diff --git a/cmd/root/eval.go b/cmd/root/eval.go index e44e222ca..c42f07f86 100644 --- a/cmd/root/eval.go +++ b/cmd/root/eval.go @@ -14,6 +14,7 @@ import ( "github.com/docker/docker-agent/pkg/config" "github.com/docker/docker-agent/pkg/evaluation" + "github.com/docker/docker-agent/pkg/model/provider/providers" "github.com/docker/docker-agent/pkg/telemetry" ) @@ -117,6 +118,10 @@ func (f *evalFlags) runEvalCommand(cmd *cobra.Command, args []string) (commandEr f.AgentFilename = agentFilename f.EvalsDir = evalsDir + // Wire the full provider set so the judge model can be built (the package + // default registry is empty; see pkg/model/provider/providers). + f.runConfig.ProviderRegistry = providers.NewDefaultRegistry() + // Run evaluation // Pass consoleOut for TTY progress bar, teeOut for results that should go to both console and log run, evalErr := evaluation.Evaluate(ctx, consoleOut, teeOut, isTTY, runName, &f.runConfig, f.Config) diff --git a/pkg/config/runtime.go b/pkg/config/runtime.go index 0bca5887f..fd0780e61 100644 --- a/pkg/config/runtime.go +++ b/pkg/config/runtime.go @@ -24,6 +24,12 @@ type RuntimeConfig struct { modelsDevStore *modelsdev.Store modelsDevStoreErr error modelsDevStoreOnce sync.Once + + // ProviderRegistry instantiates model providers for toolsets that build + // providers at load time (e.g. RAG embeddings/reranking). It is populated + // by the team loader with the same registry used for agent models. When + // nil, ProviderRegistryOrDefault falls back to provider.DefaultRegistry. + ProviderRegistry *provider.Registry } type Config struct { @@ -78,6 +84,7 @@ func (runConfig *RuntimeConfig) Clone() *RuntimeConfig { ModelsDevStoreOverride: runConfig.ModelsDevStoreOverride, modelsDevStore: store, modelsDevStoreErr: storeErr, + ProviderRegistry: runConfig.ProviderRegistry, } clone.envProviderOnce.Do(func() {}) // mark as resolved clone.modelsDevStoreOnce.Do(func() {}) // mark as resolved @@ -109,6 +116,18 @@ func (runConfig *RuntimeConfig) ModelsDevStore() (*modelsdev.Store, error) { return runConfig.modelsDevStore, runConfig.modelsDevStoreErr } +// ProviderRegistryOrDefault returns the configured provider registry, or the +// package default registry when none was set (including when the receiver is +// nil). The default registry only contains providers the core package can +// expose without optional SDK dependencies, so callers that need the full +// provider set must ensure the team loader populated ProviderRegistry. +func (runConfig *RuntimeConfig) ProviderRegistryOrDefault() *provider.Registry { + if runConfig != nil && runConfig.ProviderRegistry != nil { + return runConfig.ProviderRegistry + } + return provider.DefaultRegistry() +} + func (runConfig *RuntimeConfig) EnvProvider() environment.Provider { if runConfig.EnvProviderForTests != nil { return runConfig.EnvProviderForTests diff --git a/pkg/evaluation/eval.go b/pkg/evaluation/eval.go index f067fae4d..e0cdff0c4 100644 --- a/pkg/evaluation/eval.go +++ b/pkg/evaluation/eval.go @@ -667,7 +667,7 @@ func createJudgeModel(ctx context.Context, judgeModel string, runConfig *config. opts = append(opts, options.WithGateway(runConfig.ModelsGateway)) } - judge, err := provider.New(ctx, &cfg, runConfig.EnvProvider(), opts...) + judge, err := runConfig.ProviderRegistryOrDefault().New(ctx, &cfg, runConfig.EnvProvider(), opts...) if err != nil { return nil, fmt.Errorf("creating judge model: %w", err) } diff --git a/pkg/model/provider/factory_js.go b/pkg/model/provider/factory_js.go index 70714746f..58579f735 100644 --- a/pkg/model/provider/factory_js.go +++ b/pkg/model/provider/factory_js.go @@ -10,6 +10,9 @@ import ( "github.com/docker/docker-agent/pkg/config/latest" "github.com/docker/docker-agent/pkg/environment" + "github.com/docker/docker-agent/pkg/model/provider/anthropic" + "github.com/docker/docker-agent/pkg/model/provider/gemini" + "github.com/docker/docker-agent/pkg/model/provider/openai" "github.com/docker/docker-agent/pkg/model/provider/options" "github.com/docker/docker-agent/pkg/model/provider/rulebased" ) @@ -26,10 +29,36 @@ func NewRegistry(factories map[string]Factory) *Registry { return &Registry{factories: copied} } -var defaultFactories map[string]Factory +// defaultFactories is the js/wasm provider set. dmr (os/exec), amazon-bedrock +// and vertex AI (cloud SDKs that don't compile to wasm) are deliberately +// absent; the remaining providers reach their APIs over plain net/http, which +// the Go runtime maps to fetch in the browser. Unlike the non-js build (whose +// DefaultRegistry is empty so applications must wire providers explicitly via +// pkg/model/provider/providers), the wasm build has no such wiring point — +// pkg/model/provider/providers pulls in the cloud SDKs — so the slim set is +// registered here. +var defaultFactories = map[string]Factory{ + "openai": openaiFactory, + "openai_chatcompletions": openaiFactory, + "openai_responses": openaiFactory, + "anthropic": anthropicFactory, + "google": googleFactory, +} func DefaultRegistry() *Registry { return NewRegistry(defaultFactories) } +func openaiFactory(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { + return openai.NewClient(ctx, cfg, env, opts...) +} + +func anthropicFactory(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { + return anthropic.NewClient(ctx, cfg, env, opts...) +} + +func googleFactory(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { + return gemini.NewClient(ctx, cfg, env, opts...) +} + func (r *Registry) New(ctx context.Context, cfg *latest.ModelConfig, env environment.Provider, opts ...options.Opt) (Provider, error) { return r.NewWithModels(ctx, cfg, nil, env, opts...) } diff --git a/pkg/rag/builder.go b/pkg/rag/builder.go index 2553bf9a3..764389de4 100644 --- a/pkg/rag/builder.go +++ b/pkg/rag/builder.go @@ -27,9 +27,11 @@ type ManagersBuildConfig struct { } // NewProvider creates a model provider using the build config's environment, -// gateway, and custom provider settings. +// gateway, and custom provider settings. It uses the provider registry carried +// by RuntimeConfig (populated by the team loader with the full provider set); +// without it, model creation fails with "unknown provider type". func (c ManagersBuildConfig) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) { - return provider.New(ctx, cfg, c.Env, + return c.RuntimeConfig.ProviderRegistryOrDefault().New(ctx, cfg, c.Env, options.WithGateway(c.ModelsGateway), options.WithProviders(c.Providers)) } diff --git a/pkg/rag/strategy/strategy.go b/pkg/rag/strategy/strategy.go index c21194128..9f0b196cb 100644 --- a/pkg/rag/strategy/strategy.go +++ b/pkg/rag/strategy/strategy.go @@ -26,9 +26,11 @@ type BuildContext struct { } // NewProvider creates a model provider using the build context's environment, -// gateway, and custom provider settings. +// gateway, and custom provider settings. It uses the provider registry carried +// by RuntimeConfig (populated by the team loader with the full provider set); +// without it, model creation fails with "unknown provider type". func (c BuildContext) NewProvider(ctx context.Context, cfg *latest.ModelConfig) (provider.Provider, error) { - return provider.New(ctx, cfg, c.Env, + return c.RuntimeConfig.ProviderRegistryOrDefault().New(ctx, cfg, c.Env, options.WithGateway(c.ModelsGateway), options.WithProviders(c.Providers)) } diff --git a/pkg/teamloader/teamloader.go b/pkg/teamloader/teamloader.go index 96c93657d..3323c5f08 100644 --- a/pkg/teamloader/teamloader.go +++ b/pkg/teamloader/teamloader.go @@ -153,6 +153,9 @@ func LoadWithConfig(ctx context.Context, agentSource config.Source, runConfig *c // Make model definitions available to toolset creators (e.g., RAG reranking) runConfig.Models = cfg.Models runConfig.Providers = cfg.Providers + // Share the resolved provider registry so toolsets that build providers at + // load time (e.g. RAG embeddings/reranking) use the same one as agent models. + runConfig.ProviderRegistry = loadOpts.providerRegistry // Load agents parentDir := cmp.Or(agentSource.ParentDir(), runConfig.WorkingDir)