Skip to content

Commit f0b6b0a

Browse files
committed
addressing AI reviewed comments
1 parent 0449f7e commit f0b6b0a

2 files changed

Lines changed: 82 additions & 42 deletions

File tree

server/internal/orchestrator/swarm/rag_config.go

Lines changed: 42 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package swarm
22

33
import (
4+
"fmt"
45
"path"
56

67
"github.com/goccy/go-yaml"
@@ -22,15 +23,15 @@ type ragServerYAML struct {
2223
}
2324

2425
type ragPipelineYAML struct {
25-
Name string `yaml:"name"`
26-
Description string `yaml:"description,omitempty"`
26+
Name string `yaml:"name"`
27+
Description string `yaml:"description,omitempty"`
2728
Database ragDatabaseYAML `yaml:"database"`
2829
Tables []ragTableYAML `yaml:"tables"`
2930
EmbeddingLLM ragLLMYAML `yaml:"embedding_llm"`
3031
RAGLLM ragLLMYAML `yaml:"rag_llm"`
3132
APIKeys *ragAPIKeysYAML `yaml:"api_keys,omitempty"`
32-
TokenBudget int `yaml:"token_budget,omitempty"`
33-
TopN int `yaml:"top_n,omitempty"`
33+
TokenBudget *int `yaml:"token_budget,omitempty"`
34+
TopN *int `yaml:"top_n,omitempty"`
3435
SystemPrompt string `yaml:"system_prompt,omitempty"`
3536
Search *ragSearchYAML `yaml:"search,omitempty"`
3637
}
@@ -70,8 +71,8 @@ type ragSearchYAML struct {
7071
}
7172

7273
type ragDefaultsYAML struct {
73-
TokenBudget int `yaml:"token_budget,omitempty"`
74-
TopN int `yaml:"top_n,omitempty"`
74+
TokenBudget *int `yaml:"token_budget,omitempty"`
75+
TopN *int `yaml:"top_n,omitempty"`
7576
}
7677

7778
// RAGConfigParams holds all inputs needed to generate pgedge-rag-server.yaml.
@@ -94,20 +95,21 @@ type RAGConfigParams struct {
9495
func GenerateRAGConfig(params *RAGConfigParams) ([]byte, error) {
9596
pipelines := make([]ragPipelineYAML, 0, len(params.Config.Pipelines))
9697
for _, p := range params.Config.Pipelines {
97-
pipelines = append(pipelines, buildRAGPipelineYAML(p, params))
98+
pl, err := buildRAGPipelineYAML(p, params)
99+
if err != nil {
100+
return nil, err
101+
}
102+
pipelines = append(pipelines, pl)
98103
}
99104

100105
var defaults *ragDefaultsYAML
101106
if params.Config.Defaults != nil {
102-
d := &ragDefaultsYAML{}
103-
if params.Config.Defaults.TokenBudget != nil {
104-
d.TokenBudget = *params.Config.Defaults.TokenBudget
105-
}
106-
if params.Config.Defaults.TopN != nil {
107-
d.TopN = *params.Config.Defaults.TopN
108-
}
109-
if d.TokenBudget != 0 || d.TopN != 0 {
110-
defaults = d
107+
src := params.Config.Defaults
108+
if src.TokenBudget != nil || src.TopN != nil {
109+
defaults = &ragDefaultsYAML{
110+
TokenBudget: src.TokenBudget,
111+
TopN: src.TopN,
112+
}
111113
}
112114
}
113115

@@ -127,7 +129,7 @@ func GenerateRAGConfig(params *RAGConfigParams) ([]byte, error) {
127129
return data, nil
128130
}
129131

130-
func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) ragPipelineYAML {
132+
func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) (ragPipelineYAML, error) {
131133
tables := make([]ragTableYAML, 0, len(p.Tables))
132134
for _, t := range p.Tables {
133135
tbl := ragTableYAML{
@@ -157,10 +159,13 @@ func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) ragPi
157159
ragLLM.BaseURL = *p.RAGLLM.BaseURL
158160
}
159161

160-
apiKeys := buildRAGAPIKeysYAML(p, params.KeysDir)
162+
apiKeys, err := buildRAGAPIKeysYAML(p, params.KeysDir)
163+
if err != nil {
164+
return ragPipelineYAML{}, err
165+
}
161166

162167
pipeline := ragPipelineYAML{
163-
Name: p.Name,
168+
Name: p.Name,
164169
Database: ragDatabaseYAML{
165170
Host: params.DatabaseHost,
166171
Port: params.DatabasePort,
@@ -178,12 +183,8 @@ func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) ragPi
178183
if p.Description != nil {
179184
pipeline.Description = *p.Description
180185
}
181-
if p.TokenBudget != nil {
182-
pipeline.TokenBudget = *p.TokenBudget
183-
}
184-
if p.TopN != nil {
185-
pipeline.TopN = *p.TopN
186-
}
186+
pipeline.TokenBudget = p.TokenBudget
187+
pipeline.TopN = p.TopN
187188
if p.SystemPrompt != nil {
188189
pipeline.SystemPrompt = *p.SystemPrompt
189190
}
@@ -194,16 +195,27 @@ func buildRAGPipelineYAML(p database.RAGPipeline, params *RAGConfigParams) ragPi
194195
}
195196
}
196197

197-
return pipeline
198+
return pipeline, nil
198199
}
199200

200201
// buildRAGAPIKeysYAML maps each LLM provider that requires a key to the
201202
// corresponding bind-mounted key file path inside the container.
202203
// Embedding key: {keysDir}/{pipeline}_embedding.key
203204
// RAG key: {keysDir}/{pipeline}_rag.key
204205
// If embedding and RAG use the same provider, the RAG key path takes precedence
205-
// (both files contain the same value).
206-
func buildRAGAPIKeysYAML(p database.RAGPipeline, keysDir string) *ragAPIKeysYAML {
206+
// (both files contain the same value). Returns an error if both LLMs share a
207+
// provider but were configured with different API keys.
208+
func buildRAGAPIKeysYAML(p database.RAGPipeline, keysDir string) (*ragAPIKeysYAML, error) {
209+
// Reject mismatched keys for the same provider — the RAG server has a
210+
// single key slot per provider and cannot reconcile two different values.
211+
if p.EmbeddingLLM.Provider == p.RAGLLM.Provider &&
212+
p.EmbeddingLLM.APIKey != nil && *p.EmbeddingLLM.APIKey != "" &&
213+
p.RAGLLM.APIKey != nil && *p.RAGLLM.APIKey != "" &&
214+
*p.EmbeddingLLM.APIKey != *p.RAGLLM.APIKey {
215+
return nil, fmt.Errorf("pipeline %q: embedding_llm and rag_llm share provider %q but have different API keys",
216+
p.Name, p.EmbeddingLLM.Provider)
217+
}
218+
207219
keys := &ragAPIKeysYAML{}
208220

209221
// Embedding provider key
@@ -231,7 +243,7 @@ func buildRAGAPIKeysYAML(p database.RAGPipeline, keysDir string) *ragAPIKeysYAML
231243
}
232244

233245
if keys.Anthropic == "" && keys.OpenAI == "" && keys.Voyage == "" {
234-
return nil
246+
return nil, nil
235247
}
236-
return keys
248+
return keys, nil
237249
}

server/internal/orchestrator/swarm/rag_config_test.go

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -383,11 +383,11 @@ func TestGenerateRAGConfig_OptionalPipelineFields(t *testing.T) {
383383
if p.Description != desc {
384384
t.Errorf("description = %q, want %q", p.Description, desc)
385385
}
386-
if p.TokenBudget != 500 {
387-
t.Errorf("token_budget = %d, want 500", p.TokenBudget)
386+
if p.TokenBudget == nil || *p.TokenBudget != 500 {
387+
t.Errorf("token_budget = %v, want 500", p.TokenBudget)
388388
}
389-
if p.TopN != 5 {
390-
t.Errorf("top_n = %d, want 5", p.TopN)
389+
if p.TopN == nil || *p.TopN != 5 {
390+
t.Errorf("top_n = %v, want 5", p.TopN)
391391
}
392392
if p.SystemPrompt != prompt {
393393
t.Errorf("system_prompt = %q, want %q", p.SystemPrompt, prompt)
@@ -416,11 +416,11 @@ func TestGenerateRAGConfig_OptionalPipelineFieldsOmitted(t *testing.T) {
416416
if p.Description != "" {
417417
t.Errorf("description should be empty (omitted), got %q", p.Description)
418418
}
419-
if p.TokenBudget != 0 {
420-
t.Errorf("token_budget should be 0 (omitted), got %d", p.TokenBudget)
419+
if p.TokenBudget != nil {
420+
t.Errorf("token_budget should be nil (omitted), got %v", *p.TokenBudget)
421421
}
422-
if p.TopN != 0 {
423-
t.Errorf("top_n should be 0 (omitted), got %d", p.TopN)
422+
if p.TopN != nil {
423+
t.Errorf("top_n should be nil (omitted), got %v", *p.TopN)
424424
}
425425
if p.SystemPrompt != "" {
426426
t.Errorf("system_prompt should be empty (omitted), got %q", p.SystemPrompt)
@@ -500,11 +500,11 @@ func TestGenerateRAGConfig_DefaultsSection(t *testing.T) {
500500
if cfg.Defaults == nil {
501501
t.Fatal("defaults section should be present when configured")
502502
}
503-
if cfg.Defaults.TokenBudget != 2000 {
504-
t.Errorf("defaults.token_budget = %d, want 2000", cfg.Defaults.TokenBudget)
503+
if cfg.Defaults.TokenBudget == nil || *cfg.Defaults.TokenBudget != 2000 {
504+
t.Errorf("defaults.token_budget = %v, want 2000", cfg.Defaults.TokenBudget)
505505
}
506-
if cfg.Defaults.TopN != 20 {
507-
t.Errorf("defaults.top_n = %d, want 20", cfg.Defaults.TopN)
506+
if cfg.Defaults.TopN == nil || *cfg.Defaults.TopN != 20 {
507+
t.Errorf("defaults.top_n = %v, want 20", cfg.Defaults.TopN)
508508
}
509509
}
510510

@@ -521,3 +521,31 @@ func TestGenerateRAGConfig_DefaultsAbsent(t *testing.T) {
521521
t.Errorf("defaults section should be absent when not configured, got %+v", cfg.Defaults)
522522
}
523523
}
524+
525+
func TestGenerateRAGConfig_SameProviderDifferentKeys_ReturnsError(t *testing.T) {
526+
key1 := "sk-openai-embed"
527+
key2 := "sk-openai-rag-different"
528+
params := &RAGConfigParams{
529+
Config: &database.RAGServiceConfig{
530+
Pipelines: []database.RAGPipeline{
531+
{
532+
Name: "default",
533+
Tables: []database.RAGPipelineTable{{Table: "t", TextColumn: "c", VectorColumn: "v"}},
534+
EmbeddingLLM: database.RAGPipelineLLMConfig{
535+
Provider: "openai", Model: "text-embedding-3-small", APIKey: &key1,
536+
},
537+
RAGLLM: database.RAGPipelineLLMConfig{
538+
Provider: "openai", Model: "gpt-4o", APIKey: &key2,
539+
},
540+
},
541+
},
542+
},
543+
DatabaseName: "mydb", DatabaseHost: "host", DatabasePort: 5432,
544+
Username: "u", Password: "p", KeysDir: "/app/keys",
545+
}
546+
547+
_, err := GenerateRAGConfig(params)
548+
if err == nil {
549+
t.Fatal("expected error for same-provider mismatched API keys, got nil")
550+
}
551+
}

0 commit comments

Comments
 (0)