refactor: decouple database operations from knowledge retrieval nodes (#31981)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: wangxiaolei
Committed by GitHub on 2026-02-09 13:56:55 +08:00
parent 0428ac5f3a
commit 3348b89436
12 changed files with 2453 additions and 551 deletions

View File

@@ -50,14 +50,12 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@@ -122,11 +120,6 @@ ignore_imports =
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.file.models
@@ -146,7 +139,6 @@ ignore_imports =
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
@@ -162,9 +154,6 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -213,7 +202,6 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
@@ -229,7 +217,6 @@ ignore_imports =
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
@@ -287,8 +274,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database

View File

@@ -8,6 +8,7 @@ from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
@@ -16,6 +17,7 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
@@ -75,6 +77,7 @@ class DifyNodeFactory(NodeFactory):
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
self._rag_retrieval = DatasetRetrieval()
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -140,6 +143,15 @@ class DifyNodeFactory(NodeFactory):
file_manager=self._http_request_file_manager,
)
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
return KnowledgeRetrievalNode(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
rag_retrieval=self._rag_retrieval,
)
return node_class(
id=node_id,
config=node_config,
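
The factory now creates one shared DatasetRetrieval instance and hands it to KnowledgeRetrievalNode through the new rag_retrieval parameter, so the node itself no longer reaches into the database layer. A distilled, self-contained sketch of this constructor-injection pattern (the class and method names below are illustrative, not the Dify API):

from typing import Protocol


class Retrieval(Protocol):
    def knowledge_retrieval(self, query: str) -> list[str]: ...


class RetrievalNode:
    def __init__(self, node_id: str, retrieval: Retrieval) -> None:
        self.node_id = node_id
        self._retrieval = retrieval  # injected; the node holds no DB or Redis handles

    def run(self, query: str) -> list[str]:
        return self._retrieval.knowledge_retrieval(query)


class NodeFactory:
    def __init__(self, retrieval: Retrieval) -> None:
        self._retrieval = retrieval  # created once, reused for every node

    def create_node(self, node_id: str) -> RetrievalNode:
        return RetrievalNode(node_id, self._retrieval)

Swapping the retrieval argument for a stub is then enough to exercise a node in isolation.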

View File

@@ -1,13 +1,15 @@
import json
import logging
import math
import re
import threading
import time
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Union, cast
from flask import Flask, current_app
from sqlalchemy import and_, literal, or_, select
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import Session
from core.app.app_config.entities import (
@@ -18,6 +20,7 @@ from core.app.app_config.entities import (
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
@@ -58,12 +61,30 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import (
KnowledgeRetrievalRequest,
Source,
SourceChildChunk,
SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import (
ChildChunk,
Dataset,
DatasetMetadata,
DatasetQuery,
DocumentSegment,
RateLimitLog,
SegmentAttachmentBinding,
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService
default_retrieval_model: dict[str, Any] = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
@@ -73,6 +94,8 @@ default_retrieval_model: dict[str, Any] = {
"score_threshold_enabled": False,
}
logger = logging.getLogger(__name__)
class DatasetRetrieval:
def __init__(self, application_generate_entity=None):
@@ -91,6 +114,233 @@ class DatasetRetrieval:
else:
self._llm_usage = self._llm_usage.plus(usage)
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
self._check_knowledge_rate_limit(request.tenant_id)
available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
available_datasets_ids = [i.id for i in available_datasets]
if not available_datasets_ids:
return []
if not request.query:
return []
metadata_filter_document_ids, metadata_condition = None, None
if request.metadata_filtering_mode != "disabled":
# Convert workflow layer types to app_config layer types
if not request.metadata_model_config:
raise ValueError("metadata_model_config is required for this method")
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
app_metadata_filtering_conditions = None
if request.metadata_filtering_conditions is not None:
app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
request.metadata_filtering_conditions.model_dump()
)
query = request.query if request.query is not None else ""
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
dataset_ids=available_datasets_ids,
query=query,
tenant_id=request.tenant_id,
user_id=request.user_id,
metadata_filtering_mode=request.metadata_filtering_mode,
metadata_model_config=app_metadata_model_config,
metadata_filtering_conditions=app_metadata_filtering_conditions,
inputs={},
)
if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
planning_strategy = PlanningStrategy.REACT_ROUTER
# Ensure required fields are not None for single retrieval mode
if request.model_provider is None or request.model_name is None or request.query is None:
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=request.tenant_id,
model_type=ModelType.LLM,
provider=request.model_provider,
model=request.model_name,
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=request.model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise exc.ModelCredentialsNotInitializedError(
f"Model {request.model_name} credentials is not initialized."
)
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")
stop = []
completion_params = (request.completion_params or {}).copy()
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)
if not model_schema:
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
model_config = ModelConfigWithCredentialsEntity(
provider=request.model_provider,
model=request.model_name,
model_schema=model_schema,
mode=request.model_mode or "chat",
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
all_documents = self.single_retrieve(
request.app_id,
request.tenant_id,
request.user_id,
request.user_from,
request.query,
available_datasets,
model_instance,
model_config,
planning_strategy,
None, # message_id
metadata_filter_document_ids,
metadata_condition,
)
else:
all_documents = self.multiple_retrieve(
app_id=request.app_id,
tenant_id=request.tenant_id,
user_id=request.user_id,
user_from=request.user_from,
available_datasets=available_datasets,
query=request.query,
top_k=request.top_k,
score_threshold=request.score_threshold,
reranking_mode=request.reranking_mode,
reranking_model=request.reranking_model,
weights=request.weights,
reranking_enable=request.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=request.attachment_ids,
)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=item.metadata.get("dataset_id"),
dataset_name=item.metadata.get("dataset_name"),
document_id=item.metadata.get("document_id"),
document_name=item.metadata.get("title"),
data_source_type="external",
retriever_from="workflow",
score=item.metadata.get("score"),
doc_metadata=item.metadata,
),
title=item.metadata.get("title"),
content=item.page_content,
)
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
dataset_ids = [i.segment.dataset_id for i in records]
document_ids = [i.segment.document_id for i in records]
with session_factory.create_session() as session:
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
dataset_map = {i.id: i for i in datasets}
document_map = {i.id: i for i in documents}
if records:
for record in records:
segment = record.segment
dataset = dataset_map.get(segment.dataset_id)
document = document_map.get(segment.document_id)
if dataset and document:
source = Source(
metadata=SourceMetadata(
source="knowledge",
dataset_id=dataset.id,
dataset_name=dataset.name,
document_id=document.id,
document_name=document.name,
data_source_type=document.data_source_type,
segment_id=segment.id,
retriever_from="workflow",
score=record.score or 0.0,
segment_hit_count=segment.hit_count,
segment_word_count=segment.word_count,
segment_position=segment.position,
segment_index_node_hash=segment.index_node_hash,
doc_metadata=document.doc_metadata,
child_chunks=[
SourceChildChunk(
id=str(getattr(chunk, "id", "")),
content=str(getattr(chunk, "content", "")),
position=int(getattr(chunk, "position", 0)),
score=float(getattr(chunk, "score", 0.0)),
)
for chunk in (record.child_chunks or [])
],
position=None,
),
title=document.name,
files=list(record.files) if record.files else None,
content=segment.get_sign_content(),
)
if segment.answer:
source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
if record.summary:
source.summary = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
def _score(item: Source) -> float:
meta = item.metadata
score = meta.score
if isinstance(score, (int, float)):
return float(score)
return 0.0
retrieval_resource_list = sorted(
retrieval_resource_list,
key=_score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item.metadata.position = position # type: ignore[index]
return retrieval_resource_list
def retrieve(
self,
app_id: str,
@@ -150,14 +400,7 @@ class DatasetRetrieval:
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
available_datasets = []
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
for dataset in datasets:
if dataset.available_document_count == 0 and dataset.provider != "external":
continue
available_datasets.append(dataset)
available_datasets = self._get_available_datasets(tenant_id, dataset_ids)
if inputs:
inputs = {key: str(value) for key, value in inputs.items()}
@@ -1161,7 +1404,6 @@ class DatasetRetrieval:
query=query or "",
)
result_text = ""
try:
# handle invoke result
invoke_result = cast(
@@ -1192,7 +1434,8 @@ class DatasetRetrieval:
"condition": item.get("comparison_operator"),
}
)
except Exception:
except Exception as e:
logger.warning(e, exc_info=True)
return None
return automatic_metadata_filters
@@ -1406,7 +1649,12 @@ class DatasetRetrieval:
usage = None
for result in invoke_result:
text = result.delta.message.content
full_text += text
if isinstance(text, str):
full_text += text
elif isinstance(text, list):
for i in text:
if i.data:
full_text += i.data
if not model:
model = result.model
@@ -1524,3 +1772,53 @@ class DatasetRetrieval:
cancel_event.set()
if thread_exceptions is not None:
thread_exceptions.append(e)
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
with session_factory.create_session() as session:
subquery = (
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
.where(
DocumentModel.indexing_status == "completed",
DocumentModel.enabled == True,
DocumentModel.archived == False,
DocumentModel.dataset_id.in_(dataset_ids),
)
.group_by(DocumentModel.dataset_id)
.having(func.count(DocumentModel.id) > 0)
.subquery()
)
results = (
session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
available_datasets = []
for dataset in results:
if not dataset:
continue
available_datasets.append(dataset)
return available_datasets
def _check_knowledge_rate_limit(self, tenant_id: str):
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with session_factory.create_session() as session:
rate_limit_log = RateLimitLog(
tenant_id=tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
raise exc.RateLimitExceededError(
"you have reached the knowledge base request rate limit of your subscription."
)
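
_check_knowledge_rate_limit moves the per-tenant rate check into the retrieval service and implements it as a sliding one-minute window over a Redis sorted set: zadd records the current millisecond timestamp, zremrangebyscore evicts entries older than 60 seconds, and zcard counts what remains. A minimal sketch of the same pattern with redis-py, assuming a reachable Redis instance; the boolean helper interface is illustrative:

import time

import redis


def within_rate_limit(client: redis.Redis, tenant_id: str, limit: int) -> bool:
    # Sliding one-minute window backed by a sorted set, mirroring the key format above.
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    client.zadd(key, {str(now_ms): now_ms})           # record this request
    client.zremrangebyscore(key, 0, now_ms - 60_000)  # drop entries older than 60s
    return client.zcard(key) <= limit                 # count requests left in the window


# Illustrative usage:
# allowed = within_rate_limit(redis.Redis(), "tenant-123", limit=100)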

View File

@@ -20,3 +20,7 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError):
class InvalidModelTypeError(KnowledgeRetrievalNodeError):
"""Raised when the model is not a Large Language Model."""
class RateLimitExceededError(KnowledgeRetrievalNodeError):
"""Raised when the rate limit is exceeded."""

View File

@@ -1,29 +1,10 @@
import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from typing import TYPE_CHECKING, Any, Literal
from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
ArrayFileSegment,
FileSegment,
@@ -36,35 +17,16 @@ from core.workflow.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
METADATA_FILTER_ASSISTANT_PROMPT_1,
METADATA_FILTER_ASSISTANT_PROMPT_2,
METADATA_FILTER_COMPLETION_PROMPT,
METADATA_FILTER_SYSTEM_PROMPT,
METADATA_FILTER_USER_PROMPT_1,
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source
from .entities import KnowledgeRetrievalNodeData
from .exc import (
InvalidModelTypeError,
KnowledgeRetrievalNodeError,
ModelCredentialsNotInitializedError,
ModelNotExistError,
ModelNotSupportedError,
ModelQuotaExceededError,
RateLimitExceededError,
)
if TYPE_CHECKING:
@@ -73,14 +35,6 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
default_retrieval_model = {
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
"reranking_enable": False,
"reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
"top_k": 4,
"score_threshold_enabled": False,
}
class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
node_type = NodeType.KNOWLEDGE_RETRIEVAL
@@ -97,6 +51,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
rag_retrieval: RAGRetrievalProtocol,
*,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -108,6 +63,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# LLM file outputs, used for MultiModal outputs.
self._file_outputs = []
self._rag_retrieval = rag_retrieval
if llm_file_saver is None:
llm_file_saver = FileSaverImpl(
@@ -121,6 +77,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
return "1"
def _run(self) -> NodeRunResult:
usage = LLMUsage.empty_usage()
if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -128,7 +85,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
process_data={},
outputs={},
metadata={},
llm_usage=LLMUsage.empty_usage(),
llm_usage=usage,
)
variables: dict[str, Any] = {}
# extract variables
@@ -156,36 +113,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
else:
variables["attachments"] = [variable.value]
# TODO(-LAN-): Move this check outside.
# check rate limit
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
if knowledge_rate_limit.enabled:
current_time = int(time.time() * 1000)
key = f"rate_limit_{self.tenant_id}"
redis_client.zadd(key, {current_time: current_time})
redis_client.zremrangebyscore(key, 0, current_time - 60000)
request_count = redis_client.zcard(key)
if request_count > knowledge_rate_limit.limit:
with sessionmaker(db.engine).begin() as session:
# add ratelimit record
rate_limit_log = RateLimitLog(
tenant_id=self.tenant_id,
subscription_plan=knowledge_rate_limit.subscription_plan,
operation="knowledge",
)
session.add(rate_limit_log)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
error_type="RateLimitExceeded",
)
# retrieve knowledge
usage = LLMUsage.empty_usage()
try:
results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
outputs = {"result": ArrayObjectSegment(value=results)}
outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,
@@ -198,9 +128,17 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
},
llm_usage=usage,
)
except RateLimitExceededError as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
error=str(e),
error_type=type(e).__name__,
llm_usage=usage,
)
except KnowledgeRetrievalNodeError as e:
logger.warning("Error when running knowledge retrieval node")
logger.warning("Error when running knowledge retrieval node", exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -210,6 +148,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
)
# Temporary handle all exceptions from DatasetRetrieval class here.
except Exception as e:
logger.warning(e, exc_info=True)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
inputs=variables,
@@ -217,92 +156,47 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
error_type=type(e).__name__,
llm_usage=usage,
)
finally:
db.session.close()
def _fetch_dataset_retriever(
self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
available_datasets = []
) -> tuple[list[Source], LLMUsage]:
dataset_ids = node_data.dataset_ids
query = variables.get("query")
attachments = variables.get("attachments")
metadata_filter_document_ids = None
metadata_condition = None
metadata_usage = LLMUsage.empty_usage()
# Subquery: Count the number of available documents for each dataset
subquery = (
db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
.where(
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
Document.dataset_id.in_(dataset_ids),
)
.group_by(Document.dataset_id)
.having(func.count(Document.id) > 0)
.subquery()
)
retrieval_resource_list = []
results = (
db.session.query(Dataset)
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
.where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
.all()
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
if node_data.metadata_filtering_mode is not None:
metadata_filtering_mode = node_data.metadata_filtering_mode
# avoid blocking at retrieval
db.session.close()
for dataset in results:
# pass if dataset is not available
if not dataset:
continue
available_datasets.append(dataset)
if query:
metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
[dataset.id for dataset in available_datasets], query, node_data
)
usage = self._merge_usage(usage, metadata_usage)
all_documents = []
dataset_retrieval = DatasetRetrieval()
if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
# fetch model config
if node_data.single_retrieval_config is None:
raise ValueError("single_retrieval_config is required")
model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
# check model is support tool calling
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
# get model schema
model_schema = model_type_instance.get_model_schema(
model=model_config.model, credentials=model_config.credentials
)
if model_schema:
planning_strategy = PlanningStrategy.REACT_ROUTER
features = model_schema.features
if features:
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
planning_strategy = PlanningStrategy.ROUTER
all_documents = dataset_retrieval.single_retrieve(
available_datasets=available_datasets,
raise ValueError("single_retrieval_config is required for single retrieval mode")
model = node_data.single_retrieval_config.model
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
tenant_id=self.tenant_id,
user_id=self.user_id,
app_id=self.app_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
completion_params=model.completion_params,
model_provider=model.provider,
model_mode=model.mode,
model_name=model.name,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
query=query,
model_config=model_config,
model_instance=model_instance,
planning_strategy=planning_strategy,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
)
)
elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
if node_data.multiple_retrieval_config is None:
raise ValueError("multiple_retrieval_config is required")
reranking_model = None
weights = None
match node_data.multiple_retrieval_config.reranking_mode:
case "reranking_model":
if node_data.multiple_retrieval_config.reranking_model:
@@ -329,284 +223,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
},
}
case _:
# Handle any other reranking_mode values
reranking_model = None
weights = None
all_documents = dataset_retrieval.multiple_retrieve(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
available_datasets=available_datasets,
query=query,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_filter_document_ids=metadata_filter_document_ids,
metadata_condition=metadata_condition,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, dataset_retrieval.llm_usage)
dify_documents = [item for item in all_documents if item.provider == "dify"]
external_documents = [item for item in all_documents if item.provider == "external"]
retrieval_resource_list = []
# deal with external documents
for item in external_documents:
source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
"metadata": {
"_source": "knowledge",
"dataset_id": item.metadata.get("dataset_id"),
"dataset_name": item.metadata.get("dataset_name"),
"document_id": item.metadata.get("document_id") or item.metadata.get("title"),
"document_name": item.metadata.get("title"),
"data_source_type": "external",
"retriever_from": "workflow",
"score": item.metadata.get("score"),
"doc_metadata": item.metadata,
},
"title": item.metadata.get("title"),
"content": item.page_content,
}
retrieval_resource_list.append(source)
# deal with dify documents
if dify_documents:
records = RetrievalService.format_retrieval_documents(dify_documents)
if records:
for record in records:
segment = record.segment
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() # type: ignore
stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,
Document.archived == False,
)
document = db.session.scalar(stmt)
if dataset and document:
source = {
"metadata": {
"_source": "knowledge",
"dataset_id": dataset.id,
"dataset_name": dataset.name,
"document_id": document.id,
"document_name": document.name,
"data_source_type": document.data_source_type,
"segment_id": segment.id,
"retriever_from": "workflow",
"score": record.score or 0.0,
"child_chunks": [
{
"id": str(getattr(chunk, "id", "")),
"content": str(getattr(chunk, "content", "")),
"position": int(getattr(chunk, "position", 0)),
"score": float(getattr(chunk, "score", 0.0)),
}
for chunk in (record.child_chunks or [])
],
"segment_hit_count": segment.hit_count,
"segment_word_count": segment.word_count,
"segment_position": segment.position,
"segment_index_node_hash": segment.index_node_hash,
"doc_metadata": document.doc_metadata,
},
"title": document.name,
"files": list(record.files) if record.files else None,
}
if segment.answer:
source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
else:
source["content"] = segment.get_sign_content()
# Add summary if available
if record.summary:
source["summary"] = record.summary
retrieval_resource_list.append(source)
if retrieval_resource_list:
retrieval_resource_list = sorted(
retrieval_resource_list,
key=self._score, # type: ignore[arg-type, return-value]
reverse=True,
)
for position, item in enumerate(retrieval_resource_list, start=1):
item["metadata"]["position"] = position # type: ignore[index]
return retrieval_resource_list, usage
def _score(self, item: dict[str, Any]) -> float:
meta = item.get("metadata")
if isinstance(meta, dict):
s = meta.get("score")
if isinstance(s, (int, float)):
return float(s)
return 0.0
def _get_metadata_filter_condition(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
usage = LLMUsage.empty_usage()
document_query = db.session.query(Document).where(
Document.dataset_id.in_(dataset_ids),
Document.indexing_status == "completed",
Document.enabled == True,
Document.archived == False,
)
filters: list[Any] = []
metadata_condition = None
match node_data.metadata_filtering_mode:
case "disabled":
return None, None, usage
case "automatic":
automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
dataset_ids, query, node_data
retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
request=KnowledgeRetrievalRequest(
app_id=self.app_id,
tenant_id=self.tenant_id,
user_id=self.user_id,
user_from=self.user_from.value,
dataset_ids=dataset_ids,
query=query,
retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
top_k=node_data.multiple_retrieval_config.top_k,
score_threshold=node_data.multiple_retrieval_config.score_threshold
if node_data.multiple_retrieval_config.score_threshold is not None
else 0.0,
reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
reranking_model=reranking_model,
weights=weights,
reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
metadata_model_config=node_data.metadata_model_config,
metadata_filtering_conditions=node_data.metadata_filtering_conditions,
metadata_filtering_mode=metadata_filtering_mode,
attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
)
usage = self._merge_usage(usage, automatic_usage)
if automatic_metadata_filters:
conditions = []
for sequence, filter in enumerate(automatic_metadata_filters):
DatasetRetrieval.process_metadata_filter_func(
sequence,
filter.get("condition", ""),
filter.get("metadata_name", ""),
filter.get("value"),
filters,
)
conditions.append(
Condition(
name=filter.get("metadata_name"), # type: ignore
comparison_operator=filter.get("condition"), # type: ignore
value=filter.get("value"),
)
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator
if node_data.metadata_filtering_conditions
else "or",
conditions=conditions,
)
case "manual":
if node_data.metadata_filtering_conditions:
conditions = []
for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions): # type: ignore
metadata_name = condition.name
expected_value = condition.value
if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
if isinstance(expected_value, str):
expected_value = self.graph_runtime_state.variable_pool.convert_template(
expected_value
).value[0]
if expected_value.value_type in {"number", "integer", "float"}:
expected_value = expected_value.value
elif expected_value.value_type == "string":
expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
else:
raise ValueError("Invalid expected metadata value type")
conditions.append(
Condition(
name=metadata_name,
comparison_operator=condition.comparison_operator,
value=expected_value,
)
)
filters = DatasetRetrieval.process_metadata_filter_func(
sequence,
condition.comparison_operator,
metadata_name,
expected_value,
filters,
)
metadata_condition = MetadataCondition(
logical_operator=node_data.metadata_filtering_conditions.logical_operator,
conditions=conditions,
)
case _:
raise ValueError("Invalid metadata filtering mode")
if filters:
if (
node_data.metadata_filtering_conditions
and node_data.metadata_filtering_conditions.logical_operator == "and"
):
document_query = document_query.where(and_(*filters))
else:
document_query = document_query.where(or_(*filters))
documents = document_query.all()
# group by dataset_id
metadata_filter_document_ids = defaultdict(list) if documents else None # type: ignore
for document in documents:
metadata_filter_document_ids[document.dataset_id].append(document.id) # type: ignore
return metadata_filter_document_ids, metadata_condition, usage
def _automatic_metadata_filter_func(
self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
) -> tuple[list[dict[str, Any]], LLMUsage]:
usage = LLMUsage.empty_usage()
# get all metadata field
stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
metadata_fields = db.session.scalars(stmt).all()
all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
if node_data.metadata_model_config is None:
raise ValueError("metadata_model_config is required")
# get metadata model instance and fetch model config
model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
# fetch prompt messages
prompt_template = self._get_prompt_template(
node_data=node_data,
metadata_fields=all_metadata_fields,
query=query or "",
)
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query=query,
memory=None,
model_config=model_config,
sys_files=[],
vision_enabled=node_data.vision.enabled,
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
tenant_id=self.tenant_id,
)
result_text = ""
try:
# handle invoke result
generator = LLMNode.invoke_llm(
node_data_model=node_data.metadata_model_config,
model_instance=model_instance,
prompt_messages=prompt_messages,
stop=stop,
user_id=self.user_id,
structured_output_enabled=self.node_data.structured_output_enabled,
structured_output=None,
file_saver=self._llm_file_saver,
file_outputs=self._file_outputs,
node_id=self._node_id,
node_type=self.node_type,
)
for event in generator:
if isinstance(event, ModelInvokeCompletedEvent):
result_text = event.text
usage = self._merge_usage(usage, event.usage)
break
result_text_json = parse_and_check_json_markdown(result_text, [])
automatic_metadata_filters = []
if "metadata_map" in result_text_json:
metadata_map = result_text_json["metadata_map"]
for item in metadata_map:
if item.get("metadata_field_name") in all_metadata_fields:
automatic_metadata_filters.append(
{
"metadata_name": item.get("metadata_field_name"),
"value": item.get("metadata_field_value"),
"condition": item.get("comparison_operator"),
}
)
except Exception:
return [], usage
return automatic_metadata_filters, usage
usage = self._rag_retrieval.llm_usage
return retrieval_resource_list, usage
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -626,107 +272,3 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
if typed_node_data.query_attachment_selector:
variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
return variable_mapping
def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
model_name = model.name
provider_name = model.provider
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
)
provider_model_bundle = model_instance.provider_model_bundle
model_type_instance = model_instance.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)
model_credentials = model_instance.credentials
# check model
provider_model = provider_model_bundle.configuration.get_provider_model(
model=model_name, model_type=ModelType.LLM
)
if provider_model is None:
raise ModelNotExistError(f"Model {model_name} not exist.")
if provider_model.status == ModelStatus.NO_CONFIGURE:
raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
elif provider_model.status == ModelStatus.NO_PERMISSION:
raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")
# model config
completion_params = model.completion_params
stop = []
if "stop" in completion_params:
stop = completion_params["stop"]
del completion_params["stop"]
# get model mode
model_mode = model.mode
if not model_mode:
raise ModelNotExistError("LLM mode is required.")
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
if not model_schema:
raise ModelNotExistError(f"Model {model_name} not exist.")
return model_instance, ModelConfigWithCredentialsEntity(
provider=provider_name,
model=model_name,
model_schema=model_schema,
mode=model_mode,
provider_model_bundle=provider_model_bundle,
credentials=model_credentials,
parameters=completion_params,
stop=stop,
)
def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
model_mode = ModelMode(node_data.metadata_model_config.mode) # type: ignore
input_text = query
prompt_messages: list[LLMNodeChatModelMessage] = []
if model_mode == ModelMode.CHAT:
system_prompt_messages = LLMNodeChatModelMessage(
role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
)
prompt_messages.append(system_prompt_messages)
user_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
)
prompt_messages.append(user_prompt_message_1)
assistant_prompt_message_1 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
)
prompt_messages.append(assistant_prompt_message_1)
user_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
)
prompt_messages.append(user_prompt_message_2)
assistant_prompt_message_2 = LLMNodeChatModelMessage(
role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
)
prompt_messages.append(assistant_prompt_message_2)
user_prompt_message_3 = LLMNodeChatModelMessage(
role=PromptMessageRole.USER,
text=METADATA_FILTER_USER_PROMPT_3.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
),
)
prompt_messages.append(user_prompt_message_3)
return prompt_messages
elif model_mode == ModelMode.COMPLETION:
return LLMNodeCompletionModelPromptTemplate(
text=METADATA_FILTER_COMPLETION_PROMPT.format(
input_text=input_text,
metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
)
)
else:
raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

View File

@@ -0,0 +1,108 @@
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from core.model_runtime.entities import LLMUsage
from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from core.workflow.nodes.llm.entities import ModelConfig
class SourceChildChunk(BaseModel):
id: str = Field(default="", description="Child chunk ID")
content: str = Field(default="", description="Child chunk content")
position: int = Field(default=0, description="Child chunk position")
score: float = Field(default=0.0, description="Child chunk relevance score")
class SourceMetadata(BaseModel):
source: str = Field(
default="knowledge",
serialization_alias="_source",
description="Data source identifier",
)
dataset_id: str = Field(description="Dataset unique identifier")
dataset_name: str = Field(description="Dataset display name")
document_id: str = Field(description="Document unique identifier")
document_name: str = Field(description="Document display name")
data_source_type: str = Field(description="Type of data source")
segment_id: str | None = Field(default=None, description="Segment unique identifier")
retriever_from: str = Field(default="workflow", description="Retriever source context")
score: float = Field(default=0.0, description="Retrieval relevance score")
child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
segment_word_count: int | None = Field(default=0, description="Word count of the segment")
segment_position: int | None = Field(default=0, description="Position of segment in document")
segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
position: int | None = Field(default=0, description="Position of the document in the dataset")
class Config:
populate_by_name = True
class Source(BaseModel):
metadata: SourceMetadata = Field(description="Source metadata information")
title: str = Field(description="Document title")
files: list[Any] | None = Field(default=None, description="Associated file references")
content: str | None = Field(description="Segment content text")
summary: str | None = Field(default=None, description="Content summary if available")
class KnowledgeRetrievalRequest(BaseModel):
tenant_id: str = Field(description="Tenant unique identifier")
user_id: str = Field(description="User unique identifier")
app_id: str = Field(description="Application unique identifier")
user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
query: str | None = Field(default=None, description="Query text for knowledge retrieval")
retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
completion_params: dict[str, Any] | None = Field(
default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
)
model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
metadata_model_config: ModelConfig | None = Field(
default=None, description="Model config for metadata-based filtering"
)
metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
default=None, description="Conditions for filtering by metadata"
)
metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
)
top_k: int = Field(default=0, description="Number of top results to return")
score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")
class RAGRetrievalProtocol(Protocol):
"""Protocol for RAG-based knowledge retrieval implementations.
Implementations of this protocol handle knowledge retrieval from datasets
including rate limiting, dataset filtering, and document retrieval.
"""
@property
def llm_usage(self) -> LLMUsage:
"""Return accumulated LLM usage for retrieval operations."""
...
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
"""Retrieve knowledge from datasets based on the provided request.
Args:
request: Knowledge retrieval request with search parameters
Returns:
List of sources matching the search criteria
Raises:
RateLimitExceededError: If rate limit is exceeded
ModelNotExistError: If specified model doesn't exist
"""
...
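
Because RAGRetrievalProtocol is a structural Protocol, the node can be unit-tested without touching the database or Redis by injecting any object that exposes llm_usage and knowledge_retrieval. A sketch of such a test double built on the models defined above (the class name and canned result are illustrative):

from core.model_runtime.entities import LLMUsage
from core.workflow.repositories.rag_retrieval_protocol import (
    KnowledgeRetrievalRequest,
    RAGRetrievalProtocol,
    Source,
    SourceMetadata,
)


class StubRetrieval:
    """In-memory stand-in; satisfies RAGRetrievalProtocol by structure alone."""

    @property
    def llm_usage(self) -> LLMUsage:
        return LLMUsage.empty_usage()

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
        return [
            Source(
                metadata=SourceMetadata(
                    dataset_id="ds-1",
                    dataset_name="demo dataset",
                    document_id="doc-1",
                    document_name="demo.txt",
                    data_source_type="upload_file",
                    score=0.9,
                ),
                title="demo.txt",
                content=f"stub hit for: {request.query}",
            )
        ]


retrieval: RAGRetrievalProtocol = StubRetrieval()  # accepted via structural typing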

View File

@@ -0,0 +1,29 @@
"""
Integration tests for KnowledgeRetrievalNode.
This module provides integration tests for KnowledgeRetrievalNode with real database interactions.
Note: These tests require database setup and are more complex than unit tests.
For now, we focus on unit tests which provide better coverage for the node logic.
"""
import pytest
class TestKnowledgeRetrievalNodeIntegration:
"""
Integration test suite for KnowledgeRetrievalNode.
Note: Full integration tests require:
- Database setup with datasets and documents
- Vector store for embeddings
- Model providers for retrieval
For now, unit tests provide comprehensive coverage of the node logic.
"""
@pytest.mark.skip(reason="Integration tests require full database and vector store setup")
def test_end_to_end_knowledge_retrieval(self):
"""Test end-to-end knowledge retrieval workflow."""
# TODO: Implement with real database
pass

View File

@@ -0,0 +1,614 @@
import uuid
from unittest.mock import patch
import pytest
from faker import Faker
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
from services.account_service import AccountService, TenantService
class TestGetAvailableDatasetsIntegration:
def test_returns_datasets_with_available_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
# Create account and tenant
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create dataset
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=100),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
indexing_technique="high_quality",
)
db_session_with_containers.add(dataset)
db_session_with_containers.flush()
# Create documents with completed status, enabled, not archived
for i in range(3):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
name=f"Document {i}",
created_from="web",
created_by=account.id,
doc_form="text_model",
doc_language="en",
indexing_status="completed",
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 1
assert result[0].id == dataset.id
assert result[0].tenant_id == tenant.id
assert result[0].name == dataset.name
def test_filters_out_datasets_with_only_archived_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create only archived documents
for i in range(2):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Archived Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=True,
archived=True, # Archived
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 0
def test_filters_out_datasets_with_only_disabled_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create only disabled documents
for i in range(2):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Disabled Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=False, # Disabled
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 0
def test_filters_out_datasets_with_non_completed_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
# Create documents with non-completed status
for i, status in enumerate(["indexing", "parsing", "splitting"]):
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=i,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Document {status}",
created_by=account.id,
doc_form="text_model",
indexing_status=status, # Not completed
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 0
def test_includes_external_datasets_without_documents(
self, db_session_with_containers, mock_external_service_dependencies
):
"""
Test that external datasets are returned even with no available documents.
External datasets (e.g., from external knowledge bases) don't have
documents stored in Dify's database, so they should always be available.
Verifies:
- External datasets are included in results
- No document count check for external datasets
"""
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="external", # External provider
data_source_type="external",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])
# Assert
assert len(result) == 1
assert result[0].id == dataset.id
assert result[0].provider == "external"
def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
# Arrange
fake = Faker()
# Create two accounts/tenants
account1 = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
tenant1 = account1.current_tenant
account2 = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
tenant2 = account2.current_tenant
# Create dataset for tenant1
dataset1 = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant1.id,
name="Tenant 1 Dataset",
provider="dify",
data_source_type="upload_file",
created_by=account1.id,
)
db_session_with_containers.add(dataset1)
# Create dataset for tenant2
dataset2 = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant2.id,
name="Tenant 2 Dataset",
provider="dify",
data_source_type="upload_file",
created_by=account2.id,
)
db_session_with_containers.add(dataset2)
# Add documents to both datasets
for dataset, account in [(dataset1, account1), (dataset2, account2)]:
document = Document(
id=str(uuid.uuid4()),
tenant_id=dataset.tenant_id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Document for {dataset.name}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act - request from tenant1, should only get tenant1's dataset
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant1.id, [dataset1.id, dataset2.id])
# Assert
assert len(result) == 1
assert result[0].id == dataset1.id
assert result[0].tenant_id == tenant1.id
def test_returns_empty_list_when_no_datasets_found(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Don't create any datasets
# Act
dataset_retrieval = DatasetRetrieval()
result = dataset_retrieval._get_available_datasets(tenant.id, [str(uuid.uuid4())])
# Assert
assert result == []
def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create multiple datasets
datasets = []
for i in range(3):
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=f"Dataset {i}",
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
datasets.append(dataset)
# Add document
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=f"Document {i}",
created_by=account.id,
doc_form="text_model",
indexing_status="completed",
enabled=True,
archived=False,
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Act - request only dataset 0 and 2, not dataset 1
dataset_retrieval = DatasetRetrieval()
requested_ids = [datasets[0].id, datasets[2].id]
result = dataset_retrieval._get_available_datasets(tenant.id, requested_ids)
# Assert
assert len(result) == 2
returned_ids = {d.id for d in result}
assert returned_ids == {datasets[0].id, datasets[2].id}
class TestKnowledgeRetrievalIntegration:
def test_knowledge_retrieval_with_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
indexing_technique="high_quality",
)
db_session_with_containers.add(dataset)
document = Document(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch=str(uuid.uuid4()), # Required field
created_from="web",
name=fake.sentence(),
created_by=account.id,
indexing_status="completed",
enabled=True,
archived=False,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Create request
request = KnowledgeRetrievalRequest(
tenant_id=tenant.id,
user_id=account.id,
app_id=str(uuid.uuid4()),
user_from="web",
dataset_ids=[dataset.id],
query="test query",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock rate limit check and retrieval
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
def test_knowledge_retrieval_no_available_datasets(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
# Create dataset but no documents
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
request = KnowledgeRetrievalRequest(
tenant_id=tenant.id,
user_id=account.id,
app_id=str(uuid.uuid4()),
user_from="web",
dataset_ids=[dataset.id],
query="test query",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock rate limit check
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert result == []
def test_knowledge_retrieval_rate_limit_exceeded(
self, db_session_with_containers, mock_external_service_dependencies
):
# Arrange
fake = Faker()
account = AccountService.create_account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
password=fake.password(length=12),
)
TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
tenant = account.current_tenant
dataset = Dataset(
id=str(uuid.uuid4()),
tenant_id=tenant.id,
name=fake.company(),
provider="dify",
data_source_type="upload_file",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
request = KnowledgeRetrievalRequest(
tenant_id=tenant.id,
user_id=account.id,
app_id=str(uuid.uuid4()),
user_from="web",
dataset_ids=[dataset.id],
query="test query",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock rate limit check to raise exception
with patch.object(
dataset_retrieval,
"_check_knowledge_rate_limit",
side_effect=Exception("Rate limit exceeded"),
):
# Act & Assert
with pytest.raises(Exception, match="Rate limit exceeded"):
dataset_retrieval.knowledge_retrieval(request)
@pytest.fixture
def mock_external_service_dependencies():
with (
patch("services.account_service.FeatureService") as mock_account_feature_service,
):
# Setup default mock returns for account service
mock_account_feature_service.get_system_features.return_value.is_allow_register = True
yield {
"account_feature_service": mock_account_feature_service,
}

View File

@@ -0,0 +1,715 @@
from unittest.mock import MagicMock, Mock, patch
from uuid import uuid4
import pytest
from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
from models.dataset import Dataset
# ==================== Helper Functions ====================
def create_mock_dataset(
dataset_id: str | None = None,
tenant_id: str | None = None,
provider: str = "dify",
indexing_technique: str = "high_quality",
available_document_count: int = 10,
) -> Mock:
"""
Create a mock Dataset object for testing.
Args:
dataset_id: Unique identifier for the dataset
tenant_id: Tenant ID for the dataset
provider: Provider type ("dify" or "external")
indexing_technique: Indexing technique ("high_quality" or "economy")
available_document_count: Number of available documents
Returns:
Mock: A properly configured Dataset mock
"""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id or str(uuid4())
dataset.tenant_id = tenant_id or str(uuid4())
dataset.name = "test_dataset"
dataset.provider = provider
dataset.indexing_technique = indexing_technique
dataset.available_document_count = available_document_count
dataset.embedding_model = "text-embedding-ada-002"
dataset.embedding_model_provider = "openai"
dataset.retrieval_model = {
"search_method": "semantic_search",
"reranking_enable": False,
"top_k": 4,
"score_threshold_enabled": False,
}
return dataset
def create_mock_document(
content: str,
doc_id: str,
score: float = 0.8,
provider: str = "dify",
additional_metadata: dict | None = None,
) -> Document:
"""
Create a mock Document object for testing.
Args:
content: The text content of the document
doc_id: Unique identifier for the document chunk
score: Relevance score (0.0 to 1.0)
provider: Document provider ("dify" or "external")
additional_metadata: Optional extra metadata fields
Returns:
Document: A properly structured Document object
"""
metadata = {
"doc_id": doc_id,
"document_id": str(uuid4()),
"dataset_id": str(uuid4()),
"score": score,
}
if additional_metadata:
metadata.update(additional_metadata)
return Document(
page_content=content,
metadata=metadata,
provider=provider,
)
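# Example usage of the two helpers above, a small illustrative sketch rather than
# part of any test: one external dataset mock and one scored chunk, built the same
# way the suites below build them.
_example_dataset = create_mock_dataset(provider="external", available_document_count=0)
_example_chunk = create_mock_document("Python is a programming language", "example-doc-1", score=0.92)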
# ==================== Test _check_knowledge_rate_limit ====================
class TestCheckKnowledgeRateLimit:
"""
Test suite for _check_knowledge_rate_limit method.
The _check_knowledge_rate_limit method validates whether a tenant has
exceeded their knowledge retrieval rate limit. This is important for:
- Preventing abuse of the knowledge retrieval system
- Enforcing subscription plan limits
- Tracking usage for billing purposes
Test Cases:
============
1. Rate limit disabled - no exception raised
2. Rate limit enabled but not exceeded - no exception raised
3. Rate limit enabled and exceeded - RateLimitExceededError raised
4. Redis operations are performed correctly
5. RateLimitLog is created when limit is exceeded
"""
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service):
"""
Test that when rate limit is disabled, no exception is raised.
This test verifies the behavior when the tenant's subscription
does not have rate limiting enabled.
Verifies:
- FeatureService.get_knowledge_rate_limit is called
- No Redis operations are performed
- No exception is raised
- Retrieval proceeds normally
"""
# Arrange
tenant_id = str(uuid4())
dataset_retrieval = DatasetRetrieval()
# Mock rate limit disabled
mock_limit = Mock()
mock_limit.enabled = False
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
# Act & Assert - should not raise any exception
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
# Verify FeatureService was called
mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id)
# Verify no Redis operations were performed
assert not mock_redis.zadd.called
assert not mock_redis.zremrangebyscore.called
assert not mock_redis.zcard.called
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
@patch("core.rag.retrieval.dataset_retrieval.time")
def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory):
"""
Test that when rate limit is enabled but not exceeded, no exception is raised.
This test simulates a tenant making requests within their rate limit.
The Redis sorted set stores timestamps of recent requests, and old
requests (older than 60 seconds) are removed.
Verifies:
- Redis zadd is called to track the request
- Redis zremrangebyscore removes old entries
- Redis zcard returns count within limit
- No exception is raised
"""
# Arrange
tenant_id = str(uuid4())
dataset_retrieval = DatasetRetrieval()
# Mock rate limit enabled with limit of 100 requests per minute
mock_limit = Mock()
mock_limit.enabled = True
mock_limit.limit = 100
mock_limit.subscription_plan = "professional"
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
# Mock time
current_time = 1234567890000  # Current time in milliseconds
mock_time.time.return_value = current_time / 1000  # time.time() returns seconds; the code converts to ms
# Mock Redis operations
# zcard returns 50 (within limit of 100)
mock_redis.zcard.return_value = 50
# Mock session_factory.create_session
mock_session = MagicMock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__exit__.return_value = None
# Act & Assert - should not raise any exception
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
# Verify Redis operations
expected_key = f"rate_limit_{tenant_id}"
mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time})
mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000)
mock_redis.zcard.assert_called_once_with(expected_key)
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
@patch("core.rag.retrieval.dataset_retrieval.FeatureService")
@patch("core.rag.retrieval.dataset_retrieval.redis_client")
@patch("core.rag.retrieval.dataset_retrieval.time")
def test_rate_limit_enabled_exceeded_raises_exception(
self, mock_time, mock_redis, mock_feature_service, mock_session_factory
):
"""
Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised.
This test simulates a tenant exceeding their rate limit. When the count
of recent requests exceeds the limit, an exception should be raised and
a RateLimitLog should be created.
Verifies:
- Redis zcard returns count exceeding limit
- RateLimitExceededError is raised with correct message
- RateLimitLog is created in database
- Session operations are performed correctly
"""
# Arrange
tenant_id = str(uuid4())
dataset_retrieval = DatasetRetrieval()
# Mock rate limit enabled with limit of 100 requests per minute
mock_limit = Mock()
mock_limit.enabled = True
mock_limit.limit = 100
mock_limit.subscription_plan = "professional"
mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit
# Mock time
current_time = 1234567890000
mock_time.time.return_value = current_time / 1000
# Mock Redis operations - return count exceeding limit
mock_redis.zcard.return_value = 150 # Exceeds limit of 100
# Mock session_factory.create_session
mock_session = MagicMock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__exit__.return_value = None
# Act & Assert
with pytest.raises(exc.RateLimitExceededError) as exc_info:
dataset_retrieval._check_knowledge_rate_limit(tenant_id)
# Verify exception message
assert "knowledge base request rate limit" in str(exc_info.value)
# Verify RateLimitLog was created
mock_session.add.assert_called_once()
added_log = mock_session.add.call_args[0][0]
assert added_log.tenant_id == tenant_id
assert added_log.subscription_plan == "professional"
assert added_log.operation == "knowledge"
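# A minimal sketch of the sliding-window check these tests imply, not the
# production implementation: each request pushes a millisecond timestamp into a
# Redis sorted set, entries older than 60 seconds are trimmed, and the remaining
# count is compared against the plan limit. The `redis_client` and `limit`
# parameters stand in for the real dependencies and are illustrative assumptions.
import time


def _sliding_window_limit_exceeded(redis_client, tenant_id: str, limit: int) -> bool:
    key = f"rate_limit_{tenant_id}"
    now_ms = int(time.time() * 1000)
    redis_client.zadd(key, {now_ms: now_ms})  # record this request
    redis_client.zremrangebyscore(key, 0, now_ms - 60000)  # drop entries outside the 60s window
    return redis_client.zcard(key) > limit  # True means the tenant is rate limited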
# ==================== Test _get_available_datasets ====================
class TestGetAvailableDatasets:
"""
Test suite for _get_available_datasets method.
The _get_available_datasets method retrieves datasets that are available
for retrieval. A dataset is considered available if:
- It belongs to the specified tenant
- It's in the list of requested dataset_ids
- It has at least one completed, enabled, non-archived document OR
- It's an external provider dataset
Note: Due to SQLAlchemy subquery complexity, full testing is done in
integration tests. Unit tests here verify basic behavior.
"""
def test_method_exists_and_has_correct_signature(self):
"""
Test that the method exists and is callable.
Verifies:
- Method exists on DatasetRetrieval class
- Method is callable
"""
# Arrange
dataset_retrieval = DatasetRetrieval()
# Assert - method exists
assert hasattr(dataset_retrieval, "_get_available_datasets")
# Assert - method is callable
assert callable(dataset_retrieval._get_available_datasets)
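# The availability rule above, restated as an in-memory sketch (an assumption for
# illustration; the real method expresses this as a SQLAlchemy query over Dataset
# and Document): a dataset counts as available when it belongs to the tenant, was
# requested, and is either external or backed by at least one completed, enabled,
# non-archived document.
def _is_dataset_available(dataset, documents, tenant_id: str, requested_ids: list[str]) -> bool:
    if dataset.tenant_id != tenant_id or dataset.id not in requested_ids:
        return False
    if dataset.provider == "external":
        return True
    return any(
        doc.dataset_id == dataset.id
        and doc.indexing_status == "completed"
        and doc.enabled
        and not doc.archived
        for doc in documents
    )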
# ==================== Test knowledge_retrieval ====================
class TestDatasetRetrievalKnowledgeRetrieval:
"""
Test suite for knowledge_retrieval method.
The knowledge_retrieval method is the main entry point for retrieving
knowledge from datasets. It orchestrates the entire retrieval process:
1. Checks rate limits
2. Gets available datasets
3. Applies metadata filtering if enabled
4. Performs retrieval (single or multiple mode)
5. Formats and returns results
Test Cases:
============
1. Single mode retrieval
2. Multiple mode retrieval
3. Metadata filtering disabled
4. Metadata filtering automatic
5. Metadata filtering manual
6. External documents handling
7. Dify documents handling
8. Empty results handling
9. Rate limit exceeded
10. No available datasets
"""
def test_knowledge_retrieval_single_mode_basic(self):
"""
Test knowledge_retrieval in single retrieval mode - basic check.
Note: Full single mode testing requires complex model mocking and
is better suited for integration tests. This test verifies the
method accepts single mode requests.
Verifies:
- Method can accept single mode request
- Request parameters are correctly structured
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="single",
model_provider="openai",
model_name="gpt-4",
model_mode="chat",
completion_params={"temperature": 0.7},
)
# Assert - request is properly structured
assert request.retrieval_mode == "single"
assert request.model_provider == "openai"
assert request.model_name == "gpt-4"
assert request.model_mode == "chat"
@patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
@patch("core.rag.retrieval.dataset_retrieval.session_factory")
def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor):
"""
Test knowledge_retrieval in multiple retrieval mode.
In multiple mode, retrieval is performed across all datasets and
results are combined and reranked.
Verifies:
- Rate limit is checked
- Available datasets are retrieved
- Multiple retrieval is performed
- Results are combined and reranked
- Results are formatted correctly
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id1 = str(uuid4())
dataset_id2 = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id1, dataset_id2],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
score_threshold=0.7,
reranking_enable=True,
reranking_mode="reranking_model",
reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
)
dataset_retrieval = DatasetRetrieval()
# Mock _check_knowledge_rate_limit
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
# Mock _get_available_datasets
mock_dataset1 = create_mock_dataset(dataset_id=dataset_id1, tenant_id=tenant_id)
mock_dataset2 = create_mock_dataset(dataset_id=dataset_id2, tenant_id=tenant_id)
with patch.object(
dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2]
):
# Mock get_metadata_filter_condition
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
# Mock multiple_retrieve to return documents
doc1 = create_mock_document("Python is great", "doc1", score=0.9)
doc2 = create_mock_document("Python is awesome", "doc2", score=0.8)
with patch.object(
dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2]
) as mock_multiple_retrieve:
# Mock format_retrieval_documents
mock_record = Mock()
mock_record.segment = Mock()
mock_record.segment.dataset_id = dataset_id1
mock_record.segment.document_id = str(uuid4())
mock_record.segment.index_node_hash = "hash123"
mock_record.segment.hit_count = 5
mock_record.segment.word_count = 100
mock_record.segment.position = 1
mock_record.segment.get_sign_content.return_value = "Python is great"
mock_record.segment.answer = None
mock_record.score = 0.9
mock_record.child_chunks = []
mock_record.summary = None
mock_record.files = None
mock_retrieval_service = Mock()
mock_retrieval_service.format_retrieval_documents.return_value = [mock_record]
with patch(
"core.rag.retrieval.dataset_retrieval.RetrievalService",
return_value=mock_retrieval_service,
):
# Mock database queries
mock_session = MagicMock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_session_factory.create_session.return_value.__exit__.return_value = None
mock_dataset_from_db = Mock()
mock_dataset_from_db.id = dataset_id1
mock_dataset_from_db.name = "test_dataset"
mock_document = Mock()
mock_document.id = str(uuid4())
mock_document.name = "test_doc"
mock_document.data_source_type = "upload_file"
mock_document.doc_metadata = {}
mock_session.query.return_value.filter.return_value.all.return_value = [
mock_dataset_from_db
]
mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
[mock_dataset_from_db, mock_document]
)
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
mock_multiple_retrieve.assert_called_once()
def test_knowledge_retrieval_metadata_filtering_disabled(self):
"""
Test knowledge_retrieval with metadata filtering disabled.
When metadata filtering is disabled, get_metadata_filter_condition is
NOT called (the method checks metadata_filtering_mode != "disabled").
Verifies:
- get_metadata_filter_condition is NOT called when mode is "disabled"
- Retrieval proceeds without metadata filters
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
metadata_filtering_mode="disabled",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
# Mock get_metadata_filter_condition - should NOT be called when disabled
with patch.object(
dataset_retrieval,
"get_metadata_filter_condition",
return_value=(None, None),
) as mock_get_metadata:
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
# get_metadata_filter_condition should NOT be called when mode is "disabled"
mock_get_metadata.assert_not_called()
def test_knowledge_retrieval_with_external_documents(self):
"""
Test knowledge_retrieval with external documents.
External documents come from external knowledge bases and should
be formatted differently than Dify documents.
Verifies:
- External documents are handled correctly
- Provider is set to "external"
- Metadata includes external-specific fields
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id, provider="external")
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
# Create external document
external_doc = create_mock_document(
"External knowledge",
"doc1",
score=0.9,
provider="external",
additional_metadata={
"dataset_id": dataset_id,
"dataset_name": "external_kb",
"document_id": "ext_doc1",
"title": "External Document",
},
)
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert isinstance(result, list)
if result:
assert result[0].metadata.data_source_type == "external"
def test_knowledge_retrieval_empty_results(self):
"""
Test knowledge_retrieval when no documents are found.
Verifies:
- Empty list is returned
- No errors are raised
- All dependencies are still called
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
# Mock multiple_retrieve to return empty list
with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert result == []
def test_knowledge_retrieval_rate_limit_exceeded(self):
"""
Test knowledge_retrieval when rate limit is exceeded.
Verifies:
- RateLimitExceededError is raised
- No further processing occurs
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock _check_knowledge_rate_limit to raise exception
with patch.object(
dataset_retrieval,
"_check_knowledge_rate_limit",
side_effect=exc.RateLimitExceededError("Rate limit exceeded"),
):
# Act & Assert
with pytest.raises(exc.RateLimitExceededError):
dataset_retrieval.knowledge_retrieval(request)
def test_knowledge_retrieval_no_available_datasets(self):
"""
Test knowledge_retrieval when no datasets are available.
Verifies:
- Empty list is returned
- No retrieval is attempted
"""
# Arrange
tenant_id = str(uuid4())
user_id = str(uuid4())
app_id = str(uuid4())
dataset_id = str(uuid4())
request = KnowledgeRetrievalRequest(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
user_from="web",
dataset_ids=[dataset_id],
query="What is Python?",
retrieval_mode="multiple",
top_k=5,
)
dataset_retrieval = DatasetRetrieval()
# Mock dependencies
with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
# Mock _get_available_datasets to return empty list
with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]):
# Act
result = dataset_retrieval.knowledge_retrieval(request)
# Assert
assert result == []
def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self):
"""
Test that knowledge_retrieval processes multiple documents with different scores.
Note: Full sorting and position testing requires complex SQLAlchemy mocking
which is better suited for integration tests. This test verifies documents
with different scores can be created and have their metadata.
Verifies:
- Documents can be created with different scores
- Score metadata is properly set
"""
# Create documents with different scores
doc1 = create_mock_document("Low score", "doc1", score=0.6)
doc2 = create_mock_document("High score", "doc2", score=0.95)
doc3 = create_mock_document("Medium score", "doc3", score=0.8)
# Assert - each document has the correct score
assert doc1.metadata["score"] == 0.6
assert doc2.metadata["score"] == 0.95
assert doc3.metadata["score"] == 0.8
# Assert - sorting the local list by score works as expected (this is not a retrieval result)
unsorted = [doc1, doc2, doc3]
sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True)
assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6]
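# A simplified sketch of the control flow exercised by this suite, written as an
# assumption about the orchestration rather than the production code. Only the
# helpers the tests actually call are invoked; the remaining steps are described
# in comments because their exact signatures are not shown here.
def _knowledge_retrieval_flow(retrieval: DatasetRetrieval, request: KnowledgeRetrievalRequest) -> list:
    # 1. Rate limit: may raise RateLimitExceededError before anything else runs.
    retrieval._check_knowledge_rate_limit(request.tenant_id)
    # 2. Availability: only datasets passing the availability rule are retrieved from.
    datasets = retrieval._get_available_datasets(request.tenant_id, request.dataset_ids)
    if not datasets:
        return []
    # 3. Metadata filtering: skipped entirely when metadata_filtering_mode is "disabled";
    #    otherwise get_metadata_filter_condition is consulted before retrieval.
    # 4. Retrieval: multiple mode combines and reranks across datasets via multiple_retrieve;
    #    single mode routes the query to one dataset with an LLM.
    # 5. Retrieved chunks are then formatted into Source records (elided in this sketch).
    return []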

View File

@@ -0,0 +1,595 @@
import time
import uuid
from unittest.mock import Mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import StringSegment
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.nodes.knowledge_retrieval.entities import (
KnowledgeRetrievalNodeData,
MultipleRetrievalConfig,
RerankingModelConfig,
SingleRetrievalConfig,
)
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@pytest.fixture
def mock_graph_init_params():
"""Create mock GraphInitParams."""
return GraphInitParams(
tenant_id=str(uuid.uuid4()),
app_id=str(uuid.uuid4()),
workflow_id=str(uuid.uuid4()),
graph_config={},
user_id=str(uuid.uuid4()),
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
@pytest.fixture
def mock_graph_runtime_state():
"""Create mock GraphRuntimeState."""
variable_pool = VariablePool(
system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]),
user_inputs={},
environment_variables=[],
conversation_variables=[],
)
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
@pytest.fixture
def mock_rag_retrieval():
"""Create mock RAGRetrievalProtocol."""
mock_retrieval = Mock(spec=RAGRetrievalProtocol)
mock_retrieval.knowledge_retrieval.return_value = []
mock_retrieval.llm_usage = LLMUsage.empty_usage()
return mock_retrieval
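# An explicit, dependency-free stand-in equivalent to the Mock above: a minimal
# sketch assuming the protocol surface the node relies on is knowledge_retrieval()
# plus an llm_usage attribute, which is exactly what these tests configure.
class _FakeRAGRetrieval:
    def __init__(self) -> None:
        self.llm_usage = LLMUsage.empty_usage()

    def knowledge_retrieval(self, request):
        # Return no sources; a test could subclass or monkeypatch this to feed results.
        return []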
@pytest.fixture
def sample_node_data():
"""Create sample KnowledgeRetrievalNodeData."""
return KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="multiple",
multiple_retrieval_config=MultipleRetrievalConfig(
top_k=5,
score_threshold=0.7,
reranking_mode="reranking_model",
reranking_enable=True,
reranking_model=RerankingModelConfig(
provider="cohere",
model="rerank-v2",
),
),
)
class TestKnowledgeRetrievalNode:
"""
Test suite for KnowledgeRetrievalNode.
"""
def test_node_initialization(self, mock_graph_init_params, mock_graph_runtime_state, mock_rag_retrieval):
"""Test KnowledgeRetrievalNode initialization."""
# Arrange
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": {
"title": "Knowledge Retrieval",
"type": "knowledge-retrieval",
"dataset_ids": [str(uuid.uuid4())],
"retrieval_mode": "multiple",
},
}
# Act
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Assert
assert node.id == node_id
assert node._rag_retrieval == mock_rag_retrieval
assert node._llm_file_saver is not None
def test_run_with_no_query_or_attachment(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run returns success when no query or attachment is provided."""
# Arrange
sample_node_data.query_variable_selector = None
sample_node_data.query_attachment_selector = None
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert result.outputs == {}
assert mock_rag_retrieval.knowledge_retrieval.call_count == 0
def test_run_with_query_variable_single_mode(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""Test _run with query variable in single mode."""
# Arrange
from core.workflow.nodes.llm.entities import ModelConfig
query = "What is Python?"
query_selector = ["start", "query"]
# Add query to variable pool
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="single",
query_variable_selector=query_selector,
single_retrieval_config=SingleRetrievalConfig(
model=ModelConfig(
provider="openai",
name="gpt-4",
mode="chat",
completion_params={"temperature": 0.7},
)
),
)
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": node_data.model_dump(),
}
# Mock retrieval response
mock_source = Mock(spec=Source)
mock_source.model_dump.return_value = {"content": "Python is a programming language"}
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "result" in result.outputs
assert mock_rag_retrieval.knowledge_retrieval.called
def test_run_with_query_variable_multiple_mode(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run with query variable in multiple mode."""
# Arrange
query = "What is Python?"
query_selector = ["start", "query"]
# Add query to variable pool
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
# Mock retrieval response
mock_source = Mock(spec=Source)
mock_source.model_dump.return_value = {"content": "Python is a programming language"}
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
assert "result" in result.outputs
assert mock_rag_retrieval.knowledge_retrieval.called
def test_run_with_invalid_query_variable_type(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run fails when query variable is not StringSegment."""
# Arrange
query_selector = ["start", "query"]
# Add non-string variable to variable pool
mock_graph_runtime_state.variable_pool.add(query_selector, [1, 2, 3])
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Query variable is not string type" in result.error
def test_run_with_invalid_attachment_variable_type(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run fails when attachment variable is not FileSegment or ArrayFileSegment."""
# Arrange
attachment_selector = ["start", "attachments"]
# Add non-file variable to variable pool
mock_graph_runtime_state.variable_pool.add(attachment_selector, "not a file")
sample_node_data.query_attachment_selector = attachment_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Attachments variable is not array file or file type" in result.error
def test_run_with_rate_limit_exceeded(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run handles RateLimitExceededError properly."""
# Arrange
query = "What is Python?"
query_selector = ["start", "query"]
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
# Mock retrieval to raise RateLimitExceededError
mock_rag_retrieval.knowledge_retrieval.side_effect = RateLimitExceededError(
"knowledge base request rate limit exceeded"
)
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "rate limit" in result.error.lower()
def test_run_with_generic_exception(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _run handles generic exceptions properly."""
# Arrange
query = "What is Python?"
query_selector = ["start", "query"]
mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
sample_node_data.query_variable_selector = query_selector
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
# Mock retrieval to raise generic exception
mock_rag_retrieval.knowledge_retrieval.side_effect = Exception("Unexpected error")
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
result = node._run()
# Assert
assert result.status == WorkflowNodeExecutionStatus.FAILED
assert "Unexpected error" in result.error
def test_extract_variable_selector_to_variable_mapping(self):
"""Test _extract_variable_selector_to_variable_mapping class method."""
# Arrange
node_id = "knowledge_node_1"
node_data = {
"type": "knowledge-retrieval",
"title": "Knowledge Retrieval",
"dataset_ids": [str(uuid.uuid4())],
"retrieval_mode": "multiple",
"query_variable_selector": ["start", "query"],
"query_attachment_selector": ["start", "attachments"],
}
graph_config = {}
# Act
mapping = KnowledgeRetrievalNode._extract_variable_selector_to_variable_mapping(
graph_config=graph_config,
node_id=node_id,
node_data=node_data,
)
# Assert
assert mapping[f"{node_id}.query"] == ["start", "query"]
assert mapping[f"{node_id}.queryAttachment"] == ["start", "attachments"]
class TestFetchDatasetRetriever:
"""
Test suite for _fetch_dataset_retriever method.
"""
def test_fetch_dataset_retriever_single_mode(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""Test _fetch_dataset_retriever in single mode."""
# Arrange
from core.workflow.nodes.llm.entities import ModelConfig
query = "What is Python?"
variables = {"query": query}
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="single",
single_retrieval_config=SingleRetrievalConfig(
model=ModelConfig(
provider="openai",
name="gpt-4",
mode="chat",
completion_params={"temperature": 0.7},
)
),
)
# Mock retrieval response
mock_source = Mock(spec=Source)
mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node_id = str(uuid.uuid4())
config = {"id": node_id, "data": node_data.model_dump()}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)
# Assert
assert len(results) == 1
assert isinstance(usage, LLMUsage)
assert mock_rag_retrieval.knowledge_retrieval.called
def test_fetch_dataset_retriever_multiple_mode_with_reranking(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
sample_node_data,
):
"""Test _fetch_dataset_retriever in multiple mode with reranking."""
# Arrange
query = "What is Python?"
variables = {"query": query}
# Mock retrieval response
mock_rag_retrieval.knowledge_retrieval.return_value = []
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": sample_node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
results, usage = node._fetch_dataset_retriever(node_data=sample_node_data, variables=variables)
# Assert
assert isinstance(results, list)
assert isinstance(usage, LLMUsage)
assert mock_rag_retrieval.knowledge_retrieval.called
# Verify reranking parameters via request object
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
request = call_args[1]["request"]
assert request.reranking_enable is True
assert request.reranking_mode == "reranking_model"
def test_fetch_dataset_retriever_multiple_mode_without_reranking(
self,
mock_graph_init_params,
mock_graph_runtime_state,
mock_rag_retrieval,
):
"""Test _fetch_dataset_retriever in multiple mode without reranking."""
# Arrange
query = "What is Python?"
variables = {"query": query}
node_data = KnowledgeRetrievalNodeData(
title="Knowledge Retrieval",
type="knowledge-retrieval",
dataset_ids=[str(uuid.uuid4())],
retrieval_mode="multiple",
multiple_retrieval_config=MultipleRetrievalConfig(
top_k=5,
score_threshold=0.7,
reranking_enable=False,
reranking_mode="reranking_model",
),
)
# Mock retrieval response
mock_rag_retrieval.knowledge_retrieval.return_value = []
mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()
node_id = str(uuid.uuid4())
config = {
"id": node_id,
"data": node_data.model_dump(),
}
node = KnowledgeRetrievalNode(
id=node_id,
config=config,
graph_init_params=mock_graph_init_params,
graph_runtime_state=mock_graph_runtime_state,
rag_retrieval=mock_rag_retrieval,
)
# Act
results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)
# Assert
assert isinstance(results, list)
assert mock_rag_retrieval.knowledge_retrieval.called
# Verify reranking is disabled
call_args = mock_rag_retrieval.knowledge_retrieval.call_args
request = call_args[1]["request"]
assert request.reranking_enable is False
def test_version_method(self):
"""Test version class method."""
# Act
version = KnowledgeRetrievalNode.version()
# Assert
assert version == "1"