refactor: decouple database operations from knowledge retrieval nodes (#31981)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
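In short, the knowledge-retrieval node no longer talks to the database or Redis directly: those queries move into DatasetRetrieval, and the node receives that dependency through a new RAGRetrievalProtocol injected by DifyNodeFactory. A minimal, self-contained sketch of the injection pattern (the class names here are illustrative stand-ins, not the diff's actual signatures):

from typing import Protocol


class Retrieval(Protocol):  # stand-in for RAGRetrievalProtocol
    def knowledge_retrieval(self, request: dict) -> list[dict]: ...


class Node:  # stand-in for KnowledgeRetrievalNode
    def __init__(self, rag_retrieval: Retrieval) -> None:
        # The node only knows the protocol; DB/Redis live behind it.
        self._rag_retrieval = rag_retrieval

    def run(self, query: str) -> list[dict]:
        return self._rag_retrieval.knowledge_retrieval({"query": query})


class InMemoryRetrieval:
    # No inheritance needed: matching the method shape satisfies the Protocol.
    def knowledge_retrieval(self, request: dict) -> list[dict]:
        return [{"content": f"hit for {request['query']}"}]


print(Node(InMemoryRetrieval()).run("hello"))  # [{'content': 'hit for hello'}]

The full diff follows.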
@@ -50,14 +50,12 @@ ignore_imports =
    core.workflow.nodes.agent.agent_node -> extensions.ext_database
    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
    core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
    core.workflow.nodes.llm.file_saver -> extensions.ext_database
    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
    core.workflow.nodes.llm.node -> extensions.ext_database
    core.workflow.nodes.tool.tool_node -> extensions.ext_database
    core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
    core.workflow.graph_engine.manager -> extensions.ext_redis
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis

[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@@ -122,11 +120,6 @@ ignore_imports =
    core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
    core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
    core.workflow.nodes.llm.llm_utils -> configs
    core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
    core.workflow.nodes.llm.llm_utils -> core.file.models
@@ -146,7 +139,6 @@ ignore_imports =
    core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
@@ -162,9 +154,6 @@ ignore_imports =
    core.workflow.workflow_entry -> core.app.workflow.node_factory
    core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
    core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
    core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
    core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
    core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -213,7 +202,6 @@ ignore_imports =
    core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
    core.workflow.nodes.llm.node -> core.model_manager
    core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
    core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
    core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
    core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
@@ -229,7 +217,6 @@ ignore_imports =
    core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
    core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
    core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
    core.workflow.nodes.llm.node -> models.dataset
    core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
    core.workflow.nodes.llm.file_saver -> core.tools.signature
@@ -287,8 +274,6 @@ ignore_imports =
    core.workflow.nodes.agent.agent_node -> extensions.ext_database
    core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
    core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
    core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
    core.workflow.nodes.llm.file_saver -> extensions.ext_database
    core.workflow.nodes.llm.llm_utils -> extensions.ext_database
    core.workflow.nodes.llm.node -> extensions.ext_database

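The hunks above delete the knowledge_retrieval_node -> extensions.ext_database / -> extensions.ext_redis exemptions: once the node stops importing those modules, the contract can enforce the boundary again. For orientation, an import-linter contract of this kind has roughly the following shape (illustrative values, not the project's actual configuration):

[importlinter:contract:example-forbidden]
name = Workflow nodes must not import infrastructure
type = forbidden
source_modules =
    core.workflow
forbidden_modules =
    extensions.ext_database
    extensions.ext_redis
ignore_imports =
    core.workflow.nodes.tool.tool_node -> extensions.ext_database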
@@ -8,6 +8,7 @@ from core.file.file_manager import file_manager
from core.helper.code_executor.code_executor import CodeExecutor
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.ssrf_proxy import ssrf_proxy
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.tools.tool_file_manager import ToolFileManager
from core.workflow.entities.graph_config import NodeConfigDict
from core.workflow.enums import NodeType
@@ -16,6 +17,7 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.code.code_node import CodeNode
from core.workflow.nodes.code.limits import CodeNodeLimits
from core.workflow.nodes.http_request.node import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
from core.workflow.nodes.template_transform.template_renderer import (
@@ -75,6 +77,7 @@ class DifyNodeFactory(NodeFactory):
        self._http_request_http_client = http_request_http_client or ssrf_proxy
        self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
        self._http_request_file_manager = http_request_file_manager or file_manager
        self._rag_retrieval = DatasetRetrieval()

    @override
    def create_node(self, node_config: NodeConfigDict) -> Node:
@@ -140,6 +143,15 @@ class DifyNodeFactory(NodeFactory):
                file_manager=self._http_request_file_manager,
            )

        if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
            return KnowledgeRetrievalNode(
                id=node_id,
                config=node_config,
                graph_init_params=self.graph_init_params,
                graph_runtime_state=self.graph_runtime_state,
                rag_retrieval=self._rag_retrieval,
            )

        return node_class(
            id=node_id,
            config=node_config,

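With the retrieval backend passed through the constructor, tests can hand KnowledgeRetrievalNode any object that satisfies RAGRetrievalProtocol (defined in the new file at the end of this diff) instead of the DB-backed DatasetRetrieval. A sketch of such a test double, assuming the signatures shown in this diff:

from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, Source


class FakeRetrieval:
    """Protocol-compatible stand-in: no database, no Redis."""

    @property
    def llm_usage(self) -> LLMUsage:
        return LLMUsage.empty_usage()

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
        return []  # pretend nothing matched


# node = KnowledgeRetrievalNode(..., rag_retrieval=FakeRetrieval())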
@@ -1,13 +1,15 @@
import json
import logging
import math
import re
import threading
import time
from collections import Counter, defaultdict
from collections.abc import Generator, Mapping
from typing import Any, Union, cast

from flask import Flask, current_app
from sqlalchemy import and_, literal, or_, select
from sqlalchemy import and_, func, literal, or_, select
from sqlalchemy.orm import Session

from core.app.app_config.entities import (
@@ -18,6 +20,7 @@ from core.app.app_config.entities import (
)
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.db.session_factory import session_factory
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.file import File, FileTransferMethod, FileType
@@ -58,12 +61,30 @@ from core.rag.retrieval.template_prompts import (
)
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import (
    KnowledgeRetrievalRequest,
    Source,
    SourceChildChunk,
    SourceMetadata,
)
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models import UploadFile
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
from models.dataset import (
    ChildChunk,
    Dataset,
    DatasetMetadata,
    DatasetQuery,
    DocumentSegment,
    RateLimitLog,
    SegmentAttachmentBinding,
)
from models.dataset import Document as DatasetDocument
from models.dataset import Document as DocumentModel
from services.external_knowledge_service import ExternalDatasetService
from services.feature_service import FeatureService

default_retrieval_model: dict[str, Any] = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
@@ -73,6 +94,8 @@ default_retrieval_model: dict[str, Any] = {
    "score_threshold_enabled": False,
}

logger = logging.getLogger(__name__)

class DatasetRetrieval:
    def __init__(self, application_generate_entity=None):
@@ -91,6 +114,233 @@ class DatasetRetrieval:
        else:
            self._llm_usage = self._llm_usage.plus(usage)

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
        self._check_knowledge_rate_limit(request.tenant_id)
        available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
        available_datasets_ids = [i.id for i in available_datasets]
        if not available_datasets_ids:
            return []

        if not request.query:
            return []

        metadata_filter_document_ids, metadata_condition = None, None

        if request.metadata_filtering_mode != "disabled":
            # Convert workflow layer types to app_config layer types
            if not request.metadata_model_config:
                raise ValueError("metadata_model_config is required for this method")

            app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())

            app_metadata_filtering_conditions = None
            if request.metadata_filtering_conditions is not None:
                app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
                    request.metadata_filtering_conditions.model_dump()
                )

            query = request.query if request.query is not None else ""

            metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
                dataset_ids=available_datasets_ids,
                query=query,
                tenant_id=request.tenant_id,
                user_id=request.user_id,
                metadata_filtering_mode=request.metadata_filtering_mode,
                metadata_model_config=app_metadata_model_config,
                metadata_filtering_conditions=app_metadata_filtering_conditions,
                inputs={},
            )

        if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
            planning_strategy = PlanningStrategy.REACT_ROUTER
            # Ensure required fields are not None for single retrieval mode
            if request.model_provider is None or request.model_name is None or request.query is None:
                raise ValueError("model_provider, model_name, and query are required for single retrieval mode")

            model_manager = ModelManager()
            model_instance = model_manager.get_model_instance(
                tenant_id=request.tenant_id,
                model_type=ModelType.LLM,
                provider=request.model_provider,
                model=request.model_name,
            )

            provider_model_bundle = model_instance.provider_model_bundle
            model_type_instance = model_instance.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)

            model_credentials = model_instance.credentials

            # check model
            provider_model = provider_model_bundle.configuration.get_provider_model(
                model=request.model_name, model_type=ModelType.LLM
            )

            if provider_model is None:
                raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")

            if provider_model.status == ModelStatus.NO_CONFIGURE:
                raise exc.ModelCredentialsNotInitializedError(
                    f"Model {request.model_name} credentials is not initialized."
                )
            elif provider_model.status == ModelStatus.NO_PERMISSION:
                raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
            elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
                raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")

            stop = []
            completion_params = (request.completion_params or {}).copy()
            if "stop" in completion_params:
                stop = completion_params["stop"]
                del completion_params["stop"]

            model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)

            if not model_schema:
                raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")

            model_config = ModelConfigWithCredentialsEntity(
                provider=request.model_provider,
                model=request.model_name,
                model_schema=model_schema,
                mode=request.model_mode or "chat",
                provider_model_bundle=provider_model_bundle,
                credentials=model_credentials,
                parameters=completion_params,
                stop=stop,
            )
            all_documents = self.single_retrieve(
                request.app_id,
                request.tenant_id,
                request.user_id,
                request.user_from,
                request.query,
                available_datasets,
                model_instance,
                model_config,
                planning_strategy,
                None,  # message_id
                metadata_filter_document_ids,
                metadata_condition,
            )
        else:
            all_documents = self.multiple_retrieve(
                app_id=request.app_id,
                tenant_id=request.tenant_id,
                user_id=request.user_id,
                user_from=request.user_from,
                available_datasets=available_datasets,
                query=request.query,
                top_k=request.top_k,
                score_threshold=request.score_threshold,
                reranking_mode=request.reranking_mode,
                reranking_model=request.reranking_model,
                weights=request.weights,
                reranking_enable=request.reranking_enable,
                metadata_filter_document_ids=metadata_filter_document_ids,
                metadata_condition=metadata_condition,
                attachment_ids=request.attachment_ids,
            )

        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        retrieval_resource_list = []
        # deal with external documents
        for item in external_documents:
            source = Source(
                metadata=SourceMetadata(
                    source="knowledge",
                    dataset_id=item.metadata.get("dataset_id"),
                    dataset_name=item.metadata.get("dataset_name"),
                    document_id=item.metadata.get("document_id"),
                    document_name=item.metadata.get("title"),
                    data_source_type="external",
                    retriever_from="workflow",
                    score=item.metadata.get("score"),
                    doc_metadata=item.metadata,
                ),
                title=item.metadata.get("title"),
                content=item.page_content,
            )
            retrieval_resource_list.append(source)
        # deal with dify documents
        if dify_documents:
            records = RetrievalService.format_retrieval_documents(dify_documents)
            dataset_ids = [i.segment.dataset_id for i in records]
            document_ids = [i.segment.document_id for i in records]

            with session_factory.create_session() as session:
                datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
                documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()

            dataset_map = {i.id: i for i in datasets}
            document_map = {i.id: i for i in documents}

            if records:
                for record in records:
                    segment = record.segment
                    dataset = dataset_map.get(segment.dataset_id)
                    document = document_map.get(segment.document_id)

                    if dataset and document:
                        source = Source(
                            metadata=SourceMetadata(
                                source="knowledge",
                                dataset_id=dataset.id,
                                dataset_name=dataset.name,
                                document_id=document.id,
                                document_name=document.name,
                                data_source_type=document.data_source_type,
                                segment_id=segment.id,
                                retriever_from="workflow",
                                score=record.score or 0.0,
                                segment_hit_count=segment.hit_count,
                                segment_word_count=segment.word_count,
                                segment_position=segment.position,
                                segment_index_node_hash=segment.index_node_hash,
                                doc_metadata=document.doc_metadata,
                                child_chunks=[
                                    SourceChildChunk(
                                        id=str(getattr(chunk, "id", "")),
                                        content=str(getattr(chunk, "content", "")),
                                        position=int(getattr(chunk, "position", 0)),
                                        score=float(getattr(chunk, "score", 0.0)),
                                    )
                                    for chunk in (record.child_chunks or [])
                                ],
                                position=None,
                            ),
                            title=document.name,
                            files=list(record.files) if record.files else None,
                            content=segment.get_sign_content(),
                        )
                        if segment.answer:
                            source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"

                        if record.summary:
                            source.summary = record.summary

                        retrieval_resource_list.append(source)

        if retrieval_resource_list:

            def _score(item: Source) -> float:
                meta = item.metadata
                score = meta.score
                if isinstance(score, (int, float)):
                    return float(score)
                return 0.0

            retrieval_resource_list = sorted(
                retrieval_resource_list,
                key=_score,  # type: ignore[arg-type, return-value]
                reverse=True,
            )
            for position, item in enumerate(retrieval_resource_list, start=1):
                item.metadata.position = position  # type: ignore[index]
        return retrieval_resource_list

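The new knowledge_retrieval entry point takes one request object and returns typed Source models, already scored and position-numbered. A usage sketch with illustrative IDs (multiple-dataset mode, reranking off):

retrieval = DatasetRetrieval()
sources = retrieval.knowledge_retrieval(
    request=KnowledgeRetrievalRequest(
        tenant_id="tenant-1",  # illustrative IDs throughout
        user_id="user-1",
        app_id="app-1",
        user_from="account",
        dataset_ids=["dataset-1"],
        query="refund policy",
        retrieval_mode="multiple",
        top_k=4,
        score_threshold=0.0,
        reranking_enable=False,
    )
)
for s in sources:
    print(s.metadata.position, round(s.metadata.score, 3), s.title)
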
    def retrieve(
        self,
        app_id: str,
@@ -150,14 +400,7 @@ class DatasetRetrieval:
            if features:
                if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                    planning_strategy = PlanningStrategy.ROUTER
        available_datasets = []

        dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
        datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all()  # type: ignore
        for dataset in datasets:
            if dataset.available_document_count == 0 and dataset.provider != "external":
                continue
            available_datasets.append(dataset)
        available_datasets = self._get_available_datasets(tenant_id, dataset_ids)

        if inputs:
            inputs = {key: str(value) for key, value in inputs.items()}
@@ -1161,7 +1404,6 @@ class DatasetRetrieval:
            query=query or "",
        )

        result_text = ""
        try:
            # handle invoke result
            invoke_result = cast(
@@ -1192,7 +1434,8 @@ class DatasetRetrieval:
                        "condition": item.get("comparison_operator"),
                    }
                )
        except Exception:
        except Exception as e:
            logger.warning(e, exc_info=True)
            return None
        return automatic_metadata_filters

@@ -1406,7 +1649,12 @@ class DatasetRetrieval:
        usage = None
        for result in invoke_result:
            text = result.delta.message.content
            full_text += text
            if isinstance(text, str):
                full_text += text
            elif isinstance(text, list):
                for i in text:
                    if i.data:
                        full_text += i.data

            if not model:
                model = result.model
@@ -1524,3 +1772,53 @@ class DatasetRetrieval:
                cancel_event.set()
                if thread_exceptions is not None:
                    thread_exceptions.append(e)

    def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
        with session_factory.create_session() as session:
            subquery = (
                session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
                .where(
                    DocumentModel.indexing_status == "completed",
                    DocumentModel.enabled == True,
                    DocumentModel.archived == False,
                    DocumentModel.dataset_id.in_(dataset_ids),
                )
                .group_by(DocumentModel.dataset_id)
                .having(func.count(DocumentModel.id) > 0)
                .subquery()
            )

            results = (
                session.query(Dataset)
                .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
                .where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
                .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
                .all()
            )

        available_datasets = []
        for dataset in results:
            if not dataset:
                continue
            available_datasets.append(dataset)
        return available_datasets

    def _check_knowledge_rate_limit(self, tenant_id: str):
        knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
        if knowledge_rate_limit.enabled:
            current_time = int(time.time() * 1000)
            key = f"rate_limit_{tenant_id}"
            redis_client.zadd(key, {current_time: current_time})
            redis_client.zremrangebyscore(key, 0, current_time - 60000)
            request_count = redis_client.zcard(key)
            if request_count > knowledge_rate_limit.limit:
                with session_factory.create_session() as session:
                    rate_limit_log = RateLimitLog(
                        tenant_id=tenant_id,
                        subscription_plan=knowledge_rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                    session.add(rate_limit_log)
                raise exc.RateLimitExceededError(
                    "you have reached the knowledge base request rate limit of your subscription."
                )

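_check_knowledge_rate_limit implements a sliding one-minute window on a Redis sorted set: each request's millisecond timestamp is added as both member and score, entries older than 60 s are pruned, and the remaining cardinality is the request count (note that two requests landing in the same millisecond collapse into a single member). The same pattern in isolation, assuming a plain redis-py client:

import time

import redis  # assumption: redis-py, the client behind redis_client above

r = redis.Redis()
LIMIT = 100          # requests allowed per window (illustrative)
WINDOW_MS = 60_000   # one minute, matching the code above


def allow_request(tenant_id: str) -> bool:
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    r.zadd(key, {now_ms: now_ms})                   # record this request
    r.zremrangebyscore(key, 0, now_ms - WINDOW_MS)  # drop entries older than the window
    return r.zcard(key) <= LIMIT                    # count what is left
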
@@ -20,3 +20,7 @@ class ModelQuotaExceededError(KnowledgeRetrievalNodeError):

class InvalidModelTypeError(KnowledgeRetrievalNodeError):
    """Raised when the model is not a Large Language Model."""


class RateLimitExceededError(KnowledgeRetrievalNodeError):
    """Raised when the rate limit is exceeded."""

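Raising a dedicated RateLimitExceededError, instead of returning a failed result from deep inside the node, lets callers branch on rate limiting specifically; the node's _run below catches it and converts it into a failed NodeRunResult. In caller code the shape is simply (continuing the earlier sketch):

try:
    sources = retrieval.knowledge_retrieval(request=req)
except RateLimitExceededError:
    sources = []  # degrade gracefully, or surface the failure to the user
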
@@ -1,29 +1,10 @@
import json
import logging
import re
import time
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast

from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import sessionmaker
from typing import TYPE_CHECKING, Any, Literal

from core.app.app_config.entities import DatasetRetrieveConfigEntity
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.agent_entities import PlanningStrategy
from core.entities.model_entities import ModelStatus
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.simple_prompt_transform import ModelMode
from core.rag.datasource.retrieval_service import RetrievalService
from core.rag.entities.metadata_entities import Condition, MetadataCondition
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.variables import (
    ArrayFileSegment,
    FileSegment,
@@ -36,35 +17,16 @@ from core.workflow.enums import (
    WorkflowNodeExecutionMetadataKey,
    WorkflowNodeExecutionStatus,
)
from core.workflow.node_events import ModelInvokeCompletedEvent, NodeRunResult
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base import LLMUsageTrackingMixin
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.knowledge_retrieval.template_prompts import (
    METADATA_FILTER_ASSISTANT_PROMPT_1,
    METADATA_FILTER_ASSISTANT_PROMPT_2,
    METADATA_FILTER_COMPLETION_PROMPT,
    METADATA_FILTER_SYSTEM_PROMPT,
    METADATA_FILTER_USER_PROMPT_1,
    METADATA_FILTER_USER_PROMPT_2,
    METADATA_FILTER_USER_PROMPT_3,
)
from core.workflow.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, ModelConfig
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.json_in_md_parser import parse_and_check_json_markdown
from models.dataset import Dataset, DatasetMetadata, Document, RateLimitLog
from services.feature_service import FeatureService
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest, RAGRetrievalProtocol, Source

from .entities import KnowledgeRetrievalNodeData
from .exc import (
    InvalidModelTypeError,
    KnowledgeRetrievalNodeError,
    ModelCredentialsNotInitializedError,
    ModelNotExistError,
    ModelNotSupportedError,
    ModelQuotaExceededError,
    RateLimitExceededError,
)

if TYPE_CHECKING:
@@ -73,14 +35,6 @@ if TYPE_CHECKING:

logger = logging.getLogger(__name__)

default_retrieval_model = {
    "search_method": RetrievalMethod.SEMANTIC_SEARCH,
    "reranking_enable": False,
    "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
    "top_k": 4,
    "score_threshold_enabled": False,
}

class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeData]):
    node_type = NodeType.KNOWLEDGE_RETRIEVAL
@@ -97,6 +51,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
        config: Mapping[str, Any],
        graph_init_params: "GraphInitParams",
        graph_runtime_state: "GraphRuntimeState",
        rag_retrieval: RAGRetrievalProtocol,
        *,
        llm_file_saver: LLMFileSaver | None = None,
    ):
@@ -108,6 +63,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
        )
        # LLM file outputs, used for MultiModal outputs.
        self._file_outputs = []
        self._rag_retrieval = rag_retrieval

        if llm_file_saver is None:
            llm_file_saver = FileSaverImpl(
@@ -121,6 +77,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
        return "1"

    def _run(self) -> NodeRunResult:
        usage = LLMUsage.empty_usage()
        if not self._node_data.query_variable_selector and not self._node_data.query_attachment_selector:
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -128,7 +85,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                process_data={},
                outputs={},
                metadata={},
                llm_usage=LLMUsage.empty_usage(),
                llm_usage=usage,
            )
        variables: dict[str, Any] = {}
        # extract variables
@@ -156,36 +113,9 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
            else:
                variables["attachments"] = [variable.value]

        # TODO(-LAN-): Move this check outside.
        # check rate limit
        knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(self.tenant_id)
        if knowledge_rate_limit.enabled:
            current_time = int(time.time() * 1000)
            key = f"rate_limit_{self.tenant_id}"
            redis_client.zadd(key, {current_time: current_time})
            redis_client.zremrangebyscore(key, 0, current_time - 60000)
            request_count = redis_client.zcard(key)
            if request_count > knowledge_rate_limit.limit:
                with sessionmaker(db.engine).begin() as session:
                    # add ratelimit record
                    rate_limit_log = RateLimitLog(
                        tenant_id=self.tenant_id,
                        subscription_plan=knowledge_rate_limit.subscription_plan,
                        operation="knowledge",
                    )
                    session.add(rate_limit_log)
                return NodeRunResult(
                    status=WorkflowNodeExecutionStatus.FAILED,
                    inputs=variables,
                    error="Sorry, you have reached the knowledge base request rate limit of your subscription.",
                    error_type="RateLimitExceeded",
                )

        # retrieve knowledge
        usage = LLMUsage.empty_usage()
        try:
            results, usage = self._fetch_dataset_retriever(node_data=self._node_data, variables=variables)
            outputs = {"result": ArrayObjectSegment(value=results)}
            outputs = {"result": ArrayObjectSegment(value=[item.model_dump() for item in results])}
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
                inputs=variables,
@@ -198,9 +128,17 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                },
                llm_usage=usage,
            )

        except RateLimitExceededError as e:
            logger.warning(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
                error=str(e),
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        except KnowledgeRetrievalNodeError as e:
            logger.warning("Error when running knowledge retrieval node")
            logger.warning("Error when running knowledge retrieval node", exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
@@ -210,6 +148,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
            )
        # Temporary handle all exceptions from DatasetRetrieval class here.
        except Exception as e:
            logger.warning(e, exc_info=True)
            return NodeRunResult(
                status=WorkflowNodeExecutionStatus.FAILED,
                inputs=variables,
@@ -217,92 +156,47 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                error_type=type(e).__name__,
                llm_usage=usage,
            )
        finally:
            db.session.close()

    def _fetch_dataset_retriever(
        self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any]
    ) -> tuple[list[dict[str, Any]], LLMUsage]:
        usage = LLMUsage.empty_usage()
        available_datasets = []
    ) -> tuple[list[Source], LLMUsage]:
        dataset_ids = node_data.dataset_ids
        query = variables.get("query")
        attachments = variables.get("attachments")
        metadata_filter_document_ids = None
        metadata_condition = None
        metadata_usage = LLMUsage.empty_usage()
        # Subquery: Count the number of available documents for each dataset
        subquery = (
            db.session.query(Document.dataset_id, func.count(Document.id).label("available_document_count"))
            .where(
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
                Document.dataset_id.in_(dataset_ids),
            )
            .group_by(Document.dataset_id)
            .having(func.count(Document.id) > 0)
            .subquery()
        )
        retrieval_resource_list = []

        results = (
            db.session.query(Dataset)
            .outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
            .where(Dataset.tenant_id == self.tenant_id, Dataset.id.in_(dataset_ids))
            .where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
            .all()
        )
        metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = "disabled"
        if node_data.metadata_filtering_mode is not None:
            metadata_filtering_mode = node_data.metadata_filtering_mode

        # avoid blocking at retrieval
        db.session.close()

        for dataset in results:
            # pass if dataset is not available
            if not dataset:
                continue
            available_datasets.append(dataset)
        if query:
            metadata_filter_document_ids, metadata_condition, metadata_usage = self._get_metadata_filter_condition(
                [dataset.id for dataset in available_datasets], query, node_data
            )
            usage = self._merge_usage(usage, metadata_usage)
        all_documents = []
        dataset_retrieval = DatasetRetrieval()
        if str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE and query:
            # fetch model config
            if node_data.single_retrieval_config is None:
                raise ValueError("single_retrieval_config is required")
            model_instance, model_config = self.get_model_config(node_data.single_retrieval_config.model)
            # check model is support tool calling
            model_type_instance = model_config.provider_model_bundle.model_type_instance
            model_type_instance = cast(LargeLanguageModel, model_type_instance)
            # get model schema
            model_schema = model_type_instance.get_model_schema(
                model=model_config.model, credentials=model_config.credentials
            )

            if model_schema:
                planning_strategy = PlanningStrategy.REACT_ROUTER
                features = model_schema.features
                if features:
                    if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
                        planning_strategy = PlanningStrategy.ROUTER
            all_documents = dataset_retrieval.single_retrieve(
                available_datasets=available_datasets,
                raise ValueError("single_retrieval_config is required for single retrieval mode")
            model = node_data.single_retrieval_config.model
            retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
                request=KnowledgeRetrievalRequest(
                    tenant_id=self.tenant_id,
                    user_id=self.user_id,
                    app_id=self.app_id,
                    user_from=self.user_from.value,
                    dataset_ids=dataset_ids,
                    retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE.value,
                    completion_params=model.completion_params,
                    model_provider=model.provider,
                    model_mode=model.mode,
                    model_name=model.name,
                    metadata_model_config=node_data.metadata_model_config,
                    metadata_filtering_conditions=node_data.metadata_filtering_conditions,
                    metadata_filtering_mode=metadata_filtering_mode,
                    query=query,
                model_config=model_config,
                model_instance=model_instance,
                planning_strategy=planning_strategy,
                metadata_filter_document_ids=metadata_filter_document_ids,
                metadata_condition=metadata_condition,
                )
            )
        elif str(node_data.retrieval_mode) == DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE:
            if node_data.multiple_retrieval_config is None:
                raise ValueError("multiple_retrieval_config is required")
            reranking_model = None
            weights = None
            match node_data.multiple_retrieval_config.reranking_mode:
                case "reranking_model":
                    if node_data.multiple_retrieval_config.reranking_model:
@@ -329,284 +223,36 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
                        },
                    }
                case _:
                    # Handle any other reranking_mode values
                    reranking_model = None
                    weights = None
            all_documents = dataset_retrieval.multiple_retrieve(
                app_id=self.app_id,
                tenant_id=self.tenant_id,
                user_id=self.user_id,
                user_from=self.user_from.value,
                available_datasets=available_datasets,
                query=query,
                top_k=node_data.multiple_retrieval_config.top_k,
                score_threshold=node_data.multiple_retrieval_config.score_threshold
                if node_data.multiple_retrieval_config.score_threshold is not None
                else 0.0,
                reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
                reranking_model=reranking_model,
                weights=weights,
                reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
                metadata_filter_document_ids=metadata_filter_document_ids,
                metadata_condition=metadata_condition,
                attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
            )
            usage = self._merge_usage(usage, dataset_retrieval.llm_usage)

        dify_documents = [item for item in all_documents if item.provider == "dify"]
        external_documents = [item for item in all_documents if item.provider == "external"]
        retrieval_resource_list = []
        # deal with external documents
        for item in external_documents:
            source: dict[str, dict[str, str | Any | dict[Any, Any] | None] | Any | str | None] = {
                "metadata": {
                    "_source": "knowledge",
                    "dataset_id": item.metadata.get("dataset_id"),
                    "dataset_name": item.metadata.get("dataset_name"),
                    "document_id": item.metadata.get("document_id") or item.metadata.get("title"),
                    "document_name": item.metadata.get("title"),
                    "data_source_type": "external",
                    "retriever_from": "workflow",
                    "score": item.metadata.get("score"),
                    "doc_metadata": item.metadata,
                },
                "title": item.metadata.get("title"),
                "content": item.page_content,
            }
            retrieval_resource_list.append(source)
        # deal with dify documents
        if dify_documents:
            records = RetrievalService.format_retrieval_documents(dify_documents)
            if records:
                for record in records:
                    segment = record.segment
                    dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()  # type: ignore
                    stmt = select(Document).where(
                        Document.id == segment.document_id,
                        Document.enabled == True,
                        Document.archived == False,
                    )
                    document = db.session.scalar(stmt)
                    if dataset and document:
                        source = {
                            "metadata": {
                                "_source": "knowledge",
                                "dataset_id": dataset.id,
                                "dataset_name": dataset.name,
                                "document_id": document.id,
                                "document_name": document.name,
                                "data_source_type": document.data_source_type,
                                "segment_id": segment.id,
                                "retriever_from": "workflow",
                                "score": record.score or 0.0,
                                "child_chunks": [
                                    {
                                        "id": str(getattr(chunk, "id", "")),
                                        "content": str(getattr(chunk, "content", "")),
                                        "position": int(getattr(chunk, "position", 0)),
                                        "score": float(getattr(chunk, "score", 0.0)),
                                    }
                                    for chunk in (record.child_chunks or [])
                                ],
                                "segment_hit_count": segment.hit_count,
                                "segment_word_count": segment.word_count,
                                "segment_position": segment.position,
                                "segment_index_node_hash": segment.index_node_hash,
                                "doc_metadata": document.doc_metadata,
                            },
                            "title": document.name,
                            "files": list(record.files) if record.files else None,
                        }
                        if segment.answer:
                            source["content"] = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
                        else:
                            source["content"] = segment.get_sign_content()
                        # Add summary if available
                        if record.summary:
                            source["summary"] = record.summary
                        retrieval_resource_list.append(source)
        if retrieval_resource_list:
            retrieval_resource_list = sorted(
                retrieval_resource_list,
                key=self._score,  # type: ignore[arg-type, return-value]
                reverse=True,
            )
            for position, item in enumerate(retrieval_resource_list, start=1):
                item["metadata"]["position"] = position  # type: ignore[index]
        return retrieval_resource_list, usage

    def _score(self, item: dict[str, Any]) -> float:
        meta = item.get("metadata")
        if isinstance(meta, dict):
            s = meta.get("score")
            if isinstance(s, (int, float)):
                return float(s)
        return 0.0

    def _get_metadata_filter_condition(
        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
    ) -> tuple[dict[str, list[str]] | None, MetadataCondition | None, LLMUsage]:
        usage = LLMUsage.empty_usage()
        document_query = db.session.query(Document).where(
            Document.dataset_id.in_(dataset_ids),
            Document.indexing_status == "completed",
            Document.enabled == True,
            Document.archived == False,
        )
        filters: list[Any] = []
        metadata_condition = None
        match node_data.metadata_filtering_mode:
            case "disabled":
                return None, None, usage
            case "automatic":
                automatic_metadata_filters, automatic_usage = self._automatic_metadata_filter_func(
                    dataset_ids, query, node_data
        retrieval_resource_list = self._rag_retrieval.knowledge_retrieval(
            request=KnowledgeRetrievalRequest(
                app_id=self.app_id,
                tenant_id=self.tenant_id,
                user_id=self.user_id,
                user_from=self.user_from.value,
                dataset_ids=dataset_ids,
                query=query,
                retrieval_mode=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE.value,
                top_k=node_data.multiple_retrieval_config.top_k,
                score_threshold=node_data.multiple_retrieval_config.score_threshold
                if node_data.multiple_retrieval_config.score_threshold is not None
                else 0.0,
                reranking_mode=node_data.multiple_retrieval_config.reranking_mode,
                reranking_model=reranking_model,
                weights=weights,
                reranking_enable=node_data.multiple_retrieval_config.reranking_enable,
                metadata_model_config=node_data.metadata_model_config,
                metadata_filtering_conditions=node_data.metadata_filtering_conditions,
                metadata_filtering_mode=metadata_filtering_mode,
                attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None,
            )
                usage = self._merge_usage(usage, automatic_usage)
                if automatic_metadata_filters:
                    conditions = []
                    for sequence, filter in enumerate(automatic_metadata_filters):
                        DatasetRetrieval.process_metadata_filter_func(
                            sequence,
                            filter.get("condition", ""),
                            filter.get("metadata_name", ""),
                            filter.get("value"),
                            filters,
                        )
                        conditions.append(
                            Condition(
                                name=filter.get("metadata_name"),  # type: ignore
                                comparison_operator=filter.get("condition"),  # type: ignore
                                value=filter.get("value"),
                            )
                        )
                    metadata_condition = MetadataCondition(
                        logical_operator=node_data.metadata_filtering_conditions.logical_operator
                        if node_data.metadata_filtering_conditions
                        else "or",
                        conditions=conditions,
                    )
            case "manual":
                if node_data.metadata_filtering_conditions:
                    conditions = []
                    for sequence, condition in enumerate(node_data.metadata_filtering_conditions.conditions):  # type: ignore
                        metadata_name = condition.name
                        expected_value = condition.value
                        if expected_value is not None and condition.comparison_operator not in ("empty", "not empty"):
                            if isinstance(expected_value, str):
                                expected_value = self.graph_runtime_state.variable_pool.convert_template(
                                    expected_value
                                ).value[0]
                                if expected_value.value_type in {"number", "integer", "float"}:
                                    expected_value = expected_value.value
                                elif expected_value.value_type == "string":
                                    expected_value = re.sub(r"[\r\n\t]+", " ", expected_value.text).strip()
                                else:
                                    raise ValueError("Invalid expected metadata value type")
                        conditions.append(
                            Condition(
                                name=metadata_name,
                                comparison_operator=condition.comparison_operator,
                                value=expected_value,
                            )
                        )
                        filters = DatasetRetrieval.process_metadata_filter_func(
                            sequence,
                            condition.comparison_operator,
                            metadata_name,
                            expected_value,
                            filters,
                        )
                    metadata_condition = MetadataCondition(
                        logical_operator=node_data.metadata_filtering_conditions.logical_operator,
                        conditions=conditions,
                    )
            case _:
                raise ValueError("Invalid metadata filtering mode")
        if filters:
            if (
                node_data.metadata_filtering_conditions
                and node_data.metadata_filtering_conditions.logical_operator == "and"
            ):
                document_query = document_query.where(and_(*filters))
            else:
                document_query = document_query.where(or_(*filters))
        documents = document_query.all()
        # group by dataset_id
        metadata_filter_document_ids = defaultdict(list) if documents else None  # type: ignore
        for document in documents:
            metadata_filter_document_ids[document.dataset_id].append(document.id)  # type: ignore
        return metadata_filter_document_ids, metadata_condition, usage

    def _automatic_metadata_filter_func(
        self, dataset_ids: list, query: str, node_data: KnowledgeRetrievalNodeData
    ) -> tuple[list[dict[str, Any]], LLMUsage]:
        usage = LLMUsage.empty_usage()
        # get all metadata field
        stmt = select(DatasetMetadata).where(DatasetMetadata.dataset_id.in_(dataset_ids))
        metadata_fields = db.session.scalars(stmt).all()
        all_metadata_fields = [metadata_field.name for metadata_field in metadata_fields]
        if node_data.metadata_model_config is None:
            raise ValueError("metadata_model_config is required")
        # get metadata model instance and fetch model config
        model_instance, model_config = self.get_model_config(node_data.metadata_model_config)
        # fetch prompt messages
        prompt_template = self._get_prompt_template(
            node_data=node_data,
            metadata_fields=all_metadata_fields,
            query=query or "",
        )
        prompt_messages, stop = LLMNode.fetch_prompt_messages(
            prompt_template=prompt_template,
            sys_query=query,
            memory=None,
            model_config=model_config,
            sys_files=[],
            vision_enabled=node_data.vision.enabled,
            vision_detail=node_data.vision.configs.detail,
            variable_pool=self.graph_runtime_state.variable_pool,
            jinja2_variables=[],
            tenant_id=self.tenant_id,
        )

        result_text = ""
        try:
            # handle invoke result
            generator = LLMNode.invoke_llm(
                node_data_model=node_data.metadata_model_config,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                stop=stop,
                user_id=self.user_id,
                structured_output_enabled=self.node_data.structured_output_enabled,
                structured_output=None,
                file_saver=self._llm_file_saver,
                file_outputs=self._file_outputs,
                node_id=self._node_id,
                node_type=self.node_type,
            )

            for event in generator:
                if isinstance(event, ModelInvokeCompletedEvent):
                    result_text = event.text
                    usage = self._merge_usage(usage, event.usage)
                    break

            result_text_json = parse_and_check_json_markdown(result_text, [])
            automatic_metadata_filters = []
            if "metadata_map" in result_text_json:
                metadata_map = result_text_json["metadata_map"]
                for item in metadata_map:
                    if item.get("metadata_field_name") in all_metadata_fields:
                        automatic_metadata_filters.append(
                            {
                                "metadata_name": item.get("metadata_field_name"),
                                "value": item.get("metadata_field_value"),
                                "condition": item.get("comparison_operator"),
                            }
                        )
        except Exception:
            return [], usage
        return automatic_metadata_filters, usage
        usage = self._rag_retrieval.llm_usage
        return retrieval_resource_list, usage

    @classmethod
    def _extract_variable_selector_to_variable_mapping(
@@ -626,107 +272,3 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD
        if typed_node_data.query_attachment_selector:
            variable_mapping[node_id + ".queryAttachment"] = typed_node_data.query_attachment_selector
        return variable_mapping

    def get_model_config(self, model: ModelConfig) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
        model_name = model.name
        provider_name = model.provider

        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=self.tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name
        )

        provider_model_bundle = model_instance.provider_model_bundle
        model_type_instance = model_instance.model_type_instance
        model_type_instance = cast(LargeLanguageModel, model_type_instance)

        model_credentials = model_instance.credentials

        # check model
        provider_model = provider_model_bundle.configuration.get_provider_model(
            model=model_name, model_type=ModelType.LLM
        )

        if provider_model is None:
            raise ModelNotExistError(f"Model {model_name} not exist.")

        if provider_model.status == ModelStatus.NO_CONFIGURE:
            raise ModelCredentialsNotInitializedError(f"Model {model_name} credentials is not initialized.")
        elif provider_model.status == ModelStatus.NO_PERMISSION:
            raise ModelNotSupportedError(f"Dify Hosted OpenAI {model_name} currently not support.")
        elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
            raise ModelQuotaExceededError(f"Model provider {provider_name} quota exceeded.")

        # model config
        completion_params = model.completion_params
        stop = []
        if "stop" in completion_params:
            stop = completion_params["stop"]
            del completion_params["stop"]

        # get model mode
        model_mode = model.mode
        if not model_mode:
            raise ModelNotExistError("LLM mode is required.")

        model_schema = model_type_instance.get_model_schema(model_name, model_credentials)

        if not model_schema:
            raise ModelNotExistError(f"Model {model_name} not exist.")

        return model_instance, ModelConfigWithCredentialsEntity(
            provider=provider_name,
            model=model_name,
            model_schema=model_schema,
            mode=model_mode,
            provider_model_bundle=provider_model_bundle,
            credentials=model_credentials,
            parameters=completion_params,
            stop=stop,
        )

    def _get_prompt_template(self, node_data: KnowledgeRetrievalNodeData, metadata_fields: list, query: str):
        model_mode = ModelMode(node_data.metadata_model_config.mode)  # type: ignore
        input_text = query

        prompt_messages: list[LLMNodeChatModelMessage] = []
        if model_mode == ModelMode.CHAT:
            system_prompt_messages = LLMNodeChatModelMessage(
                role=PromptMessageRole.SYSTEM, text=METADATA_FILTER_SYSTEM_PROMPT
            )
            prompt_messages.append(system_prompt_messages)
            user_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_1
            )
            prompt_messages.append(user_prompt_message_1)
            assistant_prompt_message_1 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_1
            )
            prompt_messages.append(assistant_prompt_message_1)
            user_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER, text=METADATA_FILTER_USER_PROMPT_2
            )
            prompt_messages.append(user_prompt_message_2)
            assistant_prompt_message_2 = LLMNodeChatModelMessage(
                role=PromptMessageRole.ASSISTANT, text=METADATA_FILTER_ASSISTANT_PROMPT_2
            )
            prompt_messages.append(assistant_prompt_message_2)
            user_prompt_message_3 = LLMNodeChatModelMessage(
                role=PromptMessageRole.USER,
                text=METADATA_FILTER_USER_PROMPT_3.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                ),
            )
            prompt_messages.append(user_prompt_message_3)
            return prompt_messages
        elif model_mode == ModelMode.COMPLETION:
            return LLMNodeCompletionModelPromptTemplate(
                text=METADATA_FILTER_COMPLETION_PROMPT.format(
                    input_text=input_text,
                    metadata_fields=json.dumps(metadata_fields, ensure_ascii=False),
                )
            )

        else:
            raise InvalidModelTypeError(f"Model mode {model_mode} not support.")

api/core/workflow/repositories/rag_retrieval_protocol.py (new file, 108 lines added)
@@ -0,0 +1,108 @@
from typing import Any, Literal, Protocol

from pydantic import BaseModel, Field

from core.model_runtime.entities import LLMUsage
from core.workflow.nodes.knowledge_retrieval.entities import MetadataFilteringCondition
from core.workflow.nodes.llm.entities import ModelConfig


class SourceChildChunk(BaseModel):
    id: str = Field(default="", description="Child chunk ID")
    content: str = Field(default="", description="Child chunk content")
    position: int = Field(default=0, description="Child chunk position")
    score: float = Field(default=0.0, description="Child chunk relevance score")


class SourceMetadata(BaseModel):
    source: str = Field(
        default="knowledge",
        serialization_alias="_source",
        description="Data source identifier",
    )
    dataset_id: str = Field(description="Dataset unique identifier")
    dataset_name: str = Field(description="Dataset display name")
    document_id: str = Field(description="Document unique identifier")
    document_name: str = Field(description="Document display name")
    data_source_type: str = Field(description="Type of data source")
    segment_id: str | None = Field(default=None, description="Segment unique identifier")
    retriever_from: str = Field(default="workflow", description="Retriever source context")
    score: float = Field(default=0.0, description="Retrieval relevance score")
    child_chunks: list[SourceChildChunk] = Field(default=[], description="List of child chunks")
    segment_hit_count: int | None = Field(default=0, description="Number of times segment was retrieved")
    segment_word_count: int | None = Field(default=0, description="Word count of the segment")
    segment_position: int | None = Field(default=0, description="Position of segment in document")
    segment_index_node_hash: str | None = Field(default=None, description="Hash of index node for the segment")
    doc_metadata: dict[str, Any] | None = Field(default=None, description="Additional document metadata")
    position: int | None = Field(default=0, description="Position of the document in the dataset")

    class Config:
        populate_by_name = True


class Source(BaseModel):
    metadata: SourceMetadata = Field(description="Source metadata information")
    title: str = Field(description="Document title")
    files: list[Any] | None = Field(default=None, description="Associated file references")
    content: str | None = Field(description="Segment content text")
    summary: str | None = Field(default=None, description="Content summary if available")


class KnowledgeRetrievalRequest(BaseModel):
    tenant_id: str = Field(description="Tenant unique identifier")
    user_id: str = Field(description="User unique identifier")
    app_id: str = Field(description="Application unique identifier")
    user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')")
    dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from")
    query: str | None = Field(default=None, description="Query text for knowledge retrieval")
    retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'")
    model_provider: str | None = Field(default=None, description="Model provider name (e.g., 'openai', 'anthropic')")
    completion_params: dict[str, Any] | None = Field(
        default=None, description="Model completion parameters (e.g., temperature, max_tokens)"
    )
    model_mode: str | None = Field(default=None, description="Model mode (e.g., 'chat', 'completion')")
    model_name: str | None = Field(default=None, description="Model name (e.g., 'gpt-4', 'claude-3-opus')")
    metadata_model_config: ModelConfig | None = Field(
        default=None, description="Model config for metadata-based filtering"
    )
    metadata_filtering_conditions: MetadataFilteringCondition | None = Field(
        default=None, description="Conditions for filtering by metadata"
    )
    metadata_filtering_mode: Literal["disabled", "automatic", "manual"] = Field(
        default="disabled", description="Metadata filtering mode: 'disabled', 'automatic', or 'manual'"
    )
    top_k: int = Field(default=0, description="Number of top results to return")
    score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold")
    reranking_mode: str = Field(default="reranking_model", description="Reranking strategy")
    reranking_model: dict | None = Field(default=None, description="Reranking model configuration")
    weights: dict[str, Any] | None = Field(default=None, description="Weights for weighted score reranking")
    reranking_enable: bool = Field(default=True, description="Whether reranking is enabled")
    attachment_ids: list[str] | None = Field(default=None, description="List of attachment file IDs for retrieval")


class RAGRetrievalProtocol(Protocol):
    """Protocol for RAG-based knowledge retrieval implementations.

    Implementations of this protocol handle knowledge retrieval from datasets,
    including rate limiting, dataset filtering, and document retrieval.
    """

    @property
    def llm_usage(self) -> LLMUsage:
        """Return accumulated LLM usage for retrieval operations."""
        ...

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
        """Retrieve knowledge from datasets based on the provided request.

        Args:
            request: Knowledge retrieval request with search parameters

        Returns:
            List of sources matching the search criteria

        Raises:
            RateLimitExceededError: If rate limit is exceeded
            ModelNotExistError: If specified model doesn't exist
        """
        ...
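Because RAGRetrievalProtocol is a typing.Protocol, conformance is structural: any object exposing a matching llm_usage property and knowledge_retrieval method satisfies it, with no inheritance required. A minimal sketch of a conforming stub (StubRetrieval and its canned Source values are illustrative only, not part of this PR):

from core.model_runtime.entities import LLMUsage
from core.workflow.repositories.rag_retrieval_protocol import (
    KnowledgeRetrievalRequest,
    RAGRetrievalProtocol,
    Source,
    SourceMetadata,
)


class StubRetrieval:
    """Hypothetical stand-in that satisfies the protocol for tests."""

    @property
    def llm_usage(self) -> LLMUsage:
        return LLMUsage.empty_usage()

    def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
        # Echo the query back as a single Source; only required fields are set.
        metadata = SourceMetadata(
            dataset_id="ds-1",
            dataset_name="demo",
            document_id="doc-1",
            document_name="demo.txt",
            data_source_type="upload_file",
        )
        return [Source(metadata=metadata, title="demo.txt", content=request.query or "")]


retrieval: RAGRetrievalProtocol = StubRetrieval()  # accepted by the type checker structurally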
@@ -0,0 +1,29 @@
"""
Integration tests for KnowledgeRetrievalNode.

This module provides integration tests for KnowledgeRetrievalNode with real database interactions.

Note: These tests require database setup and are more complex than unit tests.
For now, we focus on unit tests, which provide better coverage of the node logic.
"""

import pytest


class TestKnowledgeRetrievalNodeIntegration:
    """
    Integration test suite for KnowledgeRetrievalNode.

    Note: Full integration tests require:
    - Database setup with datasets and documents
    - Vector store for embeddings
    - Model providers for retrieval

    For now, unit tests provide comprehensive coverage of the node logic.
    """

    @pytest.mark.skip(reason="Integration tests require full database and vector store setup")
    def test_end_to_end_knowledge_retrieval(self):
        """Test end-to-end knowledge retrieval workflow."""
        # TODO: Implement with real database
        pass
@@ -0,0 +1,614 @@
import uuid
from unittest.mock import patch

import pytest
from faker import Faker

from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
from models.dataset import Dataset, Document
from services.account_service import AccountService, TenantService


class TestGetAvailableDatasetsIntegration:
    def test_returns_datasets_with_available_documents(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        # Create account and tenant
        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create dataset
        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            description=fake.text(max_nb_chars=100),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
            indexing_technique="high_quality",
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.flush()

        # Create documents with completed status, enabled, not archived
        for i in range(3):
            document = Document(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=i,
                data_source_type="upload_file",
                batch=str(uuid.uuid4()),  # Required field
                name=f"Document {i}",
                created_from="web",
                created_by=account.id,
                doc_form="text_model",
                doc_language="en",
                indexing_status="completed",
                enabled=True,
                archived=False,
            )
            db_session_with_containers.add(document)

        db_session_with_containers.commit()

        # Act
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])

        # Assert
        assert len(result) == 1
        assert result[0].id == dataset.id
        assert result[0].tenant_id == tenant.id
        assert result[0].name == dataset.name

    def test_filters_out_datasets_with_only_archived_documents(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)

        # Create only archived documents
        for i in range(2):
            document = Document(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=i,
                data_source_type="upload_file",
                batch=str(uuid.uuid4()),  # Required field
                created_from="web",
                name=f"Archived Document {i}",
                created_by=account.id,
                doc_form="text_model",
                indexing_status="completed",
                enabled=True,
                archived=True,  # Archived
            )
            db_session_with_containers.add(document)

        db_session_with_containers.commit()

        # Act
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])

        # Assert
        assert len(result) == 0

    def test_filters_out_datasets_with_only_disabled_documents(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)

        # Create only disabled documents
        for i in range(2):
            document = Document(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=i,
                data_source_type="upload_file",
                batch=str(uuid.uuid4()),  # Required field
                created_from="web",
                name=f"Disabled Document {i}",
                created_by=account.id,
                doc_form="text_model",
                indexing_status="completed",
                enabled=False,  # Disabled
                archived=False,
            )
            db_session_with_containers.add(document)

        db_session_with_containers.commit()

        # Act
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])

        # Assert
        assert len(result) == 0

    def test_filters_out_datasets_with_non_completed_documents(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)

        # Create documents with non-completed status
        for i, status in enumerate(["indexing", "parsing", "splitting"]):
            document = Document(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=i,
                data_source_type="upload_file",
                batch=str(uuid.uuid4()),  # Required field
                created_from="web",
                name=f"Document {status}",
                created_by=account.id,
                doc_form="text_model",
                indexing_status=status,  # Not completed
                enabled=True,
                archived=False,
            )
            db_session_with_containers.add(document)

        db_session_with_containers.commit()

        # Act
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])

        # Assert
        assert len(result) == 0

    def test_includes_external_datasets_without_documents(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        """
        Test that external datasets are returned even with no available documents.

        External datasets (e.g., from external knowledge bases) don't have
        documents stored in Dify's database, so they should always be available.

        Verifies:
        - External datasets are included in results
        - No document count check for external datasets
        """
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="external",  # External provider
            data_source_type="external",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()

        # Act
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant.id, [dataset.id])

        # Assert
        assert len(result) == 1
        assert result[0].id == dataset.id
        assert result[0].provider == "external"

    def test_filters_by_tenant_id(self, db_session_with_containers, mock_external_service_dependencies):
        # Arrange
        fake = Faker()

        # Create two accounts/tenants
        account1 = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account1, name=fake.company())
        tenant1 = account1.current_tenant

        account2 = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account2, name=fake.company())
        tenant2 = account2.current_tenant

        # Create dataset for tenant1
        dataset1 = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant1.id,
            name="Tenant 1 Dataset",
            provider="dify",
            data_source_type="upload_file",
            created_by=account1.id,
        )
        db_session_with_containers.add(dataset1)

        # Create dataset for tenant2
        dataset2 = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant2.id,
            name="Tenant 2 Dataset",
            provider="dify",
            data_source_type="upload_file",
            created_by=account2.id,
        )
        db_session_with_containers.add(dataset2)

        # Add documents to both datasets
        for dataset, account in [(dataset1, account1), (dataset2, account2)]:
            document = Document(
                id=str(uuid.uuid4()),
                tenant_id=dataset.tenant_id,
                dataset_id=dataset.id,
                position=0,
                data_source_type="upload_file",
                batch=str(uuid.uuid4()),  # Required field
                created_from="web",
                name=f"Document for {dataset.name}",
                created_by=account.id,
                doc_form="text_model",
                indexing_status="completed",
                enabled=True,
                archived=False,
            )
            db_session_with_containers.add(document)

        db_session_with_containers.commit()

        # Act - request from tenant1, should only get tenant1's dataset
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant1.id, [dataset1.id, dataset2.id])

        # Assert
        assert len(result) == 1
        assert result[0].id == dataset1.id
        assert result[0].tenant_id == tenant1.id

    def test_returns_empty_list_when_no_datasets_found(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Don't create any datasets

        # Act
        dataset_retrieval = DatasetRetrieval()
        result = dataset_retrieval._get_available_datasets(tenant.id, [str(uuid.uuid4())])

        # Assert
        assert result == []

    def test_returns_only_requested_dataset_ids(self, db_session_with_containers, mock_external_service_dependencies):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create multiple datasets
        datasets = []
        for i in range(3):
            dataset = Dataset(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                name=f"Dataset {i}",
                provider="dify",
                data_source_type="upload_file",
                created_by=account.id,
            )
            db_session_with_containers.add(dataset)
            datasets.append(dataset)

            # Add document
            document = Document(
                id=str(uuid.uuid4()),
                tenant_id=tenant.id,
                dataset_id=dataset.id,
                position=0,
                data_source_type="upload_file",
                batch=str(uuid.uuid4()),  # Required field
                created_from="web",
                name=f"Document {i}",
                created_by=account.id,
                doc_form="text_model",
                indexing_status="completed",
                enabled=True,
                archived=False,
            )
            db_session_with_containers.add(document)

        db_session_with_containers.commit()

        # Act - request only dataset 0 and 2, not dataset 1
        dataset_retrieval = DatasetRetrieval()
        requested_ids = [datasets[0].id, datasets[2].id]
        result = dataset_retrieval._get_available_datasets(tenant.id, requested_ids)

        # Assert
        assert len(result) == 2
        returned_ids = {d.id for d in result}
        assert returned_ids == {datasets[0].id, datasets[2].id}


class TestKnowledgeRetrievalIntegration:
    def test_knowledge_retrieval_with_available_datasets(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
            indexing_technique="high_quality",
        )
        db_session_with_containers.add(dataset)

        document = Document(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            dataset_id=dataset.id,
            position=0,
            data_source_type="upload_file",
            batch=str(uuid.uuid4()),  # Required field
            created_from="web",
            name=fake.sentence(),
            created_by=account.id,
            indexing_status="completed",
            enabled=True,
            archived=False,
            doc_form="text_model",
        )
        db_session_with_containers.add(document)
        db_session_with_containers.commit()

        # Create request
        request = KnowledgeRetrievalRequest(
            tenant_id=tenant.id,
            user_id=account.id,
            app_id=str(uuid.uuid4()),
            user_from="web",
            dataset_ids=[dataset.id],
            query="test query",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock rate limit check and retrieval
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
                with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
                    # Act
                    result = dataset_retrieval.knowledge_retrieval(request)

                    # Assert
                    assert isinstance(result, list)

    def test_knowledge_retrieval_no_available_datasets(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        # Create dataset but no documents
        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant.id,
            user_id=account.id,
            app_id=str(uuid.uuid4()),
            user_from="web",
            dataset_ids=[dataset.id],
            query="test query",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock rate limit check
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            # Act
            result = dataset_retrieval.knowledge_retrieval(request)

            # Assert
            assert result == []

    def test_knowledge_retrieval_rate_limit_exceeded(
        self, db_session_with_containers, mock_external_service_dependencies
    ):
        # Arrange
        fake = Faker()

        account = AccountService.create_account(
            email=fake.email(),
            name=fake.name(),
            interface_language="en-US",
            password=fake.password(length=12),
        )
        TenantService.create_owner_tenant_if_not_exist(account, name=fake.company())
        tenant = account.current_tenant

        dataset = Dataset(
            id=str(uuid.uuid4()),
            tenant_id=tenant.id,
            name=fake.company(),
            provider="dify",
            data_source_type="upload_file",
            created_by=account.id,
        )
        db_session_with_containers.add(dataset)
        db_session_with_containers.commit()

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant.id,
            user_id=account.id,
            app_id=str(uuid.uuid4()),
            user_from="web",
            dataset_ids=[dataset.id],
            query="test query",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock rate limit check to raise exception
        with patch.object(
            dataset_retrieval,
            "_check_knowledge_rate_limit",
            side_effect=Exception("Rate limit exceeded"),
        ):
            # Act & Assert
            with pytest.raises(Exception, match="Rate limit exceeded"):
                dataset_retrieval.knowledge_retrieval(request)


@pytest.fixture
def mock_external_service_dependencies():
    with (
        patch("services.account_service.FeatureService") as mock_account_feature_service,
    ):
        # Setup default mock returns for account service
        mock_account_feature_service.get_system_features.return_value.is_allow_register = True

        yield {
            "account_feature_service": mock_account_feature_service,
        }
@@ -0,0 +1,715 @@
from unittest.mock import MagicMock, Mock, patch
from uuid import uuid4

import pytest

from core.rag.models.document import Document
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
from core.workflow.nodes.knowledge_retrieval import exc
from core.workflow.repositories.rag_retrieval_protocol import KnowledgeRetrievalRequest
from models.dataset import Dataset

# ==================== Helper Functions ====================


def create_mock_dataset(
    dataset_id: str | None = None,
    tenant_id: str | None = None,
    provider: str = "dify",
    indexing_technique: str = "high_quality",
    available_document_count: int = 10,
) -> Mock:
    """
    Create a mock Dataset object for testing.

    Args:
        dataset_id: Unique identifier for the dataset
        tenant_id: Tenant ID for the dataset
        provider: Provider type ("dify" or "external")
        indexing_technique: Indexing technique ("high_quality" or "economy")
        available_document_count: Number of available documents

    Returns:
        Mock: A properly configured Dataset mock
    """
    dataset = Mock(spec=Dataset)
    dataset.id = dataset_id or str(uuid4())
    dataset.tenant_id = tenant_id or str(uuid4())
    dataset.name = "test_dataset"
    dataset.provider = provider
    dataset.indexing_technique = indexing_technique
    dataset.available_document_count = available_document_count
    dataset.embedding_model = "text-embedding-ada-002"
    dataset.embedding_model_provider = "openai"
    dataset.retrieval_model = {
        "search_method": "semantic_search",
        "reranking_enable": False,
        "top_k": 4,
        "score_threshold_enabled": False,
    }
    return dataset


def create_mock_document(
    content: str,
    doc_id: str,
    score: float = 0.8,
    provider: str = "dify",
    additional_metadata: dict | None = None,
) -> Document:
    """
    Create a mock Document object for testing.

    Args:
        content: The text content of the document
        doc_id: Unique identifier for the document chunk
        score: Relevance score (0.0 to 1.0)
        provider: Document provider ("dify" or "external")
        additional_metadata: Optional extra metadata fields

    Returns:
        Document: A properly structured Document object
    """
    metadata = {
        "doc_id": doc_id,
        "document_id": str(uuid4()),
        "dataset_id": str(uuid4()),
        "score": score,
    }

    if additional_metadata:
        metadata.update(additional_metadata)

    return Document(
        page_content=content,
        metadata=metadata,
        provider=provider,
    )


# ==================== Test _check_knowledge_rate_limit ====================


class TestCheckKnowledgeRateLimit:
    """
    Test suite for _check_knowledge_rate_limit method.

    The _check_knowledge_rate_limit method validates whether a tenant has
    exceeded their knowledge retrieval rate limit. This is important for:
    - Preventing abuse of the knowledge retrieval system
    - Enforcing subscription plan limits
    - Tracking usage for billing purposes

    Test Cases:
    ============
    1. Rate limit disabled - no exception raised
    2. Rate limit enabled but not exceeded - no exception raised
    3. Rate limit enabled and exceeded - RateLimitExceededError raised
    4. Redis operations are performed correctly
    5. RateLimitLog is created when limit is exceeded
    """

    @patch("core.rag.retrieval.dataset_retrieval.FeatureService")
    @patch("core.rag.retrieval.dataset_retrieval.redis_client")
    def test_rate_limit_disabled_no_exception(self, mock_redis, mock_feature_service):
        """
        Test that when rate limit is disabled, no exception is raised.

        This test verifies the behavior when the tenant's subscription
        does not have rate limiting enabled.

        Verifies:
        - FeatureService.get_knowledge_rate_limit is called
        - No Redis operations are performed
        - No exception is raised
        - Retrieval proceeds normally
        """
        # Arrange
        tenant_id = str(uuid4())
        dataset_retrieval = DatasetRetrieval()

        # Mock rate limit disabled
        mock_limit = Mock()
        mock_limit.enabled = False
        mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit

        # Act & Assert - should not raise any exception
        dataset_retrieval._check_knowledge_rate_limit(tenant_id)

        # Verify FeatureService was called
        mock_feature_service.get_knowledge_rate_limit.assert_called_once_with(tenant_id)

        # Verify no Redis operations were performed
        assert not mock_redis.zadd.called
        assert not mock_redis.zremrangebyscore.called
        assert not mock_redis.zcard.called

    @patch("core.rag.retrieval.dataset_retrieval.session_factory")
    @patch("core.rag.retrieval.dataset_retrieval.FeatureService")
    @patch("core.rag.retrieval.dataset_retrieval.redis_client")
    @patch("core.rag.retrieval.dataset_retrieval.time")
    def test_rate_limit_enabled_not_exceeded(self, mock_time, mock_redis, mock_feature_service, mock_session_factory):
        """
        Test that when rate limit is enabled but not exceeded, no exception is raised.

        This test simulates a tenant making requests within their rate limit.
        The Redis sorted set stores timestamps of recent requests, and old
        requests (older than 60 seconds) are removed.

        Verifies:
        - Redis zadd is called to track the request
        - Redis zremrangebyscore removes old entries
        - Redis zcard returns count within limit
        - No exception is raised
        """
        # Arrange
        tenant_id = str(uuid4())
        dataset_retrieval = DatasetRetrieval()

        # Mock rate limit enabled with limit of 100 requests per minute
        mock_limit = Mock()
        mock_limit.enabled = True
        mock_limit.limit = 100
        mock_limit.subscription_plan = "professional"
        mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit

        # Mock time
        current_time = 1234567890000  # Current time in milliseconds
        mock_time.time.return_value = current_time / 1000  # Return seconds
        mock_time.time.__mul__ = lambda self, x: int(self * x)  # Multiply to get milliseconds

        # Mock Redis operations
        # zcard returns 50 (within limit of 100)
        mock_redis.zcard.return_value = 50

        # Mock session_factory.create_session
        mock_session = MagicMock()
        mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
        mock_session_factory.create_session.return_value.__exit__.return_value = None

        # Act & Assert - should not raise any exception
        dataset_retrieval._check_knowledge_rate_limit(tenant_id)

        # Verify Redis operations
        expected_key = f"rate_limit_{tenant_id}"
        mock_redis.zadd.assert_called_once_with(expected_key, {current_time: current_time})
        mock_redis.zremrangebyscore.assert_called_once_with(expected_key, 0, current_time - 60000)
        mock_redis.zcard.assert_called_once_with(expected_key)

    @patch("core.rag.retrieval.dataset_retrieval.session_factory")
    @patch("core.rag.retrieval.dataset_retrieval.FeatureService")
    @patch("core.rag.retrieval.dataset_retrieval.redis_client")
    @patch("core.rag.retrieval.dataset_retrieval.time")
    def test_rate_limit_enabled_exceeded_raises_exception(
        self, mock_time, mock_redis, mock_feature_service, mock_session_factory
    ):
        """
        Test that when rate limit is enabled and exceeded, RateLimitExceededError is raised.

        This test simulates a tenant exceeding their rate limit. When the count
        of recent requests exceeds the limit, an exception should be raised and
        a RateLimitLog should be created.

        Verifies:
        - Redis zcard returns count exceeding limit
        - RateLimitExceededError is raised with correct message
        - RateLimitLog is created in database
        - Session operations are performed correctly
        """
        # Arrange
        tenant_id = str(uuid4())
        dataset_retrieval = DatasetRetrieval()

        # Mock rate limit enabled with limit of 100 requests per minute
        mock_limit = Mock()
        mock_limit.enabled = True
        mock_limit.limit = 100
        mock_limit.subscription_plan = "professional"
        mock_feature_service.get_knowledge_rate_limit.return_value = mock_limit

        # Mock time
        current_time = 1234567890000
        mock_time.time.return_value = current_time / 1000

        # Mock Redis operations - return count exceeding limit
        mock_redis.zcard.return_value = 150  # Exceeds limit of 100

        # Mock session_factory.create_session
        mock_session = MagicMock()
        mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
        mock_session_factory.create_session.return_value.__exit__.return_value = None

        # Act & Assert
        with pytest.raises(exc.RateLimitExceededError) as exc_info:
            dataset_retrieval._check_knowledge_rate_limit(tenant_id)

        # Verify exception message
        assert "knowledge base request rate limit" in str(exc_info.value)

        # Verify RateLimitLog was created
        mock_session.add.assert_called_once()
        added_log = mock_session.add.call_args[0][0]
        assert added_log.tenant_id == tenant_id
        assert added_log.subscription_plan == "professional"
        assert added_log.operation == "knowledge"
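The three tests above all pin down the same sliding-window mechanism. A hedged sketch of the check they describe (the key format, 60-second window, and Redis calls are taken from the assertions; the function itself is illustrative, not the PR's implementation):

import time


def check_rate_limit_sketch(redis_client, tenant_id: str, limit: int) -> bool:
    # Record this request in a per-tenant sorted set keyed by timestamp (ms),
    # drop entries older than the 60-second window, then count what remains.
    now_ms = int(time.time() * 1000)
    key = f"rate_limit_{tenant_id}"
    redis_client.zadd(key, {now_ms: now_ms})
    redis_client.zremrangebyscore(key, 0, now_ms - 60000)
    return redis_client.zcard(key) <= limit  # False corresponds to RateLimitExceededError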
# ==================== Test _get_available_datasets ====================


class TestGetAvailableDatasets:
    """
    Test suite for _get_available_datasets method.

    The _get_available_datasets method retrieves datasets that are available
    for retrieval. A dataset is considered available if:
    - It belongs to the specified tenant
    - It's in the list of requested dataset_ids
    - It has at least one completed, enabled, non-archived document OR
    - It's an external provider dataset

    Note: Due to SQLAlchemy subquery complexity, full testing is done in
    integration tests. Unit tests here verify basic behavior.
    """

    def test_method_exists_and_has_correct_signature(self):
        """
        Test that the method exists and has the correct signature.

        Verifies:
        - Method exists on DatasetRetrieval class
        - Accepts tenant_id and dataset_ids parameters
        """
        # Arrange
        dataset_retrieval = DatasetRetrieval()

        # Assert - method exists
        assert hasattr(dataset_retrieval, "_get_available_datasets")
        # Assert - method is callable
        assert callable(dataset_retrieval._get_available_datasets)
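For readers who want the availability rule in one place rather than spread across the integration tests: a simplified, hypothetical rendering of the predicate the docstring above describes (the real method reportedly uses a SQLAlchemy subquery; this pure-Python version only illustrates the logic):

def is_available_sketch(dataset, documents) -> bool:
    # `documents` are the dataset's Document rows; external datasets skip the check.
    if dataset.provider == "external":
        return True
    return any(
        doc.indexing_status == "completed" and doc.enabled and not doc.archived
        for doc in documents
    )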
# ==================== Test knowledge_retrieval ====================


class TestDatasetRetrievalKnowledgeRetrieval:
    """
    Test suite for knowledge_retrieval method.

    The knowledge_retrieval method is the main entry point for retrieving
    knowledge from datasets. It orchestrates the entire retrieval process:
    1. Checks rate limits
    2. Gets available datasets
    3. Applies metadata filtering if enabled
    4. Performs retrieval (single or multiple mode)
    5. Formats and returns results

    Test Cases:
    ============
    1. Single mode retrieval
    2. Multiple mode retrieval
    3. Metadata filtering disabled
    4. Metadata filtering automatic
    5. Metadata filtering manual
    6. External documents handling
    7. Dify documents handling
    8. Empty results handling
    9. Rate limit exceeded
    10. No available datasets
    """

    def test_knowledge_retrieval_single_mode_basic(self):
        """
        Test knowledge_retrieval in single retrieval mode - basic check.

        Note: Full single mode testing requires complex model mocking and
        is better suited for integration tests. This test verifies the
        method accepts single mode requests.

        Verifies:
        - Method can accept single mode request
        - Request parameters are correctly structured
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id],
            query="What is Python?",
            retrieval_mode="single",
            model_provider="openai",
            model_name="gpt-4",
            model_mode="chat",
            completion_params={"temperature": 0.7},
        )

        # Assert - request is properly structured
        assert request.retrieval_mode == "single"
        assert request.model_provider == "openai"
        assert request.model_name == "gpt-4"
        assert request.model_mode == "chat"

    @patch("core.rag.retrieval.dataset_retrieval.DataPostProcessor")
    @patch("core.rag.retrieval.dataset_retrieval.session_factory")
    def test_knowledge_retrieval_multiple_mode(self, mock_session_factory, mock_data_processor):
        """
        Test knowledge_retrieval in multiple retrieval mode.

        In multiple mode, retrieval is performed across all datasets and
        results are combined and reranked.

        Verifies:
        - Rate limit is checked
        - Available datasets are retrieved
        - Multiple retrieval is performed
        - Results are combined and reranked
        - Results are formatted correctly
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id1 = str(uuid4())
        dataset_id2 = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id1, dataset_id2],
            query="What is Python?",
            retrieval_mode="multiple",
            top_k=5,
            score_threshold=0.7,
            reranking_enable=True,
            reranking_mode="reranking_model",
            reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-v2"},
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock _check_knowledge_rate_limit
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            # Mock _get_available_datasets
            mock_dataset1 = create_mock_dataset(dataset_id=dataset_id1, tenant_id=tenant_id)
            mock_dataset2 = create_mock_dataset(dataset_id=dataset_id2, tenant_id=tenant_id)
            with patch.object(
                dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset1, mock_dataset2]
            ):
                # Mock get_metadata_filter_condition
                with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
                    # Mock multiple_retrieve to return documents
                    doc1 = create_mock_document("Python is great", "doc1", score=0.9)
                    doc2 = create_mock_document("Python is awesome", "doc2", score=0.8)
                    with patch.object(
                        dataset_retrieval, "multiple_retrieve", return_value=[doc1, doc2]
                    ) as mock_multiple_retrieve:
                        # Mock format_retrieval_documents
                        mock_record = Mock()
                        mock_record.segment = Mock()
                        mock_record.segment.dataset_id = dataset_id1
                        mock_record.segment.document_id = str(uuid4())
                        mock_record.segment.index_node_hash = "hash123"
                        mock_record.segment.hit_count = 5
                        mock_record.segment.word_count = 100
                        mock_record.segment.position = 1
                        mock_record.segment.get_sign_content.return_value = "Python is great"
                        mock_record.segment.answer = None
                        mock_record.score = 0.9
                        mock_record.child_chunks = []
                        mock_record.summary = None
                        mock_record.files = None

                        mock_retrieval_service = Mock()
                        mock_retrieval_service.format_retrieval_documents.return_value = [mock_record]

                        with patch(
                            "core.rag.retrieval.dataset_retrieval.RetrievalService",
                            return_value=mock_retrieval_service,
                        ):
                            # Mock database queries
                            mock_session = MagicMock()
                            mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
                            mock_session_factory.create_session.return_value.__exit__.return_value = None

                            mock_dataset_from_db = Mock()
                            mock_dataset_from_db.id = dataset_id1
                            mock_dataset_from_db.name = "test_dataset"

                            mock_document = Mock()
                            mock_document.id = str(uuid4())
                            mock_document.name = "test_doc"
                            mock_document.data_source_type = "upload_file"
                            mock_document.doc_metadata = {}

                            mock_session.query.return_value.filter.return_value.all.return_value = [
                                mock_dataset_from_db
                            ]
                            mock_session.query.return_value.filter.return_value.all.__iter__ = lambda self: iter(
                                [mock_dataset_from_db, mock_document]
                            )

                            # Act
                            result = dataset_retrieval.knowledge_retrieval(request)

                            # Assert
                            assert isinstance(result, list)
                            mock_multiple_retrieve.assert_called_once()

    def test_knowledge_retrieval_metadata_filtering_disabled(self):
        """
        Test knowledge_retrieval with metadata filtering disabled.

        When metadata filtering is disabled, get_metadata_filter_condition is
        NOT called (the method checks metadata_filtering_mode != "disabled").

        Verifies:
        - get_metadata_filter_condition is NOT called when mode is "disabled"
        - Retrieval proceeds without metadata filters
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id],
            query="What is Python?",
            retrieval_mode="multiple",
            metadata_filtering_mode="disabled",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock dependencies
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
            with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
                # Mock get_metadata_filter_condition - should NOT be called when disabled
                with patch.object(
                    dataset_retrieval,
                    "get_metadata_filter_condition",
                    return_value=(None, None),
                ) as mock_get_metadata:
                    with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
                        # Act
                        result = dataset_retrieval.knowledge_retrieval(request)

                        # Assert
                        assert isinstance(result, list)
                        # get_metadata_filter_condition should NOT be called when mode is "disabled"
                        mock_get_metadata.assert_not_called()

    def test_knowledge_retrieval_with_external_documents(self):
        """
        Test knowledge_retrieval with external documents.

        External documents come from external knowledge bases and should
        be formatted differently from Dify documents.

        Verifies:
        - External documents are handled correctly
        - Provider is set to "external"
        - Metadata includes external-specific fields
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id],
            query="What is Python?",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock dependencies
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id, provider="external")
            with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
                with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
                    # Create external document
                    external_doc = create_mock_document(
                        "External knowledge",
                        "doc1",
                        score=0.9,
                        provider="external",
                        additional_metadata={
                            "dataset_id": dataset_id,
                            "dataset_name": "external_kb",
                            "document_id": "ext_doc1",
                            "title": "External Document",
                        },
                    )
                    with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[external_doc]):
                        # Act
                        result = dataset_retrieval.knowledge_retrieval(request)

                        # Assert
                        assert isinstance(result, list)
                        if result:
                            assert result[0].metadata.data_source_type == "external"

    def test_knowledge_retrieval_empty_results(self):
        """
        Test knowledge_retrieval when no documents are found.

        Verifies:
        - Empty list is returned
        - No errors are raised
        - All dependencies are still called
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id],
            query="What is Python?",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock dependencies
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            mock_dataset = create_mock_dataset(dataset_id=dataset_id, tenant_id=tenant_id)
            with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[mock_dataset]):
                with patch.object(dataset_retrieval, "get_metadata_filter_condition", return_value=(None, None)):
                    # Mock multiple_retrieve to return empty list
                    with patch.object(dataset_retrieval, "multiple_retrieve", return_value=[]):
                        # Act
                        result = dataset_retrieval.knowledge_retrieval(request)

                        # Assert
                        assert result == []

    def test_knowledge_retrieval_rate_limit_exceeded(self):
        """
        Test knowledge_retrieval when rate limit is exceeded.

        Verifies:
        - RateLimitExceededError is raised
        - No further processing occurs
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id],
            query="What is Python?",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock _check_knowledge_rate_limit to raise exception
        with patch.object(
            dataset_retrieval,
            "_check_knowledge_rate_limit",
            side_effect=exc.RateLimitExceededError("Rate limit exceeded"),
        ):
            # Act & Assert
            with pytest.raises(exc.RateLimitExceededError):
                dataset_retrieval.knowledge_retrieval(request)

    def test_knowledge_retrieval_no_available_datasets(self):
        """
        Test knowledge_retrieval when no datasets are available.

        Verifies:
        - Empty list is returned
        - No retrieval is attempted
        """
        # Arrange
        tenant_id = str(uuid4())
        user_id = str(uuid4())
        app_id = str(uuid4())
        dataset_id = str(uuid4())

        request = KnowledgeRetrievalRequest(
            tenant_id=tenant_id,
            user_id=user_id,
            app_id=app_id,
            user_from="web",
            dataset_ids=[dataset_id],
            query="What is Python?",
            retrieval_mode="multiple",
            top_k=5,
        )

        dataset_retrieval = DatasetRetrieval()

        # Mock dependencies
        with patch.object(dataset_retrieval, "_check_knowledge_rate_limit"):
            # Mock _get_available_datasets to return empty list
            with patch.object(dataset_retrieval, "_get_available_datasets", return_value=[]):
                # Act
                result = dataset_retrieval.knowledge_retrieval(request)

                # Assert
                assert result == []

    def test_knowledge_retrieval_handles_multiple_documents_with_different_scores(self):
        """
        Test that knowledge_retrieval processes multiple documents with different scores.

        Note: Full sorting and position testing requires complex SQLAlchemy mocking,
        which is better suited for integration tests. This test verifies that documents
        with different scores can be created and carry the expected score metadata.

        Verifies:
        - Documents can be created with different scores
        - Score metadata is properly set
        """
        # Create documents with different scores
        doc1 = create_mock_document("Low score", "doc1", score=0.6)
        doc2 = create_mock_document("High score", "doc2", score=0.95)
        doc3 = create_mock_document("Medium score", "doc3", score=0.8)

        # Assert - each document has the correct score
        assert doc1.metadata["score"] == 0.6
        assert doc2.metadata["score"] == 0.95
        assert doc3.metadata["score"] == 0.8

        # Assert - documents are correctly sorted (not the retrieval result, just the list)
        unsorted = [doc1, doc2, doc3]
        sorted_docs = sorted(unsorted, key=lambda d: d.metadata["score"], reverse=True)
        assert [d.metadata["score"] for d in sorted_docs] == [0.95, 0.8, 0.6]
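Read together, the suite pins down the order of operations that the class docstring lists. A comment-level outline of knowledge_retrieval as exercised here (method names are the ones patched above; signatures are not visible in this diff, so none are guessed):

# knowledge_retrieval(request) -- outline reconstructed from the tests, not the real body:
#   1. _check_knowledge_rate_limit(request.tenant_id)    may raise exc.RateLimitExceededError
#   2. _get_available_datasets(tenant_id, dataset_ids)   an empty result short-circuits to []
#   3. get_metadata_filter_condition(...)                skipped when metadata_filtering_mode == "disabled"
#   4. multiple_retrieve(...) or the single-mode path    produces the raw Document hits
#   5. RetrievalService.format_retrieval_documents(...)  shapes records into Source objects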
@@ -0,0 +1,595 @@
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import StringSegment
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.knowledge_retrieval.entities import (
|
||||
KnowledgeRetrievalNodeData,
|
||||
MultipleRetrievalConfig,
|
||||
RerankingModelConfig,
|
||||
SingleRetrievalConfig,
|
||||
)
|
||||
from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.repositories.rag_retrieval_protocol import RAGRetrievalProtocol, Source
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_init_params():
|
||||
"""Create mock GraphInitParams."""
|
||||
return GraphInitParams(
|
||||
tenant_id=str(uuid.uuid4()),
|
||||
app_id=str(uuid.uuid4()),
|
||||
workflow_id=str(uuid.uuid4()),
|
||||
graph_config={},
|
||||
user_id=str(uuid.uuid4()),
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_graph_runtime_state():
|
||||
"""Create mock GraphRuntimeState."""
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]),
|
||||
user_inputs={},
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_rag_retrieval():
|
||||
"""Create mock RAGRetrievalProtocol."""
|
||||
mock_retrieval = Mock(spec=RAGRetrievalProtocol)
|
||||
mock_retrieval.knowledge_retrieval.return_value = []
|
||||
mock_retrieval.llm_usage = LLMUsage.empty_usage()
|
||||
return mock_retrieval
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_node_data():
|
||||
"""Create sample KnowledgeRetrievalNodeData."""
|
||||
return KnowledgeRetrievalNodeData(
|
||||
title="Knowledge Retrieval",
|
||||
type="knowledge-retrieval",
|
||||
dataset_ids=[str(uuid.uuid4())],
|
||||
retrieval_mode="multiple",
|
||||
multiple_retrieval_config=MultipleRetrievalConfig(
|
||||
top_k=5,
|
||||
score_threshold=0.7,
|
||||
reranking_mode="reranking_model",
|
||||
reranking_enable=True,
|
||||
reranking_model=RerankingModelConfig(
|
||||
provider="cohere",
|
||||
model="rerank-v2",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class TestKnowledgeRetrievalNode:
    """Test suite for KnowledgeRetrievalNode."""

    def test_node_initialization(self, mock_graph_init_params, mock_graph_runtime_state, mock_rag_retrieval):
        """Test KnowledgeRetrievalNode initialization."""
        # Arrange
        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": {
                "title": "Knowledge Retrieval",
                "type": "knowledge-retrieval",
                "dataset_ids": [str(uuid.uuid4())],
                "retrieval_mode": "multiple",
            },
        }

        # Act
        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Assert
        assert node.id == node_id
        assert node._rag_retrieval == mock_rag_retrieval
        assert node._llm_file_saver is not None

    def test_run_with_no_query_or_attachment(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _run returns success when no query or attachment is provided."""
        # Arrange
        sample_node_data.query_variable_selector = None
        sample_node_data.query_attachment_selector = None

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert result.outputs == {}
        assert mock_rag_retrieval.knowledge_retrieval.call_count == 0

    def test_run_with_query_variable_single_mode(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
    ):
        """Test _run with query variable in single mode."""
        # Arrange
        from core.workflow.nodes.llm.entities import ModelConfig

        query = "What is Python?"
        query_selector = ["start", "query"]

        # Add query to variable pool
        mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))

        node_data = KnowledgeRetrievalNodeData(
            title="Knowledge Retrieval",
            type="knowledge-retrieval",
            dataset_ids=[str(uuid.uuid4())],
            retrieval_mode="single",
            query_variable_selector=query_selector,
            single_retrieval_config=SingleRetrievalConfig(
                model=ModelConfig(
                    provider="openai",
                    name="gpt-4",
                    mode="chat",
                    completion_params={"temperature": 0.7},
                )
            ),
        )

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": node_data.model_dump(),
        }

        # Mock retrieval response
        mock_source = Mock(spec=Source)
        mock_source.model_dump.return_value = {"content": "Python is a programming language"}
        mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert "result" in result.outputs
        assert mock_rag_retrieval.knowledge_retrieval.called

    def test_run_with_query_variable_multiple_mode(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _run with query variable in multiple mode."""
        # Arrange
        query = "What is Python?"
        query_selector = ["start", "query"]

        # Add query to variable pool
        mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
        sample_node_data.query_variable_selector = query_selector

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        # Mock retrieval response
        mock_source = Mock(spec=Source)
        mock_source.model_dump.return_value = {"content": "Python is a programming language"}
        mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED
        assert "result" in result.outputs
        assert mock_rag_retrieval.knowledge_retrieval.called

    def test_run_with_invalid_query_variable_type(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _run fails when query variable is not StringSegment."""
        # Arrange
        query_selector = ["start", "query"]

        # Add non-string variable to variable pool
        mock_graph_runtime_state.variable_pool.add(query_selector, [1, 2, 3])
        sample_node_data.query_variable_selector = query_selector

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Query variable is not string type" in result.error

    def test_run_with_invalid_attachment_variable_type(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _run fails when attachment variable is not FileSegment or ArrayFileSegment."""
        # Arrange
        attachment_selector = ["start", "attachments"]

        # Add non-file variable to variable pool
        mock_graph_runtime_state.variable_pool.add(attachment_selector, "not a file")
        sample_node_data.query_attachment_selector = attachment_selector

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Attachments variable is not array file or file type" in result.error

    def test_run_with_rate_limit_exceeded(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _run handles RateLimitExceededError properly."""
        # Arrange
        query = "What is Python?"
        query_selector = ["start", "query"]

        mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
        sample_node_data.query_variable_selector = query_selector

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        # Mock retrieval to raise RateLimitExceededError
        mock_rag_retrieval.knowledge_retrieval.side_effect = RateLimitExceededError(
            "knowledge base request rate limit exceeded"
        )
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "rate limit" in result.error.lower()

    def test_run_with_generic_exception(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _run handles generic exceptions properly."""
        # Arrange
        query = "What is Python?"
        query_selector = ["start", "query"]

        mock_graph_runtime_state.variable_pool.add(query_selector, StringSegment(value=query))
        sample_node_data.query_variable_selector = query_selector

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        # Mock retrieval to raise generic exception
        mock_rag_retrieval.knowledge_retrieval.side_effect = Exception("Unexpected error")
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        result = node._run()

        # Assert
        assert result.status == WorkflowNodeExecutionStatus.FAILED
        assert "Unexpected error" in result.error

    def test_extract_variable_selector_to_variable_mapping(self):
        """Test _extract_variable_selector_to_variable_mapping class method."""
        # Arrange
        node_id = "knowledge_node_1"
        node_data = {
            "type": "knowledge-retrieval",
            "title": "Knowledge Retrieval",
            "dataset_ids": [str(uuid.uuid4())],
            "retrieval_mode": "multiple",
            "query_variable_selector": ["start", "query"],
            "query_attachment_selector": ["start", "attachments"],
        }
        graph_config = {}

        # Act
        mapping = KnowledgeRetrievalNode._extract_variable_selector_to_variable_mapping(
            graph_config=graph_config,
            node_id=node_id,
            node_data=node_data,
        )

        # Assert
        assert mapping[f"{node_id}.query"] == ["start", "query"]
        assert mapping[f"{node_id}.queryAttachment"] == ["start", "attachments"]


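# For reference, the mapping asserted in the test above flattens to keys of the
# form "<node_id>.<input>"; the camelCase "queryAttachment" key is taken from
# the assertion itself:
#
#     {
#         "knowledge_node_1.query": ["start", "query"],
#         "knowledge_node_1.queryAttachment": ["start", "attachments"],
#     }

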
class TestFetchDatasetRetriever:
    """Test suite for _fetch_dataset_retriever method."""

    def test_fetch_dataset_retriever_single_mode(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
    ):
        """Test _fetch_dataset_retriever in single mode."""
        # Arrange
        from core.workflow.nodes.llm.entities import ModelConfig

        query = "What is Python?"
        variables = {"query": query}

        node_data = KnowledgeRetrievalNodeData(
            title="Knowledge Retrieval",
            type="knowledge-retrieval",
            dataset_ids=[str(uuid.uuid4())],
            retrieval_mode="single",
            single_retrieval_config=SingleRetrievalConfig(
                model=ModelConfig(
                    provider="openai",
                    name="gpt-4",
                    mode="chat",
                    completion_params={"temperature": 0.7},
                )
            ),
        )

        # Mock retrieval response
        mock_source = Mock(spec=Source)
        mock_rag_retrieval.knowledge_retrieval.return_value = [mock_source]
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node_id = str(uuid.uuid4())
        config = {"id": node_id, "data": node_data.model_dump()}

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)

        # Assert
        assert len(results) == 1
        assert isinstance(usage, LLMUsage)
        assert mock_rag_retrieval.knowledge_retrieval.called

    def test_fetch_dataset_retriever_multiple_mode_with_reranking(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
        sample_node_data,
    ):
        """Test _fetch_dataset_retriever in multiple mode with reranking."""
        # Arrange
        query = "What is Python?"
        variables = {"query": query}

        # Mock retrieval response
        mock_rag_retrieval.knowledge_retrieval.return_value = []
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": sample_node_data.model_dump(),
        }

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        results, usage = node._fetch_dataset_retriever(node_data=sample_node_data, variables=variables)

        # Assert
        assert isinstance(results, list)
        assert isinstance(usage, LLMUsage)
        assert mock_rag_retrieval.knowledge_retrieval.called

        # Verify reranking parameters via request object
        call_args = mock_rag_retrieval.knowledge_retrieval.call_args
        request = call_args[1]["request"]
        assert request.reranking_enable is True
        assert request.reranking_mode == "reranking_model"

    def test_fetch_dataset_retriever_multiple_mode_without_reranking(
        self,
        mock_graph_init_params,
        mock_graph_runtime_state,
        mock_rag_retrieval,
    ):
        """Test _fetch_dataset_retriever in multiple mode without reranking."""
        # Arrange
        query = "What is Python?"
        variables = {"query": query}

        node_data = KnowledgeRetrievalNodeData(
            title="Knowledge Retrieval",
            type="knowledge-retrieval",
            dataset_ids=[str(uuid.uuid4())],
            retrieval_mode="multiple",
            multiple_retrieval_config=MultipleRetrievalConfig(
                top_k=5,
                score_threshold=0.7,
                reranking_enable=False,
                reranking_mode="reranking_model",
            ),
        )

        # Mock retrieval response
        mock_rag_retrieval.knowledge_retrieval.return_value = []
        mock_rag_retrieval.llm_usage = LLMUsage.empty_usage()

        node_id = str(uuid.uuid4())
        config = {
            "id": node_id,
            "data": node_data.model_dump(),
        }

        node = KnowledgeRetrievalNode(
            id=node_id,
            config=config,
            graph_init_params=mock_graph_init_params,
            graph_runtime_state=mock_graph_runtime_state,
            rag_retrieval=mock_rag_retrieval,
        )

        # Act
        results, usage = node._fetch_dataset_retriever(node_data=node_data, variables=variables)

        # Assert
        assert isinstance(results, list)
        assert mock_rag_retrieval.knowledge_retrieval.called

        # Verify reranking is disabled
        call_args = mock_rag_retrieval.knowledge_retrieval.call_args
        request = call_args[1]["request"]
        assert request.reranking_enable is False

    def test_version_method(self):
        """Test version class method."""
        # Act
        version = KnowledgeRetrievalNode.version()

        # Assert
        assert version == "1"
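

# To run just these suites locally, something like the following should work
# (the exact file path depends on where this module lives in the repo, so the
# bare -k expression is the safer form):
#
#     pytest -q -k "KnowledgeRetrievalNode or FetchDatasetRetriever"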