mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 23:20:12 -05:00
revert: revert human input relevant code (#31766)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@@ -8,7 +8,6 @@ on:
|
|||||||
- "build/**"
|
- "build/**"
|
||||||
- "release/e-*"
|
- "release/e-*"
|
||||||
- "hotfix/**"
|
- "hotfix/**"
|
||||||
- "feat/hitl-backend"
|
|
||||||
tags:
|
tags:
|
||||||
- "*"
|
- "*"
|
||||||
|
|
||||||
|
|||||||
@@ -717,28 +717,3 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
|||||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
||||||
|
|
||||||
|
|
||||||
# Redis URL used for PubSub between API and
|
|
||||||
# celery worker
|
|
||||||
# defaults to url constructed from `REDIS_*`
|
|
||||||
# configurations
|
|
||||||
PUBSUB_REDIS_URL=
|
|
||||||
# Pub/sub channel type for streaming events.
|
|
||||||
# valid options are:
|
|
||||||
#
|
|
||||||
# - pubsub: for normal Pub/Sub
|
|
||||||
# - sharded: for sharded Pub/Sub
|
|
||||||
#
|
|
||||||
# It's highly recommended to use sharded Pub/Sub AND redis cluster
|
|
||||||
# for large deployments.
|
|
||||||
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
|
|
||||||
# Whether to use Redis cluster mode while running
|
|
||||||
# PubSub.
|
|
||||||
# It's highly recommended to enable this for large deployments.
|
|
||||||
PUBSUB_REDIS_USE_CLUSTERS=false
|
|
||||||
|
|
||||||
# Whether to Enable human input timeout check task
|
|
||||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
|
||||||
# Human input timeout check interval in minutes
|
|
||||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1
|
|
||||||
|
|||||||
@@ -36,8 +36,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||||
# TODO(QuantumGhost): fix the import violation later
|
|
||||||
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
|
|
||||||
|
|
||||||
[importlinter:contract:workflow-infrastructure-dependencies]
|
[importlinter:contract:workflow-infrastructure-dependencies]
|
||||||
name = Workflow Infrastructure Dependencies
|
name = Workflow Infrastructure Dependencies
|
||||||
@@ -60,8 +58,6 @@ ignore_imports =
|
|||||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
|
||||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
|
||||||
|
|
||||||
[importlinter:contract:workflow-external-imports]
|
[importlinter:contract:workflow-external-imports]
|
||||||
name = Workflow External Imports
|
name = Workflow External Imports
|
||||||
@@ -149,7 +145,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
|
||||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||||
@@ -253,7 +248,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||||
core.workflow.nodes.human_input.entities -> core.variables.consts
|
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||||
@@ -300,8 +294,6 @@ ignore_imports =
|
|||||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
|
||||||
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
|
|
||||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||||
core.workflow.nodes.agent.agent_node -> models
|
core.workflow.nodes.agent.agent_node -> models
|
||||||
core.workflow.nodes.base.node -> models.enums
|
core.workflow.nodes.base.node -> models.enums
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from datetime import timedelta
|
|
||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
@@ -49,16 +48,6 @@ class SecurityConfig(BaseSettings):
|
|||||||
default=5,
|
default=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
|
|
||||||
description="Maximum number of web form submissions allowed per IP within the rate limit window",
|
|
||||||
default=30,
|
|
||||||
)
|
|
||||||
|
|
||||||
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
|
|
||||||
description="Time window in seconds for web form submission rate limiting",
|
|
||||||
default=60,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOGIN_DISABLED: bool = Field(
|
LOGIN_DISABLED: bool = Field(
|
||||||
description="Whether to disable login checks",
|
description="Whether to disable login checks",
|
||||||
default=False,
|
default=False,
|
||||||
@@ -93,12 +82,6 @@ class AppExecutionConfig(BaseSettings):
|
|||||||
default=0,
|
default=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
|
|
||||||
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
|
|
||||||
default=int(timedelta(days=7).total_seconds()),
|
|
||||||
ge=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CodeExecutionSandboxConfig(BaseSettings):
|
class CodeExecutionSandboxConfig(BaseSettings):
|
||||||
"""
|
"""
|
||||||
@@ -1151,14 +1134,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
|||||||
description="Enable queue monitor task",
|
description="Enable queue monitor task",
|
||||||
default=False,
|
default=False,
|
||||||
)
|
)
|
||||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
|
|
||||||
description="Enable human input timeout check task",
|
|
||||||
default=True,
|
|
||||||
)
|
|
||||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
|
|
||||||
description="Human input timeout check interval in minutes",
|
|
||||||
default=1,
|
|
||||||
)
|
|
||||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
|
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
|
||||||
description="Enable check upgradable plugin task",
|
description="Enable check upgradable plugin task",
|
||||||
default=True,
|
default=True,
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
|
|||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
from .cache.redis_config import RedisConfig
|
from .cache.redis_config import RedisConfig
|
||||||
from .cache.redis_pubsub_config import RedisPubSubConfig
|
|
||||||
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||||
@@ -318,7 +317,6 @@ class MiddlewareConfig(
|
|||||||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||||
KeywordStoreConfig,
|
KeywordStoreConfig,
|
||||||
RedisConfig,
|
RedisConfig,
|
||||||
RedisPubSubConfig,
|
|
||||||
# configs of storage and storage providers
|
# configs of storage and storage providers
|
||||||
StorageConfig,
|
StorageConfig,
|
||||||
AliyunOSSStorageConfig,
|
AliyunOSSStorageConfig,
|
||||||
|
|||||||
@@ -1,96 +0,0 @@
|
|||||||
from typing import Literal, Protocol
|
|
||||||
from urllib.parse import quote_plus, urlunparse
|
|
||||||
|
|
||||||
from pydantic import Field
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
|
|
||||||
class RedisConfigDefaults(Protocol):
|
|
||||||
REDIS_HOST: str
|
|
||||||
REDIS_PORT: int
|
|
||||||
REDIS_USERNAME: str | None
|
|
||||||
REDIS_PASSWORD: str | None
|
|
||||||
REDIS_DB: int
|
|
||||||
REDIS_USE_SSL: bool
|
|
||||||
REDIS_USE_SENTINEL: bool | None
|
|
||||||
REDIS_USE_CLUSTERS: bool
|
|
||||||
|
|
||||||
|
|
||||||
class RedisConfigDefaultsMixin:
|
|
||||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
|
||||||
"""
|
|
||||||
Configuration settings for Redis pub/sub streaming.
|
|
||||||
"""
|
|
||||||
|
|
||||||
PUBSUB_REDIS_URL: str | None = Field(
|
|
||||||
alias="PUBSUB_REDIS_URL",
|
|
||||||
description=(
|
|
||||||
"Redis connection URL for pub/sub streaming events between API "
|
|
||||||
"and celery worker, defaults to url constructed from "
|
|
||||||
"`REDIS_*` configurations"
|
|
||||||
),
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
|
||||||
description=(
|
|
||||||
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
|
|
||||||
"recommended to enable this for large deployments."
|
|
||||||
),
|
|
||||||
default=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
|
|
||||||
description=(
|
|
||||||
"Pub/sub channel type for streaming events. "
|
|
||||||
"Valid options are:\n"
|
|
||||||
"\n"
|
|
||||||
" - pubsub: for normal Pub/Sub\n"
|
|
||||||
" - sharded: for sharded Pub/Sub\n"
|
|
||||||
"\n"
|
|
||||||
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
|
|
||||||
"for large deployments."
|
|
||||||
),
|
|
||||||
default="pubsub",
|
|
||||||
)
|
|
||||||
|
|
||||||
def _build_default_pubsub_url(self) -> str:
|
|
||||||
defaults = self._redis_defaults()
|
|
||||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
|
||||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
|
||||||
|
|
||||||
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
|
|
||||||
username = defaults.REDIS_USERNAME or None
|
|
||||||
password = defaults.REDIS_PASSWORD or None
|
|
||||||
|
|
||||||
userinfo = ""
|
|
||||||
if username:
|
|
||||||
userinfo = quote_plus(username)
|
|
||||||
if password:
|
|
||||||
password_part = quote_plus(password)
|
|
||||||
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
|
|
||||||
if userinfo:
|
|
||||||
userinfo = f"{userinfo}@"
|
|
||||||
|
|
||||||
host = defaults.REDIS_HOST
|
|
||||||
port = defaults.REDIS_PORT
|
|
||||||
db = defaults.REDIS_DB
|
|
||||||
|
|
||||||
netloc = f"{userinfo}{host}:{port}"
|
|
||||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def normalized_pubsub_redis_url(self) -> str:
|
|
||||||
pubsub_redis_url = self.PUBSUB_REDIS_URL
|
|
||||||
if pubsub_redis_url:
|
|
||||||
cleaned = pubsub_redis_url.strip()
|
|
||||||
pubsub_redis_url = cleaned or None
|
|
||||||
|
|
||||||
if pubsub_redis_url:
|
|
||||||
return pubsub_redis_url
|
|
||||||
|
|
||||||
return self._build_default_pubsub_url()
|
|
||||||
@@ -37,7 +37,6 @@ from . import (
|
|||||||
apikey,
|
apikey,
|
||||||
extension,
|
extension,
|
||||||
feature,
|
feature,
|
||||||
human_input_form,
|
|
||||||
init_validate,
|
init_validate,
|
||||||
ping,
|
ping,
|
||||||
setup,
|
setup,
|
||||||
@@ -172,7 +171,6 @@ __all__ = [
|
|||||||
"forgot_password",
|
"forgot_password",
|
||||||
"generator",
|
"generator",
|
||||||
"hit_testing",
|
"hit_testing",
|
||||||
"human_input_form",
|
|
||||||
"init_validate",
|
"init_validate",
|
||||||
"installed_app",
|
"installed_app",
|
||||||
"load_balancing_config",
|
"load_balancing_config",
|
||||||
|
|||||||
@@ -89,7 +89,6 @@ status_count_model = console_ns.model(
|
|||||||
"success": fields.Integer,
|
"success": fields.Integer,
|
||||||
"failed": fields.Integer,
|
"failed": fields.Integer,
|
||||||
"partial_success": fields.Integer,
|
"partial_success": fields.Integer,
|
||||||
"paused": fields.Integer,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ from libs.login import current_account_with_tenant, login_required
|
|||||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||||
from services.errors.conversation import ConversationNotExistsError
|
from services.errors.conversation import ConversationNotExistsError
|
||||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||||
from services.message_service import MessageService, attach_message_extra_contents
|
from services.message_service import MessageService
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||||
@@ -198,7 +198,6 @@ message_detail_model = console_ns.model(
|
|||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||||
"extra_contents": fields.List(fields.Raw),
|
|
||||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||||
"status": fields.String,
|
"status": fields.String,
|
||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
@@ -291,7 +290,6 @@ class ChatMessageListApi(Resource):
|
|||||||
has_more = False
|
has_more = False
|
||||||
|
|
||||||
history_messages = list(reversed(history_messages))
|
history_messages = list(reversed(history_messages))
|
||||||
attach_message_extra_contents(history_messages)
|
|
||||||
|
|
||||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||||
|
|
||||||
@@ -476,5 +474,4 @@ class MessageApi(Resource):
|
|||||||
if not message:
|
if not message:
|
||||||
raise NotFound("Message Not Exists.")
|
raise NotFound("Message Not Exists.")
|
||||||
|
|
||||||
attach_message_extra_contents([message])
|
|
||||||
return message
|
return message
|
||||||
|
|||||||
@@ -507,179 +507,6 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
|||||||
raise InternalServerError()
|
raise InternalServerError()
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormPreviewPayload(BaseModel):
|
|
||||||
inputs: dict[str, Any] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Values used to fill missing upstream variables referenced in form_content",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormSubmitPayload(BaseModel):
|
|
||||||
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
|
|
||||||
inputs: dict[str, Any] = Field(
|
|
||||||
...,
|
|
||||||
description="Values used to fill missing upstream variables referenced in form_content",
|
|
||||||
)
|
|
||||||
action: str = Field(..., description="Selected action ID")
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputDeliveryTestPayload(BaseModel):
|
|
||||||
delivery_method_id: str = Field(..., description="Delivery method ID")
|
|
||||||
inputs: dict[str, Any] = Field(
|
|
||||||
default_factory=dict,
|
|
||||||
description="Values used to fill missing upstream variables referenced in form_content",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
reg(HumanInputFormPreviewPayload)
|
|
||||||
reg(HumanInputFormSubmitPayload)
|
|
||||||
reg(HumanInputDeliveryTestPayload)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
|
||||||
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
|
||||||
@console_ns.doc("get_advanced_chat_draft_human_input_form")
|
|
||||||
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
|
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
|
||||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
|
||||||
@edit_permission_required
|
|
||||||
def post(self, app_model: App, node_id: str):
|
|
||||||
"""
|
|
||||||
Preview human input form content and placeholders
|
|
||||||
"""
|
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
|
||||||
inputs = args.inputs
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
|
||||||
preview = workflow_service.get_human_input_form_preview(
|
|
||||||
app_model=app_model,
|
|
||||||
account=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
inputs=inputs,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(preview)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
|
||||||
class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
|
||||||
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
|
|
||||||
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
|
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
|
||||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
|
||||||
@edit_permission_required
|
|
||||||
def post(self, app_model: App, node_id: str):
|
|
||||||
"""
|
|
||||||
Submit human input form preview
|
|
||||||
"""
|
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
|
||||||
workflow_service = WorkflowService()
|
|
||||||
result = workflow_service.submit_human_input_form_preview(
|
|
||||||
app_model=app_model,
|
|
||||||
account=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
form_inputs=args.form_inputs,
|
|
||||||
inputs=args.inputs,
|
|
||||||
action=args.action,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(result)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
|
||||||
class WorkflowDraftHumanInputFormPreviewApi(Resource):
|
|
||||||
@console_ns.doc("get_workflow_draft_human_input_form")
|
|
||||||
@console_ns.doc(description="Get human input form preview for workflow")
|
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
|
||||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
|
||||||
@edit_permission_required
|
|
||||||
def post(self, app_model: App, node_id: str):
|
|
||||||
"""
|
|
||||||
Preview human input form content and placeholders
|
|
||||||
"""
|
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
|
||||||
inputs = args.inputs
|
|
||||||
|
|
||||||
workflow_service = WorkflowService()
|
|
||||||
preview = workflow_service.get_human_input_form_preview(
|
|
||||||
app_model=app_model,
|
|
||||||
account=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
inputs=inputs,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(preview)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
|
||||||
class WorkflowDraftHumanInputFormRunApi(Resource):
|
|
||||||
@console_ns.doc("submit_workflow_draft_human_input_form")
|
|
||||||
@console_ns.doc(description="Submit human input form preview for workflow")
|
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
|
||||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
|
||||||
@edit_permission_required
|
|
||||||
def post(self, app_model: App, node_id: str):
|
|
||||||
"""
|
|
||||||
Submit human input form preview
|
|
||||||
"""
|
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
workflow_service = WorkflowService()
|
|
||||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
|
||||||
result = workflow_service.submit_human_input_form_preview(
|
|
||||||
app_model=app_model,
|
|
||||||
account=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
form_inputs=args.form_inputs,
|
|
||||||
inputs=args.inputs,
|
|
||||||
action=args.action,
|
|
||||||
)
|
|
||||||
return jsonable_encoder(result)
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
|
|
||||||
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
|
|
||||||
@console_ns.doc("test_workflow_draft_human_input_delivery")
|
|
||||||
@console_ns.doc(description="Test human input delivery for workflow")
|
|
||||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
|
||||||
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
|
||||||
@edit_permission_required
|
|
||||||
def post(self, app_model: App, node_id: str):
|
|
||||||
"""
|
|
||||||
Test human input delivery
|
|
||||||
"""
|
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
workflow_service = WorkflowService()
|
|
||||||
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
|
|
||||||
workflow_service.test_human_input_delivery(
|
|
||||||
app_model=app_model,
|
|
||||||
account=current_user,
|
|
||||||
node_id=node_id,
|
|
||||||
delivery_method_id=args.delivery_method_id,
|
|
||||||
inputs=args.inputs,
|
|
||||||
)
|
|
||||||
return jsonable_encoder({})
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
|
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
|
||||||
class DraftWorkflowRunApi(Resource):
|
class DraftWorkflowRunApi(Resource):
|
||||||
@console_ns.doc("run_draft_workflow")
|
@console_ns.doc("run_draft_workflow")
|
||||||
|
|||||||
@@ -5,15 +5,10 @@ from flask import request
|
|||||||
from flask_restx import Resource, fields, marshal_with
|
from flask_restx import Resource, fields, marshal_with
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from controllers.console import console_ns
|
from controllers.console import console_ns
|
||||||
from controllers.console.app.wraps import get_app_model
|
from controllers.console.app.wraps import get_app_model
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
from controllers.console.wraps import account_initialization_required, setup_required
|
||||||
from controllers.web.error import NotFoundError
|
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
|
||||||
from core.workflow.enums import WorkflowExecutionStatus
|
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.end_user_fields import simple_end_user_fields
|
from fields.end_user_fields import simple_end_user_fields
|
||||||
from fields.member_fields import simple_account_fields
|
from fields.member_fields import simple_account_fields
|
||||||
@@ -32,21 +27,9 @@ from libs.custom_inputs import time_duration
|
|||||||
from libs.helper import uuid_value
|
from libs.helper import uuid_value
|
||||||
from libs.login import current_user, login_required
|
from libs.login import current_user, login_required
|
||||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||||
from models.workflow import WorkflowRun
|
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
|
||||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||||
from services.workflow_run_service import WorkflowRunService
|
from services.workflow_run_service import WorkflowRunService
|
||||||
|
|
||||||
|
|
||||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
|
||||||
if not form_token:
|
|
||||||
return None
|
|
||||||
base_url = dify_config.APP_WEB_URL
|
|
||||||
if not base_url:
|
|
||||||
return None
|
|
||||||
return f"{base_url.rstrip('/')}/form/{form_token}"
|
|
||||||
|
|
||||||
|
|
||||||
# Workflow run status choices for filtering
|
# Workflow run status choices for filtering
|
||||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||||
@@ -457,63 +440,3 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return {"data": node_executions}
|
return {"data": node_executions}
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
|
|
||||||
class ConsoleWorkflowPauseDetailsApi(Resource):
|
|
||||||
"""Console API for getting workflow pause details."""
|
|
||||||
|
|
||||||
@account_initialization_required
|
|
||||||
@login_required
|
|
||||||
def get(self, workflow_run_id: str):
|
|
||||||
"""
|
|
||||||
Get workflow pause details.
|
|
||||||
|
|
||||||
GET /console/api/workflow/<workflow_run_id>/pause-details
|
|
||||||
|
|
||||||
Returns information about why and where the workflow is paused.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# Query WorkflowRun to determine if workflow is suspended
|
|
||||||
session_maker = sessionmaker(bind=db.engine)
|
|
||||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
|
|
||||||
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
|
|
||||||
if not workflow_run:
|
|
||||||
raise NotFoundError("Workflow run not found")
|
|
||||||
|
|
||||||
# Check if workflow is suspended
|
|
||||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
|
||||||
if not is_paused:
|
|
||||||
return {
|
|
||||||
"paused_at": None,
|
|
||||||
"paused_nodes": [],
|
|
||||||
}, 200
|
|
||||||
|
|
||||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
|
||||||
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
|
|
||||||
|
|
||||||
# Build response
|
|
||||||
paused_at = pause_entity.paused_at if pause_entity else None
|
|
||||||
paused_nodes = []
|
|
||||||
response = {
|
|
||||||
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
|
|
||||||
"paused_nodes": paused_nodes,
|
|
||||||
}
|
|
||||||
|
|
||||||
for reason in pause_reasons:
|
|
||||||
if isinstance(reason, HumanInputRequired):
|
|
||||||
paused_nodes.append(
|
|
||||||
{
|
|
||||||
"node_id": reason.node_id,
|
|
||||||
"node_title": reason.node_title,
|
|
||||||
"pause_type": {
|
|
||||||
"type": "human_input",
|
|
||||||
"form_id": reason.form_id,
|
|
||||||
"backstage_input_url": _build_backstage_input_url(reason.form_token),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raise AssertionError("unimplemented.")
|
|
||||||
|
|
||||||
return response, 200
|
|
||||||
|
|||||||
@@ -1,217 +0,0 @@
|
|||||||
"""
|
|
||||||
Console/Studio Human Input Form APIs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
from flask import Response, jsonify, request
|
|
||||||
from flask_restx import Resource, reqparse
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
|
||||||
|
|
||||||
from controllers.console import console_ns
|
|
||||||
from controllers.console.wraps import account_initialization_required, setup_required
|
|
||||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
|
||||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
|
||||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
|
||||||
from core.app.apps.message_generator import MessageGenerator
|
|
||||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from libs.login import current_account_with_tenant, login_required
|
|
||||||
from models import App
|
|
||||||
from models.enums import CreatorUserRole
|
|
||||||
from models.human_input import RecipientType
|
|
||||||
from models.model import AppMode
|
|
||||||
from models.workflow import WorkflowRun
|
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
|
||||||
from services.human_input_service import Form, HumanInputService
|
|
||||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _jsonify_form_definition(form: Form) -> Response:
|
|
||||||
payload = form.get_definition().model_dump()
|
|
||||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
|
||||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/form/human_input/<string:form_token>")
|
|
||||||
class ConsoleHumanInputFormApi(Resource):
|
|
||||||
"""Console API for getting human input form definition."""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _ensure_console_access(form: Form):
|
|
||||||
_, current_tenant_id = current_account_with_tenant()
|
|
||||||
|
|
||||||
if form.tenant_id != current_tenant_id:
|
|
||||||
raise NotFoundError("App not found")
|
|
||||||
|
|
||||||
@setup_required
|
|
||||||
@login_required
|
|
||||||
@account_initialization_required
|
|
||||||
def get(self, form_token: str):
|
|
||||||
"""
|
|
||||||
Get human input form definition by form token.
|
|
||||||
|
|
||||||
GET /console/api/form/human_input/<form_token>
|
|
||||||
"""
|
|
||||||
service = HumanInputService(db.engine)
|
|
||||||
form = service.get_form_definition_by_token_for_console(form_token)
|
|
||||||
if form is None:
|
|
||||||
raise NotFoundError(f"form not found, token={form_token}")
|
|
||||||
|
|
||||||
self._ensure_console_access(form)
|
|
||||||
|
|
||||||
return _jsonify_form_definition(form)
|
|
||||||
|
|
||||||
@account_initialization_required
|
|
||||||
@login_required
|
|
||||||
def post(self, form_token: str):
|
|
||||||
"""
|
|
||||||
Submit human input form by form token.
|
|
||||||
|
|
||||||
POST /console/api/form/human_input/<form_token>
|
|
||||||
|
|
||||||
Request body:
|
|
||||||
{
|
|
||||||
"inputs": {
|
|
||||||
"content": "User input content"
|
|
||||||
},
|
|
||||||
"action": "Approve"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
|
||||||
parser.add_argument("action", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
current_user, _ = current_account_with_tenant()
|
|
||||||
|
|
||||||
service = HumanInputService(db.engine)
|
|
||||||
form = service.get_form_by_token(form_token)
|
|
||||||
if form is None:
|
|
||||||
raise NotFoundError(f"form not found, token={form_token}")
|
|
||||||
|
|
||||||
self._ensure_console_access(form)
|
|
||||||
|
|
||||||
recipient_type = form.recipient_type
|
|
||||||
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
|
|
||||||
raise NotFoundError(f"form not found, token={form_token}")
|
|
||||||
# The type checker is not smart enought to validate the following invariant.
|
|
||||||
# So we need to assert it manually.
|
|
||||||
assert recipient_type is not None, "recipient_type cannot be None here."
|
|
||||||
|
|
||||||
service.submit_form_by_token(
|
|
||||||
recipient_type=recipient_type,
|
|
||||||
form_token=form_token,
|
|
||||||
selected_action_id=args["action"],
|
|
||||||
form_data=args["inputs"],
|
|
||||||
submission_user_id=current_user.id,
|
|
||||||
)
|
|
||||||
|
|
||||||
return jsonify({})
|
|
||||||
|
|
||||||
|
|
||||||
@console_ns.route("/workflow/<string:workflow_run_id>/events")
|
|
||||||
class ConsoleWorkflowEventsApi(Resource):
|
|
||||||
"""Console API for getting workflow execution events after resume."""
|
|
||||||
|
|
||||||
@account_initialization_required
|
|
||||||
@login_required
|
|
||||||
def get(self, workflow_run_id: str):
|
|
||||||
"""
|
|
||||||
Get workflow execution events stream after resume.
|
|
||||||
|
|
||||||
GET /console/api/workflow/<workflow_run_id>/events
|
|
||||||
|
|
||||||
Returns Server-Sent Events stream.
|
|
||||||
"""
|
|
||||||
|
|
||||||
user, tenant_id = current_account_with_tenant()
|
|
||||||
session_maker = sessionmaker(db.engine)
|
|
||||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
|
||||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
|
||||||
tenant_id=tenant_id,
|
|
||||||
run_id=workflow_run_id,
|
|
||||||
)
|
|
||||||
if workflow_run is None:
|
|
||||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
|
||||||
|
|
||||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
|
|
||||||
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
|
|
||||||
|
|
||||||
if workflow_run.created_by != user.id:
|
|
||||||
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
|
|
||||||
|
|
||||||
with Session(expire_on_commit=False, bind=db.engine) as session:
|
|
||||||
app = _retrieve_app_for_workflow_run(session, workflow_run)
|
|
||||||
|
|
||||||
if workflow_run.finished_at is not None:
|
|
||||||
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
|
|
||||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
|
||||||
task_id=workflow_run.id,
|
|
||||||
workflow_run=workflow_run,
|
|
||||||
creator_user=user,
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = response.model_dump(mode="json")
|
|
||||||
payload["event"] = response.event.value
|
|
||||||
|
|
||||||
def _generate_finished_events() -> Generator[str, None, None]:
|
|
||||||
yield f"data: {json.dumps(payload)}\n\n"
|
|
||||||
|
|
||||||
event_generator = _generate_finished_events
|
|
||||||
|
|
||||||
else:
|
|
||||||
msg_generator = MessageGenerator()
|
|
||||||
if app.mode == AppMode.ADVANCED_CHAT:
|
|
||||||
generator = AdvancedChatAppGenerator()
|
|
||||||
elif app.mode == AppMode.WORKFLOW:
|
|
||||||
generator = WorkflowAppGenerator()
|
|
||||||
else:
|
|
||||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
|
||||||
|
|
||||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
|
||||||
|
|
||||||
def _generate_stream_events():
|
|
||||||
if include_state_snapshot:
|
|
||||||
return generator.convert_to_event_stream(
|
|
||||||
build_workflow_event_stream(
|
|
||||||
app_mode=AppMode(app.mode),
|
|
||||||
workflow_run=workflow_run,
|
|
||||||
tenant_id=workflow_run.tenant_id,
|
|
||||||
app_id=workflow_run.app_id,
|
|
||||||
session_maker=session_maker,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return generator.convert_to_event_stream(
|
|
||||||
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
event_generator = _generate_stream_events
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
event_generator(),
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
|
|
||||||
query = select(App).where(
|
|
||||||
App.id == workflow_run.app_id,
|
|
||||||
App.tenant_id == workflow_run.tenant_id,
|
|
||||||
)
|
|
||||||
app = session.scalars(query).first()
|
|
||||||
if app is None:
|
|
||||||
raise AssertionError(
|
|
||||||
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
|
|
||||||
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return app
|
|
||||||
@@ -33,9 +33,8 @@ from core.workflow.graph_engine.manager import GraphEngineManager
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||||
from libs import helper
|
from libs import helper
|
||||||
from libs.helper import OptionalTimestampField, TimestampField
|
from libs.helper import TimestampField
|
||||||
from models.model import App, AppMode, EndUser
|
from models.model import App, AppMode, EndUser
|
||||||
from models.workflow import WorkflowRun
|
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
from repositories.factory import DifyAPIRepositoryFactory
|
||||||
from services.app_generate_service import AppGenerateService
|
from services.app_generate_service import AppGenerateService
|
||||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||||
@@ -64,32 +63,17 @@ class WorkflowLogQuery(BaseModel):
|
|||||||
|
|
||||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunStatusField(fields.Raw):
|
|
||||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
|
||||||
return obj.status.value
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowRunOutputsField(fields.Raw):
|
|
||||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
|
||||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
outputs = obj.outputs_dict
|
|
||||||
return outputs or {}
|
|
||||||
|
|
||||||
|
|
||||||
workflow_run_fields = {
|
workflow_run_fields = {
|
||||||
"id": fields.String,
|
"id": fields.String,
|
||||||
"workflow_id": fields.String,
|
"workflow_id": fields.String,
|
||||||
"status": WorkflowRunStatusField,
|
"status": fields.String,
|
||||||
"inputs": fields.Raw,
|
"inputs": fields.Raw,
|
||||||
"outputs": WorkflowRunOutputsField,
|
"outputs": fields.Raw,
|
||||||
"error": fields.String,
|
"error": fields.String,
|
||||||
"total_steps": fields.Integer,
|
"total_steps": fields.Integer,
|
||||||
"total_tokens": fields.Integer,
|
"total_tokens": fields.Integer,
|
||||||
"created_at": TimestampField,
|
"created_at": TimestampField,
|
||||||
"finished_at": OptionalTimestampField,
|
"finished_at": TimestampField,
|
||||||
"elapsed_time": fields.Float,
|
"elapsed_time": fields.Float,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ from . import (
|
|||||||
feature,
|
feature,
|
||||||
files,
|
files,
|
||||||
forgot_password,
|
forgot_password,
|
||||||
human_input_form,
|
|
||||||
login,
|
login,
|
||||||
message,
|
message,
|
||||||
passport,
|
passport,
|
||||||
@@ -31,7 +30,6 @@ from . import (
|
|||||||
saved_message,
|
saved_message,
|
||||||
site,
|
site,
|
||||||
workflow,
|
workflow,
|
||||||
workflow_events,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
api.add_namespace(web_ns)
|
api.add_namespace(web_ns)
|
||||||
@@ -46,7 +44,6 @@ __all__ = [
|
|||||||
"feature",
|
"feature",
|
||||||
"files",
|
"files",
|
||||||
"forgot_password",
|
"forgot_password",
|
||||||
"human_input_form",
|
|
||||||
"login",
|
"login",
|
||||||
"message",
|
"message",
|
||||||
"passport",
|
"passport",
|
||||||
@@ -55,5 +52,4 @@ __all__ = [
|
|||||||
"site",
|
"site",
|
||||||
"web_ns",
|
"web_ns",
|
||||||
"workflow",
|
"workflow",
|
||||||
"workflow_events",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -117,12 +117,6 @@ class InvokeRateLimitError(BaseHTTPException):
|
|||||||
code = 429
|
code = 429
|
||||||
|
|
||||||
|
|
||||||
class WebFormRateLimitExceededError(BaseHTTPException):
|
|
||||||
error_code = "web_form_rate_limit_exceeded"
|
|
||||||
description = "Too many form requests. Please try again later."
|
|
||||||
code = 429
|
|
||||||
|
|
||||||
|
|
||||||
class NotFoundError(BaseHTTPException):
|
class NotFoundError(BaseHTTPException):
|
||||||
error_code = "not_found"
|
error_code = "not_found"
|
||||||
code = 404
|
code = 404
|
||||||
|
|||||||
@@ -1,164 +0,0 @@
|
|||||||
"""
|
|
||||||
Web App Human Input Form APIs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from flask import Response, request
|
|
||||||
from flask_restx import Resource, reqparse
|
|
||||||
from werkzeug.exceptions import Forbidden
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from controllers.web import web_ns
|
|
||||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
|
||||||
from controllers.web.site import serialize_app_site_payload
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from libs.helper import RateLimiter, extract_remote_ip
|
|
||||||
from models.account import TenantStatus
|
|
||||||
from models.model import App, Site
|
|
||||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
|
||||||
prefix="web_form_submit_rate_limit",
|
|
||||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
|
||||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
|
||||||
)
|
|
||||||
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
|
||||||
prefix="web_form_access_rate_limit",
|
|
||||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
|
||||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
|
||||||
result: dict[str, str] = {}
|
|
||||||
for key, value in values.items():
|
|
||||||
if value is None:
|
|
||||||
result[key] = ""
|
|
||||||
elif isinstance(value, (dict, list)):
|
|
||||||
result[key] = json.dumps(value, ensure_ascii=False)
|
|
||||||
else:
|
|
||||||
result[key] = str(value)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _to_timestamp(value: datetime) -> int:
|
|
||||||
return int(value.timestamp())
|
|
||||||
|
|
||||||
|
|
||||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
|
||||||
"""Return the form payload (optionally with site) as a JSON response."""
|
|
||||||
definition_payload = form.get_definition().model_dump()
|
|
||||||
payload = {
|
|
||||||
"form_content": definition_payload["rendered_content"],
|
|
||||||
"inputs": definition_payload["inputs"],
|
|
||||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
|
||||||
"user_actions": definition_payload["user_actions"],
|
|
||||||
"expiration_time": _to_timestamp(form.expiration_time),
|
|
||||||
}
|
|
||||||
if site_payload is not None:
|
|
||||||
payload["site"] = site_payload
|
|
||||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
|
||||||
|
|
||||||
|
|
||||||
# TODO(QuantumGhost): disable authorization for web app
|
|
||||||
# form api temporarily
|
|
||||||
|
|
||||||
|
|
||||||
@web_ns.route("/form/human_input/<string:form_token>")
|
|
||||||
# class HumanInputFormApi(WebApiResource):
|
|
||||||
class HumanInputFormApi(Resource):
|
|
||||||
"""API for getting and submitting human input forms via the web app."""
|
|
||||||
|
|
||||||
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
|
|
||||||
def get(self, form_token: str):
|
|
||||||
"""
|
|
||||||
Get human input form definition by token.
|
|
||||||
|
|
||||||
GET /api/form/human_input/<form_token>
|
|
||||||
"""
|
|
||||||
ip_address = extract_remote_ip(request)
|
|
||||||
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
|
|
||||||
raise WebFormRateLimitExceededError()
|
|
||||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
|
||||||
|
|
||||||
service = HumanInputService(db.engine)
|
|
||||||
# TODO(QuantumGhost): forbid submision for form tokens
|
|
||||||
# that are only for console.
|
|
||||||
form = service.get_form_by_token(form_token)
|
|
||||||
|
|
||||||
if form is None:
|
|
||||||
raise NotFoundError("Form not found")
|
|
||||||
|
|
||||||
service.ensure_form_active(form)
|
|
||||||
app_model, site = _get_app_site_from_form(form)
|
|
||||||
|
|
||||||
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
|
|
||||||
|
|
||||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
|
||||||
def post(self, form_token: str):
|
|
||||||
"""
|
|
||||||
Submit human input form by token.
|
|
||||||
|
|
||||||
POST /api/form/human_input/<form_token>
|
|
||||||
|
|
||||||
Request body:
|
|
||||||
{
|
|
||||||
"inputs": {
|
|
||||||
"content": "User input content"
|
|
||||||
},
|
|
||||||
"action": "Approve"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
parser = reqparse.RequestParser()
|
|
||||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
|
||||||
parser.add_argument("action", type=str, required=True, location="json")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
ip_address = extract_remote_ip(request)
|
|
||||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
|
||||||
raise WebFormRateLimitExceededError()
|
|
||||||
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
|
|
||||||
|
|
||||||
service = HumanInputService(db.engine)
|
|
||||||
form = service.get_form_by_token(form_token)
|
|
||||||
if form is None:
|
|
||||||
raise NotFoundError("Form not found")
|
|
||||||
|
|
||||||
if (recipient_type := form.recipient_type) is None:
|
|
||||||
logger.warning("Recipient type is None for form, form_id=%", form.id)
|
|
||||||
raise AssertionError("Recipient type is None")
|
|
||||||
|
|
||||||
try:
|
|
||||||
service.submit_form_by_token(
|
|
||||||
recipient_type=recipient_type,
|
|
||||||
form_token=form_token,
|
|
||||||
selected_action_id=args["action"],
|
|
||||||
form_data=args["inputs"],
|
|
||||||
submission_end_user_id=None,
|
|
||||||
# submission_end_user_id=_end_user.id,
|
|
||||||
)
|
|
||||||
except FormNotFoundError:
|
|
||||||
raise NotFoundError("Form not found")
|
|
||||||
|
|
||||||
return {}, 200
|
|
||||||
|
|
||||||
|
|
||||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
|
||||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
|
||||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
|
||||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
|
||||||
raise NotFoundError("Form not found")
|
|
||||||
|
|
||||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
|
||||||
if site is None:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
|
||||||
raise Forbidden()
|
|
||||||
|
|
||||||
return app_model, site
|
|
||||||
@@ -1,6 +1,4 @@
|
|||||||
from typing import cast
|
from flask_restx import fields, marshal_with
|
||||||
|
|
||||||
from flask_restx import fields, marshal, marshal_with
|
|
||||||
from werkzeug.exceptions import Forbidden
|
from werkzeug.exceptions import Forbidden
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
@@ -9,7 +7,7 @@ from controllers.web.wraps import WebApiResource
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.helper import AppIconUrlField
|
from libs.helper import AppIconUrlField
|
||||||
from models.account import TenantStatus
|
from models.account import TenantStatus
|
||||||
from models.model import App, Site
|
from models.model import Site
|
||||||
from services.feature_service import FeatureService
|
from services.feature_service import FeatureService
|
||||||
|
|
||||||
|
|
||||||
@@ -110,14 +108,3 @@ class AppSiteInfo:
|
|||||||
"remove_webapp_brand": remove_webapp_brand,
|
"remove_webapp_brand": remove_webapp_brand,
|
||||||
"replace_webapp_logo": replace_webapp_logo,
|
"replace_webapp_logo": replace_webapp_logo,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def serialize_site(site: Site) -> dict:
|
|
||||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
|
||||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
|
||||||
|
|
||||||
|
|
||||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
|
||||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
|
||||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
|
||||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
|
||||||
|
|||||||
@@ -1,112 +0,0 @@
|
|||||||
"""
|
|
||||||
Web App Workflow Resume APIs.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
from collections.abc import Generator
|
|
||||||
|
|
||||||
from flask import Response, request
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
|
||||||
|
|
||||||
from controllers.web import api
|
|
||||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
|
||||||
from controllers.web.wraps import WebApiResource
|
|
||||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
|
||||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
|
||||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
|
||||||
from core.app.apps.message_generator import MessageGenerator
|
|
||||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from models.enums import CreatorUserRole
|
|
||||||
from models.model import App, AppMode, EndUser
|
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
|
||||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowEventsApi(WebApiResource):
|
|
||||||
"""API for getting workflow execution events after resume."""
|
|
||||||
|
|
||||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
|
||||||
"""
|
|
||||||
Get workflow execution events stream after resume.
|
|
||||||
|
|
||||||
GET /api/workflow/<task_id>/events
|
|
||||||
|
|
||||||
Returns Server-Sent Events stream.
|
|
||||||
"""
|
|
||||||
workflow_run_id = task_id
|
|
||||||
session_maker = sessionmaker(db.engine)
|
|
||||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
|
||||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
|
||||||
tenant_id=app_model.tenant_id,
|
|
||||||
run_id=workflow_run_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if workflow_run is None:
|
|
||||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
|
||||||
|
|
||||||
if workflow_run.app_id != app_model.id:
|
|
||||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
|
||||||
|
|
||||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
|
||||||
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
|
|
||||||
|
|
||||||
if workflow_run.created_by != end_user.id:
|
|
||||||
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
|
|
||||||
|
|
||||||
if workflow_run.finished_at is not None:
|
|
||||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
|
||||||
task_id=workflow_run.id,
|
|
||||||
workflow_run=workflow_run,
|
|
||||||
creator_user=end_user,
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = response.model_dump(mode="json")
|
|
||||||
payload["event"] = response.event.value
|
|
||||||
|
|
||||||
def _generate_finished_events() -> Generator[str, None, None]:
|
|
||||||
yield f"data: {json.dumps(payload)}\n\n"
|
|
||||||
|
|
||||||
event_generator = _generate_finished_events
|
|
||||||
else:
|
|
||||||
app_mode = AppMode.value_of(app_model.mode)
|
|
||||||
msg_generator = MessageGenerator()
|
|
||||||
generator: BaseAppGenerator
|
|
||||||
if app_mode == AppMode.ADVANCED_CHAT:
|
|
||||||
generator = AdvancedChatAppGenerator()
|
|
||||||
elif app_mode == AppMode.WORKFLOW:
|
|
||||||
generator = WorkflowAppGenerator()
|
|
||||||
else:
|
|
||||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
|
||||||
|
|
||||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
|
||||||
|
|
||||||
def _generate_stream_events():
|
|
||||||
if include_state_snapshot:
|
|
||||||
return generator.convert_to_event_stream(
|
|
||||||
build_workflow_event_stream(
|
|
||||||
app_mode=app_mode,
|
|
||||||
workflow_run=workflow_run,
|
|
||||||
tenant_id=app_model.tenant_id,
|
|
||||||
app_id=app_model.id,
|
|
||||||
session_maker=session_maker,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return generator.convert_to_event_stream(
|
|
||||||
msg_generator.retrieve_events(app_mode, workflow_run.id),
|
|
||||||
)
|
|
||||||
|
|
||||||
event_generator = _generate_stream_events
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
event_generator(),
|
|
||||||
mimetype="text/event-stream",
|
|
||||||
headers={
|
|
||||||
"Cache-Control": "no-cache",
|
|
||||||
"Connection": "keep-alive",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Register the APIs
|
|
||||||
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")
|
|
||||||
@@ -4,8 +4,8 @@ import contextvars
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
from collections.abc import Generator, Mapping
|
||||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||||
|
|
||||||
from flask import Flask, current_app
|
from flask import Flask, current_app
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@@ -29,25 +29,21 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
|||||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
|
||||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||||
from core.repositories import DifyCoreRepositoryFactory
|
from core.repositories import DifyCoreRepositoryFactory
|
||||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
|
||||||
from core.workflow.repositories.draft_variable_repository import (
|
from core.workflow.repositories.draft_variable_repository import (
|
||||||
DraftVariableSaverFactory,
|
DraftVariableSaverFactory,
|
||||||
)
|
)
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
|
||||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from libs.flask_utils import preserve_flask_contexts
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.base import Base
|
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from services.conversation_service import ConversationService
|
from services.conversation_service import ConversationService
|
||||||
from services.workflow_draft_variable_service import (
|
from services.workflow_draft_variable_service import (
|
||||||
@@ -69,9 +65,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
args: Mapping[str, Any],
|
args: Mapping[str, Any],
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Mapping[str, Any]: ...
|
) -> Mapping[str, Any]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -80,11 +74,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
args: Mapping[str, Any],
|
args: Mapping,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Generator[Mapping | str, None, None]: ...
|
) -> Generator[Mapping | str, None, None]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -93,11 +85,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
args: Mapping[str, Any],
|
args: Mapping,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@@ -105,11 +95,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
app_model: App,
|
app_model: App,
|
||||||
workflow: Workflow,
|
workflow: Workflow,
|
||||||
user: Union[Account, EndUser],
|
user: Union[Account, EndUser],
|
||||||
args: Mapping[str, Any],
|
args: Mapping,
|
||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
workflow_run_id: str,
|
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
@@ -173,6 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
# always enable retriever resource in debugger mode
|
# always enable retriever resource in debugger mode
|
||||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||||
|
|
||||||
|
workflow_run_id = str(uuid.uuid4())
|
||||||
# init application generate entity
|
# init application generate entity
|
||||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||||
task_id=str(uuid.uuid4()),
|
task_id=str(uuid.uuid4()),
|
||||||
@@ -190,7 +179,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
invoke_from=invoke_from,
|
invoke_from=invoke_from,
|
||||||
extras=extras,
|
extras=extras,
|
||||||
trace_manager=trace_manager,
|
trace_manager=trace_manager,
|
||||||
workflow_run_id=str(workflow_run_id),
|
workflow_run_id=workflow_run_id,
|
||||||
)
|
)
|
||||||
contexts.plugin_tool_providers.set({})
|
contexts.plugin_tool_providers.set({})
|
||||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||||
@@ -227,38 +216,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
stream=streaming,
|
stream=streaming,
|
||||||
pause_state_config=pause_state_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def resume(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
app_model: App,
|
|
||||||
workflow: Workflow,
|
|
||||||
user: Union[Account, EndUser],
|
|
||||||
conversation: Conversation,
|
|
||||||
message: Message,
|
|
||||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Resume a paused advanced chat execution.
|
|
||||||
"""
|
|
||||||
return self._generate(
|
|
||||||
workflow=workflow,
|
|
||||||
user=user,
|
|
||||||
invoke_from=application_generate_entity.invoke_from,
|
|
||||||
application_generate_entity=application_generate_entity,
|
|
||||||
workflow_execution_repository=workflow_execution_repository,
|
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
|
||||||
conversation=conversation,
|
|
||||||
message=message,
|
|
||||||
stream=application_generate_entity.stream,
|
|
||||||
pause_state_config=pause_state_config,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def single_iteration_generate(
|
def single_iteration_generate(
|
||||||
@@ -439,12 +396,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
conversation: Conversation | None = None,
|
conversation: Conversation | None = None,
|
||||||
message: Message | None = None,
|
|
||||||
stream: bool = True,
|
stream: bool = True,
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
|
||||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
@@ -458,12 +411,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
:param conversation: conversation
|
:param conversation: conversation
|
||||||
:param stream: is stream
|
:param stream: is stream
|
||||||
"""
|
"""
|
||||||
is_first_conversation = conversation is None
|
is_first_conversation = False
|
||||||
|
if not conversation:
|
||||||
|
is_first_conversation = True
|
||||||
|
|
||||||
if conversation is not None and message is not None:
|
# init generate records
|
||||||
pass
|
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||||
else:
|
|
||||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
|
||||||
|
|
||||||
if is_first_conversation:
|
if is_first_conversation:
|
||||||
# update conversation features
|
# update conversation features
|
||||||
@@ -486,16 +439,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
message_id=message.id,
|
message_id=message.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
|
||||||
if pause_state_config is not None:
|
|
||||||
graph_layers.append(
|
|
||||||
PauseStatePersistenceLayer(
|
|
||||||
session_factory=pause_state_config.session_factory,
|
|
||||||
generate_entity=application_generate_entity,
|
|
||||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# new thread with request context and contextvars
|
# new thread with request context and contextvars
|
||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
@@ -511,25 +454,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
"variable_loader": variable_loader,
|
"variable_loader": variable_loader,
|
||||||
"workflow_execution_repository": workflow_execution_repository,
|
"workflow_execution_repository": workflow_execution_repository,
|
||||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||||
"graph_engine_layers": tuple(graph_layers),
|
|
||||||
"graph_runtime_state": graph_runtime_state,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
worker_thread.start()
|
worker_thread.start()
|
||||||
|
|
||||||
# release database connection, because the following new thread operations may take a long time
|
# release database connection, because the following new thread operations may take a long time
|
||||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
db.session.refresh(workflow)
|
||||||
workflow = _refresh_model(session, workflow)
|
db.session.refresh(message)
|
||||||
message = _refresh_model(session, message)
|
|
||||||
# workflow_ = session.get(Workflow, workflow.id)
|
|
||||||
# assert workflow_ is not None
|
|
||||||
# workflow = workflow_
|
|
||||||
# message_ = session.get(Message, message.id)
|
|
||||||
# assert message_ is not None
|
|
||||||
# message = message_
|
|
||||||
# db.session.refresh(workflow)
|
|
||||||
# db.session.refresh(message)
|
|
||||||
# db.session.refresh(user)
|
# db.session.refresh(user)
|
||||||
db.session.close()
|
db.session.close()
|
||||||
|
|
||||||
@@ -558,8 +490,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
variable_loader: VariableLoader,
|
variable_loader: VariableLoader,
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
@@ -617,8 +547,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
app=app,
|
app=app,
|
||||||
workflow_execution_repository=workflow_execution_repository,
|
workflow_execution_repository=workflow_execution_repository,
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
graph_engine_layers=graph_engine_layers,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -686,13 +614,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
|||||||
else:
|
else:
|
||||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
_T = TypeVar("_T", bound=Base)
|
|
||||||
|
|
||||||
|
|
||||||
def _refresh_model(session, model: _T) -> _T:
|
|
||||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
|
||||||
detach_model = session.get(type(model), model.id)
|
|
||||||
assert detach_model is not None
|
|
||||||
return detach_model
|
|
||||||
|
|||||||
@@ -66,7 +66,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
@@ -83,7 +82,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
self._app = app
|
self._app = app
|
||||||
self._workflow_execution_repository = workflow_execution_repository
|
self._workflow_execution_repository = workflow_execution_repository
|
||||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||||
self._resume_graph_runtime_state = graph_runtime_state
|
|
||||||
|
|
||||||
@trace_span(WorkflowAppRunnerHandler)
|
@trace_span(WorkflowAppRunnerHandler)
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -112,21 +110,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
|||||||
invoke_from = InvokeFrom.DEBUGGER
|
invoke_from = InvokeFrom.DEBUGGER
|
||||||
user_from = self._resolve_user_from(invoke_from)
|
user_from = self._resolve_user_from(invoke_from)
|
||||||
|
|
||||||
resume_state = self._resume_graph_runtime_state
|
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||||
|
|
||||||
if resume_state is not None:
|
|
||||||
graph_runtime_state = resume_state
|
|
||||||
variable_pool = graph_runtime_state.variable_pool
|
|
||||||
graph = self._init_graph(
|
|
||||||
graph_config=self._workflow.graph_dict,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
workflow_id=self._workflow.id,
|
|
||||||
tenant_id=self._workflow.tenant_id,
|
|
||||||
user_id=self.application_generate_entity.user_id,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
user_from=user_from,
|
|
||||||
)
|
|
||||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
|
||||||
# Handle single iteration or single loop run
|
# Handle single iteration or single loop run
|
||||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||||
workflow=self._workflow,
|
workflow=self._workflow,
|
||||||
|
|||||||
@@ -24,8 +24,6 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
QueueAnnotationReplyEvent,
|
QueueAnnotationReplyEvent,
|
||||||
QueueErrorEvent,
|
QueueErrorEvent,
|
||||||
QueueHumanInputFormFilledEvent,
|
|
||||||
QueueHumanInputFormTimeoutEvent,
|
|
||||||
QueueIterationCompletedEvent,
|
QueueIterationCompletedEvent,
|
||||||
QueueIterationNextEvent,
|
QueueIterationNextEvent,
|
||||||
QueueIterationStartEvent,
|
QueueIterationStartEvent,
|
||||||
@@ -44,7 +42,6 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
QueueWorkflowPartialSuccessEvent,
|
QueueWorkflowPartialSuccessEvent,
|
||||||
QueueWorkflowPausedEvent,
|
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
WorkflowQueueMessage,
|
WorkflowQueueMessage,
|
||||||
@@ -66,8 +63,6 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
|||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
|
||||||
from core.workflow.enums import WorkflowExecutionStatus
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
@@ -76,8 +71,7 @@ from core.workflow.system_variable import SystemVariable
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||||
from models.enums import CreatorUserRole, MessageStatus
|
from models.enums import CreatorUserRole
|
||||||
from models.execution_extra_content import HumanInputContent
|
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -134,7 +128,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self._task_state = WorkflowTaskState()
|
self._task_state = WorkflowTaskState()
|
||||||
self._seed_task_state_from_message(message)
|
|
||||||
self._message_cycle_manager = MessageCycleManager(
|
self._message_cycle_manager = MessageCycleManager(
|
||||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||||
)
|
)
|
||||||
@@ -142,7 +135,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
self._application_generate_entity = application_generate_entity
|
self._application_generate_entity = application_generate_entity
|
||||||
self._workflow_id = workflow.id
|
self._workflow_id = workflow.id
|
||||||
self._workflow_features_dict = workflow.features_dict
|
self._workflow_features_dict = workflow.features_dict
|
||||||
self._workflow_tenant_id = workflow.tenant_id
|
|
||||||
self._conversation_id = conversation.id
|
self._conversation_id = conversation.id
|
||||||
self._conversation_mode = conversation.mode
|
self._conversation_mode = conversation.mode
|
||||||
self._message_id = message.id
|
self._message_id = message.id
|
||||||
@@ -152,13 +144,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
self._workflow_run_id: str = ""
|
self._workflow_run_id: str = ""
|
||||||
self._draft_var_saver_factory = draft_var_saver_factory
|
self._draft_var_saver_factory = draft_var_saver_factory
|
||||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||||
self._message_saved_on_pause = False
|
|
||||||
self._seed_graph_runtime_state_from_queue_manager()
|
self._seed_graph_runtime_state_from_queue_manager()
|
||||||
|
|
||||||
def _seed_task_state_from_message(self, message: Message) -> None:
|
|
||||||
if message.status == MessageStatus.PAUSED and message.answer:
|
|
||||||
self._task_state.answer = message.answer
|
|
||||||
|
|
||||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||||
"""
|
"""
|
||||||
Process generate task pipeline.
|
Process generate task pipeline.
|
||||||
@@ -321,7 +308,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run_id=run_id,
|
workflow_run_id=run_id,
|
||||||
workflow_id=self._workflow_id,
|
workflow_id=self._workflow_id,
|
||||||
reason=event.reason,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
yield workflow_start_resp
|
yield workflow_start_resp
|
||||||
@@ -539,35 +525,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
)
|
)
|
||||||
|
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
|
|
||||||
def _handle_workflow_paused_event(
|
|
||||||
self,
|
|
||||||
event: QueueWorkflowPausedEvent,
|
|
||||||
**kwargs,
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle workflow paused events."""
|
|
||||||
validated_state = self._ensure_graph_runtime_initialized()
|
|
||||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
graph_runtime_state=validated_state,
|
|
||||||
)
|
|
||||||
for reason in event.reasons:
|
|
||||||
if isinstance(reason, HumanInputRequired):
|
|
||||||
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
|
|
||||||
yield from responses
|
|
||||||
resolved_state: GraphRuntimeState | None = None
|
|
||||||
try:
|
|
||||||
resolved_state = self._ensure_graph_runtime_initialized()
|
|
||||||
except ValueError:
|
|
||||||
resolved_state = None
|
|
||||||
|
|
||||||
with self._database_session() as session:
|
|
||||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
|
||||||
message = self._get_message(session=session)
|
|
||||||
if message is not None:
|
|
||||||
message.status = MessageStatus.PAUSED
|
|
||||||
self._message_saved_on_pause = True
|
|
||||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||||
|
|
||||||
def _handle_workflow_failed_event(
|
def _handle_workflow_failed_event(
|
||||||
@@ -657,10 +614,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save message unless it has already been persisted on pause.
|
# Save message
|
||||||
if not self._message_saved_on_pause:
|
with self._database_session() as session:
|
||||||
with self._database_session() as session:
|
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
|
||||||
|
|
||||||
yield self._message_end_to_stream_response()
|
yield self._message_end_to_stream_response()
|
||||||
|
|
||||||
@@ -686,65 +642,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
"""Handle message replace events."""
|
"""Handle message replace events."""
|
||||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
|
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
|
||||||
|
|
||||||
def _handle_human_input_form_filled_event(
|
|
||||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle human input form filled events."""
|
|
||||||
self._persist_human_input_extra_content(node_id=event.node_id)
|
|
||||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
|
||||||
event=event, task_id=self._application_generate_entity.task_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_human_input_form_timeout_event(
|
|
||||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle human input form timeout events."""
|
|
||||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
|
||||||
event=event, task_id=self._application_generate_entity.task_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
|
|
||||||
if not self._workflow_run_id or not self._message_id:
|
|
||||||
return
|
|
||||||
|
|
||||||
if form_id is None:
|
|
||||||
if node_id is None:
|
|
||||||
return
|
|
||||||
form_id = self._load_human_input_form_id(node_id=node_id)
|
|
||||||
if form_id is None:
|
|
||||||
logger.warning(
|
|
||||||
"HumanInput form not found for workflow run %s node %s",
|
|
||||||
self._workflow_run_id,
|
|
||||||
node_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
with self._database_session() as session:
|
|
||||||
exists_stmt = select(HumanInputContent).where(
|
|
||||||
HumanInputContent.workflow_run_id == self._workflow_run_id,
|
|
||||||
HumanInputContent.message_id == self._message_id,
|
|
||||||
HumanInputContent.form_id == form_id,
|
|
||||||
)
|
|
||||||
if session.scalar(exists_stmt) is not None:
|
|
||||||
return
|
|
||||||
|
|
||||||
content = HumanInputContent(
|
|
||||||
workflow_run_id=self._workflow_run_id,
|
|
||||||
message_id=self._message_id,
|
|
||||||
form_id=form_id,
|
|
||||||
)
|
|
||||||
session.add(content)
|
|
||||||
|
|
||||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
|
||||||
form_repository = HumanInputFormRepositoryImpl(
|
|
||||||
session_factory=db.engine,
|
|
||||||
tenant_id=self._workflow_tenant_id,
|
|
||||||
)
|
|
||||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
|
||||||
if form is None:
|
|
||||||
return None
|
|
||||||
return form.id
|
|
||||||
|
|
||||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||||
"""Handle agent log events."""
|
"""Handle agent log events."""
|
||||||
yield self._workflow_response_converter.handle_agent_log(
|
yield self._workflow_response_converter.handle_agent_log(
|
||||||
@@ -762,7 +659,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
|
||||||
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
|
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
|
||||||
# Node events
|
# Node events
|
||||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||||
@@ -784,8 +680,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
QueueMessageReplaceEvent: self._handle_message_replace_event,
|
QueueMessageReplaceEvent: self._handle_message_replace_event,
|
||||||
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
|
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
|
||||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
|
||||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _dispatch_event(
|
def _dispatch_event(
|
||||||
@@ -853,9 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
case QueueWorkflowFailedEvent():
|
case QueueWorkflowFailedEvent():
|
||||||
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
||||||
break
|
break
|
||||||
case QueueWorkflowPausedEvent():
|
|
||||||
yield from self._handle_workflow_paused_event(event)
|
|
||||||
break
|
|
||||||
|
|
||||||
case QueueStopEvent():
|
case QueueStopEvent():
|
||||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||||
@@ -881,11 +772,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
|
|
||||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||||
message = self._get_message(session=session)
|
message = self._get_message(session=session)
|
||||||
if message is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
if message.status == MessageStatus.PAUSED:
|
|
||||||
message.status = MessageStatus.NORMAL
|
|
||||||
|
|
||||||
# If there are assistant files, remove markdown image links from answer
|
# If there are assistant files, remove markdown image links from answer
|
||||||
answer_text = self._task_state.answer
|
answer_text = self._task_state.answer
|
||||||
|
|||||||
@@ -5,14 +5,9 @@ from dataclasses import dataclass
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, NewType, Union
|
from typing import Any, NewType, Union
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
QueueHumanInputFormFilledEvent,
|
|
||||||
QueueHumanInputFormTimeoutEvent,
|
|
||||||
QueueIterationCompletedEvent,
|
QueueIterationCompletedEvent,
|
||||||
QueueIterationNextEvent,
|
QueueIterationNextEvent,
|
||||||
QueueIterationStartEvent,
|
QueueIterationStartEvent,
|
||||||
@@ -24,13 +19,9 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueNodeRetryEvent,
|
QueueNodeRetryEvent,
|
||||||
QueueNodeStartedEvent,
|
QueueNodeStartedEvent,
|
||||||
QueueNodeSucceededEvent,
|
QueueNodeSucceededEvent,
|
||||||
QueueWorkflowPausedEvent,
|
|
||||||
)
|
)
|
||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
AgentLogStreamResponse,
|
AgentLogStreamResponse,
|
||||||
HumanInputFormFilledResponse,
|
|
||||||
HumanInputFormTimeoutResponse,
|
|
||||||
HumanInputRequiredResponse,
|
|
||||||
IterationNodeCompletedStreamResponse,
|
IterationNodeCompletedStreamResponse,
|
||||||
IterationNodeNextStreamResponse,
|
IterationNodeNextStreamResponse,
|
||||||
IterationNodeStartStreamResponse,
|
IterationNodeStartStreamResponse,
|
||||||
@@ -40,9 +31,7 @@ from core.app.entities.task_entities import (
|
|||||||
NodeFinishStreamResponse,
|
NodeFinishStreamResponse,
|
||||||
NodeRetryStreamResponse,
|
NodeRetryStreamResponse,
|
||||||
NodeStartStreamResponse,
|
NodeStartStreamResponse,
|
||||||
StreamResponse,
|
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowPauseStreamResponse,
|
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
)
|
)
|
||||||
from core.file import FILE_MODEL_IDENTITY, File
|
from core.file import FILE_MODEL_IDENTITY, File
|
||||||
@@ -51,8 +40,6 @@ from core.tools.entities.tool_entities import ToolProviderType
|
|||||||
from core.tools.tool_manager import ToolManager
|
from core.tools.tool_manager import ToolManager
|
||||||
from core.trigger.trigger_manager import TriggerManager
|
from core.trigger.trigger_manager import TriggerManager
|
||||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
|
||||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
|
||||||
from core.workflow.enums import (
|
from core.workflow.enums import (
|
||||||
NodeType,
|
NodeType,
|
||||||
SystemVariableKey,
|
SystemVariableKey,
|
||||||
@@ -64,11 +51,8 @@ from core.workflow.runtime import GraphRuntimeState
|
|||||||
from core.workflow.system_variable import SystemVariable
|
from core.workflow.system_variable import SystemVariable
|
||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||||
from extensions.ext_database import db
|
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models import Account, EndUser
|
from models import Account, EndUser
|
||||||
from models.human_input import HumanInputForm
|
|
||||||
from models.workflow import WorkflowRun
|
|
||||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||||
|
|
||||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||||
@@ -207,7 +191,6 @@ class WorkflowResponseConverter:
|
|||||||
task_id: str,
|
task_id: str,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
workflow_id: str,
|
workflow_id: str,
|
||||||
reason: WorkflowStartReason,
|
|
||||||
) -> WorkflowStartStreamResponse:
|
) -> WorkflowStartStreamResponse:
|
||||||
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
||||||
started_at = naive_utc_now()
|
started_at = naive_utc_now()
|
||||||
@@ -221,7 +204,6 @@ class WorkflowResponseConverter:
|
|||||||
workflow_id=workflow_id,
|
workflow_id=workflow_id,
|
||||||
inputs=self._workflow_inputs,
|
inputs=self._workflow_inputs,
|
||||||
created_at=int(started_at.timestamp()),
|
created_at=int(started_at.timestamp()),
|
||||||
reason=reason,
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -282,160 +264,6 @@ class WorkflowResponseConverter:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
def workflow_pause_to_stream_response(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
event: QueueWorkflowPausedEvent,
|
|
||||||
task_id: str,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
) -> list[StreamResponse]:
|
|
||||||
run_id = self._ensure_workflow_run_id()
|
|
||||||
started_at = self._workflow_started_at
|
|
||||||
if started_at is None:
|
|
||||||
raise ValueError(
|
|
||||||
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
|
|
||||||
)
|
|
||||||
paused_at = naive_utc_now()
|
|
||||||
elapsed_time = (paused_at - started_at).total_seconds()
|
|
||||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
|
||||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
|
||||||
encoded_outputs = {}
|
|
||||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
|
||||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
|
||||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
|
||||||
if human_input_form_ids:
|
|
||||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
|
|
||||||
HumanInputForm.id.in_(human_input_form_ids)
|
|
||||||
)
|
|
||||||
with Session(bind=db.engine) as session:
|
|
||||||
for form_id, expiration_time in session.execute(stmt):
|
|
||||||
expiration_times_by_form_id[str(form_id)] = expiration_time
|
|
||||||
|
|
||||||
responses: list[StreamResponse] = []
|
|
||||||
|
|
||||||
for reason in event.reasons:
|
|
||||||
if isinstance(reason, HumanInputRequired):
|
|
||||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
|
||||||
if expiration_time is None:
|
|
||||||
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
|
|
||||||
responses.append(
|
|
||||||
HumanInputRequiredResponse(
|
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=run_id,
|
|
||||||
data=HumanInputRequiredResponse.Data(
|
|
||||||
form_id=reason.form_id,
|
|
||||||
node_id=reason.node_id,
|
|
||||||
node_title=reason.node_title,
|
|
||||||
form_content=reason.form_content,
|
|
||||||
inputs=reason.inputs,
|
|
||||||
actions=reason.actions,
|
|
||||||
display_in_ui=reason.display_in_ui,
|
|
||||||
form_token=reason.form_token,
|
|
||||||
resolved_default_values=reason.resolved_default_values,
|
|
||||||
expiration_time=int(expiration_time.timestamp()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
responses.append(
|
|
||||||
WorkflowPauseStreamResponse(
|
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=run_id,
|
|
||||||
data=WorkflowPauseStreamResponse.Data(
|
|
||||||
workflow_run_id=run_id,
|
|
||||||
paused_nodes=list(event.paused_nodes),
|
|
||||||
outputs=encoded_outputs,
|
|
||||||
reasons=pause_reasons,
|
|
||||||
status=WorkflowExecutionStatus.PAUSED.value,
|
|
||||||
created_at=int(started_at.timestamp()),
|
|
||||||
elapsed_time=elapsed_time,
|
|
||||||
total_tokens=graph_runtime_state.total_tokens,
|
|
||||||
total_steps=graph_runtime_state.node_run_steps,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return responses
|
|
||||||
|
|
||||||
def human_input_form_filled_to_stream_response(
|
|
||||||
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
|
|
||||||
) -> HumanInputFormFilledResponse:
|
|
||||||
run_id = self._ensure_workflow_run_id()
|
|
||||||
return HumanInputFormFilledResponse(
|
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=run_id,
|
|
||||||
data=HumanInputFormFilledResponse.Data(
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_title=event.node_title,
|
|
||||||
rendered_content=event.rendered_content,
|
|
||||||
action_id=event.action_id,
|
|
||||||
action_text=event.action_text,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def human_input_form_timeout_to_stream_response(
|
|
||||||
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
|
|
||||||
) -> HumanInputFormTimeoutResponse:
|
|
||||||
run_id = self._ensure_workflow_run_id()
|
|
||||||
return HumanInputFormTimeoutResponse(
|
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=run_id,
|
|
||||||
data=HumanInputFormTimeoutResponse.Data(
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_title=event.node_title,
|
|
||||||
expiration_time=int(event.expiration_time.timestamp()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def workflow_run_result_to_finish_response(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
task_id: str,
|
|
||||||
workflow_run: WorkflowRun,
|
|
||||||
creator_user: Account | EndUser,
|
|
||||||
) -> WorkflowFinishStreamResponse:
|
|
||||||
run_id = workflow_run.id
|
|
||||||
elapsed_time = workflow_run.elapsed_time
|
|
||||||
|
|
||||||
encoded_outputs = workflow_run.outputs_dict
|
|
||||||
finished_at = workflow_run.finished_at
|
|
||||||
assert finished_at is not None
|
|
||||||
|
|
||||||
created_by: Mapping[str, object]
|
|
||||||
user = creator_user
|
|
||||||
if isinstance(user, Account):
|
|
||||||
created_by = {
|
|
||||||
"id": user.id,
|
|
||||||
"name": user.name,
|
|
||||||
"email": user.email,
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
created_by = {
|
|
||||||
"id": user.id,
|
|
||||||
"user": user.session_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
return WorkflowFinishStreamResponse(
|
|
||||||
task_id=task_id,
|
|
||||||
workflow_run_id=run_id,
|
|
||||||
data=WorkflowFinishStreamResponse.Data(
|
|
||||||
id=run_id,
|
|
||||||
workflow_id=workflow_run.workflow_id,
|
|
||||||
status=workflow_run.status.value,
|
|
||||||
outputs=encoded_outputs,
|
|
||||||
error=workflow_run.error,
|
|
||||||
elapsed_time=elapsed_time,
|
|
||||||
total_tokens=workflow_run.total_tokens,
|
|
||||||
total_steps=workflow_run.total_steps,
|
|
||||||
created_by=created_by,
|
|
||||||
created_at=int(workflow_run.created_at.timestamp()),
|
|
||||||
finished_at=int(finished_at.timestamp()),
|
|
||||||
files=cls.fetch_files_from_node_outputs(encoded_outputs),
|
|
||||||
exceptions_count=workflow_run.exceptions_count,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
def workflow_node_start_to_stream_response(
|
def workflow_node_start_to_stream_response(
|
||||||
self,
|
self,
|
||||||
*,
|
*,
|
||||||
@@ -764,8 +592,7 @@ class WorkflowResponseConverter:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||||
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
|
||||||
"""
|
"""
|
||||||
Fetch files from node outputs
|
Fetch files from node outputs
|
||||||
:param outputs_dict: node outputs dict
|
:param outputs_dict: node outputs dict
|
||||||
@@ -774,7 +601,7 @@ class WorkflowResponseConverter:
|
|||||||
if not outputs_dict:
|
if not outputs_dict:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||||
# Remove None
|
# Remove None
|
||||||
files = [file for file in files if file]
|
files = [file for file in files if file]
|
||||||
# Flatten list
|
# Flatten list
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections.abc import Callable, Generator, Mapping
|
from collections.abc import Generator
|
||||||
from typing import Union, cast
|
from typing import Union, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
@@ -10,14 +10,12 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
|||||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||||
from core.app.apps.exc import GenerateTaskStoppedError
|
from core.app.apps.exc import GenerateTaskStoppedError
|
||||||
from core.app.apps.streaming_utils import stream_topic_events
|
|
||||||
from core.app.entities.app_invoke_entities import (
|
from core.app.entities.app_invoke_entities import (
|
||||||
AdvancedChatAppGenerateEntity,
|
AdvancedChatAppGenerateEntity,
|
||||||
AgentChatAppGenerateEntity,
|
AgentChatAppGenerateEntity,
|
||||||
AppGenerateEntity,
|
AppGenerateEntity,
|
||||||
ChatAppGenerateEntity,
|
ChatAppGenerateEntity,
|
||||||
CompletionAppGenerateEntity,
|
CompletionAppGenerateEntity,
|
||||||
ConversationAppGenerateEntity,
|
|
||||||
InvokeFrom,
|
InvokeFrom,
|
||||||
)
|
)
|
||||||
from core.app.entities.task_entities import (
|
from core.app.entities.task_entities import (
|
||||||
@@ -29,8 +27,6 @@ from core.app.entities.task_entities import (
|
|||||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
|
||||||
from libs.broadcast_channel.channel import Topic
|
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models import Account
|
from models import Account
|
||||||
from models.enums import CreatorUserRole
|
from models.enums import CreatorUserRole
|
||||||
@@ -160,7 +156,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
query = application_generate_entity.query or "New conversation"
|
query = application_generate_entity.query or "New conversation"
|
||||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||||
|
|
||||||
created_new_conversation = conversation is None
|
|
||||||
try:
|
try:
|
||||||
if not conversation:
|
if not conversation:
|
||||||
conversation = Conversation(
|
conversation = Conversation(
|
||||||
@@ -237,10 +232,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
db.session.add_all(message_files)
|
db.session.add_all(message_files)
|
||||||
|
|
||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
|
|
||||||
application_generate_entity.conversation_id = conversation.id
|
|
||||||
application_generate_entity.is_new_conversation = created_new_conversation
|
|
||||||
return conversation, message
|
return conversation, message
|
||||||
except Exception:
|
except Exception:
|
||||||
db.session.rollback()
|
db.session.rollback()
|
||||||
@@ -293,29 +284,3 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
|||||||
raise MessageNotExistsError("Message not exists")
|
raise MessageNotExistsError("Message not exists")
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
|
||||||
return f"channel:{app_mode}:{workflow_run_id}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
|
||||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
|
||||||
channel = get_pubsub_broadcast_channel()
|
|
||||||
topic = channel.topic(key)
|
|
||||||
return topic
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def retrieve_events(
|
|
||||||
cls,
|
|
||||||
app_mode: AppMode,
|
|
||||||
workflow_run_id: str,
|
|
||||||
idle_timeout=300,
|
|
||||||
on_subscribe: Callable[[], None] | None = None,
|
|
||||||
) -> Generator[Mapping | str, None, None]:
|
|
||||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
|
||||||
return stream_topic_events(
|
|
||||||
topic=topic,
|
|
||||||
idle_timeout=idle_timeout,
|
|
||||||
on_subscribe=on_subscribe,
|
|
||||||
)
|
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
from collections.abc import Callable, Generator, Mapping
|
|
||||||
|
|
||||||
from core.app.apps.streaming_utils import stream_topic_events
|
|
||||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
|
||||||
from libs.broadcast_channel.channel import Topic
|
|
||||||
from models.model import AppMode
|
|
||||||
|
|
||||||
|
|
||||||
class MessageGenerator:
|
|
||||||
@staticmethod
|
|
||||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
|
||||||
return f"channel:{app_mode}:{str(workflow_run_id)}"
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
|
||||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
|
||||||
channel = get_pubsub_broadcast_channel()
|
|
||||||
topic = channel.topic(key)
|
|
||||||
return topic
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def retrieve_events(
|
|
||||||
cls,
|
|
||||||
app_mode: AppMode,
|
|
||||||
workflow_run_id: str,
|
|
||||||
idle_timeout=300,
|
|
||||||
ping_interval: float = 10.0,
|
|
||||||
on_subscribe: Callable[[], None] | None = None,
|
|
||||||
) -> Generator[Mapping | str, None, None]:
|
|
||||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
|
||||||
return stream_topic_events(
|
|
||||||
topic=topic,
|
|
||||||
idle_timeout=idle_timeout,
|
|
||||||
ping_interval=ping_interval,
|
|
||||||
on_subscribe=on_subscribe,
|
|
||||||
)
|
|
||||||
@@ -1,70 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from collections.abc import Callable, Generator, Iterable, Mapping
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from core.app.entities.task_entities import StreamEvent
|
|
||||||
from libs.broadcast_channel.channel import Topic
|
|
||||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
|
||||||
|
|
||||||
|
|
||||||
def stream_topic_events(
|
|
||||||
*,
|
|
||||||
topic: Topic,
|
|
||||||
idle_timeout: float,
|
|
||||||
ping_interval: float | None = None,
|
|
||||||
on_subscribe: Callable[[], None] | None = None,
|
|
||||||
terminal_events: Iterable[str | StreamEvent] | None = None,
|
|
||||||
) -> Generator[Mapping[str, Any] | str, None, None]:
|
|
||||||
# send a PING event immediately to prevent the connection staying in pending state for a long time.
|
|
||||||
#
|
|
||||||
# This simplify the debugging process as the DevTools in Chrome does not
|
|
||||||
# provide complete curl command for pending connections.
|
|
||||||
yield StreamEvent.PING.value
|
|
||||||
|
|
||||||
terminal_values = _normalize_terminal_events(terminal_events)
|
|
||||||
last_msg_time = time.time()
|
|
||||||
last_ping_time = last_msg_time
|
|
||||||
with topic.subscribe() as sub:
|
|
||||||
# on_subscribe fires only after the Redis subscription is active.
|
|
||||||
# This is used to gate task start and reduce pub/sub race for the first event.
|
|
||||||
if on_subscribe is not None:
|
|
||||||
on_subscribe()
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
msg = sub.receive(timeout=0.1)
|
|
||||||
except SubscriptionClosedError:
|
|
||||||
return
|
|
||||||
if msg is None:
|
|
||||||
current_time = time.time()
|
|
||||||
if current_time - last_msg_time > idle_timeout:
|
|
||||||
return
|
|
||||||
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
|
|
||||||
yield StreamEvent.PING.value
|
|
||||||
last_ping_time = current_time
|
|
||||||
continue
|
|
||||||
|
|
||||||
last_msg_time = time.time()
|
|
||||||
last_ping_time = last_msg_time
|
|
||||||
event = json.loads(msg)
|
|
||||||
yield event
|
|
||||||
if not isinstance(event, dict):
|
|
||||||
continue
|
|
||||||
|
|
||||||
event_type = event.get("event")
|
|
||||||
if event_type in terminal_values:
|
|
||||||
return
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
|
||||||
if not terminal_events:
|
|
||||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
|
||||||
values: set[str] = set()
|
|
||||||
for item in terminal_events:
|
|
||||||
if isinstance(item, StreamEvent):
|
|
||||||
values.add(item.value)
|
|
||||||
else:
|
|
||||||
values.add(str(item))
|
|
||||||
return values
|
|
||||||
@@ -25,7 +25,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
|||||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
|
||||||
from core.db.session_factory import session_factory
|
from core.db.session_factory import session_factory
|
||||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||||
@@ -35,15 +34,12 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
|||||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
|
||||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories import file_factory
|
from factories import file_factory
|
||||||
from libs.flask_utils import preserve_flask_contexts
|
from libs.flask_utils import preserve_flask_contexts
|
||||||
from models.account import Account
|
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.model import App, EndUser
|
|
||||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
|
||||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -70,11 +66,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[True],
|
streaming: Literal[True],
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_run_id: str | uuid.UUID | None = None,
|
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -88,11 +82,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: Literal[False],
|
streaming: Literal[False],
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_run_id: str | uuid.UUID | None = None,
|
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Mapping[str, Any]: ...
|
) -> Mapping[str, Any]: ...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
@@ -106,11 +98,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool,
|
streaming: bool,
|
||||||
call_depth: int,
|
call_depth: int,
|
||||||
workflow_run_id: str | uuid.UUID | None = None,
|
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||||
|
|
||||||
def generate(
|
def generate(
|
||||||
@@ -123,11 +113,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
invoke_from: InvokeFrom,
|
invoke_from: InvokeFrom,
|
||||||
streaming: bool = True,
|
streaming: bool = True,
|
||||||
call_depth: int = 0,
|
call_depth: int = 0,
|
||||||
workflow_run_id: str | uuid.UUID | None = None,
|
|
||||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||||
|
|
||||||
@@ -162,7 +150,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
extras = {
|
extras = {
|
||||||
**extract_external_trace_id_from_args(args),
|
**extract_external_trace_id_from_args(args),
|
||||||
}
|
}
|
||||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
workflow_run_id = str(uuid.uuid4())
|
||||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||||
# trigger shouldn't prepare user inputs
|
# trigger shouldn't prepare user inputs
|
||||||
if self._should_prepare_user_inputs(args):
|
if self._should_prepare_user_inputs(args):
|
||||||
@@ -228,40 +216,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
root_node_id=root_node_id,
|
root_node_id=root_node_id,
|
||||||
graph_engine_layers=graph_engine_layers,
|
graph_engine_layers=graph_engine_layers,
|
||||||
pause_state_config=pause_state_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def resume(
|
def resume(self, *, workflow_run_id: str) -> None:
|
||||||
self,
|
|
||||||
*,
|
|
||||||
app_model: App,
|
|
||||||
workflow: Workflow,
|
|
||||||
user: Union[Account, EndUser],
|
|
||||||
application_generate_entity: WorkflowAppGenerateEntity,
|
|
||||||
graph_runtime_state: GraphRuntimeState,
|
|
||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
|
||||||
"""
|
"""
|
||||||
Resume a paused workflow execution using the persisted runtime state.
|
@TBD
|
||||||
"""
|
"""
|
||||||
return self._generate(
|
pass
|
||||||
app_model=app_model,
|
|
||||||
workflow=workflow,
|
|
||||||
user=user,
|
|
||||||
application_generate_entity=application_generate_entity,
|
|
||||||
invoke_from=application_generate_entity.invoke_from,
|
|
||||||
workflow_execution_repository=workflow_execution_repository,
|
|
||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
|
||||||
streaming=application_generate_entity.stream,
|
|
||||||
variable_loader=variable_loader,
|
|
||||||
graph_engine_layers=graph_engine_layers,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
pause_state_config=pause_state_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
self,
|
self,
|
||||||
@@ -277,8 +238,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
|
||||||
pause_state_config: PauseStateLayerConfig | None = None,
|
|
||||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||||
"""
|
"""
|
||||||
Generate App response.
|
Generate App response.
|
||||||
@@ -292,8 +251,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
:param workflow_node_execution_repository: repository for workflow node execution
|
:param workflow_node_execution_repository: repository for workflow node execution
|
||||||
:param streaming: is stream
|
:param streaming: is stream
|
||||||
"""
|
"""
|
||||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
|
||||||
|
|
||||||
# init queue manager
|
# init queue manager
|
||||||
queue_manager = WorkflowAppQueueManager(
|
queue_manager = WorkflowAppQueueManager(
|
||||||
task_id=application_generate_entity.task_id,
|
task_id=application_generate_entity.task_id,
|
||||||
@@ -302,15 +259,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
app_mode=app_model.mode,
|
app_mode=app_model.mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
if pause_state_config is not None:
|
|
||||||
graph_layers.append(
|
|
||||||
PauseStatePersistenceLayer(
|
|
||||||
session_factory=pause_state_config.session_factory,
|
|
||||||
generate_entity=application_generate_entity,
|
|
||||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# new thread with request context and contextvars
|
# new thread with request context and contextvars
|
||||||
context = contextvars.copy_context()
|
context = contextvars.copy_context()
|
||||||
|
|
||||||
@@ -328,8 +276,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
"root_node_id": root_node_id,
|
"root_node_id": root_node_id,
|
||||||
"workflow_execution_repository": workflow_execution_repository,
|
"workflow_execution_repository": workflow_execution_repository,
|
||||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||||
"graph_engine_layers": tuple(graph_layers),
|
"graph_engine_layers": graph_engine_layers,
|
||||||
"graph_runtime_state": graph_runtime_state,
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -431,7 +378,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
variable_loader=var_loader,
|
variable_loader=var_loader,
|
||||||
pause_state_config=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def single_loop_generate(
|
def single_loop_generate(
|
||||||
@@ -513,7 +459,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
streaming=streaming,
|
streaming=streaming,
|
||||||
variable_loader=var_loader,
|
variable_loader=var_loader,
|
||||||
pause_state_config=None,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _generate_worker(
|
def _generate_worker(
|
||||||
@@ -527,7 +472,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
root_node_id: str | None = None,
|
root_node_id: str | None = None,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Generate worker in a new thread.
|
Generate worker in a new thread.
|
||||||
@@ -573,7 +517,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
|||||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||||
root_node_id=root_node_id,
|
root_node_id=root_node_id,
|
||||||
graph_engine_layers=graph_engine_layers,
|
graph_engine_layers=graph_engine_layers,
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
workflow_execution_repository: WorkflowExecutionRepository,
|
workflow_execution_repository: WorkflowExecutionRepository,
|
||||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||||
graph_runtime_state: GraphRuntimeState | None = None,
|
|
||||||
):
|
):
|
||||||
super().__init__(
|
super().__init__(
|
||||||
queue_manager=queue_manager,
|
queue_manager=queue_manager,
|
||||||
@@ -56,7 +55,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
self._root_node_id = root_node_id
|
self._root_node_id = root_node_id
|
||||||
self._workflow_execution_repository = workflow_execution_repository
|
self._workflow_execution_repository = workflow_execution_repository
|
||||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||||
self._resume_graph_runtime_state = graph_runtime_state
|
|
||||||
|
|
||||||
@trace_span(WorkflowAppRunnerHandler)
|
@trace_span(WorkflowAppRunnerHandler)
|
||||||
def run(self):
|
def run(self):
|
||||||
@@ -65,28 +63,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
"""
|
"""
|
||||||
app_config = self.application_generate_entity.app_config
|
app_config = self.application_generate_entity.app_config
|
||||||
app_config = cast(WorkflowAppConfig, app_config)
|
app_config = cast(WorkflowAppConfig, app_config)
|
||||||
|
|
||||||
|
system_inputs = SystemVariable(
|
||||||
|
files=self.application_generate_entity.files,
|
||||||
|
user_id=self._sys_user_id,
|
||||||
|
app_id=app_config.app_id,
|
||||||
|
timestamp=int(naive_utc_now().timestamp()),
|
||||||
|
workflow_id=app_config.workflow_id,
|
||||||
|
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||||
|
)
|
||||||
|
|
||||||
invoke_from = self.application_generate_entity.invoke_from
|
invoke_from = self.application_generate_entity.invoke_from
|
||||||
# if only single iteration or single loop run is requested
|
# if only single iteration or single loop run is requested
|
||||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||||
invoke_from = InvokeFrom.DEBUGGER
|
invoke_from = InvokeFrom.DEBUGGER
|
||||||
user_from = self._resolve_user_from(invoke_from)
|
user_from = self._resolve_user_from(invoke_from)
|
||||||
|
|
||||||
resume_state = self._resume_graph_runtime_state
|
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||||
|
|
||||||
if resume_state is not None:
|
|
||||||
graph_runtime_state = resume_state
|
|
||||||
variable_pool = graph_runtime_state.variable_pool
|
|
||||||
graph = self._init_graph(
|
|
||||||
graph_config=self._workflow.graph_dict,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
workflow_id=self._workflow.id,
|
|
||||||
tenant_id=self._workflow.tenant_id,
|
|
||||||
user_id=self.application_generate_entity.user_id,
|
|
||||||
user_from=user_from,
|
|
||||||
invoke_from=invoke_from,
|
|
||||||
root_node_id=self._root_node_id,
|
|
||||||
)
|
|
||||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
|
||||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||||
workflow=self._workflow,
|
workflow=self._workflow,
|
||||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||||
@@ -96,14 +89,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
inputs = self.application_generate_entity.inputs
|
inputs = self.application_generate_entity.inputs
|
||||||
|
|
||||||
# Create a variable pool.
|
# Create a variable pool.
|
||||||
system_inputs = SystemVariable(
|
|
||||||
files=self.application_generate_entity.files,
|
|
||||||
user_id=self._sys_user_id,
|
|
||||||
app_id=app_config.app_id,
|
|
||||||
timestamp=int(naive_utc_now().timestamp()),
|
|
||||||
workflow_id=app_config.workflow_id,
|
|
||||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
|
||||||
)
|
|
||||||
variable_pool = VariablePool(
|
variable_pool = VariablePool(
|
||||||
system_variables=system_inputs,
|
system_variables=system_inputs,
|
||||||
user_inputs=inputs,
|
user_inputs=inputs,
|
||||||
@@ -112,6 +98,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
|||||||
)
|
)
|
||||||
|
|
||||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||||
|
|
||||||
|
# init graph
|
||||||
graph = self._init_graph(
|
graph = self._init_graph(
|
||||||
graph_config=self._workflow.graph_dict,
|
graph_config=self._workflow.graph_dict,
|
||||||
graph_runtime_state=graph_runtime_state,
|
graph_runtime_state=graph_runtime_state,
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
from libs.exception import BaseHTTPException
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowPausedInBlockingModeError(BaseHTTPException):
|
|
||||||
error_code = "workflow_paused_in_blocking_mode"
|
|
||||||
description = "Workflow execution paused for human input; blocking response mode is not supported."
|
|
||||||
code = 400
|
|
||||||
@@ -16,8 +16,6 @@ from core.app.entities.queue_entities import (
|
|||||||
MessageQueueMessage,
|
MessageQueueMessage,
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
QueueErrorEvent,
|
QueueErrorEvent,
|
||||||
QueueHumanInputFormFilledEvent,
|
|
||||||
QueueHumanInputFormTimeoutEvent,
|
|
||||||
QueueIterationCompletedEvent,
|
QueueIterationCompletedEvent,
|
||||||
QueueIterationNextEvent,
|
QueueIterationNextEvent,
|
||||||
QueueIterationStartEvent,
|
QueueIterationStartEvent,
|
||||||
@@ -34,7 +32,6 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
QueueWorkflowPartialSuccessEvent,
|
QueueWorkflowPartialSuccessEvent,
|
||||||
QueueWorkflowPausedEvent,
|
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
WorkflowQueueMessage,
|
WorkflowQueueMessage,
|
||||||
@@ -49,13 +46,11 @@ from core.app.entities.task_entities import (
|
|||||||
WorkflowAppBlockingResponse,
|
WorkflowAppBlockingResponse,
|
||||||
WorkflowAppStreamResponse,
|
WorkflowAppStreamResponse,
|
||||||
WorkflowFinishStreamResponse,
|
WorkflowFinishStreamResponse,
|
||||||
WorkflowPauseStreamResponse,
|
|
||||||
WorkflowStartStreamResponse,
|
WorkflowStartStreamResponse,
|
||||||
)
|
)
|
||||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||||
from core.ops.ops_trace_manager import TraceQueueManager
|
from core.ops.ops_trace_manager import TraceQueueManager
|
||||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
|
||||||
from core.workflow.enums import WorkflowExecutionStatus
|
from core.workflow.enums import WorkflowExecutionStatus
|
||||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||||
from core.workflow.runtime import GraphRuntimeState
|
from core.workflow.runtime import GraphRuntimeState
|
||||||
@@ -137,25 +132,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
for stream_response in generator:
|
for stream_response in generator:
|
||||||
if isinstance(stream_response, ErrorStreamResponse):
|
if isinstance(stream_response, ErrorStreamResponse):
|
||||||
raise stream_response.err
|
raise stream_response.err
|
||||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
|
||||||
response = WorkflowAppBlockingResponse(
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
workflow_run_id=stream_response.data.workflow_run_id,
|
|
||||||
data=WorkflowAppBlockingResponse.Data(
|
|
||||||
id=stream_response.data.workflow_run_id,
|
|
||||||
workflow_id=self._workflow.id,
|
|
||||||
status=stream_response.data.status,
|
|
||||||
outputs=stream_response.data.outputs or {},
|
|
||||||
error=None,
|
|
||||||
elapsed_time=stream_response.data.elapsed_time,
|
|
||||||
total_tokens=stream_response.data.total_tokens,
|
|
||||||
total_steps=stream_response.data.total_steps,
|
|
||||||
created_at=stream_response.data.created_at,
|
|
||||||
finished_at=None,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||||
response = WorkflowAppBlockingResponse(
|
response = WorkflowAppBlockingResponse(
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
@@ -170,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
total_tokens=stream_response.data.total_tokens,
|
total_tokens=stream_response.data.total_tokens,
|
||||||
total_steps=stream_response.data.total_steps,
|
total_steps=stream_response.data.total_steps,
|
||||||
created_at=int(stream_response.data.created_at),
|
created_at=int(stream_response.data.created_at),
|
||||||
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
|
finished_at=int(stream_response.data.finished_at),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -283,15 +259,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
run_id = self._extract_workflow_run_id(runtime_state)
|
run_id = self._extract_workflow_run_id(runtime_state)
|
||||||
self._workflow_execution_id = run_id
|
self._workflow_execution_id = run_id
|
||||||
|
|
||||||
if event.reason == WorkflowStartReason.INITIAL:
|
with self._database_session() as session:
|
||||||
with self._database_session() as session:
|
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
|
||||||
|
|
||||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||||
task_id=self._application_generate_entity.task_id,
|
task_id=self._application_generate_entity.task_id,
|
||||||
workflow_run_id=run_id,
|
workflow_run_id=run_id,
|
||||||
workflow_id=self._workflow.id,
|
workflow_id=self._workflow.id,
|
||||||
reason=event.reason,
|
|
||||||
)
|
)
|
||||||
yield start_resp
|
yield start_resp
|
||||||
|
|
||||||
@@ -466,21 +440,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
)
|
)
|
||||||
yield workflow_finish_resp
|
yield workflow_finish_resp
|
||||||
|
|
||||||
def _handle_workflow_paused_event(
|
|
||||||
self,
|
|
||||||
event: QueueWorkflowPausedEvent,
|
|
||||||
**kwargs,
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle workflow paused events."""
|
|
||||||
self._ensure_workflow_initialized()
|
|
||||||
validated_state = self._ensure_graph_runtime_initialized()
|
|
||||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
|
||||||
event=event,
|
|
||||||
task_id=self._application_generate_entity.task_id,
|
|
||||||
graph_runtime_state=validated_state,
|
|
||||||
)
|
|
||||||
yield from responses
|
|
||||||
|
|
||||||
def _handle_workflow_failed_and_stop_events(
|
def _handle_workflow_failed_and_stop_events(
|
||||||
self,
|
self,
|
||||||
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
||||||
@@ -536,22 +495,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
task_id=self._application_generate_entity.task_id, event=event
|
task_id=self._application_generate_entity.task_id, event=event
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_human_input_form_filled_event(
|
|
||||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle human input form filled events."""
|
|
||||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
|
||||||
event=event, task_id=self._application_generate_entity.task_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _handle_human_input_form_timeout_event(
|
|
||||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
|
||||||
) -> Generator[StreamResponse, None, None]:
|
|
||||||
"""Handle human input form timeout events."""
|
|
||||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
|
||||||
event=event, task_id=self._application_generate_entity.task_id
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_event_handlers(self) -> dict[type, Callable]:
|
def _get_event_handlers(self) -> dict[type, Callable]:
|
||||||
"""Get mapping of event types to their handlers using fluent pattern."""
|
"""Get mapping of event types to their handlers using fluent pattern."""
|
||||||
return {
|
return {
|
||||||
@@ -563,7 +506,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
|
||||||
# Node events
|
# Node events
|
||||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||||
@@ -578,8 +520,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
QueueLoopCompletedEvent: self._handle_loop_completed_event,
|
QueueLoopCompletedEvent: self._handle_loop_completed_event,
|
||||||
# Agent events
|
# Agent events
|
||||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
|
||||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _dispatch_event(
|
def _dispatch_event(
|
||||||
@@ -662,9 +602,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
|||||||
case QueueWorkflowFailedEvent():
|
case QueueWorkflowFailedEvent():
|
||||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||||
break
|
break
|
||||||
case QueueWorkflowPausedEvent():
|
|
||||||
yield from self._handle_workflow_paused_event(event)
|
|
||||||
break
|
|
||||||
|
|
||||||
case QueueStopEvent():
|
case QueueStopEvent():
|
||||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import logging
|
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
@@ -8,8 +7,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
|||||||
from core.app.entities.queue_entities import (
|
from core.app.entities.queue_entities import (
|
||||||
AppQueueEvent,
|
AppQueueEvent,
|
||||||
QueueAgentLogEvent,
|
QueueAgentLogEvent,
|
||||||
QueueHumanInputFormFilledEvent,
|
|
||||||
QueueHumanInputFormTimeoutEvent,
|
|
||||||
QueueIterationCompletedEvent,
|
QueueIterationCompletedEvent,
|
||||||
QueueIterationNextEvent,
|
QueueIterationNextEvent,
|
||||||
QueueIterationStartEvent,
|
QueueIterationStartEvent,
|
||||||
@@ -25,27 +22,22 @@ from core.app.entities.queue_entities import (
|
|||||||
QueueTextChunkEvent,
|
QueueTextChunkEvent,
|
||||||
QueueWorkflowFailedEvent,
|
QueueWorkflowFailedEvent,
|
||||||
QueueWorkflowPartialSuccessEvent,
|
QueueWorkflowPartialSuccessEvent,
|
||||||
QueueWorkflowPausedEvent,
|
|
||||||
QueueWorkflowStartedEvent,
|
QueueWorkflowStartedEvent,
|
||||||
QueueWorkflowSucceededEvent,
|
QueueWorkflowSucceededEvent,
|
||||||
)
|
)
|
||||||
from core.app.workflow.node_factory import DifyNodeFactory
|
from core.app.workflow.node_factory import DifyNodeFactory
|
||||||
from core.workflow.entities import GraphInitParams
|
from core.workflow.entities import GraphInitParams
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
GraphEngineEvent,
|
GraphEngineEvent,
|
||||||
GraphRunFailedEvent,
|
GraphRunFailedEvent,
|
||||||
GraphRunPartialSucceededEvent,
|
GraphRunPartialSucceededEvent,
|
||||||
GraphRunPausedEvent,
|
|
||||||
GraphRunStartedEvent,
|
GraphRunStartedEvent,
|
||||||
GraphRunSucceededEvent,
|
GraphRunSucceededEvent,
|
||||||
NodeRunAgentLogEvent,
|
NodeRunAgentLogEvent,
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunHumanInputFormFilledEvent,
|
|
||||||
NodeRunHumanInputFormTimeoutEvent,
|
|
||||||
NodeRunIterationFailedEvent,
|
NodeRunIterationFailedEvent,
|
||||||
NodeRunIterationNextEvent,
|
NodeRunIterationNextEvent,
|
||||||
NodeRunIterationStartedEvent,
|
NodeRunIterationStartedEvent,
|
||||||
@@ -69,9 +61,6 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
|
|||||||
from core.workflow.workflow_entry import WorkflowEntry
|
from core.workflow.workflow_entry import WorkflowEntry
|
||||||
from models.enums import UserFrom
|
from models.enums import UserFrom
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowBasedAppRunner:
|
class WorkflowBasedAppRunner:
|
||||||
@@ -338,7 +327,7 @@ class WorkflowBasedAppRunner:
|
|||||||
:param event: event
|
:param event: event
|
||||||
"""
|
"""
|
||||||
if isinstance(event, GraphRunStartedEvent):
|
if isinstance(event, GraphRunStartedEvent):
|
||||||
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
|
self._publish_event(QueueWorkflowStartedEvent())
|
||||||
elif isinstance(event, GraphRunSucceededEvent):
|
elif isinstance(event, GraphRunSucceededEvent):
|
||||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||||
@@ -349,38 +338,6 @@ class WorkflowBasedAppRunner:
|
|||||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||||
elif isinstance(event, GraphRunAbortedEvent):
|
elif isinstance(event, GraphRunAbortedEvent):
|
||||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||||
elif isinstance(event, GraphRunPausedEvent):
|
|
||||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
|
||||||
paused_nodes = runtime_state.get_paused_nodes()
|
|
||||||
self._enqueue_human_input_notifications(event.reasons)
|
|
||||||
self._publish_event(
|
|
||||||
QueueWorkflowPausedEvent(
|
|
||||||
reasons=event.reasons,
|
|
||||||
outputs=event.outputs,
|
|
||||||
paused_nodes=paused_nodes,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
|
|
||||||
self._publish_event(
|
|
||||||
QueueHumanInputFormFilledEvent(
|
|
||||||
node_execution_id=event.id,
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
node_title=event.node_title,
|
|
||||||
rendered_content=event.rendered_content,
|
|
||||||
action_id=event.action_id,
|
|
||||||
action_text=event.action_text,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
|
|
||||||
self._publish_event(
|
|
||||||
QueueHumanInputFormTimeoutEvent(
|
|
||||||
node_id=event.node_id,
|
|
||||||
node_type=event.node_type,
|
|
||||||
node_title=event.node_title,
|
|
||||||
expiration_time=event.expiration_time,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
elif isinstance(event, NodeRunRetryEvent):
|
elif isinstance(event, NodeRunRetryEvent):
|
||||||
node_run_result = event.node_run_result
|
node_run_result = event.node_run_result
|
||||||
inputs = node_run_result.inputs
|
inputs = node_run_result.inputs
|
||||||
@@ -587,19 +544,5 @@ class WorkflowBasedAppRunner:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
|
|
||||||
for reason in reasons:
|
|
||||||
if not isinstance(reason, HumanInputRequired):
|
|
||||||
continue
|
|
||||||
if not reason.form_id:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
dispatch_human_input_email_task.apply_async(
|
|
||||||
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
|
|
||||||
queue="mail",
|
|
||||||
)
|
|
||||||
except Exception: # pragma: no cover - defensive logging
|
|
||||||
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
|
|
||||||
|
|
||||||
def _publish_event(self, event: AppQueueEvent):
|
def _publish_event(self, event: AppQueueEvent):
|
||||||
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||||
|
|||||||
@@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
|
|||||||
extras: dict[str, Any] = Field(default_factory=dict)
|
extras: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
# tracing instance
|
# tracing instance
|
||||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
trace_manager: Optional["TraceQueueManager"] = None
|
||||||
|
|
||||||
|
|
||||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||||
@@ -156,7 +156,6 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
conversation_id: str | None = None
|
conversation_id: str | None = None
|
||||||
is_new_conversation: bool = False
|
|
||||||
parent_message_id: str | None = Field(
|
parent_message_id: str | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description=(
|
description=(
|
||||||
|
|||||||
@@ -8,8 +8,6 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities import AgentNodeStrategyInit
|
from core.workflow.entities import AgentNodeStrategyInit
|
||||||
from core.workflow.entities.pause_reason import PauseReason
|
|
||||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
|
||||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||||
from core.workflow.nodes import NodeType
|
from core.workflow.nodes import NodeType
|
||||||
|
|
||||||
@@ -48,9 +46,6 @@ class QueueEvent(StrEnum):
|
|||||||
PING = "ping"
|
PING = "ping"
|
||||||
STOP = "stop"
|
STOP = "stop"
|
||||||
RETRY = "retry"
|
RETRY = "retry"
|
||||||
PAUSE = "pause"
|
|
||||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
|
||||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
|
||||||
|
|
||||||
|
|
||||||
class AppQueueEvent(BaseModel):
|
class AppQueueEvent(BaseModel):
|
||||||
@@ -266,8 +261,6 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
|||||||
"""QueueWorkflowStartedEvent entity."""
|
"""QueueWorkflowStartedEvent entity."""
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||||
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
|
|
||||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
|
||||||
|
|
||||||
|
|
||||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||||
@@ -491,35 +484,6 @@ class QueueStopEvent(AppQueueEvent):
|
|||||||
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
|
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
|
||||||
|
|
||||||
|
|
||||||
class QueueHumanInputFormFilledEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueHumanInputFormFilledEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
|
|
||||||
|
|
||||||
node_execution_id: str
|
|
||||||
node_id: str
|
|
||||||
node_type: NodeType
|
|
||||||
node_title: str
|
|
||||||
rendered_content: str
|
|
||||||
action_id: str
|
|
||||||
action_text: str
|
|
||||||
|
|
||||||
|
|
||||||
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueHumanInputFormTimeoutEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
|
|
||||||
|
|
||||||
node_id: str
|
|
||||||
node_type: NodeType
|
|
||||||
node_title: str
|
|
||||||
expiration_time: datetime
|
|
||||||
|
|
||||||
|
|
||||||
class QueueMessage(BaseModel):
|
class QueueMessage(BaseModel):
|
||||||
"""
|
"""
|
||||||
QueueMessage abstract entity
|
QueueMessage abstract entity
|
||||||
@@ -545,14 +509,3 @@ class WorkflowQueueMessage(QueueMessage):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class QueueWorkflowPausedEvent(AppQueueEvent):
|
|
||||||
"""
|
|
||||||
QueueWorkflowPausedEvent entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
event: QueueEvent = QueueEvent.PAUSE
|
|
||||||
reasons: Sequence[PauseReason] = Field(default_factory=list)
|
|
||||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
|
||||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
|
||||||
|
|||||||
@@ -7,9 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||||
from core.workflow.entities import AgentNodeStrategyInit
|
from core.workflow.entities import AgentNodeStrategyInit
|
||||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
|
||||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
|
||||||
|
|
||||||
|
|
||||||
class AnnotationReplyAccount(BaseModel):
|
class AnnotationReplyAccount(BaseModel):
|
||||||
@@ -71,7 +69,6 @@ class StreamEvent(StrEnum):
|
|||||||
AGENT_THOUGHT = "agent_thought"
|
AGENT_THOUGHT = "agent_thought"
|
||||||
AGENT_MESSAGE = "agent_message"
|
AGENT_MESSAGE = "agent_message"
|
||||||
WORKFLOW_STARTED = "workflow_started"
|
WORKFLOW_STARTED = "workflow_started"
|
||||||
WORKFLOW_PAUSED = "workflow_paused"
|
|
||||||
WORKFLOW_FINISHED = "workflow_finished"
|
WORKFLOW_FINISHED = "workflow_finished"
|
||||||
NODE_STARTED = "node_started"
|
NODE_STARTED = "node_started"
|
||||||
NODE_FINISHED = "node_finished"
|
NODE_FINISHED = "node_finished"
|
||||||
@@ -85,9 +82,6 @@ class StreamEvent(StrEnum):
|
|||||||
TEXT_CHUNK = "text_chunk"
|
TEXT_CHUNK = "text_chunk"
|
||||||
TEXT_REPLACE = "text_replace"
|
TEXT_REPLACE = "text_replace"
|
||||||
AGENT_LOG = "agent_log"
|
AGENT_LOG = "agent_log"
|
||||||
HUMAN_INPUT_REQUIRED = "human_input_required"
|
|
||||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
|
||||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
|
||||||
|
|
||||||
|
|
||||||
class StreamResponse(BaseModel):
|
class StreamResponse(BaseModel):
|
||||||
@@ -211,8 +205,6 @@ class WorkflowStartStreamResponse(StreamResponse):
|
|||||||
workflow_id: str
|
workflow_id: str
|
||||||
inputs: Mapping[str, Any]
|
inputs: Mapping[str, Any]
|
||||||
created_at: int
|
created_at: int
|
||||||
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
|
|
||||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
|
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
@@ -239,7 +231,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||||||
total_steps: int
|
total_steps: int
|
||||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||||
created_at: int
|
created_at: int
|
||||||
finished_at: int | None
|
finished_at: int
|
||||||
exceptions_count: int | None = 0
|
exceptions_count: int | None = 0
|
||||||
files: Sequence[Mapping[str, Any]] | None = []
|
files: Sequence[Mapping[str, Any]] | None = []
|
||||||
|
|
||||||
@@ -248,85 +240,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
|||||||
data: Data
|
data: Data
|
||||||
|
|
||||||
|
|
||||||
class WorkflowPauseStreamResponse(StreamResponse):
|
|
||||||
"""
|
|
||||||
WorkflowPauseStreamResponse entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Data(BaseModel):
|
|
||||||
"""
|
|
||||||
Data entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
workflow_run_id: str
|
|
||||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
|
||||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
|
||||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
|
||||||
status: str
|
|
||||||
created_at: int
|
|
||||||
elapsed_time: float
|
|
||||||
total_tokens: int
|
|
||||||
total_steps: int
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
|
|
||||||
workflow_run_id: str
|
|
||||||
data: Data
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputRequiredResponse(StreamResponse):
|
|
||||||
class Data(BaseModel):
|
|
||||||
"""
|
|
||||||
Data entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
form_id: str
|
|
||||||
node_id: str
|
|
||||||
node_title: str
|
|
||||||
form_content: str
|
|
||||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
|
||||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
|
||||||
display_in_ui: bool = False
|
|
||||||
form_token: str | None = None
|
|
||||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
|
||||||
expiration_time: int = Field(..., description="Unix timestamp in seconds")
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
|
|
||||||
workflow_run_id: str
|
|
||||||
data: Data
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormFilledResponse(StreamResponse):
|
|
||||||
class Data(BaseModel):
|
|
||||||
"""
|
|
||||||
Data entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
node_id: str
|
|
||||||
node_title: str
|
|
||||||
rendered_content: str
|
|
||||||
action_id: str
|
|
||||||
action_text: str
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
|
|
||||||
workflow_run_id: str
|
|
||||||
data: Data
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormTimeoutResponse(StreamResponse):
|
|
||||||
class Data(BaseModel):
|
|
||||||
"""
|
|
||||||
Data entity
|
|
||||||
"""
|
|
||||||
|
|
||||||
node_id: str
|
|
||||||
node_title: str
|
|
||||||
expiration_time: int
|
|
||||||
|
|
||||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
|
|
||||||
workflow_run_id: str
|
|
||||||
data: Data
|
|
||||||
|
|
||||||
|
|
||||||
class NodeStartStreamResponse(StreamResponse):
|
class NodeStartStreamResponse(StreamResponse):
|
||||||
"""
|
"""
|
||||||
NodeStartStreamResponse entity
|
NodeStartStreamResponse entity
|
||||||
@@ -813,7 +726,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
|||||||
total_tokens: int
|
total_tokens: int
|
||||||
total_steps: int
|
total_steps: int
|
||||||
created_at: int
|
created_at: int
|
||||||
finished_at: int | None
|
finished_at: int
|
||||||
|
|
||||||
workflow_run_id: str
|
workflow_run_id: str
|
||||||
data: Data
|
data: Data
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import contextlib
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
@@ -104,14 +103,6 @@ class RateLimit:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
|
|
||||||
request_id = rate_limit.enter(request_id)
|
|
||||||
yield
|
|
||||||
if request_id is not None:
|
|
||||||
rate_limit.exit(request_id)
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimitGenerator:
|
class RateLimitGenerator:
|
||||||
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
|
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
|
||||||
self.rate_limit = rate_limit
|
self.rate_limit = rate_limit
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from dataclasses import dataclass
|
|
||||||
from typing import Annotated, Literal, Self, TypeAlias
|
from typing import Annotated, Literal, Self, TypeAlias
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -53,14 +52,6 @@ class WorkflowResumptionContext(BaseModel):
|
|||||||
return self.generate_entity.entity
|
return self.generate_entity.entity
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class PauseStateLayerConfig:
|
|
||||||
"""Configuration container for instantiating pause persistence layers."""
|
|
||||||
|
|
||||||
session_factory: Engine | sessionmaker[Session]
|
|
||||||
state_owner_user_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -82,11 +82,10 @@ class MessageCycleManager:
|
|||||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
is_first_message = self._application_generate_entity.is_new_conversation
|
is_first_message = self._application_generate_entity.conversation_id is None
|
||||||
extras = self._application_generate_entity.extras
|
extras = self._application_generate_entity.extras
|
||||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||||
|
|
||||||
thread: Thread | None = None
|
|
||||||
if auto_generate_conversation_name and is_first_message:
|
if auto_generate_conversation_name and is_first_message:
|
||||||
# start generate thread
|
# start generate thread
|
||||||
# time.sleep not block other logic
|
# time.sleep not block other logic
|
||||||
@@ -102,10 +101,9 @@ class MessageCycleManager:
|
|||||||
thread.daemon = True
|
thread.daemon = True
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
if is_first_message:
|
return thread
|
||||||
self._application_generate_entity.is_new_conversation = False
|
|
||||||
|
|
||||||
return thread
|
return None
|
||||||
|
|
||||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||||
with flask_app.app_context():
|
with flask_app.app_context():
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from typing import Any, TypeAlias
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field
|
|
||||||
|
|
||||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
|
||||||
from models.execution_extra_content import ExecutionContentType
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormDefinition(BaseModel):
|
|
||||||
model_config = ConfigDict(frozen=True)
|
|
||||||
|
|
||||||
form_id: str
|
|
||||||
node_id: str
|
|
||||||
node_title: str
|
|
||||||
form_content: str
|
|
||||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
|
||||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
|
||||||
display_in_ui: bool = False
|
|
||||||
form_token: str | None = None
|
|
||||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
|
||||||
expiration_time: int
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormSubmissionData(BaseModel):
|
|
||||||
model_config = ConfigDict(frozen=True)
|
|
||||||
|
|
||||||
node_id: str
|
|
||||||
node_title: str
|
|
||||||
rendered_content: str
|
|
||||||
action_id: str
|
|
||||||
action_text: str
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputContent(BaseModel):
|
|
||||||
model_config = ConfigDict(frozen=True)
|
|
||||||
|
|
||||||
workflow_run_id: str
|
|
||||||
submitted: bool
|
|
||||||
form_definition: HumanInputFormDefinition | None = None
|
|
||||||
form_submission_data: HumanInputFormSubmissionData | None = None
|
|
||||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
|
||||||
|
|
||||||
|
|
||||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ExecutionExtraContentDomainModel",
|
|
||||||
"HumanInputContent",
|
|
||||||
"HumanInputFormDefinition",
|
|
||||||
"HumanInputFormSubmissionData",
|
|
||||||
]
|
|
||||||
@@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
|
|||||||
)
|
)
|
||||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||||
|
from extensions.ext_database import db
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from models.engine import db
|
|
||||||
from models.provider import (
|
from models.provider import (
|
||||||
LoadBalancingModelConfig,
|
LoadBalancingModelConfig,
|
||||||
Provider,
|
Provider,
|
||||||
|
|||||||
@@ -15,7 +15,10 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||||
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
|
from core.ops.entities.config_entity import (
|
||||||
|
OPS_FILE_PATH,
|
||||||
|
TracingProviderEnum,
|
||||||
|
)
|
||||||
from core.ops.entities.trace_entity import (
|
from core.ops.entities.trace_entity import (
|
||||||
DatasetRetrievalTraceInfo,
|
DatasetRetrievalTraceInfo,
|
||||||
GenerateNameTraceInfo,
|
GenerateNameTraceInfo,
|
||||||
@@ -28,8 +31,8 @@ from core.ops.entities.trace_entity import (
|
|||||||
WorkflowTraceInfo,
|
WorkflowTraceInfo,
|
||||||
)
|
)
|
||||||
from core.ops.utils import get_message_data
|
from core.ops.utils import get_message_data
|
||||||
|
from extensions.ext_database import db
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from models.engine import db
|
|
||||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||||
from models.workflow import WorkflowAppLog
|
from models.workflow import WorkflowAppLog
|
||||||
from tasks.ops_trace_task import process_trace_tasks
|
from tasks.ops_trace_task import process_trace_tasks
|
||||||
@@ -466,8 +469,6 @@ class TraceTask:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_workflow_run_repo(cls):
|
def _get_workflow_run_repo(cls):
|
||||||
from repositories.factory import DifyAPIRepositoryFactory
|
|
||||||
|
|
||||||
if cls._workflow_run_repo is None:
|
if cls._workflow_run_repo is None:
|
||||||
with cls._repo_lock:
|
with cls._repo_lock:
|
||||||
if cls._workflow_run_repo is None:
|
if cls._workflow_run_repo is None:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
|||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from models.engine import db
|
from extensions.ext_database import db
|
||||||
from models.model import Message
|
from models.model import Message
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
import uuid
|
|
||||||
from collections.abc import Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
@@ -12,7 +11,6 @@ from core.app.apps.chat.app_generator import ChatAppGenerator
|
|||||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
|
||||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from models import Account
|
from models import Account
|
||||||
@@ -103,11 +101,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("unexpected app type")
|
raise ValueError("unexpected app type")
|
||||||
|
|
||||||
pause_config = PauseStateLayerConfig(
|
|
||||||
session_factory=db.engine,
|
|
||||||
state_owner_user_id=workflow.created_by,
|
|
||||||
)
|
|
||||||
|
|
||||||
return AdvancedChatAppGenerator().generate(
|
return AdvancedChatAppGenerator().generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
@@ -119,9 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
"conversation_id": conversation_id,
|
"conversation_id": conversation_id,
|
||||||
},
|
},
|
||||||
invoke_from=InvokeFrom.SERVICE_API,
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
workflow_run_id=str(uuid.uuid4()),
|
|
||||||
streaming=stream,
|
streaming=stream,
|
||||||
pause_state_config=pause_config,
|
|
||||||
)
|
)
|
||||||
elif app.mode == AppMode.AGENT_CHAT:
|
elif app.mode == AppMode.AGENT_CHAT:
|
||||||
return AgentChatAppGenerator().generate(
|
return AgentChatAppGenerator().generate(
|
||||||
@@ -168,11 +159,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
if not workflow:
|
if not workflow:
|
||||||
raise ValueError("unexpected app type")
|
raise ValueError("unexpected app type")
|
||||||
|
|
||||||
pause_config = PauseStateLayerConfig(
|
|
||||||
session_factory=db.engine,
|
|
||||||
state_owner_user_id=workflow.created_by,
|
|
||||||
)
|
|
||||||
|
|
||||||
return WorkflowAppGenerator().generate(
|
return WorkflowAppGenerator().generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
@@ -181,7 +167,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
|||||||
invoke_from=InvokeFrom.SERVICE_API,
|
invoke_from=InvokeFrom.SERVICE_API,
|
||||||
streaming=stream,
|
streaming=stream,
|
||||||
call_depth=1,
|
call_depth=1,
|
||||||
pause_state_config=pause_config,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -1,18 +1,19 @@
|
|||||||
"""Repository implementations for data access."""
|
"""
|
||||||
|
Repository implementations for data access.
|
||||||
|
|
||||||
from __future__ import annotations
|
This package contains concrete implementations of the repository interfaces
|
||||||
|
defined in the core.workflow.repository package.
|
||||||
|
"""
|
||||||
|
|
||||||
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||||
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||||
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
|
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||||
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||||
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"CeleryWorkflowExecutionRepository",
|
"CeleryWorkflowExecutionRepository",
|
||||||
"CeleryWorkflowNodeExecutionRepository",
|
"CeleryWorkflowNodeExecutionRepository",
|
||||||
"DifyCoreRepositoryFactory",
|
"DifyCoreRepositoryFactory",
|
||||||
"RepositoryImportError",
|
"RepositoryImportError",
|
||||||
"SQLAlchemyWorkflowExecutionRepository",
|
|
||||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -1,553 +0,0 @@
|
|||||||
import dataclasses
|
|
||||||
import json
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import Engine, select
|
|
||||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
|
||||||
|
|
||||||
from core.workflow.nodes.human_input.entities import (
|
|
||||||
DeliveryChannelConfig,
|
|
||||||
EmailDeliveryMethod,
|
|
||||||
EmailRecipients,
|
|
||||||
ExternalRecipient,
|
|
||||||
FormDefinition,
|
|
||||||
HumanInputNodeData,
|
|
||||||
MemberRecipient,
|
|
||||||
WebAppDeliveryMethod,
|
|
||||||
)
|
|
||||||
from core.workflow.nodes.human_input.enums import (
|
|
||||||
DeliveryMethodType,
|
|
||||||
HumanInputFormKind,
|
|
||||||
HumanInputFormStatus,
|
|
||||||
)
|
|
||||||
from core.workflow.repositories.human_input_form_repository import (
|
|
||||||
FormCreateParams,
|
|
||||||
FormNotFoundError,
|
|
||||||
HumanInputFormEntity,
|
|
||||||
HumanInputFormRecipientEntity,
|
|
||||||
)
|
|
||||||
from libs.datetime_utils import naive_utc_now
|
|
||||||
from libs.uuid_utils import uuidv7
|
|
||||||
from models.account import Account, TenantAccountJoin
|
|
||||||
from models.human_input import (
|
|
||||||
BackstageRecipientPayload,
|
|
||||||
ConsoleDeliveryPayload,
|
|
||||||
ConsoleRecipientPayload,
|
|
||||||
EmailExternalRecipientPayload,
|
|
||||||
EmailMemberRecipientPayload,
|
|
||||||
HumanInputDelivery,
|
|
||||||
HumanInputForm,
|
|
||||||
HumanInputFormRecipient,
|
|
||||||
RecipientType,
|
|
||||||
StandaloneWebAppRecipientPayload,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class _DeliveryAndRecipients:
|
|
||||||
delivery: HumanInputDelivery
|
|
||||||
recipients: Sequence[HumanInputFormRecipient]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class _WorkspaceMemberInfo:
|
|
||||||
user_id: str
|
|
||||||
email: str
|
|
||||||
|
|
||||||
|
|
||||||
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
|
|
||||||
def __init__(self, recipient_model: HumanInputFormRecipient):
|
|
||||||
self._recipient_model = recipient_model
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> str:
|
|
||||||
return self._recipient_model.id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def token(self) -> str:
|
|
||||||
if self._recipient_model.access_token is None:
|
|
||||||
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
|
|
||||||
return self._recipient_model.access_token
|
|
||||||
|
|
||||||
|
|
||||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
|
||||||
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
|
|
||||||
self._form_model = form_model
|
|
||||||
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
|
|
||||||
self._web_app_recipient = next(
|
|
||||||
(
|
|
||||||
recipient
|
|
||||||
for recipient in recipient_models
|
|
||||||
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
self._console_recipient = next(
|
|
||||||
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
self._submitted_data: Mapping[str, Any] | None = (
|
|
||||||
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
|
|
||||||
)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def id(self) -> str:
|
|
||||||
return self._form_model.id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def web_app_token(self):
|
|
||||||
if self._console_recipient is not None:
|
|
||||||
return self._console_recipient.access_token
|
|
||||||
if self._web_app_recipient is None:
|
|
||||||
return None
|
|
||||||
return self._web_app_recipient.access_token
|
|
||||||
|
|
||||||
@property
|
|
||||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
|
||||||
return list(self._recipients)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def rendered_content(self) -> str:
|
|
||||||
return self._form_model.rendered_content
|
|
||||||
|
|
||||||
@property
|
|
||||||
def selected_action_id(self) -> str | None:
|
|
||||||
return self._form_model.selected_action_id
|
|
||||||
|
|
||||||
@property
|
|
||||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
|
||||||
return self._submitted_data
|
|
||||||
|
|
||||||
@property
|
|
||||||
def submitted(self) -> bool:
|
|
||||||
return self._form_model.submitted_at is not None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def status(self) -> HumanInputFormStatus:
|
|
||||||
return self._form_model.status
|
|
||||||
|
|
||||||
@property
|
|
||||||
def expiration_time(self) -> datetime:
|
|
||||||
return self._form_model.expiration_time
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
|
||||||
class HumanInputFormRecord:
|
|
||||||
form_id: str
|
|
||||||
workflow_run_id: str | None
|
|
||||||
node_id: str
|
|
||||||
tenant_id: str
|
|
||||||
app_id: str
|
|
||||||
form_kind: HumanInputFormKind
|
|
||||||
definition: FormDefinition
|
|
||||||
rendered_content: str
|
|
||||||
created_at: datetime
|
|
||||||
expiration_time: datetime
|
|
||||||
status: HumanInputFormStatus
|
|
||||||
selected_action_id: str | None
|
|
||||||
submitted_data: Mapping[str, Any] | None
|
|
||||||
submitted_at: datetime | None
|
|
||||||
submission_user_id: str | None
|
|
||||||
submission_end_user_id: str | None
|
|
||||||
completed_by_recipient_id: str | None
|
|
||||||
recipient_id: str | None
|
|
||||||
recipient_type: RecipientType | None
|
|
||||||
access_token: str | None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def submitted(self) -> bool:
|
|
||||||
return self.submitted_at is not None
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_models(
|
|
||||||
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
|
|
||||||
) -> "HumanInputFormRecord":
|
|
||||||
definition_payload = json.loads(form_model.form_definition)
|
|
||||||
if "expiration_time" not in definition_payload:
|
|
||||||
definition_payload["expiration_time"] = form_model.expiration_time
|
|
||||||
return cls(
|
|
||||||
form_id=form_model.id,
|
|
||||||
workflow_run_id=form_model.workflow_run_id,
|
|
||||||
node_id=form_model.node_id,
|
|
||||||
tenant_id=form_model.tenant_id,
|
|
||||||
app_id=form_model.app_id,
|
|
||||||
form_kind=form_model.form_kind,
|
|
||||||
definition=FormDefinition.model_validate(definition_payload),
|
|
||||||
rendered_content=form_model.rendered_content,
|
|
||||||
created_at=form_model.created_at,
|
|
||||||
expiration_time=form_model.expiration_time,
|
|
||||||
status=form_model.status,
|
|
||||||
selected_action_id=form_model.selected_action_id,
|
|
||||||
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
|
|
||||||
submitted_at=form_model.submitted_at,
|
|
||||||
submission_user_id=form_model.submission_user_id,
|
|
||||||
submission_end_user_id=form_model.submission_end_user_id,
|
|
||||||
completed_by_recipient_id=form_model.completed_by_recipient_id,
|
|
||||||
recipient_id=recipient_model.id if recipient_model else None,
|
|
||||||
recipient_type=recipient_model.recipient_type if recipient_model else None,
|
|
||||||
access_token=recipient_model.access_token if recipient_model else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _InvalidTimeoutStatusError(ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormRepositoryImpl:
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
session_factory: sessionmaker | Engine,
|
|
||||||
tenant_id: str,
|
|
||||||
):
|
|
||||||
if isinstance(session_factory, Engine):
|
|
||||||
session_factory = sessionmaker(bind=session_factory)
|
|
||||||
self._session_factory = session_factory
|
|
||||||
self._tenant_id = tenant_id
|
|
||||||
|
|
||||||
def _delivery_method_to_model(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
form_id: str,
|
|
||||||
delivery_method: DeliveryChannelConfig,
|
|
||||||
) -> _DeliveryAndRecipients:
|
|
||||||
delivery_id = str(uuidv7())
|
|
||||||
delivery_model = HumanInputDelivery(
|
|
||||||
id=delivery_id,
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_method_type=delivery_method.type,
|
|
||||||
delivery_config_id=delivery_method.id,
|
|
||||||
channel_payload=delivery_method.model_dump_json(),
|
|
||||||
)
|
|
||||||
recipients: list[HumanInputFormRecipient] = []
|
|
||||||
if isinstance(delivery_method, WebAppDeliveryMethod):
|
|
||||||
recipient_model = HumanInputFormRecipient(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=delivery_id,
|
|
||||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
|
||||||
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
|
|
||||||
)
|
|
||||||
recipients.append(recipient_model)
|
|
||||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
|
||||||
email_recipients_config = delivery_method.config.recipients
|
|
||||||
recipients.extend(
|
|
||||||
self._build_email_recipients(
|
|
||||||
session=session,
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=delivery_id,
|
|
||||||
recipients_config=email_recipients_config,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
|
|
||||||
|
|
||||||
def _build_email_recipients(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
form_id: str,
|
|
||||||
delivery_id: str,
|
|
||||||
recipients_config: EmailRecipients,
|
|
||||||
) -> list[HumanInputFormRecipient]:
|
|
||||||
member_user_ids = [
|
|
||||||
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
|
|
||||||
]
|
|
||||||
external_emails = [
|
|
||||||
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
|
|
||||||
]
|
|
||||||
if recipients_config.whole_workspace:
|
|
||||||
members = self._query_all_workspace_members(session=session)
|
|
||||||
else:
|
|
||||||
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
|
|
||||||
|
|
||||||
return self._create_email_recipients_from_resolved(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=delivery_id,
|
|
||||||
members=members,
|
|
||||||
external_emails=external_emails,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _create_email_recipients_from_resolved(
|
|
||||||
*,
|
|
||||||
form_id: str,
|
|
||||||
delivery_id: str,
|
|
||||||
members: Sequence[_WorkspaceMemberInfo],
|
|
||||||
external_emails: Sequence[str],
|
|
||||||
) -> list[HumanInputFormRecipient]:
|
|
||||||
recipient_models: list[HumanInputFormRecipient] = []
|
|
||||||
seen_emails: set[str] = set()
|
|
||||||
|
|
||||||
for member in members:
|
|
||||||
if not member.email:
|
|
||||||
continue
|
|
||||||
if member.email in seen_emails:
|
|
||||||
continue
|
|
||||||
seen_emails.add(member.email)
|
|
||||||
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
|
|
||||||
recipient_models.append(
|
|
||||||
HumanInputFormRecipient.new(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=delivery_id,
|
|
||||||
payload=payload,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for email in external_emails:
|
|
||||||
if not email:
|
|
||||||
continue
|
|
||||||
if email in seen_emails:
|
|
||||||
continue
|
|
||||||
seen_emails.add(email)
|
|
||||||
recipient_models.append(
|
|
||||||
HumanInputFormRecipient.new(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=delivery_id,
|
|
||||||
payload=EmailExternalRecipientPayload(email=email),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return recipient_models
|
|
||||||
|
|
||||||
def _query_all_workspace_members(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
) -> list[_WorkspaceMemberInfo]:
|
|
||||||
stmt = (
|
|
||||||
select(Account.id, Account.email)
|
|
||||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
|
||||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
|
||||||
)
|
|
||||||
rows = session.execute(stmt).all()
|
|
||||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
|
||||||
|
|
||||||
def _query_workspace_members_by_ids(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
restrict_to_user_ids: Sequence[str],
|
|
||||||
) -> list[_WorkspaceMemberInfo]:
|
|
||||||
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
|
|
||||||
if not unique_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
stmt = (
|
|
||||||
select(Account.id, Account.email)
|
|
||||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
|
||||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
|
||||||
)
|
|
||||||
stmt = stmt.where(Account.id.in_(unique_ids))
|
|
||||||
|
|
||||||
rows = session.execute(stmt).all()
|
|
||||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
|
||||||
|
|
||||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
|
||||||
form_config: HumanInputNodeData = params.form_config
|
|
||||||
|
|
||||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
|
||||||
# Generate unique form ID
|
|
||||||
form_id = str(uuidv7())
|
|
||||||
start_time = naive_utc_now()
|
|
||||||
node_expiration = form_config.expiration_time(start_time)
|
|
||||||
form_definition = FormDefinition(
|
|
||||||
form_content=form_config.form_content,
|
|
||||||
inputs=form_config.inputs,
|
|
||||||
user_actions=form_config.user_actions,
|
|
||||||
rendered_content=params.rendered_content,
|
|
||||||
expiration_time=node_expiration,
|
|
||||||
default_values=dict(params.resolved_default_values),
|
|
||||||
display_in_ui=params.display_in_ui,
|
|
||||||
node_title=form_config.title,
|
|
||||||
)
|
|
||||||
form_model = HumanInputForm(
|
|
||||||
id=form_id,
|
|
||||||
tenant_id=self._tenant_id,
|
|
||||||
app_id=params.app_id,
|
|
||||||
workflow_run_id=params.workflow_execution_id,
|
|
||||||
form_kind=params.form_kind,
|
|
||||||
node_id=params.node_id,
|
|
||||||
form_definition=form_definition.model_dump_json(),
|
|
||||||
rendered_content=params.rendered_content,
|
|
||||||
expiration_time=node_expiration,
|
|
||||||
created_at=start_time,
|
|
||||||
)
|
|
||||||
session.add(form_model)
|
|
||||||
recipient_models: list[HumanInputFormRecipient] = []
|
|
||||||
for delivery in params.delivery_methods:
|
|
||||||
delivery_and_recipients = self._delivery_method_to_model(
|
|
||||||
session=session,
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_method=delivery,
|
|
||||||
)
|
|
||||||
session.add(delivery_and_recipients.delivery)
|
|
||||||
session.add_all(delivery_and_recipients.recipients)
|
|
||||||
recipient_models.extend(delivery_and_recipients.recipients)
|
|
||||||
if params.console_recipient_required and not any(
|
|
||||||
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
|
|
||||||
):
|
|
||||||
console_delivery_id = str(uuidv7())
|
|
||||||
console_delivery = HumanInputDelivery(
|
|
||||||
id=console_delivery_id,
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
|
||||||
delivery_config_id=None,
|
|
||||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
|
||||||
)
|
|
||||||
console_recipient = HumanInputFormRecipient(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=console_delivery_id,
|
|
||||||
recipient_type=RecipientType.CONSOLE,
|
|
||||||
recipient_payload=ConsoleRecipientPayload(
|
|
||||||
account_id=params.console_creator_account_id,
|
|
||||||
).model_dump_json(),
|
|
||||||
)
|
|
||||||
session.add(console_delivery)
|
|
||||||
session.add(console_recipient)
|
|
||||||
recipient_models.append(console_recipient)
|
|
||||||
if params.backstage_recipient_required and not any(
|
|
||||||
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
|
|
||||||
):
|
|
||||||
backstage_delivery_id = str(uuidv7())
|
|
||||||
backstage_delivery = HumanInputDelivery(
|
|
||||||
id=backstage_delivery_id,
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
|
||||||
delivery_config_id=None,
|
|
||||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
|
||||||
)
|
|
||||||
backstage_recipient = HumanInputFormRecipient(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=backstage_delivery_id,
|
|
||||||
recipient_type=RecipientType.BACKSTAGE,
|
|
||||||
recipient_payload=BackstageRecipientPayload(
|
|
||||||
account_id=params.console_creator_account_id,
|
|
||||||
).model_dump_json(),
|
|
||||||
)
|
|
||||||
session.add(backstage_delivery)
|
|
||||||
session.add(backstage_recipient)
|
|
||||||
recipient_models.append(backstage_recipient)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
|
||||||
|
|
||||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
|
||||||
form_query = select(HumanInputForm).where(
|
|
||||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
|
||||||
HumanInputForm.node_id == node_id,
|
|
||||||
HumanInputForm.tenant_id == self._tenant_id,
|
|
||||||
)
|
|
||||||
with self._session_factory(expire_on_commit=False) as session:
|
|
||||||
form_model: HumanInputForm | None = session.scalars(form_query).first()
|
|
||||||
if form_model is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
|
|
||||||
recipient_models = session.scalars(recipient_query).all()
|
|
||||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormSubmissionRepository:
|
|
||||||
"""Repository for fetching and submitting human input forms."""
|
|
||||||
|
|
||||||
def __init__(self, session_factory: sessionmaker | Engine):
|
|
||||||
if isinstance(session_factory, Engine):
|
|
||||||
session_factory = sessionmaker(bind=session_factory)
|
|
||||||
self._session_factory = session_factory
|
|
||||||
|
|
||||||
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
|
|
||||||
query = (
|
|
||||||
select(HumanInputFormRecipient)
|
|
||||||
.options(selectinload(HumanInputFormRecipient.form))
|
|
||||||
.where(HumanInputFormRecipient.access_token == form_token)
|
|
||||||
)
|
|
||||||
with self._session_factory(expire_on_commit=False) as session:
|
|
||||||
recipient_model = session.scalars(query).first()
|
|
||||||
if recipient_model is None or recipient_model.form is None:
|
|
||||||
return None
|
|
||||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
|
||||||
|
|
||||||
def get_by_form_id_and_recipient_type(
|
|
||||||
self,
|
|
||||||
form_id: str,
|
|
||||||
recipient_type: RecipientType,
|
|
||||||
) -> HumanInputFormRecord | None:
|
|
||||||
query = (
|
|
||||||
select(HumanInputFormRecipient)
|
|
||||||
.options(selectinload(HumanInputFormRecipient.form))
|
|
||||||
.where(
|
|
||||||
HumanInputFormRecipient.form_id == form_id,
|
|
||||||
HumanInputFormRecipient.recipient_type == recipient_type,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
with self._session_factory(expire_on_commit=False) as session:
|
|
||||||
recipient_model = session.scalars(query).first()
|
|
||||||
if recipient_model is None or recipient_model.form is None:
|
|
||||||
return None
|
|
||||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
|
||||||
|
|
||||||
def mark_submitted(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
form_id: str,
|
|
||||||
recipient_id: str | None,
|
|
||||||
selected_action_id: str,
|
|
||||||
form_data: Mapping[str, Any],
|
|
||||||
submission_user_id: str | None,
|
|
||||||
submission_end_user_id: str | None,
|
|
||||||
) -> HumanInputFormRecord:
|
|
||||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
|
||||||
form_model = session.get(HumanInputForm, form_id)
|
|
||||||
if form_model is None:
|
|
||||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
|
||||||
|
|
||||||
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
|
|
||||||
|
|
||||||
form_model.selected_action_id = selected_action_id
|
|
||||||
form_model.submitted_data = json.dumps(form_data)
|
|
||||||
form_model.submitted_at = naive_utc_now()
|
|
||||||
form_model.status = HumanInputFormStatus.SUBMITTED
|
|
||||||
form_model.submission_user_id = submission_user_id
|
|
||||||
form_model.submission_end_user_id = submission_end_user_id
|
|
||||||
form_model.completed_by_recipient_id = recipient_id
|
|
||||||
|
|
||||||
session.add(form_model)
|
|
||||||
session.flush()
|
|
||||||
session.refresh(form_model)
|
|
||||||
if recipient_model is not None:
|
|
||||||
session.refresh(recipient_model)
|
|
||||||
|
|
||||||
return HumanInputFormRecord.from_models(form_model, recipient_model)
|
|
||||||
|
|
||||||
def mark_timeout(
|
|
||||||
self,
|
|
||||||
*,
|
|
||||||
form_id: str,
|
|
||||||
timeout_status: HumanInputFormStatus,
|
|
||||||
reason: str | None = None,
|
|
||||||
) -> HumanInputFormRecord:
|
|
||||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
|
||||||
form_model = session.get(HumanInputForm, form_id)
|
|
||||||
if form_model is None:
|
|
||||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
|
||||||
|
|
||||||
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
|
||||||
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
|
|
||||||
|
|
||||||
# already handled or submitted
|
|
||||||
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
|
||||||
return HumanInputFormRecord.from_models(form_model, None)
|
|
||||||
|
|
||||||
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
|
|
||||||
raise FormNotFoundError(f"form already submitted, id={form_id}")
|
|
||||||
|
|
||||||
form_model.status = timeout_status
|
|
||||||
form_model.selected_action_id = None
|
|
||||||
form_model.submitted_data = None
|
|
||||||
form_model.submission_user_id = None
|
|
||||||
form_model.submission_end_user_id = None
|
|
||||||
form_model.completed_by_recipient_id = None
|
|
||||||
# Reason is recorded in status/error downstream; not stored on form.
|
|
||||||
session.add(form_model)
|
|
||||||
session.flush()
|
|
||||||
session.refresh(form_model)
|
|
||||||
|
|
||||||
return HumanInputFormRecord.from_models(form_model, None)
|
|
||||||
@@ -488,7 +488,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
|||||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._app_id:
|
if self._app_id:
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||||
from libs.exception import BaseHTTPException
|
|
||||||
|
|
||||||
|
|
||||||
class ToolProviderNotFoundError(ValueError):
|
class ToolProviderNotFoundError(ValueError):
|
||||||
@@ -38,12 +37,6 @@ class ToolCredentialPolicyViolationError(ValueError):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
|
|
||||||
error_code = "workflow_tool_human_input_not_supported"
|
|
||||||
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
|
|
||||||
code = 400
|
|
||||||
|
|
||||||
|
|
||||||
class ToolEngineInvokeError(Exception):
|
class ToolEngineInvokeError(Exception):
|
||||||
meta: ToolInvokeMeta
|
meta: ToolInvokeMeta
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,6 @@ from typing import Any
|
|||||||
|
|
||||||
from core.app.app_config.entities import VariableEntity
|
from core.app.app_config.entities import VariableEntity
|
||||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
|
||||||
from core.workflow.enums import NodeType
|
|
||||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||||
|
|
||||||
|
|
||||||
@@ -52,13 +50,6 @@ class WorkflowToolConfigurationUtils:
|
|||||||
|
|
||||||
return [outputs_by_variable[variable] for variable in variable_order]
|
return [outputs_by_variable[variable] for variable in variable_order]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
|
|
||||||
nodes = graph.get("nodes", [])
|
|
||||||
for node in nodes:
|
|
||||||
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
|
|
||||||
raise WorkflowToolHumanInputNotSupportedError()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_is_synced(
|
def check_is_synced(
|
||||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||||
|
|||||||
@@ -98,10 +98,6 @@ class WorkflowTool(Tool):
|
|||||||
invoke_from=self.runtime.invoke_from,
|
invoke_from=self.runtime.invoke_from,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
call_depth=self.workflow_call_depth + 1,
|
call_depth=self.workflow_call_depth + 1,
|
||||||
# NOTE(QuantumGhost): We explicitly set `pause_state_config` to `None`
|
|
||||||
# because workflow pausing mechanisms (such as HumanInput) are not
|
|
||||||
# supported within WorkflowTool execution context.
|
|
||||||
pause_state_config=None,
|
|
||||||
)
|
)
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
data = result.get("data", {})
|
data = result.get("data", {})
|
||||||
|
|||||||
@@ -2,12 +2,10 @@ from .agent import AgentNodeStrategyInit
|
|||||||
from .graph_init_params import GraphInitParams
|
from .graph_init_params import GraphInitParams
|
||||||
from .workflow_execution import WorkflowExecution
|
from .workflow_execution import WorkflowExecution
|
||||||
from .workflow_node_execution import WorkflowNodeExecution
|
from .workflow_node_execution import WorkflowNodeExecution
|
||||||
from .workflow_start_reason import WorkflowStartReason
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentNodeStrategyInit",
|
"AgentNodeStrategyInit",
|
||||||
"GraphInitParams",
|
"GraphInitParams",
|
||||||
"WorkflowExecution",
|
"WorkflowExecution",
|
||||||
"WorkflowNodeExecution",
|
"WorkflowNodeExecution",
|
||||||
"WorkflowStartReason",
|
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -5,16 +5,6 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
|
|
||||||
class GraphInitParams(BaseModel):
|
class GraphInitParams(BaseModel):
|
||||||
"""GraphInitParams encapsulates the configurations and contextual information
|
|
||||||
that remain constant throughout a single execution of the graph engine.
|
|
||||||
|
|
||||||
A single execution is defined as follows: as long as the execution has not reached
|
|
||||||
its conclusion, it is considered one execution. For instance, if a workflow is suspended
|
|
||||||
and later resumed, it is still regarded as a single execution, not two.
|
|
||||||
|
|
||||||
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# init params
|
# init params
|
||||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||||
app_id: str = Field(..., description="app id")
|
app_id: str = Field(..., description="app id")
|
||||||
|
|||||||
@@ -1,11 +1,8 @@
|
|||||||
from collections.abc import Mapping
|
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
from typing import Annotated, Any, Literal, TypeAlias
|
from typing import Annotated, Literal, TypeAlias
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
|
||||||
|
|
||||||
|
|
||||||
class PauseReasonType(StrEnum):
|
class PauseReasonType(StrEnum):
|
||||||
HUMAN_INPUT_REQUIRED = auto()
|
HUMAN_INPUT_REQUIRED = auto()
|
||||||
@@ -14,31 +11,10 @@ class PauseReasonType(StrEnum):
|
|||||||
|
|
||||||
class HumanInputRequired(BaseModel):
|
class HumanInputRequired(BaseModel):
|
||||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||||
|
|
||||||
form_id: str
|
form_id: str
|
||||||
form_content: str
|
# The identifier of the human input node causing the pause.
|
||||||
inputs: list[FormInput] = Field(default_factory=list)
|
|
||||||
actions: list[UserAction] = Field(default_factory=list)
|
|
||||||
display_in_ui: bool = False
|
|
||||||
node_id: str
|
node_id: str
|
||||||
node_title: str
|
|
||||||
|
|
||||||
# The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
|
|
||||||
# `output_variable_name` to their resolved values.
|
|
||||||
#
|
|
||||||
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
|
|
||||||
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
|
|
||||||
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
|
|
||||||
# `resolved_default_values` is `{"name": "John"}`.
|
|
||||||
#
|
|
||||||
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
|
|
||||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
# The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
|
|
||||||
# `HumanInputFormRecipient.access_token`.
|
|
||||||
#
|
|
||||||
# This field is `None` if webapp delivery is not set and not
|
|
||||||
# in orchestrating mode.
|
|
||||||
form_token: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class SchedulingPause(BaseModel):
|
class SchedulingPause(BaseModel):
|
||||||
|
|||||||
@@ -1,8 +0,0 @@
|
|||||||
from enum import StrEnum
|
|
||||||
|
|
||||||
|
|
||||||
class WorkflowStartReason(StrEnum):
|
|
||||||
"""Reason for workflow start events across graph/queue/SSE layers."""
|
|
||||||
|
|
||||||
INITIAL = "initial" # First start of a workflow run.
|
|
||||||
RESUMPTION = "resumption" # Start triggered after resuming a paused run.
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
import time
|
|
||||||
|
|
||||||
|
|
||||||
def get_timestamp() -> float:
|
|
||||||
"""Retrieve a timestamp as a float point numer representing the number of seconds
|
|
||||||
since the Unix epoch.
|
|
||||||
|
|
||||||
This function is primarily used to measure the execution time of the workflow engine.
|
|
||||||
Since workflow execution may be paused and resumed on a different machine,
|
|
||||||
`time.perf_counter` cannot be used as it is inconsistent across machines.
|
|
||||||
|
|
||||||
To address this, the function uses the wall clock as the time source.
|
|
||||||
However, it assumes that the clocks of all servers are properly synchronized.
|
|
||||||
"""
|
|
||||||
return round(time.time())
|
|
||||||
@@ -2,14 +2,12 @@
|
|||||||
GraphEngine configuration models.
|
GraphEngine configuration models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
class GraphEngineConfig(BaseModel):
|
class GraphEngineConfig(BaseModel):
|
||||||
"""Configuration for GraphEngine worker pool scaling."""
|
"""Configuration for GraphEngine worker pool scaling."""
|
||||||
|
|
||||||
model_config = ConfigDict(frozen=True)
|
|
||||||
|
|
||||||
min_workers: int = 1
|
min_workers: int = 1
|
||||||
max_workers: int = 5
|
max_workers: int = 5
|
||||||
scale_up_threshold: int = 3
|
scale_up_threshold: int = 3
|
||||||
|
|||||||
@@ -192,13 +192,9 @@ class EventHandler:
|
|||||||
self._event_collector.collect(edge_event)
|
self._event_collector.collect(edge_event)
|
||||||
|
|
||||||
# Enqueue ready nodes
|
# Enqueue ready nodes
|
||||||
if self._graph_execution.is_paused:
|
for node_id in ready_nodes:
|
||||||
for node_id in ready_nodes:
|
self._state_manager.enqueue_node(node_id)
|
||||||
self._graph_runtime_state.register_deferred_node(node_id)
|
self._state_manager.start_execution(node_id)
|
||||||
else:
|
|
||||||
for node_id in ready_nodes:
|
|
||||||
self._state_manager.enqueue_node(node_id)
|
|
||||||
self._state_manager.start_execution(node_id)
|
|
||||||
|
|
||||||
# Update execution tracking
|
# Update execution tracking
|
||||||
self._state_manager.finish_execution(event.node_id)
|
self._state_manager.finish_execution(event.node_id)
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ from collections.abc import Generator
|
|||||||
from typing import TYPE_CHECKING, cast, final
|
from typing import TYPE_CHECKING, cast, final
|
||||||
|
|
||||||
from core.workflow.context import capture_current_context
|
from core.workflow.context import capture_current_context
|
||||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
|
||||||
from core.workflow.enums import NodeExecutionType
|
from core.workflow.enums import NodeExecutionType
|
||||||
from core.workflow.graph import Graph
|
from core.workflow.graph import Graph
|
||||||
from core.workflow.graph_events import (
|
from core.workflow.graph_events import (
|
||||||
@@ -57,9 +56,6 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_CONFIG = GraphEngineConfig()
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
@final
|
||||||
class GraphEngine:
|
class GraphEngine:
|
||||||
"""
|
"""
|
||||||
@@ -75,7 +71,7 @@ class GraphEngine:
|
|||||||
graph: Graph,
|
graph: Graph,
|
||||||
graph_runtime_state: GraphRuntimeState,
|
graph_runtime_state: GraphRuntimeState,
|
||||||
command_channel: CommandChannel,
|
command_channel: CommandChannel,
|
||||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
config: GraphEngineConfig,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||||
# stop event
|
# stop event
|
||||||
@@ -239,9 +235,7 @@ class GraphEngine:
|
|||||||
self._graph_execution.paused = False
|
self._graph_execution.paused = False
|
||||||
self._graph_execution.pause_reasons = []
|
self._graph_execution.pause_reasons = []
|
||||||
|
|
||||||
start_event = GraphRunStartedEvent(
|
start_event = GraphRunStartedEvent()
|
||||||
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
|
|
||||||
)
|
|
||||||
self._event_manager.notify_layers(start_event)
|
self._event_manager.notify_layers(start_event)
|
||||||
yield start_event
|
yield start_event
|
||||||
|
|
||||||
@@ -310,17 +304,15 @@ class GraphEngine:
|
|||||||
for layer in self._layers:
|
for layer in self._layers:
|
||||||
try:
|
try:
|
||||||
layer.on_graph_start()
|
layer.on_graph_start()
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
|
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||||
|
|
||||||
def _start_execution(self, *, resume: bool = False) -> None:
|
def _start_execution(self, *, resume: bool = False) -> None:
|
||||||
"""Start execution subsystems."""
|
"""Start execution subsystems."""
|
||||||
self._stop_event.clear()
|
self._stop_event.clear()
|
||||||
paused_nodes: list[str] = []
|
paused_nodes: list[str] = []
|
||||||
deferred_nodes: list[str] = []
|
|
||||||
if resume:
|
if resume:
|
||||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||||
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
|
|
||||||
|
|
||||||
# Start worker pool (it calculates initial workers internally)
|
# Start worker pool (it calculates initial workers internally)
|
||||||
self._worker_pool.start()
|
self._worker_pool.start()
|
||||||
@@ -336,11 +328,7 @@ class GraphEngine:
|
|||||||
self._state_manager.enqueue_node(root_node.id)
|
self._state_manager.enqueue_node(root_node.id)
|
||||||
self._state_manager.start_execution(root_node.id)
|
self._state_manager.start_execution(root_node.id)
|
||||||
else:
|
else:
|
||||||
seen_nodes: set[str] = set()
|
for node_id in paused_nodes:
|
||||||
for node_id in paused_nodes + deferred_nodes:
|
|
||||||
if node_id in seen_nodes:
|
|
||||||
continue
|
|
||||||
seen_nodes.add(node_id)
|
|
||||||
self._state_manager.enqueue_node(node_id)
|
self._state_manager.enqueue_node(node_id)
|
||||||
self._state_manager.start_execution(node_id)
|
self._state_manager.start_execution(node_id)
|
||||||
|
|
||||||
@@ -358,8 +346,8 @@ class GraphEngine:
|
|||||||
for layer in self._layers:
|
for layer in self._layers:
|
||||||
try:
|
try:
|
||||||
layer.on_graph_end(self._graph_execution.error)
|
layer.on_graph_end(self._graph_execution.error)
|
||||||
except Exception:
|
except Exception as e:
|
||||||
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
|
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||||
|
|
||||||
# Public property accessors for attributes that need external access
|
# Public property accessors for attributes that need external access
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -224,8 +224,6 @@ class GraphStateManager:
|
|||||||
Returns:
|
Returns:
|
||||||
Number of executing nodes
|
Number of executing nodes
|
||||||
"""
|
"""
|
||||||
# This count is a best-effort snapshot and can change concurrently.
|
|
||||||
# Only use it for pause-drain checks where scheduling is already frozen.
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
return len(self._executing_nodes)
|
return len(self._executing_nodes)
|
||||||
|
|
||||||
|
|||||||
@@ -83,12 +83,12 @@ class Dispatcher:
|
|||||||
"""Main dispatcher loop."""
|
"""Main dispatcher loop."""
|
||||||
try:
|
try:
|
||||||
self._process_commands()
|
self._process_commands()
|
||||||
paused = False
|
|
||||||
while not self._stop_event.is_set():
|
while not self._stop_event.is_set():
|
||||||
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
|
if (
|
||||||
break
|
self._execution_coordinator.aborted
|
||||||
if self._execution_coordinator.paused:
|
or self._execution_coordinator.paused
|
||||||
paused = True
|
or self._execution_coordinator.execution_complete
|
||||||
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
self._execution_coordinator.check_scaling()
|
self._execution_coordinator.check_scaling()
|
||||||
@@ -101,10 +101,13 @@ class Dispatcher:
|
|||||||
time.sleep(0.1)
|
time.sleep(0.1)
|
||||||
|
|
||||||
self._process_commands()
|
self._process_commands()
|
||||||
if paused:
|
while True:
|
||||||
self._drain_events_until_idle()
|
try:
|
||||||
else:
|
event = self._event_queue.get(block=False)
|
||||||
self._drain_event_queue()
|
self._event_handler.dispatch(event)
|
||||||
|
self._event_queue.task_done()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Dispatcher error")
|
logger.exception("Dispatcher error")
|
||||||
@@ -119,24 +122,3 @@ class Dispatcher:
|
|||||||
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
||||||
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
||||||
self._execution_coordinator.process_commands()
|
self._execution_coordinator.process_commands()
|
||||||
|
|
||||||
def _drain_event_queue(self) -> None:
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
event = self._event_queue.get(block=False)
|
|
||||||
self._event_handler.dispatch(event)
|
|
||||||
self._event_queue.task_done()
|
|
||||||
except queue.Empty:
|
|
||||||
break
|
|
||||||
|
|
||||||
def _drain_events_until_idle(self) -> None:
|
|
||||||
while not self._stop_event.is_set():
|
|
||||||
try:
|
|
||||||
event = self._event_queue.get(timeout=0.1)
|
|
||||||
self._event_handler.dispatch(event)
|
|
||||||
self._event_queue.task_done()
|
|
||||||
self._process_commands(event)
|
|
||||||
except queue.Empty:
|
|
||||||
if not self._execution_coordinator.has_executing_nodes():
|
|
||||||
break
|
|
||||||
self._drain_event_queue()
|
|
||||||
|
|||||||
@@ -94,11 +94,3 @@ class ExecutionCoordinator:
|
|||||||
|
|
||||||
self._worker_pool.stop()
|
self._worker_pool.stop()
|
||||||
self._state_manager.clear_executing()
|
self._state_manager.clear_executing()
|
||||||
|
|
||||||
def has_executing_nodes(self) -> bool:
|
|
||||||
"""Return True if any nodes are currently marked as executing."""
|
|
||||||
# This check is only safe once execution has already paused.
|
|
||||||
# Before pause, executing state can change concurrently, which makes the result unreliable.
|
|
||||||
if not self._graph_execution.is_paused:
|
|
||||||
raise AssertionError("has_executing_nodes should only be called after execution is paused")
|
|
||||||
return self._state_manager.get_executing_count() > 0
|
|
||||||
|
|||||||
@@ -38,8 +38,6 @@ from .loop import (
|
|||||||
from .node import (
|
from .node import (
|
||||||
NodeRunExceptionEvent,
|
NodeRunExceptionEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunHumanInputFormFilledEvent,
|
|
||||||
NodeRunHumanInputFormTimeoutEvent,
|
|
||||||
NodeRunPauseRequestedEvent,
|
NodeRunPauseRequestedEvent,
|
||||||
NodeRunRetrieverResourceEvent,
|
NodeRunRetrieverResourceEvent,
|
||||||
NodeRunRetryEvent,
|
NodeRunRetryEvent,
|
||||||
@@ -62,8 +60,6 @@ __all__ = [
|
|||||||
"NodeRunAgentLogEvent",
|
"NodeRunAgentLogEvent",
|
||||||
"NodeRunExceptionEvent",
|
"NodeRunExceptionEvent",
|
||||||
"NodeRunFailedEvent",
|
"NodeRunFailedEvent",
|
||||||
"NodeRunHumanInputFormFilledEvent",
|
|
||||||
"NodeRunHumanInputFormTimeoutEvent",
|
|
||||||
"NodeRunIterationFailedEvent",
|
"NodeRunIterationFailedEvent",
|
||||||
"NodeRunIterationNextEvent",
|
"NodeRunIterationNextEvent",
|
||||||
"NodeRunIterationStartedEvent",
|
"NodeRunIterationStartedEvent",
|
||||||
|
|||||||
@@ -1,16 +1,11 @@
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
|
||||||
from core.workflow.entities.pause_reason import PauseReason
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
|
||||||
from core.workflow.graph_events import BaseGraphEvent
|
from core.workflow.graph_events import BaseGraphEvent
|
||||||
|
|
||||||
|
|
||||||
class GraphRunStartedEvent(BaseGraphEvent):
|
class GraphRunStartedEvent(BaseGraphEvent):
|
||||||
# Reason is emitted for workflow start events and is always set.
|
pass
|
||||||
reason: WorkflowStartReason = Field(
|
|
||||||
default=WorkflowStartReason.INITIAL,
|
|
||||||
description="reason for workflow start",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||||
|
|||||||
@@ -54,22 +54,6 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
|
|||||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||||
|
|
||||||
|
|
||||||
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
|
|
||||||
"""Emitted when a HumanInput form is submitted and before the node finishes."""
|
|
||||||
|
|
||||||
node_title: str = Field(..., description="HumanInput node title")
|
|
||||||
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
|
|
||||||
action_id: str = Field(..., description="User action identifier chosen in the form.")
|
|
||||||
action_text: str = Field(..., description="Display text of the chosen action button.")
|
|
||||||
|
|
||||||
|
|
||||||
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
|
|
||||||
"""Emitted when a HumanInput form times out."""
|
|
||||||
|
|
||||||
node_title: str = Field(..., description="HumanInput node title")
|
|
||||||
expiration_time: datetime = Field(..., description="Form expiration time")
|
|
||||||
|
|
||||||
|
|
||||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||||
reason: PauseReason = Field(..., description="pause reason")
|
reason: PauseReason = Field(..., description="pause reason")
|
||||||
|
|
||||||
|
|||||||
@@ -13,8 +13,6 @@ from .loop import (
|
|||||||
LoopSucceededEvent,
|
LoopSucceededEvent,
|
||||||
)
|
)
|
||||||
from .node import (
|
from .node import (
|
||||||
HumanInputFormFilledEvent,
|
|
||||||
HumanInputFormTimeoutEvent,
|
|
||||||
ModelInvokeCompletedEvent,
|
ModelInvokeCompletedEvent,
|
||||||
PauseRequestedEvent,
|
PauseRequestedEvent,
|
||||||
RunRetrieverResourceEvent,
|
RunRetrieverResourceEvent,
|
||||||
@@ -25,8 +23,6 @@ from .node import (
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AgentLogEvent",
|
"AgentLogEvent",
|
||||||
"HumanInputFormFilledEvent",
|
|
||||||
"HumanInputFormTimeoutEvent",
|
|
||||||
"IterationFailedEvent",
|
"IterationFailedEvent",
|
||||||
"IterationNextEvent",
|
"IterationNextEvent",
|
||||||
"IterationStartedEvent",
|
"IterationStartedEvent",
|
||||||
|
|||||||
@@ -47,19 +47,3 @@ class StreamCompletedEvent(NodeEventBase):
|
|||||||
|
|
||||||
class PauseRequestedEvent(NodeEventBase):
|
class PauseRequestedEvent(NodeEventBase):
|
||||||
reason: PauseReason = Field(..., description="pause reason")
|
reason: PauseReason = Field(..., description="pause reason")
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormFilledEvent(NodeEventBase):
|
|
||||||
"""Event emitted when a human input form is submitted."""
|
|
||||||
|
|
||||||
node_title: str
|
|
||||||
rendered_content: str
|
|
||||||
action_id: str
|
|
||||||
action_text: str
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormTimeoutEvent(NodeEventBase):
|
|
||||||
"""Event emitted when a human input form times out."""
|
|
||||||
|
|
||||||
node_title: str
|
|
||||||
expiration_time: datetime
|
|
||||||
|
|||||||
@@ -18,8 +18,6 @@ from core.workflow.graph_events import (
|
|||||||
GraphNodeEventBase,
|
GraphNodeEventBase,
|
||||||
NodeRunAgentLogEvent,
|
NodeRunAgentLogEvent,
|
||||||
NodeRunFailedEvent,
|
NodeRunFailedEvent,
|
||||||
NodeRunHumanInputFormFilledEvent,
|
|
||||||
NodeRunHumanInputFormTimeoutEvent,
|
|
||||||
NodeRunIterationFailedEvent,
|
NodeRunIterationFailedEvent,
|
||||||
NodeRunIterationNextEvent,
|
NodeRunIterationNextEvent,
|
||||||
NodeRunIterationStartedEvent,
|
NodeRunIterationStartedEvent,
|
||||||
@@ -36,8 +34,6 @@ from core.workflow.graph_events import (
|
|||||||
)
|
)
|
||||||
from core.workflow.node_events import (
|
from core.workflow.node_events import (
|
||||||
AgentLogEvent,
|
AgentLogEvent,
|
||||||
HumanInputFormFilledEvent,
|
|
||||||
HumanInputFormTimeoutEvent,
|
|
||||||
IterationFailedEvent,
|
IterationFailedEvent,
|
||||||
IterationNextEvent,
|
IterationNextEvent,
|
||||||
IterationStartedEvent,
|
IterationStartedEvent,
|
||||||
@@ -65,15 +61,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Node(Generic[NodeDataT]):
|
class Node(Generic[NodeDataT]):
|
||||||
"""BaseNode serves as the foundational class for all node implementations.
|
|
||||||
|
|
||||||
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
|
|
||||||
attribute to track files generated by the LLM). However, these states are not persisted
|
|
||||||
when the workflow is suspended or resumed. If a node needs its state to be preserved
|
|
||||||
across workflow suspension and resumption, it should include the relevant state data
|
|
||||||
in its output.
|
|
||||||
"""
|
|
||||||
|
|
||||||
node_type: ClassVar[NodeType]
|
node_type: ClassVar[NodeType]
|
||||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||||
@@ -264,33 +251,10 @@ class Node(Generic[NodeDataT]):
|
|||||||
return self._node_execution_id
|
return self._node_execution_id
|
||||||
|
|
||||||
def ensure_execution_id(self) -> str:
|
def ensure_execution_id(self) -> str:
|
||||||
if self._node_execution_id:
|
if not self._node_execution_id:
|
||||||
return self._node_execution_id
|
self._node_execution_id = str(uuid4())
|
||||||
|
|
||||||
resumed_execution_id = self._restore_execution_id_from_runtime_state()
|
|
||||||
if resumed_execution_id:
|
|
||||||
self._node_execution_id = resumed_execution_id
|
|
||||||
return self._node_execution_id
|
|
||||||
|
|
||||||
self._node_execution_id = str(uuid4())
|
|
||||||
return self._node_execution_id
|
return self._node_execution_id
|
||||||
|
|
||||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
|
||||||
graph_execution = self.graph_runtime_state.graph_execution
|
|
||||||
try:
|
|
||||||
node_executions = graph_execution.node_executions
|
|
||||||
except AttributeError:
|
|
||||||
return None
|
|
||||||
if not isinstance(node_executions, dict):
|
|
||||||
return None
|
|
||||||
node_execution = node_executions.get(self._node_id)
|
|
||||||
if node_execution is None:
|
|
||||||
return None
|
|
||||||
execution_id = node_execution.execution_id
|
|
||||||
if not execution_id:
|
|
||||||
return None
|
|
||||||
return str(execution_id)
|
|
||||||
|
|
||||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||||
|
|
||||||
@@ -656,28 +620,6 @@ class Node(Generic[NodeDataT]):
|
|||||||
metadata=event.metadata,
|
metadata=event.metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
@_dispatch.register
|
|
||||||
def _(self, event: HumanInputFormFilledEvent):
|
|
||||||
return NodeRunHumanInputFormFilledEvent(
|
|
||||||
id=self.execution_id,
|
|
||||||
node_id=self._node_id,
|
|
||||||
node_type=self.node_type,
|
|
||||||
node_title=event.node_title,
|
|
||||||
rendered_content=event.rendered_content,
|
|
||||||
action_id=event.action_id,
|
|
||||||
action_text=event.action_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
@_dispatch.register
|
|
||||||
def _(self, event: HumanInputFormTimeoutEvent):
|
|
||||||
return NodeRunHumanInputFormTimeoutEvent(
|
|
||||||
id=self.execution_id,
|
|
||||||
node_id=self._node_id,
|
|
||||||
node_type=self.node_type,
|
|
||||||
node_title=event.node_title,
|
|
||||||
expiration_time=event.expiration_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
@_dispatch.register
|
@_dispatch.register
|
||||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||||
return NodeRunLoopStartedEvent(
|
return NodeRunLoopStartedEvent(
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
"""
|
from .human_input_node import HumanInputNode
|
||||||
Human Input node implementation.
|
|
||||||
"""
|
__all__ = ["HumanInputNode"]
|
||||||
|
|||||||
@@ -1,350 +1,10 @@
|
|||||||
"""
|
from pydantic import Field
|
||||||
Human Input node entities.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
import uuid
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from datetime import datetime, timedelta
|
|
||||||
from typing import Annotated, Any, ClassVar, Literal, Self
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
|
||||||
|
|
||||||
from core.variables.consts import SELECTORS_LENGTH
|
|
||||||
from core.workflow.nodes.base import BaseNodeData
|
from core.workflow.nodes.base import BaseNodeData
|
||||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
|
||||||
from core.workflow.runtime import VariablePool
|
|
||||||
|
|
||||||
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
|
|
||||||
|
|
||||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
|
||||||
|
|
||||||
|
|
||||||
class _WebAppDeliveryConfig(BaseModel):
|
|
||||||
"""Configuration for webapp delivery method."""
|
|
||||||
|
|
||||||
pass # Empty for webapp delivery
|
|
||||||
|
|
||||||
|
|
||||||
class MemberRecipient(BaseModel):
|
|
||||||
"""Member recipient for email delivery."""
|
|
||||||
|
|
||||||
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
|
|
||||||
class ExternalRecipient(BaseModel):
|
|
||||||
"""External recipient for email delivery."""
|
|
||||||
|
|
||||||
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
|
|
||||||
email: str
|
|
||||||
|
|
||||||
|
|
||||||
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
|
|
||||||
|
|
||||||
|
|
||||||
class EmailRecipients(BaseModel):
|
|
||||||
"""Email recipients configuration."""
|
|
||||||
|
|
||||||
# When true, recipients are the union of all workspace members and external items.
|
|
||||||
# Member items are ignored because they are already covered by the workspace scope.
|
|
||||||
# De-duplication is applied by email, with member recipients taking precedence.
|
|
||||||
whole_workspace: bool = False
|
|
||||||
items: list[EmailRecipient] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class EmailDeliveryConfig(BaseModel):
|
|
||||||
"""Configuration for email delivery method."""
|
|
||||||
|
|
||||||
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
|
|
||||||
|
|
||||||
recipients: EmailRecipients
|
|
||||||
|
|
||||||
# the subject of email
|
|
||||||
subject: str
|
|
||||||
|
|
||||||
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
|
|
||||||
# represent the url to submit the form.
|
|
||||||
#
|
|
||||||
# It may also reference the output variable of the previous node with the syntax
|
|
||||||
# `{{#<node_id>.<field_name>#}}`.
|
|
||||||
body: str
|
|
||||||
debug_mode: bool = False
|
|
||||||
|
|
||||||
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
|
|
||||||
if not user_id:
|
|
||||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
|
|
||||||
return self.model_copy(update={"recipients": debug_recipients})
|
|
||||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
|
|
||||||
return self.model_copy(update={"recipients": debug_recipients})
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
|
|
||||||
"""Replace the url placeholder with provided value."""
|
|
||||||
return body.replace(cls.URL_PLACEHOLDER, url or "")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def render_body_template(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
body: str,
|
|
||||||
url: str | None,
|
|
||||||
variable_pool: VariablePool | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Render email body by replacing placeholders with runtime values."""
|
|
||||||
templated_body = cls.replace_url_placeholder(body, url)
|
|
||||||
if variable_pool is None:
|
|
||||||
return templated_body
|
|
||||||
return variable_pool.convert_template(templated_body).text
|
|
||||||
|
|
||||||
|
|
||||||
class _DeliveryMethodBase(BaseModel):
|
|
||||||
"""Base delivery method configuration."""
|
|
||||||
|
|
||||||
enabled: bool = True
|
|
||||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
|
||||||
|
|
||||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
|
||||||
return ()
|
|
||||||
|
|
||||||
|
|
||||||
class WebAppDeliveryMethod(_DeliveryMethodBase):
|
|
||||||
"""Webapp delivery method configuration."""
|
|
||||||
|
|
||||||
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
|
|
||||||
# The config field is not used currently.
|
|
||||||
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
|
|
||||||
|
|
||||||
|
|
||||||
class EmailDeliveryMethod(_DeliveryMethodBase):
|
|
||||||
"""Email delivery method configuration."""
|
|
||||||
|
|
||||||
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
|
|
||||||
config: EmailDeliveryConfig
|
|
||||||
|
|
||||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
|
||||||
variable_template_parser = VariableTemplateParser(template=self.config.body)
|
|
||||||
selectors: list[Sequence[str]] = []
|
|
||||||
for variable_selector in variable_template_parser.extract_variable_selectors():
|
|
||||||
value_selector = list(variable_selector.value_selector)
|
|
||||||
if len(value_selector) < SELECTORS_LENGTH:
|
|
||||||
continue
|
|
||||||
selectors.append(value_selector[:SELECTORS_LENGTH])
|
|
||||||
return selectors
|
|
||||||
|
|
||||||
|
|
||||||
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
|
|
||||||
|
|
||||||
|
|
||||||
def apply_debug_email_recipient(
|
|
||||||
method: DeliveryChannelConfig,
|
|
||||||
*,
|
|
||||||
enabled: bool,
|
|
||||||
user_id: str,
|
|
||||||
) -> DeliveryChannelConfig:
|
|
||||||
if not enabled:
|
|
||||||
return method
|
|
||||||
if not isinstance(method, EmailDeliveryMethod):
|
|
||||||
return method
|
|
||||||
if not method.config.debug_mode:
|
|
||||||
return method
|
|
||||||
debug_config = method.config.with_debug_recipient(user_id or "")
|
|
||||||
return method.model_copy(update={"config": debug_config})
|
|
||||||
|
|
||||||
|
|
||||||
class FormInputDefault(BaseModel):
|
|
||||||
"""Default configuration for form inputs."""
|
|
||||||
|
|
||||||
# NOTE: Ideally, a discriminated union would be used to model
|
|
||||||
# FormInputDefault. However, the UI requires preserving the previous
|
|
||||||
# value when switching between `VARIABLE` and `CONSTANT` types. This
|
|
||||||
# necessitates retaining all fields, making a discriminated union unsuitable.
|
|
||||||
|
|
||||||
type: PlaceholderType
|
|
||||||
|
|
||||||
# The selector of default variable, used when `type` is `VARIABLE`.
|
|
||||||
selector: Sequence[str] = Field(default_factory=tuple) #
|
|
||||||
|
|
||||||
# The value of the default, used when `type` is `CONSTANT`.
|
|
||||||
# TODO: How should we express JSON values?
|
|
||||||
value: str = ""
|
|
||||||
|
|
||||||
@model_validator(mode="after")
|
|
||||||
def _validate_selector(self) -> Self:
|
|
||||||
if self.type == PlaceholderType.CONSTANT:
|
|
||||||
return self
|
|
||||||
if len(self.selector) < SELECTORS_LENGTH:
|
|
||||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class FormInput(BaseModel):
|
|
||||||
"""Form input definition."""
|
|
||||||
|
|
||||||
type: FormInputType
|
|
||||||
output_variable_name: str
|
|
||||||
default: FormInputDefault | None = None
|
|
||||||
|
|
||||||
|
|
||||||
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
|
||||||
|
|
||||||
|
|
||||||
class UserAction(BaseModel):
|
|
||||||
"""User action configuration."""
|
|
||||||
|
|
||||||
# id is the identifier for this action.
|
|
||||||
# It also serves as the identifiers of output handle.
|
|
||||||
#
|
|
||||||
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
|
|
||||||
id: str = Field(max_length=20)
|
|
||||||
title: str = Field(max_length=20)
|
|
||||||
button_style: ButtonStyle = ButtonStyle.DEFAULT
|
|
||||||
|
|
||||||
@field_validator("id")
|
|
||||||
@classmethod
|
|
||||||
def _validate_id(cls, value: str) -> str:
|
|
||||||
if not _IDENTIFIER_PATTERN.match(value):
|
|
||||||
raise ValueError(
|
|
||||||
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
|
|
||||||
f"and contain only letters, numbers, or underscores."
|
|
||||||
)
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputNodeData(BaseNodeData):
|
class HumanInputNodeData(BaseNodeData):
|
||||||
"""Human Input node data."""
|
"""Configuration schema for the HumanInput node."""
|
||||||
|
|
||||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
required_variables: list[str] = Field(default_factory=list)
|
||||||
form_content: str = ""
|
pause_reason: str | None = Field(default=None)
|
||||||
inputs: list[FormInput] = Field(default_factory=list)
|
|
||||||
user_actions: list[UserAction] = Field(default_factory=list)
|
|
||||||
timeout: int = 36
|
|
||||||
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
|
|
||||||
|
|
||||||
@field_validator("inputs")
|
|
||||||
@classmethod
|
|
||||||
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
|
|
||||||
seen_names: set[str] = set()
|
|
||||||
for form_input in inputs:
|
|
||||||
name = form_input.output_variable_name
|
|
||||||
if name in seen_names:
|
|
||||||
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
|
|
||||||
seen_names.add(name)
|
|
||||||
return inputs
|
|
||||||
|
|
||||||
@field_validator("user_actions")
|
|
||||||
@classmethod
|
|
||||||
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
|
|
||||||
seen_ids: set[str] = set()
|
|
||||||
for action in user_actions:
|
|
||||||
action_id = action.id
|
|
||||||
if action_id in seen_ids:
|
|
||||||
raise ValueError(f"duplicated user action id '{action_id}'")
|
|
||||||
seen_ids.add(action_id)
|
|
||||||
return user_actions
|
|
||||||
|
|
||||||
def is_webapp_enabled(self) -> bool:
|
|
||||||
for dm in self.delivery_methods:
|
|
||||||
if not dm.enabled:
|
|
||||||
continue
|
|
||||||
if dm.type == DeliveryMethodType.WEBAPP:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def expiration_time(self, start_time: datetime) -> datetime:
|
|
||||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
|
||||||
return start_time + timedelta(hours=self.timeout)
|
|
||||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
|
||||||
return start_time + timedelta(days=self.timeout)
|
|
||||||
else:
|
|
||||||
raise AssertionError("unknown timeout unit.")
|
|
||||||
|
|
||||||
def outputs_field_names(self) -> Sequence[str]:
|
|
||||||
field_names = []
|
|
||||||
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
|
|
||||||
field_names.append(match.group("field_name"))
|
|
||||||
return field_names
|
|
||||||
|
|
||||||
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
|
|
||||||
variable_mappings: dict[str, Sequence[str]] = {}
|
|
||||||
|
|
||||||
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
|
|
||||||
for selector in selectors:
|
|
||||||
if len(selector) < SELECTORS_LENGTH:
|
|
||||||
continue
|
|
||||||
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
|
|
||||||
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
|
|
||||||
|
|
||||||
form_template_parser = VariableTemplateParser(template=self.form_content)
|
|
||||||
_add_variable_selectors(
|
|
||||||
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
|
|
||||||
)
|
|
||||||
for delivery_method in self.delivery_methods:
|
|
||||||
if not delivery_method.enabled:
|
|
||||||
continue
|
|
||||||
_add_variable_selectors(delivery_method.extract_variable_selectors())
|
|
||||||
|
|
||||||
for input in self.inputs:
|
|
||||||
default_value = input.default
|
|
||||||
if default_value is None:
|
|
||||||
continue
|
|
||||||
if default_value.type == PlaceholderType.CONSTANT:
|
|
||||||
continue
|
|
||||||
default_value_key = ".".join(default_value.selector)
|
|
||||||
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
|
|
||||||
variable_mappings[qualified_variable_mapping_key] = default_value.selector
|
|
||||||
|
|
||||||
return variable_mappings
|
|
||||||
|
|
||||||
def find_action_text(self, action_id: str) -> str:
|
|
||||||
"""
|
|
||||||
Resolve action display text by id.
|
|
||||||
"""
|
|
||||||
for action in self.user_actions:
|
|
||||||
if action.id == action_id:
|
|
||||||
return action.title
|
|
||||||
return action_id
|
|
||||||
|
|
||||||
|
|
||||||
class FormDefinition(BaseModel):
|
|
||||||
form_content: str
|
|
||||||
inputs: list[FormInput] = Field(default_factory=list)
|
|
||||||
user_actions: list[UserAction] = Field(default_factory=list)
|
|
||||||
rendered_content: str
|
|
||||||
expiration_time: datetime
|
|
||||||
|
|
||||||
# this is used to store the resolved default values
|
|
||||||
default_values: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
# node_title records the title of the HumanInput node.
|
|
||||||
node_title: str | None = None
|
|
||||||
|
|
||||||
# display_in_ui controls whether the form should be displayed in UI surfaces.
|
|
||||||
display_in_ui: bool | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputSubmissionValidationError(ValueError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def validate_human_input_submission(
|
|
||||||
*,
|
|
||||||
inputs: Sequence[FormInput],
|
|
||||||
user_actions: Sequence[UserAction],
|
|
||||||
selected_action_id: str,
|
|
||||||
form_data: Mapping[str, Any],
|
|
||||||
) -> None:
|
|
||||||
available_actions = {action.id for action in user_actions}
|
|
||||||
if selected_action_id not in available_actions:
|
|
||||||
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
|
|
||||||
|
|
||||||
provided_inputs = set(form_data.keys())
|
|
||||||
missing_inputs = [
|
|
||||||
form_input.output_variable_name
|
|
||||||
for form_input in inputs
|
|
||||||
if form_input.output_variable_name not in provided_inputs
|
|
||||||
]
|
|
||||||
|
|
||||||
if missing_inputs:
|
|
||||||
missing_list = ", ".join(missing_inputs)
|
|
||||||
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")
|
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
import enum
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormStatus(enum.StrEnum):
|
|
||||||
"""Status of a human input form."""
|
|
||||||
|
|
||||||
# Awaiting submission from any recipient. Forms stay in this state until
|
|
||||||
# submitted or a timeout rule applies.
|
|
||||||
WAITING = enum.auto()
|
|
||||||
# Global timeout reached. The workflow run is stopped and will not resume.
|
|
||||||
# This is distinct from node-level timeout.
|
|
||||||
EXPIRED = enum.auto()
|
|
||||||
# Submitted by a recipient; form data is available and execution resumes
|
|
||||||
# along the selected action edge.
|
|
||||||
SUBMITTED = enum.auto()
|
|
||||||
# Node-level timeout reached. The human input node should emit a timeout
|
|
||||||
# event and the workflow should resume along the timeout edge.
|
|
||||||
TIMEOUT = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormKind(enum.StrEnum):
|
|
||||||
"""Kind of a human input form."""
|
|
||||||
|
|
||||||
RUNTIME = enum.auto() # Form created during workflow execution.
|
|
||||||
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
|
|
||||||
|
|
||||||
|
|
||||||
class DeliveryMethodType(enum.StrEnum):
|
|
||||||
"""Delivery method types for human input forms."""
|
|
||||||
|
|
||||||
# WEBAPP controls whether the form is delivered to the web app. It not only controls
|
|
||||||
# the standalone web app, but also controls the installed apps in the console.
|
|
||||||
WEBAPP = enum.auto()
|
|
||||||
|
|
||||||
EMAIL = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class ButtonStyle(enum.StrEnum):
|
|
||||||
"""Button styles for user actions."""
|
|
||||||
|
|
||||||
PRIMARY = enum.auto()
|
|
||||||
DEFAULT = enum.auto()
|
|
||||||
ACCENT = enum.auto()
|
|
||||||
GHOST = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class TimeoutUnit(enum.StrEnum):
|
|
||||||
"""Timeout unit for form expiration."""
|
|
||||||
|
|
||||||
HOUR = enum.auto()
|
|
||||||
DAY = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class FormInputType(enum.StrEnum):
|
|
||||||
"""Form input types."""
|
|
||||||
|
|
||||||
TEXT_INPUT = enum.auto()
|
|
||||||
PARAGRAPH = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class PlaceholderType(enum.StrEnum):
|
|
||||||
"""Default value types for form inputs."""
|
|
||||||
|
|
||||||
VARIABLE = enum.auto()
|
|
||||||
CONSTANT = enum.auto()
|
|
||||||
|
|
||||||
|
|
||||||
class EmailRecipientType(enum.StrEnum):
|
|
||||||
"""Email recipient types."""
|
|
||||||
|
|
||||||
MEMBER = enum.auto()
|
|
||||||
EXTERNAL = enum.auto()
|
|
||||||
@@ -1,42 +1,12 @@
|
|||||||
import json
|
from collections.abc import Mapping
|
||||||
import logging
|
from typing import Any
|
||||||
from collections.abc import Generator, Mapping, Sequence
|
|
||||||
from typing import TYPE_CHECKING, Any
|
|
||||||
|
|
||||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
|
||||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||||
from core.workflow.node_events import (
|
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||||
HumanInputFormFilledEvent,
|
|
||||||
HumanInputFormTimeoutEvent,
|
|
||||||
NodeRunResult,
|
|
||||||
PauseRequestedEvent,
|
|
||||||
)
|
|
||||||
from core.workflow.node_events.base import NodeEventBase
|
|
||||||
from core.workflow.node_events.node import StreamCompletedEvent
|
|
||||||
from core.workflow.nodes.base.node import Node
|
from core.workflow.nodes.base.node import Node
|
||||||
from core.workflow.repositories.human_input_form_repository import (
|
|
||||||
FormCreateParams,
|
|
||||||
HumanInputFormEntity,
|
|
||||||
HumanInputFormRepository,
|
|
||||||
)
|
|
||||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
|
||||||
from extensions.ext_database import db
|
|
||||||
from libs.datetime_utils import naive_utc_now
|
|
||||||
|
|
||||||
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
|
from .entities import HumanInputNodeData
|
||||||
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
|
||||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
|
||||||
|
|
||||||
|
|
||||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputNode(Node[HumanInputNodeData]):
|
class HumanInputNode(Node[HumanInputNodeData]):
|
||||||
@@ -47,7 +17,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||||||
"edge_source_handle",
|
"edge_source_handle",
|
||||||
"edgeSourceHandle",
|
"edgeSourceHandle",
|
||||||
"source_handle",
|
"source_handle",
|
||||||
_SELECTED_BRANCH_KEY,
|
"selected_branch",
|
||||||
"selectedBranch",
|
"selectedBranch",
|
||||||
"branch",
|
"branch",
|
||||||
"branch_id",
|
"branch_id",
|
||||||
@@ -55,37 +25,43 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||||||
"handle",
|
"handle",
|
||||||
)
|
)
|
||||||
|
|
||||||
_node_data: HumanInputNodeData
|
|
||||||
_form_repository: HumanInputFormRepository
|
|
||||||
_OUTPUT_FIELD_ACTION_ID = "__action_id"
|
|
||||||
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
|
|
||||||
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
id: str,
|
|
||||||
config: Mapping[str, Any],
|
|
||||||
graph_init_params: "GraphInitParams",
|
|
||||||
graph_runtime_state: "GraphRuntimeState",
|
|
||||||
form_repository: HumanInputFormRepository | None = None,
|
|
||||||
) -> None:
|
|
||||||
super().__init__(
|
|
||||||
id=id,
|
|
||||||
config=config,
|
|
||||||
graph_init_params=graph_init_params,
|
|
||||||
graph_runtime_state=graph_runtime_state,
|
|
||||||
)
|
|
||||||
if form_repository is None:
|
|
||||||
form_repository = HumanInputFormRepositoryImpl(
|
|
||||||
session_factory=db.engine,
|
|
||||||
tenant_id=self.tenant_id,
|
|
||||||
)
|
|
||||||
self._form_repository = form_repository
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def version(cls) -> str:
|
def version(cls) -> str:
|
||||||
return "1"
|
return "1"
|
||||||
|
|
||||||
|
def _run(self): # type: ignore[override]
|
||||||
|
if self._is_completion_ready():
|
||||||
|
branch_handle = self._resolve_branch_selection()
|
||||||
|
return NodeRunResult(
|
||||||
|
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||||
|
outputs={},
|
||||||
|
edge_source_handle=branch_handle or "source",
|
||||||
|
)
|
||||||
|
|
||||||
|
return self._pause_generator()
|
||||||
|
|
||||||
|
def _pause_generator(self):
|
||||||
|
# TODO(QuantumGhost): yield a real form id.
|
||||||
|
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||||
|
|
||||||
|
def _is_completion_ready(self) -> bool:
|
||||||
|
"""Determine whether all required inputs are satisfied."""
|
||||||
|
|
||||||
|
if not self.node_data.required_variables:
|
||||||
|
return False
|
||||||
|
|
||||||
|
variable_pool = self.graph_runtime_state.variable_pool
|
||||||
|
|
||||||
|
for selector_str in self.node_data.required_variables:
|
||||||
|
parts = selector_str.split(".")
|
||||||
|
if len(parts) != 2:
|
||||||
|
return False
|
||||||
|
segment = variable_pool.get(parts)
|
||||||
|
if segment is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def _resolve_branch_selection(self) -> str | None:
|
def _resolve_branch_selection(self) -> str | None:
|
||||||
"""Determine the branch handle selected by human input if available."""
|
"""Determine the branch handle selected by human input if available."""
|
||||||
|
|
||||||
@@ -132,224 +108,3 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
|||||||
return candidate
|
return candidate
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@property
|
|
||||||
def _workflow_execution_id(self) -> str:
|
|
||||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
|
||||||
assert workflow_exec_id is not None
|
|
||||||
return workflow_exec_id
|
|
||||||
|
|
||||||
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
|
|
||||||
required_event = self._human_input_required_event(form_entity)
|
|
||||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
|
||||||
return pause_requested_event
|
|
||||||
|
|
||||||
def resolve_default_values(self) -> Mapping[str, Any]:
|
|
||||||
variable_pool = self.graph_runtime_state.variable_pool
|
|
||||||
resolved_defaults: dict[str, Any] = {}
|
|
||||||
for input in self._node_data.inputs:
|
|
||||||
if (default_value := input.default) is None:
|
|
||||||
continue
|
|
||||||
if default_value.type == PlaceholderType.CONSTANT:
|
|
||||||
continue
|
|
||||||
resolved_value = variable_pool.get(default_value.selector)
|
|
||||||
if resolved_value is None:
|
|
||||||
# TODO: How should we handle this?
|
|
||||||
continue
|
|
||||||
resolved_defaults[input.output_variable_name] = (
|
|
||||||
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
|
|
||||||
)
|
|
||||||
|
|
||||||
return resolved_defaults
|
|
||||||
|
|
||||||
def _should_require_console_recipient(self) -> bool:
|
|
||||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
|
||||||
return True
|
|
||||||
if self.invoke_from == InvokeFrom.EXPLORE:
|
|
||||||
return self._node_data.is_webapp_enabled()
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _display_in_ui(self) -> bool:
|
|
||||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
|
||||||
return True
|
|
||||||
return self._node_data.is_webapp_enabled()
|
|
||||||
|
|
||||||
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
|
|
||||||
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
|
|
||||||
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
|
||||||
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
|
|
||||||
return [
|
|
||||||
apply_debug_email_recipient(
|
|
||||||
method,
|
|
||||||
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
|
|
||||||
user_id=self.user_id or "",
|
|
||||||
)
|
|
||||||
for method in enabled_methods
|
|
||||||
]
|
|
||||||
|
|
||||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
|
||||||
node_data = self._node_data
|
|
||||||
resolved_default_values = self.resolve_default_values()
|
|
||||||
display_in_ui = self._display_in_ui()
|
|
||||||
form_token = form_entity.web_app_token
|
|
||||||
if display_in_ui and form_token is None:
|
|
||||||
raise AssertionError("Form token should be available for UI execution.")
|
|
||||||
return HumanInputRequired(
|
|
||||||
form_id=form_entity.id,
|
|
||||||
form_content=form_entity.rendered_content,
|
|
||||||
inputs=node_data.inputs,
|
|
||||||
actions=node_data.user_actions,
|
|
||||||
display_in_ui=display_in_ui,
|
|
||||||
node_id=self.id,
|
|
||||||
node_title=node_data.title,
|
|
||||||
form_token=form_token,
|
|
||||||
resolved_default_values=resolved_default_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
|
||||||
"""
|
|
||||||
Execute the human input node.
|
|
||||||
|
|
||||||
This method will:
|
|
||||||
1. Generate a unique form ID
|
|
||||||
2. Create form content with variable substitution
|
|
||||||
3. Create form in database
|
|
||||||
4. Send form via configured delivery methods
|
|
||||||
5. Suspend workflow execution
|
|
||||||
6. Wait for form submission to resume
|
|
||||||
"""
|
|
||||||
repo = self._form_repository
|
|
||||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
|
||||||
if form is None:
|
|
||||||
display_in_ui = self._display_in_ui()
|
|
||||||
params = FormCreateParams(
|
|
||||||
app_id=self.app_id,
|
|
||||||
workflow_execution_id=self._workflow_execution_id,
|
|
||||||
node_id=self.id,
|
|
||||||
form_config=self._node_data,
|
|
||||||
rendered_content=self.render_form_content_before_submission(),
|
|
||||||
delivery_methods=self._effective_delivery_methods(),
|
|
||||||
display_in_ui=display_in_ui,
|
|
||||||
resolved_default_values=self.resolve_default_values(),
|
|
||||||
console_recipient_required=self._should_require_console_recipient(),
|
|
||||||
console_creator_account_id=(
|
|
||||||
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
|
|
||||||
),
|
|
||||||
backstage_recipient_required=True,
|
|
||||||
)
|
|
||||||
form_entity = self._form_repository.create_form(params)
|
|
||||||
# Create human input required event
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
|
||||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
|
||||||
self.id,
|
|
||||||
form_entity.id,
|
|
||||||
)
|
|
||||||
yield self._form_to_pause_event(form_entity)
|
|
||||||
return
|
|
||||||
|
|
||||||
if (
|
|
||||||
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
|
|
||||||
or form.expiration_time <= naive_utc_now()
|
|
||||||
):
|
|
||||||
yield HumanInputFormTimeoutEvent(
|
|
||||||
node_title=self._node_data.title,
|
|
||||||
expiration_time=form.expiration_time,
|
|
||||||
)
|
|
||||||
yield StreamCompletedEvent(
|
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
||||||
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
|
|
||||||
edge_source_handle=self._TIMEOUT_HANDLE,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if not form.submitted:
|
|
||||||
yield self._form_to_pause_event(form)
|
|
||||||
return
|
|
||||||
|
|
||||||
selected_action_id = form.selected_action_id
|
|
||||||
if selected_action_id is None:
|
|
||||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
|
||||||
submitted_data = form.submitted_data or {}
|
|
||||||
outputs: dict[str, Any] = dict(submitted_data)
|
|
||||||
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
|
|
||||||
rendered_content = self.render_form_content_with_outputs(
|
|
||||||
form.rendered_content,
|
|
||||||
outputs,
|
|
||||||
self._node_data.outputs_field_names(),
|
|
||||||
)
|
|
||||||
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
|
|
||||||
|
|
||||||
action_text = self._node_data.find_action_text(selected_action_id)
|
|
||||||
|
|
||||||
yield HumanInputFormFilledEvent(
|
|
||||||
node_title=self._node_data.title,
|
|
||||||
rendered_content=rendered_content,
|
|
||||||
action_id=selected_action_id,
|
|
||||||
action_text=action_text,
|
|
||||||
)
|
|
||||||
|
|
||||||
yield StreamCompletedEvent(
|
|
||||||
node_run_result=NodeRunResult(
|
|
||||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
|
||||||
outputs=outputs,
|
|
||||||
edge_source_handle=selected_action_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def render_form_content_before_submission(self) -> str:
|
|
||||||
"""
|
|
||||||
Process form content by substituting variables.
|
|
||||||
|
|
||||||
This method should:
|
|
||||||
1. Parse the form_content markdown
|
|
||||||
2. Substitute {{#node_name.var_name#}} with actual values
|
|
||||||
3. Keep {{#$output.field_name#}} placeholders for form inputs
|
|
||||||
"""
|
|
||||||
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
|
|
||||||
self._node_data.form_content,
|
|
||||||
)
|
|
||||||
return rendered_form_content.markdown
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def render_form_content_with_outputs(
|
|
||||||
form_content: str,
|
|
||||||
outputs: Mapping[str, Any],
|
|
||||||
field_names: Sequence[str],
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Replace {{#$output.xxx#}} placeholders with submitted values.
|
|
||||||
"""
|
|
||||||
rendered_content = form_content
|
|
||||||
for field_name in field_names:
|
|
||||||
placeholder = "{{#$output." + field_name + "#}}"
|
|
||||||
value = outputs.get(field_name)
|
|
||||||
if value is None:
|
|
||||||
replacement = ""
|
|
||||||
elif isinstance(value, (dict, list)):
|
|
||||||
replacement = json.dumps(value, ensure_ascii=False)
|
|
||||||
else:
|
|
||||||
replacement = str(value)
|
|
||||||
rendered_content = rendered_content.replace(placeholder, replacement)
|
|
||||||
return rendered_content
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _extract_variable_selector_to_variable_mapping(
|
|
||||||
cls,
|
|
||||||
*,
|
|
||||||
graph_config: Mapping[str, Any],
|
|
||||||
node_id: str,
|
|
||||||
node_data: Mapping[str, Any],
|
|
||||||
) -> Mapping[str, Sequence[str]]:
|
|
||||||
"""
|
|
||||||
Extract variable selectors referenced in form content and input default values.
|
|
||||||
|
|
||||||
This method should parse:
|
|
||||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
|
||||||
2. Variables referenced in input default values
|
|
||||||
"""
|
|
||||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
|
||||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
|
||||||
|
|||||||
@@ -1,152 +0,0 @@
|
|||||||
import abc
|
|
||||||
import dataclasses
|
|
||||||
from collections.abc import Mapping, Sequence
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any, Protocol
|
|
||||||
|
|
||||||
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
|
|
||||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputError(Exception):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class FormNotFoundError(HumanInputError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
|
||||||
class FormCreateParams:
|
|
||||||
# app_id is the identifier for the app that the form belongs to.
|
|
||||||
# It is a string with uuid format.
|
|
||||||
app_id: str
|
|
||||||
# None when creating a delivery test form; set for runtime forms.
|
|
||||||
workflow_execution_id: str | None
|
|
||||||
|
|
||||||
# node_id is the identifier for a specific
|
|
||||||
# node in the graph.
|
|
||||||
#
|
|
||||||
# TODO: for node inside loop / iteration, this would
|
|
||||||
# cause problems, as a single node may be executed multiple times.
|
|
||||||
node_id: str
|
|
||||||
|
|
||||||
form_config: HumanInputNodeData
|
|
||||||
rendered_content: str
|
|
||||||
# Delivery methods already filtered by runtime context (invoke_from).
|
|
||||||
delivery_methods: Sequence[DeliveryChannelConfig]
|
|
||||||
# UI display flag computed by runtime context.
|
|
||||||
display_in_ui: bool
|
|
||||||
|
|
||||||
# resolved_default_values saves the values for defaults with
|
|
||||||
# type = VARIABLE.
|
|
||||||
#
|
|
||||||
# For type = CONSTANT, the value is not stored inside `resolved_default_values`
|
|
||||||
resolved_default_values: Mapping[str, Any]
|
|
||||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
|
||||||
|
|
||||||
# Force creating a console-only recipient for submission in Console.
|
|
||||||
console_recipient_required: bool = False
|
|
||||||
console_creator_account_id: str | None = None
|
|
||||||
# Force creating a backstage recipient for submission in Console.
|
|
||||||
backstage_recipient_required: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormEntity(abc.ABC):
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def id(self) -> str:
|
|
||||||
"""id returns the identifer of the form."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def web_app_token(self) -> str | None:
|
|
||||||
"""web_app_token returns the token for submission inside webapp.
|
|
||||||
|
|
||||||
For console/debug execution, this may point to the console submission token
|
|
||||||
if the form is configured to require console delivery.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# TODO: what if the users are allowed to add multiple
|
|
||||||
# webapp delivery?
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def rendered_content(self) -> str:
|
|
||||||
"""Rendered markdown content associated with the form."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def selected_action_id(self) -> str | None:
|
|
||||||
"""Identifier of the selected user action if the form has been submitted."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
|
||||||
"""Submitted form data if available."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def submitted(self) -> bool:
|
|
||||||
"""Whether the form has been submitted."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def status(self) -> HumanInputFormStatus:
|
|
||||||
"""Current status of the form."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def expiration_time(self) -> datetime:
|
|
||||||
"""When the form expires."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormRecipientEntity(abc.ABC):
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def id(self) -> str:
|
|
||||||
"""id returns the identifer of this recipient."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abc.abstractmethod
|
|
||||||
def token(self) -> str:
|
|
||||||
"""token returns a random string used to submit form"""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormRepository(Protocol):
|
|
||||||
"""
|
|
||||||
Repository interface for HumanInputForm.
|
|
||||||
|
|
||||||
This interface defines the contract for accessing and manipulating
|
|
||||||
HumanInputForm data, regardless of the underlying storage mechanism.
|
|
||||||
|
|
||||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
|
||||||
and other implementation details should be handled at the implementation level, not in
|
|
||||||
the core interface. This keeps the core domain model clean and independent of specific
|
|
||||||
application domains or deployment scenarios.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
|
||||||
"""Get the form created for a given human input node in a workflow execution. Returns
|
|
||||||
`None` if the form has not been created yet."""
|
|
||||||
...
|
|
||||||
|
|
||||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
|
||||||
"""
|
|
||||||
Create a human input form from form definition.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
@@ -6,18 +6,14 @@ import threading
|
|||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping, Sequence
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any, Protocol
|
from typing import Any, Protocol
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from pydantic.json import pydantic_encoder
|
from pydantic.json import pydantic_encoder
|
||||||
|
|
||||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||||
from core.workflow.enums import NodeState
|
from core.workflow.entities.pause_reason import PauseReason
|
||||||
from core.workflow.runtime.variable_pool import VariablePool
|
from core.workflow.runtime.variable_pool import VariablePool
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from core.workflow.entities.pause_reason import PauseReason
|
|
||||||
|
|
||||||
|
|
||||||
class ReadyQueueProtocol(Protocol):
|
class ReadyQueueProtocol(Protocol):
|
||||||
"""Structural interface required from ready queue implementations."""
|
"""Structural interface required from ready queue implementations."""
|
||||||
@@ -64,7 +60,7 @@ class GraphExecutionProtocol(Protocol):
|
|||||||
aborted: bool
|
aborted: bool
|
||||||
error: Exception | None
|
error: Exception | None
|
||||||
exceptions_count: int
|
exceptions_count: int
|
||||||
pause_reasons: Sequence[PauseReason]
|
pause_reasons: list[PauseReason]
|
||||||
|
|
||||||
def start(self) -> None:
|
def start(self) -> None:
|
||||||
"""Transition execution into the running state."""
|
"""Transition execution into the running state."""
|
||||||
@@ -107,33 +103,14 @@ class ResponseStreamCoordinatorProtocol(Protocol):
|
|||||||
...
|
...
|
||||||
|
|
||||||
|
|
||||||
class NodeProtocol(Protocol):
|
|
||||||
"""Structural interface for graph nodes."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
state: NodeState
|
|
||||||
|
|
||||||
|
|
||||||
class EdgeProtocol(Protocol):
|
|
||||||
id: str
|
|
||||||
state: NodeState
|
|
||||||
|
|
||||||
|
|
||||||
class GraphProtocol(Protocol):
|
class GraphProtocol(Protocol):
|
||||||
"""Structural interface required from graph instances attached to the runtime state."""
|
"""Structural interface required from graph instances attached to the runtime state."""
|
||||||
|
|
||||||
nodes: Mapping[str, NodeProtocol]
|
nodes: Mapping[str, object]
|
||||||
edges: Mapping[str, EdgeProtocol]
|
edges: Mapping[str, object]
|
||||||
root_node: NodeProtocol
|
root_node: object
|
||||||
|
|
||||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
def get_outgoing_edges(self, node_id: str) -> Sequence[object]: ...
|
||||||
|
|
||||||
|
|
||||||
class _GraphStateSnapshot(BaseModel):
|
|
||||||
"""Serializable graph state snapshot for node/edge states."""
|
|
||||||
|
|
||||||
nodes: dict[str, NodeState] = Field(default_factory=dict)
|
|
||||||
edges: dict[str, NodeState] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(slots=True)
|
@dataclass(slots=True)
|
||||||
@@ -151,20 +128,10 @@ class _GraphRuntimeStateSnapshot:
|
|||||||
graph_execution_dump: str | None
|
graph_execution_dump: str | None
|
||||||
response_coordinator_dump: str | None
|
response_coordinator_dump: str | None
|
||||||
paused_nodes: tuple[str, ...]
|
paused_nodes: tuple[str, ...]
|
||||||
deferred_nodes: tuple[str, ...]
|
|
||||||
graph_node_states: dict[str, NodeState]
|
|
||||||
graph_edge_states: dict[str, NodeState]
|
|
||||||
|
|
||||||
|
|
||||||
class GraphRuntimeState:
|
class GraphRuntimeState:
|
||||||
"""Mutable runtime state shared across graph execution components.
|
"""Mutable runtime state shared across graph execution components."""
|
||||||
|
|
||||||
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
|
|
||||||
including scheduling details, variable values, and timing information.
|
|
||||||
|
|
||||||
Values that are initialized prior to workflow execution and remain constant
|
|
||||||
throughout the execution should be part of `GraphInitParams` instead.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -202,16 +169,6 @@ class GraphRuntimeState:
|
|||||||
self._pending_response_coordinator_dump: str | None = None
|
self._pending_response_coordinator_dump: str | None = None
|
||||||
self._pending_graph_execution_workflow_id: str | None = None
|
self._pending_graph_execution_workflow_id: str | None = None
|
||||||
self._paused_nodes: set[str] = set()
|
self._paused_nodes: set[str] = set()
|
||||||
self._deferred_nodes: set[str] = set()
|
|
||||||
|
|
||||||
# Node and edges states needed to be restored into
|
|
||||||
# graph object.
|
|
||||||
#
|
|
||||||
# These two fields are non-None only when resuming from a snapshot.
|
|
||||||
# Once the graph is attached, these two fields will be set to None.
|
|
||||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
|
||||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
|
||||||
|
|
||||||
self.stop_event: threading.Event = threading.Event()
|
self.stop_event: threading.Event = threading.Event()
|
||||||
|
|
||||||
if graph is not None:
|
if graph is not None:
|
||||||
@@ -233,7 +190,6 @@ class GraphRuntimeState:
|
|||||||
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
||||||
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
||||||
self._pending_response_coordinator_dump = None
|
self._pending_response_coordinator_dump = None
|
||||||
self._apply_pending_graph_state()
|
|
||||||
|
|
||||||
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
||||||
"""Ensure core collaborators are initialized with the provided context."""
|
"""Ensure core collaborators are initialized with the provided context."""
|
||||||
@@ -355,13 +311,8 @@ class GraphRuntimeState:
|
|||||||
"ready_queue": self.ready_queue.dumps(),
|
"ready_queue": self.ready_queue.dumps(),
|
||||||
"graph_execution": self.graph_execution.dumps(),
|
"graph_execution": self.graph_execution.dumps(),
|
||||||
"paused_nodes": list(self._paused_nodes),
|
"paused_nodes": list(self._paused_nodes),
|
||||||
"deferred_nodes": list(self._deferred_nodes),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
graph_state = self._snapshot_graph_state()
|
|
||||||
if graph_state is not None:
|
|
||||||
snapshot["graph_state"] = graph_state
|
|
||||||
|
|
||||||
if self._response_coordinator is not None and self._graph is not None:
|
if self._response_coordinator is not None and self._graph is not None:
|
||||||
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
||||||
|
|
||||||
@@ -395,11 +346,6 @@ class GraphRuntimeState:
|
|||||||
|
|
||||||
self._paused_nodes.add(node_id)
|
self._paused_nodes.add(node_id)
|
||||||
|
|
||||||
def get_paused_nodes(self) -> list[str]:
|
|
||||||
"""Retrieve the list of paused nodes without mutating internal state."""
|
|
||||||
|
|
||||||
return list(self._paused_nodes)
|
|
||||||
|
|
||||||
def consume_paused_nodes(self) -> list[str]:
|
def consume_paused_nodes(self) -> list[str]:
|
||||||
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
||||||
|
|
||||||
@@ -407,23 +353,6 @@ class GraphRuntimeState:
|
|||||||
self._paused_nodes.clear()
|
self._paused_nodes.clear()
|
||||||
return nodes
|
return nodes
|
||||||
|
|
||||||
def register_deferred_node(self, node_id: str) -> None:
|
|
||||||
"""Record a node that became ready during pause and should resume later."""
|
|
||||||
|
|
||||||
self._deferred_nodes.add(node_id)
|
|
||||||
|
|
||||||
def get_deferred_nodes(self) -> list[str]:
|
|
||||||
"""Retrieve deferred nodes without mutating internal state."""
|
|
||||||
|
|
||||||
return list(self._deferred_nodes)
|
|
||||||
|
|
||||||
def consume_deferred_nodes(self) -> list[str]:
|
|
||||||
"""Retrieve and clear deferred nodes awaiting resume."""
|
|
||||||
|
|
||||||
nodes = list(self._deferred_nodes)
|
|
||||||
self._deferred_nodes.clear()
|
|
||||||
return nodes
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Builders
|
# Builders
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
@@ -485,10 +414,6 @@ class GraphRuntimeState:
|
|||||||
graph_execution_payload = payload.get("graph_execution")
|
graph_execution_payload = payload.get("graph_execution")
|
||||||
response_payload = payload.get("response_coordinator")
|
response_payload = payload.get("response_coordinator")
|
||||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||||
deferred_nodes_payload = payload.get("deferred_nodes", [])
|
|
||||||
graph_state_payload = payload.get("graph_state", {}) or {}
|
|
||||||
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
|
|
||||||
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
|
|
||||||
|
|
||||||
return _GraphRuntimeStateSnapshot(
|
return _GraphRuntimeStateSnapshot(
|
||||||
start_at=start_at,
|
start_at=start_at,
|
||||||
@@ -502,9 +427,6 @@ class GraphRuntimeState:
|
|||||||
graph_execution_dump=graph_execution_payload,
|
graph_execution_dump=graph_execution_payload,
|
||||||
response_coordinator_dump=response_payload,
|
response_coordinator_dump=response_payload,
|
||||||
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
||||||
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
|
|
||||||
graph_node_states=graph_node_states,
|
|
||||||
graph_edge_states=graph_edge_states,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
||||||
@@ -520,10 +442,6 @@ class GraphRuntimeState:
|
|||||||
self._restore_graph_execution(snapshot.graph_execution_dump)
|
self._restore_graph_execution(snapshot.graph_execution_dump)
|
||||||
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
||||||
self._paused_nodes = set(snapshot.paused_nodes)
|
self._paused_nodes = set(snapshot.paused_nodes)
|
||||||
self._deferred_nodes = set(snapshot.deferred_nodes)
|
|
||||||
self._pending_graph_node_states = snapshot.graph_node_states or None
|
|
||||||
self._pending_graph_edge_states = snapshot.graph_edge_states or None
|
|
||||||
self._apply_pending_graph_state()
|
|
||||||
|
|
||||||
def _restore_ready_queue(self, payload: str | None) -> None:
|
def _restore_ready_queue(self, payload: str | None) -> None:
|
||||||
if payload is not None:
|
if payload is not None:
|
||||||
@@ -560,68 +478,3 @@ class GraphRuntimeState:
|
|||||||
|
|
||||||
self._pending_response_coordinator_dump = payload
|
self._pending_response_coordinator_dump = payload
|
||||||
self._response_coordinator = None
|
self._response_coordinator = None
|
||||||
|
|
||||||
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
|
|
||||||
graph = self._graph
|
|
||||||
if graph is None:
|
|
||||||
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
|
|
||||||
return _GraphStateSnapshot()
|
|
||||||
return _GraphStateSnapshot(
|
|
||||||
nodes=self._pending_graph_node_states or {},
|
|
||||||
edges=self._pending_graph_edge_states or {},
|
|
||||||
)
|
|
||||||
|
|
||||||
nodes = graph.nodes
|
|
||||||
edges = graph.edges
|
|
||||||
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
|
|
||||||
return _GraphStateSnapshot()
|
|
||||||
|
|
||||||
node_states = {}
|
|
||||||
for node_id, node in nodes.items():
|
|
||||||
if not isinstance(node_id, str):
|
|
||||||
continue
|
|
||||||
node_states[node_id] = node.state
|
|
||||||
|
|
||||||
edge_states = {}
|
|
||||||
for edge_id, edge in edges.items():
|
|
||||||
if not isinstance(edge_id, str):
|
|
||||||
continue
|
|
||||||
edge_states[edge_id] = edge.state
|
|
||||||
|
|
||||||
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
|
|
||||||
|
|
||||||
def _apply_pending_graph_state(self) -> None:
|
|
||||||
if self._graph is None:
|
|
||||||
return
|
|
||||||
if self._pending_graph_node_states:
|
|
||||||
for node_id, state in self._pending_graph_node_states.items():
|
|
||||||
node = self._graph.nodes.get(node_id)
|
|
||||||
if node is None:
|
|
||||||
continue
|
|
||||||
node.state = state
|
|
||||||
if self._pending_graph_edge_states:
|
|
||||||
for edge_id, state in self._pending_graph_edge_states.items():
|
|
||||||
edge = self._graph.edges.get(edge_id)
|
|
||||||
if edge is None:
|
|
||||||
continue
|
|
||||||
edge.state = state
|
|
||||||
|
|
||||||
self._pending_graph_node_states = None
|
|
||||||
self._pending_graph_edge_states = None
|
|
||||||
|
|
||||||
|
|
||||||
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
|
|
||||||
if not isinstance(payload, Mapping):
|
|
||||||
return {}
|
|
||||||
raw_map = payload.get(key, {})
|
|
||||||
if not isinstance(raw_map, Mapping):
|
|
||||||
return {}
|
|
||||||
result: dict[str, NodeState] = {}
|
|
||||||
for node_id, raw_state in raw_map.items():
|
|
||||||
if not isinstance(node_id, str):
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
result[node_id] = NodeState(str(raw_state))
|
|
||||||
except ValueError:
|
|
||||||
continue
|
|
||||||
return result
|
|
||||||
|
|||||||
@@ -15,14 +15,12 @@ class WorkflowRuntimeTypeConverter:
|
|||||||
def to_json_encodable(self, value: None) -> None: ...
|
def to_json_encodable(self, value: None) -> None: ...
|
||||||
|
|
||||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||||
"""Convert runtime values to JSON-serializable structures."""
|
result = self._to_json_encodable_recursive(value)
|
||||||
|
|
||||||
result = self.value_to_json_encodable_recursive(value)
|
|
||||||
if isinstance(result, Mapping) or result is None:
|
if isinstance(result, Mapping) or result is None:
|
||||||
return result
|
return result
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def value_to_json_encodable_recursive(self, value: Any):
|
def _to_json_encodable_recursive(self, value: Any):
|
||||||
if value is None:
|
if value is None:
|
||||||
return value
|
return value
|
||||||
if isinstance(value, (bool, int, str, float)):
|
if isinstance(value, (bool, int, str, float)):
|
||||||
@@ -31,7 +29,7 @@ class WorkflowRuntimeTypeConverter:
|
|||||||
# Convert Decimal to float for JSON serialization
|
# Convert Decimal to float for JSON serialization
|
||||||
return float(value)
|
return float(value)
|
||||||
if isinstance(value, Segment):
|
if isinstance(value, Segment):
|
||||||
return self.value_to_json_encodable_recursive(value.value)
|
return self._to_json_encodable_recursive(value.value)
|
||||||
if isinstance(value, File):
|
if isinstance(value, File):
|
||||||
return value.to_dict()
|
return value.to_dict()
|
||||||
if isinstance(value, BaseModel):
|
if isinstance(value, BaseModel):
|
||||||
@@ -39,11 +37,11 @@ class WorkflowRuntimeTypeConverter:
|
|||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
res = {}
|
res = {}
|
||||||
for k, v in value.items():
|
for k, v in value.items():
|
||||||
res[k] = self.value_to_json_encodable_recursive(v)
|
res[k] = self._to_json_encodable_recursive(v)
|
||||||
return res
|
return res
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
res_list = []
|
res_list = []
|
||||||
for item in value:
|
for item in value:
|
||||||
res_list.append(self.value_to_json_encodable_recursive(item))
|
res_list.append(self._to_json_encodable_recursive(item))
|
||||||
return res_list
|
return res_list
|
||||||
return value
|
return value
|
||||||
|
|||||||
@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
|||||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||||
# Cloud edition: separate queues for dataset and trigger tasks
|
# Cloud edition: separate queues for dataset and trigger tasks
|
||||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||||
else
|
else
|
||||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||||
fi
|
fi
|
||||||
else
|
else
|
||||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||||
@@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then
|
|||||||
fi
|
fi
|
||||||
|
|
||||||
echo "Running Flask job command: flask $*"
|
echo "Running Flask job command: flask $*"
|
||||||
|
|
||||||
# Temporarily disable exit on error to capture exit code
|
# Temporarily disable exit on error to capture exit code
|
||||||
set +e
|
set +e
|
||||||
flask "$@"
|
flask "$@"
|
||||||
|
|||||||
@@ -151,12 +151,6 @@ def init_app(app: DifyApp) -> Celery:
|
|||||||
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
||||||
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
|
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
|
||||||
}
|
}
|
||||||
if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK:
|
|
||||||
imports.append("tasks.human_input_timeout_tasks")
|
|
||||||
beat_schedule["human_input_form_timeout"] = {
|
|
||||||
"task": "human_input_form_timeout.check_and_resume",
|
|
||||||
"schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL),
|
|
||||||
}
|
|
||||||
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
||||||
imports.append("schedule.check_upgradable_plugin_task")
|
imports.append("schedule.check_upgradable_plugin_task")
|
||||||
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
|
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
|
||||||
|
|||||||
@@ -8,16 +8,12 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
|
|||||||
import redis
|
import redis
|
||||||
from redis import RedisError
|
from redis import RedisError
|
||||||
from redis.cache import CacheConfig
|
from redis.cache import CacheConfig
|
||||||
from redis.client import PubSub
|
|
||||||
from redis.cluster import ClusterNode, RedisCluster
|
from redis.cluster import ClusterNode, RedisCluster
|
||||||
from redis.connection import Connection, SSLConnection
|
from redis.connection import Connection, SSLConnection
|
||||||
from redis.sentinel import Sentinel
|
from redis.sentinel import Sentinel
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from dify_app import DifyApp
|
from dify_app import DifyApp
|
||||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
|
||||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
|
||||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from redis.lock import Lock
|
from redis.lock import Lock
|
||||||
@@ -110,7 +106,6 @@ class RedisClientWrapper:
|
|||||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||||
def zcard(self, name: str | bytes) -> Any: ...
|
def zcard(self, name: str | bytes) -> Any: ...
|
||||||
def getdel(self, name: str | bytes) -> Any: ...
|
def getdel(self, name: str | bytes) -> Any: ...
|
||||||
def pubsub(self) -> PubSub: ...
|
|
||||||
|
|
||||||
def __getattr__(self, item: str) -> Any:
|
def __getattr__(self, item: str) -> Any:
|
||||||
if self._client is None:
|
if self._client is None:
|
||||||
@@ -119,7 +114,6 @@ class RedisClientWrapper:
|
|||||||
|
|
||||||
|
|
||||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||||
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
|
|
||||||
|
|
||||||
|
|
||||||
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||||
@@ -232,12 +226,6 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
|||||||
return client
|
return client
|
||||||
|
|
||||||
|
|
||||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
|
|
||||||
if use_clusters:
|
|
||||||
return RedisCluster.from_url(pubsub_url)
|
|
||||||
return redis.Redis.from_url(pubsub_url)
|
|
||||||
|
|
||||||
|
|
||||||
def init_app(app: DifyApp):
|
def init_app(app: DifyApp):
|
||||||
"""Initialize Redis client and attach it to the app."""
|
"""Initialize Redis client and attach it to the app."""
|
||||||
global redis_client
|
global redis_client
|
||||||
@@ -256,24 +244,6 @@ def init_app(app: DifyApp):
|
|||||||
redis_client.initialize(client)
|
redis_client.initialize(client)
|
||||||
app.extensions["redis"] = redis_client
|
app.extensions["redis"] = redis_client
|
||||||
|
|
||||||
pubsub_client = client
|
|
||||||
if dify_config.normalized_pubsub_redis_url:
|
|
||||||
pubsub_client = _create_pubsub_client(
|
|
||||||
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
|
|
||||||
)
|
|
||||||
pubsub_redis_client.initialize(pubsub_client)
|
|
||||||
|
|
||||||
|
|
||||||
def get_pubsub_redis_client() -> RedisClientWrapper:
|
|
||||||
return pubsub_redis_client
|
|
||||||
|
|
||||||
|
|
||||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
|
||||||
redis_conn = get_pubsub_redis_client()
|
|
||||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
|
||||||
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
|
||||||
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
|
||||||
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from typing import Any
|
|||||||
|
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
|
||||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
|
||||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||||
from extensions.logstore.repositories import safe_float, safe_int
|
from extensions.logstore.repositories import safe_float, safe_int
|
||||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
|
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
|
||||||
@@ -208,10 +207,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
for row in deduplicated_results:
|
if deduplicated_results:
|
||||||
model = _dict_to_workflow_node_execution_model(row)
|
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
|
||||||
if model.status != WorkflowNodeExecutionStatus.PAUSED:
|
|
||||||
return model
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@@ -312,8 +309,6 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
|||||||
if model and model.id: # Ensure model is valid
|
if model and model.id: # Ensure model is valid
|
||||||
models.append(model)
|
models.append(model)
|
||||||
|
|
||||||
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
|
|
||||||
|
|
||||||
# Sort by index DESC for trace visualization
|
# Sort by index DESC for trace visualization
|
||||||
models.sort(key=lambda x: x.index, reverse=True)
|
models.sort(key=lambda x: x.index, reverse=True)
|
||||||
|
|
||||||
|
|||||||
@@ -192,7 +192,6 @@ class StatusCount(ResponseModel):
|
|||||||
success: int
|
success: int
|
||||||
failed: int
|
failed: int
|
||||||
partial_success: int
|
partial_success: int
|
||||||
paused: int
|
|
||||||
|
|
||||||
|
|
||||||
class ModelConfig(ResponseModel):
|
class ModelConfig(ResponseModel):
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||||
|
|
||||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
|
||||||
from core.file import File
|
from core.file import File
|
||||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||||
|
|
||||||
@@ -62,7 +61,6 @@ class MessageListItem(ResponseModel):
|
|||||||
message_files: list[MessageFile]
|
message_files: list[MessageFile]
|
||||||
status: str
|
status: str
|
||||||
error: str | None = None
|
error: str | None = None
|
||||||
extra_contents: list[ExecutionExtraContentDomainModel]
|
|
||||||
|
|
||||||
@field_validator("inputs", mode="before")
|
@field_validator("inputs", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class RedisSubscriptionBase(Subscription):
|
|||||||
self._start_if_needed()
|
self._start_if_needed()
|
||||||
return iter(self._message_iterator())
|
return iter(self._message_iterator())
|
||||||
|
|
||||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||||
"""Receive the next message from the subscription."""
|
"""Receive the next message from the subscription."""
|
||||||
if self._closed.is_set():
|
if self._closed.is_set():
|
||||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||||
|
|||||||
@@ -61,14 +61,7 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
|||||||
|
|
||||||
def _get_message(self) -> dict | None:
|
def _get_message(self) -> dict | None:
|
||||||
assert self._pubsub is not None
|
assert self._pubsub is not None
|
||||||
# NOTE(QuantumGhost): this is an issue in
|
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
|
||||||
# upstream code. If Sharded PubSub is used with Cluster, the
|
|
||||||
# `ClusterPubSub.get_sharded_message` will return `None` regardless of
|
|
||||||
# message['type'].
|
|
||||||
#
|
|
||||||
# Since we have already filtered at the caller's site, we can safely set
|
|
||||||
# `ignore_subscribe_messages=False`.
|
|
||||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=0.1) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
def _get_message_type(self) -> str:
|
def _get_message_type(self) -> str:
|
||||||
return "smessage"
|
return "smessage"
|
||||||
|
|||||||
@@ -1,49 +0,0 @@
|
|||||||
"""
|
|
||||||
Email template rendering helpers with configurable safety modes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
from collections.abc import Mapping
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from flask import render_template_string
|
|
||||||
from jinja2.runtime import Context
|
|
||||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
|
||||||
|
|
||||||
from configs import dify_config
|
|
||||||
from configs.feature import TemplateMode
|
|
||||||
|
|
||||||
|
|
||||||
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
|
|
||||||
"""Sandboxed environment with execution timeout."""
|
|
||||||
|
|
||||||
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
|
|
||||||
self._deadline = time.time() + timeout if timeout else None
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
|
|
||||||
if self._deadline is not None and time.time() > self._deadline:
|
|
||||||
raise TimeoutError("Template rendering timeout")
|
|
||||||
return super().call(context, obj, *args, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def render_email_template(template: str, substitutions: Mapping[str, str]) -> str:
|
|
||||||
"""
|
|
||||||
Render email template content according to the configured template mode.
|
|
||||||
|
|
||||||
In unsafe mode, Jinja expressions are evaluated directly.
|
|
||||||
In sandbox mode, a sandboxed environment with timeout is used.
|
|
||||||
In disabled mode, the template is returned without rendering.
|
|
||||||
"""
|
|
||||||
mode = dify_config.MAIL_TEMPLATING_MODE
|
|
||||||
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
|
|
||||||
|
|
||||||
if mode == TemplateMode.UNSAFE:
|
|
||||||
return render_template_string(template, **substitutions)
|
|
||||||
if mode == TemplateMode.SANDBOX:
|
|
||||||
env = SandboxedEnvironment(timeout=timeout)
|
|
||||||
tmpl = env.from_string(template)
|
|
||||||
return tmpl.render(substitutions)
|
|
||||||
if mode == TemplateMode.DISABLED:
|
|
||||||
return template
|
|
||||||
raise ValueError(f"Unsupported mail templating mode: {mode}")
|
|
||||||
@@ -1,15 +1,12 @@
|
|||||||
import contextvars
|
import contextvars
|
||||||
from collections.abc import Iterator
|
from collections.abc import Iterator
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import TYPE_CHECKING, TypeVar
|
from typing import TypeVar
|
||||||
|
|
||||||
from flask import Flask, g
|
from flask import Flask, g
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from models import Account, EndUser
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def preserve_flask_contexts(
|
def preserve_flask_contexts(
|
||||||
@@ -67,7 +64,3 @@ def preserve_flask_contexts(
|
|||||||
finally:
|
finally:
|
||||||
# Any cleanup can be added here if needed
|
# Any cleanup can be added here if needed
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def set_login_user(user: "Account | EndUser"):
|
|
||||||
g._login_user = user
|
|
||||||
|
|||||||
@@ -7,10 +7,10 @@ import struct
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable, Generator, Mapping
|
from collections.abc import Generator, Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
from zoneinfo import available_timezones
|
from zoneinfo import available_timezones
|
||||||
|
|
||||||
@@ -126,13 +126,6 @@ class TimestampField(fields.Raw):
|
|||||||
return int(value.timestamp())
|
return int(value.timestamp())
|
||||||
|
|
||||||
|
|
||||||
class OptionalTimestampField(fields.Raw):
|
|
||||||
def format(self, value) -> int | None:
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
return int(value.timestamp())
|
|
||||||
|
|
||||||
|
|
||||||
def email(email):
|
def email(email):
|
||||||
# Define a regex pattern for email addresses
|
# Define a regex pattern for email addresses
|
||||||
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
|
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
|
||||||
@@ -244,26 +237,6 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"):
|
|||||||
|
|
||||||
|
|
||||||
def generate_string(n):
|
def generate_string(n):
|
||||||
"""
|
|
||||||
Generates a cryptographically secure random string of the specified length.
|
|
||||||
|
|
||||||
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
|
|
||||||
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
|
|
||||||
|
|
||||||
Each character in the generated string provides approximately 5.95 bits of entropy
|
|
||||||
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
|
|
||||||
length of the string (`n`) should be at least 22 characters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
n (int): The length of the random string to generate. For secure usage,
|
|
||||||
`n` should be 22 or greater.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: A random string of length `n` composed of ASCII letters and digits.
|
|
||||||
|
|
||||||
Note:
|
|
||||||
This function is suitable for generating credentials or other secure tokens.
|
|
||||||
"""
|
|
||||||
letters_digits = string.ascii_letters + string.digits
|
letters_digits = string.ascii_letters + string.digits
|
||||||
result = ""
|
result = ""
|
||||||
for _ in range(n):
|
for _ in range(n):
|
||||||
@@ -432,35 +405,11 @@ class TokenManager:
|
|||||||
return f"{token_type}:account:{account_id}"
|
return f"{token_type}:account:{account_id}"
|
||||||
|
|
||||||
|
|
||||||
class _RateLimiterRedisClient(Protocol):
|
|
||||||
def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ...
|
|
||||||
|
|
||||||
def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ...
|
|
||||||
|
|
||||||
def zcard(self, name: str | bytes) -> int: ...
|
|
||||||
|
|
||||||
def expire(self, name: str | bytes, time: int) -> bool: ...
|
|
||||||
|
|
||||||
|
|
||||||
def _default_rate_limit_member_factory() -> str:
|
|
||||||
current_time = int(time.time())
|
|
||||||
return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}"
|
|
||||||
|
|
||||||
|
|
||||||
class RateLimiter:
|
class RateLimiter:
|
||||||
def __init__(
|
def __init__(self, prefix: str, max_attempts: int, time_window: int):
|
||||||
self,
|
|
||||||
prefix: str,
|
|
||||||
max_attempts: int,
|
|
||||||
time_window: int,
|
|
||||||
member_factory: Callable[[], str] = _default_rate_limit_member_factory,
|
|
||||||
redis_client: _RateLimiterRedisClient = redis_client,
|
|
||||||
):
|
|
||||||
self.prefix = prefix
|
self.prefix = prefix
|
||||||
self.max_attempts = max_attempts
|
self.max_attempts = max_attempts
|
||||||
self.time_window = time_window
|
self.time_window = time_window
|
||||||
self._member_factory = member_factory
|
|
||||||
self._redis_client = redis_client
|
|
||||||
|
|
||||||
def _get_key(self, email: str) -> str:
|
def _get_key(self, email: str) -> str:
|
||||||
return f"{self.prefix}:{email}"
|
return f"{self.prefix}:{email}"
|
||||||
@@ -470,8 +419,8 @@ class RateLimiter:
|
|||||||
current_time = int(time.time())
|
current_time = int(time.time())
|
||||||
window_start_time = current_time - self.time_window
|
window_start_time = current_time - self.time_window
|
||||||
|
|
||||||
self._redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||||
attempts = self._redis_client.zcard(key)
|
attempts = redis_client.zcard(key)
|
||||||
|
|
||||||
if attempts and int(attempts) >= self.max_attempts:
|
if attempts and int(attempts) >= self.max_attempts:
|
||||||
return True
|
return True
|
||||||
@@ -479,8 +428,7 @@ class RateLimiter:
|
|||||||
|
|
||||||
def increment_rate_limit(self, email: str):
|
def increment_rate_limit(self, email: str):
|
||||||
key = self._get_key(email)
|
key = self._get_key(email)
|
||||||
member = self._member_factory()
|
|
||||||
current_time = int(time.time())
|
current_time = int(time.time())
|
||||||
|
|
||||||
self._redis_client.zadd(key, {member: current_time})
|
redis_client.zadd(key, {current_time: current_time})
|
||||||
self._redis_client.expire(key, self.time_window * 2)
|
redis_client.expire(key, self.time_window * 2)
|
||||||
|
|||||||
@@ -1,99 +0,0 @@
|
|||||||
"""Add human input related db models
|
|
||||||
|
|
||||||
Revision ID: e8c3b3c46151
|
|
||||||
Revises: 788d3099ae3a
|
|
||||||
Create Date: 2026-01-29 14:15:23.081903
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import models as models
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision = "e8c3b3c46151"
|
|
||||||
down_revision = "788d3099ae3a"
|
|
||||||
branch_labels = None
|
|
||||||
depends_on = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade():
|
|
||||||
op.create_table(
|
|
||||||
"execution_extra_contents",
|
|
||||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
|
|
||||||
sa.Column("type", sa.String(length=30), nullable=False),
|
|
||||||
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("message_id", models.types.StringUUID(), nullable=True),
|
|
||||||
sa.Column("form_id", models.types.StringUUID(), nullable=True),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("execution_extra_contents_pkey")),
|
|
||||||
)
|
|
||||||
with op.batch_alter_table("execution_extra_contents", schema=None) as batch_op:
|
|
||||||
batch_op.create_index(batch_op.f("execution_extra_contents_message_id_idx"), ["message_id"], unique=False)
|
|
||||||
batch_op.create_index(
|
|
||||||
batch_op.f("execution_extra_contents_workflow_run_id_idx"), ["workflow_run_id"], unique=False
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"human_input_form_deliveries",
|
|
||||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
|
|
||||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("delivery_method_type", sa.String(length=20), nullable=False),
|
|
||||||
sa.Column("delivery_config_id", models.types.StringUUID(), nullable=True),
|
|
||||||
sa.Column("channel_payload", sa.Text(), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_deliveries_pkey")),
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"human_input_form_recipients",
|
|
||||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
|
|
||||||
sa.Column("form_id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("delivery_id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("recipient_type", sa.String(length=20), nullable=False),
|
|
||||||
sa.Column("recipient_payload", sa.Text(), nullable=False),
|
|
||||||
sa.Column("access_token", sa.VARCHAR(length=32), nullable=False),
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_form_recipients_pkey")),
|
|
||||||
)
|
|
||||||
with op.batch_alter_table('human_input_form_recipients', schema=None) as batch_op:
|
|
||||||
batch_op.create_unique_constraint(batch_op.f('human_input_form_recipients_access_token_key'), ['access_token'])
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"human_input_forms",
|
|
||||||
sa.Column("id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
|
||||||
|
|
||||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
|
||||||
sa.Column("workflow_run_id", models.types.StringUUID(), nullable=True),
|
|
||||||
sa.Column("form_kind", sa.String(length=20), nullable=False),
|
|
||||||
sa.Column("node_id", sa.String(length=60), nullable=False),
|
|
||||||
sa.Column("form_definition", sa.Text(), nullable=False),
|
|
||||||
sa.Column("rendered_content", sa.Text(), nullable=False),
|
|
||||||
sa.Column("status", sa.String(length=20), nullable=False),
|
|
||||||
sa.Column("expiration_time", sa.DateTime(), nullable=False),
|
|
||||||
sa.Column("selected_action_id", sa.String(length=200), nullable=True),
|
|
||||||
sa.Column("submitted_data", sa.Text(), nullable=True),
|
|
||||||
sa.Column("submitted_at", sa.DateTime(), nullable=True),
|
|
||||||
sa.Column("submission_user_id", models.types.StringUUID(), nullable=True),
|
|
||||||
sa.Column("submission_end_user_id", models.types.StringUUID(), nullable=True),
|
|
||||||
sa.Column("completed_by_recipient_id", models.types.StringUUID(), nullable=True),
|
|
||||||
|
|
||||||
sa.PrimaryKeyConstraint("id", name=op.f("human_input_forms_pkey")),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade():
|
|
||||||
op.drop_table("human_input_forms")
|
|
||||||
op.drop_table("human_input_form_recipients")
|
|
||||||
op.drop_table("human_input_form_deliveries")
|
|
||||||
op.drop_table("execution_extra_contents")
|
|
||||||
@@ -34,8 +34,6 @@ from .enums import (
|
|||||||
WorkflowRunTriggeredFrom,
|
WorkflowRunTriggeredFrom,
|
||||||
WorkflowTriggerStatus,
|
WorkflowTriggerStatus,
|
||||||
)
|
)
|
||||||
from .execution_extra_content import ExecutionExtraContent, HumanInputContent
|
|
||||||
from .human_input import HumanInputForm
|
|
||||||
from .model import (
|
from .model import (
|
||||||
AccountTrialAppRecord,
|
AccountTrialAppRecord,
|
||||||
ApiRequest,
|
ApiRequest,
|
||||||
@@ -157,12 +155,9 @@ __all__ = [
|
|||||||
"DocumentSegment",
|
"DocumentSegment",
|
||||||
"Embedding",
|
"Embedding",
|
||||||
"EndUser",
|
"EndUser",
|
||||||
"ExecutionExtraContent",
|
|
||||||
"ExporleBanner",
|
"ExporleBanner",
|
||||||
"ExternalKnowledgeApis",
|
"ExternalKnowledgeApis",
|
||||||
"ExternalKnowledgeBindings",
|
"ExternalKnowledgeBindings",
|
||||||
"HumanInputContent",
|
|
||||||
"HumanInputForm",
|
|
||||||
"IconType",
|
"IconType",
|
||||||
"InstalledApp",
|
"InstalledApp",
|
||||||
"InvitationCode",
|
"InvitationCode",
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ class DefaultFieldsMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime,
|
__name_pos=DateTime,
|
||||||
nullable=False,
|
nullable=False,
|
||||||
default=naive_utc_now,
|
default=naive_utc_now,
|
||||||
server_default=func.current_timestamp(),
|
server_default=func.current_timestamp(),
|
||||||
|
|||||||
@@ -36,7 +36,6 @@ class MessageStatus(StrEnum):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
NORMAL = "normal"
|
NORMAL = "normal"
|
||||||
PAUSED = "paused"
|
|
||||||
ERROR = "error"
|
ERROR = "error"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,78 +0,0 @@
|
|||||||
from enum import StrEnum, auto
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from .base import Base, DefaultFieldsMixin
|
|
||||||
from .types import EnumText, StringUUID
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .human_input import HumanInputForm
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionContentType(StrEnum):
|
|
||||||
HUMAN_INPUT = auto()
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionExtraContent(DefaultFieldsMixin, Base):
|
|
||||||
"""ExecutionExtraContent stores extra contents produced during workflow / chatflow execution."""
|
|
||||||
|
|
||||||
# The `ExecutionExtraContent` uses single table inheritance to model different
|
|
||||||
# kinds of contents produced during message generation.
|
|
||||||
#
|
|
||||||
# See: https://docs.sqlalchemy.org/en/20/orm/inheritance.html#single-table-inheritance
|
|
||||||
|
|
||||||
__tablename__ = "execution_extra_contents"
|
|
||||||
__mapper_args__ = {
|
|
||||||
"polymorphic_abstract": True,
|
|
||||||
"polymorphic_on": "type",
|
|
||||||
"with_polymorphic": "*",
|
|
||||||
}
|
|
||||||
# type records the type of the content. It serves as the `discriminator` for the
|
|
||||||
# single table inheritance.
|
|
||||||
type: Mapped[ExecutionContentType] = mapped_column(
|
|
||||||
EnumText(ExecutionContentType, length=30),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# `workflow_run_id` records the workflow execution which generates this content, correspond to
|
|
||||||
# `WorkflowRun.id`.
|
|
||||||
workflow_run_id: Mapped[str] = mapped_column(StringUUID, nullable=False, index=True)
|
|
||||||
|
|
||||||
# `message_id` records the messages generated by the execution associated with this `ExecutionExtraContent`.
|
|
||||||
# It references to `Message.id`.
|
|
||||||
#
|
|
||||||
# For workflow execution, this field is `None`.
|
|
||||||
#
|
|
||||||
# For chatflow execution, `message_id`` is not None, and the following condition holds:
|
|
||||||
#
|
|
||||||
# The message referenced by `message_id` has `message.workflow_run_id == execution_extra_content.workflow_run_id`
|
|
||||||
#
|
|
||||||
message_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, index=True)
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputContent(ExecutionExtraContent):
|
|
||||||
"""HumanInputContent is a concrete class that represents human input content.
|
|
||||||
It should only be initialized with the `new` class method."""
|
|
||||||
|
|
||||||
__mapper_args__ = {
|
|
||||||
"polymorphic_identity": ExecutionContentType.HUMAN_INPUT,
|
|
||||||
}
|
|
||||||
|
|
||||||
# A relation to HumanInputForm table.
|
|
||||||
#
|
|
||||||
# While the form_id column is nullable in database (due to the nature of single table inheritance),
|
|
||||||
# the form_id field should not be null for a given `HumanInputContent` instance.
|
|
||||||
form_id: Mapped[str] = mapped_column(StringUUID, nullable=True)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def new(cls, form_id: str, message_id: str | None) -> "HumanInputContent":
|
|
||||||
return cls(form_id=form_id, message_id=message_id)
|
|
||||||
|
|
||||||
form: Mapped["HumanInputForm"] = relationship(
|
|
||||||
"HumanInputForm",
|
|
||||||
foreign_keys=[form_id],
|
|
||||||
uselist=False,
|
|
||||||
lazy="raise",
|
|
||||||
primaryjoin="foreign(HumanInputContent.form_id) == HumanInputForm.id",
|
|
||||||
)
|
|
||||||
@@ -1,237 +0,0 @@
|
|||||||
from datetime import datetime
|
|
||||||
from enum import StrEnum
|
|
||||||
from typing import Annotated, Literal, Self, final
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from core.workflow.nodes.human_input.enums import (
|
|
||||||
DeliveryMethodType,
|
|
||||||
HumanInputFormKind,
|
|
||||||
HumanInputFormStatus,
|
|
||||||
)
|
|
||||||
from libs.helper import generate_string
|
|
||||||
|
|
||||||
from .base import Base, DefaultFieldsMixin
|
|
||||||
from .types import EnumText, StringUUID
|
|
||||||
|
|
||||||
_token_length = 22
|
|
||||||
# A 32-character string can store a base64-encoded value with 192 bits of entropy
|
|
||||||
# or a base62-encoded value with over 180 bits of entropy, providing sufficient
|
|
||||||
# uniqueness for most use cases.
|
|
||||||
_token_field_length = 32
|
|
||||||
_email_field_length = 330
|
|
||||||
|
|
||||||
|
|
||||||
def _generate_token() -> str:
|
|
||||||
return generate_string(_token_length)
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputForm(DefaultFieldsMixin, Base):
|
|
||||||
__tablename__ = "human_input_forms"
|
|
||||||
|
|
||||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
|
||||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
|
||||||
workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
|
||||||
form_kind: Mapped[HumanInputFormKind] = mapped_column(
|
|
||||||
EnumText(HumanInputFormKind),
|
|
||||||
nullable=False,
|
|
||||||
default=HumanInputFormKind.RUNTIME,
|
|
||||||
)
|
|
||||||
|
|
||||||
# The human input node the current form corresponds to.
|
|
||||||
node_id: Mapped[str] = mapped_column(sa.String(60), nullable=False)
|
|
||||||
form_definition: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
|
||||||
rendered_content: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
|
||||||
status: Mapped[HumanInputFormStatus] = mapped_column(
|
|
||||||
EnumText(HumanInputFormStatus),
|
|
||||||
nullable=False,
|
|
||||||
default=HumanInputFormStatus.WAITING,
|
|
||||||
)
|
|
||||||
|
|
||||||
expiration_time: Mapped[datetime] = mapped_column(
|
|
||||||
sa.DateTime,
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Submission-related fields (nullable until a submission happens).
|
|
||||||
selected_action_id: Mapped[str | None] = mapped_column(sa.String(200), nullable=True)
|
|
||||||
submitted_data: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
|
|
||||||
submitted_at: Mapped[datetime | None] = mapped_column(sa.DateTime, nullable=True)
|
|
||||||
submission_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
|
||||||
submission_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
|
||||||
|
|
||||||
completed_by_recipient_id: Mapped[str | None] = mapped_column(
|
|
||||||
StringUUID,
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
deliveries: Mapped[list["HumanInputDelivery"]] = relationship(
|
|
||||||
"HumanInputDelivery",
|
|
||||||
primaryjoin="HumanInputForm.id == foreign(HumanInputDelivery.form_id)",
|
|
||||||
uselist=True,
|
|
||||||
back_populates="form",
|
|
||||||
lazy="raise",
|
|
||||||
)
|
|
||||||
completed_by_recipient: Mapped["HumanInputFormRecipient | None"] = relationship(
|
|
||||||
"HumanInputFormRecipient",
|
|
||||||
primaryjoin="HumanInputForm.completed_by_recipient_id == foreign(HumanInputFormRecipient.id)",
|
|
||||||
lazy="raise",
|
|
||||||
viewonly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputDelivery(DefaultFieldsMixin, Base):
|
|
||||||
__tablename__ = "human_input_form_deliveries"
|
|
||||||
|
|
||||||
form_id: Mapped[str] = mapped_column(
|
|
||||||
StringUUID,
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
delivery_method_type: Mapped[DeliveryMethodType] = mapped_column(
|
|
||||||
EnumText(DeliveryMethodType),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
delivery_config_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
|
|
||||||
channel_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
|
||||||
|
|
||||||
form: Mapped[HumanInputForm] = relationship(
|
|
||||||
"HumanInputForm",
|
|
||||||
uselist=False,
|
|
||||||
foreign_keys=[form_id],
|
|
||||||
primaryjoin="HumanInputDelivery.form_id == HumanInputForm.id",
|
|
||||||
back_populates="deliveries",
|
|
||||||
lazy="raise",
|
|
||||||
)
|
|
||||||
|
|
||||||
recipients: Mapped[list["HumanInputFormRecipient"]] = relationship(
|
|
||||||
"HumanInputFormRecipient",
|
|
||||||
primaryjoin="HumanInputDelivery.id == foreign(HumanInputFormRecipient.delivery_id)",
|
|
||||||
uselist=True,
|
|
||||||
back_populates="delivery",
|
|
||||||
# Require explicit preloading
|
|
||||||
lazy="raise",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RecipientType(StrEnum):
|
|
||||||
# EMAIL_MEMBER member means that the
|
|
||||||
EMAIL_MEMBER = "email_member"
|
|
||||||
EMAIL_EXTERNAL = "email_external"
|
|
||||||
# STANDALONE_WEB_APP is used by the standalone web app.
|
|
||||||
#
|
|
||||||
# It's not used while running workflows / chatflows containing HumanInput
|
|
||||||
# node inside console.
|
|
||||||
STANDALONE_WEB_APP = "standalone_web_app"
|
|
||||||
# CONSOLE is used while running workflows / chatflows containing HumanInput
|
|
||||||
# node inside console. (E.G. running installed apps or debugging workflows / chatflows)
|
|
||||||
CONSOLE = "console"
|
|
||||||
# BACKSTAGE is used for backstage input inside console.
|
|
||||||
BACKSTAGE = "backstage"
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class EmailMemberRecipientPayload(BaseModel):
|
|
||||||
TYPE: Literal[RecipientType.EMAIL_MEMBER] = RecipientType.EMAIL_MEMBER
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
# The `email` field here is only used for mail sending.
|
|
||||||
email: str
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class EmailExternalRecipientPayload(BaseModel):
|
|
||||||
TYPE: Literal[RecipientType.EMAIL_EXTERNAL] = RecipientType.EMAIL_EXTERNAL
|
|
||||||
email: str
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class StandaloneWebAppRecipientPayload(BaseModel):
|
|
||||||
TYPE: Literal[RecipientType.STANDALONE_WEB_APP] = RecipientType.STANDALONE_WEB_APP
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class ConsoleRecipientPayload(BaseModel):
|
|
||||||
TYPE: Literal[RecipientType.CONSOLE] = RecipientType.CONSOLE
|
|
||||||
account_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class BackstageRecipientPayload(BaseModel):
|
|
||||||
TYPE: Literal[RecipientType.BACKSTAGE] = RecipientType.BACKSTAGE
|
|
||||||
account_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@final
|
|
||||||
class ConsoleDeliveryPayload(BaseModel):
|
|
||||||
type: Literal["console"] = "console"
|
|
||||||
internal: bool = True
|
|
||||||
|
|
||||||
|
|
||||||
RecipientPayload = Annotated[
|
|
||||||
EmailMemberRecipientPayload
|
|
||||||
| EmailExternalRecipientPayload
|
|
||||||
| StandaloneWebAppRecipientPayload
|
|
||||||
| ConsoleRecipientPayload
|
|
||||||
| BackstageRecipientPayload,
|
|
||||||
Field(discriminator="TYPE"),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
|
|
||||||
__tablename__ = "human_input_form_recipients"
|
|
||||||
|
|
||||||
form_id: Mapped[str] = mapped_column(
|
|
||||||
StringUUID,
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
delivery_id: Mapped[str] = mapped_column(
|
|
||||||
StringUUID,
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
recipient_type: Mapped["RecipientType"] = mapped_column(EnumText(RecipientType), nullable=False)
|
|
||||||
recipient_payload: Mapped[str] = mapped_column(sa.Text, nullable=False)
|
|
||||||
|
|
||||||
# Token primarily used for authenticated resume links (email, etc.).
|
|
||||||
access_token: Mapped[str | None] = mapped_column(
|
|
||||||
sa.VARCHAR(_token_field_length),
|
|
||||||
nullable=False,
|
|
||||||
default=_generate_token,
|
|
||||||
unique=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
delivery: Mapped[HumanInputDelivery] = relationship(
|
|
||||||
"HumanInputDelivery",
|
|
||||||
uselist=False,
|
|
||||||
foreign_keys=[delivery_id],
|
|
||||||
back_populates="recipients",
|
|
||||||
primaryjoin="HumanInputFormRecipient.delivery_id == HumanInputDelivery.id",
|
|
||||||
# Require explicit preloading
|
|
||||||
lazy="raise",
|
|
||||||
)
|
|
||||||
|
|
||||||
form: Mapped[HumanInputForm] = relationship(
|
|
||||||
"HumanInputForm",
|
|
||||||
uselist=False,
|
|
||||||
foreign_keys=[form_id],
|
|
||||||
primaryjoin="HumanInputFormRecipient.form_id == HumanInputForm.id",
|
|
||||||
# Require explicit preloading
|
|
||||||
lazy="raise",
|
|
||||||
)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def new(
|
|
||||||
cls,
|
|
||||||
form_id: str,
|
|
||||||
delivery_id: str,
|
|
||||||
payload: RecipientPayload,
|
|
||||||
) -> Self:
|
|
||||||
recipient_model = cls(
|
|
||||||
form_id=form_id,
|
|
||||||
delivery_id=delivery_id,
|
|
||||||
recipient_type=payload.TYPE,
|
|
||||||
recipient_payload=payload.model_dump_json(),
|
|
||||||
access_token=_generate_token(),
|
|
||||||
)
|
|
||||||
return recipient_model
|
|
||||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Mapping, Sequence
|
from collections.abc import Mapping
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
from enum import StrEnum, auto
|
from enum import StrEnum, auto
|
||||||
@@ -943,7 +943,6 @@ class Conversation(Base):
|
|||||||
WorkflowExecutionStatus.FAILED: 0,
|
WorkflowExecutionStatus.FAILED: 0,
|
||||||
WorkflowExecutionStatus.STOPPED: 0,
|
WorkflowExecutionStatus.STOPPED: 0,
|
||||||
WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0,
|
WorkflowExecutionStatus.PARTIAL_SUCCEEDED: 0,
|
||||||
WorkflowExecutionStatus.PAUSED: 0,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for message in messages:
|
for message in messages:
|
||||||
@@ -964,7 +963,6 @@ class Conversation(Base):
|
|||||||
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
"success": status_counts[WorkflowExecutionStatus.SUCCEEDED],
|
||||||
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
"failed": status_counts[WorkflowExecutionStatus.FAILED],
|
||||||
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
"partial_success": status_counts[WorkflowExecutionStatus.PARTIAL_SUCCEEDED],
|
||||||
"paused": status_counts[WorkflowExecutionStatus.PAUSED],
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1347,14 +1345,6 @@ class Message(Base):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# TODO(QuantumGhost): dirty hacks, fix this later.
|
|
||||||
def set_extra_contents(self, contents: Sequence[dict[str, Any]]) -> None:
|
|
||||||
self._extra_contents = list(contents)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def extra_contents(self) -> list[dict[str, Any]]:
|
|
||||||
return getattr(self, "_extra_contents", [])
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def workflow_run(self):
|
def workflow_run(self):
|
||||||
if self.workflow_run_id:
|
if self.workflow_run_id:
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ from sqlalchemy import (
|
|||||||
select,
|
select,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||||
from typing_extensions import deprecated
|
|
||||||
|
|
||||||
from core.file.constants import maybe_file_object
|
from core.file.constants import maybe_file_object
|
||||||
from core.file.models import File
|
from core.file.models import File
|
||||||
@@ -31,7 +30,7 @@ from core.workflow.constants import (
|
|||||||
SYSTEM_VARIABLE_NODE_ID,
|
SYSTEM_VARIABLE_NODE_ID,
|
||||||
)
|
)
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||||
from core.workflow.enums import NodeType, WorkflowExecutionStatus
|
from core.workflow.enums import NodeType
|
||||||
from extensions.ext_storage import Storage
|
from extensions.ext_storage import Storage
|
||||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
@@ -406,11 +405,6 @@ class Workflow(Base): # bug
|
|||||||
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
|
return helper.generate_text_hash(json.dumps(entity, sort_keys=True))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecated(
|
|
||||||
"This property is not accurate for determining if a workflow is published as a tool."
|
|
||||||
"It only checks if there's a WorkflowToolProvider for the app, "
|
|
||||||
"not if this specific workflow version is the one being used by the tool."
|
|
||||||
)
|
|
||||||
def tool_published(self) -> bool:
|
def tool_published(self) -> bool:
|
||||||
"""
|
"""
|
||||||
DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
|
DEPRECATED: This property is not accurate for determining if a workflow is published as a tool.
|
||||||
@@ -613,16 +607,13 @@ class WorkflowRun(Base):
|
|||||||
version: Mapped[str] = mapped_column(String(255))
|
version: Mapped[str] = mapped_column(String(255))
|
||||||
graph: Mapped[str | None] = mapped_column(LongText)
|
graph: Mapped[str | None] = mapped_column(LongText)
|
||||||
inputs: Mapped[str | None] = mapped_column(LongText)
|
inputs: Mapped[str | None] = mapped_column(LongText)
|
||||||
status: Mapped[WorkflowExecutionStatus] = mapped_column(
|
status: Mapped[str] = mapped_column(String(255)) # running, succeeded, failed, stopped, partial-succeeded
|
||||||
EnumText(WorkflowExecutionStatus, length=255),
|
|
||||||
nullable=False,
|
|
||||||
)
|
|
||||||
outputs: Mapped[str | None] = mapped_column(LongText, default="{}")
|
outputs: Mapped[str | None] = mapped_column(LongText, default="{}")
|
||||||
error: Mapped[str | None] = mapped_column(LongText)
|
error: Mapped[str | None] = mapped_column(LongText)
|
||||||
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
elapsed_time: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
|
||||||
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
total_tokens: Mapped[int] = mapped_column(sa.BigInteger, server_default=sa.text("0"))
|
||||||
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
total_steps: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255)) # account, end_user
|
created_by_role: Mapped[str] = mapped_column(String(255)) # account, end_user
|
||||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
|
||||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||||
@@ -638,13 +629,11 @@ class WorkflowRun(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
|
||||||
def created_by_account(self):
|
def created_by_account(self):
|
||||||
created_by_role = CreatorUserRole(self.created_by_role)
|
created_by_role = CreatorUserRole(self.created_by_role)
|
||||||
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
return db.session.get(Account, self.created_by) if created_by_role == CreatorUserRole.ACCOUNT else None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
|
||||||
def created_by_end_user(self):
|
def created_by_end_user(self):
|
||||||
from .model import EndUser
|
from .model import EndUser
|
||||||
|
|
||||||
@@ -664,7 +653,6 @@ class WorkflowRun(Base):
|
|||||||
return json.loads(self.outputs) if self.outputs else {}
|
return json.loads(self.outputs) if self.outputs else {}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
|
||||||
def message(self):
|
def message(self):
|
||||||
from .model import Message
|
from .model import Message
|
||||||
|
|
||||||
@@ -673,7 +661,6 @@ class WorkflowRun(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@deprecated("This method is retained for historical reasons; avoid using it if possible.")
|
|
||||||
def workflow(self):
|
def workflow(self):
|
||||||
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
return db.session.query(Workflow).where(Workflow.id == self.workflow_id).first()
|
||||||
|
|
||||||
@@ -1874,12 +1861,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
|||||||
|
|
||||||
def to_entity(self) -> PauseReason:
|
def to_entity(self) -> PauseReason:
|
||||||
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
if self.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||||
return HumanInputRequired(
|
return HumanInputRequired(form_id=self.form_id, node_id=self.node_id)
|
||||||
form_id=self.form_id,
|
|
||||||
form_content="",
|
|
||||||
node_id=self.node_id,
|
|
||||||
node_title="",
|
|
||||||
)
|
|
||||||
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
|
elif self.type_ == PauseReasonType.SCHEDULED_PAUSE:
|
||||||
return SchedulingPause(message=self.message)
|
return SchedulingPause(message=self.message)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ dependencies = [
|
|||||||
"numpy~=1.26.4",
|
"numpy~=1.26.4",
|
||||||
"openpyxl~=3.1.5",
|
"openpyxl~=3.1.5",
|
||||||
"opik~=1.8.72",
|
"opik~=1.8.72",
|
||||||
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
|
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
|
||||||
"opentelemetry-api==1.27.0",
|
"opentelemetry-api==1.27.0",
|
||||||
"opentelemetry-distro==0.48b0",
|
"opentelemetry-distro==0.48b0",
|
||||||
"opentelemetry-exporter-otlp==1.27.0",
|
"opentelemetry-exporter-otlp==1.27.0",
|
||||||
@@ -230,23 +230,3 @@ vdb = [
|
|||||||
"mo-vector~=0.1.13",
|
"mo-vector~=0.1.13",
|
||||||
"mysql-connector-python>=9.3.0",
|
"mysql-connector-python>=9.3.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.mypy]
|
|
||||||
|
|
||||||
[[tool.mypy.overrides]]
|
|
||||||
# targeted ignores for current type-check errors
|
|
||||||
# TODO(QuantumGhost): suppress type errors in HITL related code.
|
|
||||||
# fix the type error later
|
|
||||||
module = [
|
|
||||||
"configs.middleware.cache.redis_pubsub_config",
|
|
||||||
"extensions.ext_redis",
|
|
||||||
"tasks.workflow_execution_tasks",
|
|
||||||
"core.workflow.nodes.base.node",
|
|
||||||
"services.human_input_delivery_test_service",
|
|
||||||
"core.app.apps.advanced_chat.app_generator",
|
|
||||||
"controllers.console.human_input_form",
|
|
||||||
"controllers.console.app.workflow_run",
|
|
||||||
"repositories.sqlalchemy_api_workflow_node_execution_repository",
|
|
||||||
"extensions.logstore.repositories.logstore_api_workflow_run_repository",
|
|
||||||
]
|
|
||||||
ignore_errors = true
|
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Protocol
|
from typing import Protocol
|
||||||
|
|
||||||
@@ -20,27 +19,6 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
|
|||||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class WorkflowNodeExecutionSnapshot:
|
|
||||||
"""
|
|
||||||
Minimal snapshot of workflow node execution for stream recovery.
|
|
||||||
|
|
||||||
Only includes fields required by snapshot events.
|
|
||||||
"""
|
|
||||||
|
|
||||||
execution_id: str # Unique execution identifier (node_execution_id or row id).
|
|
||||||
node_id: str # Workflow graph node id.
|
|
||||||
node_type: str # Workflow graph node type (e.g. "human-input").
|
|
||||||
title: str # Human-friendly node title.
|
|
||||||
index: int # Execution order index within the workflow run.
|
|
||||||
status: str # Execution status (running/succeeded/failed/paused).
|
|
||||||
elapsed_time: float # Execution elapsed time in seconds.
|
|
||||||
created_at: datetime # Execution created timestamp.
|
|
||||||
finished_at: datetime | None # Execution finished timestamp.
|
|
||||||
iteration_id: str | None = None # Iteration id from execution metadata, if any.
|
|
||||||
loop_id: str | None = None # Loop id from execution metadata, if any.
|
|
||||||
|
|
||||||
|
|
||||||
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
|
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
|
||||||
"""
|
"""
|
||||||
Protocol for service-layer operations on WorkflowNodeExecutionModel.
|
Protocol for service-layer operations on WorkflowNodeExecutionModel.
|
||||||
@@ -101,8 +79,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
|
|||||||
Args:
|
Args:
|
||||||
tenant_id: The tenant identifier
|
tenant_id: The tenant identifier
|
||||||
app_id: The application identifier
|
app_id: The application identifier
|
||||||
workflow_id: The workflow identifier
|
|
||||||
triggered_from: The workflow trigger source
|
|
||||||
workflow_run_id: The workflow run identifier
|
workflow_run_id: The workflow run identifier
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -110,27 +86,6 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
|
|||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_execution_snapshots_by_workflow_run(
|
|
||||||
self,
|
|
||||||
tenant_id: str,
|
|
||||||
app_id: str,
|
|
||||||
workflow_id: str,
|
|
||||||
triggered_from: str,
|
|
||||||
workflow_run_id: str,
|
|
||||||
) -> Sequence[WorkflowNodeExecutionSnapshot]:
|
|
||||||
"""
|
|
||||||
Get minimal snapshots for node executions in a workflow run.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id: The tenant identifier
|
|
||||||
app_id: The application identifier
|
|
||||||
workflow_run_id: The workflow run identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A sequence of WorkflowNodeExecutionSnapshot ordered by creation time
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def get_execution_by_id(
|
def get_execution_by_id(
|
||||||
self,
|
self,
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
|
|||||||
@@ -432,13 +432,6 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||||||
# while creating pause.
|
# while creating pause.
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None:
|
|
||||||
"""Retrieve the current pause for a workflow execution.
|
|
||||||
|
|
||||||
If there is no current pause, this method would return `None`.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
def resume_workflow_pause(
|
def resume_workflow_pause(
|
||||||
self,
|
self,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -634,19 +627,3 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
|||||||
[{"date": "2024-01-01", "interactions": 2.5}, ...]
|
[{"date": "2024-01-01", "interactions": 2.5}, ...]
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|
||||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
|
||||||
"""
|
|
||||||
Get a specific workflow run by its id and the associated tenant id.
|
|
||||||
|
|
||||||
This function does not apply application isolation. It should only be used when
|
|
||||||
the application identifier is not available.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tenant_id: Tenant identifier for multi-tenant isolation
|
|
||||||
run_id: Workflow run identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
WorkflowRun object if found, None otherwise
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|||||||
@@ -63,12 +63,6 @@ class WorkflowPauseEntity(ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def paused_at(self) -> datetime:
|
|
||||||
"""`paused_at` returns the creation time of the pause."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||||
"""
|
"""
|
||||||
@@ -76,5 +70,7 @@ class WorkflowPauseEntity(ABC):
|
|||||||
|
|
||||||
Returns a sequence of `PauseReason` objects describing the specific nodes and
|
Returns a sequence of `PauseReason` objects describing the specific nodes and
|
||||||
reasons for which the workflow execution was paused.
|
reasons for which the workflow execution was paused.
|
||||||
|
This information is related to, but distinct from, the `PauseReason` type
|
||||||
|
defined in `api/core/workflow/entities/pause_reason.py`.
|
||||||
"""
|
"""
|
||||||
...
|
...
|
||||||
|
|||||||
@@ -1,13 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Protocol
|
|
||||||
|
|
||||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionExtraContentRepository(Protocol):
|
|
||||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ...
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["ExecutionExtraContentRepository"]
|
|
||||||
@@ -5,7 +5,6 @@ This module provides a concrete implementation of the service repository protoco
|
|||||||
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import cast
|
from typing import cast
|
||||||
@@ -14,12 +13,11 @@ from sqlalchemy import asc, delete, desc, func, select
|
|||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session, sessionmaker
|
from sqlalchemy.orm import Session, sessionmaker
|
||||||
|
|
||||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
from models.workflow import (
|
||||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
WorkflowNodeExecutionModel,
|
||||||
from repositories.api_workflow_node_execution_repository import (
|
WorkflowNodeExecutionOffload,
|
||||||
DifyAPIWorkflowNodeExecutionRepository,
|
|
||||||
WorkflowNodeExecutionSnapshot,
|
|
||||||
)
|
)
|
||||||
|
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||||
|
|
||||||
|
|
||||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||||
@@ -81,7 +79,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
WorkflowNodeExecutionModel.app_id == app_id,
|
WorkflowNodeExecutionModel.app_id == app_id,
|
||||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||||
WorkflowNodeExecutionModel.node_id == node_id,
|
WorkflowNodeExecutionModel.node_id == node_id,
|
||||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
|
||||||
)
|
)
|
||||||
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
||||||
.limit(1)
|
.limit(1)
|
||||||
@@ -120,80 +117,6 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
|||||||
with self._session_maker() as session:
|
with self._session_maker() as session:
|
||||||
return session.execute(stmt).scalars().all()
|
return session.execute(stmt).scalars().all()
|
||||||
|
|
||||||
def get_execution_snapshots_by_workflow_run(
|
|
||||||
self,
|
|
||||||
tenant_id: str,
|
|
||||||
app_id: str,
|
|
||||||
workflow_id: str,
|
|
||||||
triggered_from: str,
|
|
||||||
workflow_run_id: str,
|
|
||||||
) -> Sequence[WorkflowNodeExecutionSnapshot]:
|
|
||||||
stmt = (
|
|
||||||
select(
|
|
||||||
WorkflowNodeExecutionModel.id,
|
|
||||||
WorkflowNodeExecutionModel.node_execution_id,
|
|
||||||
WorkflowNodeExecutionModel.node_id,
|
|
||||||
WorkflowNodeExecutionModel.node_type,
|
|
||||||
WorkflowNodeExecutionModel.title,
|
|
||||||
WorkflowNodeExecutionModel.index,
|
|
||||||
WorkflowNodeExecutionModel.status,
|
|
||||||
WorkflowNodeExecutionModel.elapsed_time,
|
|
||||||
WorkflowNodeExecutionModel.created_at,
|
|
||||||
WorkflowNodeExecutionModel.finished_at,
|
|
||||||
WorkflowNodeExecutionModel.execution_metadata,
|
|
||||||
)
|
|
||||||
.where(
|
|
||||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
|
||||||
WorkflowNodeExecutionModel.app_id == app_id,
|
|
||||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
|
||||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
|
||||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
|
||||||
)
|
|
||||||
.order_by(
|
|
||||||
asc(WorkflowNodeExecutionModel.created_at),
|
|
||||||
asc(WorkflowNodeExecutionModel.index),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._session_maker() as session:
|
|
||||||
rows = session.execute(stmt).all()
|
|
||||||
|
|
||||||
return [self._row_to_snapshot(row) for row in rows]
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
|
||||||
metadata: dict[str, object] = {}
|
|
||||||
execution_metadata = getattr(row, "execution_metadata", None)
|
|
||||||
if execution_metadata:
|
|
||||||
try:
|
|
||||||
metadata = json.loads(execution_metadata)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
metadata = {}
|
|
||||||
iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value)
|
|
||||||
loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value)
|
|
||||||
execution_id = getattr(row, "node_execution_id", None) or row.id
|
|
||||||
elapsed_time = getattr(row, "elapsed_time", None)
|
|
||||||
created_at = row.created_at
|
|
||||||
finished_at = getattr(row, "finished_at", None)
|
|
||||||
if elapsed_time is None:
|
|
||||||
if finished_at is not None and created_at is not None:
|
|
||||||
elapsed_time = (finished_at - created_at).total_seconds()
|
|
||||||
else:
|
|
||||||
elapsed_time = 0.0
|
|
||||||
return WorkflowNodeExecutionSnapshot(
|
|
||||||
execution_id=str(execution_id),
|
|
||||||
node_id=row.node_id,
|
|
||||||
node_type=row.node_type,
|
|
||||||
title=row.title,
|
|
||||||
index=row.index,
|
|
||||||
status=row.status,
|
|
||||||
elapsed_time=float(elapsed_time),
|
|
||||||
created_at=created_at,
|
|
||||||
finished_at=finished_at,
|
|
||||||
iteration_id=str(iteration_id) if iteration_id else None,
|
|
||||||
loop_id=str(loop_id) if loop_id else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_execution_by_id(
|
def get_execution_by_id(
|
||||||
self,
|
self,
|
||||||
execution_id: str,
|
execution_id: str,
|
||||||
|
|||||||
@@ -19,7 +19,6 @@ Implementation Notes:
|
|||||||
- Maintains data consistency with proper transaction handling
|
- Maintains data consistency with proper transaction handling
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import Callable, Sequence
|
from collections.abc import Callable, Sequence
|
||||||
@@ -28,14 +27,12 @@ from decimal import Decimal
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from pydantic import ValidationError
|
|
||||||
from sqlalchemy import and_, delete, func, null, or_, select
|
from sqlalchemy import and_, delete, func, null, or_, select
|
||||||
from sqlalchemy.engine import CursorResult
|
from sqlalchemy.engine import CursorResult
|
||||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||||
|
|
||||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
|
||||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
|
||||||
from extensions.ext_storage import storage
|
from extensions.ext_storage import storage
|
||||||
from libs.datetime_utils import naive_utc_now
|
from libs.datetime_utils import naive_utc_now
|
||||||
from libs.helper import convert_datetime_to_date
|
from libs.helper import convert_datetime_to_date
|
||||||
@@ -43,7 +40,6 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
|||||||
from libs.time_parser import get_time_threshold
|
from libs.time_parser import get_time_threshold
|
||||||
from libs.uuid_utils import uuidv7
|
from libs.uuid_utils import uuidv7
|
||||||
from models.enums import WorkflowRunTriggeredFrom
|
from models.enums import WorkflowRunTriggeredFrom
|
||||||
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
|
|
||||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||||
@@ -61,67 +57,6 @@ class _WorkflowRunError(Exception):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def _select_recipient_token(
|
|
||||||
recipients: Sequence[HumanInputFormRecipient],
|
|
||||||
recipient_type: RecipientType,
|
|
||||||
) -> str | None:
|
|
||||||
for recipient in recipients:
|
|
||||||
if recipient.recipient_type == recipient_type and recipient.access_token:
|
|
||||||
return recipient.access_token
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _build_human_input_required_reason(
|
|
||||||
reason_model: WorkflowPauseReason,
|
|
||||||
form_model: HumanInputForm | None,
|
|
||||||
recipients: Sequence[HumanInputFormRecipient],
|
|
||||||
) -> HumanInputRequired:
|
|
||||||
form_content = ""
|
|
||||||
inputs = []
|
|
||||||
actions = []
|
|
||||||
display_in_ui = False
|
|
||||||
resolved_default_values: dict[str, Any] = {}
|
|
||||||
node_title = "Human Input"
|
|
||||||
form_id = reason_model.form_id
|
|
||||||
node_id = reason_model.node_id
|
|
||||||
if form_model is not None:
|
|
||||||
form_id = form_model.id
|
|
||||||
node_id = form_model.node_id or node_id
|
|
||||||
try:
|
|
||||||
definition_payload = json.loads(form_model.form_definition)
|
|
||||||
if "expiration_time" not in definition_payload:
|
|
||||||
definition_payload["expiration_time"] = form_model.expiration_time
|
|
||||||
definition = FormDefinition.model_validate(definition_payload)
|
|
||||||
except ValidationError:
|
|
||||||
definition = None
|
|
||||||
|
|
||||||
if definition is not None:
|
|
||||||
form_content = definition.form_content
|
|
||||||
inputs = list(definition.inputs)
|
|
||||||
actions = list(definition.user_actions)
|
|
||||||
display_in_ui = bool(definition.display_in_ui)
|
|
||||||
resolved_default_values = dict(definition.default_values)
|
|
||||||
node_title = definition.node_title or node_title
|
|
||||||
|
|
||||||
form_token = (
|
|
||||||
_select_recipient_token(recipients, RecipientType.BACKSTAGE)
|
|
||||||
or _select_recipient_token(recipients, RecipientType.CONSOLE)
|
|
||||||
or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP)
|
|
||||||
)
|
|
||||||
|
|
||||||
return HumanInputRequired(
|
|
||||||
form_id=form_id,
|
|
||||||
form_content=form_content,
|
|
||||||
inputs=inputs,
|
|
||||||
actions=actions,
|
|
||||||
display_in_ui=display_in_ui,
|
|
||||||
node_id=node_id,
|
|
||||||
node_title=node_title,
|
|
||||||
form_token=form_token,
|
|
||||||
resolved_default_values=resolved_default_values,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||||
"""
|
"""
|
||||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||||
@@ -741,11 +676,9 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||||
|
|
||||||
# Check if workflow is in RUNNING status
|
# Check if workflow is in RUNNING status
|
||||||
# TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status`
|
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||||
# happens before the execution of GraphLayer
|
|
||||||
if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}:
|
|
||||||
raise _WorkflowRunError(
|
raise _WorkflowRunError(
|
||||||
f"Only WorkflowRun with RUNNING or PAUSED status can be paused, "
|
f"Only WorkflowRun with RUNNING status can be paused, "
|
||||||
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
||||||
)
|
)
|
||||||
#
|
#
|
||||||
@@ -796,48 +729,13 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
|
|
||||||
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||||
|
|
||||||
return _PrivateWorkflowPauseEntity(
|
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
|
||||||
pause_model=pause_model,
|
|
||||||
reason_models=pause_reason_models,
|
|
||||||
pause_reasons=pause_reasons,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
|
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
|
||||||
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
|
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
|
||||||
pause_reason_models = session.scalars(reason_stmt).all()
|
pause_reason_models = session.scalars(reason_stmt).all()
|
||||||
return pause_reason_models
|
return pause_reason_models
|
||||||
|
|
||||||
def _hydrate_pause_reasons(
|
|
||||||
self,
|
|
||||||
session: Session,
|
|
||||||
pause_reason_models: Sequence[WorkflowPauseReason],
|
|
||||||
) -> list[PauseReason]:
|
|
||||||
form_ids = [
|
|
||||||
reason.form_id
|
|
||||||
for reason in pause_reason_models
|
|
||||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id
|
|
||||||
]
|
|
||||||
form_models: dict[str, HumanInputForm] = {}
|
|
||||||
recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {}
|
|
||||||
if form_ids:
|
|
||||||
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
|
|
||||||
for form in session.scalars(form_stmt).all():
|
|
||||||
form_models[form.id] = form
|
|
||||||
|
|
||||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
|
||||||
for recipient in session.scalars(recipient_stmt).all():
|
|
||||||
recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient)
|
|
||||||
|
|
||||||
pause_reasons: list[PauseReason] = []
|
|
||||||
for reason in pause_reason_models:
|
|
||||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
|
||||||
form_model = form_models.get(reason.form_id)
|
|
||||||
recipients = recipient_models_by_form.get(reason.form_id, [])
|
|
||||||
pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients))
|
|
||||||
else:
|
|
||||||
pause_reasons.append(reason.to_entity())
|
|
||||||
return pause_reasons
|
|
||||||
|
|
||||||
def get_workflow_pause(
|
def get_workflow_pause(
|
||||||
self,
|
self,
|
||||||
workflow_run_id: str,
|
workflow_run_id: str,
|
||||||
@@ -869,12 +767,14 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
if pause_model is None:
|
if pause_model is None:
|
||||||
return None
|
return None
|
||||||
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
|
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||||
pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models)
|
|
||||||
|
human_input_form: list[Any] = []
|
||||||
|
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
|
||||||
|
|
||||||
return _PrivateWorkflowPauseEntity(
|
return _PrivateWorkflowPauseEntity(
|
||||||
pause_model=pause_model,
|
pause_model=pause_model,
|
||||||
reason_models=pause_reason_models,
|
reason_models=pause_reason_models,
|
||||||
pause_reasons=pause_reasons,
|
human_input_form=human_input_form,
|
||||||
)
|
)
|
||||||
|
|
||||||
def resume_workflow_pause(
|
def resume_workflow_pause(
|
||||||
@@ -928,10 +828,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
||||||
|
|
||||||
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
|
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||||
hydrated_pause_reasons = self._hydrate_pause_reasons(session, pause_reasons)
|
|
||||||
|
|
||||||
# Mark as resumed
|
# Mark as resumed
|
||||||
pause_model.resumed_at = naive_utc_now()
|
pause_model.resumed_at = naive_utc_now()
|
||||||
|
workflow_run.pause_id = None # type: ignore
|
||||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||||
|
|
||||||
session.add(pause_model)
|
session.add(pause_model)
|
||||||
@@ -939,11 +839,7 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
|||||||
|
|
||||||
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||||
|
|
||||||
return _PrivateWorkflowPauseEntity(
|
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
|
||||||
pause_model=pause_model,
|
|
||||||
reason_models=pause_reasons,
|
|
||||||
pause_reasons=hydrated_pause_reasons,
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_workflow_pause(
|
def delete_workflow_pause(
|
||||||
self,
|
self,
|
||||||
@@ -1269,15 +1165,6 @@ GROUP BY
|
|||||||
|
|
||||||
return cast(list[AverageInteractionStats], response_data)
|
return cast(list[AverageInteractionStats], response_data)
|
||||||
|
|
||||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
|
||||||
"""Get a specific workflow run by its id and the associated tenant id."""
|
|
||||||
with self._session_maker() as session:
|
|
||||||
stmt = select(WorkflowRun).where(
|
|
||||||
WorkflowRun.tenant_id == tenant_id,
|
|
||||||
WorkflowRun.id == run_id,
|
|
||||||
)
|
|
||||||
return session.scalar(stmt)
|
|
||||||
|
|
||||||
|
|
||||||
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||||
"""
|
"""
|
||||||
@@ -1292,12 +1179,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
|||||||
*,
|
*,
|
||||||
pause_model: WorkflowPause,
|
pause_model: WorkflowPause,
|
||||||
reason_models: Sequence[WorkflowPauseReason],
|
reason_models: Sequence[WorkflowPauseReason],
|
||||||
pause_reasons: Sequence[PauseReason] | None = None,
|
|
||||||
human_input_form: Sequence = (),
|
human_input_form: Sequence = (),
|
||||||
) -> None:
|
) -> None:
|
||||||
self._pause_model = pause_model
|
self._pause_model = pause_model
|
||||||
self._reason_models = reason_models
|
self._reason_models = reason_models
|
||||||
self._pause_reasons = pause_reasons
|
|
||||||
self._cached_state: bytes | None = None
|
self._cached_state: bytes | None = None
|
||||||
self._human_input_form = human_input_form
|
self._human_input_form = human_input_form
|
||||||
|
|
||||||
@@ -1334,10 +1219,4 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
|||||||
return self._pause_model.resumed_at
|
return self._pause_model.resumed_at
|
||||||
|
|
||||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||||
if self._pause_reasons is not None:
|
|
||||||
return list(self._pause_reasons)
|
|
||||||
return [reason.to_entity() for reason in self._reason_models]
|
return [reason.to_entity() for reason in self._reason_models]
|
||||||
|
|
||||||
@property
|
|
||||||
def paused_at(self) -> datetime:
|
|
||||||
return self._pause_model.created_at
|
|
||||||
|
|||||||
@@ -1,200 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from collections import defaultdict
|
|
||||||
from collections.abc import Sequence
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
|
||||||
|
|
||||||
from core.entities.execution_extra_content import (
|
|
||||||
ExecutionExtraContentDomainModel,
|
|
||||||
HumanInputFormDefinition,
|
|
||||||
HumanInputFormSubmissionData,
|
|
||||||
)
|
|
||||||
from core.entities.execution_extra_content import (
|
|
||||||
HumanInputContent as HumanInputContentDomainModel,
|
|
||||||
)
|
|
||||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
|
||||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
|
||||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
|
||||||
from models.execution_extra_content import (
|
|
||||||
ExecutionExtraContent as ExecutionExtraContentModel,
|
|
||||||
)
|
|
||||||
from models.execution_extra_content import (
|
|
||||||
HumanInputContent as HumanInputContentModel,
|
|
||||||
)
|
|
||||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
|
||||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_output_field_names(form_content: str) -> list[str]:
|
|
||||||
if not form_content:
|
|
||||||
return []
|
|
||||||
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
|
|
||||||
|
|
||||||
|
|
||||||
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
|
|
||||||
def __init__(self, session_maker: sessionmaker[Session]):
|
|
||||||
self._session_maker = session_maker
|
|
||||||
|
|
||||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
|
||||||
if not message_ids:
|
|
||||||
return []
|
|
||||||
|
|
||||||
grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
|
|
||||||
message_id: [] for message_id in message_ids
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt = (
|
|
||||||
select(ExecutionExtraContentModel)
|
|
||||||
.where(ExecutionExtraContentModel.message_id.in_(message_ids))
|
|
||||||
.options(selectinload(HumanInputContentModel.form))
|
|
||||||
.order_by(ExecutionExtraContentModel.created_at.asc())
|
|
||||||
)
|
|
||||||
|
|
||||||
with self._session_maker() as session:
|
|
||||||
results = session.scalars(stmt).all()
|
|
||||||
|
|
||||||
form_ids = {
|
|
||||||
content.form_id
|
|
||||||
for content in results
|
|
||||||
if isinstance(content, HumanInputContentModel) and content.form_id is not None
|
|
||||||
}
|
|
||||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
|
|
||||||
if form_ids:
|
|
||||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
|
||||||
recipients = session.scalars(recipient_stmt).all()
|
|
||||||
for recipient in recipients:
|
|
||||||
recipients_by_form_id[recipient.form_id].append(recipient)
|
|
||||||
else:
|
|
||||||
recipients_by_form_id = {}
|
|
||||||
|
|
||||||
for content in results:
|
|
||||||
message_id = content.message_id
|
|
||||||
if not message_id or message_id not in grouped_contents:
|
|
||||||
continue
|
|
||||||
|
|
||||||
domain_model = self._map_model_to_domain(content, recipients_by_form_id)
|
|
||||||
if domain_model is None:
|
|
||||||
continue
|
|
||||||
|
|
||||||
grouped_contents[message_id].append(domain_model)
|
|
||||||
|
|
||||||
return [grouped_contents[message_id] for message_id in message_ids]
|
|
||||||
|
|
||||||
def _map_model_to_domain(
|
|
||||||
self,
|
|
||||||
model: ExecutionExtraContentModel,
|
|
||||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
|
||||||
) -> ExecutionExtraContentDomainModel | None:
|
|
||||||
if isinstance(model, HumanInputContentModel):
|
|
||||||
return self._map_human_input_content(model, recipients_by_form_id)
|
|
||||||
|
|
||||||
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _map_human_input_content(
|
|
||||||
self,
|
|
||||||
model: HumanInputContentModel,
|
|
||||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
|
||||||
) -> HumanInputContentDomainModel | None:
|
|
||||||
form = model.form
|
|
||||||
if form is None:
|
|
||||||
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
definition_payload = json.loads(form.form_definition)
|
|
||||||
if "expiration_time" not in definition_payload:
|
|
||||||
definition_payload["expiration_time"] = form.expiration_time
|
|
||||||
form_definition = FormDefinition.model_validate(definition_payload)
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
|
|
||||||
return None
|
|
||||||
node_title = form_definition.node_title or form.node_id
|
|
||||||
display_in_ui = bool(form_definition.display_in_ui)
|
|
||||||
|
|
||||||
submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
|
|
||||||
if not submitted:
|
|
||||||
form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
|
|
||||||
return HumanInputContentDomainModel(
|
|
||||||
workflow_run_id=model.workflow_run_id,
|
|
||||||
submitted=False,
|
|
||||||
form_definition=HumanInputFormDefinition(
|
|
||||||
form_id=form.id,
|
|
||||||
node_id=form.node_id,
|
|
||||||
node_title=node_title,
|
|
||||||
form_content=form.rendered_content,
|
|
||||||
inputs=form_definition.inputs,
|
|
||||||
actions=form_definition.user_actions,
|
|
||||||
display_in_ui=display_in_ui,
|
|
||||||
form_token=form_token,
|
|
||||||
resolved_default_values=form_definition.default_values,
|
|
||||||
expiration_time=int(form.expiration_time.timestamp()),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
selected_action_id = form.selected_action_id
|
|
||||||
if not selected_action_id:
|
|
||||||
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
action_text = next(
|
|
||||||
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
|
|
||||||
selected_action_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
submitted_data: dict[str, Any] = {}
|
|
||||||
if form.submitted_data:
|
|
||||||
try:
|
|
||||||
submitted_data = json.loads(form.submitted_data)
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
rendered_content = HumanInputNode.render_form_content_with_outputs(
|
|
||||||
form.rendered_content,
|
|
||||||
submitted_data,
|
|
||||||
_extract_output_field_names(form_definition.form_content),
|
|
||||||
)
|
|
||||||
|
|
||||||
return HumanInputContentDomainModel(
|
|
||||||
workflow_run_id=model.workflow_run_id,
|
|
||||||
submitted=True,
|
|
||||||
form_submission_data=HumanInputFormSubmissionData(
|
|
||||||
node_id=form.node_id,
|
|
||||||
node_title=node_title,
|
|
||||||
rendered_content=rendered_content,
|
|
||||||
action_id=selected_action_id,
|
|
||||||
action_text=action_text,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
|
|
||||||
console_recipient = next(
|
|
||||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if console_recipient and console_recipient.access_token:
|
|
||||||
return console_recipient.access_token
|
|
||||||
|
|
||||||
web_app_recipient = next(
|
|
||||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
if web_app_recipient and web_app_recipient.access_token:
|
|
||||||
return web_app_recipient.access_token
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]
|
|
||||||
@@ -92,16 +92,6 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
|||||||
|
|
||||||
return list(self.session.scalars(query).all())
|
return list(self.session.scalars(query).all())
|
||||||
|
|
||||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
|
||||||
"""Get the trigger log associated with a workflow run."""
|
|
||||||
query = (
|
|
||||||
select(WorkflowTriggerLog)
|
|
||||||
.where(WorkflowTriggerLog.workflow_run_id == workflow_run_id)
|
|
||||||
.order_by(WorkflowTriggerLog.created_at.desc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
return self.session.scalar(query)
|
|
||||||
|
|
||||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||||
"""
|
"""
|
||||||
Delete trigger logs associated with the given workflow run ids.
|
Delete trigger logs associated with the given workflow run ids.
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user