diff --git a/api/extensions/otel/decorators/base.py b/api/extensions/otel/decorators/base.py index 14221d24dd..a7bb8d051b 100644 --- a/api/extensions/otel/decorators/base.py +++ b/api/extensions/otel/decorators/base.py @@ -1,6 +1,6 @@ import functools from collections.abc import Callable -from typing import Any, TypeVar, cast +from typing import ParamSpec, TypeVar, cast from opentelemetry.trace import get_tracer @@ -8,7 +8,8 @@ from configs import dify_config from extensions.otel.decorators.handler import SpanHandler from extensions.otel.runtime import is_instrument_flag_enabled -T = TypeVar("T", bound=Callable[..., Any]) +P = ParamSpec("P") +R = TypeVar("R") _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} @@ -20,7 +21,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler: return _HANDLER_INSTANCES[handler_class] -def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], T]: +def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]: """ Decorator that traces a function with an OpenTelemetry span. @@ -30,9 +31,9 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], :param handler_class: Optional handler class to use for this span. If None, uses the default SpanHandler. """ - def decorator(func: T) -> T: + def decorator(func: Callable[P, R]) -> Callable[P, R]: @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): return func(*args, **kwargs) @@ -46,6 +47,6 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], kwargs=kwargs, ) - return cast(T, wrapper) + return cast(Callable[P, R], wrapper) return decorator diff --git a/api/extensions/otel/decorators/handler.py b/api/extensions/otel/decorators/handler.py index 1a7def5b0b..6915b63dce 100644 --- a/api/extensions/otel/decorators/handler.py +++ b/api/extensions/otel/decorators/handler.py @@ -1,9 +1,11 @@ import inspect from collections.abc import Callable, Mapping -from typing import Any +from typing import Any, TypeVar from opentelemetry.trace import SpanKind, Status, StatusCode +R = TypeVar("R") + class SpanHandler: """ @@ -31,9 +33,9 @@ class SpanHandler: def _extract_arguments( self, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], + wrapped: Callable[..., R], + args: tuple[object, ...], + kwargs: Mapping[str, object], ) -> dict[str, Any] | None: """ Extract function arguments using inspect.signature. @@ -62,10 +64,10 @@ class SpanHandler: def wrapper( self, tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + wrapped: Callable[..., R], + args: tuple[object, ...], + kwargs: Mapping[str, object], + ) -> R: """ Fully control the wrapper behavior. diff --git a/api/extensions/otel/decorators/handlers/generate_handler.py b/api/extensions/otel/decorators/handlers/generate_handler.py index 63748a9824..b37aca664a 100644 --- a/api/extensions/otel/decorators/handlers/generate_handler.py +++ b/api/extensions/otel/decorators/handlers/generate_handler.py @@ -1,6 +1,6 @@ import logging from collections.abc import Callable, Mapping -from typing import Any +from typing import Any, TypeVar from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.util.types import AttributeValue @@ -12,16 +12,19 @@ from models.model import Account logger = logging.getLogger(__name__) +R = TypeVar("R") + + class AppGenerateHandler(SpanHandler): """Span handler for ``AppGenerateService.generate``.""" def wrapper( self, tracer: Any, - wrapped: Callable[..., Any], - args: tuple[Any, ...], - kwargs: Mapping[str, Any], - ) -> Any: + wrapped: Callable[..., R], + args: tuple[object, ...], + kwargs: Mapping[str, object], + ) -> R: try: arguments = self._extract_arguments(wrapped, args, kwargs) if not arguments: