From 243c3f7dc0166a43ec2ff3a75f58a3369130f4c9 Mon Sep 17 00:00:00 2001 From: NFish Date: Mon, 9 Feb 2026 15:52:22 +0800 Subject: [PATCH 1/8] fix: include app id in automatic generation requests (#32138) --- .../config/automatic/get-automatic-res.tsx | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx index f5ebaac3ca..44ce5cde52 100644 --- a/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx +++ b/web/app/components/app/configuration/config/automatic/get-automatic-res.tsx @@ -19,19 +19,21 @@ import { useBoolean, useSessionStorageState } from 'ahooks' import * as React from 'react' import { useCallback, useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' +import { useShallow } from 'zustand/react/shallow' import Button from '@/app/components/base/button' import Confirm from '@/app/components/base/confirm' import { Generator } from '@/app/components/base/icons/src/vender/other' -import Loading from '@/app/components/base/loading' +import Loading from '@/app/components/base/loading' import Modal from '@/app/components/base/modal' import Toast from '@/app/components/base/toast' -import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' +import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations' import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks' import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal' import { generateBasicAppFirstTimeRule, generateRule } from '@/service/debug' import { useGenerateRuleTemplate } from '@/service/use-apps' +import { useStore } from '../../../store' import IdeaOutput from './idea-output' import InstructionEditorInBasic from './instruction-editor' import InstructionEditorInWorkflow from './instruction-editor-in-workflow' @@ -83,6 +85,9 @@ const GetAutomaticRes: FC = ({ onFinished, }) => { const { t } = useTranslation() + const { appDetail } = useStore(useShallow(state => ({ + appDetail: state.appDetail, + }))) const localModel = localStorage.getItem('auto-gen-model') ? JSON.parse(localStorage.getItem('auto-gen-model') as string) as Model : null @@ -235,6 +240,7 @@ const GetAutomaticRes: FC = ({ instruction, model_config: model, no_variable: false, + app_id: appDetail?.id, }) apiRes = { ...res, @@ -256,6 +262,7 @@ const GetAutomaticRes: FC = ({ instruction, ideal_output: ideaOutput, model_config: model, + app_id: appDetail?.id, }) apiRes = res if (error) { From 51946a734aa96aa5011cb1aa8d32c86b0e34c1ae Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Thu, 5 Feb 2026 14:42:34 +0800 Subject: [PATCH 2/8] fix: fix miss use db.session (#31971) --- api/tasks/document_indexing_update_task.py | 6 +- .../test_document_indexing_update_task.py | 182 ++++++++++++++++++ 2 files changed, 184 insertions(+), 4 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 67a23be952..45d58c92ec 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -8,7 +8,6 @@ from sqlalchemy import delete, select from core.db.session_factory import session_factory from core.indexing_runner import DocumentIsPausedError, IndexingRunner from core.rag.index_processor.index_processor_factory import IndexProcessorFactory -from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, Document, DocumentSegment @@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start update document: {document_id}", fg="green")) start_at = time.perf_counter() - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: @@ -36,7 +35,6 @@ def document_indexing_update_task(dataset_id: str, document_id: str): document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - session.commit() # delete all document segment and index try: @@ -56,7 +54,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str): segment_ids = [segment.id for segment in segments] segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) session.execute(segment_delete_stmt) - db.session.commit() + end_at = time.perf_counter() logger.info( click.style( diff --git a/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py new file mode 100644 index 0000000000..7f37f84113 --- /dev/null +++ b/api/tests/test_containers_integration_tests/tasks/test_document_indexing_update_task.py @@ -0,0 +1,182 @@ +from unittest.mock import MagicMock, patch + +import pytest +from faker import Faker + +from models import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, Document, DocumentSegment +from tasks.document_indexing_update_task import document_indexing_update_task + + +class TestDocumentIndexingUpdateTask: + @pytest.fixture + def mock_external_dependencies(self): + """Patch external collaborators used by the update task. + - IndexProcessorFactory.init_index_processor().clean(...) + - IndexingRunner.run([...]) + """ + with ( + patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory, + patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner, + ): + processor_instance = MagicMock() + mock_factory.return_value.init_index_processor.return_value = processor_instance + + runner_instance = MagicMock() + mock_runner.return_value = runner_instance + + yield { + "factory": mock_factory, + "processor": processor_instance, + "runner": mock_runner, + "runner_instance": runner_instance, + } + + def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2): + fake = Faker() + + # Account and tenant + account = Account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + status="active", + ) + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=fake.company(), status="normal") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=TenantAccountRole.OWNER, + current=True, + ) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + # Dataset and document + dataset = Dataset( + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=64), + data_source_type="upload_file", + indexing_technique="high_quality", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + document = Document( + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="upload_file", + batch="test_batch", + name=fake.file_name(), + created_from="upload_file", + created_by=account.id, + indexing_status="waiting", + enabled=True, + doc_form="text_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + # Segments + node_ids = [] + for i in range(segment_count): + node_id = f"node-{i + 1}" + seg = DocumentSegment( + tenant_id=tenant.id, + dataset_id=dataset.id, + document_id=document.id, + position=i, + content=fake.text(max_nb_chars=32), + answer=None, + word_count=10, + tokens=5, + index_node_id=node_id, + status="completed", + created_by=account.id, + ) + db_session_with_containers.add(seg) + node_ids.append(node_id) + db_session_with_containers.commit() + + # Refresh to ensure ORM state + db_session_with_containers.refresh(dataset) + db_session_with_containers.refresh(document) + + return dataset, document, node_ids + + def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies): + dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) + + # Act + document_indexing_update_task(dataset.id, document.id) + + # Ensure we see committed changes from another session + db_session_with_containers.expire_all() + + # Assert document status updated before reindex + updated = db_session_with_containers.query(Document).where(Document.id == document.id).first() + assert updated.indexing_status == "parsing" + assert updated.processing_started_at is not None + + # Segments should be deleted + remaining = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + ) + assert remaining == 0 + + # Assert index processor clean was called with expected args + clean_call = mock_external_dependencies["processor"].clean.call_args + assert clean_call is not None + args, kwargs = clean_call + # args[0] is a Dataset instance (from another session) — validate by id + assert getattr(args[0], "id", None) == dataset.id + # args[1] should contain our node_ids + assert set(args[1]) == set(node_ids) + assert kwargs.get("with_keywords") is True + assert kwargs.get("delete_child_chunks") is True + + # Assert indexing runner invoked with the updated document + run_call = mock_external_dependencies["runner_instance"].run.call_args + assert run_call is not None + run_docs = run_call[0][0] + assert len(run_docs) == 1 + first = run_docs[0] + assert getattr(first, "id", None) == document.id + + def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies): + dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers) + + # Force clean to raise; task should continue to indexing + mock_external_dependencies["processor"].clean.side_effect = Exception("boom") + + document_indexing_update_task(dataset.id, document.id) + + # Ensure we see committed changes from another session + db_session_with_containers.expire_all() + + # Indexing should still be triggered + mock_external_dependencies["runner_instance"].run.assert_called_once() + + # Segments should remain (since clean failed before DB delete) + remaining = ( + db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count() + ) + assert remaining > 0 + + def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies): + fake = Faker() + # Act with non-existent document id + document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4()) + + # Neither processor nor runner should be called + mock_external_dependencies["processor"].clean.assert_not_called() + mock_external_dependencies["runner_instance"].run.assert_not_called() From 9742185e6b1ed608d89c1f74ae1f7e8a8e1bfff9 Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Thu, 5 Feb 2026 19:05:09 +0800 Subject: [PATCH 3/8] perf(api): Optimize the response time of AppListApi endpoint (#31999) --- api/controllers/console/app/app.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 8c371da596..91034f2d87 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -1,3 +1,4 @@ +import logging import uuid from datetime import datetime from typing import Any, Literal, TypeAlias @@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co register_enum_models(console_ns, IconType) +_logger = logging.getLogger(__name__) + class AppListQuery(BaseModel): page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)") @@ -499,6 +502,7 @@ class AppListApi(Resource): select(Workflow).where( Workflow.version == Workflow.VERSION_DRAFT, Workflow.app_id.in_(workflow_capable_app_ids), + Workflow.tenant_id == current_tenant_id, ) ) .scalars() @@ -510,12 +514,14 @@ class AppListApi(Resource): NodeType.TRIGGER_PLUGIN, } for workflow in draft_workflows: + node_id = None try: - for _, node_data in workflow.walk_nodes(): + for node_id, node_data in workflow.walk_nodes(): if node_data.get("type") in trigger_node_types: draft_trigger_app_ids.add(str(workflow.app_id)) break except Exception: + _logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id) continue for app in app_pagination.items: From 075e90a253e50d4a12da9548b628484266e16813 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Fri, 6 Feb 2026 11:24:39 +0800 Subject: [PATCH 4/8] fix: fix agent node tool type is not right (#32008) Infer real tool type via querying relevant database tables. The root cause for incorrect `type` field is still not clear. --- api/.importlinter | 2 + api/core/workflow/nodes/agent/agent_node.py | 42 +++- .../core/workflow/nodes/agent/__init__.py | 0 .../workflow/nodes/agent/test_agent_node.py | 197 ++++++++++++++++++ 4 files changed, 239 insertions(+), 2 deletions(-) create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/__init__.py create mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py diff --git a/api/.importlinter b/api/.importlinter index 9dad254560..cbbb9e6618 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -102,6 +102,8 @@ forbidden_modules = core.trigger core.variables ignore_imports = + core.workflow.nodes.agent.agent_node -> core.db.session_factory + core.workflow.nodes.agent.agent_node -> models.tools core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index e195aebe6d..e64a83034c 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, Union, cast from packaging.version import Version from pydantic import ValidationError @@ -11,6 +11,7 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter +from core.db.session_factory import session_factory from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -49,6 +50,12 @@ from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy from models import ToolFile from models.model import Conversation +from models.tools import ( + ApiToolProvider, + BuiltinToolProvider, + MCPToolProvider, + WorkflowToolProvider, +) from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exc import ( @@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) + provider_type = self._infer_tool_provider_type(tool, self.tenant_id) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]): llm_usage=llm_usage, ) ) + + @staticmethod + def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType: + provider_type_str = tool_config.get("type") + if provider_type_str: + return ToolProviderType(provider_type_str) + + provider_id = tool_config.get("provider_name") + if not provider_id: + return ToolProviderType.BUILT_IN + + with session_factory.create_session() as session: + provider_map: dict[ + type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]], + ToolProviderType, + ] = { + WorkflowToolProvider: ToolProviderType.WORKFLOW, + MCPToolProvider: ToolProviderType.MCP, + ApiToolProvider: ToolProviderType.API, + BuiltinToolProvider: ToolProviderType.BUILT_IN, + } + + for provider_model, provider_type in provider_map.items(): + stmt = select(provider_model).where( + provider_model.id == provider_id, + provider_model.tenant_id == tenant_id, + ) + if session.scalar(stmt): + return provider_type + + raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.") diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/__init__.py b/api/tests/unit_tests/core/workflow/nodes/agent/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py new file mode 100644 index 0000000000..a95892d0b6 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py @@ -0,0 +1,197 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from core.tools.entities.tool_entities import ToolProviderType +from core.workflow.nodes.agent.agent_node import AgentNode + + +class TestInferToolProviderType: + """Test cases for AgentNode._infer_tool_provider_type method.""" + + def test_infer_type_from_config_workflow(self): + """Test inferring workflow provider type from config.""" + tool_config = { + "type": "workflow", + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.WORKFLOW + + def test_infer_type_from_config_builtin(self): + """Test inferring builtin provider type from config.""" + tool_config = { + "type": "builtin", + "provider_name": "builtin-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + + def test_infer_type_from_config_api(self): + """Test inferring API provider type from config.""" + tool_config = { + "type": "api", + "provider_name": "api-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.API + + def test_infer_type_from_config_mcp(self): + """Test inferring MCP provider type from config.""" + tool_config = { + "type": "mcp", + "provider_name": "mcp-provider-id", + } + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.MCP + + def test_infer_type_invalid_config_value_raises_error(self): + """Test that invalid type value in config raises ValueError.""" + tool_config = { + "type": "invalid-type", + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + with pytest.raises(ValueError): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + def test_infer_workflow_type_from_database(self): + """Test inferring workflow provider type from database.""" + tool_config = { + "provider_name": "workflow-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns a result + mock_session.scalar.return_value = True + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.WORKFLOW + # Should only query once (after finding WorkflowToolProvider) + assert mock_session.scalar.call_count == 1 + + def test_infer_mcp_type_from_database(self): + """Test inferring MCP provider type from database.""" + tool_config = { + "provider_name": "mcp-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns None + # Second query (MCPToolProvider) returns a result + mock_session.scalar.side_effect = [None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.MCP + assert mock_session.scalar.call_count == 2 + + def test_infer_api_type_from_database(self): + """Test inferring API provider type from database.""" + tool_config = { + "provider_name": "api-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First query (WorkflowToolProvider) returns None + # Second query (MCPToolProvider) returns None + # Third query (ApiToolProvider) returns a result + mock_session.scalar.side_effect = [None, None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.API + assert mock_session.scalar.call_count == 3 + + def test_infer_builtin_type_from_database(self): + """Test inferring builtin provider type from database.""" + tool_config = { + "provider_name": "builtin-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # First three queries return None + # Fourth query (BuiltinToolProvider) returns a result + mock_session.scalar.side_effect = [None, None, None, True] + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + assert mock_session.scalar.call_count == 4 + + def test_infer_type_default_when_not_found(self): + """Test raising AgentNodeError when provider is not found in database.""" + tool_config = { + "provider_name": "unknown-provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # All queries return None + mock_session.scalar.return_value = None + + # Current implementation raises AgentNodeError when provider not found + from core.workflow.nodes.agent.exc import AgentNodeError + + with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + def test_infer_type_default_when_no_provider_name(self): + """Test defaulting to BUILT_IN when provider_name is missing.""" + tool_config = {} + tenant_id = "test-tenant" + + result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) + + assert result == ToolProviderType.BUILT_IN + + def test_infer_type_database_exception_propagates(self): + """Test that database exception propagates (current implementation doesn't catch it).""" + tool_config = { + "provider_name": "provider-id", + } + tenant_id = "test-tenant" + + with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: + mock_session = MagicMock() + mock_create_session.return_value.__enter__.return_value = mock_session + + # Database query raises exception + mock_session.scalar.side_effect = Exception("Database error") + + # Current implementation doesn't catch exceptions, so it propagates + with pytest.raises(Exception, match="Database error"): + AgentNode._infer_tool_provider_type(tool_config, tenant_id) From 9898df5ed59822b1ad87bc7d55113c1c38a46e45 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Fri, 6 Feb 2026 14:38:15 +0800 Subject: [PATCH 5/8] fix: fix tool type is miss (#32042) --- api/.importlinter | 2 - api/core/workflow/nodes/agent/agent_node.py | 42 +--- .../workflow/nodes/agent/test_agent_node.py | 197 ------------------ .../config/agent/agent-tools/index.tsx | 1 + .../hooks/use-tool-selector-state.ts | 1 + .../workflow/block-selector/types.ts | 1 + 6 files changed, 5 insertions(+), 239 deletions(-) delete mode 100644 api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py diff --git a/api/.importlinter b/api/.importlinter index cbbb9e6618..9dad254560 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -102,8 +102,6 @@ forbidden_modules = core.trigger core.variables ignore_imports = - core.workflow.nodes.agent.agent_node -> core.db.session_factory - core.workflow.nodes.agent.agent_node -> models.tools core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.workflow_entry -> core.app.workflow.layers.observability diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index e64a83034c..e195aebe6d 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -2,7 +2,7 @@ from __future__ import annotations import json from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Union, cast +from typing import TYPE_CHECKING, Any, cast from packaging.version import Version from pydantic import ValidationError @@ -11,7 +11,6 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from core.agent.plugin_entities import AgentStrategyParameter -from core.db.session_factory import session_factory from core.file import File, FileTransferMethod from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance, ModelManager @@ -50,12 +49,6 @@ from factories import file_factory from factories.agent_factory import get_plugin_agent_strategy from models import ToolFile from models.model import Conversation -from models.tools import ( - ApiToolProvider, - BuiltinToolProvider, - MCPToolProvider, - WorkflowToolProvider, -) from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exc import ( @@ -266,7 +259,7 @@ class AgentNode(Node[AgentNodeData]): value = cast(list[dict[str, Any]], value) tool_value = [] for tool in value: - provider_type = self._infer_tool_provider_type(tool, self.tenant_id) + provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN)) setting_params = tool.get("settings", {}) parameters = tool.get("parameters", {}) manual_input_params = [key for key, value in parameters.items() if value is not None] @@ -755,34 +748,3 @@ class AgentNode(Node[AgentNodeData]): llm_usage=llm_usage, ) ) - - @staticmethod - def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType: - provider_type_str = tool_config.get("type") - if provider_type_str: - return ToolProviderType(provider_type_str) - - provider_id = tool_config.get("provider_name") - if not provider_id: - return ToolProviderType.BUILT_IN - - with session_factory.create_session() as session: - provider_map: dict[ - type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]], - ToolProviderType, - ] = { - WorkflowToolProvider: ToolProviderType.WORKFLOW, - MCPToolProvider: ToolProviderType.MCP, - ApiToolProvider: ToolProviderType.API, - BuiltinToolProvider: ToolProviderType.BUILT_IN, - } - - for provider_model, provider_type in provider_map.items(): - stmt = select(provider_model).where( - provider_model.id == provider_id, - provider_model.tenant_id == tenant_id, - ) - if session.scalar(stmt): - return provider_type - - raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.") diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py deleted file mode 100644 index a95892d0b6..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/agent/test_agent_node.py +++ /dev/null @@ -1,197 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest - -from core.tools.entities.tool_entities import ToolProviderType -from core.workflow.nodes.agent.agent_node import AgentNode - - -class TestInferToolProviderType: - """Test cases for AgentNode._infer_tool_provider_type method.""" - - def test_infer_type_from_config_workflow(self): - """Test inferring workflow provider type from config.""" - tool_config = { - "type": "workflow", - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.WORKFLOW - - def test_infer_type_from_config_builtin(self): - """Test inferring builtin provider type from config.""" - tool_config = { - "type": "builtin", - "provider_name": "builtin-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - - def test_infer_type_from_config_api(self): - """Test inferring API provider type from config.""" - tool_config = { - "type": "api", - "provider_name": "api-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.API - - def test_infer_type_from_config_mcp(self): - """Test inferring MCP provider type from config.""" - tool_config = { - "type": "mcp", - "provider_name": "mcp-provider-id", - } - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.MCP - - def test_infer_type_invalid_config_value_raises_error(self): - """Test that invalid type value in config raises ValueError.""" - tool_config = { - "type": "invalid-type", - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - with pytest.raises(ValueError): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - def test_infer_workflow_type_from_database(self): - """Test inferring workflow provider type from database.""" - tool_config = { - "provider_name": "workflow-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns a result - mock_session.scalar.return_value = True - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.WORKFLOW - # Should only query once (after finding WorkflowToolProvider) - assert mock_session.scalar.call_count == 1 - - def test_infer_mcp_type_from_database(self): - """Test inferring MCP provider type from database.""" - tool_config = { - "provider_name": "mcp-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns None - # Second query (MCPToolProvider) returns a result - mock_session.scalar.side_effect = [None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.MCP - assert mock_session.scalar.call_count == 2 - - def test_infer_api_type_from_database(self): - """Test inferring API provider type from database.""" - tool_config = { - "provider_name": "api-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First query (WorkflowToolProvider) returns None - # Second query (MCPToolProvider) returns None - # Third query (ApiToolProvider) returns a result - mock_session.scalar.side_effect = [None, None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.API - assert mock_session.scalar.call_count == 3 - - def test_infer_builtin_type_from_database(self): - """Test inferring builtin provider type from database.""" - tool_config = { - "provider_name": "builtin-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # First three queries return None - # Fourth query (BuiltinToolProvider) returns a result - mock_session.scalar.side_effect = [None, None, None, True] - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - assert mock_session.scalar.call_count == 4 - - def test_infer_type_default_when_not_found(self): - """Test raising AgentNodeError when provider is not found in database.""" - tool_config = { - "provider_name": "unknown-provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # All queries return None - mock_session.scalar.return_value = None - - # Current implementation raises AgentNodeError when provider not found - from core.workflow.nodes.agent.exc import AgentNodeError - - with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - def test_infer_type_default_when_no_provider_name(self): - """Test defaulting to BUILT_IN when provider_name is missing.""" - tool_config = {} - tenant_id = "test-tenant" - - result = AgentNode._infer_tool_provider_type(tool_config, tenant_id) - - assert result == ToolProviderType.BUILT_IN - - def test_infer_type_database_exception_propagates(self): - """Test that database exception propagates (current implementation doesn't catch it).""" - tool_config = { - "provider_name": "provider-id", - } - tenant_id = "test-tenant" - - with patch("core.db.session_factory.session_factory.create_session") as mock_create_session: - mock_session = MagicMock() - mock_create_session.return_value.__enter__.return_value = mock_session - - # Database query raises exception - mock_session.scalar.side_effect = Exception("Database error") - - # Current implementation doesn't catch exceptions, so it propagates - with pytest.raises(Exception, match="Database error"): - AgentNode._infer_tool_provider_type(tool_config, tenant_id) diff --git a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx index 486c0a8ac9..b97aa6e775 100644 --- a/web/app/components/app/configuration/config/agent/agent-tools/index.tsx +++ b/web/app/components/app/configuration/config/agent/agent-tools/index.tsx @@ -109,6 +109,7 @@ const AgentTools: FC = () => { tool_parameters: paramsWithDefaultValue, notAuthor: !tool.is_team_authorization, enabled: true, + type: tool.provider_type as CollectionType, } } const handleSelectTool = (tool: ToolDefaultValue) => { diff --git a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts index 44d0ff864e..e5edea5679 100644 --- a/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts +++ b/web/app/components/plugins/plugin-detail-panel/tool-selector/hooks/use-tool-selector-state.ts @@ -129,6 +129,7 @@ export const useToolSelectorState = ({ extra: { description: tool.tool_description, }, + type: tool.provider_type, } }, []) diff --git a/web/app/components/workflow/block-selector/types.ts b/web/app/components/workflow/block-selector/types.ts index 07efb0d02f..39e7b033bd 100644 --- a/web/app/components/workflow/block-selector/types.ts +++ b/web/app/components/workflow/block-selector/types.ts @@ -87,6 +87,7 @@ export type ToolValue = { enabled?: boolean extra?: { description?: string } & Record credential_id?: string + type?: string } export type DataSourceItem = { From b035b091fa48610186511d72c11386eb52894d6d Mon Sep 17 00:00:00 2001 From: QuantumGhost Date: Fri, 6 Feb 2026 15:12:32 +0800 Subject: [PATCH 6/8] perf: use batch delete method instead of single delete (#32036) Co-authored-by: fatelei Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: FFXN --- api/tasks/batch_clean_document_task.py | 211 ++++++++++++++---- api/tasks/delete_segment_from_index_task.py | 11 +- api/tasks/document_indexing_sync_task.py | 4 +- .../tasks/test_document_indexing_sync_task.py | 15 ++ 4 files changed, 190 insertions(+), 51 deletions(-) diff --git a/api/tasks/batch_clean_document_task.py b/api/tasks/batch_clean_document_task.py index d388284980..747106d373 100644 --- a/api/tasks/batch_clean_document_task.py +++ b/api/tasks/batch_clean_document_task.py @@ -14,6 +14,9 @@ from models.model import UploadFile logger = logging.getLogger(__name__) +# Batch size for database operations to keep transactions short +BATCH_SIZE = 1000 + @shared_task(queue="dataset") def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]): @@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form if not doc_form: raise ValueError("doc_form is required") - with session_factory.create_session() as session: - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - - if not dataset: - raise Exception("Document has no dataset") - - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id.in_(document_ids), - ).delete(synchronize_session=False) + storage_keys_to_delete: list[str] = [] + index_node_ids: list[str] = [] + segment_ids: list[str] = [] + total_image_upload_file_ids: list[str] = [] + try: + # ============ Step 1: Query segment and file data (short read-only transaction) ============ + with session_factory.create_session() as session: + # Get segments info segments = session.scalars( select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)) ).all() - # check segment is exist + if segments: index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() - index_processor.clean( - dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True - ) + segment_ids = [segment.id for segment in segments] + # Collect image file IDs from segment content for segment in segments: image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all() - for image_file in image_files: - try: - if image_file and image_file.key: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) - stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(stmt) - session.delete(segment) + total_image_upload_file_ids.extend(image_upload_file_ids) + + # Query storage keys for image files + if total_image_upload_file_ids: + image_files = session.scalars( + select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids)) + ).all() + storage_keys_to_delete.extend([f.key for f in image_files if f and f.key]) + + # Query storage keys for document files if file_ids: files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all() - for file in files: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file.id) - stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) - session.execute(stmt) + storage_keys_to_delete.extend([f.key for f in files if f and f.key]) - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned documents when documents deleted latency: {end_at - start_at}", - fg="green", + # ============ Step 2: Clean vector index (external service, fresh session for dataset) ============ + if index_node_ids: + try: + # Fetch dataset in a fresh session to avoid DetachedInstanceError + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id) + else: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + index_processor.clean( + dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True + ) + except Exception: + logger.exception( + "Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d", + dataset_id, + document_ids, + len(index_node_ids), ) - ) + + # ============ Step 3: Delete metadata binding (separate short transaction) ============ + try: + with session_factory.create_session() as session: + deleted_count = ( + session.query(DatasetMetadataBinding) + .where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id.in_(document_ids), + ) + .delete(synchronize_session=False) + ) + session.commit() + logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id) except Exception: - logger.exception("Cleaned documents when documents deleted failed") + logger.exception( + "Failed to delete metadata bindings for dataset_id: %s, document_ids: %s", + dataset_id, + document_ids, + ) + + # ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============ + if total_image_upload_file_ids: + failed_batches = 0 + total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE + for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE): + batch = total_image_upload_file_ids[i : i + BATCH_SIZE] + try: + with session_factory.create_session() as session: + stmt = delete(UploadFile).where(UploadFile.id.in_(batch)) + session.execute(stmt) + session.commit() + except Exception: + failed_batches += 1 + logger.exception( + "Failed to delete image UploadFile batch %d-%d for dataset_id: %s", + i, + i + len(batch), + dataset_id, + ) + if failed_batches > 0: + logger.warning( + "Image UploadFile deletion: %d/%d batches failed for dataset_id: %s", + failed_batches, + total_batches, + dataset_id, + ) + + # ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============ + if segment_ids: + failed_batches = 0 + total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE + for i in range(0, len(segment_ids), BATCH_SIZE): + batch = segment_ids[i : i + BATCH_SIZE] + try: + with session_factory.create_session() as session: + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch)) + session.execute(segment_delete_stmt) + session.commit() + except Exception: + failed_batches += 1 + logger.exception( + "Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s", + i, + i + len(batch), + dataset_id, + document_ids, + ) + if failed_batches > 0: + logger.warning( + "DocumentSegment deletion: %d/%d batches failed, document_ids: %s", + failed_batches, + total_batches, + document_ids, + ) + + # ============ Step 6: Delete document-associated files (separate short transaction) ============ + if file_ids: + try: + with session_factory.create_session() as session: + stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids)) + session.execute(stmt) + session.commit() + except Exception: + logger.exception( + "Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s", + dataset_id, + file_ids, + ) + + # ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============ + storage_delete_failures = 0 + for storage_key in storage_keys_to_delete: + try: + storage.delete(storage_key) + except Exception: + storage_delete_failures += 1 + logger.exception("Failed to delete file from storage, key: %s", storage_key) + if storage_delete_failures > 0: + logger.warning( + "Storage file deletion completed with %d failures out of %d total files for dataset_id: %s", + storage_delete_failures, + len(storage_keys_to_delete), + dataset_id, + ) + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, " + f"dataset_id: {dataset_id}, document_ids: {document_ids}, " + f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, " + f"storage_files: {len(storage_keys_to_delete)}", + fg="green", + ) + ) + except Exception: + logger.exception( + "Batch clean documents failed for dataset_id: %s, document_ids: %s", + dataset_id, + document_ids, + ) diff --git a/api/tasks/delete_segment_from_index_task.py b/api/tasks/delete_segment_from_index_task.py index 764c635d83..a6a2dcebc8 100644 --- a/api/tasks/delete_segment_from_index_task.py +++ b/api/tasks/delete_segment_from_index_task.py @@ -3,6 +3,7 @@ import time import click from celery import shared_task +from sqlalchemy import delete from core.db.session_factory import session_factory from core.rag.index_processor.index_processor_factory import IndexProcessorFactory @@ -67,8 +68,14 @@ def delete_segment_from_index_task( if segment_attachment_bindings: attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings] index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False) - for binding in segment_attachment_bindings: - session.delete(binding) + segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings] + + for i in range(0, len(segment_attachment_bind_ids), 1000): + segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where( + SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000]) + ) + session.execute(segment_attachment_bind_delete_stmt) + # delete upload file session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False) session.commit() diff --git a/api/tasks/document_indexing_sync_task.py b/api/tasks/document_indexing_sync_task.py index 149185f6e2..8fa5faa796 100644 --- a/api/tasks/document_indexing_sync_task.py +++ b/api/tasks/document_indexing_sync_task.py @@ -28,7 +28,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): logger.info(click.style(f"Start sync document: {document_id}", fg="green")) start_at = time.perf_counter() - with session_factory.create_session() as session: + with session_factory.create_session() as session, session.begin(): document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() if not document: @@ -68,7 +68,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): document.indexing_status = "error" document.error = "Datasource credential not found. Please reconnect your Notion workspace." document.stopped_at = naive_utc_now() - session.commit() return loader = NotionExtractor( @@ -85,7 +84,6 @@ def document_indexing_sync_task(dataset_id: str, document_id: str): if last_edited_time != page_edited_time: document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - session.commit() # delete all document segment and index try: diff --git a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py index fa33034f40..24e0bc76cf 100644 --- a/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py +++ b/api/tests/unit_tests/tasks/test_document_indexing_sync_task.py @@ -114,6 +114,21 @@ def mock_db_session(): session = MagicMock() # Ensure tests can observe session.close() via context manager teardown session.close = MagicMock() + session.commit = MagicMock() + + # Mock session.begin() context manager to auto-commit on exit + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def _begin_exit_side_effect(*args, **kwargs): + # session.begin().__exit__() should commit if no exception + if args[0] is None: # No exception + session.commit() + + begin_cm.__exit__.side_effect = _begin_exit_side_effect + session.begin.return_value = begin_cm + + # Mock create_session() context manager cm = MagicMock() cm.__enter__.return_value = session From 55de893984f72b5ba94871635b76b32ad19b18fc Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Sun, 8 Feb 2026 21:01:54 +0800 Subject: [PATCH 7/8] =?UTF-8?q?refactor:=20partition=20Celery=20task=20ses?= =?UTF-8?q?sions=20into=20smaller,=20discrete=20execu=E2=80=A6=20(#32085)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- .../add_annotation_to_index_task.py | 3 - .../delete_annotation_index_task.py | 3 - .../update_annotation_to_index_task.py | 3 - .../batch_create_segment_to_index_task.py | 180 ++++--- api/tasks/clean_document_task.py | 150 +++--- api/tasks/document_indexing_task.py | 72 +-- api/tasks/workflow_draft_var_tasks.py | 5 +- ...test_batch_create_segment_to_index_task.py | 28 +- .../tasks/test_dataset_indexing_task.py | 508 +++++++----------- 9 files changed, 436 insertions(+), 516 deletions(-) diff --git a/api/tasks/annotation/add_annotation_to_index_task.py b/api/tasks/annotation/add_annotation_to_index_task.py index 23c49f2742..a9a8b892c2 100644 --- a/api/tasks/annotation/add_annotation_to_index_task.py +++ b/api/tasks/annotation/add_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -58,5 +57,3 @@ def add_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/delete_annotation_index_task.py b/api/tasks/annotation/delete_annotation_index_task.py index e928c25546..432732af95 100644 --- a/api/tasks/annotation/delete_annotation_index_task.py +++ b/api/tasks/annotation/delete_annotation_index_task.py @@ -5,7 +5,6 @@ import click from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green")) except Exception: logger.exception("Annotation deleted index failed") - finally: - db.session.close() diff --git a/api/tasks/annotation/update_annotation_to_index_task.py b/api/tasks/annotation/update_annotation_to_index_task.py index 957d8f7e45..6ff34c0e74 100644 --- a/api/tasks/annotation/update_annotation_to_index_task.py +++ b/api/tasks/annotation/update_annotation_to_index_task.py @@ -6,7 +6,6 @@ from celery import shared_task from core.rag.datasource.vdb.vector_factory import Vector from core.rag.models.document import Document -from extensions.ext_database import db from models.dataset import Dataset from services.dataset_service import DatasetCollectionBindingService @@ -59,5 +58,3 @@ def update_annotation_to_index_task( ) except Exception: logger.exception("Build index for annotation failed") - finally: - db.session.close() diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index 8ee09d5738..f69f17b16d 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -48,6 +48,11 @@ def batch_create_segment_to_index_task( indexing_cache_key = f"segment_batch_import_{job_id}" + # Initialize variables with default values + upload_file_key: str | None = None + dataset_config: dict | None = None + document_config: dict | None = None + with session_factory.create_session() as session: try: dataset = session.get(Dataset, dataset_id) @@ -69,86 +74,115 @@ def batch_create_segment_to_index_task( if not upload_file: raise ValueError("UploadFile not found.") - with tempfile.TemporaryDirectory() as temp_dir: - suffix = Path(upload_file.key).suffix - file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore - storage.download(upload_file.key, file_path) + dataset_config = { + "id": dataset.id, + "indexing_technique": dataset.indexing_technique, + "tenant_id": dataset.tenant_id, + "embedding_model_provider": dataset.embedding_model_provider, + "embedding_model": dataset.embedding_model, + } - df = pd.read_csv(file_path) - content = [] - for _, row in df.iterrows(): - if dataset_document.doc_form == "qa_model": - data = {"content": row.iloc[0], "answer": row.iloc[1]} - else: - data = {"content": row.iloc[0]} - content.append(data) - if len(content) == 0: - raise ValueError("The CSV file is empty.") + document_config = { + "id": dataset_document.id, + "doc_form": dataset_document.doc_form, + "word_count": dataset_document.word_count or 0, + } - document_segments = [] - embedding_model = None - if dataset.indexing_technique == "high_quality": - model_manager = ModelManager() - embedding_model = model_manager.get_model_instance( - tenant_id=dataset.tenant_id, - provider=dataset.embedding_model_provider, - model_type=ModelType.TEXT_EMBEDDING, - model=dataset.embedding_model, - ) + upload_file_key = upload_file.key - word_count_change = 0 - if embedding_model: - tokens_list = embedding_model.get_text_embedding_num_tokens( - texts=[segment["content"] for segment in content] - ) + except Exception: + logger.exception("Segments batch created index failed") + redis_client.setex(indexing_cache_key, 600, "error") + return + + # Ensure required variables are set before proceeding + if upload_file_key is None or dataset_config is None or document_config is None: + logger.error("Required configuration not set due to session error") + redis_client.setex(indexing_cache_key, 600, "error") + return + + with tempfile.TemporaryDirectory() as temp_dir: + suffix = Path(upload_file_key).suffix + file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore + storage.download(upload_file_key, file_path) + + df = pd.read_csv(file_path) + content = [] + for _, row in df.iterrows(): + if document_config["doc_form"] == "qa_model": + data = {"content": row.iloc[0], "answer": row.iloc[1]} else: - tokens_list = [0] * len(content) + data = {"content": row.iloc[0]} + content.append(data) + if len(content) == 0: + raise ValueError("The CSV file is empty.") - for segment, tokens in zip(content, tokens_list): - content = segment["content"] - doc_id = str(uuid.uuid4()) - segment_hash = helper.generate_text_hash(content) - max_position = ( - session.query(func.max(DocumentSegment.position)) - .where(DocumentSegment.document_id == dataset_document.id) - .scalar() - ) - segment_document = DocumentSegment( - tenant_id=tenant_id, - dataset_id=dataset_id, - document_id=document_id, - index_node_id=doc_id, - index_node_hash=segment_hash, - position=max_position + 1 if max_position else 1, - content=content, - word_count=len(content), - tokens=tokens, - created_by=user_id, - indexing_at=naive_utc_now(), - status="completed", - completed_at=naive_utc_now(), - ) - if dataset_document.doc_form == "qa_model": - segment_document.answer = segment["answer"] - segment_document.word_count += len(segment["answer"]) - word_count_change += segment_document.word_count - session.add(segment_document) - document_segments.append(segment_document) + document_segments = [] + embedding_model = None + if dataset_config["indexing_technique"] == "high_quality": + model_manager = ModelManager() + embedding_model = model_manager.get_model_instance( + tenant_id=dataset_config["tenant_id"], + provider=dataset_config["embedding_model_provider"], + model_type=ModelType.TEXT_EMBEDDING, + model=dataset_config["embedding_model"], + ) + word_count_change = 0 + if embedding_model: + tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content]) + else: + tokens_list = [0] * len(content) + + with session_factory.create_session() as session, session.begin(): + for segment, tokens in zip(content, tokens_list): + content = segment["content"] + doc_id = str(uuid.uuid4()) + segment_hash = helper.generate_text_hash(content) + max_position = ( + session.query(func.max(DocumentSegment.position)) + .where(DocumentSegment.document_id == document_config["id"]) + .scalar() + ) + segment_document = DocumentSegment( + tenant_id=tenant_id, + dataset_id=dataset_id, + document_id=document_id, + index_node_id=doc_id, + index_node_hash=segment_hash, + position=max_position + 1 if max_position else 1, + content=content, + word_count=len(content), + tokens=tokens, + created_by=user_id, + indexing_at=naive_utc_now(), + status="completed", + completed_at=naive_utc_now(), + ) + if document_config["doc_form"] == "qa_model": + segment_document.answer = segment["answer"] + segment_document.word_count += len(segment["answer"]) + word_count_change += segment_document.word_count + session.add(segment_document) + document_segments.append(segment_document) + + with session_factory.create_session() as session, session.begin(): + dataset_document = session.get(Document, document_id) + if dataset_document: assert dataset_document.word_count is not None dataset_document.word_count += word_count_change session.add(dataset_document) - VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form) - session.commit() - redis_client.setex(indexing_cache_key, 600, "completed") - end_at = time.perf_counter() - logger.info( - click.style( - f"Segment batch created job: {job_id} latency: {end_at - start_at}", - fg="green", - ) - ) - except Exception: - logger.exception("Segments batch created index failed") - redis_client.setex(indexing_cache_key, 600, "error") + with session_factory.create_session() as session: + dataset = session.get(Dataset, dataset_id) + if dataset: + VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"]) + + redis_client.setex(indexing_cache_key, 600, "completed") + end_at = time.perf_counter() + logger.info( + click.style( + f"Segment batch created job: {job_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/clean_document_task.py b/api/tasks/clean_document_task.py index 91ace6be02..a017e9114b 100644 --- a/api/tasks/clean_document_task.py +++ b/api/tasks/clean_document_task.py @@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i """ logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green")) start_at = time.perf_counter() + total_attachment_files = [] with session_factory.create_session() as session: try: @@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i SegmentAttachmentBinding.document_id == document_id, ) ).all() - # check segment is exist - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - index_processor = IndexProcessorFactory(doc_form).init_index_processor() + + attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] + binding_ids = [binding.id for binding, _ in attachments_with_bindings] + total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings]) + + index_node_ids = [segment.index_node_id for segment in segments] + segment_contents = [segment.content for segment in segments] + except Exception: + logger.exception("Cleaned document when document deleted failed") + return + + # check segment is exist + if index_node_ids: + index_processor = IndexProcessorFactory(doc_form).init_index_processor() + with session_factory.create_session() as session: + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if dataset: index_processor.clean( dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True ) - for segment in segments: - image_upload_file_ids = get_image_upload_file_ids(segment.content) - image_files = session.scalars( - select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - ).all() - for image_file in image_files: - if image_file is None: - continue - try: - storage.delete(image_file.key) - except Exception: - logger.exception( - "Delete image_files failed when storage deleted, \ - image_upload_file_is: %s", - image_file.id, - ) + total_image_files = [] + with session_factory.create_session() as session, session.begin(): + for segment_content in segment_contents: + image_upload_file_ids = get_image_upload_file_ids(segment_content) + image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all() + total_image_files.extend([image_file.key for image_file in image_files]) + image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) + session.execute(image_file_delete_stmt) - image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)) - session.execute(image_file_delete_stmt) - session.delete(segment) + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) - session.commit() - if file_id: - file = session.query(UploadFile).where(UploadFile.id == file_id).first() - if file: - try: - storage.delete(file.key) - except Exception: - logger.exception("Delete file failed when document deleted, file_id: %s", file_id) - session.delete(file) - # delete segment attachments - if attachments_with_bindings: - attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings] - binding_ids = [binding.id for binding, _ in attachments_with_bindings] - for binding, attachment_file in attachments_with_bindings: - try: - storage.delete(attachment_file.key) - except Exception: - logger.exception( - "Delete attachment_file failed when storage deleted, \ - attachment_file_id: %s", - binding.attachment_id, - ) - attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) - session.execute(attachment_file_delete_stmt) - - binding_delete_stmt = delete(SegmentAttachmentBinding).where( - SegmentAttachmentBinding.id.in_(binding_ids) - ) - session.execute(binding_delete_stmt) - - # delete dataset metadata binding - session.query(DatasetMetadataBinding).where( - DatasetMetadataBinding.dataset_id == dataset_id, - DatasetMetadataBinding.document_id == document_id, - ).delete() - session.commit() - - end_at = time.perf_counter() - logger.info( - click.style( - f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", - fg="green", - ) - ) + for image_file_key in total_image_files: + try: + storage.delete(image_file_key) except Exception: - logger.exception("Cleaned document when document deleted failed") + logger.exception( + "Delete image_files failed when storage deleted, \ + image_upload_file_is: %s", + image_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + if file_id: + file = session.query(UploadFile).where(UploadFile.id == file_id).first() + if file: + try: + storage.delete(file.key) + except Exception: + logger.exception("Delete file failed when document deleted, file_id: %s", file_id) + session.delete(file) + + with session_factory.create_session() as session, session.begin(): + # delete segment attachments + if attachment_ids: + attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids)) + session.execute(attachment_file_delete_stmt) + + if binding_ids: + binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids)) + session.execute(binding_delete_stmt) + + for attachment_file_key in total_attachment_files: + try: + storage.delete(attachment_file_key) + except Exception: + logger.exception( + "Delete attachment_file failed when storage deleted, \ + attachment_file_id: %s", + attachment_file_key, + ) + + with session_factory.create_session() as session, session.begin(): + # delete dataset metadata binding + session.query(DatasetMetadataBinding).where( + DatasetMetadataBinding.dataset_id == dataset_id, + DatasetMetadataBinding.document_id == document_id, + ).delete() + + end_at = time.perf_counter() + logger.info( + click.style( + f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}", + fg="green", + ) + ) diff --git a/api/tasks/document_indexing_task.py b/api/tasks/document_indexing_task.py index 34496e9c6f..11edcf151f 100644 --- a/api/tasks/document_indexing_task.py +++ b/api/tasks/document_indexing_task.py @@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): session.commit() return - for document_id in document_ids: - logger.info(click.style(f"Start process document: {document_id}", fg="green")) - - document = ( - session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first() - ) + # Phase 1: Update status to parsing (short transaction) + with session_factory.create_session() as session, session.begin(): + documents = ( + session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all() + ) + for document in documents: if document: document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - documents.append(document) session.add(document) - session.commit() + # Transaction committed and closed - try: - indexing_runner = IndexingRunner() - indexing_runner.run(documents) - end_at = time.perf_counter() - logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + # Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions) + has_error = False + try: + indexing_runner = IndexingRunner() + indexing_runner.run(documents) + end_at = time.perf_counter() + logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + has_error = True + except Exception: + logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) + has_error = True + if not has_error: + with session_factory.create_session() as session: # Trigger summary index generation for completed documents if enabled # Only generate for high_quality indexing technique and when summary_index_setting is enabled # Re-query dataset to get latest summary_index_setting (in case it was updated) @@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): # expire all session to get latest document's indexing status session.expire_all() # Check each document's indexing status and trigger summary generation if completed - for document_id in document_ids: - # Re-query document to get latest status (IndexingRunner may have updated it) - document = ( - session.query(Document) - .where(Document.id == document_id, Document.dataset_id == dataset_id) - .first() - ) + + documents = ( + session.query(Document) + .where(Document.id.in_(document_ids), Document.dataset_id == dataset_id) + .all() + ) + + for document in documents: if document: logger.info( "Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, @@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]): and document.need_summary is True ): try: - generate_summary_index_task.delay(dataset.id, document_id, None) + generate_summary_index_task.delay(dataset.id, document.id, None) logger.info( "Queued summary index generation task for document %s in dataset %s " "after indexing completed", - document_id, + document.id, dataset.id, ) except Exception: logger.exception( "Failed to queue summary index generation task for document %s", - document_id, + document.id, ) # Don't fail the entire indexing process if summary task queuing fails else: logger.info( "Skipping summary generation for document %s: " "status=%s, doc_form=%s, need_summary=%s", - document_id, + document.id, document.indexing_status, document.doc_form, document.need_summary, ) else: - logger.warning("Document %s not found after indexing", document_id) - else: - logger.info( - "Summary index generation skipped for dataset %s: summary_index_setting.enable=%s", - dataset.id, - summary_index_setting.get("enable") if summary_index_setting else None, - ) + logger.warning("Document %s not found after indexing", document.id) else: logger.info( "Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')", dataset.id, dataset.indexing_technique, ) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("Document indexing task failed, dataset_id: %s", dataset_id) def _document_indexing_with_tenant_queue( diff --git a/api/tasks/workflow_draft_var_tasks.py b/api/tasks/workflow_draft_var_tasks.py index fcb98ec39e..26f8f7c29e 100644 --- a/api/tasks/workflow_draft_var_tasks.py +++ b/api/tasks/workflow_draft_var_tasks.py @@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers. """ from celery import shared_task # type: ignore[import-untyped] -from sqlalchemy.orm import Session -from extensions.ext_database import db +from core.db.session_factory import session_factory from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService @@ -17,6 +16,6 @@ def save_workflow_execution_task( self, deletions: list[DraftVarFileDeletion], ): - with Session(bind=db.engine) as session, session.begin(): + with session_factory.create_session() as session, session.begin(): srv = WorkflowDraftVariableService(session=session) srv.delete_workflow_draft_variable_file(deletions=deletions) diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index 1b844d6357..61f6b75b10 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask: mock_storage.download.side_effect = mock_download - # Execute the task + # Execute the task - should raise ValueError for empty CSV job_id = str(uuid.uuid4()) - batch_create_segment_to_index_task( - job_id=job_id, - upload_file_id=upload_file.id, - dataset_id=dataset.id, - document_id=document.id, - tenant_id=tenant.id, - user_id=account.id, - ) + with pytest.raises(ValueError, match="The CSV file is empty"): + batch_create_segment_to_index_task( + job_id=job_id, + upload_file_id=upload_file.id, + dataset_id=dataset.id, + document_id=document.id, + tenant_id=tenant.id, + user_id=account.id, + ) # Verify error handling - # Check Redis cache was set to error status - from extensions.ext_redis import redis_client - - cache_key = f"segment_batch_import_{job_id}" - cache_value = redis_client.get(cache_key) - assert cache_value == b"error" - - # Verify no segments were created + # Since exception was raised, no segments should be created from extensions.ext_database import db segments = db.session.query(DocumentSegment).all() diff --git a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py index e24ef32a24..8d8e2b0db0 100644 --- a/api/tests/unit_tests/tasks/test_dataset_indexing_task.py +++ b/api/tests/unit_tests/tasks/test_dataset_indexing_task.py @@ -83,23 +83,127 @@ def mock_documents(document_ids, dataset_id): def mock_db_session(): """Mock database session via session_factory.create_session().""" with patch("tasks.document_indexing_task.session_factory") as mock_sf: - session = MagicMock() - # Ensure tests that expect session.close() to be called can observe it via the context manager - session.close = MagicMock() - cm = MagicMock() - cm.__enter__.return_value = session - # Link __exit__ to session.close so "close" expectations reflect context manager teardown + sessions = [] # Track all created sessions + # Shared mock data that all sessions will access + shared_mock_data = {"dataset": None, "documents": None, "doc_iter": None} - def _exit_side_effect(*args, **kwargs): - session.close() + def create_session_side_effect(): + session = MagicMock() + session.close = MagicMock() - cm.__exit__.side_effect = _exit_side_effect - mock_sf.create_session.return_value = cm + # Track commit calls + commit_mock = MagicMock() + session.commit = commit_mock + cm = MagicMock() + cm.__enter__.return_value = session - query = MagicMock() - session.query.return_value = query - query.where.return_value = query - yield session + def _exit_side_effect(*args, **kwargs): + session.close() + + cm.__exit__.side_effect = _exit_side_effect + + # Support session.begin() for transactions + begin_cm = MagicMock() + begin_cm.__enter__.return_value = session + + def begin_exit_side_effect(*args, **kwargs): + # Auto-commit on transaction exit (like SQLAlchemy) + session.commit() + # Also mark wrapper's commit as called + if sessions: + sessions[0].commit() + + begin_cm.__exit__ = MagicMock(side_effect=begin_exit_side_effect) + session.begin = MagicMock(return_value=begin_cm) + + sessions.append(session) + + # Setup query with side_effect to handle both Dataset and Document queries + def query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + # Support both .first() and .all() calls with chaining + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + # Create an iterator for .first() calls if not exists + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return query + + session.query = MagicMock(side_effect=query_side_effect) + return cm + + mock_sf.create_session.side_effect = create_session_side_effect + + # Create a wrapper that behaves like the first session but has access to all sessions + class SessionWrapper: + def __init__(self): + self._sessions = sessions + self._shared_data = shared_mock_data + # Create a default session for setup phase + self._default_session = MagicMock() + self._default_session.close = MagicMock() + self._default_session.commit = MagicMock() + + # Support session.begin() for default session too + begin_cm = MagicMock() + begin_cm.__enter__.return_value = self._default_session + + def default_begin_exit_side_effect(*args, **kwargs): + self._default_session.commit() + + begin_cm.__exit__ = MagicMock(side_effect=default_begin_exit_side_effect) + self._default_session.begin = MagicMock(return_value=begin_cm) + + def default_query_side_effect(*args): + query = MagicMock() + if args and args[0] == Dataset and shared_mock_data["dataset"] is not None: + where_result = MagicMock() + where_result.first.return_value = shared_mock_data["dataset"] + query.where = MagicMock(return_value=where_result) + elif args and args[0] == Document and shared_mock_data["documents"] is not None: + where_result = MagicMock() + where_result.where = MagicMock(return_value=where_result) + + if shared_mock_data["doc_iter"] is None: + docs = shared_mock_data["documents"] or [None] + shared_mock_data["doc_iter"] = iter(docs) + + where_result.first = lambda: next(shared_mock_data["doc_iter"], None) + docs_or_empty = shared_mock_data["documents"] or [] + where_result.all = MagicMock(return_value=docs_or_empty) + query.where = MagicMock(return_value=where_result) + else: + query.where = MagicMock(return_value=query) + return query + + self._default_session.query = MagicMock(side_effect=default_query_side_effect) + + def __getattr__(self, name): + # Forward all attribute access to the first session, or default if none created yet + target_session = self._sessions[0] if self._sessions else self._default_session + return getattr(target_session, name) + + @property + def all_sessions(self): + """Access all created sessions for testing.""" + return self._sessions + + wrapper = SessionWrapper() + yield wrapper @pytest.fixture @@ -252,18 +356,9 @@ class TestTaskEnqueuing: use the deprecated function. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -304,21 +399,9 @@ class TestBatchProcessing: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create an iterator for documents - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - # Return documents one by one for each call - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -357,19 +440,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.PROFESSIONAL @@ -407,19 +480,9 @@ class TestBatchProcessing: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents mock_feature_service.get_features.return_value.billing.enabled = True mock_feature_service.get_features.return_value.billing.subscription.plan = CloudPlan.SANDBOX @@ -444,7 +507,10 @@ class TestBatchProcessing: """ # Arrange document_ids = [] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset + + # Set shared mock data with empty documents list + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -482,19 +548,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -528,19 +584,9 @@ class TestProgressTracking: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -635,19 +681,9 @@ class TestErrorHandling: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set up to trigger vector space limit error mock_feature_service.get_features.return_value.billing.enabled = True @@ -674,17 +710,9 @@ class TestErrorHandling: Errors during indexing should be caught and logged, but not crash the task. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Indexing failed") @@ -708,17 +736,9 @@ class TestErrorHandling: but not treated as a failure. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise DocumentIsPausedError mock_indexing_runner.run.side_effect = DocumentIsPausedError("Document is paused") @@ -853,17 +873,9 @@ class TestTaskCancellation: Session cleanup should happen in finally block. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -883,17 +895,9 @@ class TestTaskCancellation: Session cleanup should happen even when errors occur. """ # Arrange - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first.side_effect = mock_documents - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = Exception("Test error") @@ -962,6 +966,7 @@ class TestAdvancedScenarios: document_ids = [str(uuid.uuid4()) for _ in range(3)] # Create only 2 documents (simulate one missing) + # The new code uses .all() which will only return existing documents mock_documents = [] for i, doc_id in enumerate([document_ids[0], document_ids[2]]): # Skip middle one doc = MagicMock(spec=Document) @@ -971,21 +976,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - # Create iterator that returns None for missing document - doc_responses = [mock_documents[0], None, mock_documents[1]] - doc_iter = iter(doc_responses) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data - .all() will only return existing documents + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1075,19 +1068,9 @@ class TestAdvancedScenarios: doc.stopped_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space exactly at limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1219,19 +1202,9 @@ class TestAdvancedScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Billing disabled - limits should not be checked mock_feature_service.get_features.return_value.billing.enabled = False @@ -1273,19 +1246,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1321,19 +1284,9 @@ class TestIntegration: # Set up rpop to return None for concurrency check (no more tasks) mock_redis.rpop.side_effect = [None] - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1415,17 +1368,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1465,17 +1410,9 @@ class TestEdgeCases: mock_document.indexing_status = "waiting" mock_document.processing_started_at = None - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: mock_document - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = [mock_document] with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1555,19 +1492,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set vector space limit to 0 (unlimited) mock_feature_service.get_features.return_value.billing.enabled = True @@ -1612,19 +1539,9 @@ class TestEdgeCases: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Set negative vector space limit mock_feature_service.get_features.return_value.billing.enabled = True @@ -1675,19 +1592,9 @@ class TestPerformanceScenarios: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Configure billing with sufficient limits mock_feature_service.get_features.return_value.billing.enabled = True @@ -1826,19 +1733,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents # Make IndexingRunner raise an exception mock_indexing_runner.run.side_effect = RuntimeError("Unexpected indexing error") @@ -1866,7 +1763,7 @@ class TestRobustness: - No exceptions occur Expected behavior: - - Database session is closed + - All database sessions are closed - No connection leaks """ # Arrange @@ -1879,19 +1776,9 @@ class TestRobustness: doc.processing_started_at = None mock_documents.append(doc) - mock_db_session.query.return_value.where.return_value.first.return_value = mock_dataset - - doc_iter = iter(mock_documents) - - def mock_query_side_effect(*args): - mock_query = MagicMock() - if args[0] == Dataset: - mock_query.where.return_value.first.return_value = mock_dataset - elif args[0] == Document: - mock_query.where.return_value.first = lambda: next(doc_iter, None) - return mock_query - - mock_db_session.query.side_effect = mock_query_side_effect + # Set shared mock data so all sessions can access it + mock_db_session._shared_data["dataset"] = mock_dataset + mock_db_session._shared_data["documents"] = mock_documents with patch("tasks.document_indexing_task.FeatureService.get_features") as mock_features: mock_features.return_value.billing.enabled = False @@ -1899,10 +1786,11 @@ class TestRobustness: # Act _document_indexing(dataset_id, document_ids) - # Assert - assert mock_db_session.close.called - # Verify close is called exactly once - assert mock_db_session.close.call_count == 1 + # Assert - All created sessions should be closed + # The code creates multiple sessions: validation, Phase 1 (parsing), Phase 3 (summary) + assert len(mock_db_session.all_sessions) >= 1 + for session in mock_db_session.all_sessions: + assert session.close.called, "All sessions should be closed" def test_task_proxy_handles_feature_service_failure(self, tenant_id, dataset_id, document_ids, mock_redis): """ From 284c5f40f10162c494e2b66c470bc33e5f28f2c1 Mon Sep 17 00:00:00 2001 From: wangxiaolei Date: Mon, 9 Feb 2026 10:49:23 +0800 Subject: [PATCH 8/8] refactor: document_indexing_update_task split database session (#32105) Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- api/pyproject.toml | 2 +- api/tasks/document_indexing_update_task.py | 56 +++++++++++----------- api/uv.lock | 11 +++-- 3 files changed, 35 insertions(+), 34 deletions(-) diff --git a/api/pyproject.toml b/api/pyproject.toml index 4be7afff26..2a7c946e6e 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -81,7 +81,7 @@ dependencies = [ "starlette==0.49.1", "tiktoken~=0.9.0", "transformers~=4.56.1", - "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", + "unstructured[docx,epub,md,ppt,pptx]~=0.18.18", "yarl~=1.18.3", "webvtt-py~=0.5.1", "sseclient-py~=1.8.0", diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 45d58c92ec..c7508c6d05 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -36,25 +36,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str): document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - # delete all document segment and index - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + return - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - - segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - segment_ids = [segment.id for segment in segments] - segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) - session.execute(segment_delete_stmt) + index_type = document.doc_form + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + index_node_ids = [segment.index_node_id for segment in segments] + clean_success = False + try: + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if index_node_ids: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() logger.info( click.style( @@ -64,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str): fg="green", ) ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + clean_success = True + except Exception: + logger.exception("Failed to clean document index during update, document_id: %s", document_id) - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_update_task failed, document_id: %s", document_id) + if clean_success: + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_update_task failed, document_id: %s", document_id) diff --git a/api/uv.lock b/api/uv.lock index 700011d7b1..a3b5433952 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1653,7 +1653,7 @@ requires-dist = [ { name = "starlette", specifier = "==0.49.1" }, { name = "tiktoken", specifier = "~=0.9.0" }, { name = "transformers", specifier = "~=4.56.1" }, - { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, + { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.17.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, @@ -6814,12 +6814,12 @@ wheels = [ [[package]] name = "unstructured" -version = "0.16.25" +version = "0.18.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, { name = "beautifulsoup4" }, - { name = "chardet" }, + { name = "charset-normalizer" }, { name = "dataclasses-json" }, { name = "emoji" }, { name = "filetype" }, @@ -6827,6 +6827,7 @@ dependencies = [ { name = "langdetect" }, { name = "lxml" }, { name = "nltk" }, + { name = "numba" }, { name = "numpy" }, { name = "psutil" }, { name = "python-iso639" }, @@ -6839,9 +6840,9 @@ dependencies = [ { name = "unstructured-client" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" }, + { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" }, ] [package.optional-dependencies]