Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
import hashlib
import os
import re
from datetime import datetime, UTC
from datetime import UTC, datetime

from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator

from opentelemetry.sdk.trace import RandomIdGenerator

HASH_LENGTH = 16
HASHED_ID_PATTERN = re.compile(r"^[0-9a-f]{16}$")


Expand Down Expand Up @@ -67,19 +67,25 @@ def operation_id_to_span_id(operation_id: str) -> int:

class DeterministicIdGenerator(RandomIdGenerator):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The base class was only there to make the type checker happy, but I don't think it's correct.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the reason is that the type checker is complaining about this line:

self._provider.id_generator = self._id_generator

The provider's id_generator should use the base class, but it uses the RandomIdGenerator class instead.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you find a way to fix this?

"""An ID generator that produces deterministic span IDs when a pending
operation ID is set, and random IDs otherwise.
operation ID is set, and falls back to the provided generator otherwise.

Trace IDs are deterministic when an execution ARN is set, ensuring all
invocations of the same durable execution share a single trace.
invocations of the same durable execution share a single trace. When no
deterministic ID is available, generation is delegated to the fallback
generator (the tracer provider's original ID generator by default).

Trace IDs embed a real timestamp so they satisfy the X-Ray format
requirement (first 8 hex chars = Unix epoch seconds).

Args:
fallback_id_generator: Generator used when no deterministic ID is
available. Defaults to a new ``RandomIdGenerator``.
"""

def __init__(self) -> None:
def __init__(self, fallback_id_generator: IdGenerator | None = None) -> None:
self._next_span_id: int | None = None
self._execution_trace_id: int | None = None
self._random_id_generator = RandomIdGenerator()
self._fallback_id_generator = fallback_id_generator or RandomIdGenerator()

def set_next_span_id(self, span_id: int | None) -> None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will the threading issue be addressed in a separate PR?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the threading issue is tracked here: #428

"""Set the operation ID to use for the next span's ID.
Expand All @@ -101,9 +107,11 @@ def set_trace_id(

def generate_trace_id(self) -> int:
"""Generate a 128-bit trace ID."""
return self._execution_trace_id or self._random_id_generator.generate_trace_id()
return (
self._execution_trace_id or self._fallback_id_generator.generate_trace_id()
)

def generate_span_id(self) -> int:
"""Generate a 64-bit span ID."""
span_id, self._next_span_id = self._next_span_id, None
return span_id or self._random_id_generator.generate_span_id()
return span_id or self._fallback_id_generator.generate_span_id()
Original file line number Diff line number Diff line change
Expand Up @@ -7,30 +7,31 @@
import threading
from typing import TYPE_CHECKING, Any

from opentelemetry import trace, context
from opentelemetry.context import Context
from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider
from opentelemetry.sdk.trace.sampling import TraceIdRatioBased
from opentelemetry.trace import (
Tracer,
StatusCode,
SpanContext,
Span,
Link,
TraceFlags,
)

from aws_durable_execution_sdk_python.lambda_service import OperationType
from aws_durable_execution_sdk_python.plugin import (
DurableInstrumentationPlugin,
InvocationEndInfo,
InvocationStartInfo,
OperationEndInfo,
OperationStartInfo,
UserFunctionStartInfo,
UserFunctionEndInfo,
UserFunctionOutcome,
UserFunctionStartInfo,
)
from opentelemetry import context, trace
from opentelemetry.context import Context
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace import TracerProvider as SdkTracerProvider
from opentelemetry.sdk.trace.sampling import TraceIdRatioBased
from opentelemetry.trace import (
Link,
Span,
SpanContext,
StatusCode,
TraceFlags,
Tracer,
)

from aws_durable_execution_sdk_python_otel.context_extractors import (
ContextExtractor,
xray_context_extractor,
Expand All @@ -40,6 +41,7 @@
operation_id_to_span_id,
)


if TYPE_CHECKING:
pass

Expand Down Expand Up @@ -95,9 +97,13 @@ def __init__(
context_extractor or xray_context_extractor
)

self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator()

self._provider = trace_provider
# A ProxyTracerProvider (the API default from trace.get_tracer_provider()
# before an SDK provider is configured) has no id_generator; fall back to
# None so DeterministicIdGenerator uses its own default generator.
self._id_generator: DeterministicIdGenerator = DeterministicIdGenerator(
fallback_id_generator=getattr(self._provider, "id_generator", None)
)
self._provider.id_generator = self._id_generator
self._provider.sampler = TraceIdRatioBased(sampling_rate)
self._tracer: Tracer = self._provider.get_tracer(instrument_name)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from datetime import UTC, datetime

from opentelemetry.sdk.trace import IdGenerator, RandomIdGenerator

from aws_durable_execution_sdk_python_otel.deterministic_id_generator import (
HASHED_ID_PATTERN,
DeterministicIdGenerator,
Expand All @@ -14,6 +16,20 @@
)


class _StubIdGenerator(IdGenerator):
    """Test double whose trace and span IDs are fixed at construction time.

    Args:
        trace_id: Value returned by every ``generate_trace_id`` call.
        span_id: Value returned by every ``generate_span_id`` call.
    """

    def __init__(self, trace_id: int, span_id: int) -> None:
        # Both canned values are stored as-is and handed back verbatim.
        self._trace_id, self._span_id = trace_id, span_id

    def generate_span_id(self) -> int:
        """Return the canned span ID."""
        return self._span_id

    def generate_trace_id(self) -> int:
        """Return the canned trace ID."""
        return self._trace_id


def test_parse_xray_root_trace_id_returns_root_from_header():
"""Verify X-Ray Root trace ID parsing ignores other header fields."""
trace_header = (
Expand Down Expand Up @@ -94,7 +110,7 @@ def test_deterministic_id_generator_falls_back_to_random_trace_id(monkeypatch):
expected_trace_id = int("1" * 32, 16)
generator = DeterministicIdGenerator()
monkeypatch.setattr(
generator._random_id_generator,
generator._fallback_id_generator,
"generate_trace_id",
lambda: expected_trace_id,
)
Expand All @@ -108,7 +124,7 @@ def test_deterministic_id_generator_uses_next_span_id_once(monkeypatch):
random_span_id = int("3" * 16, 16)
generator = DeterministicIdGenerator()
monkeypatch.setattr(
generator._random_id_generator,
generator._fallback_id_generator,
"generate_span_id",
lambda: random_span_id,
)
Expand All @@ -124,11 +140,66 @@ def test_deterministic_id_generator_accepts_cleared_next_span_id(monkeypatch):
expected_span_id = int("4" * 16, 16)
generator = DeterministicIdGenerator()
monkeypatch.setattr(
generator._random_id_generator,
generator._fallback_id_generator,
"generate_span_id",
lambda: expected_span_id,
)

generator.set_next_span_id(None)

assert generator.generate_span_id() == expected_span_id


def test_deterministic_id_generator_defaults_to_random_fallback():
    """Verify a RandomIdGenerator is installed as the fallback by default."""
    # No fallback argument is passed, so the constructor must pick its own.
    default_fallback = DeterministicIdGenerator()._fallback_id_generator

    assert isinstance(default_fallback, RandomIdGenerator)


def test_deterministic_id_generator_uses_provided_fallback_for_trace_id(monkeypatch):
    """Verify trace IDs come from the supplied fallback generator when no
    execution trace ID has been set."""
    monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False)
    fallback_trace_id = int("a" * 32, 16)
    stub = _StubIdGenerator(trace_id=fallback_trace_id, span_id=int("b" * 16, 16))

    generator = DeterministicIdGenerator(fallback_id_generator=stub)

    assert generator.generate_trace_id() == fallback_trace_id


def test_deterministic_id_generator_uses_provided_fallback_for_span_id():
    """Verify span IDs come from the supplied fallback generator when no
    deterministic span ID is pending."""
    fallback_span_id = int("b" * 16, 16)
    stub = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=fallback_span_id)

    generator = DeterministicIdGenerator(fallback_id_generator=stub)

    assert generator.generate_span_id() == fallback_span_id


def test_deterministic_id_generator_prefers_execution_trace_id_over_fallback(
    monkeypatch,
):
    """Verify the trace ID set via ``set_trace_id`` wins over the fallback's."""
    monkeypatch.delenv("_X_AMZN_TRACE_ID", raising=False)
    stub = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=int("b" * 16, 16))
    generator = DeterministicIdGenerator(fallback_id_generator=stub)
    timestamp = datetime(2024, 1, 2, 3, 4, 5, tzinfo=UTC)
    # The leading 8 hex digits (65937d25) are the Unix epoch seconds of
    # ``timestamp``; the stub's all-"a" trace ID must not appear.
    expected_trace_id = int("65937d253aa8c3f7ffe36c50d65b1a6d", 16)

    generator.set_trace_id(
        "arn:aws:lambda:us-west-2:123456789012:function:workflow:$LATEST",
        timestamp,
    )

    assert generator.generate_trace_id() == expected_trace_id


def test_deterministic_id_generator_prefers_next_span_id_over_fallback():
    """Verify a pending span ID is used once, ahead of the fallback generator."""
    pending_span_id = int("c" * 16, 16)
    fallback_span_id = int("b" * 16, 16)
    stub = _StubIdGenerator(trace_id=int("a" * 32, 16), span_id=fallback_span_id)
    generator = DeterministicIdGenerator(fallback_id_generator=stub)

    generator.set_next_span_id(pending_span_id)
    first_span_id = generator.generate_span_id()
    # The pending ID is consumed by the first call, so the second call must
    # be served by the provided fallback generator.
    second_span_id = generator.generate_span_id()

    assert first_span_id == pending_span_id
    assert second_span_id == fallback_span_id
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,6 @@
from concurrent.futures import ThreadPoolExecutor
from datetime import UTC, datetime

from opentelemetry.context import Context
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from aws_durable_execution_sdk_python.lambda_service import (
InvocationStatus,
OperationStatus,
Expand All @@ -24,6 +19,11 @@
UserFunctionOutcome,
UserFunctionStartInfo,
)
from opentelemetry.context import Context
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

from aws_durable_execution_sdk_python_otel.deterministic_id_generator import (
operation_id_to_span_id,
)
Expand Down Expand Up @@ -131,7 +131,9 @@ def test_operation_end_without_start_emits_continuation_span_with_link():
plugin.on_invocation_start(_invocation_start_info())
operation_id = "wait-existing"
random_span_id = int("1234567890abcdef", 16)
plugin._id_generator._random_id_generator.generate_span_id = lambda: random_span_id
plugin._id_generator._fallback_id_generator.generate_span_id = lambda: (
random_span_id
)

plugin.on_operation_end(
OperationEndInfo(
Expand Down
Loading