diff --git a/sentry_sdk/ai/monitoring.py b/sentry_sdk/ai/monitoring.py index e149ebe7df..e826f3bf90 100644 --- a/sentry_sdk/ai/monitoring.py +++ b/sentry_sdk/ai/monitoring.py @@ -3,7 +3,7 @@ import sentry_sdk.utils from sentry_sdk import start_span -from sentry_sdk.tracing import Span +from sentry_sdk.tracing import POTelSpan as Span from sentry_sdk.utils import ContextVar from typing import TYPE_CHECKING diff --git a/sentry_sdk/ai/utils.py b/sentry_sdk/ai/utils.py index ed3494f679..4a972071a9 100644 --- a/sentry_sdk/ai/utils.py +++ b/sentry_sdk/ai/utils.py @@ -3,7 +3,7 @@ if TYPE_CHECKING: from typing import Any -from sentry_sdk.tracing import Span +from sentry_sdk.tracing import POTelSpan as Span from sentry_sdk.utils import logger diff --git a/sentry_sdk/integrations/langchain.py b/sentry_sdk/integrations/langchain.py index afce913d8e..deb700bde2 100644 --- a/sentry_sdk/integrations/langchain.py +++ b/sentry_sdk/integrations/langchain.py @@ -3,10 +3,10 @@ import sentry_sdk from sentry_sdk.ai.monitoring import set_ai_pipeline_name, record_token_usage -from sentry_sdk.consts import OP, SPANDATA +from sentry_sdk.consts import OP, SPANDATA, SPANSTATUS from sentry_sdk.ai.utils import set_data_normalized from sentry_sdk.scope import should_send_default_pii -from sentry_sdk.tracing import Span +from sentry_sdk.tracing import POTelSpan as Span from sentry_sdk.integrations import DidNotEnable, Integration from sentry_sdk.utils import logger, capture_internal_exceptions @@ -72,7 +72,6 @@ def setup_once(): class WatchedSpan: - span = None # type: Span num_completion_tokens = 0 # type: int num_prompt_tokens = 0 # type: int no_collect_tokens = False # type: bool @@ -123,8 +122,9 @@ def _handle_error(self, run_id, error): span_data = self.span_map[run_id] if not span_data: return - sentry_sdk.capture_exception(error, span_data.span.scope) - span_data.span.__exit__(None, None, None) + sentry_sdk.capture_exception(error) + span_data.span.set_status(SPANSTATUS.INTERNAL_ERROR) + span_data.span.finish() del self.span_map[run_id] def _normalize_langchain_message(self, message): @@ -136,23 +136,27 @@ def _normalize_langchain_message(self, message): def _create_span(self, run_id, parent_id, **kwargs): # type: (SentryLangchainCallback, UUID, Optional[Any], Any) -> WatchedSpan - watched_span = None # type: Optional[WatchedSpan] - if parent_id: - parent_span = self.span_map.get(parent_id) # type: Optional[WatchedSpan] - if parent_span: - watched_span = WatchedSpan(parent_span.span.start_child(**kwargs)) - parent_span.children.append(watched_span) - if watched_span is None: - watched_span = WatchedSpan( - sentry_sdk.start_span(only_if_parent=True, **kwargs) - ) + parent_watched_span = self.span_map.get(parent_id) if parent_id else None + sentry_span = sentry_sdk.start_span( + parent_span=parent_watched_span.span if parent_watched_span else None, + only_if_parent=True, + **kwargs, + ) + watched_span = WatchedSpan(sentry_span) + if parent_watched_span: + parent_watched_span.children.append(watched_span) if kwargs.get("op", "").startswith("ai.pipeline."): if kwargs.get("name"): set_ai_pipeline_name(kwargs.get("name")) watched_span.is_pipeline = True - watched_span.span.__enter__() + # the same run_id is reused for the pipeline it seems + # so we need to end the older span to avoid orphan spans + existing_span_data = self.span_map.get(run_id) + if existing_span_data is not None: + self._exit_span(existing_span_data, run_id) + self.span_map[run_id] = watched_span self.gc_span_map() return watched_span @@ -163,7 +167,8 @@ def _exit_span(self, span_data, run_id): if span_data.is_pipeline: set_ai_pipeline_name(None) - span_data.span.__exit__(None, None, None) + span_data.span.set_status(SPANSTATUS.OK) + span_data.span.finish() del self.span_map[run_id] def on_llm_start( diff --git a/sentry_sdk/integrations/opentelemetry/span_processor.py b/sentry_sdk/integrations/opentelemetry/span_processor.py index 42ad32a5ea..8d513ec97d 100644 --- a/sentry_sdk/integrations/opentelemetry/span_processor.py +++ b/sentry_sdk/integrations/opentelemetry/span_processor.py @@ -291,3 +291,17 @@ def _common_span_transaction_attributes_as_json(self, span): common_json["tags"] = tags return common_json + + def _log_debug_info(self): + # type: () -> None + import pprint + + pprint.pprint( + { + format_span_id(span_id): [ + (format_span_id(child.context.span_id), child.name) + for child in children + ] + for span_id, children in self._children_spans.items() + } + ) diff --git a/sentry_sdk/tracing.py b/sentry_sdk/tracing.py index a0b9439dc8..3ee155aedb 100644 --- a/sentry_sdk/tracing.py +++ b/sentry_sdk/tracing.py @@ -1213,6 +1213,7 @@ def __init__( source=TRANSACTION_SOURCE_CUSTOM, # type: str attributes=None, # type: OTelSpanAttributes only_if_parent=False, # type: bool + parent_span=None, # type: Optional[POTelSpan] otel_span=None, # type: Optional[OtelSpan] **_, # type: dict[str, object] ): @@ -1231,7 +1232,7 @@ def __init__( self._otel_span = otel_span else: skip_span = False - if only_if_parent: + if only_if_parent and parent_span is None: parent_span_context = get_current_span().get_span_context() skip_span = ( not parent_span_context.is_valid or parent_span_context.is_remote @@ -1262,8 +1263,17 @@ def __init__( if sampled is not None: attributes[SentrySpanAttribute.CUSTOM_SAMPLED] = sampled + parent_context = None + if parent_span is not None: + parent_context = otel_trace.set_span_in_context( + parent_span._otel_span + ) + self._otel_span = tracer.start_span( - span_name, start_time=start_timestamp, attributes=attributes + span_name, + context=parent_context, + start_time=start_timestamp, + attributes=attributes, ) self.origin = origin or DEFAULT_SPAN_ORIGIN @@ -1506,10 +1516,7 @@ def timestamp(self): def start_child(self, **kwargs): # type: (**Any) -> POTelSpan - kwargs.setdefault("sampled", self.sampled) - - span = POTelSpan(only_if_parent=True, **kwargs) - return span + return POTelSpan(sampled=self.sampled, parent_span=self, **kwargs) def iter_headers(self): # type: () -> Iterator[Tuple[str, str]] diff --git a/tests/integrations/langchain/test_langchain.py b/tests/integrations/langchain/test_langchain.py index 2ac6679321..f8ab30054d 100644 --- a/tests/integrations/langchain/test_langchain.py +++ b/tests/integrations/langchain/test_langchain.py @@ -187,17 +187,11 @@ def test_langchain_agent( assert "measurements" not in chat_spans[0] if send_default_pii and include_prompts: - assert ( - "You are very powerful" - in chat_spans[0]["data"]["ai.input_messages"][0]["content"] - ) + assert "You are very powerful" in chat_spans[0]["data"]["ai.input_messages"] assert "5" in chat_spans[0]["data"]["ai.responses"] assert "word" in tool_exec_span["data"]["ai.input_messages"] assert 5 == int(tool_exec_span["data"]["ai.responses"]) - assert ( - "You are very powerful" - in chat_spans[1]["data"]["ai.input_messages"][0]["content"] - ) + assert "You are very powerful" in chat_spans[1]["data"]["ai.input_messages"] assert "5" in chat_spans[1]["data"]["ai.responses"] else: assert "ai.input_messages" not in chat_spans[0].get("data", {})