diff --git a/sentry_sdk/_span_batcher.py b/sentry_sdk/_span_batcher.py index 275462b21c..096194d353 100644 --- a/sentry_sdk/_span_batcher.py +++ b/sentry_sdk/_span_batcher.py @@ -13,10 +13,10 @@ if TYPE_CHECKING: from typing import Any, Callable, Optional - from sentry_sdk.traces import StreamedSpan + from sentry_sdk._types import SpanJSON -class SpanBatcher(Batcher["StreamedSpan"]): +class SpanBatcher(Batcher["SpanJSON"]): # MAX_BEFORE_FLUSH should be lower than MAX_BEFORE_DROP, so that there is # a bit of a buffer for spans that appear between the trigger to flush # and actually flushing the buffer. @@ -42,7 +42,7 @@ def __init__( # by trace_id, so that we can then send the buckets each in its own # envelope. # trace_id -> span buffer - self._span_buffer: dict[str, list["StreamedSpan"]] = defaultdict(list) + self._span_buffer: dict[str, list["SpanJSON"]] = defaultdict(list) self._running_size: dict[str, int] = defaultdict(lambda: 0) self._capture_func = capture_func self._record_lost_func = record_lost_func @@ -99,7 +99,7 @@ def _flush_loop(self) -> None: self._flush() self._last_full_flush = time.monotonic() - def add(self, span: "StreamedSpan") -> None: + def add(self, span: "SpanJSON") -> None: # Bail out if the current thread is already executing batcher code. # This prevents deadlocks when code running inside the batcher (e.g. # _add_to_envelope during flush, or _flush_event.wait/set) triggers @@ -115,7 +115,7 @@ def add(self, span: "StreamedSpan") -> None: return None with self._lock: - size = len(self._span_buffer[span.trace_id]) + size = len(self._span_buffer[span["trace_id"]]) if size >= self.MAX_BEFORE_DROP: self._record_lost_func( reason="queue_overflow", @@ -124,14 +124,15 @@ def add(self, span: "StreamedSpan") -> None: ) return None - self._span_buffer[span.trace_id].append(span) - self._running_size[span.trace_id] += self._estimate_size(span) + self._span_buffer[span["trace_id"]].append(span) + self._running_size[span["trace_id"]] += self._estimate_size(span) if ( size + 1 >= self.MAX_BEFORE_FLUSH - or self._running_size[span.trace_id] >= self.MAX_BYTES_BEFORE_FLUSH + or self._running_size[span["trace_id"]] + >= self.MAX_BYTES_BEFORE_FLUSH ): - self._pending_flush.add(span.trace_id) + self._pending_flush.add(span["trace_id"]) notify = True else: notify = False @@ -142,12 +143,12 @@ def add(self, span: "StreamedSpan") -> None: self._active.flag = False @staticmethod - def _estimate_size(item: "StreamedSpan") -> int: + def _estimate_size(item: "SpanJSON") -> int: # Rough estimate of serialized span size that's quick to compute. # 210 is the rough size of the payload without attributes, and then we # estimate the attributes separately. 
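+        # Roughly: each attribute is charged a flat 50 bytes for its key
+        # and overhead, and string values are charged extra on top of that.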
         estimate = 210
-        for value in item._attributes.values():
+        for value in (item.get("attributes") or {}).values():
             estimate += 50
 
             if isinstance(value, str):
@@ -158,26 +159,17 @@ def _estimate_size(item: "StreamedSpan") -> int:
         return estimate
 
     @staticmethod
-    def _to_transport_format(item: "StreamedSpan") -> "Any":
-        res: "dict[str, Any]" = {
-            "trace_id": item.trace_id,
-            "span_id": item.span_id,
-            "name": item._name if item._name is not None else "",
-            "status": item._status,
-            "is_segment": item._is_segment(),
-            "start_timestamp": item._start_timestamp.timestamp(),
-        }
-
-        if item._end_timestamp:
-            res["end_timestamp"] = item._end_timestamp.timestamp()
-
-        if item._parent_span_id:
-            res["parent_span_id"] = item._parent_span_id
-
-        if item._attributes:
+    def _to_transport_format(item: "SpanJSON") -> "Any":
+        res = {k: v for k, v in item.items() if k not in ("_segment_span",)}
+
+        if item.get("attributes"):
             res["attributes"] = {
-                k: serialize_attribute(v) for (k, v) in item._attributes.items()
+                k: serialize_attribute(v) for (k, v) in item["attributes"].items()
             }
+        else:
+            # "attributes" is NotRequired in SpanJSON, so only remove the
+            # key if it is actually present.
+            res.pop("attributes", None)
 
         return res
 
@@ -201,7 +191,7 @@ def _flush(self, only_pending: bool = False) -> None:
             if not spans:
                 continue
 
-            dsc = spans[0]._dynamic_sampling_context()
+            dsc = spans[0]["_segment_span"]._dynamic_sampling_context()
 
             # Max per envelope is 1000, so if we happen to have more than
             # 1000 spans in one bucket, we'll need to separate them.
diff --git a/sentry_sdk/_types.py b/sentry_sdk/_types.py
index ad3fa35849..f952baf44c 100644
--- a/sentry_sdk/_types.py
+++ b/sentry_sdk/_types.py
@@ -140,6 +140,7 @@ def substituted_because_contains_sensitive_data(cls) -> "AnnotatedValue":
 if TYPE_CHECKING:
     from collections.abc import Container, MutableMapping, Sequence
     from datetime import datetime
+    from sentry_sdk.traces import StreamedSpan
     from types import TracebackType
 
     from typing import Any, Callable, Dict, Mapping, NotRequired, Optional, Type
@@ -317,6 +318,22 @@ class SDKInfo(TypedDict):
     MetricProcessor = Callable[[Metric, Hint], Optional[Metric]]
 
+    SpanJSON = TypedDict(
+        "SpanJSON",
+        {
+            "trace_id": str,
+            "span_id": str,
+            "parent_span_id": NotRequired[str],
+            "name": str,
+            "status": str,
+            "is_segment": bool,
+            "start_timestamp": float,
+            "end_timestamp": NotRequired[float],
+            "attributes": NotRequired[Attributes],
+            "_segment_span": NotRequired[StreamedSpan],
+        },
+    )
+
     # TODO: Make a proper type definition for this (PRs welcome!)
     Breadcrumb = Dict[str, Any]
 
diff --git a/sentry_sdk/client.py b/sentry_sdk/client.py
index d0b93e3bb1..a79df1f2a4 100644
--- a/sentry_sdk/client.py
+++ b/sentry_sdk/client.py
@@ -25,10 +25,12 @@
     logger,
     get_before_send_log,
     get_before_send_metric,
+    get_before_send_span,
     has_logs_enabled,
     has_metrics_enabled,
 )
 from sentry_sdk.serializer import serialize
+from sentry_sdk.traces import StreamedSpan
 from sentry_sdk.tracing import trace
 from sentry_sdk.tracing_utils import has_span_streaming_enabled
 from sentry_sdk.transport import (
@@ -71,7 +73,6 @@
     from sentry_sdk.scope import Scope
     from sentry_sdk.session import Session
     from sentry_sdk.spotlight import SpotlightClient
-    from sentry_sdk.traces import StreamedSpan
     from sentry_sdk.transport import Transport, Item
     from sentry_sdk._log_batcher import LogBatcher
     from sentry_sdk._metrics_batcher import MetricsBatcher
@@ -938,34 +939,72 @@ def _capture_telemetry(
         ty: str,
         scope: "Scope",
     ) -> None:
-        # Capture attributes-based telemetry (logs, metrics, spansV2)
+        """
+        Capture attributes-based telemetry (logs, metrics, streamed spans).
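+
+        For streamed spans, before_send_span receives the span serialized
+        to its JSON transport format (a SpanJSON dict), not the
+        StreamedSpan object itself.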
+
+        Apply any attributes set on the scope to the telemetry item, then
+        run the user's before_send_{telemetry} hook on it, if one is
+        configured.
+        """
         if telemetry is None:
             return
 
         scope.apply_to_telemetry(telemetry)
 
         before_send = None
+
         if ty == "log":
             before_send = get_before_send_log(self.options)
+            serialized = telemetry
+
         elif ty == "metric":
             before_send = get_before_send_metric(self.options)
+            serialized = telemetry
+
+        elif ty == "span":
+            before_send = get_before_send_span(self.options)
+            serialized = telemetry._to_json()  # type: ignore[union-attr]
 
         if before_send is not None:
-            telemetry = before_send(telemetry, {})  # type: ignore
+            serialized = before_send(serialized, {})  # type: ignore[arg-type]
+
+        if ty in ("log", "metric"):
+            # Logs and metrics can be dropped in their respective
+            # before_send, so if we get None, don't queue them for sending.
+            if serialized is None:
+                return
+
+        elif ty == "span" and isinstance(telemetry, StreamedSpan):
+            # Spans can't be dropped in before_send_span by design. They can
+            # be altered though (e.g. to sanitize). Only allow changes to
+            # name and attributes.
+            if isinstance(serialized, dict) and serialized and "name" in serialized:
+                telemetry.name = serialized["name"]  # type: ignore[typeddict-item]
+                telemetry._attributes = {}
+                for k, v in (serialized.get("attributes") or {}).items():
+                    telemetry.set_attribute(k, v)
+
+            else:
+                logger.debug(
+                    "[Tracing] Invalid return value from before_send_span. Keeping original span."
+                )
 
-        if telemetry is None:
-            return
+            serialized = telemetry._to_json()
 
         batcher = None
         if ty == "log":
             batcher = self.log_batcher
+
         elif ty == "metric":
             batcher = self.metrics_batcher
+
         elif ty == "span":
+            # The batcher needs a reference to the segment span so that it
+            # can populate the dynamic sampling context (DSC).
+            serialized["_segment_span"] = telemetry._segment  # type: ignore
             batcher = self.span_batcher
 
         if batcher is not None:
-            batcher.add(telemetry)  # type: ignore
+            batcher.add(serialized)  # type: ignore
 
     def _capture_log(self, log: "Optional[Log]", scope: "Scope") -> None:
         self._capture_telemetry(log, "log", scope)
diff --git a/sentry_sdk/consts.py b/sentry_sdk/consts.py
index d2b4cd89af..c81581132f 100644
--- a/sentry_sdk/consts.py
+++ b/sentry_sdk/consts.py
@@ -56,6 +56,7 @@ class CompressionAlgo(Enum):
         Log,
         Metric,
         ProfilerMode,
+        SpanJSON,
         TracesSampler,
         TransactionProcessor,
     )
@@ -85,6 +86,9 @@ class CompressionAlgo(Enum):
         "before_send_metric": Optional[Callable[[Metric, Hint], Optional[Metric]]],
         "trace_lifecycle": Optional[Literal["static", "stream"]],
         "ignore_spans": Optional[IgnoreSpansConfig],
+        "before_send_span": Optional[
+            Callable[[SpanJSON, Hint], Optional[SpanJSON]]
+        ],
         "suppress_asgi_chained_exceptions": Optional[bool],
     },
     total=False,
diff --git a/sentry_sdk/traces.py b/sentry_sdk/traces.py
index f49760f03b..f0ea5b6780 100644
--- a/sentry_sdk/traces.py
+++ b/sentry_sdk/traces.py
@@ -43,7 +43,7 @@
         Union,
     )
 
-    from sentry_sdk._types import Attributes, AttributeValue
+    from sentry_sdk._types import Attributes, AttributeValue, SpanJSON
     from sentry_sdk.profiler.continuous_profiler import ContinuousProfile
 
     P = ParamSpec("P")
@@ -574,6 +574,26 @@ def _set_segment_attributes(self) -> None:
 
         self.set_attribute("process.command_args", sys.argv)
 
+    def _to_json(self) -> "SpanJSON":
+        res: "SpanJSON" = {
+            "trace_id": self.trace_id,
+            "span_id": self.span_id,
+            "name": self._name if self._name is not None else "",
+            "status": self._status,
+            "is_segment": self._is_segment(),
+            "start_timestamp": self._start_timestamp.timestamp(),
+        }
+
+        if self._end_timestamp:
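+            # .timestamp() yields epoch seconds; spans that have not ended
+            # yet omit end_timestamp from the payload.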
res["end_timestamp"] = self._end_timestamp.timestamp() + + if self._parent_span_id: + res["parent_span_id"] = self._parent_span_id + + res["attributes"] = {k: v for k, v in self._attributes.items()} + + return res + class NoOpStreamedSpan(StreamedSpan): __slots__ = ( diff --git a/sentry_sdk/utils.py b/sentry_sdk/utils.py index 5051a3d9d2..76f1919e98 100644 --- a/sentry_sdk/utils.py +++ b/sentry_sdk/utils.py @@ -77,6 +77,7 @@ Metric, SerializedAttributeValue, ) + from sentry_sdk.traces import StreamedSpan P = ParamSpec("P") R = TypeVar("R") @@ -2111,6 +2112,15 @@ def get_before_send_metric( ) +def get_before_send_span( + options: "Optional[dict[str, Any]]", +) -> "Optional[Callable[[StreamedSpan, Hint], Optional[StreamedSpan]]]": + if options is None: + return None + + return options["_experiments"].get("before_send_span") + + def format_attribute(val: "Any") -> "AttributeValue": """ Turn unsupported attribute value types into an AttributeValue. diff --git a/tests/integrations/sqlalchemy/test_sqlalchemy.py b/tests/integrations/sqlalchemy/test_sqlalchemy.py index d942d5fea3..a938ad9d7b 100644 --- a/tests/integrations/sqlalchemy/test_sqlalchemy.py +++ b/tests/integrations/sqlalchemy/test_sqlalchemy.py @@ -1080,19 +1080,19 @@ class Person(Base): class fake_record_sql_queries: # noqa: N801 def __init__(self, *args, **kwargs): - with record_sql_queries_supporting_streaming( + self._ctx_mgr = record_sql_queries_supporting_streaming( *args, **kwargs - ) as span: - self.span = span + ) + def __enter__(self): + self.span = self._ctx_mgr.__enter__() self.span._start_timestamp = datetime(2024, 1, 1, microsecond=0) self.span._end_timestamp = datetime(2024, 1, 1, microsecond=101000) - - def __enter__(self): return self.span def __exit__(self, type, value, traceback): - pass + self.span._end_timestamp = None + self._ctx_mgr.__exit__(type, value, traceback) with mock.patch( "sentry_sdk.integrations.sqlalchemy.record_sql_queries_supporting_streaming", diff --git a/tests/tracing/test_span_batcher.py b/tests/tracing/test_span_batcher.py index 4286691785..fd575b8b83 100644 --- a/tests/tracing/test_span_batcher.py +++ b/tests/tracing/test_span_batcher.py @@ -236,7 +236,7 @@ def test_weight_based_flushing_by_attribute_size( with sentry_sdk.traces.start_span(name="small span") as bare_span: pass - bare_span_size = SpanBatcher._estimate_size(bare_span) + bare_span_size = SpanBatcher._estimate_size(bare_span._to_json()) big_attr = "x" * bare_span_size monkeypatch.setattr(SpanBatcher, "MAX_BYTES_BEFORE_FLUSH", bare_span_size * 3) diff --git a/tests/tracing/test_span_streaming.py b/tests/tracing/test_span_streaming.py index 0e095b5147..b5cb001745 100644 --- a/tests/tracing/test_span_streaming.py +++ b/tests/tracing/test_span_streaming.py @@ -271,6 +271,143 @@ def traces_sampler(sampling_context): ... +def test_before_send_span_basic(sentry_init, capture_items): + def before_send_span(span, hint): + assert isinstance(span, dict) + + span["name"] = "Better span name" + del span["attributes"]["drop"] + span["attributes"]["sanitize"] = "[Removed]" + span["attributes"]["add"] = "new" + + return span + + sentry_init( + traces_sample_rate=1.0, + _experiments={ + "before_send_span": before_send_span, + "trace_lifecycle": "stream", + }, + ) + + items = capture_items("span") + + with sentry_sdk.traces.start_span( + name="span", + attributes={ + "drop": True, + "sanitize": "myamazingpassword", + }, + ): + ... 
+ + sentry_sdk.get_client().flush() + spans = [item.payload for item in items] + + assert len(spans) == 1 + (span,) = spans + + assert span["name"] == "Better span name" + assert "drop" not in span["attributes"] + assert span["attributes"]["sanitize"] == "[Removed]" + assert span["attributes"]["add"] == "new" + + +@pytest.mark.parametrize( + "return_value", + [None, {}, {"not_a_span": True}], +) +def test_before_send_span_invalid_return_value( + sentry_init, capture_items, return_value +): + def before_send_span(span, hint): + # Spans can't be dropped in before_send_span, so unsupported return + # values will be ignored + return return_value + + sentry_init( + traces_sample_rate=1.0, + _experiments={ + "before_send_span": before_send_span, + "trace_lifecycle": "stream", + }, + ) + + items = capture_items("span") + + with sentry_sdk.traces.start_span(name="span"): + ... + + sentry_sdk.get_client().flush() + spans = [item.payload for item in items] + + assert len(spans) == 1 + (span,) = spans + + assert span["name"] == "span" + + +def test_before_send_span_unsupported_edit(sentry_init, capture_items): + def before_send_span(span, hint): + # Anything beyond attribute and name changes will be ignored + span["trace_id"] = "my-trace-id" + + return span + + sentry_init( + traces_sample_rate=1.0, + _experiments={ + "before_send_span": before_send_span, + "trace_lifecycle": "stream", + }, + ) + + items = capture_items("span") + + with sentry_sdk.traces.start_span(name="span"): + ... + + sentry_sdk.get_client().flush() + spans = [item.payload for item in items] + + assert len(spans) == 1 + (span,) = spans + + assert span["name"] == "span" + assert span["trace_id"] != "my-trace-id" + + +def test_before_send_span_doesnt_receive_ignored_spans(sentry_init, capture_items): + before_send_span_called = False + + def before_send_span(span, hint): + nonlocal before_send_span_called + before_send_span_called = True + return span + + sentry_init( + traces_sample_rate=1.0, + _experiments={ + "before_send_span": before_send_span, + "trace_lifecycle": "stream", + "ignore_spans": [ + "ignored", + ], + }, + ) + + items = capture_items("span") + + with sentry_sdk.traces.start_span(name="ignored"): + ... + + sentry_sdk.get_client().flush() + spans = [item.payload for item in items] + + assert not spans + assert not before_send_span_called + + def test_span_attributes(sentry_init, capture_items): sentry_init( traces_sample_rate=1.0,