From 14bbbb91cd07a3e8bc690ab1017f9c17ae01d9b5 Mon Sep 17 00:00:00 2001 From: Daniel Cazzulino Date: Tue, 5 May 2026 16:20:26 -0300 Subject: [PATCH] Fix streaming usage accounting Closes #142 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- AGENTS.md | 1 + src/xAI.Tests/ChatClientTests.cs | 153 ++++++++++++++++++++++++ src/xAI.Tests/Extensions/CallHelpers.cs | 26 ++++ src/xAI.Tests/SanityChecks.cs | 59 +++++++++ src/xAI/GrokChatClient.cs | 33 ++++- 5 files changed, 271 insertions(+), 1 deletion(-) diff --git a/AGENTS.md b/AGENTS.md index ab361ef..10ee8d1 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -6,3 +6,4 @@ - `AsISpeechToTextClient` returns an `ISpeechToTextClient` implementation that uses `POST /v1/stt` for file transcription and `wss://.../v1/stt` for raw-audio streaming transcription. - TTS defaults follow xAI docs: voice `eve`, language `en` when omitted by `TextToSpeechOptions`, and MP3 output when no codec is specified. - STT streaming defaults follow xAI docs: encoding `pcm` and sample rate `16000` when omitted; WebSocket input must be raw encoded audio, not MP3/WAV container bytes. +- Chat streaming `GetChatCompletionChunk.Usage` values are cumulative within a sampling segment and may reset across tool-driven segments; emit deltas (or restart deltas after a reset) so `ToChatResponse()` totals match non-streaming usage. 
diff --git a/src/xAI.Tests/ChatClientTests.cs b/src/xAI.Tests/ChatClientTests.cs index ea541d2..6b7c325 100644 --- a/src/xAI.Tests/ChatClientTests.cs +++ b/src/xAI.Tests/ChatClientTests.cs @@ -670,6 +670,159 @@ public async Task GrokCustomFactoryInvokedFromOptions() Assert.Equal("Hey Cazzulino!", response.Text); } + /* Two chunks carry growing usage snapshots (10/2/12, then 10/4/14); the aggregated response must report 10/4/14, i.e. the last snapshot rather than the sum of both snapshots. */ [Fact] + public async Task GrokStreamingResponseUsesUsageDeltas() + { + var client = new Mock(MockBehavior.Strict); + client.Setup(x => x.GetCompletionChunk(It.IsAny(), null, null, CancellationToken.None)) + .Returns(CallHelpers.CreateAsyncServerStreamingCall( + new GetChatCompletionChunk + { + Id = "response-1", + Model = "grok-4-1-fast-non-reasoning", + Outputs = + { + new CompletionOutputChunk + { + Delta = new Delta + { + Role = MessageRole.RoleAssistant, + Content = "Hello" + }, + FinishReason = FinishReason.ReasonInvalid, + Index = 0 + } + }, + Usage = new SamplingUsage + { + PromptTokens = 10, + CompletionTokens = 2, + TotalTokens = 12 + } + }, + new GetChatCompletionChunk + { + Id = "response-1", + Model = "grok-4-1-fast-non-reasoning", + Outputs = + { + new CompletionOutputChunk + { + Delta = new Delta + { + Content = " world" + }, + FinishReason = FinishReason.ReasonStop, + Index = 0 + } + }, + Usage = new SamplingUsage + { + PromptTokens = 10, + CompletionTokens = 4, + TotalTokens = 14 + } + })); + + var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning"); + + var updates = await grok.GetStreamingResponseAsync("Hi").ToListAsync(); + var response = updates.ToChatResponse(); + + Assert.NotNull(response.Usage); + Assert.Equal(10, response.Usage.InputTokenCount); + Assert.Equal(4, response.Usage.OutputTokenCount); + Assert.Equal(14, response.Usage.TotalTokenCount); + } + + /* Three chunks: the first two grow (10/2/12, 10/4/14), the third reports lower counters (3/3/6). The expected totals are 10+3, 4+3 and 14+6, i.e. the lower snapshot is added on top of what was already counted. */ [Fact] + public async Task GrokStreamingResponseUsageHandlesCounterResets() + { + var client = new Mock(MockBehavior.Strict); + client.Setup(x => x.GetCompletionChunk(It.IsAny(), null, null, CancellationToken.None)) + 
.Returns(CallHelpers.CreateAsyncServerStreamingCall( + new GetChatCompletionChunk + { + Id = "response-2", + Model = "grok-4-1-fast-non-reasoning", + Outputs = + { + new CompletionOutputChunk + { + Delta = new Delta + { + Role = MessageRole.RoleAssistant, + Content = "Phase one" + }, + FinishReason = FinishReason.ReasonInvalid, + Index = 0 + } + }, + Usage = new SamplingUsage + { + PromptTokens = 10, + CompletionTokens = 2, + TotalTokens = 12 + } + }, + new GetChatCompletionChunk + { + Id = "response-2", + Model = "grok-4-1-fast-non-reasoning", + Outputs = + { + new CompletionOutputChunk + { + Delta = new Delta + { + Content = " complete" + }, + FinishReason = FinishReason.ReasonInvalid, + Index = 0 + } + }, + Usage = new SamplingUsage + { + PromptTokens = 10, + CompletionTokens = 4, + TotalTokens = 14 + } + }, + new GetChatCompletionChunk + { + Id = "response-2", + Model = "grok-4-1-fast-non-reasoning", + Outputs = + { + new CompletionOutputChunk + { + Delta = new Delta + { + Content = " phase two" + }, + FinishReason = FinishReason.ReasonStop, + Index = 0 + } + }, + /* every counter is below the previous chunk's value: this is the simulated reset */ Usage = new SamplingUsage + { + PromptTokens = 3, + CompletionTokens = 3, + TotalTokens = 6 + } + })); + + var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning"); + + var updates = await grok.GetStreamingResponseAsync("Hi").ToListAsync(); + var response = updates.ToChatResponse(); + + Assert.NotNull(response.Usage); + Assert.Equal(13, response.Usage.InputTokenCount); + Assert.Equal(7, response.Usage.OutputTokenCount); + Assert.Equal(20, response.Usage.TotalTokenCount); + } + [Fact] public async Task AskFiles() { diff --git a/src/xAI.Tests/Extensions/CallHelpers.cs b/src/xAI.Tests/Extensions/CallHelpers.cs index 77b8050..b52ed73 100644 --- a/src/xAI.Tests/Extensions/CallHelpers.cs +++ b/src/xAI.Tests/Extensions/CallHelpers.cs @@ -42,5 +42,31 @@ public static AsyncUnaryCall CreateAsyncUnaryCall(StatusCo () => new Metadata(), () => { }); } + + /* Test helper: wraps the given responses in a server-streaming call whose stream yields them in order; the call is built with empty Metadata instances, Status.DefaultSuccess and a no-op final callback. */ public static AsyncServerStreamingCall 
CreateAsyncServerStreamingCall(params TResponse[] responses) + { + return new AsyncServerStreamingCall( + new TestAsyncStreamReader(responses), + Task.FromResult(new Metadata()), + () => Status.DefaultSuccess, + () => new Metadata(), + () => { }); + } + + /* Stream reader over an in-memory sequence: MoveNext advances the underlying enumerator and publishes the item through Current, completing synchronously; the cancellation token is not observed. */ class TestAsyncStreamReader(IEnumerable responses) : IAsyncStreamReader + { + readonly IEnumerator enumerator = responses.GetEnumerator(); + + public TResponse Current { get; private set; } = default!; + + public Task MoveNext(CancellationToken cancellationToken) + { + if (!enumerator.MoveNext()) + return Task.FromResult(false); + + Current = enumerator.Current; + return Task.FromResult(true); + } + } } } diff --git a/src/xAI.Tests/SanityChecks.cs b/src/xAI.Tests/SanityChecks.cs index a972bb3..d85917a 100644 --- a/src/xAI.Tests/SanityChecks.cs +++ b/src/xAI.Tests/SanityChecks.cs @@ -446,6 +446,33 @@ parseable by a decimal parser. output.WriteLine($"Code interpreter calls: {codeInterpreterCalls.Count}"); } + /* Live-API check: runs the same three-prompt conversation once without and once with streaming, and requires the summed total token counts to differ by at most 20% of the non-streaming sum. */ [SecretsFact("CI_XAI_API_KEY")] + public async Task StreamingUsageRoughlyMatchesNonStreamingUsage() + { + var client = new GrokClient(Configuration["CI_XAI_API_KEY"]!) 
+ .AsIChatClient("grok-4-1-fast-non-reasoning"); + + var prompts = new[] + { + """Reply with JSON only: {"number":7}""", + """Using the previous assistant response, add 5 and reply with JSON only: {"number":12}""", + """Using the latest number from this conversation, multiply it by 3 and reply with JSON only: {"number":36}""", + }; + + var nonStreamingUsage = await GetConversationUsageAsync(client, prompts, streaming: false); + var streamingUsage = await GetConversationUsageAsync(client, prompts, streaming: true); + /* relative difference, measured against the non-streaming sum */ var usageDelta = Math.Abs(streamingUsage - nonStreamingUsage) / (double)nonStreamingUsage; + + output.WriteLine($"Non-streaming total tokens: {nonStreamingUsage}"); + output.WriteLine($"Streaming total tokens: {streamingUsage}"); + output.WriteLine($"Relative delta: {usageDelta:P2}"); + + Assert.True(nonStreamingUsage > 0, "Expected non-streaming total token usage to be reported."); + Assert.True(streamingUsage > 0, "Expected streaming total token usage to be reported."); + Assert.True(usageDelta <= 0.20, + $"Expected streaming total token usage to remain within 20% of non-streaming usage, but got {streamingUsage} vs {nonStreamingUsage} ({usageDelta:P2})."); + } + [SecretsTheory("CI_XAI_API_KEY")] [InlineData("rex")] public async Task TextToSpeech_SpeechToText(string voiceId) @@ -521,6 +548,38 @@ static async Task GetResponseAsync(IChatClient client, ChatConvers return updates.ToChatResponse(); } + /* Sends each prompt as a user turn on one growing conversation (seeded with a system message) and returns the sum of the per-response TotalTokenCount values; asserts that every response reports a positive total. The streaming flag is forwarded to GetResponseAsync. */ static async Task GetConversationUsageAsync(IChatClient client, IReadOnlyList prompts, bool streaming) + { + var chat = new ChatConversation + { + { "system", "You are a precise assistant. Reply with compact JSON only." 
}, + }; + + long totalTokenCount = 0; + + foreach (var prompt in prompts) + { + chat.Add("user", prompt); + + var response = await GetResponseAsync(client, chat, new GrokChatOptions + { + ResponseFormat = ChatResponseFormat.Json, + Temperature = 0, + MaxOutputTokens = 64, + }, streaming); + + Assert.NotNull(response.Usage); + Assert.NotNull(response.Usage.TotalTokenCount); + var tokenCount = response.Usage.TotalTokenCount.Value; + Assert.True(tokenCount > 0, $"Expected token usage for prompt '{prompt}'."); + + totalTokenCount += tokenCount; + /* keep the response messages in the conversation so the next prompt is sent with them */ chat.AddRange(response.Messages); + } + + return totalTokenCount; + } + static T ParseJson(ChatResponse response, ITestOutputHelper output) { var responseText = response.Messages.Last().Text; diff --git a/src/xAI/GrokChatClient.cs b/src/xAI/GrokChatClient.cs index 8763ae2..ee55e99 100644 --- a/src/xAI/GrokChatClient.cs +++ b/src/xAI/GrokChatClient.cs @@ -70,6 +70,9 @@ async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable { var request = this.AsCompletionsRequest(messages, options); var call = client.GetCompletionChunk(request, cancellationToken: cancellationToken); + /* Usage counters from the most recent chunk of this stream (zero before the first chunk); passed by ref to ConvertStreamingUsageDelta, which reads and overwrites them. */ var promptTokens = 0; + var completionTokens = 0; + var totalTokens = 0; await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken)) { @@ -107,7 +110,7 @@ async IAsyncEnumerable CompleteChatStreamingCore(IEnumerable text is not null) update.Contents.Add(new TextContent(text)); - if (chunk.Usage.Convert() is { } usage) + if (ConvertStreamingUsageDelta(chunk.Usage, ref promptTokens, ref completionTokens, ref totalTokens) is { } usage) update.Contents.Add(new UsageContent(usage) { RawRepresentation = chunk.Usage }); yield return update; @@ -136,6 +139,34 @@ static CitationAnnotation MapCitation(string citation) }; } + /* Turns the usage snapshot carried by a streaming chunk into the increment since the previous chunk. The three ref arguments hold the previously seen counters and are overwritten with the snapshot's values on every non-null call. If any counter is lower than its previous value the snapshot is treated as a restarted count and its values are returned unchanged as the increment. Returns null when usage is null or when all three increments are zero. */ static UsageDetails? 
ConvertStreamingUsageDelta(SamplingUsage usage, ref int promptTokens, ref int completionTokens, ref int totalTokens) + { + if (usage == null) + return null; + + /* a single decreasing counter is enough to classify the whole snapshot as a reset */ var reset = usage.PromptTokens < promptTokens + || usage.CompletionTokens < completionTokens + || usage.TotalTokens < totalTokens; + + /* after a reset the snapshot itself is the increment; otherwise subtract the previous counters */ var inputDelta = reset ? usage.PromptTokens : usage.PromptTokens - promptTokens; + var outputDelta = reset ? usage.CompletionTokens : usage.CompletionTokens - completionTokens; + var totalDelta = reset ? usage.TotalTokens : usage.TotalTokens - totalTokens; + + promptTokens = usage.PromptTokens; + completionTokens = usage.CompletionTokens; + totalTokens = usage.TotalTokens; + + /* nothing changed since the previous chunk: return null so the caller adds no usage content */ if (inputDelta == 0 && outputDelta == 0 && totalDelta == 0) + return null; + + return new UsageDetails + { + InputTokenCount = inputDelta, + OutputTokenCount = outputDelta, + TotalTokenCount = totalDelta + }; + } + /// public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch {