diff --git a/src/a2a/server/agent_execution/active_task.py b/src/a2a/server/agent_execution/active_task.py index cb1e0a061..b0154c8d6 100644 --- a/src/a2a/server/agent_execution/active_task.py +++ b/src/a2a/server/agent_execution/active_task.py @@ -59,7 +59,7 @@ Event, EventQueueSource, QueueShutDown, - _create_async_queue, + create_async_queue, ) from a2a.server.tasks import PushNotificationEvent from a2a.types.a2a_pb2 import ( @@ -402,7 +402,7 @@ def __init__( # Queue for incoming requests self._request_queue: AsyncQueue[tuple[RequestContext, uuid.UUID]] = ( - _create_async_queue() + create_async_queue() ) @property diff --git a/src/a2a/server/events/event_queue.py b/src/a2a/server/events/event_queue.py index 046dad291..ecab0ea3a 100644 --- a/src/a2a/server/events/event_queue.py +++ b/src/a2a/server/events/event_queue.py @@ -1,6 +1,5 @@ import asyncio import logging -import sys import warnings from abc import ABC, abstractmethod @@ -9,33 +8,17 @@ from typing_extensions import Self - -if sys.version_info >= (3, 13): - from asyncio import Queue as AsyncQueue - from asyncio import QueueShutDown - - def _create_async_queue(maxsize: int = 0) -> AsyncQueue[Any]: - """Create a backwards-compatible queue object.""" - return AsyncQueue(maxsize=maxsize) -else: - import culsans - - from culsans import AsyncQueue # type: ignore[no-redef] - from culsans import ( - AsyncQueueShutDown as QueueShutDown, # type: ignore[no-redef] - ) - - def _create_async_queue(maxsize: int = 0) -> AsyncQueue[Any]: - """Create a backwards-compatible queue object.""" - return culsans.Queue(maxsize=maxsize).async_q # type: ignore[no-any-return] - - from a2a.types.a2a_pb2 import ( Message, Task, TaskArtifactUpdateEvent, TaskStatusUpdateEvent, ) +from a2a.utils._async_queue_compat import ( + AsyncQueue, + QueueShutDown, + create_async_queue, +) from a2a.utils.telemetry import SpanKind, trace_class @@ -101,7 +84,7 @@ def __init__(self, max_queue_size: int = DEFAULT_MAX_QUEUE_SIZE) -> None: if max_queue_size <= 0: raise ValueError('max_queue_size must be greater than 0') - self._queue: AsyncQueue[Event] = _create_async_queue( + self._queue: AsyncQueue[Event] = create_async_queue( maxsize=max_queue_size ) self._children: list[EventQueueLegacy] = [] diff --git a/src/a2a/server/events/event_queue_v2.py b/src/a2a/server/events/event_queue_v2.py index 224cb8e56..3386a732e 100644 --- a/src/a2a/server/events/event_queue_v2.py +++ b/src/a2a/server/events/event_queue_v2.py @@ -12,7 +12,7 @@ Event, EventQueue, QueueShutDown, - _create_async_queue, + create_async_queue, ) from a2a.utils.telemetry import SpanKind, trace_class @@ -37,7 +37,7 @@ def __init__( if max_queue_size <= 0: raise ValueError('max_queue_size must be greater than 0') - self._incoming_queue: AsyncQueue[Event] = _create_async_queue( + self._incoming_queue: AsyncQueue[Event] = create_async_queue( maxsize=max_queue_size ) self._lock = asyncio.Lock() @@ -293,7 +293,7 @@ def __init__( raise ValueError('max_queue_size must be greater than 0') self._parent = parent - self._queue: AsyncQueue[Event] = _create_async_queue( + self._queue: AsyncQueue[Event] = create_async_queue( maxsize=max_queue_size ) self._is_closed = False diff --git a/src/a2a/utils/_async_queue_compat.py b/src/a2a/utils/_async_queue_compat.py new file mode 100644 index 000000000..078a221f9 --- /dev/null +++ b/src/a2a/utils/_async_queue_compat.py @@ -0,0 +1,28 @@ +"""Cross-version aliases for async queue primitives.""" + +import sys + +from typing import Any + + +if sys.version_info >= (3, 13): + from asyncio import Queue as AsyncQueue + from asyncio import QueueShutDown + + def create_async_queue(maxsize: int = 0) -> AsyncQueue[Any]: + """Create a backwards-compatible async queue object.""" + return AsyncQueue(maxsize=maxsize) +else: + import culsans + + from culsans import AsyncQueue # type: ignore[no-redef] + from culsans import ( + AsyncQueueShutDown as QueueShutDown, # type: ignore[no-redef] + ) + + def create_async_queue(maxsize: int = 0) -> AsyncQueue[Any]: + """Create a backwards-compatible async queue object.""" + return culsans.Queue(maxsize=maxsize).async_q # type: ignore[no-any-return] + + +__all__ = ['AsyncQueue', 'QueueShutDown', 'create_async_queue'] diff --git a/src/a2a/utils/telemetry.py b/src/a2a/utils/telemetry.py index 7dbfd12c4..ae6a50954 100644 --- a/src/a2a/utils/telemetry.py +++ b/src/a2a/utils/telemetry.py @@ -74,6 +74,8 @@ def internal_method(self): from typing_extensions import Self +from a2a.utils._async_queue_compat import QueueShutDown + if TYPE_CHECKING: from opentelemetry.trace import ( @@ -144,6 +146,12 @@ def __getattr__(self, name: str) -> Any: __all__ = ['SpanKind'] +_NON_ERROR_EXCEPTIONS: tuple[type[BaseException], ...] = ( + asyncio.CancelledError, + QueueShutDown, +) + + def trace_function( # noqa: PLR0915 func: Callable | None = None, *, @@ -233,11 +241,14 @@ async def async_wrapper(*args, **kwargs) -> Any: # Async wrapper, await for the function call to complete. result = await func(*args, **kwargs) span.set_status(StatusCode.OK) - # asyncio.CancelledError extends from BaseException - except asyncio.CancelledError as ce: + except _NON_ERROR_EXCEPTIONS as ge: exception = None - logger.debug('CancelledError in span %s', actual_span_name) - span.record_exception(ce) + logger.debug( + '%s in span %s', + type(ge).__name__, + actual_span_name, + ) + span.record_exception(ge) raise except Exception as e: exception = e diff --git a/tests/utils/test_telemetry.py b/tests/utils/test_telemetry.py index a43bf1fa3..c751ffa3b 100644 --- a/tests/utils/test_telemetry.py +++ b/tests/utils/test_telemetry.py @@ -8,6 +8,7 @@ import pytest +from a2a.server.events.event_queue import QueueShutDown from a2a.utils.telemetry import trace_class, trace_function @@ -266,3 +267,30 @@ def test_env_var_disabled_logs_message( in caplog.text ) assert 'OTEL_INSTRUMENTATION_A2A_SDK_ENABLED' in caplog.text + + +@pytest.mark.asyncio +@pytest.mark.parametrize('exc_cls', [asyncio.CancelledError, QueueShutDown]) +async def test_trace_function_async_non_error_exception_does_not_mark_span_error( + mock_span: mock.MagicMock, + exc_cls: type[BaseException], +) -> None: + """`trace_function` records non-error exceptions but never marks span ERROR. + + Covers `asyncio.CancelledError` and `QueueShutDown`. + """ + + @trace_function + async def non_error_exception() -> NoReturn: + await asyncio.sleep(0) + raise exc_cls('operation ended with non-error exception') + + with pytest.raises(exc_cls): + await non_error_exception() + + mock_span.record_exception.assert_called() + # The wrapper only passes `description=` when calling + # `set_status(StatusCode.ERROR, ...)`. Its absence on every call proves + # the span was never marked as failed. + for call in mock_span.set_status.call_args_list: + assert 'description' not in call.kwargs