Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
- `AsISpeechToTextClient` returns an `ISpeechToTextClient` implementation that uses `POST /v1/stt` for file transcription and `wss://.../v1/stt` for raw-audio streaming transcription.
- TTS defaults follow xAI docs: voice `eve`, language `en` when omitted by `TextToSpeechOptions`, and MP3 output when no codec is specified.
- STT streaming defaults follow xAI docs: encoding `pcm` and sample rate `16000` when omitted; WebSocket input must be raw encoded audio, not MP3/WAV container bytes.
- Chat streaming `GetChatCompletionChunk.Usage` values are cumulative within a sampling segment and may reset across tool-driven segments; emit deltas (or restart deltas after a reset) so `ToChatResponse()` totals match non-streaming usage.
153 changes: 153 additions & 0 deletions src/xAI.Tests/ChatClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -670,6 +670,159 @@ public async Task GrokCustomFactoryInvokedFromOptions()
Assert.Equal("Hey Cazzulino!", response.Text);
}

[Fact]
public async Task GrokStreamingResponseUsesUsageDeltas()
{
    // Two chunks of the same response; the usage counters on the second chunk
    // are running totals that include what the first chunk already reported.
    var openingChunk = new GetChatCompletionChunk
    {
        Id = "response-1",
        Model = "grok-4-1-fast-non-reasoning",
        Usage = new SamplingUsage
        {
            PromptTokens = 10,
            CompletionTokens = 2,
            TotalTokens = 12
        }
    };
    openingChunk.Outputs.Add(new CompletionOutputChunk
    {
        Delta = new Delta
        {
            Role = MessageRole.RoleAssistant,
            Content = "Hello"
        },
        FinishReason = FinishReason.ReasonInvalid,
        Index = 0
    });

    var closingChunk = new GetChatCompletionChunk
    {
        Id = "response-1",
        Model = "grok-4-1-fast-non-reasoning",
        Usage = new SamplingUsage
        {
            PromptTokens = 10,
            CompletionTokens = 4,
            TotalTokens = 14
        }
    };
    closingChunk.Outputs.Add(new CompletionOutputChunk
    {
        Delta = new Delta
        {
            Content = " world"
        },
        FinishReason = FinishReason.ReasonStop,
        Index = 0
    });

    var protocolClient = new Mock<xAI.Protocol.Chat.ChatClient>(MockBehavior.Strict);
    protocolClient
        .Setup(c => c.GetCompletionChunk(It.IsAny<GetCompletionsRequest>(), null, null, CancellationToken.None))
        .Returns(CallHelpers.CreateAsyncServerStreamingCall(openingChunk, closingChunk));

    var grok = new GrokChatClient(protocolClient.Object, "grok-4-1-fast-non-reasoning");

    var streamed = await grok.GetStreamingResponseAsync("Hi").ToListAsync();
    var response = streamed.ToChatResponse();

    // The aggregated usage must equal the last running totals (10/4/14),
    // not the sum of the per-chunk running totals.
    Assert.NotNull(response.Usage);
    Assert.Equal(10, response.Usage.InputTokenCount);
    Assert.Equal(4, response.Usage.OutputTokenCount);
    Assert.Equal(14, response.Usage.TotalTokenCount);
}

[Fact]
public async Task GrokStreamingResponseUsageHandlesCounterResets()
{
    // Chunks one and two report growing running totals; chunk three reports
    // smaller counters than chunk two, i.e. the counters started over.
    var firstChunk = new GetChatCompletionChunk
    {
        Id = "response-2",
        Model = "grok-4-1-fast-non-reasoning",
        Usage = new SamplingUsage
        {
            PromptTokens = 10,
            CompletionTokens = 2,
            TotalTokens = 12
        }
    };
    firstChunk.Outputs.Add(new CompletionOutputChunk
    {
        Delta = new Delta
        {
            Role = MessageRole.RoleAssistant,
            Content = "Phase one"
        },
        FinishReason = FinishReason.ReasonInvalid,
        Index = 0
    });

    var secondChunk = new GetChatCompletionChunk
    {
        Id = "response-2",
        Model = "grok-4-1-fast-non-reasoning",
        Usage = new SamplingUsage
        {
            PromptTokens = 10,
            CompletionTokens = 4,
            TotalTokens = 14
        }
    };
    secondChunk.Outputs.Add(new CompletionOutputChunk
    {
        Delta = new Delta
        {
            Content = " complete"
        },
        FinishReason = FinishReason.ReasonInvalid,
        Index = 0
    });

    var resetChunk = new GetChatCompletionChunk
    {
        Id = "response-2",
        Model = "grok-4-1-fast-non-reasoning",
        Usage = new SamplingUsage
        {
            PromptTokens = 3,
            CompletionTokens = 3,
            TotalTokens = 6
        }
    };
    resetChunk.Outputs.Add(new CompletionOutputChunk
    {
        Delta = new Delta
        {
            Content = " phase two"
        },
        FinishReason = FinishReason.ReasonStop,
        Index = 0
    });

    var protocolClient = new Mock<xAI.Protocol.Chat.ChatClient>(MockBehavior.Strict);
    protocolClient
        .Setup(c => c.GetCompletionChunk(It.IsAny<GetCompletionsRequest>(), null, null, CancellationToken.None))
        .Returns(CallHelpers.CreateAsyncServerStreamingCall(firstChunk, secondChunk, resetChunk));

    var grok = new GrokChatClient(protocolClient.Object, "grok-4-1-fast-non-reasoning");

    var streamed = await grok.GetStreamingResponseAsync("Hi").ToListAsync();
    var response = streamed.ToChatResponse();

    // Expected totals are the pre-reset running totals plus the post-reset
    // counters: 10 + 3 input, 4 + 3 output, 14 + 6 total.
    Assert.NotNull(response.Usage);
    Assert.Equal(13, response.Usage.InputTokenCount);
    Assert.Equal(7, response.Usage.OutputTokenCount);
    Assert.Equal(20, response.Usage.TotalTokenCount);
}

[Fact]
public async Task AskFiles()
{
Expand Down
26 changes: 26 additions & 0 deletions src/xAI.Tests/Extensions/CallHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,31 @@ public static AsyncUnaryCall<TResponse> CreateAsyncUnaryCall<TResponse>(StatusCo
() => new Metadata(),
() => { });
}

/// <summary>
/// Creates a server-streaming call for tests whose response stream is backed by
/// the supplied <paramref name="responses"/>.
/// </summary>
/// <typeparam name="TResponse">Type of the messages produced by the stream.</typeparam>
/// <param name="responses">Messages handed to the in-memory stream reader.</param>
/// <returns>A call with empty headers and trailers, a success status, and a no-op dispose action.</returns>
public static AsyncServerStreamingCall<TResponse> CreateAsyncServerStreamingCall<TResponse>(params TResponse[] responses)
{
    var responseStream = new TestAsyncStreamReader<TResponse>(responses);
    var responseHeaders = Task.FromResult(new Metadata());

    return new AsyncServerStreamingCall<TResponse>(
        responseStream,
        responseHeaders,
        () => Status.DefaultSuccess,
        () => new Metadata(),
        () => { });
}

/// <summary>
/// In-memory <see cref="IAsyncStreamReader{T}"/> that walks a fixed sequence of responses.
/// </summary>
class TestAsyncStreamReader<TResponse>(IEnumerable<TResponse> responses) : IAsyncStreamReader<TResponse>
{
    readonly IEnumerator<TResponse> source = responses.GetEnumerator();

    /// <summary>The element most recently read by <see cref="MoveNext"/>.</summary>
    public TResponse Current { get; private set; } = default!;

    /// <summary>
    /// Advances to the next response, completing synchronously.
    /// The <paramref name="cancellationToken"/> is not observed.
    /// </summary>
    /// <returns><c>true</c> when an element was read; <c>false</c> once the sequence is exhausted.</returns>
    public Task<bool> MoveNext(CancellationToken cancellationToken)
    {
        var advanced = source.MoveNext();

        // Only overwrite Current on success, so the last element stays visible after the end.
        if (advanced)
            Current = source.Current;

        return Task.FromResult(advanced);
    }
}
}
}
59 changes: 59 additions & 0 deletions src/xAI.Tests/SanityChecks.cs
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,33 @@ parseable by a decimal parser.
output.WriteLine($"Code interpreter calls: {codeInterpreterCalls.Count}");
}

[SecretsFact("CI_XAI_API_KEY")]
public async Task StreamingUsageRoughlyMatchesNonStreamingUsage()
{
    // Runs the same three-turn conversation twice (non-streaming, then streaming)
    // and compares the summed total-token usage of both runs.
    var grok = new GrokClient(Configuration["CI_XAI_API_KEY"]!);
    var client = grok.AsIChatClient("grok-4-1-fast-non-reasoning");

    var prompts = new string[]
    {
        """Reply with JSON only: {"number":7}""",
        """Using the previous assistant response, add 5 and reply with JSON only: {"number":12}""",
        """Using the latest number from this conversation, multiply it by 3 and reply with JSON only: {"number":36}""",
    };

    long nonStreamingUsage = await GetConversationUsageAsync(client, prompts, streaming: false);
    long streamingUsage = await GetConversationUsageAsync(client, prompts, streaming: true);

    // Relative difference, expressed as a fraction of the non-streaming total.
    long absoluteDifference = Math.Abs(streamingUsage - nonStreamingUsage);
    double usageDelta = absoluteDifference / (double)nonStreamingUsage;

    output.WriteLine($"Non-streaming total tokens: {nonStreamingUsage}");
    output.WriteLine($"Streaming total tokens: {streamingUsage}");
    output.WriteLine($"Relative delta: {usageDelta:P2}");

    Assert.True(nonStreamingUsage > 0, "Expected non-streaming total token usage to be reported.");
    Assert.True(streamingUsage > 0, "Expected streaming total token usage to be reported.");
    Assert.True(usageDelta <= 0.20,
        $"Expected streaming total token usage to remain within 20% of non-streaming usage, but got {streamingUsage} vs {nonStreamingUsage} ({usageDelta:P2}).");
}

[SecretsTheory("CI_XAI_API_KEY")]
[InlineData("rex")]
public async Task TextToSpeech_SpeechToText(string voiceId)
Expand Down Expand Up @@ -521,6 +548,38 @@ static async Task<ChatResponse> GetResponseAsync(IChatClient client, ChatConvers
return updates.ToChatResponse();
}

/// <summary>
/// Sends each prompt as a user turn of one growing conversation and sums the
/// total token count reported for every turn.
/// </summary>
/// <param name="client">Chat client used for every turn.</param>
/// <param name="prompts">User prompts, sent in order.</param>
/// <param name="streaming">Forwarded to <c>GetResponseAsync</c> to select the streaming or non-streaming path.</param>
/// <returns>The sum of the per-turn total token counts.</returns>
static async Task<long> GetConversationUsageAsync(IChatClient client, IReadOnlyList<string> prompts, bool streaming)
{
    var chat = new ChatConversation();
    chat.Add("system", "You are a precise assistant. Reply with compact JSON only.");

    long totalTokenCount = 0;

    for (var turn = 0; turn < prompts.Count; turn++)
    {
        var prompt = prompts[turn];
        chat.Add("user", prompt);

        // A fresh options instance is built for every turn.
        var options = new GrokChatOptions
        {
            ResponseFormat = ChatResponseFormat.Json,
            Temperature = 0,
            MaxOutputTokens = 64,
        };
        var response = await GetResponseAsync(client, chat, options, streaming);

        Assert.NotNull(response.Usage);
        Assert.NotNull(response.Usage.TotalTokenCount);
        var tokenCount = response.Usage.TotalTokenCount.Value;
        Assert.True(tokenCount > 0, $"Expected token usage for prompt '{prompt}'.");

        totalTokenCount += tokenCount;

        // Feed the reply back into the conversation so the next turn can build on it.
        chat.AddRange(response.Messages);
    }

    return totalTokenCount;
}

static T ParseJson<T>(ChatResponse response, ITestOutputHelper output)
{
var responseText = response.Messages.Last().Text;
Expand Down
33 changes: 32 additions & 1 deletion src/xAI/GrokChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,9 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
{
var request = this.AsCompletionsRequest(messages, options);
var call = client.GetCompletionChunk(request, cancellationToken: cancellationToken);
var promptTokens = 0;
var completionTokens = 0;
var totalTokens = 0;

await foreach (var chunk in call.ResponseStream.ReadAllAsync(cancellationToken))
{
Expand Down Expand Up @@ -107,7 +110,7 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
text is not null)
update.Contents.Add(new TextContent(text));

if (chunk.Usage.Convert() is { } usage)
if (ConvertStreamingUsageDelta(chunk.Usage, ref promptTokens, ref completionTokens, ref totalTokens) is { } usage)
update.Contents.Add(new UsageContent(usage) { RawRepresentation = chunk.Usage });

yield return update;
Expand Down Expand Up @@ -136,6 +139,34 @@ static CitationAnnotation MapCitation(string citation)
};
}

/// <summary>
/// Converts the usage counters reported on a streaming chunk into the increment
/// since the previously seen chunk, and records the new counters as the baseline.
/// </summary>
/// <param name="usage">Usage reported on the current chunk; may be <c>null</c> when the chunk carries none.</param>
/// <param name="promptTokens">Last seen prompt token counter; replaced with the current value.</param>
/// <param name="completionTokens">Last seen completion token counter; replaced with the current value.</param>
/// <param name="totalTokens">Last seen total token counter; replaced with the current value.</param>
/// <returns>
/// The per-chunk usage increment, or <c>null</c> when <paramref name="usage"/> is
/// <c>null</c> or nothing changed since the previous chunk.
/// </returns>
static UsageDetails? ConvertStreamingUsageDelta(SamplingUsage? usage, ref int promptTokens, ref int completionTokens, ref int totalTokens)
{
    if (usage is null)
        return null;

    // Any counter moving backwards means the running totals started over, so the
    // reported values are themselves the increment rather than a difference.
    var reset = usage.PromptTokens < promptTokens
        || usage.CompletionTokens < completionTokens
        || usage.TotalTokens < totalTokens;

    var inputDelta = reset ? usage.PromptTokens : usage.PromptTokens - promptTokens;
    var outputDelta = reset ? usage.CompletionTokens : usage.CompletionTokens - completionTokens;
    var totalDelta = reset ? usage.TotalTokens : usage.TotalTokens - totalTokens;

    // Remember the current counters as the baseline for the next chunk.
    promptTokens = usage.PromptTokens;
    completionTokens = usage.CompletionTokens;
    totalTokens = usage.TotalTokens;

    // Chunks that repeat the previous counters contribute no usage.
    if (inputDelta == 0 && outputDelta == 0 && totalDelta == 0)
        return null;

    return new UsageDetails
    {
        InputTokenCount = inputDelta,
        OutputTokenCount = outputDelta,
        TotalTokenCount = totalDelta
    };
}

/// <inheritdoc />
public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch
{
Expand Down
Loading