diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index ef9e9c103a..1d439323f2 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -3,8 +3,8 @@ from __future__ import annotations import base64 import json import logging -from collections.abc import Generator -from typing import Any +from collections.abc import Generator, Mapping +from typing import Any, cast from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError @@ -17,6 +17,7 @@ from core.mcp.types import ( TextContent, TextResourceContents, ) +from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType @@ -46,6 +47,7 @@ class MCPTool(Tool): self.headers = headers or {} self.timeout = timeout self.sse_read_timeout = sse_read_timeout + self._latest_usage = LLMUsage.empty_usage() def tool_provider_type(self) -> ToolProviderType: return ToolProviderType.MCP @@ -59,6 +61,10 @@ class MCPTool(Tool): message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: result = self.invoke_remote_mcp_tool(tool_parameters) + + # Extract usage metadata from MCP protocol's _meta field + self._latest_usage = self._derive_usage_from_result(result) + # handle dify tool output for content in result.content: if isinstance(content, TextContent): @@ -120,6 +126,99 @@ class MCPTool(Tool): for item in json_list: yield self.create_json_message(item) + @property + def latest_usage(self) -> LLMUsage: + return self._latest_usage + + @classmethod + def _derive_usage_from_result(cls, result: CallToolResult) -> LLMUsage: + """ + Extract usage metadata from MCP tool result's _meta field. + + The MCP protocol's _meta field (aliased as 'meta' in Python) can contain + usage information such as token counts, costs, and other metadata. + + Args: + result: The CallToolResult from MCP tool invocation + + Returns: + LLMUsage instance with values from meta or empty_usage if not found + """ + # Extract usage from the meta field if present + if result.meta: + usage_dict = cls._extract_usage_dict(result.meta) + if usage_dict is not None: + return LLMUsage.from_metadata(cast(LLMUsageMetadata, cast(object, dict(usage_dict)))) + + return LLMUsage.empty_usage() + + @classmethod + def _extract_usage_dict(cls, payload: Mapping[str, Any]) -> Mapping[str, Any] | None: + """ + Recursively search for usage dictionary in the payload. + + The MCP protocol's _meta field can contain usage data in various formats: + - Direct usage field: {"usage": {...}} + - Nested in metadata: {"metadata": {"usage": {...}}} + - Or nested within other fields + + Args: + payload: The payload to search for usage data + + Returns: + The usage dictionary if found, None otherwise + """ + # Check for direct usage field + usage_candidate = payload.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + # Check for metadata nested usage + metadata_candidate = payload.get("metadata") + if isinstance(metadata_candidate, Mapping): + usage_candidate = metadata_candidate.get("usage") + if isinstance(usage_candidate, Mapping): + return usage_candidate + + # Check for common token counting fields directly in payload + # Some MCP servers may include token counts directly + if "total_tokens" in payload or "prompt_tokens" in payload or "completion_tokens" in payload: + usage_dict: dict[str, Any] = {} + for key in ( + "prompt_tokens", + "completion_tokens", + "total_tokens", + "prompt_unit_price", + "completion_unit_price", + "total_price", + "currency", + "prompt_price_unit", + "completion_price_unit", + "prompt_price", + "completion_price", + "latency", + "time_to_first_token", + "time_to_generate", + ): + if key in payload: + usage_dict[key] = payload[key] + if usage_dict: + return usage_dict + + # Recursively search through nested structures + for value in payload.values(): + if isinstance(value, Mapping): + found = cls._extract_usage_dict(value) + if found is not None: + return found + elif isinstance(value, list) and not isinstance(value, (str, bytes, bytearray)): + for item in value: + if isinstance(item, Mapping): + found = cls._extract_usage_dict(item) + if found is not None: + return found + return None + def fork_tool_runtime(self, runtime: ToolRuntime) -> MCPTool: return MCPTool( entity=self.entity, diff --git a/api/pyproject.toml b/api/pyproject.toml index 3d61c652f5..f6f47403df 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -85,7 +85,7 @@ dependencies = [ "starlette==0.49.1", "tiktoken~=0.9.0", "transformers~=4.56.1", - "unstructured[docx,epub,md,ppt,pptx]~=0.16.1", + "unstructured[docx,epub,md,ppt,pptx]~=0.18.18", "yarl~=1.18.3", "webvtt-py~=0.5.1", "sseclient-py~=1.8.0", diff --git a/api/tasks/document_indexing_update_task.py b/api/tasks/document_indexing_update_task.py index 45d58c92ec..c7508c6d05 100644 --- a/api/tasks/document_indexing_update_task.py +++ b/api/tasks/document_indexing_update_task.py @@ -36,25 +36,19 @@ def document_indexing_update_task(dataset_id: str, document_id: str): document.indexing_status = "parsing" document.processing_started_at = naive_utc_now() - # delete all document segment and index - try: - dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() - if not dataset: - raise Exception("Dataset not found") + dataset = session.query(Dataset).where(Dataset.id == dataset_id).first() + if not dataset: + return - index_type = document.doc_form - index_processor = IndexProcessorFactory(index_type).init_index_processor() - - segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() - if segments: - index_node_ids = [segment.index_node_id for segment in segments] - - # delete from vector index - index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) - segment_ids = [segment.id for segment in segments] - segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids)) - session.execute(segment_delete_stmt) + index_type = document.doc_form + segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all() + index_node_ids = [segment.index_node_id for segment in segments] + clean_success = False + try: + index_processor = IndexProcessorFactory(index_type).init_index_processor() + if index_node_ids: + index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True) end_at = time.perf_counter() logger.info( click.style( @@ -64,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str): fg="green", ) ) - except Exception: - logger.exception("Cleaned document when document update data source or process rule failed") + clean_success = True + except Exception: + logger.exception("Failed to clean document index during update, document_id: %s", document_id) - try: - indexing_runner = IndexingRunner() - indexing_runner.run([document]) - end_at = time.perf_counter() - logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) - except DocumentIsPausedError as ex: - logger.info(click.style(str(ex), fg="yellow")) - except Exception: - logger.exception("document_indexing_update_task failed, document_id: %s", document_id) + if clean_success: + with session_factory.create_session() as session, session.begin(): + segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id) + session.execute(segment_delete_stmt) + + try: + indexing_runner = IndexingRunner() + indexing_runner.run([document]) + end_at = time.perf_counter() + logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green")) + except DocumentIsPausedError as ex: + logger.info(click.style(str(ex), fg="yellow")) + except Exception: + logger.exception("document_indexing_update_task failed, document_id: %s", document_id) diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index a527773e4e..5930b63f58 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -1,4 +1,5 @@ import base64 +from decimal import Decimal from unittest.mock import Mock, patch import pytest @@ -9,8 +10,10 @@ from core.mcp.types import ( CallToolResult, EmbeddedResource, ImageContent, + TextContent, TextResourceContents, ) +from core.model_runtime.entities.llm_entities import LLMUsage from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage @@ -120,3 +123,231 @@ class TestMCPToolInvoke: # Validate values values = {m.message.variable_name: m.message.variable_value for m in var_msgs} assert values == {"a": 1, "b": "x"} + + +class TestMCPToolUsageExtraction: + """Test usage metadata extraction from MCP tool results.""" + + def test_extract_usage_dict_from_direct_usage_field(self) -> None: + """Test extraction when usage is directly in meta.usage field.""" + meta = { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "total_price": "0.001", + "currency": "USD", + } + } + usage_dict = MCPTool._extract_usage_dict(meta) + assert usage_dict is not None + assert usage_dict["prompt_tokens"] == 100 + assert usage_dict["completion_tokens"] == 50 + assert usage_dict["total_tokens"] == 150 + assert usage_dict["total_price"] == "0.001" + assert usage_dict["currency"] == "USD" + + def test_extract_usage_dict_from_nested_metadata(self) -> None: + """Test extraction when usage is nested in meta.metadata.usage.""" + meta = { + "metadata": { + "usage": { + "prompt_tokens": 200, + "completion_tokens": 100, + "total_tokens": 300, + } + } + } + usage_dict = MCPTool._extract_usage_dict(meta) + assert usage_dict is not None + assert usage_dict["prompt_tokens"] == 200 + assert usage_dict["total_tokens"] == 300 + + def test_extract_usage_dict_from_flat_token_fields(self) -> None: + """Test extraction when token counts are directly in meta.""" + meta = { + "prompt_tokens": 150, + "completion_tokens": 75, + "total_tokens": 225, + "currency": "EUR", + } + usage_dict = MCPTool._extract_usage_dict(meta) + assert usage_dict is not None + assert usage_dict["prompt_tokens"] == 150 + assert usage_dict["completion_tokens"] == 75 + assert usage_dict["total_tokens"] == 225 + assert usage_dict["currency"] == "EUR" + + def test_extract_usage_dict_recursive(self) -> None: + """Test recursive search through nested structures.""" + meta = { + "custom": { + "nested": { + "usage": { + "total_tokens": 500, + "prompt_tokens": 300, + "completion_tokens": 200, + } + } + } + } + usage_dict = MCPTool._extract_usage_dict(meta) + assert usage_dict is not None + assert usage_dict["total_tokens"] == 500 + + def test_extract_usage_dict_from_list(self) -> None: + """Test extraction from nested list structures.""" + meta = { + "items": [ + {"usage": {"total_tokens": 100}}, + {"other": "data"}, + ] + } + usage_dict = MCPTool._extract_usage_dict(meta) + assert usage_dict is not None + assert usage_dict["total_tokens"] == 100 + + def test_extract_usage_dict_returns_none_when_missing(self) -> None: + """Test that None is returned when no usage data is present.""" + meta = {"other": "data", "custom": {"nested": {"value": 123}}} + usage_dict = MCPTool._extract_usage_dict(meta) + assert usage_dict is None + + def test_extract_usage_dict_empty_meta(self) -> None: + """Test with empty meta dict.""" + usage_dict = MCPTool._extract_usage_dict({}) + assert usage_dict is None + + def test_derive_usage_from_result_with_meta(self) -> None: + """Test _derive_usage_from_result with populated meta.""" + meta = { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + "total_tokens": 150, + "total_price": "0.0015", + "currency": "USD", + } + } + result = CallToolResult(content=[], _meta=meta) + usage = MCPTool._derive_usage_from_result(result) + + assert isinstance(usage, LLMUsage) + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + assert usage.total_tokens == 150 + assert usage.total_price == Decimal("0.0015") + assert usage.currency == "USD" + + def test_derive_usage_from_result_without_meta(self) -> None: + """Test _derive_usage_from_result with no meta returns empty usage.""" + result = CallToolResult(content=[], meta=None) + usage = MCPTool._derive_usage_from_result(result) + + assert isinstance(usage, LLMUsage) + assert usage.total_tokens == 0 + assert usage.prompt_tokens == 0 + assert usage.completion_tokens == 0 + + def test_derive_usage_from_result_calculates_total_tokens(self) -> None: + """Test that total_tokens is calculated when missing.""" + meta = { + "usage": { + "prompt_tokens": 100, + "completion_tokens": 50, + # total_tokens is missing + } + } + result = CallToolResult(content=[], _meta=meta) + usage = MCPTool._derive_usage_from_result(result) + + assert usage.total_tokens == 150 # 100 + 50 + assert usage.prompt_tokens == 100 + assert usage.completion_tokens == 50 + + def test_invoke_sets_latest_usage_from_meta(self) -> None: + """Test that _invoke sets _latest_usage from result meta.""" + tool = _make_mcp_tool() + meta = { + "usage": { + "prompt_tokens": 200, + "completion_tokens": 100, + "total_tokens": 300, + "total_price": "0.003", + "currency": "USD", + } + } + result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=meta) + + with patch.object(tool, "invoke_remote_mcp_tool", return_value=result): + list(tool._invoke(user_id="test_user", tool_parameters={})) + + # Verify latest_usage was set correctly + assert tool.latest_usage.prompt_tokens == 200 + assert tool.latest_usage.completion_tokens == 100 + assert tool.latest_usage.total_tokens == 300 + assert tool.latest_usage.total_price == Decimal("0.003") + + def test_invoke_with_no_meta_returns_empty_usage(self) -> None: + """Test that _invoke returns empty usage when no meta is present.""" + tool = _make_mcp_tool() + result = CallToolResult(content=[TextContent(type="text", text="test")], _meta=None) + + with patch.object(tool, "invoke_remote_mcp_tool", return_value=result): + list(tool._invoke(user_id="test_user", tool_parameters={})) + + # Verify latest_usage is empty + assert tool.latest_usage.total_tokens == 0 + assert tool.latest_usage.prompt_tokens == 0 + assert tool.latest_usage.completion_tokens == 0 + + def test_latest_usage_property_returns_llm_usage(self) -> None: + """Test that latest_usage property returns LLMUsage instance.""" + tool = _make_mcp_tool() + assert isinstance(tool.latest_usage, LLMUsage) + + def test_initial_usage_is_empty(self) -> None: + """Test that MCPTool is initialized with empty usage.""" + tool = _make_mcp_tool() + assert tool.latest_usage.total_tokens == 0 + assert tool.latest_usage.prompt_tokens == 0 + assert tool.latest_usage.completion_tokens == 0 + assert tool.latest_usage.total_price == Decimal(0) + + @pytest.mark.parametrize( + "meta_data", + [ + # Direct usage field + {"usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}}, + # Nested metadata + {"metadata": {"usage": {"total_tokens": 100}}}, + # Flat token fields + {"total_tokens": 50, "prompt_tokens": 30, "completion_tokens": 20}, + # With price info + { + "usage": { + "total_tokens": 150, + "total_price": "0.002", + "currency": "EUR", + } + }, + # Deep nested + {"level1": {"level2": {"usage": {"total_tokens": 200}}}}, + ], + ) + def test_various_meta_formats(self, meta_data) -> None: + """Test that various meta formats are correctly parsed.""" + result = CallToolResult(content=[], _meta=meta_data) + usage = MCPTool._derive_usage_from_result(result) + + assert isinstance(usage, LLMUsage) + # Should have at least some usage data + if meta_data.get("usage", {}).get("total_tokens") or meta_data.get("total_tokens"): + expected_total = ( + meta_data.get("usage", {}).get("total_tokens") + or meta_data.get("total_tokens") + or meta_data.get("metadata", {}).get("usage", {}).get("total_tokens") + or meta_data.get("level1", {}).get("level2", {}).get("usage", {}).get("total_tokens") + ) + if expected_total: + assert usage.total_tokens == expected_total diff --git a/api/uv.lock b/api/uv.lock index 2fd670fdff..0a746e7df1 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -1783,7 +1783,7 @@ requires-dist = [ { name = "starlette", specifier = "==0.49.1" }, { name = "tiktoken", specifier = "~=0.9.0" }, { name = "transformers", specifier = "~=4.56.1" }, - { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.16.1" }, + { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.18.18" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.17.0" }, { name = "webvtt-py", specifier = "~=0.5.1" }, @@ -7085,12 +7085,12 @@ wheels = [ [[package]] name = "unstructured" -version = "0.16.25" +version = "0.18.31" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "backoff" }, { name = "beautifulsoup4" }, - { name = "chardet" }, + { name = "charset-normalizer" }, { name = "dataclasses-json" }, { name = "emoji" }, { name = "filetype" }, @@ -7098,6 +7098,7 @@ dependencies = [ { name = "langdetect" }, { name = "lxml" }, { name = "nltk" }, + { name = "numba" }, { name = "numpy" }, { name = "psutil" }, { name = "python-iso639" }, @@ -7110,9 +7111,9 @@ dependencies = [ { name = "unstructured-client" }, { name = "wrapt" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/64/31/98c4c78e305d1294888adf87fd5ee30577a4c393951341ca32b43f167f1e/unstructured-0.16.25.tar.gz", hash = "sha256:73b9b0f51dbb687af572ecdb849a6811710b9cac797ddeab8ee80fa07d8aa5e6", size = 1683097, upload-time = "2025-03-07T11:19:39.507Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a9/5f/64285bd69a538bc28753f1423fcaa9d64cd79a9e7c097171b1f0d27e9cdb/unstructured-0.18.31.tar.gz", hash = "sha256:af4bbe32d1894ae6e755f0da6fc0dd307a1d0adeebe0e7cc6278f6cf744339ca", size = 1707700, upload-time = "2026-01-27T15:33:05.378Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/12/4f/ad08585b5c8a33c82ea119494c4d3023f4796958c56e668b15cc282ec0a0/unstructured-0.16.25-py3-none-any.whl", hash = "sha256:14719ccef2830216cf1c5bf654f75e2bf07b17ca5dcee9da5ac74618130fd337", size = 1769286, upload-time = "2025-03-07T11:19:37.299Z" }, + { url = "https://files.pythonhosted.org/packages/c8/4a/9c43f39d9e443c9bc3f2e379b305bca27110adc653b071221b3132c18de5/unstructured-0.18.31-py3-none-any.whl", hash = "sha256:fab4641176cb9b192ed38048758aa0d9843121d03626d18f42275afb31e5b2d3", size = 1794889, upload-time = "2026-01-27T15:33:03.136Z" }, ] [package.optional-dependencies] diff --git a/web/app/components/header/account-setting/model-provider-page/hooks.ts b/web/app/components/header/account-setting/model-provider-page/hooks.ts index 6aba41d4e4..8db964cc27 100644 --- a/web/app/components/header/account-setting/model-provider-page/hooks.ts +++ b/web/app/components/header/account-setting/model-provider-page/hooks.ts @@ -308,7 +308,7 @@ export const useMarketplaceAllPlugins = (providers: ModelProvider[], searchText: }, [plugins, collectionPlugins, exclude]) return { - plugins: allPlugins, + plugins: searchText ? plugins : allPlugins, isLoading: isCollectionLoading || isPluginsLoading, } } diff --git a/web/app/components/workflow/skill/editor/skill-editor/plugins/file-picker-panel.tsx b/web/app/components/workflow/skill/editor/skill-editor/plugins/file-picker-panel.tsx index 39692fff98..c57dbb02bf 100644 --- a/web/app/components/workflow/skill/editor/skill-editor/plugins/file-picker-panel.tsx +++ b/web/app/components/workflow/skill/editor/skill-editor/plugins/file-picker-panel.tsx @@ -1,7 +1,7 @@ import type { NodeRendererProps } from 'react-arborist' import type { FileAppearanceType } from '@/app/components/base/file-uploader/types' import type { TreeNodeData } from '@/app/components/workflow/skill/type' -import { RiArrowDownSLine, RiArrowRightSLine, RiFolderLine, RiFolderOpenLine, RiQuestionLine } from '@remixicon/react' +import { RiArrowDownSLine, RiArrowRightSLine, RiFolderLine, RiFolderOpenLine } from '@remixicon/react' import { useSize } from 'ahooks' import * as React from 'react' import { useCallback, useMemo, useRef } from 'react' @@ -52,7 +52,7 @@ const FilePickerTreeNode = ({ node, style, dragHandle, onSelectNode }: FilePicke aria-selected={isSelected} aria-expanded={isFolder ? node.isOpen : undefined} className={cn( - 'group relative flex h-6 cursor-pointer items-center gap-px overflow-hidden rounded-md', + 'group relative flex h-6 cursor-pointer items-center gap-0 overflow-hidden rounded-md', 'hover:bg-state-base-hover', 'focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-inset focus-visible:ring-components-input-border-active', isSelected && 'bg-state-base-active', @@ -82,6 +82,12 @@ const FilePickerTreeNode = ({ node, style, dragHandle, onSelectNode }: FilePicke {node.data.name} + {isFolder && ( +