From 62d6d6b4b379db3425a36f45b2165e69b5371048 Mon Sep 17 00:00:00 2001 From: Alex Wang Date: Wed, 10 Jun 2026 16:50:03 -0700 Subject: [PATCH] fix(otel): Add id generator fallback Import TracerProvider from opentelemetry.sdk.trace instead of the API base class so id_generator and sampler attributes resolve correctly under mypy. Make DeterministicIdGenerator subclass RandomIdGenerator for type compatibility with the provider's id_generator attribute. Change DeterministicIdGenerator to fall back to the tracer provider's original id generator rather than always creating a new RandomIdGenerator. The plugin now captures the provider's existing generator and passes it as the fallback before overwriting it. Add unit tests covering the fallback behavior and update existing tests to reference the renamed _fallback_id_generator attribute. --- .../deterministic_id_generator.py | 26 ++++--- .../plugin.py | 38 +++++---- .../tests/test_deterministic_id_generator.py | 77 ++++++++++++++++++- .../tests/test_plugin.py | 14 ++-- 4 files changed, 121 insertions(+), 34 deletions(-) diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py index 56089303..7b5c2f50 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py @@ -5,11 +5,11 @@ import hashlib import os import re -from datetime import datetime, UTC +from datetime import UTC, datetime + +from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator -from opentelemetry.sdk.trace import RandomIdGenerator -HASH_LENGTH = 16 HASHED_ID_PATTERN = re.compile(r"^[0-9a-f]{16}$") @@ -67,19 +67,25 @@ def operation_id_to_span_id(operation_id: str) -> int: class DeterministicIdGenerator(RandomIdGenerator): """An ID generator that produces deterministic span IDs when a pending - operation ID is set, and random IDs otherwise. + operation ID is set, and falls back to the provided generator otherwise. Trace IDs are deterministic when an execution ARN is set, ensuring all - invocations of the same durable execution share a single trace. + invocations of the same durable execution share a single trace. When no + deterministic ID is available, generation is delegated to the fallback + generator (the tracer provider's original ID generator by default). Trace IDs embed a real timestamp so they satisfy the X-Ray format requirement (first 8 hex chars = Unix epoch seconds). + + Args: + fallback_id_generator: Generator used when no deterministic ID is + available. Defaults to a new ``RandomIdGenerator``. """ - def __init__(self) -> None: + def __init__(self, fallback_id_generator: IdGenerator | None = None) -> None: self._next_span_id: int | None = None self._execution_trace_id: int | None = None - self._random_id_generator = RandomIdGenerator() + self._fallback_id_generator = fallback_id_generator or RandomIdGenerator() def set_next_span_id(self, span_id: int | None) -> None: """Set the operation ID to use for the next span's ID. @@ -101,9 +107,11 @@ def set_trace_id( def generate_trace_id(self) -> int: """Generate a 128-bit trace ID.""" - return self._execution_trace_id or self._random_id_generator.generate_trace_id() + return ( + self._execution_trace_id or self._fallback_id_generator.generate_trace_id() + ) def generate_span_id(self) -> int: """Generate a 64-bit span ID.""" span_id, self._next_span_id = self._next_span_id, None - return span_id or self._random_id_generator.generate_span_id() + return span_id or self._fallback_id_generator.generate_span_id() diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py index ecd3895f..3f8b0b0e 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py @@ -7,19 +7,6 @@ import threading from typing import TYPE_CHECKING, Any -from opentelemetry import trace, context -from opentelemetry.context import Context -from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider -from opentelemetry.sdk.trace.sampling import TraceIdRatioBased -from opentelemetry.trace import ( - Tracer, - StatusCode, - SpanContext, - Span, - Link, - TraceFlags, -) - from aws_durable_execution_sdk_python.lambda_service import OperationType from aws_durable_execution_sdk_python.plugin import ( DurableInstrumentationPlugin, @@ -27,10 +14,24 @@ InvocationStartInfo, OperationEndInfo, OperationStartInfo, - UserFunctionStartInfo, UserFunctionEndInfo, UserFunctionOutcome, + UserFunctionStartInfo, ) +from opentelemetry import context, trace +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider +from opentelemetry.sdk.trace.sampling import TraceIdRatioBased +from opentelemetry.trace import ( + Link, + Span, + SpanContext, + StatusCode, + TraceFlags, + Tracer, +) + from aws_durable_execution_sdk_python_otel.context_extractors import ( ContextExtractor, xray_context_extractor, @@ -40,6 +41,7 @@ operation_id_to_span_id, ) + if TYPE_CHECKING: pass @@ -95,9 +97,13 @@ def __init__( context_extractor or xray_context_extractor ) - self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator() - self._provider = trace_provider + # A ProxyTracerProvider (the API default from trace.get_tracer_provider() + # before an SDK provider is configured) has no id_generator; fall back to + # None so DeterministicIdGenerator uses its own default generator. + self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator( + fallback_id_generator=getattr(self._provider, "id_generator", None) + ) self._provider.id_generator = self._id_generator self._provider.sampler = TraceIdRatioBased(sampling_rate) self._tracer: Tracer = self._provider.get_tracer(instrument_name) diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py index 3f4e53f7..8cb9dc7c 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py @@ -4,6 +4,8 @@ from datetime import UTC, datetime +from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator + from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( HASHED_ID_PATTERN, DeterministicIdGenerator, @@ -14,6 +16,20 @@ ) +class _StubIdGenerator(IdGenerator): + """An IdGenerator that returns fixed, identifiable IDs.""" + + def __init__(self, trace_id: int, span_id: int) -> None: + self._trace_id = trace_id + self._span_id = span_id + + def generate_trace_id(self) -> int: + return self._trace_id + + def generate_span_id(self) -> int: + return self._span_id + + def test_parse_xray_root_trace_id_returns_root_from_header(): """Verify X-Ray Root trace ID parsing ignores other header fields.""" trace_header = ( @@ -94,7 +110,7 @@ def test_deterministic_id_generator_falls_back_to_random_trace_id(monkeypatch): expected_trace_id = int("1" * 32, 16) generator = DeterministicIdGenerator() monkeypatch.setattr( - generator._random_id_generator, + generator._fallback_id_generator, "generate_trace_id", lambda: expected_trace_id, ) @@ -108,7 +124,7 @@ def test_deterministic_id_generator_uses_next_span_id_once(monkeypatch): random_span_id = int("3" * 16, 16) generator = DeterministicIdGenerator() monkeypatch.setattr( - generator._random_id_generator, + generator._fallback_id_generator, "generate_span_id", lambda: random_span_id, ) @@ -124,7 +140,7 @@ def test_deterministic_id_generator_accepts_cleared_next_span_id(monkeypatch): expected_span_id = int("4" * 16, 16) generator = DeterministicIdGenerator() monkeypatch.setattr( - generator._random_id_generator, + generator._fallback_id_generator, "generate_span_id", lambda: expected_span_id, ) @@ -132,3 +148,58 @@ def test_deterministic_id_generator_accepts_cleared_next_span_id(monkeypatch): generator.set_next_span_id(None) assert generator.generate_span_id() == expected_span_id + + +def test_deterministic_id_generator_defaults_to_random_fallback(): + """Verify the fallback defaults to a RandomIdGenerator when none is given.""" + generator = DeterministicIdGenerator() + + assert isinstance(generator._fallback_id_generator, RandomIdGenerator) + + +def test_deterministic_id_generator_uses_provided_fallback_for_trace_id(monkeypatch): + """Verify the supplied fallback generator produces trace IDs when no + execution trace ID is set.""" + monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False) + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + assert generator.generate_trace_id() == int("a" * 32, 16) + + +def test_deterministic_id_generator_uses_provided_fallback_for_span_id(): + """Verify the supplied fallback generator produces span IDs when no + deterministic span ID is pending.""" + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + assert generator.generate_span_id() == int("b" * 16, 16) + + +def test_deterministic_id_generator_prefers_execution_trace_id_over_fallback( + monkeypatch, +): + """Verify a configured execution trace ID takes precedence over the fallback.""" + monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False) + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + generator.set_trace_id( + "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST", + datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC), + ) + + assert generator.generate_trace_id() == int("65937d253aa8c3f7ffe36c50d65b1a6d", 16) + + +def test_deterministic_id_generator_prefers_next_span_id_over_fallback(): + """Verify a pending deterministic span ID takes precedence over the fallback.""" + deterministic_span_id = int("c" * 16, 16) + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + generator.set_next_span_id(deterministic_span_id) + + assert generator.generate_span_id() == deterministic_span_id + # Subsequent calls fall back to the provided generator. + assert generator.generate_span_id() == int("b" * 16, 16) diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py index 5fb8a430..f44cbe01 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py @@ -5,11 +5,6 @@ from concurrent.futures import ThreadPoolExecutor from datetime import UTC, datetime -from opentelemetry.context import Context -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - from aws_durable_execution_sdk_python.lambda_service import ( InvocationStatus, OperationStatus, @@ -24,6 +19,11 @@ UserFunctionOutcome, UserFunctionStartInfo, ) +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( operation_id_to_span_id, ) @@ -131,7 +131,9 @@ def test_operation_end_without_start_emits_continuation_span_with_link(): plugin.on_invocation_start(_invocation_start_info()) operation_id = "wait-existing" random_span_id = int("1234567890abcdef", 16) - plugin._id_generator._random_id_generator.generate_span_id = lambda: random_span_id + plugin._id_generator._fallback_id_generator.generate_span_id = lambda: ( + random_span_id + ) plugin.on_operation_end( OperationEndInfo(