Commit efa8b0d

fix(openai): support for openai non-consumed streams (#3155)
Parent commit: 87f8932
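
The commit title refers to streamed chat completions that a caller starts but never iterates to the end. A minimal sketch of the scenario, assuming the OpenAI v1 Python client; the model name and prompt are illustrative placeholders, not taken from this commit:

from openai import OpenAI

client = OpenAI()

stream = client.chat.completions.create(
    model="gpt-4o",  # placeholder model
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)

for i, chunk in enumerate(stream):
    if i == 0:
        break  # abandon the stream after the first chunk

Previously the instrumentation's ChatStream wrapper only ended its span and recorded streaming metrics when the underlying stream raised StopIteration or StopAsyncIteration, so an abandoned stream like the one above left the span open. The diff below adds explicit cleanup on __exit__, on __del__, and on errors raised during iteration.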

File tree: 6 files changed (+1205, -53 lines)

packages/opentelemetry-instrumentation-openai/opentelemetry/instrumentation/openai/shared/chat_wrappers.py

Lines changed: 144 additions & 17 deletions
@@ -1,6 +1,7 @@
 import copy
 import json
 import logging
+import threading
 import time
 from functools import singledispatch
 from typing import List, Optional, Union
@@ -269,7 +270,8 @@ async def _handle_request(span, kwargs, instance):
                 MessageEvent(
                     content=message.get("content"),
                     role=message.get("role"),
-                    tool_calls=_parse_tool_calls(message.get("tool_calls", None)),
+                    tool_calls=_parse_tool_calls(
+                        message.get("tool_calls", None)),
                 )
             )
     else:
@@ -292,6 +294,7 @@ def _handle_response(
     choice_counter=None,
     duration_histogram=None,
     duration=None,
+    is_streaming: bool = False,
 ):
     if is_openai_v1():
         response_dict = model_as_dict(response)
@@ -306,6 +309,7 @@ def _handle_response(
         duration_histogram,
         response_dict,
         duration,
+        is_streaming,
     )

     # span attributes
@@ -323,13 +327,19 @@ def _handle_response(


 def _set_chat_metrics(
-    instance, token_counter, choice_counter, duration_histogram, response_dict, duration
+    instance,
+    token_counter,
+    choice_counter,
+    duration_histogram,
+    response_dict,
+    duration,
+    is_streaming: bool = False,
 ):
     shared_attributes = metric_shared_attributes(
         response_model=response_dict.get("model") or None,
         operation="chat",
         server_address=_get_openai_base_url(instance),
-        is_streaming=False,
+        is_streaming=is_streaming,
     )

     # token metrics
@@ -420,7 +430,8 @@ async def _set_prompts(span, messages):
             content = json.dumps(content)
         _set_span_attribute(span, f"{prefix}.content", content)
         if msg.get("tool_call_id"):
-            _set_span_attribute(span, f"{prefix}.tool_call_id", msg.get("tool_call_id"))
+            _set_span_attribute(
+                span, f"{prefix}.tool_call_id", msg.get("tool_call_id"))
         tool_calls = msg.get("tool_calls")
         if tool_calls:
             for i, tool_call in enumerate(tool_calls):
@@ -476,9 +487,11 @@ def _set_completions(span, choices):
         _set_span_attribute(span, f"{prefix}.role", message.get("role"))

         if message.get("refusal"):
-            _set_span_attribute(span, f"{prefix}.refusal", message.get("refusal"))
+            _set_span_attribute(
+                span, f"{prefix}.refusal", message.get("refusal"))
         else:
-            _set_span_attribute(span, f"{prefix}.content", message.get("content"))
+            _set_span_attribute(
+                span, f"{prefix}.content", message.get("content"))

         function_call = message.get("function_call")
         if function_call:
@@ -533,7 +546,8 @@ def _set_streaming_token_metrics(
     # If API response doesn't have usage, fallback to tiktoken calculation
    if prompt_usage == -1 or completion_usage == -1:
        model_name = (
-            complete_response.get("model") or request_kwargs.get("model") or "gpt-4"
+            complete_response.get("model") or request_kwargs.get(
+                "model") or "gpt-4"
        )

        # Calculate prompt tokens if not available from API
@@ -543,7 +557,8 @@ def _set_streaming_token_metrics(
            if msg.get("content"):
                prompt_content += msg.get("content")
        if model_name and should_record_stream_token_usage():
-            prompt_usage = get_token_count_from_string(prompt_content, model_name)
+            prompt_usage = get_token_count_from_string(
+                prompt_content, model_name)

    # Calculate completion tokens if not available from API
    if completion_usage == -1 and complete_response.get("choices"):
@@ -566,7 +581,8 @@ def _set_streaming_token_metrics(
            **shared_attributes,
            SpanAttributes.LLM_TOKEN_TYPE: "input",
        }
-        token_counter.record(prompt_usage, attributes=attributes_with_token_type)
+        token_counter.record(
+            prompt_usage, attributes=attributes_with_token_type)

    if isinstance(completion_usage, int) and completion_usage >= 0:
        attributes_with_token_type = {
@@ -619,11 +635,34 @@ def __init__(
         self._time_of_first_token = self._start_time
         self._complete_response = {"choices": [], "model": ""}

+        # Cleanup state tracking to prevent duplicate operations
+        self._cleanup_completed = False
+        self._cleanup_lock = threading.Lock()
+
+    def __del__(self):
+        """Cleanup when object is garbage collected"""
+        if hasattr(self, '_cleanup_completed') and not self._cleanup_completed:
+            self._ensure_cleanup()
+
     def __enter__(self):
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
-        self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)
+        cleanup_exception = None
+        try:
+            self._ensure_cleanup()
+        except Exception as e:
+            cleanup_exception = e
+            # Don't re-raise to avoid masking original exception
+
+        result = self.__wrapped__.__exit__(exc_type, exc_val, exc_tb)
+
+        if cleanup_exception:
+            # Log cleanup exception but don't affect context manager behavior
+            logger.debug(
+                "Error during ChatStream cleanup in __exit__: %s", cleanup_exception)
+
+        return result

     async def __aenter__(self):
         return self
@@ -643,6 +682,11 @@ def __next__(self):
         except Exception as e:
             if isinstance(e, StopIteration):
                 self._process_complete_response()
+            else:
+                # Handle cleanup for other exceptions during stream iteration
+                self._ensure_cleanup()
+                if self._span and self._span.is_recording():
+                    self._span.set_status(Status(StatusCode.ERROR, str(e)))
             raise
         else:
             self._process_item(chunk)
@@ -654,13 +698,19 @@ async def __anext__(self):
         except Exception as e:
             if isinstance(e, StopAsyncIteration):
                 self._process_complete_response()
+            else:
+                # Handle cleanup for other exceptions during stream iteration
+                self._ensure_cleanup()
+                if self._span and self._span.is_recording():
+                    self._span.set_status(Status(StatusCode.ERROR, str(e)))
             raise
         else:
             self._process_item(chunk)
             return chunk

     def _process_item(self, item):
-        self._span.add_event(name=f"{SpanAttributes.LLM_CONTENT_COMPLETION_CHUNK}")
+        self._span.add_event(
+            name=f"{SpanAttributes.LLM_CONTENT_COMPLETION_CHUNK}")

         if self._first_token and self._streaming_time_to_first_token:
             self._time_of_first_token = time.time()
@@ -721,10 +771,82 @@ def _process_complete_response(self):
                 emit_event(_parse_choice_event(choice))
         else:
             if should_send_prompts():
-                _set_completions(self._span, self._complete_response.get("choices"))
+                _set_completions(
+                    self._span, self._complete_response.get("choices"))

         self._span.set_status(Status(StatusCode.OK))
         self._span.end()
+        self._cleanup_completed = True
+
+    @dont_throw
+    def _ensure_cleanup(self):
+        """Thread-safe cleanup method that handles different cleanup scenarios"""
+        with self._cleanup_lock:
+            if self._cleanup_completed:
+                logger.debug("ChatStream cleanup already completed, skipping")
+                return
+
+            try:
+                logger.debug("Starting ChatStream cleanup")
+
+                # Calculate partial metrics based on available data
+                self._record_partial_metrics()
+
+                # Set span status and close it
+                if self._span and self._span.is_recording():
+                    self._span.set_status(Status(StatusCode.OK))
+                    self._span.end()
+                    logger.debug("ChatStream span closed successfully")
+
+                self._cleanup_completed = True
+                logger.debug("ChatStream cleanup completed successfully")
+
+            except Exception as e:
+                # Log cleanup errors but don't propagate to avoid masking original issues
+                logger.debug("Error during ChatStream cleanup: %s", str(e))
+
+                # Still try to close the span even if metrics recording failed
+                try:
+                    if self._span and self._span.is_recording():
+                        self._span.set_status(
+                            Status(StatusCode.ERROR, "Cleanup failed"))
+                        self._span.end()
+                    self._cleanup_completed = True
+                except Exception:
+                    # Final fallback - just mark as completed to prevent infinite loops
+                    self._cleanup_completed = True
+
+    @dont_throw
+    def _record_partial_metrics(self):
+        """Record metrics based on available partial data"""
+        # Always record duration if we have start time
+        if self._start_time and isinstance(self._start_time, (float, int)) and self._duration_histogram:
+            duration = time.time() - self._start_time
+            self._duration_histogram.record(
+                duration, attributes=self._shared_attributes()
+            )
+
+        # Record basic span attributes even without complete response
+        if self._span and self._span.is_recording():
+            _set_response_attributes(self._span, self._complete_response)
+
+        # Record partial token metrics if we have any data
+        if self._complete_response.get("choices") or self._request_kwargs:
+            _set_streaming_token_metrics(
+                self._request_kwargs,
+                self._complete_response,
+                self._span,
+                self._token_counter,
+                self._shared_attributes(),
+            )
+
+        # Record choice metrics if we have any choices processed
+        if self._choice_counter and self._complete_response.get("choices"):
+            _set_choice_counter_metrics(
+                self._choice_counter,
+                self._complete_response.get("choices"),
+                self._shared_attributes(),
+            )


 # Backward compatibility with OpenAI v0
@@ -755,7 +877,8 @@ def _build_from_streaming_response(

         if first_token and streaming_time_to_first_token:
             time_of_first_token = time.time()
-            streaming_time_to_first_token.record(time_of_first_token - start_time)
+            streaming_time_to_first_token.record(
+                time_of_first_token - start_time)
             first_token = False

         _accumulate_stream_items(item, complete_response)
@@ -825,7 +948,8 @@ async def _abuild_from_streaming_response(

         if first_token and streaming_time_to_first_token:
             time_of_first_token = time.time()
-            streaming_time_to_first_token.record(time_of_first_token - start_time)
+            streaming_time_to_first_token.record(
+                time_of_first_token - start_time)
             first_token = False

         _accumulate_stream_items(item, complete_response)
@@ -943,7 +1067,8 @@ def _(choice: dict) -> ChoiceEvent:

     content = choice.get("message").get("content", "") if has_message else None
     role = choice.get("message").get("role") if has_message else "unknown"
-    finish_reason = choice.get("finish_reason") if has_finish_reason else "unknown"
+    finish_reason = choice.get(
+        "finish_reason") if has_finish_reason else "unknown"

     if has_tool_calls and has_function_call:
         tool_calls = message.get("tool_calls") + [message.get("function_call")]
@@ -982,7 +1107,8 @@ def _accumulate_stream_items(item, complete_response):

     # prompt filter results
     if item.get("prompt_filter_results"):
-        complete_response["prompt_filter_results"] = item.get("prompt_filter_results")
+        complete_response["prompt_filter_results"] = item.get(
+            "prompt_filter_results")

     for choice in item.get("choices"):
         index = choice.get("index")
@@ -1029,4 +1155,5 @@ def _accumulate_stream_items(item, complete_response):
                 if tool_call_function and tool_call_function.get("name"):
                     span_function["name"] = tool_call_function.get("name")
                 if tool_call_function and tool_call_function.get("arguments"):
-                    span_function["arguments"] += tool_call_function.get("arguments")
+                    span_function["arguments"] += tool_call_function.get(
+                        "arguments")
