diff --git a/core/backend/ctx_propagation_test.go b/core/backend/ctx_propagation_test.go index 34f269aa3e14..1d5c05444ec9 100644 --- a/core/backend/ctx_propagation_test.go +++ b/core/backend/ctx_propagation_test.go @@ -157,6 +157,33 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() { stampViaRouterCtx() }) + // Regression for #10636: a canceled request context must NOT cancel the + // model LOAD. The heavy image/audio backends bind the load to the request + // context so the routing holder reaches the SmartRouter; but a large + // diffusers/LLM model on a slow (e.g. shared-memory iGPU) host can take + // far longer to load than the client stays connected. If the request's + // cancellation propagates to the load, the LoadModel RPC is aborted, the + // backend process is torn down, and every retry restarts from scratch and + // never converges. The load must instead run to completion and cache while + // still carrying the request's routing holder value. + It("ImageGeneration does not propagate request cancellation to the model load", func() { + canceledCtx, cancel := context.WithCancel(reqCtx) + cancel() // client disconnected while the (slow) load was still running + + _, err := backend.ImageGeneration(canceledCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil) + // The load reached the router (short-circuit sentinel), i.e. it was + // NOT aborted early by the already-canceled request context. + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("router short-circuit (test)")) + + routerCtx := routerCtxOf() + Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked") + Expect(routerCtx.Err()).To(BeNil(), + "a canceled request must not cancel the model load") + // The routing holder value still propagates despite the decoupling. + stampViaRouterCtx() + }) + It("does NOT leak the holder when the app context is used instead", func() { // Sanity: the bug being fixed manifests as the router getting // appCfg.Context (no holder) instead of reqCtx (holder). A direct diff --git a/core/backend/embeddings.go b/core/backend/embeddings.go index eff88ef04b19..cab2c4b8a5a5 100644 --- a/core/backend/embeddings.go +++ b/core/backend/embeddings.go @@ -40,10 +40,14 @@ func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, erro func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { - // model.WithContext(ctx) overrides the app-context default set in - // ModelOptions so distributed routing decisions reach the request's - // X-LocalAI-Node holder via distributedhdr.Stamp. - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // model.WithContext carries the request context into the load so distributed + // routing decisions reach the request's X-LocalAI-Node holder via + // distributedhdr.Stamp. context.WithoutCancel keeps those values but drops + // the request's cancellation, so a slow first load still completes and + // caches if the client disconnects instead of aborting the LoadModel RPC and + // tearing down the backend process (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) inferenceModel, err := loader.Load(opts...) if err != nil { diff --git a/core/backend/image.go b/core/backend/image.go index 7414d5d2b623..c167c8ef4b64 100644 --- a/core/backend/image.go +++ b/core/backend/image.go @@ -13,10 +13,14 @@ import ( func ImageGeneration(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { - // model.WithContext(ctx) overrides the app-context default set in - // ModelOptions so distributed routing decisions reach the request's - // X-LocalAI-Node holder via distributedhdr.Stamp. - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // model.WithContext carries the request context into the load so distributed + // routing decisions reach the request's X-LocalAI-Node holder via + // distributedhdr.Stamp. context.WithoutCancel keeps those values but drops + // the request's cancellation, so a slow first load still completes and + // caches if the client disconnects instead of aborting the LoadModel RPC and + // tearing down the backend process (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) inferenceModel, err := loader.Load( opts..., ) diff --git a/core/backend/llm.go b/core/backend/llm.go index 4f6b4d216b5a..1f3b372f7487 100644 --- a/core/backend/llm.go +++ b/core/backend/llm.go @@ -111,7 +111,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima } ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource) - opts := ModelOptions(*c, o, model.WithContext(ctx)) + // context.WithoutCancel decouples the model load from the request's + // cancellation while preserving its routing values, so a slow load still + // completes and caches if the client disconnects instead of aborting the + // LoadModel RPC mid-load (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(*c, o, model.WithContext(context.WithoutCancel(ctx))) inferenceModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile}) diff --git a/core/backend/rerank.go b/core/backend/rerank.go index 0c0d15b226a1..de45f568fa95 100644 --- a/core/backend/rerank.go +++ b/core/backend/rerank.go @@ -57,10 +57,14 @@ func (r *modelReranker) Rerank(ctx context.Context, query string, documents []st } func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) { - // model.WithContext(ctx) overrides the app-context default set in - // ModelOptions so distributed routing decisions reach the request's - // X-LocalAI-Node holder via distributedhdr.Stamp. - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // model.WithContext carries the request context into the load so distributed + // routing decisions reach the request's X-LocalAI-Node holder via + // distributedhdr.Stamp. context.WithoutCancel keeps those values but drops + // the request's cancellation, so a slow first load still completes and + // caches if the client disconnects instead of aborting the LoadModel RPC and + // tearing down the backend process (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) rerankModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) diff --git a/core/backend/transcript.go b/core/backend/transcript.go index 211269160750..f91c66e2178f 100644 --- a/core/backend/transcript.go +++ b/core/backend/transcript.go @@ -45,10 +45,14 @@ func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelCon if modelConfig.Backend == "" { modelConfig.Backend = model.WhisperBackend } - // model.WithContext(ctx) overrides the app-context default set in - // ModelOptions so distributed routing decisions reach the request's - // X-LocalAI-Node holder via distributedhdr.Stamp. - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // model.WithContext carries the request context into the load so distributed + // routing decisions reach the request's X-LocalAI-Node holder via + // distributedhdr.Stamp. context.WithoutCancel keeps those values but drops + // the request's cancellation, so a slow first load still completes and + // caches if the client disconnects instead of aborting the LoadModel RPC and + // tearing down the backend process (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) transcriptionModel, err := ml.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) diff --git a/core/backend/tts.go b/core/backend/tts.go index 2b49149ae46a..ffea458f7ae2 100644 --- a/core/backend/tts.go +++ b/core/backend/tts.go @@ -50,10 +50,14 @@ func ModelTTS( appConfig *config.ApplicationConfig, modelConfig config.ModelConfig, ) (string, *proto.Result, error) { - // model.WithContext(ctx) overrides the app-context default set in - // ModelOptions so distributed routing decisions reach the request's - // X-LocalAI-Node holder via distributedhdr.Stamp. - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // model.WithContext carries the request context into the load so distributed + // routing decisions reach the request's X-LocalAI-Node holder via + // distributedhdr.Stamp. context.WithoutCancel keeps those values but drops + // the request's cancellation, so a slow first load still completes and + // caches if the client disconnects instead of aborting the LoadModel RPC and + // tearing down the backend process (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) ttsModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) @@ -153,7 +157,9 @@ func ModelTTSStream( modelConfig config.ModelConfig, audioCallback func([]byte) error, ) error { - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // See ModelTTS above: WithoutCancel decouples the load from request + // cancellation while preserving routing values (issue #10636). + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) ttsModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) diff --git a/core/backend/vad.go b/core/backend/vad.go index 1c874e3382e5..fd5caa7a2f87 100644 --- a/core/backend/vad.go +++ b/core/backend/vad.go @@ -14,10 +14,14 @@ func VAD(request *schema.VADRequest, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*schema.VADResponse, error) { - // model.WithContext(ctx) overrides the app-context default set in - // ModelOptions so distributed routing decisions reach the request's - // X-LocalAI-Node holder via distributedhdr.Stamp. - opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx)) + // model.WithContext carries the request context into the load so distributed + // routing decisions reach the request's X-LocalAI-Node holder via + // distributedhdr.Stamp. context.WithoutCancel keeps those values but drops + // the request's cancellation, so a slow first load still completes and + // caches if the client disconnects instead of aborting the LoadModel RPC and + // tearing down the backend process (issue #10636). Inference below keeps the + // cancellable ctx, so a disconnect still stops generation. + opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx))) vadModel, err := ml.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)