resolve: conflict

This commit is contained in:
crazywoola
2026-02-09 15:17:25 +08:00
parent f4d6383019
commit 481c707fab
78 changed files with 3470 additions and 971 deletions

View File

@@ -104,6 +104,8 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability

View File

@@ -1,9 +1,13 @@
import logging
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,

View File

@@ -32,6 +32,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel
@@ -285,6 +286,35 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_workflow_flowchart(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
available_nodes: Sequence[dict[str, object]] | None = None,
existing_nodes: Sequence[dict[str, object]] | None = None,
available_tools: Sequence[dict[str, object]] | None = None,
selected_node_ids: Sequence[str] | None = None,
previous_workflow: dict[str, object] | None = None,
regenerate_mode: bool = False,
preferred_language: str | None = None,
available_models: Sequence[dict[str, object]] | None = None,
):
return WorkflowGenerator.generate_workflow_flowchart(
tenant_id=tenant_id,
instruction=instruction,
model_config=model_config,
available_nodes=available_nodes,
existing_nodes=existing_nodes,
available_tools=available_tools,
selected_node_ids=selected_node_ids,
previous_workflow=previous_workflow,
regenerate_mode=regenerate_mode,
preferred_language=preferred_language,
available_models=available_models,
)
@classmethod
def generate_code(
cls,

View File

@@ -143,6 +143,50 @@ Based on task description, please create a well-structured prompt template that
Please generate the full prompt template with at least 300 words and output only the prompt template.
""" # noqa: E501
WORKFLOW_FLOWCHART_PROMPT_TEMPLATE = """
You are an expert workflow designer. Generate a Mermaid flowchart based on the user's request.
Constraints:
- Detect the language of the user's request. Generate all node titles in the same language as the user's input.
- If the input language cannot be determined, use {{PREFERRED_LANGUAGE}} as the fallback language.
- Use only node types listed in <available_nodes>.
- Use only tools listed in <available_tools>. When using a tool node, set type=tool and tool=<tool_key>.
- Tools may include MCP providers (provider_type=mcp). Tool selection still uses tool_key.
- Prefer reusing node titles from <existing_nodes> when possible.
- Output must be valid Mermaid flowchart syntax, no markdown, no extra text.
- First line must be: flowchart LR
- Every node must be declared on its own line using:
<id>["type=<type>|title=<title>|tool=<tool_key>"]
- type is required and must match a type in <available_nodes>.
- title is required for non-tool nodes.
- tool is required only when type=tool, otherwise omit tool.
- Declare all node lines before any edges.
- Edges must use:
<id> --> <id>
<id> -->|true| <id>
<id> -->|false| <id>
- Keep node ids unique and simple (N1, N2, ...).
- For complex orchestration:
- Break the request into stages (ingest, transform, decision, action, output).
- Use IfElse for branching and label edges true/false only.
- Fan-in branches by connecting multiple nodes into a shared downstream node.
- Avoid cycles unless explicitly requested.
- Keep each branch complete with a clear downstream target.
<user_request>
{{TASK_DESCRIPTION}}
</user_request>
<available_nodes>
{{AVAILABLE_NODES}}
</available_nodes>
<existing_nodes>
{{EXISTING_NODES}}
</existing_nodes>
<available_tools>
{{AVAILABLE_TOOLS}}
</available_tools>
"""
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
Here is a task description for which I would like you to create a high-quality prompt template for:
<task_description>

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Union, cast
from packaging.version import Version
from pydantic import ValidationError
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -49,6 +50,12 @@ from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
@@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)
@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)
provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN
with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}
for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")

View File

@@ -212,6 +212,14 @@ class Node(Generic[NodeDataT]):
return None
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
"""
Get the default configuration schema for the node.
Used for LLM generation.
"""
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {}

View File

@@ -1,3 +1,5 @@
from typing import Any
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -9,6 +11,24 @@ class EndNode(Node[EndNodeData]):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Workflow exit point - defines output variables",
"required": ["outputs"],
"parameters": {
"outputs": {
"type": "array",
"description": "Output variables to return",
"item_schema": {
"variable": "string - output variable name",
"type": "enum: string, number, object, array",
"value_selector": "array - path to source value, e.g. ['node_id', 'field']",
},
},
},
}
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -14,6 +14,27 @@ class StartNode(Node[StartNodeData]):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Workflow entry point - defines input variables",
"required": [],
"parameters": {
"variables": {
"type": "array",
"description": "Input variables for the workflow",
"item_schema": {
"variable": "string - variable name",
"label": "string - display label",
"type": "enum: text-input, paragraph, number, select, file, file-list",
"required": "boolean",
"max_length": "number (optional)",
},
},
},
"outputs": ["All defined variables are available as {{#start.variable_name#}}"],
}
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -50,6 +50,19 @@ class ToolNode(Node[ToolNodeData]):
def version(cls) -> str:
return "1"
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Execute an external tool",
"required": ["provider_id", "tool_id", "tool_parameters"],
"parameters": {
"provider_id": {"type": "string"},
"provider_type": {"type": "string"},
"tool_id": {"type": "string"},
"tool_parameters": {"type": "object"},
},
}
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node

View File

@@ -0,0 +1,400 @@
"""
Unit tests for GraphBuilder.
Tests the automatic graph construction from node lists with dependency declarations.
"""
import pytest
from core.workflow.generator.utils.graph_builder import (
CyclicDependencyError,
GraphBuilder,
)
class TestGraphBuilderBasic:
"""Basic functionality tests."""
def test_empty_nodes_creates_minimal_workflow(self):
"""Empty node list creates start -> end workflow."""
result_nodes, result_edges = GraphBuilder.build_graph([])
assert len(result_nodes) == 2
assert result_nodes[0]["type"] == "start"
assert result_nodes[1]["type"] == "end"
assert len(result_edges) == 1
assert result_edges[0]["source"] == "start"
assert result_edges[0]["target"] == "end"
def test_simple_linear_workflow(self):
"""Simple linear workflow: start -> fetch -> process -> end."""
nodes = [
{"id": "fetch", "type": "http-request", "depends_on": []},
{"id": "process", "type": "llm", "depends_on": ["fetch"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have: start + 2 user nodes + end = 4
assert len(result_nodes) == 4
assert result_nodes[0]["type"] == "start"
assert result_nodes[-1]["type"] == "end"
# Should have: start->fetch, fetch->process, process->end = 3
assert len(result_edges) == 3
# Verify edge connections
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("start", "fetch") in edge_pairs
assert ("fetch", "process") in edge_pairs
assert ("process", "end") in edge_pairs
class TestParallelWorkflow:
"""Tests for parallel node handling."""
def test_parallel_workflow(self):
"""Parallel workflow: multiple nodes from start, merging to one."""
nodes = [
{"id": "api1", "type": "http-request", "depends_on": []},
{"id": "api2", "type": "http-request", "depends_on": []},
{"id": "merge", "type": "llm", "depends_on": ["api1", "api2"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# start should connect to both api1 and api2
start_edges = [e for e in result_edges if e["source"] == "start"]
assert len(start_edges) == 2
start_targets = {e["target"] for e in start_edges}
assert start_targets == {"api1", "api2"}
# Both api1 and api2 should connect to merge
merge_incoming = [e for e in result_edges if e["target"] == "merge"]
assert len(merge_incoming) == 2
def test_multiple_terminal_nodes(self):
"""Multiple terminal nodes all connect to end."""
nodes = [
{"id": "branch1", "type": "llm", "depends_on": []},
{"id": "branch2", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Both branches should connect to end
end_incoming = [e for e in result_edges if e["target"] == "end"]
assert len(end_incoming) == 2
class TestIfElseWorkflow:
"""Tests for if-else branching."""
def test_if_else_workflow(self):
"""Conditional branching workflow."""
nodes = [
{
"id": "check",
"type": "if-else",
"config": {"true_branch": "success", "false_branch": "fallback"},
"depends_on": [],
},
{"id": "success", "type": "llm", "depends_on": []},
{"id": "fallback", "type": "code", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have true and false branch edges
branch_edges = [e for e in result_edges if e["source"] == "check"]
assert len(branch_edges) == 2
assert any(e.get("sourceHandle") == "true" for e in branch_edges)
assert any(e.get("sourceHandle") == "false" for e in branch_edges)
# Verify targets
true_edge = next(e for e in branch_edges if e.get("sourceHandle") == "true")
false_edge = next(e for e in branch_edges if e.get("sourceHandle") == "false")
assert true_edge["target"] == "success"
assert false_edge["target"] == "fallback"
def test_if_else_missing_branch_no_error(self):
"""if-else with only true branch doesn't error (warning only)."""
nodes = [
{
"id": "check",
"type": "if-else",
"config": {"true_branch": "success"},
"depends_on": [],
},
{"id": "success", "type": "llm", "depends_on": []},
]
# Should not raise
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have one branch edge
branch_edges = [e for e in result_edges if e["source"] == "check"]
assert len(branch_edges) == 1
assert branch_edges[0].get("sourceHandle") == "true"
class TestQuestionClassifierWorkflow:
"""Tests for question-classifier branching."""
def test_question_classifier_workflow(self):
"""Question classifier with multiple classes."""
nodes = [
{
"id": "classifier",
"type": "question-classifier",
"config": {
"query": ["start", "user_input"],
"classes": [
{"id": "tech", "name": "技术问题", "target": "tech_handler"},
{"id": "sales", "name": "销售咨询", "target": "sales_handler"},
{"id": "other", "name": "其他问题", "target": "other_handler"},
],
},
"depends_on": [],
},
{"id": "tech_handler", "type": "llm", "depends_on": []},
{"id": "sales_handler", "type": "llm", "depends_on": []},
{"id": "other_handler", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have 3 branch edges from classifier
classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
assert len(classifier_edges) == 3
# Each should use class id as sourceHandle
assert any(e.get("sourceHandle") == "tech" and e["target"] == "tech_handler" for e in classifier_edges)
assert any(e.get("sourceHandle") == "sales" and e["target"] == "sales_handler" for e in classifier_edges)
assert any(e.get("sourceHandle") == "other" and e["target"] == "other_handler" for e in classifier_edges)
def test_question_classifier_missing_target(self):
"""Classes without target connect to end."""
nodes = [
{
"id": "classifier",
"type": "question-classifier",
"config": {
"classes": [
{"id": "known", "name": "已知问题", "target": "handler"},
{"id": "unknown", "name": "未知问题"}, # Missing target
],
},
"depends_on": [],
},
{"id": "handler", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Missing target should connect to end
classifier_edges = [e for e in result_edges if e["source"] == "classifier"]
assert any(e.get("sourceHandle") == "unknown" and e["target"] == "end" for e in classifier_edges)
class TestVariableDependencyInference:
"""Tests for automatic dependency inference from variables."""
def test_variable_dependency_inference(self):
"""Dependencies inferred from variable references."""
nodes = [
{"id": "fetch", "type": "http-request", "depends_on": []},
{
"id": "process",
"type": "llm",
"config": {"prompt_template": [{"text": "{{#fetch.body#}}"}]},
# No explicit depends_on, but references fetch
},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should automatically infer process depends on fetch
assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)
def test_system_variable_not_inferred(self):
"""System variables (sys, start) not inferred as dependencies."""
nodes = [
{
"id": "process",
"type": "llm",
"config": {"prompt_template": [{"text": "{{#sys.query#}} {{#start.input#}}"}]},
"depends_on": [],
},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should connect to start, not create dependency on sys or start
edge_sources = {e["source"] for e in result_edges}
assert "sys" not in edge_sources
assert "start" in edge_sources
class TestCycleDetection:
"""Tests for cyclic dependency detection."""
def test_cyclic_dependency_detected(self):
"""Cyclic dependencies raise error."""
nodes = [
{"id": "a", "type": "llm", "depends_on": ["c"]},
{"id": "b", "type": "llm", "depends_on": ["a"]},
{"id": "c", "type": "llm", "depends_on": ["b"]},
]
with pytest.raises(CyclicDependencyError):
GraphBuilder.build_graph(nodes)
def test_self_dependency_detected(self):
"""Self-dependency raises error."""
nodes = [
{"id": "a", "type": "llm", "depends_on": ["a"]},
]
with pytest.raises(CyclicDependencyError):
GraphBuilder.build_graph(nodes)
class TestErrorRecovery:
"""Tests for silent error recovery."""
def test_invalid_dependency_removed(self):
"""Invalid dependencies (non-existent nodes) are silently removed."""
nodes = [
{"id": "process", "type": "llm", "depends_on": ["nonexistent"]},
]
# Should not raise, invalid dependency silently removed
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Process should connect from start (since invalid dep was removed)
assert any(e["source"] == "start" and e["target"] == "process" for e in result_edges)
def test_depends_on_as_string(self):
"""depends_on as string is converted to list."""
nodes = [
{"id": "fetch", "type": "http-request", "depends_on": []},
{"id": "process", "type": "llm", "depends_on": "fetch"}, # String instead of list
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should work correctly
assert any(e["source"] == "fetch" and e["target"] == "process" for e in result_edges)
class TestContainerNodes:
"""Tests for container nodes (iteration, loop)."""
def test_iteration_node_as_regular_node(self):
"""Iteration nodes behave as regular single-in-single-out nodes."""
nodes = [
{"id": "prepare", "type": "code", "depends_on": []},
{
"id": "loop",
"type": "iteration",
"config": {"iterator_selector": ["prepare", "items"]},
"depends_on": ["prepare"],
},
{"id": "process_result", "type": "llm", "depends_on": ["loop"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should have standard edges: start->prepare, prepare->loop, loop->process_result, process_result->end
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("start", "prepare") in edge_pairs
assert ("prepare", "loop") in edge_pairs
assert ("loop", "process_result") in edge_pairs
assert ("process_result", "end") in edge_pairs
def test_loop_node_as_regular_node(self):
"""Loop nodes behave as regular single-in-single-out nodes."""
nodes = [
{"id": "init", "type": "code", "depends_on": []},
{
"id": "repeat",
"type": "loop",
"config": {"loop_count": 5},
"depends_on": ["init"],
},
{"id": "finish", "type": "llm", "depends_on": ["repeat"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Standard edge flow
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("init", "repeat") in edge_pairs
assert ("repeat", "finish") in edge_pairs
def test_iteration_with_variable_inference(self):
"""Iteration node dependencies can be inferred from iterator_selector."""
nodes = [
{"id": "data_source", "type": "http-request", "depends_on": []},
{
"id": "process_each",
"type": "iteration",
"config": {
"iterator_selector": ["data_source", "items"],
},
# No explicit depends_on, but references data_source
},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Should infer dependency from iterator_selector reference
# Note: iterator_selector format is different from {{#...#}}, so this tests
# that explicit depends_on is properly handled when not provided
# In this case, process_each has no depends_on, so it connects to start
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
# Without explicit depends_on, connects to start
assert ("start", "process_each") in edge_pairs or ("data_source", "process_each") in edge_pairs
def test_loop_node_self_reference_not_cycle(self):
"""Loop nodes referencing their own outputs should not create cycle."""
nodes = [
{"id": "init", "type": "code", "depends_on": []},
{
"id": "my_loop",
"type": "loop",
"config": {
"loop_count": 5,
# Loop node referencing its own output (common pattern)
"prompt": "Previous: {{#my_loop.output#}}, continue...",
},
"depends_on": ["init"],
},
{"id": "finish", "type": "llm", "depends_on": ["my_loop"]},
]
# Should NOT raise CyclicDependencyError
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
# Verify the graph is built correctly
assert len(result_nodes) == 5 # start + 3 + end
edge_pairs = [(e["source"], e["target"]) for e in result_edges]
assert ("init", "my_loop") in edge_pairs
assert ("my_loop", "finish") in edge_pairs
class TestEdgeStructure:
"""Tests for edge structure correctness."""
def test_edge_has_required_fields(self):
"""Edges have all required fields."""
nodes = [
{"id": "node1", "type": "llm", "depends_on": []},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
for edge in result_edges:
assert "id" in edge
assert "source" in edge
assert "target" in edge
assert "sourceHandle" in edge
assert "targetHandle" in edge
def test_edge_id_unique(self):
"""Each edge has a unique ID."""
nodes = [
{"id": "a", "type": "llm", "depends_on": []},
{"id": "b", "type": "llm", "depends_on": []},
{"id": "c", "type": "llm", "depends_on": ["a", "b"]},
]
result_nodes, result_edges = GraphBuilder.build_graph(nodes)
edge_ids = [e["id"] for e in result_edges]
assert len(edge_ids) == len(set(edge_ids)) # All unique

View File

@@ -0,0 +1,287 @@
"""
Unit tests for the Mermaid Generator.
Tests cover:
- Basic workflow rendering
- Reserved word handling ('end''end_node')
- Question classifier multi-branch edges
- If-else branch labels
- Edge validation and skipping
- Tool node formatting
"""
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
class TestBasicWorkflow:
"""Tests for basic workflow Mermaid generation."""
def test_simple_start_end_workflow(self):
"""Test simple Start → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [{"source": "start", "target": "end"}],
}
result = generate_mermaid(workflow_data)
assert "flowchart TD" in result
assert 'start["type=start|title=Start"]' in result
assert 'end_node["type=end|title=End"]' in result
assert "start --> end_node" in result
def test_start_llm_end_workflow(self):
"""Test Start → LLM → End workflow."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "llm", "type": "llm", "title": "Generate"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
}
result = generate_mermaid(workflow_data)
assert 'llm["type=llm|title=Generate"]' in result
assert "start --> llm" in result
assert "llm --> end_node" in result
def test_empty_workflow(self):
"""Test empty workflow returns minimal output."""
workflow_data = {"nodes": [], "edges": []}
result = generate_mermaid(workflow_data)
assert result == "flowchart TD"
def test_missing_keys_handled(self):
"""Test workflow with missing keys doesn't crash."""
workflow_data = {}
result = generate_mermaid(workflow_data)
assert "flowchart TD" in result
class TestReservedWords:
"""Tests for reserved word handling in node IDs."""
def test_end_node_id_is_replaced(self):
"""Test 'end' node ID is replaced with 'end_node'."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "title": "End"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Should use end_node instead of end
assert "end_node[" in result
assert '"type=end|title=End"' in result
def test_subgraph_node_id_is_replaced(self):
"""Test 'subgraph' node ID is replaced with 'subgraph_node'."""
workflow_data = {
"nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "subgraph_node[" in result
def test_edge_uses_safe_ids(self):
"""Test edges correctly reference safe IDs after replacement."""
workflow_data = {
"nodes": [
{"id": "start", "type": "start", "title": "Start"},
{"id": "end", "type": "end", "title": "End"},
],
"edges": [{"source": "start", "target": "end"}],
}
result = generate_mermaid(workflow_data)
# Edge should use end_node, not end
assert "start --> end_node" in result
assert "start --> end\n" not in result
class TestBranchEdges:
"""Tests for branching node edge labels."""
def test_question_classifier_source_handles(self):
"""Test question-classifier edges with sourceHandle labels."""
workflow_data = {
"nodes": [
{"id": "classifier", "type": "question-classifier", "title": "Classify"},
{"id": "refund", "type": "llm", "title": "Handle Refund"},
{"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
],
"edges": [
{"source": "classifier", "target": "refund", "sourceHandle": "refund"},
{"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
],
}
result = generate_mermaid(workflow_data)
assert "classifier -->|refund| refund" in result
assert "classifier -->|inquiry| inquiry" in result
def test_if_else_true_false_handles(self):
"""Test if-else edges with true/false labels."""
workflow_data = {
"nodes": [
{"id": "ifelse", "type": "if-else", "title": "Check"},
{"id": "yes_branch", "type": "llm", "title": "Yes"},
{"id": "no_branch", "type": "llm", "title": "No"},
],
"edges": [
{"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
{"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
],
}
result = generate_mermaid(workflow_data)
assert "ifelse -->|true| yes_branch" in result
assert "ifelse -->|false| no_branch" in result
def test_source_handle_source_is_ignored(self):
"""Test sourceHandle='source' doesn't add label."""
workflow_data = {
"nodes": [
{"id": "llm1", "type": "llm", "title": "LLM 1"},
{"id": "llm2", "type": "llm", "title": "LLM 2"},
],
"edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
}
result = generate_mermaid(workflow_data)
# Should be plain arrow without label
assert "llm1 --> llm2" in result
assert "llm1 -->|source|" not in result
class TestEdgeValidation:
"""Tests for edge validation and error handling."""
def test_edge_with_missing_source_is_skipped(self):
"""Test edge with non-existent source node is skipped."""
workflow_data = {
"nodes": [{"id": "end", "type": "end", "title": "End"}],
"edges": [{"source": "nonexistent", "target": "end"}],
}
result = generate_mermaid(workflow_data)
# Should not contain the invalid edge
assert "nonexistent" not in result
assert "-->" not in result or "nonexistent" not in result
def test_edge_with_missing_target_is_skipped(self):
"""Test edge with non-existent target node is skipped."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
"edges": [{"source": "start", "target": "nonexistent"}],
}
result = generate_mermaid(workflow_data)
# Edge should be skipped
assert "start --> nonexistent" not in result
def test_edge_without_source_or_target_is_skipped(self):
"""Test edge missing source or target is skipped."""
workflow_data = {
"nodes": [{"id": "start", "type": "start", "title": "Start"}],
"edges": [{"source": "start"}, {"target": "start"}, {}],
}
result = generate_mermaid(workflow_data)
# No edges should be rendered
assert result.count("-->") == 0
class TestToolNodes:
"""Tests for tool node formatting."""
def test_tool_node_includes_tool_key(self):
"""Test tool node includes tool_key in label."""
workflow_data = {
"nodes": [
{
"id": "search",
"type": "tool",
"title": "Search",
"config": {"tool_key": "google/search"},
}
],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert 'search["type=tool|title=Search|tool=google/search"]' in result
def test_tool_node_with_tool_name_fallback(self):
"""Test tool node uses tool_name as fallback."""
workflow_data = {
"nodes": [
{
"id": "tool1",
"type": "tool",
"title": "My Tool",
"config": {"tool_name": "my_tool"},
}
],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "tool=my_tool" in result
def test_tool_node_missing_tool_key_shows_unknown(self):
"""Test tool node without tool_key shows 'unknown'."""
workflow_data = {
"nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "tool=unknown" in result
class TestNodeFormatting:
"""Tests for node label formatting."""
def test_quotes_in_title_are_escaped(self):
"""Test double quotes in title are replaced with single quotes."""
workflow_data = {
"nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Double quotes should be replaced
assert "Say 'Hello'" in result
assert 'Say "Hello"' not in result
def test_node_without_id_is_skipped(self):
"""Test node without id is skipped."""
workflow_data = {
"nodes": [{"type": "llm", "title": "No ID"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
# Should only have flowchart header
lines = [line for line in result.split("\n") if line.strip()]
assert len(lines) == 1
def test_node_default_values(self):
"""Test node with missing type/title uses defaults."""
workflow_data = {
"nodes": [{"id": "node1"}],
"edges": [],
}
result = generate_mermaid(workflow_data)
assert "type=unknown" in result
assert "title=Untitled" in result

View File

@@ -0,0 +1,81 @@
from core.workflow.generator.utils.node_repair import NodeRepair
class TestNodeRepair:
"""Tests for NodeRepair utility."""
def test_repair_if_else_valid_operators(self):
"""Test that valid operators remain unchanged."""
nodes = [
{
"id": "node1",
"type": "if-else",
"config": {
"cases": [
{
"conditions": [
{"comparison_operator": "", "value": "1"},
{"comparison_operator": "=", "value": "2"},
]
}
]
},
}
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False
assert result.nodes == nodes
def test_repair_if_else_invalid_operators(self):
"""Test that invalid operators are normalized."""
nodes = [
{
"id": "node1",
"type": "if-else",
"config": {
"cases": [
{
"conditions": [
{"comparison_operator": ">=", "value": "1"},
{"comparison_operator": "<=", "value": "2"},
{"comparison_operator": "!=", "value": "3"},
{"comparison_operator": "==", "value": "4"},
]
}
]
},
}
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is True
assert len(result.repairs_made) == 4
conditions = result.nodes[0]["config"]["cases"][0]["conditions"]
assert conditions[0]["comparison_operator"] == ""
assert conditions[1]["comparison_operator"] == ""
assert conditions[2]["comparison_operator"] == ""
assert conditions[3]["comparison_operator"] == "="
def test_repair_ignores_other_nodes(self):
"""Test that other node types are ignored."""
nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False
assert result.nodes[0]["config"]["some_field"] == ">="
def test_repair_handles_missing_config(self):
"""Test robustness against missing fields."""
nodes = [
{
"id": "node1",
"type": "if-else",
# Missing config
},
{
"id": "node2",
"type": "if-else",
"config": {}, # Missing cases
},
]
result = NodeRepair.repair(nodes)
assert result.was_repaired is False

View File

@@ -0,0 +1,99 @@
"""
Tests for node schemas validation.
Ensures that the node configuration stays in sync with registered node types.
"""
from core.workflow.generator.config.node_schemas import (
get_builtin_node_schemas,
validate_node_schemas,
)
class TestNodeSchemasValidation:
"""Tests for node schema validation utilities."""
def test_validate_node_schemas_returns_no_warnings(self):
"""Ensure all registered node types have corresponding schemas."""
warnings = validate_node_schemas()
# If this test fails, it means a new node type was added but
# no schema was defined for it in node_schemas.py
assert len(warnings) == 0, (
f"Missing schemas for node types: {warnings}. "
"Please add schemas for these node types in node_schemas.py "
"or add them to _INTERNAL_NODE_TYPES if they don't need schemas."
)
def test_builtin_node_schemas_not_empty(self):
"""Ensure BUILTIN_NODE_SCHEMAS contains expected node types."""
# get_builtin_node_schemas() includes dynamic schemas
all_schemas = get_builtin_node_schemas()
assert len(all_schemas) > 0
# Core node types should always be present
expected_types = ["llm", "code", "http-request", "if-else"]
for node_type in expected_types:
assert node_type in all_schemas, f"Missing schema for core node type: {node_type}"
def test_schema_structure(self):
"""Ensure each schema has required fields."""
all_schemas = get_builtin_node_schemas()
for node_type, schema in all_schemas.items():
assert "description" in schema, f"Missing 'description' in schema for {node_type}"
# 'parameters' is optional but if present should be a dict
if "parameters" in schema:
assert isinstance(schema["parameters"], dict), (
f"'parameters' in schema for {node_type} should be a dict"
)
class TestNodeSchemasMerged:
"""Tests to verify the merged configuration works correctly."""
def test_fallback_rules_available(self):
"""Ensure FALLBACK_RULES is available from node_schemas."""
from core.workflow.generator.config.node_schemas import FALLBACK_RULES
assert len(FALLBACK_RULES) > 0
assert "http-request" in FALLBACK_RULES
assert "code" in FALLBACK_RULES
assert "llm" in FALLBACK_RULES
def test_node_type_aliases_available(self):
"""Ensure NODE_TYPE_ALIASES is available from node_schemas."""
from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES
assert len(NODE_TYPE_ALIASES) > 0
assert NODE_TYPE_ALIASES.get("gpt") == "llm"
assert NODE_TYPE_ALIASES.get("api") == "http-request"
def test_field_name_corrections_available(self):
"""Ensure FIELD_NAME_CORRECTIONS is available from node_schemas."""
from core.workflow.generator.config.node_schemas import (
FIELD_NAME_CORRECTIONS,
get_corrected_field_name,
)
assert len(FIELD_NAME_CORRECTIONS) > 0
# Test the helper function
assert get_corrected_field_name("http-request", "text") == "body"
assert get_corrected_field_name("llm", "response") == "text"
assert get_corrected_field_name("code", "unknown") == "unknown"
def test_config_init_exports(self):
"""Ensure config __init__.py exports all needed symbols."""
from core.workflow.generator.config import (
BUILTIN_NODE_SCHEMAS,
FALLBACK_RULES,
FIELD_NAME_CORRECTIONS,
NODE_TYPE_ALIASES,
get_corrected_field_name,
validate_node_schemas,
)
# Just verify imports work
assert BUILTIN_NODE_SCHEMAS is not None
assert FALLBACK_RULES is not None
assert FIELD_NAME_CORRECTIONS is not None
assert NODE_TYPE_ALIASES is not None
assert callable(get_corrected_field_name)
assert callable(validate_node_schemas)

View File

@@ -0,0 +1,172 @@
"""
Unit tests for the Planner Prompts.
Tests cover:
- Tool formatting for planner context
- Edge cases with missing fields
- Empty tool lists
"""
from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner
class TestFormatToolsForPlanner:
"""Tests for format_tools_for_planner function."""
def test_empty_tools_returns_default_message(self):
"""Test empty tools list returns default message."""
result = format_tools_for_planner([])
assert result == "No external tools available."
def test_none_tools_returns_default_message(self):
"""Test None tools list returns default message."""
result = format_tools_for_planner(None)
assert result == "No external tools available."
def test_single_tool_formatting(self):
"""Test single tool is formatted correctly."""
tools = [
{
"provider_id": "google",
"tool_key": "search",
"tool_label": "Google Search",
"tool_description": "Search the web using Google",
}
]
result = format_tools_for_planner(tools)
assert "[google/search]" in result
assert "Google Search" in result
assert "Search the web using Google" in result
def test_multiple_tools_formatting(self):
"""Test multiple tools are formatted correctly."""
tools = [
{
"provider_id": "google",
"tool_key": "search",
"tool_label": "Search",
"tool_description": "Web search",
},
{
"provider_id": "slack",
"tool_key": "send_message",
"tool_label": "Send Message",
"tool_description": "Send a Slack message",
},
]
result = format_tools_for_planner(tools)
lines = result.strip().split("\n")
assert len(lines) == 2
assert "[google/search]" in result
assert "[slack/send_message]" in result
def test_tool_without_provider_uses_key_only(self):
"""Test tool without provider_id uses tool_key only."""
tools = [
{
"tool_key": "my_tool",
"tool_label": "My Tool",
"tool_description": "A custom tool",
}
]
result = format_tools_for_planner(tools)
# Should format as [my_tool] without provider prefix
assert "[my_tool]" in result
assert "My Tool" in result
def test_tool_with_tool_name_fallback(self):
"""Test tool uses tool_name when tool_key is missing."""
tools = [
{
"tool_name": "fallback_tool",
"description": "Fallback description",
}
]
result = format_tools_for_planner(tools)
assert "fallback_tool" in result
assert "Fallback description" in result
def test_tool_with_missing_description(self):
"""Test tool with missing description doesn't crash."""
tools = [
{
"provider_id": "test",
"tool_key": "tool1",
"tool_label": "Tool 1",
}
]
result = format_tools_for_planner(tools)
assert "[test/tool1]" in result
assert "Tool 1" in result
def test_tool_with_all_missing_fields(self):
"""Test tool with all fields missing uses defaults."""
tools = [{}]
result = format_tools_for_planner(tools)
# Should not crash, may produce minimal output
assert isinstance(result, str)
def test_tool_uses_provider_fallback(self):
"""Test tool uses 'provider' when 'provider_id' is missing."""
tools = [
{
"provider": "openai",
"tool_key": "dalle",
"tool_label": "DALL-E",
"tool_description": "Generate images",
}
]
result = format_tools_for_planner(tools)
assert "[openai/dalle]" in result
def test_tool_label_fallback_to_key(self):
"""Test tool_label falls back to tool_key when missing."""
tools = [
{
"provider_id": "test",
"tool_key": "my_key",
"tool_description": "Description here",
}
]
result = format_tools_for_planner(tools)
# Label should fallback to key
assert "my_key" in result
assert "Description here" in result
class TestPlannerPromptConstants:
"""Tests for planner prompt constant availability."""
def test_planner_system_prompt_exists(self):
"""Test PLANNER_SYSTEM_PROMPT is defined."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
assert PLANNER_SYSTEM_PROMPT is not None
assert len(PLANNER_SYSTEM_PROMPT) > 0
assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT
def test_planner_user_prompt_exists(self):
"""Test PLANNER_USER_PROMPT is defined."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT
assert PLANNER_USER_PROMPT is not None
assert "{instruction}" in PLANNER_USER_PROMPT
def test_planner_system_prompt_has_required_sections(self):
"""Test PLANNER_SYSTEM_PROMPT has required XML sections."""
from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT
assert "<role>" in PLANNER_SYSTEM_PROMPT
assert "<task>" in PLANNER_SYSTEM_PROMPT
assert "<available_tools>" in PLANNER_SYSTEM_PROMPT
assert "<response_format>" in PLANNER_SYSTEM_PROMPT

View File

@@ -0,0 +1,510 @@
"""
Unit tests for the Validation Rule Engine.
Tests cover:
- Structure rules (required fields, types, formats)
- Semantic rules (variable references, edge connections)
- Reference rules (model exists, tool configured, dataset valid)
- ValidationEngine integration
"""
from core.workflow.generator.validation import (
ValidationContext,
ValidationEngine,
)
from core.workflow.generator.validation.rules import (
extract_variable_refs,
is_placeholder,
)
class TestPlaceholderDetection:
"""Tests for placeholder detection utility."""
def test_detects_please_select(self):
assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True
def test_detects_your_prefix(self):
assert is_placeholder("YOUR_API_KEY") is True
def test_detects_todo(self):
assert is_placeholder("TODO: fill this in") is True
def test_detects_placeholder(self):
assert is_placeholder("PLACEHOLDER_VALUE") is True
def test_detects_example_prefix(self):
assert is_placeholder("EXAMPLE_URL") is True
def test_detects_replace_prefix(self):
assert is_placeholder("REPLACE_WITH_ACTUAL") is True
def test_case_insensitive(self):
assert is_placeholder("please_select") is True
assert is_placeholder("Please_Select") is True
def test_valid_values_not_detected(self):
assert is_placeholder("https://api.example.com") is False
assert is_placeholder("gpt-4") is False
assert is_placeholder("my_variable") is False
def test_non_string_returns_false(self):
assert is_placeholder(123) is False
assert is_placeholder(None) is False
assert is_placeholder(["list"]) is False
class TestVariableRefExtraction:
"""Tests for variable reference extraction."""
def test_extracts_simple_ref(self):
refs = extract_variable_refs("Hello {{#start.query#}}")
assert refs == [("start", "query")]
def test_extracts_multiple_refs(self):
refs = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
assert refs == [("node1", "output"), ("node2", "text")]
def test_extracts_nested_field(self):
refs = extract_variable_refs("{{#http_request.body#}}")
assert refs == [("http_request", "body")]
def test_no_refs_returns_empty(self):
refs = extract_variable_refs("No references here")
assert refs == []
def test_handles_malformed_refs(self):
refs = extract_variable_refs("{{#invalid}} and {{incomplete#}}")
assert refs == []
class TestValidationContext:
"""Tests for ValidationContext."""
def test_node_map_lookup(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start"},
{"id": "llm_1", "type": "llm"},
]
)
assert ctx.get_node("start") == {"id": "start", "type": "start"}
assert ctx.get_node("nonexistent") is None
def test_model_set(self):
ctx = ValidationContext(
available_models=[
{"provider": "openai", "model": "gpt-4"},
{"provider": "anthropic", "model": "claude-3"},
]
)
assert ctx.has_model("openai", "gpt-4") is True
assert ctx.has_model("anthropic", "claude-3") is True
assert ctx.has_model("openai", "gpt-3.5") is False
def test_tool_set(self):
ctx = ValidationContext(
available_tools=[
{"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
{"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
]
)
assert ctx.has_tool("google/search") is True
assert ctx.has_tool("search") is True
assert ctx.is_tool_configured("google/search") is True
assert ctx.is_tool_configured("slack/send_message") is False
def test_upstream_downstream_nodes(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start"},
{"id": "llm", "type": "llm"},
{"id": "end", "type": "end"},
],
edges=[
{"source": "start", "target": "llm"},
{"source": "llm", "target": "end"},
],
)
assert ctx.get_upstream_nodes("llm") == ["start"]
assert ctx.get_downstream_nodes("llm") == ["end"]
class TestStructureRules:
"""Tests for structure validation rules."""
def test_llm_missing_prompt_template(self):
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_errors
errors = [e for e in result.all_errors if e.rule_id == "llm.prompt_template.required"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_llm_with_prompt_template_passes(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [
{"role": "system", "text": "You are helpful"},
{"role": "user", "text": "Hello"},
]
},
}
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
# No prompt_template errors
errors = [e for e in result.all_errors if "prompt_template" in e.rule_id]
assert len(errors) == 0
def test_http_request_missing_url(self):
ctx = ValidationContext(nodes=[{"id": "http_1", "type": "http-request", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "http.url" in e.rule_id]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_http_request_placeholder_url(self):
ctx = ValidationContext(
nodes=[
{
"id": "http_1",
"type": "http-request",
"config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
}
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "placeholder" in e.rule_id]
assert len(errors) == 1
def test_code_node_missing_fields(self):
ctx = ValidationContext(nodes=[{"id": "code_1", "type": "code", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
error_rules = {e.rule_id for e in result.all_errors}
assert "code.code.required" in error_rules
assert "code.language.required" in error_rules
def test_knowledge_retrieval_missing_dataset(self):
ctx = ValidationContext(nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "knowledge.dataset" in e.rule_id]
assert len(errors) == 1
assert errors[0].is_fixable is False # User must configure
class TestSemanticRules:
"""Tests for semantic validation rules."""
def test_valid_variable_reference(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#start.query#}}"}]},
},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
# No variable reference errors
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
assert len(errors) == 0
def test_invalid_variable_reference(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Process: {{#nonexistent.field#}}"}]},
},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "variable.ref" in e.rule_id]
assert len(errors) == 1
assert "nonexistent" in errors[0].message
def test_edge_validation(self):
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{"id": "end", "type": "end", "config": {}},
],
edges=[
{"source": "start", "target": "end"},
{"source": "nonexistent", "target": "end"},
],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "edge" in e.rule_id]
assert len(errors) == 1
assert "nonexistent" in errors[0].message
class TestReferenceRules:
"""Tests for reference validation rules (models, tools)."""
def test_llm_missing_model_with_available(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.required"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_llm_missing_model_no_available(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[], # No models available
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.no_available"]
assert len(errors) == 1
assert errors[0].is_fixable is False
def test_llm_with_valid_model(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [{"role": "user", "text": "Hi"}],
"model": {"provider": "openai", "name": "gpt-4"},
},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if "model" in e.rule_id]
assert len(errors) == 0
def test_llm_with_invalid_model(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {
"prompt_template": [{"role": "user", "text": "Hi"}],
"model": {"provider": "openai", "name": "gpt-99"},
},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "model.not_found"]
assert len(errors) == 1
assert errors[0].is_fixable is True
def test_tool_node_not_found(self):
ctx = ValidationContext(
nodes=[
{
"id": "tool_1",
"type": "tool",
"config": {"tool_key": "nonexistent/tool"},
}
],
available_tools=[],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "tool.not_found"]
assert len(errors) == 1
def test_tool_node_not_configured(self):
ctx = ValidationContext(
nodes=[
{
"id": "tool_1",
"type": "tool",
"config": {"tool_key": "google/search"},
}
],
available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
errors = [e for e in result.all_errors if e.rule_id == "tool.not_configured"]
assert len(errors) == 1
assert errors[0].is_fixable is False
class TestValidationResult:
"""Tests for ValidationResult classification."""
def test_has_errors(self):
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_errors is True
assert result.is_valid is False
def test_has_fixable_errors(self):
ctx = ValidationContext(
nodes=[
{
"id": "llm_1",
"type": "llm",
"config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
}
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
assert result.has_fixable_errors is True
assert len(result.fixable_errors) > 0
def test_get_fixable_by_node(self):
ctx = ValidationContext(
nodes=[
{"id": "llm_1", "type": "llm", "config": {}},
{"id": "http_1", "type": "http-request", "config": {}},
]
)
engine = ValidationEngine()
result = engine.validate(ctx)
by_node = result.get_fixable_by_node()
assert "llm_1" in by_node
assert "http_1" in by_node
def test_to_dict(self):
ctx = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
engine = ValidationEngine()
result = engine.validate(ctx)
d = result.to_dict()
assert "fixable" in d
assert "user_required" in d
assert "warnings" in d
assert "all_warnings" in d
assert "stats" in d
class TestIntegration:
"""Integration tests for the full validation pipeline."""
def test_complete_workflow_validation(self):
"""Test validation of a complete workflow."""
ctx = ValidationContext(
nodes=[
{
"id": "start",
"type": "start",
"config": {"variables": [{"variable": "query", "type": "text-input"}]},
},
{
"id": "llm_1",
"type": "llm",
"config": {
"model": {"provider": "openai", "name": "gpt-4"},
"prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
},
},
{
"id": "end",
"type": "end",
"config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
},
],
edges=[
{"source": "start", "target": "llm_1"},
{"source": "llm_1", "target": "end"},
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
# Should have no errors
assert result.is_valid is True
assert len(result.fixable_errors) == 0
assert len(result.user_required_errors) == 0
def test_workflow_with_multiple_errors(self):
"""Test workflow with multiple types of errors."""
ctx = ValidationContext(
nodes=[
{"id": "start", "type": "start", "config": {}},
{
"id": "llm_1",
"type": "llm",
"config": {}, # Missing prompt_template and model
},
{
"id": "kb_1",
"type": "knowledge-retrieval",
"config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
},
{"id": "end", "type": "end", "config": {}},
],
available_models=[{"provider": "openai", "model": "gpt-4"}],
)
engine = ValidationEngine()
result = engine.validate(ctx)
# Should have multiple errors
assert result.has_errors is True
assert len(result.fixable_errors) >= 2 # model, prompt_template
assert len(result.user_required_errors) >= 1 # dataset placeholder
# Check stats
assert result.stats["total_nodes"] == 4
assert result.stats["total_errors"] >= 3

View File

@@ -0,0 +1,197 @@
from unittest.mock import MagicMock, patch
import pytest
from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.agent.agent_node import AgentNode
class TestInferToolProviderType:
"""Test cases for AgentNode._infer_tool_provider_type method."""
def test_infer_type_from_config_workflow(self):
"""Test inferring workflow provider type from config."""
tool_config = {
"type": "workflow",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
def test_infer_type_from_config_builtin(self):
"""Test inferring builtin provider type from config."""
tool_config = {
"type": "builtin",
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_from_config_api(self):
"""Test inferring API provider type from config."""
tool_config = {
"type": "api",
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
def test_infer_type_from_config_mcp(self):
"""Test inferring MCP provider type from config."""
tool_config = {
"type": "mcp",
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
def test_infer_type_invalid_config_value_raises_error(self):
"""Test that invalid type value in config raises ValueError."""
tool_config = {
"type": "invalid-type",
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with pytest.raises(ValueError):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_workflow_type_from_database(self):
"""Test inferring workflow provider type from database."""
tool_config = {
"provider_name": "workflow-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns a result
mock_session.scalar.return_value = True
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.WORKFLOW
# Should only query once (after finding WorkflowToolProvider)
assert mock_session.scalar.call_count == 1
def test_infer_mcp_type_from_database(self):
"""Test inferring MCP provider type from database."""
tool_config = {
"provider_name": "mcp-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns a result
mock_session.scalar.side_effect = [None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.MCP
assert mock_session.scalar.call_count == 2
def test_infer_api_type_from_database(self):
"""Test inferring API provider type from database."""
tool_config = {
"provider_name": "api-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First query (WorkflowToolProvider) returns None
# Second query (MCPToolProvider) returns None
# Third query (ApiToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.API
assert mock_session.scalar.call_count == 3
def test_infer_builtin_type_from_database(self):
"""Test inferring builtin provider type from database."""
tool_config = {
"provider_name": "builtin-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# First three queries return None
# Fourth query (BuiltinToolProvider) returns a result
mock_session.scalar.side_effect = [None, None, None, True]
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
assert mock_session.scalar.call_count == 4
def test_infer_type_default_when_not_found(self):
"""Test raising AgentNodeError when provider is not found in database."""
tool_config = {
"provider_name": "unknown-provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# All queries return None
mock_session.scalar.return_value = None
# Current implementation raises AgentNodeError when provider not found
from core.workflow.nodes.agent.exc import AgentNodeError
with pytest.raises(AgentNodeError, match="Tool provider with ID 'unknown-provider-id' not found"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)
def test_infer_type_default_when_no_provider_name(self):
"""Test defaulting to BUILT_IN when provider_name is missing."""
tool_config = {}
tenant_id = "test-tenant"
result = AgentNode._infer_tool_provider_type(tool_config, tenant_id)
assert result == ToolProviderType.BUILT_IN
def test_infer_type_database_exception_propagates(self):
"""Test that database exception propagates (current implementation doesn't catch it)."""
tool_config = {
"provider_name": "provider-id",
}
tenant_id = "test-tenant"
with patch("core.db.session_factory.session_factory.create_session") as mock_create_session:
mock_session = MagicMock()
mock_create_session.return_value.__enter__.return_value = mock_session
# Database query raises exception
mock_session.scalar.side_effect = Exception("Database error")
# Current implementation doesn't catch exceptions, so it propagates
with pytest.raises(Exception, match="Database error"):
AgentNode._infer_tool_provider_type(tool_config, tenant_id)

View File

@@ -10,9 +10,15 @@ type VersionSelectorProps = {
versionLen: number
value: number
onChange: (index: number) => void
contentClassName?: string
}
const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, onChange }) => {
const VersionSelector: React.FC<VersionSelectorProps> = ({
versionLen,
value,
onChange,
contentClassName,
}) => {
const { t } = useTranslation()
const [isOpen, {
setFalse: handleOpenFalse,
@@ -64,6 +70,7 @@ const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, on
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={cn(
'z-[99]',
contentClassName,
)}
>
<div

View File

@@ -1,9 +1,10 @@
import type { ActionItem, AppSearchResult } from './types'
import type { AppSearchResult, ScopeDescriptor } from './types'
import type { App } from '@/types/app'
import { fetchAppList } from '@/service/apps'
import { searchApps } from '@/service/use-goto-anything'
import { getRedirectionPath } from '@/utils/app-redirection'
import { AppTypeIcon } from '../../app/type-selector'
import AppIcon from '../../base/app-icon'
import { ACTION_KEYS } from '../constants'
const parser = (apps: App[]): AppSearchResult[] => {
return apps.map(app => ({
@@ -35,21 +36,14 @@ const parser = (apps: App[]): AppSearchResult[] => {
}))
}
export const appAction: ActionItem = {
key: '@app',
shortcut: '@app',
export const appScope: ScopeDescriptor = {
id: 'app',
shortcut: ACTION_KEYS.APP,
title: 'Search Applications',
description: 'Search and navigate to your applications',
// action,
search: async (_, searchTerm = '', _locale) => {
try {
const response = await fetchAppList({
url: 'apps',
params: {
page: 1,
name: searchTerm,
},
})
const response = await searchApps(searchTerm)
const apps = response?.data || []
return parser(apps)
}

View File

@@ -9,7 +9,7 @@ export {
export { slashCommandRegistry, SlashCommandRegistry } from './registry'
// Command system exports
export { slashAction } from './slash'
export { slashScope } from './slash'
export { registerSlashCommands, SlashCommandProvider, unregisterSlashCommands } from './slash'
export type { SlashCommandHandler } from './types'

View File

@@ -1,12 +1,13 @@
import type { CommandSearchResult } from '../types'
import type { SlashCommandHandler } from './types'
import type { Locale } from '@/i18n-config/language'
import { getI18n } from 'react-i18next'
import { languages } from '@/i18n-config/language'
import { registerCommands, unregisterCommands } from './command-bus'
// Language dependency types
type LanguageDeps = {
setLocale?: (locale: string) => Promise<void>
setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}
const buildLanguageCommands = (query: string): CommandSearchResult[] => {

View File

@@ -6,20 +6,21 @@ import type { SlashCommandHandler } from './types'
* Responsible for managing registration, lookup, and search of all slash commands
*/
export class SlashCommandRegistry {
private commands = new Map<string, SlashCommandHandler>()
private commandDeps = new Map<string, any>()
private commands = new Map<string, SlashCommandHandler<unknown>>()
private commandDeps = new Map<string, unknown>()
/**
* Register command handler
*/
register<TDeps = any>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) {
register<TDeps = unknown>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) {
// Register main command name
this.commands.set(handler.name, handler)
// Cast to unknown first, then to SlashCommandHandler<unknown> to handle generic type variance
this.commands.set(handler.name, handler as SlashCommandHandler<unknown>)
// Register aliases
if (handler.aliases) {
handler.aliases.forEach((alias) => {
this.commands.set(alias, handler)
this.commands.set(alias, handler as SlashCommandHandler<unknown>)
})
}
@@ -57,7 +58,7 @@ export class SlashCommandRegistry {
/**
* Find command handler
*/
findCommand(commandName: string): SlashCommandHandler | undefined {
findCommand(commandName: string): SlashCommandHandler<unknown> | undefined {
return this.commands.get(commandName)
}
@@ -65,7 +66,7 @@ export class SlashCommandRegistry {
* Smart partial command matching
* Prioritize alias matching, then match command name prefix
*/
private findBestPartialMatch(partialName: string): SlashCommandHandler | undefined {
private findBestPartialMatch(partialName: string): SlashCommandHandler<unknown> | undefined {
const lowerPartial = partialName.toLowerCase()
// First check if any alias starts with this
@@ -81,7 +82,7 @@ export class SlashCommandRegistry {
/**
* Find handler by alias prefix
*/
private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler | undefined {
private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler<unknown> | undefined {
for (const handler of this.getAllCommands()) {
if (handler.aliases?.some(alias => alias.toLowerCase().startsWith(prefix)))
return handler
@@ -92,7 +93,7 @@ export class SlashCommandRegistry {
/**
* Find handler by name prefix
*/
private findHandlerByNamePrefix(prefix: string): SlashCommandHandler | undefined {
private findHandlerByNamePrefix(prefix: string): SlashCommandHandler<unknown> | undefined {
return this.getAllCommands().find(handler =>
handler.name.toLowerCase().startsWith(prefix),
)
@@ -101,8 +102,8 @@ export class SlashCommandRegistry {
/**
* Get all registered commands (deduplicated)
*/
getAllCommands(): SlashCommandHandler[] {
const uniqueCommands = new Map<string, SlashCommandHandler>()
getAllCommands(): SlashCommandHandler<unknown>[] {
const uniqueCommands = new Map<string, SlashCommandHandler<unknown>>()
this.commands.forEach((handler) => {
uniqueCommands.set(handler.name, handler)
})
@@ -113,7 +114,7 @@ export class SlashCommandRegistry {
* Get all available commands in current context (deduplicated and filtered)
* Commands without isAvailable method are considered always available
*/
getAvailableCommands(): SlashCommandHandler[] {
getAvailableCommands(): SlashCommandHandler<unknown>[] {
return this.getAllCommands().filter(handler => this.isCommandAvailable(handler))
}
@@ -228,7 +229,7 @@ export class SlashCommandRegistry {
/**
* Get command dependencies
*/
getCommandDependencies(commandName: string): any {
getCommandDependencies(commandName: string): unknown {
return this.commandDeps.get(commandName)
}
@@ -236,7 +237,7 @@ export class SlashCommandRegistry {
* Determine if a command is available in the current context.
* Defaults to true when a handler does not implement the guard.
*/
private isCommandAvailable(handler: SlashCommandHandler) {
private isCommandAvailable(handler: SlashCommandHandler<unknown>) {
return handler.isAvailable?.() ?? true
}
}

View File

@@ -1,12 +1,13 @@
'use client'
import type { ActionItem } from '../types'
import type { ScopeDescriptor } from '../types'
import type { SlashCommandDependencies } from './types'
import { useTheme } from 'next-themes'
import { useEffect } from 'react'
import { getI18n } from 'react-i18next'
import { setLocaleOnClient } from '@/i18n-config'
import { ACTION_KEYS } from '../../constants'
import { accountCommand } from './account'
import { bananaCommand } from './banana'
import { executeCommand } from './command-bus'
import { communityCommand } from './community'
import { docsCommand } from './docs'
import { forumCommand } from './forum'
@@ -17,17 +18,11 @@ import { zenCommand } from './zen'
const i18n = getI18n()
export const slashAction: ActionItem = {
key: '/',
shortcut: '/',
export const slashScope: ScopeDescriptor = {
id: 'slash',
shortcut: ACTION_KEYS.SLASH,
title: i18n.t('gotoAnything.actions.slashTitle', { ns: 'app' }),
description: i18n.t('gotoAnything.actions.slashDesc', { ns: 'app' }),
action: (result) => {
if (result.type !== 'command')
return
const { command, args } = result.data
executeCommand(command, args)
},
search: async (query, _searchTerm = '') => {
// Delegate all search logic to the command registry system
return slashCommandRegistry.search(query, i18n.language)
@@ -35,7 +30,7 @@ export const slashAction: ActionItem = {
}
// Register/unregister default handlers for slash commands with external dependencies.
export const registerSlashCommands = (deps: Record<string, any>) => {
export const registerSlashCommands = (deps: SlashCommandDependencies) => {
// Register command handlers to the registry system with their respective dependencies
slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme })
slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale })

View File

@@ -1,10 +1,11 @@
import type { CommandSearchResult } from '../types'
import type { Locale } from '@/i18n-config/language'
/**
* Slash command handler interface
* Each slash command should implement this interface
*/
export type SlashCommandHandler<TDeps = any> = {
export type SlashCommandHandler<TDeps = unknown> = {
/** Command name (e.g., 'theme', 'language') */
name: string
@@ -51,3 +52,31 @@ export type SlashCommandHandler<TDeps = any> = {
*/
unregister?: () => void
}
/**
* Theme command dependencies
*/
export type ThemeCommandDeps = {
setTheme?: (value: 'light' | 'dark' | 'system') => void
}
/**
* Language command dependencies
*/
export type LanguageCommandDeps = {
setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}
/**
* Commands without external dependencies
*/
export type NoDepsCommandDeps = Record<string, never>
/**
* Union type of all slash command dependencies
* Used for type-safe dependency injection in registerSlashCommands
*/
export type SlashCommandDependencies = {
setTheme?: (value: 'light' | 'dark' | 'system') => void
setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}

View File

@@ -3,228 +3,66 @@
*
* This file defines the action registry for the goto-anything search system.
* Actions handle different types of searches: apps, knowledge bases, plugins, workflow nodes, and commands.
*
* ## How to Add a New Slash Command
*
* 1. **Create Command Handler File** (in `./commands/` directory):
* ```typescript
* // commands/my-command.ts
* import type { SlashCommandHandler } from './types'
* import type { CommandSearchResult } from '../types'
* import { registerCommands, unregisterCommands } from './command-bus'
*
* interface MyCommandDeps {
* myService?: (data: any) => Promise<void>
* }
*
* export const myCommand: SlashCommandHandler<MyCommandDeps> = {
* name: 'mycommand',
* aliases: ['mc'], // Optional aliases
* description: 'My custom command description',
*
* async search(args: string, locale: string = 'en') {
* // Return search results based on args
* return [{
* id: 'my-result',
* title: 'My Command Result',
* description: 'Description of the result',
* type: 'command' as const,
* data: { command: 'my.action', args: { value: args } }
* }]
* },
*
* register(deps: MyCommandDeps) {
* registerCommands({
* 'my.action': async (args) => {
* await deps.myService?.(args?.value)
* }
* })
* },
*
* unregister() {
* unregisterCommands(['my.action'])
* }
* }
* ```
*
* **Example for Self-Contained Command (no external dependencies):**
* ```typescript
* // commands/calculator-command.ts
* export const calculatorCommand: SlashCommandHandler = {
* name: 'calc',
* aliases: ['calculator'],
* description: 'Simple calculator',
*
* async search(args: string) {
* if (!args.trim()) return []
* try {
* // Safe math evaluation (implement proper parser in real use)
* const result = Function('"use strict"; return (' + args + ')')()
* return [{
* id: 'calc-result',
* title: `${args} = ${result}`,
* description: 'Calculator result',
* type: 'command' as const,
* data: { command: 'calc.copy', args: { result: result.toString() } }
* }]
* } catch {
* return [{
* id: 'calc-error',
* title: 'Invalid expression',
* description: 'Please enter a valid math expression',
* type: 'command' as const,
* data: { command: 'calc.noop', args: {} }
* }]
* }
* },
*
* register() {
* registerCommands({
* 'calc.copy': (args) => navigator.clipboard.writeText(args.result),
* 'calc.noop': () => {} // No operation
* })
* },
*
* unregister() {
* unregisterCommands(['calc.copy', 'calc.noop'])
* }
* }
* ```
*
* 2. **Register Command** (in `./commands/slash.tsx`):
* ```typescript
* import { myCommand } from './my-command'
* import { calculatorCommand } from './calculator-command' // For self-contained commands
*
* export const registerSlashCommands = (deps: Record<string, any>) => {
* slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme })
* slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale })
* slashCommandRegistry.register(myCommand, { myService: deps.myService }) // With dependencies
* slashCommandRegistry.register(calculatorCommand) // Self-contained, no dependencies
* }
*
* export const unregisterSlashCommands = () => {
* slashCommandRegistry.unregister('theme')
* slashCommandRegistry.unregister('language')
* slashCommandRegistry.unregister('mycommand')
* slashCommandRegistry.unregister('calc') // Add this line
* }
* ```
*
*
* 3. **Update SlashCommandProvider** (in `./commands/slash.tsx`):
* ```typescript
* export const SlashCommandProvider = () => {
* const theme = useTheme()
* const myService = useMyService() // Add external dependency if needed
*
* useEffect(() => {
* registerSlashCommands({
* setTheme: theme.setTheme, // Required for theme command
* setLocale: setLocaleOnClient, // Required for language command
* myService: myService, // Required for your custom command
* // Note: calculatorCommand doesn't need dependencies, so not listed here
* })
* return () => unregisterSlashCommands()
* }, [theme.setTheme, myService]) // Update dependency array for all dynamic deps
*
* return null
* }
* ```
*
* **Note:** Self-contained commands (like calculator) don't require dependencies but are
* still registered through the same system for consistent lifecycle management.
*
* 4. **Usage**: Users can now type `/mycommand` or `/mc` to use your command
*
* ## Command System Architecture
* - Commands are registered via `SlashCommandRegistry`
* - Each command is self-contained with its own dependencies
* - Commands support aliases for easier access
* - Command execution is handled by the command bus system
* - All commands should be registered through `SlashCommandProvider` for consistent lifecycle management
*
* ## Command Types
* **Commands with External Dependencies:**
* - Require external services, APIs, or React hooks
* - Must provide dependencies in `SlashCommandProvider`
* - Example: theme commands (needs useTheme), API commands (needs service)
*
* **Self-Contained Commands:**
* - Pure logic operations, no external dependencies
* - Still recommended to register through `SlashCommandProvider` for consistency
* - Example: calculator, text manipulation commands
*
* ## Available Actions
* - `@app` - Search applications
* - `@knowledge` / `@kb` - Search knowledge bases
* - `@plugin` - Search plugins
* - `@node` - Search workflow nodes (workflow pages only)
* - `/` - Execute slash commands (theme, language, banana, etc.)
*/
import type { ActionItem, SearchResult } from './types'
import { appAction } from './app'
import { slashAction } from './commands'
import type { ScopeContext, ScopeDescriptor, SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { appScope } from './app'
import { slashScope } from './commands'
import { slashCommandRegistry } from './commands/registry'
import { knowledgeAction } from './knowledge'
import { pluginAction } from './plugin'
import { ragPipelineNodesAction } from './rag-pipeline-nodes'
import { workflowNodesAction } from './workflow-nodes'
import { knowledgeScope } from './knowledge'
import { pluginScope } from './plugin'
import { registerRagPipelineNodeScope } from './rag-pipeline-nodes'
import { scopeRegistry, useScopeRegistry } from './scope-registry'
import { registerWorkflowNodeScope } from './workflow-nodes'
// Create dynamic Actions based on context
export const createActions = (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
const baseActions = {
slash: slashAction,
app: appAction,
knowledge: knowledgeAction,
plugin: pluginAction,
}
let scopesInitialized = false
// Add appropriate node search based on context
if (isRagPipelinePage) {
return {
...baseActions,
node: ragPipelineNodesAction,
}
}
else if (isWorkflowPage) {
return {
...baseActions,
node: workflowNodesAction,
}
}
export const initGotoAnythingScopes = () => {
if (scopesInitialized)
return
// Default actions without node search
return baseActions
scopesInitialized = true
scopeRegistry.register(slashScope)
scopeRegistry.register(appScope)
scopeRegistry.register(knowledgeScope)
scopeRegistry.register(pluginScope)
registerWorkflowNodeScope()
registerRagPipelineNodeScope()
}
// Legacy export for backward compatibility
export const Actions = {
slash: slashAction,
app: appAction,
knowledge: knowledgeAction,
plugin: pluginAction,
node: workflowNodesAction,
export const useGotoAnythingScopes = (context: ScopeContext) => {
initGotoAnythingScopes()
return useScopeRegistry(context)
}
const isSlashScope = (scope: ScopeDescriptor) => {
if (scope.shortcut === ACTION_KEYS.SLASH)
return true
return scope.aliases?.includes(ACTION_KEYS.SLASH) ?? false
}
const getScopeShortcuts = (scope: ScopeDescriptor) => [scope.shortcut, ...(scope.aliases ?? [])]
export const searchAnything = async (
locale: string,
query: string,
actionItem?: ActionItem,
dynamicActions?: Record<string, ActionItem>,
scope: ScopeDescriptor | undefined,
scopes: ScopeDescriptor[],
): Promise<SearchResult[]> => {
const trimmedQuery = query.trim()
if (actionItem) {
if (scope) {
const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
const prefixPattern = new RegExp(`^(${escapeRegExp(actionItem.key)}|${escapeRegExp(actionItem.shortcut)})\\s*`)
const shortcuts = getScopeShortcuts(scope).map(escapeRegExp)
const prefixPattern = new RegExp(`^(${shortcuts.join('|')})\\s*`)
const searchTerm = trimmedQuery.replace(prefixPattern, '').trim()
try {
return await actionItem.search(query, searchTerm, locale)
return await scope.search(query, searchTerm, locale)
}
catch (error) {
console.warn(`Search failed for ${actionItem.key}:`, error)
console.warn(`Search failed for ${scope.id}:`, error)
return []
}
}
@@ -232,19 +70,19 @@ export const searchAnything = async (
if (trimmedQuery.startsWith('@') || trimmedQuery.startsWith('/'))
return []
const globalSearchActions = Object.values(dynamicActions || Actions)
// Exclude slash commands from general search results
.filter(action => action.key !== '/')
// Filter out slash commands from general search
const searchScopes = scopes.filter(scope => !isSlashScope(scope))
// Use Promise.allSettled to handle partial failures gracefully
const searchPromises = globalSearchActions.map(async (action) => {
const searchPromises = searchScopes.map(async (action) => {
const actionId = action.id
try {
const results = await action.search(query, query, locale)
return { success: true, data: results, actionType: action.key }
return { success: true, data: results, actionType: actionId }
}
catch (error) {
console.warn(`Search failed for ${action.key}:`, error)
return { success: false, data: [], actionType: action.key, error }
console.warn(`Search failed for ${actionId}:`, error)
return { success: false, data: [], actionType: actionId, error }
}
})
@@ -258,7 +96,7 @@ export const searchAnything = async (
allResults.push(...result.value.data)
}
else {
const actionKey = globalSearchActions[index]?.key || 'unknown'
const actionKey = searchScopes[index]?.id || 'unknown'
failedActions.push(actionKey)
}
})
@@ -269,31 +107,31 @@ export const searchAnything = async (
return allResults
}
export const matchAction = (query: string, actions: Record<string, ActionItem>) => {
return Object.values(actions).find((action) => {
// Special handling for slash commands
if (action.key === '/') {
// Get all registered commands from the registry
const allCommands = slashCommandRegistry.getAllCommands()
// ...
// Check if query matches any registered command
export const matchAction = (query: string, scopes: ScopeDescriptor[]) => {
const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
return scopes.find((scope) => {
// Special handling for slash commands
if (isSlashScope(scope)) {
const allCommands = slashCommandRegistry.getAllCommands()
return allCommands.some((cmd) => {
const cmdPattern = `/${cmd.name}`
// For direct mode commands, don't match (keep in command selector)
if (cmd.mode === 'direct')
return false
// For submenu mode commands, match when complete command is entered
return query === cmdPattern || query.startsWith(`${cmdPattern} `)
})
}
const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
// Check if query matches shortcut (exact or prefix)
// Only match if it's the full shortcut followed by space
const shortcuts = getScopeShortcuts(scope).map(escapeRegExp)
const reg = new RegExp(`^(${shortcuts.join('|')})(?:\\s|$)`)
return reg.test(query)
})
}
export * from './commands'
export * from './scope-registry'
export * from './types'
export { appAction, knowledgeAction, pluginAction, workflowNodesAction }
export { appScope, knowledgeScope, pluginScope }

View File

@@ -1,8 +1,9 @@
import type { ActionItem, KnowledgeSearchResult } from './types'
import type { KnowledgeSearchResult, ScopeDescriptor } from './types'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import { searchDatasets } from '@/service/use-goto-anything'
import { cn } from '@/utils/classnames'
import { Folder } from '../../base/icons/src/vender/solid/files'
import { ACTION_KEYS } from '../constants'
const EXTERNAL_PROVIDER = 'external' as const
const isExternalProvider = (provider: string): boolean => provider === EXTERNAL_PROVIDER
@@ -30,22 +31,15 @@ const parser = (datasets: DataSet[]): KnowledgeSearchResult[] => {
})
}
export const knowledgeAction: ActionItem = {
key: '@knowledge',
shortcut: '@kb',
export const knowledgeScope: ScopeDescriptor = {
id: 'knowledge',
shortcut: ACTION_KEYS.KNOWLEDGE,
aliases: ['@kb'],
title: 'Search Knowledge Bases',
description: 'Search and navigate to your knowledge bases',
// action,
search: async (_, searchTerm = '', _locale) => {
try {
const response = await fetchDatasets({
url: '/datasets',
params: {
page: 1,
limit: 10,
keyword: searchTerm,
},
})
const response = await searchDatasets(searchTerm)
const datasets = response?.data || []
return parser(datasets)
}

View File

@@ -1,9 +1,10 @@
import type { Plugin, PluginsFromMarketplaceResponse } from '../../plugins/types'
import type { ActionItem, PluginSearchResult } from './types'
import type { Plugin } from '../../plugins/types'
import type { PluginSearchResult, ScopeDescriptor } from './types'
import { renderI18nObject } from '@/i18n-config'
import { postMarketplace } from '@/service/base'
import { searchPlugins } from '@/service/use-goto-anything'
import Icon from '../../plugins/card/base/card-icon'
import { getPluginIconInMarketplace } from '../../plugins/marketplace/utils'
import { ACTION_KEYS } from '../constants'
const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
return plugins.map((plugin) => {
@@ -18,21 +19,14 @@ const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
})
}
export const pluginAction: ActionItem = {
key: '@plugin',
shortcut: '@plugin',
export const pluginScope: ScopeDescriptor = {
id: 'plugin',
shortcut: ACTION_KEYS.PLUGIN,
title: 'Search Plugins',
description: 'Search and navigate to your plugins',
search: async (_, searchTerm = '', locale) => {
try {
const response = await postMarketplace<{ data: PluginsFromMarketplaceResponse }>('/plugins/search/advanced', {
body: {
page: 1,
page_size: 10,
query: searchTerm,
type: 'plugin',
},
})
const response = await searchPlugins(searchTerm)
if (!response?.data?.plugins) {
console.warn('Plugin search: Unexpected response structure', response)

View File

@@ -1,24 +1,41 @@
import type { ActionItem } from './types'
import type { ScopeSearchHandler } from './scope-registry'
import type { SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { scopeRegistry } from './scope-registry'
// Create the RAG pipeline nodes action
export const ragPipelineNodesAction: ActionItem = {
key: '@node',
shortcut: '@node',
title: 'Search RAG Pipeline Nodes',
description: 'Find and jump to nodes in the current RAG pipeline by name or type',
searchFn: undefined, // Will be set by useRagPipelineSearch hook
search: async (_, searchTerm = '', _locale) => {
const scopeId = 'rag-pipeline-node'
let scopeRegistered = false
const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => {
return async (_, searchTerm = '', _locale) => {
try {
// Use the searchFn if available (set by useRagPipelineSearch hook)
if (ragPipelineNodesAction.searchFn)
return ragPipelineNodesAction.searchFn(searchTerm)
// If not in RAG pipeline context, return empty array
if (searchFn)
return searchFn(searchTerm)
return []
}
catch (error) {
console.warn('RAG pipeline nodes search failed:', error)
return []
}
},
}
}
export const registerRagPipelineNodeScope = () => {
if (scopeRegistered)
return
scopeRegistered = true
scopeRegistry.register({
id: scopeId,
shortcut: ACTION_KEYS.NODE,
title: 'Search RAG Pipeline Nodes',
description: 'Find and jump to nodes in the current RAG pipeline by name or type',
isAvailable: context => context.isRagPipelinePage,
search: buildSearchHandler(),
})
}
export const setRagPipelineNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => {
registerRagPipelineNodeScope()
scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn))
}

View File

@@ -0,0 +1,123 @@
import type { SearchResult } from './types'
import { useCallback, useMemo, useSyncExternalStore } from 'react'
export type ScopeContext = {
isWorkflowPage: boolean
isRagPipelinePage: boolean
isAdmin?: boolean
}
export type ScopeSearchHandler = (
query: string,
searchTerm: string,
locale?: string,
) => Promise<SearchResult[]> | SearchResult[]
export type ScopeDescriptor = {
/**
* Unique identifier for the scope (e.g. 'app', 'plugin')
*/
id: string
/**
* Shortcut to trigger this scope (e.g. '@app')
*/
shortcut: string
/**
* Additional shortcuts that map to this scope (e.g. ['@kb'])
*/
aliases?: string[]
/**
* I18n key or string for the scope title
*/
title: string
/**
* Description for help text
*/
description: string
/**
* Search handler function
*/
search: ScopeSearchHandler
/**
* Predicate to check if this scope is available in current context
*/
isAvailable?: (context: ScopeContext) => boolean
}
type Listener = () => void
class ScopeRegistry {
private scopes: Map<string, ScopeDescriptor> = new Map()
private listeners: Set<Listener> = new Set()
private version = 0
register(scope: ScopeDescriptor) {
this.scopes.set(scope.id, scope)
this.notify()
}
unregister(id: string) {
if (this.scopes.delete(id))
this.notify()
}
getScope(id: string) {
return this.scopes.get(id)
}
getScopes(context: ScopeContext): ScopeDescriptor[] {
return Array.from(this.scopes.values())
.filter(scope => !scope.isAvailable || scope.isAvailable(context))
.sort((a, b) => a.shortcut.localeCompare(b.shortcut))
}
updateSearchHandler(id: string, search: ScopeSearchHandler) {
const scope = this.scopes.get(id)
if (!scope)
return
this.scopes.set(id, { ...scope, search })
this.notify()
}
getVersion() {
return this.version
}
subscribe(listener: Listener) {
this.listeners.add(listener)
return () => {
this.listeners.delete(listener)
}
}
private notify() {
this.version += 1
this.listeners.forEach(listener => listener())
}
}
export const scopeRegistry = new ScopeRegistry()
export const useScopeRegistry = (context: ScopeContext) => {
const subscribe = useCallback(
(listener: Listener) => scopeRegistry.subscribe(listener),
[],
)
const getSnapshot = useCallback(
() => scopeRegistry.getVersion(),
[],
)
const version = useSyncExternalStore(
subscribe,
getSnapshot,
getSnapshot,
)
return useMemo(
() => scopeRegistry.getScopes(context),
[version, context.isWorkflowPage, context.isRagPipelinePage, context.isAdmin],
)
}

View File

@@ -1,5 +1,4 @@
import type { ReactNode } from 'react'
import type { TypeWithI18N } from '../../base/form/types'
import type { Plugin } from '../../plugins/types'
import type { CommonNodeType } from '../../workflow/types'
import type { DataSet } from '@/models/datasets'
@@ -7,7 +6,7 @@ import type { App } from '@/types/app'
export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command'
export type BaseSearchResult<T = any> = {
export type BaseSearchResult<T = unknown> = {
id: string
title: string
description?: string
@@ -39,20 +38,8 @@ export type WorkflowNodeSearchResult = {
export type CommandSearchResult = {
type: 'command'
} & BaseSearchResult<{ command: string, args?: Record<string, any> }>
} & BaseSearchResult<{ command: string, args?: Record<string, unknown> }>
export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult
export type ActionItem = {
key: '@app' | '@knowledge' | '@plugin' | '@node' | '/'
shortcut: string
title: string | TypeWithI18N
description: string
action?: (data: SearchResult) => void
searchFn?: (searchTerm: string) => SearchResult[]
search: (
query: string,
searchTerm: string,
locale?: string,
) => (Promise<SearchResult[]> | SearchResult[])
}
export type { ScopeContext, ScopeDescriptor } from './scope-registry'

View File

@@ -1,24 +1,41 @@
import type { ActionItem } from './types'
import type { ScopeSearchHandler } from './scope-registry'
import type { SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { scopeRegistry } from './scope-registry'
// Create the workflow nodes action
export const workflowNodesAction: ActionItem = {
key: '@node',
shortcut: '@node',
title: 'Search Workflow Nodes',
description: 'Find and jump to nodes in the current workflow by name or type',
searchFn: undefined, // Will be set by useWorkflowSearch hook
search: async (_, searchTerm = '', _locale) => {
const scopeId = 'workflow-node'
let scopeRegistered = false
const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => {
return async (_, searchTerm = '', _locale) => {
try {
// Use the searchFn if available (set by useWorkflowSearch hook)
if (workflowNodesAction.searchFn)
return workflowNodesAction.searchFn(searchTerm)
// If not in workflow context, return empty array
if (searchFn)
return searchFn(searchTerm)
return []
}
catch (error) {
console.warn('Workflow nodes search failed:', error)
return []
}
},
}
}
export const registerWorkflowNodeScope = () => {
if (scopeRegistered)
return
scopeRegistered = true
scopeRegistry.register({
id: scopeId,
shortcut: ACTION_KEYS.NODE,
title: 'Search Workflow Nodes',
description: 'Find and jump to nodes in the current workflow by name or type',
isAvailable: context => context.isWorkflowPage,
search: buildSearchHandler(),
})
}
export const setWorkflowNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => {
registerWorkflowNodeScope()
scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn))
}

View File

@@ -1,5 +1,5 @@
import type { ActionItem } from './actions/types'
import { render, screen } from '@testing-library/react'
import type { ScopeDescriptor } from './actions/scope-registry'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { Command } from 'cmdk'
import * as React from 'react'
@@ -22,263 +22,315 @@ vi.mock('./actions/commands/registry', () => ({
},
}))
const createActions = (): Record<string, ActionItem> => ({
app: {
key: '@app',
type CommandSelectorProps = React.ComponentProps<typeof CommandSelector>
const mockScopes: ScopeDescriptor[] = [
{
id: 'app',
shortcut: '@app',
title: 'Apps',
title: 'Search Applications',
description: 'Search apps',
search: vi.fn(),
description: '',
} as ActionItem,
plugin: {
key: '@plugin',
},
{
id: 'knowledge',
shortcut: '@knowledge',
title: 'Search Knowledge Bases',
description: 'Search knowledge bases',
search: vi.fn(),
},
{
id: 'plugin',
shortcut: '@plugin',
title: 'Plugins',
title: 'Search Plugins',
description: 'Search plugins',
search: vi.fn(),
description: '',
} as ActionItem,
})
},
{
id: 'workflow-node',
shortcut: '@node',
title: 'Search Nodes',
description: 'Search workflow nodes',
search: vi.fn(),
},
]
const mockOnCommandSelect = vi.fn()
const mockOnCommandValueChange = vi.fn()
const buildCommandSelector = (props: Partial<CommandSelectorProps> = {}) => (
<Command>
<Command.List>
<CommandSelector
scopes={mockScopes}
onCommandSelect={mockOnCommandSelect}
{...props}
/>
</Command.List>
</Command>
)
const renderCommandSelector = (props: Partial<CommandSelectorProps> = {}) => {
return render(buildCommandSelector(props))
}
describe('CommandSelector', () => {
it('should list contextual search actions and notify selection', async () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="app"
originalQuery="@app"
/>
</Command>,
)
const actionButton = screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')
await userEvent.click(actionButton)
expect(onSelect).toHaveBeenCalledWith('@app')
beforeEach(() => {
vi.clearAllMocks()
})
it('should render slash commands when query starts with slash', async () => {
const actions = createActions()
const onSelect = vi.fn()
describe('Basic Rendering', () => {
it('should render all scopes when no filter is provided', () => {
renderCommandSelector()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="zen"
originalQuery="/zen"
/>
</Command>,
)
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc')
await userEvent.click(slashItem)
it('should render empty filter as showing all scopes', () => {
renderCommandSelector({ searchFilter: '' })
expect(onSelect).toHaveBeenCalledWith('/zen')
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
})
describe('Filtering Functionality', () => {
it('should filter scopes based on searchFilter - single match', () => {
renderCommandSelector({ searchFilter: 'k' })
expect(screen.queryByText('@app')).not.toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
expect(screen.queryByText('@node')).not.toBeInTheDocument()
})
it('should filter scopes with multiple matches', () => {
renderCommandSelector({ searchFilter: 'p' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.queryByText('@node')).not.toBeInTheDocument()
})
it('should be case-insensitive when filtering', () => {
renderCommandSelector({ searchFilter: 'APP' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
})
it('should match partial strings', () => {
renderCommandSelector({ searchFilter: 'od' })
expect(screen.queryByText('@app')).not.toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
})
describe('Empty State', () => {
it('should show empty state when no matches found', () => {
renderCommandSelector({ searchFilter: 'xyz' })
expect(screen.queryByText('@app')).not.toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
expect(screen.queryByText('@node')).not.toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
})
it('should not show empty state when filter is empty', () => {
renderCommandSelector({ searchFilter: '' })
expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
})
})
describe('Selection and Highlight Management', () => {
it('should call onCommandValueChange when filter changes and first item differs', async () => {
const { rerender } = renderCommandSelector({
searchFilter: '',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
})
rerender(buildCommandSelector({
searchFilter: 'k',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
}))
await waitFor(() => {
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge')
})
})
it('should not call onCommandValueChange if current value still exists', async () => {
const { rerender } = renderCommandSelector({
searchFilter: '',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
})
rerender(buildCommandSelector({
searchFilter: 'a',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
}))
await waitFor(() => {
expect(mockOnCommandValueChange).not.toHaveBeenCalled()
})
})
it('should handle onCommandSelect callback correctly', async () => {
const user = userEvent.setup()
renderCommandSelector({ searchFilter: 'k' })
await user.click(screen.getByText('@knowledge'))
expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge')
})
})
describe('Edge Cases', () => {
it('should handle empty scopes array', () => {
renderCommandSelector({ scopes: [] })
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
})
it('should handle special characters in filter', () => {
renderCommandSelector({ searchFilter: '@' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
it('should handle undefined onCommandValueChange gracefully', () => {
const { rerender } = renderCommandSelector({ searchFilter: '' })
expect(() => {
rerender(buildCommandSelector({ searchFilter: 'k' }))
}).not.toThrow()
})
})
describe('User Interactions', () => {
it('should list contextual scopes and notify selection', async () => {
const user = userEvent.setup()
renderCommandSelector({ searchFilter: 'app', originalQuery: '@app' })
await user.click(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc'))
expect(mockOnCommandSelect).toHaveBeenCalledWith('@app')
})
it('should render slash commands when query starts with slash', async () => {
const user = userEvent.setup()
renderCommandSelector({ searchFilter: 'zen', originalQuery: '/zen' })
const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc')
await user.click(slashItem)
expect(mockOnCommandSelect).toHaveBeenCalledWith('/zen')
})
})
it('should show all slash commands when no filter provided', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="/"
/>
</Command>,
)
renderCommandSelector({ searchFilter: '', originalQuery: '/' })
// Should show the zen command from mock
expect(screen.getByText('/zen')).toBeInTheDocument()
})
it('should exclude slash action when in @ mode', () => {
const actions = {
...createActions(),
slash: {
key: '/',
it('should exclude slash scope when in @ mode', () => {
const scopesWithSlash: ScopeDescriptor[] = [
...mockScopes,
{
id: 'slash',
shortcut: '/',
title: 'Slash',
search: vi.fn(),
description: '',
} as ActionItem,
}
const onSelect = vi.fn()
search: vi.fn(),
},
]
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
renderCommandSelector({ scopes: scopesWithSlash, searchFilter: '', originalQuery: '@' })
// Should show @ commands but not /
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.queryByText('/')).not.toBeInTheDocument()
})
it('should show all actions when no filter in @ mode', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
it('should show all scopes when no filter in @ mode', () => {
renderCommandSelector({ searchFilter: '', originalQuery: '@' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
})
it('should set default command value when items exist but value does not', () => {
const actions = createActions()
const onSelect = vi.fn()
const onCommandValueChange = vi.fn()
renderCommandSelector({
searchFilter: '',
originalQuery: '@',
commandValue: 'non-existent',
onCommandValueChange: mockOnCommandValueChange,
})
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
commandValue="non-existent"
onCommandValueChange={onCommandValueChange}
/>
</Command>,
)
expect(onCommandValueChange).toHaveBeenCalledWith('@app')
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@app')
})
it('should NOT set command value when value already exists in items', () => {
const actions = createActions()
const onSelect = vi.fn()
const onCommandValueChange = vi.fn()
renderCommandSelector({
searchFilter: '',
originalQuery: '@',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
})
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
commandValue="@app"
onCommandValueChange={onCommandValueChange}
/>
</Command>,
)
expect(onCommandValueChange).not.toHaveBeenCalled()
expect(mockOnCommandValueChange).not.toHaveBeenCalled()
})
it('should show no matching commands message when filter has no results', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="nonexistent"
originalQuery="@nonexistent"
/>
</Command>,
)
renderCommandSelector({ searchFilter: 'nonexistent', originalQuery: '@nonexistent' })
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
})
it('should show no matching commands for slash mode with no results', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="nonexistentcommand"
originalQuery="/nonexistentcommand"
/>
</Command>,
)
renderCommandSelector({ searchFilter: 'nonexistentcommand', originalQuery: '/nonexistentcommand' })
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
})
it('should render description for @ commands', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
renderCommandSelector({ searchFilter: '', originalQuery: '@' })
expect(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.actions.searchPluginsDesc')).toBeInTheDocument()
})
it('should render group header for @ mode', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="@"
/>
</Command>,
)
renderCommandSelector({ searchFilter: '', originalQuery: '@' })
expect(screen.getByText('app.gotoAnything.selectSearchType')).toBeInTheDocument()
})
it('should render group header for slash mode', () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter=""
originalQuery="/"
/>
</Command>,
)
renderCommandSelector({ searchFilter: '', originalQuery: '/' })
expect(screen.getByText('app.gotoAnything.groups.commands')).toBeInTheDocument()
})

View File

@@ -1,13 +1,14 @@
import type { FC } from 'react'
import type { ActionItem } from './actions/types'
import type { ScopeDescriptor } from './actions/scope-registry'
import { Command } from 'cmdk'
import { usePathname } from 'next/navigation'
import { useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { slashCommandRegistry } from './actions/commands/registry'
import { ACTION_KEYS } from './constants'
type Props = {
actions: Record<string, ActionItem>
scopes: ScopeDescriptor[]
onCommandSelect: (commandKey: string) => void
searchFilter?: string
commandValue?: string
@@ -15,7 +16,7 @@ type Props = {
originalQuery?: string
}
const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => {
const CommandSelector: FC<Props> = ({ scopes, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => {
const { t } = useTranslation()
const pathname = usePathname()
@@ -43,22 +44,31 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co
}))
}, [isSlashMode, searchFilter, pathname])
const filteredActions = useMemo(() => {
const filteredScopes = useMemo(() => {
if (isSlashMode)
return []
return Object.values(actions).filter((action) => {
return scopes.filter((scope) => {
// Exclude slash action when in @ mode
if (action.key === '/')
if (scope.id === 'slash' || scope.shortcut === ACTION_KEYS.SLASH)
return false
if (!searchFilter)
return true
const filterLower = searchFilter.toLowerCase()
return action.shortcut.toLowerCase().includes(filterLower)
})
}, [actions, searchFilter, isSlashMode])
const allItems = isSlashMode ? slashCommands : filteredActions
// Match against shortcut/aliases or title
const filterLower = searchFilter.toLowerCase()
const shortcuts = [scope.shortcut, ...(scope.aliases || [])]
return shortcuts.some(shortcut => shortcut.toLowerCase().includes(filterLower))
|| scope.title.toLowerCase().includes(filterLower)
}).map(scope => ({
key: scope.shortcut, // Map to shortcut for UI display consistency
shortcut: scope.shortcut,
title: scope.title,
description: scope.description,
}))
}, [scopes, searchFilter, isSlashMode])
const allItems = isSlashMode ? slashCommands : filteredScopes
useEffect(() => {
if (allItems.length > 0 && onCommandValueChange) {

View File

@@ -83,10 +83,10 @@ describe('EmptyState', () => {
})
it('should show specific search hint with shortcuts', () => {
const Actions = {
app: { key: '@app', shortcut: '@app' },
plugin: { key: '@plugin', shortcut: '@plugin' },
} as unknown as Record<string, import('../actions/types').ActionItem>
const Actions = [
{ id: 'app', shortcut: '@app', title: 'App', description: '', search: vi.fn() },
{ id: 'plugin', shortcut: '@plugin', title: 'Plugin', description: '', search: vi.fn() },
] as import('../actions/types').ScopeDescriptor[]
render(<EmptyState variant="no-results" searchMode="general" Actions={Actions} />)
expect(screen.getByText('gotoAnything.emptyState.trySpecificSearch:@app, @plugin')).toBeInTheDocument()

View File

@@ -1,7 +1,7 @@
'use client'
import type { FC } from 'react'
import type { ActionItem } from '../actions/types'
import type { ScopeDescriptor } from '../actions/types'
import { useTranslation } from 'react-i18next'
export type EmptyStateVariant = 'no-results' | 'error' | 'default' | 'loading'
@@ -10,14 +10,14 @@ export type EmptyStateProps = {
variant: EmptyStateVariant
searchMode?: string
error?: Error | null
Actions?: Record<string, ActionItem>
Actions?: ScopeDescriptor[]
}
const EmptyState: FC<EmptyStateProps> = ({
variant,
searchMode = 'general',
error,
Actions = {},
Actions = [],
}) => {
const { t } = useTranslation()
@@ -88,7 +88,7 @@ const EmptyState: FC<EmptyStateProps> = ({
return t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
}
const shortcuts = Object.values(Actions).map(action => action.shortcut).join(', ')
const shortcuts = Actions.map(scope => scope.shortcut).join(', ')
return t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts })
}

View File

@@ -0,0 +1,20 @@
/**
* Goto Anything Constants
* Centralized constants for action keys
*/
/**
* Action keys for scope-based searches
*/
export const ACTION_KEYS = {
APP: '@app',
KNOWLEDGE: '@knowledge',
PLUGIN: '@plugin',
NODE: '@node',
SLASH: '/',
} as const
/**
* Type-safe action key union type
*/
export type ActionKey = typeof ACTION_KEYS[keyof typeof ACTION_KEYS]

View File

@@ -32,23 +32,17 @@ vi.mock('../actions/commands/registry', () => ({
},
}))
const createMockActionItem = (
key: '@app' | '@knowledge' | '@plugin' | '@node' | '/',
extra: Record<string, unknown> = {},
) => ({
key,
shortcut: key,
title: `${key} title`,
description: `${key} description`,
search: vi.fn().mockResolvedValue([]),
...extra,
})
const mockExecuteCommand = vi.fn()
vi.mock('../actions/commands', () => ({
executeCommand: (...args: unknown[]) => mockExecuteCommand(...args),
}))
vi.mock('@/app/components/workflow/constants', () => ({
VIBE_COMMAND_EVENT: 'vibe-command',
}))
const createMockOptions = (overrides = {}) => ({
Actions: {
slash: createMockActionItem('/', { action: vi.fn() }),
app: createMockActionItem('@app'),
},
setSearchQuery: vi.fn(),
clearSelection: vi.fn(),
inputRef: { current: { focus: vi.fn() } } as unknown as React.RefObject<HTMLInputElement>,
@@ -60,6 +54,7 @@ describe('useGotoAnythingNavigation', () => {
beforeEach(() => {
vi.clearAllMocks()
mockFindCommandResult = null
mockExecuteCommand.mockReset()
vi.useFakeTimers()
})
@@ -221,13 +216,8 @@ describe('useGotoAnythingNavigation', () => {
expect(mockRouterPush).not.toHaveBeenCalled()
})
it('should execute slash command action for command type', () => {
const actionMock = vi.fn()
const options = createMockOptions({
Actions: {
slash: { key: '/', shortcut: '/', action: actionMock },
},
})
it('should execute command via executeCommand for command type', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
@@ -242,7 +232,7 @@ describe('useGotoAnythingNavigation', () => {
result.current.handleNavigate(commandResult)
})
expect(actionMock).toHaveBeenCalledWith(commandResult)
expect(mockExecuteCommand).toHaveBeenCalledWith('theme.set', { theme: 'dark' })
})
it('should set activePlugin for plugin type', () => {
@@ -368,10 +358,8 @@ describe('useGotoAnythingNavigation', () => {
// No error should occur
})
it('should handle missing slash action', () => {
const options = createMockOptions({
Actions: {},
})
it('should handle command execution without error', () => {
const options = createMockOptions()
const { result } = renderHook(() => useGotoAnythingNavigation(options))
@@ -385,7 +373,7 @@ describe('useGotoAnythingNavigation', () => {
})
})
// No error should occur
expect(mockExecuteCommand).toHaveBeenCalledWith('test-command', undefined)
})
})
})

View File

@@ -2,10 +2,12 @@
import type { RefObject } from 'react'
import type { Plugin } from '../../plugins/types'
import type { ActionItem, SearchResult } from '../actions/types'
import type { SearchResult } from '../actions/types'
import { useRouter } from 'next/navigation'
import { useCallback, useState } from 'react'
import { VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { executeCommand } from '../actions/commands'
import { slashCommandRegistry } from '../actions/commands/registry'
export type UseGotoAnythingNavigationReturn = {
@@ -16,7 +18,6 @@ export type UseGotoAnythingNavigationReturn = {
}
export type UseGotoAnythingNavigationOptions = {
Actions: Record<string, ActionItem>
setSearchQuery: (query: string) => void
clearSelection: () => void
inputRef: RefObject<HTMLInputElement | null>
@@ -27,7 +28,6 @@ export const useGotoAnythingNavigation = (
options: UseGotoAnythingNavigationOptions,
): UseGotoAnythingNavigationReturn => {
const {
Actions,
setSearchQuery,
clearSelection,
inputRef,
@@ -67,9 +67,16 @@ export const useGotoAnythingNavigation = (
switch (result.type) {
case 'command': {
// Execute slash commands
const action = Actions.slash
action?.action?.(result)
if (result.data.command === 'workflow.vibe') {
if (typeof document !== 'undefined') {
document.dispatchEvent(new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: result.data.args?.dsl } }))
}
break
}
// Execute slash commands using the command bus
const { command, args } = result.data
executeCommand(command, args)
break
}
case 'plugin':
@@ -79,13 +86,12 @@ export const useGotoAnythingNavigation = (
// Handle workflow node selection and navigation
if (result.metadata?.nodeId)
selectWorkflowNode(result.metadata.nodeId, true)
break
default:
if (result.path)
router.push(result.path)
}
}, [router, Actions, onClose, setSearchQuery])
}, [router, onClose, setSearchQuery])
return {
handleCommandSelect,

View File

@@ -35,11 +35,11 @@ vi.mock('../actions', () => ({
searchAnything: (...args: unknown[]) => mockSearchAnything(...args),
}))
const createMockActionItem = (key: '@app' | '@knowledge' | '@plugin' | '@node' | '/') => ({
key,
shortcut: key,
title: `${key} title`,
description: `${key} description`,
const createMockScopeDescriptor = (id: string, shortcut: string) => ({
id,
shortcut,
title: `${shortcut} title`,
description: `${shortcut} description`,
search: vi.fn().mockResolvedValue([]),
})
@@ -47,7 +47,7 @@ const createMockOptions = (overrides = {}) => ({
searchQueryDebouncedValue: '',
searchMode: 'general',
isCommandsMode: false,
Actions: { app: createMockActionItem('@app') },
scopes: [createMockScopeDescriptor('app', '@app')],
isWorkflowPage: false,
isRagPipelinePage: false,
cmdVal: '_',
@@ -300,36 +300,36 @@ describe('useGotoAnythingResults', () => {
describe('queryFn execution', () => {
it('should call matchAction with lowercased query', async () => {
const mockActions = { app: createMockActionItem('@app') }
mockMatchAction.mockReturnValue({ key: '@app' })
const mockScopes = [createMockScopeDescriptor('app', '@app')]
mockMatchAction.mockReturnValue(mockScopes[0])
mockSearchAnything.mockResolvedValue([])
renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'TEST QUERY',
Actions: mockActions,
scopes: mockScopes,
})))
expect(capturedQueryFn).toBeDefined()
await capturedQueryFn!()
expect(mockMatchAction).toHaveBeenCalledWith('test query', mockActions)
expect(mockMatchAction).toHaveBeenCalledWith('test query', mockScopes)
})
it('should call searchAnything with correct parameters', async () => {
const mockActions = { app: createMockActionItem('@app') }
const mockAction = { key: '@app' }
const mockScopes = [createMockScopeDescriptor('app', '@app')]
const mockAction = mockScopes[0]
mockMatchAction.mockReturnValue(mockAction)
mockSearchAnything.mockResolvedValue([{ id: '1', type: 'app', title: 'Result' }])
renderHook(() => useGotoAnythingResults(createMockOptions({
searchQueryDebouncedValue: 'My Query',
Actions: mockActions,
scopes: mockScopes,
})))
expect(capturedQueryFn).toBeDefined()
const result = await capturedQueryFn!()
expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockActions)
expect(mockSearchAnything).toHaveBeenCalledWith('en_US', 'my query', mockAction, mockScopes)
expect(result).toEqual([{ id: '1', type: 'app', title: 'Result' }])
})

View File

@@ -1,6 +1,6 @@
'use client'
import type { ActionItem, SearchResult } from '../actions/types'
import type { ScopeDescriptor, SearchResult } from '../actions/types'
import { useQuery } from '@tanstack/react-query'
import { useEffect, useMemo } from 'react'
import { useGetLanguage } from '@/context/i18n'
@@ -19,7 +19,7 @@ export type UseGotoAnythingResultsOptions = {
searchQueryDebouncedValue: string
searchMode: string
isCommandsMode: boolean
Actions: Record<string, ActionItem>
scopes: ScopeDescriptor[]
isWorkflowPage: boolean
isRagPipelinePage: boolean
cmdVal: string
@@ -33,7 +33,7 @@ export const useGotoAnythingResults = (
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
Actions,
scopes,
isWorkflowPage,
isRagPipelinePage,
cmdVal,
@@ -42,13 +42,9 @@ export const useGotoAnythingResults = (
const defaultLocale = useGetLanguage()
// Use action keys as stable cache key instead of the full Actions object
// (Actions contains functions which are not serializable)
const actionKeys = useMemo(() => Object.keys(Actions).sort(), [Actions])
const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
// eslint-disable-next-line @tanstack/query/exhaustive-deps -- Actions intentionally excluded: contains non-serializable functions; actionKeys provides stable representation
// eslint-disable-next-line @tanstack/query/exhaustive-deps -- scopes intentionally excluded: contains non-serializable functions; scope IDs provide stable representation
queryKey: [
'goto-anything',
'search-result',
@@ -57,12 +53,12 @@ export const useGotoAnythingResults = (
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
actionKeys,
scopes.map(s => s.id).sort().join(','),
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
return await searchAnything(defaultLocale, query, action, Actions)
const scope = matchAction(query, scopes)
return await searchAnything(defaultLocale, query, scope, scopes)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,

View File

@@ -1,9 +1,25 @@
import type { ActionItem } from '../actions/types'
import type { ScopeDescriptor } from '../actions/types'
import { act, renderHook } from '@testing-library/react'
import { useGotoAnythingSearch } from './use-goto-anything-search'
let mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
let mockMatchActionResult: Partial<ActionItem> | undefined
let mockMatchActionResult: ScopeDescriptor | undefined
const baseScopesMock: ScopeDescriptor[] = [
{ id: 'slash', shortcut: '/', title: 'Slash', description: 'Slash commands', search: vi.fn() },
{ id: 'app', shortcut: '@app', title: 'App', description: 'Search apps', search: vi.fn() },
{ id: 'knowledge', shortcut: '@knowledge', title: 'Knowledge', description: 'Search KB', search: vi.fn() },
]
const workflowScopesMock: ScopeDescriptor[] = [
...baseScopesMock,
{ id: 'node', shortcut: '@node', title: 'Node', description: 'Search nodes', search: vi.fn() },
]
const ragScopesMock: ScopeDescriptor[] = [
...baseScopesMock,
{ id: 'ragNode', shortcut: '@node', title: 'RAG Node', description: 'Search RAG nodes', search: vi.fn() },
]
vi.mock('ahooks', () => ({
useDebounce: <T>(value: T) => value,
@@ -14,19 +30,12 @@ vi.mock('../context', () => ({
}))
vi.mock('../actions', () => ({
createActions: (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
const base = {
slash: { key: '/', shortcut: '/' },
app: { key: '@app', shortcut: '@app' },
knowledge: { key: '@knowledge', shortcut: '@kb' },
}
if (isWorkflowPage) {
return { ...base, node: { key: '@node', shortcut: '@node' } }
}
if (isRagPipelinePage) {
return { ...base, ragNode: { key: '@node', shortcut: '@node' } }
}
return base
useGotoAnythingScopes: (context: { isWorkflowPage: boolean, isRagPipelinePage: boolean }) => {
if (context.isWorkflowPage)
return workflowScopesMock
if (context.isRagPipelinePage)
return ragScopesMock
return baseScopesMock
},
matchAction: () => mockMatchActionResult,
}))
@@ -74,30 +83,30 @@ describe('useGotoAnythingSearch', () => {
})
})
describe('Actions', () => {
it('should provide Actions based on context', () => {
describe('scopes', () => {
it('should provide scopes based on context', () => {
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions).toBeDefined()
expect(typeof result.current.Actions).toBe('object')
expect(result.current.scopes).toBeDefined()
expect(Array.isArray(result.current.scopes)).toBe(true)
})
it('should include node action when on workflow page', () => {
it('should include node scope when on workflow page', () => {
mockContextValue = { isWorkflowPage: true, isRagPipelinePage: false }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions.node).toBeDefined()
expect(result.current.scopes.find(s => s.id === 'node')).toBeDefined()
})
it('should include ragNode action when on RAG pipeline page', () => {
it('should include ragNode scope when on RAG pipeline page', () => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: true }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions.ragNode).toBeDefined()
expect(result.current.scopes.find(s => s.id === 'ragNode')).toBeDefined()
})
it('should not include node actions when on regular page', () => {
it('should not include node scopes when on regular page', () => {
mockContextValue = { isWorkflowPage: false, isRagPipelinePage: false }
const { result } = renderHook(() => useGotoAnythingSearch())
expect(result.current.Actions.node).toBeUndefined()
expect(result.current.Actions.ragNode).toBeUndefined()
expect(result.current.scopes.find(s => s.id === 'node')).toBeUndefined()
expect(result.current.scopes.find(s => s.id === 'ragNode')).toBeUndefined()
})
})
@@ -145,7 +154,7 @@ describe('useGotoAnythingSearch', () => {
})
it('should return false when query starts with "@" and action matches', () => {
mockMatchActionResult = { key: '@app', shortcut: '@app' }
mockMatchActionResult = baseScopesMock.find(s => s.id === 'app')
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
@@ -206,8 +215,8 @@ describe('useGotoAnythingSearch', () => {
expect(result.current.searchMode).toBe('general')
})
it('should return action key when action matches', () => {
mockMatchActionResult = { key: '@app', shortcut: '@app' }
it('should return action shortcut when action matches', () => {
mockMatchActionResult = baseScopesMock.find(s => s.id === 'app')
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {
@@ -217,8 +226,8 @@ describe('useGotoAnythingSearch', () => {
expect(result.current.searchMode).toBe('@app')
})
it('should return "@command" when action key is "/"', () => {
mockMatchActionResult = { key: '/', shortcut: '/' }
it('should return "@command" when action is slash', () => {
mockMatchActionResult = baseScopesMock.find(s => s.id === 'slash')
const { result } = renderHook(() => useGotoAnythingSearch())
act(() => {

View File

@@ -1,9 +1,10 @@
'use client'
import type { ActionItem } from '../actions/types'
import type { ScopeDescriptor } from '../actions/types'
import { useDebounce } from 'ahooks'
import { useCallback, useMemo, useState } from 'react'
import { createActions, matchAction } from '../actions'
import { matchAction, useGotoAnythingScopes } from '../actions'
import { ACTION_KEYS } from '../constants'
import { useGotoAnythingContext } from '../context'
export type UseGotoAnythingSearchReturn = {
@@ -15,7 +16,7 @@ export type UseGotoAnythingSearchReturn = {
cmdVal: string
setCmdVal: (val: string) => void
clearSelection: () => void
Actions: Record<string, ActionItem>
scopes: ScopeDescriptor[]
}
export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
@@ -23,10 +24,8 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('_')
// Filter actions based on context
const Actions = useMemo(() => {
return createActions(isWorkflowPage, isRagPipelinePage)
}, [isWorkflowPage, isRagPipelinePage])
// Fetch scopes from registry based on context
const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage })
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
@@ -35,28 +34,30 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
const isCommandsMode = useMemo(() => {
const trimmed = searchQuery.trim()
return trimmed === '@' || trimmed === '/'
|| (trimmed.startsWith('@') && !matchAction(trimmed, Actions))
|| (trimmed.startsWith('/') && !matchAction(trimmed, Actions))
}, [searchQuery, Actions])
|| (trimmed.startsWith('@') && !matchAction(trimmed, scopes))
|| (trimmed.startsWith('/') && !matchAction(trimmed, scopes))
}, [searchQuery, scopes])
const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands' // default fallback
return 'commands'
}
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
const action = matchAction(query, scopes)
if (!action)
return 'general'
return action.key === '/' ? '@command' : action.key
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])
if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH)
return '@command'
return action.shortcut
}, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery])
// Prevent automatic selection of the first option when cmdVal is not set
const clearSelection = useCallback(() => {
@@ -72,6 +73,6 @@ export const useGotoAnythingSearch = (): UseGotoAnythingSearchReturn => {
cmdVal,
setCmdVal,
clearSelection,
Actions,
scopes,
}
}

View File

@@ -0,0 +1,93 @@
import { keepPreviousData, useQuery } from '@tanstack/react-query'
import { useDebounce } from 'ahooks'
import { useMemo } from 'react'
import { useGetLanguage } from '@/context/i18n'
import { matchAction, searchAnything, useGotoAnythingScopes } from '../actions'
import { ACTION_KEYS } from '../constants'
import { useGotoAnythingContext } from '../context'
export const useSearch = (searchQuery: string) => {
const defaultLocale = useGetLanguage()
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
// Fetch scopes from registry based on context
const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage })
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
})
const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/'
|| (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), scopes))
|| (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), scopes))
const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands' // default fallback
}
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, scopes)
if (!action)
return 'general'
if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH)
return '@command'
return action.shortcut
}, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery])
const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
queryKey: [
'goto-anything',
'search-result',
searchQueryDebouncedValue,
searchMode,
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
scopes.map(s => s.id).sort().join(','),
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const scope = matchAction(query, scopes)
return await searchAnything(defaultLocale, query, scope, scopes)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,
gcTime: 300000,
placeholderData: keepPreviousData,
},
)
const dedupedResults = useMemo(() => {
if (!searchQuery.trim())
return []
const seen = new Set<string>()
return searchResults.filter((result) => {
const key = `${result.type}-${result.id}`
if (seen.has(key))
return false
seen.add(key)
return true
})
}, [searchResults, searchQuery])
return {
scopes,
searchResults: dedupedResults,
isLoading,
isError,
error,
searchMode,
isCommandsMode,
}
}

View File

@@ -1,5 +1,6 @@
import type { ReactNode } from 'react'
import type { ActionItem, SearchResult } from './actions/types'
import type { ScopeDescriptor } from './actions/scope-registry'
import type { SearchResult } from './actions/types'
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
@@ -58,6 +59,7 @@ const triggerKeyPress = (combo: string) => {
let mockQueryResult = { data: [] as TestSearchResult[], isLoading: false, isError: false, error: null as Error | null }
vi.mock('@tanstack/react-query', () => ({
useQuery: () => mockQueryResult,
keepPreviousData: (data: unknown) => data,
}))
vi.mock('@/context/i18n', () => ({
@@ -70,37 +72,30 @@ vi.mock('./context', () => ({
GotoAnythingProvider: ({ children }: { children: React.ReactNode }) => <>{children}</>,
}))
vi.mock('@/app/components/workflow/utils', () => ({
getKeyboardKeyNameBySystem: (key: string) => key,
}))
type MatchAction = typeof import('./actions').matchAction
type SearchAnything = typeof import('./actions').searchAnything
const createActionItem = (key: ActionItem['key'], shortcut: string): ActionItem => ({
key,
shortcut,
title: `${key} title`,
description: `${key} desc`,
action: vi.fn(),
search: vi.fn(),
const mockState = vi.hoisted(() => {
const state = {
scopes: [] as ScopeDescriptor[],
useGotoAnythingScopesMock: vi.fn(() => state.scopes),
matchActionMock: vi.fn<MatchAction>(() => undefined),
searchAnythingMock: vi.fn<SearchAnything>(async () => []),
}
return state
})
const actionsMock = {
slash: createActionItem('/', '/'),
app: createActionItem('@app', '@app'),
plugin: createActionItem('@plugin', '@plugin'),
}
const createActionsMock = vi.fn(() => actionsMock)
const matchActionMock = vi.fn(() => undefined)
const searchAnythingMock = vi.fn(async () => mockQueryResult.data)
vi.mock('./actions', () => ({
createActions: () => createActionsMock(),
matchAction: () => matchActionMock(),
searchAnything: () => searchAnythingMock(),
__esModule: true,
matchAction: (...args: Parameters<MatchAction>) => mockState.matchActionMock(...args),
searchAnything: (...args: Parameters<SearchAnything>) => mockState.searchAnythingMock(...args),
useGotoAnythingScopes: () => mockState.useGotoAnythingScopesMock(),
}))
vi.mock('./actions/commands', () => ({
SlashCommandProvider: () => null,
executeCommand: vi.fn(),
}))
type MockSlashCommand = {
@@ -118,6 +113,20 @@ vi.mock('./actions/commands/registry', () => ({
},
}))
const createScope = (id: ScopeDescriptor['id'], shortcut: string): ScopeDescriptor => ({
id,
shortcut,
title: `${id} title`,
description: `${id} desc`,
search: vi.fn(),
})
const scopesMock = [
createScope('slash', '/'),
createScope('app', '@app'),
createScope('plugin', '@plugin'),
]
vi.mock('@/app/components/workflow/utils/common', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',
getKeyboardKeyNameBySystem: (key: string) => key,
@@ -144,8 +153,10 @@ describe('GotoAnything', () => {
routerPush.mockClear()
Object.keys(keyPressHandlers).forEach(key => delete keyPressHandlers[key])
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
matchActionMock.mockReset()
searchAnythingMock.mockClear()
mockState.scopes = scopesMock
mockState.matchActionMock.mockReset()
mockState.searchAnythingMock.mockClear()
mockState.searchAnythingMock.mockImplementation(async () => mockQueryResult.data as SearchResult[])
mockFindCommand = null
})

View File

@@ -39,7 +39,7 @@ const GotoAnything: FC<Props> = ({
cmdVal,
setCmdVal,
clearSelection,
Actions,
scopes,
} = useGotoAnythingSearch()
// Modal state management
@@ -76,7 +76,7 @@ const GotoAnything: FC<Props> = ({
searchQueryDebouncedValue,
searchMode,
isCommandsMode,
Actions,
scopes,
isWorkflowPage,
isRagPipelinePage,
cmdVal,
@@ -90,7 +90,6 @@ const GotoAnything: FC<Props> = ({
activePlugin,
setActivePlugin,
} = useGotoAnythingNavigation({
Actions,
setSearchQuery,
clearSelection,
inputRef,
@@ -179,7 +178,7 @@ const GotoAnything: FC<Props> = ({
{isCommandsMode
? (
<CommandSelector
actions={Actions}
scopes={scopes}
onCommandSelect={handleCommandSelect}
searchFilter={searchQuery.trim().substring(1)}
commandValue={cmdVal}
@@ -198,7 +197,7 @@ const GotoAnything: FC<Props> = ({
<EmptyState
variant="no-results"
searchMode={searchMode}
Actions={Actions}
Actions={scopes}
/>
)}

View File

@@ -2052,9 +2052,6 @@ describe('CommonCreateModal', () => {
expect(mockCreateBuilder).toHaveBeenCalled()
})
// Flush pending state updates from createBuilder promise resolution
await act(async () => {})
const input = screen.getByTestId('form-field-webhook_url')
fireEvent.change(input, { target: { value: 'test' } })

View File

@@ -145,6 +145,22 @@ vi.mock('@/app/components/workflow/constants', () => ({
WORKFLOW_DATA_UPDATE: 'WORKFLOW_DATA_UPDATE',
}))
// Mock FileReader
class MockFileReader {
result: string | null = null
onload: ((e: { target: { result: string | null } }) => void) | null = null
readAsText(_file: File) {
// Simulate async file reading using queueMicrotask for more reliable async behavior
queueMicrotask(() => {
this.result = 'test file content'
if (this.onload) {
this.onload({ target: { result: this.result } })
}
})
}
}
afterEach(() => {
cleanup()
vi.clearAllMocks()
@@ -154,6 +170,7 @@ describe('UpdateDSLModal', () => {
const mockOnCancel = vi.fn()
const mockOnBackup = vi.fn()
const mockOnImport = vi.fn()
let originalFileReader: typeof FileReader
const defaultProps = {
onCancel: mockOnCancel,
@@ -169,6 +186,14 @@ describe('UpdateDSLModal', () => {
pipeline_id: 'test-pipeline-id',
})
mockHandleCheckPluginDependencies.mockResolvedValue(undefined)
// Mock FileReader
originalFileReader = globalThis.FileReader
globalThis.FileReader = MockFileReader as unknown as typeof FileReader
})
afterEach(() => {
globalThis.FileReader = originalFileReader
})
describe('rendering', () => {
@@ -538,7 +563,6 @@ describe('UpdateDSLModal', () => {
const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' })
fireEvent.change(fileInput, { target: { files: [file] } })
// Wait for FileReader to process and button to be enabled
await waitFor(() => {
const importButton = screen.getByText('common.overwriteAndImport')
expect(importButton).not.toBeDisabled()
@@ -563,12 +587,15 @@ describe('UpdateDSLModal', () => {
const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' })
fireEvent.change(fileInput, { target: { files: [file] } })
// Wait for FileReader to complete and button to be enabled
// Wait for FileReader to complete (setTimeout 0) and button to be enabled
await waitFor(() => {
const importButton = screen.getByText('common.overwriteAndImport')
expect(importButton).not.toBeDisabled()
})
// Give extra time for the FileReader's setTimeout to complete
await new Promise(resolve => setTimeout(resolve, 10))
const importButton = screen.getByText('common.overwriteAndImport')
fireEvent.click(importButton)
@@ -597,11 +624,6 @@ describe('UpdateDSLModal', () => {
expect(importButton).not.toBeDisabled()
})
// Flush the FileReader microtask to ensure fileContent is set
await act(async () => {
await new Promise<void>(resolve => queueMicrotask(resolve))
})
const importButton = screen.getByText('common.overwriteAndImport')
fireEvent.click(importButton)
@@ -703,7 +725,7 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(screen.getByText('1.0.0')).toBeInTheDocument()
expect(screen.getByText('2.0.0')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
})
it('should close error modal when cancel button is clicked', async () => {
@@ -732,7 +754,7 @@ describe('UpdateDSLModal', () => {
// Wait for error modal
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
// Find and click cancel button in error modal - it should be the one with secondary variant
const cancelButtons = screen.getAllByText('newApp.Cancel')
@@ -750,8 +772,6 @@ describe('UpdateDSLModal', () => {
})
it('should call importDSLConfirm when confirm button is clicked in error modal', async () => {
vi.useFakeTimers({ shouldAdvanceTime: true })
mockImportDSL.mockResolvedValue({
id: 'import-id',
status: DSLImportStatus.PENDING,
@@ -769,27 +789,20 @@ describe('UpdateDSLModal', () => {
const fileInput = screen.getByTestId('file-input')
const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' })
fireEvent.change(fileInput, { target: { files: [file] } })
await act(async () => {
fireEvent.change(fileInput, { target: { files: [file] } })
// Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask)
await new Promise<void>(resolve => queueMicrotask(resolve))
await waitFor(() => {
const importButton = screen.getByText('common.overwriteAndImport')
expect(importButton).not.toBeDisabled()
})
const importButton = screen.getByText('common.overwriteAndImport')
expect(importButton).not.toBeDisabled()
await act(async () => {
fireEvent.click(importButton)
// Flush the promise resolution from mockImportDSL
await Promise.resolve()
// Advance past the 300ms setTimeout in the component
await vi.advanceTimersByTimeAsync(350)
})
fireEvent.click(importButton)
// Wait for error modal
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
// Click confirm button
const confirmButton = screen.getByText('newApp.Confirm')
@@ -798,8 +811,6 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(mockImportDSLConfirm).toHaveBeenCalledWith('import-id')
})
vi.useRealTimers()
})
it('should show success notification after confirm completes', async () => {
@@ -832,7 +843,7 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
const confirmButton = screen.getByText('newApp.Confirm')
fireEvent.click(confirmButton)
@@ -874,7 +885,7 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
const confirmButton = screen.getByText('newApp.Confirm')
fireEvent.click(confirmButton)
@@ -913,7 +924,7 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
const confirmButton = screen.getByText('newApp.Confirm')
fireEvent.click(confirmButton)
@@ -955,7 +966,7 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
const confirmButton = screen.getByText('newApp.Confirm')
fireEvent.click(confirmButton)
@@ -997,7 +1008,7 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
const confirmButton = screen.getByText('newApp.Confirm')
fireEvent.click(confirmButton)
@@ -1008,8 +1019,6 @@ describe('UpdateDSLModal', () => {
})
it('should call handleCheckPluginDependencies after confirm', async () => {
vi.useFakeTimers({ shouldAdvanceTime: true })
mockImportDSL.mockResolvedValue({
id: 'import-id',
status: DSLImportStatus.PENDING,
@@ -1027,27 +1036,19 @@ describe('UpdateDSLModal', () => {
const fileInput = screen.getByTestId('file-input')
const file = new File(['test content'], 'test.pipeline', { type: 'text/yaml' })
fireEvent.change(fileInput, { target: { files: [file] } })
await act(async () => {
fireEvent.change(fileInput, { target: { files: [file] } })
// Flush microtasks scheduled by the FileReader mock (which uses queueMicrotask)
await new Promise<void>(resolve => queueMicrotask(resolve))
await waitFor(() => {
const importButton = screen.getByText('common.overwriteAndImport')
expect(importButton).not.toBeDisabled()
})
const importButton = screen.getByText('common.overwriteAndImport')
expect(importButton).not.toBeDisabled()
await act(async () => {
fireEvent.click(importButton)
// Flush the promise resolution from mockImportDSL
await Promise.resolve()
// Advance past the 300ms setTimeout in the component
await vi.advanceTimersByTimeAsync(350)
})
fireEvent.click(importButton)
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
const confirmButton = screen.getByText('newApp.Confirm')
fireEvent.click(confirmButton)
@@ -1055,8 +1056,6 @@ describe('UpdateDSLModal', () => {
await waitFor(() => {
expect(mockHandleCheckPluginDependencies).toHaveBeenCalledWith('test-pipeline-id', true)
})
vi.useRealTimers()
})
it('should handle undefined imported_dsl_version and current_dsl_version', async () => {
@@ -1085,7 +1084,7 @@ describe('UpdateDSLModal', () => {
// Should show error modal even with undefined versions
await waitFor(() => {
expect(screen.getByText('newApp.appCreateDSLErrorTitle')).toBeInTheDocument()
}, { timeout: 1000 })
}, { timeout: 500 })
})
it('should not call importDSLConfirm when importId is not set', async () => {

View File

@@ -1,49 +1,79 @@
import { act, renderHook, waitFor } from '@testing-library/react'
import { renderHook } from '@testing-library/react'
import { act } from 'react'
import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
// ============================================================================
// Import after mocks
// ============================================================================
import { useDSL } from './use-DSL'
// Mock dependencies
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
useToastContext: () => ({ notify: mockNotify }),
}))
const mockEventEmitter = { emit: vi.fn() }
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: () => ({ eventEmitter: mockEventEmitter }),
}))
const mockDoSyncWorkflowDraft = vi.fn()
vi.mock('./use-nodes-sync-draft', () => ({
useNodesSyncDraft: () => ({ doSyncWorkflowDraft: mockDoSyncWorkflowDraft }),
}))
const mockGetState = vi.fn()
vi.mock('@/app/components/workflow/store', () => ({
useWorkflowStore: () => ({ getState: mockGetState }),
}))
const mockExportPipelineConfig = vi.fn()
vi.mock('@/service/use-pipeline', () => ({
useExportPipelineDSL: () => ({ mutateAsync: mockExportPipelineConfig }),
}))
const mockFetchWorkflowDraft = vi.fn()
vi.mock('@/service/workflow', () => ({
fetchWorkflowDraft: (...args: unknown[]) => mockFetchWorkflowDraft(...args),
}))
const mockDownloadBlob = vi.fn()
vi.mock('@/utils/download', () => ({
downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args),
}))
// ============================================================================
// Mocks
// ============================================================================
// Mock react-i18next
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key,
}),
}))
// Mock toast context
const mockNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
useToastContext: () => ({
notify: mockNotify,
}),
}))
// Mock event emitter context
const mockEmit = vi.fn()
vi.mock('@/context/event-emitter', () => ({
useEventEmitterContextContext: () => ({
eventEmitter: {
emit: mockEmit,
},
}),
}))
// Mock workflow store
const mockWorkflowStoreGetState = vi.fn()
vi.mock('@/app/components/workflow/store', () => ({
useWorkflowStore: () => ({
getState: mockWorkflowStoreGetState,
}),
}))
// Mock useNodesSyncDraft
const mockDoSyncWorkflowDraft = vi.fn()
vi.mock('./use-nodes-sync-draft', () => ({
useNodesSyncDraft: () => ({
doSyncWorkflowDraft: mockDoSyncWorkflowDraft,
}),
}))
// Mock pipeline service
const mockExportPipelineConfig = vi.fn()
vi.mock('@/service/use-pipeline', () => ({
useExportPipelineDSL: () => ({
mutateAsync: mockExportPipelineConfig,
}),
}))
// Mock download utility
const mockDownloadBlob = vi.fn()
vi.mock('@/utils/download', () => ({
downloadBlob: (...args: unknown[]) => mockDownloadBlob(...args),
}))
// Mock workflow service
const mockFetchWorkflowDraft = vi.fn()
vi.mock('@/service/workflow', () => ({
fetchWorkflowDraft: (url: string) => mockFetchWorkflowDraft(url),
}))
// Mock workflow constants
vi.mock('@/app/components/workflow/constants', () => ({
DSL_EXPORT_CHECK: 'DSL_EXPORT_CHECK',
}))
@@ -53,63 +83,48 @@ vi.mock('@/app/components/workflow/constants', () => ({
// ============================================================================
describe('useDSL', () => {
let mockLink: { href: string, download: string, click: ReturnType<typeof vi.fn>, style: { display: string }, remove: ReturnType<typeof vi.fn> }
let originalCreateElement: typeof document.createElement
let originalAppendChild: typeof document.body.appendChild
let mockCreateObjectURL: ReturnType<typeof vi.spyOn>
let mockRevokeObjectURL: ReturnType<typeof vi.spyOn>
beforeEach(() => {
vi.clearAllMocks()
// Create a proper mock link element with all required properties for downloadBlob
mockLink = {
href: '',
download: '',
click: vi.fn(),
style: { display: '' },
remove: vi.fn(),
}
// Save original and mock selectively - only intercept 'a' elements
originalCreateElement = document.createElement.bind(document)
document.createElement = vi.fn((tagName: string) => {
if (tagName === 'a') {
return mockLink as unknown as HTMLElement
}
return originalCreateElement(tagName)
}) as typeof document.createElement
// Mock document.body.appendChild for downloadBlob
originalAppendChild = document.body.appendChild.bind(document.body)
document.body.appendChild = vi.fn(<T extends Node>(node: T): T => node) as typeof document.body.appendChild
// downloadBlob uses window.URL, not URL
mockCreateObjectURL = vi.spyOn(window.URL, 'createObjectURL').mockReturnValue('blob:test-url')
mockRevokeObjectURL = vi.spyOn(window.URL, 'revokeObjectURL').mockImplementation(() => {})
// Default store state
mockGetState.mockReturnValue({
mockWorkflowStoreGetState.mockReturnValue({
pipelineId: 'test-pipeline-id',
knowledgeName: 'Test Knowledge Base',
})
mockDoSyncWorkflowDraft.mockResolvedValue(undefined)
mockExportPipelineConfig.mockResolvedValue({ data: 'yaml-content' })
mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: [] })
mockFetchWorkflowDraft.mockResolvedValue({
environment_variables: [],
})
})
afterEach(() => {
document.createElement = originalCreateElement
document.body.appendChild = originalAppendChild
mockCreateObjectURL.mockRestore()
mockRevokeObjectURL.mockRestore()
vi.clearAllMocks()
})
describe('hook initialization', () => {
it('should return exportCheck function', () => {
const { result } = renderHook(() => useDSL())
expect(result.current.exportCheck).toBeDefined()
expect(typeof result.current.exportCheck).toBe('function')
})
it('should return handleExportDSL function', () => {
const { result } = renderHook(() => useDSL())
expect(result.current.handleExportDSL).toBeDefined()
expect(typeof result.current.handleExportDSL).toBe('function')
})
})
describe('handleExportDSL', () => {
it('should return early when pipelineId is not set', async () => {
mockGetState.mockReturnValue({ pipelineId: null, knowledgeName: 'test' })
it('should not export when pipelineId is missing', async () => {
mockWorkflowStoreGetState.mockReturnValue({
pipelineId: undefined,
knowledgeName: 'Test',
})
const { result } = renderHook(() => useDSL())
@@ -118,6 +133,30 @@ describe('useDSL', () => {
})
expect(mockDoSyncWorkflowDraft).not.toHaveBeenCalled()
expect(mockExportPipelineConfig).not.toHaveBeenCalled()
})
it('should sync workflow draft before export', async () => {
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.handleExportDSL()
})
expect(mockDoSyncWorkflowDraft).toHaveBeenCalled()
})
it('should call exportPipelineConfig with correct params', async () => {
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.handleExportDSL(true)
})
expect(mockExportPipelineConfig).toHaveBeenCalledWith({
pipelineId: 'test-pipeline-id',
include: true,
})
})
it('should create and download file', async () => {
@@ -130,7 +169,7 @@ describe('useDSL', () => {
expect(mockDownloadBlob).toHaveBeenCalled()
})
it('should set correct download filename', async () => {
it('should use correct file extension for download', async () => {
const { result } = renderHook(() => useDSL())
await act(async () => {
@@ -158,7 +197,7 @@ describe('useDSL', () => {
)
})
it('should handle export error', async () => {
it('should show error notification on export failure', async () => {
mockExportPipelineConfig.mockRejectedValue(new Error('Export failed'))
const { result } = renderHook(() => useDSL())
@@ -167,33 +206,19 @@ describe('useDSL', () => {
await result.current.handleExportDSL()
})
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: 'exportFailed',
})
})
})
it('should pass include parameter', async () => {
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.handleExportDSL(true)
})
await waitFor(() => {
expect(mockExportPipelineConfig).toHaveBeenCalledWith({
pipelineId: 'test-pipeline-id',
include: true,
})
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: 'exportFailed',
})
})
})
describe('exportCheck', () => {
it('should return early when pipelineId is not set', async () => {
mockGetState.mockReturnValue({ pipelineId: null })
it('should not check when pipelineId is missing', async () => {
mockWorkflowStoreGetState.mockReturnValue({
pipelineId: undefined,
knowledgeName: 'Test',
})
const { result } = renderHook(() => useDSL())
@@ -204,8 +229,22 @@ describe('useDSL', () => {
expect(mockFetchWorkflowDraft).not.toHaveBeenCalled()
})
it('should call handleExportDSL directly when no secret variables', async () => {
mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: [] })
it('should fetch workflow draft', async () => {
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.exportCheck()
})
expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft')
})
it('should directly export when no secret environment variables', async () => {
mockFetchWorkflowDraft.mockResolvedValue({
environment_variables: [
{ id: '1', value_type: 'string', value: 'test' },
],
})
const { result } = renderHook(() => useDSL())
@@ -213,15 +252,16 @@ describe('useDSL', () => {
await result.current.exportCheck()
})
await waitFor(() => {
expect(mockFetchWorkflowDraft).toHaveBeenCalledWith('/rag/pipelines/test-pipeline-id/workflows/draft')
expect(mockDoSyncWorkflowDraft).toHaveBeenCalled()
})
// Should call doSyncWorkflowDraft (which means handleExportDSL was called)
expect(mockDoSyncWorkflowDraft).toHaveBeenCalled()
})
it('should emit event when secret variables exist', async () => {
const secretVars = [{ value_type: 'secret', name: 'API_KEY' }]
mockFetchWorkflowDraft.mockResolvedValue({ environment_variables: secretVars })
it('should emit DSL_EXPORT_CHECK event when secret variables exist', async () => {
mockFetchWorkflowDraft.mockResolvedValue({
environment_variables: [
{ id: '1', value_type: 'secret', value: 'secret-value' },
],
})
const { result } = renderHook(() => useDSL())
@@ -229,17 +269,15 @@ describe('useDSL', () => {
await result.current.exportCheck()
})
await waitFor(() => {
expect(mockEventEmitter.emit).toHaveBeenCalledWith({
type: expect.any(String),
payload: {
data: secretVars,
},
})
expect(mockEmit).toHaveBeenCalledWith({
type: 'DSL_EXPORT_CHECK',
payload: {
data: [{ id: '1', value_type: 'secret', value: 'secret-value' }],
},
})
})
it('should handle export check error', async () => {
it('should show error notification on check failure', async () => {
mockFetchWorkflowDraft.mockRejectedValue(new Error('Fetch failed'))
const { result } = renderHook(() => useDSL())
@@ -248,12 +286,68 @@ describe('useDSL', () => {
await result.current.exportCheck()
})
await waitFor(() => {
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: 'exportFailed',
})
expect(mockNotify).toHaveBeenCalledWith({
type: 'error',
message: 'exportFailed',
})
})
it('should filter only secret environment variables', async () => {
mockFetchWorkflowDraft.mockResolvedValue({
environment_variables: [
{ id: '1', value_type: 'string', value: 'plain' },
{ id: '2', value_type: 'secret', value: 'secret1' },
{ id: '3', value_type: 'number', value: '123' },
{ id: '4', value_type: 'secret', value: 'secret2' },
],
})
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.exportCheck()
})
expect(mockEmit).toHaveBeenCalledWith({
type: 'DSL_EXPORT_CHECK',
payload: {
data: [
{ id: '2', value_type: 'secret', value: 'secret1' },
{ id: '4', value_type: 'secret', value: 'secret2' },
],
},
})
})
it('should handle empty environment variables', async () => {
mockFetchWorkflowDraft.mockResolvedValue({
environment_variables: [],
})
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.exportCheck()
})
// Should directly call handleExportDSL since no secrets
expect(mockEmit).not.toHaveBeenCalled()
expect(mockDoSyncWorkflowDraft).toHaveBeenCalled()
})
it('should handle undefined environment variables', async () => {
mockFetchWorkflowDraft.mockResolvedValue({
environment_variables: undefined,
})
const { result } = renderHook(() => useDSL())
await act(async () => {
await result.current.exportCheck()
})
// Should directly call handleExportDSL since no secrets
expect(mockEmit).not.toHaveBeenCalled()
})
})
})

View File

@@ -5,7 +5,7 @@ import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types'
import type { ToolNodeType } from '@/app/components/workflow/nodes/tool/types'
import type { CommonNodeType } from '@/app/components/workflow/types'
import { useCallback, useEffect, useMemo } from 'react'
import { ragPipelineNodesAction } from '@/app/components/goto-anything/actions/rag-pipeline-nodes'
import { setRagPipelineNodesSearchFn } from '@/app/components/goto-anything/actions/rag-pipeline-nodes'
import BlockIcon from '@/app/components/workflow/block-icon'
import { useNodesInteractions } from '@/app/components/workflow/hooks/use-nodes-interactions'
import { useGetToolIcon } from '@/app/components/workflow/hooks/use-tool-icon'
@@ -153,16 +153,15 @@ export const useRagPipelineSearch = () => {
return results
}, [searchableNodes, calculateScore])
// Directly set the search function on the action object
// Directly set the search function using the setter
useEffect(() => {
if (searchableNodes.length > 0) {
// Set the search function directly on the action
ragPipelineNodesAction.searchFn = searchRagPipelineNodes
setRagPipelineNodesSearchFn(searchRagPipelineNodes)
}
return () => {
// Clean up when component unmounts
ragPipelineNodesAction.searchFn = undefined
setRagPipelineNodesSearchFn(() => [])
}
}, [searchableNodes, searchRagPipelineNodes])

View File

@@ -168,7 +168,6 @@ describe('EditCustomCollectionModal', () => {
const schemaInput = screen.getByPlaceholderText('tools.createTool.schemaPlaceHolder')
fireEvent.change(schemaInput, { target: { value: '{}' } })
// Wait for parseParamsSchema to be called and state to be updated
await waitFor(() => {
expect(parseParamsSchemaMock).toHaveBeenCalledWith('{}')
})
@@ -185,13 +184,13 @@ describe('EditCustomCollectionModal', () => {
provider: 'provider',
schema: '{}',
schema_type: 'openapi',
credentials: {
auth_type: 'none',
},
icon: {
content: '🕵️',
background: '#FEF7C3',
},
credentials: {
auth_type: 'none',
},
labels: [],
}))
expect(toastNotifySpy).not.toHaveBeenCalled()

View File

@@ -11,12 +11,7 @@ vi.mock('@/app/components/base/modal', () => ({
onClose,
children,
closable,
}: {
isShow: boolean
onClose?: () => void
children?: React.ReactNode
closable?: boolean
}) {
}: any) {
if (!isShow)
return null
@@ -44,10 +39,7 @@ vi.mock('./start-node-selection-panel', () => ({
default: function MockStartNodeSelectionPanel({
onSelectUserInput,
onSelectTrigger,
}: {
onSelectUserInput?: () => void
onSelectTrigger?: (type: BlockEnum, config?: Record<string, unknown>) => void
}) {
}: any) {
return (
<div data-testid="start-node-selection-panel">
<button data-testid="select-user-input" onClick={onSelectUserInput}>
@@ -55,13 +47,13 @@ vi.mock('./start-node-selection-panel', () => ({
</button>
<button
data-testid="select-trigger-schedule"
onClick={() => onSelectTrigger?.(BlockEnum.TriggerSchedule)}
onClick={() => onSelectTrigger(BlockEnum.TriggerSchedule)}
>
Select Trigger Schedule
</button>
<button
data-testid="select-trigger-webhook"
onClick={() => onSelectTrigger?.(BlockEnum.TriggerWebhook, { config: 'test' })}
onClick={() => onSelectTrigger(BlockEnum.TriggerWebhook, { config: 'test' })}
>
Select Trigger Webhook
</button>
@@ -557,7 +549,7 @@ describe('WorkflowOnboardingModal', () => {
// Arrange & Act
renderComponent({ isShow: true })
// Assert - ShortcutsName component renders keys in div elements with system-kbd class
// Assert
const escKey = screen.getByText('workflow.onboarding.escTip.key')
// ShortcutsName renders a <div> with class system-kbd, not a <kbd> element
expect(escKey.closest('.system-kbd')).toBeInTheDocument()

View File

@@ -10,8 +10,7 @@ export const X_OFFSET = 60
export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET
export const Y_OFFSET = 39
export const VIBE_COMMAND_EVENT = 'workflow-vibe-command'
export const VIBE_REGENERATE_EVENT = 'workflow-vibe-regenerate'
export const VIBE_ACCEPT_EVENT = 'workflow-vibe-accept'
export const VIBE_APPLY_EVENT = 'workflow-vibe-apply'
export const START_INITIAL_POSITION = { x: 80, y: 282 }
export const AUTO_LAYOUT_OFFSET = {
x: -42,

View File

@@ -160,7 +160,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
}
}
else {
usedVars = getNodeUsedVars(node).filter(v => v.length > 0)
usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0)
}
if (node.type === CUSTOM_NODE) {
@@ -359,7 +359,7 @@ export const useChecklistBeforePublish = () => {
}
}
else {
usedVars = getNodeUsedVars(node).filter(v => v.length > 0)
usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0)
}
const checkData = getCheckData(node.data, datasets)
const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)

View File

@@ -5,7 +5,7 @@ import type { CommonNodeType } from '../types'
import type { Emoji } from '@/app/components/tools/types'
import { useCallback, useEffect, useMemo } from 'react'
import { useNodes } from 'reactflow'
import { workflowNodesAction } from '@/app/components/goto-anything/actions/workflow-nodes'
import { setWorkflowNodesSearchFn } from '@/app/components/goto-anything/actions/workflow-nodes'
import { CollectionType } from '@/app/components/tools/types'
import BlockIcon from '@/app/components/workflow/block-icon'
import {
@@ -183,16 +183,15 @@ export const useWorkflowSearch = () => {
return results
}, [searchableNodes, calculateScore])
// Directly set the search function on the action object
// Directly set the search function using the setter
useEffect(() => {
if (searchableNodes.length > 0) {
// Set the search function directly on the action
workflowNodesAction.searchFn = searchWorkflowNodes
setWorkflowNodesSearchFn(searchWorkflowNodes)
}
return () => {
// Clean up when component unmounts
workflowNodesAction.searchFn = undefined
setWorkflowNodesSearchFn(() => [])
}
}, [searchableNodes, searchWorkflowNodes])

View File

@@ -471,12 +471,14 @@ export const useNodesReadOnly = () => {
const workflowRunningData = useStore(s => s.workflowRunningData)
const historyWorkflowData = useStore(s => s.historyWorkflowData)
const isRestoring = useStore(s => s.isRestoring)
// const showVibePanel = useStore(s => s.showVibePanel)
const getNodesReadOnly = useCallback((): boolean => {
const {
workflowRunningData,
historyWorkflowData,
isRestoring,
// showVibePanel,
} = workflowStore.getState()
return !!(

View File

@@ -68,6 +68,7 @@ import {
useWorkflow,
useWorkflowReadOnly,
useWorkflowRefreshDraft,
useWorkflowVibe,
} from './hooks'
import { HooksStoreContextProvider, useHooksStore } from './hooks-store'
import { useWorkflowSearch } from './hooks/use-workflow-search'
@@ -329,6 +330,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
useShortcuts()
// Initialize workflow node search functionality
useWorkflowSearch()
useWorkflowVibe()
// Set up scroll to node event listener using the utility function
useEffect(() => {

View File

@@ -33,9 +33,9 @@ const FileUploadSetting: FC<Props> = ({
const { t } = useTranslation()
const {
allowed_file_upload_methods,
allowed_file_upload_methods = [],
max_length,
allowed_file_types,
allowed_file_types = [],
allowed_file_extensions,
} = payload
const { data: fileUploadConfigResponse } = useFileUploadConfig()

View File

@@ -1404,9 +1404,9 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => {
payload.url,
payload.headers,
payload.params,
typeof payload.body.data === 'string'
typeof payload.body?.data === 'string'
? payload.body.data
: payload.body.data.map(d => d.value).join(''),
: (payload.body?.data?.map(d => d.value).join('') ?? ''),
])
break
}

View File

@@ -5,6 +5,9 @@ import { useCallback, useEffect, useState } from 'react'
const UNIQUE_ID_PREFIX = 'key-value-'
const strToKeyValueList = (value: string) => {
if (typeof value !== 'string' || !value)
return []
return value.split('\n').map((item) => {
const [key, ...others] = item.split(':')
return {
@@ -16,7 +19,7 @@ const strToKeyValueList = (value: string) => {
}
const useKeyValueList = (value: string, onChange: (value: string) => void, noFilter?: boolean) => {
const [list, doSetList] = useState<KeyValue[]>(() => value ? strToKeyValueList(value) : [])
const [list, doSetList] = useState<KeyValue[]>(() => typeof value === 'string' && value ? strToKeyValueList(value) : [])
const setList = (l: KeyValue[]) => {
doSetList(l.map((item) => {
return {

View File

@@ -49,7 +49,7 @@ const ConditionValue = ({
if (value === true || value === false)
return value ? 'True' : 'False'
return value.replace(/\{\{#([^#]*)#\}\}/g, (a, b) => {
return String(value).replace(/\{\{#([^#]*)#\}\}/g, (a, b) => {
const arr: string[] = b.split('.')
if (isSystemVar(arr))
return `{{${b}}}`

View File

@@ -18,7 +18,6 @@ import {
Group,
} from '@/app/components/workflow/nodes/_base/components/layout'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
import { IS_CE_EDITION } from '@/config'
import Split from '../_base/components/split'
import ChunkStructure from './components/chunk-structure'
import EmbeddingModel from './components/embedding-model'
@@ -173,7 +172,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
{
data.indexing_technique === IndexMethodEnum.QUALIFIED
&& [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
&& IS_CE_EDITION && (
&& (
<>
<SummaryIndexSetting
summaryIndexSetting={data.summary_index_setting}

View File

@@ -1,7 +1,6 @@
import type { ToolNodeType, ToolVarInputs } from './types'
import type { InputVar } from '@/app/components/workflow/types'
import { useBoolean } from 'ahooks'
import { capitalize } from 'es-toolkit/string'
import { produce } from 'immer'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -26,12 +25,6 @@ import {
} from '@/service/use-tools'
import { canFindTool } from '@/utils'
import { useWorkflowStore } from '../../store'
import { normalizeJsonSchemaType } from './output-schema-utils'
const formatDisplayType = (output: Record<string, unknown>): string => {
const normalizedType = normalizeJsonSchemaType(output) || 'Unknown'
return capitalize(normalizedType)
}
const useConfig = (id: string, payload: ToolNodeType) => {
const workflowStore = useWorkflowStore()
@@ -254,13 +247,20 @@ const useConfig = (id: string, payload: ToolNodeType) => {
})
}
else {
const normalizedType = normalizeJsonSchemaType(output)
res.push({
name: outputKey,
type:
normalizedType === 'array'
? `Array[${output.items ? formatDisplayType(output.items) : 'Unknown'}]`
: formatDisplayType(output),
output.type === 'array'
? `Array[${output.items?.type
? output.items.type.slice(0, 1).toLocaleUpperCase()
+ output.items.type.slice(1)
: 'Unknown'
}]`
: `${output.type
? output.type.slice(0, 1).toLocaleUpperCase()
+ output.type.slice(1)
: 'Unknown'
}`,
description: output.description,
})
}

View File

@@ -127,23 +127,30 @@ const NodeGroupItem = ({
!!item.variables.length && (
<div className="space-y-0.5">
{
item.variables.map((variable = [], index) => {
const isSystem = isSystemVar(variable)
item.variables
.map((variable = [], index) => {
// Ensure variable is an array
const safeVariable = Array.isArray(variable) ? variable : []
if (!safeVariable.length)
return null
const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === variable[0])
const varName = isSystem ? `sys.${variable[variable.length - 1]}` : variable.slice(1).join('.')
const isException = isExceptionVariable(varName, node?.data.type)
const isSystem = isSystemVar(safeVariable)
return (
<VariableLabelInNode
key={index}
variables={variable}
nodeType={node?.data.type}
nodeTitle={node?.data.title}
isExceptionVariable={isException}
/>
)
})
const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === safeVariable[0])
const varName = isSystem ? `sys.${safeVariable[safeVariable.length - 1]}` : safeVariable.slice(1).join('.')
const isException = isExceptionVariable(varName, node?.data.type)
return (
<VariableLabelInNode
key={index}
variables={safeVariable}
nodeType={node?.data.type}
nodeTitle={node?.data.title}
isExceptionVariable={isException}
/>
)
})
.filter(Boolean)
}
</div>
)

View File

@@ -8,6 +8,7 @@ import { cn } from '@/utils/classnames'
import { Panel as NodePanel } from '../nodes'
import { useStore } from '../store'
import EnvPanel from './env-panel'
import VibePanel from './vibe-panel'
const VersionHistoryPanel = dynamic(() => import('@/app/components/workflow/panel/version-history-panel'), {
ssr: false,
@@ -85,6 +86,7 @@ const Panel: FC<PanelProps> = ({
const showEnvPanel = useStore(s => s.showEnvPanel)
const isRestoring = useStore(s => s.isRestoring)
const showWorkflowVersionHistoryPanel = useStore(s => s.showWorkflowVersionHistoryPanel)
const showVibePanel = useStore(s => s.showVibePanel)
// widths used for adaptive layout
const workflowCanvasWidth = useStore(s => s.workflowCanvasWidth)
@@ -124,33 +126,36 @@ const Panel: FC<PanelProps> = ({
)
return (
<div
ref={rightPanelRef}
tabIndex={-1}
className={cn('absolute bottom-1 right-0 top-14 z-10 flex outline-none')}
key={`${isRestoring}`}
>
{components?.left}
{!!selectedNode && <NodePanel {...selectedNode} />}
<>
<div
className="relative"
ref={otherPanelRef}
ref={rightPanelRef}
tabIndex={-1}
className={cn('absolute bottom-1 right-0 top-14 z-10 flex outline-none')}
key={`${isRestoring}`}
>
{
components?.right
}
{
showWorkflowVersionHistoryPanel && (
<VersionHistoryPanel {...versionHistoryPanelProps} />
)
}
{
showEnvPanel && (
<EnvPanel />
)
}
{components?.left}
{!!selectedNode && <NodePanel {...selectedNode} />}
<div
className="relative"
ref={otherPanelRef}
>
{
components?.right
}
{
showWorkflowVersionHistoryPanel && (
<VersionHistoryPanel {...versionHistoryPanelProps} />
)
}
{
showEnvPanel && (
<EnvPanel />
)
}
</div>
</div>
</div>
{showVibePanel && <VibePanel />}
</>
)
}

View File

@@ -11,8 +11,8 @@ import type { LayoutSliceShape } from './layout-slice'
import type { NodeSliceShape } from './node-slice'
import type { PanelSliceShape } from './panel-slice'
import type { ToolSliceShape } from './tool-slice'
import type { VibeWorkflowSliceShape } from './vibe-workflow-slice'
import type { VersionSliceShape } from './version-slice'
import type { VibeWorkflowSliceShape } from './vibe-workflow-slice'
import type { WorkflowDraftSliceShape } from './workflow-draft-slice'
import type { WorkflowSliceShape } from './workflow-slice'
import type { RagPipelineSliceShape } from '@/app/components/rag-pipeline/store'
@@ -34,8 +34,8 @@ import { createNodeSlice } from './node-slice'
import { createPanelSlice } from './panel-slice'
import { createToolSlice } from './tool-slice'
import { createVibeWorkflowSlice } from './vibe-workflow-slice'
import { createVersionSlice } from './version-slice'
import { createVibeWorkflowSlice } from './vibe-workflow-slice'
import { createWorkflowDraftSlice } from './workflow-draft-slice'
import { createWorkflowSlice } from './workflow-slice'
@@ -57,8 +57,8 @@ export type Shape
& WorkflowSliceShape
& InspectVarsSliceShape
& LayoutSliceShape
& VibeWorkflowSliceShape
& SliceFromInjection
& VibeWorkflowSliceShape
export type InjectWorkflowStoreSliceFn = StateCreator<SliceFromInjection>

View File

@@ -1,4 +1,7 @@
import type { StateCreator } from 'zustand'
import type { BackendEdgeSpec, BackendNodeSpec } from '@/service/debug'
export type VibeIntent = 'generate' | 'off_topic' | 'error' | ''
export type PanelSliceShape = {
panelWidth: number
@@ -24,6 +27,26 @@ export type PanelSliceShape = {
setShowVariableInspectPanel: (showVariableInspectPanel: boolean) => void
initShowLastRunTab: boolean
setInitShowLastRunTab: (initShowLastRunTab: boolean) => void
showVibePanel: boolean
setShowVibePanel: (showVibePanel: boolean) => void
vibePanelMermaidCode: string
setVibePanelMermaidCode: (vibePanelMermaidCode: string) => void
vibePanelBackendNodes?: BackendNodeSpec[]
setVibePanelBackendNodes: (nodes?: BackendNodeSpec[]) => void
vibePanelBackendEdges?: BackendEdgeSpec[]
setVibePanelBackendEdges: (edges?: BackendEdgeSpec[]) => void
isVibeGenerating: boolean
setIsVibeGenerating: (isVibeGenerating: boolean) => void
vibePanelInstruction: string
setVibePanelInstruction: (vibePanelInstruction: string) => void
vibePanelIntent: VibeIntent
setVibePanelIntent: (vibePanelIntent: VibeIntent) => void
vibePanelMessage: string
setVibePanelMessage: (vibePanelMessage: string) => void
vibePanelSuggestions: string[]
setVibePanelSuggestions: (vibePanelSuggestions: string[]) => void
vibePanelLastWarnings: string[]
setVibePanelLastWarnings: (vibePanelLastWarnings: string[]) => void
}
export const createPanelSlice: StateCreator<PanelSliceShape> = set => ({
@@ -44,4 +67,24 @@ export const createPanelSlice: StateCreator<PanelSliceShape> = set => ({
setShowVariableInspectPanel: showVariableInspectPanel => set(() => ({ showVariableInspectPanel })),
initShowLastRunTab: false,
setInitShowLastRunTab: initShowLastRunTab => set(() => ({ initShowLastRunTab })),
showVibePanel: false,
setShowVibePanel: showVibePanel => set(() => ({ showVibePanel })),
vibePanelMermaidCode: '',
setVibePanelMermaidCode: vibePanelMermaidCode => set(() => ({ vibePanelMermaidCode })),
vibePanelBackendNodes: undefined,
setVibePanelBackendNodes: vibePanelBackendNodes => set(() => ({ vibePanelBackendNodes })),
vibePanelBackendEdges: undefined,
setVibePanelBackendEdges: vibePanelBackendEdges => set(() => ({ vibePanelBackendEdges })),
isVibeGenerating: false,
setIsVibeGenerating: isVibeGenerating => set(() => ({ isVibeGenerating })),
vibePanelInstruction: '',
setVibePanelInstruction: vibePanelInstruction => set(() => ({ vibePanelInstruction })),
vibePanelIntent: '',
setVibePanelIntent: vibePanelIntent => set(() => ({ vibePanelIntent })),
vibePanelMessage: '',
setVibePanelMessage: vibePanelMessage => set(() => ({ vibePanelMessage })),
vibePanelSuggestions: [],
setVibePanelSuggestions: vibePanelSuggestions => set(() => ({ vibePanelSuggestions })),
vibePanelLastWarnings: [],
setVibePanelLastWarnings: vibePanelLastWarnings => set(() => ({ vibePanelLastWarnings })),
})

View File

@@ -111,8 +111,8 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
const currentNode = nodes[i] as Node<IterationNodeType | LoopNodeType>
if (currentNode.data.type === BlockEnum.Iteration) {
if (currentNode.data.start_node_id) {
if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE)
if (currentNode.data.start_node_id && nodesMap[currentNode.data.start_node_id]) {
if (nodesMap[currentNode.data.start_node_id].type !== CUSTOM_ITERATION_START_NODE)
iterationNodesWithStartNode.push(currentNode)
}
else {
@@ -121,8 +121,8 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
}
if (currentNode.data.type === BlockEnum.Loop) {
if (currentNode.data.start_node_id) {
if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_LOOP_START_NODE)
if (currentNode.data.start_node_id && nodesMap[currentNode.data.start_node_id]) {
if (nodesMap[currentNode.data.start_node_id].type !== CUSTOM_LOOP_START_NODE)
loopNodesWithStartNode.push(currentNode)
}
else {

View File

@@ -65,7 +65,7 @@ const IfElseNode: FC<NodeProps<IfElseNodeType>> = (props) => {
</div>
<div className="space-y-0.5">
{caseItem.conditions.map((condition, i) => (
<div key={condition.id} className="relative">
<div key={condition.id || i} className="relative">
{
checkIsConditionSet(condition)
? (

View File

@@ -2,6 +2,7 @@
import type {
EdgeChange,
FitViewOptions,
NodeChange,
Viewport,
} from 'reactflow'
@@ -59,8 +60,10 @@ const edgeTypes = {
type WorkflowPreviewProps = {
nodes: Node[]
edges: Edge[]
viewport: Viewport
viewport?: Viewport
className?: string
fitView?: boolean
fitViewOptions?: FitViewOptions
miniMapToRight?: boolean
}
const WorkflowPreview = ({
@@ -68,6 +71,8 @@ const WorkflowPreview = ({
edges,
viewport,
className,
fitView,
fitViewOptions,
miniMapToRight,
}: WorkflowPreviewProps) => {
const [nodesData, setNodesData] = useState(() => initialNodes(nodes, edges))
@@ -125,6 +130,8 @@ const WorkflowPreview = ({
selectionKeyCode={null}
selectionMode={SelectionMode.Partial}
minZoom={0.25}
fitView={fitView}
fitViewOptions={fitViewOptions}
>
<Background
gap={[14, 14]}

View File

@@ -0,0 +1,87 @@
import type { AppListResponse } from '@/models/app'
import type { DataSetListResponse } from '@/models/datasets'
import type { BackendEdgeSpec, BackendNodeSpec, FlowchartGenRes } from '@/service/debug'
import { type } from '@orpc/contract'
import { base } from '../base'
// Search APIs
export const searchAppsContract = base
.route({
path: '/apps',
method: 'GET',
})
.input(type<{
query?: {
page?: number
limit?: number
name?: string
}
}>())
.output(type<AppListResponse>())
export const searchDatasetsContract = base
.route({
path: '/datasets',
method: 'GET',
})
.input(type<{
query?: {
page?: number
limit?: number
keyword?: string
}
}>())
.output(type<DataSetListResponse>())
// Vibe Workflow API
export type GenerateFlowchartInput = {
instruction: string
model_config: {
provider: string
name: string
mode: string
completion_params: Record<string, unknown>
} | null
available_nodes: Array<{
type: string
title?: string
description?: string
}>
existing_nodes?: Array<{
id: string
type: string
title?: string
}>
existing_edges?: BackendEdgeSpec[]
available_tools: Array<{
provider_id: string
provider_name?: string
provider_type?: string
tool_name: string
tool_label?: string
tool_key?: string
tool_description?: string
}>
selected_node_ids?: string[]
previous_workflow?: {
nodes: BackendNodeSpec[]
edges: BackendEdgeSpec[]
warnings?: string[]
}
regenerate_mode?: boolean
language: string
available_models?: Array<{
provider: string
model: string
}>
}
export const generateFlowchartContract = base
.route({
path: '/flowchart-generate',
method: 'POST',
})
.input(type<{
body: GenerateFlowchartInput
}>())
.output(type<FlowchartGenRes>())

View File

@@ -1,5 +1,6 @@
import type { InferContractRouterInputs } from '@orpc/contract'
import { bindPartnerStackContract, invoicesContract } from './console/billing'
import { generateFlowchartContract, searchAppsContract, searchDatasetsContract } from './console/goto-anything'
import { systemFeaturesContract } from './console/system'
import {
triggerOAuthConfigContract,
@@ -58,6 +59,11 @@ export const consoleRouterContract = {
oauthDelete: triggerOAuthDeleteContract,
oauthInitiate: triggerOAuthInitiateContract,
},
gotoAnything: {
searchApps: searchAppsContract,
searchDatasets: searchDatasetsContract,
generateFlowchart: generateFlowchartContract,
},
}
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>

View File

@@ -269,6 +269,9 @@
}
},
"app/components/app/app-publisher/index.tsx": {
"tailwindcss/no-unnecessary-whitespace": {
"count": 1
},
"tailwindcss/no-unnecessary-whitespace": {
"count": 1
},
@@ -3204,6 +3207,11 @@
"count": 1
}
},
"app/components/share/text-generation/result/header.tsx": {
"tailwindcss/no-unnecessary-whitespace": {
"count": 3
}
},
"app/components/share/text-generation/result/index.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 3
@@ -3643,6 +3651,11 @@
"count": 1
}
},
"app/components/workflow/nodes/_base/components/before-run-form/panel-wrap.tsx": {
"tailwindcss/no-unnecessary-whitespace": {
"count": 1
}
},
"app/components/workflow/nodes/_base/components/editor/code-editor/editor-support-vars.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
@@ -5447,4 +5460,4 @@
"count": 2
}
}
}
}

View File

@@ -75,6 +75,9 @@
"gotoAnything.actions.themeLightDesc": "Use light appearance",
"gotoAnything.actions.themeSystem": "System Theme",
"gotoAnything.actions.themeSystemDesc": "Follow your OS appearance",
"gotoAnything.actions.vibeDesc": "Generate workflow from natural language",
"gotoAnything.actions.vibeHint": "Try: {{prompt}}",
"gotoAnything.actions.vibeTitle": "Vibe",
"gotoAnything.actions.zenDesc": "Toggle canvas focus mode",
"gotoAnything.actions.zenTitle": "Zen Mode",
"gotoAnything.clearToSearchAll": "Clear @ to search all",

View File

@@ -1150,5 +1150,26 @@
"versionHistory.nameThisVersion": "Name this version",
"versionHistory.releaseNotesPlaceholder": "Describe what changed",
"versionHistory.restorationTip": "After version restoration, the current draft will be overwritten.",
"versionHistory.title": "Versions"
"versionHistory.title": "Versions",
"vibe.apply": "Apply",
"vibe.generateError": "Failed to generate workflow. Please try again.",
"vibe.generatingFlowchart": "Generating flowchart preview...",
"vibe.invalidFlowchart": "The generated flowchart could not be parsed.",
"vibe.missingFlowchart": "No flowchart was generated.",
"vibe.missingInstruction": "Describe the workflow you want to build.",
"vibe.modelUnavailable": "No model available for flowchart generation.",
"vibe.noFlowchart": "No flowchart provided",
"vibe.noFlowchartYet": "No flowchart preview available",
"vibe.nodeTypeUnavailable": "Node type \"{{type}}\" is not available in this workflow.",
"vibe.nodesUnavailable": "Workflow nodes are not available yet.",
"vibe.offTopicDefault": "I'm the Dify workflow design assistant. I can help you create AI automation workflows, but I can't answer general questions. Would you like to create a workflow instead?",
"vibe.offTopicTitle": "Off-Topic Request",
"vibe.panelTitle": "Workflow Preview",
"vibe.readOnly": "This workflow is read-only.",
"vibe.regenerate": "Regenerate",
"vibe.regenerateReminder": "Please verify your input and re-generate.",
"vibe.toolUnavailable": "Tool \"{{tool}}\" is not available in this workspace.",
"vibe.trySuggestion": "Try one of these suggestions:",
"vibe.unknownNodeId": "Node \"{{id}}\" is used before it is defined.",
"vibe.unsupportedEdgeLabel": "Unsupported edge label \"{{label}}\". Only true/false are allowed for if/else."
}

View File

@@ -1150,5 +1150,6 @@
"versionHistory.nameThisVersion": "命名",
"versionHistory.releaseNotesPlaceholder": "请描述变更",
"versionHistory.restorationTip": "版本回滚后,当前草稿将被覆盖。",
"versionHistory.title": "版本"
"versionHistory.title": "版本",
"vibe.regenerateReminder": "请检查输入并重新生成。"
}

View File

@@ -236,7 +236,8 @@
"vite": "7.3.1",
"vite-tsconfig-paths": "6.0.4",
"vitest": "4.0.17",
"vitest-canvas-mock": "1.1.3"
"vitest-canvas-mock": "1.1.3",
"vitest-tiny-reporter": "1.3.1"
},
"pnpm": {
"overrides": {

15
web/pnpm-lock.yaml generated
View File

@@ -585,6 +585,9 @@ importers:
vitest-canvas-mock:
specifier: 1.1.3
version: 1.1.3(vitest@4.0.17)
vitest-tiny-reporter:
specifier: 1.3.1
version: 1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17)
packages:
@@ -7291,6 +7294,12 @@ packages:
peerDependencies:
vitest: ^3.0.0 || ^4.0.0
vitest-tiny-reporter@1.3.1:
resolution: {integrity: sha512-9WfLruQBbxm4EqMIS0jDZmQjvMgsWgHUso9mHQWgjA6hM3tEVhjdG8wYo7ePFh1XbwEFzEo3XUQqkGoKZ/Td2Q==}
peerDependencies:
'@vitest/runner': ^2.0.0 || ^3.0.2 || ^4.0.0
vitest: ^2.0.0 || ^3.0.2 || ^4.0.0
vitest@4.0.17:
resolution: {integrity: sha512-FQMeF0DJdWY0iOnbv466n/0BudNdKj1l5jYgl5JVTwjSsZSlqyXFt/9+1sEyhR6CLowbZpV7O1sCHrzBhucKKg==}
engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0}
@@ -15342,6 +15351,12 @@ snapshots:
moo-color: 1.0.3
vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
vitest-tiny-reporter@1.3.1(@vitest/runner@4.0.17)(vitest@4.0.17):
dependencies:
'@vitest/runner': 4.0.17
tinyrainbow: 3.0.3
vitest: 4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2)
vitest@4.0.17(@types/node@18.15.0)(@vitest/browser-playwright@4.0.17)(jiti@1.21.7)(jsdom@27.3.0(canvas@3.2.1))(sass@1.93.2)(terser@5.46.0)(tsx@4.21.0)(yaml@2.8.2):
dependencies:
'@vitest/expect': 4.0.17

View File

@@ -19,6 +19,48 @@ export type GenRes = {
error?: string
}
export type ToolRecommendation = {
requested_capability: string
unconfigured_tools: Array<{
provider_id: string
tool_name: string
description: string
}>
configured_alternatives: Array<{
provider_id: string
tool_name: string
description: string
}>
recommendation: string
}
export type BackendNodeSpec = {
id: string
type: string
title?: string
config?: Record<string, any>
position?: { x: number, y: number }
}
export type BackendEdgeSpec = {
source: string
target: string
sourceHandle?: string
targetHandle?: string
}
export type FlowchartGenRes = {
intent?: 'generate' | 'off_topic' | 'error'
flowchart: string
nodes?: BackendNodeSpec[]
edges?: BackendEdgeSpec[]
message?: string
warnings?: string[]
suggestions?: string[]
tool_recommendations?: ToolRecommendation[]
error?: string
}
export type CodeGenRes = {
code: string
language: string[]
@@ -75,6 +117,12 @@ export const generateRule = (body: Record<string, any>) => {
})
}
export const generateFlowchart = (body: Record<string, any>) => {
return post<FlowchartGenRes>('/flowchart-generate', {
body,
})
}
export const fetchModelParams = (providerName: string, modelId: string) => {
return get(`workspaces/current/model-providers/${providerName}/models/parameter-rules`, {
params: {

View File

@@ -0,0 +1,50 @@
import type { GenerateFlowchartInput } from '@/contract/console/goto-anything'
import { consoleClient, consoleQuery, marketplaceClient, marketplaceQuery } from '@/service/client'
// Search APIs
export const searchAppsQueryKey = consoleQuery.gotoAnything.searchApps.queryKey
export const searchApps = async (name?: string) => {
return consoleClient.gotoAnything.searchApps({
query: {
page: 1,
name,
},
})
}
export const searchDatasetsQueryKey = consoleQuery.gotoAnything.searchDatasets.queryKey
export const searchDatasets = async (keyword?: string) => {
return consoleClient.gotoAnything.searchDatasets({
query: {
page: 1,
limit: 10,
keyword,
},
})
}
export const searchPluginsQueryKey = marketplaceQuery.searchAdvanced.queryKey
export const searchPlugins = async (query?: string) => {
return marketplaceClient.searchAdvanced({
params: {
kind: 'plugins',
},
body: {
query: query || '',
page: 1,
page_size: 10,
},
})
}
// Vibe Workflow API
export const generateFlowchartMutationKey = consoleQuery.gotoAnything.generateFlowchart.mutationKey
export const generateFlowchart = async (input: GenerateFlowchartInput) => {
return consoleClient.gotoAnything.generateFlowchart({
body: input,
})
}