mirror of
https://github.com/langgenius/dify.git
synced 2026-02-09 15:10:13 -05:00
feat: extract mcp tool usage (#31802)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from core.mcp.auth_client import MCPClientWithAuthRetry
|
||||
from core.mcp.error import MCPConnectionError
|
||||
@@ -17,6 +17,7 @@ from core.mcp.types import (
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType
|
||||
@@ -46,6 +47,7 @@ class MCPTool(Tool):
|
||||
self.headers = headers or {}
|
||||
self.timeout = timeout
|
||||
self.sse_read_timeout = sse_read_timeout
|
||||
self._latest_usage = LLMUsage.empty_usage()
|
||||
|
||||
def tool_provider_type(self) -> ToolProviderType:
|
||||
return ToolProviderType.MCP
|
||||
@@ -59,6 +61,10 @@ class MCPTool(Tool):
|
||||
message_id: str | None = None,
|
||||
) -> Generator[ToolInvokeMessage, None, None]:
|
||||
result = self.invoke_remote_mcp_tool(tool_parameters)
|
||||
|
||||
# Extract usage metadata from MCP protocol's _meta field
|
||||
self._latest_usage = self._derive_usage_from_result(result)
|
||||
|
||||
# handle dify tool output
|
||||
for content in result.content:
|
||||
if isinstance(content, TextContent):
|
||||
@@ -120,6 +126,99 @@ class MCPTool(Tool):
|
||||
for item in json_list:
|
||||
yield self.create_json_message(item)
|
||||
|
||||
@property
|
||||
def latest_usage(self) -> LLMUsage:
|
||||
return self._latest_usage
|
||||
|
||||
@classmethod
|
||||
def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage:
|
||||
"""
|
||||
Extract usage metadata from MCP tool result's _meta field.
|
||||
|
||||
The MCP protocol's _meta field (aliased as 'meta' in Python) can contain
|
||||
usage information such as token counts, costs, and other metadata.
|
||||
|
||||
Args:
|
||||
result: The CallToolResult from MCP tool invocation
|
||||
|
||||
Returns:
|
||||
LLMUsage instance with values from meta or empty_usage if not found
|
||||
"""
|
||||
# Extract usage from the meta field if present
|
||||
if result.meta:
|
||||
usage_dict = cls._extract_usage_dict(result.meta)
|
||||
if usage_dict is not None:
|
||||
return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict))))
|
||||
|
||||
return LLMUsage.empty_usage()
|
||||
|
||||
@classmethod
|
||||
def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None:
|
||||
"""
|
||||
Recursively search for usage dictionary in the payload.
|
||||
|
||||
The MCP protocol's _meta field can contain usage data in various formats:
|
||||
- Direct usage field: {"usage": {...}}
|
||||
- Nested in metadata: {"metadata": {"usage": {...}}}
|
||||
- Or nested within other fields
|
||||
|
||||
Args:
|
||||
payload: The payload to search for usage data
|
||||
|
||||
Returns:
|
||||
The usage dictionary if found, None otherwise
|
||||
"""
|
||||
# Check for direct usage field
|
||||
usage_candidate = payload.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
# Check for metadata nested usage
|
||||
metadata_candidate = payload.get("metadata")
|
||||
if isinstance(metadata_candidate, Mapping):
|
||||
usage_candidate = metadata_candidate.get("usage")
|
||||
if isinstance(usage_candidate, Mapping):
|
||||
return usage_candidate
|
||||
|
||||
# Check for common token counting fields directly in payload
|
||||
# Some MCP servers may include token counts directly
|
||||
if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload:
|
||||
usage_dict: dict[str, Any] = {}
|
||||
for key in (
|
||||
"prompt_tokens",
|
||||
"completion_tokens",
|
||||
"total_tokens",
|
||||
"prompt_unit_price",
|
||||
"completion_unit_price",
|
||||
"total_price",
|
||||
"currency",
|
||||
"prompt_price_unit",
|
||||
"completion_price_unit",
|
||||
"prompt_price",
|
||||
"completion_price",
|
||||
"latency",
|
||||
"time_to_first_token",
|
||||
"time_to_generate",
|
||||
):
|
||||
if key in payload:
|
||||
usage_dict[key] = payload[key]
|
||||
if usage_dict:
|
||||
return usage_dict
|
||||
|
||||
# Recursively search through nested structures
|
||||
for value in payload.values():
|
||||
if isinstance(value, Mapping):
|
||||
found = cls._extract_usage_dict(value)
|
||||
if found is not None:
|
||||
return found
|
||||
elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)):
|
||||
for item in value:
|
||||
if isinstance(item, Mapping):
|
||||
found = cls._extract_usage_dict(item)
|
||||
if found is not None:
|
||||
return found
|
||||
return None
|
||||
|
||||
def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool:
|
||||
return MCPTool(
|
||||
entity=self.entity,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import base64
|
||||
from decimal import Decimal
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -9,8 +10,10 @@ from core.mcp.types import (
|
||||
CallToolResult,
|
||||
EmbeddedResource,
|
||||
ImageContent,
|
||||
TextContent,
|
||||
TextResourceContents,
|
||||
)
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage
|
||||
@@ -120,3 +123,231 @@ class TestMCPToolInvoke:
|
||||
# Validate values
|
||||
values = {m.message.variable_name: m.message.variable_value for m in var_msgs}
|
||||
assert values == {"a": 1, "b": "x"}
|
||||
|
||||
|
||||
class TestMCPToolUsageExtraction:
|
||||
"""Test usage metadata extraction from MCP tool results."""
|
||||
|
||||
def test_extract_usage_dict_from_direct_usage_field(self) -> None:
|
||||
"""Test extraction when usage is directly in meta.usage field."""
|
||||
meta = {
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
"total_price": "0.001",
|
||||
"currency": "USD",
|
||||
}
|
||||
}
|
||||
usage_dict = MCPTool._extract_usage_dict(meta)
|
||||
assert usage_dict is not None
|
||||
assert usage_dict["prompt_tokens"] == 100
|
||||
assert usage_dict["completion_tokens"] == 50
|
||||
assert usage_dict["total_tokens"] == 150
|
||||
assert usage_dict["total_price"] == "0.001"
|
||||
assert usage_dict["currency"] == "USD"
|
||||
|
||||
def test_extract_usage_dict_from_nested_metadata(self) -> None:
|
||||
"""Test extraction when usage is nested in meta.metadata.usage."""
|
||||
meta = {
|
||||
"metadata": {
|
||||
"usage": {
|
||||
"prompt_tokens": 200,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
}
|
||||
}
|
||||
}
|
||||
usage_dict = MCPTool._extract_usage_dict(meta)
|
||||
assert usage_dict is not None
|
||||
assert usage_dict["prompt_tokens"] == 200
|
||||
assert usage_dict["total_tokens"] == 300
|
||||
|
||||
def test_extract_usage_dict_from_flat_token_fields(self) -> None:
|
||||
"""Test extraction when token counts are directly in meta."""
|
||||
meta = {
|
||||
"prompt_tokens": 150,
|
||||
"completion_tokens": 75,
|
||||
"total_tokens": 225,
|
||||
"currency": "EUR",
|
||||
}
|
||||
usage_dict = MCPTool._extract_usage_dict(meta)
|
||||
assert usage_dict is not None
|
||||
assert usage_dict["prompt_tokens"] == 150
|
||||
assert usage_dict["completion_tokens"] == 75
|
||||
assert usage_dict["total_tokens"] == 225
|
||||
assert usage_dict["currency"] == "EUR"
|
||||
|
||||
def test_extract_usage_dict_recursive(self) -> None:
|
||||
"""Test recursive search through nested structures."""
|
||||
meta = {
|
||||
"custom": {
|
||||
"nested": {
|
||||
"usage": {
|
||||
"total_tokens": 500,
|
||||
"prompt_tokens": 300,
|
||||
"completion_tokens": 200,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
usage_dict = MCPTool._extract_usage_dict(meta)
|
||||
assert usage_dict is not None
|
||||
assert usage_dict["total_tokens"] == 500
|
||||
|
||||
def test_extract_usage_dict_from_list(self) -> None:
|
||||
"""Test extraction from nested list structures."""
|
||||
meta = {
|
||||
"items": [
|
||||
{"usage": {"total_tokens": 100}},
|
||||
{"other": "data"},
|
||||
]
|
||||
}
|
||||
usage_dict = MCPTool._extract_usage_dict(meta)
|
||||
assert usage_dict is not None
|
||||
assert usage_dict["total_tokens"] == 100
|
||||
|
||||
def test_extract_usage_dict_returns_none_when_missing(self) -> None:
|
||||
"""Test that None is returned when no usage data is present."""
|
||||
meta = {"other": "data", "custom": {"nested": {"value": 123}}}
|
||||
usage_dict = MCPTool._extract_usage_dict(meta)
|
||||
assert usage_dict is None
|
||||
|
||||
def test_extract_usage_dict_empty_meta(self) -> None:
|
||||
"""Test with empty meta dict."""
|
||||
usage_dict = MCPTool._extract_usage_dict({})
|
||||
assert usage_dict is None
|
||||
|
||||
def test_derive_usage_from_result_with_meta(self) -> None:
|
||||
"""Test _derive_usage_from_result with populated meta."""
|
||||
meta = {
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
"total_price": "0.0015",
|
||||
"currency": "USD",
|
||||
}
|
||||
}
|
||||
result = CallToolResult(content=[], _meta=meta)
|
||||
usage = MCPTool._derive_usage_from_result(result)
|
||||
|
||||
assert isinstance(usage, LLMUsage)
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
assert usage.total_tokens == 150
|
||||
assert usage.total_price == Decimal("0.0015")
|
||||
assert usage.currency == "USD"
|
||||
|
||||
def test_derive_usage_from_result_without_meta(self) -> None:
|
||||
"""Test _derive_usage_from_result with no meta returns empty usage."""
|
||||
result = CallToolResult(content=[], meta=None)
|
||||
usage = MCPTool._derive_usage_from_result(result)
|
||||
|
||||
assert isinstance(usage, LLMUsage)
|
||||
assert usage.total_tokens == 0
|
||||
assert usage.prompt_tokens == 0
|
||||
assert usage.completion_tokens == 0
|
||||
|
||||
def test_derive_usage_from_result_calculates_total_tokens(self) -> None:
|
||||
"""Test that total_tokens is calculated when missing."""
|
||||
meta = {
|
||||
"usage": {
|
||||
"prompt_tokens": 100,
|
||||
"completion_tokens": 50,
|
||||
# total_tokens is missing
|
||||
}
|
||||
}
|
||||
result = CallToolResult(content=[], _meta=meta)
|
||||
usage = MCPTool._derive_usage_from_result(result)
|
||||
|
||||
assert usage.total_tokens == 150 # 100 + 50
|
||||
assert usage.prompt_tokens == 100
|
||||
assert usage.completion_tokens == 50
|
||||
|
||||
def test_invoke_sets_latest_usage_from_meta(self) -> None:
|
||||
"""Test that _invoke sets _latest_usage from result meta."""
|
||||
tool = _make_mcp_tool()
|
||||
meta = {
|
||||
"usage": {
|
||||
"prompt_tokens": 200,
|
||||
"completion_tokens": 100,
|
||||
"total_tokens": 300,
|
||||
"total_price": "0.003",
|
||||
"currency": "USD",
|
||||
}
|
||||
}
|
||||
result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=meta)
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
# Verify latest_usage was set correctly
|
||||
assert tool.latest_usage.prompt_tokens == 200
|
||||
assert tool.latest_usage.completion_tokens == 100
|
||||
assert tool.latest_usage.total_tokens == 300
|
||||
assert tool.latest_usage.total_price == Decimal("0.003")
|
||||
|
||||
def test_invoke_with_no_meta_returns_empty_usage(self) -> None:
|
||||
"""Test that _invoke returns empty usage when no meta is present."""
|
||||
tool = _make_mcp_tool()
|
||||
result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=None)
|
||||
|
||||
with patch.object(tool, "invoke_remote_mcp_tool", return_value=result):
|
||||
list(tool._invoke(user_id="test_user", tool_parameters={}))
|
||||
|
||||
# Verify latest_usage is empty
|
||||
assert tool.latest_usage.total_tokens == 0
|
||||
assert tool.latest_usage.prompt_tokens == 0
|
||||
assert tool.latest_usage.completion_tokens == 0
|
||||
|
||||
def test_latest_usage_property_returns_llm_usage(self) -> None:
|
||||
"""Test that latest_usage property returns LLMUsage instance."""
|
||||
tool = _make_mcp_tool()
|
||||
assert isinstance(tool.latest_usage, LLMUsage)
|
||||
|
||||
def test_initial_usage_is_empty(self) -> None:
|
||||
"""Test that MCPTool is initialized with empty usage."""
|
||||
tool = _make_mcp_tool()
|
||||
assert tool.latest_usage.total_tokens == 0
|
||||
assert tool.latest_usage.prompt_tokens == 0
|
||||
assert tool.latest_usage.completion_tokens == 0
|
||||
assert tool.latest_usage.total_price == Decimal(0)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"meta_data",
|
||||
[
|
||||
# Direct usage field
|
||||
{"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}},
|
||||
# Nested metadata
|
||||
{"metadata": {"usage": {"total_tokens": 100}}},
|
||||
# Flat token fields
|
||||
{"total_tokens": 50, "prompt_tokens": 30, "completion_tokens": 20},
|
||||
# With price info
|
||||
{
|
||||
"usage": {
|
||||
"total_tokens": 150,
|
||||
"total_price": "0.002",
|
||||
"currency": "EUR",
|
||||
}
|
||||
},
|
||||
# Deep nested
|
||||
{"level1": {"level2": {"usage": {"total_tokens": 200}}}},
|
||||
],
|
||||
)
|
||||
def test_various_meta_formats(self, meta_data) -> None:
|
||||
"""Test that various meta formats are correctly parsed."""
|
||||
result = CallToolResult(content=[], _meta=meta_data)
|
||||
usage = MCPTool._derive_usage_from_result(result)
|
||||
|
||||
assert isinstance(usage, LLMUsage)
|
||||
# Should have at least some usage data
|
||||
if meta_data.get("usage", {}).get("total_tokens") or meta_data.get("total_tokens"):
|
||||
expected_total = (
|
||||
meta_data.get("usage", {}).get("total_tokens")
|
||||
or meta_data.get("total_tokens")
|
||||
or meta_data.get("metadata", {}).get("usage", {}).get("total_tokens")
|
||||
or meta_data.get("level1", {}).get("level2", {}).get("usage", {}).get("total_tokens")
|
||||
)
|
||||
if expected_total:
|
||||
assert usage.total_tokens == expected_total
|
||||
|
||||
Reference in New Issue
Block a user