Compare commits

..

113 Commits

Author SHA1 Message Date
yyh
a0aa8cdb45 Merge remote-tracking branch 'origin/main' into feature/task-quadrant-view 2026-01-16 18:20:29 +08:00
yyh
ae8618877b fix(web): quadrant matrix i18n 2026-01-16 18:17:28 +08:00
가은 정
fad6fa141d chore: improve accessibility for learn more link (#31120)
Co-authored-by: khmandarrin <jeong-ga-eun@jeong-ga-eun-ui-MacBookAir.local>
2026-01-16 18:12:07 +08:00
Pádraic Slattery
30821fd26c chore: Update outdated GitHub Actions versions (#31114) 2026-01-16 17:56:55 +08:00
Xiangxuan Qu
1a9fdd9a65 refactor: migrate tag list API query parameters to Pydantic (#31097)
Co-authored-by: fghpdf <fghpdf@users.noreply.github.com>
2026-01-16 17:49:52 +08:00
Stream
de610cbf39 fix: call get_text_content() instead of casting to str (#31121)
Signed-off-by: Stream <Stream_2@qq.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-16 18:41:00 +09:00
yyh
1c55602445 fix(web): add calendar icon and DDL label to deadline badge in task-item 2026-01-16 17:24:11 +08:00
yyh
a3f1220d23 feat(web): add fullscreen expand mode to quadrant-matrix component
- Add expand button in header to open FullScreenModal
- Add numbered circles (1-4) to quadrant headers
- Add expanded prop to show full content without line-clamp
- Reorder grid layout: Q1 top-left, Q2 top-right, Q3 bottom-left, Q4 bottom-right
- Remove axis labels for cleaner design
2026-01-16 17:16:13 +08:00
yyh
d62e16b9bb fix(web): improve quadrant-matrix layout and text overflow handling
- Simplify axis label layout with horizontal/vertical arrangement
- Add proper text truncation with line-clamp and tooltips
- Fix overflow issues by adding min-w-0 on flex children
- Move scores inline with task name for compact display
- Add task count badge to quadrant headers
- Reduce maxDisplay to 3 for better density
2026-01-16 16:58:57 +08:00
yyh
13f2a43ccc feat(web): add Eisenhower Matrix visualization component for task quadrants
Add a new quadrant-matrix component that renders tasks in a 2x2 grid based
on importance and urgency scores. Integrate with code-block as a new
'quadrant' language type for markdown rendering.
2026-01-16 16:58:56 +08:00
yyh
6903c31b84 fix(search-input): retain focus after clearing input (#31107) 2026-01-16 16:22:14 +08:00
盐粒 Yanli
b2cc9b255d chore: Update coding agent workflow for backend (#31093) 2026-01-16 14:28:47 +08:00
XiaoBa
e9f0e1e839 fix(web): replace Response.json with legacy Response constructor for pre-Chrome 105 compatibility(#31091) (#31095)
Co-authored-by: Xiaoba Yu <xb1823725853@gmail.com>
2026-01-16 14:26:23 +08:00
pavior
cd497a8c52 fix(web): use portal for variable picker in code editor (Fixes #31063) (#31066) 2026-01-16 13:31:57 +08:00
Stephen Zhou
7aab4529e6 chore: lint for state hooks (#31088) 2026-01-16 11:58:28 +08:00
E.G
4bff0cd0ab fix: resolve 'Expand all chunks' button not working (#31074)
Co-authored-by: GlobalStar117 <GlobalStar117@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <427733928@qq.com>
2026-01-16 11:34:42 +08:00
byteforge
c98870c3f4 refactor: always preserve marketplace search state in URL (#31069)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-16 08:52:53 +09:00
Stephen Zhou
b06c7c8f33 ci: disable limit annotation (#31072) 2026-01-15 23:04:26 +08:00
Stephen Zhou
1a2fce7055 ci: eslint annotation (#31056) 2026-01-15 21:49:46 +08:00
lif
2b021e8752 fix: remove hardcoded 48-character limit from text inputs (#30156)
Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-15 17:43:00 +08:00
wangxiaolei
4a197b9458 fix: fix log updated_at is refreshed (#31045) 2026-01-15 15:42:46 +08:00
Xiyuan Chen
772ff636ec feat: credential sync fix for enterprise edition (#30626) 2026-01-14 23:33:24 -08:00
Stephen Zhou
ab1c5a2027 refactor: remove manual set query logic (#31039) 2026-01-15 15:25:43 +08:00
hj24
33e99f069b fix: message clean service ut (#31038) 2026-01-15 15:13:25 +08:00
hj24
52af829f1f refactor: enhance clean messages task (#29638)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: 非法操作 <hjlarry@163.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-15 14:03:17 +08:00
-LAN-
0ef8b5a0ca chore: bump version to 1.11.4 (#30961) 2026-01-15 11:36:15 +08:00
wangxiaolei
2bfc54314e feat: single run add opentelemetry (#31020) 2026-01-15 11:10:55 +08:00
Coding On Star
bdd8d5b470 test: add unit tests for PluginPage and related components (#30908)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-15 10:56:02 +08:00
Joseph Adams
4955de5905 fix: validation error when uploading images with None URL values (#31012)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-15 10:54:10 +08:00
yyh
3bee2ee067 refactor(contract): restructure console contracts with nested billing module (#30999) 2026-01-15 10:41:18 +08:00
Stephen Zhou
328897f81c build: require node 24.13.0 (#30945) 2026-01-15 10:38:55 +08:00
Coding On Star
ab078380a3 feat(web): refactor documents component structure and enhance functionality (#30854)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-15 10:33:58 +08:00
Coding On Star
a33ac77a22 feat: implement document creation pipeline with multi-step wizard and datasource management (#30843)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-15 10:33:48 +08:00
Asuka Minato
d3923e7b56 refactor: port AppAnnotationHitHistory (#30922)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-15 10:14:55 +08:00
Asuka Minato
2f633de45e refactor: port TenantCreditPool (#30926)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-15 10:14:15 +08:00
wangxiaolei
98c88cec34 refactor: delete_endpoint should be idempotent (#30954) 2026-01-15 10:10:10 +08:00
wangxiaolei
c6999fb5be fix: fix plugin edit endpoint app disappear (#30951) 2026-01-15 10:09:57 +08:00
Asuka Minato
f7f9a08fa5 refactor: port TidbAuthBinding( (#31006)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-01-15 10:07:02 +08:00
wangxiaolei
5008f5e89b fix: Use raw SQL UPDATE to set read status without triggering updated… (#31015) 2026-01-15 09:51:44 +08:00
wangxiaolei
1dd89a02ea fix: fix missing id and message_id (#31008) 2026-01-14 23:26:17 +09:00
盐粒 Yanli
5bf4114d6f fix: increase name length limit in ExternalDatasetCreatePayload (#31000)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
2026-01-14 22:13:53 +09:00
yyh
a56e94ba8e feat: add .agent/skills symlink and orpc-contract-first skill (#30968) 2026-01-14 21:13:14 +08:00
Milad Rashidikhah
11f1782df0 fix: correct API Extension documentation link (#30962) 2026-01-14 21:21:15 +09:00
wangxiaolei
8cf5d9a6a1 fix: fix Cannot destructure property 'name' of 'value' as it is undef… (#30991) 2026-01-14 19:30:47 +08:00
wangxiaolei
0ec2b12e65 feat: allow pass hostname in docker env (#30975) 2026-01-14 19:30:37 +08:00
Stephen Zhou
f33b1a3332 fix: redirect after login (#30985) 2026-01-14 17:20:49 +08:00
kenwoodjw
08026f7399 fix(deps): security updates for pdfminer.six, authlib, werkzeug, aiohttp and others (#30976)
Signed-off-by: kenwoodjw <blackxin55+@gmail.com>
2026-01-14 17:03:46 +08:00
yyh
18e051bd66 chore(web): remove unused demo service component (#30979) 2026-01-14 17:03:35 +08:00
yyh
42f991dbef chore(web): disable Serwist dev logs (#30980) 2026-01-14 16:23:58 +08:00
yyh
b1b2c9636f fix(web): preserve HTTP method in ORPC fetchCompat mode (#30971)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-14 16:18:12 +08:00
-LAN-
01f17b7ddc refactor(http_request_node): apply DI for http request node (#30509)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-14 14:19:48 +08:00
yyh
14b2e5bd0d refactor(web): MCP tool availability to context-based version gating (#30955) 2026-01-14 13:40:16 +08:00
wangxiaolei
d095bd413b fix: fix LOOP_CHILDREN_Z_INDEX (#30719) 2026-01-14 10:22:31 +08:00
heyszt
3473ff7ad1 fix: use Factory to create repository in Aliyun Trace (#30899) 2026-01-14 10:21:46 +08:00
fanadong
138c56bd6e fix(logstore): prevent SQL injection, fix serialization issues, and optimize initialization (#30697) 2026-01-14 10:21:26 +08:00
jialin li
c327d0bb44 fix: Correction to the full name of Volc TOS (#30741) 2026-01-14 10:11:30 +08:00
dependabot[bot]
e4b97fba29 chore(deps): bump azure-core from 1.36.0 to 1.38.0 in /api (#30941)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-14 10:10:49 +08:00
UMDKyle
7f9884e7a1 feat: Add option to delete or keep API keys when uninstalling plugin (#28201)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-01-14 10:09:30 +08:00
dependabot[bot]
e389cd1665 chore(deps): bump filelock from 3.20.0 to 3.20.3 in /api (#30939)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-14 09:56:02 +08:00
wangxiaolei
87f348a0de feat: change param to pydantic model (#30870) 2026-01-14 09:46:41 +08:00
-LAN-
206706987d refactor(variables): clarify base vs union type naming (#30634)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-13 23:39:34 +09:00
Stephen Zhou
91da784f84 refactor: init orpc contract (#30885)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2026-01-13 23:38:28 +09:00
Yunlu Wen
a129e684cc feat: inject traceparent in enterprise api (#30895)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-13 23:37:39 +09:00
wangxiaolei
fe07c810ba fix: fix instance is not bind to session (#30913) 2026-01-13 21:15:21 +08:00
-LAN-
a22cc5bc5e chore: Bump Dify version to 1.11.3 (#30903) 2026-01-13 17:49:13 +08:00
yyh
1fbdf6b465 refactor(web): setup status caching (#30798) 2026-01-13 16:59:49 +08:00
非法操作
491e1fd6a4 chore: case insensitive email (#29978)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: -LAN- <laipz8200@outlook.com>
2026-01-13 15:42:44 +08:00
青枕
0e33dfb5c2 fix: In the LLM model in dify, when a message is added, the first cli… (#29540)
Co-authored-by: 青枕 <qingzhen.ww@alibaba-inc.com>
2026-01-13 15:42:32 +08:00
lif
ea708e7a32 fix(web): add null check for SSE stream bufferObj to prevent TypeError (#30131)
Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-13 15:40:43 +08:00
非法操作
c09e29c3f8 chore: rename the migration file (#30893) 2026-01-13 15:26:41 +08:00
wangxiaolei
2d53ba8671 fix: fix object value is optional should skip validate (#30894) 2026-01-13 15:21:06 +08:00
呆萌闷油瓶
9be863fefa fix: missing content if assistant message with tool_calls (#30083)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-13 12:46:33 +08:00
Coding On Star
8f43629cd8 fix(amplitude): update sessionReplaySampleRate default value to 0.5 (#30880)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-13 12:26:50 +08:00
wangxiaolei
9ee71902c1 fix: fix formatNumber accuracy (#30877) 2026-01-13 11:51:15 +08:00
hsiong
a012c87445 fix: entrypoint.sh overrides NEXT_PUBLIC_TEXT_GENERATION_TIMEOUT_MS when TEXT_GENERATION_TIMEOUT_MS is unset (#30864) (#30865) 2026-01-13 10:12:51 +08:00
heyszt
450578d4c0 feat(ops): set root span kind for AliyunTrace to enable service-level metrics aggregation (#30728)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-13 10:12:00 +08:00
非法操作
837237aa6d fix: use node factory for single-step workflow nodes (#30859) 2026-01-13 10:11:18 +08:00
QuantumGhost
b63dfbf654 fix(api): defer streaming response until referenced variables are updated (#30832)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-12 16:23:18 +08:00
非法操作
51ea87ab85 feat: clear free plan workflow run logs (#29494)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-12 15:57:40 +08:00
Stephen Zhou
00698e41b7 build: limit esbuild, glob, docker base version to avoid cve (#30848) 2026-01-12 15:33:20 +08:00
QuantumGhost
df938a4543 ci: add HITL test env deployment action (#30846) 2026-01-12 15:07:53 +08:00
yyh
9161936f41 refactor(web): extract isServer/isClient utility & upgrade Node.js to 22.12.0 (#30803)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-12 12:57:43 +08:00
Lemonadeccc
f9a21b56ab feat: add block-no-verify hook for Claude Code (#30839) 2026-01-12 12:56:05 +08:00
Stephen Zhou
220e1df847 docs(web): add corepack recommendation (#30837)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-12 12:44:30 +08:00
dependabot[bot]
8cfdde594c chore(deps-dev): bump tos from 2.7.2 to 2.9.0 in /api (#30834)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-12 12:44:21 +08:00
dependabot[bot]
31a8fd810c chore(deps-dev): bump @storybook/react from 9.1.13 to 9.1.17 in /web (#30833)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-01-12 12:44:11 +08:00
yihong
9fad97ec9b fix: drop useless pyrefly in ci (#30826)
Signed-off-by: yihong0618 <zouzou0208@gmail.com>
2026-01-12 09:45:49 +08:00
wangxiaolei
0c2729d9b3 fix: fix refresh token deadlock (#30828) 2026-01-12 09:35:31 +08:00
wangxiaolei
a2e03b811e fix: Broken import in .storybook/preview.tsx (#30812) 2026-01-10 19:49:23 +08:00
-LAN-
1e10bf525c refactor(models): Refine MessageAgentThought SQLAlchemy typing (#27749)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-01-10 17:17:45 +09:00
Stephen Zhou
8b1af36d94 feat(web): migrate PWA to Serwist (#30808) 2026-01-10 17:16:18 +09:00
wangxiaolei
0711dd4159 feat: enhance start node object value check (#30732) 2026-01-09 16:13:17 +08:00
QuantumGhost
ae0a26f5b6 revert: "fix: fix assign value stand as default (#30651)" (#30717)
The original fix seems correct on its own. However, for chatflows with multiple answer nodes, the `message_replace` command only preserves the output of the last executed answer node.
2026-01-09 16:08:24 +08:00
Stephen Zhou
d4432ed80f refactor: marketplace state management (#30702)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-09 14:31:24 +08:00
lif
9d9f027246 fix(web): invalidate app list cache after deleting app from detail page (#30751)
Signed-off-by: majiayu000 <1835304752@qq.com>
2026-01-09 14:08:37 +08:00
wangxiaolei
77f097ce76 fix: fix enhance app mode check (#30758) 2026-01-09 14:07:40 +08:00
Maries
7843afc91c feat(workflows): add agent-dev deploy workflow (#30774) 2026-01-09 13:55:49 +08:00
Coding On Star
98df99b0ca feat(embedding-process): implement embedding process components and polling logic (#30622)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-09 10:21:27 +08:00
Coding On Star
9848823dcd feat: implement step two of dataset creation with comprehensive UI components and hooks (#30681)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
2026-01-09 10:21:18 +08:00
github-actions[bot]
5ad2385799 chore(i18n): sync translations with en-US (#30750)
Co-authored-by: claude[bot] <41898282+claude[bot]@users.noreply.github.com>
2026-01-08 22:53:04 +08:00
yyh
7774a1312e fix(ci): use repository_dispatch for i18n sync workflow (#30744) 2026-01-08 21:28:49 +08:00
MkDev11
91d44719f4 fix(web): resolve chat message loading race conditions and infinite loops (#30695)
Co-authored-by: crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-01-08 18:05:32 +08:00
xuwei95
b2cbeeae92 fix(web): restrict postMessage targetOrigin from wildcard to specific origins (#30690)
Co-authored-by: XW <wei.xu1@wiz.ai>
2026-01-08 17:23:27 +08:00
Coding On Star
cd1af04dee feat: model total credits (#30727)
Co-authored-by: CodingOnStar <hanxujiang@dify.ai>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-01-08 14:11:44 +08:00
zyssyz123
fe0802262c feat: credit pool (#30720)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-08 13:17:30 +08:00
NFish
c5b99ebd17 fix: web app login code encrypt (#30705) 2026-01-07 18:04:42 -08:00
Xiyuan Chen
adaf0e32c0 feat: add decryption decorators for password and code fields in webapp (#30704) 2026-01-08 10:03:39 +08:00
Rhon Joe
27a803a6f0 fix(web): resolve key-value input box height inconsistency on focus/blur (#30715) (#30716) 2026-01-08 09:54:27 +08:00
yyh
25ff4ae5da fix(i18n): resolve Claude Code sandbox path issues in workflow (#30710) 2026-01-08 09:53:32 +08:00
-LAN-
7ccf858ce6 fix(workflow): pass correct user_from/invoke_from into graph init (#30637) 2026-01-07 21:47:23 +08:00
Asuka Minato
885f226f77 refactor: split changes for api/controllers/console/workspace/trigger… (#30627)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 21:18:02 +08:00
yyh
a422908efd feat(i18n): Migrate translation workflow to Claude Code GitHub Actions (#30692)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-07 21:17:50 +08:00
Xiangxuan Qu
d8a0291382 refactor(web): remove unused type alias VoiceLanguageKey (#30694)
Co-authored-by: fghpdf <fghpdf@users.noreply.github.com>
2026-01-07 21:15:43 +08:00
510 changed files with 35473 additions and 19474 deletions

1
.agent/skills Symbolic link
View File

@@ -0,0 +1 @@
../.claude/skills

View File

@@ -5,5 +5,18 @@
"typescript-lsp@claude-plugins-official": true,
"pyright-lsp@claude-plugins-official": true,
"ralph-loop@claude-plugins-official": true
},
"hooks": {
"PreToolUse": [
{
"matcher": "Bash",
"hooks": [
{
"type": "command",
"command": "npx -y block-no-verify@1.1.1"
}
]
}
]
}
}

View File

@@ -0,0 +1,46 @@
---
name: orpc-contract-first
description: Guide for implementing oRPC contract-first API patterns in Dify frontend. Triggers when creating new API contracts, adding service endpoints, integrating TanStack Query with typed contracts, or migrating legacy service calls to oRPC. Use for all API layer work in web/contract and web/service directories.
---
# oRPC Contract-First Development
## Project Structure
```
web/contract/
├── base.ts # Base contract (inputStructure: 'detailed')
├── router.ts # Router composition & type exports
├── marketplace.ts # Marketplace contracts
└── console/ # Console contracts by domain
├── system.ts
└── billing.ts
```
## Workflow
1. **Create contract** in `web/contract/console/{domain}.ts`
- Import `base` from `../base` and `type` from `@orpc/contract`
- Define route with `path`, `method`, `input`, `output`
2. **Register in router** at `web/contract/router.ts`
- Import directly from domain file (no barrel files)
- Nest by API prefix: `billing: { invoices, bindPartnerStack }`
3. **Create hooks** in `web/service/use-{domain}.ts`
- Use `consoleQuery.{group}.{contract}.queryKey()` for query keys
- Use `consoleClient.{group}.{contract}()` for API calls
## Key Rules
- **Input structure**: Always use `{ params, query?, body? }` format
- **Path params**: Use `{paramName}` in path, match in `params` object
- **Router nesting**: Group by API prefix (e.g., `/billing/*``billing: {}`)
- **No barrel files**: Import directly from specific files
- **Types**: Import from `@/types/`, use `type<T>()` helper
## Type Export
```typescript
export type ConsoleInputs = InferContractRouterInputs<typeof consoleRouterContract>
```

View File

@@ -39,12 +39,6 @@ jobs:
- name: Install dependencies
run: uv sync --project api --dev
- name: Run pyrefly check
run: |
cd api
uv add --dev pyrefly
uv run pyrefly check || true
- name: Run dify config tests
run: uv run --project api dev/pytest/pytest_config_tests.py

View File

@@ -16,14 +16,14 @@ jobs:
- name: Check Docker Compose inputs
id: docker-compose-changes
uses: tj-actions/changed-files@v46
uses: tj-actions/changed-files@v47
with:
files: |
docker/generate_docker_compose
docker/.env.example
docker/docker-compose-template.yaml
docker/docker-compose.yaml
- uses: actions/setup-python@v5
- uses: actions/setup-python@v6
with:
python-version: "3.11"

View File

@@ -112,7 +112,7 @@ jobs:
context: "web"
steps:
- name: Download digests
uses: actions/download-artifact@v4
uses: actions/download-artifact@v7
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*

View File

@@ -1,4 +1,4 @@
name: Deploy Trigger Dev
name: Deploy Agent Dev
permissions:
contents: read
@@ -7,7 +7,7 @@ on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "deploy/trigger-dev"
- "deploy/agent-dev"
types:
- completed
@@ -16,12 +16,12 @@ jobs:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
github.event.workflow_run.head_branch == 'deploy/trigger-dev'
github.event.workflow_run.head_branch == 'deploy/agent-dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.TRIGGER_SSH_HOST }}
host: ${{ secrets.AGENT_DEV_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |

View File

@@ -16,7 +16,7 @@ jobs:
github.event.workflow_run.head_branch == 'deploy/dev'
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v0.1.8
uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.SSH_HOST }}
username: ${{ secrets.SSH_USER }}

29
.github/workflows/deploy-hitl.yml vendored Normal file
View File

@@ -0,0 +1,29 @@
name: Deploy HITL
on:
workflow_run:
workflows: ["Build and Push API & Web"]
branches:
- "feat/hitl-frontend"
- "feat/hitl-backend"
types:
- completed
jobs:
deploy:
runs-on: ubuntu-latest
if: |
github.event.workflow_run.conclusion == 'success' &&
(
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
)
steps:
- name: Deploy to server
uses: appleboy/ssh-action@v1
with:
host: ${{ secrets.HITL_SSH_HOST }}
username: ${{ secrets.SSH_USER }}
key: ${{ secrets.SSH_PRIVATE_KEY }}
script: |
${{ vars.SSH_SCRIPT || secrets.SSH_SCRIPT }}

View File

@@ -18,7 +18,7 @@ jobs:
pull-requests: write
steps:
- uses: actions/stale@v5
- uses: actions/stale@v10
with:
days-before-issue-stale: 15
days-before-issue-close: 3

View File

@@ -65,6 +65,9 @@ jobs:
defaults:
run:
working-directory: ./web
permissions:
checks: write
pull-requests: read
steps:
- name: Checkout code
@@ -90,7 +93,7 @@ jobs:
uses: actions/setup-node@v6
if: steps.changed-files.outputs.any_changed == 'true'
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
@@ -103,7 +106,15 @@ jobs:
if: steps.changed-files.outputs.any_changed == 'true'
working-directory: ./web
run: |
pnpm run lint
pnpm run lint:report
continue-on-error: true
# - name: Annotate Code
# if: steps.changed-files.outputs.any_changed == 'true' && github.event_name == 'pull_request'
# uses: DerLev/eslint-annotations@51347b3a0abfb503fc8734d5ae31c4b151297fae
# with:
# eslint-report: web/eslint_report.json
# github-token: ${{ secrets.GITHUB_TOKEN }}
- name: Web type check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -16,10 +16,6 @@ jobs:
name: unit test for Node.js SDK
runs-on: ubuntu-latest
strategy:
matrix:
node-version: [16, 18, 20, 22]
defaults:
run:
working-directory: sdks/nodejs-client
@@ -29,10 +25,10 @@ jobs:
with:
persist-credentials: false
- name: Use Node.js ${{ matrix.node-version }}
- name: Use Node.js
uses: actions/setup-node@v6
with:
node-version: ${{ matrix.node-version }}
node-version: 24
cache: ''
cache-dependency-path: 'pnpm-lock.yaml'

View File

@@ -1,94 +0,0 @@
name: Translate i18n Files Based on English
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
workflow_dispatch:
permissions:
contents: write
pull-requests: write
jobs:
check-and-update:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
defaults:
run:
working-directory: web
steps:
# Keep use old checkout action version for https://github.com/peter-evans/create-pull-request/issues/4272
- uses: actions/checkout@v4
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Check for file changes in i18n/en-US
id: check_files
run: |
# Skip check for manual trigger, translate all files
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
echo "FILE_ARGS=" >> $GITHUB_ENV
echo "Manual trigger: translating all files"
else
git fetch origin "${{ github.event.before }}" || true
git fetch origin "${{ github.sha }}" || true
changed_files=$(git diff --name-only "${{ github.event.before }}" "${{ github.sha }}" -- 'i18n/en-US/*.json')
echo "Changed files: $changed_files"
if [ -n "$changed_files" ]; then
echo "FILES_CHANGED=true" >> $GITHUB_ENV
file_args=""
for file in $changed_files; do
filename=$(basename "$file" .json)
file_args="$file_args --file $filename"
done
echo "FILE_ARGS=$file_args" >> $GITHUB_ENV
echo "File arguments: $file_args"
else
echo "FILES_CHANGED=false" >> $GITHUB_ENV
fi
fi
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
if: env.FILES_CHANGED == 'true'
uses: actions/setup-node@v6
with:
node-version: 'lts/*'
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install dependencies
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm install --frozen-lockfile
- name: Generate i18n translations
if: env.FILES_CHANGED == 'true'
working-directory: ./web
run: pnpm run i18n:gen ${{ env.FILE_ARGS }}
- name: Create Pull Request
if: env.FILES_CHANGED == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore(i18n): update translations based on en-US changes'
title: 'chore(i18n): translate i18n files based on en-US changes'
body: |
This PR was automatically created to update i18n translation files based on changes in en-US locale.
**Triggered by:** ${{ github.sha }}
**Changes included:**
- Updated translation files for all locales
branch: chore/automated-i18n-updates-${{ github.sha }}
delete-branch: true

View File

@@ -0,0 +1,421 @@
name: Translate i18n Files with Claude Code
# Note: claude-code-action doesn't support push events directly.
# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch.
# See: https://github.com/langgenius/dify/issues/30743
on:
repository_dispatch:
types: [i18n-sync]
workflow_dispatch:
inputs:
files:
description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.'
required: false
type: string
languages:
description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.'
required: false
type: string
mode:
description: 'Sync mode: incremental (only changes) or full (re-check all keys)'
required: false
default: 'incremental'
type: choice
options:
- incremental
- full
permissions:
contents: write
pull-requests: write
jobs:
translate:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 60
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
token: ${{ secrets.GITHUB_TOKEN }}
- name: Configure Git
run: |
git config --global user.name "github-actions[bot]"
git config --global user.email "github-actions[bot]@users.noreply.github.com"
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Set up Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Detect changed files and generate diff
id: detect_changes
run: |
if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then
# Manual trigger
if [ -n "${{ github.event.inputs.files }}" ]; then
echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT
else
# Get all JSON files in en-US directory
files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ')
echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT
fi
echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT
# For manual trigger with incremental mode, get diff from last commit
# For full mode, we'll do a complete check anyway
if [ "${{ github.event.inputs.mode }}" == "full" ]; then
echo "Full mode: will check all keys" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
else
git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
if [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
fi
elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then
# Triggered by push via trigger-i18n-sync.yml workflow
# Validate required payload fields
if [ -z "${{ github.event.client_payload.changed_files }}" ]; then
echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2
exit 1
fi
echo "CHANGED_FILES=${{ github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT
echo "TARGET_LANGS=" >> $GITHUB_OUTPUT
echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT
# Decode the base64-encoded diff from the trigger workflow
if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then
if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then
echo "Warning: Failed to decode base64 diff payload" >&2
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
elif [ -s /tmp/i18n-diff.txt ]; then
echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT
else
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "" > /tmp/i18n-diff.txt
echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT
fi
else
echo "Unsupported event type: ${{ github.event_name }}"
exit 1
fi
# Truncate diff if too large (keep first 50KB)
if [ -f /tmp/i18n-diff.txt ]; then
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
fi
echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')"
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@v1
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}
prompt: |
You are a professional i18n synchronization engineer for the Dify project.
Your task is to keep all language translations in sync with the English source (en-US).
## CRITICAL TOOL RESTRICTIONS
- Use **Read** tool to read files (NOT cat or bash)
- Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts)
- Use **Bash** ONLY for: git commands, gh commands, pnpm commands
- Run bash commands ONE BY ONE, never combine with && or ||
- NEVER use `$()` command substitution - it's not supported. Split into separate commands instead.
## WORKING DIRECTORY & ABSOLUTE PATHS
Claude Code sandbox working directory may vary. Always use absolute paths:
- For pnpm: `pnpm --dir ${{ github.workspace }}/web <command>`
- For git: `git -C ${{ github.workspace }} <command>`
- For gh: `gh --repo ${{ github.repository }} <command>`
- For file paths: `${{ github.workspace }}/web/i18n/`
## EFFICIENCY RULES
- **ONE Edit per language file** - batch all key additions into a single Edit
- Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them
- Translate ALL keys for a language mentally first, then do ONE Edit
## Context
- Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
- Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }}
- Sync mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json
- Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts
- Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }}
## CRITICAL DESIGN: Verify First, Then Sync
You MUST follow this three-phase approach:
═══════════════════════════════════════════════════════════════
║ PHASE 1: VERIFY - Analyze and Generate Change Report ║
═══════════════════════════════════════════════════════════════
### Step 1.1: Analyze Git Diff (for incremental mode)
Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff.
Parse the diff to categorize changes:
- Lines with `+` (not `+++`): Added or modified values
- Lines with `-` (not `---`): Removed or old values
- Identify specific keys for each category:
* ADD: Keys that appear only in `+` lines (new keys)
* UPDATE: Keys that appear in both `-` and `+` lines (value changed)
* DELETE: Keys that appear only in `-` lines (removed keys)
### Step 1.2: Read Language Configuration
Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`.
Extract all languages with `supported: true`.
### Step 1.3: Run i18n:check for Each Language
```bash
pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile
```
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
This will report:
- Missing keys (need to ADD)
- Extra keys (need to DELETE)
### Step 1.4: Generate Change Report
Create a structured report identifying:
```
╔══════════════════════════════════════════════════════════════╗
║ I18N SYNC CHANGE REPORT ║
╠══════════════════════════════════════════════════════════════╣
║ Files to process: [list] ║
║ Languages to sync: [list] ║
╠══════════════════════════════════════════════════════════════╣
║ ADD (New Keys): ║
║ - [filename].[key]: "English value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ UPDATE (Modified Keys - MUST re-translate): ║
║ - [filename].[key]: "Old value" → "New value" ║
║ ... ║
╠══════════════════════════════════════════════════════════════╣
║ DELETE (Extra Keys): ║
║ - [language]/[filename].[key] ║
║ ... ║
╚══════════════════════════════════════════════════════════════╝
```
**IMPORTANT**: For UPDATE detection, compare git diff to find keys where
the English value changed. These MUST be re-translated even if target
language already has a translation (it's now stale!).
═══════════════════════════════════════════════════════════════
║ PHASE 2: SYNC - Execute Changes Based on Report ║
═══════════════════════════════════════════════════════════════
### Step 2.1: Process ADD Operations (BATCH per language file)
**CRITICAL WORKFLOW for efficiency:**
1. First, translate ALL new keys for ALL languages mentally
2. Then, for EACH language file, do ONE Edit operation:
- Read the file once
- Insert ALL new keys at the beginning (right after the opening `{`)
- Don't worry about alphabetical order - lint:fix will sort them later
Example Edit (adding 3 keys to zh-Hans/app.json):
```
old_string: '{\n "accessControl"'
new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"'
```
**IMPORTANT**:
- ONE Edit per language file (not one Edit per key!)
- Always use the Edit tool. NEVER use bash scripts, node, or jq.
### Step 2.2: Process UPDATE Operations
**IMPORTANT: Special handling for zh-Hans and ja-JP**
If zh-Hans or ja-JP files were ALSO modified in the same push:
- Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files
- If found, it means someone manually translated them. Apply these rules:
1. **Missing keys**: Still ADD them (completeness required)
2. **Existing translations**: Compare with the NEW English value:
- If translation is **completely wrong** or **unrelated** → Update it
- If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work
- When in doubt, **keep the manual translation**
Example:
- English changed: "Save" → "Save Changes"
- Manual translation: "保存更改" → Keep it (correct meaning)
- Manual translation: "删除" → Update it (completely wrong)
For other languages:
Use Edit tool to replace the old value with the new translation.
You can batch multiple updates in one Edit if they are adjacent.
### Step 2.3: Process DELETE Operations
For extra keys reported by i18n:check:
- Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove`
- Or manually remove from target language JSON files
## Translation Guidelines
- PRESERVE all placeholders exactly as-is:
- `{{variable}}` - Mustache interpolation
- `${variable}` - Template literal
- `<tag>content</tag>` - HTML tags
- `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values)
- Use appropriate language register (formal/informal) based on existing translations
- Match existing translation style in each language
- Technical terms: check existing conventions per language
- For CJK languages: no spaces between characters unless necessary
- For RTL languages (ar-TN, fa-IR): ensure proper text handling
## Output Format Requirements
- Alphabetical key ordering (if original file uses it)
- 2-space indentation
- Trailing newline at end of file
- Valid JSON (use proper escaping for special characters)
═══════════════════════════════════════════════════════════════
║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║
═══════════════════════════════════════════════════════════════
### Step 3.1: Run Lint Fix (IMPORTANT!)
```bash
pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json'
```
This ensures:
- JSON keys are sorted alphabetically (jsonc/sort-keys rule)
- Valid i18n keys (dify-i18n/valid-i18n-keys rule)
- No extra keys (dify-i18n/no-extra-keys rule)
### Step 3.2: Run Final i18n Check
```bash
pnpm --dir ${{ github.workspace }}/web run i18n:check
```
### Step 3.3: Fix Any Remaining Issues
If check reports issues:
- Go back to PHASE 2 for unresolved items
- Repeat until check passes
### Step 3.4: Generate Final Summary
```
╔══════════════════════════════════════════════════════════════╗
║ SYNC COMPLETED SUMMARY ║
╠══════════════════════════════════════════════════════════════╣
║ Language │ Added │ Updated │ Deleted │ Status ║
╠══════════════════════════════════════════════════════════════╣
║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║
║ ... │ ... │ ... │ ... │ ... ║
╠══════════════════════════════════════════════════════════════╣
║ i18n:check │ PASSED - All keys in sync ║
╚══════════════════════════════════════════════════════════════╝
```
## Mode-Specific Behavior
**SYNC_MODE = "incremental"** (default):
- Focus on keys identified from git diff
- Also check i18n:check output for any missing/extra keys
- Efficient for small changes
**SYNC_MODE = "full"**:
- Compare ALL keys between en-US and each language
- Run i18n:check to identify all discrepancies
- Use for first-time sync or fixing historical issues
## Important Notes
1. Always run i18n:check BEFORE and AFTER making changes
2. The check script is the source of truth for missing/extra keys
3. For UPDATE scenario: git diff is the source of truth for changed values
4. Create a single commit with all translation changes
5. If any translation fails, continue with others and report failures
═══════════════════════════════════════════════════════════════
║ PHASE 4: COMMIT AND CREATE PR ║
═══════════════════════════════════════════════════════════════
After all translations are complete and verified:
### Step 4.1: Check for changes
```bash
git -C ${{ github.workspace }} status --porcelain
```
If there are changes:
### Step 4.2: Create a new branch and commit
Run these git commands ONE BY ONE (not combined with &&).
**IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands:
1. First, get the timestamp:
```bash
date +%Y%m%d-%H%M%S
```
(Note the output, e.g., "20260115-143052")
2. Then create branch using the timestamp value:
```bash
git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052
```
(Replace "20260115-143052" with the actual timestamp from step 1)
3. Stage changes:
```bash
git -C ${{ github.workspace }} add web/i18n/
```
4. Commit:
```bash
git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}"
```
5. Push:
```bash
git -C ${{ github.workspace }} push origin HEAD
```
### Step 4.3: Create Pull Request
```bash
gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary
This PR was automatically generated to sync i18n translation files.
### Changes
- Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}
- Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }}
### Verification
- [x] \`i18n:check\` passed
- [x] \`lint:fix\` applied
🤖 Generated with Claude Code GitHub Action" --base main
```
claude_args: |
--max-turns 150
--allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep"

66
.github/workflows/trigger-i18n-sync.yml vendored Normal file
View File

@@ -0,0 +1,66 @@
name: Trigger i18n Sync on Push
# This workflow bridges the push event to repository_dispatch
# because claude-code-action doesn't support push events directly.
# See: https://github.com/langgenius/dify/issues/30743
on:
push:
branches: [main]
paths:
- 'web/i18n/en-US/*.json'
permissions:
contents: write
jobs:
trigger:
if: github.repository == 'langgenius/dify'
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- name: Checkout repository
uses: actions/checkout@v6
with:
fetch-depth: 0
- name: Detect changed files and generate diff
id: detect
run: |
BEFORE_SHA="${{ github.event.before }}"
# Handle edge case: force push may have null/zero SHA
if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then
BEFORE_SHA="HEAD~1"
fi
# Detect changed i18n files
changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "")
echo "changed_files=$changed" >> $GITHUB_OUTPUT
# Generate diff for context
git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt
# Truncate if too large (keep first 50KB to match receiving workflow)
head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt
mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt
# Base64 encode the diff for safe JSON transport (portable, single-line)
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n')
echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT
if [ -n "$changed" ]; then
echo "has_changes=true" >> $GITHUB_OUTPUT
echo "Detected changed files: $changed"
else
echo "has_changes=false" >> $GITHUB_OUTPUT
echo "No i18n changes detected"
fi
- name: Trigger i18n sync workflow
if: steps.detect.outputs.has_changes == 'true'
uses: peter-evans/repository-dispatch@v3
with:
token: ${{ secrets.GITHUB_TOKEN }}
event-type: i18n-sync
client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}'

View File

@@ -31,7 +31,7 @@ jobs:
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 22
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml

1
.nvmrc
View File

@@ -1 +0,0 @@
22.11.0

View File

@@ -12,12 +12,8 @@ The codebase is split into:
## Backend Workflow
- Read `api/AGENTS.md` for details
- Run backend CLI commands through `uv run --project api <command>`.
- Before submission, all backend modifications must pass local checks: `make lint`, `make type-check`, and `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Use Makefile targets for linting and formatting; `make lint` and `make type-check` cover the required checks.
- Integration tests are CI-only and are not expected to run in the local environment.
## Frontend Workflow

View File

@@ -61,7 +61,8 @@ check:
lint:
@echo "🔧 Running ruff format, check with fixes, import linter, and dotenv-linter..."
@uv run --project api --dev sh -c 'ruff format ./api && ruff check --fix ./api'
@uv run --project api --dev ruff format ./api
@uv run --project api --dev ruff check --fix ./api
@uv run --directory api --dev lint-imports
@uv run --project api --dev dotenv-linter ./api/.env.example ./web/.env.example
@echo "✅ Linting complete"
@@ -73,7 +74,12 @@ type-check:
test:
@echo "🧪 Running backend unit tests..."
@uv run --project api --dev dev/pytest/pytest_unit_tests.sh
@if [ -n "$(TARGET_TESTS)" ]; then \
echo "Target: $(TARGET_TESTS)"; \
uv run --project api --dev pytest $(TARGET_TESTS); \
else \
uv run --project api --dev dev/pytest/pytest_unit_tests.sh; \
fi
@echo "✅ Tests complete"
# Build Docker images
@@ -125,7 +131,7 @@ help:
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checking with basedpyright"
@echo " make test - Run backend unit tests"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"
@echo " make build-web - Build web Docker image"

View File

@@ -417,6 +417,8 @@ SMTP_USERNAME=123
SMTP_PASSWORD=abc
SMTP_USE_TLS=true
SMTP_OPPORTUNISTIC_TLS=false
# Optional: override the local hostname used for SMTP HELO/EHLO
SMTP_LOCAL_HOSTNAME=
# Sendgid configuration
SENDGRID_API_KEY=
# Sentry configuration
@@ -589,6 +591,7 @@ ENABLE_CLEAN_UNUSED_DATASETS_TASK=false
ENABLE_CREATE_TIDB_SERVERLESS_TASK=false
ENABLE_UPDATE_TIDB_SERVERLESS_STATUS_TASK=false
ENABLE_CLEAN_MESSAGES=false
ENABLE_WORKFLOW_RUN_CLEANUP_TASK=false
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK=false
ENABLE_DATASETS_QUEUE_MONITOR=false
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK=true
@@ -712,3 +715,4 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30

View File

@@ -1,62 +1,236 @@
# Agent Skill Index
# API Agent Guide
## Agent Notes (must-check)
Before you start work on any backend file under `api/`, you MUST check whether a related note exists under:
- `agent-notes/<same-relative-path-as-target-file>.md`
Rules:
- **Path mapping**: for a target file `<path>/<name>.py`, the note must be `agent-notes/<path>/<name>.py.md` (same folder structure, same filename, plus `.md`).
- **Before working**:
- If the note exists, read it first and follow any constraints/decisions recorded there.
- If the note conflicts with the current code, or references an "origin" file/path that has been deleted, renamed, or migrated, treat the **code as the single source of truth** and update the note to match reality.
- If the note does not exist, create it with a short architecture/intent summary and any relevant invariants/edge cases.
- **During working**:
- Keep the note in sync as you discover constraints, make decisions, or change approach.
- If you move/rename a file, migrate its note to the new mapped path (and fix any outdated references inside the note).
- Record non-obvious edge cases, trade-offs, and the test/verification plan as you go (not just at the end).
- Keep notes **coherent**: integrate new findings into the relevant sections and rewrite for clarity; avoid append-only “recent fix” / changelog-style additions unless the note is explicitly intended to be a changelog.
- **When finishing work**:
- Update the related note(s) to reflect what changed, why, and any new edge cases/tests.
- If a file is deleted, remove or clearly deprecate the corresponding note so it cannot be mistaken as current guidance.
- Keep notes concise and accurate; they are meant to prevent repeated rediscovery.
## Skill Index
Start with the section that best matches your need. Each entry lists the problems it solves plus key files/concepts so you know what to expect before opening it.
______________________________________________________________________
### Platform Foundations
## Platform Foundations
- **[Infrastructure Overview](agent_skills/infra.md)**\
When to read this:
#### [Infrastructure Overview](agent_skills/infra.md)
- **When to read this**
- You need to understand where a feature belongs in the architecture.
- Youre wiring storage, Redis, vector stores, or OTEL.
- Youre about to add CLI commands or async jobs.\
What it covers: configuration stack (`configs/app_config.py`, remote settings), storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`), Redis conventions (`extensions/ext_redis.py`), plugin runtime topology, vector-store factory (`core/rag/datasource/vdb/*`), observability hooks, SSRF proxy usage, and core CLI commands.
- Youre about to add CLI commands or async jobs.
- **What it covers**
- Configuration stack (`configs/app_config.py`, remote settings)
- Storage entry points (`extensions/ext_storage.py`, `core/file/file_manager.py`)
- Redis conventions (`extensions/ext_redis.py`)
- Plugin runtime topology
- Vector-store factory (`core/rag/datasource/vdb/*`)
- Observability hooks
- SSRF proxy usage
- Core CLI commands
- **[Coding Style](agent_skills/coding_style.md)**\
When to read this:
### Plugin & Extension Development
- Youre writing or reviewing backend code and need the authoritative checklist.
- Youre unsure about Pydantic validators, SQLAlchemy session usage, or logging patterns.
- You want the exact lint/type/test commands used in PRs.\
Includes: Ruff & BasedPyright commands, no-annotation policy, session examples (`with Session(db.engine, ...)`), `@field_validator` usage, logging expectations, and the rule set for file size, helpers, and package management.
______________________________________________________________________
## Plugin & Extension Development
- **[Plugin Systems](agent_skills/plugin.md)**\
When to read this:
#### [Plugin Systems](agent_skills/plugin.md)
- **When to read this**
- Youre building or debugging a marketplace plugin.
- You need to know how manifests, providers, daemons, and migrations fit together.\
What it covers: plugin manifests (`core/plugin/entities/plugin.py`), installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands), runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent), daemon coordination (`core/plugin/entities/plugin_daemon.py`), and how provider registries surface capabilities to the rest of the platform.
- You need to know how manifests, providers, daemons, and migrations fit together.
- **What it covers**
- Plugin manifests (`core/plugin/entities/plugin.py`)
- Installation/upgrade flows (`services/plugin/plugin_service.py`, CLI commands)
- Runtime adapters (`core/plugin/impl/*` for tool/model/datasource/trigger/endpoint/agent)
- Daemon coordination (`core/plugin/entities/plugin_daemon.py`)
- How provider registries surface capabilities to the rest of the platform
- **[Plugin OAuth](agent_skills/plugin_oauth.md)**\
When to read this:
#### [Plugin OAuth](agent_skills/plugin_oauth.md)
- **When to read this**
- You must integrate OAuth for a plugin or datasource.
- Youre handling credential encryption or refresh flows.\
Topics: credential storage, encryption helpers (`core/helper/provider_encryption.py`), OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`), and how console/API layers expose the flows.
- Youre handling credential encryption or refresh flows.
- **Topics**
- Credential storage
- Encryption helpers (`core/helper/provider_encryption.py`)
- OAuth client bootstrap (`services/plugin/oauth_service.py`, `services/plugin/plugin_parameter_service.py`)
- How console/API layers expose the flows
______________________________________________________________________
### Workflow Entry & Execution
## Workflow Entry & Execution
#### [Trigger Concepts](agent_skills/trigger.md)
- **[Trigger Concepts](agent_skills/trigger.md)**\
When to read this:
- **When to read this**
- Youre debugging why a workflow didnt start.
- Youre adding a new trigger type or hook.
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.\
Details: Start-node taxonomy, webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`), async orchestration (`services/async_workflow_service.py`, Celery queues), debug event bus, and storage/logging interactions.
- You need to trace async execution, draft debugging, or webhook/schedule pipelines.
- **Details**
- Start-node taxonomy
- Webhook & schedule internals (`core/workflow/nodes/trigger_*`, `services/trigger/*`)
- Async orchestration (`services/async_workflow_service.py`, Celery queues)
- Debug event bus
- Storage/logging interactions
______________________________________________________________________
## General Reminders
## Additional Notes for Agents
- All skill docs assume you follow the coding style guide—run Ruff/BasedPyright/tests listed there before submitting changes.
- All skill docs assume you follow the coding style rules below—run the lint/type/test commands before submitting changes.
- When you cannot find an answer in these briefs, search the codebase using the paths referenced (e.g., `core/plugin/impl/tool.py`, `services/dataset_service.py`).
- If you run into cross-cutting concerns (tenancy, configuration, storage), check the infrastructure guide first; it links to most supporting modules.
- Keep multi-tenancy and configuration central: everything flows through `configs.dify_config` and `tenant_id`.
- When touching plugins or triggers, consult both the system overview and the specialised doc to ensure you adjust lifecycle, storage, and observability consistently.
## Coding Style
This is the default standard for backend code in this repo. Follow it for new code and use it as the checklist when reviewing changes.
### Linting & Formatting
- Use Ruff for formatting and linting (follow `.ruff.toml`).
- Keep each line under 120 characters (including spaces).
### Naming Conventions
- Use `snake_case` for variables and functions.
- Use `PascalCase` for classes.
- Use `UPPER_CASE` for constants.
### Typing & Class Layout
- Code should usually include type annotations that match the repos current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless theres a strong reason.
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python
from datetime import datetime
class Example:
user_id: str
created_at: datetime
def __init__(self, user_id: str, created_at: datetime) -> None:
self.user_id = user_id
self.created_at = created_at
```
### General Rules
- Use Pydantic v2 conventions.
- Use `uv` for Python package management in this repo (usually with `--project api`).
- Prefer simple functions over small “utility classes” for lightweight helpers.
- Avoid implementing dunder methods unless its clearly needed and matches existing patterns.
- Never start long-running services as part of agent work (`uv run app.py`, `flask run`, etc.); running tests is allowed.
- Keep files below ~800 lines; split when necessary.
- Keep code readable and explicit—avoid clever hacks.
### Architecture & Boundaries
- Mirror the layered architecture: controller → service → core/domain.
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
### Logging & Errors
- Never use `print`; use a module-level logger:
- `logger = logging.getLogger(__name__)`
- Include tenant/app/workflow identifiers in log context when relevant.
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate them into HTTP responses in controllers.
- Log retryable events at `warning`, terminal failures at `error`.
### SQLAlchemy Patterns
- Models inherit from `models.base.TypeBase`; do not create ad-hoc metadata or engines.
- Open sessions with context managers:
```python
from sqlalchemy.orm import Session
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Prefer SQLAlchemy expressions; avoid raw SQL unless necessary.
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
- Introduce repository abstractions only for very large tables (e.g., workflow executions) or when alternative storage strategies are required.
### Storage & External I/O
- Access storage via `extensions.ext_storage.storage`.
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
- Background tasks that touch storage must be idempotent, and should log relevant object identifiers.
### Pydantic Usage
- Define DTOs with Pydantic v2 models and forbid extras by default.
- Use `@field_validator` / `@model_validator` for domain rules.
Example:
```python
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
class TriggerConfig(BaseModel):
endpoint: HttpUrl
secret: str
model_config = ConfigDict(extra="forbid")
@field_validator("secret")
def ensure_secret_prefix(cls, value: str) -> str:
if not value.startswith("dify_"):
raise ValueError("secret must start with dify_")
return value
```
### Generics & Protocols
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
### Tooling & Checks
Quick checks while iterating:
- Format: `make format`
- Lint (includes auto-fix): `make lint`
- Type check: `make type-check`
- Targeted tests: `make test TARGET_TESTS=./api/tests/<target_tests>`
Before opening a PR / submitting:
- `make lint`
- `make type-check`
- `make test`
### Controllers & Services
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Document non-obvious behaviour with concise comments.
### Miscellaneous
- Use `configs.dify_config` for configuration—never read environment variables directly.
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
- Keep experimental scripts under `dev/`; do not ship them in production builds.

View File

@@ -1,115 +0,0 @@
## Linter
- Always follow `.ruff.toml`.
- Run `uv run ruff check --fix --unsafe-fixes`.
- Keep each line under 100 characters (including spaces).
## Code Style
- `snake_case` for variables and functions.
- `PascalCase` for classes.
- `UPPER_CASE` for constants.
## Rules
- Use Pydantic v2 standard.
- Use `uv` for package management.
- Do not override dunder methods like `__init__`, `__iadd__`, etc.
- Never launch services (`uv run app.py`, `flask run`, etc.); running tests under `tests/` is allowed.
- Prefer simple functions over classes for lightweight helpers.
- Keep files below 800 lines; split when necessary.
- Keep code readable—no clever hacks.
- Never use `print`; log with `logger = logging.getLogger(__name__)`.
## Guiding Principles
- Mirror the projects layered architecture: controller → service → core/domain.
- Reuse existing helpers in `core/`, `services/`, and `libs/` before creating new abstractions.
- Optimise for observability: deterministic control flow, clear logging, actionable errors.
## SQLAlchemy Patterns
- Models inherit from `models.base.Base`; never create ad-hoc metadata or engines.
- Open sessions with context managers:
```python
from sqlalchemy.orm import Session
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(Workflow).where(
Workflow.id == workflow_id,
Workflow.tenant_id == tenant_id,
)
workflow = session.execute(stmt).scalar_one_or_none()
```
- Use SQLAlchemy expressions; avoid raw SQL unless necessary.
- Introduce repository abstractions only for very large tables (e.g., workflow executions) to support alternative storage strategies.
- Always scope queries by `tenant_id` and protect write paths with safeguards (`FOR UPDATE`, row counts, etc.).
## Storage & External IO
- Access storage via `extensions.ext_storage.storage`.
- Use `core.helper.ssrf_proxy` for outbound HTTP fetches.
- Background tasks that touch storage must be idempotent and log the relevant object identifiers.
## Pydantic Usage
- Define DTOs with Pydantic v2 models and forbid extras by default.
- Use `@field_validator` / `@model_validator` for domain rules.
- Example:
```python
from pydantic import BaseModel, ConfigDict, HttpUrl, field_validator
class TriggerConfig(BaseModel):
endpoint: HttpUrl
secret: str
model_config = ConfigDict(extra="forbid")
@field_validator("secret")
def ensure_secret_prefix(cls, value: str) -> str:
if not value.startswith("dify_"):
raise ValueError("secret must start with dify_")
return value
```
## Generics & Protocols
- Use `typing.Protocol` to define behavioural contracts (e.g., cache interfaces).
- Apply generics (`TypeVar`, `Generic`) for reusable utilities like caches or providers.
- Validate dynamic inputs at runtime when generics cannot enforce safety alone.
## Error Handling & Logging
- Raise domain-specific exceptions (`services/errors`, `core/errors`) and translate to HTTP responses in controllers.
- Declare `logger = logging.getLogger(__name__)` at module top.
- Include tenant/app/workflow identifiers in log context.
- Log retryable events at `warning`, terminal failures at `error`.
## Tooling & Checks
- Format/lint: `uv run --project api --dev ruff format ./api` and `uv run --project api --dev ruff check --fix --unsafe-fixes ./api`.
- Type checks: `uv run --directory api --dev basedpyright`.
- Tests: `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`.
- Run all of the above before submitting your work.
## Controllers & Services
- Controllers: parse input via Pydantic, invoke services, return serialised responses; no business logic.
- Services: coordinate repositories, providers, background tasks; keep side effects explicit.
- Avoid repositories unless necessary; direct SQLAlchemy usage is preferred for typical tables.
- Document non-obvious behaviour with concise comments.
## Miscellaneous
- Use `configs.dify_config` for configuration—never read environment variables directly.
- Maintain tenant awareness end-to-end; `tenant_id` must flow through every layer touching shared resources.
- Queue async work through `services/async_workflow_service`; implement tasks under `tasks/` with explicit queue selection.
- Keep experimental scripts under `dev/`; do not ship them in production builds.

View File

@@ -1,7 +1,9 @@
import base64
import datetime
import json
import logging
import secrets
import time
from typing import Any
import click
@@ -34,7 +36,7 @@ from libs.rsa import generate_key_pair
from models import Tenant
from models.dataset import Dataset, DatasetCollectionBinding, DatasetMetadata, DatasetMetadataBinding, DocumentSegment
from models.dataset import Document as DatasetDocument
from models.model import Account, App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.model import App, AppAnnotationSetting, AppMode, Conversation, MessageAnnotation, UploadFile
from models.oauth import DatasourceOauthParamConfig, DatasourceProvider
from models.provider import Provider, ProviderModel
from models.provider_ids import DatasourceProviderID, ToolProviderID
@@ -45,6 +47,9 @@ from services.clear_free_plan_tenant_expired_logs import ClearFreePlanTenantExpi
from services.plugin.data_migration import PluginDataMigration
from services.plugin.plugin_migration import PluginMigration
from services.plugin.plugin_service import PluginService
from services.retention.conversation.messages_clean_policy import create_message_clean_policy
from services.retention.conversation.messages_clean_service import MessagesCleanService
from services.retention.workflow_run.clear_free_plan_expired_workflow_run_logs import WorkflowRunCleanup
from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
@@ -62,8 +67,10 @@ def reset_password(email, new_password, password_confirm):
if str(new_password).strip() != str(password_confirm).strip():
click.echo(click.style("Passwords do not match.", fg="red"))
return
normalized_email = email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
@@ -84,7 +91,7 @@ def reset_password(email, new_password, password_confirm):
base64_password_hashed = base64.b64encode(password_hashed).decode()
account.password = base64_password_hashed
account.password_salt = base64_salt
AccountService.reset_login_error_rate_limit(email)
AccountService.reset_login_error_rate_limit(normalized_email)
click.echo(click.style("Password reset successfully.", fg="green"))
@@ -100,20 +107,22 @@ def reset_email(email, new_email, email_confirm):
if str(new_email).strip() != str(email_confirm).strip():
click.echo(click.style("New emails do not match.", fg="red"))
return
normalized_new_email = new_email.strip().lower()
with sessionmaker(db.engine, expire_on_commit=False).begin() as session:
account = session.query(Account).where(Account.email == email).one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email.strip(), session=session)
if not account:
click.echo(click.style(f"Account not found for email: {email}", fg="red"))
return
try:
email_validate(new_email)
email_validate(normalized_new_email)
except:
click.echo(click.style(f"Invalid email: {new_email}", fg="red"))
return
account.email = new_email
account.email = normalized_new_email
click.echo(click.style("Email updated successfully.", fg="green"))
@@ -658,7 +667,7 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
return
# Create account
email = email.strip()
email = email.strip().lower()
if "@" not in email:
click.echo(click.style("Invalid email address.", fg="red"))
@@ -852,6 +861,61 @@ def clear_free_plan_tenant_expired_logs(days: int, batch: int, tenant_ids: list[
click.echo(click.style("Clear free plan tenant expired logs completed.", fg="green"))
@click.command("clean-workflow-runs", help="Clean expired workflow runs and related data for free tenants.")
@click.option("--days", default=30, show_default=True, help="Delete workflow runs created before N days ago.")
@click.option("--batch-size", default=200, show_default=True, help="Batch size for selecting workflow runs.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at; must be paired with --end-before.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional upper bound (exclusive) for created_at; must be paired with --start-from.",
)
@click.option(
"--dry-run",
is_flag=True,
help="Preview cleanup results without deleting any workflow run data.",
)
def clean_workflow_runs(
days: int,
batch_size: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
dry_run: bool,
):
"""
Clean workflow runs and related workflow data for free tenants.
"""
if (start_from is None) ^ (end_before is None):
raise click.UsageError("--start-from and --end-before must be provided together.")
start_time = datetime.datetime.now(datetime.UTC)
click.echo(click.style(f"Starting workflow run cleanup at {start_time.isoformat()}.", fg="white"))
WorkflowRunCleanup(
days=days,
batch_size=batch_size,
start_from=start_from,
end_before=end_before,
dry_run=dry_run,
).run()
end_time = datetime.datetime.now(datetime.UTC)
elapsed = end_time - start_time
click.echo(
click.style(
f"Workflow run cleanup completed. start={start_time.isoformat()} "
f"end={end_time.isoformat()} duration={elapsed}",
fg="green",
)
)
@click.option("-f", "--force", is_flag=True, help="Skip user confirmation and force the command to execute.")
@click.command("clear-orphaned-file-records", help="Clear orphaned file records.")
def clear_orphaned_file_records(force: bool):
@@ -2111,3 +2175,79 @@ def migrate_oss(
except Exception as e:
db.session.rollback()
click.echo(click.style(f"Failed to update DB storage_type: {str(e)}", fg="red"))
@click.command("clean-expired-messages", help="Clean expired messages.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
default=21,
show_default=True,
help="Graceful period in days after subscription expiration, will be ignored when billing is disabled.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Show messages logs would be cleaned without deleting")
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
Clean expired messages and related data for tenants based on clean policy.
"""
click.echo(click.style("clean_messages: start clean messages.", fg="green"))
start_at = time.perf_counter()
try:
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()
click.echo(
click.style(
f"clean_messages: completed successfully\n"
f" - Latency: {end_at - start_at:.2f}s\n"
f" - Batches processed: {stats['batches']}\n"
f" - Total messages scanned: {stats['total_messages']}\n"
f" - Messages filtered: {stats['filtered_messages']}\n"
f" - Messages deleted: {stats['total_deleted']}",
fg="green",
)
)
except Exception as e:
end_at = time.perf_counter()
logger.exception("clean_messages failed")
click.echo(
click.style(
f"clean_messages: failed after {end_at - start_at:.2f}s - {str(e)}",
fg="red",
)
)
raise
click.echo(click.style("messages cleanup completed.", fg="green"))

View File

@@ -949,6 +949,12 @@ class MailConfig(BaseSettings):
default=False,
)
SMTP_LOCAL_HOSTNAME: str | None = Field(
description="Override the local hostname used in SMTP HELO/EHLO. "
"Useful behind NAT or when the default hostname causes rejections.",
default=None,
)
EMAIL_SEND_IP_LIMIT_PER_MINUTE: PositiveInt = Field(
description="Maximum number of emails allowed to be sent from the same IP address in a minute",
default=50,
@@ -1101,6 +1107,10 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable clean messages task",
default=False,
)
ENABLE_WORKFLOW_RUN_CLEANUP_TASK: bool = Field(
description="Enable scheduled workflow run cleanup task",
default=False,
)
ENABLE_MAIL_CLEAN_DOCUMENT_NOTIFY_TASK: bool = Field(
description="Enable mail clean document notify task",
default=False,

View File

@@ -8,6 +8,11 @@ class HostedCreditConfig(BaseSettings):
default="",
)
HOSTED_POOL_CREDITS: int = Field(
description="Pool credits for hosted service",
default=200,
)
def get_model_credits(self, model_name: str) -> int:
"""
Get credit value for a specific model name.
@@ -60,19 +65,46 @@ class HostedOpenAiConfig(BaseSettings):
HOSTED_OPENAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gpt-3.5-turbo,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-instruct,"
default="gpt-4,"
"gpt-4-turbo-preview,"
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
"gpt-3.5-turbo-1106,"
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"text-davinci-003",
)
HOSTED_OPENAI_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted OpenAI service usage",
default=200,
"gpt-3.5-turbo-instruct,"
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
)
HOSTED_OPENAI_PAID_ENABLED: bool = Field(
@@ -87,6 +119,13 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-4-turbo-2024-04-09,"
"gpt-4-1106-preview,"
"gpt-4-0125-preview,"
"gpt-4-turbo,"
"gpt-4.1,"
"gpt-4.1-2025-04-14,"
"gpt-4.1-mini,"
"gpt-4.1-mini-2025-04-14,"
"gpt-4.1-nano,"
"gpt-4.1-nano-2025-04-14,"
"gpt-3.5-turbo,"
"gpt-3.5-turbo-16k,"
"gpt-3.5-turbo-16k-0613,"
@@ -94,7 +133,150 @@ class HostedOpenAiConfig(BaseSettings):
"gpt-3.5-turbo-0613,"
"gpt-3.5-turbo-0125,"
"gpt-3.5-turbo-instruct,"
"text-davinci-003",
"text-davinci-003,"
"chatgpt-4o-latest,"
"gpt-4o,"
"gpt-4o-2024-05-13,"
"gpt-4o-2024-08-06,"
"gpt-4o-2024-11-20,"
"gpt-4o-audio-preview,"
"gpt-4o-audio-preview-2025-06-03,"
"gpt-4o-mini,"
"gpt-4o-mini-2024-07-18,"
"o3-mini,"
"o3-mini-2025-01-31,"
"gpt-5-mini-2025-08-07,"
"gpt-5-mini,"
"o4-mini,"
"o4-mini-2025-04-16,"
"gpt-5-chat-latest,"
"gpt-5,"
"gpt-5-2025-08-07,"
"gpt-5-nano,"
"gpt-5-nano-2025-08-07",
)
class HostedGeminiConfig(BaseSettings):
"""
Configuration for fetching Gemini service
"""
HOSTED_GEMINI_API_KEY: str | None = Field(
description="API key for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_API_BASE: str | None = Field(
description="Base URL for hosted Gemini API",
default=None,
)
HOSTED_GEMINI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Gemini service",
default=None,
)
HOSTED_GEMINI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Gemini service",
default=False,
)
HOSTED_GEMINI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
HOSTED_GEMINI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted gemini service",
default=False,
)
HOSTED_GEMINI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="gemini-2.5-flash,gemini-2.0-flash,gemini-2.0-flash-lite,",
)
class HostedXAIConfig(BaseSettings):
"""
Configuration for fetching XAI service
"""
HOSTED_XAI_API_KEY: str | None = Field(
description="API key for hosted XAI service",
default=None,
)
HOSTED_XAI_API_BASE: str | None = Field(
description="Base URL for hosted XAI API",
default=None,
)
HOSTED_XAI_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted XAI service",
default=None,
)
HOSTED_XAI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted XAI service",
default=False,
)
HOSTED_XAI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
HOSTED_XAI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted XAI service",
default=False,
)
HOSTED_XAI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="grok-3,grok-3-mini,grok-3-mini-fast",
)
class HostedDeepseekConfig(BaseSettings):
"""
Configuration for fetching Deepseek service
"""
HOSTED_DEEPSEEK_API_KEY: str | None = Field(
description="API key for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_API_BASE: str | None = Field(
description="Base URL for hosted Deepseek API",
default=None,
)
HOSTED_DEEPSEEK_API_ORGANIZATION: str | None = Field(
description="Organization ID for hosted Deepseek service",
default=None,
)
HOSTED_DEEPSEEK_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="deepseek-chat,deepseek-reasoner",
)
HOSTED_DEEPSEEK_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Deepseek service",
default=False,
)
HOSTED_DEEPSEEK_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="deepseek-chat,deepseek-reasoner",
)
@@ -144,16 +326,66 @@ class HostedAnthropicConfig(BaseSettings):
default=False,
)
HOSTED_ANTHROPIC_QUOTA_LIMIT: NonNegativeInt = Field(
description="Quota limit for hosted Anthropic service usage",
default=600000,
)
HOSTED_ANTHROPIC_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_ANTHROPIC_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
HOSTED_ANTHROPIC_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="claude-opus-4-20250514,"
"claude-sonnet-4-20250514,"
"claude-3-5-haiku-20241022,"
"claude-3-opus-20240229,"
"claude-3-7-sonnet-20250219,"
"claude-3-haiku-20240307",
)
class HostedTongyiConfig(BaseSettings):
"""
Configuration for hosted Tongyi service
"""
HOSTED_TONGYI_API_KEY: str | None = Field(
description="API key for hosted Tongyi service",
default=None,
)
HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT: bool = Field(
description="Use international endpoint for hosted Tongyi service",
default=False,
)
HOSTED_TONGYI_TRIAL_ENABLED: bool = Field(
description="Enable trial access to hosted Tongyi service",
default=False,
)
HOSTED_TONGYI_PAID_ENABLED: bool = Field(
description="Enable paid access to hosted Anthropic service",
default=False,
)
HOSTED_TONGYI_TRIAL_MODELS: str = Field(
description="Comma-separated list of available models for trial access",
default="",
)
HOSTED_TONGYI_PAID_MODELS: str = Field(
description="Comma-separated list of available models for paid access",
default="",
)
class HostedMinmaxConfig(BaseSettings):
"""
@@ -246,9 +478,13 @@ class HostedServiceConfig(
HostedOpenAiConfig,
HostedSparkConfig,
HostedZhipuAIConfig,
HostedTongyiConfig,
# moderation
HostedModerationConfig,
# credit config
HostedCreditConfig,
HostedGeminiConfig,
HostedXAIConfig,
HostedDeepseekConfig,
):
pass

View File

@@ -4,7 +4,7 @@ from pydantic_settings import BaseSettings
class VolcengineTOSStorageConfig(BaseSettings):
"""
Configuration settings for Volcengine Tinder Object Storage (TOS)
Configuration settings for Volcengine Torch Object Storage (TOS)
"""
VOLCENGINE_TOS_BUCKET_NAME: str | None = Field(

View File

@@ -592,9 +592,12 @@ def _get_conversation(app_model, conversation_id):
if not conversation:
raise NotFound("Conversation Not Exists.")
if not conversation.read_at:
conversation.read_at = naive_utc_now()
conversation.read_account_id = current_user.id
db.session.commit()
db.session.execute(
sa.update(Conversation)
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
)
db.session.commit()
db.session.refresh(conversation)
return conversation

View File

@@ -202,7 +202,6 @@ message_detail_model = console_ns.model(
"status": fields.String,
"error": fields.String,
"parent_message_id": fields.String,
"generation_detail": fields.Raw,
},
)

View File

@@ -63,10 +63,9 @@ class ActivateCheckApi(Resource):
args = ActivateCheckQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
workspaceId = args.workspace_id
reg_email = args.email
token = args.token
invitation = RegisterService.get_invitation_if_token_valid(workspaceId, reg_email, token)
invitation = RegisterService.get_invitation_with_case_fallback(workspaceId, args.email, token)
if invitation:
data = invitation.get("data", {})
tenant = invitation.get("tenant", None)
@@ -100,11 +99,12 @@ class ActivateApi(Resource):
def post(self):
args = ActivatePayload.model_validate(console_ns.payload)
invitation = RegisterService.get_invitation_if_token_valid(args.workspace_id, args.email, args.token)
normalized_request_email = args.email.lower() if args.email else None
invitation = RegisterService.get_invitation_with_case_fallback(args.workspace_id, args.email, args.token)
if invitation is None:
raise AlreadyActivateError()
RegisterService.revoke_token(args.workspace_id, args.email, args.token)
RegisterService.revoke_token(args.workspace_id, normalized_request_email, args.token)
account = invitation["account"]
account.name = args.name

View File

@@ -1,7 +1,6 @@
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
@@ -62,6 +61,7 @@ class EmailRegisterSendEmailApi(Resource):
@email_register_enabled
def post(self):
args = EmailRegisterSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -70,13 +70,12 @@ class EmailRegisterSendEmailApi(Resource):
if args.language in languages:
language = args.language
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
token = None
token = AccountService.send_email_register_email(email=args.email, account=account, language=language)
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_email_register_email(email=normalized_email, account=account, language=language)
return {"result": "success", "data": token}
@@ -88,9 +87,9 @@ class EmailRegisterCheckApi(Resource):
def post(self):
args = EmailRegisterValidityPayload.model_validate(console_ns.payload)
user_email = args.email
user_email = args.email.lower()
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(args.email)
is_email_register_error_rate_limit = AccountService.is_email_register_error_rate_limit(user_email)
if is_email_register_error_rate_limit:
raise EmailRegisterLimitError()
@@ -98,11 +97,14 @@ class EmailRegisterCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_email_register_error_rate_limit(args.email)
AccountService.add_email_register_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -113,8 +115,8 @@ class EmailRegisterCheckApi(Resource):
user_email, code=args.code, additional_data={"phase": "register"}
)
AccountService.reset_email_register_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_email_register_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/email-register")
@@ -141,22 +143,23 @@ class EmailRegisterResetApi(Resource):
AccountService.revoke_email_register_token(args.token)
email = register_data.get("email", "")
normalized_email = email.lower()
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
raise EmailAlreadyInUseError()
else:
account = self._create_new_account(email, args.password_confirm)
account = self._create_new_account(normalized_email, args.password_confirm)
if not account:
raise AccountNotFoundError()
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(email)
AccountService.reset_login_error_rate_limit(normalized_email)
return {"result": "success", "data": token_pair.model_dump()}
def _create_new_account(self, email, password) -> Account | None:
def _create_new_account(self, email: str, password: str) -> Account | None:
# Create new account if allowed
account = None
try:

View File

@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource, fields
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.console import console_ns
@@ -21,7 +20,6 @@ from events.tenant_event import tenant_was_created
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from services.account_service import AccountService, TenantService
from services.feature_service import FeatureService
@@ -76,6 +74,7 @@ class ForgotPasswordSendEmailApi(Resource):
@email_password_login_enabled
def post(self):
args = ForgotPasswordSendPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -87,11 +86,11 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
token = AccountService.send_reset_password_email(
account=account,
email=args.email,
email=normalized_email,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
)
@@ -122,9 +121,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
args = ForgotPasswordCheckPayload.model_validate(console_ns.payload)
user_email = args.email
user_email = args.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(args.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -132,11 +131,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(args.email)
AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -144,11 +148,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=args.code, additional_data={"phase": "reset"}
token_email, code=args.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/forgot-password/resets")
@@ -187,9 +191,8 @@ class ForgotPasswordResetApi(Resource):
password_hashed = hash_password(args.new_password, salt)
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@@ -90,32 +90,38 @@ class LoginApi(Resource):
def post(self):
"""Authenticate user and login."""
args = LoginPayload.model_validate(console_ns.payload)
request_email = args.email
normalized_email = request_email.lower()
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(args.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountInFreezeError()
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(args.email)
is_login_error_rate_limit = AccountService.is_login_error_rate_limit(normalized_email)
if is_login_error_rate_limit:
raise EmailPasswordLoginLimitError()
invite_token = args.invite_token
invitation_data: dict[str, Any] | None = None
if args.invite_token:
invitation_data = RegisterService.get_invitation_if_token_valid(None, args.email, args.invite_token)
if invite_token:
invitation_data = RegisterService.get_invitation_with_case_fallback(None, request_email, invite_token)
if invitation_data is None:
invite_token = None
try:
if invitation_data:
data = invitation_data.get("data", {})
invitee_email = data.get("email") if data else None
if invitee_email != args.email:
invitee_email_normalized = invitee_email.lower() if isinstance(invitee_email, str) else invitee_email
if invitee_email_normalized != normalized_email:
raise InvalidEmailError()
account = AccountService.authenticate(args.email, args.password, args.invite_token)
else:
account = AccountService.authenticate(args.email, args.password)
account = _authenticate_account_with_case_fallback(
request_email, normalized_email, args.password, invite_token
)
except services.errors.account.AccountLoginError:
raise AccountBannedError()
except services.errors.account.AccountPasswordError:
AccountService.add_login_error_rate_limit(args.email)
raise AuthenticationFailedError()
except services.errors.account.AccountPasswordError as exc:
AccountService.add_login_error_rate_limit(normalized_email)
raise AuthenticationFailedError() from exc
# SELF_HOSTED only have one workspace
tenants = TenantService.get_join_tenants(account)
if len(tenants) == 0:
@@ -130,7 +136,7 @@ class LoginApi(Resource):
}
token_pair = AccountService.login(account=account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(normalized_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -170,18 +176,19 @@ class ResetPasswordSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
if args.language is not None and args.language == "zh-Hans":
language = "zh-Hans"
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
token = AccountService.send_reset_password_email(
email=args.email,
email=normalized_email,
account=account,
language=language,
is_allow_register=FeatureService.get_system_features().is_allow_register,
@@ -196,6 +203,7 @@ class EmailCodeLoginSendEmailApi(Resource):
@console_ns.expect(console_ns.models[EmailPayload.__name__])
def post(self):
args = EmailPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
@@ -206,13 +214,13 @@ class EmailCodeLoginSendEmailApi(Resource):
else:
language = "en-US"
try:
account = AccountService.get_user_through_email(args.email)
account = _get_account_with_case_fallback(args.email)
except AccountRegisterError:
raise AccountInFreezeError()
if account is None:
if FeatureService.get_system_features().is_allow_register:
token = AccountService.send_email_code_login_email(email=args.email, language=language)
token = AccountService.send_email_code_login_email(email=normalized_email, language=language)
else:
raise AccountNotFound()
else:
@@ -229,14 +237,17 @@ class EmailCodeLoginApi(Resource):
def post(self):
args = EmailCodeLoginPayload.model_validate(console_ns.payload)
user_email = args.email
original_email = args.email
user_email = original_email.lower()
language = args.language
token_data = AccountService.get_email_code_login_data(args.token)
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args.email:
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args.code:
@@ -244,7 +255,7 @@ class EmailCodeLoginApi(Resource):
AccountService.revoke_email_code_login_token(args.token)
try:
account = AccountService.get_user_through_email(user_email)
account = _get_account_with_case_fallback(original_email)
except AccountRegisterError:
raise AccountInFreezeError()
if account:
@@ -275,7 +286,7 @@ class EmailCodeLoginApi(Resource):
except WorkspacesLimitExceededError:
raise WorkspacesLimitExceeded()
token_pair = AccountService.login(account, ip_address=extract_remote_ip(request))
AccountService.reset_login_error_rate_limit(args.email)
AccountService.reset_login_error_rate_limit(user_email)
# Create response with cookies instead of returning tokens in body
response = make_response({"result": "success"})
@@ -309,3 +320,22 @@ class RefreshTokenApi(Resource):
return response
except Exception as e:
return {"result": "fail", "message": str(e)}, 401
def _get_account_with_case_fallback(email: str):
account = AccountService.get_user_through_email(email)
if account or email == email.lower():
return account
return AccountService.get_user_through_email(email.lower())
def _authenticate_account_with_case_fallback(
original_email: str, normalized_email: str, password: str, invite_token: str | None
):
try:
return AccountService.authenticate(original_email, password, invite_token)
except services.errors.account.AccountPasswordError:
if original_email == normalized_email:
raise
return AccountService.authenticate(normalized_email, password, invite_token)

View File

@@ -3,7 +3,6 @@ import logging
import httpx
from flask import current_app, redirect, request
from flask_restx import Resource
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import Unauthorized
@@ -118,7 +117,10 @@ class OAuthCallback(Resource):
invitation = RegisterService.get_invitation_by_token(token=invite_token)
if invitation:
invitation_email = invitation.get("email", None)
if invitation_email != user_info.email:
invitation_email_normalized = (
invitation_email.lower() if isinstance(invitation_email, str) else invitation_email
)
if invitation_email_normalized != user_info.email.lower():
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin?message=Invalid invitation token.")
return redirect(f"{dify_config.CONSOLE_WEB_URL}/signin/invite-settings?invite_token={invite_token}")
@@ -175,7 +177,7 @@ def _get_account_by_openid_or_email(provider: str, user_info: OAuthUserInfo) ->
if not account:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=user_info.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(user_info.email, session=session)
return account
@@ -197,9 +199,10 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
tenant_was_created.send(new_tenant)
if not account:
normalized_email = user_info.email.lower()
oauth_new_user = True
if not FeatureService.get_system_features().is_allow_register:
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(user_info.email):
if dify_config.BILLING_ENABLED and BillingService.is_email_in_freeze(normalized_email):
raise AccountRegisterError(
description=(
"This email account has been deleted within the past "
@@ -210,7 +213,11 @@ def _generate_account(provider: str, user_info: OAuthUserInfo) -> tuple[Account,
raise AccountRegisterError(description=("Invalid email or password"))
account_name = user_info.name or "Dify"
account = RegisterService.register(
email=user_info.email, name=account_name, password=None, open_id=user_info.id, provider=provider
email=normalized_email,
name=account_name,
password=None,
open_id=user_info.id,
provider=provider,
)
# Set interface language

View File

@@ -7,7 +7,7 @@ from typing import Literal, cast
import sqlalchemy as sa
from flask import request
from flask_restx import Resource, fields, marshal, marshal_with
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import asc, desc, select
from werkzeug.exceptions import Forbidden, NotFound
@@ -104,6 +104,15 @@ class DocumentRenamePayload(BaseModel):
name: str
class DocumentDatasetListParam(BaseModel):
page: int = Field(1, title="Page", description="Page number.")
limit: int = Field(20, title="Limit", description="Page size.")
search: str | None = Field(None, alias="keyword", title="Search", description="Search keyword.")
sort_by: str = Field("-created_at", alias="sort", title="SortBy", description="Sort by field.")
status: str | None = Field(None, title="Status", description="Document status.")
fetch_val: str = Field("false", alias="fetch")
register_schema_models(
console_ns,
KnowledgeConfig,
@@ -225,14 +234,16 @@ class DatasetDocumentListApi(Resource):
def get(self, dataset_id):
current_user, current_tenant_id = current_account_with_tenant()
dataset_id = str(dataset_id)
page = request.args.get("page", default=1, type=int)
limit = request.args.get("limit", default=20, type=int)
search = request.args.get("keyword", default=None, type=str)
sort = request.args.get("sort", default="-created_at", type=str)
status = request.args.get("status", default=None, type=str)
raw_args = request.args.to_dict()
param = DocumentDatasetListParam.model_validate(raw_args)
page = param.page
limit = param.limit
search = param.search
sort = param.sort_by
status = param.status
# "yes", "true", "t", "y", "1" convert to True, while others convert to False.
try:
fetch_val = request.args.get("fetch", default="false")
fetch_val = param.fetch_val
if isinstance(fetch_val, bool):
fetch = fetch_val
else:

View File

@@ -81,7 +81,7 @@ class ExternalKnowledgeApiPayload(BaseModel):
class ExternalDatasetCreatePayload(BaseModel):
external_knowledge_api_id: str
external_knowledge_id: str
name: str = Field(..., min_length=1, max_length=40)
name: str = Field(..., min_length=1, max_length=100)
description: str | None = Field(None, max_length=400)
external_retrieval_model: dict[str, object] | None = None

View File

@@ -84,10 +84,11 @@ class SetupApi(Resource):
raise NotInitValidateError()
args = SetupRequestPayload.model_validate(console_ns.payload)
normalized_email = args.email.lower()
# setup
RegisterService.setup(
email=args.email,
email=normalized_email,
name=args.name,
password=args.password,
ip_address=extract_remote_ip(request),

View File

@@ -30,6 +30,11 @@ class TagBindingRemovePayload(BaseModel):
type: Literal["knowledge", "app"] | None = Field(default=None, description="Tag type")
class TagListQueryParam(BaseModel):
type: Literal["knowledge", "app", ""] = Field("", description="Tag type filter")
keyword: str | None = Field(None, description="Search keyword")
register_schema_models(
console_ns,
TagBasePayload,
@@ -43,12 +48,15 @@ class TagListApi(Resource):
@setup_required
@login_required
@account_initialization_required
@console_ns.doc(
params={"type": 'Tag type filter. Can be "knowledge" or "app".', "keyword": "Search keyword for tag name."}
)
@marshal_with(dataset_tag_fields)
def get(self):
_, current_tenant_id = current_account_with_tenant()
tag_type = request.args.get("type", type=str, default="")
keyword = request.args.get("keyword", default=None, type=str)
tags = TagService.get_tags(tag_type, current_tenant_id, keyword)
raw_args = request.args.to_dict()
param = TagListQueryParam.model_validate(raw_args)
tags = TagService.get_tags(param.type, current_tenant_id, param.keyword)
return tags, 200

View File

@@ -41,7 +41,7 @@ from fields.member_fields import account_fields
from libs.datetime_utils import naive_utc_now
from libs.helper import EmailStr, TimestampField, extract_remote_ip, timezone
from libs.login import current_account_with_tenant, login_required
from models import Account, AccountIntegrate, InvitationCode
from models import AccountIntegrate, InvitationCode
from services.account_service import AccountService
from services.billing_service import BillingService
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
@@ -536,7 +536,8 @@ class ChangeEmailSendEmailApi(Resource):
else:
language = "en-US"
account = None
user_email = args.email
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
@@ -546,16 +547,24 @@ class ChangeEmailSendEmailApi(Resource):
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email != current_user.email:
if user_email.lower() != current_user.email.lower():
raise InvalidEmailError()
user_email = current_user.email
else:
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=args.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(args.email, session=session)
if account is None:
raise AccountNotFound()
email_for_sending = account.email
user_email = account.email
token = AccountService.send_change_email_email(
account=account, email=args.email, old_email=user_email, language=language, phase=args.phase
account=account,
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
)
return {"result": "success", "data": token}
@@ -571,9 +580,9 @@ class ChangeEmailCheckApi(Resource):
payload = console_ns.payload or {}
args = ChangeEmailValidityPayload.model_validate(payload)
user_email = args.email
user_email = args.email.lower()
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(args.email)
is_change_email_error_rate_limit = AccountService.is_change_email_error_rate_limit(user_email)
if is_change_email_error_rate_limit:
raise EmailChangeLimitError()
@@ -581,11 +590,13 @@ class ChangeEmailCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if user_email != normalized_token_email:
raise InvalidEmailError()
if args.code != token_data.get("code"):
AccountService.add_change_email_error_rate_limit(args.email)
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -596,8 +607,8 @@ class ChangeEmailCheckApi(Resource):
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
)
AccountService.reset_change_email_error_rate_limit(args.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_change_email_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@console_ns.route("/account/change-email/reset")
@@ -611,11 +622,12 @@ class ChangeEmailResetApi(Resource):
def post(self):
payload = console_ns.payload or {}
args = ChangeEmailResetPayload.model_validate(payload)
normalized_new_email = args.new_email.lower()
if AccountService.is_account_in_freeze(args.new_email):
if AccountService.is_account_in_freeze(normalized_new_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args.new_email):
if not AccountService.check_email_unique(normalized_new_email):
raise EmailAlreadyInUseError()
reset_data = AccountService.get_change_email_data(args.token)
@@ -626,13 +638,13 @@ class ChangeEmailResetApi(Resource):
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email != old_email:
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
updated_account = AccountService.update_account_email(current_user, email=args.new_email)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(
email=args.new_email,
email=normalized_new_email,
)
return updated_account
@@ -645,8 +657,9 @@ class CheckEmailUnique(Resource):
def post(self):
payload = console_ns.payload or {}
args = CheckEmailUniquePayload.model_validate(payload)
if AccountService.is_account_in_freeze(args.email):
normalized_email = args.email.lower()
if AccountService.is_account_in_freeze(normalized_email):
raise AccountInFreezeError()
if not AccountService.check_email_unique(args.email):
if not AccountService.check_email_unique(normalized_email):
raise EmailAlreadyInUseError()
return {"result": "success"}

View File

@@ -116,26 +116,31 @@ class MemberInviteEmailApi(Resource):
raise WorkspaceMembersLimitExceeded()
for invitee_email in invitee_emails:
normalized_invitee_email = invitee_email.lower()
try:
if not inviter.current_tenant:
raise ValueError("No current tenant")
token = RegisterService.invite_new_member(
inviter.current_tenant, invitee_email, interface_language, role=invitee_role, inviter=inviter
tenant=inviter.current_tenant,
email=invitee_email,
language=interface_language,
role=invitee_role,
inviter=inviter,
)
encoded_invitee_email = parse.quote(invitee_email)
encoded_invitee_email = parse.quote(normalized_invitee_email)
invitation_results.append(
{
"status": "success",
"email": invitee_email,
"email": normalized_invitee_email,
"url": f"{console_web_url}/activate?email={encoded_invitee_email}&token={token}",
}
)
except AccountAlreadyInTenantError:
invitation_results.append(
{"status": "success", "email": invitee_email, "url": f"{console_web_url}/signin"}
{"status": "success", "email": normalized_invitee_email, "url": f"{console_web_url}/signin"}
)
except Exception as e:
invitation_results.append({"status": "failed", "email": invitee_email, "message": str(e)})
invitation_results.append({"status": "failed", "email": normalized_invitee_email, "message": str(e)})
return {
"result": "success",

View File

@@ -1,14 +1,14 @@
import logging
from collections.abc import Mapping
from typing import Any
from flask import make_response, redirect, request
from flask_restx import Resource, reqparse
from pydantic import BaseModel, Field, model_validator
from flask_restx import Resource
from pydantic import BaseModel, model_validator
from sqlalchemy.orm import Session
from werkzeug.exceptions import BadRequest, Forbidden
from configs import dify_config
from controllers.common.schema import register_schema_models
from controllers.web.error import NotFoundError
from core.model_runtime.utils.encoders import jsonable_encoder
from core.plugin.entities.plugin_daemon import CredentialType
@@ -35,35 +35,38 @@ from ..wraps import (
logger = logging.getLogger(__name__)
class TriggerSubscriptionUpdateRequest(BaseModel):
"""Request payload for updating a trigger subscription"""
class TriggerSubscriptionBuilderCreatePayload(BaseModel):
credential_type: str = CredentialType.UNAUTHORIZED
name: str | None = Field(default=None, description="The name for the subscription")
credentials: Mapping[str, Any] | None = Field(default=None, description="The credentials for the subscription")
parameters: Mapping[str, Any] | None = Field(default=None, description="The parameters for the subscription")
properties: Mapping[str, Any] | None = Field(default=None, description="The properties for the subscription")
class TriggerSubscriptionBuilderVerifyPayload(BaseModel):
credentials: dict[str, Any]
class TriggerSubscriptionBuilderUpdatePayload(BaseModel):
name: str | None = None
parameters: dict[str, Any] | None = None
properties: dict[str, Any] | None = None
credentials: dict[str, Any] | None = None
@model_validator(mode="after")
def check_at_least_one_field(self):
if all(v is None for v in (self.name, self.credentials, self.parameters, self.properties)):
if all(v is None for v in self.model_dump().values()):
raise ValueError("At least one of name, credentials, parameters, or properties must be provided")
return self
class TriggerSubscriptionVerifyRequest(BaseModel):
"""Request payload for verifying subscription credentials."""
credentials: Mapping[str, Any] = Field(description="The credentials to verify")
class TriggerOAuthClientPayload(BaseModel):
client_params: dict[str, Any] | None = None
enabled: bool | None = None
console_ns.schema_model(
TriggerSubscriptionUpdateRequest.__name__,
TriggerSubscriptionUpdateRequest.model_json_schema(ref_template="#/definitions/{model}"),
)
console_ns.schema_model(
TriggerSubscriptionVerifyRequest.__name__,
TriggerSubscriptionVerifyRequest.model_json_schema(ref_template="#/definitions/{model}"),
register_schema_models(
console_ns,
TriggerSubscriptionBuilderCreatePayload,
TriggerSubscriptionBuilderVerifyPayload,
TriggerSubscriptionBuilderUpdatePayload,
TriggerOAuthClientPayload,
)
@@ -132,16 +135,11 @@ class TriggerSubscriptionListApi(Resource):
raise
parser = reqparse.RequestParser().add_argument(
"credential_type", type=str, required=False, nullable=True, location="json"
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/create",
)
class TriggerSubscriptionBuilderCreateApi(Resource):
@console_ns.expect(parser)
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderCreatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -151,10 +149,10 @@ class TriggerSubscriptionBuilderCreateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = parser.parse_args()
payload = TriggerSubscriptionBuilderCreatePayload.model_validate(console_ns.payload or {})
try:
credential_type = CredentialType.of(args.get("credential_type") or CredentialType.UNAUTHORIZED.value)
credential_type = CredentialType.of(payload.credential_type)
subscription_builder = TriggerSubscriptionBuilderService.create_trigger_subscription_builder(
tenant_id=user.current_tenant_id,
user_id=user.id,
@@ -182,18 +180,11 @@ class TriggerSubscriptionBuilderGetApi(Resource):
)
parser_api = (
reqparse.RequestParser()
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/verify-and-update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
@console_ns.expect(parser_api)
class TriggerSubscriptionBuilderVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -203,7 +194,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = parser_api.parse_args()
payload = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_verify to prevent race conditions
@@ -213,7 +204,7 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
credentials=args.get("credentials", None),
credentials=payload.credentials,
),
)
except Exception as e:
@@ -221,24 +212,11 @@ class TriggerSubscriptionBuilderVerifyAndUpdateApi(Resource):
raise ValueError(str(e)) from e
parser_update_api = (
reqparse.RequestParser()
# The name of the subscription builder
.add_argument("name", type=str, required=False, nullable=True, location="json")
# The parameters of the subscription builder
.add_argument("parameters", type=dict, required=False, nullable=True, location="json")
# The properties of the subscription builder
.add_argument("properties", type=dict, required=False, nullable=True, location="json")
# The credentials of the subscription builder
.add_argument("credentials", type=dict, required=False, nullable=True, location="json")
)
@console_ns.route(
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/update/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderUpdateApi(Resource):
@console_ns.expect(parser_update_api)
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -249,7 +227,7 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
assert isinstance(user, Account)
assert user.current_tenant_id is not None
args = parser_update_api.parse_args()
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
return jsonable_encoder(
TriggerSubscriptionBuilderService.update_trigger_subscription_builder(
@@ -257,10 +235,10 @@ class TriggerSubscriptionBuilderUpdateApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
credentials=args.get("credentials", None),
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
credentials=payload.credentials,
),
)
)
@@ -295,7 +273,7 @@ class TriggerSubscriptionBuilderLogsApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/builder/build/<path:subscription_builder_id>",
)
class TriggerSubscriptionBuilderBuildApi(Resource):
@console_ns.expect(parser_update_api)
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -304,7 +282,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"""Build a subscription instance for a trigger provider"""
user = current_user
assert user.current_tenant_id is not None
args = parser_update_api.parse_args()
payload = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
try:
# Use atomic update_and_build to prevent race conditions
TriggerSubscriptionBuilderService.update_and_build_builder(
@@ -313,9 +291,9 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
provider_id=TriggerProviderID(provider),
subscription_builder_id=subscription_builder_id,
subscription_builder_updater=SubscriptionBuilderUpdater(
name=args.get("name", None),
parameters=args.get("parameters", None),
properties=args.get("properties", None),
name=payload.name,
parameters=payload.parameters,
properties=payload.properties,
),
)
return 200
@@ -328,7 +306,7 @@ class TriggerSubscriptionBuilderBuildApi(Resource):
"/workspaces/current/trigger-provider/<path:subscription_id>/subscriptions/update",
)
class TriggerSubscriptionUpdateApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionUpdateRequest.__name__])
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderUpdatePayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -338,7 +316,7 @@ class TriggerSubscriptionUpdateApi(Resource):
user = current_user
assert user.current_tenant_id is not None
request = TriggerSubscriptionUpdateRequest.model_validate(console_ns.payload)
request = TriggerSubscriptionBuilderUpdatePayload.model_validate(console_ns.payload or {})
subscription = TriggerProviderService.get_subscription_by_id(
tenant_id=user.current_tenant_id,
@@ -568,13 +546,6 @@ class TriggerOAuthCallbackApi(Resource):
return redirect(f"{dify_config.CONSOLE_WEB_URL}/oauth-callback")
parser_oauth_client = (
reqparse.RequestParser()
.add_argument("client_params", type=dict, required=False, nullable=True, location="json")
.add_argument("enabled", type=bool, required=False, nullable=True, location="json")
)
@console_ns.route("/workspaces/current/trigger-provider/<path:provider>/oauth/client")
class TriggerOAuthClientManageApi(Resource):
@setup_required
@@ -622,7 +593,7 @@ class TriggerOAuthClientManageApi(Resource):
logger.exception("Error getting OAuth client", exc_info=e)
raise
@console_ns.expect(parser_oauth_client)
@console_ns.expect(console_ns.models[TriggerOAuthClientPayload.__name__])
@setup_required
@login_required
@is_admin_or_owner_required
@@ -632,15 +603,15 @@ class TriggerOAuthClientManageApi(Resource):
user = current_user
assert user.current_tenant_id is not None
args = parser_oauth_client.parse_args()
payload = TriggerOAuthClientPayload.model_validate(console_ns.payload or {})
try:
provider_id = TriggerProviderID(provider)
return TriggerProviderService.save_custom_oauth_client_params(
tenant_id=user.current_tenant_id,
provider_id=provider_id,
client_params=args.get("client_params"),
enabled=args.get("enabled"),
client_params=payload.client_params,
enabled=payload.enabled,
)
except ValueError as e:
@@ -676,7 +647,7 @@ class TriggerOAuthClientManageApi(Resource):
"/workspaces/current/trigger-provider/<path:provider>/subscriptions/verify/<path:subscription_id>",
)
class TriggerSubscriptionVerifyApi(Resource):
@console_ns.expect(console_ns.models[TriggerSubscriptionVerifyRequest.__name__])
@console_ns.expect(console_ns.models[TriggerSubscriptionBuilderVerifyPayload.__name__])
@setup_required
@login_required
@edit_permission_required
@@ -686,9 +657,7 @@ class TriggerSubscriptionVerifyApi(Resource):
user = current_user
assert user.current_tenant_id is not None
verify_request: TriggerSubscriptionVerifyRequest = TriggerSubscriptionVerifyRequest.model_validate(
console_ns.payload
)
verify_request = TriggerSubscriptionBuilderVerifyPayload.model_validate(console_ns.payload or {})
try:
result = TriggerProviderService.verify_subscription_credentials(

View File

@@ -80,6 +80,9 @@ tenant_fields = {
"in_trial": fields.Boolean,
"trial_end_reason": fields.String,
"custom_config": fields.Raw(attribute="custom_config"),
"trial_credits": fields.Integer,
"trial_credits_used": fields.Integer,
"next_credit_reset_date": fields.Integer,
}
tenants_fields = {

View File

@@ -4,7 +4,6 @@ import secrets
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_models
@@ -22,7 +21,7 @@ from controllers.web import web_ns
from extensions.ext_database import db
from libs.helper import EmailStr, extract_remote_ip
from libs.password import hash_password, valid_password
from models import Account
from models.account import Account
from services.account_service import AccountService
@@ -70,6 +69,9 @@ class ForgotPasswordSendEmailApi(Resource):
def post(self):
payload = ForgotPasswordSendPayload.model_validate(web_ns.payload or {})
request_email = payload.email
normalized_email = request_email.lower()
ip_address = extract_remote_ip(request)
if AccountService.is_email_send_ip_limit(ip_address):
raise EmailSendIpLimitError()
@@ -80,12 +82,12 @@ class ForgotPasswordSendEmailApi(Resource):
language = "en-US"
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=payload.email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(request_email, session=session)
token = None
if account is None:
raise AuthenticationFailedError()
else:
token = AccountService.send_reset_password_email(account=account, email=payload.email, language=language)
token = AccountService.send_reset_password_email(account=account, email=normalized_email, language=language)
return {"result": "success", "data": token}
@@ -104,9 +106,9 @@ class ForgotPasswordCheckApi(Resource):
def post(self):
payload = ForgotPasswordCheckPayload.model_validate(web_ns.payload or {})
user_email = payload.email
user_email = payload.email.lower()
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(payload.email)
is_forgot_password_error_rate_limit = AccountService.is_forgot_password_error_rate_limit(user_email)
if is_forgot_password_error_rate_limit:
raise EmailPasswordResetLimitError()
@@ -114,11 +116,16 @@ class ForgotPasswordCheckApi(Resource):
if token_data is None:
raise InvalidTokenError()
if user_email != token_data.get("email"):
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if user_email != normalized_token_email:
raise InvalidEmailError()
if payload.code != token_data.get("code"):
AccountService.add_forgot_password_error_rate_limit(payload.email)
AccountService.add_forgot_password_error_rate_limit(user_email)
raise EmailCodeError()
# Verified, revoke the first token
@@ -126,11 +133,11 @@ class ForgotPasswordCheckApi(Resource):
# Refresh token data by generating a new token
_, new_token = AccountService.generate_reset_password_token(
user_email, code=payload.code, additional_data={"phase": "reset"}
token_email, code=payload.code, additional_data={"phase": "reset"}
)
AccountService.reset_forgot_password_error_rate_limit(payload.email)
return {"is_valid": True, "email": token_data.get("email"), "token": new_token}
AccountService.reset_forgot_password_error_rate_limit(user_email)
return {"is_valid": True, "email": normalized_token_email, "token": new_token}
@web_ns.route("/forgot-password/resets")
@@ -174,7 +181,7 @@ class ForgotPasswordResetApi(Resource):
email = reset_data.get("email", "")
with Session(db.engine) as session:
account = session.execute(select(Account).filter_by(email=email)).scalar_one_or_none()
account = AccountService.get_account_by_email_with_case_fallback(email, session=session)
if account:
self._update_existing_account(account, password_hashed, salt, session)

View File

@@ -10,7 +10,12 @@ from controllers.console.auth.error import (
InvalidEmailError,
)
from controllers.console.error import AccountBannedError
from controllers.console.wraps import only_edition_enterprise, setup_required
from controllers.console.wraps import (
decrypt_code_field,
decrypt_password_field,
only_edition_enterprise,
setup_required,
)
from controllers.web import web_ns
from controllers.web.wraps import decode_jwt_token
from libs.helper import email
@@ -42,6 +47,7 @@ class LoginApi(Resource):
404: "Account not found",
}
)
@decrypt_password_field
def post(self):
"""Authenticate user and login."""
parser = (
@@ -181,6 +187,7 @@ class EmailCodeLoginApi(Resource):
404: "Account not found",
}
)
@decrypt_code_field
def post(self):
parser = (
reqparse.RequestParser()
@@ -190,25 +197,29 @@ class EmailCodeLoginApi(Resource):
)
args = parser.parse_args()
user_email = args["email"]
user_email = args["email"].lower()
token_data = WebAppAuthService.get_email_code_login_data(args["token"])
if token_data is None:
raise InvalidTokenError()
if token_data["email"] != args["email"]:
token_email = token_data.get("email")
if not isinstance(token_email, str):
raise InvalidEmailError()
normalized_token_email = token_email.lower()
if normalized_token_email != user_email:
raise InvalidEmailError()
if token_data["code"] != args["code"]:
raise EmailCodeError()
WebAppAuthService.revoke_email_code_login_token(args["token"])
account = WebAppAuthService.get_user_through_email(user_email)
account = WebAppAuthService.get_user_through_email(token_email)
if not account:
raise AuthenticationFailedError()
token = WebAppAuthService.login(account=account)
AccountService.reset_login_error_rate_limit(args["email"])
AccountService.reset_login_error_rate_limit(user_email)
response = make_response({"result": "success", "data": {"access_token": token}})
# set_access_token_to_cookie(request, response, token, samesite="None", httponly=False)
return response

View File

@@ -1,380 +0,0 @@
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentEntity, AgentLog, AgentResult
from core.agent.patterns.strategy_factory import StrategyFactory
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from models.model import Message
logger = logging.getLogger(__name__)
class AgentAppRunner(BaseAgentRunner):
def _create_tool_invoke_hook(self, message: Message):
"""
Create a tool invoke hook that uses ToolEngine.agent_invoke.
This hook handles file creation and returns proper meta information.
"""
# Get trace manager from app generate entity
trace_manager = self.application_generate_entity.trace_manager
def tool_invoke_hook(
tool: Tool, tool_args: dict[str, Any], tool_name: str
) -> tuple[str, list[str], ToolInvokeMeta]:
"""Hook that uses agent_invoke for proper file and meta handling."""
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool,
tool_parameters=tool_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=message.id,
conversation_id=self.conversation.id,
)
# Publish files and track IDs
for message_file_id in message_files:
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id),
PublishFrom.APPLICATION_MANAGER,
)
self._current_message_file_ids.append(message_file_id)
return tool_invoke_response, message_files, tool_invoke_meta
return tool_invoke_hook
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run Agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, _ = self._init_prompt_tools()
assert app_config.agent
# Create tool invoke hook for agent_invoke
tool_invoke_hook = self._create_tool_invoke_hook(message)
# Get instruction for ReAct strategy
instruction = self.app_config.prompt_template.simple_prompt_template or ""
# Use factory to create appropriate strategy
strategy = StrategyFactory.create_strategy(
model_features=self.model_features,
model_instance=self.model_instance,
tools=list(tool_instances.values()),
files=list(self.files),
max_iterations=app_config.agent.max_iteration,
context=self.build_execution_context(),
agent_strategy=self.config.strategy,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Initialize state variables
current_agent_thought_id = None
has_published_thought = False
current_tool_name: str | None = None
self._current_message_file_ids: list[str] = []
# organize prompt messages
prompt_messages = self._organize_prompt_messages()
# Run strategy
generator = strategy.run(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
stop=app_generate_entity.model_conf.stop,
stream=True,
)
# Consume generator and collect result
result: AgentResult | None = None
try:
while True:
try:
output = next(generator)
except StopIteration as e:
# Generator finished, get the return value
result = e.value
break
if isinstance(output, LLMResultChunk):
# Handle LLM chunk
if current_agent_thought_id and not has_published_thought:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
has_published_thought = True
yield output
elif isinstance(output, AgentLog):
# Handle Agent Log using log_type for type-safe dispatch
if output.status == AgentLog.LogStatus.START:
if output.log_type == AgentLog.LogType.ROUND:
# Start of a new round
message_file_ids: list[str] = []
current_agent_thought_id = self.create_agent_thought(
message_id=message.id,
message="",
tool_name="",
tool_input="",
messages_ids=message_file_ids,
)
has_published_thought = False
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call start - extract data from structured fields
current_tool_name = output.data.get("tool_name", "")
tool_input = output.data.get("tool_args", {})
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=current_tool_name,
tool_input=tool_input,
thought=None,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.status == AgentLog.LogStatus.SUCCESS:
if output.log_type == AgentLog.LogType.THOUGHT:
if current_agent_thought_id is None:
continue
thought_text = output.data.get("thought")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=thought_text,
observation=None,
tool_invoke_meta=None,
answer=None,
messages_ids=[],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.TOOL_CALL:
if current_agent_thought_id is None:
continue
# Tool call finished
tool_output = output.data.get("output")
# Get meta from strategy output (now properly populated)
tool_meta = output.data.get("meta")
# Wrap tool_meta with tool_name as key (required by agent_service)
if tool_meta and current_tool_name:
tool_meta = {current_tool_name: tool_meta}
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=None,
observation=tool_output,
tool_invoke_meta=tool_meta,
answer=None,
messages_ids=self._current_message_file_ids,
)
# Clear message file ids after saving
self._current_message_file_ids = []
current_tool_name = None
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
elif output.log_type == AgentLog.LogType.ROUND:
if current_agent_thought_id is None:
continue
# Round finished - save LLM usage and answer
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
llm_result = output.data.get("llm_result")
final_answer = output.data.get("final_answer")
self.save_agent_thought(
agent_thought_id=current_agent_thought_id,
tool_name=None,
tool_input=None,
thought=llm_result,
observation=None,
tool_invoke_meta=None,
answer=final_answer,
messages_ids=[],
llm_usage=llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
PublishFrom.APPLICATION_MANAGER,
)
except Exception:
# Re-raise any other exceptions
raise
# Process final result
if isinstance(result, AgentResult):
final_answer = result.text
usage = result.usage or LLMUsage.empty_usage()
# Publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=self.model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=usage,
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_template:
return prompt_messages or []
prompt_messages = prompt_messages or []
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
return prompt_messages
if not prompt_messages:
return [SystemPromptMessage(content=prompt_template)]
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
# For ReAct strategy, use the agent prompt template
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
prompt_template = self.config.prompt.first_prompt
else:
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@@ -1,11 +1,12 @@
import json
import logging
import uuid
from decimal import Decimal
from typing import Union, cast
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity, ExecutionContext
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager
@@ -41,6 +42,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool
from extensions.ext_database import db
from factories import file_factory
from models.enums import CreatorUserRole
from models.model import Conversation, Message, MessageAgentThought, MessageFile
logger = logging.getLogger(__name__)
@@ -114,20 +116,9 @@ class BaseAgentRunner(AppRunner):
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
self.model_features = features
self.query: str | None = ""
self._current_thoughts: list[PromptMessage] = []
def build_execution_context(self) -> ExecutionContext:
"""Build execution context."""
return ExecutionContext(
user_id=self.user_id,
app_id=self.app_config.app_id,
conversation_id=self.conversation.id,
message_id=self.message.id,
tenant_id=self.tenant_id,
)
def _repack_app_generate_entity(
self, app_generate_entity: AgentChatAppGenerateEntity
) -> AgentChatAppGenerateEntity:
@@ -300,6 +291,7 @@ class BaseAgentRunner(AppRunner):
thought = MessageAgentThought(
message_id=message_id,
message_chain_id=None,
tool_process_data=None,
thought="",
tool=tool_name,
tool_labels_str="{}",
@@ -307,20 +299,20 @@ class BaseAgentRunner(AppRunner):
tool_input=tool_input,
message=message,
message_token=0,
message_unit_price=0,
message_price_unit=0,
message_unit_price=Decimal(0),
message_price_unit=Decimal("0.001"),
message_files=json.dumps(messages_ids) if messages_ids else "",
answer="",
observation="",
answer_token=0,
answer_unit_price=0,
answer_price_unit=0,
answer_unit_price=Decimal(0),
answer_price_unit=Decimal("0.001"),
tokens=0,
total_price=0,
total_price=Decimal(0),
position=self.agent_thought_count + 1,
currency="USD",
latency=0,
created_by_role="account",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=self.user_id,
)
@@ -353,7 +345,8 @@ class BaseAgentRunner(AppRunner):
raise ValueError("agent thought not found")
if thought:
agent_thought.thought += thought
existing_thought = agent_thought.thought or ""
agent_thought.thought = f"{existing_thought}{thought}"
if tool_name:
agent_thought.tool = tool_name
@@ -451,21 +444,30 @@ class BaseAgentRunner(AppRunner):
agent_thoughts: list[MessageAgentThought] = message.agent_thoughts
if agent_thoughts:
for agent_thought in agent_thoughts:
tools = agent_thought.tool
if tools:
tools = tools.split(";")
tool_names_raw = agent_thought.tool
if tool_names_raw:
tool_names = tool_names_raw.split(";")
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_call_response: list[ToolPromptMessage] = []
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception:
tool_inputs = {tool: {} for tool in tools}
try:
tool_responses = json.loads(agent_thought.observation)
except Exception:
tool_responses = dict.fromkeys(tools, agent_thought.observation)
tool_input_payload = agent_thought.tool_input
if tool_input_payload:
try:
tool_inputs = json.loads(tool_input_payload)
except Exception:
tool_inputs = {tool: {} for tool in tool_names}
else:
tool_inputs = {tool: {} for tool in tool_names}
for tool in tools:
observation_payload = agent_thought.observation
if observation_payload:
try:
tool_responses = json.loads(observation_payload)
except Exception:
tool_responses = dict.fromkeys(tool_names, observation_payload)
else:
tool_responses = dict.fromkeys(tool_names, observation_payload)
for tool in tool_names:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
tool_calls.append(
@@ -495,7 +497,7 @@ class BaseAgentRunner(AppRunner):
*tool_call_response,
]
)
if not tools:
if not tool_names_raw:
result.append(AssistantPromptMessage(content=agent_thought.thought))
else:
if message.answer:

View File

@@ -0,0 +1,437 @@
import json
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping, Sequence
from typing import Any
from core.agent.base_agent_runner import BaseAgentRunner
from core.agent.entities import AgentScratchpadUnit
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
UserPromptMessage,
)
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.__base.tool import Tool
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
class CotAgentRunner(BaseAgentRunner, ABC):
_is_first_iteration = True
_ignore_observation_providers = ["wenxin"]
_historic_prompt_messages: list[PromptMessage]
_agent_scratchpad: list[AgentScratchpadUnit]
_instruction: str
_query: str
_prompt_messages_tools: Sequence[PromptMessageTool]
def run(
self,
message: Message,
query: str,
inputs: Mapping[str, str],
) -> Generator:
"""
Run Cot agent application
"""
app_generate_entity = self.application_generate_entity
self._repack_app_generate_entity(app_generate_entity)
self._init_react_state(query)
trace_manager = app_generate_entity.trace_manager
# check model mode
if "Observation" not in app_generate_entity.model_conf.stop:
if app_generate_entity.model_conf.provider not in self._ignore_observation_providers:
app_generate_entity.model_conf.stop.append("Observation")
app_config = self.app_config
assert app_config.agent
# init instruction
inputs = inputs or {}
instruction = app_config.prompt_template.simple_prompt_template or ""
self._instruction = self._fill_in_inputs_from_external_data_tools(instruction, inputs)
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
self._prompt_messages_tools = prompt_messages_tools
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
agent_thought_id = "" # Initialize agent_thought_id
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
# continue to run until there is not any tool call
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
self._prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
if iteration_step > 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=[],
stop=app_generate_entity.model_conf.stop,
stream=True,
user=self.user_id,
callbacks=[],
)
usage_dict: dict[str, LLMUsage | None] = {}
react_chunks = CotAgentOutputParser.handle_react_stream_output(chunks, usage_dict)
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
# publish agent thought if it's first iteration
if iteration_step == 1:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
action = chunk
# detect action
assert scratchpad.agent_response is not None
scratchpad.agent_response += json.dumps(chunk.model_dump())
scratchpad.action_str = json.dumps(chunk.model_dump())
scratchpad.action = action
else:
assert scratchpad.agent_response is not None
scratchpad.agent_response += chunk
assert scratchpad.thought is not None
scratchpad.thought += chunk
yield LLMResultChunk(
model=self.model_config.model,
prompt_messages=prompt_messages,
system_fingerprint="",
delta=LLMResultChunkDelta(index=0, message=AssistantPromptMessage(content=chunk), usage=None),
)
assert scratchpad.thought is not None
scratchpad.thought = scratchpad.thought.strip() or "I am thinking about how to help you"
self._agent_scratchpad.append(scratchpad)
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and scratchpad.action:
if scratchpad.action.action_name.lower() != "final answer":
raise AgentMaxIterationError(app_config.agent.max_iteration)
# get llm usage
if "usage" in usage_dict:
if usage_dict["usage"] is not None:
increase_usage(llm_usage, usage_dict["usage"])
else:
usage_dict["usage"] = LLMUsage.empty_usage()
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=(scratchpad.action.action_name if scratchpad.action and not scratchpad.is_final() else ""),
tool_input={scratchpad.action.action_name: scratchpad.action.action_input} if scratchpad.action else {},
tool_invoke_meta={},
thought=scratchpad.thought or "",
observation="",
answer=scratchpad.agent_response or "",
messages_ids=[],
llm_usage=usage_dict["usage"],
)
if not scratchpad.is_final():
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
if not scratchpad.action:
# failed to extract action, return final answer directly
final_answer = ""
else:
if scratchpad.action.action_name.lower() == "final answer":
# action is final answer, return final answer directly
try:
if isinstance(scratchpad.action.action_input, dict):
final_answer = json.dumps(scratchpad.action.action_input, ensure_ascii=False)
elif isinstance(scratchpad.action.action_input, str):
final_answer = scratchpad.action.action_input
else:
final_answer = f"{scratchpad.action.action_input}"
except TypeError:
final_answer = f"{scratchpad.action.action_input}"
else:
function_call_state = True
# action is tool call, invoke tool
tool_invoke_response, tool_invoke_meta = self._handle_invoke_action(
action=scratchpad.action,
tool_instances=tool_instances,
message_file_ids=message_file_ids,
trace_manager=trace_manager,
)
scratchpad.observation = tool_invoke_response
scratchpad.agent_response = tool_invoke_response
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=scratchpad.action.action_name,
tool_input={scratchpad.action.action_name: scratchpad.action.action_input},
thought=scratchpad.thought or "",
observation={scratchpad.action.action_name: tool_invoke_response},
tool_invoke_meta={scratchpad.action.action_name: tool_invoke_meta.to_dict()},
answer=scratchpad.agent_response,
messages_ids=message_file_ids,
llm_usage=usage_dict["usage"],
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool message
for prompt_tool in self._prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
),
system_fingerprint="",
)
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input={},
tool_invoke_meta={},
thought=final_answer,
observation={},
answer=final_answer,
messages_ids=[],
)
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def _handle_invoke_action(
self,
action: AgentScratchpadUnit.Action,
tool_instances: Mapping[str, Tool],
message_file_ids: list[str],
trace_manager: TraceQueueManager | None = None,
) -> tuple[str, ToolInvokeMeta]:
"""
handle invoke action
:param action: action
:param tool_instances: tool instances
:param message_file_ids: message file ids
:param trace_manager: trace manager
:return: observation, meta
"""
# action is tool call, invoke tool
tool_call_name = action.action_name
tool_call_args = action.action_input
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
answer = f"there is not a tool named {tool_call_name}"
return answer, ToolInvokeMeta.error_instance(answer)
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
return tool_invoke_response, tool_invoke_meta
def _convert_dict_to_action(self, action: dict) -> AgentScratchpadUnit.Action:
"""
convert dict to action
"""
return AgentScratchpadUnit.Action(action_name=action["action"], action_input=action["action_input"])
def _fill_in_inputs_from_external_data_tools(self, instruction: str, inputs: Mapping[str, Any]) -> str:
"""
fill in inputs from external data tools
"""
for key, value in inputs.items():
try:
instruction = instruction.replace(f"{{{{{key}}}}}", str(value))
except Exception:
continue
return instruction
def _init_react_state(self, query):
"""
init agent scratchpad
"""
self._query = query
self._agent_scratchpad = []
self._historic_prompt_messages = self._organize_historic_prompt_messages()
@abstractmethod
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
organize prompt messages
"""
def _format_assistant_message(self, agent_scratchpad: list[AgentScratchpadUnit]) -> str:
"""
format assistant message
"""
message = ""
for scratchpad in agent_scratchpad:
if scratchpad.is_final():
message += f"Final Answer: {scratchpad.agent_response}"
else:
message += f"Thought: {scratchpad.thought}\n\n"
if scratchpad.action_str:
message += f"Action: {scratchpad.action_str}\n\n"
if scratchpad.observation:
message += f"Observation: {scratchpad.observation}\n\n"
return message
def _organize_historic_prompt_messages(
self, current_session_messages: list[PromptMessage] | None = None
) -> list[PromptMessage]:
"""
organize historic prompt messages
"""
result: list[PromptMessage] = []
scratchpads: list[AgentScratchpadUnit] = []
current_scratchpad: AgentScratchpadUnit | None = None
for message in self.history_prompt_messages:
if isinstance(message, AssistantPromptMessage):
if not current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad = AgentScratchpadUnit(
agent_response=message.content,
thought=message.content or "I am thinking about how to help you",
action_str="",
action=None,
observation=None,
)
scratchpads.append(current_scratchpad)
if message.tool_calls:
try:
current_scratchpad.action = AgentScratchpadUnit.Action(
action_name=message.tool_calls[0].function.name,
action_input=json.loads(message.tool_calls[0].function.arguments),
)
current_scratchpad.action_str = json.dumps(current_scratchpad.action.to_dict())
except Exception:
logger.exception("Failed to parse tool call from assistant message")
elif isinstance(message, ToolPromptMessage):
if current_scratchpad:
assert isinstance(message.content, str)
current_scratchpad.observation = message.content
else:
raise NotImplementedError("expected str type")
elif isinstance(message, UserPromptMessage):
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
scratchpads = []
current_scratchpad = None
result.append(message)
if scratchpads:
result.append(AssistantPromptMessage(content=self._format_assistant_message(scratchpads)))
historic_prompts = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=current_session_messages or [],
history_messages=result,
memory=self.memory,
).get_prompt()
return historic_prompts

View File

@@ -0,0 +1,118 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
PromptMessage,
SystemPromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.model_runtime.utils.encoders import jsonable_encoder
class CotChatAgentRunner(CotAgentRunner):
def _organize_system_prompt(self) -> SystemPromptMessage:
"""
Organize system prompt
"""
assert self.app_config.agent
assert self.app_config.agent.prompt
prompt_entity = self.app_config.agent.prompt
if not prompt_entity:
raise ValueError("Agent prompt configuration is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return SystemPromptMessage(content=system_prompt)
def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize
"""
# organize system prompt
system_message = self._organize_system_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
if not agent_scratchpad:
assistant_messages = []
else:
assistant_message = AssistantPromptMessage(content="")
assistant_message.content = "" # FIXME: type check tell mypy that assistant_message.content is str
for unit in agent_scratchpad:
if unit.is_final():
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Final Answer: {unit.agent_response}"
else:
assert isinstance(assistant_message.content, str)
assistant_message.content += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_message.content += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_message.content += f"Observation: {unit.observation}\n\n"
assistant_messages = [assistant_message]
# query messages
query_messages = self._organize_user_query(self._query, [])
if assistant_messages:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages(
[system_message, *query_messages, *assistant_messages, UserPromptMessage(content="continue")]
)
messages = [
system_message,
*historic_messages,
*query_messages,
*assistant_messages,
UserPromptMessage(content="continue"),
]
else:
# organize historic prompt messages
historic_messages = self._organize_historic_prompt_messages([system_message, *query_messages])
messages = [system_message, *historic_messages, *query_messages]
# join all messages
return messages

View File

@@ -0,0 +1,87 @@
import json
from core.agent.cot_agent_runner import CotAgentRunner
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
TextPromptMessageContent,
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
class CotCompletionAgentRunner(CotAgentRunner):
def _organize_instruction_prompt(self) -> str:
"""
Organize instruction prompt
"""
if self.app_config.agent is None:
raise ValueError("Agent configuration is not set")
prompt_entity = self.app_config.agent.prompt
if prompt_entity is None:
raise ValueError("prompt entity is not set")
first_prompt = prompt_entity.first_prompt
system_prompt = (
first_prompt.replace("{{instruction}}", self._instruction)
.replace("{{tools}}", json.dumps(jsonable_encoder(self._prompt_messages_tools)))
.replace("{{tool_names}}", ", ".join([tool.name for tool in self._prompt_messages_tools]))
)
return system_prompt
def _organize_historic_prompt(self, current_session_messages: list[PromptMessage] | None = None) -> str:
"""
Organize historic prompt
"""
historic_prompt_messages = self._organize_historic_prompt_messages(current_session_messages)
historic_prompt = ""
for message in historic_prompt_messages:
if isinstance(message, UserPromptMessage):
historic_prompt += f"Question: {message.content}\n\n"
elif isinstance(message, AssistantPromptMessage):
if isinstance(message.content, str):
historic_prompt += message.content + "\n\n"
elif isinstance(message.content, list):
for content in message.content:
if not isinstance(content, TextPromptMessageContent):
continue
historic_prompt += content.data
return historic_prompt
def _organize_prompt_messages(self) -> list[PromptMessage]:
"""
Organize prompt messages
"""
# organize system prompt
system_prompt = self._organize_instruction_prompt()
# organize historic prompt messages
historic_prompt = self._organize_historic_prompt()
# organize current assistant messages
agent_scratchpad = self._agent_scratchpad
assistant_prompt = ""
for unit in agent_scratchpad or []:
if unit.is_final():
assistant_prompt += f"Final Answer: {unit.agent_response}"
else:
assistant_prompt += f"Thought: {unit.thought}\n\n"
if unit.action_str:
assistant_prompt += f"Action: {unit.action_str}\n\n"
if unit.observation:
assistant_prompt += f"Observation: {unit.observation}\n\n"
# query messages
query_prompt = f"Question: {self._query}"
# join all messages
prompt = (
system_prompt.replace("{{historic_messages}}", historic_prompt)
.replace("{{agent_scratchpad}}", assistant_prompt)
.replace("{{query}}", query_prompt)
)
return [UserPromptMessage(content=prompt)]

View File

@@ -1,5 +1,3 @@
import uuid
from collections.abc import Mapping
from enum import StrEnum
from typing import Any, Union
@@ -94,96 +92,3 @@ class AgentInvokeMessage(ToolInvokeMessage):
"""
pass
class ExecutionContext(BaseModel):
"""Execution context containing trace and audit information.
This context carries all the IDs and metadata that are not part of
the core business logic but needed for tracing, auditing, and
correlation purposes.
"""
user_id: str | None = None
app_id: str | None = None
conversation_id: str | None = None
message_id: str | None = None
tenant_id: str | None = None
@classmethod
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
"""Create a minimal context with only essential fields."""
return cls(user_id=user_id)
def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary for passing to legacy code."""
return {
"user_id": self.user_id,
"app_id": self.app_id,
"conversation_id": self.conversation_id,
"message_id": self.message_id,
"tenant_id": self.tenant_id,
}
def with_updates(self, **kwargs) -> "ExecutionContext":
"""Create a new context with updated fields."""
data = self.to_dict()
data.update(kwargs)
return ExecutionContext(
user_id=data.get("user_id"),
app_id=data.get("app_id"),
conversation_id=data.get("conversation_id"),
message_id=data.get("message_id"),
tenant_id=data.get("tenant_id"),
)
class AgentLog(BaseModel):
"""
Agent Log.
"""
class LogType(StrEnum):
"""Type of agent log entry."""
ROUND = "round" # A complete iteration round
THOUGHT = "thought" # LLM thinking/reasoning
TOOL_CALL = "tool_call" # Tool invocation
class LogMetadata(StrEnum):
STARTED_AT = "started_at"
FINISHED_AT = "finished_at"
ELAPSED_TIME = "elapsed_time"
TOTAL_PRICE = "total_price"
TOTAL_TOKENS = "total_tokens"
PROVIDER = "provider"
CURRENCY = "currency"
LLM_USAGE = "llm_usage"
ICON = "icon"
ICON_DARK = "icon_dark"
class LogStatus(StrEnum):
START = "start"
ERROR = "error"
SUCCESS = "success"
id: str = Field(default_factory=lambda: str(uuid.uuid4()), description="The id of the log")
label: str = Field(..., description="The label of the log")
log_type: LogType = Field(..., description="The type of the log")
parent_id: str | None = Field(default=None, description="Leave empty for root log")
error: str | None = Field(default=None, description="The error message")
status: LogStatus = Field(..., description="The status of the log")
data: Mapping[str, Any] = Field(..., description="Detailed log data")
metadata: Mapping[LogMetadata, Any] = Field(default={}, description="The metadata of the log")
class AgentResult(BaseModel):
"""
Agent execution result.
"""
text: str = Field(default="", description="The generated text")
files: list[Any] = Field(default_factory=list, description="Files produced during execution")
usage: Any | None = Field(default=None, description="LLM usage statistics")
finish_reason: str | None = Field(default=None, description="Reason for completion")

View File

@@ -0,0 +1,468 @@
import json
import logging
from collections.abc import Generator
from copy import deepcopy
from typing import Any, Union
from core.agent.base_agent_runner import BaseAgentRunner
from core.app.apps.base_app_queue_manager import PublishFrom
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
from core.file import file_manager
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageContentType,
SystemPromptMessage,
TextPromptMessageContent,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
from core.tools.entities.tool_entities import ToolInvokeMeta
from core.tools.tool_engine import ToolEngine
from core.workflow.nodes.agent.exc import AgentMaxIterationError
from models.model import Message
logger = logging.getLogger(__name__)
class FunctionCallAgentRunner(BaseAgentRunner):
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
"""
Run FunctionCall agent application
"""
self.query = query
app_generate_entity = self.application_generate_entity
app_config = self.app_config
assert app_config is not None, "app_config is required"
assert app_config.agent is not None, "app_config.agent is required"
# convert tools into ModelRuntime Tool format
tool_instances, prompt_messages_tools = self._init_prompt_tools()
assert app_config.agent
iteration_step = 1
max_iteration_steps = min(app_config.agent.max_iteration, 99) + 1
# continue to run until there is not any tool call
function_call_state = True
llm_usage: dict[str, LLMUsage | None] = {"usage": None}
final_answer = ""
prompt_messages: list = [] # Initialize prompt_messages
# get tracing instance
trace_manager = app_generate_entity.trace_manager
def increase_usage(final_llm_usage_dict: dict[str, LLMUsage | None], usage: LLMUsage):
if not final_llm_usage_dict["usage"]:
final_llm_usage_dict["usage"] = usage
else:
llm_usage = final_llm_usage_dict["usage"]
llm_usage.prompt_tokens += usage.prompt_tokens
llm_usage.completion_tokens += usage.completion_tokens
llm_usage.total_tokens += usage.total_tokens
llm_usage.prompt_price += usage.prompt_price
llm_usage.completion_price += usage.completion_price
llm_usage.total_price += usage.total_price
model_instance = self.model_instance
while function_call_state and iteration_step <= max_iteration_steps:
function_call_state = False
if iteration_step == max_iteration_steps:
# the last iteration, remove all tools
prompt_messages_tools = []
message_file_ids: list[str] = []
agent_thought_id = self.create_agent_thought(
message_id=message.id, message="", tool_name="", tool_input="", messages_ids=message_file_ids
)
# recalc llm max tokens
prompt_messages = self._organize_prompt_messages()
self.recalc_llm_max_tokens(self.model_config, prompt_messages)
# invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = model_instance.invoke_llm(
prompt_messages=prompt_messages,
model_parameters=app_generate_entity.model_conf.parameters,
tools=prompt_messages_tools,
stop=app_generate_entity.model_conf.stop,
stream=self.stream_tool_call,
user=self.user_id,
callbacks=[],
)
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
# save full response
response = ""
# save tool call names and inputs
tool_call_names = ""
tool_call_inputs = ""
current_llm_usage = None
if isinstance(chunks, Generator):
is_first_chunk = True
for chunk in chunks:
if is_first_chunk:
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
is_first_chunk = False
# check if there is any tool call
if self.check_tool_calls(chunk):
function_call_state = True
tool_calls.extend(self.extract_tool_calls(chunk) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except TypeError:
# fallback: force ASCII to handle non-serializable objects
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if chunk.delta.message and chunk.delta.message.content:
if isinstance(chunk.delta.message.content, list):
for content in chunk.delta.message.content:
response += content.data
else:
response += str(chunk.delta.message.content)
if chunk.delta.usage:
increase_usage(llm_usage, chunk.delta.usage)
current_llm_usage = chunk.delta.usage
yield chunk
else:
result = chunks
# check if there is any tool call
if self.check_blocking_tool_calls(result):
function_call_state = True
tool_calls.extend(self.extract_blocking_tool_calls(result) or [])
tool_call_names = ";".join([tool_call[1] for tool_call in tool_calls])
try:
tool_call_inputs = json.dumps(
{tool_call[1]: tool_call[2] for tool_call in tool_calls}, ensure_ascii=False
)
except TypeError:
# fallback: force ASCII to handle non-serializable objects
tool_call_inputs = json.dumps({tool_call[1]: tool_call[2] for tool_call in tool_calls})
if result.usage:
increase_usage(llm_usage, result.usage)
current_llm_usage = result.usage
if result.message and result.message.content:
if isinstance(result.message.content, list):
for content in result.message.content:
response += content.data
else:
response += str(result.message.content)
if not result.message.content:
result.message.content = ""
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
yield LLMResultChunk(
model=model_instance.model,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
index=0,
message=result.message,
usage=result.usage,
),
)
assistant_message = AssistantPromptMessage(content=response, tool_calls=[])
if tool_calls:
assistant_message.tool_calls = [
AssistantPromptMessage.ToolCall(
id=tool_call[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_call[1], arguments=json.dumps(tool_call[2], ensure_ascii=False)
),
)
for tool_call in tool_calls
]
self._current_thoughts.append(assistant_message)
# save thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name=tool_call_names,
tool_input=tool_call_inputs,
thought=response,
tool_invoke_meta=None,
observation=None,
answer=response,
messages_ids=[],
llm_usage=current_llm_usage,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
final_answer += response + "\n"
# Check if max iteration is reached and model still wants to call tools
if iteration_step == max_iteration_steps and tool_calls:
raise AgentMaxIterationError(app_config.agent.max_iteration)
# call tools
tool_responses = []
for tool_call_id, tool_call_name, tool_call_args in tool_calls:
tool_instance = tool_instances.get(tool_call_name)
if not tool_instance:
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}",
"meta": ToolInvokeMeta.error_instance(f"there is not a tool named {tool_call_name}").to_dict(),
}
else:
# invoke tool
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
trace_manager=trace_manager,
app_id=self.application_generate_entity.app_config.app_id,
message_id=self.message.id,
conversation_id=self.conversation.id,
)
# publish files
for message_file_id in message_files:
# publish message file
self.queue_manager.publish(
QueueMessageFileEvent(message_file_id=message_file_id), PublishFrom.APPLICATION_MANAGER
)
# add message file ids
message_file_ids.append(message_file_id)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": tool_invoke_response,
"meta": tool_invoke_meta.to_dict(),
}
tool_responses.append(tool_response)
if tool_response["tool_response"] is not None:
self._current_thoughts.append(
ToolPromptMessage(
content=str(tool_response["tool_response"]),
tool_call_id=tool_call_id,
name=tool_call_name,
)
)
if len(tool_responses) > 0:
# save agent thought
self.save_agent_thought(
agent_thought_id=agent_thought_id,
tool_name="",
tool_input="",
thought="",
tool_invoke_meta={
tool_response["tool_call_name"]: tool_response["meta"] for tool_response in tool_responses
},
observation={
tool_response["tool_call_name"]: tool_response["tool_response"]
for tool_response in tool_responses
},
answer="",
messages_ids=message_file_ids,
)
self.queue_manager.publish(
QueueAgentThoughtEvent(agent_thought_id=agent_thought_id), PublishFrom.APPLICATION_MANAGER
)
# update prompt tool
for prompt_tool in prompt_messages_tools:
self.update_prompt_message_tool(tool_instances[prompt_tool.name], prompt_tool)
iteration_step += 1
# publish end event
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
system_fingerprint="",
)
),
PublishFrom.APPLICATION_MANAGER,
)
def check_tool_calls(self, llm_result_chunk: LLMResultChunk) -> bool:
"""
Check if there is any tool call in llm result chunk
"""
if llm_result_chunk.delta.message.tool_calls:
return True
return False
def check_blocking_tool_calls(self, llm_result: LLMResult) -> bool:
"""
Check if there is any blocking tool call in llm result
"""
if llm_result.message.tool_calls:
return True
return False
def extract_tool_calls(self, llm_result_chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract tool calls from llm result chunk
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result_chunk.delta.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def extract_blocking_tool_calls(self, llm_result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""
Extract blocking tool calls from llm result
Returns:
List[Tuple[str, str, Dict[str, Any]]]: [(tool_call_id, tool_call_name, tool_call_args)]
"""
tool_calls = []
for prompt_message in llm_result.message.tool_calls:
args = {}
if prompt_message.function.arguments != "":
args = json.loads(prompt_message.function.arguments)
tool_calls.append(
(
prompt_message.id,
prompt_message.function.name,
args,
)
)
return tool_calls
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Initialize system message
"""
if not prompt_messages and prompt_template:
return [
SystemPromptMessage(content=prompt_template),
]
if prompt_messages and not isinstance(prompt_messages[0], SystemPromptMessage) and prompt_template:
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
return prompt_messages or []
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
Organize user query
"""
if self.files:
# get image detail config
image_detail_config = (
self.application_generate_entity.file_upload_config.image_config.detail
if (
self.application_generate_entity.file_upload_config
and self.application_generate_entity.file_upload_config.image_config
)
else None
)
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
for file in self.files:
prompt_message_contents.append(
file_manager.to_prompt_message_content(
file,
image_detail_config=image_detail_config,
)
)
prompt_message_contents.append(TextPromptMessageContent(data=query))
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
else:
prompt_messages.append(UserPromptMessage(content=query))
return prompt_messages
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
"""
As for now, gpt supports both fc and vision at the first iteration.
We need to remove the image messages from the prompt messages at the first iteration.
"""
prompt_messages = deepcopy(prompt_messages)
for prompt_message in prompt_messages:
if isinstance(prompt_message, UserPromptMessage):
if isinstance(prompt_message.content, list):
prompt_message.content = "\n".join(
[
content.data
if content.type == PromptMessageContentType.TEXT
else "[image]"
if content.type == PromptMessageContentType.IMAGE
else "[file]"
for content in prompt_message.content
]
)
return prompt_messages
def _organize_prompt_messages(self):
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
query_prompt_messages = self._organize_user_query(self.query or "", [])
self.history_prompt_messages = AgentHistoryPromptTransform(
model_config=self.model_config,
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
history_messages=self.history_prompt_messages,
memory=self.memory,
).get_prompt()
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
if len(self._current_thoughts) != 0:
# clear messages after the first iteration
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
return prompt_messages

View File

@@ -1,55 +0,0 @@
# Agent Patterns
A unified agent pattern module that powers both Agent V2 workflow nodes and agent applications. Strategies share a common execution contract while adapting to model capabilities and tool availability.
## Overview
The module applies a strategy pattern around LLM/tool orchestration. `StrategyFactory` auto-selects the best implementation based on model features or an explicit agent strategy, and each strategy streams logs and usage consistently.
## Key Features
- **Dual strategies**
- `FunctionCallStrategy`: uses native LLM function/tool calling when the model exposes `TOOL_CALL`, `MULTI_TOOL_CALL`, or `STREAM_TOOL_CALL`.
- `ReActStrategy`: ReAct (reasoning + acting) flow driven by `CotAgentOutputParser`, used when function calling is unavailable or explicitly requested.
- **Explicit or auto selection**
- `StrategyFactory.create_strategy` prefers an explicit `AgentEntity.Strategy` (FUNCTION_CALLING or CHAIN_OF_THOUGHT).
- Otherwise it falls back to function calling when tool-call features exist, or ReAct when they do not.
- **Unified execution contract**
- `AgentPattern.run` yields streaming `AgentLog` entries and `LLMResultChunk` data, returning an `AgentResult` with text, files, usage, and `finish_reason`.
- Iterations are configurable and hard-capped at 99 rounds; the last round forces a final answer by withholding tools.
- **Tool handling and hooks**
- Tools convert to `PromptMessageTool` objects before invocation.
- Optional `tool_invoke_hook` lets callers override tool execution (e.g., agent apps) while workflow runs use `ToolEngine.generic_invoke`.
- Tool outputs support text, links, JSON, variables, blobs, retriever resources, and file attachments; `target=="self"` files are reloaded into model context, others are returned as outputs.
- **File-aware arguments**
- Tool args accept `[File: <id>]` or `[Files: <id1, id2>]` placeholders that resolve to `File` objects before invocation, enabling models to reference uploaded files safely.
- **ReAct prompt shaping**
- System prompts replace `{{instruction}}`, `{{tools}}`, and `{{tool_names}}` placeholders.
- Adds `Observation` to stop sequences and appends scratchpad text so the model sees prior Thought/Action/Observation history.
- **Observability and accounting**
- Standardized `AgentLog` entries for rounds, model thoughts, and tool calls, including usage aggregation (`LLMUsage`) across streaming and non-streaming paths.
## Architecture
```
agent/patterns/
├── base.py # Shared utilities: logging, usage, tool invocation, file handling
├── function_call.py # Native function-calling loop with tool execution
├── react.py # ReAct loop with CoT parsing and scratchpad wiring
└── strategy_factory.py # Strategy selection by model features or explicit override
```
## Usage
- For auto-selection:
- Call `StrategyFactory.create_strategy(model_features, model_instance, context, tools, files, ...)` and run the returned strategy with prompt messages and model params.
- For explicit behavior:
- Pass `agent_strategy=AgentEntity.Strategy.FUNCTION_CALLING` to force native calls (falls back to ReAct if unsupported), or `CHAIN_OF_THOUGHT` to force ReAct.
- Both strategies stream chunks and logs; collect the generator output until it returns an `AgentResult`.
## Integration Points
- **Model runtime**: delegates to `ModelInstance.invoke_llm` for both streaming and non-streaming calls.
- **Tool system**: defaults to `ToolEngine.generic_invoke`, with `tool_invoke_hook` for custom callers.
- **Files**: flows through `File` objects for tool inputs/outputs and model-context attachments.
- **Execution context**: `ExecutionContext` fields (user/app/conversation/message) propagate to tool invocations and logging.

View File

@@ -1,19 +0,0 @@
"""Agent patterns module.
This module provides different strategies for agent execution:
- FunctionCallStrategy: Uses native function/tool calling
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
- StrategyFactory: Factory for creating strategies based on model features
"""
from .base import AgentPattern
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
from .strategy_factory import StrategyFactory
__all__ = [
"AgentPattern",
"FunctionCallStrategy",
"ReActStrategy",
"StrategyFactory",
]

View File

@@ -1,474 +0,0 @@
"""Base class for agent strategies."""
from __future__ import annotations
import json
import re
import time
from abc import ABC, abstractmethod
from collections.abc import Callable, Generator
from typing import TYPE_CHECKING, Any
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
PromptMessageTool,
)
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import TextPromptMessageContent
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
# Type alias for tool invoke hook
# Returns: (response_content, message_file_ids, tool_invoke_meta)
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
class AgentPattern(ABC):
"""Base class for agent execution strategies."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
):
"""Initialize the agent strategy."""
self.model_instance = model_instance
self.tools = tools
self.context = context
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
self.workflow_call_depth = workflow_call_depth
self.files: list[File] = files
self.tool_invoke_hook = tool_invoke_hook
@abstractmethod
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the agent strategy."""
pass
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
"""Accumulate LLM usage statistics."""
if not total_usage.get("usage"):
# Create a copy to avoid modifying the original
total_usage["usage"] = LLMUsage(
prompt_tokens=delta_usage.prompt_tokens,
prompt_unit_price=delta_usage.prompt_unit_price,
prompt_price_unit=delta_usage.prompt_price_unit,
prompt_price=delta_usage.prompt_price,
completion_tokens=delta_usage.completion_tokens,
completion_unit_price=delta_usage.completion_unit_price,
completion_price_unit=delta_usage.completion_price_unit,
completion_price=delta_usage.completion_price,
total_tokens=delta_usage.total_tokens,
total_price=delta_usage.total_price,
currency=delta_usage.currency,
latency=delta_usage.latency,
)
else:
current: LLMUsage = total_usage["usage"]
current.prompt_tokens += delta_usage.prompt_tokens
current.completion_tokens += delta_usage.completion_tokens
current.total_tokens += delta_usage.total_tokens
current.prompt_price += delta_usage.prompt_price
current.completion_price += delta_usage.completion_price
current.total_price += delta_usage.total_price
def _extract_content(self, content: Any) -> str:
"""Extract text content from message content."""
if isinstance(content, list):
# Content items are PromptMessageContentUnionTypes
text_parts = []
for c in content:
# Check if it's a TextPromptMessageContent (which has data attribute)
if isinstance(c, TextPromptMessageContent):
text_parts.append(c.data)
return "".join(text_parts)
return str(content)
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
"""Check if chunk contains tool calls."""
# LLMResultChunk always has delta attribute
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
def _has_tool_calls_result(self, result: LLMResult) -> bool:
"""Check if result contains tool calls (non-streaming)."""
# LLMResult always has message attribute
return bool(result.message and result.message.tool_calls)
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from streaming chunk."""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
if chunk.delta.message and chunk.delta.message.tool_calls:
for tool_call in chunk.delta.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
"""Extract tool calls from non-streaming result."""
tool_calls = []
if result.message and result.message.tool_calls:
for tool_call in result.message.tool_calls:
if tool_call.function:
try:
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
except json.JSONDecodeError:
args = {}
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
return tool_calls
def _extract_text_from_message(self, message: PromptMessage) -> str:
"""Extract text content from a prompt message."""
# PromptMessage always has content attribute
content = message.content
if isinstance(content, str):
return content
elif isinstance(content, list):
# Extract text from content list
text_parts = []
for item in content:
if isinstance(item, TextPromptMessageContent):
text_parts.append(item.data)
return " ".join(text_parts)
return ""
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
"""Get metadata for a tool including provider and icon info."""
from core.tools.tool_manager import ToolManager
metadata: dict[AgentLog.LogMetadata, Any] = {}
if tool_instance.entity and tool_instance.entity.identity:
identity = tool_instance.entity.identity
if identity.provider:
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
# Get icon using ToolManager for proper URL generation
tenant_id = self.context.tenant_id
if tenant_id and identity.provider:
try:
provider_type = tool_instance.tool_provider_type()
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
if isinstance(icon, str):
metadata[AgentLog.LogMetadata.ICON] = icon
elif isinstance(icon, dict):
# Handle icon dict with background/content or light/dark variants
metadata[AgentLog.LogMetadata.ICON] = icon
except Exception:
# Fallback to identity.icon if ToolManager fails
if identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
elif identity.icon:
metadata[AgentLog.LogMetadata.ICON] = identity.icon
return metadata
def _create_log(
self,
label: str,
log_type: AgentLog.LogType,
status: AgentLog.LogStatus,
data: dict[str, Any] | None = None,
parent_id: str | None = None,
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
) -> AgentLog:
"""Create a new AgentLog with standard metadata."""
metadata: dict[AgentLog.LogMetadata, Any] = {
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
}
if extra_metadata:
metadata.update(extra_metadata)
return AgentLog(
label=label,
log_type=log_type,
status=status,
data=data or {},
parent_id=parent_id,
metadata=metadata,
)
def _finish_log(
self,
log: AgentLog,
data: dict[str, Any] | None = None,
usage: LLMUsage | None = None,
) -> AgentLog:
"""Finish an AgentLog by updating its status and metadata."""
log.status = AgentLog.LogStatus.SUCCESS
if data is not None:
log.data = data
# Calculate elapsed time
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
finished_at = time.perf_counter()
# Update metadata
log.metadata = {
**log.metadata,
AgentLog.LogMetadata.FINISHED_AT: finished_at,
# Calculate elapsed time in seconds
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
}
# Add usage information if provided
if usage:
log.metadata.update(
{
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
AgentLog.LogMetadata.CURRENCY: usage.currency,
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
AgentLog.LogMetadata.LLM_USAGE: usage,
}
)
return log
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
"""
Replace file references in tool arguments with actual File objects.
Args:
tool_args: Dictionary of tool arguments
Returns:
Updated tool arguments with file references replaced
"""
# Process each argument in the dictionary
processed_args: dict[str, Any] = {}
for key, value in tool_args.items():
processed_args[key] = self._process_file_reference(value)
return processed_args
def _process_file_reference(self, data: Any) -> Any:
"""
Recursively process data to replace file references.
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
Args:
data: The data to process (can be dict, list, str, or other types)
Returns:
Processed data with file references replaced
"""
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
if isinstance(data, dict):
# Process dictionary recursively
return {key: self._process_file_reference(value) for key, value in data.items()}
elif isinstance(data, list):
# Process list recursively
return [self._process_file_reference(item) for item in data]
elif isinstance(data, str):
# Check for single file pattern [File: file_id]
single_match = single_file_pattern.match(data.strip())
if single_match:
file_id = single_match.group(1).strip()
# Find the file in self.files
for file in self.files:
if file.id and str(file.id) == file_id:
return file
# If file not found, return original value
return data
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
multiple_match = multiple_files_pattern.match(data.strip())
if multiple_match:
file_ids_str = multiple_match.group(1).strip()
# Split by comma and strip whitespace
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
# Find all matching files
matched_files: list[File] = []
for file_id in file_ids:
for file in self.files:
if file.id and str(file.id) == file_id:
matched_files.append(file)
break
# Return list of files if any were found, otherwise return original
return matched_files or data
return data
else:
# Return other types as-is
return data
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
"""Create a text chunk for streaming."""
return LLMResultChunk(
model=self.model_instance.model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=AssistantPromptMessage(content=text),
usage=None,
),
system_fingerprint="",
)
def _invoke_tool(
self,
tool_instance: Tool,
tool_args: dict[str, Any],
tool_name: str,
) -> tuple[str, list[File], ToolInvokeMeta | None]:
"""
Invoke a tool and collect its response.
Args:
tool_instance: The tool instance to invoke
tool_args: Tool arguments
tool_name: Name of the tool
Returns:
Tuple of (response_content, tool_files, tool_invoke_meta)
"""
# Process tool_args to replace file references with actual File objects
tool_args = self._replace_file_references(tool_args)
# If a tool invoke hook is set, use it instead of generic_invoke
if self.tool_invoke_hook:
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
# The caller (AgentAppRunner) handles file publishing
return response_content, [], tool_invoke_meta
# Default: use generic_invoke for workflow scenarios
# Import here to avoid circular import
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
tool_response = ToolEngine().generic_invoke(
tool=tool_instance,
tool_parameters=tool_args,
user_id=self.context.user_id or "",
workflow_tool_callback=DifyWorkflowCallbackHandler(),
workflow_call_depth=self.workflow_call_depth,
app_id=self.context.app_id,
conversation_id=self.context.conversation_id,
message_id=self.context.message_id,
)
# Collect response and files
response_content = ""
tool_files: list[File] = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
response_content += response.message.text
elif response.type == ToolInvokeMessage.MessageType.LINK:
# Handle link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Link: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
# Handle image URL messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
# Handle image link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
response_content += f"[Image: {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
# Handle binary file link messages
if isinstance(response.message, ToolInvokeMessage.TextMessage):
filename = response.meta.get("filename", "file") if response.meta else "file"
response_content += f"[File: {filename} - {response.message.text}]"
elif response.type == ToolInvokeMessage.MessageType.JSON:
# Handle JSON messages
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
elif response.type == ToolInvokeMessage.MessageType.BLOB:
# Handle blob messages - convert to text representation
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
mime_type = (
response.meta.get("mime_type", "application/octet-stream")
if response.meta
else "application/octet-stream"
)
size = len(response.message.blob)
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
# Handle variable messages
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
var_name = response.message.variable_name
var_value = response.message.variable_value
if isinstance(var_value, str):
response_content += var_value
else:
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
# Handle blob chunk messages - these are parts of a larger blob
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
# Handle retriever resources messages
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
response_content += response.message.context
elif response.type == ToolInvokeMessage.MessageType.FILE:
# Extract file from meta
if response.meta and "file" in response.meta:
file = response.meta["file"]
if isinstance(file, File):
# Check if file is for model or tool output
if response.meta.get("target") == "self":
# File is for model - add to files for next prompt
self.files.append(file)
response_content += f"File '{file.filename}' has been loaded into your context."
else:
# File is tool output
tool_files.append(file)
return response_content, tool_files, None
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
"""Find a tool instance by its name."""
for tool in self.tools:
if tool.entity.identity.name == tool_name:
return tool
return None
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
"""Convert tools to prompt message format."""
prompt_tools: list[PromptMessageTool] = []
for tool in self.tools:
prompt_tools.append(tool.to_prompt_message_tool())
return prompt_tools
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
"""Initialize usage tracking with empty usage if not set."""
if "usage" not in llm_usage or llm_usage["usage"] is None:
llm_usage["usage"] = LLMUsage.empty_usage()

View File

@@ -1,299 +0,0 @@
"""Function Call strategy implementation."""
import json
from collections.abc import Generator
from typing import Any, Union
from core.agent.entities import AgentLog, AgentResult
from core.file import File
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
LLMUsage,
PromptMessage,
PromptMessageTool,
ToolPromptMessage,
)
from core.tools.entities.tool_entities import ToolInvokeMeta
from .base import AgentPattern
class FunctionCallStrategy(AgentPattern):
"""Function Call strategy using model's native tool calling capability."""
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the function call agent strategy."""
# Convert tools to prompt format
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
# Initialize tracking
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
function_call_state: bool = True
total_usage: dict[str, LLMUsage | None] = {"usage": None}
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
final_text: str = ""
finish_reason: str | None = None
output_files: list[File] = [] # Track files produced by tools
while function_call_state and iteration_step <= max_iterations:
function_call_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# On last iteration, remove tools to force final answer
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, LLMUsage | None] = {"usage": None}
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages,
model_parameters=model_parameters,
tools=current_tools,
stop=stop,
stream=stream,
user=self.context.user_id,
callbacks=[],
)
# Process response
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log
)
messages.append(self._create_assistant_message(response_content, tool_calls))
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update final text if no tool calls (this is likely the final answer)
if not tool_calls:
final_text = response_content
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Process tool calls
tool_outputs: dict[str, str] = {}
if tool_calls:
function_call_state = True
# Execute tools
for tool_call_id, tool_name, tool_args in tool_calls:
tool_response, tool_files, _ = yield from self._handle_tool_call(
tool_name, tool_args, tool_call_id, messages, round_log
)
tool_outputs[tool_name] = tool_response
# Track files produced by tools
output_files.extend(tool_files)
yield self._finish_log(
round_log,
data={
"llm_result": response_content,
"tool_calls": [
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
]
if tool_calls
else [],
"final_answer": final_text if not function_call_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text,
files=output_files,
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
finish_reason=finish_reason,
)
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, LLMUsage | None],
start_log: AgentLog,
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
]:
"""Handle LLM response chunks and extract tool calls and content.
Returns a tuple of (tool_calls, response_content, finish_reason).
"""
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
response_content: str = ""
finish_reason: str | None = None
if isinstance(chunks, Generator):
# Streaming response
for chunk in chunks:
# Extract tool calls
if self._has_tool_calls(chunk):
tool_calls.extend(self._extract_tool_calls(chunk))
# Extract content
if chunk.delta.message and chunk.delta.message.content:
response_content += self._extract_content(chunk.delta.message.content)
# Track usage
if chunk.delta.usage:
self._accumulate_usage(llm_usage, chunk.delta.usage)
# Capture finish reason
if chunk.delta.finish_reason:
finish_reason = chunk.delta.finish_reason
yield chunk
else:
# Non-streaming response
result: LLMResult = chunks
if self._has_tool_calls_result(result):
tool_calls.extend(self._extract_tool_calls_result(result))
if result.message and result.message.content:
response_content += self._extract_content(result.message.content)
if result.usage:
self._accumulate_usage(llm_usage, result.usage)
# Convert to streaming format
yield LLMResultChunk(
model=result.model,
prompt_messages=result.prompt_messages,
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
)
yield self._finish_log(
start_log,
data={
"result": response_content,
},
usage=llm_usage.get("usage"),
)
return tool_calls, response_content, finish_reason
def _create_assistant_message(
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
) -> AssistantPromptMessage:
"""Create assistant message with tool calls."""
if tool_calls is None:
return AssistantPromptMessage(content=content)
return AssistantPromptMessage(
content=content or "",
tool_calls=[
AssistantPromptMessage.ToolCall(
id=tc[0],
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
)
for tc in tool_calls
],
)
def _handle_tool_call(
self,
tool_name: str,
tool_args: dict[str, Any],
tool_call_id: str,
messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None]]:
"""Handle a single tool call and return response with files and meta."""
# Find tool
tool_instance = self._find_tool_by_name(tool_name)
if not tool_instance:
raise ValueError(f"Tool {tool_name} not found")
# Get tool metadata (provider, icon, etc.)
tool_metadata = self._get_tool_metadata(tool_instance)
# Create tool call log
tool_call_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_call_log
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
yield self._finish_log(
tool_call_log,
data={
**tool_call_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
final_content = response_content or "Tool executed successfully"
# Add tool response to messages
messages.append(
ToolPromptMessage(
content=final_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return response_content, tool_files, tool_invoke_meta
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_call_log.status = AgentLog.LogStatus.ERROR
tool_call_log.error = error_message
tool_call_log.data = {
**tool_call_log.data,
"error": error_message,
}
yield tool_call_log
# Add error message to conversation
error_content = f"Tool execution failed: {error_message}"
messages.append(
ToolPromptMessage(
content=error_content,
tool_call_id=tool_call_id,
name=tool_name,
)
)
return error_content, [], None

View File

@@ -1,418 +0,0 @@
"""ReAct strategy implementation."""
from __future__ import annotations
import json
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Union
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
from core.file import File
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
LLMResult,
LLMResultChunk,
LLMResultChunkDelta,
PromptMessage,
SystemPromptMessage,
)
from .base import AgentPattern, ToolInvokeHook
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class ReActStrategy(AgentPattern):
"""ReAct strategy using reasoning and acting approach."""
def __init__(
self,
model_instance: ModelInstance,
tools: list[Tool],
context: ExecutionContext,
max_iterations: int = 10,
workflow_call_depth: int = 0,
files: list[File] = [],
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
):
"""Initialize the ReAct strategy with instruction support."""
super().__init__(
model_instance=model_instance,
tools=tools,
context=context,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
files=files,
tool_invoke_hook=tool_invoke_hook,
)
self.instruction = instruction
def run(
self,
prompt_messages: list[PromptMessage],
model_parameters: dict[str, Any],
stop: list[str] = [],
stream: bool = True,
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
"""Execute the ReAct agent strategy."""
# Initialize tracking
agent_scratchpad: list[AgentScratchpadUnit] = []
iteration_step: int = 1
max_iterations: int = self.max_iterations + 1
react_state: bool = True
total_usage: dict[str, Any] = {"usage": None}
output_files: list[File] = [] # Track files produced by tools
final_text: str = ""
finish_reason: str | None = None
# Add "Observation" to stop sequences
if "Observation" not in stop:
stop = stop.copy()
stop.append("Observation")
while react_state and iteration_step <= max_iterations:
react_state = False
round_log = self._create_log(
label=f"ROUND {iteration_step}",
log_type=AgentLog.LogType.ROUND,
status=AgentLog.LogStatus.START,
data={},
)
yield round_log
# Build prompt with/without tools based on iteration
include_tools = iteration_step < max_iterations
current_messages = self._build_prompt_with_react_format(
prompt_messages, agent_scratchpad, include_tools, self.instruction
)
model_log = self._create_log(
label=f"{self.model_instance.model} Thought",
log_type=AgentLog.LogType.THOUGHT,
status=AgentLog.LogStatus.START,
data={},
parent_id=round_log.id,
extra_metadata={
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
},
)
yield model_log
# Track usage for this round only
round_usage: dict[str, Any] = {"usage": None}
# Use current messages directly (files are handled by base class if needed)
messages_to_use = current_messages
# Invoke model
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
prompt_messages=messages_to_use,
model_parameters=model_parameters,
stop=stop,
stream=stream,
user=self.context.user_id or "",
callbacks=[],
)
# Process response
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
chunks, round_usage, model_log, current_messages
)
agent_scratchpad.append(scratchpad)
# Accumulate to total usage
round_usage_value = round_usage.get("usage")
if round_usage_value:
self._accumulate_usage(total_usage, round_usage_value)
# Update finish reason
if chunk_finish_reason:
finish_reason = chunk_finish_reason
# Check if we have an action to execute
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
react_state = True
# Execute tool
observation, tool_files = yield from self._handle_tool_call(
scratchpad.action, current_messages, round_log
)
scratchpad.observation = observation
# Track files produced by tools
output_files.extend(tool_files)
# Add observation to scratchpad for display
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
else:
# Extract final answer
if scratchpad.action and scratchpad.action.action_input:
final_answer = scratchpad.action.action_input
if isinstance(final_answer, dict):
final_answer = json.dumps(final_answer, ensure_ascii=False)
final_text = str(final_answer)
elif scratchpad.thought:
# If no action but we have thought, use thought as final answer
final_text = scratchpad.thought
yield self._finish_log(
round_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
"observation": scratchpad.observation or None,
"final_answer": final_text if not react_state else None,
},
usage=round_usage.get("usage"),
)
iteration_step += 1
# Return final result
from core.agent.entities import AgentResult
return AgentResult(
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
)
def _build_prompt_with_react_format(
self,
original_messages: list[PromptMessage],
agent_scratchpad: list[AgentScratchpadUnit],
include_tools: bool = True,
instruction: str = "",
) -> list[PromptMessage]:
"""Build prompt messages with ReAct format."""
# Copy messages to avoid modifying original
messages = list(original_messages)
# Find and update the system prompt that should already exist
system_prompt_found = False
for i, msg in enumerate(messages):
if isinstance(msg, SystemPromptMessage):
system_prompt_found = True
# The system prompt from frontend already has the template, just replace placeholders
# Format tools
tools_str = ""
tool_names = []
if include_tools and self.tools:
# Convert tools to prompt message tools format
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
tool_names = [tool.name for tool in prompt_tools]
# Format tools as JSON for comprehensive information
from core.model_runtime.utils.encoders import jsonable_encoder
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
else:
tools_str = "No tools available"
tool_names_str = ""
# Replace placeholders in the existing system prompt
updated_content = msg.content
assert isinstance(updated_content, str)
updated_content = updated_content.replace("{{instruction}}", instruction)
updated_content = updated_content.replace("{{tools}}", tools_str)
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
# Create new SystemPromptMessage with updated content
messages[i] = SystemPromptMessage(content=updated_content)
break
# If no system prompt found, that's unexpected but add scratchpad anyway
if not system_prompt_found:
# This shouldn't happen if frontend is working correctly
pass
# Format agent scratchpad
scratchpad_str = ""
if agent_scratchpad:
scratchpad_parts: list[str] = []
for unit in agent_scratchpad:
if unit.thought:
scratchpad_parts.append(f"Thought: {unit.thought}")
if unit.action_str:
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
if unit.observation:
scratchpad_parts.append(f"Observation: {unit.observation}")
scratchpad_str = "\n".join(scratchpad_parts)
# If there's a scratchpad, append it to the last message
if scratchpad_str:
messages.append(AssistantPromptMessage(content=scratchpad_str))
return messages
def _handle_chunks(
self,
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
llm_usage: dict[str, Any],
model_log: AgentLog,
current_messages: list[PromptMessage],
) -> Generator[
LLMResultChunk | AgentLog,
None,
tuple[AgentScratchpadUnit, str | None],
]:
"""Handle LLM response chunks and extract action/thought.
Returns a tuple of (scratchpad_unit, finish_reason).
"""
usage_dict: dict[str, Any] = {}
# Convert non-streaming to streaming format if needed
if isinstance(chunks, LLMResult):
# Create a generator from the LLMResult
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
model=chunks.model,
prompt_messages=chunks.prompt_messages,
delta=LLMResultChunkDelta(
index=0,
message=chunks.message,
usage=chunks.usage,
finish_reason=None, # LLMResult doesn't have finish_reason, only streaming chunks do
),
system_fingerprint=chunks.system_fingerprint or "",
)
streaming_chunks = result_to_chunks()
else:
streaming_chunks = chunks
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
# Initialize scratchpad unit
scratchpad = AgentScratchpadUnit(
agent_response="",
thought="",
action_str="",
observation="",
action=None,
)
finish_reason: str | None = None
# Process chunks
for chunk in react_chunks:
if isinstance(chunk, AgentScratchpadUnit.Action):
# Action detected
action_str = json.dumps(chunk.model_dump())
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
scratchpad.action_str = action_str
scratchpad.action = chunk
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
else:
# Text chunk
chunk_text = str(chunk)
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
scratchpad.thought = (scratchpad.thought or "") + chunk_text
yield self._create_text_chunk(chunk_text, current_messages)
# Update usage
if usage_dict.get("usage"):
if llm_usage.get("usage"):
self._accumulate_usage(llm_usage, usage_dict["usage"])
else:
llm_usage["usage"] = usage_dict["usage"]
# Clean up thought
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
# Finish model log
yield self._finish_log(
model_log,
data={
"thought": scratchpad.thought,
"action": scratchpad.action_str if scratchpad.action else None,
},
usage=llm_usage.get("usage"),
)
return scratchpad, finish_reason
def _handle_tool_call(
self,
action: AgentScratchpadUnit.Action,
prompt_messages: list[PromptMessage],
round_log: AgentLog,
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
"""Handle tool call and return observation with files."""
tool_name = action.action_name
tool_args: dict[str, Any] | str = action.action_input
# Find tool instance first to get metadata
tool_instance = self._find_tool_by_name(tool_name)
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
# Start tool log with tool metadata
tool_log = self._create_log(
label=f"CALL {tool_name}",
log_type=AgentLog.LogType.TOOL_CALL,
status=AgentLog.LogStatus.START,
data={
"tool_name": tool_name,
"tool_args": tool_args,
},
parent_id=round_log.id,
extra_metadata=tool_metadata,
)
yield tool_log
if not tool_instance:
# Finish tool log with error
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"error": f"Tool {tool_name} not found",
},
)
return f"Tool {tool_name} not found", []
# Ensure tool_args is a dict
tool_args_dict: dict[str, Any]
if isinstance(tool_args, str):
try:
tool_args_dict = json.loads(tool_args)
except json.JSONDecodeError:
tool_args_dict = {"input": tool_args}
elif not isinstance(tool_args, dict):
tool_args_dict = {"input": str(tool_args)}
else:
tool_args_dict = tool_args
# Invoke tool using base class method with error handling
try:
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
# Finish tool log
yield self._finish_log(
tool_log,
data={
**tool_log.data,
"output": response_content,
"files": len(tool_files),
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
},
)
return response_content or "Tool executed successfully", tool_files
except Exception as e:
# Tool invocation failed, yield error log
error_message = str(e)
tool_log.status = AgentLog.LogStatus.ERROR
tool_log.error = error_message
tool_log.data = {
**tool_log.data,
"error": error_message,
}
yield tool_log
return f"Tool execution failed: {error_message}", []

View File

@@ -1,107 +0,0 @@
"""Strategy factory for creating agent strategies."""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.agent.entities import AgentEntity, ExecutionContext
from core.file.models import File
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from .base import AgentPattern, ToolInvokeHook
from .function_call import FunctionCallStrategy
from .react import ReActStrategy
if TYPE_CHECKING:
from core.tools.__base.tool import Tool
class StrategyFactory:
"""Factory for creating agent strategies based on model features."""
# Tool calling related features
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
@staticmethod
def create_strategy(
model_features: list[ModelFeature],
model_instance: ModelInstance,
context: ExecutionContext,
tools: list[Tool],
files: list[File],
max_iterations: int = 10,
workflow_call_depth: int = 0,
agent_strategy: AgentEntity.Strategy | None = None,
tool_invoke_hook: ToolInvokeHook | None = None,
instruction: str = "",
) -> AgentPattern:
"""
Create an appropriate strategy based on model features.
Args:
model_features: List of model features/capabilities
model_instance: Model instance to use
context: Execution context containing trace/audit information
tools: Available tools
files: Available files
max_iterations: Maximum iterations for the strategy
workflow_call_depth: Depth of workflow calls
agent_strategy: Optional explicit strategy override
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
instruction: Optional instruction for ReAct strategy
Returns:
AgentStrategy instance
"""
# If explicit strategy is provided and it's Function Calling, try to use it if supported
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
# Fallback to ReAct if FC is requested but not supported
# If explicit strategy is Chain of Thought (ReAct)
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)
# Default auto-selection logic
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
# Model supports native function calling
return FunctionCallStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
)
else:
# Use ReAct strategy for models without function calling
return ReActStrategy(
model_instance=model_instance,
context=context,
tools=tools,
files=files,
max_iterations=max_iterations,
workflow_call_depth=workflow_call_depth,
tool_invoke_hook=tool_invoke_hook,
instruction=instruction,
)

View File

@@ -1,4 +1,3 @@
import json
from collections.abc import Sequence
from enum import StrEnum, auto
from typing import Any, Literal
@@ -121,7 +120,7 @@ class VariableEntity(BaseModel):
allowed_file_types: Sequence[FileType] | None = Field(default_factory=list)
allowed_file_extensions: Sequence[str] | None = Field(default_factory=list)
allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list)
json_schema: str | None = Field(default=None)
json_schema: dict | None = Field(default=None)
@field_validator("description", mode="before")
@classmethod
@@ -135,17 +134,11 @@ class VariableEntity(BaseModel):
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, schema: str | None) -> str | None:
def validate_json_schema(cls, schema: dict | None) -> dict | None:
if schema is None:
return None
try:
json_schema = json.loads(schema)
except json.JSONDecodeError:
raise ValueError(f"invalid json_schema value {schema}")
try:
Draft7Validator.check_schema(json_schema)
Draft7Validator.check_schema(schema)
except SchemaError as e:
raise ValueError(f"Invalid JSON schema: {e.message}")
return schema

View File

@@ -26,7 +26,6 @@ class AdvancedChatAppConfigManager(BaseAppConfigManager):
@classmethod
def get_app_config(cls, app_model: App, workflow: Workflow) -> AdvancedChatAppConfig:
features_dict = workflow.features_dict
app_mode = AppMode.value_of(app_model.mode)
app_config = AdvancedChatAppConfig(
tenant_id=app_model.tenant_id,

View File

@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -39,7 +39,6 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from models import Workflow
from models.enums import UserFrom
from models.model import App, Conversation, Message, MessageAnnotation
from models.workflow import ConversationVariable
from services.conversation_variable_updater import ConversationVariableUpdater
@@ -106,6 +105,11 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
if not app_record:
raise ValueError("App not found")
invoke_from = self.application_generate_entity.invoke_from
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
@@ -145,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -158,6 +162,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
)
db.session.close()
@@ -175,12 +181,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
user_from=user_from,
invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,
@@ -316,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -341,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""

View File

@@ -4,7 +4,6 @@ import re
import time
from collections.abc import Callable, Generator, Mapping
from contextlib import contextmanager
from dataclasses import dataclass, field
from threading import Thread
from typing import Any, Union
@@ -20,7 +19,6 @@ from core.app.entities.app_invoke_entities import (
InvokeFrom,
)
from core.app.entities.queue_entities import (
ChunkType,
MessageQueueMessage,
QueueAdvancedChatMessageEndEvent,
QueueAgentLogEvent,
@@ -72,122 +70,13 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, LLMGenerationDetail, Message, MessageFile
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@dataclass
class StreamEventBuffer:
"""
Buffer for recording stream events in order to reconstruct the generation sequence.
Records the exact order of text chunks, thoughts, and tool calls as they stream.
"""
# Accumulated reasoning content (each thought block is a separate element)
reasoning_content: list[str] = field(default_factory=list)
# Current reasoning buffer (accumulates until we see a different event type)
_current_reasoning: str = ""
# Tool calls with their details
tool_calls: list[dict] = field(default_factory=list)
# Tool call ID to index mapping for updating results
_tool_call_id_map: dict[str, int] = field(default_factory=dict)
# Sequence of events in stream order
sequence: list[dict] = field(default_factory=list)
# Current position in answer text
_content_position: int = 0
# Track last event type to detect transitions
_last_event_type: str | None = None
def _flush_current_reasoning(self) -> None:
"""Flush accumulated reasoning to the list and add to sequence."""
if self._current_reasoning.strip():
self.reasoning_content.append(self._current_reasoning.strip())
self.sequence.append({"type": "reasoning", "index": len(self.reasoning_content) - 1})
self._current_reasoning = ""
def record_text_chunk(self, text: str) -> None:
"""Record a text chunk event."""
if not text:
return
# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()
text_len = len(text)
start_pos = self._content_position
# If last event was also content, extend it; otherwise create new
if self.sequence and self.sequence[-1].get("type") == "content":
self.sequence[-1]["end"] = start_pos + text_len
else:
self.sequence.append({"type": "content", "start": start_pos, "end": start_pos + text_len})
self._content_position += text_len
self._last_event_type = "content"
def record_thought_chunk(self, text: str) -> None:
"""Record a thought/reasoning chunk event."""
if not text:
return
# Accumulate thought content
self._current_reasoning += text
self._last_event_type = "thought"
def record_tool_call(self, tool_call_id: str, tool_name: str, tool_arguments: str) -> None:
"""Record a tool call event."""
if not tool_call_id:
return
# Flush any pending reasoning first
if self._last_event_type == "thought":
self._flush_current_reasoning()
# Check if this tool call already exists (we might get multiple chunks)
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
# Update arguments if provided
if tool_arguments:
self.tool_calls[idx]["arguments"] = tool_arguments
else:
# New tool call
tool_call = {
"id": tool_call_id or "",
"name": tool_name or "",
"arguments": tool_arguments or "",
"result": "",
"elapsed_time": None,
}
self.tool_calls.append(tool_call)
idx = len(self.tool_calls) - 1
self._tool_call_id_map[tool_call_id] = idx
self.sequence.append({"type": "tool_call", "index": idx})
self._last_event_type = "tool_call"
def record_tool_result(self, tool_call_id: str, result: str, tool_elapsed_time: float | None = None) -> None:
"""Record a tool result event (update existing tool call)."""
if not tool_call_id:
return
if tool_call_id in self._tool_call_id_map:
idx = self._tool_call_id_map[tool_call_id]
self.tool_calls[idx]["result"] = result
self.tool_calls[idx]["elapsed_time"] = tool_elapsed_time
def finalize(self) -> None:
"""Finalize the buffer, flushing any pending data."""
if self._last_event_type == "thought":
self._flush_current_reasoning()
def has_data(self) -> bool:
"""Check if there's any meaningful data recorded."""
return bool(self.reasoning_content or self.tool_calls or self.sequence)
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""
AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application.
@@ -255,8 +144,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
# Stream event buffer for recording generation sequence
self._stream_buffer = StreamEventBuffer()
self._seed_graph_runtime_state_from_queue_manager()
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
@@ -471,25 +358,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if node_finish_resp:
yield node_finish_resp
# For ANSWER nodes, check if we need to send a message_replace event
# Only send if the final output differs from the accumulated task_state.answer
# This happens when variables were updated by variable_assigner during workflow execution
if event.node_type == NodeType.ANSWER and event.outputs:
final_answer = event.outputs.get("answer")
if final_answer is not None and final_answer != self._task_state.answer:
logger.info(
"ANSWER node final output '%s' differs from accumulated answer '%s', sending message_replace event",
final_answer,
self._task_state.answer,
)
# Update the task state answer
self._task_state.answer = str(final_answer)
# Send message_replace event to update the UI
yield self._message_cycle_manager.message_replace_to_stream_response(
answer=str(final_answer),
reason="variable_update",
)
def _handle_node_failed_events(
self,
event: Union[QueueNodeFailedEvent, QueueNodeExceptionEvent],
@@ -515,7 +383,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
queue_message: Union[WorkflowQueueMessage, MessageQueueMessage] | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle text chunk events and record to stream buffer for sequence reconstruction."""
"""Handle text chunk events."""
delta_text = event.text
if delta_text is None:
return
@@ -537,52 +405,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else ""
tool_name = tool_payload.name if tool_payload and tool_payload.name else ""
tool_arguments = tool_call.arguments if tool_call and tool_call.arguments else ""
tool_files = tool_result.files if tool_result else []
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
tool_icon = tool_payload.icon if tool_payload else None
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
# Record stream event based on chunk type
chunk_type = event.chunk_type or ChunkType.TEXT
match chunk_type:
case ChunkType.TEXT:
self._stream_buffer.record_text_chunk(delta_text)
self._task_state.answer += delta_text
case ChunkType.THOUGHT:
# Reasoning should not be part of final answer text
self._stream_buffer.record_thought_chunk(delta_text)
case ChunkType.TOOL_CALL:
self._stream_buffer.record_tool_call(
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
)
case ChunkType.TOOL_RESULT:
self._stream_buffer.record_tool_result(
tool_call_id=tool_call_id,
result=delta_text,
tool_elapsed_time=tool_elapsed_time,
)
self._task_state.answer += delta_text
case _:
pass
self._task_state.answer += delta_text
yield self._message_cycle_manager.message_to_stream_response(
answer=delta_text,
message_id=self._message_id,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type.value if event.chunk_type else None,
tool_call_id=tool_call_id or None,
tool_name=tool_name or None,
tool_arguments=tool_arguments or None,
tool_files=tool_files,
tool_elapsed_time=tool_elapsed_time,
tool_icon=tool_icon,
tool_icon_dark=tool_icon_dark,
answer=delta_text, message_id=self._message_id, from_variable_selector=event.from_variable_selector
)
def _handle_iteration_start_event(
@@ -950,7 +775,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer
answer_text = self._strip_think_blocks(answer_text)
if self._recorded_files:
# Remove markdown image links since we're storing files separately
answer_text = re.sub(r"!\[.*?\]\(.*?\)", "", answer_text).strip()
@@ -1002,54 +826,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
# Save generation detail (reasoning/tool calls/sequence) from stream buffer
self._save_generation_detail(session=session, message=message)
@staticmethod
def _strip_think_blocks(text: str) -> str:
"""Remove <think>...</think> blocks (including their content) from text."""
if not text or "<think" not in text.lower():
return text
clean_text = re.sub(r"<think[^>]*>.*?</think>", "", text, flags=re.IGNORECASE | re.DOTALL)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text
def _save_generation_detail(self, *, session: Session, message: Message) -> None:
"""
Save LLM generation detail for Chatflow using stream event buffer.
The buffer records the exact order of events as they streamed,
allowing accurate reconstruction of the generation sequence.
"""
# Finalize the stream buffer to flush any pending data
self._stream_buffer.finalize()
# Only save if there's meaningful data
if not self._stream_buffer.has_data():
return
reasoning_content = self._stream_buffer.reasoning_content
tool_calls = self._stream_buffer.tool_calls
sequence = self._stream_buffer.sequence
# Check if generation detail already exists for this message
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
if existing:
existing.reasoning_content = json.dumps(reasoning_content) if reasoning_content else None
existing.tool_calls = json.dumps(tool_calls) if tool_calls else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_content) if reasoning_content else None,
tool_calls=json.dumps(tool_calls) if tool_calls else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@@ -3,8 +3,10 @@ from typing import cast
from sqlalchemy import select
from core.agent.agent_app_runner import AgentAppRunner
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
from core.agent.entities import AgentEntity
from core.agent.fc_agent_runner import FunctionCallAgentRunner
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.apps.base_app_runner import AppRunner
@@ -12,7 +14,8 @@ from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity
from core.app.entities.queue_entities import QueueAnnotationReplyEvent
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.model_entities import ModelFeature
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.moderation.base import ModerationError
from extensions.ext_database import db
@@ -191,7 +194,22 @@ class AgentChatAppRunner(AppRunner):
raise ValueError("Message not found")
db.session.close()
runner = AgentAppRunner(
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
# start agent runner
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
# check LLM mode
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
runner_cls = CotChatAgentRunner
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
runner_cls = CotCompletionAgentRunner
else:
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
runner_cls = FunctionCallAgentRunner
else:
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
runner = runner_cls(
tenant_id=app_config.tenant_id,
application_generate_entity=application_generate_entity,
conversation=conversation_result,

View File

@@ -1,4 +1,3 @@
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Union, final
@@ -76,12 +75,24 @@ class BaseAppGenerator:
user_inputs = {**user_inputs, **files_inputs, **file_list_inputs}
# Check if all files are converted to File
if any(filter(lambda v: isinstance(v, dict), user_inputs.values())):
raise ValueError("Invalid input type")
if any(
filter(lambda v: isinstance(v, dict), filter(lambda item: isinstance(item, list), user_inputs.values()))
):
raise ValueError("Invalid input type")
invalid_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, dict)
and entity_dictionary[k].type not in {VariableEntityType.FILE, VariableEntityType.JSON_OBJECT}
]
if invalid_dict_keys:
raise ValueError(f"Invalid input type for {invalid_dict_keys}")
invalid_list_dict_keys = [
k
for k, v in user_inputs.items()
if isinstance(v, list)
and any(isinstance(item, dict) for item in v)
and entity_dictionary[k].type != VariableEntityType.FILE_LIST
]
if invalid_list_dict_keys:
raise ValueError(f"Invalid input type for {invalid_list_dict_keys}")
return user_inputs
@@ -178,12 +189,8 @@ class BaseAppGenerator:
elif value == 0:
value = False
case VariableEntityType.JSON_OBJECT:
if not isinstance(value, str):
raise ValueError(f"{variable_entity.variable} in input form must be a string")
try:
json.loads(value)
except json.JSONDecodeError:
raise ValueError(f"{variable_entity.variable} in input form must be a valid JSON object")
if value and not isinstance(value, dict):
raise ValueError(f"{variable_entity.variable} in input form must be a dict")
case _:
raise AssertionError("this statement should be unreachable.")

View File

@@ -671,7 +671,7 @@ class WorkflowResponseConverter:
task_id=task_id,
data=AgentLogStreamResponse.Data(
node_execution_id=event.node_execution_id,
message_id=event.id,
id=event.id,
parent_id=event.parent_id,
label=event.label,
error=event.error,

View File

@@ -73,9 +73,15 @@ class PipelineRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(PipelineConfig, app_config)
invoke_from = self.application_generate_entity.invoke_from
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
user_id = None
if self.application_generate_entity.invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
if invoke_from in {InvokeFrom.WEB_APP, InvokeFrom.SERVICE_API}:
end_user = db.session.query(EndUser).where(EndUser.id == self.application_generate_entity.user_id).first()
if end_user:
user_id = end_user.session_id
@@ -117,7 +123,7 @@ class PipelineRunner(WorkflowBasedAppRunner):
dataset_id=self.application_generate_entity.dataset_id,
datasource_type=self.application_generate_entity.datasource_type,
datasource_info=self.application_generate_entity.datasource_info,
invoke_from=self.application_generate_entity.invoke_from.value,
invoke_from=invoke_from.value,
)
rag_pipeline_variables = []
@@ -149,6 +155,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph_runtime_state=graph_runtime_state,
start_node_id=self.application_generate_entity.start_node_id,
workflow=workflow,
user_from=user_from,
invoke_from=invoke_from,
)
# RUN WORKFLOW
@@ -159,12 +167,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
user_from=user_from,
invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
graph_runtime_state=graph_runtime_state,
variable_pool=variable_pool,
@@ -210,7 +214,12 @@ class PipelineRunner(WorkflowBasedAppRunner):
return workflow
def _init_rag_pipeline_graph(
self, workflow: Workflow, graph_runtime_state: GraphRuntimeState, start_node_id: str | None = None
self,
workflow: Workflow,
graph_runtime_state: GraphRuntimeState,
start_node_id: str | None = None,
user_from: UserFrom = UserFrom.ACCOUNT,
invoke_from: InvokeFrom = InvokeFrom.SERVICE_API,
) -> Graph:
"""
Init pipeline graph
@@ -253,8 +262,8 @@ class PipelineRunner(WorkflowBasedAppRunner):
workflow_id=workflow.id,
graph_config=graph_config,
user_id=self.application_generate_entity.user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
user_from=user_from,
invoke_from=invoke_from,
call_depth=0,
)

View File

@@ -20,7 +20,6 @@ from core.workflow.workflow_entry import WorkflowEntry
from extensions.ext_redis import redis_client
from extensions.otel import WorkflowAppRunnerHandler, trace_span
from libs.datetime_utils import naive_utc_now
from models.enums import UserFrom
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -74,7 +73,12 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
@@ -102,6 +106,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
@@ -120,12 +126,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
graph=graph,
graph_config=self._workflow.graph_dict,
user_id=self.application_generate_entity.user_id,
user_from=(
UserFrom.ACCOUNT
if self.application_generate_entity.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
else UserFrom.END_USER
),
invoke_from=self.application_generate_entity.invoke_from,
user_from=user_from,
invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
variable_pool=variable_pool,
graph_runtime_state=graph_runtime_state,

View File

@@ -13,7 +13,6 @@ from core.app.apps.common.workflow_response_converter import WorkflowResponseCon
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
AppQueueEvent,
ChunkType,
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
@@ -484,33 +483,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if delta_text is None:
return
tool_call = event.tool_call
tool_result = event.tool_result
tool_payload = tool_call or tool_result
tool_call_id = tool_payload.id if tool_payload and tool_payload.id else None
tool_name = tool_payload.name if tool_payload and tool_payload.name else None
tool_arguments = tool_call.arguments if tool_call else None
tool_elapsed_time = tool_result.elapsed_time if tool_result else None
tool_files = tool_result.files if tool_result else []
tool_icon = tool_payload.icon if tool_payload else None
tool_icon_dark = tool_payload.icon_dark if tool_payload else None
# only publish tts message at text chunk streaming
if tts_publisher and queue_message:
tts_publisher.publish(queue_message)
yield self._text_chunk_to_stream_response(
text=delta_text,
from_variable_selector=event.from_variable_selector,
chunk_type=event.chunk_type,
tool_call_id=tool_call_id,
tool_name=tool_name,
tool_arguments=tool_arguments,
tool_files=tool_files,
tool_elapsed_time=tool_elapsed_time,
tool_icon=tool_icon,
tool_icon_dark=tool_icon_dark,
)
yield self._text_chunk_to_stream_response(delta_text, from_variable_selector=event.from_variable_selector)
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
@@ -673,61 +650,16 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
session.add(workflow_app_log)
def _text_chunk_to_stream_response(
self,
text: str,
from_variable_selector: list[str] | None = None,
chunk_type: ChunkType | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
tool_elapsed_time: float | None = None,
tool_icon: str | dict | None = None,
tool_icon_dark: str | dict | None = None,
self, text: str, from_variable_selector: list[str] | None = None
) -> TextChunkStreamResponse:
"""
Handle completed event.
:param text: text
:return:
"""
from core.app.entities.task_entities import ChunkType as ResponseChunkType
response_chunk_type = ResponseChunkType(chunk_type.value) if chunk_type else ResponseChunkType.TEXT
data = TextChunkStreamResponse.Data(
text=text,
from_variable_selector=from_variable_selector,
chunk_type=response_chunk_type,
)
if response_chunk_type == ResponseChunkType.TOOL_CALL:
data = data.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
elif response_chunk_type == ResponseChunkType.TOOL_RESULT:
data = data.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_files": tool_files,
"tool_error": tool_error,
"tool_elapsed_time": tool_elapsed_time,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
response = TextChunkStreamResponse(
task_id=self._application_generate_entity.task_id,
data=data,
data=TextChunkStreamResponse.Data(text=text, from_variable_selector=from_variable_selector),
)
return response

View File

@@ -77,10 +77,18 @@ class WorkflowBasedAppRunner:
self._app_id = app_id
self._graph_engine_layers = graph_engine_layers
@staticmethod
def _resolve_user_from(invoke_from: InvokeFrom) -> UserFrom:
if invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}:
return UserFrom.ACCOUNT
return UserFrom.END_USER
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
@@ -105,8 +113,8 @@ class WorkflowBasedAppRunner:
workflow_id=workflow_id,
graph_config=graph_config,
user_id=user_id,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
user_from=user_from,
invoke_from=invoke_from,
call_depth=0,
)
@@ -250,7 +258,7 @@ class WorkflowBasedAppRunner:
graph_config=graph_config,
user_id="",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
)
@@ -455,20 +463,12 @@ class WorkflowBasedAppRunner:
)
)
elif isinstance(event, NodeRunStreamChunkEvent):
from core.app.entities.queue_entities import ChunkType as QueueChunkType
if event.is_final and not event.chunk:
return
self._publish_event(
QueueTextChunkEvent(
text=event.chunk,
from_variable_selector=list(event.selector),
in_iteration_id=event.in_iteration_id,
in_loop_id=event.in_loop_id,
chunk_type=QueueChunkType(event.chunk_type.value),
tool_call=event.tool_call,
tool_result=event.tool_result,
)
)
elif isinstance(event, NodeRunRetrieverResourceEvent):

View File

@@ -1,70 +0,0 @@
"""
LLM Generation Detail entities.
Defines the structure for storing and transmitting LLM generation details
including reasoning content, tool calls, and their sequence.
"""
from typing import Literal
from pydantic import BaseModel, Field
class ContentSegment(BaseModel):
"""Represents a content segment in the generation sequence."""
type: Literal["content"] = "content"
start: int = Field(..., description="Start position in the text")
end: int = Field(..., description="End position in the text")
class ReasoningSegment(BaseModel):
"""Represents a reasoning segment in the generation sequence."""
type: Literal["reasoning"] = "reasoning"
index: int = Field(..., description="Index into reasoning_content array")
class ToolCallSegment(BaseModel):
"""Represents a tool call segment in the generation sequence."""
type: Literal["tool_call"] = "tool_call"
index: int = Field(..., description="Index into tool_calls array")
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
class ToolCallDetail(BaseModel):
"""Represents a tool call with its arguments and result."""
id: str = Field(default="", description="Unique identifier for the tool call")
name: str = Field(..., description="Name of the tool")
arguments: str = Field(default="", description="JSON string of tool arguments")
result: str = Field(default="", description="Result from the tool execution")
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
class LLMGenerationDetailData(BaseModel):
"""
Domain model for LLM generation detail.
Contains the structured data for reasoning content, tool calls,
and their display sequence.
"""
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
def is_empty(self) -> bool:
"""Check if there's any meaningful generation detail."""
return not self.reasoning_content and not self.tool_calls
def to_response_dict(self) -> dict:
"""Convert to dictionary for API response."""
return {
"reasoning_content": self.reasoning_content,
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
"sequence": [seg.model_dump() for seg in self.sequence],
}

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@@ -177,17 +177,6 @@ class QueueLoopCompletedEvent(AppQueueEvent):
error: str | None = None
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class QueueTextChunkEvent(AppQueueEvent):
"""
QueueTextChunkEvent entity
@@ -202,16 +191,6 @@ class QueueTextChunkEvent(AppQueueEvent):
in_loop_id: str | None = None
"""loop id if node is in loop"""
# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool streaming payloads
tool_call: ToolCall | None = None
"""structured tool call info"""
tool_result: ToolResult | None = None
"""structured tool result info"""
class QueueAgentMessageEvent(AppQueueEvent):
"""

View File

@@ -113,38 +113,6 @@ class MessageStreamResponse(StreamResponse):
answer: str
from_variable_selector: list[str] | None = None
# Extended fields for Agent/Tool streaming (imported at runtime to avoid circular import)
chunk_type: str | None = None
"""type of the chunk: text, tool_call, tool_result, thought"""
# Tool call fields (when chunk_type == "tool_call")
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == "tool_result")
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
tool_elapsed_time: float | None = None
"""elapsed time spent executing the tool"""
tool_icon: str | dict | None = None
"""icon of the tool"""
tool_icon_dark: str | dict | None = None
"""dark theme icon of the tool"""
def model_dump(self, *args, **kwargs) -> dict[str, object]:
kwargs.setdefault("exclude_none", True)
return super().model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs) -> str:
kwargs.setdefault("exclude_none", True)
return super().model_dump_json(*args, **kwargs)
class MessageAudioStreamResponse(StreamResponse):
"""
@@ -614,17 +582,6 @@ class LoopNodeCompletedStreamResponse(StreamResponse):
data: Data
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class TextChunkStreamResponse(StreamResponse):
"""
TextChunkStreamResponse entity
@@ -638,36 +595,6 @@ class TextChunkStreamResponse(StreamResponse):
text: str
from_variable_selector: list[str] | None = None
# Extended fields for Agent/Tool streaming
chunk_type: ChunkType = ChunkType.TEXT
"""type of the chunk"""
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call_id: str | None = None
"""unique identifier for this tool call"""
tool_name: str | None = None
"""name of the tool being called"""
tool_arguments: str | None = None
"""accumulated tool arguments JSON"""
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_files: list[str] | None = None
"""file IDs produced by tool"""
tool_error: str | None = None
"""error message if tool failed"""
# Tool elapsed time fields (when chunk_type == TOOL_RESULT)
tool_elapsed_time: float | None = None
"""elapsed time spent executing the tool"""
def model_dump(self, *args, **kwargs) -> dict[str, object]:
kwargs.setdefault("exclude_none", True)
return super().model_dump(*args, **kwargs)
def model_dump_json(self, *args, **kwargs) -> str:
kwargs.setdefault("exclude_none", True)
return super().model_dump_json(*args, **kwargs)
event: StreamEvent = StreamEvent.TEXT_CHUNK
data: Data
@@ -816,7 +743,7 @@ class AgentLogStreamResponse(StreamResponse):
"""
node_execution_id: str
message_id: str
id: str
label: str
parent_id: str | None = None
error: str | None = None

View File

@@ -1,6 +1,6 @@
import logging
from core.variables import Variable
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,

View File

@@ -1,5 +1,4 @@
import logging
import re
import time
from collections.abc import Generator
from threading import Thread
@@ -59,7 +58,7 @@ from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, LLMGenerationDetail, Message, MessageAgentThought
from models.model import AppMode, Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@@ -69,8 +68,6 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
EasyUIBasedGenerateTaskPipeline is a class that generate stream output and state management for Application.
"""
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
_task_state: EasyUITaskState
_application_generate_entity: Union[ChatAppGenerateEntity, CompletionAppGenerateEntity, AgentChatAppGenerateEntity]
@@ -412,136 +409,11 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
)
)
# Save LLM generation detail if there's reasoning_content
self._save_generation_detail(session=session, message=message, llm_result=llm_result)
message_was_created.send(
message,
application_generate_entity=self._application_generate_entity,
)
def _save_generation_detail(self, *, session: Session, message: Message, llm_result: LLMResult) -> None:
"""
Save LLM generation detail for Completion/Chat/Agent-Chat applications.
For Agent-Chat, also merges MessageAgentThought records.
"""
import json
reasoning_list: list[str] = []
tool_calls_list: list[dict] = []
sequence: list[dict] = []
answer = message.answer or ""
# Check if this is Agent-Chat mode by looking for agent thoughts
agent_thoughts = (
session.query(MessageAgentThought)
.filter_by(message_id=message.id)
.order_by(MessageAgentThought.position.asc())
.all()
)
if agent_thoughts:
# Agent-Chat mode: merge MessageAgentThought records
content_pos = 0
cleaned_answer_parts: list[str] = []
for thought in agent_thoughts:
# Add thought/reasoning
if thought.thought:
reasoning_text = thought.thought
if "<think" in reasoning_text.lower():
clean_text, extracted_reasoning = self._split_reasoning_from_answer(reasoning_text)
if extracted_reasoning:
reasoning_text = extracted_reasoning
thought.thought = clean_text or extracted_reasoning
reasoning_list.append(reasoning_text)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
# Add tool calls
if thought.tool:
tool_calls_list.append(
{
"name": thought.tool,
"arguments": thought.tool_input or "",
"result": thought.observation or "",
}
)
sequence.append({"type": "tool_call", "index": len(tool_calls_list) - 1})
# Add answer content if present
if thought.answer:
content_text = thought.answer
if "<think" in content_text.lower():
clean_answer, extracted_reasoning = self._split_reasoning_from_answer(content_text)
if extracted_reasoning:
reasoning_list.append(extracted_reasoning)
sequence.append({"type": "reasoning", "index": len(reasoning_list) - 1})
content_text = clean_answer
thought.answer = clean_answer or content_text
if content_text:
start = content_pos
end = content_pos + len(content_text)
sequence.append({"type": "content", "start": start, "end": end})
content_pos = end
cleaned_answer_parts.append(content_text)
if cleaned_answer_parts:
merged_answer = "".join(cleaned_answer_parts)
message.answer = merged_answer
llm_result.message.content = merged_answer
else:
# Completion/Chat mode: use reasoning_content from llm_result
reasoning_content = llm_result.reasoning_content
if not reasoning_content and answer:
# Extract reasoning from <think> blocks and clean the final answer
clean_answer, reasoning_content = self._split_reasoning_from_answer(answer)
if reasoning_content:
answer = clean_answer
llm_result.message.content = clean_answer
llm_result.reasoning_content = reasoning_content
message.answer = clean_answer
if reasoning_content:
reasoning_list = [reasoning_content]
# Content comes first, then reasoning
if answer:
sequence.append({"type": "content", "start": 0, "end": len(answer)})
sequence.append({"type": "reasoning", "index": 0})
# Only save if there's meaningful generation detail
if not reasoning_list and not tool_calls_list:
return
# Check if generation detail already exists
existing = session.query(LLMGenerationDetail).filter_by(message_id=message.id).first()
if existing:
existing.reasoning_content = json.dumps(reasoning_list) if reasoning_list else None
existing.tool_calls = json.dumps(tool_calls_list) if tool_calls_list else None
existing.sequence = json.dumps(sequence) if sequence else None
else:
generation_detail = LLMGenerationDetail(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
message_id=message.id,
reasoning_content=json.dumps(reasoning_list) if reasoning_list else None,
tool_calls=json.dumps(tool_calls_list) if tool_calls_list else None,
sequence=json.dumps(sequence) if sequence else None,
)
session.add(generation_detail)
@classmethod
def _split_reasoning_from_answer(cls, text: str) -> tuple[str, str]:
"""
Extract reasoning segments from <think> blocks and return (clean_text, reasoning).
"""
matches = cls._THINK_PATTERN.findall(text)
reasoning_content = "\n".join(match.strip() for match in matches) if matches else ""
clean_text = cls._THINK_PATTERN.sub("", text)
clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip()
return clean_text, reasoning_content or ""
def _handle_stop(self, event: QueueStopEvent):
"""
Handle stop.

View File

@@ -232,31 +232,15 @@ class MessageCycleManager:
answer: str,
message_id: str,
from_variable_selector: list[str] | None = None,
chunk_type: str | None = None,
tool_call_id: str | None = None,
tool_name: str | None = None,
tool_arguments: str | None = None,
tool_files: list[str] | None = None,
tool_error: str | None = None,
tool_elapsed_time: float | None = None,
tool_icon: str | dict | None = None,
tool_icon_dark: str | dict | None = None,
event_type: StreamEvent | None = None,
) -> MessageStreamResponse:
"""
Message to stream response.
:param answer: answer
:param message_id: message id
:param from_variable_selector: from variable selector
:param chunk_type: type of the chunk (text, function_call, tool_result, thought)
:param tool_call_id: unique identifier for this tool call
:param tool_name: name of the tool being called
:param tool_arguments: accumulated tool arguments JSON
:param tool_files: file IDs produced by tool
:param tool_error: error message if tool failed
:return:
"""
response = MessageStreamResponse(
return MessageStreamResponse(
task_id=self._application_generate_entity.task_id,
id=message_id,
answer=answer,
@@ -264,35 +248,6 @@ class MessageCycleManager:
event=event_type or StreamEvent.MESSAGE,
)
if chunk_type:
response = response.model_copy(update={"chunk_type": chunk_type})
if chunk_type == "tool_call":
response = response.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
elif chunk_type == "tool_result":
response = response.model_copy(
update={
"tool_call_id": tool_call_id,
"tool_name": tool_name,
"tool_arguments": tool_arguments,
"tool_files": tool_files,
"tool_error": tool_error,
"tool_elapsed_time": tool_elapsed_time,
"tool_icon": tool_icon,
"tool_icon_dark": tool_icon_dark,
}
)
return response
def message_replace_to_stream_response(self, answer: str, reason: str = "") -> MessageReplaceStreamResponse:
"""
Message replace to stream response.

View File

@@ -5,6 +5,7 @@ from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.rag.index_processor.constant.index_type import IndexStructureType
from core.rag.models.document import Document
@@ -89,8 +90,6 @@ class DatasetIndexToolCallbackHandler:
# TODO(-LAN-): Improve type check
def return_retriever_resource_info(self, resource: Sequence[RetrievalSourceMetadata]):
"""Handle return_retriever_resource_info."""
from core.app.entities.queue_entities import QueueRetrieverResourcesEvent
self._queue_manager.publish(
QueueRetrieverResourcesEvent(retriever_resources=resource), PublishFrom.APPLICATION_MANAGER
)

View File

@@ -33,6 +33,10 @@ class MaxRetriesExceededError(ValueError):
pass
request_error = httpx.RequestError
max_retries_exceeded_error = MaxRetriesExceededError
def _create_proxy_mounts() -> dict[str, httpx.HTTPTransport]:
return {
"http://": httpx.HTTPTransport(

View File

@@ -56,6 +56,10 @@ class HostingConfiguration:
self.provider_map[f"{DEFAULT_PLUGIN_ID}/minimax/minimax"] = self.init_minimax()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/spark/spark"] = self.init_spark()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/zhipuai/zhipuai"] = self.init_zhipuai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/gemini/google"] = self.init_gemini()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/x/x"] = self.init_xai()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/deepseek/deepseek"] = self.init_deepseek()
self.provider_map[f"{DEFAULT_PLUGIN_ID}/tongyi/tongyi"] = self.init_tongyi()
self.moderation_config = self.init_moderation_config()
@@ -128,7 +132,7 @@ class HostingConfiguration:
quotas: list[HostingQuota] = []
if dify_config.HOSTED_OPENAI_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_OPENAI_QUOTA_LIMIT
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_OPENAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
@@ -156,18 +160,49 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
@staticmethod
def init_anthropic() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS
def init_gemini(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_GEMINI_TRIAL_ENABLED:
hosted_quota_limit = 0
trial_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trial_models)
quotas.append(trial_quota)
if dify_config.HOSTED_GEMINI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_GEMINI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"google_api_key": dify_config.HOSTED_GEMINI_API_KEY,
}
if dify_config.HOSTED_GEMINI_API_BASE:
credentials["google_base_url"] = dify_config.HOSTED_GEMINI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_anthropic(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_ANTHROPIC_TRIAL_ENABLED:
hosted_quota_limit = dify_config.HOSTED_ANTHROPIC_QUOTA_LIMIT
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit)
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_ANTHROPIC_PAID_ENABLED:
paid_quota = PaidHostingQuota()
paid_models = self.parse_restrict_models_from_env("HOSTED_ANTHROPIC_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
@@ -185,6 +220,94 @@ class HostingConfiguration:
quota_unit=quota_unit,
)
def init_tongyi(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_TONGYI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_TONGYI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_TONGYI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"dashscope_api_key": dify_config.HOSTED_TONGYI_API_KEY,
"use_international_endpoint": dify_config.HOSTED_TONGYI_USE_INTERNATIONAL_ENDPOINT,
}
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_xai(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_XAI_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_XAI_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_XAI_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_XAI_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_XAI_API_KEY,
}
if dify_config.HOSTED_XAI_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_XAI_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
def init_deepseek(self) -> HostingProvider:
quota_unit = QuotaUnit.CREDITS
quotas: list[HostingQuota] = []
if dify_config.HOSTED_DEEPSEEK_TRIAL_ENABLED:
hosted_quota_limit = 0
trail_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_TRIAL_MODELS")
trial_quota = TrialHostingQuota(quota_limit=hosted_quota_limit, restrict_models=trail_models)
quotas.append(trial_quota)
if dify_config.HOSTED_DEEPSEEK_PAID_ENABLED:
paid_models = self.parse_restrict_models_from_env("HOSTED_DEEPSEEK_PAID_MODELS")
paid_quota = PaidHostingQuota(restrict_models=paid_models)
quotas.append(paid_quota)
if len(quotas) > 0:
credentials = {
"api_key": dify_config.HOSTED_DEEPSEEK_API_KEY,
}
if dify_config.HOSTED_DEEPSEEK_API_BASE:
credentials["endpoint_url"] = dify_config.HOSTED_DEEPSEEK_API_BASE
return HostingProvider(enabled=True, credentials=credentials, quota_unit=quota_unit, quotas=quotas)
return HostingProvider(
enabled=False,
quota_unit=quota_unit,
)
@staticmethod
def init_minimax() -> HostingProvider:
quota_unit = QuotaUnit.TOKENS

View File

@@ -71,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = cast(str, response.message.content)
if answer is None:
answer = response.message.get_text_content()
if answer == "":
return ""
try:
result_dict = json.loads(answer)
@@ -184,7 +184,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = cast(str, response.message.content)
rule_config["prompt"] = response.message.get_text_content()
except InvokeError as e:
error = str(e)
@@ -237,13 +237,11 @@ class LLMGenerator:
return rule_config
rule_config["prompt"] = cast(str, prompt_content.message.content)
rule_config["prompt"] = prompt_content.message.get_text_content()
if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -253,7 +251,7 @@ class LLMGenerator:
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"INPUT_TEXT": prompt_content.message.content,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
@@ -263,7 +261,7 @@ class LLMGenerator:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content))
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
@@ -272,7 +270,7 @@ class LLMGenerator:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
@@ -315,7 +313,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = cast(str, response.message.content)
generated_code = response.message.get_text_content()
return {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
@@ -351,7 +349,7 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
answer = cast(str, response.message.content)
answer = response.message.get_text_content()
return answer.strip()
@classmethod
@@ -375,10 +373,7 @@ class LLMGenerator:
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = response.message.content
if not isinstance(raw_content, str):
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)

View File

@@ -251,10 +251,7 @@ class AssistantPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_calls:
return False
return True
return super().is_empty() and not self.tool_calls
class SystemPromptMessage(PromptMessage):

View File

@@ -1,6 +1,7 @@
import logging
from collections.abc import Sequence
from opentelemetry.trace import SpanKind
from sqlalchemy.orm import sessionmaker
from core.ops.aliyun_trace.data_exporter.traceclient import (
@@ -54,7 +55,7 @@ from core.ops.entities.trace_entity import (
ToolTraceInfo,
WorkflowTraceInfo,
)
from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
@@ -151,6 +152,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -273,7 +275,7 @@ class AliyunDataTrace(BaseTraceInstance):
service_account = self.get_service_account_with_tenant(app_id)
session_factory = sessionmaker(bind=db.engine)
workflow_node_execution_repository = SQLAlchemyWorkflowNodeExecutionRepository(
workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=service_account,
app_id=app_id,
@@ -456,6 +458,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER,
)
self.trace_client.add_span(message_span)
@@ -475,6 +478,7 @@ class AliyunDataTrace(BaseTraceInstance):
),
status=status,
links=trace_metadata.links,
span_kind=SpanKind.SERVER if message_span_id is None else SpanKind.INTERNAL,
)
self.trace_client.add_span(workflow_span)

View File

@@ -166,7 +166,7 @@ class SpanBuilder:
attributes=span_data.attributes,
events=span_data.events,
links=span_data.links,
kind=trace_api.SpanKind.INTERNAL,
kind=span_data.span_kind,
status=span_data.status,
start_time=span_data.start_time,
end_time=span_data.end_time,

View File

@@ -4,7 +4,7 @@ from typing import Any
from opentelemetry import trace as trace_api
from opentelemetry.sdk.trace import Event
from opentelemetry.trace import Status, StatusCode
from opentelemetry.trace import SpanKind, Status, StatusCode
from pydantic import BaseModel, Field
@@ -34,3 +34,4 @@ class SpanData(BaseModel):
status: Status = Field(default=Status(StatusCode.UNSET), description="The status of the span.")
start_time: int | None = Field(..., description="The start time of the span in nanoseconds.")
end_time: int | None = Field(..., description="The end time of the span in nanoseconds.")
span_kind: SpanKind = Field(default=SpanKind.INTERNAL, description="The OpenTelemetry SpanKind for this span.")

View File

@@ -1,5 +1,6 @@
from core.plugin.entities.endpoint import EndpointEntityWithInstance
from core.plugin.impl.base import BasePluginClient
from core.plugin.impl.exc import PluginDaemonInternalServerError
class PluginEndpointClient(BasePluginClient):
@@ -70,18 +71,27 @@ class PluginEndpointClient(BasePluginClient):
def delete_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""
Delete the given endpoint.
This operation is idempotent: if the endpoint is already deleted (record not found),
it will return True instead of raising an error.
"""
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
try:
return self._request_with_plugin_daemon_response(
"POST",
f"plugin/{tenant_id}/endpoint/remove",
bool,
data={
"endpoint_id": endpoint_id,
},
headers={
"Content-Type": "application/json",
},
)
except PluginDaemonInternalServerError as e:
# Make delete idempotent: if record is not found, consider it a success
if "record not found" in str(e.description).lower():
return True
raise
def enable_endpoint(self, tenant_id: str, user_id: str, endpoint_id: str):
"""

View File

@@ -618,18 +618,18 @@ class ProviderManager:
)
for quota in configuration.quotas:
if quota.quota_type == ProviderQuotaType.TRIAL:
if quota.quota_type in (ProviderQuotaType.TRIAL, ProviderQuotaType.PAID):
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
if quota.quota_type not in provider_quota_to_provider_record_dict:
try:
# FIXME ignore the type error, only TrialHostingQuota has limit need to change the logic
new_provider_record = Provider(
tenant_id=tenant_id,
# TODO: Use provider name with prefix after the data migration.
provider_name=ModelProviderID(provider_name).provider_name,
provider_type=ProviderType.SYSTEM,
quota_type=ProviderQuotaType.TRIAL,
quota_limit=quota.quota_limit, # type: ignore
provider_type=ProviderType.SYSTEM.value,
quota_type=quota.quota_type,
quota_limit=0, # type: ignore
quota_used=0,
is_valid=True,
)
@@ -641,8 +641,8 @@ class ProviderManager:
stmt = select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == ModelProviderID(provider_name).provider_name,
Provider.provider_type == ProviderType.SYSTEM,
Provider.quota_type == ProviderQuotaType.TRIAL,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == quota.quota_type,
)
existed_provider_record = db.session.scalar(stmt)
if not existed_provider_record:
@@ -912,6 +912,22 @@ class ProviderManager:
provider_record
)
quota_configurations = []
if dify_config.EDITION == "CLOUD":
from services.credit_pool_service import CreditPoolService
trail_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.TRIAL.value,
)
paid_pool = CreditPoolService.get_pool(
tenant_id=tenant_id,
pool_type=ProviderQuotaType.PAID.value,
)
else:
trail_pool = None
paid_pool = None
for provider_quota in provider_hosting_configuration.quotas:
if provider_quota.quota_type not in quota_type_to_provider_records_dict:
if provider_quota.quota_type == ProviderQuotaType.FREE:
@@ -932,16 +948,36 @@ class ProviderManager:
raise ValueError("quota_used is None")
if provider_record.quota_limit is None:
raise ValueError("quota_limit is None")
if provider_quota.quota_type == ProviderQuotaType.TRIAL and trail_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=trail_pool.quota_used,
quota_limit=trail_pool.quota_limit,
is_valid=trail_pool.quota_limit > trail_pool.quota_used or trail_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
elif provider_quota.quota_type == ProviderQuotaType.PAID and paid_pool is not None:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=paid_pool.quota_used,
quota_limit=paid_pool.quota_limit,
is_valid=paid_pool.quota_limit > paid_pool.quota_used or paid_pool.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
else:
quota_configuration = QuotaConfiguration(
quota_type=provider_quota.quota_type,
quota_unit=provider_hosting_configuration.quota_unit or QuotaUnit.TOKENS,
quota_used=provider_record.quota_used,
quota_limit=provider_record.quota_limit,
is_valid=provider_record.quota_limit > provider_record.quota_used
or provider_record.quota_limit == -1,
restrict_models=provider_quota.restrict_models,
)
quota_configurations.append(quota_configuration)

View File

@@ -29,7 +29,6 @@ from models import (
Account,
CreatorUserRole,
EndUser,
LLMGenerationDetail,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionTriggeredFrom,
)
@@ -458,113 +457,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
session.merge(db_model)
session.flush()
# Save LLMGenerationDetail for LLM nodes with successful execution
if (
domain_model.node_type == NodeType.LLM
and domain_model.status == WorkflowNodeExecutionStatus.SUCCEEDED
and domain_model.outputs is not None
):
self._save_llm_generation_detail(session, domain_model)
def _save_llm_generation_detail(self, session, execution: WorkflowNodeExecution) -> None:
"""
Save LLM generation detail for LLM nodes.
Extracts reasoning_content, tool_calls, and sequence from outputs and metadata.
"""
outputs = execution.outputs or {}
metadata = execution.metadata or {}
reasoning_list = self._extract_reasoning(outputs)
tool_calls_list = self._extract_tool_calls(metadata.get(WorkflowNodeExecutionMetadataKey.AGENT_LOG))
if not reasoning_list and not tool_calls_list:
return
sequence = self._build_generation_sequence(outputs.get("text", ""), reasoning_list, tool_calls_list)
self._upsert_generation_detail(session, execution, reasoning_list, tool_calls_list, sequence)
def _extract_reasoning(self, outputs: Mapping[str, Any]) -> list[str]:
"""Extract reasoning_content as a clean list of non-empty strings."""
reasoning_content = outputs.get("reasoning_content")
if isinstance(reasoning_content, str):
trimmed = reasoning_content.strip()
return [trimmed] if trimmed else []
if isinstance(reasoning_content, list):
return [item.strip() for item in reasoning_content if isinstance(item, str) and item.strip()]
return []
def _extract_tool_calls(self, agent_log: Any) -> list[dict[str, str]]:
"""Extract tool call records from agent logs."""
if not agent_log or not isinstance(agent_log, list):
return []
tool_calls: list[dict[str, str]] = []
for log in agent_log:
log_data = log.data if hasattr(log, "data") else (log.get("data", {}) if isinstance(log, dict) else {})
tool_name = log_data.get("tool_name")
if tool_name and str(tool_name).strip():
tool_calls.append(
{
"id": log_data.get("tool_call_id", ""),
"name": tool_name,
"arguments": json.dumps(log_data.get("tool_args", {})),
"result": str(log_data.get("output", "")),
}
)
return tool_calls
def _build_generation_sequence(
self, text: str, reasoning_list: list[str], tool_calls_list: list[dict[str, str]]
) -> list[dict[str, Any]]:
"""Build a simple content/reasoning/tool_call sequence."""
sequence: list[dict[str, Any]] = []
if text:
sequence.append({"type": "content", "start": 0, "end": len(text)})
for index in range(len(reasoning_list)):
sequence.append({"type": "reasoning", "index": index})
for index in range(len(tool_calls_list)):
sequence.append({"type": "tool_call", "index": index})
return sequence
def _upsert_generation_detail(
self,
session,
execution: WorkflowNodeExecution,
reasoning_list: list[str],
tool_calls_list: list[dict[str, str]],
sequence: list[dict[str, Any]],
) -> None:
"""Insert or update LLMGenerationDetail with serialized fields."""
existing = (
session.query(LLMGenerationDetail)
.filter_by(
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
)
.first()
)
reasoning_json = json.dumps(reasoning_list) if reasoning_list else None
tool_calls_json = json.dumps(tool_calls_list) if tool_calls_list else None
sequence_json = json.dumps(sequence) if sequence else None
if existing:
existing.reasoning_content = reasoning_json
existing.tool_calls = tool_calls_json
existing.sequence = sequence_json
return
generation_detail = LLMGenerationDetail(
tenant_id=self._tenant_id,
app_id=self._app_id,
workflow_run_id=execution.workflow_execution_id,
node_id=execution.node_id,
reasoning_content=reasoning_json,
tool_calls=tool_calls_json,
sequence=sequence_json,
)
session.add(generation_detail)
def get_db_models_by_workflow_run(
self,
workflow_run_id: str,

View File

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.model import File
from core.model_runtime.entities.message_entities import PromptMessageTool
from core.tools.__base.tool_runtime import ToolRuntime
from core.tools.entities.tool_entities import (
ToolEntity,
@@ -155,60 +154,6 @@ class Tool(ABC):
return parameters
def to_prompt_message_tool(self) -> PromptMessageTool:
message_tool = PromptMessageTool(
name=self.entity.identity.name,
description=self.entity.description.llm if self.entity.description else "",
parameters={
"type": "object",
"properties": {},
"required": [],
},
)
parameters = self.get_merged_runtime_parameters()
for parameter in parameters:
if parameter.form != ToolParameter.ToolParameterForm.LLM:
continue
parameter_type = parameter.type.as_normal_type()
if parameter.type in {
ToolParameter.ToolParameterType.SYSTEM_FILES,
ToolParameter.ToolParameterType.FILE,
ToolParameter.ToolParameterType.FILES,
}:
# Determine the description based on parameter type
if parameter.type == ToolParameter.ToolParameterType.FILE:
file_format_desc = " Input the file id with format: [File: file_id]."
else:
file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "
message_tool.parameters["properties"][parameter.name] = {
"type": "string",
"description": (parameter.llm_description or "") + file_format_desc,
}
continue
enum = []
if parameter.type == ToolParameter.ToolParameterType.SELECT:
enum = [option.value for option in parameter.options] if parameter.options else []
message_tool.parameters["properties"][parameter.name] = (
{
"type": parameter_type,
"description": parameter.llm_description or "",
}
if parameter.input_schema is None
else parameter.input_schema
)
if len(enum) > 0:
message_tool.parameters["properties"][parameter.name]["enum"] = enum
if parameter.required:
message_tool.parameters["required"].append(parameter.name)
return message_tool
def create_image_message(
self,
image: str,

View File

@@ -7,8 +7,8 @@ from typing import Any, cast
from flask import has_request_context
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.db.session_factory import session_factory
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.tools.__base.tool import Tool
@@ -20,7 +20,6 @@ from core.tools.entities.tool_entities import (
ToolProviderType,
)
from core.tools.errors import ToolInvokeError
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs.login import current_user
from models import Account, Tenant
@@ -230,30 +229,32 @@ class WorkflowTool(Tool):
"""
Resolve user from database (worker/Celery context).
"""
with session_factory.create_session() as session:
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = session.scalar(user_stmt)
if user:
user.current_tenant = tenant
session.expunge(user)
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = session.scalar(end_user_stmt)
if end_user:
session.expunge(end_user)
return end_user
tenant_stmt = select(Tenant).where(Tenant.id == self.runtime.tenant_id)
tenant = db.session.scalar(tenant_stmt)
if not tenant:
return None
user_stmt = select(Account).where(Account.id == user_id)
user = db.session.scalar(user_stmt)
if user:
user.current_tenant = tenant
return user
end_user_stmt = select(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == tenant.id)
end_user = db.session.scalar(end_user_stmt)
if end_user:
return end_user
return None
def _get_workflow(self, app_id: str, version: str) -> Workflow:
"""
get the workflow by app id and version
"""
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
if not version:
stmt = (
select(Workflow)
@@ -265,22 +266,24 @@ class WorkflowTool(Tool):
stmt = select(Workflow).where(Workflow.app_id == app_id, Workflow.version == version)
workflow = session.scalar(stmt)
if not workflow:
raise ValueError("workflow not found or not published")
if not workflow:
raise ValueError("workflow not found or not published")
return workflow
session.expunge(workflow)
return workflow
def _get_app(self, app_id: str) -> App:
"""
get the app by app id
"""
stmt = select(App).where(App.id == app_id)
with Session(db.engine, expire_on_commit=False) as session, session.begin():
with session_factory.create_session() as session, session.begin():
app = session.scalar(stmt)
if not app:
raise ValueError("app not found")
if not app:
raise ValueError("app not found")
return app
session.expunge(app)
return app
def _transform_args(self, tool_parameters: dict) -> tuple[dict, list[dict]]:
"""

View File

@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
"VariableBase",
]

View File

@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]

View File

@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
class Variable(Segment):
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable):
class StringVariable(StringSegment, VariableBase):
pass
class FloatVariable(FloatSegment, Variable):
class FloatVariable(FloatSegment, VariableBase):
pass
class IntegerVariable(IntegerSegment, Variable):
class IntegerVariable(IntegerSegment, VariableBase):
pass
class ObjectVariable(ObjectSegment, Variable):
class ObjectVariable(ObjectSegment, VariableBase):
pass
class ArrayVariable(ArraySegment, Variable):
class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
class FileVariable(FileSegment, VariableBase):
pass
class BooleanVariable(BooleanSegment, Variable):
class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
# - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `VariableBase`.
Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]

View File

@@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import Variable
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
:param variable: The `VariableBase` instance containing the updated value.
"""
pass

View File

@@ -1,16 +1,11 @@
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .tool_entities import ToolCall, ToolCallResult, ToolResult, ToolResultStatus
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"ToolCall",
"ToolCallResult",
"ToolResult",
"ToolResultStatus",
"WorkflowExecution",
"WorkflowNodeExecution",
]

View File

@@ -1,39 +0,0 @@
from enum import StrEnum
from pydantic import BaseModel, Field
from core.file import File
class ToolResultStatus(StrEnum):
SUCCESS = "success"
ERROR = "error"
class ToolCall(BaseModel):
id: str | None = Field(default=None, description="Unique identifier for this tool call")
name: str | None = Field(default=None, description="Name of the tool being called")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolResult(BaseModel):
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
name: str | None = Field(default=None, description="Name of the tool")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[str] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
icon: str | dict | None = Field(default=None, description="Icon of the tool")
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
class ToolCallResult(BaseModel):
id: str | None = Field(default=None, description="Identifier for the tool call")
name: str | None = Field(default=None, description="Name of the tool")
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
output: str | None = Field(default=None, description="Tool output text, error or success message")
files: list[File] = Field(default_factory=list, description="File produced by tool")
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")

View File

@@ -211,6 +211,10 @@ class WorkflowExecutionStatus(StrEnum):
def is_ended(self) -> bool:
return self in _END_STATE
@classmethod
def ended_values(cls) -> list[str]:
return [status.value for status in _END_STATE]
_END_STATE = frozenset(
[
@@ -247,8 +251,6 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
LLM_CONTENT_SEQUENCE = "llm_content_sequence"
LLM_TRACE = "llm_trace"
COMPLETED_REASON = "completed_reason" # completed reason for loop node

View File

@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
class CommandType(StrEnum):
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):

View File

@@ -16,13 +16,7 @@ from pydantic import BaseModel, Field
from core.workflow.enums import NodeExecutionType, NodeState
from core.workflow.graph import Graph
from core.workflow.graph_events import (
ChunkType,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
from core.workflow.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent
from core.workflow.nodes.base.template import TextSegment, VariableSegment
from core.workflow.runtime import VariablePool
@@ -327,24 +321,11 @@ class ResponseStreamCoordinator:
selector: Sequence[str],
chunk: str,
is_final: bool = False,
chunk_type: ChunkType = ChunkType.TEXT,
tool_call: ToolCall | None = None,
tool_result: ToolResult | None = None,
) -> NodeRunStreamChunkEvent:
"""Create a stream chunk event with consistent structure.
For selectors with special prefixes (sys, env, conversation), we use the
active response node's information since these are not actual node IDs.
Args:
node_id: The node ID to attribute the event to
execution_id: The execution ID for this node
selector: The variable selector
chunk: The chunk content
is_final: Whether this is the final chunk
chunk_type: The semantic type of the chunk being streamed
tool_call: Structured data for tool_call chunks
tool_result: Structured data for tool_result chunks
"""
# Check if this is a special selector that doesn't correspond to a node
if selector and selector[0] not in self._graph.nodes and self._active_session:
@@ -357,9 +338,6 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
# Standard case: selector refers to an actual node
@@ -371,9 +349,6 @@ class ResponseStreamCoordinator:
selector=selector,
chunk=chunk,
is_final=is_final,
chunk_type=chunk_type,
tool_call=tool_call,
tool_result=tool_result,
)
def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]:
@@ -381,8 +356,6 @@ class ResponseStreamCoordinator:
Handles both regular node selectors and special system selectors (sys, env, conversation).
For special selectors, we attribute the output to the active response node.
For object-type variables, automatically streams all child fields that have stream events.
"""
events: list[NodeRunStreamChunkEvent] = []
source_selector_prefix = segment.selector[0] if segment.selector else ""
@@ -391,81 +364,60 @@ class ResponseStreamCoordinator:
# Determine which node to attribute the output to
# For special selectors (sys, env, conversation), use the active response node
# For regular selectors, use the source node
active_session = self._active_session
special_selector = bool(active_session and source_selector_prefix not in self._graph.nodes)
output_node_id = active_session.node_id if special_selector and active_session else source_selector_prefix
if self._active_session and source_selector_prefix not in self._graph.nodes:
# Special selector - use active response node
output_node_id = self._active_session.node_id
else:
# Regular node selector
output_node_id = source_selector_prefix
execution_id = self._get_or_create_execution_id(output_node_id)
# Check if there's a direct stream for this selector
has_direct_stream = (
tuple(segment.selector) in self._stream_buffers or tuple(segment.selector) in self._closed_streams
)
stream_targets = [segment.selector] if has_direct_stream else sorted(self._find_child_streams(segment.selector))
if stream_targets:
all_complete = True
for target_selector in stream_targets:
while self._has_unread_stream(target_selector):
if event := self._pop_stream_chunk(target_selector):
events.append(
self._rewrite_stream_event(
event=event,
output_node_id=output_node_id,
execution_id=execution_id,
special_selector=bool(special_selector),
)
)
if not self._is_stream_closed(target_selector):
all_complete = False
is_complete = all_complete
# Fallback: check if scalar value exists in variable pool
if not is_complete and not has_direct_stream:
if value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session
and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
# Stream all available chunks
while self._has_unread_stream(segment.selector):
if event := self._pop_stream_chunk(segment.selector):
# For special selectors, we need to update the event to use
# the active response node's information
if self._active_session and source_selector_prefix not in self._graph.nodes:
response_node = self._graph.nodes[self._active_session.node_id]
# Create a new event with the response node's information
# but keep the original selector
updated_event = NodeRunStreamChunkEvent(
id=execution_id,
node_id=response_node.id,
node_type=response_node.node_type,
selector=event.selector, # Keep original selector
chunk=event.chunk,
is_final=event.is_final,
)
events.append(updated_event)
else:
# Regular node selector - use event as is
events.append(event)
# Check if this is the last chunk by looking ahead
stream_closed = self._is_stream_closed(segment.selector)
# Check if stream is closed to determine if segment is complete
if stream_closed:
is_complete = True
elif value := self._variable_pool.get(segment.selector):
# Process scalar value
is_last_segment = bool(
self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1
)
events.append(
self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=segment.selector,
chunk=value.markdown,
is_final=is_last_segment,
)
is_complete = True
)
is_complete = True
return events, is_complete
def _rewrite_stream_event(
self,
event: NodeRunStreamChunkEvent,
output_node_id: str,
execution_id: str,
special_selector: bool,
) -> NodeRunStreamChunkEvent:
"""Rewrite event to attribute to active response node when selector is special."""
if not special_selector:
return event
return self._create_stream_chunk_event(
node_id=output_node_id,
execution_id=execution_id,
selector=event.selector,
chunk=event.chunk,
is_final=event.is_final,
chunk_type=event.chunk_type,
tool_call=event.tool_call,
tool_result=event.tool_result,
)
def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]:
"""Process a text segment. Returns (events, is_complete)."""
assert self._active_session is not None
@@ -561,36 +513,6 @@ class ResponseStreamCoordinator:
# ============= Internal Stream Management Methods =============
def _find_child_streams(self, parent_selector: Sequence[str]) -> list[tuple[str, ...]]:
"""Find all child stream selectors that are descendants of the parent selector.
For example, if parent_selector is ['llm', 'generation'], this will find:
- ['llm', 'generation', 'content']
- ['llm', 'generation', 'tool_calls']
- ['llm', 'generation', 'tool_results']
- ['llm', 'generation', 'thought']
Args:
parent_selector: The parent selector to search for children
Returns:
List of child selector tuples found in stream buffers or closed streams
"""
parent_key = tuple(parent_selector)
parent_len = len(parent_key)
child_streams: set[tuple[str, ...]] = set()
# Search in both active buffers and closed streams
all_selectors = set(self._stream_buffers.keys()) | self._closed_streams
for selector_key in all_selectors:
# Check if this selector is a direct child of the parent
# Direct child means: len(child) == len(parent) + 1 and child starts with parent
if len(selector_key) == parent_len + 1 and selector_key[:parent_len] == parent_key:
child_streams.add(selector_key)
return sorted(child_streams)
def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None:
"""
Append a stream chunk to the internal buffer.

View File

@@ -36,7 +36,6 @@ from .loop import (
# Node events
from .node import (
ChunkType,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunPauseRequestedEvent,
@@ -45,13 +44,10 @@ from .node import (
NodeRunStartedEvent,
NodeRunStreamChunkEvent,
NodeRunSucceededEvent,
ToolCall,
ToolResult,
)
__all__ = [
"BaseGraphEvent",
"ChunkType",
"GraphEngineEvent",
"GraphNodeEventBase",
"GraphRunAbortedEvent",
@@ -77,6 +73,4 @@ __all__ = [
"NodeRunStartedEvent",
"NodeRunStreamChunkEvent",
"NodeRunSucceededEvent",
"ToolCall",
"ToolResult",
]

View File

@@ -1,11 +1,10 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit, ToolCall, ToolResult
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from .base import GraphNodeEventBase
@@ -22,39 +21,13 @@ class NodeRunStartedEvent(GraphNodeEventBase):
provider_id: str = ""
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class NodeRunStreamChunkEvent(GraphNodeEventBase):
"""Stream chunk event for workflow node execution."""
# Base fields
# Spec-compliant fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
# Tool call fields (when chunk_type == TOOL_CALL)
tool_call: ToolCall | None = Field(
default=None,
description="structured payload for tool_call chunks",
)
# Tool result fields (when chunk_type == TOOL_RESULT)
tool_result: ToolResult | None = Field(
default=None,
description="structured payload for tool_result chunks",
)
class NodeRunRetrieverResourceEvent(GraphNodeEventBase):

View File

@@ -13,21 +13,16 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
ChunkType,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
RunRetryEvent,
StreamChunkEvent,
StreamCompletedEvent,
ThoughtChunkEvent,
ToolCallChunkEvent,
ToolResultChunkEvent,
)
__all__ = [
"AgentLogEvent",
"ChunkType",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",
@@ -44,7 +39,4 @@ __all__ = [
"RunRetryEvent",
"StreamChunkEvent",
"StreamCompletedEvent",
"ThoughtChunkEvent",
"ToolCallChunkEvent",
"ToolResultChunkEvent",
]

View File

@@ -1,13 +1,11 @@
from collections.abc import Sequence
from datetime import datetime
from enum import StrEnum
from pydantic import Field
from core.file import File
from core.model_runtime.entities.llm_entities import LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import ToolCall, ToolResult
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.node_events import NodeRunResult
@@ -34,60 +32,13 @@ class RunRetryEvent(NodeEventBase):
start_at: datetime = Field(..., description="Retry start time")
class ChunkType(StrEnum):
"""Stream chunk type for LLM-related events."""
TEXT = "text" # Normal text streaming
TOOL_CALL = "tool_call" # Tool call arguments streaming
TOOL_RESULT = "tool_result" # Tool execution result
THOUGHT = "thought" # Agent thinking process (ReAct)
THOUGHT_START = "thought_start" # Agent thought start
THOUGHT_END = "thought_end" # Agent thought end
class StreamChunkEvent(NodeEventBase):
"""Base stream chunk event - normal text streaming output."""
# Spec-compliant fields
selector: Sequence[str] = Field(
..., description="selector identifying the output location (e.g., ['nodeA', 'text'])"
)
chunk: str = Field(..., description="the actual chunk content")
is_final: bool = Field(default=False, description="indicates if this is the last chunk")
chunk_type: ChunkType = Field(default=ChunkType.TEXT, description="type of the chunk")
tool_call: ToolCall | None = Field(default=None, description="structured payload for tool_call chunks")
tool_result: ToolResult | None = Field(default=None, description="structured payload for tool_result chunks")
class ToolCallChunkEvent(StreamChunkEvent):
"""Tool call streaming event - tool call arguments streaming output."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_CALL, frozen=True)
tool_call: ToolCall | None = Field(default=None, description="structured tool call payload")
class ToolResultChunkEvent(StreamChunkEvent):
"""Tool result event - tool execution result."""
chunk_type: ChunkType = Field(default=ChunkType.TOOL_RESULT, frozen=True)
tool_result: ToolResult | None = Field(default=None, description="structured tool result payload")
class ThoughtStartChunkEvent(StreamChunkEvent):
"""Agent thought start streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_START, frozen=True)
class ThoughtEndChunkEvent(StreamChunkEvent):
"""Agent thought end streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT_END, frozen=True)
class ThoughtChunkEvent(StreamChunkEvent):
"""Agent thought streaming event - Agent thinking process (ReAct)."""
chunk_type: ChunkType = Field(default=ChunkType.THOUGHT, frozen=True)
class StreamCompletedEvent(NodeEventBase):

Some files were not shown because too many files have changed in this diff Show More