diff --git a/api/.importlinter b/api/.importlinter index 2a6bb66a95..0853c2ded8 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -50,14 +50,12 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> extensions.ext_database core.workflow.nodes.datasource.datasource_node -> extensions.ext_database core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database core.workflow.nodes.llm.file_saver -> extensions.ext_database core.workflow.nodes.llm.llm_utils -> extensions.ext_database core.workflow.nodes.llm.node -> extensions.ext_database core.workflow.nodes.tool.tool_node -> extensions.ext_database core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis core.workflow.graph_engine.manager -> extensions.ext_redis - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis [importlinter:contract:workflow-external-imports] name = Workflow External Imports @@ -122,11 +120,6 @@ ignore_imports = core.workflow.nodes.http_request.node -> core.tools.tool_file_manager core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model core.workflow.nodes.llm.llm_utils -> configs core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities core.workflow.nodes.llm.llm_utils -> core.file.models @@ -146,7 +139,6 @@ ignore_imports = core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform @@ -162,9 +154,6 @@ ignore_imports = core.workflow.workflow_entry -> core.app.workflow.node_factory core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager @@ -213,7 +202,6 @@ ignore_imports = core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output core.workflow.nodes.llm.node -> core.model_manager core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities @@ -229,7 +217,6 @@ ignore_imports = core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods core.workflow.nodes.llm.node -> models.dataset core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer core.workflow.nodes.llm.file_saver -> core.tools.signature @@ -287,8 +274,6 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> extensions.ext_database core.workflow.nodes.datasource.datasource_node -> extensions.ext_database core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database - core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis core.workflow.nodes.llm.file_saver -> extensions.ext_database core.workflow.nodes.llm.llm_utils -> extensions.ext_database core.workflow.nodes.llm.node -> extensions.ext_database diff --git a/api/core/app/workflow/node_factory.py b/api/core/app/workflow/node_factory.py index 6717be3ae6..18db750d28 100644 --- a/api/core/app/workflow/node_factory.py +++ b/api/core/app/workflow/node_factory.py @@ -8,6 +8,7 @@ from core.file.file_manager import file_manager from core.helper.code_executor.code_executor import CodeExecutor from core.helper.code_executor.code_node_provider import CodeNodeProvider from core.helper.ssrf_proxy import ssrf_proxy +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities.graph_config import NodeConfigDict from core.workflow.enums import NodeType @@ -16,6 +17,7 @@ from core.workflow.nodes.base.node import Node from core.workflow.nodes.code.code_node import CodeNode from core.workflow.nodes.code.limits import CodeNodeLimits from core.workflow.nodes.http_request.node import HttpRequestNode +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol from core.workflow.nodes.template_transform.template_renderer import ( @@ -75,6 +77,7 @@ class DifyNodeFactory(NodeFactory): self._http_request_http_client = http_request_http_client or ssrf_proxy self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory self._http_request_file_manager = http_request_file_manager or file_manager + self._rag_retrieval = DatasetRetrieval() @override def create_node(self, node_config: NodeConfigDict) -> Node: @@ -140,6 +143,15 @@ class DifyNodeFactory(NodeFactory): file_manager=self._http_request_file_manager, ) + if node_type == NodeType.KNOWLEDGE_RETRIEVAL: + return KnowledgeRetrievalNode( + id=node_id, + config=node_config, + graph_init_params=self.graph_init_params, + graph_runtime_state=self.graph_runtime_state, + rag_retrieval=self._rag_retrieval, + ) + return node_class( id=node_id, config=node_config, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 541c241ae5..a8133aa556 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -1,13 +1,15 @@ import json +import logging import math import re import threading +import time from collections import Counter, defaultdict from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app -from sqlalchemy import and_, literal, or_, select +from sqlalchemy import and_, func, literal, or_, select from sqlalchemy.orm import Session from core.app.app_config.entities import ( @@ -18,6 +20,7 @@ from core.app.app_config.entities import ( ) from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler +from core.db.session_factory import session_factory from core.entities.agent_entities import PlanningStrategy from core.entities.model_entities import ModelStatus from core.file import File, FileTransferMethod, FileType @@ -58,12 +61,30 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.nodes.knowledge_retrieval import exc +from core.workflow.repositories.rag_retrieval_protocol import ( + KnowledgeRetrievalRequest, + Source, + SourceChildChunk, + SourceMetadata, +) from extensions.ext_database import db +from extensions.ext_redis import redis_client from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile -from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding +from models.dataset import ( + ChildChunk, + Dataset, + DatasetMetadata, + DatasetQuery, + DocumentSegment, + RateLimitLog, + SegmentAttachmentBinding, +) from models.dataset import Document as DatasetDocument +from models.dataset import Document as DocumentModel from services.external_knowledge_service import ExternalDatasetService +from services.feature_service import FeatureService default_retrieval_model: dict[str, Any] = { "search_method": RetrievalMethod.SEMANTIC_SEARCH, @@ -73,6 +94,8 @@ default_retrieval_model: dict[str, Any] = { "score_threshold_enabled": False, } +logger = logging.getLogger(__name__) + class DatasetRetrieval: def __init__(self, application_generate_entity=None): @@ -91,6 +114,233 @@ class DatasetRetrieval: else: self._llm_usage = self._llm_usage.plus(usage) + def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: + self._check_knowledge_rate_limit(request.tenant_id) + available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids) + available_datasets_ids = [i.id for i in available_datasets] + if not available_datasets_ids: + return [] + + if not request.query: + return [] + + metadata_filter_document_ids, metadata_condition = None, None + + if request.metadata_filtering_mode != "disabled": + # Convert workflow layer types to app_config layer types + if not request.metadata_model_config: + raise ValueError("metadata_model_config is required for this method") + + app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump()) + + app_metadata_filtering_conditions = None + if request.metadata_filtering_conditions is not None: + app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate( + request.metadata_filtering_conditions.model_dump() + ) + + query = request.query if request.query is not None else "" + + metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition( + dataset_ids=available_datasets_ids, + query=query, + tenant_id=request.tenant_id, + user_id=request.user_id, + metadata_filtering_mode=request.metadata_filtering_mode, + metadata_model_config=app_metadata_model_config, + metadata_filtering_conditions=app_metadata_filtering_conditions, + inputs={}, + ) + + if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE: + planning_strategy = PlanningStrategy.REACT_ROUTER + # Ensure required fields are not None for single retrieval mode + if request.model_provider is None or request.model_name is None or request.query is None: + raise ValueError("model_provider, model_name, and query are required for single retrieval mode") + + model_manager = ModelManager() + model_instance = model_manager.get_model_instance( + tenant_id=request.tenant_id, + model_type=ModelType.LLM, + provider=request.model_provider, + model=request.model_name, + ) + + provider_model_bundle = model_instance.provider_model_bundle + model_type_instance = model_instance.model_type_instance + model_type_instance = cast(LargeLanguageModel, model_type_instance) + + model_credentials = model_instance.credentials + + # check model + provider_model = provider_model_bundle.configuration.get_provider_model( + model=request.model_name, model_type=ModelType.LLM + ) + + if provider_model is None: + raise exc.ModelNotExistError(f"Model {request.model_name} not exist.") + + if provider_model.status == ModelStatus.NO_CONFIGURE: + raise exc.ModelCredentialsNotInitializedError( + f"Model {request.model_name} credentials is not initialized." + ) + elif provider_model.status == ModelStatus.NO_PERMISSION: + raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.") + elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: + raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.") + + stop = [] + completion_params = (request.completion_params or {}).copy() + if "stop" in completion_params: + stop = completion_params["stop"] + del completion_params["stop"] + + model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials) + + if not model_schema: + raise exc.ModelNotExistError(f"Model {request.model_name} not exist.") + + model_config = ModelConfigWithCredentialsEntity( + provider=request.model_provider, + model=request.model_name, + model_schema=model_schema, + mode=request.model_mode or "chat", + provider_model_bundle=provider_model_bundle, + credentials=model_credentials, + parameters=completion_params, + stop=stop, + ) + all_documents = self.single_retrieve( + request.app_id, + request.tenant_id, + request.user_id, + request.user_from, + request.query, + available_datasets, + model_instance, + model_config, + planning_strategy, + None, # message_id + metadata_filter_document_ids, + metadata_condition, + ) + else: + all_documents = self.multiple_retrieve( + app_id=request.app_id, + tenant_id=request.tenant_id, + user_id=request.user_id, + user_from=request.user_from, + available_datasets=available_datasets, + query=request.query, + top_k=request.top_k, + score_threshold=request.score_threshold, + reranking_mode=request.reranking_mode, + reranking_model=request.reranking_model, + weights=request.weights, + reranking_enable=request.reranking_enable, + metadata_filter_document_ids=metadata_filter_document_ids, + metadata_condition=metadata_condition, + attachment_ids=request.attachment_ids, + ) + + dify_documents = [item for item in all_documents if item.provider == "dify"] + external_documents = [item for item in all_documents if item.provider == "external"] + retrieval_resource_list = [] + # deal with external documents + for item in external_documents: + source = Source( + metadata=SourceMetadata( + source="knowledge", + dataset_id=item.metadata.get("dataset_id"), + dataset_name=item.metadata.get("dataset_name"), + document_id=item.metadata.get("document_id"), + document_name=item.metadata.get("title"), + data_source_type="external", + retriever_from="workflow", + score=item.metadata.get("score"), + doc_metadata=item.metadata, + ), + title=item.metadata.get("title"), + content=item.page_content, + ) + retrieval_resource_list.append(source) + # deal with dify documents + if dify_documents: + records = RetrievalService.format_retrieval_documents(dify_documents) + dataset_ids = [i.segment.dataset_id for i in records] + document_ids = [i.segment.document_id for i in records] + + with session_factory.create_session() as session: + datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all() + documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all() + + dataset_map = {i.id: i for i in datasets} + document_map = {i.id: i for i in documents} + + if records: + for record in records: + segment = record.segment + dataset = dataset_map.get(segment.dataset_id) + document = document_map.get(segment.document_id) + + if dataset and document: + source = Source( + metadata=SourceMetadata( + source="knowledge", + dataset_id=dataset.id, + dataset_name=dataset.name, + document_id=document.id, + document_name=document.name, + data_source_type=document.data_source_type, + segment_id=segment.id, + retriever_from="workflow", + score=record.score or 0.0, + segment_hit_count=segment.hit_count, + segment_word_count=segment.word_count, + segment_position=segment.position, + segment_index_node_hash=segment.index_node_hash, + doc_metadata=document.doc_metadata, + child_chunks=[ + SourceChildChunk( + id=str(getattr(chunk, "id", "")), + content=str(getattr(chunk, "content", "")), + position=int(getattr(chunk, "position", 0)), + score=float(getattr(chunk, "score", 0.0)), + ) + for chunk in (record.child_chunks or []) + ], + position=None, + ), + title=document.name, + files=list(record.files) if record.files else None, + content=segment.get_sign_content(), + ) + if segment.answer: + source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" + + if record.summary: + source.summary = record.summary + + retrieval_resource_list.append(source) + + if retrieval_resource_list: + + def _score(item: Source) -> float: + meta = item.metadata + score = meta.score + if isinstance(score, (int, float)): + return float(score) + return 0.0 + + retrieval_resource_list = sorted( + retrieval_resource_list, + key=_score, # type: ignore[arg-type, return-value] + reverse=True, + ) + for position, item in enumerate(retrieval_resource_list, start=1): + item.metadata.position = position # type: ignore[index] + return retrieval_resource_list + def retrieve( self, app_id: str, @@ -150,14 +400,7 @@ class DatasetRetrieval: if features: if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: planning_strategy = PlanningStrategy.ROUTER - available_datasets = [] - - dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids)) - datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore - for dataset in datasets: - if dataset.available_document_count == 0 and dataset.provider != "external": - continue - available_datasets.append(dataset) + available_datasets = self._get_available_datasets(tenant_id, dataset_ids) if inputs: inputs = {key: str(value) for key, value in inputs.items()} @@ -1161,7 +1404,6 @@ class DatasetRetrieval: query=query or "", ) - result_text = "" try: # handle invoke result invoke_result = cast( @@ -1192,7 +1434,8 @@ class DatasetRetrieval: "condition": item.get("comparison_operator"), } ) - except Exception: + except Exception as e: + logger.warning(e, exc_info=True) return None return automatic_metadata_filters @@ -1406,7 +1649,12 @@ class DatasetRetrieval: usage = None for result in invoke_result: text = result.delta.message.content - full_text += text + if isinstance(text, str): + full_text += text + elif isinstance(text, list): + for i in text: + if i.data: + full_text += i.data if not model: model = result.model @@ -1524,3 +1772,53 @@ class DatasetRetrieval: cancel_event.set() if thread_exceptions is not None: thread_exceptions.append(e) + + def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]: + with session_factory.create_session() as session: + subquery = ( + session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count")) + .where( + DocumentModel.indexing_status == "completed", + DocumentModel.enabled == True, + DocumentModel.archived == False, + DocumentModel.dataset_id.in_(dataset_ids), + ) + .group_by(DocumentModel.dataset_id) + .having(func.count(DocumentModel.id) > 0) + .subquery() + ) + + results = ( + session.query(Dataset) + .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) + .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids)) + .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) + .all() + ) + + available_datasets = [] + for dataset in results: + if not dataset: + continue + available_datasets.append(dataset) + return available_datasets + + def _check_knowledge_rate_limit(self, tenant_id: str): + knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id) + if knowledge_rate_limit.enabled: + current_time = int(time.time() * 1000) + key = f"rate_limit_{tenant_id}" + redis_client.zadd(key, {current_time: current_time}) + redis_client.zremrangebyscore(key, 0, current_time - 60000) + request_count = redis_client.zcard(key) + if request_count > knowledge_rate_limit.limit: + with session_factory.create_session() as session: + rate_limit_log = RateLimitLog( + tenant_id=tenant_id, + subscription_plan=knowledge_rate_limit.subscription_plan, + operation="knowledge", + ) + session.add(rate_limit_log) + raise exc.RateLimitExceededError( + "you have reached the knowledge base request rate limit of your subscription." + ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/exc.py b/api/core/workflow/nodes/knowledge_retrieval/exc.py index 6bcdc32790..13d36e8cc1 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/exc.py +++ b/api/core/workflow/nodes/knowledge_retrieval/exc.py @@ -20,3 +20,7 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError): class InvalidModelTypeError(KnowledgeRetrievalNodeError): """Raised when the model is not a Large Language Model.""" + + +class RateLimitExceededError(KnowledgeRetrievalNodeError): + """Raised when the rate limit is exceeded.""" diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 0827494a48..65c2792355 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -1,29 +1,10 @@ -import json import logging -import re -import time -from collections import defaultdict from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from sqlalchemy import and_, func, or_, select -from sqlalchemy.orm import sessionmaker +from typing import TYPE_CHECKING, Any, Literal from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.entities.agent_entities import PlanningStrategy -from core.entities.model_entities import ModelStatus -from core.model_manager import ModelInstance, ModelManager from core.model_runtime.entities.llm_entities import LLMUsage -from core.model_runtime.entities.message_entities import PromptMessageRole -from core.model_runtime.entities.model_entities import ModelFeature, ModelType -from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from core.model_runtime.utils.encoders import jsonable_encoder -from core.prompt.simple_prompt_transform import ModelMode -from core.rag.datasource.retrieval_service import RetrievalService -from core.rag.entities.metadata_entities import Condition, MetadataCondition -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.variables import ( ArrayFileSegment, FileSegment, @@ -36,35 +17,16 @@ from core.workflow.enums import ( WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult +from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import LLMUsageTrackingMixin from core.workflow.nodes.base.node import Node -from core.workflow.nodes.knowledge_retrieval.template_prompts import ( - METADATA_FILTER_ASSISTANT_PROMPT_1, - METADATA_FILTER_ASSISTANT_PROMPT_2, - METADATA_FILTER_COMPLETION_PROMPT, - METADATA_FILTER_SYSTEM_PROMPT, - METADATA_FILTER_USER_PROMPT_1, - METADATA_FILTER_USER_PROMPT_2, - METADATA_FILTER_USER_PROMPT_3, -) -from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from core.workflow.nodes.llm.node import LLMNode -from extensions.ext_database import db -from extensions.ext_redis import redis_client -from libs.json_in_md_parser import parse_and_check_json_markdown -from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog -from services.feature_service import FeatureService +from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source from .entities import KnowledgeRetrievalNodeData from .exc import ( - InvalidModelTypeError, KnowledgeRetrievalNodeError, - ModelCredentialsNotInitializedError, - ModelNotExistError, - ModelNotSupportedError, - ModelQuotaExceededError, + RateLimitExceededError, ) if TYPE_CHECKING: @@ -73,14 +35,6 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -default_retrieval_model = { - "search_method": RetrievalMethod.SEMANTIC_SEARCH, - "reranking_enable": False, - "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, - "top_k": 4, - "score_threshold_enabled": False, -} - class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]): node_type = NodeType.KNOWLEDGE_RETRIEVAL @@ -97,6 +51,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD config: Mapping[str, Any], graph_init_params: "GraphInitParams", graph_runtime_state: "GraphRuntimeState", + rag_retrieval: RAGRetrievalProtocol, *, llm_file_saver: LLMFileSaver | None = None, ): @@ -108,6 +63,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) # LLM file outputs, used for MultiModal outputs. self._file_outputs = [] + self._rag_retrieval = rag_retrieval if llm_file_saver is None: llm_file_saver = FileSaverImpl( @@ -121,6 +77,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD return "1" def _run(self) -> NodeRunResult: + usage = LLMUsage.empty_usage() if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector: return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, @@ -128,7 +85,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD process_data={}, outputs={}, metadata={}, - llm_usage=LLMUsage.empty_usage(), + llm_usage=usage, ) variables: dict[str, Any] = {} # extract variables @@ -156,36 +113,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD else: variables["attachments"] = [variable.value] - # TODO(-LAN-): Move this check outside. - # check rate limit - knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id) - if knowledge_rate_limit.enabled: - current_time = int(time.time() * 1000) - key = f"rate_limit_{self.tenant_id}" - redis_client.zadd(key, {current_time: current_time}) - redis_client.zremrangebyscore(key, 0, current_time - 60000) - request_count = redis_client.zcard(key) - if request_count > knowledge_rate_limit.limit: - with sessionmaker(db.engine).begin() as session: - # add ratelimit record - rate_limit_log = RateLimitLog( - tenant_id=self.tenant_id, - subscription_plan=knowledge_rate_limit.subscription_plan, - operation="knowledge", - ) - session.add(rate_limit_log) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error="Sorry, you have reached the knowledge base request rate limit of your subscription.", - error_type="RateLimitExceeded", - ) - - # retrieve knowledge - usage = LLMUsage.empty_usage() try: results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables) - outputs = {"result": ArrayObjectSegment(value=results)} + outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])} return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, @@ -198,9 +128,17 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD }, llm_usage=usage, ) - + except RateLimitExceededError as e: + logger.warning(e, exc_info=True) + return NodeRunResult( + status=WorkflowNodeExecutionStatus.FAILED, + inputs=variables, + error=str(e), + error_type=type(e).__name__, + llm_usage=usage, + ) except KnowledgeRetrievalNodeError as e: - logger.warning("Error when running knowledge retrieval node") + logger.warning("Error when running knowledge retrieval node", exc_info=True) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -210,6 +148,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD ) # Temporary handle all exceptions from DatasetRetrieval class here. except Exception as e: + logger.warning(e, exc_info=True) return NodeRunResult( status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, @@ -217,92 +156,47 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD error_type=type(e).__name__, llm_usage=usage, ) - finally: - db.session.close() def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] - ) -> tuple[list[dict[str, Any]], LLMUsage]: - usage = LLMUsage.empty_usage() - available_datasets = [] + ) -> tuple[list[Source], LLMUsage]: dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") - metadata_filter_document_ids = None - metadata_condition = None - metadata_usage = LLMUsage.empty_usage() - # Subquery: Count the number of available documents for each dataset - subquery = ( - db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count")) - .where( - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - Document.dataset_id.in_(dataset_ids), - ) - .group_by(Document.dataset_id) - .having(func.count(Document.id) > 0) - .subquery() - ) + retrieval_resource_list = [] - results = ( - db.session.query(Dataset) - .outerjoin(subquery, Dataset.id == subquery.c.dataset_id) - .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids)) - .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external")) - .all() - ) + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled" + if node_data.metadata_filtering_mode is not None: + metadata_filtering_mode = node_data.metadata_filtering_mode - # avoid blocking at retrieval - db.session.close() - - for dataset in results: - # pass if dataset is not available - if not dataset: - continue - available_datasets.append(dataset) - if query: - metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition( - [dataset.id for dataset in available_datasets], query, node_data - ) - usage = self._merge_usage(usage, metadata_usage) - all_documents = [] - dataset_retrieval = DatasetRetrieval() if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query: # fetch model config if node_data.single_retrieval_config is None: - raise ValueError("single_retrieval_config is required") - model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model) - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - # get model schema - model_schema = model_type_instance.get_model_schema( - model=model_config.model, credentials=model_config.credentials - ) - - if model_schema: - planning_strategy = PlanningStrategy.REACT_ROUTER - features = model_schema.features - if features: - if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features: - planning_strategy = PlanningStrategy.ROUTER - all_documents = dataset_retrieval.single_retrieve( - available_datasets=available_datasets, + raise ValueError("single_retrieval_config is required for single retrieval mode") + model = node_data.single_retrieval_config.model + retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( + request=KnowledgeRetrievalRequest( tenant_id=self.tenant_id, user_id=self.user_id, app_id=self.app_id, user_from=self.user_from.value, + dataset_ids=dataset_ids, + retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value, + completion_params=model.completion_params, + model_provider=model.provider, + model_mode=model.mode, + model_name=model.name, + metadata_model_config=node_data.metadata_model_config, + metadata_filtering_conditions=node_data.metadata_filtering_conditions, + metadata_filtering_mode=metadata_filtering_mode, query=query, - model_config=model_config, - model_instance=model_instance, - planning_strategy=planning_strategy, - metadata_filter_document_ids=metadata_filter_document_ids, - metadata_condition=metadata_condition, ) + ) elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE: if node_data.multiple_retrieval_config is None: raise ValueError("multiple_retrieval_config is required") + reranking_model = None + weights = None match node_data.multiple_retrieval_config.reranking_mode: case "reranking_model": if node_data.multiple_retrieval_config.reranking_model: @@ -329,284 +223,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD }, } case _: + # Handle any other reranking_mode values reranking_model = None weights = None - all_documents = dataset_retrieval.multiple_retrieve( - app_id=self.app_id, - tenant_id=self.tenant_id, - user_id=self.user_id, - user_from=self.user_from.value, - available_datasets=available_datasets, - query=query, - top_k=node_data.multiple_retrieval_config.top_k, - score_threshold=node_data.multiple_retrieval_config.score_threshold - if node_data.multiple_retrieval_config.score_threshold is not None - else 0.0, - reranking_mode=node_data.multiple_retrieval_config.reranking_mode, - reranking_model=reranking_model, - weights=weights, - reranking_enable=node_data.multiple_retrieval_config.reranking_enable, - metadata_filter_document_ids=metadata_filter_document_ids, - metadata_condition=metadata_condition, - attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, - ) - usage = self._merge_usage(usage, dataset_retrieval.llm_usage) - dify_documents = [item for item in all_documents if item.provider == "dify"] - external_documents = [item for item in all_documents if item.provider == "external"] - retrieval_resource_list = [] - # deal with external documents - for item in external_documents: - source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = { - "metadata": { - "_source": "knowledge", - "dataset_id": item.metadata.get("dataset_id"), - "dataset_name": item.metadata.get("dataset_name"), - "document_id": item.metadata.get("document_id") or item.metadata.get("title"), - "document_name": item.metadata.get("title"), - "data_source_type": "external", - "retriever_from": "workflow", - "score": item.metadata.get("score"), - "doc_metadata": item.metadata, - }, - "title": item.metadata.get("title"), - "content": item.page_content, - } - retrieval_resource_list.append(source) - # deal with dify documents - if dify_documents: - records = RetrievalService.format_retrieval_documents(dify_documents) - if records: - for record in records: - segment = record.segment - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore - stmt = select(Document).where( - Document.id == segment.document_id, - Document.enabled == True, - Document.archived == False, - ) - document = db.session.scalar(stmt) - if dataset and document: - source = { - "metadata": { - "_source": "knowledge", - "dataset_id": dataset.id, - "dataset_name": dataset.name, - "document_id": document.id, - "document_name": document.name, - "data_source_type": document.data_source_type, - "segment_id": segment.id, - "retriever_from": "workflow", - "score": record.score or 0.0, - "child_chunks": [ - { - "id": str(getattr(chunk, "id", "")), - "content": str(getattr(chunk, "content", "")), - "position": int(getattr(chunk, "position", 0)), - "score": float(getattr(chunk, "score", 0.0)), - } - for chunk in (record.child_chunks or []) - ], - "segment_hit_count": segment.hit_count, - "segment_word_count": segment.word_count, - "segment_position": segment.position, - "segment_index_node_hash": segment.index_node_hash, - "doc_metadata": document.doc_metadata, - }, - "title": document.name, - "files": list(record.files) if record.files else None, - } - if segment.answer: - source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}" - else: - source["content"] = segment.get_sign_content() - # Add summary if available - if record.summary: - source["summary"] = record.summary - retrieval_resource_list.append(source) - if retrieval_resource_list: - retrieval_resource_list = sorted( - retrieval_resource_list, - key=self._score, # type: ignore[arg-type, return-value] - reverse=True, - ) - for position, item in enumerate(retrieval_resource_list, start=1): - item["metadata"]["position"] = position # type: ignore[index] - return retrieval_resource_list, usage - - def _score(self, item: dict[str, Any]) -> float: - meta = item.get("metadata") - if isinstance(meta, dict): - s = meta.get("score") - if isinstance(s, (int, float)): - return float(s) - return 0.0 - - def _get_metadata_filter_condition( - self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]: - usage = LLMUsage.empty_usage() - document_query = db.session.query(Document).where( - Document.dataset_id.in_(dataset_ids), - Document.indexing_status == "completed", - Document.enabled == True, - Document.archived == False, - ) - filters: list[Any] = [] - metadata_condition = None - match node_data.metadata_filtering_mode: - case "disabled": - return None, None, usage - case "automatic": - automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func( - dataset_ids, query, node_data + retrieval_resource_list = self._rag_retrieval.knowledge_retrieval( + request=KnowledgeRetrievalRequest( + app_id=self.app_id, + tenant_id=self.tenant_id, + user_id=self.user_id, + user_from=self.user_from.value, + dataset_ids=dataset_ids, + query=query, + retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value, + top_k=node_data.multiple_retrieval_config.top_k, + score_threshold=node_data.multiple_retrieval_config.score_threshold + if node_data.multiple_retrieval_config.score_threshold is not None + else 0.0, + reranking_mode=node_data.multiple_retrieval_config.reranking_mode, + reranking_model=reranking_model, + weights=weights, + reranking_enable=node_data.multiple_retrieval_config.reranking_enable, + metadata_model_config=node_data.metadata_model_config, + metadata_filtering_conditions=node_data.metadata_filtering_conditions, + metadata_filtering_mode=metadata_filtering_mode, + attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, ) - usage = self._merge_usage(usage, automatic_usage) - if automatic_metadata_filters: - conditions = [] - for sequence, filter in enumerate(automatic_metadata_filters): - DatasetRetrieval.process_metadata_filter_func( - sequence, - filter.get("condition", ""), - filter.get("metadata_name", ""), - filter.get("value"), - filters, - ) - conditions.append( - Condition( - name=filter.get("metadata_name"), # type: ignore - comparison_operator=filter.get("condition"), # type: ignore - value=filter.get("value"), - ) - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator - if node_data.metadata_filtering_conditions - else "or", - conditions=conditions, - ) - case "manual": - if node_data.metadata_filtering_conditions: - conditions = [] - for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore - metadata_name = condition.name - expected_value = condition.value - if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"): - if isinstance(expected_value, str): - expected_value = self.graph_runtime_state.variable_pool.convert_template( - expected_value - ).value[0] - if expected_value.value_type in {"number", "integer", "float"}: - expected_value = expected_value.value - elif expected_value.value_type == "string": - expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip() - else: - raise ValueError("Invalid expected metadata value type") - conditions.append( - Condition( - name=metadata_name, - comparison_operator=condition.comparison_operator, - value=expected_value, - ) - ) - filters = DatasetRetrieval.process_metadata_filter_func( - sequence, - condition.comparison_operator, - metadata_name, - expected_value, - filters, - ) - metadata_condition = MetadataCondition( - logical_operator=node_data.metadata_filtering_conditions.logical_operator, - conditions=conditions, - ) - case _: - raise ValueError("Invalid metadata filtering mode") - if filters: - if ( - node_data.metadata_filtering_conditions - and node_data.metadata_filtering_conditions.logical_operator == "and" - ): - document_query = document_query.where(and_(*filters)) - else: - document_query = document_query.where(or_(*filters)) - documents = document_query.all() - # group by dataset_id - metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore - for document in documents: - metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore - return metadata_filter_document_ids, metadata_condition, usage - - def _automatic_metadata_filter_func( - self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData - ) -> tuple[list[dict[str, Any]], LLMUsage]: - usage = LLMUsage.empty_usage() - # get all metadata field - stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids)) - metadata_fields = db.session.scalars(stmt).all() - all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields] - if node_data.metadata_model_config is None: - raise ValueError("metadata_model_config is required") - # get metadata model instance and fetch model config - model_instance, model_config = self.get_model_config(node_data.metadata_model_config) - # fetch prompt messages - prompt_template = self._get_prompt_template( - node_data=node_data, - metadata_fields=all_metadata_fields, - query=query or "", - ) - prompt_messages, stop = LLMNode.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query=query, - memory=None, - model_config=model_config, - sys_files=[], - vision_enabled=node_data.vision.enabled, - vision_detail=node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - tenant_id=self.tenant_id, - ) - - result_text = "" - try: - # handle invoke result - generator = LLMNode.invoke_llm( - node_data_model=node_data.metadata_model_config, - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.user_id, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=None, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, ) - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = self._merge_usage(usage, event.usage) - break - - result_text_json = parse_and_check_json_markdown(result_text, []) - automatic_metadata_filters = [] - if "metadata_map" in result_text_json: - metadata_map = result_text_json["metadata_map"] - for item in metadata_map: - if item.get("metadata_field_name") in all_metadata_fields: - automatic_metadata_filters.append( - { - "metadata_name": item.get("metadata_field_name"), - "value": item.get("metadata_field_value"), - "condition": item.get("comparison_operator"), - } - ) - except Exception: - return [], usage - return automatic_metadata_filters, usage + usage = self._rag_retrieval.llm_usage + return retrieval_resource_list, usage @classmethod def _extract_variable_selector_to_variable_mapping( @@ -626,107 +272,3 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD if typed_node_data.query_attachment_selector: variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector return variable_mapping - - def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: - model_name = model.name - provider_name = model.provider - - model_manager = ModelManager() - model_instance = model_manager.get_model_instance( - tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name - ) - - provider_model_bundle = model_instance.provider_model_bundle - model_type_instance = model_instance.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_credentials = model_instance.credentials - - # check model - provider_model = provider_model_bundle.configuration.get_provider_model( - model=model_name, model_type=ModelType.LLM - ) - - if provider_model is None: - raise ModelNotExistError(f"Model {model_name} not exist.") - - if provider_model.status == ModelStatus.NO_CONFIGURE: - raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.") - elif provider_model.status == ModelStatus.NO_PERMISSION: - raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.") - elif provider_model.status == ModelStatus.QUOTA_EXCEEDED: - raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.") - - # model config - completion_params = model.completion_params - stop = [] - if "stop" in completion_params: - stop = completion_params["stop"] - del completion_params["stop"] - - # get model mode - model_mode = model.mode - if not model_mode: - raise ModelNotExistError("LLM mode is required.") - - model_schema = model_type_instance.get_model_schema(model_name, model_credentials) - - if not model_schema: - raise ModelNotExistError(f"Model {model_name} not exist.") - - return model_instance, ModelConfigWithCredentialsEntity( - provider=provider_name, - model=model_name, - model_schema=model_schema, - mode=model_mode, - provider_model_bundle=provider_model_bundle, - credentials=model_credentials, - parameters=completion_params, - stop=stop, - ) - - def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str): - model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore - input_text = query - - prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == ModelMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT - ) - prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1 - ) - prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1 - ) - prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2 - ) - prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2 - ) - prompt_messages.append(assistant_prompt_message_2) - user_prompt_message_3 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, - text=METADATA_FILTER_USER_PROMPT_3.format( - input_text=input_text, - metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), - ), - ) - prompt_messages.append(user_prompt_message_3) - return prompt_messages - elif model_mode == ModelMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=METADATA_FILTER_COMPLETION_PROMPT.format( - input_text=input_text, - metadata_fields=json.dumps(metadata_fields, ensure_ascii=False), - ) - ) - - else: - raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/core/workflow/repositories/rag_retrieval_protocol.py b/api/core/workflow/repositories/rag_retrieval_protocol.py new file mode 100644 index 0000000000..f91cecb694 --- /dev/null +++ b/api/core/workflow/repositories/rag_retrieval_protocol.py @@ -0,0 +1,108 @@ +from typing import Any, Literal, Protocol + +from pydantic import BaseModel, Field + +from core.model_runtime.entities import LLMUsage +from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition +from core.workflow.nodes.llm.entities import ModelConfig + + +class SourceChildChunk(BaseModel): + id: str = Field(default="", description="Child chunk ID") + content: str = Field(default="", description="Child chunk content") + position: int = Field(default=0, description="Child chunk position") + score: float = Field(default=0.0, description="Child chunk relevance score") + + +class SourceMetadata(BaseModel): + source: str = Field( + default="knowledge", + serialization_alias="_source", + description="Data source identifier", + ) + dataset_id: str = Field(description="Dataset unique identifier") + dataset_name: str = Field(description="Dataset display name") + document_id: str = Field(description="Document unique identifier") + document_name: str = Field(description="Document display name") + data_source_type: str = Field(description="Type of data source") + segment_id: str | None = Field(default=None, description="Segment unique identifier") + retriever_from: str = Field(default="workflow", description="Retriever source context") + score: float = Field(default=0.0, description="Retrieval relevance score") + child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks") + segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved") + segment_word_count: int | None = Field(default=0, description="Word count of the segment") + segment_position: int | None = Field(default=0, description="Position of segment in document") + segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment") + doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata") + position: int | None = Field(default=0, description="Position of the document in the dataset") + + class Config: + populate_by_name = True + + +class Source(BaseModel): + metadata: SourceMetadata = Field(description="Source metadata information") + title: str = Field(description="Document title") + files: list[Any] | None = Field(default=None, description="Associated file references") + content: str | None = Field(description="Segment content text") + summary: str | None = Field(default=None, description="Content summary if available") + + +class KnowledgeRetrievalRequest(BaseModel): + tenant_id: str = Field(description="Tenant unique identifier") + user_id: str = Field(description="User unique identifier") + app_id: str = Field(description="Application unique identifier") + user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')") + dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from") + query: str | None = Field(default=None, description="Query text for knowledge retrieval") + retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'") + model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')") + completion_params: dict[str, Any] | None = Field( + default=None, description="Model completion parameters (e.g., temperature, max_tokens)" + ) + model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')") + model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')") + metadata_model_config: ModelConfig | None = Field( + default=None, description="Model config for metadata-based filtering" + ) + metadata_filtering_conditions: MetadataFilteringCondition | None = Field( + default=None, description="Conditions for filtering by metadata" + ) + metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field( + default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'" + ) + top_k: int = Field(default=0, description="Number of top results to return") + score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") + reranking_mode: str = Field(default="reranking_model", description="Reranking strategy") + reranking_model: dict | None = Field(default=None, description="Reranking model configuration") + weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking") + reranking_enable: bool = Field(default=True, description="Whether reranking is enabled") + attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval") + + +class RAGRetrievalProtocol(Protocol): + """Protocol for RAG-based knowledge retrieval implementations. + + Implementations of this protocol handle knowledge retrieval from datasets + including rate limiting, dataset filtering, and document retrieval. + """ + + @property + def llm_usage(self) -> LLMUsage: + """Return accumulated LLM usage for retrieval operations.""" + ... + + def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]: + """Retrieve knowledge from datasets based on the provided request. + + Args: + request: Knowledge retrieval request with search parameters + + Returns: + List of sources matching the search criteria + + Raises: + RateLimitExceededError: If rate limit is exceeded + ModelNotExistError: If specified model doesn't exist + """ + ... diff --git a/api/tests/integration_tests/workflow/nodes/knowledge_retrieval/__init__.py b/api/tests/integration_tests/workflow/nodes/knowledge_retrieval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/integration_tests/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node_integration.py b/api/tests/integration_tests/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node_integration.py new file mode 100644 index 0000000000..d029115df1 --- /dev/null +++ b/api/tests/integration_tests/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node_integration.py @@ -0,0 +1,29 @@ +""" +Integration tests for KnowledgeRetrievalNode. + +This module provides integration tests for KnowledgeRetrievalNode with real database interactions. + +Note: These tests require database setup and are more complex than unit tests. +For now, we focus on unit tests which provide better coverage for the node logic. +""" + +import pytest + + +class TestKnowledgeRetrievalNodeIntegration: + """ + Integration test suite for KnowledgeRetrievalNode. + + Note: Full integration tests require: + - Database setup with datasets and documents + - Vector store for embeddings + - Model providers for retrieval + + For now, unit tests provide comprehensive coverage of the node logic. + """ + + @pytest.mark.skip(reason="Integration tests require full database and vector store setup") + def test_end_to_end_knowledge_retrieval(self): + """Test end-to-end knowledge retrieval workflow.""" + # TODO: Implement with real database + pass diff --git a/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py new file mode 100644 index 0000000000..4e6cc620ac --- /dev/null +++ b/api/tests/test_containers_integration_tests/core/rag/retrieval/test_dataset_retrieval_integration.py @@ -0,0 +1,614 @@ +import uuid +from unittest.mock import patch + +import pytest +from faker import Faker + +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from models.dataset import Dataset, Document +from services.account_service import AccountService, TenantService + + +class TestGetAvailableDatasetsIntegration: + def test_returns_datasets_with_available_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + # Create account and tenant + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + description=fake.text(max_nb_chars=100), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + indexing_technique="high_quality", + ) + db_session_with_containers.add(dataset) + db_session_with_containers.flush() + + # Create documents with completed status, enabled, not archived + for i in range(3): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + name=f"Document {i}", + created_from="web", + created_by=account.id, + doc_form="text_model", + doc_language="en", + indexing_status="completed", + enabled=True, + archived=False, + ) + db_session_with_containers.add(document) + + db_session_with_containers.commit() + + # Act + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id]) + + # Assert + assert len(result) == 1 + assert result[0].id == dataset.id + assert result[0].tenant_id == tenant.id + assert result[0].name == dataset.name + + def test_filters_out_datasets_with_only_archived_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + + # Create only archived documents + for i in range(2): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + created_from="web", + name=f"Archived Document {i}", + created_by=account.id, + doc_form="text_model", + indexing_status="completed", + enabled=True, + archived=True, # Archived + ) + db_session_with_containers.add(document) + + db_session_with_containers.commit() + + # Act + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id]) + + # Assert + assert len(result) == 0 + + def test_filters_out_datasets_with_only_disabled_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + + # Create only disabled documents + for i in range(2): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + created_from="web", + name=f"Disabled Document {i}", + created_by=account.id, + doc_form="text_model", + indexing_status="completed", + enabled=False, # Disabled + archived=False, + ) + db_session_with_containers.add(document) + + db_session_with_containers.commit() + + # Act + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id]) + + # Assert + assert len(result) == 0 + + def test_filters_out_datasets_with_non_completed_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + + # Create documents with non-completed status + for i, status in enumerate(["indexing", "parsing", "splitting"]): + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=i, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + created_from="web", + name=f"Document {status}", + created_by=account.id, + doc_form="text_model", + indexing_status=status, # Not completed + enabled=True, + archived=False, + ) + db_session_with_containers.add(document) + + db_session_with_containers.commit() + + # Act + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id]) + + # Assert + assert len(result) == 0 + + def test_includes_external_datasets_without_documents( + self, db_session_with_containers, mock_external_service_dependencies + ): + """ + Test that external datasets are returned even with no available documents. + + External datasets (e.g., from external knowledge bases) don't have + documents stored in Dify's database, so they should always be available. + + Verifies: + - External datasets are included in results + - No document count check for external datasets + """ + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="external", # External provider + data_source_type="external", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + # Act + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id]) + + # Assert + assert len(result) == 1 + assert result[0].id == dataset.id + assert result[0].provider == "external" + + def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies): + # Arrange + fake = Faker() + + # Create two accounts/tenants + account1 = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company()) + tenant1 = account1.current_tenant + + account2 = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company()) + tenant2 = account2.current_tenant + + # Create dataset for tenant1 + dataset1 = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant1.id, + name="Tenant 1 Dataset", + provider="dify", + data_source_type="upload_file", + created_by=account1.id, + ) + db_session_with_containers.add(dataset1) + + # Create dataset for tenant2 + dataset2 = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant2.id, + name="Tenant 2 Dataset", + provider="dify", + data_source_type="upload_file", + created_by=account2.id, + ) + db_session_with_containers.add(dataset2) + + # Add documents to both datasets + for dataset, account in [(dataset1, account1), (dataset2, account2)]: + document = Document( + id=str(uuid.uuid4()), + tenant_id=dataset.tenant_id, + dataset_id=dataset.id, + position=0, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + created_from="web", + name=f"Document for {dataset.name}", + created_by=account.id, + doc_form="text_model", + indexing_status="completed", + enabled=True, + archived=False, + ) + db_session_with_containers.add(document) + + db_session_with_containers.commit() + + # Act - request from tenant1, should only get tenant1's dataset + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant1.id, [dataset1.id, dataset2.id]) + + # Assert + assert len(result) == 1 + assert result[0].id == dataset1.id + assert result[0].tenant_id == tenant1.id + + def test_returns_empty_list_when_no_datasets_found( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Don't create any datasets + + # Act + dataset_retrieval = DatasetRetrieval() + result = dataset_retrieval._get_available_datasets(tenant.id, [str(uuid.uuid4())]) + + # Assert + assert result == [] + + def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create multiple datasets + datasets = [] + for i in range(3): + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=f"Dataset {i}", + provider="dify", + data_source_type="upload_file", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + datasets.append(dataset) + + # Add document + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + created_from="web", + name=f"Document {i}", + created_by=account.id, + doc_form="text_model", + indexing_status="completed", + enabled=True, + archived=False, + ) + db_session_with_containers.add(document) + + db_session_with_containers.commit() + + # Act - request only dataset 0 and 2, not dataset 1 + dataset_retrieval = DatasetRetrieval() + requested_ids = [datasets[0].id, datasets[2].id] + result = dataset_retrieval._get_available_datasets(tenant.id, requested_ids) + + # Assert + assert len(result) == 2 + returned_ids = {d.id for d in result} + assert returned_ids == {datasets[0].id, datasets[2].id} + + +class TestKnowledgeRetrievalIntegration: + def test_knowledge_retrieval_with_available_datasets( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + indexing_technique="high_quality", + ) + db_session_with_containers.add(dataset) + + document = Document( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + dataset_id=dataset.id, + position=0, + data_source_type="upload_file", + batch=str(uuid.uuid4()), # Required field + created_from="web", + name=fake.sentence(), + created_by=account.id, + indexing_status="completed", + enabled=True, + archived=False, + doc_form="text_model", + ) + db_session_with_containers.add(document) + db_session_with_containers.commit() + + # Create request + request = KnowledgeRetrievalRequest( + tenant_id=tenant.id, + user_id=account.id, + app_id=str(uuid.uuid4()), + user_from="web", + dataset_ids=[dataset.id], + query="test query", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit check and retrieval + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + + def test_knowledge_retrieval_no_available_datasets( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + # Create dataset but no documents + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + request = KnowledgeRetrievalRequest( + tenant_id=tenant.id, + user_id=account.id, + app_id=str(uuid.uuid4()), + user_from="web", + dataset_ids=[dataset.id], + query="test query", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit check + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_rate_limit_exceeded( + self, db_session_with_containers, mock_external_service_dependencies + ): + # Arrange + fake = Faker() + + account = AccountService.create_account( + email=fake.email(), + name=fake.name(), + interface_language="en-US", + password=fake.password(length=12), + ) + TenantService.create_owner_tenant_if_not_exist(account, name=fake.company()) + tenant = account.current_tenant + + dataset = Dataset( + id=str(uuid.uuid4()), + tenant_id=tenant.id, + name=fake.company(), + provider="dify", + data_source_type="upload_file", + created_by=account.id, + ) + db_session_with_containers.add(dataset) + db_session_with_containers.commit() + + request = KnowledgeRetrievalRequest( + tenant_id=tenant.id, + user_id=account.id, + app_id=str(uuid.uuid4()), + user_from="web", + dataset_ids=[dataset.id], + query="test query", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit check to raise exception + with patch.object( + dataset_retrieval, + "_check_knowledge_rate_limit", + side_effect=Exception("Rate limit exceeded"), + ): + # Act & Assert + with pytest.raises(Exception, match="Rate limit exceeded"): + dataset_retrieval.knowledge_retrieval(request) + + +@pytest.fixture +def mock_external_service_dependencies(): + with ( + patch("services.account_service.FeatureService") as mock_account_feature_service, + ): + # Setup default mock returns for account service + mock_account_feature_service.get_system_features.return_value.is_allow_register = True + + yield { + "account_feature_service": mock_account_feature_service, + } diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py new file mode 100644 index 0000000000..4bc802dc23 --- /dev/null +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval_methods.py @@ -0,0 +1,715 @@ +from unittest.mock import MagicMock, Mock, patch +from uuid import uuid4 + +import pytest + +from core.rag.models.document import Document +from core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.nodes.knowledge_retrieval import exc +from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest +from models.dataset import Dataset + +# ==================== Helper Functions ==================== + + +def create_mock_dataset( + dataset_id: str | None = None, + tenant_id: str | None = None, + provider: str = "dify", + indexing_technique: str = "high_quality", + available_document_count: int = 10, +) -> Mock: + """ + Create a mock Dataset object for testing. + + Args: + dataset_id: Unique identifier for the dataset + tenant_id: Tenant ID for the dataset + provider: Provider type ("dify" or "external") + indexing_technique: Indexing technique ("high_quality" or "economy") + available_document_count: Number of available documents + + Returns: + Mock: A properly configured Dataset mock + """ + dataset = Mock(spec=Dataset) + dataset.id = dataset_id or str(uuid4()) + dataset.tenant_id = tenant_id or str(uuid4()) + dataset.name = "test_dataset" + dataset.provider = provider + dataset.indexing_technique = indexing_technique + dataset.available_document_count = available_document_count + dataset.embedding_model = "text-embedding-ada-002" + dataset.embedding_model_provider = "openai" + dataset.retrieval_model = { + "search_method": "semantic_search", + "reranking_enable": False, + "top_k": 4, + "score_threshold_enabled": False, + } + return dataset + + +def create_mock_document( + content: str, + doc_id: str, + score: float = 0.8, + provider: str = "dify", + additional_metadata: dict | None = None, +) -> Document: + """ + Create a mock Document object for testing. + + Args: + content: The text content of the document + doc_id: Unique identifier for the document chunk + score: Relevance score (0.0 to 1.0) + provider: Document provider ("dify" or "external") + additional_metadata: Optional extra metadata fields + + Returns: + Document: A properly structured Document object + """ + metadata = { + "doc_id": doc_id, + "document_id": str(uuid4()), + "dataset_id": str(uuid4()), + "score": score, + } + + if additional_metadata: + metadata.update(additional_metadata) + + return Document( + page_content=content, + metadata=metadata, + provider=provider, + ) + + +# ==================== Test _check_knowledge_rate_limit ==================== + + +class TestCheckKnowledgeRateLimit: + """ + Test suite for _check_knowledge_rate_limit method. + + The _check_knowledge_rate_limit method validates whether a tenant has + exceeded their knowledge retrieval rate limit. This is important for: + - Preventing abuse of the knowledge retrieval system + - Enforcing subscription plan limits + - Tracking usage for billing purposes + + Test Cases: + ============ + 1. Rate limit disabled - no exception raised + 2. Rate limit enabled but not exceeded - no exception raised + 3. Rate limit enabled and exceeded - RateLimitExceededError raised + 4. Redis operations are performed correctly + 5. RateLimitLog is created when limit is exceeded + """ + + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service): + """ + Test that when rate limit is disabled, no exception is raised. + + This test verifies the behavior when the tenant's subscription + does not have rate limiting enabled. + + Verifies: + - FeatureService.get_knowledge_rate_limit is called + - No Redis operations are performed + - No exception is raised + - Retrieval proceeds normally + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit disabled + mock_limit = Mock() + mock_limit.enabled = False + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify FeatureService was called + mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id) + + # Verify no Redis operations were performed + assert not mock_redis.zadd.called + assert not mock_redis.zremrangebyscore.called + assert not mock_redis.zcard.called + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory): + """ + Test that when rate limit is enabled but not exceeded, no exception is raised. + + This test simulates a tenant making requests within their rate limit. + The Redis sorted set stores timestamps of recent requests, and old + requests (older than 60 seconds) are removed. + + Verifies: + - Redis zadd is called to track the request + - Redis zremrangebyscore removes old entries + - Redis zcard returns count within limit + - No exception is raised + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 # Current time in milliseconds + mock_time.time.return_value = current_time / 1000 # Return seconds + mock_time.time.__mul__ = lambda self, x: int(self * x) # Multiply to get milliseconds + + # Mock Redis operations + # zcard returns 50 (within limit of 100) + mock_redis.zcard.return_value = 50 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert - should not raise any exception + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify Redis operations + expected_key = f"rate_limit_{tenant_id}" + mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time}) + mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000) + mock_redis.zcard.assert_called_once_with(expected_key) + + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + @patch("core.rag.retrieval.dataset_retrieval.FeatureService") + @patch("core.rag.retrieval.dataset_retrieval.redis_client") + @patch("core.rag.retrieval.dataset_retrieval.time") + def test_rate_limit_enabled_exceeded_raises_exception( + self, mock_time, mock_redis, mock_feature_service, mock_session_factory + ): + """ + Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised. + + This test simulates a tenant exceeding their rate limit. When the count + of recent requests exceeds the limit, an exception should be raised and + a RateLimitLog should be created. + + Verifies: + - Redis zcard returns count exceeding limit + - RateLimitExceededError is raised with correct message + - RateLimitLog is created in database + - Session operations are performed correctly + """ + # Arrange + tenant_id = str(uuid4()) + dataset_retrieval = DatasetRetrieval() + + # Mock rate limit enabled with limit of 100 requests per minute + mock_limit = Mock() + mock_limit.enabled = True + mock_limit.limit = 100 + mock_limit.subscription_plan = "professional" + mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit + + # Mock time + current_time = 1234567890000 + mock_time.time.return_value = current_time / 1000 + + # Mock Redis operations - return count exceeding limit + mock_redis.zcard.return_value = 150 # Exceeds limit of 100 + + # Mock session_factory.create_session + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + # Act & Assert + with pytest.raises(exc.RateLimitExceededError) as exc_info: + dataset_retrieval._check_knowledge_rate_limit(tenant_id) + + # Verify exception message + assert "knowledge base request rate limit" in str(exc_info.value) + + # Verify RateLimitLog was created + mock_session.add.assert_called_once() + added_log = mock_session.add.call_args[0][0] + assert added_log.tenant_id == tenant_id + assert added_log.subscription_plan == "professional" + assert added_log.operation == "knowledge" + + +# ==================== Test _get_available_datasets ==================== + + +class TestGetAvailableDatasets: + """ + Test suite for _get_available_datasets method. + + The _get_available_datasets method retrieves datasets that are available + for retrieval. A dataset is considered available if: + - It belongs to the specified tenant + - It's in the list of requested dataset_ids + - It has at least one completed, enabled, non-archived document OR + - It's an external provider dataset + + Note: Due to SQLAlchemy subquery complexity, full testing is done in + integration tests. Unit tests here verify basic behavior. + """ + + def test_method_exists_and_has_correct_signature(self): + """ + Test that the method exists and has the correct signature. + + Verifies: + - Method exists on DatasetRetrieval class + - Accepts tenant_id and dataset_ids parameters + """ + # Arrange + dataset_retrieval = DatasetRetrieval() + + # Assert - method exists + assert hasattr(dataset_retrieval, "_get_available_datasets") + # Assert - method is callable + assert callable(dataset_retrieval._get_available_datasets) + + +# ==================== Test knowledge_retrieval ==================== + + +class TestDatasetRetrievalKnowledgeRetrieval: + """ + Test suite for knowledge_retrieval method. + + The knowledge_retrieval method is the main entry point for retrieving + knowledge from datasets. It orchestrates the entire retrieval process: + 1. Checks rate limits + 2. Gets available datasets + 3. Applies metadata filtering if enabled + 4. Performs retrieval (single or multiple mode) + 5. Formats and returns results + + Test Cases: + ============ + 1. Single mode retrieval + 2. Multiple mode retrieval + 3. Metadata filtering disabled + 4. Metadata filtering automatic + 5. Metadata filtering manual + 6. External documents handling + 7. Dify documents handling + 8. Empty results handling + 9. Rate limit exceeded + 10. No available datasets + """ + + def test_knowledge_retrieval_single_mode_basic(self): + """ + Test knowledge_retrieval in single retrieval mode - basic check. + + Note: Full single mode testing requires complex model mocking and + is better suited for integration tests. This test verifies the + method accepts single mode requests. + + Verifies: + - Method can accept single mode request + - Request parameters are correctly structured + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="single", + model_provider="openai", + model_name="gpt-4", + model_mode="chat", + completion_params={"temperature": 0.7}, + ) + + # Assert - request is properly structured + assert request.retrieval_mode == "single" + assert request.model_provider == "openai" + assert request.model_name == "gpt-4" + assert request.model_mode == "chat" + + @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor") + @patch("core.rag.retrieval.dataset_retrieval.session_factory") + def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor): + """ + Test knowledge_retrieval in multiple retrieval mode. + + In multiple mode, retrieval is performed across all datasets and + results are combined and reranked. + + Verifies: + - Rate limit is checked + - Available datasets are retrieved + - Multiple retrieval is performed + - Results are combined and reranked + - Results are formatted correctly + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id1 = str(uuid4()) + dataset_id2 = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id1, dataset_id2], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + score_threshold=0.7, + reranking_enable=True, + reranking_mode="reranking_model", + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"}, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets + mock_dataset1 = create_mock_dataset(dataset_id=dataset_id1, tenant_id=tenant_id) + mock_dataset2 = create_mock_dataset(dataset_id=dataset_id2, tenant_id=tenant_id) + with patch.object( + dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2] + ): + # Mock get_metadata_filter_condition + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return documents + doc1 = create_mock_document("Python is great", "doc1", score=0.9) + doc2 = create_mock_document("Python is awesome", "doc2", score=0.8) + with patch.object( + dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2] + ) as mock_multiple_retrieve: + # Mock format_retrieval_documents + mock_record = Mock() + mock_record.segment = Mock() + mock_record.segment.dataset_id = dataset_id1 + mock_record.segment.document_id = str(uuid4()) + mock_record.segment.index_node_hash = "hash123" + mock_record.segment.hit_count = 5 + mock_record.segment.word_count = 100 + mock_record.segment.position = 1 + mock_record.segment.get_sign_content.return_value = "Python is great" + mock_record.segment.answer = None + mock_record.score = 0.9 + mock_record.child_chunks = [] + mock_record.summary = None + mock_record.files = None + + mock_retrieval_service = Mock() + mock_retrieval_service.format_retrieval_documents.return_value = [mock_record] + + with patch( + "core.rag.retrieval.dataset_retrieval.RetrievalService", + return_value=mock_retrieval_service, + ): + # Mock database queries + mock_session = MagicMock() + mock_session_factory.create_session.return_value.__enter__.return_value = mock_session + mock_session_factory.create_session.return_value.__exit__.return_value = None + + mock_dataset_from_db = Mock() + mock_dataset_from_db.id = dataset_id1 + mock_dataset_from_db.name = "test_dataset" + + mock_document = Mock() + mock_document.id = str(uuid4()) + mock_document.name = "test_doc" + mock_document.data_source_type = "upload_file" + mock_document.doc_metadata = {} + + mock_session.query.return_value.filter.return_value.all.return_value = [ + mock_dataset_from_db + ] + mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter( + [mock_dataset_from_db, mock_document] + ) + + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + mock_multiple_retrieve.assert_called_once() + + def test_knowledge_retrieval_metadata_filtering_disabled(self): + """ + Test knowledge_retrieval with metadata filtering disabled. + + When metadata filtering is disabled, get_metadata_filter_condition is + NOT called (the method checks metadata_filtering_mode != "disabled"). + + Verifies: + - get_metadata_filter_condition is NOT called when mode is "disabled" + - Retrieval proceeds without metadata filters + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + metadata_filtering_mode="disabled", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + # Mock get_metadata_filter_condition - should NOT be called when disabled + with patch.object( + dataset_retrieval, + "get_metadata_filter_condition", + return_value=(None, None), + ) as mock_get_metadata: + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + # get_metadata_filter_condition should NOT be called when mode is "disabled" + mock_get_metadata.assert_not_called() + + def test_knowledge_retrieval_with_external_documents(self): + """ + Test knowledge_retrieval with external documents. + + External documents come from external knowledge bases and should + be formatted differently than Dify documents. + + Verifies: + - External documents are handled correctly + - Provider is set to "external" + - Metadata includes external-specific fields + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id, provider="external") + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Create external document + external_doc = create_mock_document( + "External knowledge", + "doc1", + score=0.9, + provider="external", + additional_metadata={ + "dataset_id": dataset_id, + "dataset_name": "external_kb", + "document_id": "ext_doc1", + "title": "External Document", + }, + ) + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert isinstance(result, list) + if result: + assert result[0].metadata.data_source_type == "external" + + def test_knowledge_retrieval_empty_results(self): + """ + Test knowledge_retrieval when no documents are found. + + Verifies: + - Empty list is returned + - No errors are raised + - All dependencies are still called + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id) + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]): + with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)): + # Mock multiple_retrieve to return empty list + with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_rate_limit_exceeded(self): + """ + Test knowledge_retrieval when rate limit is exceeded. + + Verifies: + - RateLimitExceededError is raised + - No further processing occurs + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock _check_knowledge_rate_limit to raise exception + with patch.object( + dataset_retrieval, + "_check_knowledge_rate_limit", + side_effect=exc.RateLimitExceededError("Rate limit exceeded"), + ): + # Act & Assert + with pytest.raises(exc.RateLimitExceededError): + dataset_retrieval.knowledge_retrieval(request) + + def test_knowledge_retrieval_no_available_datasets(self): + """ + Test knowledge_retrieval when no datasets are available. + + Verifies: + - Empty list is returned + - No retrieval is attempted + """ + # Arrange + tenant_id = str(uuid4()) + user_id = str(uuid4()) + app_id = str(uuid4()) + dataset_id = str(uuid4()) + + request = KnowledgeRetrievalRequest( + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + user_from="web", + dataset_ids=[dataset_id], + query="What is Python?", + retrieval_mode="multiple", + top_k=5, + ) + + dataset_retrieval = DatasetRetrieval() + + # Mock dependencies + with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"): + # Mock _get_available_datasets to return empty list + with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]): + # Act + result = dataset_retrieval.knowledge_retrieval(request) + + # Assert + assert result == [] + + def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self): + """ + Test that knowledge_retrieval processes multiple documents with different scores. + + Note: Full sorting and position testing requires complex SQLAlchemy mocking + which is better suited for integration tests. This test verifies documents + with different scores can be created and have their metadata. + + Verifies: + - Documents can be created with different scores + - Score metadata is properly set + """ + # Create documents with different scores + doc1 = create_mock_document("Low score", "doc1", score=0.6) + doc2 = create_mock_document("High score", "doc2", score=0.95) + doc3 = create_mock_document("Medium score", "doc3", score=0.8) + + # Assert - each document has the correct score + assert doc1.metadata["score"] == 0.6 + assert doc2.metadata["score"] == 0.95 + assert doc3.metadata["score"] == 0.8 + + # Assert - documents are correctly sorted (not the retrieval result, just the list) + unsorted = [doc1, doc2, doc3] + sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True) + assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6] diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/__init__.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py new file mode 100644 index 0000000000..5733b2cf5b --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -0,0 +1,595 @@ +import time +import uuid +from unittest.mock import Mock + +import pytest + +from core.app.entities.app_invoke_entities import InvokeFrom +from core.model_runtime.entities.llm_entities import LLMUsage +from core.variables import StringSegment +from core.workflow.entities import GraphInitParams +from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.nodes.knowledge_retrieval.entities import ( + KnowledgeRetrievalNodeData, + MultipleRetrievalConfig, + RerankingModelConfig, + SingleRetrievalConfig, +) +from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError +from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode +from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source +from core.workflow.runtime import GraphRuntimeState, VariablePool +from core.workflow.system_variable import SystemVariable +from models.enums import UserFrom + + +@pytest.fixture +def mock_graph_init_params(): + """Create mock GraphInitParams.""" + return GraphInitParams( + tenant_id=str(uuid.uuid4()), + app_id=str(uuid.uuid4()), + workflow_id=str(uuid.uuid4()), + graph_config={}, + user_id=str(uuid.uuid4()), + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + """Create mock GraphRuntimeState.""" + variable_pool = VariablePool( + system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + user_inputs={}, + environment_variables=[], + conversation_variables=[], + ) + return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + + +@pytest.fixture +def mock_rag_retrieval(): + """Create mock RAGRetrievalProtocol.""" + mock_retrieval = Mock(spec=RAGRetrievalProtocol) + mock_retrieval.knowledge_retrieval.return_value = [] + mock_retrieval.llm_usage = LLMUsage.empty_usage() + return mock_retrieval + + +@pytest.fixture +def sample_node_data(): + """Create sample KnowledgeRetrievalNodeData.""" + return KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + multiple_retrieval_config=MultipleRetrievalConfig( + top_k=5, + score_threshold=0.7, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RerankingModelConfig( + provider="cohere", + model="rerank-v2", + ), + ), + ) + + +class TestKnowledgeRetrievalNode: + """ + Test suite for KnowledgeRetrievalNode. + """ + + def test_node_initialization(self, mock_graph_init_params, mock_graph_runtime_state, mock_rag_retrieval): + """Test KnowledgeRetrievalNode initialization.""" + # Arrange + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": { + "title": "Knowledge Retrieval", + "type": "knowledge-retrieval", + "dataset_ids": [str(uuid.uuid4())], + "retrieval_mode": "multiple", + }, + } + + # Act + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Assert + assert node.id == node_id + assert node._rag_retrieval == mock_rag_retrieval + assert node._llm_file_saver is not None + + def test_run_with_no_query_or_attachment( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _run returns success when no query or attachment is provided.""" + # Arrange + sample_node_data.query_variable_selector = None + sample_node_data.query_attachment_selector = None + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs == {} + assert mock_rag_retrieval.knowledge_retrieval.call_count == 0 + + def test_run_with_query_variable_single_mode( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """Test _run with query variable in single mode.""" + # Arrange + from core.workflow.nodes.llm.entities import ModelConfig + + query = "What is Python?" + query_selector = ["start", "query"] + + # Add query to variable pool + mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query)) + + node_data = KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="single", + query_variable_selector=query_selector, + single_retrieval_config=SingleRetrievalConfig( + model=ModelConfig( + provider="openai", + name="gpt-4", + mode="chat", + completion_params={"temperature": 0.7}, + ) + ), + ) + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": node_data.model_dump(), + } + + # Mock retrieval response + mock_source = Mock(spec=Source) + mock_source.model_dump.return_value = {"content": "Python is a programming language"} + mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert "result" in result.outputs + assert mock_rag_retrieval.knowledge_retrieval.called + + def test_run_with_query_variable_multiple_mode( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _run with query variable in multiple mode.""" + # Arrange + query = "What is Python?" + query_selector = ["start", "query"] + + # Add query to variable pool + mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query)) + sample_node_data.query_variable_selector = query_selector + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + # Mock retrieval response + mock_source = Mock(spec=Source) + mock_source.model_dump.return_value = {"content": "Python is a programming language"} + mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert "result" in result.outputs + assert mock_rag_retrieval.knowledge_retrieval.called + + def test_run_with_invalid_query_variable_type( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _run fails when query variable is not StringSegment.""" + # Arrange + query_selector = ["start", "query"] + + # Add non-string variable to variable pool + mock_graph_runtime_state.variable_pool.add(query_selector, [1, 2, 3]) + sample_node_data.query_variable_selector = query_selector + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Query variable is not string type" in result.error + + def test_run_with_invalid_attachment_variable_type( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _run fails when attachment variable is not FileSegment or ArrayFileSegment.""" + # Arrange + attachment_selector = ["start", "attachments"] + + # Add non-file variable to variable pool + mock_graph_runtime_state.variable_pool.add(attachment_selector, "not a file") + sample_node_data.query_attachment_selector = attachment_selector + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Attachments variable is not array file or file type" in result.error + + def test_run_with_rate_limit_exceeded( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _run handles RateLimitExceededError properly.""" + # Arrange + query = "What is Python?" + query_selector = ["start", "query"] + + mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query)) + sample_node_data.query_variable_selector = query_selector + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + # Mock retrieval to raise RateLimitExceededError + mock_rag_retrieval.knowledge_retrieval.side_effect = RateLimitExceededError( + "knowledge base request rate limit exceeded" + ) + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "rate limit" in result.error.lower() + + def test_run_with_generic_exception( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _run handles generic exceptions properly.""" + # Arrange + query = "What is Python?" + query_selector = ["start", "query"] + + mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query)) + sample_node_data.query_variable_selector = query_selector + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + # Mock retrieval to raise generic exception + mock_rag_retrieval.knowledge_retrieval.side_effect = Exception("Unexpected error") + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + result = node._run() + + # Assert + assert result.status == WorkflowNodeExecutionStatus.FAILED + assert "Unexpected error" in result.error + + def test_extract_variable_selector_to_variable_mapping(self): + """Test _extract_variable_selector_to_variable_mapping class method.""" + # Arrange + node_id = "knowledge_node_1" + node_data = { + "type": "knowledge-retrieval", + "title": "Knowledge Retrieval", + "dataset_ids": [str(uuid.uuid4())], + "retrieval_mode": "multiple", + "query_variable_selector": ["start", "query"], + "query_attachment_selector": ["start", "attachments"], + } + graph_config = {} + + # Act + mapping = KnowledgeRetrievalNode._extract_variable_selector_to_variable_mapping( + graph_config=graph_config, + node_id=node_id, + node_data=node_data, + ) + + # Assert + assert mapping[f"{node_id}.query"] == ["start", "query"] + assert mapping[f"{node_id}.queryAttachment"] == ["start", "attachments"] + + +class TestFetchDatasetRetriever: + """ + Test suite for _fetch_dataset_retriever method. + """ + + def test_fetch_dataset_retriever_single_mode( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """Test _fetch_dataset_retriever in single mode.""" + # Arrange + from core.workflow.nodes.llm.entities import ModelConfig + + query = "What is Python?" + variables = {"query": query} + + node_data = KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="single", + single_retrieval_config=SingleRetrievalConfig( + model=ModelConfig( + provider="openai", + name="gpt-4", + mode="chat", + completion_params={"temperature": 0.7}, + ) + ), + ) + + # Mock retrieval response + mock_source = Mock(spec=Source) + mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node_id = str(uuid.uuid4()) + config = {"id": node_id, "data": node_data.model_dump()} + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables) + + # Assert + assert len(results) == 1 + assert isinstance(usage, LLMUsage) + assert mock_rag_retrieval.knowledge_retrieval.called + + def test_fetch_dataset_retriever_multiple_mode_with_reranking( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + sample_node_data, + ): + """Test _fetch_dataset_retriever in multiple mode with reranking.""" + # Arrange + query = "What is Python?" + variables = {"query": query} + + # Mock retrieval response + mock_rag_retrieval.knowledge_retrieval.return_value = [] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": sample_node_data.model_dump(), + } + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + results, usage = node._fetch_dataset_retriever(node_data=sample_node_data, variables=variables) + + # Assert + assert isinstance(results, list) + assert isinstance(usage, LLMUsage) + assert mock_rag_retrieval.knowledge_retrieval.called + + # Verify reranking parameters via request object + call_args = mock_rag_retrieval.knowledge_retrieval.call_args + request = call_args[1]["request"] + assert request.reranking_enable is True + assert request.reranking_mode == "reranking_model" + + def test_fetch_dataset_retriever_multiple_mode_without_reranking( + self, + mock_graph_init_params, + mock_graph_runtime_state, + mock_rag_retrieval, + ): + """Test _fetch_dataset_retriever in multiple mode without reranking.""" + # Arrange + query = "What is Python?" + variables = {"query": query} + + node_data = KnowledgeRetrievalNodeData( + title="Knowledge Retrieval", + type="knowledge-retrieval", + dataset_ids=[str(uuid.uuid4())], + retrieval_mode="multiple", + multiple_retrieval_config=MultipleRetrievalConfig( + top_k=5, + score_threshold=0.7, + reranking_enable=False, + reranking_mode="reranking_model", + ), + ) + + # Mock retrieval response + mock_rag_retrieval.knowledge_retrieval.return_value = [] + mock_rag_retrieval.llm_usage = LLMUsage.empty_usage() + + node_id = str(uuid.uuid4()) + config = { + "id": node_id, + "data": node_data.model_dump(), + } + + node = KnowledgeRetrievalNode( + id=node_id, + config=config, + graph_init_params=mock_graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + rag_retrieval=mock_rag_retrieval, + ) + + # Act + results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables) + + # Assert + assert isinstance(results, list) + assert mock_rag_retrieval.knowledge_retrieval.called + + # Verify reranking is disabled + call_args = mock_rag_retrieval.knowledge_retrieval.call_args + request = call_args[1]["request"] + assert request.reranking_enable is False + + def test_version_method(self): + """Test version class method.""" + # Act + version = KnowledgeRetrievalNode.version() + + # Assert + assert version == "1"