mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 15:10:13 -05:00
fix(api): register knowledge pipeline service API routes (#32097)
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com> Co-authored-by: FFXN <31929997+FFXN@users.noreply.github.com>
This commit is contained in:
@@ -34,6 +34,7 @@ from .dataset import (
|
|||||||
metadata,
|
metadata,
|
||||||
segment,
|
segment,
|
||||||
)
|
)
|
||||||
|
from .dataset.rag_pipeline import rag_pipeline_workflow
|
||||||
from .end_user import end_user
|
from .end_user import end_user
|
||||||
from .workspace import models
|
from .workspace import models
|
||||||
|
|
||||||
@@ -53,6 +54,7 @@ __all__ = [
|
|||||||
"message",
|
"message",
|
||||||
"metadata",
|
"metadata",
|
||||||
"models",
|
"models",
|
||||||
|
"rag_pipeline_workflow",
|
||||||
"segment",
|
"segment",
|
||||||
"site",
|
"site",
|
||||||
"workflow",
|
"workflow",
|
||||||
|
|||||||
@@ -1,5 +1,3 @@
|
|||||||
import string
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -41,7 +39,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
|
|||||||
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
|
||||||
class DatasourcePluginsApi(DatasetApiResource):
|
class DatasourcePluginsApi(DatasetApiResource):
|
||||||
"""Resource for datasource plugins."""
|
"""Resource for datasource plugins."""
|
||||||
|
|
||||||
@@ -76,7 +74,7 @@ class DatasourcePluginsApi(DatasetApiResource):
|
|||||||
return datasource_plugins, 200
|
return datasource_plugins, 200
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
|
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
|
||||||
class DatasourceNodeRunApi(DatasetApiResource):
|
class DatasourceNodeRunApi(DatasetApiResource):
|
||||||
"""Resource for datasource node run."""
|
"""Resource for datasource node run."""
|
||||||
|
|
||||||
@@ -131,7 +129,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
|
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
|
||||||
class PipelineRunApi(DatasetApiResource):
|
class PipelineRunApi(DatasetApiResource):
|
||||||
"""Resource for datasource node run."""
|
"""Resource for datasource node run."""
|
||||||
|
|
||||||
|
|||||||
@@ -217,6 +217,8 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
|||||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||||
@wraps(view)
|
@wraps(view)
|
||||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||||
|
api_token = validate_and_get_api_token("dataset")
|
||||||
|
|
||||||
# get url path dataset_id from positional args or kwargs
|
# get url path dataset_id from positional args or kwargs
|
||||||
# Flask passes URL path parameters as positional arguments
|
# Flask passes URL path parameters as positional arguments
|
||||||
dataset_id = None
|
dataset_id = None
|
||||||
@@ -253,12 +255,18 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
|||||||
# Validate dataset if dataset_id is provided
|
# Validate dataset if dataset_id is provided
|
||||||
if dataset_id:
|
if dataset_id:
|
||||||
dataset_id = str(dataset_id)
|
dataset_id = str(dataset_id)
|
||||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
dataset = (
|
||||||
|
db.session.query(Dataset)
|
||||||
|
.where(
|
||||||
|
Dataset.id == dataset_id,
|
||||||
|
Dataset.tenant_id == api_token.tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise NotFound("Dataset not found.")
|
raise NotFound("Dataset not found.")
|
||||||
if not dataset.enable_api:
|
if not dataset.enable_api:
|
||||||
raise Forbidden("Dataset api access is not enabled.")
|
raise Forbidden("Dataset api access is not enabled.")
|
||||||
api_token = validate_and_get_api_token("dataset")
|
|
||||||
tenant_account_join = (
|
tenant_account_join = (
|
||||||
db.session.query(Tenant, TenantAccountJoin)
|
db.session.query(Tenant, TenantAccountJoin)
|
||||||
.where(Tenant.id == api_token.tenant_id)
|
.where(Tenant.id == api_token.tenant_id)
|
||||||
|
|||||||
@@ -1329,10 +1329,24 @@ class RagPipelineService:
|
|||||||
"""
|
"""
|
||||||
Get datasource plugins
|
Get datasource plugins
|
||||||
"""
|
"""
|
||||||
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
dataset: Dataset | None = (
|
||||||
|
db.session.query(Dataset)
|
||||||
|
.where(
|
||||||
|
Dataset.id == dataset_id,
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset not found")
|
raise ValueError("Dataset not found")
|
||||||
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
|
pipeline: Pipeline | None = (
|
||||||
|
db.session.query(Pipeline)
|
||||||
|
.where(
|
||||||
|
Pipeline.id == dataset.pipeline_id,
|
||||||
|
Pipeline.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not pipeline:
|
if not pipeline:
|
||||||
raise ValueError("Pipeline not found")
|
raise ValueError("Pipeline not found")
|
||||||
|
|
||||||
@@ -1413,10 +1427,24 @@ class RagPipelineService:
|
|||||||
"""
|
"""
|
||||||
Get pipeline
|
Get pipeline
|
||||||
"""
|
"""
|
||||||
dataset: Dataset | None = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
dataset: Dataset | None = (
|
||||||
|
db.session.query(Dataset)
|
||||||
|
.where(
|
||||||
|
Dataset.id == dataset_id,
|
||||||
|
Dataset.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise ValueError("Dataset not found")
|
raise ValueError("Dataset not found")
|
||||||
pipeline: Pipeline | None = db.session.query(Pipeline).where(Pipeline.id == dataset.pipeline_id).first()
|
pipeline: Pipeline | None = (
|
||||||
|
db.session.query(Pipeline)
|
||||||
|
.where(
|
||||||
|
Pipeline.id == dataset.pipeline_id,
|
||||||
|
Pipeline.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.first()
|
||||||
|
)
|
||||||
if not pipeline:
|
if not pipeline:
|
||||||
raise ValueError("Pipeline not found")
|
raise ValueError("Pipeline not found")
|
||||||
return pipeline
|
return pipeline
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
"""
|
||||||
|
Unit tests for Service API knowledge pipeline route registration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import ast
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
def test_rag_pipeline_routes_registered():
|
||||||
|
api_dir = Path(__file__).resolve().parents[5]
|
||||||
|
|
||||||
|
service_api_init = api_dir / "controllers" / "service_api" / "__init__.py"
|
||||||
|
rag_pipeline_workflow = (
|
||||||
|
api_dir / "controllers" / "service_api" / "dataset" / "rag_pipeline" / "rag_pipeline_workflow.py"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert service_api_init.exists()
|
||||||
|
assert rag_pipeline_workflow.exists()
|
||||||
|
|
||||||
|
init_tree = ast.parse(service_api_init.read_text(encoding="utf-8"))
|
||||||
|
import_found = False
|
||||||
|
for node in ast.walk(init_tree):
|
||||||
|
if not isinstance(node, ast.ImportFrom):
|
||||||
|
continue
|
||||||
|
if node.module != "dataset.rag_pipeline" or node.level != 1:
|
||||||
|
continue
|
||||||
|
if any(alias.name == "rag_pipeline_workflow" for alias in node.names):
|
||||||
|
import_found = True
|
||||||
|
break
|
||||||
|
assert import_found, "from .dataset.rag_pipeline import rag_pipeline_workflow not found in service_api/__init__.py"
|
||||||
|
|
||||||
|
workflow_tree = ast.parse(rag_pipeline_workflow.read_text(encoding="utf-8"))
|
||||||
|
route_paths: set[str] = set()
|
||||||
|
|
||||||
|
for node in ast.walk(workflow_tree):
|
||||||
|
if not isinstance(node, ast.ClassDef):
|
||||||
|
continue
|
||||||
|
for decorator in node.decorator_list:
|
||||||
|
if not isinstance(decorator, ast.Call):
|
||||||
|
continue
|
||||||
|
if not isinstance(decorator.func, ast.Attribute):
|
||||||
|
continue
|
||||||
|
if decorator.func.attr != "route":
|
||||||
|
continue
|
||||||
|
if not decorator.args:
|
||||||
|
continue
|
||||||
|
first_arg = decorator.args[0]
|
||||||
|
if isinstance(first_arg, ast.Constant) and isinstance(first_arg.value, str):
|
||||||
|
route_paths.add(first_arg.value)
|
||||||
|
|
||||||
|
assert "/datasets/<uuid:dataset_id>/pipeline/datasource-plugins" in route_paths
|
||||||
|
assert "/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run" in route_paths
|
||||||
|
assert "/datasets/<uuid:dataset_id>/pipeline/run" in route_paths
|
||||||
|
assert "/datasets/pipeline/file-upload" in route_paths
|
||||||
Reference in New Issue
Block a user