feat: agent add context
@@ -7,7 +7,7 @@ This module provides memory management for LLM conversations, enabling context r
The memory module contains two types of memory implementations:

1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (to be implemented, **Chatflow only**)
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)

> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
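For orientation, a minimal instantiation sketch (the constructor parameters follow the API section later in this document; the surrounding variables are assumed to come from the executing node):

```python
# Minimal sketch: node-level memory is keyed by both the conversation and the
# node, which is why it only works in Chatflow (advanced-chat) apps.
memory = NodeTokenBufferMemory(
    app_id=app_id,
    conversation_id=conversation_id,  # only available in Chatflow
    node_id=node_id,                  # the LLM node inside the workflow
    tenant_id=tenant_id,
    model_instance=model_instance,    # used for token counting
)
```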
@@ -28,8 +28,8 @@ The memory module contains two types of memory implementations:
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
│ │ NodeTokenBufferMemory │ │
│ │ Scope: Node within Conversation │ │
│ │ Storage: Object Storage (JSON file) │ │
│ │ Key: (app_id, conversation_id, node_id) │ │
│ │ Storage: WorkflowNodeExecutionModel.outputs["context"] │ │
│ │ Key: (conversation_id, node_id, workflow_run_id) │ │
│ └─────────────────────────────────────────────────────────────────────-┘ │
│ │
└─────────────────────────────────────────────────────────────────────────────┘
@@ -98,7 +98,7 @@ history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit

---

## NodeTokenBufferMemory (To Be Implemented)
## NodeTokenBufferMemory

### Purpose

@@ -110,114 +110,69 @@ history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history

### Design Decisions
### Design: Zero Extra Storage

#### Storage: Object Storage for Messages (No New Database Table)
**Key insight**: LLM node already saves complete context in `outputs["context"]`.

| Aspect | Database | Object Storage |
| ------------------------- | -------------------- | ------------------ |
| Cost | High | Low |
| Query Flexibility | High | Low |
| Schema Changes | Migration required | None |
| Consistency with existing | ConversationVariable | File uploads, logs |

**Decision**: Store message data in object storage, but still use existing database tables for file metadata.

**What is stored in Object Storage:**

- Message content (text)
- Message metadata (role, token_count, created_at)
- File references (upload_file_id, tool_file_id, etc.)
- Thread relationships (message_id, parent_message_id)

**What still requires Database queries:**

- File reconstruction: When reading node memory, file references are used to query
  `UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
  complete `File` objects with storage_key, mime_type, etc.
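A condensed sketch of that reconstruction step, based on the `_memory_file_to_mapping()` / `_rebuild_files()` helpers that appear later in this diff (error handling omitted):

```python
from factories import file_factory

def rebuild_file(reference: dict, tenant_id: str):
    # `reference` holds only IDs, e.g. {"type": "image",
    #   "transfer_method": "local_file", "upload_file_id": "file-uuid-123"}.
    # build_from_mapping() queries the UploadFile / ToolFile tables to restore
    # storage_key, mime_type, etc. and returns a complete File object.
    return file_factory.build_from_mapping(mapping=reference, tenant_id=tenant_id)
```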

**Why this hybrid approach:**

- No database migration required (no new tables)
- Message data may be large, object storage is cost-effective
- File metadata is already in database, no need to duplicate
- Aligns with existing storage patterns (file uploads, logs)

#### Storage Key Format

```
node_memory/{app_id}/{conversation_id}/{node_id}.json
```

#### Data Structure

```json
{
  "version": 1,
  "messages": [
    {
      "message_id": "msg-001",
      "parent_message_id": null,
      "role": "user",
      "content": "Analyze this image",
      "files": [
        {
          "type": "image",
          "transfer_method": "local_file",
          "upload_file_id": "file-uuid-123",
          "belongs_to": "user"
        }
      ],
      "token_count": 15,
      "created_at": "2026-01-07T10:00:00Z"
    },
    {
      "message_id": "msg-002",
      "parent_message_id": "msg-001",
      "role": "assistant",
      "content": "This is a landscape image...",
      "files": [],
      "token_count": 50,
      "created_at": "2026-01-07T10:00:01Z"
    }
  ]
}
```

Each LLM node execution outputs:
```python
outputs = {
    "text": clean_text,
    "context": self._build_context(prompt_messages, clean_text),  # Complete dialogue history!
    ...
}
```

### Thread Support
This `outputs["context"]` contains:
- All previous user/assistant messages (excluding system prompt)
- The current assistant response

Node memory also supports thread extraction (for regeneration scenarios):
**No separate storage needed** - we just read from the last execution's `outputs["context"]`.

```python
def _extract_thread(
    self,
    messages: list[NodeMemoryMessage],
    current_message_id: str
) -> list[NodeMemoryMessage]:
    """
    Extract messages belonging to the thread of current_message_id.
    Similar to extract_thread_messages() in TokenBufferMemory.
    """
    ...
```

### Benefits

| Aspect | Old Design (Object Storage) | New Design (outputs["context"]) |
|--------|----------------------------|--------------------------------|
| Storage | Separate JSON file | Already in WorkflowNodeExecutionModel |
| Concurrency | Race condition risk | No issue (each execution is INSERT) |
| Cleanup | Need separate cleanup task | Follows node execution lifecycle |
| Migration | Required | None |
| Complexity | High | Low |

### Data Flow

```
WorkflowNodeExecutionModel NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ │ │
│ SELECT outputs FROM │ │
│ workflow_node_executions │ │
│ WHERE workflow_run_id = ? │ │
│ AND node_id = ? │ │
│◀─────────────────────────────────┤ │
│ │ │
│ outputs["context"] │ │
├─────────────────────────────────▶│ │
│ │ │
│ deserialize PromptMessages │
│ │ │
│ truncate by max_token_limit │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
```
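The final truncation step in this flow can be sketched as follows (an illustrative sketch only; it assumes `ModelInstance.get_llm_num_tokens()` is the token-counting helper, which may differ from the actual implementation):

```python
def _truncate_by_tokens(
    self, messages: list[PromptMessage], max_token_limit: int
) -> list[PromptMessage]:
    # Drop the oldest messages until the remaining history fits the token budget.
    while len(messages) > 1 and self.model_instance.get_llm_num_tokens(messages) > max_token_limit:
        messages.pop(0)
    return messages
```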

### File Handling
### Thread Tracking

Files are stored as references (not full metadata):
Thread extraction still uses `Message` table's `parent_message_id` structure:

```python
class NodeMemoryFile(BaseModel):
    type: str  # image, audio, video, document, custom
    transfer_method: str  # local_file, remote_url, tool_file
    upload_file_id: str | None  # for local_file
    tool_file_id: str | None  # for tool_file
    url: str | None  # for remote_url
    belongs_to: str  # user / assistant
```

1. Query `Message` table for conversation → get thread's `workflow_run_ids`
2. Get the last completed `workflow_run_id` in the thread
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]` (a condensed sketch follows below)
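A condensed sketch of these three steps, following the `_get_thread_workflow_run_ids()` and `get_history_prompt_messages()` implementations later in this diff (imports as in `node_token_buffer_memory.py`; session handling simplified and regeneration edge cases omitted):

```python
def _last_context(self) -> list[dict]:
    with Session(db.engine, expire_on_commit=False) as session:
        # 1. Messages of this conversation, newest first
        msgs = session.scalars(
            select(Message)
            .where(Message.conversation_id == self.conversation_id)
            .order_by(Message.created_at.desc())
            .limit(500)
        ).all()
        # Keep only the current thread (parent_message_id chain), oldest first
        thread = list(reversed(extract_thread_messages(msgs)))
        run_ids = [m.workflow_run_id for m in thread if m.workflow_run_id]
        if not run_ids:
            return []
        # 2. Last completed run in the thread, 3. its node execution's outputs
        execution = session.scalars(
            select(WorkflowNodeExecutionModel).where(
                WorkflowNodeExecutionModel.workflow_run_id == run_ids[-1],
                WorkflowNodeExecutionModel.node_id == self.node_id,
                WorkflowNodeExecutionModel.status == "succeeded",
            )
        ).first()
        return (execution.outputs_dict or {}).get("context", []) if execution else []
```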

When reading, files are rebuilt using `file_factory.build_from_mapping()`.

### API Design
### API
```python
class NodeTokenBufferMemory:
@@ -226,160 +181,29 @@ class NodeTokenBufferMemory:
        app_id: str,
        conversation_id: str,
        node_id: str,
        tenant_id: str,
        model_instance: ModelInstance,
    ):
        """
        Initialize node-level memory.

        :param app_id: Application ID
        :param conversation_id: Conversation ID
        :param node_id: Node ID in the workflow
        :param model_instance: Model instance for token counting
        """
        ...

    def add_messages(
        self,
        message_id: str,
        parent_message_id: str | None,
        user_content: str,
        user_files: Sequence[File],
        assistant_content: str,
        assistant_files: Sequence[File],
    ) -> None:
        """
        Append a dialogue turn (user + assistant) to node memory.
        Call this after LLM node execution completes.

        :param message_id: Current message ID (from Message table)
        :param parent_message_id: Parent message ID (for thread tracking)
        :param user_content: User's text input
        :param user_files: Files attached by user
        :param assistant_content: Assistant's text response
        :param assistant_files: Files generated by assistant
        """
        """Initialize node-level memory."""
        ...

    def get_history_prompt_messages(
        self,
        current_message_id: str,
        tenant_id: str,
        *,
        max_token_limit: int = 2000,
        file_upload_config: FileUploadConfig | None = None,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Retrieve history as PromptMessage sequence.

        :param current_message_id: Current message ID (for thread extraction)
        :param tenant_id: Tenant ID (for file reconstruction)
        :param max_token_limit: Maximum tokens for history
        :param file_upload_config: File upload configuration
        :return: Sequence of PromptMessage for LLM context

        Reads from last completed execution's outputs["context"].
        """
        ...

    def flush(self) -> None:
        """
        Persist buffered changes to object storage.
        Call this at the end of node execution.
        """
        ...

    def clear(self) -> None:
        """
        Clear all messages in this node's memory.
        """
        ...
```

### Data Flow

```
Object Storage NodeTokenBufferMemory LLM Node
│ │ │
│ │◀── get_history_prompt_messages()
│ storage.load(key) │ │
│◀─────────────────────────────────┤ │
│ │ │
│ JSON data │ │
├─────────────────────────────────▶│ │
│ │ │
│ _extract_thread() │
│ │ │
│ _rebuild_files() via file_factory │
│ │ │
│ _build_prompt_messages() │
│ │ │
│ _truncate_by_tokens() │
│ │ │
│ │ Sequence[PromptMessage] │
│ ├──────────────────────────▶│
│ │ │
│ │◀── LLM execution complete │
│ │ │
│ │◀── add_messages() │
│ │ │
│ storage.save(key, data) │ │
│◀─────────────────────────────────┤ │
│ │ │
```

### Integration with LLM Node

```python
# In LLM Node execution

# 1. Fetch memory based on mode
if node_data.memory and node_data.memory.mode == MemoryMode.NODE:
    # Node-level memory (Chatflow only)
    memory = fetch_node_memory(
        variable_pool=variable_pool,
        app_id=app_id,
        node_id=self.node_id,
        node_data_memory=node_data.memory,
        model_instance=model_instance,
    )
elif node_data.memory and node_data.memory.mode == MemoryMode.CONVERSATION:
    # Conversation-level memory (existing behavior)
    memory = fetch_memory(
        variable_pool=variable_pool,
        app_id=app_id,
        node_data_memory=node_data.memory,
        model_instance=model_instance,
    )
else:
    memory = None

# 2. Get history for context
if memory:
    if isinstance(memory, NodeTokenBufferMemory):
        history = memory.get_history_prompt_messages(
            current_message_id=current_message_id,
            tenant_id=tenant_id,
            max_token_limit=max_token_limit,
        )
    else:  # TokenBufferMemory
        history = memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
        )
    prompt_messages = [*history, *current_messages]
else:
    prompt_messages = current_messages

# 3. Call LLM
response = model_instance.invoke(prompt_messages)

# 4. Append to node memory (only for NodeTokenBufferMemory)
if isinstance(memory, NodeTokenBufferMemory):
    memory.add_messages(
        message_id=message_id,
        parent_message_id=parent_message_id,
        user_content=user_input,
        user_files=user_files,
        assistant_content=response.content,
        assistant_files=response_files,
    )
    memory.flush()
# Legacy methods (no-op, kept for compatibility)
def add_messages(self, *args, **kwargs) -> None: pass
def flush(self) -> None: pass
def clear(self) -> None: pass
```

### Configuration
@@ -388,16 +212,13 @@ Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:

```python
class MemoryMode(StrEnum):
    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
    NODE = "node"  # Use NodeTokenBufferMemory (new, Chatflow only)
    CONVERSATION = "conversation"  # Use TokenBufferMemory (default)
    NODE = "node"  # Use NodeTokenBufferMemory (Chatflow only)

class MemoryConfig(BaseModel):
    # Existing fields
    role_prefix: RolePrefix | None = None
    window: MemoryWindowConfig | None = None
    query_prompt_template: str | None = None

    # Memory mode (new)
    mode: MemoryMode = MemoryMode.CONVERSATION
```
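For illustration, enabling node-level memory on an LLM node might look like the following (a sketch only; field names follow the `MemoryConfig` definition above, and how the config is attached to node data is assumed):

```python
# Hypothetical usage sketch: switch an LLM node from the default
# conversation-level memory to node-level memory.
memory_config = MemoryConfig(
    mode=MemoryMode.NODE,
    window=None,                 # optional window settings unchanged
    query_prompt_template=None,
)
```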

@@ -408,27 +229,39 @@ class MemoryConfig(BaseModel):
| `conversation` | TokenBufferMemory | Entire conversation | All app modes |
| `node` | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |

> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
> fall back to no memory or raise a configuration error.
> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.
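A minimal sketch of that fallback, assuming the guard lives in `fetch_node_memory()` and that `SystemVariableKey.CONVERSATION_ID` is available in the variable pool (both are assumptions; the helper introduced in this commit may structure this differently):

```python
# Hypothetical guard inside fetch_node_memory() illustrating the fallback.
conversation_id_var = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
if conversation_id_var is None or not conversation_id_var.value:
    return None  # not Chatflow: no conversation_id, so no node-level memory
```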

---

## Comparison

| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | Object Storage (JSON) |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via file references) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |

| Feature | TokenBufferMemory | NodeTokenBufferMemory |
| -------------- | ------------------------ | ---------------------------------- |
| Scope | Conversation | Node within Conversation |
| Storage | Database (Message table) | WorkflowNodeExecutionModel.outputs |
| Thread Support | Yes | Yes |
| File Support | Yes (via MessageFile) | Yes (via context serialization) |
| Token Limit | Yes | Yes |
| Use Case | Standard chat apps | Complex workflows |

---

## Extending to Other Nodes

Currently only **LLM Node** outputs `context` in its outputs. To enable node memory for other nodes:

1. Add `outputs["context"] = self._build_context(prompt_messages, response)` in the node
2. The `NodeTokenBufferMemory` will automatically pick it up
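A sketch of step 1 for another node type (illustrative; in this commit the shared helper is `llm_utils.build_context()`, which is what `question_classifier` and `parameter_extractor` call):

```python
from core.workflow.nodes.llm import llm_utils

# Expose the dialogue as outputs["context"] so NodeTokenBufferMemory can read it back.
outputs = {
    "text": response_text,
    # prompt_messages: what was sent to the model; response_text: what it returned
    "context": llm_utils.build_context(prompt_messages, response_text),
}
```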

Nodes that could potentially support this:
- `question_classifier`
- `parameter_extractor`
- `agent`

---

## Future Considerations

1. **Cleanup Task**: Add a Celery task to clean up old node memory files
2. **Concurrency**: Consider Redis lock for concurrent node executions
3. **Compression**: Compress large memory files to reduce storage costs
4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory
1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
2. **Compression**: For very long conversations, consider summarization strategies
3. **Extension**: Other nodes may benefit from node-level memory
@@ -1,15 +1,11 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
    NodeMemoryData,
    NodeMemoryFile,
    NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory

__all__ = [
    "BaseMemory",
    "NodeMemoryData",
    "NodeMemoryFile",
    "NodeTokenBufferMemory",
    "TokenBufferMemory",
]
@@ -8,73 +8,44 @@ Note: This is only available in Chatflow (advanced-chat mode) because it require
both conversation_id and node_id.

Design:
- Storage is indexed by workflow_run_id (each execution stores one turn)
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
- No separate storage needed - the context is already saved during node execution
- Thread tracking leverages Message table's parent_message_id structure
- On read: query Message table for current thread, then filter Node Memory by workflow_run_ids
"""

import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeMemoryFile(BaseModel):
|
||||
"""File reference stored in node memory."""
|
||||
|
||||
type: str # image, audio, video, document, custom
|
||||
transfer_method: str # local_file, remote_url, tool_file
|
||||
upload_file_id: str | None = None
|
||||
tool_file_id: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class NodeMemoryTurn(BaseModel):
|
||||
"""A single dialogue turn (user + assistant) in node memory."""
|
||||
|
||||
user_content: str = ""
|
||||
user_files: list[NodeMemoryFile] = []
|
||||
assistant_content: str = ""
|
||||
assistant_files: list[NodeMemoryFile] = []
|
||||
|
||||
|
||||
class NodeMemoryData(BaseModel):
|
||||
"""Root data structure for node memory storage."""
|
||||
|
||||
version: int = 1
|
||||
# Key: workflow_run_id, Value: dialogue turn
|
||||
turns: dict[str, NodeMemoryTurn] = {}
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history, stored in object storage.
|
||||
its own independent conversation history.
|
||||
|
||||
Key design: Thread tracking is delegated to Message table's parent_message_id.
|
||||
Storage is indexed by workflow_run_id for easy filtering.
|
||||
|
||||
Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
|
||||
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
|
||||
which is already saved during node execution. No separate storage needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -85,132 +56,25 @@ class NodeTokenBufferMemory(BaseMemory):
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
"""
|
||||
Initialize node-level memory.
|
||||
|
||||
:param app_id: Application ID
|
||||
:param conversation_id: Conversation ID
|
||||
:param node_id: Node ID in the workflow
|
||||
:param tenant_id: Tenant ID for file reconstruction
|
||||
:param model_instance: Model instance for token counting
|
||||
"""
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
|
||||
self._data: NodeMemoryData | None = None
|
||||
self._dirty = False
|
||||
|
||||
def _load(self) -> NodeMemoryData:
|
||||
"""Load data from object storage."""
|
||||
if self._data is not None:
|
||||
return self._data
|
||||
|
||||
try:
|
||||
raw = storage.load_once(self._storage_key)
|
||||
self._data = NodeMemoryData.model_validate_json(raw)
|
||||
except Exception:
|
||||
# File not found or parse error, start fresh
|
||||
self._data = NodeMemoryData()
|
||||
|
||||
return self._data
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save data to object storage."""
|
||||
if self._data is not None:
|
||||
storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
|
||||
self._dirty = False
|
||||
|
||||
def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
|
||||
"""Convert File object to NodeMemoryFile reference."""
|
||||
return NodeMemoryFile(
|
||||
type=file.type.value if hasattr(file.type, "value") else str(file.type),
|
||||
transfer_method=(
|
||||
file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
|
||||
),
|
||||
upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
|
||||
url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
)
|
||||
|
||||
def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
|
||||
"""Convert NodeMemoryFile to mapping for file_factory."""
|
||||
mapping: dict = {
|
||||
"type": memory_file.type,
|
||||
"transfer_method": memory_file.transfer_method,
|
||||
}
|
||||
if memory_file.upload_file_id:
|
||||
mapping["upload_file_id"] = memory_file.upload_file_id
|
||||
if memory_file.tool_file_id:
|
||||
mapping["tool_file_id"] = memory_file.tool_file_id
|
||||
if memory_file.url:
|
||||
mapping["url"] = memory_file.url
|
||||
return mapping
|
||||
|
||||
def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
|
||||
"""Rebuild File objects from NodeMemoryFile references."""
|
||||
if not memory_files:
|
||||
return []
|
||||
|
||||
from factories import file_factory
|
||||
|
||||
files = []
|
||||
for mf in memory_files:
|
||||
try:
|
||||
mapping = self._memory_file_to_mapping(mf)
|
||||
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
|
||||
files.append(file)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rebuild file from memory: %s", e)
|
||||
continue
|
||||
return files
|
||||
|
||||
def _build_prompt_message(
|
||||
self,
|
||||
role: str,
|
||||
content: str,
|
||||
files: list[File],
|
||||
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
|
||||
) -> PromptMessage:
|
||||
"""Build PromptMessage from content and files."""
|
||||
from core.file import file_manager
|
||||
|
||||
if not files:
|
||||
if role == "user":
|
||||
return UserPromptMessage(content=content)
|
||||
else:
|
||||
return AssistantPromptMessage(content=content)
|
||||
|
||||
# Build multimodal content
|
||||
prompt_contents: list = []
|
||||
for file in files:
|
||||
try:
|
||||
prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
|
||||
prompt_contents.append(prompt_content)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to convert file to prompt content: %s", e)
|
||||
continue
|
||||
|
||||
prompt_contents.append(TextPromptMessageContent(data=content))
|
||||
|
||||
if role == "user":
|
||||
return UserPromptMessage(content=prompt_contents)
|
||||
else:
|
||||
return AssistantPromptMessage(content=prompt_contents)
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
# Query messages for this conversation
|
||||
stmt = (
|
||||
select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
messages = db.session.scalars(stmt.limit(500)).all()
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(500)
|
||||
)
|
||||
messages = list(session.scalars(stmt).all())
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
@@ -223,46 +87,31 @@ class NodeTokenBufferMemory(BaseMemory):
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
workflow_run_ids = []
|
||||
for msg in reversed(thread_messages):
|
||||
if msg.workflow_run_id:
|
||||
workflow_run_ids.append(msg.workflow_run_id)
|
||||
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
|
||||
|
||||
return workflow_run_ids
|
||||
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
|
||||
"""Deserialize a dict to PromptMessage based on role."""
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
return UserPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
return AssistantPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
return SystemPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
return ToolPromptMessage.model_validate(msg_dict)
|
||||
else:
|
||||
return PromptMessage.model_validate(msg_dict)
|
||||
|
||||
def add_messages(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
user_content: str,
|
||||
user_files: Sequence[File] | None = None,
|
||||
assistant_content: str = "",
|
||||
assistant_files: Sequence[File] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a dialogue turn to node memory.
|
||||
Call this after LLM node execution completes.
|
||||
|
||||
:param workflow_run_id: Current workflow execution ID
|
||||
:param user_content: User's text input
|
||||
:param user_files: Files attached by user
|
||||
:param assistant_content: Assistant's text response
|
||||
:param assistant_files: Files generated by assistant
|
||||
"""
|
||||
data = self._load()
|
||||
|
||||
# Convert files to memory file references
|
||||
user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
|
||||
assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
|
||||
|
||||
# Store the turn indexed by workflow_run_id
|
||||
data.turns[workflow_run_id] = NodeMemoryTurn(
|
||||
user_content=user_content,
|
||||
user_files=user_memory_files,
|
||||
assistant_content=assistant_content,
|
||||
assistant_files=assistant_memory_files,
|
||||
)
|
||||
|
||||
self._dirty = True
|
||||
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
|
||||
"""Deserialize context data from outputs to list of PromptMessage."""
|
||||
messages = []
|
||||
for msg_dict in context_data:
|
||||
try:
|
||||
messages.append(self._deserialize_prompt_message(msg_dict))
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize prompt message: %s", e)
|
||||
return messages
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
@@ -272,55 +121,38 @@ class NodeTokenBufferMemory(BaseMemory):
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
|
||||
Thread tracking is handled by querying Message table's parent_message_id structure.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: unused, for interface compatibility
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
History is read directly from the last completed node execution's outputs["context"].
|
||||
"""
|
||||
# message_limit is unused in NodeTokenBufferMemory (uses token limit instead)
|
||||
_ = message_limit
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
data = self._load()
|
||||
_ = message_limit # unused, kept for interface compatibility
|
||||
|
||||
if not data.turns:
|
||||
return []
|
||||
|
||||
# Get workflow_run_ids for current thread from Message table
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Build prompt messages in thread order
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for wf_run_id in thread_workflow_run_ids:
|
||||
turn = data.turns.get(wf_run_id)
|
||||
if not turn:
|
||||
# This workflow execution didn't have node memory stored
|
||||
continue
|
||||
# Get the last completed workflow_run_id (contains accumulated context)
|
||||
last_run_id = thread_workflow_run_ids[-1]
|
||||
|
||||
# Build user message
|
||||
user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
|
||||
user_msg = self._build_prompt_message(
|
||||
role="user",
|
||||
content=turn.user_content,
|
||||
files=user_files,
|
||||
detail=detail,
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
|
||||
WorkflowNodeExecutionModel.node_id == self.node_id,
|
||||
WorkflowNodeExecutionModel.status == "succeeded",
|
||||
)
|
||||
prompt_messages.append(user_msg)
|
||||
execution = session.scalars(stmt).first()
|
||||
|
||||
# Build assistant message
|
||||
assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
|
||||
assistant_msg = self._build_prompt_message(
|
||||
role="assistant",
|
||||
content=turn.assistant_content,
|
||||
files=assistant_files,
|
||||
detail=detail,
|
||||
)
|
||||
prompt_messages.append(assistant_msg)
|
||||
if not execution:
|
||||
return []
|
||||
|
||||
outputs = execution.outputs_dict
|
||||
if not outputs:
|
||||
return []
|
||||
|
||||
context_data = outputs.get("context")
|
||||
|
||||
if not context_data or not isinstance(context_data, list):
|
||||
return []
|
||||
|
||||
prompt_messages = self._deserialize_context(context_data)
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
@@ -334,20 +166,3 @@ class NodeTokenBufferMemory(BaseMemory):
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def flush(self) -> None:
|
||||
"""
|
||||
Persist buffered changes to object storage.
|
||||
Call this at the end of node execution.
|
||||
"""
|
||||
if self._dirty:
|
||||
self._save()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages in this node's memory."""
|
||||
self._data = NodeMemoryData()
|
||||
self._save()
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""Check if node memory exists in storage."""
|
||||
return storage.exists(self._storage_key)
|
||||
|
||||
@@ -276,7 +276,5 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
# ToolPromptMessage is not empty if it has content OR has a tool_call_id
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
|
||||
@@ -17,6 +17,12 @@ from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryMode
|
||||
@@ -527,6 +533,95 @@ class AgentNode(Node[AgentNodeData]):
|
||||
# Conversation-level memory doesn't need saving here
|
||||
return None
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_query: str,
|
||||
assistant_response: str,
|
||||
agent_logs: list[AgentLogEvent],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Build context from user query, tool calls, and assistant response.
|
||||
Format: user -> assistant(with tool_calls) -> tool -> assistant
|
||||
|
||||
The context includes:
|
||||
- Current user query (always present, may be empty)
|
||||
- Assistant message with tool_calls (if tools were called)
|
||||
- Tool results
|
||||
- Assistant's final response
|
||||
"""
|
||||
context_messages: list[PromptMessage] = []
|
||||
|
||||
# Always add user query (even if empty, to maintain conversation structure)
|
||||
context_messages.append(UserPromptMessage(content=user_query or ""))
|
||||
|
||||
# Extract actual tool calls from agent logs
|
||||
# Only include logs with label starting with "CALL " - these are real tool invocations
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
|
||||
|
||||
for log in agent_logs:
|
||||
if log.status == "success" and log.label and log.label.startswith("CALL "):
|
||||
# Extract tool name from label (format: "CALL tool_name")
|
||||
tool_name = log.label[5:] # Remove "CALL " prefix
|
||||
tool_call_id = log.message_id
|
||||
|
||||
# Parse tool response from data
|
||||
data = log.data or {}
|
||||
tool_response = ""
|
||||
|
||||
# Try to extract the actual tool response
|
||||
if "tool_response" in data:
|
||||
tool_response = data["tool_response"]
|
||||
elif "output" in data:
|
||||
tool_response = data["output"]
|
||||
elif "result" in data:
|
||||
tool_response = data["result"]
|
||||
|
||||
if isinstance(tool_response, dict):
|
||||
tool_response = str(tool_response)
|
||||
|
||||
# Get tool input for arguments
|
||||
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
|
||||
if isinstance(tool_input, dict):
|
||||
import json
|
||||
|
||||
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
|
||||
else:
|
||||
tool_input_str = str(tool_input) if tool_input else ""
|
||||
|
||||
if tool_response:
|
||||
tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_name,
|
||||
arguments=tool_input_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
tool_results.append((tool_call_id, tool_name, str(tool_response)))
|
||||
|
||||
# Add assistant message with tool_calls if there were tool calls
|
||||
if tool_calls:
|
||||
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
|
||||
|
||||
# Add tool result messages
|
||||
for tool_call_id, tool_name, result in tool_results:
|
||||
context_messages.append(
|
||||
ToolPromptMessage(
|
||||
content=result,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Add final assistant response
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
|
||||
return context_messages
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
@@ -782,20 +877,11 @@ class AgentNode(Node[AgentNodeData]):
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Save to node memory if in node memory mode
|
||||
from core.workflow.nodes.llm import llm_utils
|
||||
# Get user query from parameters for building context
|
||||
user_query = parameters_for_log.get("query", "")
|
||||
|
||||
# Get user query from sys.query
|
||||
user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
|
||||
user_query = user_query_var.text if user_query_var else ""
|
||||
|
||||
llm_utils.save_node_memory(
|
||||
memory=memory,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
user_query=user_query,
|
||||
assistant_response=text,
|
||||
assistant_files=files,
|
||||
)
|
||||
# Build context from history, user query, tool calls and assistant response
|
||||
context = self._build_context(parameters_for_log, user_query, text, agent_logs)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
@@ -805,6 +891,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
"context": context,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
|
||||
@@ -285,7 +285,7 @@ class Node(Generic[NodeDataT]):
|
||||
extractor_configs.append(node_config)
|
||||
return extractor_configs
|
||||
|
||||
def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
"""
|
||||
Execute all extractor nodes associated with this node.
|
||||
|
||||
@@ -349,7 +349,7 @@ class Node(Generic[NodeDataT]):
|
||||
self._start_at = naive_utc_now()
|
||||
|
||||
# Step 1: Execute associated extractor nodes before main node execution
|
||||
yield from self._execute_extractor_nodes()
|
||||
yield from self._execute_mention_nodes()
|
||||
|
||||
# Create and push start event with required fields
|
||||
start_event = NodeRunStartedEvent(
|
||||
|
||||
@@ -12,6 +12,13 @@ from core.memory import NodeTokenBufferMemory, TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentUnionTypes,
|
||||
PromptMessageRole,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
|
||||
@@ -139,50 +146,6 @@ def fetch_memory(
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
|
||||
def save_node_memory(
|
||||
memory: BaseMemory | None,
|
||||
variable_pool: VariablePool,
|
||||
user_query: str,
|
||||
assistant_response: str,
|
||||
user_files: Sequence["File"] | None = None,
|
||||
assistant_files: Sequence["File"] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Save dialogue turn to node memory if applicable.
|
||||
|
||||
This function handles the storage logic for NodeTokenBufferMemory.
|
||||
For TokenBufferMemory (conversation-level), no action is taken as it uses
|
||||
the Message table which is managed elsewhere.
|
||||
|
||||
:param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
|
||||
:param variable_pool: Variable pool containing system variables
|
||||
:param user_query: User's input text
|
||||
:param assistant_response: Assistant's response text
|
||||
:param user_files: Files attached by user (optional)
|
||||
:param assistant_files: Files generated by assistant (optional)
|
||||
"""
|
||||
if not isinstance(memory, NodeTokenBufferMemory):
|
||||
return
|
||||
|
||||
# Get workflow_run_id as the key for this execution
|
||||
workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
|
||||
if not isinstance(workflow_run_id_var, StringSegment):
|
||||
return
|
||||
|
||||
workflow_run_id = workflow_run_id_var.value
|
||||
if not workflow_run_id:
|
||||
return
|
||||
|
||||
memory.add_messages(
|
||||
workflow_run_id=workflow_run_id,
|
||||
user_content=user_query,
|
||||
user_files=list(user_files) if user_files else None,
|
||||
assistant_content=assistant_response,
|
||||
assistant_files=list(assistant_files) if assistant_files else None,
|
||||
)
|
||||
memory.flush()
|
||||
|
||||
|
||||
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
provider_configuration = provider_model_bundle.configuration
|
||||
@@ -246,3 +209,45 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||
)
|
||||
session.execute(stmt)
|
||||
session.commit()
|
||||
|
||||
|
||||
def build_context(
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
assistant_response: str,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Build context from prompt messages and assistant response.
|
||||
Excludes system messages and includes the current LLM response.
|
||||
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
|
||||
|
||||
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
|
||||
"""
|
||||
context_messages: list[PromptMessage] = [
|
||||
_truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
|
||||
]
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
return context_messages
|
||||
|
||||
|
||||
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Truncate multi-modal content base64 data in a message to avoid storing large data.
|
||||
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, truncating multi-modal base64 data
|
||||
new_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# Truncate base64_data similar to prompt_messages_to_prompt_for_saving
|
||||
truncated_base64 = ""
|
||||
if item.base64_data:
|
||||
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
|
||||
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
|
||||
else:
|
||||
new_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": new_content})
|
||||
|
||||
@@ -20,7 +20,6 @@ from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
TextPromptMessageContent,
|
||||
@@ -327,25 +326,13 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"context": self._build_context(prompt_messages, clean_text),
|
||||
"context": llm_utils.build_context(prompt_messages, clean_text),
|
||||
}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
if self._file_outputs:
|
||||
outputs["files"] = ArrayFileSegment(value=self._file_outputs)
|
||||
|
||||
# Write to Node Memory if in node memory mode
|
||||
# Resolve the query template to get actual user content
|
||||
actual_query = variable_pool.convert_template(query or "").text
|
||||
llm_utils.save_node_memory(
|
||||
memory=memory,
|
||||
variable_pool=variable_pool,
|
||||
user_query=actual_query,
|
||||
assistant_response=clean_text,
|
||||
user_files=files,
|
||||
assistant_files=self._file_outputs,
|
||||
)
|
||||
|
||||
# Send final chunk event to indicate streaming is complete
|
||||
yield StreamChunkEvent(
|
||||
selector=[self._node_id, "text"],
|
||||
@@ -607,48 +594,6 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Separated mode: always return clean text and reasoning_content
|
||||
return clean_text, reasoning_content or ""
|
||||
|
||||
@staticmethod
|
||||
def _build_context(
|
||||
prompt_messages: Sequence[PromptMessage],
|
||||
assistant_response: str,
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Build context from prompt messages and assistant response.
|
||||
Excludes system messages and includes the current LLM response.
|
||||
Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
|
||||
|
||||
Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
|
||||
"""
|
||||
context_messages: list[PromptMessage] = [
|
||||
LLMNode._truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
|
||||
]
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
return context_messages
|
||||
|
||||
@staticmethod
|
||||
def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Truncate multi-modal content base64 data in a message to avoid storing large data.
|
||||
Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, truncating multi-modal base64 data
|
||||
new_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# Truncate base64_data similar to prompt_messages_to_prompt_for_saving
|
||||
truncated_base64 = ""
|
||||
if item.base64_data:
|
||||
truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
|
||||
new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
|
||||
else:
|
||||
new_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": new_content})
|
||||
|
||||
def _transform_chat_messages(
|
||||
self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
|
||||
) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
|
||||
@@ -716,54 +661,158 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"""
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# Build a map from context index to its messages
|
||||
context_messages_map: dict[int, list[PromptMessage]] = {}
|
||||
# Process messages in DSL order: iterate once and handle each type directly
|
||||
combined_messages: list[PromptMessage] = []
|
||||
context_idx = 0
|
||||
for idx, type_ in template_order:
|
||||
static_idx = 0
|
||||
|
||||
for _, type_ in template_order:
|
||||
if type_ == "context":
|
||||
# Handle context reference
|
||||
ctx_ref = context_refs[context_idx]
|
||||
ctx_var = variable_pool.get(ctx_ref.value_selector)
|
||||
if ctx_var is None:
|
||||
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
|
||||
if not isinstance(ctx_var, ArrayPromptMessageSegment):
|
||||
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
|
||||
context_messages_map[idx] = list(ctx_var.value)
|
||||
combined_messages.extend(ctx_var.value)
|
||||
context_idx += 1
|
||||
|
||||
# Process static messages
|
||||
static_prompt_messages: Sequence[PromptMessage] = []
|
||||
stop: Sequence[str] | None = None
|
||||
if static_messages:
|
||||
static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# Combine messages according to original DSL order
|
||||
combined_messages: list[PromptMessage] = []
|
||||
static_msg_iter = iter(static_prompt_messages)
|
||||
for idx, type_ in template_order:
|
||||
if type_ == "context":
|
||||
combined_messages.extend(context_messages_map[idx])
|
||||
else:
|
||||
if msg := next(static_msg_iter, None):
|
||||
combined_messages.append(msg)
|
||||
# Append any remaining static messages (e.g., memory messages)
|
||||
combined_messages.extend(static_msg_iter)
|
||||
# Handle static message
|
||||
static_msg = static_messages[static_idx]
|
||||
processed_msgs = LLMNode.handle_list_messages(
|
||||
messages=[static_msg],
|
||||
context=context,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=self.node_data.vision.configs.detail,
|
||||
)
|
||||
combined_messages.extend(processed_msgs)
|
||||
static_idx += 1
|
||||
|
||||
# Append memory messages
|
||||
memory_messages = _handle_memory_chat_mode(
|
||||
memory=memory,
|
||||
memory_config=self.node_data.memory,
|
||||
model_config=model_config,
|
||||
)
|
||||
combined_messages.extend(memory_messages)
|
||||
|
||||
# Append current query if provided
|
||||
if query:
|
||||
query_message = LLMNodeChatModelMessage(
|
||||
text=query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
query_msgs = LLMNode.handle_list_messages(
|
||||
messages=[query_message],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=self.node_data.vision.configs.detail,
|
||||
)
|
||||
combined_messages.extend(query_msgs)
|
||||
|
||||
# Handle files (sys_files and context_files)
|
||||
combined_messages = self._append_files_to_messages(
|
||||
messages=combined_messages,
|
||||
sys_files=files,
|
||||
context_files=context_files,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# Filter empty messages and get stop sequences
|
||||
combined_messages = self._filter_messages(combined_messages, model_config)
|
||||
stop = self._get_stop_sequences(model_config)
|
||||
|
||||
return combined_messages, stop
|
||||
|
||||
def _append_files_to_messages(
|
||||
self,
|
||||
*,
|
||||
messages: list[PromptMessage],
|
||||
sys_files: Sequence[File],
|
||||
context_files: list[File],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> list[PromptMessage]:
|
||||
"""Append sys_files and context_files to messages."""
|
||||
vision_enabled = self.node_data.vision.enabled
|
||||
vision_detail = self.node_data.vision.configs.detail
|
||||
|
||||
# Handle sys_files (will be deprecated later)
|
||||
if vision_enabled and sys_files:
|
||||
file_prompts = [
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
|
||||
]
|
||||
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
|
||||
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
|
||||
else:
|
||||
messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Handle context_files
|
||||
if vision_enabled and context_files:
|
||||
file_prompts = [
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
for file in context_files
|
||||
]
|
||||
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
|
||||
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
|
||||
else:
|
||||
messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
return messages
|
||||
|
||||
def _filter_messages(
|
||||
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
) -> list[PromptMessage]:
|
||||
"""Filter empty messages and unsupported content types."""
|
||||
filtered_messages: list[PromptMessage] = []
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
filtered_content: list[PromptMessageContentUnionTypes] = []
|
||||
for content_item in message.content:
|
||||
# Skip non-text content if features are not defined
|
||||
if not model_config.model_schema.features:
|
||||
if content_item.type != PromptMessageContentType.TEXT:
|
||||
continue
|
||||
filtered_content.append(content_item)
|
||||
continue
|
||||
|
||||
# Skip content if corresponding feature is not supported
|
||||
feature_map = {
|
||||
PromptMessageContentType.IMAGE: ModelFeature.VISION,
|
||||
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
|
||||
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
|
||||
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
|
||||
}
|
||||
required_feature = feature_map.get(content_item.type)
|
||||
if required_feature and required_feature not in model_config.model_schema.features:
|
||||
continue
|
||||
filtered_content.append(content_item)
|
||||
|
||||
# Simplify single text content
|
||||
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
|
||||
message.content = filtered_content[0].data
|
||||
else:
|
||||
message.content = filtered_content
|
||||
|
||||
if not message.is_empty():
|
||||
filtered_messages.append(message)
|
||||
|
||||
if not filtered_messages:
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
return filtered_messages
|
||||
|
||||
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
|
||||
"""Get stop sequences from model config."""
|
||||
return model_config.stop
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
|
||||
@@ -246,13 +246,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
# transform result into standard format
|
||||
result = self._transform_result(data=node_data, result=result or {})
|
||||
|
||||
# Save to node memory if in node memory mode
|
||||
llm_utils.save_node_memory(
|
||||
memory=memory,
|
||||
variable_pool=variable_pool,
|
||||
user_query=query,
|
||||
assistant_response=json.dumps(result, ensure_ascii=False),
|
||||
)
|
||||
# Build context from prompt messages and response
|
||||
assistant_response = json.dumps(result, ensure_ascii=False)
|
||||
context = llm_utils.build_context(prompt_messages, assistant_response)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
@@ -262,6 +258,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
"__is_success": 1 if not error else 0,
|
||||
"__reason": error,
|
||||
"__usage": jsonable_encoder(usage),
|
||||
"context": context,
|
||||
**result,
|
||||
},
|
||||
metadata={
|
||||
|
||||
@@ -199,20 +199,17 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
# Build context from prompt messages and response
|
||||
assistant_response = f"class_name: {category_name}, class_id: {category_id}"
|
||||
context = llm_utils.build_context(prompt_messages, assistant_response)
|
||||
|
||||
outputs = {
|
||||
"class_name": category_name,
|
||||
"class_id": category_id,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"context": context,
|
||||
}
|
||||
|
||||
# Save to node memory if in node memory mode
|
||||
llm_utils.save_node_memory(
|
||||
memory=memory,
|
||||
variable_pool=variable_pool,
|
||||
user_query=query or "",
|
||||
assistant_response=f"class_name: {category_name}, class_id: {category_id}",
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=variables,
|
||||
|
||||