Compare commits

..

2 Commits

Author SHA1 Message Date
Yanli 盐粒
ce1a1c062f chore(make): default PATH_TO_CHECK to api/ 2026-01-29 02:43:32 +08:00
Yanli 盐粒
5af4d7f9a1 chore(api): bootstrap ty type-checking
- Add api/ty.toml exclusions for incremental rollout
- Wire ty into make type-check (with basedpyright + mypy)
- Update style workflow to run make type-check
2026-01-29 02:41:26 +08:00
2817 changed files with 23645 additions and 199577 deletions

View File

@@ -480,4 +480,4 @@ const useButtonState = () => {
### Related Skills
- `frontend-testing` - For testing refactored components
- `web/docs/test.md` - Testing specification
- `web/testing/testing.md` - Testing specification

View File

@@ -7,7 +7,7 @@ description: Generate Vitest + React Testing Library tests for Dify frontend com
This skill enables Claude to generate high-quality, comprehensive frontend tests for the Dify project following established conventions and best practices.
> **⚠️ Authoritative Source**: This skill is derived from `web/docs/test.md`. Use Vitest mock/timer APIs (`vi.*`).
> **⚠️ Authoritative Source**: This skill is derived from `web/testing/testing.md`. Use Vitest mock/timer APIs (`vi.*`).
## When to Apply This Skill
@@ -309,7 +309,7 @@ For more detailed information, refer to:
### Primary Specification (MUST follow)
- **`web/docs/test.md`** - The canonical testing specification. This skill is derived from this document.
- **`web/testing/testing.md`** - The canonical testing specification. This skill is derived from this document.
### Reference Examples in Codebase

View File

@@ -4,7 +4,7 @@ This guide defines the workflow for generating tests, especially for complex com
## Scope Clarification
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/docs/test.md` § Coverage Goals.
This guide addresses **multi-file workflow** (how to process multiple test files). For coverage requirements within a single test file, see `web/testing/testing.md` § Coverage Goals.
| Scope | Rule |
|-------|------|

View File

@@ -0,0 +1 @@
../../.agents/skills/component-refactoring

View File

@@ -0,0 +1 @@
../../.agents/skills/frontend-code-review

View File

@@ -0,0 +1 @@
../../.agents/skills/frontend-testing

View File

@@ -0,0 +1 @@
../../.agents/skills/orpc-contract-first

View File

@@ -7,7 +7,7 @@ cd web && pnpm install
pipx install uv
echo "alias start-api=\"cd $WORKSPACE_ROOT/api && uv run python -m flask run --host 0.0.0.0 --port=5001 --debug\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution\"" >> ~/.bashrc
echo "alias start-worker=\"cd $WORKSPACE_ROOT/api && uv run python -m celery -A app.celery worker -P threads -c 1 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention\"" >> ~/.bashrc
echo "alias start-web=\"cd $WORKSPACE_ROOT/web && pnpm dev:inspect\"" >> ~/.bashrc
echo "alias start-web-prod=\"cd $WORKSPACE_ROOT/web && pnpm build && pnpm start\"" >> ~/.bashrc
echo "alias start-containers=\"cd $WORKSPACE_ROOT/docker && docker-compose -f docker-compose.middleware.yaml -p dify --env-file middleware.env up -d\"" >> ~/.bashrc

10
.github/CODEOWNERS vendored
View File

@@ -9,9 +9,6 @@
# CODEOWNERS file
/.github/CODEOWNERS @laipz8200 @crazywoola
# Agents
/.agents/skills/ @hyoban
# Docs
/docs/ @crazywoola
@@ -24,10 +21,6 @@
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
/api/controllers/mcp/ @Nov1c444
/api/controllers/console/app/mcp_server.py @Nov1c444
# Backend - Tests
/api/tests/ @laipz8200 @QuantumGhost
/api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
@@ -238,9 +231,6 @@
# Frontend - Base Components
/web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Base Components Tests
/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar
# Frontend - Utils and Hooks
/web/utils/classnames.ts @iamjoel @zxhlyh
/web/utils/time.ts @iamjoel @zxhlyh

View File

@@ -72,7 +72,6 @@ jobs:
OPENDAL_FS_ROOT: /tmp/dify-storage
run: |
uv run --project api pytest \
-n auto \
--timeout "${PYTEST_TIMEOUT:-180}" \
api/tests/integration_tests/workflow \
api/tests/integration_tests/tools \

View File

@@ -79,6 +79,29 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install web dependencies
run: |
cd web
pnpm install --frozen-lockfile
- name: ESLint autofix
run: |
cd web
pnpm lint:fix || true
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |

View File

@@ -8,7 +8,6 @@ on:
- "build/**"
- "release/e-*"
- "hotfix/**"
- "feat/hitl-backend"
tags:
- "*"
@@ -76,9 +75,7 @@ jobs:
with:
context: "{{defaultContext}}:${{ matrix.context }}"
platforms: ${{ matrix.platform }}
build-args: |
COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
ENABLE_PROD_SOURCEMAP=${{ matrix.context == 'web' && github.ref_name == 'deploy/dev' }}
build-args: COMMIT_SHA=${{ fromJSON(steps.meta.outputs.json).labels['org.opencontainers.image.revision'] }}
labels: ${{ steps.meta.outputs.labels }}
outputs: type=image,name=${{ env[matrix.image_name_env] }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=gha,scope=${{ matrix.service_name }}

View File

@@ -4,7 +4,8 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "build/feat/hitl"
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
@@ -13,7 +14,10 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'build/feat/hitl'
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1

View File

@@ -39,7 +39,7 @@ jobs:
run: pnpm install --frozen-lockfile
- name: Run tests
run: pnpm test:ci
run: pnpm test:coverage
- name: Coverage Summary
if: always()

2
.gitignore vendored
View File

@@ -209,7 +209,6 @@ api/.vscode
.history
.idea/
web/migration/
# pnpm
/.pnpm-store
@@ -222,7 +221,6 @@ mise.toml
# AI Assistant
.sisyphus/
.roo/
api/.env.backup
/clickzetta

View File

@@ -37,7 +37,7 @@
"-c",
"1",
"-Q",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
"--loglevel",
"INFO"
],

View File

@@ -7,7 +7,7 @@ Dify is an open-source platform for developing LLM applications with an intuitiv
The codebase is split into:
- **Backend API** (`/api`): Python Flask application organized with Domain-Driven Design
- **Frontend Web** (`/web`): Next.js application using TypeScript and React
- **Frontend Web** (`/web`): Next.js 15 application using TypeScript and React 19
- **Docker deployment** (`/docker`): Containerized deployment configurations
## Backend Workflow
@@ -18,7 +18,36 @@ The codebase is split into:
## Frontend Workflow
- Read `web/AGENTS.md` for details
```bash
cd web
pnpm lint:fix
pnpm type-check:tsgo
pnpm test
```
### Frontend Linting
ESLint is used for frontend code quality. Available commands:
```bash
# Lint all files (report only)
pnpm lint
# Lint and auto-fix issues
pnpm lint:fix
# Lint specific files or directories
pnpm lint:fix app/components/base/button/
pnpm lint:fix app/components/base/button/index.tsx
# Lint quietly (errors only, no warnings)
pnpm lint:quiet
# Check code complexity
pnpm lint:complexity
```
**Important**: Always run `pnpm lint:fix` before committing. The pre-commit hook runs `lint-staged` which only lints staged files.
## Testing & Quality Practices

View File

@@ -77,7 +77,7 @@ How we prioritize:
For setting up the frontend service, please refer to our comprehensive [guide](https://github.com/langgenius/dify/blob/main/web/README.md) in the `web/README.md` file. This document provides detailed instructions to help you set up the frontend environment properly.
**Testing**: All React components must have comprehensive test coverage. See [web/docs/test.md](https://github.com/langgenius/dify/blob/main/web/docs/test.md) for the canonical frontend testing guidelines and follow every requirement described there.
**Testing**: All React components must have comprehensive test coverage. See [web/testing/testing.md](https://github.com/langgenius/dify/blob/main/web/testing/testing.md) for the canonical frontend testing guidelines and follow every requirement described there.
#### Backend

View File

@@ -3,6 +3,7 @@ DOCKER_REGISTRY=langgenius
WEB_IMAGE=$(DOCKER_REGISTRY)/dify-web
API_IMAGE=$(DOCKER_REGISTRY)/dify-api
VERSION=latest
PATH_TO_CHECK ?= api/
# Default target - show help
.DEFAULT_GOAL := help
@@ -80,7 +81,7 @@ test:
echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \
else \
PYTEST_XDIST_ARGS="-n auto" uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
fi
@echo "✅ Tests complete"

View File

@@ -33,9 +33,6 @@ TRIGGER_URL=http://localhost:5001
# The time in seconds after the signature is rejected
FILES_ACCESS_TIMEOUT=300
# Collaboration mode toggle
ENABLE_COLLABORATION_MODE=false
# Access token expiration time in minutes
ACCESS_TOKEN_EXPIRE_MINUTES=60
@@ -620,18 +617,12 @@ PLUGIN_DAEMON_URL=http://127.0.0.1:5002
PLUGIN_REMOTE_INSTALL_PORT=5003
PLUGIN_REMOTE_INSTALL_HOST=localhost
PLUGIN_MAX_PACKAGE_SIZE=15728640
PLUGIN_MODEL_SCHEMA_CACHE_TTL=3600
INNER_API_KEY_FOR_PLUGIN=QaHbTe77CtuXmsfyhR7+vRjI/+XbV1AaFy691iy+kGDv2Jvy0/eAh8Y1
# Marketplace configuration
MARKETPLACE_ENABLED=true
MARKETPLACE_API_URL=https://marketplace.dify.ai
# Creators Platform configuration
CREATORS_PLATFORM_FEATURES_ENABLED=true
CREATORS_PLATFORM_API_URL=https://creators.dify.ai
CREATORS_PLATFORM_OAUTH_CLIENT_ID=
# Endpoint configuration
ENDPOINT_URL_TEMPLATE=http://localhost:5002/e/{hook_id}
@@ -723,58 +714,6 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
# Sandbox expired records clean configuration
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL=200
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
# Sandbox Dify CLI configuration
# Directory containing dify CLI binaries (dify-cli-<os>-<arch>). Defaults to api/bin when unset.
SANDBOX_DIFY_CLI_ROOT=
# CLI API URL for sandbox (dify-sandbox or e2b) to call back to Dify API.
# This URL must be accessible from the sandbox environment.
# For local development: use http://localhost:5001 or http://127.0.0.1:5001
# For middleware docker stack (api on host): keep localhost/127.0.0.1 and use agentbox via 127.0.0.1:2222
# For Docker deployment: use http://api:5001 (internal Docker network)
# For external sandbox (e.g., e2b): use a publicly accessible URL
CLI_API_URL=http://localhost:5001
# Base URL for storage file ticket API endpoints (upload/download).
# Used by sandbox containers (internal or external like e2b) that need an absolute,
# routable address to reach the Dify API file endpoints.
# Falls back to FILES_URL if not specified.
# For local development: http://localhost:5001
# For Docker deployment: http://api:5001
FILES_API_URL=http://localhost:5001
# Optional defaults for SSH sandbox provider setup (for manual config/CLI usage).
# Middleware/local dev usually uses 127.0.0.1:2222; full docker deployment usually uses agentbox:22.
SSH_SANDBOX_HOST=127.0.0.1
SSH_SANDBOX_PORT=2222
SSH_SANDBOX_USERNAME=agentbox
SSH_SANDBOX_PASSWORD=agentbox
SSH_SANDBOX_BASE_WORKING_PATH=/workspace/sandboxes
# Redis URL used for PubSub between API and
# celery worker
# defaults to url constructed from `REDIS_*`
# configurations
PUBSUB_REDIS_URL=
# Pub/sub channel type for streaming events.
# valid options are:
#
# - pubsub: for normal Pub/Sub
# - sharded: for sharded Pub/Sub
#
# It's highly recommended to use sharded Pub/Sub AND redis cluster
# for large deployments.
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while running
# PubSub.
# It's highly recommended to enable this for large deployments.
PUBSUB_REDIS_USE_CLUSTERS=false
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
# Human input timeout check interval in minutes
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1

View File

@@ -36,8 +36,6 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
# TODO(QuantumGhost): fix the import violation later
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
@@ -52,14 +50,14 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@@ -124,6 +122,11 @@ ignore_imports =
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.llm_utils -> core.file.models
@@ -133,6 +136,7 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.template_transform.template_transform_node -> configs
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
@@ -141,9 +145,9 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.agent.entities
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
@@ -159,6 +163,9 @@ ignore_imports =
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
@@ -207,6 +214,7 @@ ignore_imports =
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
core.workflow.nodes.llm.node -> core.model_manager
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
@@ -219,9 +227,7 @@ ignore_imports =
core.workflow.nodes.knowledge_index.entities -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.knowledge_index.knowledge_index_node -> models.dataset
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
core.workflow.nodes.llm.node -> models.dataset
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
core.workflow.nodes.llm.file_saver -> core.tools.signature
@@ -239,7 +245,6 @@ ignore_imports =
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
@@ -280,12 +285,12 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> extensions.ext_database
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
core.workflow.nodes.llm.file_saver -> extensions.ext_database
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
@@ -295,58 +300,6 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> services
core.workflow.nodes.tool.tool_node -> services
[importlinter:contract:model-runtime-no-internal-imports]
name = Model Runtime Internal Imports
type = forbidden
source_modules =
core.model_runtime
forbidden_modules =
configs
controllers
extensions
models
services
tasks
core.agent
core.app
core.base
core.callback_handler
core.datasource
core.db
core.entities
core.errors
core.extension
core.external_data_tool
core.file
core.helper
core.hosting_configuration
core.indexing_runner
core.llm_generator
core.logging
core.mcp
core.memory
core.model_manager
core.moderation
core.ops
core.plugin
core.prompt
core.provider_manager
core.rag
core.repositories
core.schemas
core.tools
core.trigger
core.variables
core.workflow
ignore_imports =
core.model_runtime.model_providers.__base.ai_model -> configs
core.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis
core.model_runtime.model_providers.__base.large_language_model -> configs
core.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type
core.model_runtime.model_providers.model_provider_factory -> configs
core.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis
core.model_runtime.model_providers.model_provider_factory -> models.provider_ids
[importlinter:contract:rsc]
name = RSC
type = layers

View File

@@ -53,7 +53,6 @@ select = [
"S301", # suspicious-pickle-usage, disallow use of `pickle` and its wrappers.
"S302", # suspicious-marshal-usage, disallow use of `marshal` module
"S311", # suspicious-non-cryptographic-random-usage,
"TID", # flake8-tidy-imports
]
@@ -89,7 +88,6 @@ ignore = [
"SIM113", # enumerate-for-loop
"SIM117", # multiple-with-statements
"SIM210", # if-expr-with-true-false
"TID252", # allow relative imports from parent modules
]
[lint.per-file-ignores]
@@ -111,20 +109,10 @@ ignore = [
"S110", # allow ignoring exceptions in tests code (currently)
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]
"controllers/web/human_input_form.py" = ["TID251"]
[lint.pyflakes]
allowed-unused-imports = [
"_pytest.monkeypatch",
"tests.integration_tests",
"tests.unit_tests",
]
[lint.flake8-tidy-imports]
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse"]
msg = "Use Pydantic payload/query models instead of reqparse."
[lint.flake8-tidy-imports.banned-api."flask_restx.reqparse.RequestParser"]
msg = "Use Pydantic payload/query models instead of reqparse."

View File

@@ -54,7 +54,7 @@
"--loglevel",
"DEBUG",
"-Q",
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,workflow_based_app_execution,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
]
}
]

View File

@@ -122,7 +122,7 @@ These commands assume you start from the repository root.
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@@ -1,9 +0,0 @@
Summary:
Summary:
- Application configuration definitions, including file access settings.
Invariants:
- File access settings drive signed URL expiration and base URLs.
Tests:
- Config parsing tests under tests/unit_tests/configs.

View File

@@ -1,9 +0,0 @@
Summary:
- Registers file-related API namespaces and routes for files service.
- Includes app-assets and sandbox archive proxy controllers.
Invariants:
- files_ns must include all file controller modules to register routes.
Tests:
- Coverage via controller unit tests and route registration smoke checks.

View File

@@ -1,14 +0,0 @@
Summary:
- App assets download proxy endpoint (signed URL verification, stream from storage).
Invariants:
- Validates AssetPath fields (UUIDs, asset_type allowlist).
- Verifies tenant-scoped signature and expiration before reading storage.
- URL uses expires_at/nonce/sign query params.
Edge Cases:
- Missing files return NotFound.
- Invalid signature or expired link returns Forbidden.
Tests:
- Verify signature validation and invalid/expired cases.

View File

@@ -1,13 +0,0 @@
Summary:
- App assets upload proxy endpoint (signed URL verification, upload to storage).
Invariants:
- Validates AssetPath fields (UUIDs, asset_type allowlist).
- Verifies tenant-scoped signature and expiration before writing storage.
- URL uses expires_at/nonce/sign query params.
Edge Cases:
- Invalid signature or expired link returns Forbidden.
Tests:
- Verify signature validation and invalid/expired cases.

View File

@@ -1,14 +0,0 @@
Summary:
- Sandbox archive upload/download proxy endpoints (signed URL verification, stream to storage).
Invariants:
- Validates tenant_id and sandbox_id UUIDs.
- Verifies tenant-scoped signature and expiration before storage access.
- URL uses expires_at/nonce/sign query params.
Edge Cases:
- Missing archive returns NotFound.
- Invalid signature or expired link returns Forbidden.
Tests:
- Add unit tests for signature validation if needed.

View File

@@ -1,9 +0,0 @@
Summary:
Summary:
- Collects file assets and emits FileAsset entries with storage keys.
Invariants:
- Storage keys are derived via AppAssetStorage for draft files.
Tests:
- Covered by asset build pipeline tests.

View File

@@ -1,14 +0,0 @@
Summary:
Summary:
- Builds skill artifacts from markdown assets and uploads resolved outputs.
Invariants:
- Reads draft asset content via AppAssetStorage refs.
- Writes resolved artifacts via AppAssetStorage refs.
- FileAsset storage keys are derived via AppAssetStorage.
Edge Cases:
- Missing or invalid JSON content yields empty skill content/metadata.
Tests:
- Build pipeline unit tests covering compile/upload paths.

View File

@@ -1,9 +0,0 @@
Summary:
Summary:
- Converts AppAssetFileTree to FileAsset items for packaging.
Invariants:
- Storage keys for assets are derived via AppAssetStorage.
Tests:
- Used in packaging/service tests for asset bundles.

View File

@@ -1,14 +0,0 @@
# Zip Packager Notes
## Purpose
- Builds a ZIP archive of asset contents stored via the configured storage backend.
## Key Decisions
- Packaging writes assets into an in-memory zip buffer returned as bytes.
- Asset fetch + zip writing are executed via a thread pool with a lock guarding `ZipFile` writes.
## Edge Cases
- ZIP writes are serialized by the lock; storage reads still run in parallel.
## Tests/Verification
- None yet.

View File

@@ -1,9 +0,0 @@
Summary:
Summary:
- Builds AssetItem entries for asset trees using AssetPath-derived storage keys.
Invariants:
- Uses AssetPath to compute draft storage keys.
Tests:
- Covered by asset parsing and packaging tests.

View File

@@ -1,20 +0,0 @@
Summary:
- Defines AssetPath facade + typed asset path classes for app-asset storage access.
- Maps asset paths to storage keys and generates presigned or signed-proxy URLs.
- Signs proxy URLs using tenant private keys and enforces expiration.
- Exposes app_asset_storage singleton for reuse.
Invariants:
- AssetPathBase fields (tenant_id/app_id/resource_id/node_id) must be UUIDs.
- AssetPath.from_components enforces valid types and resolved node_id presence.
- Storage keys are derived internally via AssetPathBase.get_storage_key; callers never supply raw paths.
- AppAssetStorage.storage returns the cached presign wrapper (not the raw storage).
Edge Cases:
- Storage backends without presign support must fall back to signed proxy URLs.
- Signed proxy verification enforces expiration and tenant-scoped signing keys.
- Upload URLs also fall back to signed proxy endpoints when presign is unsupported.
- load_or_none treats SilentStorage "File Not Found" bytes as missing.
Tests:
- Unit tests for ref validation, storage key mapping, and signed URL verification.

View File

@@ -1,10 +0,0 @@
Summary:
Summary:
- Extracts asset files from a zip and persists them into app asset storage.
Invariants:
- Rejects path traversal/absolute/backslash paths.
- Saves extracted files via AppAssetStorage draft refs.
Tests:
- Zip security edge cases and tree construction tests.

View File

@@ -1,9 +0,0 @@
Summary:
Summary:
- Downloads published app asset zip into sandbox and extracts it.
Invariants:
- Uses AppAssetStorage to generate download URLs for build zips (internal URL).
Tests:
- Sandbox initialization integration tests.

View File

@@ -1,12 +0,0 @@
Summary:
Summary:
- Downloads draft/resolved assets into sandbox for draft execution.
Invariants:
- Uses AppAssetStorage to generate download URLs for draft/resolved refs (internal URL).
Edge Cases:
- No nodes -> returns early.
Tests:
- Sandbox draft initialization tests.

View File

@@ -1,9 +0,0 @@
Summary:
- Sandbox lifecycle wrapper (ready/cancel/fail signals, mount/unmount, release).
Invariants:
- wait_ready raises with the original initialization error as the cause.
- release always attempts unmount and environment release, logging failures.
Tests:
- Covered by sandbox lifecycle/unit tests and workflow execution error handling.

View File

@@ -1,2 +0,0 @@
Summary:
- Sandbox security helper modules.

View File

@@ -1,13 +0,0 @@
Summary:
- Generates and verifies signed URLs for sandbox archive upload/download.
Invariants:
- tenant_id and sandbox_id must be UUIDs.
- Signatures are tenant-scoped and include operation, expiry, and nonce.
Edge Cases:
- Missing tenant private key raises ValueError.
- Expired or tampered signatures are rejected.
Tests:
- Add unit tests if sandbox archive signature behavior expands.

View File

@@ -1,12 +0,0 @@
Summary:
- Manages sandbox archive uploads/downloads for workspace persistence.
Invariants:
- Archive storage key is sandbox/<tenant_id>/<sandbox_id>.tar.gz.
- Signed URLs are tenant-scoped and use external files URL.
Edge Cases:
- Missing archive skips mount.
Tests:
- Covered indirectly via sandbox integration tests.

View File

@@ -1,9 +0,0 @@
Summary:
Summary:
- Loads/saves skill bundles to app asset storage.
Invariants:
- Skill bundles use AppAssetStorage refs and JSON serialization.
Tests:
- Covered by skill bundle build/load unit tests.

View File

@@ -1,16 +0,0 @@
# E2B Sandbox Provider Notes
## Purpose
- Implements the E2B-backed `VirtualEnvironment` provider and bootstraps sandbox metadata, file I/O, and command execution.
## Key Decisions
- Sandbox metadata is gathered during `_construct_environment` using the E2B SDK before returning `Metadata`.
- Architecture/OS detection uses a single `uname -m -s` call split by whitespace to reduce round-trips.
- Command execution streams stdout/stderr through `QueueTransportReadCloser`; stdin is unsupported.
## Edge Cases
- `release_environment` raises when sandbox termination fails.
- `execute_command` runs in a background thread; consumers must read stdout/stderr until EOF.
## Tests/Verification
- None yet. Add targeted service tests when behavior changes.

View File

@@ -1,14 +0,0 @@
Summary:
- App asset CRUD, publish/build pipeline, and presigned URL generation.
Invariants:
- Asset storage access goes through AppAssetStorage + AssetPath, using app_asset_storage singleton.
- Tree operations require tenant/app scoping and lock for mutation.
- Asset zips are packaged via raw storage with storage keys from AppAssetStorage.
Edge Cases:
- File nodes larger than preview limit are rejected.
- Deletion runs asynchronously; storage failures are logged.
Tests:
- Unit tests for storage URL generation and publish/build flows.

View File

@@ -1,10 +0,0 @@
Summary:
Summary:
- Imports app bundles, including asset extraction into app asset storage.
Invariants:
- Asset imports respect zip security checks and tenant/app scoping.
- Draft asset packaging uses AppAssetStorage for key mapping.
Tests:
- Bundle import unit tests and zip validation coverage.

View File

@@ -1,6 +0,0 @@
Summary:
Summary:
- Unit tests for AppAssetStorage ref validation, key mapping, and signing.
Tests:
- Covers valid/invalid refs, signature verify, expiration handling, and proxy URL generation.

View File

@@ -1,13 +1,4 @@
from __future__ import annotations
import os
import sys
from typing import TYPE_CHECKING, cast
if TYPE_CHECKING:
from celery import Celery
celery: Celery
def is_db_command() -> bool:
@@ -17,15 +8,10 @@ def is_db_command() -> bool:
# create app
flask_app = None
socketio_app = None
if is_db_command():
from app_factory import create_migrations_app
app = create_migrations_app()
socketio_app = app
flask_app = app
else:
# Gunicorn and Celery handle monkey patching automatically in production by
# specifying the `gevent` worker class. Manual monkey patching is not required here.
@@ -36,15 +22,8 @@ else:
from app_factory import create_app
socketio_app, flask_app = create_app()
app = flask_app
celery = cast("Celery", flask_app.extensions["celery"])
app = create_app()
celery = app.extensions["celery"]
if __name__ == "__main__":
from gevent import pywsgi
from geventwebsocket.handler import WebSocketHandler # type: ignore[reportMissingTypeStubs]
host = os.environ.get("HOST", "0.0.0.0")
port = int(os.environ.get("PORT", 5001))
server = pywsgi.WSGIServer((host, port), socketio_app, handler_class=WebSocketHandler)
server.serve_forever()
app.run(host="0.0.0.0", port=5001)

View File

@@ -1,7 +1,6 @@
import logging
import time
import socketio # type: ignore[reportMissingTypeStubs]
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
@@ -9,7 +8,6 @@ from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from core.logging.context import init_request_context
from dify_app import DifyApp
from extensions.ext_socketio import sio
logger = logging.getLogger(__name__)
@@ -62,18 +60,14 @@ def create_flask_app_with_configs() -> DifyApp:
return dify_app
def create_app() -> tuple[socketio.WSGIApp, DifyApp]:
def create_app() -> DifyApp:
start_time = time.perf_counter()
app = create_flask_app_with_configs()
initialize_extensions(app)
sio.app = app
socketio_app = socketio.WSGIApp(sio, app)
end_time = time.perf_counter()
if dify_config.DEBUG:
logger.info("Finished create_app (%s ms)", round((end_time - start_time) * 1000, 2))
return socketio_app, app
return app
def initialize_extensions(app: DifyApp):
@@ -155,7 +149,7 @@ def initialize_extensions(app: DifyApp):
logger.info("Loaded %s (%s ms)", short_name, round((end_time - start_time) * 1000, 2))
def create_migrations_app() -> DifyApp:
def create_migrations_app():
app = create_flask_app_with_configs()
from extensions import ext_database, ext_migrate

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -23,8 +23,7 @@ from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.index_processor.constant.built_in_field import BuiltInField
from core.rag.models.document import ChildDocument, Document
from core.sandbox import SandboxBuilder, SandboxType
from core.tools.utils.system_encryption import encrypt_system_params
from core.tools.utils.system_oauth_encryption import encrypt_system_oauth_params
from events.app_event import app_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -740,10 +739,8 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
except Exception:
logger.exception("Failed to execute database migration")
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
lock.release()
else:
@@ -1453,58 +1450,54 @@ def clear_orphaned_file_records(force: bool):
all_ids_in_tables = []
for ids_table in ids_tables:
query = ""
match ids_table["type"]:
case "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
if ids_table["type"] == "uuid":
click.echo(
click.style(
f"- Listing file ids in column {ids_table['column']} in table {ids_table['table']}", fg="white"
)
c = ids_table["column"]
query = f"SELECT {c} FROM {ids_table['table']} WHERE {c} IS NOT NULL"
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
case "text":
t = ids_table["table"]
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {t}",
fg="white",
)
)
query = (
f"SELECT {ids_table['column']} FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
all_ids_in_tables.append({"table": ids_table["table"], "id": str(i[0])})
elif ids_table["type"] == "text":
click.echo(
click.style(
f"- Listing file-id-like strings in column {ids_table['column']} in table {ids_table['table']}",
fg="white",
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
query = (
f"SELECT regexp_matches({ids_table['column']}, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
elif ids_table["type"] == "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case "json":
click.echo(
click.style(
(
f"- Listing file-id-like JSON string in column {ids_table['column']} "
f"in table {ids_table['table']}"
),
fg="white",
)
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
case _:
pass
)
query = (
f"SELECT regexp_matches({ids_table['column']}::text, '{guid_regexp}', 'g') AS extracted_id "
f"FROM {ids_table['table']}"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for i in rs:
for j in i[0]:
all_ids_in_tables.append({"table": ids_table["table"], "id": j})
click.echo(click.style(f"Found {len(all_ids_in_tables)} file ids in tables.", fg="white"))
except Exception as e:
@@ -1614,7 +1607,7 @@ def remove_orphaned_files_on_storage(force: bool):
click.echo(click.style(f"- Scanning files on storage path {storage_path}", fg="white"))
files = storage.scan(path=storage_path, files=True, directories=False)
all_files_on_storage.extend(files)
except FileNotFoundError:
except FileNotFoundError as e:
click.echo(click.style(f" -> Skipping path {storage_path} as it does not exist.", fg="yellow"))
continue
except Exception as e:
@@ -1744,18 +1737,59 @@ def file_usage(
if src_filter != src:
continue
match ids_table["type"]:
case "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ids_table["type"] == "uuid":
# Direct UUID match
query = (
f"SELECT {ids_table['pk_column']}, {ids_table['column']} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
ref_file_id = str(row[1])
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
elif ids_table["type"] in ("text", "json"):
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
@@ -1778,50 +1812,6 @@ def file_usage(
)
total_count += 1
case "text" | "json":
# Extract UUIDs from text/json content
column_cast = f"{ids_table['column']}::text" if ids_table["type"] == "json" else ids_table["column"]
query = (
f"SELECT {ids_table['pk_column']}, {column_cast} "
f"FROM {ids_table['table']} WHERE {ids_table['column']} IS NOT NULL"
)
with db.engine.begin() as conn:
rs = conn.execute(sa.text(query))
for row in rs:
record_id = str(row[0])
content = str(row[1])
# Find all UUIDs in the content
import re
uuid_pattern = re.compile(guid_regexp, re.IGNORECASE)
matches = uuid_pattern.findall(content)
for ref_file_id in matches:
if ref_file_id not in file_key_map:
continue
storage_key = file_key_map[ref_file_id]
# Apply filters
if file_id and ref_file_id != file_id:
continue
if key and not storage_key.endswith(key):
continue
# Only collect items within the requested page range
if offset <= total_count < offset + limit:
paginated_usages.append(
{
"src": f"{ids_table['table']}.{ids_table['column']}",
"record_id": record_id,
"file_id": ref_file_id,
"key": storage_key,
}
)
total_count += 1
case _:
pass
# Output results
if output_json:
result = {
@@ -1865,59 +1855,6 @@ def file_usage(
click.echo(click.style(f"Use --offset {offset + limit} to see next page", fg="white"))
@click.command("setup-sandbox-system-config", help="Setup system-level sandbox provider configuration.")
@click.option(
"--provider-type", prompt=True, type=click.Choice(["e2b", "docker", "local", "ssh"]), help="Sandbox provider type"
)
@click.option("--config", prompt=True, help='Configuration JSON (e.g., {"api_key": "xxx"} for e2b)')
def setup_sandbox_system_config(provider_type: str, config: str):
"""
Setup system-level sandbox provider configuration.
Examples:
flask setup-sandbox-system-config --provider-type e2b --config '{"api_key": "e2b_xxx"}'
flask setup-sandbox-system-config --provider-type docker --config '{"docker_sock": "unix:///var/run/docker.sock"}'
flask setup-sandbox-system-config --provider-type local --config '{}'
flask setup-sandbox-system-config --provider-type ssh --config \
'{"ssh_host": "agentbox", "ssh_port": "22", "ssh_username": "agentbox", "ssh_password": "agentbox"}'
"""
from models.sandbox import SandboxProviderSystemConfig
try:
click.echo(click.style(f"Validating config: {config}", fg="yellow"))
config_dict = TypeAdapter(dict[str, Any]).validate_json(config)
click.echo(click.style("Config validated successfully.", fg="green"))
click.echo(click.style(f"Validating config schema for provider type: {provider_type}", fg="yellow"))
SandboxBuilder.validate(SandboxType(provider_type), config_dict)
click.echo(click.style("Config schema validated successfully.", fg="green"))
click.echo(click.style("Encrypting config...", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
encrypted_config = encrypt_system_params(config_dict)
click.echo(click.style("Config encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error validating/encrypting config: {str(e)}", fg="red"))
return
deleted_count = db.session.query(SandboxProviderSystemConfig).filter_by(provider_type=provider_type).delete()
if deleted_count > 0:
click.echo(
click.style(
f"Deleted {deleted_count} existing system config for provider type: {provider_type}", fg="yellow"
)
)
system_config = SandboxProviderSystemConfig(
provider_type=provider_type,
encrypted_config=encrypted_config,
)
db.session.add(system_config)
db.session.commit()
click.echo(click.style(f"Sandbox system config setup successfully. id: {system_config.id}", fg="green"))
click.echo(click.style(f"Provider type: {provider_type}", fg="green"))
@click.command("setup-system-tool-oauth-client", help="Setup system tool oauth client.")
@click.option("--provider", prompt=True, help="Provider name")
@click.option("--client-params", prompt=True, help="Client Params")
@@ -1937,7 +1874,7 @@ def setup_system_tool_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_params(client_params_dict)
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))
@@ -1986,7 +1923,7 @@ def setup_system_trigger_oauth_client(provider, client_params):
click.echo(click.style(f"Encrypting client params: {client_params}", fg="yellow"))
click.echo(click.style(f"Using SECRET_KEY: `{dify_config.SECRET_KEY}`", fg="yellow"))
oauth_client_params = encrypt_system_params(client_params_dict)
oauth_client_params = encrypt_system_oauth_params(client_params_dict)
click.echo(click.style("Client params encrypted successfully.", fg="green"))
except Exception as e:
click.echo(click.style(f"Error parsing client params: {str(e)}", fg="red"))

View File

@@ -2,7 +2,6 @@ import logging
from pathlib import Path
from typing import Any
from pydantic import Field
from pydantic.fields import FieldInfo
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict, TomlConfigSettingsSource
@@ -83,17 +82,6 @@ class DifyConfig(
extra="ignore",
)
SANDBOX_DIFY_CLI_ROOT: str | None = Field(
default=None,
description=(
"Filesystem directory containing dify CLI binaries named dify-cli-<os>-<arch>. "
"Defaults to api/bin when unset."
),
)
DIFY_PORT: int = Field(
default=5001,
description="Port used by Dify to communicate with the host machine.",
)
# Before adding any config,
# please consider to arrange it in the proper config group of existed or added
# for better readability and maintainability.

View File

@@ -1,4 +1,3 @@
from datetime import timedelta
from enum import StrEnum
from typing import Literal
@@ -49,16 +48,6 @@ class SecurityConfig(BaseSettings):
default=5,
)
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
description="Maximum number of web form submissions allowed per IP within the rate limit window",
default=30,
)
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
description="Time window in seconds for web form submission rate limiting",
default=60,
)
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
@@ -93,12 +82,6 @@ class AppExecutionConfig(BaseSettings):
default=0,
)
HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
default=int(timedelta(days=7).total_seconds()),
ge=1,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""
@@ -260,22 +243,6 @@ class PluginConfig(BaseSettings):
default=15728640 * 12,
)
PLUGIN_MODEL_SCHEMA_CACHE_TTL: PositiveInt = Field(
description="TTL in seconds for caching plugin model schemas in Redis",
default=60 * 60,
)
class CliApiConfig(BaseSettings):
"""
Configuration for CLI API (for dify-cli to call back from external sandbox environments)
"""
CLI_API_URL: str = Field(
description="CLI API URL for external sandbox (e.g., e2b) to call back.",
default="http://localhost:5001",
)
class MarketplaceConfig(BaseSettings):
"""
@@ -293,27 +260,6 @@ class MarketplaceConfig(BaseSettings):
)
class CreatorsPlatformConfig(BaseSettings):
"""
Configuration for creators platform
"""
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
description="Enable or disable creators platform features",
default=True,
)
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
description="Creators Platform API URL",
default=HttpUrl("https://creators.dify.ai"),
)
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
description="OAuth client_id for the Creators Platform app registered in Dify",
default="",
)
class EndpointConfig(BaseSettings):
"""
Configuration for various application endpoints and URLs
@@ -368,15 +314,6 @@ class FileAccessConfig(BaseSettings):
default="",
)
FILES_API_URL: str = Field(
description="Base URL for storage file ticket API endpoints."
" Used by sandbox containers (internal or external like e2b) that need"
" an absolute, routable address to upload/download files via the API."
" Falls back to FILES_URL if not specified."
" For Docker deployments, set to http://api:5001.",
default="",
)
FILES_ACCESS_TIMEOUT: int = Field(
description="Expiration time in seconds for file access URLs",
default=300,
@@ -1192,14 +1129,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable queue monitor task",
default=False,
)
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
description="Enable human input timeout check task",
default=True,
)
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
description="Human input timeout check interval in minutes",
default=1,
)
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
description="Enable check upgradable plugin task",
default=True,
@@ -1221,16 +1150,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
default=0,
)
# API token last_used_at batch update
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
description="Enable periodic batch update of API token last_used_at timestamps",
default=True,
)
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
description="Interval in minutes for batch updating API token last_used_at (default 30)",
default=30,
)
# Trigger provider refresh (simple version)
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
description="Enable trigger provider refresh poller",
@@ -1310,13 +1229,6 @@ class PositionConfig(BaseSettings):
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
class CollaborationConfig(BaseSettings):
ENABLE_COLLABORATION_MODE: bool = Field(
description="Whether to enable collaboration mode features across the workspace",
default=False,
)
class LoginConfig(BaseSettings):
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
description="whether to enable email code login",
@@ -1392,10 +1304,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
description="Maximum number of records to process in each batch",
default=1000,
)
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
description="Maximum interval in milliseconds between batches",
default=200,
)
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
description="Retention days for sandbox expired workflow_run records and message records",
default=30,
@@ -1415,9 +1323,7 @@ class FeatureConfig(
TriggerConfig,
AsyncWorkflowConfig,
PluginConfig,
CliApiConfig,
MarketplaceConfig,
CreatorsPlatformConfig,
DataSetConfig,
EndpointConfig,
FileAccessConfig,
@@ -1441,7 +1347,6 @@ class FeatureConfig(
WorkflowConfig,
WorkflowNodeExecutionConfig,
WorkspaceConfig,
CollaborationConfig,
LoginConfig,
AccountConfig,
SwaggerUIConfig,

View File

@@ -6,7 +6,6 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
from .cache.redis_pubsub_config import RedisPubSubConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
@@ -259,20 +258,11 @@ class CeleryConfig(DatabaseConfig):
description="Password of the Redis Sentinel master.",
default=None,
)
CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field(
description="Timeout for Redis Sentinel socket operations in seconds.",
default=0.1,
)
CELERY_TASK_ANNOTATIONS: dict[str, Any] | None = Field(
description=(
"Annotations for Celery tasks as a JSON mapping of task name -> options "
"(for example, rate limits or other task-specific settings)."
),
default=None,
)
@computed_field
def CELERY_RESULT_BACKEND(self) -> str | None:
if self.CELERY_BACKEND in ("database", "rabbitmq"):
@@ -327,7 +317,6 @@ class MiddlewareConfig(
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
KeywordStoreConfig,
RedisConfig,
RedisPubSubConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,

View File

@@ -1,96 +0,0 @@
from typing import Literal, Protocol
from urllib.parse import quote_plus, urlunparse
from pydantic import Field
from pydantic_settings import BaseSettings
class RedisConfigDefaults(Protocol):
REDIS_HOST: str
REDIS_PORT: int
REDIS_USERNAME: str | None
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
"""
Configuration settings for Redis pub/sub streaming.
"""
PUBSUB_REDIS_URL: str | None = Field(
alias="PUBSUB_REDIS_URL",
description=(
"Redis connection URL for pub/sub streaming events between API "
"and celery worker, defaults to url constructed from "
"`REDIS_*` configurations"
),
default=None,
)
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
description=(
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
"recommended to enable this for large deployments."
),
default=False,
)
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
description=(
"Pub/sub channel type for streaming events. "
"Valid options are:\n"
"\n"
" - pubsub: for normal Pub/Sub\n"
" - sharded: for sharded Pub/Sub\n"
"\n"
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
"for large deployments."
),
default="pubsub",
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
username = defaults.REDIS_USERNAME or None
password = defaults.REDIS_PASSWORD or None
userinfo = ""
if username:
userinfo = quote_plus(username)
if password:
password_part = quote_plus(password)
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property
def normalized_pubsub_redis_url(self) -> str:
pubsub_redis_url = self.PUBSUB_REDIS_URL
if pubsub_redis_url:
cleaned = pubsub_redis_url.strip()
pubsub_redis_url = cleaned or None
if pubsub_redis_url:
return pubsub_redis_url
return self._build_default_pubsub_url()

View File

@@ -21,7 +21,6 @@ language_timezone_mapping = {
"th-TH": "Asia/Bangkok",
"id-ID": "Asia/Jakarta",
"ar-TN": "Africa/Tunis",
"nl-NL": "Europe/Amsterdam",
}
languages = list(language_timezone_mapping.keys())

View File

@@ -6,6 +6,7 @@ from contexts.wrapper import RecyclableContextVar
if TYPE_CHECKING:
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
from core.model_runtime.entities.model_entities import AIModelEntity
from core.plugin.entities.plugin_daemon import PluginModelProviderEntity
from core.tools.plugin_tool.provider import PluginToolProviderController
from core.trigger.provider import PluginTriggerProviderController
@@ -28,6 +29,12 @@ plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(
ContextVar("plugin_model_providers_lock")
)
plugin_model_schema_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_model_schema_lock"))
plugin_model_schemas: RecyclableContextVar[dict[str, "AIModelEntity"]] = RecyclableContextVar(
ContextVar("plugin_model_schemas")
)
datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = (
RecyclableContextVar(ContextVar("datasource_plugin_providers"))
)

View File

@@ -1,27 +0,0 @@
from flask import Blueprint
from flask_restx import Namespace
from libs.external_api import ExternalApi
bp = Blueprint("cli_api", __name__, url_prefix="/cli/api")
api = ExternalApi(
bp,
version="1.0",
title="CLI API",
description="APIs for Dify CLI to call back from external sandbox environments (e.g., e2b)",
)
# Create namespace
cli_api_ns = Namespace("cli_api", description="CLI API operations", path="/")
from .dify_cli import cli_api as _plugin
api.add_namespace(cli_api_ns)
__all__ = [
"_plugin",
"api",
"bp",
"cli_api_ns",
]

View File

@@ -1,192 +0,0 @@
from flask import abort
from flask_restx import Resource
from pydantic import BaseModel
from controllers.cli_api import cli_api_ns
from controllers.cli_api.dify_cli.wraps import get_cli_user_tenant, plugin_data
from controllers.cli_api.wraps import cli_api_only
from controllers.console.wraps import setup_required
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.helpers import get_signed_file_url_for_plugin
from core.plugin.backwards_invocation.app import PluginAppBackwardsInvocation
from core.plugin.backwards_invocation.base import BaseBackwardsInvocationResponse
from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation
from core.plugin.backwards_invocation.tool import PluginToolBackwardsInvocation
from core.plugin.entities.request import (
RequestInvokeApp,
RequestInvokeLLM,
RequestInvokeTool,
RequestRequestUploadFile,
)
from core.sandbox.bash.dify_cli import DifyCliToolConfig
from core.session.cli_api import CliContext
from core.skill.entities import ToolInvocationRequest
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from libs.helper import length_prefixed_response
from models.account import Account
from models.model import EndUser, Tenant
class FetchToolItem(BaseModel):
tool_type: str
tool_provider: str
tool_name: str
credential_id: str | None = None
class FetchToolBatchRequest(BaseModel):
tools: list[FetchToolItem]
@cli_api_ns.route("/invoke/llm")
class CliInvokeLLMApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestInvokeLLM)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestInvokeLLM,
cli_context: CliContext,
):
def generator():
response = PluginModelBackwardsInvocation.invoke_llm(user_model.id, tenant_model, payload)
return PluginModelBackwardsInvocation.convert_to_event_stream(response)
return length_prefixed_response(0xF, generator())
@cli_api_ns.route("/invoke/tool")
class CliInvokeToolApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestInvokeTool)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestInvokeTool,
cli_context: CliContext,
):
tool_type = ToolProviderType.value_of(payload.tool_type)
request = ToolInvocationRequest(
tool_type=tool_type,
provider=payload.provider,
tool_name=payload.tool,
credential_id=payload.credential_id,
)
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
abort(403, description=f"Access denied for tool: {payload.provider}/{payload.tool}")
def generator():
return PluginToolBackwardsInvocation.convert_to_event_stream(
PluginToolBackwardsInvocation.invoke_tool(
tenant_id=tenant_model.id,
user_id=user_model.id,
tool_type=tool_type,
provider=payload.provider,
tool_name=payload.tool,
tool_parameters=payload.tool_parameters,
credential_id=payload.credential_id,
),
)
return length_prefixed_response(0xF, generator())
@cli_api_ns.route("/invoke/app")
class CliInvokeAppApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestInvokeApp)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestInvokeApp,
cli_context: CliContext,
):
response = PluginAppBackwardsInvocation.invoke_app(
app_id=payload.app_id,
user_id=user_model.id,
tenant_id=tenant_model.id,
conversation_id=payload.conversation_id,
query=payload.query,
stream=payload.response_mode == "streaming",
inputs=payload.inputs,
files=payload.files,
)
return length_prefixed_response(0xF, PluginAppBackwardsInvocation.convert_to_event_stream(response))
@cli_api_ns.route("/upload/file/request")
class CliUploadFileRequestApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=RequestRequestUploadFile)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: RequestRequestUploadFile,
cli_context: CliContext,
):
url = get_signed_file_url_for_plugin(
filename=payload.filename,
mimetype=payload.mimetype,
tenant_id=tenant_model.id,
user_id=user_model.id,
)
return BaseBackwardsInvocationResponse(data={"url": url}).model_dump()
@cli_api_ns.route("/fetch/tools/batch")
class CliFetchToolsBatchApi(Resource):
@cli_api_only
@get_cli_user_tenant
@setup_required
@plugin_data(payload_type=FetchToolBatchRequest)
def post(
self,
user_model: Account | EndUser,
tenant_model: Tenant,
payload: FetchToolBatchRequest,
cli_context: CliContext,
):
tools: list[dict] = []
for item in payload.tools:
provider_type = ToolProviderType.value_of(item.tool_type)
request = ToolInvocationRequest(
tool_type=provider_type,
provider=item.tool_provider,
tool_name=item.tool_name,
credential_id=item.credential_id,
)
if cli_context.tool_access and not cli_context.tool_access.is_allowed(request):
abort(403, description=f"Access denied for tool: {item.tool_provider}/{item.tool_name}")
try:
tool_runtime = ToolManager.get_tool_runtime(
tenant_id=tenant_model.id,
provider_type=provider_type,
provider_id=item.tool_provider,
tool_name=item.tool_name,
invoke_from=InvokeFrom.AGENT,
credential_id=item.credential_id,
)
tool_config = DifyCliToolConfig.create_from_tool(tool_runtime)
tools.append(tool_config.model_dump())
except Exception:
continue
return BaseBackwardsInvocationResponse(data={"tools": tools}).model_dump()

View File

@@ -1,137 +0,0 @@
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import current_app, g, request
from flask_login import user_logged_in
from pydantic import BaseModel
from sqlalchemy.orm import Session
from core.session.cli_api import CliApiSession, CliContext
from extensions.ext_database import db
from libs.login import current_user
from models.account import Tenant
from models.model import DefaultEndUserSessionID, EndUser
P = ParamSpec("P")
R = TypeVar("R")
class TenantUserPayload(BaseModel):
tenant_id: str
user_id: str
def get_user(tenant_id: str, user_id: str | None) -> EndUser:
"""
Get current user
NOTE: user_id is not trusted, it could be maliciously set to any value.
As a result, it could only be considered as an end user id.
"""
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
is_anonymous = user_id == DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
with Session(db.engine) as session:
user_model = None
if is_anonymous:
user_model = (
session.query(EndUser)
.where(
EndUser.session_id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
else:
user_model = (
session.query(EndUser)
.where(
EndUser.id == user_id,
EndUser.tenant_id == tenant_id,
)
.first()
)
if not user_model:
user_model = EndUser(
tenant_id=tenant_id,
type="service_api",
is_anonymous=is_anonymous,
session_id=user_id,
)
session.add(user_model)
session.commit()
session.refresh(user_model)
except Exception:
raise ValueError("user not found")
return user_model
def get_cli_user_tenant(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
session: CliApiSession | None = getattr(g, "cli_api_session", None)
if session is None:
raise ValueError("session not found")
user_id = session.user_id
tenant_id = session.tenant_id
cli_context = CliContext.model_validate(session.context)
if not user_id:
user_id = DefaultEndUserSessionID.DEFAULT_SESSION_ID
try:
tenant_model = (
db.session.query(Tenant)
.where(
Tenant.id == tenant_id,
)
.first()
)
except Exception:
raise ValueError("tenant not found")
if not tenant_model:
raise ValueError("tenant not found")
kwargs["tenant_model"] = tenant_model
kwargs["user_model"] = get_user(tenant_id, user_id)
kwargs["cli_context"] = cli_context
current_app.login_manager._update_request_context_with_user(kwargs["user_model"]) # type: ignore
user_logged_in.send(current_app._get_current_object(), user=current_user) # type: ignore
return view_func(*args, **kwargs)
return decorated_view
def plugin_data(view: Callable[P, R] | None = None, *, payload_type: type[BaseModel]):
def decorator(view_func: Callable[P, R]):
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
try:
data = request.get_json()
except Exception:
raise ValueError("invalid json")
try:
payload = payload_type.model_validate(data)
except Exception as e:
raise ValueError(f"invalid payload: {str(e)}")
kwargs["payload"] = payload
return view_func(*args, **kwargs)
return decorated_view
if view is None:
return decorator
else:
return decorator(view)

View File

@@ -1,56 +0,0 @@
import hashlib
import hmac
import time
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import abort, g, request
from core.session.cli_api import CliApiSessionManager
P = ParamSpec("P")
R = TypeVar("R")
SIGNATURE_TTL_SECONDS = 300
def _verify_signature(session_secret: str, timestamp: str, body: bytes, signature: str) -> bool:
expected = hmac.new(
session_secret.encode(),
f"{timestamp}.".encode() + body,
hashlib.sha256,
).hexdigest()
return hmac.compare_digest(f"sha256={expected}", signature)
def cli_api_only(view: Callable[P, R]):
@wraps(view)
def decorated(*args: P.args, **kwargs: P.kwargs):
session_id = request.headers.get("X-Cli-Api-Session-Id")
timestamp = request.headers.get("X-Cli-Api-Timestamp")
signature = request.headers.get("X-Cli-Api-Signature")
if not session_id or not timestamp or not signature:
abort(401)
try:
ts = int(timestamp)
if abs(time.time() - ts) > SIGNATURE_TTL_SECONDS:
abort(401)
except ValueError:
abort(401)
session = CliApiSessionManager().get(session_id)
if not session:
abort(401)
body = request.get_data()
if not _verify_signature(session.secret, timestamp, body, signature):
abort(401)
g.cli_api_session = session
return view(*args, **kwargs)
return decorated

View File

@@ -5,6 +5,8 @@ from enum import StrEnum
from flask_restx import Namespace
from pydantic import BaseModel, TypeAdapter
from controllers.console import console_ns
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
@@ -22,9 +24,6 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
def get_or_create_model(model_name: str, field_def):
# Import lazily to avoid circular imports between console controllers and schema helpers.
from controllers.console import console_ns
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)

View File

@@ -32,16 +32,13 @@ for module_name in RESOURCE_MODULES:
# Ensure resource modules are imported so route decorators are evaluated.
# Import other controllers
# Sandbox file browser
from . import (
admin,
apikey,
extension,
feature,
human_input_form,
init_validate,
ping,
sandbox_files,
setup,
spec,
version,
@@ -53,7 +50,6 @@ from .app import (
agent,
annotation,
app,
app_asset,
audio,
completion,
conversation,
@@ -64,11 +60,9 @@ from .app import (
model_config,
ops_trace,
site,
skills,
statistic,
workflow,
workflow_app_log,
workflow_comment,
workflow_draft_variable,
workflow_run,
workflow_statistic,
@@ -120,7 +114,6 @@ from .explore import (
saved_message,
trial,
)
from .socketio import workflow as socketio_workflow # pyright: ignore[reportUnusedImport]
# Import tag controllers
from .tag import tags
@@ -135,7 +128,6 @@ from .workspace import (
model_providers,
models,
plugin,
sandbox_providers,
tool_providers,
trigger_providers,
workspace,
@@ -154,7 +146,6 @@ __all__ = [
"api",
"apikey",
"app",
"app_asset",
"audio",
"banner",
"billing",
@@ -180,7 +171,6 @@ __all__ = [
"forgot_password",
"generator",
"hit_testing",
"human_input_form",
"init_validate",
"installed_app",
"load_balancing_config",
@@ -204,12 +194,9 @@ __all__ = [
"rag_pipeline_import",
"rag_pipeline_workflow",
"recommended_app",
"sandbox_files",
"sandbox_providers",
"saved_message",
"setup",
"site",
"skills",
"spec",
"statistic",
"tags",
@@ -220,7 +207,6 @@ __all__ = [
"website",
"workflow",
"workflow_app_log",
"workflow_comment",
"workflow_draft_variable",
"workflow_run",
"workflow_statistic",

View File

@@ -243,13 +243,15 @@ class InsertExploreBannerApi(Resource):
def post(self):
payload = InsertExploreBannerPayload.model_validate(console_ns.payload)
content = {
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
}
banner = ExporleBanner(
content={
"category": payload.category,
"title": payload.title,
"description": payload.description,
"img-src": payload.img_src,
},
content=content,
link=payload.link,
sort=payload.sort,
language=payload.language,

View File

@@ -10,7 +10,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.dataset import Dataset
from models.model import ApiToken, App
from services.api_token_service import ApiTokenCache
from . import console_ns
from .wraps import account_initialization_required, edit_permission_required, setup_required
@@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
if key is None:
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -1,11 +1,10 @@
from typing import Any, Literal
from flask import abort, make_response, request
from flask_restx import Resource
from pydantic import BaseModel, Field, TypeAdapter, field_validator
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel, Field, field_validator
from controllers.common.errors import NoFileUploadedError, TooManyFilesError
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
@@ -17,11 +16,9 @@ from controllers.console.wraps import (
)
from extensions.ext_redis import redis_client
from fields.annotation_fields import (
Annotation,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
AnnotationList,
annotation_fields,
annotation_hit_history_fields,
build_annotation_model,
)
from libs.helper import uuid_value
from libs.login import login_required
@@ -92,14 +89,6 @@ reg(CreateAnnotationPayload)
reg(UpdateAnnotationPayload)
reg(AnnotationReplyStatusQuery)
reg(AnnotationFilePayload)
register_schema_models(
console_ns,
Annotation,
AnnotationList,
AnnotationExportList,
AnnotationHitHistory,
AnnotationHitHistoryList,
)
@console_ns.route("/apps/<uuid:app_id>/annotation-reply/<string:action>")
@@ -118,11 +107,10 @@ class AnnotationReplyActionApi(Resource):
def post(self, app_id, action: Literal["enable", "disable"]):
app_id = str(app_id)
args = AnnotationReplyPayload.model_validate(console_ns.payload)
match action:
case "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
case "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
if action == "enable":
result = AppAnnotationService.enable_app_annotation(args.model_dump(), app_id)
elif action == "disable":
result = AppAnnotationService.disable_app_annotation(app_id)
return result, 200
@@ -213,33 +201,33 @@ class AnnotationApi(Resource):
app_id = str(app_id)
annotation_list, total = AppAnnotationService.get_annotation_list_by_app_id(app_id, page, limit, keyword)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response = AnnotationList(
data=annotation_models,
has_more=len(annotation_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json"), 200
response = {
"data": marshal(annotation_list, annotation_fields),
"has_more": len(annotation_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response, 200
@console_ns.doc("create_annotation")
@console_ns.doc(description="Create a new annotation for an app")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[CreateAnnotationPayload.__name__])
@console_ns.response(201, "Annotation created successfully", console_ns.models[Annotation.__name__])
@console_ns.response(201, "Annotation created successfully", build_annotation_model(console_ns))
@console_ns.response(403, "Insufficient permissions")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@marshal_with(annotation_fields)
@edit_permission_required
def post(self, app_id):
app_id = str(app_id)
args = CreateAnnotationPayload.model_validate(console_ns.payload)
data = args.model_dump(exclude_none=True)
annotation = AppAnnotationService.up_insert_app_annotation_from_message(data, app_id)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
return annotation
@setup_required
@login_required
@@ -276,7 +264,7 @@ class AnnotationExportApi(Resource):
@console_ns.response(
200,
"Annotations exported successfully",
console_ns.models[AnnotationExportList.__name__],
console_ns.model("AnnotationList", {"data": fields.List(fields.Nested(build_annotation_model(console_ns)))}),
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -286,8 +274,7 @@ class AnnotationExportApi(Resource):
def get(self, app_id):
app_id = str(app_id)
annotation_list = AppAnnotationService.export_annotation_list_by_app_id(app_id)
annotation_models = TypeAdapter(list[Annotation]).validate_python(annotation_list, from_attributes=True)
response_data = AnnotationExportList(data=annotation_models).model_dump(mode="json")
response_data = {"data": marshal(annotation_list, annotation_fields)}
# Create response with secure headers for CSV export
response = make_response(response_data, 200)
@@ -302,7 +289,7 @@ class AnnotationUpdateDeleteApi(Resource):
@console_ns.doc("update_delete_annotation")
@console_ns.doc(description="Update or delete an annotation")
@console_ns.doc(params={"app_id": "Application ID", "annotation_id": "Annotation ID"})
@console_ns.response(200, "Annotation updated successfully", console_ns.models[Annotation.__name__])
@console_ns.response(200, "Annotation updated successfully", build_annotation_model(console_ns))
@console_ns.response(204, "Annotation deleted successfully")
@console_ns.response(403, "Insufficient permissions")
@console_ns.expect(console_ns.models[UpdateAnnotationPayload.__name__])
@@ -311,6 +298,7 @@ class AnnotationUpdateDeleteApi(Resource):
@account_initialization_required
@cloud_edition_billing_resource_check("annotation")
@edit_permission_required
@marshal_with(annotation_fields)
def post(self, app_id, annotation_id):
app_id = str(app_id)
annotation_id = str(annotation_id)
@@ -318,7 +306,7 @@ class AnnotationUpdateDeleteApi(Resource):
annotation = AppAnnotationService.update_app_annotation_directly(
args.model_dump(exclude_none=True), app_id, annotation_id
)
return Annotation.model_validate(annotation, from_attributes=True).model_dump(mode="json")
return annotation
@setup_required
@login_required
@@ -426,7 +414,14 @@ class AnnotationHitHistoryListApi(Resource):
@console_ns.response(
200,
"Hit histories retrieved successfully",
console_ns.models[AnnotationHitHistoryList.__name__],
console_ns.model(
"AnnotationHitHistoryList",
{
"data": fields.List(
fields.Nested(console_ns.model("AnnotationHitHistoryItem", annotation_hit_history_fields))
)
},
),
)
@console_ns.response(403, "Insufficient permissions")
@setup_required
@@ -441,14 +436,11 @@ class AnnotationHitHistoryListApi(Resource):
annotation_hit_history_list, total = AppAnnotationService.get_annotation_hit_histories(
app_id, annotation_id, page, limit
)
history_models = TypeAdapter(list[AnnotationHitHistory]).validate_python(
annotation_hit_history_list, from_attributes=True
)
response = AnnotationHitHistoryList(
data=history_models,
has_more=len(annotation_hit_history_list) == limit,
limit=limit,
total=total,
page=page,
)
return response.model_dump(mode="json")
response = {
"data": marshal(annotation_hit_history_list, annotation_hit_history_fields),
"has_more": len(annotation_hit_history_list) == limit,
"limit": limit,
"total": total,
"page": page,
}
return response

View File

@@ -1,7 +1,5 @@
import logging
import uuid
from datetime import datetime
from enum import StrEnum
from typing import Any, Literal, TypeAlias
from flask import request
@@ -32,7 +30,6 @@ from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App, DatasetPermissionEnum, Workflow
from models.model import IconType
from models.workflow_features import WorkflowFeatures
from services.app_dsl_service import AppDslService, ImportMode
from services.app_service import AppService
from services.enterprise.enterprise_service import EnterpriseService
@@ -57,13 +54,6 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
class RuntimeType(StrEnum):
CLASSIC = "classic"
SANDBOXED = "sandboxed"
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -129,11 +119,6 @@ class AppExportQuery(BaseModel):
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppExportBundleQuery(BaseModel):
include_secret: bool = Field(default=False, description="Include secrets in export")
workflow_id: str | None = Field(default=None, description="Specific workflow ID to export")
class AppNamePayload(BaseModel):
name: str = Field(..., min_length=1, description="Name to check")
@@ -359,7 +344,6 @@ class AppPartial(ResponseModel):
create_user_name: str | None = None
author_name: str | None = None
has_draft_trigger: bool | None = None
runtime_type: RuntimeType = RuntimeType.CLASSIC
@computed_field(return_type=str | None) # type: ignore
@property
@@ -509,14 +493,12 @@ class AppListApi(Resource):
str(app.id) for app in app_pagination.items if app.mode in {"workflow", "advanced-chat"}
]
draft_trigger_app_ids: set[str] = set()
sandbox_app_ids: set[str] = set()
if workflow_capable_app_ids:
draft_workflows = (
db.session.execute(
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
Workflow.tenant_id == current_tenant_id,
)
)
.scalars()
@@ -528,23 +510,16 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
# Check sandbox feature
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
sandbox_app_ids.add(str(workflow.app_id))
node_id = None
try:
for node_id, node_data in workflow.walk_nodes():
for _, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
continue
for app in app_pagination.items:
app.has_draft_trigger = str(app.id) in draft_trigger_app_ids
app.runtime_type = RuntimeType.SANDBOXED if str(app.id) in sandbox_app_ids else RuntimeType.CLASSIC
pagination_model = AppPagination.model_validate(app_pagination, from_attributes=True)
return pagination_model.model_dump(mode="json"), 200
@@ -713,69 +688,6 @@ class AppExportApi(Resource):
return payload.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/export-bundle")
class AppExportBundleApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def get(self, app_model):
from services.app_bundle_service import AppBundleService
args = AppExportBundleQuery.model_validate(request.args.to_dict(flat=True))
current_user, _ = current_account_with_tenant()
result = AppBundleService.export_bundle(
app_model=app_model,
account_id=str(current_user.id),
include_secret=args.include_secret,
workflow_id=args.workflow_id,
)
return result.model_dump(mode="json")
@console_ns.route("/apps/<uuid:app_id>/publish-to-creators-platform")
class AppPublishToCreatorsPlatformApi(Resource):
@get_app_model
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self, app_model):
"""Bundle the app DSL and upload to Creators Platform, returning a redirect URL."""
import httpx
from configs import dify_config
from core.helper.creators import get_redirect_url, upload_dsl
from services.app_bundle_service import AppBundleService
if not dify_config.CREATORS_PLATFORM_FEATURES_ENABLED:
return {"message": "Creators Platform is not enabled"}, 403
current_user, _ = current_account_with_tenant()
# 1. Export the app bundle (DSL + assets) to a temporary zip
bundle_result = AppBundleService.export_bundle(
app_model=app_model,
account_id=str(current_user.id),
include_secret=False,
)
# 2. Download the zip from the temporary URL
download_response = httpx.get(bundle_result.download_url, timeout=60, follow_redirects=True)
download_response.raise_for_status()
# 3. Upload to Creators Platform
claim_code = upload_dsl(download_response.content, filename=bundle_result.filename)
# 4. Generate redirect URL (with optional OAuth code)
redirect_url = get_redirect_url(str(current_user.id), claim_code)
return {"redirect_url": redirect_url}
@console_ns.route("/apps/<uuid:app_id>/name")
class AppNameApi(Resource):
@console_ns.doc("check_app_name")

View File

@@ -1,321 +0,0 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from controllers.console import console_ns
from controllers.console.app.error import (
AppAssetNodeNotFoundError,
AppAssetPathConflictError,
)
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.entities.app_asset_entities import BatchUploadNode
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from services.app_asset_service import AppAssetService
from services.errors.app_asset import (
AppAssetNodeNotFoundError as ServiceNodeNotFoundError,
)
from services.errors.app_asset import (
AppAssetParentNotFoundError,
)
from services.errors.app_asset import (
AppAssetPathConflictError as ServicePathConflictError,
)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class CreateFolderPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
parent_id: str | None = None
class CreateFilePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
parent_id: str | None = None
@field_validator("name", mode="before")
@classmethod
def strip_name(cls, v: str) -> str:
return v.strip() if isinstance(v, str) else v
@field_validator("parent_id", mode="before")
@classmethod
def empty_to_none(cls, v: str | None) -> str | None:
return v or None
class GetUploadUrlPayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
size: int = Field(..., ge=0)
parent_id: str | None = None
@field_validator("name", mode="before")
@classmethod
def strip_name(cls, v: str) -> str:
return v.strip() if isinstance(v, str) else v
@field_validator("parent_id", mode="before")
@classmethod
def empty_to_none(cls, v: str | None) -> str | None:
return v or None
class BatchUploadPayload(BaseModel):
children: list[BatchUploadNode] = Field(..., min_length=1)
class UpdateFileContentPayload(BaseModel):
content: str
class RenameNodePayload(BaseModel):
name: str = Field(..., min_length=1, max_length=255)
class MoveNodePayload(BaseModel):
parent_id: str | None = None
class ReorderNodePayload(BaseModel):
after_node_id: str | None = Field(default=None, description="Place after this node, None for first position")
def reg(cls: type[BaseModel]) -> None:
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
reg(CreateFolderPayload)
reg(CreateFilePayload)
reg(GetUploadUrlPayload)
reg(BatchUploadNode)
reg(BatchUploadPayload)
reg(UpdateFileContentPayload)
reg(RenameNodePayload)
reg(MoveNodePayload)
reg(ReorderNodePayload)
@console_ns.route("/apps/<string:app_id>/assets/tree")
class AppAssetTreeResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
current_user, _ = current_account_with_tenant()
tree = AppAssetService.get_asset_tree(app_model, current_user.id)
return {"children": [view.model_dump() for view in tree.transform()]}
@console_ns.route("/apps/<string:app_id>/assets/folders")
class AppAssetFolderResource(Resource):
@console_ns.expect(console_ns.models[CreateFolderPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
payload = CreateFolderPayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.create_folder(app_model, current_user.id, payload.name, payload.parent_id)
return node.model_dump(), 201
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>")
class AppAssetFileDetailResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
try:
content = AppAssetService.get_file_content(app_model, current_user.id, node_id)
return {"content": content.decode("utf-8", errors="replace")}
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.expect(console_ns.models[UpdateFileContentPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def put(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
file = request.files.get("file")
if file:
content = file.read()
else:
payload = UpdateFileContentPayload.model_validate(console_ns.payload or {})
content = payload.content.encode("utf-8")
try:
node = AppAssetService.update_file_content(app_model, current_user.id, node_id, content)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>")
class AppAssetNodeResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def delete(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
try:
AppAssetService.delete_node(app_model, current_user.id, node_id)
return {"result": "success"}, 200
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/rename")
class AppAssetNodeRenameResource(Resource):
@console_ns.expect(console_ns.models[RenameNodePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
payload = RenameNodePayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.rename_node(app_model, current_user.id, node_id, payload.name)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/move")
class AppAssetNodeMoveResource(Resource):
@console_ns.expect(console_ns.models[MoveNodePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
payload = MoveNodePayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.move_node(app_model, current_user.id, node_id, payload.parent_id)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/nodes/<string:node_id>/reorder")
class AppAssetNodeReorderResource(Resource):
@console_ns.expect(console_ns.models[ReorderNodePayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
payload = ReorderNodePayload.model_validate(console_ns.payload or {})
try:
node = AppAssetService.reorder_node(app_model, current_user.id, node_id, payload.after_node_id)
return node.model_dump()
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/files/<string:node_id>/download-url")
class AppAssetFileDownloadUrlResource(Resource):
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, node_id: str):
current_user, _ = current_account_with_tenant()
try:
download_url = AppAssetService.get_file_download_url(app_model, current_user.id, node_id)
return {"download_url": download_url}
except ServiceNodeNotFoundError:
raise AppAssetNodeNotFoundError()
@console_ns.route("/apps/<string:app_id>/assets/files/upload")
class AppAssetFileUploadUrlResource(Resource):
@console_ns.expect(console_ns.models[GetUploadUrlPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
payload = GetUploadUrlPayload.model_validate(console_ns.payload or {})
try:
node, upload_url = AppAssetService.get_file_upload_url(
app_model, current_user.id, payload.name, payload.size, payload.parent_id
)
return {"node": node.model_dump(), "upload_url": upload_url}, 201
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()
@console_ns.route("/apps/<string:app_id>/assets/batch-upload")
class AppAssetBatchUploadResource(Resource):
@console_ns.expect(console_ns.models[BatchUploadPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
"""
Create nodes from tree structure and return upload URLs.
Input:
{
"children": [
{"name": "folder1", "node_type": "folder", "children": [
{"name": "file1.txt", "node_type": "file", "size": 1024}
]},
{"name": "root.txt", "node_type": "file", "size": 512}
]
}
Output:
{
"children": [
{"id": "xxx", "name": "folder1", "node_type": "folder", "children": [
{"id": "yyy", "name": "file1.txt", "node_type": "file", "size": 1024, "upload_url": "..."}
]},
{"id": "zzz", "name": "root.txt", "node_type": "file", "size": 512, "upload_url": "..."}
]
}
"""
current_user, _ = current_account_with_tenant()
payload = BatchUploadPayload.model_validate(console_ns.payload or {})
try:
result_children = AppAssetService.batch_create_from_tree(app_model, current_user.id, payload.children)
return {"children": [child.model_dump() for child in result_children]}, 201
except AppAssetParentNotFoundError:
raise AppAssetNodeNotFoundError()
except ServicePathConflictError:
raise AppAssetPathConflictError()

View File

@@ -51,14 +51,6 @@ class AppImportPayload(BaseModel):
app_id: str | None = Field(None)
class AppImportBundleConfirmPayload(BaseModel):
name: str | None = None
description: str | None = None
icon_type: str | None = None
icon: str | None = None
icon_background: str | None = None
console_ns.schema_model(
AppImportPayload.__name__, AppImportPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
@@ -147,68 +139,3 @@ class AppImportCheckDependenciesApi(Resource):
result = import_service.check_dependencies(app_model=app_model)
return result.model_dump(mode="json"), 200
@console_ns.route("/apps/imports-bundle/prepare")
class AppImportBundlePrepareApi(Resource):
"""Step 1: Get upload URL for bundle import."""
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
from services.app_bundle_service import AppBundleService
current_user, current_tenant_id = current_account_with_tenant()
result = AppBundleService.prepare_import(
tenant_id=current_tenant_id,
account_id=current_user.id,
)
return {"import_id": result.import_id, "upload_url": result.upload_url}, 200
@console_ns.route("/apps/imports-bundle/<string:import_id>/confirm")
class AppImportBundleConfirmApi(Resource):
"""Step 2: Confirm bundle import after upload."""
@setup_required
@login_required
@account_initialization_required
@marshal_with(app_import_model)
@cloud_edition_billing_resource_check("apps")
@edit_permission_required
def post(self, import_id: str):
from flask import request
from core.app.entities.app_bundle_entities import BundleFormatError
from services.app_bundle_service import AppBundleService
current_user, _ = current_account_with_tenant()
args = AppImportBundleConfirmPayload.model_validate(request.get_json() or {})
try:
result = AppBundleService.confirm_import(
import_id=import_id,
account=current_user,
name=args.name,
description=args.description,
icon_type=args.icon_type,
icon=args.icon,
icon_background=args.icon_background,
)
except BundleFormatError as e:
return {"error": str(e)}, 400
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, "private")
status = result.status
if status == ImportStatus.FAILED:
return result.model_dump(mode="json"), 400
elif status == ImportStatus.PENDING:
return result.model_dump(mode="json"), 202
return result.model_dump(mode="json"), 200

View File

@@ -6,7 +6,6 @@ from pydantic import BaseModel, Field
from werkzeug.exceptions import InternalServerError
import services
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
AppUnavailableError,
@@ -34,6 +33,7 @@ from services.errors.audio import (
)
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class TextToSpeechPayload(BaseModel):
@@ -47,11 +47,13 @@ class TextToSpeechVoiceQuery(BaseModel):
language: str = Field(..., description="Language code")
class AudioTranscriptResponse(BaseModel):
text: str = Field(description="Transcribed text from audio")
register_schema_models(console_ns, AudioTranscriptResponse, TextToSpeechPayload, TextToSpeechVoiceQuery)
console_ns.schema_model(
TextToSpeechPayload.__name__, TextToSpeechPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechVoiceQuery.__name__,
TextToSpeechVoiceQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/apps/<uuid:app_id>/audio-to-text")
@@ -62,7 +64,7 @@ class ChatMessageAudioApi(Resource):
@console_ns.response(
200,
"Audio transcription successful",
console_ns.models[AudioTranscriptResponse.__name__],
console_ns.model("AudioTranscriptResponse", {"text": fields.String(description="Transcribed text from audio")}),
)
@console_ns.response(400, "Bad request - No audio uploaded or unsupported type")
@console_ns.response(413, "Audio file too large")

View File

@@ -89,7 +89,6 @@ status_count_model = console_ns.model(
"success": fields.Integer,
"failed": fields.Integer,
"partial_success": fields.Integer,
"paused": fields.Integer,
},
)
@@ -509,19 +508,16 @@ class ChatConversationApi(Resource):
case "created_at" | "-created_at" | _:
query = query.where(Conversation.created_at <= end_datetime_utc)
match args.annotation_status:
case "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
case "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
case "all":
pass
if args.annotation_status == "annotated":
query = query.options(joinedload(Conversation.message_annotations)).join( # type: ignore
MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id
)
elif args.annotation_status == "not_annotated":
query = (
query.outerjoin(MessageAnnotation, MessageAnnotation.conversation_id == Conversation.id)
.group_by(Conversation.id)
.having(func.count(MessageAnnotation.id) == 0)
)
if app_model.mode == AppMode.ADVANCED_CHAT:
query = query.where(Conversation.invoke_from != InvokeFrom.DEBUGGER)
@@ -599,12 +595,7 @@ def _get_conversation(app_model, conversation_id):
db.session.execute(
sa.update(Conversation)
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
# Keep updated_at unchanged when only marking a conversation as read.
.values(
read_at=naive_utc_now(),
read_account_id=current_user.id,
updated_at=Conversation.updated_at,
)
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
)
db.session.commit()
db.session.refresh(conversation)

View File

@@ -110,6 +110,8 @@ class TracingConfigCheckError(BaseHTTPException):
class InvokeRateLimitError(BaseHTTPException):
"""Raised when the Invoke returns rate limit error."""
error_code = "rate_limit_error"
description = "Rate Limit Error"
code = 429
@@ -119,21 +121,3 @@ class NeedAddIdsError(BaseHTTPException):
error_code = "need_add_ids"
description = "Need to add ids."
code = 400
class AppAssetNodeNotFoundError(BaseHTTPException):
error_code = "app_asset_node_not_found"
description = "App asset node not found."
code = 404
class AppAssetFileRequiredError(BaseHTTPException):
error_code = "app_asset_file_required"
description = "File is required."
code = 400
class AppAssetPathConflictError(BaseHTTPException):
error_code = "app_asset_path_conflict"
description = "Path already exists."
code = 409

View File

@@ -12,17 +12,10 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.context_models import (
AvailableVarPayload,
CodeContextPayload,
ParameterInfoPayload,
)
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@@ -33,13 +26,28 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
@@ -47,34 +55,6 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
class ContextGeneratePayload(BaseModel):
"""Payload for generating extractor code node."""
language: str = Field(default="python3", description="Code language (python3/javascript)")
prompt_messages: list[dict[str, Any]] = Field(
..., description="Multi-turn conversation history, last message is the current instruction"
)
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
code_context: CodeContextPayload = Field(description="Existing code node context for incremental generation")
class SuggestedQuestionsPayload(BaseModel):
"""Payload for generating suggested questions."""
language: str = Field(
default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
)
model_config_data: dict[str, Any] = Field(
default_factory=dict,
alias="model_config",
description="Model configuration (optional, uses system default if not provided)",
)
available_vars: list[AvailableVarPayload] = Field(..., description="Available variables from upstream nodes")
parameter_info: ParameterInfoPayload = Field(..., description="Target parameter metadata from the frontend")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -84,9 +64,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ContextGeneratePayload)
reg(SuggestedQuestionsPayload)
reg(ModelConfig)
@console_ns.route("/rule-generate")
@@ -105,7 +82,12 @@ class RuleGenerateApi(Resource):
_, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -136,7 +118,9 @@ class RuleCodeGenerateApi(Resource):
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -168,7 +152,8 @@ class RuleStructuredOutputGenerateApi(Resource):
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -219,29 +204,23 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=RuleCodeGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
),
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
)
case _:
return {"error": f"invalid node type: {node_type}"}
@@ -299,70 +278,3 @@ class InstructionGenerationTemplateApi(Resource):
return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
case _:
raise ValueError(f"Invalid type: {args.type}")
@console_ns.route("/context-generate")
class ContextGenerateApi(Resource):
@console_ns.doc("generate_with_context")
@console_ns.doc(description="Generate with multi-turn conversation context")
@console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
@console_ns.response(200, "Content generated successfully")
@console_ns.response(400, "Invalid request parameters or workflow not found")
@console_ns.response(402, "Provider quota exceeded")
@setup_required
@login_required
@account_initialization_required
def post(self):
from core.llm_generator.utils import deserialize_prompt_messages
args = ContextGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
return LLMGenerator.generate_with_context(
tenant_id=current_tenant_id,
language=args.language,
prompt_messages=deserialize_prompt_messages(args.prompt_messages),
model_config=args.model_config_data,
available_vars=args.available_vars,
parameter_info=args.parameter_info,
code_context=args.code_context,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)
@console_ns.route("/context-generate/suggested-questions")
class SuggestedQuestionsApi(Resource):
@console_ns.doc("generate_suggested_questions")
@console_ns.doc(description="Generate suggested questions for context generation")
@console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
@console_ns.response(200, "Questions generated successfully")
@setup_required
@login_required
@account_initialization_required
def post(self):
args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
try:
return LLMGenerator.generate_suggested_questions(
tenant_id=current_tenant_id,
language=args.language,
available_vars=args.available_vars,
parameter_info=args.parameter_info,
model_config=args.model_config_data,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except InvokeError as e:
raise CompletionRequestError(e.description)

View File

@@ -7,7 +7,6 @@ from pydantic import BaseModel, Field, field_validator
from sqlalchemy import exists, select
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@@ -33,9 +32,10 @@ from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService, attach_message_extra_contents
from services.message_service import MessageService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class ChatMessagesQuery(BaseModel):
@@ -90,22 +90,13 @@ class FeedbackExportQuery(BaseModel):
raise ValueError("has_comment must be a boolean value")
class AnnotationCountResponse(BaseModel):
count: int = Field(description="Number of annotations")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
class SuggestedQuestionsResponse(BaseModel):
data: list[str] = Field(description="Suggested question")
register_schema_models(
console_ns,
ChatMessagesQuery,
MessageFeedbackPayload,
FeedbackExportQuery,
AnnotationCountResponse,
SuggestedQuestionsResponse,
)
reg(ChatMessagesQuery)
reg(MessageFeedbackPayload)
reg(FeedbackExportQuery)
# Register models for flask_restx to avoid dict type issues in Swagger
# Register in dependency order: base models first, then dependent models
@@ -207,12 +198,10 @@ message_detail_model = console_ns.model(
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
"generation_detail": fields.Raw,
},
)
@@ -242,7 +231,7 @@ class ChatMessageListApi(Resource):
@marshal_with(message_infinite_scroll_pagination_model)
@edit_permission_required
def get(self, app_model):
args = ChatMessagesQuery.model_validate(request.args.to_dict())
args = ChatMessagesQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
conversation = (
db.session.query(Conversation)
@@ -301,7 +290,6 @@ class ChatMessageListApi(Resource):
has_more = False
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@@ -368,7 +356,7 @@ class MessageAnnotationCountApi(Resource):
@console_ns.response(
200,
"Annotation count retrieved successfully",
console_ns.models[AnnotationCountResponse.__name__],
console_ns.model("AnnotationCountResponse", {"count": fields.Integer(description="Number of annotations")}),
)
@get_app_model
@setup_required
@@ -388,7 +376,9 @@ class MessageSuggestedQuestionApi(Resource):
@console_ns.response(
200,
"Suggested questions retrieved successfully",
console_ns.models[SuggestedQuestionsResponse.__name__],
console_ns.model(
"SuggestedQuestionsResponse", {"data": fields.List(fields.String(description="Suggested question"))}
),
)
@console_ns.response(404, "Message or conversation not found")
@setup_required
@@ -438,7 +428,7 @@ class MessageFeedbackExportApi(Resource):
@login_required
@account_initialization_required
def get(self, app_model):
args = FeedbackExportQuery.model_validate(request.args.to_dict())
args = FeedbackExportQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
# Import the service function
from services.feedback_service import FeedbackService
@@ -484,5 +474,4 @@ class MessageApi(Resource):
if not message:
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return message

View File

@@ -1,83 +0,0 @@
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, current_account_with_tenant, setup_required
from libs.login import login_required
from models import App
from models.model import AppMode
from services.skill_service import SkillService
from services.workflow_service import WorkflowService
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nodes/<string:node_id>/skills")
class NodeSkillsApi(Resource):
"""API for retrieving skill references for a specific workflow node."""
@console_ns.doc("get_node_skills")
@console_ns.doc(description="Get skill references for a specific node in the draft workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.response(200, "Node skills retrieved successfully")
@console_ns.response(404, "Workflow or node not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App, node_id: str):
"""
Get skill information for a specific node in the draft workflow.
Returns information about skill references in the node, including:
- skill_references: List of prompt messages marked as skills
- tool_references: Aggregated tool references from all skill prompts
- file_references: Aggregated file references from all skill prompts
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if not workflow:
raise DraftWorkflowNotExist()
skill_info = SkillService.get_node_skill_info(
app=app_model,
workflow=workflow,
node_id=node_id,
user_id=current_user.id,
)
return skill_info.model_dump()
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/skills")
class WorkflowSkillsApi(Resource):
"""API for retrieving all skill references in a workflow."""
@console_ns.doc("get_workflow_skills")
@console_ns.doc(description="Get all skill references in the draft workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Workflow skills retrieved successfully")
@console_ns.response(404, "Workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def get(self, app_model: App):
"""
Get skill information for all nodes in the draft workflow that have skill references.
Returns a list of nodes with their skill information.
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
workflow = workflow_service.get_draft_workflow(app_model=app_model)
if not workflow:
raise DraftWorkflowNotExist()
skills_info = SkillService.get_workflow_skills(
app=app_model,
workflow=workflow,
user_id=current_user.id,
)
return {"nodes": [info.model_dump() for info in skills_info]}

View File

@@ -33,10 +33,8 @@ from core.trigger.debug.event_selectors import (
from core.workflow.enums import NodeType
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from factories import file_factory, variable_factory
from fields.member_fields import simple_account_fields
from fields.online_user_fields import online_user_list_fields
from fields.workflow_fields import workflow_fields, workflow_pagination_fields
from libs import helper
from libs.datetime_utils import naive_utc_now
@@ -45,12 +43,9 @@ from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import AppMode
from models.workflow import Workflow
from repositories.workflow_collaboration_repository import WORKFLOW_ONLINE_USERS_PREFIX
from services.app_generate_service import AppGenerateService
from services.errors.app import WorkflowHashNotEqualError
from services.errors.llm import InvokeRateLimitError
from services.workflow.entities import NestedNodeGraphRequest, NestedNodeParameterSchema
from services.workflow.nested_node_graph_service import NestedNodeGraphService
from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService
logger = logging.getLogger(__name__)
@@ -165,14 +160,6 @@ class WorkflowUpdatePayload(BaseModel):
marked_comment: str | None = Field(default=None, max_length=100)
class WorkflowFeaturesPayload(BaseModel):
features: dict[str, Any] = Field(..., description="Workflow feature configuration")
class WorkflowOnlineUsersQuery(BaseModel):
workflow_ids: str = Field(..., description="Comma-separated workflow IDs")
class DraftWorkflowTriggerRunPayload(BaseModel):
node_id: str
@@ -181,15 +168,6 @@ class DraftWorkflowTriggerRunAllPayload(BaseModel):
node_ids: list[str]
class NestedNodeGraphPayload(BaseModel):
"""Request payload for generating nested node graph."""
parent_node_id: str = Field(description="ID of the parent node that uses the extracted value")
parameter_key: str = Field(description="Key of the parameter being extracted")
context_source: list[str] = Field(description="Variable selector for the context source")
parameter_schema: dict[str, Any] = Field(description="Schema of the parameter to extract")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -205,11 +183,8 @@ reg(DefaultBlockConfigQuery)
reg(ConvertToWorkflowPayload)
reg(WorkflowListQuery)
reg(WorkflowUpdatePayload)
reg(WorkflowFeaturesPayload)
reg(WorkflowOnlineUsersQuery)
reg(DraftWorkflowTriggerRunPayload)
reg(DraftWorkflowTriggerRunAllPayload)
reg(NestedNodeGraphPayload)
# TODO(QuantumGhost): Refactor existing node run API to handle file parameter parsing
@@ -532,179 +507,6 @@ class WorkflowDraftRunLoopNodeApi(Resource):
raise InternalServerError()
class HumanInputFormPreviewPayload(BaseModel):
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Values used to fill missing upstream variables referenced in form_content",
)
class HumanInputFormSubmitPayload(BaseModel):
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
inputs: dict[str, Any] = Field(
...,
description="Values used to fill missing upstream variables referenced in form_content",
)
action: str = Field(..., description="Selected action ID")
class HumanInputDeliveryTestPayload(BaseModel):
delivery_method_id: str = Field(..., description="Delivery method ID")
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Values used to fill missing upstream variables referenced in form_content",
)
reg(HumanInputFormPreviewPayload)
reg(HumanInputFormSubmitPayload)
reg(HumanInputDeliveryTestPayload)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
@console_ns.doc("get_advanced_chat_draft_human_input_form")
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
workflow_service = WorkflowService()
preview = workflow_service.get_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
inputs=inputs,
)
return jsonable_encoder(preview)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
class AdvancedChatDraftHumanInputFormRunApi(Resource):
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args.form_inputs,
inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
class WorkflowDraftHumanInputFormPreviewApi(Resource):
@console_ns.doc("get_workflow_draft_human_input_form")
@console_ns.doc(description="Get human input form preview for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
workflow_service = WorkflowService()
preview = workflow_service.get_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
inputs=inputs,
)
return jsonable_encoder(preview)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
class WorkflowDraftHumanInputFormRunApi(Resource):
@console_ns.doc("submit_workflow_draft_human_input_form")
@console_ns.doc(description="Submit human input form preview for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args.form_inputs,
inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
@console_ns.doc("test_workflow_draft_human_input_delivery")
@console_ns.doc(description="Test human input delivery for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Test human input delivery
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
workflow_service.test_human_input_delivery(
app_model=app_model,
account=current_user,
node_id=node_id,
delivery_method_id=args.delivery_method_id,
inputs=args.inputs,
)
return jsonable_encoder({})
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")
@@ -852,14 +654,13 @@ class PublishedWorkflowApi(Resource):
"""
Publish workflow
"""
from services.app_bundle_service import AppBundleService
current_user, _ = current_account_with_tenant()
args = PublishWorkflowPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
with Session(db.engine) as session:
workflow = AppBundleService.publish(
workflow = workflow_service.publish_workflow(
session=session,
app_model=app_model,
account=current_user,
@@ -970,31 +771,6 @@ class ConvertToWorkflowApi(Resource):
}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/features")
class WorkflowFeaturesApi(Resource):
"""Update draft workflow features."""
@console_ns.expect(console_ns.models[WorkflowFeaturesPayload.__name__])
@console_ns.doc("update_workflow_features")
@console_ns.doc(description="Update draft workflow features")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Workflow features updated successfully")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
current_user, _ = current_account_with_tenant()
args = WorkflowFeaturesPayload.model_validate(console_ns.payload or {})
features = args.features
workflow_service = WorkflowService()
workflow_service.update_draft_workflow_features(app_model=app_model, features=features, account=current_user)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows")
class PublishedAllWorkflowApi(Resource):
@console_ns.expect(console_ns.models[WorkflowListQuery.__name__])
@@ -1372,83 +1148,3 @@ class DraftWorkflowTriggerRunAllApi(Resource):
"status": "error",
}
), 400
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/nested-node-graph")
class NestedNodeGraphApi(Resource):
"""
API for generating Nested Node LLM graph structures.
This endpoint creates a complete graph structure containing an LLM node
configured to extract values from list[PromptMessage] variables.
"""
@console_ns.doc("generate_nested_node_graph")
@console_ns.doc(description="Generate a Nested Node LLM graph structure")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[NestedNodeGraphPayload.__name__])
@console_ns.response(200, "Nested node graph generated successfully")
@console_ns.response(400, "Invalid request parameters")
@console_ns.response(403, "Permission denied")
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App):
"""
Generate a Nested Node LLM graph structure.
Returns a complete graph structure containing a single LLM node
configured for extracting values from list[PromptMessage] context.
"""
payload = NestedNodeGraphPayload.model_validate(console_ns.payload or {})
parameter_schema = NestedNodeParameterSchema(
name=payload.parameter_schema.get("name", payload.parameter_key),
type=payload.parameter_schema.get("type", "string"),
description=payload.parameter_schema.get("description", ""),
)
request = NestedNodeGraphRequest(
parent_node_id=payload.parent_node_id,
parameter_key=payload.parameter_key,
context_source=payload.context_source,
parameter_schema=parameter_schema,
)
with Session(db.engine) as session:
service = NestedNodeGraphService(session)
response = service.generate_nested_node_graph(tenant_id=app_model.tenant_id, request=request)
return response.model_dump()
@console_ns.route("/apps/workflows/online-users")
class WorkflowOnlineUsersApi(Resource):
@console_ns.expect(console_ns.models[WorkflowOnlineUsersQuery.__name__])
@console_ns.doc("get_workflow_online_users")
@console_ns.doc(description="Get workflow online users")
@setup_required
@login_required
@account_initialization_required
@marshal_with(online_user_list_fields)
def get(self):
args = WorkflowOnlineUsersQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workflow_ids = [workflow_id.strip() for workflow_id in args.workflow_ids.split(",") if workflow_id.strip()]
results = []
for workflow_id in workflow_ids:
users_json = redis_client.hgetall(f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}")
users = []
for _, user_info_json in users_json.items():
try:
users.append(json.loads(user_info_json))
except Exception:
continue
results.append({"workflow_id": workflow_id, "users": users})
return {"data": results}

View File

@@ -1,322 +0,0 @@
import logging
from flask_restx import Resource, marshal_with
from pydantic import BaseModel, Field, TypeAdapter
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from fields.member_fields import AccountWithRole
from fields.workflow_comment_fields import (
workflow_comment_basic_fields,
workflow_comment_create_fields,
workflow_comment_detail_fields,
workflow_comment_reply_create_fields,
workflow_comment_reply_update_fields,
workflow_comment_resolve_fields,
workflow_comment_update_fields,
)
from libs.login import current_user, login_required
from models import App
from services.account_service import TenantService
from services.workflow_comment_service import WorkflowCommentService
logger = logging.getLogger(__name__)
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowCommentCreatePayload(BaseModel):
position_x: float = Field(..., description="Comment X position")
position_y: float = Field(..., description="Comment Y position")
content: str = Field(..., description="Comment content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentUpdatePayload(BaseModel):
content: str = Field(..., description="Comment content")
position_x: float | None = Field(default=None, description="Comment X position")
position_y: float | None = Field(default=None, description="Comment Y position")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyCreatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentReplyUpdatePayload(BaseModel):
content: str = Field(..., description="Reply content")
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
class WorkflowCommentMentionUsersResponse(BaseModel):
users: list[AccountWithRole] = Field(description="Mentionable users")
for model in (
WorkflowCommentCreatePayload,
WorkflowCommentUpdatePayload,
WorkflowCommentReplyCreatePayload,
WorkflowCommentReplyUpdatePayload,
):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
workflow_comment_reply_create_model = console_ns.model(
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
)
workflow_comment_reply_update_model = console_ns.model(
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
)
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
class WorkflowCommentListApi(Resource):
"""API for listing and creating workflow comments."""
@console_ns.doc("list_workflow_comments")
@console_ns.doc(description="Get all comments for a workflow")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_basic_model, envelope="data")
def get(self, app_model: App):
"""Get all comments for a workflow."""
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
return comments
@console_ns.doc("create_workflow_comment")
@console_ns.doc(description="Create a new workflow comment")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_create_model)
def post(self, app_model: App):
"""Create a new workflow comment."""
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
created_by=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
class WorkflowCommentDetailApi(Resource):
"""API for managing individual workflow comments."""
@console_ns.doc("get_workflow_comment")
@console_ns.doc(description="Get a specific workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_detail_model)
def get(self, app_model: App, comment_id: str):
"""Get a specific workflow comment."""
comment = WorkflowCommentService.get_comment(
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
)
return comment
@console_ns.doc("update_workflow_comment")
@console_ns.doc(description="Update a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_update_model)
def put(self, app_model: App, comment_id: str):
"""Update a workflow comment."""
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.update_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
content=payload.content,
position_x=payload.position_x,
position_y=payload.position_y,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result
@console_ns.doc("delete_workflow_comment")
@console_ns.doc(description="Delete a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(204, "Comment deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str):
"""Delete a workflow comment."""
WorkflowCommentService.delete_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
class WorkflowCommentResolveApi(Resource):
"""API for resolving and reopening workflow comments."""
@console_ns.doc("resolve_workflow_comment")
@console_ns.doc(description="Resolve a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_resolve_model)
def post(self, app_model: App, comment_id: str):
"""Resolve a workflow comment."""
comment = WorkflowCommentService.resolve_comment(
tenant_id=current_user.current_tenant_id,
app_id=app_model.id,
comment_id=comment_id,
user_id=current_user.id,
)
return comment
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
class WorkflowCommentReplyApi(Resource):
"""API for managing comment replies."""
@console_ns.doc("create_workflow_comment_reply")
@console_ns.doc(description="Add a reply to a workflow comment")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_create_model)
def post(self, app_model: App, comment_id: str):
"""Add a reply to a workflow comment."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
result = WorkflowCommentService.create_reply(
comment_id=comment_id,
content=payload.content,
created_by=current_user.id,
mentioned_user_ids=payload.mentioned_user_ids,
)
return result, 201
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
class WorkflowCommentReplyDetailApi(Resource):
"""API for managing individual comment replies."""
@console_ns.doc("update_workflow_comment_reply")
@console_ns.doc(description="Update a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
@marshal_with(workflow_comment_reply_update_model)
def put(self, app_model: App, comment_id: str, reply_id: str):
"""Update a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
reply = WorkflowCommentService.update_reply(
reply_id=reply_id,
user_id=current_user.id,
content=payload.content,
mentioned_user_ids=payload.mentioned_user_ids,
)
return reply
@console_ns.doc("delete_workflow_comment_reply")
@console_ns.doc(description="Delete a comment reply")
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
@console_ns.response(204, "Reply deleted successfully")
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def delete(self, app_model: App, comment_id: str, reply_id: str):
"""Delete a comment reply."""
# Validate comment access first
WorkflowCommentService.validate_comment_access(
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
)
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
return {"result": "success"}, 204
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
class WorkflowCommentMentionUsersApi(Resource):
"""API for getting mentionable users for workflow comments."""
@console_ns.doc("workflow_comment_mention_users")
@console_ns.doc(description="Get all users in current tenant for mentions")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
@login_required
@setup_required
@account_initialization_required
@get_app_model()
def get(self, app_model: App):
"""Get all users in current tenant for mentions."""
members = TenantService.get_tenant_members(current_user.current_tenant)
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
response = WorkflowCommentMentionUsersResponse(users=member_models)
return response.model_dump(mode="json"), 200

View File

@@ -17,16 +17,15 @@ from controllers.console.wraps import account_initialization_required, edit_perm
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.file import helpers as file_helpers
from core.variables.segment_group import SegmentGroup
from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.variables.types import SegmentType
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from extensions.ext_database import db
from factories import variable_factory
from factories.file_factory import build_from_mapping, build_from_mappings
from libs.login import current_account_with_tenant, login_required
from factories.variable_factory import build_segment_with_type
from libs.login import login_required
from models import App, AppMode
from models.workflow import WorkflowDraftVariable
from services.sandbox.sandbox_service import SandboxService
from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService
from services.workflow_service import WorkflowService
@@ -44,16 +43,6 @@ class WorkflowDraftVariableUpdatePayload(BaseModel):
value: Any | None = Field(default=None, description="Variable value")
class ConversationVariableUpdatePayload(BaseModel):
conversation_variables: list[dict[str, Any]] = Field(
..., description="Conversation variables for the draft workflow"
)
class EnvironmentVariableUpdatePayload(BaseModel):
environment_variables: list[dict[str, Any]] = Field(..., description="Environment variables for the draft workflow")
console_ns.schema_model(
WorkflowDraftVariableListQuery.__name__,
WorkflowDraftVariableListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
@@ -62,14 +51,6 @@ console_ns.schema_model(
WorkflowDraftVariableUpdatePayload.__name__,
WorkflowDraftVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
ConversationVariableUpdatePayload.__name__,
ConversationVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
EnvironmentVariableUpdatePayload.__name__,
EnvironmentVariableUpdatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def _convert_values_to_json_serializable_object(value: Segment):
@@ -77,8 +58,6 @@ def _convert_values_to_json_serializable_object(value: Segment):
return value.value.model_dump()
elif isinstance(value, ArrayFileSegment):
return [i.model_dump() for i in value.value]
elif isinstance(value, ArrayPromptMessageSegment):
return value.to_object()
elif isinstance(value, SegmentGroup):
return [_convert_values_to_json_serializable_object(i) for i in value.value]
else:
@@ -268,8 +247,6 @@ class WorkflowVariableCollectionApi(Resource):
@console_ns.response(204, "Workflow variables deleted successfully")
@_api_prerequisite
def delete(self, app_model: App):
current_user, _ = current_account_with_tenant()
SandboxService.delete_draft_storage(app_model.tenant_id, app_model.id, current_user.id)
draft_var_srv = WorkflowDraftVariableService(
session=db.session(),
)
@@ -406,7 +383,7 @@ class VariableApi(Resource):
if len(raw_value) > 0 and not isinstance(raw_value[0], dict):
raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}")
raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id)
new_value = variable_factory.build_segment_with_type(variable.value_type, raw_value)
new_value = build_segment_with_type(variable.value_type, raw_value)
draft_var_srv.update_variable(variable, name=new_name, value=new_value)
db.session.commit()
return variable
@@ -499,35 +476,6 @@ class ConversationVariableCollectionApi(Resource):
db.session.commit()
return _get_variable_list(app_model, CONVERSATION_VARIABLE_NODE_ID)
@console_ns.expect(console_ns.models[ConversationVariableUpdatePayload.__name__])
@console_ns.doc("update_conversation_variables")
@console_ns.doc(description="Update conversation variables for workflow draft")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Conversation variables updated successfully")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=AppMode.ADVANCED_CHAT)
def post(self, app_model: App):
payload = ConversationVariableUpdatePayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
conversation_variables_list = payload.conversation_variables
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
current_user, _ = current_account_with_tenant()
workflow_service.update_draft_workflow_conversation_variables(
app_model=app_model,
account=current_user,
conversation_variables=conversation_variables,
)
return {"result": "success"}
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/system-variables")
class SystemVariableCollectionApi(Resource):
@@ -579,32 +527,3 @@ class EnvironmentVariableCollectionApi(Resource):
)
return {"items": env_vars_list}
@console_ns.expect(console_ns.models[EnvironmentVariableUpdatePayload.__name__])
@console_ns.doc("update_environment_variables")
@console_ns.doc(description="Update environment variables for workflow draft")
@console_ns.doc(params={"app_id": "Application ID"})
@console_ns.response(200, "Environment variables updated successfully")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
def post(self, app_model: App):
payload = EnvironmentVariableUpdatePayload.model_validate(console_ns.payload or {})
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
environment_variables_list = payload.environment_variables
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
workflow_service.update_draft_workflow_environment_variables(
app_model=app_model,
account=current_user,
environment_variables=environment_variables,
)
return {"result": "success"}

View File

@@ -5,15 +5,10 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -32,21 +27,9 @@ from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
if not form_token:
return None
base_url = dify_config.APP_WEB_URL
if not base_url:
return None
return f"{base_url.rstrip('/')}/form/{form_token}"
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
@@ -457,68 +440,3 @@ class WorkflowRunNodeExecutionListApi(Resource):
)
return {"data": node_executions}
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""
@setup_required
@login_required
@account_initialization_required
def get(self, workflow_run_id: str):
"""
Get workflow pause details.
GET /console/api/workflow/<workflow_run_id>/pause-details
Returns information about why and where the workflow is paused.
"""
# Query WorkflowRun to determine if workflow is suspended
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
if not workflow_run:
raise NotFoundError("Workflow run not found")
if workflow_run.tenant_id != current_user.current_tenant_id:
raise NotFoundError("Workflow run not found")
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {
"paused_at": None,
"paused_nodes": [],
}, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
paused_nodes = []
response = {
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
"paused_nodes": paused_nodes,
}
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
paused_nodes.append(
{
"node_id": reason.node_id,
"node_title": reason.node_title,
"pause_type": {
"type": "human_input",
"form_id": reason.form_id,
"backstage_input_url": _build_backstage_input_url(reason.form_token),
},
}
)
else:
raise AssertionError("unimplemented.")
return response, 200

View File

@@ -2,11 +2,9 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from pydantic import BaseModel, Field
from flask_restx import Resource, fields
from configs import dify_config
from controllers.common.schema import register_schema_models
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
@@ -16,26 +14,6 @@ from ..wraps import account_initialization_required, is_admin_or_owner_required,
logger = logging.getLogger(__name__)
class OAuthDataSourceResponse(BaseModel):
data: str = Field(description="Authorization URL or 'internal' for internal setup")
class OAuthDataSourceBindingResponse(BaseModel):
result: str = Field(description="Operation result")
class OAuthDataSourceSyncResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
OAuthDataSourceResponse,
OAuthDataSourceBindingResponse,
OAuthDataSourceSyncResponse,
)
def get_oauth_providers():
with current_app.app_context():
notion_oauth = NotionOAuth(
@@ -56,7 +34,10 @@ class OAuthDataSource(Resource):
@console_ns.response(
200,
"Authorization URL or internal setup success",
console_ns.models[OAuthDataSourceResponse.__name__],
console_ns.model(
"OAuthDataSourceResponse",
{"data": fields.Raw(description="Authorization URL or 'internal' for internal setup")},
),
)
@console_ns.response(400, "Invalid provider")
@console_ns.response(403, "Admin privileges required")
@@ -120,7 +101,7 @@ class OAuthDataSourceBinding(Resource):
@console_ns.response(
200,
"Data source binding success",
console_ns.models[OAuthDataSourceBindingResponse.__name__],
console_ns.model("OAuthDataSourceBindingResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Invalid provider or code")
def get(self, provider: str):
@@ -152,7 +133,7 @@ class OAuthDataSourceSync(Resource):
@console_ns.response(
200,
"Data source sync success",
console_ns.models[OAuthDataSourceSyncResponse.__name__],
console_ns.model("OAuthDataSourceSyncResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Invalid provider or sync failed")
@setup_required

View File

@@ -2,11 +2,10 @@ import base64
import secrets
from flask import request
from flask_restx import Resource
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.auth.error import (
EmailCodeError,
@@ -49,31 +48,8 @@ class ForgotPasswordResetPayload(BaseModel):
return valid_password(value)
class ForgotPasswordEmailResponse(BaseModel):
result: str = Field(description="Operation result")
data: str | None = Field(default=None, description="Reset token")
code: str | None = Field(default=None, description="Error code if account not found")
class ForgotPasswordCheckResponse(BaseModel):
is_valid: bool = Field(description="Whether code is valid")
email: EmailStr = Field(description="Email address")
token: str = Field(description="New reset token")
class ForgotPasswordResetResponse(BaseModel):
result: str = Field(description="Operation result")
register_schema_models(
console_ns,
ForgotPasswordSendPayload,
ForgotPasswordCheckPayload,
ForgotPasswordResetPayload,
ForgotPasswordEmailResponse,
ForgotPasswordCheckResponse,
ForgotPasswordResetResponse,
)
for model in (ForgotPasswordSendPayload, ForgotPasswordCheckPayload, ForgotPasswordResetPayload):
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@console_ns.route("/forgot-password")
@@ -84,7 +60,14 @@ class ForgotPasswordSendEmailApi(Resource):
@console_ns.response(
200,
"Email sent successfully",
console_ns.models[ForgotPasswordEmailResponse.__name__],
console_ns.model(
"ForgotPasswordEmailResponse",
{
"result": fields.String(description="Operation result"),
"data": fields.String(description="Reset token"),
"code": fields.String(description="Error code if account not found"),
},
),
)
@console_ns.response(400, "Invalid email or rate limit exceeded")
@setup_required
@@ -123,7 +106,14 @@ class ForgotPasswordCheckApi(Resource):
@console_ns.response(
200,
"Code verified successfully",
console_ns.models[ForgotPasswordCheckResponse.__name__],
console_ns.model(
"ForgotPasswordCheckResponse",
{
"is_valid": fields.Boolean(description="Whether code is valid"),
"email": fields.String(description="Email address"),
"token": fields.String(description="New reset token"),
},
),
)
@console_ns.response(400, "Invalid code or token")
@setup_required
@@ -173,7 +163,7 @@ class ForgotPasswordResetApi(Resource):
@console_ns.response(
200,
"Password reset successfully",
console_ns.models[ForgotPasswordResetResponse.__name__],
console_ns.model("ForgotPasswordResetResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Invalid token or password mismatch")
@setup_required

View File

@@ -155,43 +155,43 @@ class OAuthServerUserTokenApi(Resource):
grant_type = OAuthGrantType(payload.grant_type)
except ValueError:
raise BadRequest("invalid grant_type")
match grant_type:
case OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
if grant_type == OAuthGrantType.AUTHORIZATION_CODE:
if not payload.code:
raise BadRequest("code is required")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
if payload.client_secret != oauth_provider_app.client_secret:
raise BadRequest("client_secret is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
case OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
if payload.redirect_uri not in oauth_provider_app.redirect_uris:
raise BadRequest("redirect_uri is invalid")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, code=payload.code, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
elif grant_type == OAuthGrantType.REFRESH_TOKEN:
if not payload.refresh_token:
raise BadRequest("refresh_token is required")
access_token, refresh_token = OAuthServerService.sign_oauth_access_token(
grant_type, refresh_token=payload.refresh_token, client_id=oauth_provider_app.client_id
)
return jsonable_encoder(
{
"access_token": access_token,
"token_type": "Bearer",
"expires_in": OAUTH_ACCESS_TOKEN_EXPIRES_IN,
"refresh_token": refresh_token,
}
)
@console_ns.route("/oauth/provider/account")

View File

@@ -1,6 +1,6 @@
import json
from collections.abc import Generator
from typing import Any, Literal, cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal_with
@@ -157,8 +157,9 @@ class DataSourceApi(Resource):
@setup_required
@login_required
@account_initialization_required
def patch(self, binding_id, action: Literal["enable", "disable"]):
def patch(self, binding_id, action):
binding_id = str(binding_id)
action = str(action)
with Session(db.engine) as session:
data_source_binding = session.execute(
select(DataSourceOauthBinding).filter_by(id=binding_id)
@@ -166,24 +167,23 @@ class DataSourceApi(Resource):
if data_source_binding is None:
raise NotFound("Data source binding not found.")
# enable binding
match action:
case "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
case "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
if action == "enable":
if data_source_binding.disabled:
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is not disabled.")
# disable binding
if action == "disable":
if not data_source_binding.disabled:
data_source_binding.disabled = True
data_source_binding.updated_at = naive_utc_now()
db.session.add(data_source_binding)
db.session.commit()
else:
raise ValueError("Data source is disabled.")
return {"result": "success"}, 200

View File

@@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
from models.dataset import DatasetPermissionEnum
from models.provider_ids import ModelProviderID
from services.api_token_service import ApiTokenCache
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
# Register models for flask_restx to avoid dict type issues in Swagger
@@ -149,7 +148,6 @@ class DatasetUpdatePayload(BaseModel):
embedding_model: str | None = None
embedding_model_provider: str | None = None
retrieval_model: dict[str, Any] | None = None
summary_index_setting: dict[str, Any] | None = None
partial_member_list: list[dict[str, str]] | None = None
external_retrieval_model: dict[str, Any] | None = None
external_knowledge_id: str | None = None
@@ -290,14 +288,7 @@ class DatasetListApi(Resource):
@enterprise_license_required
def get(self):
current_user, current_tenant_id = current_account_with_tenant()
# Convert query parameters to dict, handling list parameters correctly
query_params: dict[str, str | list[str]] = dict(request.args.to_dict())
# Handle ids and tag_ids as lists (Flask request.args.getlist returns list even for single value)
if "ids" in request.args:
query_params["ids"] = request.args.getlist("ids")
if "tag_ids" in request.args:
query_params["tag_ids"] = request.args.getlist("tag_ids")
query = ConsoleDatasetListQuery.model_validate(query_params)
query = ConsoleDatasetListQuery.model_validate(request.args.to_dict())
# provider = request.args.get("provider", default="vendor")
if query.ids:
datasets, total = DatasetService.get_datasets_by_ids(query.ids, current_tenant_id)
@@ -821,11 +812,6 @@ class DatasetApiDeleteApi(Resource):
if key is None:
console_ns.abort(404, message="API key not found")
# Invalidate cache before deleting from database
# Type assertion: key is guaranteed to be non-None here because abort() raises
assert key is not None # nosec - for type checker only
ApiTokenCache.delete(key.token, key.type)
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
db.session.commit()

View File

@@ -45,7 +45,6 @@ from models.dataset import DocumentPipelineExecutionLog
from services.dataset_service import DatasetService, DocumentService
from services.entities.knowledge_entities.knowledge_entities import KnowledgeConfig, ProcessRule, RetrievalModel
from services.file_service import FileService
from tasks.generate_summary_index_task import generate_summary_index_task
from ..app.error import (
ProviderModelCurrentlyNotSupportError,
@@ -104,10 +103,6 @@ class DocumentRenamePayload(BaseModel):
name: str
class GenerateSummaryPayload(BaseModel):
document_list: list[str]
class DocumentBatchDownloadZipPayload(BaseModel):
"""Request payload for bulk downloading documents as a zip archive."""
@@ -130,7 +125,6 @@ register_schema_models(
RetrievalModel,
DocumentRetryPayload,
DocumentRenamePayload,
GenerateSummaryPayload,
DocumentBatchDownloadZipPayload,
)
@@ -318,13 +312,6 @@ class DatasetDocumentListApi(Resource):
paginated_documents = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
documents = paginated_documents.items
DocumentService.enrich_documents_with_summary_index_status(
documents=documents,
dataset=dataset,
tenant_id=current_tenant_id,
)
if fetch:
for document in documents:
completed_segments = (
@@ -576,62 +563,63 @@ class DocumentBatchIndexingEstimateApi(DocumentResource):
if document.indexing_status in {"completed", "error"}:
raise DocumentAlreadyFinishedError()
data_source_info = document.data_source_info_dict
match document.data_source_type:
case "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
if file_detail is None:
raise NotFound("File not found.")
if document.data_source_type == "upload_file":
if not data_source_info:
continue
file_id = data_source_info["upload_file_id"]
file_detail = (
db.session.query(UploadFile)
.where(UploadFile.tenant_id == current_tenant_id, UploadFile.id == file_id)
.first()
)
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
case "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
case "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
if file_detail is None:
raise NotFound("File not found.")
case _:
raise ValueError("Data source type not support")
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=file_detail, document_model=document.doc_form
)
extract_settings.append(extract_setting)
elif document.data_source_type == "notion_import":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.NOTION,
notion_info=NotionInfo.model_validate(
{
"credential_id": data_source_info.get("credential_id"),
"notion_workspace_id": data_source_info["notion_workspace_id"],
"notion_obj_id": data_source_info["notion_page_id"],
"notion_page_type": data_source_info["type"],
"tenant_id": current_tenant_id,
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
elif document.data_source_type == "website_crawl":
if not data_source_info:
continue
extract_setting = ExtractSetting(
datasource_type=DatasourceType.WEBSITE,
website_info=WebsiteInfo.model_validate(
{
"provider": data_source_info["provider"],
"job_id": data_source_info["job_id"],
"url": data_source_info["url"],
"tenant_id": current_tenant_id,
"mode": data_source_info["mode"],
"only_main_content": data_source_info["only_main_content"],
}
),
document_model=document.doc_form,
)
extract_settings.append(extract_setting)
else:
raise ValueError("Data source type not support")
indexing_runner = IndexingRunner()
try:
response = indexing_runner.indexing_estimate(
@@ -809,7 +797,6 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
else:
dataset_process_rules = DatasetService.get_process_rules(dataset_id)
@@ -845,7 +832,6 @@ class DocumentApi(DocumentResource):
"display_status": document.display_status,
"doc_form": document.doc_form,
"doc_language": document.doc_language,
"need_summary": document.need_summary if document.need_summary is not None else False,
}
return response, 200
@@ -953,24 +939,23 @@ class DocumentProcessingApi(DocumentResource):
if not current_user.is_dataset_editor:
raise Forbidden()
match action:
case "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
if action == "pause":
if document.indexing_status != "indexing":
raise InvalidActionError("Document not in indexing state.")
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
document.paused_by = current_user.id
document.paused_at = naive_utc_now()
document.is_paused = True
db.session.commit()
case "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
elif action == "resume":
if document.indexing_status not in {"paused", "error"}:
raise InvalidActionError("Document not in paused or error state.")
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
document.paused_by = None
document.paused_at = None
document.is_paused = False
db.session.commit()
return {"result": "success"}, 200
@@ -1270,149 +1255,3 @@ class DocumentPipelineExecutionLogApi(DocumentResource):
"input_data": log.input_data,
"datasource_node_id": log.datasource_node_id,
}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/generate-summary")
class DocumentGenerateSummaryApi(Resource):
@console_ns.doc("generate_summary_for_documents")
@console_ns.doc(description="Generate summary index for documents")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[GenerateSummaryPayload.__name__])
@console_ns.response(200, "Summary generation started successfully")
@console_ns.response(400, "Invalid request or dataset configuration")
@console_ns.response(403, "Permission denied")
@console_ns.response(404, "Dataset not found")
@setup_required
@login_required
@account_initialization_required
@cloud_edition_billing_rate_limit_check("knowledge")
def post(self, dataset_id):
"""
Generate summary index for specified documents.
This endpoint checks if the dataset configuration supports summary generation
(indexing_technique must be 'high_quality' and summary_index_setting.enable must be true),
then asynchronously generates summary indexes for the provided documents.
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
if not current_user.is_dataset_editor:
raise Forbidden()
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Validate request payload
payload = GenerateSummaryPayload.model_validate(console_ns.payload or {})
document_list = payload.document_list
if not document_list:
from werkzeug.exceptions import BadRequest
raise BadRequest("document_list cannot be empty.")
# Check if dataset configuration supports summary generation
if dataset.indexing_technique != "high_quality":
raise ValueError(
f"Summary generation is only available for 'high_quality' indexing technique. "
f"Current indexing technique: {dataset.indexing_technique}"
)
summary_index_setting = dataset.summary_index_setting
if not summary_index_setting or not summary_index_setting.get("enable"):
raise ValueError("Summary index is not enabled for this dataset. Please enable it in the dataset settings.")
# Verify all documents exist and belong to the dataset
documents = DocumentService.get_documents_by_ids(dataset_id, document_list)
if len(documents) != len(document_list):
found_ids = {doc.id for doc in documents}
missing_ids = set(document_list) - found_ids
raise NotFound(f"Some documents not found: {list(missing_ids)}")
# Update need_summary to True for documents that don't have it set
# This handles the case where documents were created when summary_index_setting was disabled
documents_to_update = [doc for doc in documents if not doc.need_summary and doc.doc_form != "qa_model"]
if documents_to_update:
document_ids_to_update = [str(doc.id) for doc in documents_to_update]
DocumentService.update_documents_need_summary(
dataset_id=dataset_id,
document_ids=document_ids_to_update,
need_summary=True,
)
# Dispatch async tasks for each document
for document in documents:
# Skip qa_model documents as they don't generate summaries
if document.doc_form == "qa_model":
logger.info("Skipping summary generation for qa_model document %s", document.id)
continue
# Dispatch async task
generate_summary_index_task.delay(dataset_id, document.id)
logger.info(
"Dispatched summary generation task for document %s in dataset %s",
document.id,
dataset_id,
)
return {"result": "success"}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/summary-status")
class DocumentSummaryStatusApi(DocumentResource):
@console_ns.doc("get_document_summary_status")
@console_ns.doc(description="Get summary index generation status for a document")
@console_ns.doc(params={"dataset_id": "Dataset ID", "document_id": "Document ID"})
@console_ns.response(200, "Summary status retrieved successfully")
@console_ns.response(404, "Document not found")
@setup_required
@login_required
@account_initialization_required
def get(self, dataset_id, document_id):
"""
Get summary index generation status for a document.
Returns:
- total_segments: Total number of segments in the document
- summary_status: Dictionary with status counts
- completed: Number of summaries completed
- generating: Number of summaries being generated
- error: Number of summaries with errors
- not_started: Number of segments without summary records
- summaries: List of summary records with status and content preview
"""
current_user, _ = current_account_with_tenant()
dataset_id = str(dataset_id)
document_id = str(document_id)
# Get dataset
dataset = DatasetService.get_dataset(dataset_id)
if not dataset:
raise NotFound("Dataset not found.")
# Check permissions
try:
DatasetService.check_dataset_permission(dataset, current_user)
except services.errors.account.NoPermissionError as e:
raise Forbidden(str(e))
# Get summary status detail from service
from services.summary_index_service import SummaryIndexService
result = SummaryIndexService.get_document_summary_status_detail(
document_id=document_id,
dataset_id=dataset_id,
)
return result, 200

View File

@@ -41,17 +41,6 @@ from services.errors.chunk import ChildChunkIndexingError as ChildChunkIndexingS
from tasks.batch_create_segment_to_index_task import batch_create_segment_to_index_task
def _get_segment_with_summary(segment, dataset_id):
"""Helper function to marshal segment and add summary information."""
from services.summary_index_service import SummaryIndexService
segment_dict = dict(marshal(segment, segment_fields))
# Query summary for this segment (only enabled summaries)
summary = SummaryIndexService.get_segment_summary(segment_id=segment.id, dataset_id=dataset_id)
segment_dict["summary"] = summary.summary_content if summary else None
return segment_dict
class SegmentListQuery(BaseModel):
limit: int = Field(default=20, ge=1, le=100)
status: list[str] = Field(default_factory=list)
@@ -74,7 +63,6 @@ class SegmentUpdatePayload(BaseModel):
keywords: list[str] | None = None
regenerate_child_chunks: bool = False
attachment_ids: list[str] | None = None
summary: str | None = None # Summary content for summary index
class BatchImportPayload(BaseModel):
@@ -193,25 +181,8 @@ class DatasetDocumentSegmentListApi(Resource):
segments = db.paginate(select=query, page=page, per_page=limit, max_per_page=100, error_out=False)
# Query summaries for all segments in this page (batch query for efficiency)
segment_ids = [segment.id for segment in segments.items]
summaries = {}
if segment_ids:
from services.summary_index_service import SummaryIndexService
summary_records = SummaryIndexService.get_segments_summaries(segment_ids=segment_ids, dataset_id=dataset_id)
# Only include enabled summaries (already filtered by service)
summaries = {chunk_id: summary.summary_content for chunk_id, summary in summary_records.items()}
# Add summary to each segment
segments_with_summary = []
for segment in segments.items:
segment_dict = dict(marshal(segment, segment_fields))
segment_dict["summary"] = summaries.get(segment.id)
segments_with_summary.append(segment_dict)
response = {
"data": segments_with_summary,
"data": marshal(segments.items, segment_fields),
"limit": limit,
"total": segments.total,
"total_pages": segments.pages,
@@ -357,7 +328,7 @@ class DatasetDocumentSegmentAddApi(Resource):
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
segment = SegmentService.create_segment(payload_dict, document, dataset)
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@console_ns.route("/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>")
@@ -419,12 +390,10 @@ class DatasetDocumentSegmentUpdateApi(Resource):
payload = SegmentUpdatePayload.model_validate(console_ns.payload or {})
payload_dict = payload.model_dump(exclude_none=True)
SegmentService.segment_create_args_validate(payload_dict, document)
# Update segment (summary update with change detection is handled in SegmentService.update_segment)
segment = SegmentService.update_segment(
SegmentUpdateArgs.model_validate(payload.model_dump(exclude_none=True)), segment, document, dataset
)
return {"data": _get_segment_with_summary(segment, dataset_id), "doc_form": document.doc_form}, 200
return {"data": marshal(segment, segment_fields), "doc_form": document.doc_form}, 200
@setup_required
@login_required

View File

@@ -1,13 +1,6 @@
from flask_restx import Resource, fields
from flask_restx import Resource
from controllers.common.schema import register_schema_model
from fields.hit_testing_fields import (
child_chunk_fields,
document_fields,
files_fields,
hit_testing_record_fields,
segment_fields,
)
from libs.login import login_required
from .. import console_ns
@@ -21,45 +14,13 @@ from ..wraps import (
register_schema_model(console_ns, HitTestingPayload)
def _get_or_create_model(model_name: str, field_def):
"""Get or create a flask_restx model to avoid dict type issues in Swagger."""
existing = console_ns.models.get(model_name)
if existing is None:
existing = console_ns.model(model_name, field_def)
return existing
# Register models for flask_restx to avoid dict type issues in Swagger
document_model = _get_or_create_model("HitTestingDocument", document_fields)
segment_fields_copy = segment_fields.copy()
segment_fields_copy["document"] = fields.Nested(document_model)
segment_model = _get_or_create_model("HitTestingSegment", segment_fields_copy)
child_chunk_model = _get_or_create_model("HitTestingChildChunk", child_chunk_fields)
files_model = _get_or_create_model("HitTestingFile", files_fields)
hit_testing_record_fields_copy = hit_testing_record_fields.copy()
hit_testing_record_fields_copy["segment"] = fields.Nested(segment_model)
hit_testing_record_fields_copy["child_chunks"] = fields.List(fields.Nested(child_chunk_model))
hit_testing_record_fields_copy["files"] = fields.List(fields.Nested(files_model))
hit_testing_record_model = _get_or_create_model("HitTestingRecord", hit_testing_record_fields_copy)
# Response model for hit testing API
hit_testing_response_fields = {
"query": fields.String,
"records": fields.List(fields.Nested(hit_testing_record_model)),
}
hit_testing_response_model = _get_or_create_model("HitTestingResponse", hit_testing_response_fields)
@console_ns.route("/datasets/<uuid:dataset_id>/hit-testing")
class HitTestingApi(Resource, DatasetsHitTestingBase):
@console_ns.doc("test_dataset_retrieval")
@console_ns.doc(description="Test dataset knowledge retrieval")
@console_ns.doc(params={"dataset_id": "Dataset ID"})
@console_ns.expect(console_ns.models[HitTestingPayload.__name__])
@console_ns.response(200, "Hit testing completed successfully", model=hit_testing_response_model)
@console_ns.response(200, "Hit testing completed successfully")
@console_ns.response(404, "Dataset not found")
@console_ns.response(400, "Invalid parameters")
@setup_required

View File

@@ -126,11 +126,10 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
raise NotFound("Dataset not found.")
DatasetService.check_dataset_permission(dataset, current_user)
match action:
case "enable":
MetadataService.enable_built_in_field(dataset)
case "disable":
MetadataService.disable_built_in_field(dataset)
if action == "enable":
MetadataService.enable_built_in_field(dataset)
elif action == "disable":
MetadataService.disable_built_in_field(dataset)
return {"result": "success"}, 200

View File

@@ -1,9 +1,10 @@
import json
import logging
from typing import Any, Literal, cast
from uuid import UUID
from flask import abort, request
from flask_restx import Resource, marshal_with # type: ignore
from flask_restx import Resource, marshal_with, reqparse # type: ignore
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
@@ -37,7 +38,7 @@ from core.model_runtime.utils.encoders import jsonable_encoder
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField, UUIDStrOrEmpty
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, current_user, login_required
from models import Account
from models.dataset import Pipeline
@@ -109,7 +110,7 @@ class NodeIdQuery(BaseModel):
class WorkflowRunQuery(BaseModel):
last_id: UUIDStrOrEmpty | None = None
last_id: UUID | None = None
limit: int = Field(default=20, ge=1, le=100)
@@ -120,10 +121,6 @@ class DatasourceVariablesPayload(BaseModel):
start_node_title: str
class RagPipelineRecommendedPluginQuery(BaseModel):
type: str = "all"
register_schema_models(
console_ns,
DraftWorkflowSyncPayload,
@@ -138,7 +135,6 @@ register_schema_models(
NodeIdQuery,
WorkflowRunQuery,
DatasourceVariablesPayload,
RagPipelineRecommendedPluginQuery,
)
@@ -979,8 +975,11 @@ class RagPipelineRecommendedPluginApi(Resource):
@login_required
@account_initialization_required
def get(self):
query = RagPipelineRecommendedPluginQuery.model_validate(request.args.to_dict())
parser = reqparse.RequestParser()
parser.add_argument("type", type=str, location="args", required=False, default="all")
args = parser.parse_args()
type = args["type"]
rag_pipeline_service = RagPipelineService()
recommended_plugins = rag_pipeline_service.get_recommended_plugins(query.type)
recommended_plugins = rag_pipeline_service.get_recommended_plugins(type)
return recommended_plugins

View File

@@ -1,9 +1,8 @@
import logging
from typing import Any, Literal, cast
from typing import Any, cast
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
import services
@@ -52,7 +51,7 @@ from fields.app_fields import (
tag_fields,
)
from fields.dataset_fields import dataset_fields
from fields.member_fields import simple_account_fields
from fields.member_fields import build_simple_account_model
from fields.workflow_fields import (
conversation_variable_fields,
pipeline_variable_fields,
@@ -104,7 +103,7 @@ app_detail_fields_with_site_copy["tags"] = fields.List(fields.Nested(tag_model))
app_detail_fields_with_site_copy["site"] = fields.Nested(site_model)
app_detail_with_site_model = get_or_create_model("TrialAppDetailWithSite", app_detail_fields_with_site_copy)
simple_account_model = get_or_create_model("SimpleAccount", simple_account_fields)
simple_account_model = build_simple_account_model(console_ns)
conversation_variable_model = get_or_create_model("TrialConversationVariable", conversation_variable_fields)
pipeline_variable_model = get_or_create_model("TrialPipelineVariable", pipeline_variable_fields)
@@ -118,56 +117,7 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
# Pydantic models for request validation
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class WorkflowRunRequest(BaseModel):
inputs: dict
files: list | None = None
class ChatRequest(BaseModel):
inputs: dict
query: str
files: list | None = None
conversation_id: str | None = None
parent_message_id: str | None = None
retriever_from: str = "explore_app"
class TextToSpeechRequest(BaseModel):
message_id: str | None = None
voice: str | None = None
text: str | None = None
streaming: bool | None = None
class CompletionRequest(BaseModel):
inputs: dict
query: str = ""
files: list | None = None
response_mode: Literal["blocking", "streaming"] | None = None
retriever_from: str = "explore_app"
# Register schemas for Swagger documentation
console_ns.schema_model(
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
console_ns.schema_model(
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
)
class TrialAppWorkflowRunApi(TrialAppResource):
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
def post(self, trial_app):
"""
Run workflow
@@ -179,8 +129,10 @@ class TrialAppWorkflowRunApi(TrialAppResource):
if app_mode != AppMode.WORKFLOW:
raise NotWorkflowAppError()
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
parser.add_argument("files", type=list, required=False, location="json")
args = parser.parse_args()
assert current_user is not None
try:
app_id = app_model.id
@@ -231,7 +183,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
class TrialChatApi(TrialAppResource):
@console_ns.expect(console_ns.models[ChatRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
@@ -239,14 +190,14 @@ class TrialChatApi(TrialAppResource):
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
raise NotChatAppError()
request_data = ChatRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
# Validate UUID values if provided
if args.get("conversation_id"):
args["conversation_id"] = uuid_value(args["conversation_id"])
if args.get("parent_message_id"):
args["parent_message_id"] = uuid_value(args["parent_message_id"])
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, required=True, location="json")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("conversation_id", type=uuid_value, location="json")
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
args["auto_generate_name"] = False
@@ -369,16 +320,20 @@ class TrialChatAudioApi(TrialAppResource):
class TrialChatTextApi(TrialAppResource):
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
try:
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
parser = reqparse.RequestParser()
parser.add_argument("message_id", type=str, required=False, location="json")
parser.add_argument("voice", type=str, location="json")
parser.add_argument("text", type=str, location="json")
parser.add_argument("streaming", type=bool, location="json")
args = parser.parse_args()
message_id = request_data.message_id
text = request_data.text
voice = request_data.voice
message_id = args.get("message_id", None)
text = args.get("text", None)
voice = args.get("voice", None)
if not isinstance(current_user, Account):
raise ValueError("current_user must be an Account instance")
@@ -416,15 +371,19 @@ class TrialChatTextApi(TrialAppResource):
class TrialCompletionApi(TrialAppResource):
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
@trial_feature_enable
def post(self, trial_app):
app_model = trial_app
if app_model.mode != "completion":
raise NotCompletionAppError()
request_data = CompletionRequest.model_validate(console_ns.payload)
args = request_data.model_dump()
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("query", type=str, location="json", default="")
parser.add_argument("files", type=list, required=False, location="json")
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
args = parser.parse_args()
streaming = args["response_mode"] == "streaming"
args["auto_generate_name"] = False

View File

@@ -1,217 +0,0 @@
"""
Console/Studio Human Input Form APIs.
"""
import json
import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.human_input_service import Form, HumanInputService
from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@console_ns.route("/form/human_input/<string:form_token>")
class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@setup_required
@login_required
@account_initialization_required
def get(self, form_token: str):
"""
Get human input form definition by form token.
GET /console/api/form/human_input/<form_token>
"""
service = HumanInputService(db.engine)
form = service.get_form_definition_by_token_for_console(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
def post(self, form_token: str):
"""
Submit human input form by form token.
POST /console/api/form/human_input/<form_token>
Request body:
{
"inputs": {
"content": "User input content"
},
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_user_id=current_user.id,
)
return jsonify({})
@console_ns.route("/workflow/<string:workflow_run_id>/events")
class ConsoleWorkflowEventsApi(Resource):
"""Console API for getting workflow execution events after resume."""
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
GET /console/api/workflow/<workflow_run_id>/events
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=tenant_id,
run_id=workflow_run_id,
)
if workflow_run is None:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
if workflow_run.created_by != user.id:
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
with Session(expire_on_commit=False, bind=db.engine) as session:
app = _retrieve_app_for_workflow_run(session, workflow_run)
if workflow_run.finished_at is not None:
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,
creator_user=user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=AppMode(app.mode),
workflow_run=workflow_run,
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
query = select(App).where(
App.id == workflow_run.app_id,
App.tenant_id == workflow_run.tenant_id,
)
app = session.scalars(query).first()
if app is None:
raise AssertionError(
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
)
return app

View File

@@ -1,74 +1,87 @@
import os
from typing import Literal
from flask import session
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from controllers.fastopenapi import console_router
from extensions.ext_database import db
from models.model import DifySetup
from services.account_service import TenantService
from . import console_ns
from .error import AlreadySetupError, InitValidateFailedError
from .wraps import only_edition_self_hosted
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class InitValidatePayload(BaseModel):
password: str = Field(..., max_length=30, description="Initialization password")
password: str = Field(..., max_length=30)
class InitStatusResponse(BaseModel):
status: Literal["finished", "not_started"] = Field(..., description="Initialization status")
class InitValidateResponse(BaseModel):
result: str = Field(description="Operation result", examples=["success"])
@console_router.get(
"/init",
response_model=InitStatusResponse,
tags=["console"],
console_ns.schema_model(
InitValidatePayload.__name__,
InitValidatePayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
def get_init_status() -> InitStatusResponse:
"""Get initialization validation status."""
init_status = get_init_validate_status()
if init_status:
return InitStatusResponse(status="finished")
return InitStatusResponse(status="not_started")
@console_router.post(
"/init",
response_model=InitValidateResponse,
tags=["console"],
status_code=201,
)
@only_edition_self_hosted
def validate_init_password(payload: InitValidatePayload) -> InitValidateResponse:
"""Validate initialization password."""
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
@console_ns.route("/init")
class InitValidateAPI(Resource):
@console_ns.doc("get_init_status")
@console_ns.doc(description="Get initialization validation status")
@console_ns.response(
200,
"Success",
model=console_ns.model(
"InitStatusResponse",
{"status": fields.String(description="Initialization status", enum=["finished", "not_started"])},
),
)
def get(self):
"""Get initialization validation status"""
init_status = get_init_validate_status()
if init_status:
return {"status": "finished"}
return {"status": "not_started"}
if payload.password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
@console_ns.doc("validate_init_password")
@console_ns.doc(description="Validate initialization password for self-hosted edition")
@console_ns.expect(console_ns.models[InitValidatePayload.__name__])
@console_ns.response(
201,
"Success",
model=console_ns.model("InitValidateResponse", {"result": fields.String(description="Operation result")}),
)
@console_ns.response(400, "Already setup or validation failed")
@only_edition_self_hosted
def post(self):
"""Validate initialization password"""
# is tenant created
tenant_count = TenantService.get_tenant_count()
if tenant_count > 0:
raise AlreadySetupError()
session["is_init_validated"] = True
return InitValidateResponse(result="success")
payload = InitValidatePayload.model_validate(console_ns.payload)
input_password = payload.password
if input_password != os.environ.get("INIT_PASSWORD"):
session["is_init_validated"] = False
raise InitValidateFailedError()
session["is_init_validated"] = True
return {"result": "success"}, 201
def get_init_validate_status() -> bool:
def get_init_validate_status():
if dify_config.EDITION == "SELF_HOSTED":
if os.environ.get("INIT_PASSWORD"):
if session.get("is_init_validated"):
return True
with Session(db.engine) as db_session:
return db_session.execute(select(DifySetup)).scalar_one_or_none() is not None
return db_session.execute(select(DifySetup)).scalar_one_or_none()
return True

View File

@@ -11,59 +11,68 @@ from controllers.common.errors import (
RemoteFileUploadError,
UnsupportedFileTypeError,
)
from controllers.console import console_ns
from controllers.common.schema import register_schema_models
from core.file import helpers as file_helpers
from core.helper import ssrf_proxy
from extensions.ext_database import db
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
from libs.login import current_account_with_tenant, login_required
from libs.login import current_account_with_tenant
from services.file_service import FileService
from . import console_ns
register_schema_models(console_ns, RemoteFileInfo, FileWithSignedUrl)
@console_ns.route("/remote-files/<path:url>")
class RemoteFileInfoApi(Resource):
@console_ns.response(200, "Remote file info", console_ns.models[RemoteFileInfo.__name__])
def get(self, url):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
# failed back to get method
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
info = RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
)
return info.model_dump(mode="json")
class RemoteFileUploadPayload(BaseModel):
url: str = Field(..., description="URL to fetch")
@console_ns.route("/remote-files/<path:url>")
class GetRemoteFileInfo(Resource):
@login_required
def get(self, url: str):
decoded_url = urllib.parse.unquote(url)
resp = ssrf_proxy.head(decoded_url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(decoded_url, timeout=3)
resp.raise_for_status()
return RemoteFileInfo(
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
file_length=int(resp.headers.get("Content-Length", 0)),
).model_dump(mode="json")
console_ns.schema_model(
RemoteFileUploadPayload.__name__,
RemoteFileUploadPayload.model_json_schema(ref_template="#/definitions/{model}"),
)
@console_ns.route("/remote-files/upload")
class RemoteFileUpload(Resource):
@login_required
class RemoteFileUploadApi(Resource):
@console_ns.expect(console_ns.models[RemoteFileUploadPayload.__name__])
@console_ns.response(201, "Remote file uploaded", console_ns.models[FileWithSignedUrl.__name__])
def post(self):
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = payload.url
args = RemoteFileUploadPayload.model_validate(console_ns.payload)
url = args.url
# Try to fetch remote file metadata/content first
try:
resp = ssrf_proxy.head(url=url)
if resp.status_code != httpx.codes.OK:
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
if resp.status_code != httpx.codes.OK:
# Normalize into a user-friendly error message expected by tests
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
except httpx.RequestError as e:
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
file_info = helpers.guess_file_info_from_response(resp)
# Enforce file size limit with 400 (Bad Request) per tests' expectation
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
raise FileTooLargeError()
raise FileTooLargeError
# Load content if needed
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
try:
@@ -80,17 +89,14 @@ class RemoteFileUpload(Resource):
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
# Success: return created resource with 201 status
return (
FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
).model_dump(mode="json"),
201,
payload = FileWithSignedUrl(
id=upload_file.id,
name=upload_file.name,
size=upload_file.size,
extension=upload_file.extension,
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
mime_type=upload_file.mime_type,
created_by=upload_file.created_by,
created_at=int(upload_file.created_at.timestamp()),
)
return payload.model_dump(mode="json"), 201

View File

@@ -1,103 +0,0 @@
from __future__ import annotations
from fastapi.encoders import jsonable_encoder
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from libs.login import current_account_with_tenant, login_required
from services.sandbox.sandbox_file_service import SandboxFileService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class SandboxFileListQuery(BaseModel):
path: str | None = Field(default=None, description="Workspace relative path")
recursive: bool = Field(default=False, description="List recursively")
class SandboxFileDownloadRequest(BaseModel):
path: str = Field(..., description="Workspace relative file path")
console_ns.schema_model(
SandboxFileListQuery.__name__,
SandboxFileListQuery.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
SandboxFileDownloadRequest.__name__,
SandboxFileDownloadRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
SANDBOX_FILE_NODE_FIELDS = {
"path": fields.String,
"is_dir": fields.Boolean,
"size": fields.Raw,
"mtime": fields.Raw,
"extension": fields.String,
}
SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS = {
"download_url": fields.String,
"expires_in": fields.Integer,
"export_id": fields.String,
}
sandbox_file_node_model = console_ns.model("SandboxFileNode", SANDBOX_FILE_NODE_FIELDS)
sandbox_file_download_ticket_model = console_ns.model("SandboxFileDownloadTicket", SANDBOX_FILE_DOWNLOAD_TICKET_FIELDS)
@console_ns.route("/apps/<string:app_id>/sandbox/files")
class SandboxFilesApi(Resource):
"""List sandbox files for the current user.
The sandbox_id is derived from the current user's ID, as each user has
their own sandbox workspace per app.
"""
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[SandboxFileListQuery.__name__])
@console_ns.marshal_list_with(sandbox_file_node_model)
def get(self, app_id: str):
args = SandboxFileListQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore[arg-type]
account, tenant_id = current_account_with_tenant()
sandbox_id = account.id
return jsonable_encoder(
SandboxFileService.list_files(
tenant_id=tenant_id,
app_id=app_id,
sandbox_id=sandbox_id,
path=args.path,
recursive=args.recursive,
)
)
@console_ns.route("/apps/<string:app_id>/sandbox/files/download")
class SandboxFileDownloadApi(Resource):
"""Download a sandbox file for the current user.
The sandbox_id is derived from the current user's ID, as each user has
their own sandbox workspace per app.
"""
@setup_required
@login_required
@account_initialization_required
@console_ns.expect(console_ns.models[SandboxFileDownloadRequest.__name__])
@console_ns.marshal_with(sandbox_file_download_ticket_model)
def post(self, app_id: str):
payload = SandboxFileDownloadRequest.model_validate(console_ns.payload or {})
account, tenant_id = current_account_with_tenant()
sandbox_id = account.id
res = SandboxFileService.download_file(
tenant_id=tenant_id, app_id=app_id, sandbox_id=sandbox_id, path=payload.path
)
return jsonable_encoder(res)

View File

@@ -42,15 +42,7 @@ class SetupResponse(BaseModel):
tags=["console"],
)
def get_setup_status_api() -> SetupStatusResponse:
"""Get system setup status.
NOTE: This endpoint is unauthenticated by design.
During first-time bootstrap there is no admin account yet, so frontend initialization must be
able to query setup progress before any login flow exists.
Only bootstrap-safe status information should be returned by this endpoint.
"""
"""Get system setup status."""
if dify_config.EDITION == "SELF_HOSTED":
setup_status = get_setup_status()
if setup_status and not isinstance(setup_status, bool):
@@ -69,12 +61,7 @@ def get_setup_status_api() -> SetupStatusResponse:
)
@only_edition_self_hosted
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
"""Initialize system setup with admin account.
NOTE: This endpoint is unauthenticated by design for first-time bootstrap.
Access is restricted by deployment mode (`SELF_HOSTED`), one-time setup guards,
and init-password validation rather than user session authentication.
"""
"""Initialize system setup with admin account."""
if get_setup_status():
raise AlreadySetupError()

Some files were not shown because too many files have changed in this diff Show More