diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 23c49f2742..a9a8b892c2 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -58,5 +57,3 @@ def add_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index e928c25546..432732af95 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,7 +5,6 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) except Exception: logger.exception("Annotation deleted index failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 957d8f7e45..6ff34c0e74 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -59,5 +58,3 @@ def update_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 8ee09d5738..f69f17b16d 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -48,6 +48,11 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" + # Initialize variables with default values + upload_file_key: str | None = None + dataset_config: dict | None = None + document_config: dict | None = None + with session_factory.create_session() as session: try: dataset = session.get(Dataset, dataset_id) @@ -69,86 +74,115 @@ def batch_create_segment_to_index_task( if not upload_file: raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - storage.download(upload_file.key, file_path) + dataset_config = { + "id": dataset.id, + "indexing_technique": dataset.indexing_technique, + "tenant_id": dataset.tenant_id, + "embedding_model_provider": dataset.embedding_model_provider, + "embedding_model": dataset.embedding_model, + } - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): - if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + document_config = { + "id": dataset_document.id, + "doc_form": dataset_document.doc_form, + "word_count": dataset_document.word_count or 0, + } - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) + upload_file_key = upload_file.key - word_count_change = 0 - if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens( - texts=[segment["content"] for segment in content] - ) + except Exception: + logger.exception("Segments batch created index failed") + redis_client.setex(indexing_cache_key, 600, "error") + return + + # Ensure required variables are set before proceeding + if upload_file_key is None or dataset_config is None or document_config is None: + logger.error("Required configuration not set due to session error") + redis_client.setex(indexing_cache_key, 600, "error") + return + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file_key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file_key, file_path) + + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if document_config["doc_form"] == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} else: - tokens_list = [0] * len(content) + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") - for segment, tokens in zip(content, tokens_list): - content = segment["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - max_position = ( - session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == dataset_document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=tenant_id, - dataset_id=dataset_id, - document_id=document_id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - created_by=user_id, - indexing_at=naive_utc_now(), - status="completed", - completed_at=naive_utc_now(), - ) - if dataset_document.doc_form == "qa_model": - segment_document.answer = segment["answer"] - segment_document.word_count += len(segment["answer"]) - word_count_change += segment_document.word_count - session.add(segment_document) - document_segments.append(segment_document) + document_segments = [] + embedding_model = None + if dataset_config["indexing_technique"] == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset_config["tenant_id"], + provider=dataset_config["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=dataset_config["embedding_model"], + ) + word_count_change = 0 + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content]) + else: + tokens_list = [0] * len(content) + + with session_factory.create_session() as session, session.begin(): + for segment, tokens in zip(content, tokens_list): + content = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + max_position = ( + session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document_config["id"]) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=naive_utc_now(), + status="completed", + completed_at=naive_utc_now(), + ) + if document_config["doc_form"] == "qa_model": + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count + session.add(segment_document) + document_segments.append(segment_document) + + with session_factory.create_session() as session, session.begin(): + dataset_document = session.get(Document, document_id) + if dataset_document: assert dataset_document.word_count is not None dataset_document.word_count += word_count_change session.add(dataset_document) - VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) - session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - f"Segment batch created job: {job_id} latency: {end_at - start_at}", - fg="green", - ) - ) - except Exception: - logger.exception("Segments batch created index failed") - redis_client.setex(indexing_cache_key, 600, "error") + with session_factory.create_session() as session: + dataset = session.get(Dataset, dataset_id) + if dataset: + VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"]) + + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment batch created job: {job_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 91ace6be02..a017e9114b 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i """ logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) start_at = time.perf_counter() + total_attachment_files = [] with session_factory.create_session() as session: try: @@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i SegmentAttachmentBinding.document_id == document_id, ) ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() + + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for binding, _ in attachments_with_bindings] + total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings]) + + index_node_ids = [segment.index_node_id for segment in segments] + segment_contents = [segment.content for segment in segments] + except Exception: + logger.exception("Cleaned document when document deleted failed") + return + + # check segment is exist + if index_node_ids: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True ) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.scalars( - select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - ).all() - for image_file in image_files: - if image_file is None: - continue - try: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) + total_image_files = [] + with session_factory.create_session() as session, session.begin(): + for segment_content in segment_contents: + image_upload_file_ids = get_image_upload_file_ids(segment_content) + image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all() + total_image_files.extend([image_file.key for image_file in image_files]) + image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(image_file_delete_stmt) - image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(image_file_delete_stmt) - session.delete(segment) + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) - session.commit() - if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() - if file: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file_id) - session.delete(file) - # delete segment attachments - if attachments_with_bindings: - attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] - binding_ids = [binding.id for binding, _ in attachments_with_bindings] - for binding, attachment_file in attachments_with_bindings: - try: - storage.delete(attachment_file.key) - except Exception: - logger.exception( - "Delete attachment_file failed when storage deleted, \ - attachment_file_id: %s", - binding.attachment_id, - ) - attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) - session.execute(attachment_file_delete_stmt) - - binding_delete_stmt = delete(SegmentAttachmentBinding).where( - SegmentAttachmentBinding.id.in_(binding_ids) - ) - session.execute(binding_delete_stmt) - - # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", - fg="green", - ) - ) + for image_file_key in total_image_files: + try: + storage.delete(image_file_key) except Exception: - logger.exception("Cleaned document when document deleted failed") + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + if file_id: + file = session.query(UploadFile).where(UploadFile.id == file_id).first() + if file: + try: + storage.delete(file.key) + except Exception: + logger.exception("Delete file failed when document deleted, file_id: %s", file_id) + session.delete(file) + + with session_factory.create_session() as session, session.begin(): + # delete segment attachments + if attachment_ids: + attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) + + if binding_ids: + binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids)) + session.execute(binding_delete_stmt) + + for attachment_file_key in total_attachment_files: + try: + storage.delete(attachment_file_key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + attachment_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + # delete dataset metadata binding + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ).delete() + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 34496e9c6f..11edcf151f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.commit() return - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) + # Phase 1: Update status to parsing (short transaction) + with session_factory.create_session() as session, session.begin(): + documents = ( + session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + ) + for document in documents: if document: document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - documents.append(document) session.add(document) - session.commit() + # Transaction committed and closed - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions) + has_error = False + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + has_error = True + except Exception: + logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) + has_error = True + if not has_error: + with session_factory.create_session() as session: # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) @@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # expire all session to get latest document's indexing status session.expire_all() # Check each document's indexing status and trigger summary generation if completed - for document_id in document_ids: - # Re-query document to get latest status (IndexingRunner may have updated it) - document = ( - session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() - ) + + documents = ( + session.query(Document) + .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + .all() + ) + + for document in documents: if document: logger.info( "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, @@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): and document.need_summary is True ): try: - generate_summary_index_task.delay(dataset.id, document_id, None) + generate_summary_index_task.delay(dataset.id, document.id, None) logger.info( "Queued summary index generation task for document %s in dataset %s " "after indexing completed", - document_id, + document.id, dataset.id, ) except Exception: logger.exception( "Failed to queue summary index generation task for document %s", - document_id, + document.id, ) # Don't fail the entire indexing process if summary task queuing fails else: logger.info( "Skipping summary generation for document %s: " "status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, ) else: - logger.warning("Document %s not found after indexing", document_id) - else: - logger.info( - "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", - dataset.id, - summary_index_setting.get("enable") if summary_index_setting else None, - ) + logger.warning("Document %s not found after indexing", document.id) else: logger.info( "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", dataset.id, dataset.indexing_technique, ) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) def _document_indexing_with_tenant_queue( diff --git a/api/tasks/workflow_draft_var_tasks.py b/api/tasks/workflow_draft_var_tasks.py index fcb98ec39e..26f8f7c29e 100644 --- a/api/tasks/workflow_draft_var_tasks.py +++ b/api/tasks/workflow_draft_var_tasks.py @@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers. """ from celery import shared_task # type: ignore[import-untyped] -from sqlalchemy.orm import Session -from extensions.ext_database import db +from core.db.session_factory import session_factory from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService @@ -17,6 +16,6 @@ def save_workflow_execution_task( self, deletions: list[DraftVarFileDeletion], ): - with Session(bind=db.engine) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): srv = WorkflowDraftVariableService(session=session) srv.delete_workflow_draft_variable_file(deletions=deletions) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 1b844d6357..61f6b75b10 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask: mock_storage.download.side_effect = mock_download - # Execute the task + # Execute the task - should raise ValueError for empty CSV job_id = str(uuid.uuid4()) - batch_create_segment_to_index_task( - job_id=job_id, - upload_file_id=upload_file.id, - dataset_id=dataset.id, - document_id=document.id, - tenant_id=tenant.id, - user_id=account.id, - ) + with pytest.raises(ValueError, match="The CSV file is empty"): + batch_create_segment_to_index_task( + job_id=job_id, + upload_file_id=upload_file.id, + dataset_id=dataset.id, + document_id=document.id, + tenant_id=tenant.id, + user_id=account.id, + ) # Verify error handling - # Check Redis cache was set to error status - from extensions.ext_redis import redis_client - - cache_key = f"segment_batch_import_{job_id}" - cache_value = redis_client.get(cache_key) - assert cache_value == b"error" - - # Verify no segments were created + # Since exception was raised, no segments should be created from extensions.ext_database import db segments = db.session.query(DocumentSegment).all() diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index e24ef32a24..8d8e2b0db0 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id): def mock_db_session(): """Mock database session via session_factory.create_session().""" with patch("tasks.document_indexing_task.session_factory") as mock_sf: - session = MagicMock() - # Ensure tests that expect session.close() to be called can observe it via the context manager - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - # Link __exit__ to session.close so "close" expectations reflect context manager teardown + sessions = [] # Track all created sessions + # Shared mock data that all sessions will access + shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None} - def _exit_side_effect(*args, **kwargs): - session.close() + def create_session_side_effect(): + session = MagicMock() + session.close = MagicMock() - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm + # Track commit calls + commit_mock = MagicMock() + session.commit = commit_mock + cm = MagicMock() + cm.__enter__.return_value = session - query = MagicMock() - session.query.return_value = query - query.where.return_value = query - yield session + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + + # Support session.begin() for transactions + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def begin_exit_side_effect(*args, **kwargs): + # Auto-commit on transaction exit (like SQLAlchemy) + session.commit() + # Also mark wrapper's commit as called + if sessions: + sessions[0].commit() + + begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect) + session.begin = MagicMock(return_value=begin_cm) + + sessions.append(session) + + # Setup query with side_effect to handle both Dataset and Document queries + def query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + # Support both .first() and .all() calls with chaining + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + # Create an iterator for .first() calls if not exists + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return query + + session.query = MagicMock(side_effect=query_side_effect) + return cm + + mock_sf.create_session.side_effect = create_session_side_effect + + # Create a wrapper that behaves like the first session but has access to all sessions + class SessionWrapper: + def __init__(self): + self._sessions = sessions + self._shared_data = shared_mock_data + # Create a default session for setup phase + self._default_session = MagicMock() + self._default_session.close = MagicMock() + self._default_session.commit = MagicMock() + + # Support session.begin() for default session too + begin_cm = MagicMock() + begin_cm.__enter__.return_value = self._default_session + + def default_begin_exit_side_effect(*args, **kwargs): + self._default_session.commit() + + begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect) + self._default_session.begin = MagicMock(return_value=begin_cm) + + def default_query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return query + + self._default_session.query = MagicMock(side_effect=default_query_side_effect) + + def __getattr__(self, name): + # Forward all attribute access to the first session, or default if none created yet + target_session = self._sessions[0] if self._sessions else self._default_session + return getattr(target_session, name) + + @property + def all_sessions(self): + """Access all created sessions for testing.""" + return self._sessions + + wrapper = SessionWrapper() + yield wrapper @pytest.fixture @@ -252,18 +356,9 @@ class TestTaskEnqueuing: use the deprecated function. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -304,21 +399,9 @@ class TestBatchProcessing: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create an iterator for documents - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -357,19 +440,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL @@ -407,19 +480,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX @@ -444,7 +507,10 @@ class TestBatchProcessing: """ # Arrange document_ids = [] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + # Set shared mock data with empty documents list + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -482,19 +548,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -528,19 +584,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -635,19 +681,9 @@ class TestErrorHandling: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set up to trigger vector space limit error mock_feature_service.get_features.return_value.billing.enabled = True @@ -674,17 +710,9 @@ class TestErrorHandling: Errors during indexing should be caught and logged, but not crash the task. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Indexing failed") @@ -708,17 +736,9 @@ class TestErrorHandling: but not treated as a failure. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise DocumentIsPausedError mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") @@ -853,17 +873,9 @@ class TestTaskCancellation: Session cleanup should happen in finally block. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -883,17 +895,9 @@ class TestTaskCancellation: Session cleanup should happen even when errors occur. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Test error") @@ -962,6 +966,7 @@ class TestAdvancedScenarios: document_ids = [str(uuid.uuid4()) for _ in range(3)] # Create only 2 documents (simulate one missing) + # The new code uses .all() which will only return existing documents mock_documents = [] for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one doc = MagicMock(spec=Document) @@ -971,21 +976,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create iterator that returns None for missing document - doc_responses = [mock_documents[0], None, mock_documents[1]] - doc_iter = iter(doc_responses) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data - .all() will only return existing documents + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1075,19 +1068,9 @@ class TestAdvancedScenarios: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space exactly at limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1219,19 +1202,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Billing disabled - limits should not be checked mock_feature_service.get_features.return_value.billing.enabled = False @@ -1273,19 +1246,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1321,19 +1284,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1415,17 +1368,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1465,17 +1410,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1555,19 +1492,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space limit to 0 (unlimited) mock_feature_service.get_features.return_value.billing.enabled = True @@ -1612,19 +1539,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set negative vector space limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1675,19 +1592,9 @@ class TestPerformanceScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Configure billing with sufficient limits mock_feature_service.get_features.return_value.billing.enabled = True @@ -1826,19 +1733,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error") @@ -1866,7 +1763,7 @@ class TestRobustness: - No exceptions occur Expected behavior: - - Database session is closed + - All database sessions are closed - No connection leaks """ # Arrange @@ -1879,19 +1776,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1899,10 +1786,11 @@ class TestRobustness: # Act _document_indexing(dataset_id, document_ids) - # Assert - assert mock_db_session.close.called - # Verify close is called exactly once - assert mock_db_session.close.call_count == 1 + # Assert - All created sessions should be closed + # The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary) + assert len(mock_db_session.all_sessions) >= 1 + for session in mock_db_session.all_sessions: + assert session.close.called, "All sessions should be closed" def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis): """