refactor: partition Celery task sessions into smaller, discrete execu… (#32085)

commit 55de893984
parent b035b091fa
Author: wangxiaolei
Date: 2026-02-08 21:01:54 +08:00
Committed by: NFish
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>

9 changed files with 436 additions and 516 deletions

View File

@@ -6,7 +6,6 @@ from celery import shared_task
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
         )
     except Exception:
         logger.exception("Build index for annotation failed")
-    finally:
-        db.session.close()

View File

@@ -5,7 +5,6 @@ import click
 from celery import shared_task
 from core.rag.datasource.vdb.vector_factory import Vector
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
         logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
     except Exception:
         logger.exception("Annotation deleted index failed")
-    finally:
-        db.session.close()

View File

@@ -6,7 +6,6 @@ from celery import shared_task
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
-from extensions.ext_database import db
 from models.dataset import Dataset
 from services.dataset_service import DatasetCollectionBindingService
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
         )
     except Exception:
         logger.exception("Build index for annotation failed")
-    finally:
-        db.session.close()

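The three hunks above are the same cleanup applied to add_annotation_to_index_task, delete_annotation_index_task, and update_annotation_to_index_task: the extensions.ext_database import goes away along with the finally: db.session.close() teardown, since the tasks no longer lean on the shared scoped session. A minimal sketch of the shape such a task presumably takes after the change (the body below is illustrative, not the repository's exact code; only session_factory.create_session() and the model imports are taken from this commit):

import logging

from celery import shared_task

from core.db.session_factory import session_factory
from models.dataset import Dataset

logger = logging.getLogger(__name__)


@shared_task
def annotation_index_task_sketch(annotation_id: str, dataset_id: str):
    try:
        # The session is scoped to the lookup and closes when the block exits,
        # so no explicit finally/close is needed.
        with session_factory.create_session() as session:
            dataset = session.get(Dataset, dataset_id)
        if dataset is None:
            return
        # Vector index work happens here, outside the session, keeping the
        # database connection checkout short.
        ...
    except Exception:
        logger.exception("Build index for annotation failed")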
View File

@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
     indexing_cache_key = f"segment_batch_import_{job_id}"
+    # Initialize variables with default values
+    upload_file_key: str | None = None
+    dataset_config: dict | None = None
+    document_config: dict | None = None
     with session_factory.create_session() as session:
         try:
             dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
             if not upload_file:
                 raise ValueError("UploadFile not found.")
-            with tempfile.TemporaryDirectory() as temp_dir:
-                suffix = Path(upload_file.key).suffix
-                file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
-                storage.download(upload_file.key, file_path)
-                df = pd.read_csv(file_path)
-                content = []
-                for _, row in df.iterrows():
-                    if dataset_document.doc_form == "qa_model":
-                        data = {"content": row.iloc[0], "answer": row.iloc[1]}
-                    else:
-                        data = {"content": row.iloc[0]}
-                    content.append(data)
-                if len(content) == 0:
-                    raise ValueError("The CSV file is empty.")
-                document_segments = []
-                embedding_model = None
-                if dataset.indexing_technique == "high_quality":
-                    model_manager = ModelManager()
-                    embedding_model = model_manager.get_model_instance(
-                        tenant_id=dataset.tenant_id,
-                        provider=dataset.embedding_model_provider,
-                        model_type=ModelType.TEXT_EMBEDDING,
-                        model=dataset.embedding_model,
-                    )
-                word_count_change = 0
-                if embedding_model:
-                    tokens_list = embedding_model.get_text_embedding_num_tokens(
-                        texts=[segment["content"] for segment in content]
-                    )
-                else:
-                    tokens_list = [0] * len(content)
-                for segment, tokens in zip(content, tokens_list):
-                    content = segment["content"]
-                    doc_id = str(uuid.uuid4())
-                    segment_hash = helper.generate_text_hash(content)
-                    max_position = (
-                        session.query(func.max(DocumentSegment.position))
-                        .where(DocumentSegment.document_id == dataset_document.id)
-                        .scalar()
-                    )
-                    segment_document = DocumentSegment(
-                        tenant_id=tenant_id,
-                        dataset_id=dataset_id,
-                        document_id=document_id,
-                        index_node_id=doc_id,
-                        index_node_hash=segment_hash,
-                        position=max_position + 1 if max_position else 1,
-                        content=content,
-                        word_count=len(content),
-                        tokens=tokens,
-                        created_by=user_id,
-                        indexing_at=naive_utc_now(),
-                        status="completed",
-                        completed_at=naive_utc_now(),
-                    )
-                    if dataset_document.doc_form == "qa_model":
-                        segment_document.answer = segment["answer"]
-                        segment_document.word_count += len(segment["answer"])
-                    word_count_change += segment_document.word_count
-                    session.add(segment_document)
-                    document_segments.append(segment_document)
+            dataset_config = {
+                "id": dataset.id,
+                "indexing_technique": dataset.indexing_technique,
+                "tenant_id": dataset.tenant_id,
+                "embedding_model_provider": dataset.embedding_model_provider,
+                "embedding_model": dataset.embedding_model,
+            }
+            document_config = {
+                "id": dataset_document.id,
+                "doc_form": dataset_document.doc_form,
+                "word_count": dataset_document.word_count or 0,
+            }
+            upload_file_key = upload_file.key
+        except Exception:
+            logger.exception("Segments batch created index failed")
+            redis_client.setex(indexing_cache_key, 600, "error")
+            return
+
+    # Ensure required variables are set before proceeding
+    if upload_file_key is None or dataset_config is None or document_config is None:
+        logger.error("Required configuration not set due to session error")
+        redis_client.setex(indexing_cache_key, 600, "error")
+        return
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        suffix = Path(upload_file_key).suffix
+        file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}"  # type: ignore
+        storage.download(upload_file_key, file_path)
+        df = pd.read_csv(file_path)
+        content = []
+        for _, row in df.iterrows():
+            if document_config["doc_form"] == "qa_model":
+                data = {"content": row.iloc[0], "answer": row.iloc[1]}
+            else:
+                data = {"content": row.iloc[0]}
+            content.append(data)
+        if len(content) == 0:
+            raise ValueError("The CSV file is empty.")
+
+    document_segments = []
+    embedding_model = None
+    if dataset_config["indexing_technique"] == "high_quality":
+        model_manager = ModelManager()
+        embedding_model = model_manager.get_model_instance(
+            tenant_id=dataset_config["tenant_id"],
+            provider=dataset_config["embedding_model_provider"],
+            model_type=ModelType.TEXT_EMBEDDING,
+            model=dataset_config["embedding_model"],
+        )
+
+    word_count_change = 0
+    if embedding_model:
+        tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
+    else:
+        tokens_list = [0] * len(content)
+
+    with session_factory.create_session() as session, session.begin():
+        for segment, tokens in zip(content, tokens_list):
+            content = segment["content"]
+            doc_id = str(uuid.uuid4())
+            segment_hash = helper.generate_text_hash(content)
+            max_position = (
+                session.query(func.max(DocumentSegment.position))
+                .where(DocumentSegment.document_id == document_config["id"])
+                .scalar()
+            )
+            segment_document = DocumentSegment(
+                tenant_id=tenant_id,
+                dataset_id=dataset_id,
+                document_id=document_id,
+                index_node_id=doc_id,
+                index_node_hash=segment_hash,
+                position=max_position + 1 if max_position else 1,
+                content=content,
+                word_count=len(content),
+                tokens=tokens,
+                created_by=user_id,
+                indexing_at=naive_utc_now(),
+                status="completed",
+                completed_at=naive_utc_now(),
+            )
+            if document_config["doc_form"] == "qa_model":
+                segment_document.answer = segment["answer"]
+                segment_document.word_count += len(segment["answer"])
+            word_count_change += segment_document.word_count
+            session.add(segment_document)
+            document_segments.append(segment_document)
+
+    with session_factory.create_session() as session, session.begin():
+        dataset_document = session.get(Document, document_id)
+        if dataset_document:
             assert dataset_document.word_count is not None
             dataset_document.word_count += word_count_change
             session.add(dataset_document)
-            VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
-            session.commit()
-            redis_client.setex(indexing_cache_key, 600, "completed")
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Segment batch created job: {job_id} latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Segments batch created index failed")
-            redis_client.setex(indexing_cache_key, 600, "error")
+
+    with session_factory.create_session() as session:
+        dataset = session.get(Dataset, dataset_id)
+        if dataset:
+            VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
+
+    redis_client.setex(indexing_cache_key, 600, "completed")
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            f"Segment batch created job: {job_id} latency: {end_at - start_at}",
+            fg="green",
+        )
+    )

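Taken together, this hunk splits batch_create_segment_to_index_task into phases: one session reads what is needed into plain dicts (dataset_config, document_config, upload_file_key), the CSV download, parsing, and token counting run with no session open, and the writes happen in fresh session.begin() blocks. A condensed sketch of that ordering follows; do_expensive_work() is a hypothetical stand-in for the file and embedding work, while session_factory and Dataset come from this commit:

from core.db.session_factory import session_factory
from models.dataset import Dataset


def do_expensive_work(config: dict) -> list:
    # Hypothetical stand-in for the CSV parsing / embedding token counting above.
    return []


def phased_task_sketch(dataset_id: str) -> None:
    # Phase 1: read into plain values, then release the connection.
    with session_factory.create_session() as session:
        dataset = session.get(Dataset, dataset_id)
        if dataset is None:
            return
        dataset_config = {"id": dataset.id, "tenant_id": dataset.tenant_id}

    # Phase 2: slow work (downloads, model calls) with no session held.
    rows = do_expensive_work(dataset_config)

    # Phase 3: a short write transaction; session.begin() commits on a clean exit.
    with session_factory.create_session() as session, session.begin():
        for row in rows:
            session.add(row)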
View File

@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
""" """
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter() start_at = time.perf_counter()
total_attachment_files = []
with session_factory.create_session() as session: with session_factory.create_session() as session:
try: try:
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
                     SegmentAttachmentBinding.document_id == document_id,
                 )
             ).all()
-            # check segment is exist
-            if segments:
-                index_node_ids = [segment.index_node_id for segment in segments]
-                index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+            attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
+            binding_ids = [binding.id for binding, _ in attachments_with_bindings]
+            total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
+            index_node_ids = [segment.index_node_id for segment in segments]
+            segment_contents = [segment.content for segment in segments]
+        except Exception:
+            logger.exception("Cleaned document when document deleted failed")
+            return
+
+    # check segment is exist
+    if index_node_ids:
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        with session_factory.create_session() as session:
+            dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
+            if dataset:
                 index_processor.clean(
                     dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
                 )
-            for segment in segments:
-                image_upload_file_ids = get_image_upload_file_ids(segment.content)
-                image_files = session.scalars(
-                    select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                ).all()
-                for image_file in image_files:
-                    if image_file is None:
-                        continue
-                    try:
-                        storage.delete(image_file.key)
-                    except Exception:
-                        logger.exception(
-                            "Delete image_files failed when storage deleted, \
-                            image_upload_file_is: %s",
-                            image_file.id,
-                        )
-                image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
-                session.execute(image_file_delete_stmt)
-                session.delete(segment)
-            session.commit()
-            if file_id:
-                file = session.query(UploadFile).where(UploadFile.id == file_id).first()
-                if file:
-                    try:
-                        storage.delete(file.key)
-                    except Exception:
-                        logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
-                    session.delete(file)
-            # delete segment attachments
-            if attachments_with_bindings:
-                attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
-                binding_ids = [binding.id for binding, _ in attachments_with_bindings]
-                for binding, attachment_file in attachments_with_bindings:
-                    try:
-                        storage.delete(attachment_file.key)
-                    except Exception:
-                        logger.exception(
-                            "Delete attachment_file failed when storage deleted, \
-                            attachment_file_id: %s",
-                            binding.attachment_id,
-                        )
-                attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
-                session.execute(attachment_file_delete_stmt)
-                binding_delete_stmt = delete(SegmentAttachmentBinding).where(
-                    SegmentAttachmentBinding.id.in_(binding_ids)
-                )
-                session.execute(binding_delete_stmt)
-            # delete dataset metadata binding
-            session.query(DatasetMetadataBinding).where(
-                DatasetMetadataBinding.dataset_id == dataset_id,
-                DatasetMetadataBinding.document_id == document_id,
-            ).delete()
-            session.commit()
-            end_at = time.perf_counter()
-            logger.info(
-                click.style(
-                    f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
-                    fg="green",
-                )
-            )
-        except Exception:
-            logger.exception("Cleaned document when document deleted failed")
+    total_image_files = []
+    with session_factory.create_session() as session, session.begin():
+        for segment_content in segment_contents:
+            image_upload_file_ids = get_image_upload_file_ids(segment_content)
+            image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
+            total_image_files.extend([image_file.key for image_file in image_files])
+            image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
+            session.execute(image_file_delete_stmt)
+
+    with session_factory.create_session() as session, session.begin():
+        segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
+        session.execute(segment_delete_stmt)
+
+    for image_file_key in total_image_files:
+        try:
+            storage.delete(image_file_key)
+        except Exception:
+            logger.exception(
+                "Delete image_files failed when storage deleted, \
+                image_upload_file_is: %s",
+                image_file_key,
+            )
+
+    with session_factory.create_session() as session, session.begin():
+        if file_id:
+            file = session.query(UploadFile).where(UploadFile.id == file_id).first()
+            if file:
+                try:
+                    storage.delete(file.key)
+                except Exception:
+                    logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
+                session.delete(file)
+
+    with session_factory.create_session() as session, session.begin():
+        # delete segment attachments
+        if attachment_ids:
+            attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
+            session.execute(attachment_file_delete_stmt)
+        if binding_ids:
+            binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
+            session.execute(binding_delete_stmt)
+
+    for attachment_file_key in total_attachment_files:
+        try:
+            storage.delete(attachment_file_key)
+        except Exception:
+            logger.exception(
+                "Delete attachment_file failed when storage deleted, \
+                attachment_file_id: %s",
+                attachment_file_key,
+            )
+
+    with session_factory.create_session() as session, session.begin():
+        # delete dataset metadata binding
+        session.query(DatasetMetadataBinding).where(
+            DatasetMetadataBinding.dataset_id == dataset_id,
+            DatasetMetadataBinding.document_id == document_id,
+        ).delete()
+
+    end_at = time.perf_counter()
+    logger.info(
+        click.style(
+            f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
+            fg="green",
+        )
+    )

View File

@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
             session.commit()
             return
-        for document_id in document_ids:
-            logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-            document = (
-                session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
-            )
-            if document:
-                document.indexing_status = "parsing"
-                document.processing_started_at = naive_utc_now()
-                documents.append(document)
-                session.add(document)
-        session.commit()
-        try:
-            indexing_runner = IndexingRunner()
-            indexing_runner.run(documents)
-            end_at = time.perf_counter()
-            logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+    # Phase 1: Update status to parsing (short transaction)
+    with session_factory.create_session() as session, session.begin():
+        documents = (
+            session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
+        )
+        for document in documents:
+            if document:
+                document.indexing_status = "parsing"
+                document.processing_started_at = naive_utc_now()
+                session.add(document)
+    # Transaction committed and closed
+
+    # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
+    has_error = False
+    try:
+        indexing_runner = IndexingRunner()
+        indexing_runner.run(documents)
+        end_at = time.perf_counter()
+        logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
+    except DocumentIsPausedError as ex:
+        logger.info(click.style(str(ex), fg="yellow"))
+        has_error = True
+    except Exception:
+        logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
+        has_error = True
+
+    if not has_error:
+        with session_factory.create_session() as session:
             # Trigger summary index generation for completed documents if enabled
             # Only generate for high_quality indexing technique and when summary_index_setting is enabled
             # Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
                     # expire all session to get latest document's indexing status
                     session.expire_all()
                     # Check each document's indexing status and trigger summary generation if completed
-                    for document_id in document_ids:
-                        # Re-query document to get latest status (IndexingRunner may have updated it)
-                        document = (
-                            session.query(Document)
-                            .where(Document.id == document_id, Document.dataset_id == dataset_id)
-                            .first()
-                        )
+                    documents = (
+                        session.query(Document)
+                        .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
+                        .all()
+                    )
+                    for document in documents:
                         if document:
                             logger.info(
                                 "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
-                                document_id,
+                                document.id,
                                 document.indexing_status,
                                 document.doc_form,
                                 document.need_summary,
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
                                 and document.need_summary is True
                             ):
                                 try:
-                                    generate_summary_index_task.delay(dataset.id, document_id, None)
+                                    generate_summary_index_task.delay(dataset.id, document.id, None)
                                     logger.info(
                                         "Queued summary index generation task for document %s in dataset %s "
                                         "after indexing completed",
-                                        document_id,
+                                        document.id,
                                         dataset.id,
                                     )
                                 except Exception:
                                     logger.exception(
                                         "Failed to queue summary index generation task for document %s",
-                                        document_id,
+                                        document.id,
                                     )
                                     # Don't fail the entire indexing process if summary task queuing fails
                             else:
                                 logger.info(
                                     "Skipping summary generation for document %s: "
                                     "status=%s, doc_form=%s, need_summary=%s",
-                                    document_id,
+                                    document.id,
                                     document.indexing_status,
                                     document.doc_form,
                                     document.need_summary,
                                 )
                         else:
-                            logger.warning("Document %s not found after indexing", document_id)
-                else:
-                    logger.info(
-                        "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
-                        dataset.id,
-                        summary_index_setting.get("enable") if summary_index_setting else None,
-                    )
+                            logger.warning("Document %s not found after indexing", document.id)
             else:
                 logger.info(
                     "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
                     dataset.id,
                     dataset.indexing_technique,
                 )
-        except DocumentIsPausedError as ex:
-            logger.info(click.style(str(ex), fg="yellow"))
-        except Exception:
-            logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)


 def _document_indexing_with_tenant_queue(

View File

@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
""" """
from celery import shared_task # type: ignore[import-untyped] from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
self, self,
deletions: list[DraftVarFileDeletion], deletions: list[DraftVarFileDeletion],
): ):
with Session(bind=db.engine) as session, session.begin(): with session_factory.create_session() as session, session.begin():
srv = WorkflowDraftVariableService(session=session) srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions) srv.delete_workflow_draft_variable_file(deletions=deletions)

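The only behavioral assumption in this hunk is that session_factory.create_session() hands back the same kind of context-managed SQLAlchemy Session that Session(bind=db.engine) did, so the surrounding session.begin() transaction semantics stay the same. A rough, self-contained sketch of that assumed equivalence (not the factory's actual implementation):

from contextlib import contextmanager

from sqlalchemy.orm import Session, sessionmaker


@contextmanager
def create_session_sketch(engine):
    # Hypothetical stand-in for core.db.session_factory.session_factory.create_session().
    maker = sessionmaker(bind=engine)
    session: Session = maker()
    try:
        yield session
    finally:
        session.close()

Used as in the diff, "with create_session_sketch(engine) as session, session.begin():", the transaction commits when the block exits cleanly and rolls back if the block raises.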
View File

@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
         mock_storage.download.side_effect = mock_download

-        # Execute the task
+        # Execute the task - should raise ValueError for empty CSV
         job_id = str(uuid.uuid4())
-        batch_create_segment_to_index_task(
-            job_id=job_id,
-            upload_file_id=upload_file.id,
-            dataset_id=dataset.id,
-            document_id=document.id,
-            tenant_id=tenant.id,
-            user_id=account.id,
-        )
+        with pytest.raises(ValueError, match="The CSV file is empty"):
+            batch_create_segment_to_index_task(
+                job_id=job_id,
+                upload_file_id=upload_file.id,
+                dataset_id=dataset.id,
+                document_id=document.id,
+                tenant_id=tenant.id,
+                user_id=account.id,
+            )

         # Verify error handling
-        # Check Redis cache was set to error status
-        from extensions.ext_redis import redis_client
-
-        cache_key = f"segment_batch_import_{job_id}"
-        cache_value = redis_client.get(cache_key)
-        assert cache_value == b"error"
-
-        # Verify no segments were created
+        # Since exception was raised, no segments should be created
         from extensions.ext_database import db

         segments = db.session.query(DocumentSegment).all()

View File

@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
 def mock_db_session():
     """Mock database session via session_factory.create_session()."""
     with patch("tasks.document_indexing_task.session_factory") as mock_sf:
-        session = MagicMock()
-        # Ensure tests that expect session.close() to be called can observe it via the context manager
-        session.close = MagicMock()
-        cm = MagicMock()
-        cm.__enter__.return_value = session
-
-        # Link __exit__ to session.close so "close" expectations reflect context manager teardown
-        def _exit_side_effect(*args, **kwargs):
-            session.close()
-
-        cm.__exit__.side_effect = _exit_side_effect
-        mock_sf.create_session.return_value = cm
-
-        query = MagicMock()
-        session.query.return_value = query
-        query.where.return_value = query
-        yield session
+        sessions = []  # Track all created sessions
+        # Shared mock data that all sessions will access
+        shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
+
+        def create_session_side_effect():
+            session = MagicMock()
+            session.close = MagicMock()
+            # Track commit calls
+            commit_mock = MagicMock()
+            session.commit = commit_mock
+            cm = MagicMock()
+            cm.__enter__.return_value = session
+
+            def _exit_side_effect(*args, **kwargs):
+                session.close()
+
+            cm.__exit__.side_effect = _exit_side_effect
+
+            # Support session.begin() for transactions
+            begin_cm = MagicMock()
+            begin_cm.__enter__.return_value = session
+
+            def begin_exit_side_effect(*args, **kwargs):
+                # Auto-commit on transaction exit (like SQLAlchemy)
+                session.commit()
+                # Also mark wrapper's commit as called
+                if sessions:
+                    sessions[0].commit()
+
+            begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
+            session.begin = MagicMock(return_value=begin_cm)
+
+            sessions.append(session)
+
+            # Setup query with side_effect to handle both Dataset and Document queries
+            def query_side_effect(*args):
+                query = MagicMock()
+                if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
+                    where_result = MagicMock()
+                    where_result.first.return_value = shared_mock_data["dataset"]
+                    query.where = MagicMock(return_value=where_result)
+                elif args and args[0] == Document and shared_mock_data["documents"] is not None:
+                    # Support both .first() and .all() calls with chaining
+                    where_result = MagicMock()
+                    where_result.where = MagicMock(return_value=where_result)
+                    # Create an iterator for .first() calls if not exists
+                    if shared_mock_data["doc_iter"] is None:
+                        docs = shared_mock_data["documents"] or [None]
+                        shared_mock_data["doc_iter"] = iter(docs)
+                    where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
+                    docs_or_empty = shared_mock_data["documents"] or []
+                    where_result.all = MagicMock(return_value=docs_or_empty)
+                    query.where = MagicMock(return_value=where_result)
+                else:
+                    query.where = MagicMock(return_value=query)
+                return query
+
+            session.query = MagicMock(side_effect=query_side_effect)
+            return cm
+
+        mock_sf.create_session.side_effect = create_session_side_effect
+
+        # Create a wrapper that behaves like the first session but has access to all sessions
+        class SessionWrapper:
+            def __init__(self):
+                self._sessions = sessions
+                self._shared_data = shared_mock_data
+                # Create a default session for setup phase
+                self._default_session = MagicMock()
+                self._default_session.close = MagicMock()
+                self._default_session.commit = MagicMock()
+
+                # Support session.begin() for default session too
+                begin_cm = MagicMock()
+                begin_cm.__enter__.return_value = self._default_session
+
+                def default_begin_exit_side_effect(*args, **kwargs):
+                    self._default_session.commit()
+
+                begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
+                self._default_session.begin = MagicMock(return_value=begin_cm)
+
+                def default_query_side_effect(*args):
+                    query = MagicMock()
+                    if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
+                        where_result = MagicMock()
+                        where_result.first.return_value = shared_mock_data["dataset"]
+                        query.where = MagicMock(return_value=where_result)
+                    elif args and args[0] == Document and shared_mock_data["documents"] is not None:
+                        where_result = MagicMock()
+                        where_result.where = MagicMock(return_value=where_result)
+                        if shared_mock_data["doc_iter"] is None:
+                            docs = shared_mock_data["documents"] or [None]
+                            shared_mock_data["doc_iter"] = iter(docs)
+                        where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
+                        docs_or_empty = shared_mock_data["documents"] or []
+                        where_result.all = MagicMock(return_value=docs_or_empty)
+                        query.where = MagicMock(return_value=where_result)
+                    else:
+                        query.where = MagicMock(return_value=query)
+                    return query
+
+                self._default_session.query = MagicMock(side_effect=default_query_side_effect)
+
+            def __getattr__(self, name):
+                # Forward all attribute access to the first session, or default if none created yet
+                target_session = self._sessions[0] if self._sessions else self._default_session
+                return getattr(target_session, name)
+
+            @property
+            def all_sessions(self):
+                """Access all created sessions for testing."""
+                return self._sessions
+
+        wrapper = SessionWrapper()
+        yield wrapper


 @pytest.fixture
@@ -252,18 +356,9 @@ class TestTaskEnqueuing:
         use the deprecated function.
         """
         # Arrange
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                # Return documents one by one for each call
-                mock_query.where.return_value.first.side_effect = mock_documents
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -304,21 +399,9 @@ class TestBatchProcessing:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        # Create an iterator for documents
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                # Return documents one by one for each call
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -357,19 +440,9 @@ class TestBatchProcessing:
             doc.stopped_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         mock_feature_service.get_features.return_value.billing.enabled = True
         mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
@@ -407,19 +480,9 @@ class TestBatchProcessing:
             doc.stopped_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         mock_feature_service.get_features.return_value.billing.enabled = True
         mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
@@ -444,7 +507,10 @@ class TestBatchProcessing:
""" """
# Arrange # Arrange
document_ids = [] document_ids = []
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
# Set shared mock data with empty documents list
mock_db_session._shared_data["dataset"] = mock_dataset
mock_db_session._shared_data["documents"] = []
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
mock_features.return_value.billing.enabled = False mock_features.return_value.billing.enabled = False
@@ -482,19 +548,9 @@ class TestProgressTracking:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -528,19 +584,9 @@ class TestProgressTracking:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -635,19 +681,9 @@ class TestErrorHandling:
             doc.stopped_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Set up to trigger vector space limit error
         mock_feature_service.get_features.return_value.billing.enabled = True
@@ -674,17 +710,9 @@ class TestErrorHandling:
         Errors during indexing should be caught and logged, but not crash the task.
         """
         # Arrange
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first.side_effect = mock_documents
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Make IndexingRunner raise an exception
         mock_indexing_runner.run.side_effect = Exception("Indexing failed")
@@ -708,17 +736,9 @@ class TestErrorHandling:
         but not treated as a failure.
         """
         # Arrange
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first.side_effect = mock_documents
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Make IndexingRunner raise DocumentIsPausedError
         mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
@@ -853,17 +873,9 @@ class TestTaskCancellation:
         Session cleanup should happen in finally block.
         """
         # Arrange
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first.side_effect = mock_documents
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -883,17 +895,9 @@ class TestTaskCancellation:
         Session cleanup should happen even when errors occur.
         """
         # Arrange
-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first.side_effect = mock_documents
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Make IndexingRunner raise an exception
         mock_indexing_runner.run.side_effect = Exception("Test error")
@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
         document_ids = [str(uuid.uuid4()) for _ in range(3)]

         # Create only 2 documents (simulate one missing)
+        # The new code uses .all() which will only return existing documents
         mock_documents = []
         for i, doc_id in enumerate([document_ids[0], document_ids[2]]):  # Skip middle one
             doc = MagicMock(spec=Document)
@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        # Create iterator that returns None for missing document
-        doc_responses = [mock_documents[0], None, mock_documents[1]]
-        doc_iter = iter(doc_responses)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data - .all() will only return existing documents
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
             doc.stopped_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Set vector space exactly at limit
         mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Billing disabled - limits should not be checked
         mock_feature_service.get_features.return_value.billing.enabled = False
@@ -1273,19 +1246,9 @@ class TestIntegration:
         # Set up rpop to return None for concurrency check (no more tasks)
         mock_redis.rpop.side_effect = [None]

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1321,19 +1284,9 @@ class TestIntegration:
         # Set up rpop to return None for concurrency check (no more tasks)
         mock_redis.rpop.side_effect = [None]

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1415,17 +1368,9 @@ class TestEdgeCases:
         mock_document.indexing_status = "waiting"
         mock_document.processing_started_at = None

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: mock_document
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = [mock_document]

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1465,17 +1410,9 @@ class TestEdgeCases:
         mock_document.indexing_status = "waiting"
         mock_document.processing_started_at = None

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: mock_document
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = [mock_document]

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1555,19 +1492,9 @@ class TestEdgeCases:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Set vector space limit to 0 (unlimited)
         mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1612,19 +1539,9 @@ class TestEdgeCases:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Set negative vector space limit
         mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Configure billing with sufficient limits
         mock_feature_service.get_features.return_value.billing.enabled = True
@@ -1826,19 +1733,9 @@ class TestRobustness:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         # Make IndexingRunner raise an exception
         mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
@@ -1866,7 +1763,7 @@ class TestRobustness:
         - No exceptions occur

         Expected behavior:
-        - Database session is closed
+        - All database sessions are closed
         - No connection leaks
         """
         # Arrange
@@ -1879,19 +1776,9 @@ class TestRobustness:
             doc.processing_started_at = None
             mock_documents.append(doc)

-        mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
-
-        doc_iter = iter(mock_documents)
-
-        def mock_query_side_effect(*args):
-            mock_query = MagicMock()
-            if args[0] == Dataset:
-                mock_query.where.return_value.first.return_value = mock_dataset
-            elif args[0] == Document:
-                mock_query.where.return_value.first = lambda: next(doc_iter, None)
-            return mock_query
-
-        mock_db_session.query.side_effect = mock_query_side_effect
+        # Set shared mock data so all sessions can access it
+        mock_db_session._shared_data["dataset"] = mock_dataset
+        mock_db_session._shared_data["documents"] = mock_documents

         with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
             mock_features.return_value.billing.enabled = False
@@ -1899,10 +1786,11 @@ class TestRobustness:
         # Act
         _document_indexing(dataset_id, document_ids)

-        # Assert
-        assert mock_db_session.close.called
-        # Verify close is called exactly once
-        assert mock_db_session.close.call_count == 1
+        # Assert - All created sessions should be closed
+        # The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
+        assert len(mock_db_session.all_sessions) >= 1
+        for session in mock_db_session.all_sessions:
+            assert session.close.called, "All sessions should be closed"

     def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
         """