Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion internal/analysisengine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"path/filepath"
"time"

"github.com/go-logr/logr"
"github.com/openshift/osde2e/internal/aggregator"
"github.com/openshift/osde2e/internal/llm"
"github.com/openshift/osde2e/internal/llm/tools"
Expand Down Expand Up @@ -34,6 +35,7 @@ type Engine struct {
aggregatorService *aggregator.Aggregator
promptStore *prompts.PromptStore
llmClient llm.LLMClient
fallbackLLMClient llm.LLMClient
}

// New creates a new analysis engine
Expand Down Expand Up @@ -65,11 +67,17 @@ func New(ctx context.Context, config *Config) (*Engine, error) {
return nil, fmt.Errorf("failed to initialize LLM client: %w", err)
}

fallbackLLMClient, err := llm.NewGeminiClientWithModel(ctx, config.APIKey, llm.FallbackModel)
if err != nil {
return nil, fmt.Errorf("failed to initialize fallback LLM client: %w", err)
}

return &Engine{
config: config,
aggregatorService: aggregatorService,
promptStore: promptStore,
llmClient: client,
fallbackLLMClient: fallbackLLMClient,
}, nil
}

Expand Down Expand Up @@ -114,7 +122,15 @@ func (e *Engine) Run(ctx context.Context) (*Result, error) {
}
}

result, err := e.llmClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
logger := logr.FromContextOrDiscard(ctx)
result, err := llm.AnalyzeWithRetry(ctx, logger,
func() (*llm.AnalysisResult, error) {
return e.llmClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
},
func() (*llm.AnalysisResult, error) {
return e.fallbackLLMClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
},
)
if err != nil {
return nil, fmt.Errorf("log analysis failed: %w", err)
}
Expand Down
10 changes: 10 additions & 0 deletions internal/llm/errors.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package llm

import "errors"

// Sentinel errors for the llm package. Compare against them with errors.Is.
var (
	// ErrNoResponseCandidates reports a response that carries no candidates.
	ErrNoResponseCandidates = errors.New("no response candidates from gemini")

	// ErrNoContentInResponse reports a candidate that carries no content.
	ErrNoContentInResponse = errors.New("no content in gemini response")

	// ErrToolCallFailed reports a tool call that could not be handled.
	ErrToolCallFailed = errors.New("failed to handle tool call")

	// ErrMaxIterations reports that the iteration limit was reached before a
	// final response was produced.
	ErrMaxIterations = errors.New("max iterations reached without final response")
)
13 changes: 8 additions & 5 deletions internal/llm/gemini.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ import (
"github.com/openshift/osde2e/internal/llm/tools"
)

// Gemini model identifiers.
const (
	// DefaultModel is the primary Gemini model name.
	// NOTE(review): presumably the model used when none is specified; confirm
	// against the client constructor.
	DefaultModel = "gemini-3.1-pro-preview"

	// FallbackModel is the Gemini model name used to build the fallback
	// client that is tried after the primary model's retries are exhausted.
	FallbackModel = "gemini-2.5-pro"
)

type GeminiClient struct {
client *genai.Client
Expand Down Expand Up @@ -111,17 +114,17 @@ func (g *GeminiClient) handleConversationWithTools(ctx context.Context, contents
}
}

return &AnalysisResult{ToolCalls: toolCalls}, fmt.Errorf("max iterations reached without final response")
return &AnalysisResult{ToolCalls: toolCalls}, ErrMaxIterations
}

func (g *GeminiClient) extractCandidate(resp *genai.GenerateContentResponse) (*genai.Candidate, error) {
if len(resp.Candidates) == 0 {
return nil, fmt.Errorf("no response candidates from gemini")
return nil, ErrNoResponseCandidates
}

candidate := resp.Candidates[0]
if candidate.Content == nil || len(candidate.Content.Parts) == 0 {
return nil, fmt.Errorf("no content in gemini response")
return nil, ErrNoContentInResponse
}

return candidate, nil
Expand Down Expand Up @@ -151,7 +154,7 @@ func (g *GeminiClient) processFunctionCalls(ctx context.Context, contents []*gen
// Execute the tool and get the result
toolResult, err := toolRegistry.HandleToolCall(ctx, functionCall)
if err != nil {
return nil, fmt.Errorf("failed to handle tool call: %w", err)
return nil, fmt.Errorf("%w: %w", ErrToolCallFailed, err)
}

// Add the tool result to conversation history
Expand Down
123 changes: 123 additions & 0 deletions internal/llm/retry.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
package llm

import (
"context"
"errors"
"fmt"
"net"
"time"

"github.com/go-logr/logr"
"google.golang.org/genai"
)

// Retry budget and pacing used by AnalyzeWithRetry.
const (
	// primaryRetries is the number of retries (after the first attempt)
	// allowed for the primary model.
	primaryRetries = 2

	// fallbackRetries is the number of retries (after the first attempt)
	// allowed for the fallback model; zero means a single attempt.
	fallbackRetries = 0

	// retryDelay is the pause between consecutive attempts.
	retryDelay = 30 * time.Second
)

// retryDelayOverride, when greater than zero, replaces retryDelay as the
// pause between attempts in retryLoop.
var retryDelayOverride time.Duration

// retryableStatusCodes is the set of API error codes that isRetryable treats
// as transient. Codes absent from the map are not retried.
var retryableStatusCodes = map[int]bool{
	429: true, // Rate limit
	500: true, // Internal server error
	502: true, // Bad gateway
	503: true, // Service unavailable
}

// AnalyzeWithRetry runs primaryFn with up to primaryRetries retries and, only
// if those retries are exhausted, runs fallbackFn with up to fallbackRetries
// retries.
//
// A primary failure that did not exhaust the retry budget is returned
// unchanged and the fallback is not attempted. When the fallback also fails,
// the primary and fallback errors are joined, logged, and returned wrapped.
func AnalyzeWithRetry(
	ctx context.Context,
	logger logr.Logger,
	primaryFn func() (*AnalysisResult, error),
	fallbackFn func() (*AnalysisResult, error),
) (*AnalysisResult, error) {
	primaryResult, primaryExhausted, primaryErr := retryLoop(ctx, logger, "primary", primaryFn, primaryRetries)
	switch {
	case primaryErr == nil:
		return primaryResult, nil
	case !primaryExhausted:
		// The loop stopped early without using up the budget; the fallback
		// model is not consulted in that case.
		return nil, primaryErr
	}

	logger.Info("switching to fallback model", "reason", "primary model retries exhausted")

	fallbackResult, _, fallbackErr := retryLoop(ctx, logger, "fallback", fallbackFn, fallbackRetries)
	if fallbackErr == nil {
		return fallbackResult, nil
	}

	// Keep both failures visible to the caller and in the log.
	combined := errors.Join(primaryErr, fallbackErr)
	logger.Error(combined, "LLM analysis failed after all retries on both models")
	return nil, fmt.Errorf("LLM analysis unavailable: both primary and fallback models failed after retries: %w", combined)
}

// retryLoop calls fn until it succeeds, fails with a non-retryable error, the
// context ends, or maxRetries retries (maxRetries+1 attempts) have been used.
//
// It returns the result of the successful call, or nil with the error. The
// boolean is true only when every attempt failed with a retryable error and
// the context is still live, i.e. the retry budget was genuinely exhausted.
// A negative maxRetries is treated as zero so fn is always called at least
// once.
//
// Between attempts it waits retryDelay, or retryDelayOverride when that is
// greater than zero, and aborts the wait if ctx is done.
func retryLoop(ctx context.Context, logger logr.Logger, modelName string, fn func() (*AnalysisResult, error), maxRetries int) (*AnalysisResult, bool, error) {
	// Without this clamp a negative budget would skip fn entirely and report
	// "exhausted" with a nil error, which callers would read as success.
	if maxRetries < 0 {
		maxRetries = 0
	}

	var lastErr error

	for attempt := 0; attempt <= maxRetries; attempt++ {
		result, err := fn()
		if err == nil {
			if attempt > 0 {
				logger.Info("LLM analysis succeeded after retry", "model", modelName, "attempt", attempt)
			}
			return result, false, nil
		}

		lastErr = err

		if !isRetryable(err) {
			return nil, false, err
		}

		// Once the context is done, further attempts (including a switch to
		// another model) cannot succeed, so do not report this as exhaustion.
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, false, fmt.Errorf("retry canceled: %w (last LLM error: %v)", ctxErr, lastErr)
		}

		if attempt == maxRetries {
			break
		}

		backoff := retryDelay
		if retryDelayOverride > 0 {
			backoff = retryDelayOverride
		}
		logger.Info("retrying LLM analysis", "model", modelName, "attempt", attempt+1, "maxRetries", maxRetries, "backoff", backoff, "error", err.Error())

		timer := time.NewTimer(backoff)
		select {
		case <-ctx.Done():
			timer.Stop()
			return nil, false, fmt.Errorf("retry canceled: %w (last LLM error: %v)", ctx.Err(), lastErr)
		case <-timer.C:
		}
	}

	return nil, true, lastErr
}

func isRetryable(err error) bool {
var apiErr genai.APIError
if errors.As(err, &apiErr) {
return retryableStatusCodes[apiErr.Code]
}
Comment thread
varunraokadaparthi marked this conversation as resolved.

if errors.Is(err, context.DeadlineExceeded) {
return true
}

if errors.Is(err, context.Canceled) {
return false
}

var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() {
return true
}

if errors.Is(err, ErrNoResponseCandidates) || errors.Is(err, ErrNoContentInResponse) {
return true
}

if errors.Is(err, ErrToolCallFailed) || errors.Is(err, ErrMaxIterations) {
return false
}

return false
}
Loading