diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py index 56089303..7b5c2f50 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/deterministic_id_generator.py @@ -5,11 +5,11 @@ import hashlib import os import re -from datetime import datetime, UTC +from datetime import UTC, datetime + +from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator -from opentelemetry.sdk.trace import RandomIdGenerator -HASH_LENGTH = 16 HASHED_ID_PATTERN = re.compile(r"^[0-9a-f]{16}$") @@ -67,19 +67,25 @@ def operation_id_to_span_id(operation_id: str) -> int: class DeterministicIdGenerator(RandomIdGenerator): """An ID generator that produces deterministic span IDs when a pending - operation ID is set, and random IDs otherwise. + operation ID is set, and falls back to the provided generator otherwise. Trace IDs are deterministic when an execution ARN is set, ensuring all - invocations of the same durable execution share a single trace. + invocations of the same durable execution share a single trace. When no + deterministic ID is available, generation is delegated to the fallback + generator (the tracer provider's original ID generator by default). Trace IDs embed a real timestamp so they satisfy the X-Ray format requirement (first 8 hex chars = Unix epoch seconds). + + Args: + fallback_id_generator: Generator used when no deterministic ID is + available. Defaults to a new ``RandomIdGenerator``. """ - def __init__(self) -> None: + def __init__(self, fallback_id_generator: IdGenerator | None = None) -> None: self._next_span_id: int | None = None self._execution_trace_id: int | None = None - self._random_id_generator = RandomIdGenerator() + self._fallback_id_generator = fallback_id_generator or RandomIdGenerator() def set_next_span_id(self, span_id: int | None) -> None: """Set the operation ID to use for the next span's ID. @@ -101,9 +107,11 @@ def set_trace_id( def generate_trace_id(self) -> int: """Generate a 128-bit trace ID.""" - return self._execution_trace_id or self._random_id_generator.generate_trace_id() + return ( + self._execution_trace_id or self._fallback_id_generator.generate_trace_id() + ) def generate_span_id(self) -> int: """Generate a 64-bit span ID.""" span_id, self._next_span_id = self._next_span_id, None - return span_id or self._random_id_generator.generate_span_id() + return span_id or self._fallback_id_generator.generate_span_id() diff --git a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py index ecd3895f..3f8b0b0e 100644 --- a/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/src/aws_durable_execution_sdk_python_otel/plugin.py @@ -7,19 +7,6 @@ import threading from typing import TYPE_CHECKING, Any -from opentelemetry import trace, context -from opentelemetry.context import Context -from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider -from opentelemetry.sdk.trace.sampling import TraceIdRatioBased -from opentelemetry.trace import ( - Tracer, - StatusCode, - SpanContext, - Span, - Link, - TraceFlags, -) - from aws_durable_execution_sdk_python.lambda_service import OperationType from aws_durable_execution_sdk_python.plugin import ( DurableInstrumentationPlugin, @@ -27,10 +14,24 @@ InvocationStartInfo, OperationEndInfo, OperationStartInfo, - UserFunctionStartInfo, UserFunctionEndInfo, UserFunctionOutcome, + UserFunctionStartInfo, ) +from opentelemetry import context, trace +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider +from opentelemetry.sdk.trace.sampling import TraceIdRatioBased +from opentelemetry.trace import ( + Link, + Span, + SpanContext, + StatusCode, + TraceFlags, + Tracer, +) + from aws_durable_execution_sdk_python_otel.context_extractors import ( ContextExtractor, xray_context_extractor, @@ -40,6 +41,7 @@ operation_id_to_span_id, ) + if TYPE_CHECKING: pass @@ -95,9 +97,13 @@ def __init__( context_extractor or xray_context_extractor ) - self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator() - self._provider = trace_provider + # A ProxyTracerProvider (the API default from trace.get_tracer_provider() + # before an SDK provider is configured) has no id_generator; fall back to + # None so DeterministicIdGenerator uses its own default generator. + self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator( + fallback_id_generator=getattr(self._provider, "id_generator", None) + ) self._provider.id_generator = self._id_generator self._provider.sampler = TraceIdRatioBased(sampling_rate) self._tracer: Tracer = self._provider.get_tracer(instrument_name) diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py index 3f4e53f7..8cb9dc7c 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_deterministic_id_generator.py @@ -4,6 +4,8 @@ from datetime import UTC, datetime +from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator + from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( HASHED_ID_PATTERN, DeterministicIdGenerator, @@ -14,6 +16,20 @@ ) +class _StubIdGenerator(IdGenerator): + """An IdGenerator that returns fixed, identifiable IDs.""" + + def __init__(self, trace_id: int, span_id: int) -> None: + self._trace_id = trace_id + self._span_id = span_id + + def generate_trace_id(self) -> int: + return self._trace_id + + def generate_span_id(self) -> int: + return self._span_id + + def test_parse_xray_root_trace_id_returns_root_from_header(): """Verify X-Ray Root trace ID parsing ignores other header fields.""" trace_header = ( @@ -94,7 +110,7 @@ def test_deterministic_id_generator_falls_back_to_random_trace_id(monkeypatch): expected_trace_id = int("1" * 32, 16) generator = DeterministicIdGenerator() monkeypatch.setattr( - generator._random_id_generator, + generator._fallback_id_generator, "generate_trace_id", lambda: expected_trace_id, ) @@ -108,7 +124,7 @@ def test_deterministic_id_generator_uses_next_span_id_once(monkeypatch): random_span_id = int("3" * 16, 16) generator = DeterministicIdGenerator() monkeypatch.setattr( - generator._random_id_generator, + generator._fallback_id_generator, "generate_span_id", lambda: random_span_id, ) @@ -124,7 +140,7 @@ def test_deterministic_id_generator_accepts_cleared_next_span_id(monkeypatch): expected_span_id = int("4" * 16, 16) generator = DeterministicIdGenerator() monkeypatch.setattr( - generator._random_id_generator, + generator._fallback_id_generator, "generate_span_id", lambda: expected_span_id, ) @@ -132,3 +148,58 @@ def test_deterministic_id_generator_accepts_cleared_next_span_id(monkeypatch): generator.set_next_span_id(None) assert generator.generate_span_id() == expected_span_id + + +def test_deterministic_id_generator_defaults_to_random_fallback(): + """Verify the fallback defaults to a RandomIdGenerator when none is given.""" + generator = DeterministicIdGenerator() + + assert isinstance(generator._fallback_id_generator, RandomIdGenerator) + + +def test_deterministic_id_generator_uses_provided_fallback_for_trace_id(monkeypatch): + """Verify the supplied fallback generator produces trace IDs when no + execution trace ID is set.""" + monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False) + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + assert generator.generate_trace_id() == int("a" * 32, 16) + + +def test_deterministic_id_generator_uses_provided_fallback_for_span_id(): + """Verify the supplied fallback generator produces span IDs when no + deterministic span ID is pending.""" + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + assert generator.generate_span_id() == int("b" * 16, 16) + + +def test_deterministic_id_generator_prefers_execution_trace_id_over_fallback( + monkeypatch, +): + """Verify a configured execution trace ID takes precedence over the fallback.""" + monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False) + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + generator.set_trace_id( + "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST", + datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC), + ) + + assert generator.generate_trace_id() == int("65937d253aa8c3f7ffe36c50d65b1a6d", 16) + + +def test_deterministic_id_generator_prefers_next_span_id_over_fallback(): + """Verify a pending deterministic span ID takes precedence over the fallback.""" + deterministic_span_id = int("c" * 16, 16) + fallback = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16)) + generator = DeterministicIdGenerator(fallback_id_generator=fallback) + + generator.set_next_span_id(deterministic_span_id) + + assert generator.generate_span_id() == deterministic_span_id + # Subsequent calls fall back to the provided generator. + assert generator.generate_span_id() == int("b" * 16, 16) diff --git a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py index 5fb8a430..f44cbe01 100644 --- a/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py +++ b/packages/aws-durable-execution-sdk-python-otel/tests/test_plugin.py @@ -5,11 +5,6 @@ from concurrent.futures import ThreadPoolExecutor from datetime import UTC, datetime -from opentelemetry.context import Context -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import SimpleSpanProcessor -from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter - from aws_durable_execution_sdk_python.lambda_service import ( InvocationStatus, OperationStatus, @@ -24,6 +19,11 @@ UserFunctionOutcome, UserFunctionStartInfo, ) +from opentelemetry.context import Context +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + from aws_durable_execution_sdk_python_otel.deterministic_id_generator import ( operation_id_to_span_id, ) @@ -131,7 +131,9 @@ def test_operation_end_without_start_emits_continuation_span_with_link(): plugin.on_invocation_start(_invocation_start_info()) operation_id = "wait-existing" random_span_id = int("1234567890abcdef", 16) - plugin._id_generator._random_id_generator.generate_span_id = lambda: random_span_id + plugin._id_generator._fallback_id_generator.generate_span_id = lambda: ( + random_span_id + ) plugin.on_operation_end( OperationEndInfo(