refactor(api): tighten OTel decorator typing (#32163)

This commit is contained in:
Shuvam Pandey
2026-02-09 21:31:02 +05:45
committed by GitHub
parent e0fcf33979
commit 7fb6e0cdfe
3 changed files with 25 additions and 19 deletions

View File

@@ -1,6 +1,6 @@
import functools import functools
from collections.abc import Callable from collections.abc import Callable
from typing import Any, TypeVar, cast from typing import ParamSpec, TypeVar, cast
from opentelemetry.trace import get_tracer from opentelemetry.trace import get_tracer
@@ -8,7 +8,8 @@ from configs import dify_config
from extensions.otel.decorators.handler import SpanHandler from extensions.otel.decorators.handler import SpanHandler
from extensions.otel.runtime import is_instrument_flag_enabled from extensions.otel.runtime import is_instrument_flag_enabled
T = TypeVar("T", bound=Callable[..., Any]) P = ParamSpec("P")
R = TypeVar("R")
_HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()} _HANDLER_INSTANCES: dict[type[SpanHandler], SpanHandler] = {SpanHandler: SpanHandler()}
@@ -20,7 +21,7 @@ def _get_handler_instance(handler_class: type[SpanHandler]) -> SpanHandler:
return _HANDLER_INSTANCES[handler_class] return _HANDLER_INSTANCES[handler_class]
def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T], T]: def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[Callable[P, R]], Callable[P, R]]:
""" """
Decorator that traces a function with an OpenTelemetry span. Decorator that traces a function with an OpenTelemetry span.
@@ -30,9 +31,9 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
:param handler_class: Optional handler class to use for this span. If None, uses the default SpanHandler. :param handler_class: Optional handler class to use for this span. If None, uses the default SpanHandler.
""" """
def decorator(func: T) -> T: def decorator(func: Callable[P, R]) -> Callable[P, R]:
@functools.wraps(func) @functools.wraps(func)
def wrapper(*args: Any, **kwargs: Any) -> Any: def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()): if not (dify_config.ENABLE_OTEL or is_instrument_flag_enabled()):
return func(*args, **kwargs) return func(*args, **kwargs)
@@ -46,6 +47,6 @@ def trace_span(handler_class: type[SpanHandler] | None = None) -> Callable[[T],
kwargs=kwargs, kwargs=kwargs,
) )
return cast(T, wrapper) return cast(Callable[P, R], wrapper)
return decorator return decorator

View File

@@ -1,9 +1,11 @@
import inspect import inspect
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import Any from typing import Any, TypeVar
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode
R = TypeVar("R")
class SpanHandler: class SpanHandler:
""" """
@@ -31,9 +33,9 @@ class SpanHandler:
def _extract_arguments( def _extract_arguments(
self, self,
wrapped: Callable[..., Any], wrapped: Callable[..., R],
args: tuple[Any, ...], args: tuple[object, ...],
kwargs: Mapping[str, Any], kwargs: Mapping[str, object],
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
""" """
Extract function arguments using inspect.signature. Extract function arguments using inspect.signature.
@@ -62,10 +64,10 @@ class SpanHandler:
def wrapper( def wrapper(
self, self,
tracer: Any, tracer: Any,
wrapped: Callable[..., Any], wrapped: Callable[..., R],
args: tuple[Any, ...], args: tuple[object, ...],
kwargs: Mapping[str, Any], kwargs: Mapping[str, object],
) -> Any: ) -> R:
""" """
Fully control the wrapper behavior. Fully control the wrapper behavior.

View File

@@ -1,6 +1,6 @@
import logging import logging
from collections.abc import Callable, Mapping from collections.abc import Callable, Mapping
from typing import Any from typing import Any, TypeVar
from opentelemetry.trace import SpanKind, Status, StatusCode from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.util.types import AttributeValue from opentelemetry.util.types import AttributeValue
@@ -12,16 +12,19 @@ from models.model import Account
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
R = TypeVar("R")
class AppGenerateHandler(SpanHandler): class AppGenerateHandler(SpanHandler):
"""Span handler for ``AppGenerateService.generate``.""" """Span handler for ``AppGenerateService.generate``."""
def wrapper( def wrapper(
self, self,
tracer: Any, tracer: Any,
wrapped: Callable[..., Any], wrapped: Callable[..., R],
args: tuple[Any, ...], args: tuple[object, ...],
kwargs: Mapping[str, Any], kwargs: Mapping[str, object],
) -> Any: ) -> R:
try: try:
arguments = self._extract_arguments(wrapped, args, kwargs) arguments = self._extract_arguments(wrapped, args, kwargs)
if not arguments: if not arguments: