mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 23:20:12 -05:00
refactor: partition Celery task sessions into smaller, discrete execu… (#32085)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -6,7 +6,6 @@ from celery import shared_task
|
|||||||
|
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from services.dataset_service import DatasetCollectionBindingService
|
from services.dataset_service import DatasetCollectionBindingService
|
||||||
|
|
||||||
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Build index for annotation failed")
|
logger.exception("Build index for annotation failed")
|
||||||
finally:
|
|
||||||
db.session.close()
|
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ import click
|
|||||||
from celery import shared_task
|
from celery import shared_task
|
||||||
|
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from services.dataset_service import DatasetCollectionBindingService
|
from services.dataset_service import DatasetCollectionBindingService
|
||||||
|
|
||||||
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
|
|||||||
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
|
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Annotation deleted index failed")
|
logger.exception("Annotation deleted index failed")
|
||||||
finally:
|
|
||||||
db.session.close()
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from celery import shared_task
|
|||||||
|
|
||||||
from core.rag.datasource.vdb.vector_factory import Vector
|
from core.rag.datasource.vdb.vector_factory import Vector
|
||||||
from core.rag.models.document import Document
|
from core.rag.models.document import Document
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.dataset import Dataset
|
from models.dataset import Dataset
|
||||||
from services.dataset_service import DatasetCollectionBindingService
|
from services.dataset_service import DatasetCollectionBindingService
|
||||||
|
|
||||||
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Build index for annotation failed")
|
logger.exception("Build index for annotation failed")
|
||||||
finally:
|
|
||||||
db.session.close()
|
|
||||||
|
|||||||
@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
|
|||||||
|
|
||||||
indexing_cache_key = f"segment_batch_import_{job_id}"
|
indexing_cache_key = f"segment_batch_import_{job_id}"
|
||||||
|
|
||||||
|
# Initialize variables with default values
|
||||||
|
upload_file_key: str | None = None
|
||||||
|
dataset_config: dict | None = None
|
||||||
|
document_config: dict | None = None
|
||||||
|
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
try:
|
try:
|
||||||
dataset = session.get(Dataset, dataset_id)
|
dataset = session.get(Dataset, dataset_id)
|
||||||
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
|
|||||||
if not upload_file:
|
if not upload_file:
|
||||||
raise ValueError("UploadFile not found.")
|
raise ValueError("UploadFile not found.")
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as temp_dir:
|
dataset_config = {
|
||||||
suffix = Path(upload_file.key).suffix
|
"id": dataset.id,
|
||||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
"indexing_technique": dataset.indexing_technique,
|
||||||
storage.download(upload_file.key, file_path)
|
"tenant_id": dataset.tenant_id,
|
||||||
|
"embedding_model_provider": dataset.embedding_model_provider,
|
||||||
|
"embedding_model": dataset.embedding_model,
|
||||||
|
}
|
||||||
|
|
||||||
df = pd.read_csv(file_path)
|
document_config = {
|
||||||
content = []
|
"id": dataset_document.id,
|
||||||
for _, row in df.iterrows():
|
"doc_form": dataset_document.doc_form,
|
||||||
if dataset_document.doc_form == "qa_model":
|
"word_count": dataset_document.word_count or 0,
|
||||||
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
}
|
||||||
else:
|
|
||||||
data = {"content": row.iloc[0]}
|
|
||||||
content.append(data)
|
|
||||||
if len(content) == 0:
|
|
||||||
raise ValueError("The CSV file is empty.")
|
|
||||||
|
|
||||||
document_segments = []
|
upload_file_key = upload_file.key
|
||||||
embedding_model = None
|
|
||||||
if dataset.indexing_technique == "high_quality":
|
|
||||||
model_manager = ModelManager()
|
|
||||||
embedding_model = model_manager.get_model_instance(
|
|
||||||
tenant_id=dataset.tenant_id,
|
|
||||||
provider=dataset.embedding_model_provider,
|
|
||||||
model_type=ModelType.TEXT_EMBEDDING,
|
|
||||||
model=dataset.embedding_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
word_count_change = 0
|
except Exception:
|
||||||
if embedding_model:
|
logger.exception("Segments batch created index failed")
|
||||||
tokens_list = embedding_model.get_text_embedding_num_tokens(
|
redis_client.setex(indexing_cache_key, 600, "error")
|
||||||
texts=[segment["content"] for segment in content]
|
return
|
||||||
)
|
|
||||||
|
# Ensure required variables are set before proceeding
|
||||||
|
if upload_file_key is None or dataset_config is None or document_config is None:
|
||||||
|
logger.error("Required configuration not set due to session error")
|
||||||
|
redis_client.setex(indexing_cache_key, 600, "error")
|
||||||
|
return
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir:
|
||||||
|
suffix = Path(upload_file_key).suffix
|
||||||
|
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||||
|
storage.download(upload_file_key, file_path)
|
||||||
|
|
||||||
|
df = pd.read_csv(file_path)
|
||||||
|
content = []
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
if document_config["doc_form"] == "qa_model":
|
||||||
|
data = {"content": row.iloc[0], "answer": row.iloc[1]}
|
||||||
else:
|
else:
|
||||||
tokens_list = [0] * len(content)
|
data = {"content": row.iloc[0]}
|
||||||
|
content.append(data)
|
||||||
|
if len(content) == 0:
|
||||||
|
raise ValueError("The CSV file is empty.")
|
||||||
|
|
||||||
for segment, tokens in zip(content, tokens_list):
|
document_segments = []
|
||||||
content = segment["content"]
|
embedding_model = None
|
||||||
doc_id = str(uuid.uuid4())
|
if dataset_config["indexing_technique"] == "high_quality":
|
||||||
segment_hash = helper.generate_text_hash(content)
|
model_manager = ModelManager()
|
||||||
max_position = (
|
embedding_model = model_manager.get_model_instance(
|
||||||
session.query(func.max(DocumentSegment.position))
|
tenant_id=dataset_config["tenant_id"],
|
||||||
.where(DocumentSegment.document_id == dataset_document.id)
|
provider=dataset_config["embedding_model_provider"],
|
||||||
.scalar()
|
model_type=ModelType.TEXT_EMBEDDING,
|
||||||
)
|
model=dataset_config["embedding_model"],
|
||||||
segment_document = DocumentSegment(
|
)
|
||||||
tenant_id=tenant_id,
|
|
||||||
dataset_id=dataset_id,
|
|
||||||
document_id=document_id,
|
|
||||||
index_node_id=doc_id,
|
|
||||||
index_node_hash=segment_hash,
|
|
||||||
position=max_position + 1 if max_position else 1,
|
|
||||||
content=content,
|
|
||||||
word_count=len(content),
|
|
||||||
tokens=tokens,
|
|
||||||
created_by=user_id,
|
|
||||||
indexing_at=naive_utc_now(),
|
|
||||||
status="completed",
|
|
||||||
completed_at=naive_utc_now(),
|
|
||||||
)
|
|
||||||
if dataset_document.doc_form == "qa_model":
|
|
||||||
segment_document.answer = segment["answer"]
|
|
||||||
segment_document.word_count += len(segment["answer"])
|
|
||||||
word_count_change += segment_document.word_count
|
|
||||||
session.add(segment_document)
|
|
||||||
document_segments.append(segment_document)
|
|
||||||
|
|
||||||
|
word_count_change = 0
|
||||||
|
if embedding_model:
|
||||||
|
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
|
||||||
|
else:
|
||||||
|
tokens_list = [0] * len(content)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
for segment, tokens in zip(content, tokens_list):
|
||||||
|
content = segment["content"]
|
||||||
|
doc_id = str(uuid.uuid4())
|
||||||
|
segment_hash = helper.generate_text_hash(content)
|
||||||
|
max_position = (
|
||||||
|
session.query(func.max(DocumentSegment.position))
|
||||||
|
.where(DocumentSegment.document_id == document_config["id"])
|
||||||
|
.scalar()
|
||||||
|
)
|
||||||
|
segment_document = DocumentSegment(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
dataset_id=dataset_id,
|
||||||
|
document_id=document_id,
|
||||||
|
index_node_id=doc_id,
|
||||||
|
index_node_hash=segment_hash,
|
||||||
|
position=max_position + 1 if max_position else 1,
|
||||||
|
content=content,
|
||||||
|
word_count=len(content),
|
||||||
|
tokens=tokens,
|
||||||
|
created_by=user_id,
|
||||||
|
indexing_at=naive_utc_now(),
|
||||||
|
status="completed",
|
||||||
|
completed_at=naive_utc_now(),
|
||||||
|
)
|
||||||
|
if document_config["doc_form"] == "qa_model":
|
||||||
|
segment_document.answer = segment["answer"]
|
||||||
|
segment_document.word_count += len(segment["answer"])
|
||||||
|
word_count_change += segment_document.word_count
|
||||||
|
session.add(segment_document)
|
||||||
|
document_segments.append(segment_document)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
dataset_document = session.get(Document, document_id)
|
||||||
|
if dataset_document:
|
||||||
assert dataset_document.word_count is not None
|
assert dataset_document.word_count is not None
|
||||||
dataset_document.word_count += word_count_change
|
dataset_document.word_count += word_count_change
|
||||||
session.add(dataset_document)
|
session.add(dataset_document)
|
||||||
|
|
||||||
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
|
with session_factory.create_session() as session:
|
||||||
session.commit()
|
dataset = session.get(Dataset, dataset_id)
|
||||||
redis_client.setex(indexing_cache_key, 600, "completed")
|
if dataset:
|
||||||
end_at = time.perf_counter()
|
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
|
||||||
logger.info(
|
|
||||||
click.style(
|
redis_client.setex(indexing_cache_key, 600, "completed")
|
||||||
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
end_at = time.perf_counter()
|
||||||
fg="green",
|
logger.info(
|
||||||
)
|
click.style(
|
||||||
)
|
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
|
||||||
except Exception:
|
fg="green",
|
||||||
logger.exception("Segments batch created index failed")
|
)
|
||||||
redis_client.setex(indexing_cache_key, 600, "error")
|
)
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||||||
"""
|
"""
|
||||||
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
|
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
|
||||||
start_at = time.perf_counter()
|
start_at = time.perf_counter()
|
||||||
|
total_attachment_files = []
|
||||||
|
|
||||||
with session_factory.create_session() as session:
|
with session_factory.create_session() as session:
|
||||||
try:
|
try:
|
||||||
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
|
|||||||
SegmentAttachmentBinding.document_id == document_id,
|
SegmentAttachmentBinding.document_id == document_id,
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
# check segment is exist
|
|
||||||
if segments:
|
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
||||||
index_node_ids = [segment.index_node_id for segment in segments]
|
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
||||||
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
|
||||||
|
|
||||||
|
index_node_ids = [segment.index_node_id for segment in segments]
|
||||||
|
segment_contents = [segment.content for segment in segments]
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Cleaned document when document deleted failed")
|
||||||
|
return
|
||||||
|
|
||||||
|
# check segment is exist
|
||||||
|
if index_node_ids:
|
||||||
|
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
|
||||||
|
with session_factory.create_session() as session:
|
||||||
|
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||||
|
if dataset:
|
||||||
index_processor.clean(
|
index_processor.clean(
|
||||||
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
|
||||||
)
|
)
|
||||||
|
|
||||||
for segment in segments:
|
total_image_files = []
|
||||||
image_upload_file_ids = get_image_upload_file_ids(segment.content)
|
with session_factory.create_session() as session, session.begin():
|
||||||
image_files = session.scalars(
|
for segment_content in segment_contents:
|
||||||
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
image_upload_file_ids = get_image_upload_file_ids(segment_content)
|
||||||
).all()
|
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
|
||||||
for image_file in image_files:
|
total_image_files.extend([image_file.key for image_file in image_files])
|
||||||
if image_file is None:
|
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
||||||
continue
|
session.execute(image_file_delete_stmt)
|
||||||
try:
|
|
||||||
storage.delete(image_file.key)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"Delete image_files failed when storage deleted, \
|
|
||||||
image_upload_file_is: %s",
|
|
||||||
image_file.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
|
with session_factory.create_session() as session, session.begin():
|
||||||
session.execute(image_file_delete_stmt)
|
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
|
||||||
session.delete(segment)
|
session.execute(segment_delete_stmt)
|
||||||
|
|
||||||
session.commit()
|
for image_file_key in total_image_files:
|
||||||
if file_id:
|
try:
|
||||||
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
storage.delete(image_file_key)
|
||||||
if file:
|
|
||||||
try:
|
|
||||||
storage.delete(file.key)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
|
||||||
session.delete(file)
|
|
||||||
# delete segment attachments
|
|
||||||
if attachments_with_bindings:
|
|
||||||
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
|
|
||||||
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
|
|
||||||
for binding, attachment_file in attachments_with_bindings:
|
|
||||||
try:
|
|
||||||
storage.delete(attachment_file.key)
|
|
||||||
except Exception:
|
|
||||||
logger.exception(
|
|
||||||
"Delete attachment_file failed when storage deleted, \
|
|
||||||
attachment_file_id: %s",
|
|
||||||
binding.attachment_id,
|
|
||||||
)
|
|
||||||
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
|
||||||
session.execute(attachment_file_delete_stmt)
|
|
||||||
|
|
||||||
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
|
|
||||||
SegmentAttachmentBinding.id.in_(binding_ids)
|
|
||||||
)
|
|
||||||
session.execute(binding_delete_stmt)
|
|
||||||
|
|
||||||
# delete dataset metadata binding
|
|
||||||
session.query(DatasetMetadataBinding).where(
|
|
||||||
DatasetMetadataBinding.dataset_id == dataset_id,
|
|
||||||
DatasetMetadataBinding.document_id == document_id,
|
|
||||||
).delete()
|
|
||||||
session.commit()
|
|
||||||
|
|
||||||
end_at = time.perf_counter()
|
|
||||||
logger.info(
|
|
||||||
click.style(
|
|
||||||
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
|
||||||
fg="green",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Cleaned document when document deleted failed")
|
logger.exception(
|
||||||
|
"Delete image_files failed when storage deleted, \
|
||||||
|
image_upload_file_is: %s",
|
||||||
|
image_file_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
if file_id:
|
||||||
|
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
|
||||||
|
if file:
|
||||||
|
try:
|
||||||
|
storage.delete(file.key)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
|
||||||
|
session.delete(file)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
# delete segment attachments
|
||||||
|
if attachment_ids:
|
||||||
|
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
|
||||||
|
session.execute(attachment_file_delete_stmt)
|
||||||
|
|
||||||
|
if binding_ids:
|
||||||
|
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
|
||||||
|
session.execute(binding_delete_stmt)
|
||||||
|
|
||||||
|
for attachment_file_key in total_attachment_files:
|
||||||
|
try:
|
||||||
|
storage.delete(attachment_file_key)
|
||||||
|
except Exception:
|
||||||
|
logger.exception(
|
||||||
|
"Delete attachment_file failed when storage deleted, \
|
||||||
|
attachment_file_id: %s",
|
||||||
|
attachment_file_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
# delete dataset metadata binding
|
||||||
|
session.query(DatasetMetadataBinding).where(
|
||||||
|
DatasetMetadataBinding.dataset_id == dataset_id,
|
||||||
|
DatasetMetadataBinding.document_id == document_id,
|
||||||
|
).delete()
|
||||||
|
|
||||||
|
end_at = time.perf_counter()
|
||||||
|
logger.info(
|
||||||
|
click.style(
|
||||||
|
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
|
||||||
|
fg="green",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|||||||
@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||||||
session.commit()
|
session.commit()
|
||||||
return
|
return
|
||||||
|
|
||||||
for document_id in document_ids:
|
# Phase 1: Update status to parsing (short transaction)
|
||||||
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
|
with session_factory.create_session() as session, session.begin():
|
||||||
|
documents = (
|
||||||
document = (
|
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
|
||||||
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
|
)
|
||||||
)
|
|
||||||
|
|
||||||
|
for document in documents:
|
||||||
if document:
|
if document:
|
||||||
document.indexing_status = "parsing"
|
document.indexing_status = "parsing"
|
||||||
document.processing_started_at = naive_utc_now()
|
document.processing_started_at = naive_utc_now()
|
||||||
documents.append(document)
|
|
||||||
session.add(document)
|
session.add(document)
|
||||||
session.commit()
|
# Transaction committed and closed
|
||||||
|
|
||||||
try:
|
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
|
||||||
indexing_runner = IndexingRunner()
|
has_error = False
|
||||||
indexing_runner.run(documents)
|
try:
|
||||||
end_at = time.perf_counter()
|
indexing_runner = IndexingRunner()
|
||||||
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
indexing_runner.run(documents)
|
||||||
|
end_at = time.perf_counter()
|
||||||
|
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
|
||||||
|
except DocumentIsPausedError as ex:
|
||||||
|
logger.info(click.style(str(ex), fg="yellow"))
|
||||||
|
has_error = True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
||||||
|
has_error = True
|
||||||
|
|
||||||
|
if not has_error:
|
||||||
|
with session_factory.create_session() as session:
|
||||||
# Trigger summary index generation for completed documents if enabled
|
# Trigger summary index generation for completed documents if enabled
|
||||||
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
|
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
|
||||||
# Re-query dataset to get latest summary_index_setting (in case it was updated)
|
# Re-query dataset to get latest summary_index_setting (in case it was updated)
|
||||||
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||||||
# expire all session to get latest document's indexing status
|
# expire all session to get latest document's indexing status
|
||||||
session.expire_all()
|
session.expire_all()
|
||||||
# Check each document's indexing status and trigger summary generation if completed
|
# Check each document's indexing status and trigger summary generation if completed
|
||||||
for document_id in document_ids:
|
|
||||||
# Re-query document to get latest status (IndexingRunner may have updated it)
|
documents = (
|
||||||
document = (
|
session.query(Document)
|
||||||
session.query(Document)
|
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
|
||||||
.where(Document.id == document_id, Document.dataset_id == dataset_id)
|
.all()
|
||||||
.first()
|
)
|
||||||
)
|
|
||||||
|
for document in documents:
|
||||||
if document:
|
if document:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
|
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
|
||||||
document_id,
|
document.id,
|
||||||
document.indexing_status,
|
document.indexing_status,
|
||||||
document.doc_form,
|
document.doc_form,
|
||||||
document.need_summary,
|
document.need_summary,
|
||||||
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
|
|||||||
and document.need_summary is True
|
and document.need_summary is True
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
generate_summary_index_task.delay(dataset.id, document_id, None)
|
generate_summary_index_task.delay(dataset.id, document.id, None)
|
||||||
logger.info(
|
logger.info(
|
||||||
"Queued summary index generation task for document %s in dataset %s "
|
"Queued summary index generation task for document %s in dataset %s "
|
||||||
"after indexing completed",
|
"after indexing completed",
|
||||||
document_id,
|
document.id,
|
||||||
dataset.id,
|
dataset.id,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception(
|
logger.exception(
|
||||||
"Failed to queue summary index generation task for document %s",
|
"Failed to queue summary index generation task for document %s",
|
||||||
document_id,
|
document.id,
|
||||||
)
|
)
|
||||||
# Don't fail the entire indexing process if summary task queuing fails
|
# Don't fail the entire indexing process if summary task queuing fails
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Skipping summary generation for document %s: "
|
"Skipping summary generation for document %s: "
|
||||||
"status=%s, doc_form=%s, need_summary=%s",
|
"status=%s, doc_form=%s, need_summary=%s",
|
||||||
document_id,
|
document.id,
|
||||||
document.indexing_status,
|
document.indexing_status,
|
||||||
document.doc_form,
|
document.doc_form,
|
||||||
document.need_summary,
|
document.need_summary,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("Document %s not found after indexing", document_id)
|
logger.warning("Document %s not found after indexing", document.id)
|
||||||
else:
|
|
||||||
logger.info(
|
|
||||||
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
|
|
||||||
dataset.id,
|
|
||||||
summary_index_setting.get("enable") if summary_index_setting else None,
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
logger.info(
|
logger.info(
|
||||||
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
|
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
|
||||||
dataset.id,
|
dataset.id,
|
||||||
dataset.indexing_technique,
|
dataset.indexing_technique,
|
||||||
)
|
)
|
||||||
except DocumentIsPausedError as ex:
|
|
||||||
logger.info(click.style(str(ex), fg="yellow"))
|
|
||||||
except Exception:
|
|
||||||
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
|
|
||||||
|
|
||||||
|
|
||||||
def _document_indexing_with_tenant_queue(
|
def _document_indexing_with_tenant_queue(
|
||||||
|
|||||||
@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from celery import shared_task # type: ignore[import-untyped]
|
from celery import shared_task # type: ignore[import-untyped]
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from extensions.ext_database import db
|
from core.db.session_factory import session_factory
|
||||||
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
|
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
|
||||||
|
|
||||||
|
|
||||||
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
|
|||||||
self,
|
self,
|
||||||
deletions: list[DraftVarFileDeletion],
|
deletions: list[DraftVarFileDeletion],
|
||||||
):
|
):
|
||||||
with Session(bind=db.engine) as session, session.begin():
|
with session_factory.create_session() as session, session.begin():
|
||||||
srv = WorkflowDraftVariableService(session=session)
|
srv = WorkflowDraftVariableService(session=session)
|
||||||
srv.delete_workflow_draft_variable_file(deletions=deletions)
|
srv.delete_workflow_draft_variable_file(deletions=deletions)
|
||||||
|
|||||||
@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
|
|||||||
|
|
||||||
mock_storage.download.side_effect = mock_download
|
mock_storage.download.side_effect = mock_download
|
||||||
|
|
||||||
# Execute the task
|
# Execute the task - should raise ValueError for empty CSV
|
||||||
job_id = str(uuid.uuid4())
|
job_id = str(uuid.uuid4())
|
||||||
batch_create_segment_to_index_task(
|
with pytest.raises(ValueError, match="The CSV file is empty"):
|
||||||
job_id=job_id,
|
batch_create_segment_to_index_task(
|
||||||
upload_file_id=upload_file.id,
|
job_id=job_id,
|
||||||
dataset_id=dataset.id,
|
upload_file_id=upload_file.id,
|
||||||
document_id=document.id,
|
dataset_id=dataset.id,
|
||||||
tenant_id=tenant.id,
|
document_id=document.id,
|
||||||
user_id=account.id,
|
tenant_id=tenant.id,
|
||||||
)
|
user_id=account.id,
|
||||||
|
)
|
||||||
|
|
||||||
# Verify error handling
|
# Verify error handling
|
||||||
# Check Redis cache was set to error status
|
# Since exception was raised, no segments should be created
|
||||||
from extensions.ext_redis import redis_client
|
|
||||||
|
|
||||||
cache_key = f"segment_batch_import_{job_id}"
|
|
||||||
cache_value = redis_client.get(cache_key)
|
|
||||||
assert cache_value == b"error"
|
|
||||||
|
|
||||||
# Verify no segments were created
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
|
|
||||||
segments = db.session.query(DocumentSegment).all()
|
segments = db.session.query(DocumentSegment).all()
|
||||||
|
|||||||
@@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id):
|
|||||||
def mock_db_session():
|
def mock_db_session():
|
||||||
"""Mock database session via session_factory.create_session()."""
|
"""Mock database session via session_factory.create_session()."""
|
||||||
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
with patch("tasks.document_indexing_task.session_factory") as mock_sf:
|
||||||
session = MagicMock()
|
sessions = [] # Track all created sessions
|
||||||
# Ensure tests that expect session.close() to be called can observe it via the context manager
|
# Shared mock data that all sessions will access
|
||||||
session.close = MagicMock()
|
shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None}
|
||||||
cm = MagicMock()
|
|
||||||
cm.__enter__.return_value = session
|
|
||||||
# Link __exit__ to session.close so "close" expectations reflect context manager teardown
|
|
||||||
|
|
||||||
def _exit_side_effect(*args, **kwargs):
|
def create_session_side_effect():
|
||||||
session.close()
|
session = MagicMock()
|
||||||
|
session.close = MagicMock()
|
||||||
|
|
||||||
cm.__exit__.side_effect = _exit_side_effect
|
# Track commit calls
|
||||||
mock_sf.create_session.return_value = cm
|
commit_mock = MagicMock()
|
||||||
|
session.commit = commit_mock
|
||||||
|
cm = MagicMock()
|
||||||
|
cm.__enter__.return_value = session
|
||||||
|
|
||||||
query = MagicMock()
|
def _exit_side_effect(*args, **kwargs):
|
||||||
session.query.return_value = query
|
session.close()
|
||||||
query.where.return_value = query
|
|
||||||
yield session
|
cm.__exit__.side_effect = _exit_side_effect
|
||||||
|
|
||||||
|
# Support session.begin() for transactions
|
||||||
|
begin_cm = MagicMock()
|
||||||
|
begin_cm.__enter__.return_value = session
|
||||||
|
|
||||||
|
def begin_exit_side_effect(*args, **kwargs):
|
||||||
|
# Auto-commit on transaction exit (like SQLAlchemy)
|
||||||
|
session.commit()
|
||||||
|
# Also mark wrapper's commit as called
|
||||||
|
if sessions:
|
||||||
|
sessions[0].commit()
|
||||||
|
|
||||||
|
begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect)
|
||||||
|
session.begin = MagicMock(return_value=begin_cm)
|
||||||
|
|
||||||
|
sessions.append(session)
|
||||||
|
|
||||||
|
# Setup query with side_effect to handle both Dataset and Document queries
|
||||||
|
def query_side_effect(*args):
|
||||||
|
query = MagicMock()
|
||||||
|
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
|
||||||
|
where_result = MagicMock()
|
||||||
|
where_result.first.return_value = shared_mock_data["dataset"]
|
||||||
|
query.where = MagicMock(return_value=where_result)
|
||||||
|
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
|
||||||
|
# Support both .first() and .all() calls with chaining
|
||||||
|
where_result = MagicMock()
|
||||||
|
where_result.where = MagicMock(return_value=where_result)
|
||||||
|
|
||||||
|
# Create an iterator for .first() calls if not exists
|
||||||
|
if shared_mock_data["doc_iter"] is None:
|
||||||
|
docs = shared_mock_data["documents"] or [None]
|
||||||
|
shared_mock_data["doc_iter"] = iter(docs)
|
||||||
|
|
||||||
|
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
|
||||||
|
docs_or_empty = shared_mock_data["documents"] or []
|
||||||
|
where_result.all = MagicMock(return_value=docs_or_empty)
|
||||||
|
query.where = MagicMock(return_value=where_result)
|
||||||
|
else:
|
||||||
|
query.where = MagicMock(return_value=query)
|
||||||
|
return query
|
||||||
|
|
||||||
|
session.query = MagicMock(side_effect=query_side_effect)
|
||||||
|
return cm
|
||||||
|
|
||||||
|
mock_sf.create_session.side_effect = create_session_side_effect
|
||||||
|
|
||||||
|
# Create a wrapper that behaves like the first session but has access to all sessions
|
||||||
|
class SessionWrapper:
|
||||||
|
def __init__(self):
|
||||||
|
self._sessions = sessions
|
||||||
|
self._shared_data = shared_mock_data
|
||||||
|
# Create a default session for setup phase
|
||||||
|
self._default_session = MagicMock()
|
||||||
|
self._default_session.close = MagicMock()
|
||||||
|
self._default_session.commit = MagicMock()
|
||||||
|
|
||||||
|
# Support session.begin() for default session too
|
||||||
|
begin_cm = MagicMock()
|
||||||
|
begin_cm.__enter__.return_value = self._default_session
|
||||||
|
|
||||||
|
def default_begin_exit_side_effect(*args, **kwargs):
|
||||||
|
self._default_session.commit()
|
||||||
|
|
||||||
|
begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect)
|
||||||
|
self._default_session.begin = MagicMock(return_value=begin_cm)
|
||||||
|
|
||||||
|
def default_query_side_effect(*args):
|
||||||
|
query = MagicMock()
|
||||||
|
if args and args[0] == Dataset and shared_mock_data["dataset"] is not None:
|
||||||
|
where_result = MagicMock()
|
||||||
|
where_result.first.return_value = shared_mock_data["dataset"]
|
||||||
|
query.where = MagicMock(return_value=where_result)
|
||||||
|
elif args and args[0] == Document and shared_mock_data["documents"] is not None:
|
||||||
|
where_result = MagicMock()
|
||||||
|
where_result.where = MagicMock(return_value=where_result)
|
||||||
|
|
||||||
|
if shared_mock_data["doc_iter"] is None:
|
||||||
|
docs = shared_mock_data["documents"] or [None]
|
||||||
|
shared_mock_data["doc_iter"] = iter(docs)
|
||||||
|
|
||||||
|
where_result.first = lambda: next(shared_mock_data["doc_iter"], None)
|
||||||
|
docs_or_empty = shared_mock_data["documents"] or []
|
||||||
|
where_result.all = MagicMock(return_value=docs_or_empty)
|
||||||
|
query.where = MagicMock(return_value=where_result)
|
||||||
|
else:
|
||||||
|
query.where = MagicMock(return_value=query)
|
||||||
|
return query
|
||||||
|
|
||||||
|
self._default_session.query = MagicMock(side_effect=default_query_side_effect)
|
||||||
|
|
||||||
|
def __getattr__(self, name):
|
||||||
|
# Forward all attribute access to the first session, or default if none created yet
|
||||||
|
target_session = self._sessions[0] if self._sessions else self._default_session
|
||||||
|
return getattr(target_session, name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def all_sessions(self):
|
||||||
|
"""Access all created sessions for testing."""
|
||||||
|
return self._sessions
|
||||||
|
|
||||||
|
wrapper = SessionWrapper()
|
||||||
|
yield wrapper
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -252,18 +356,9 @@ class TestTaskEnqueuing:
|
|||||||
use the deprecated function.
|
use the deprecated function.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
# Return documents one by one for each call
|
|
||||||
mock_query.where.return_value.first.side_effect = mock_documents
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -304,21 +399,9 @@ class TestBatchProcessing:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
# Create an iterator for documents
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
doc_iter = iter(mock_documents)
|
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
# Return documents one by one for each call
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -357,19 +440,9 @@ class TestBatchProcessing:
|
|||||||
doc.stopped_at = None
|
doc.stopped_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL
|
||||||
@@ -407,19 +480,9 @@ class TestBatchProcessing:
|
|||||||
doc.stopped_at = None
|
doc.stopped_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX
|
||||||
@@ -444,7 +507,10 @@ class TestBatchProcessing:
|
|||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
document_ids = []
|
document_ids = []
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
|
||||||
|
# Set shared mock data with empty documents list
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
|
mock_db_session._shared_data["documents"] = []
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -482,19 +548,9 @@ class TestProgressTracking:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -528,19 +584,9 @@ class TestProgressTracking:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -635,19 +681,9 @@ class TestErrorHandling:
|
|||||||
doc.stopped_at = None
|
doc.stopped_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Set up to trigger vector space limit error
|
# Set up to trigger vector space limit error
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
@@ -674,17 +710,9 @@ class TestErrorHandling:
|
|||||||
Errors during indexing should be caught and logged, but not crash the task.
|
Errors during indexing should be caught and logged, but not crash the task.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first.side_effect = mock_documents
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Make IndexingRunner raise an exception
|
# Make IndexingRunner raise an exception
|
||||||
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
|
mock_indexing_runner.run.side_effect = Exception("Indexing failed")
|
||||||
@@ -708,17 +736,9 @@ class TestErrorHandling:
|
|||||||
but not treated as a failure.
|
but not treated as a failure.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first.side_effect = mock_documents
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Make IndexingRunner raise DocumentIsPausedError
|
# Make IndexingRunner raise DocumentIsPausedError
|
||||||
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
|
mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused")
|
||||||
@@ -853,17 +873,9 @@ class TestTaskCancellation:
|
|||||||
Session cleanup should happen in finally block.
|
Session cleanup should happen in finally block.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first.side_effect = mock_documents
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -883,17 +895,9 @@ class TestTaskCancellation:
|
|||||||
Session cleanup should happen even when errors occur.
|
Session cleanup should happen even when errors occur.
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first.side_effect = mock_documents
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Make IndexingRunner raise an exception
|
# Make IndexingRunner raise an exception
|
||||||
mock_indexing_runner.run.side_effect = Exception("Test error")
|
mock_indexing_runner.run.side_effect = Exception("Test error")
|
||||||
@@ -962,6 +966,7 @@ class TestAdvancedScenarios:
|
|||||||
document_ids = [str(uuid.uuid4()) for _ in range(3)]
|
document_ids = [str(uuid.uuid4()) for _ in range(3)]
|
||||||
|
|
||||||
# Create only 2 documents (simulate one missing)
|
# Create only 2 documents (simulate one missing)
|
||||||
|
# The new code uses .all() which will only return existing documents
|
||||||
mock_documents = []
|
mock_documents = []
|
||||||
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
|
for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one
|
||||||
doc = MagicMock(spec=Document)
|
doc = MagicMock(spec=Document)
|
||||||
@@ -971,21 +976,9 @@ class TestAdvancedScenarios:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data - .all() will only return existing documents
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
# Create iterator that returns None for missing document
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
doc_responses = [mock_documents[0], None, mock_documents[1]]
|
|
||||||
doc_iter = iter(doc_responses)
|
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -1075,19 +1068,9 @@ class TestAdvancedScenarios:
|
|||||||
doc.stopped_at = None
|
doc.stopped_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Set vector space exactly at limit
|
# Set vector space exactly at limit
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
@@ -1219,19 +1202,9 @@ class TestAdvancedScenarios:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Billing disabled - limits should not be checked
|
# Billing disabled - limits should not be checked
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = False
|
mock_feature_service.get_features.return_value.billing.enabled = False
|
||||||
@@ -1273,19 +1246,9 @@ class TestIntegration:
|
|||||||
|
|
||||||
# Set up rpop to return None for concurrency check (no more tasks)
|
# Set up rpop to return None for concurrency check (no more tasks)
|
||||||
mock_redis.rpop.side_effect = [None]
|
mock_redis.rpop.side_effect = [None]
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -1321,19 +1284,9 @@ class TestIntegration:
|
|||||||
|
|
||||||
# Set up rpop to return None for concurrency check (no more tasks)
|
# Set up rpop to return None for concurrency check (no more tasks)
|
||||||
mock_redis.rpop.side_effect = [None]
|
mock_redis.rpop.side_effect = [None]
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -1415,17 +1368,9 @@ class TestEdgeCases:
|
|||||||
mock_document.indexing_status = "waiting"
|
mock_document.indexing_status = "waiting"
|
||||||
mock_document.processing_started_at = None
|
mock_document.processing_started_at = None
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = [mock_document]
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: mock_document
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -1465,17 +1410,9 @@ class TestEdgeCases:
|
|||||||
mock_document.indexing_status = "waiting"
|
mock_document.indexing_status = "waiting"
|
||||||
mock_document.processing_started_at = None
|
mock_document.processing_started_at = None
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
def mock_query_side_effect(*args):
|
mock_db_session._shared_data["documents"] = [mock_document]
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: mock_document
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -1555,19 +1492,9 @@ class TestEdgeCases:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Set vector space limit to 0 (unlimited)
|
# Set vector space limit to 0 (unlimited)
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
@@ -1612,19 +1539,9 @@ class TestEdgeCases:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Set negative vector space limit
|
# Set negative vector space limit
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
@@ -1675,19 +1592,9 @@ class TestPerformanceScenarios:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Configure billing with sufficient limits
|
# Configure billing with sufficient limits
|
||||||
mock_feature_service.get_features.return_value.billing.enabled = True
|
mock_feature_service.get_features.return_value.billing.enabled = True
|
||||||
@@ -1826,19 +1733,9 @@ class TestRobustness:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
# Make IndexingRunner raise an exception
|
# Make IndexingRunner raise an exception
|
||||||
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
|
mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error")
|
||||||
@@ -1866,7 +1763,7 @@ class TestRobustness:
|
|||||||
- No exceptions occur
|
- No exceptions occur
|
||||||
|
|
||||||
Expected behavior:
|
Expected behavior:
|
||||||
- Database session is closed
|
- All database sessions are closed
|
||||||
- No connection leaks
|
- No connection leaks
|
||||||
"""
|
"""
|
||||||
# Arrange
|
# Arrange
|
||||||
@@ -1879,19 +1776,9 @@ class TestRobustness:
|
|||||||
doc.processing_started_at = None
|
doc.processing_started_at = None
|
||||||
mock_documents.append(doc)
|
mock_documents.append(doc)
|
||||||
|
|
||||||
mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset
|
# Set shared mock data so all sessions can access it
|
||||||
|
mock_db_session._shared_data["dataset"] = mock_dataset
|
||||||
doc_iter = iter(mock_documents)
|
mock_db_session._shared_data["documents"] = mock_documents
|
||||||
|
|
||||||
def mock_query_side_effect(*args):
|
|
||||||
mock_query = MagicMock()
|
|
||||||
if args[0] == Dataset:
|
|
||||||
mock_query.where.return_value.first.return_value = mock_dataset
|
|
||||||
elif args[0] == Document:
|
|
||||||
mock_query.where.return_value.first = lambda: next(doc_iter, None)
|
|
||||||
return mock_query
|
|
||||||
|
|
||||||
mock_db_session.query.side_effect = mock_query_side_effect
|
|
||||||
|
|
||||||
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features:
|
||||||
mock_features.return_value.billing.enabled = False
|
mock_features.return_value.billing.enabled = False
|
||||||
@@ -1899,10 +1786,11 @@ class TestRobustness:
|
|||||||
# Act
|
# Act
|
||||||
_document_indexing(dataset_id, document_ids)
|
_document_indexing(dataset_id, document_ids)
|
||||||
|
|
||||||
# Assert
|
# Assert - All created sessions should be closed
|
||||||
assert mock_db_session.close.called
|
# The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary)
|
||||||
# Verify close is called exactly once
|
assert len(mock_db_session.all_sessions) >= 1
|
||||||
assert mock_db_session.close.call_count == 1
|
for session in mock_db_session.all_sessions:
|
||||||
|
assert session.close.called, "All sessions should be closed"
|
||||||
|
|
||||||
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
|
def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user