feat: agent add context

This commit is contained in:
Novice
2026-01-16 11:28:49 +08:00
parent 2591615a3c
commit a7826d9ea4
10 changed files with 458 additions and 681 deletions

View File

@@ -7,7 +7,7 @@ This module provides memory management for LLM conversations, enabling context r
The memory module contains two types of memory implementations:
1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (to be implemented, **Chatflow only**)
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)
> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
@@ -28,8 +28,8 @@ The memory module contains two types of memory implementations:
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ NodeTokenBufferMemory │ │
│ │ Scope: Node within Conversation │ │
│ │ Storage: Object Storage (JSON file) │ │
│ │ Key: (app_id, conversation_id, node_id) │ │
│ │ Storage: WorkflowNodeExecutionModel.outputs["context"] │ │
│ │ Key: (conversation_id, node_id, workflow_run_id) │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
@@ -98,7 +98,7 @@ history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit
---
## NodeTokenBufferMemory (To Be Implemented)
## NodeTokenBufferMemory
### Purpose
@@ -110,114 +110,69 @@ history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history
### Design Decisions
### Design: Zero Extra Storage
#### Storage: Object Storage for Messages (No New Database Table)
**Key insight**: LLM node already saves complete context in `outputs["context"]`.
| Aspect | Database | Object Storage |
| ------------------------- | -------------------- | ------------------ |
| Cost | High | Low |
| Query Flexibility | High | Low |
| Schema Changes | Migration required | None |
| Consistency with existing | ConversationVariable | File uploads, logs |
**Decision**: Store message data in object storage, but still use existing database tables for file metadata.
**What is stored in Object Storage:**
- Message content (text)
- Message metadata (role, token_count, created_at)
- File references (upload_file_id, tool_file_id, etc.)
- Thread relationships (message_id, parent_message_id)
**What still requires Database queries:**
- File reconstruction: When reading node memory, file references are used to query
`UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
complete `File` objects with storage_key, mime_type, etc.
**Why this hybrid approach:**
- No database migration required (no new tables)
- Message data may be large, object storage is cost-effective
- File metadata is already in database, no need to duplicate
- Aligns with existing storage patterns (file uploads, logs)
#### Storage Key Format
```
node_memory/{app_id}/{conversation_id}/{node_id}.json
```
#### Data Structure
```json
{
"version": 1,
"messages": [
{
"message_id": "msg-001",
"parent_message_id": null,
"role": "user",
"content": "Analyze this image",
"files": [
{
"type": "image",
"transfer_method": "local_file",
"upload_file_id": "file-uuid-123",
"belongs_to": "user"
}
],
"token_count": 15,
"created_at": "2026-01-07T10:00:00Z"
},
{
"message_id": "msg-002",
"parent_message_id": "msg-001",
"role": "assistant",
"content": "This is a landscape image...",
"files": [],
"token_count": 50,
"created_at": "2026-01-07T10:00:01Z"
}
  ]
}
```
Each LLM node execution outputs:
```python
outputs = {
"text": clean_text,
"context": self._build_context(prompt_messages, clean_text), # Complete dialogue history!
...
}
```
### Thread Support
This `outputs["context"]` contains:
- All previous user/assistant messages (excluding system prompt)
- The current assistant response
Node memory also supports thread extraction (for regeneration scenarios):
**No separate storage needed** - we just read from the last execution's `outputs["context"]`.
```python
def _extract_thread(
self,
messages: list[NodeMemoryMessage],
current_message_id: str
) -> list[NodeMemoryMessage]:
"""
Extract messages belonging to the thread of current_message_id.
Similar to extract_thread_messages() in TokenBufferMemory.
"""
    ...
```
### Benefits
| Aspect | Old Design (Object Storage) | New Design (outputs["context"]) |
|--------|----------------------------|--------------------------------|
| Storage | Separate JSON file | Already in WorkflowNodeExecutionModel |
| Concurrency | Race condition risk | No issue (each execution is INSERT) |
| Cleanup | Need separate cleanup task | Follows node execution lifecycle |
| Migration | Required | None |
| Complexity | High | Low |
### Data Flow
```
WorkflowNodeExecutionModel NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ │ │
│ SELECT outputs FROM │ │
│ workflow_node_executions │ │
│ WHERE workflow_run_id = ? │ │
│ AND node_id = ? │ │
│◀─────────────────────────────────┤ │
│ │ │
│ outputs["context"] │ │
├─────────────────────────────────▶│ │
│ │ │
│ deserialize PromptMessages │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
```
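The query in the diagram maps onto a small read helper, sketched roughly below. Names such as `WorkflowNodeExecutionModel`, `outputs_dict`, and the `"succeeded"` status string follow the implementation later in this commit; session handling and error paths are simplified.
```python
# Sketch only: load the accumulated context from the last node execution.
from sqlalchemy import select
from sqlalchemy.orm import Session

from extensions.ext_database import db
from models.workflow import WorkflowNodeExecutionModel


def load_node_context(workflow_run_id: str, node_id: str) -> list[dict]:
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(WorkflowNodeExecutionModel).where(
            WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
            WorkflowNodeExecutionModel.node_id == node_id,
            WorkflowNodeExecutionModel.status == "succeeded",
        )
        execution = session.scalars(stmt).first()
        if not execution or not execution.outputs_dict:
            return []
        context = execution.outputs_dict.get("context")
    return context if isinstance(context, list) else []
```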
### File Handling
### Thread Tracking
Files are stored as references (not full metadata):
Thread extraction still uses the `Message` table's `parent_message_id` structure (a read-path sketch follows the numbered steps below):
```python
class NodeMemoryFile(BaseModel):
type: str # image, audio, video, document, custom
transfer_method: str # local_file, remote_url, tool_file
upload_file_id: str | None # for local_file
tool_file_id: str | None # for tool_file
url: str | None # for remote_url
belongs_to: str # user / assistant
```
1. Query `Message` table for conversation → get thread's `workflow_run_ids`
2. Get the last completed `workflow_run_id` in the thread
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]`
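A rough sketch of these steps, using the `Message` query and the `extract_thread_messages` helper from the implementation later in this commit (the handling of an incomplete first message is omitted); the last id returned here feeds the execution query sketched under *Data Flow* above:
```python
# Sketch only: resolve the workflow_run_ids that belong to the current thread.
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message


def thread_workflow_run_ids(conversation_id: str) -> list[str]:
    with Session(db.engine, expire_on_commit=False) as session:
        stmt = (
            select(Message)
            .where(Message.conversation_id == conversation_id)
            .order_by(Message.created_at.desc())
            .limit(500)
        )
        messages = list(session.scalars(stmt).all())
    thread = extract_thread_messages(messages)
    # Reverse to chronological order; keep only messages tied to a workflow run.
    return [m.workflow_run_id for m in reversed(thread) if m.workflow_run_id]
```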
When reading, files are rebuilt using `file_factory.build_from_mapping()`.
### API Design
### API
```python
class NodeTokenBufferMemory:
@@ -226,160 +181,29 @@ class NodeTokenBufferMemory:
app_id: str,
conversation_id: str,
node_id: str,
tenant_id: str,
model_instance: ModelInstance,
):
"""
Initialize node-level memory.
:param app_id: Application ID
:param conversation_id: Conversation ID
:param node_id: Node ID in the workflow
:param model_instance: Model instance for token counting
"""
...
def add_messages(
self,
message_id: str,
parent_message_id: str | None,
user_content: str,
user_files: Sequence[File],
assistant_content: str,
assistant_files: Sequence[File],
) -> None:
"""
Append a dialogue turn (user + assistant) to node memory.
Call this after LLM node execution completes.
:param message_id: Current message ID (from Message table)
:param parent_message_id: Parent message ID (for thread tracking)
:param user_content: User's text input
:param user_files: Files attached by user
:param assistant_content: Assistant's text response
:param assistant_files: Files generated by assistant
"""
"""Initialize node-level memory."""
...
def get_history_prompt_messages(
self,
current_message_id: str,
tenant_id: str,
*,
max_token_limit: int = 2000,
file_upload_config: FileUploadConfig | None = None,
message_limit: int | None = None,
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
:param current_message_id: Current message ID (for thread extraction)
:param tenant_id: Tenant ID (for file reconstruction)
:param max_token_limit: Maximum tokens for history
:param file_upload_config: File upload configuration
:return: Sequence of PromptMessage for LLM context
Reads from last completed execution's outputs["context"].
"""
...
def flush(self) -> None:
"""
Persist buffered changes to object storage.
Call this at the end of node execution.
"""
...
def clear(self) -> None:
"""
Clear all messages in this node's memory.
"""
...
```
### Data Flow
```
Object Storage NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ storage.load(key) │ │
│◀─────────────────────────────────┤ │
│ │ │
│ JSON data │ │
├─────────────────────────────────▶│ │
│ │ │
│ _extract_thread() │
│ │ │
│ _rebuild_files() via file_factory │
│ │ │
│ _build_prompt_messages() │
│ │ │
│ _truncate_by_tokens() │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
│ │◀── LLM execution complete │
│ │ │
│ │◀── add_messages() │
│ │ │
│ storage.save(key, data) │ │
│◀─────────────────────────────────┤ │
│ │ │
```
### Integration with LLM Node
```python
# In LLM Node execution
# 1. Fetch memory based on mode
if node_data.memory and node_data.memory.mode == MemoryMode.NODE:
# Node-level memory (Chatflow only)
memory = fetch_node_memory(
variable_pool=variable_pool,
app_id=app_id,
node_id=self.node_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
elif node_data.memory and node_data.memory.mode == MemoryMode.CONVERSATION:
# Conversation-level memory (existing behavior)
memory = fetch_memory(
variable_pool=variable_pool,
app_id=app_id,
node_data_memory=node_data.memory,
model_instance=model_instance,
)
else:
memory = None
# 2. Get history for context
if memory:
if isinstance(memory, NodeTokenBufferMemory):
history = memory.get_history_prompt_messages(
current_message_id=current_message_id,
tenant_id=tenant_id,
max_token_limit=max_token_limit,
)
else: # TokenBufferMemory
history = memory.get_history_prompt_messages(
max_token_limit=max_token_limit,
)
prompt_messages = [*history, *current_messages]
else:
prompt_messages = current_messages
# 3. Call LLM
response = model_instance.invoke(prompt_messages)
# 4. Append to node memory (only for NodeTokenBufferMemory)
if isinstance(memory, NodeTokenBufferMemory):
memory.add_messages(
message_id=message_id,
parent_message_id=parent_message_id,
user_content=user_input,
user_files=user_files,
assistant_content=response.content,
assistant_files=response_files,
)
memory.flush()
# Legacy methods (no-op, kept for compatibility)
def add_messages(self, *args, **kwargs) -> None: pass
def flush(self) -> None: pass
def clear(self) -> None: pass
```
### Configuration
@@ -388,16 +212,13 @@ Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:
```python
class MemoryMode(StrEnum):
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
NODE = "node" # Use NodeTokenBufferMemory (new, Chatflow only)
CONVERSATION = "conversation" # Use TokenBufferMemory (default)
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
class MemoryConfig(BaseModel):
# Existing fields
role_prefix: RolePrefix | None = None
window: MemoryWindowConfig | None = None
query_prompt_template: str | None = None
# Memory mode (new)
mode: MemoryMode = MemoryMode.CONVERSATION
```
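For illustration, a minimal sketch of selecting a memory mode with the fields above (assuming the optional fields default as shown; nothing beyond this definition is used):
```python
# Sketch: node-level memory vs. the default conversation-level behavior.
node_level = MemoryConfig(mode=MemoryMode.NODE)  # Chatflow only
conversation_level = MemoryConfig()              # defaults to MemoryMode.CONVERSATION
```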
@@ -408,27 +229,39 @@ class MemoryConfig(BaseModel):
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
> fall back to no memory or raise a configuration error.
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.
---
## Comparison
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | Object Storage (JSON) |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via file references) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |
| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ---------------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | WorkflowNodeExecutionModel.outputs |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via context serialization) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |
---
## Extending to Other Nodes
Node memory is available to any node that writes `context` to its outputs. To enable it for another node:
1. Add `outputs["context"] = llm_utils.build_context(prompt_messages, response)` to the node's execution result (the Agent node builds its context with a node-specific helper instead)
2. `NodeTokenBufferMemory` then picks it up automatically (see the sketch after the node list below)
Nodes that also output `context` as of this commit:
- `question_classifier`
- `parameter_extractor`
- `agent`
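As referenced above, the change a node needs is a single extra output. A rough sketch using `llm_utils.build_context` added in this commit; `prompt_messages` and `response_text` stand in for whatever the node already has on hand:
```python
from core.workflow.nodes.llm import llm_utils

# Sketch: expose the node's dialogue history so NodeTokenBufferMemory can read it back.
outputs = {
    "text": response_text,
    "context": llm_utils.build_context(prompt_messages, response_text),
}
```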
---
## Future Considerations
1. **Cleanup Task**: Add a Celery task to clean up old node memory files
2. **Concurrency**: Consider Redis lock for concurrent node executions
3. **Compression**: Compress large memory files to reduce storage costs
4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory
1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
2. **Compression**: For very long conversations, consider summarization strategies
3. **Extension**: Other nodes may benefit from node-level memory

View File

@@ -1,15 +1,11 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
NodeMemoryData,
NodeMemoryFile,
NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory
__all__ = [
"BaseMemory",
"NodeMemoryData",
"NodeMemoryFile",
"NodeTokenBufferMemory",
"TokenBufferMemory",
]

View File

@@ -8,73 +8,44 @@ Note: This is only available in Chatflow (advanced-chat mode) because it require
both conversation_id and node_id.
Design:
- Storage is indexed by workflow_run_id (each execution stores one turn)
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed - the context is already saved during node execution
- Thread tracking leverages Message table's parent_message_id structure
- On read: query Message table for current thread, then filter Node Memory by workflow_run_ids
"""
import logging
from collections.abc import Sequence
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
AssistantPromptMessage,
ImagePromptMessageContent,
PromptMessage,
TextPromptMessageContent,
PromptMessageRole,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel
logger = logging.getLogger(__name__)
class NodeMemoryFile(BaseModel):
"""File reference stored in node memory."""
type: str # image, audio, video, document, custom
transfer_method: str # local_file, remote_url, tool_file
upload_file_id: str | None = None
tool_file_id: str | None = None
url: str | None = None
class NodeMemoryTurn(BaseModel):
"""A single dialogue turn (user + assistant) in node memory."""
user_content: str = ""
user_files: list[NodeMemoryFile] = []
assistant_content: str = ""
assistant_files: list[NodeMemoryFile] = []
class NodeMemoryData(BaseModel):
"""Root data structure for node memory storage."""
version: int = 1
# Key: workflow_run_id, Value: dialogue turn
turns: dict[str, NodeMemoryTurn] = {}
class NodeTokenBufferMemory(BaseMemory):
"""
Node-level Token Buffer Memory.
Provides node-scoped memory within a conversation. Each LLM node can maintain
its own independent conversation history, stored in object storage.
its own independent conversation history.
Key design: Thread tracking is delegated to Message table's parent_message_id.
Storage is indexed by workflow_run_id for easy filtering.
Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
which is already saved during node execution. No separate storage needed.
"""
def __init__(
@@ -85,132 +56,25 @@ class NodeTokenBufferMemory(BaseMemory):
tenant_id: str,
model_instance: ModelInstance,
):
"""
Initialize node-level memory.
:param app_id: Application ID
:param conversation_id: Conversation ID
:param node_id: Node ID in the workflow
:param tenant_id: Tenant ID for file reconstruction
:param model_instance: Model instance for token counting
"""
self.app_id = app_id
self.conversation_id = conversation_id
self.node_id = node_id
self.tenant_id = tenant_id
self.model_instance = model_instance
self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
self._data: NodeMemoryData | None = None
self._dirty = False
def _load(self) -> NodeMemoryData:
"""Load data from object storage."""
if self._data is not None:
return self._data
try:
raw = storage.load_once(self._storage_key)
self._data = NodeMemoryData.model_validate_json(raw)
except Exception:
# File not found or parse error, start fresh
self._data = NodeMemoryData()
return self._data
def _save(self) -> None:
"""Save data to object storage."""
if self._data is not None:
storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
self._dirty = False
def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
"""Convert File object to NodeMemoryFile reference."""
return NodeMemoryFile(
type=file.type.value if hasattr(file.type, "value") else str(file.type),
transfer_method=(
file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
),
upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
)
def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
"""Convert NodeMemoryFile to mapping for file_factory."""
mapping: dict = {
"type": memory_file.type,
"transfer_method": memory_file.transfer_method,
}
if memory_file.upload_file_id:
mapping["upload_file_id"] = memory_file.upload_file_id
if memory_file.tool_file_id:
mapping["tool_file_id"] = memory_file.tool_file_id
if memory_file.url:
mapping["url"] = memory_file.url
return mapping
def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
"""Rebuild File objects from NodeMemoryFile references."""
if not memory_files:
return []
from factories import file_factory
files = []
for mf in memory_files:
try:
mapping = self._memory_file_to_mapping(mf)
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
files.append(file)
except Exception as e:
logger.warning("Failed to rebuild file from memory: %s", e)
continue
return files
def _build_prompt_message(
self,
role: str,
content: str,
files: list[File],
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
) -> PromptMessage:
"""Build PromptMessage from content and files."""
from core.file import file_manager
if not files:
if role == "user":
return UserPromptMessage(content=content)
else:
return AssistantPromptMessage(content=content)
# Build multimodal content
prompt_contents: list = []
for file in files:
try:
prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
prompt_contents.append(prompt_content)
except Exception as e:
logger.warning("Failed to convert file to prompt content: %s", e)
continue
prompt_contents.append(TextPromptMessageContent(data=content))
if role == "user":
return UserPromptMessage(content=prompt_contents)
else:
return AssistantPromptMessage(content=prompt_contents)
def _get_thread_workflow_run_ids(self) -> list[str]:
"""
Get workflow_run_ids for the current thread by querying Message table.
Returns workflow_run_ids in chronological order (oldest first).
"""
# Query messages for this conversation
stmt = (
select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
)
messages = db.session.scalars(stmt.limit(500)).all()
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(Message)
.where(Message.conversation_id == self.conversation_id)
.order_by(Message.created_at.desc())
.limit(500)
)
messages = list(session.scalars(stmt).all())
if not messages:
return []
@@ -223,46 +87,31 @@ class NodeTokenBufferMemory(BaseMemory):
thread_messages.pop(0)
# Reverse to get chronological order, extract workflow_run_ids
workflow_run_ids = []
for msg in reversed(thread_messages):
if msg.workflow_run_id:
workflow_run_ids.append(msg.workflow_run_id)
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
return workflow_run_ids
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
"""Deserialize a dict to PromptMessage based on role."""
role = msg_dict.get("role")
if role in (PromptMessageRole.USER, "user"):
return UserPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
return AssistantPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.SYSTEM, "system"):
return SystemPromptMessage.model_validate(msg_dict)
elif role in (PromptMessageRole.TOOL, "tool"):
return ToolPromptMessage.model_validate(msg_dict)
else:
return PromptMessage.model_validate(msg_dict)
def add_messages(
self,
workflow_run_id: str,
user_content: str,
user_files: Sequence[File] | None = None,
assistant_content: str = "",
assistant_files: Sequence[File] | None = None,
) -> None:
"""
Add a dialogue turn to node memory.
Call this after LLM node execution completes.
:param workflow_run_id: Current workflow execution ID
:param user_content: User's text input
:param user_files: Files attached by user
:param assistant_content: Assistant's text response
:param assistant_files: Files generated by assistant
"""
data = self._load()
# Convert files to memory file references
user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
# Store the turn indexed by workflow_run_id
data.turns[workflow_run_id] = NodeMemoryTurn(
user_content=user_content,
user_files=user_memory_files,
assistant_content=assistant_content,
assistant_files=assistant_memory_files,
)
self._dirty = True
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
"""Deserialize context data from outputs to list of PromptMessage."""
messages = []
for msg_dict in context_data:
try:
messages.append(self._deserialize_prompt_message(msg_dict))
except Exception as e:
logger.warning("Failed to deserialize prompt message: %s", e)
return messages
def get_history_prompt_messages(
self,
@@ -272,55 +121,38 @@ class NodeTokenBufferMemory(BaseMemory):
) -> Sequence[PromptMessage]:
"""
Retrieve history as PromptMessage sequence.
Thread tracking is handled by querying Message table's parent_message_id structure.
:param max_token_limit: Maximum tokens for history
:param message_limit: unused, for interface compatibility
:return: Sequence of PromptMessage for LLM context
History is read directly from the last completed node execution's outputs["context"].
"""
# message_limit is unused in NodeTokenBufferMemory (uses token limit instead)
_ = message_limit
detail = ImagePromptMessageContent.DETAIL.HIGH
data = self._load()
_ = message_limit # unused, kept for interface compatibility
if not data.turns:
return []
# Get workflow_run_ids for current thread from Message table
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
if not thread_workflow_run_ids:
return []
# Build prompt messages in thread order
prompt_messages: list[PromptMessage] = []
for wf_run_id in thread_workflow_run_ids:
turn = data.turns.get(wf_run_id)
if not turn:
# This workflow execution didn't have node memory stored
continue
# Get the last completed workflow_run_id (contains accumulated context)
last_run_id = thread_workflow_run_ids[-1]
# Build user message
user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
user_msg = self._build_prompt_message(
role="user",
content=turn.user_content,
files=user_files,
detail=detail,
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
WorkflowNodeExecutionModel.node_id == self.node_id,
WorkflowNodeExecutionModel.status == "succeeded",
)
prompt_messages.append(user_msg)
execution = session.scalars(stmt).first()
# Build assistant message
assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
assistant_msg = self._build_prompt_message(
role="assistant",
content=turn.assistant_content,
files=assistant_files,
detail=detail,
)
prompt_messages.append(assistant_msg)
if not execution:
return []
outputs = execution.outputs_dict
if not outputs:
return []
context_data = outputs.get("context")
if not context_data or not isinstance(context_data, list):
return []
prompt_messages = self._deserialize_context(context_data)
if not prompt_messages:
return []
@@ -334,20 +166,3 @@ class NodeTokenBufferMemory(BaseMemory):
logger.warning("Failed to count tokens for truncation: %s", e)
return prompt_messages
def flush(self) -> None:
"""
Persist buffered changes to object storage.
Call this at the end of node execution.
"""
if self._dirty:
self._save()
def clear(self) -> None:
"""Clear all messages in this node's memory."""
self._data = NodeMemoryData()
self._save()
def exists(self) -> bool:
"""Check if node memory exists in storage."""
return storage.exists(self._storage_key)

View File

@@ -276,7 +276,5 @@ class ToolPromptMessage(PromptMessage):
:return: True if prompt message is empty, False otherwise
"""
if not super().is_empty() and not self.tool_call_id:
return False
return True
# ToolPromptMessage is not empty if it has content OR has a tool_call_id
return super().is_empty() and not self.tool_call_id
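# Illustration only (not part of this change): with the corrected logic, a tool
# message carrying just a tool_call_id is not empty, while one with neither
# content nor a tool_call_id is. Constructor usage mirrors the agent node
# changes later in this commit.
#   assert not ToolPromptMessage(content="", tool_call_id="call-1", name="search").is_empty()
#   assert ToolPromptMessage(content="", tool_call_id="", name="search").is_empty()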

View File

@@ -17,6 +17,12 @@ from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessage,
ToolPromptMessage,
UserPromptMessage,
)
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
@@ -527,6 +533,95 @@ class AgentNode(Node[AgentNodeData]):
# Conversation-level memory doesn't need saving here
return None
def _build_context(
self,
parameters_for_log: dict[str, Any],
user_query: str,
assistant_response: str,
agent_logs: list[AgentLogEvent],
) -> list[PromptMessage]:
"""
Build context from user query, tool calls, and assistant response.
Format: user -> assistant(with tool_calls) -> tool -> assistant
The context includes:
- Current user query (always present, may be empty)
- Assistant message with tool_calls (if tools were called)
- Tool results
- Assistant's final response
"""
context_messages: list[PromptMessage] = []
# Always add user query (even if empty, to maintain conversation structure)
context_messages.append(UserPromptMessage(content=user_query or ""))
# Extract actual tool calls from agent logs
# Only include logs with label starting with "CALL " - these are real tool invocations
tool_calls: list[AssistantPromptMessage.ToolCall] = []
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
for log in agent_logs:
if log.status == "success" and log.label and log.label.startswith("CALL "):
# Extract tool name from label (format: "CALL tool_name")
tool_name = log.label[5:] # Remove "CALL " prefix
tool_call_id = log.message_id
# Parse tool response from data
data = log.data or {}
tool_response = ""
# Try to extract the actual tool response
if "tool_response" in data:
tool_response = data["tool_response"]
elif "output" in data:
tool_response = data["output"]
elif "result" in data:
tool_response = data["result"]
if isinstance(tool_response, dict):
tool_response = str(tool_response)
# Get tool input for arguments
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
if isinstance(tool_input, dict):
import json
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
else:
tool_input_str = str(tool_input) if tool_input else ""
if tool_response:
tool_calls.append(
AssistantPromptMessage.ToolCall(
id=tool_call_id,
type="function",
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
name=tool_name,
arguments=tool_input_str,
),
)
)
tool_results.append((tool_call_id, tool_name, str(tool_response)))
# Add assistant message with tool_calls if there were tool calls
if tool_calls:
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
# Add tool result messages
for tool_call_id, tool_name, result in tool_results:
context_messages.append(
ToolPromptMessage(
content=result,
tool_call_id=tool_call_id,
name=tool_name,
)
)
# Add final assistant response
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
def _transform_message(
self,
messages: Generator[ToolInvokeMessage, None, None],
@@ -782,20 +877,11 @@ class AgentNode(Node[AgentNodeData]):
is_final=True,
)
# Save to node memory if in node memory mode
from core.workflow.nodes.llm import llm_utils
# Get user query from parameters for building context
user_query = parameters_for_log.get("query", "")
# Get user query from sys.query
user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
user_query = user_query_var.text if user_query_var else ""
llm_utils.save_node_memory(
memory=memory,
variable_pool=self.graph_runtime_state.variable_pool,
user_query=user_query,
assistant_response=text,
assistant_files=files,
)
# Build context from history, user query, tool calls and assistant response
context = self._build_context(parameters_for_log, user_query, text, agent_logs)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
@@ -805,6 +891,7 @@ class AgentNode(Node[AgentNodeData]):
"usage": jsonable_encoder(llm_usage),
"files": ArrayFileSegment(value=files),
"json": json_output,
"context": context,
**variables,
},
metadata={

View File

@@ -285,7 +285,7 @@ class Node(Generic[NodeDataT]):
extractor_configs.append(node_config)
return extractor_configs
def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
"""
Execute all extractor nodes associated with this node.
@@ -349,7 +349,7 @@ class Node(Generic[NodeDataT]):
self._start_at = naive_utc_now()
# Step 1: Execute associated extractor nodes before main node execution
yield from self._execute_extractor_nodes()
yield from self._execute_mention_nodes()
# Create and push start event with required fields
start_event = NodeRunStartedEvent(

View File

@@ -12,6 +12,13 @@ from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
@@ -139,50 +146,6 @@ def fetch_memory(
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def save_node_memory(
memory: BaseMemory | None,
variable_pool: VariablePool,
user_query: str,
assistant_response: str,
user_files: Sequence["File"] | None = None,
assistant_files: Sequence["File"] | None = None,
) -> None:
"""
Save dialogue turn to node memory if applicable.
This function handles the storage logic for NodeTokenBufferMemory.
For TokenBufferMemory (conversation-level), no action is taken as it uses
the Message table which is managed elsewhere.
:param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
:param variable_pool: Variable pool containing system variables
:param user_query: User's input text
:param assistant_response: Assistant's response text
:param user_files: Files attached by user (optional)
:param assistant_files: Files generated by assistant (optional)
"""
if not isinstance(memory, NodeTokenBufferMemory):
return
# Get workflow_run_id as the key for this execution
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
if not isinstance(workflow_run_id_var, StringSegment):
return
workflow_run_id = workflow_run_id_var.value
if not workflow_run_id:
return
memory.add_messages(
workflow_run_id=workflow_run_id,
user_content=user_query,
user_files=list(user_files) if user_files else None,
assistant_content=assistant_response,
assistant_files=list(assistant_files) if assistant_files else None,
)
memory.flush()
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
@@ -246,3 +209,45 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
)
session.execute(stmt)
session.commit()
def build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
) -> list[PromptMessage]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
"""
context_messages: list[PromptMessage] = [
_truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
"""
Truncate multi-modal content base64 data in a message to avoid storing large data.
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, truncating multi-modal base64 data
new_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# Truncate base64_data similar to prompt_messages_to_prompt_for_saving
truncated_base64 = ""
if item.base64_data:
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
else:
new_content.append(item)
return message.model_copy(update={"content": new_content})

View File

@@ -20,7 +20,6 @@ from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities import (
ImagePromptMessageContent,
MultiModalPromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
@@ -327,25 +326,13 @@ class LLMNode(Node[LLMNodeData]):
"reasoning_content": reasoning_content,
"usage": jsonable_encoder(usage),
"finish_reason": finish_reason,
"context": self._build_context(prompt_messages, clean_text),
"context": llm_utils.build_context(prompt_messages, clean_text),
}
if structured_output:
outputs["structured_output"] = structured_output.structured_output
if self._file_outputs:
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
# Write to Node Memory if in node memory mode
# Resolve the query template to get actual user content
actual_query = variable_pool.convert_template(query or "").text
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=actual_query,
assistant_response=clean_text,
user_files=files,
assistant_files=self._file_outputs,
)
# Send final chunk event to indicate streaming is complete
yield StreamChunkEvent(
selector=[self._node_id, "text"],
@@ -607,48 +594,6 @@ class LLMNode(Node[LLMNodeData]):
# Separated mode: always return clean text and reasoning_content
return clean_text, reasoning_content or ""
@staticmethod
def _build_context(
prompt_messages: Sequence[PromptMessage],
assistant_response: str,
) -> list[PromptMessage]:
"""
Build context from prompt messages and assistant response.
Excludes system messages and includes the current LLM response.
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
"""
context_messages: list[PromptMessage] = [
LLMNode._truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
]
context_messages.append(AssistantPromptMessage(content=assistant_response))
return context_messages
@staticmethod
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
"""
Truncate multi-modal content base64 data in a message to avoid storing large data.
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
"""
content = message.content
if content is None or isinstance(content, str):
return message
# Process list content, truncating multi-modal base64 data
new_content: list[PromptMessageContentUnionTypes] = []
for item in content:
if isinstance(item, MultiModalPromptMessageContent):
# Truncate base64_data similar to prompt_messages_to_prompt_for_saving
truncated_base64 = ""
if item.base64_data:
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
else:
new_content.append(item)
return message.model_copy(update={"content": new_content})
def _transform_chat_messages(
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -716,54 +661,158 @@ class LLMNode(Node[LLMNodeData]):
"""
variable_pool = self.graph_runtime_state.variable_pool
# Build a map from context index to its messages
context_messages_map: dict[int, list[PromptMessage]] = {}
# Process messages in DSL order: iterate once and handle each type directly
combined_messages: list[PromptMessage] = []
context_idx = 0
for idx, type_ in template_order:
static_idx = 0
for _, type_ in template_order:
if type_ == "context":
# Handle context reference
ctx_ref = context_refs[context_idx]
ctx_var = variable_pool.get(ctx_ref.value_selector)
if ctx_var is None:
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
if not isinstance(ctx_var, ArrayPromptMessageSegment):
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
context_messages_map[idx] = list(ctx_var.value)
combined_messages.extend(ctx_var.value)
context_idx += 1
# Process static messages
static_prompt_messages: Sequence[PromptMessage] = []
stop: Sequence[str] | None = None
if static_messages:
static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
sys_query=query,
sys_files=files,
context=context,
memory=memory,
model_config=model_config,
prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
memory_config=self.node_data.memory,
vision_enabled=self.node_data.vision.enabled,
vision_detail=self.node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
tenant_id=self.tenant_id,
context_files=context_files,
)
# Combine messages according to original DSL order
combined_messages: list[PromptMessage] = []
static_msg_iter = iter(static_prompt_messages)
for idx, type_ in template_order:
if type_ == "context":
combined_messages.extend(context_messages_map[idx])
else:
if msg := next(static_msg_iter, None):
combined_messages.append(msg)
# Append any remaining static messages (e.g., memory messages)
combined_messages.extend(static_msg_iter)
# Handle static message
static_msg = static_messages[static_idx]
processed_msgs = LLMNode.handle_list_messages(
messages=[static_msg],
context=context,
jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
variable_pool=variable_pool,
vision_detail_config=self.node_data.vision.configs.detail,
)
combined_messages.extend(processed_msgs)
static_idx += 1
# Append memory messages
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=self.node_data.memory,
model_config=model_config,
)
combined_messages.extend(memory_messages)
# Append current query if provided
if query:
query_message = LLMNodeChatModelMessage(
text=query,
role=PromptMessageRole.USER,
edition_type="basic",
)
query_msgs = LLMNode.handle_list_messages(
messages=[query_message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=self.node_data.vision.configs.detail,
)
combined_messages.extend(query_msgs)
# Handle files (sys_files and context_files)
combined_messages = self._append_files_to_messages(
messages=combined_messages,
sys_files=files,
context_files=context_files,
model_config=model_config,
)
# Filter empty messages and get stop sequences
combined_messages = self._filter_messages(combined_messages, model_config)
stop = self._get_stop_sequences(model_config)
return combined_messages, stop
def _append_files_to_messages(
self,
*,
messages: list[PromptMessage],
sys_files: Sequence[File],
context_files: list[File],
model_config: ModelConfigWithCredentialsEntity,
) -> list[PromptMessage]:
"""Append sys_files and context_files to messages."""
vision_enabled = self.node_data.vision.enabled
vision_detail = self.node_data.vision.configs.detail
# Handle sys_files (will be deprecated later)
if vision_enabled and sys_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
# Handle context_files
if vision_enabled and context_files:
file_prompts = [
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
for file in context_files
]
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
else:
messages.append(UserPromptMessage(content=file_prompts))
return messages
def _filter_messages(
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
) -> list[PromptMessage]:
"""Filter empty messages and unsupported content types."""
filtered_messages: list[PromptMessage] = []
for message in messages:
if isinstance(message.content, list):
filtered_content: list[PromptMessageContentUnionTypes] = []
for content_item in message.content:
# Skip non-text content if features are not defined
if not model_config.model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
filtered_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
feature_map = {
PromptMessageContentType.IMAGE: ModelFeature.VISION,
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
}
required_feature = feature_map.get(content_item.type)
if required_feature and required_feature not in model_config.model_schema.features:
continue
filtered_content.append(content_item)
# Simplify single text content
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
message.content = filtered_content[0].data
else:
message.content = filtered_content
if not message.is_empty():
filtered_messages.append(message)
if not filtered_messages:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_messages
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
"""Get stop sequences from model config."""
return model_config.stop
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
variables: dict[str, Any] = {}

View File

@@ -246,13 +246,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
# transform result into standard format
result = self._transform_result(data=node_data, result=result or {})
# Save to node memory if in node memory mode
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=query,
assistant_response=json.dumps(result, ensure_ascii=False),
)
# Build context from prompt messages and response
assistant_response = json.dumps(result, ensure_ascii=False)
context = llm_utils.build_context(prompt_messages, assistant_response)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -262,6 +258,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
"__is_success": 1 if not error else 0,
"__reason": error,
"__usage": jsonable_encoder(usage),
"context": context,
**result,
},
metadata={

View File

@@ -199,20 +199,17 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
"model_provider": model_config.provider,
"model_name": model_config.model,
}
# Build context from prompt messages and response
assistant_response = f"class_name: {category_name}, class_id: {category_id}"
context = llm_utils.build_context(prompt_messages, assistant_response)
outputs = {
"class_name": category_name,
"class_id": category_id,
"usage": jsonable_encoder(usage),
"context": context,
}
# Save to node memory if in node memory mode
llm_utils.save_node_memory(
memory=memory,
variable_pool=variable_pool,
user_query=query or "",
assistant_response=f"class_name: {category_name}, class_id: {category_id}",
)
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs=variables,