diff --git a/internal/analysisengine/engine.go b/internal/analysisengine/engine.go
index d88ee625b9..97ed86ae45 100644
--- a/internal/analysisengine/engine.go
+++ b/internal/analysisengine/engine.go
@@ -7,6 +7,7 @@ import (
 	"path/filepath"
 	"time"
 
+	"github.com/go-logr/logr"
 	"github.com/openshift/osde2e/internal/aggregator"
 	"github.com/openshift/osde2e/internal/llm"
 	"github.com/openshift/osde2e/internal/llm/tools"
@@ -34,6 +35,7 @@ type Engine struct {
 	aggregatorService *aggregator.Aggregator
 	promptStore       *prompts.PromptStore
 	llmClient         llm.LLMClient
+	fallbackLLMClient llm.LLMClient
 }
 
 // New creates a new analysis engine
@@ -65,11 +67,17 @@ func New(ctx context.Context, config *Config) (*Engine, error) {
 		return nil, fmt.Errorf("failed to initialize LLM client: %w", err)
 	}
 
+	fallbackLLMClient, err := llm.NewGeminiClientWithModel(ctx, config.APIKey, llm.FallbackModel)
+	if err != nil {
+		return nil, fmt.Errorf("failed to initialize fallback LLM client: %w", err)
+	}
+
 	return &Engine{
 		config:            config,
 		aggregatorService: aggregatorService,
 		promptStore:       promptStore,
 		llmClient:         client,
+		fallbackLLMClient: fallbackLLMClient,
 	}, nil
 }
 
@@ -114,7 +122,15 @@ func (e *Engine) Run(ctx context.Context) (*Result, error) {
 		}
 	}
 
-	result, err := e.llmClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
+	logger := logr.FromContextOrDiscard(ctx)
+	result, err := llm.AnalyzeWithRetry(ctx, logger,
+		func() (*llm.AnalysisResult, error) {
+			return e.llmClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
+		},
+		func() (*llm.AnalysisResult, error) {
+			return e.fallbackLLMClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
+		},
+	)
 	if err != nil {
 		return nil, fmt.Errorf("log analysis failed: %w", err)
 	}
diff --git a/internal/llm/errors.go b/internal/llm/errors.go
new file mode 100644
index 0000000000..995c69e785
--- /dev/null
+++ b/internal/llm/errors.go
@@ -0,0 +1,11 @@
+package llm
+
+import "errors"
+
+// Sentinel errors returned by the Gemini client so callers can classify failures with errors.Is.
+var (
+	ErrNoResponseCandidates = errors.New("no response candidates from gemini")
+	ErrNoContentInResponse  = errors.New("no content in gemini response")
+	ErrToolCallFailed       = errors.New("failed to handle tool call")
+	ErrMaxIterations        = errors.New("max iterations reached without final response")
+)
diff --git a/internal/llm/gemini.go b/internal/llm/gemini.go
index c846d72c1d..0e66cf48b6 100644
--- a/internal/llm/gemini.go
+++ b/internal/llm/gemini.go
@@ -9,7 +9,10 @@ import (
 	"github.com/openshift/osde2e/internal/llm/tools"
 )
 
-const DefaultModel = "gemini-3.1-pro-preview"
+const (
+	DefaultModel  = "gemini-3.1-pro-preview"
+	FallbackModel = "gemini-2.5-pro"
+)
 
 type GeminiClient struct {
 	client *genai.Client
@@ -111,17 +114,17 @@ func (g *GeminiClient) handleConversationWithTools(ctx context.Context, contents
 		}
 	}
 
-	return &AnalysisResult{ToolCalls: toolCalls}, fmt.Errorf("max iterations reached without final response")
+	return &AnalysisResult{ToolCalls: toolCalls}, ErrMaxIterations
 }
 
 func (g *GeminiClient) extractCandidate(resp *genai.GenerateContentResponse) (*genai.Candidate, error) {
 	if len(resp.Candidates) == 0 {
-		return nil, fmt.Errorf("no response candidates from gemini")
+		return nil, ErrNoResponseCandidates
 	}
 
 	candidate := resp.Candidates[0]
 	if candidate.Content == nil || len(candidate.Content.Parts) == 0 {
-		return nil, fmt.Errorf("no content in gemini response")
+		return nil, ErrNoContentInResponse
 	}
 
 	return candidate, nil
@@ -151,7 +154,7 @@ func (g *GeminiClient) processFunctionCalls(ctx context.Context, contents []*gen
 		// Execute the tool and get the result
 		toolResult, err := toolRegistry.HandleToolCall(ctx, functionCall)
 		if err != nil {
-			return nil, fmt.Errorf("failed to handle tool call: %w", err)
+			return nil, fmt.Errorf("%w: %w", ErrToolCallFailed, err)
 		}
 
 		// Add the tool result to conversation history
diff --git a/internal/llm/retry.go b/internal/llm/retry.go
new file mode 100644
index 0000000000..68cf204b02
--- /dev/null
+++ b/internal/llm/retry.go
@@ -0,0 +1,151 @@
+package llm
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"net"
+	"time"
+
+	"github.com/go-logr/logr"
+	"google.golang.org/genai"
+)
+
+// Retry policy: the primary model is attempted 1+primaryRetries times and the
+// fallback model 1+fallbackRetries times, waiting retryDelay between two
+// attempts on the same model.
+const (
+	primaryRetries  = 2
+	fallbackRetries = 0
+	retryDelay      = 30 * time.Second
+)
+
+var (
+	// retryDelayOverride replaces retryDelay when positive; tests set it to keep backoff short.
+	retryDelayOverride time.Duration
+
+	// retryableStatusCodes lists the API status codes that are treated as transient.
+	retryableStatusCodes = map[int]bool{
+		429: true, // Rate limit
+		500: true, // Internal server error
+		502: true, // Bad gateway
+		503: true, // Service unavailable
+	}
+)
+
+// AnalyzeWithRetry runs primaryFn, retrying retryable failures up to
+// primaryRetries times, and switches to fallbackFn only once those retries
+// are exhausted. Non-retryable errors and context cancellation are returned
+// immediately without trying the fallback. If the fallback also fails, the
+// returned error wraps both the primary and the fallback error.
+func AnalyzeWithRetry(
+	ctx context.Context,
+	logger logr.Logger,
+	primaryFn func() (*AnalysisResult, error),
+	fallbackFn func() (*AnalysisResult, error),
+) (*AnalysisResult, error) {
+	result, exhausted, primaryErr := retryLoop(ctx, logger, "primary", primaryFn, primaryRetries)
+	if primaryErr == nil {
+		return result, nil
+	}
+
+	if !exhausted {
+		return nil, primaryErr
+	}
+
+	logger.Info("switching to fallback model", "reason", "primary model retries exhausted")
+
+	result, _, fallbackErr := retryLoop(ctx, logger, "fallback", fallbackFn, fallbackRetries)
+	if fallbackErr != nil {
+		logger.Error(errors.Join(primaryErr, fallbackErr), "LLM analysis failed after all retries on both models")
+		return nil, fmt.Errorf("LLM analysis unavailable: both primary and fallback models failed after retries: %w", errors.Join(primaryErr, fallbackErr))
+	}
+
+	return result, nil
+}
+
+// retryLoop calls fn up to maxRetries+1 times, sleeping between attempts.
+// The boolean result reports whether every attempt failed with a retryable
+// error while ctx was still live, i.e. whether moving on to another model
+// makes sense.
+func retryLoop(ctx context.Context, logger logr.Logger, modelName string, fn func() (*AnalysisResult, error), maxRetries int) (*AnalysisResult, bool, error) {
+	var lastErr error
+
+	for attempt := 0; attempt <= maxRetries; attempt++ {
+		result, err := fn()
+		if err == nil {
+			if attempt > 0 {
+				logger.Info("LLM analysis succeeded after retry", "model", modelName, "attempt", attempt)
+			}
+			return result, false, nil
+		}
+
+		lastErr = err
+
+		if !isRetryable(err) {
+			return nil, false, err
+		}
+
+		// Retrying or falling back on a done context is pointless, so report
+		// the cancellation instead of exhaustion, even on the last attempt.
+		if ctx.Err() != nil {
+			return nil, false, retryCanceledError(ctx, lastErr)
+		}
+
+		if attempt < maxRetries {
+			backoff := retryDelay
+			if retryDelayOverride > 0 {
+				backoff = retryDelayOverride
+			}
+			logger.Info("retrying LLM analysis", "model", modelName, "attempt", attempt+1, "maxRetries", maxRetries, "backoff", backoff, "error", err.Error())
+
+			timer := time.NewTimer(backoff)
+			select {
+			case <-ctx.Done():
+				timer.Stop()
+				return nil, false, retryCanceledError(ctx, lastErr)
+			case <-timer.C:
+			}
+		}
+	}
+
+	return nil, true, lastErr
+}
+
+// retryCanceledError wraps ctx's error and keeps the last LLM error in the message.
+func retryCanceledError(ctx context.Context, lastErr error) error {
+	return fmt.Errorf("retry canceled: %w (last LLM error: %v)", ctx.Err(), lastErr)
+}
+
+// isRetryable reports whether err looks transient. Rules are checked in
+// order and the first match wins: API errors by status code, then context
+// deadline/cancellation, network timeouts, and the package sentinel errors.
+func isRetryable(err error) bool {
+	var apiErr genai.APIError
+	if errors.As(err, &apiErr) {
+		return retryableStatusCodes[apiErr.Code]
+	}
+
+	if errors.Is(err, context.DeadlineExceeded) {
+		return true
+	}
+
+	if errors.Is(err, context.Canceled) {
+		return false
+	}
+
+	var netErr net.Error
+	if errors.As(err, &netErr) && netErr.Timeout() {
+		return true
+	}
+
+	if errors.Is(err, ErrNoResponseCandidates) || errors.Is(err, ErrNoContentInResponse) {
+		return true
+	}
+
+	if errors.Is(err, ErrToolCallFailed) || errors.Is(err, ErrMaxIterations) {
+		return false
+	}
+
+	return false
+}
diff --git a/internal/llm/retry_test.go b/internal/llm/retry_test.go
new file mode 100644
index 0000000000..18c0545f6c
--- /dev/null
+++ b/internal/llm/retry_test.go
@@ -0,0 +1,503 @@
+package llm
+
+import (
+	"context"
+	"fmt"
+	"testing"
+	"time"
+
+	"github.com/go-logr/logr"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+	"google.golang.org/genai"
+)
+
+type mockNetError struct {
+	timeout bool
+}
+
+func (e *mockNetError) Error() string   { return "network error" }
+func (e *mockNetError) Timeout() bool   { return e.timeout }
+func (e *mockNetError) Temporary() bool { return true }
+
+func TestAnalyzeWithRetry_SuccessFirstAttempt(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		return &AnalysisResult{Content: "success from primary"}, nil
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "success from fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.NoError(t, err)
+	assert.Equal(t, "success from primary", result.Content)
+	assert.Equal(t, 1, primaryCallCount)
+	assert.Equal(t, 0, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_SuccessAfterTransientFailure(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		if primaryCallCount == 1 {
+			return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 429, Message: "rate limited"})
+		}
+		return &AnalysisResult{Content: "success after retry"}, nil
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.NoError(t, err)
+	assert.Equal(t, "success after retry", result.Content)
+	assert.Equal(t, 2, primaryCallCount)
+	assert.Equal(t, 0, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_PrimaryExhaustedFallbackSucceeds(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 503, Message: "service unavailable"})
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "success from fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.NoError(t, err)
+	assert.Equal(t, "success from fallback", result.Content)
+	assert.Equal(t, 3, primaryCallCount)
+	assert.Equal(t, 1, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_BothModelsExhausted(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 500, Message: "internal server error"})
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 500, Message: "internal server error"})
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.Contains(t, err.Error(), "LLM analysis unavailable")
+	assert.Equal(t, 3, primaryCallCount)
+	assert.Equal(t, 1, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_NonRetryableErrorNoPrimaryRetry(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 400, Message: "bad request"})
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.Equal(t, 1, primaryCallCount)
+	assert.Equal(t, 0, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_NonRetryableErrorOnFallback(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 503, Message: "service unavailable"})
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 401, Message: "unauthorized"})
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.Equal(t, 3, primaryCallCount)
+	assert.Equal(t, 1, fallbackCallCount)
+	assert.ErrorContains(t, err, "503")
+	assert.ErrorContains(t, err, "401")
+}
+
+func TestAnalyzeWithRetry_RetryableStatusCodes(t *testing.T) {
+	testCases := []struct {
+		name string
+		code int
+	}{
+		{"rate limit", 429},
+		{"internal server error", 500},
+		{"bad gateway", 502},
+		{"service unavailable", 503},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			retryDelayOverride = time.Millisecond
+			t.Cleanup(func() { retryDelayOverride = 0 })
+
+			primaryCallCount := 0
+			fallbackCallCount := 0
+
+			primaryFn := func() (*AnalysisResult, error) {
+				primaryCallCount++
+				if primaryCallCount == 1 {
+					return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: tc.code, Message: "error"})
+				}
+				return &AnalysisResult{Content: "success"}, nil
+			}
+
+			fallbackFn := func() (*AnalysisResult, error) {
+				fallbackCallCount++
+				return &AnalysisResult{Content: "fallback"}, nil
+			}
+
+			result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+			require.NoError(t, err)
+			assert.Equal(t, "success", result.Content)
+			assert.Equal(t, 2, primaryCallCount, "Expected retry for status code %d", tc.code)
+			assert.Equal(t, 0, fallbackCallCount)
+		})
+	}
+}
+
+func TestAnalyzeWithRetry_NonRetryableStatusCodes(t *testing.T) {
+	testCases := []struct {
+		name string
+		code int
+	}{
+		{"bad request", 400},
+		{"unauthorized", 401},
+		{"forbidden", 403},
+		{"not found", 404},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			retryDelayOverride = time.Millisecond
+			t.Cleanup(func() { retryDelayOverride = 0 })
+
+			primaryCallCount := 0
+			fallbackCallCount := 0
+
+			primaryFn := func() (*AnalysisResult, error) {
+				primaryCallCount++
+				return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: tc.code, Message: "error"})
+			}
+
+			fallbackFn := func() (*AnalysisResult, error) {
+				fallbackCallCount++
+				return &AnalysisResult{Content: "fallback"}, nil
+			}
+
+			result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+			require.Error(t, err)
+			assert.Nil(t, result)
+			assert.Equal(t, 1, primaryCallCount, "Expected no retry for status code %d", tc.code)
+			assert.Equal(t, 0, fallbackCallCount)
+		})
+	}
+}
+
+func TestAnalyzeWithRetry_ContextCancellation(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	primaryCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		if primaryCallCount == 2 {
+			cancel()
+		}
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 503, Message: "service unavailable"})
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		t.Fatal("fallback should not be called when context is canceled")
+		return nil, fmt.Errorf("fallback should not be called when context is canceled")
+	}
+
+	start := time.Now()
+	result, err := AnalyzeWithRetry(ctx, logr.Discard(), primaryFn, fallbackFn)
+	elapsed := time.Since(start)
+
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Less(t, elapsed, 1*time.Second, "Should return promptly when context is canceled")
+}
+
+func TestAnalyzeWithRetry_NetworkTimeout(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		if primaryCallCount == 1 {
+			return nil, &mockNetError{timeout: true}
+		}
+		return &AnalysisResult{Content: "success after timeout"}, nil
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.NoError(t, err)
+	assert.Equal(t, "success after timeout", result.Content)
+	assert.Equal(t, 2, primaryCallCount)
+	assert.Equal(t, 0, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_ContextCanceledAsAPIError(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		return nil, context.Canceled
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Equal(t, 1, primaryCallCount)
+	assert.Equal(t, 0, fallbackCallCount)
+}
+
+func TestAnalyzeWithRetry_NonRetryableErrorMessages(t *testing.T) {
+	testCases := []struct {
+		name     string
+		err      error
+		sentinel error
+	}{
+		{"tool call failed", fmt.Errorf("%w: unknown tool", ErrToolCallFailed), ErrToolCallFailed},
+		{"max iterations", ErrMaxIterations, ErrMaxIterations},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			retryDelayOverride = time.Millisecond
+			t.Cleanup(func() { retryDelayOverride = 0 })
+
+			primaryCallCount := 0
+			fallbackCallCount := 0
+
+			primaryFn := func() (*AnalysisResult, error) {
+				primaryCallCount++
+				return nil, tc.err
+			}
+
+			fallbackFn := func() (*AnalysisResult, error) {
+				fallbackCallCount++
+				return &AnalysisResult{Content: "fallback"}, nil
+			}
+
+			result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+			require.Error(t, err)
+			assert.Nil(t, result)
+			assert.ErrorIs(t, err, tc.sentinel)
+			assert.Equal(t, 1, primaryCallCount, "Should not retry for error: %v", tc.err)
+			assert.Equal(t, 0, fallbackCallCount)
+		})
+	}
+}
+
+func TestAnalyzeWithRetry_RetryableEmptyResponseErrors(t *testing.T) {
+	testCases := []struct {
+		name string
+		err  error
+	}{
+		{"no response candidates", ErrNoResponseCandidates},
+		{"no content in response", ErrNoContentInResponse},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			retryDelayOverride = time.Millisecond
+			t.Cleanup(func() { retryDelayOverride = 0 })
+
+			primaryCallCount := 0
+			fallbackCallCount := 0
+
+			primaryFn := func() (*AnalysisResult, error) {
+				primaryCallCount++
+				if primaryCallCount == 1 {
+					return nil, tc.err
+				}
+				return &AnalysisResult{Content: "success after retry"}, nil
+			}
+
+			fallbackFn := func() (*AnalysisResult, error) {
+				fallbackCallCount++
+				return &AnalysisResult{Content: "fallback"}, nil
+			}
+
+			result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+			require.NoError(t, err)
+			assert.Equal(t, "success after retry", result.Content)
+			assert.Equal(t, 2, primaryCallCount, "Should retry for error: %v", tc.err)
+			assert.Equal(t, 0, fallbackCallCount)
+		})
+	}
+}
+
+func TestAnalyzeWithRetry_RetryableWrappedErrors(t *testing.T) {
+	testCases := []struct {
+		name string
+		err  error
+	}{
+		{
+			"double wrapped retryable api error",
+			fmt.Errorf("outer: %w", fmt.Errorf("gemini API error: %w", genai.APIError{Code: 429})),
+		},
+		{
+			"triple wrapped retryable api error",
+			fmt.Errorf("outer: %w", fmt.Errorf("middle: %w", fmt.Errorf("gemini API error: %w", genai.APIError{Code: 503}))),
+		},
+		{
+			"double wrapped sentinel error",
+			fmt.Errorf("outer: %w", ErrNoResponseCandidates),
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			retryDelayOverride = time.Millisecond
+			t.Cleanup(func() { retryDelayOverride = 0 })
+
+			primaryCallCount := 0
+
+			primaryFn := func() (*AnalysisResult, error) {
+				primaryCallCount++
+				if primaryCallCount == 1 {
+					return nil, tc.err
+				}
+				return &AnalysisResult{Content: "success"}, nil
+			}
+
+			fallbackFn := func() (*AnalysisResult, error) {
+				return &AnalysisResult{Content: "fallback"}, nil
+			}
+
+			result, err := AnalyzeWithRetry(context.Background(), logr.Discard(), primaryFn, fallbackFn)
+
+			require.NoError(t, err)
+			assert.Equal(t, "success", result.Content)
+			assert.Equal(t, 2, primaryCallCount, "wrapped retryable error should trigger retry: %v", tc.err)
+		})
+	}
+}
+
+// TestAnalyzeWithRetry_ContextCanceledOnFinalAttempt checks that cancellation during the last primary attempt skips the fallback.
+func TestAnalyzeWithRetry_ContextCanceledOnFinalAttempt(t *testing.T) {
+	retryDelayOverride = time.Millisecond
+	t.Cleanup(func() { retryDelayOverride = 0 })
+
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	primaryCallCount := 0
+	fallbackCallCount := 0
+
+	primaryFn := func() (*AnalysisResult, error) {
+		primaryCallCount++
+		if primaryCallCount == primaryRetries+1 {
+			cancel()
+		}
+		return nil, fmt.Errorf("gemini API error: %w", genai.APIError{Code: 503, Message: "service unavailable"})
+	}
+
+	fallbackFn := func() (*AnalysisResult, error) {
+		fallbackCallCount++
+		return &AnalysisResult{Content: "fallback"}, nil
+	}
+
+	result, err := AnalyzeWithRetry(ctx, logr.Discard(), primaryFn, fallbackFn)
+
+	require.Error(t, err)
+	assert.Nil(t, result)
+	assert.ErrorIs(t, err, context.Canceled)
+	assert.Equal(t, primaryRetries+1, primaryCallCount)
+	assert.Equal(t, 0, fallbackCallCount)
+}
diff --git a/pkg/e2e/e2e.go b/pkg/e2e/e2e.go
index 11ff978ba7..f39dcd1f75 100644
--- a/pkg/e2e/e2e.go
+++ b/pkg/e2e/e2e.go
@@ -264,6 +264,10 @@ func (o *E2EOrchestrator) AnalyzeLogs(ctx context.Context, testErr error) error
 
 	result, err := engine.Run(ctx)
 	if err != nil {
+		o.analysisResult = &analysisengine.Result{
+			Status: "failed",
+			Error:  err.Error(),
+		}
 		return fmt.Errorf("log analysis failed: %w", err)
 	}
 
diff --git a/pkg/krknai/analysisengine/engine.go b/pkg/krknai/analysisengine/engine.go
index 3f9f15ae56..06774d7ca6 100644
--- a/pkg/krknai/analysisengine/engine.go
+++ b/pkg/krknai/analysisengine/engine.go
@@ -9,6 +9,7 @@ import (
 	"path/filepath"
 	"time"
 
+	"github.com/go-logr/logr"
 	"github.com/openshift/osde2e/internal/analysisengine"
 	"github.com/openshift/osde2e/internal/llm"
 	"github.com/openshift/osde2e/internal/llm/tools"
@@ -35,10 +36,11 @@ type Config struct {
 
 // Engine analyzes krkn-ai chaos test results using LLM.
 type Engine struct {
-	config      *Config
-	aggregator  *krknAggregator.KrknAIAggregator
-	promptStore *prompts.PromptStore
-	llmClient   llm.LLMClient
+	config            *Config
+	aggregator        *krknAggregator.KrknAIAggregator
+	promptStore       *prompts.PromptStore
+	llmClient         llm.LLMClient
+	fallbackLLMClient llm.LLMClient
 }
 
 // New creates a new krkn-ai analysis engine.
@@ -75,11 +77,17 @@ func New(ctx context.Context, config *Config) (*Engine, error) {
 		return nil, fmt.Errorf("failed to initialize LLM client: %w", err)
 	}
 
+	fallbackLLMClient, err := llm.NewGeminiClientWithModel(ctx, config.APIKey, llm.FallbackModel)
+	if err != nil {
+		return nil, fmt.Errorf("failed to initialize fallback LLM client: %w", err)
+	}
+
 	return &Engine{
-		config:      config,
-		aggregator:  agg,
-		promptStore: promptStore,
-		llmClient:   client,
+		config:            config,
+		aggregator:        agg,
+		promptStore:       promptStore,
+		llmClient:         client,
+		fallbackLLMClient: fallbackLLMClient,
 	}, nil
 }
 
@@ -127,7 +135,15 @@ func (e *Engine) Run(ctx context.Context) (*analysisengine.Result, error) {
 	}
 
 	// Run LLM analysis
-	result, err := e.llmClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
+	logger := logr.FromContextOrDiscard(ctx)
+	result, err := llm.AnalyzeWithRetry(ctx, logger,
+		func() (*llm.AnalysisResult, error) {
+			return e.llmClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
+		},
+		func() (*llm.AnalysisResult, error) {
+			return e.fallbackLLMClient.Analyze(ctx, userPrompt, llmConfig, toolRegistry)
+		},
+	)
 	if err != nil {
 		return nil, fmt.Errorf("LLM analysis failed: %w", err)
 	}