diff --git a/posthog/ai/utils.py b/posthog/ai/utils.py index d9177686..fff047b9 100644 --- a/posthog/ai/utils.py +++ b/posthog/ai/utils.py @@ -2,7 +2,7 @@ import uuid from typing import Any, Callable, Dict, List, Optional, cast -from posthog import get_tags, identify_context, new_context, tag +from posthog import get_tags, identify_context, new_context, tag, contexts from posthog.ai.sanitization import ( sanitize_anthropic, sanitize_gemini, @@ -366,6 +366,10 @@ def call_llm_and_track_usage( if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) + if not contexts.get_context_distinct_id(): + # Use trace_id as district_id if it's not defined. + identify_context(posthog_trace_id) + if response and ( hasattr(response, "usage") or (provider == "gemini" and hasattr(response, "usage_metadata")) @@ -445,7 +449,7 @@ def call_llm_and_track_usage( sdk_tags, posthog_properties ) ph_client.capture( - distinct_id=posthog_distinct_id or posthog_trace_id, + distinct_id=contexts.get_context_distinct_id(), event="$ai_generation", properties=merged_properties, groups=posthog_groups, @@ -501,6 +505,10 @@ async def call_llm_and_track_usage_async( if posthog_trace_id is None: posthog_trace_id = str(uuid.uuid4()) + if not contexts.get_context_distinct_id(): + # Use trace_id as district_id if it's not defined. + identify_context(posthog_trace_id) + if response and ( hasattr(response, "usage") or (provider == "gemini" and hasattr(response, "usage_metadata")) @@ -580,7 +588,7 @@ async def call_llm_and_track_usage_async( sdk_tags, posthog_properties ) ph_client.capture( - distinct_id=posthog_distinct_id or posthog_trace_id, + distinct_id=contexts.get_context_distinct_id(), event="$ai_generation", properties=merged_properties, groups=posthog_groups,