Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions core/backend/ctx_propagation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,33 @@ var _ = Describe("X-LocalAI-Node ctx propagation contract", func() {
stampViaRouterCtx()
})

// Regression for #10636: a canceled request context must NOT cancel the
// model LOAD. The heavy image/audio backends bind the load to the request
// context so the routing holder reaches the SmartRouter; but a large
// diffusers/LLM model on a slow (e.g. shared-memory iGPU) host can take
// far longer to load than the client stays connected. If the request's
// cancellation propagates to the load, the LoadModel RPC is aborted, the
// backend process is torn down, and every retry restarts from scratch and
// never converges. The load must instead run to completion and cache while
// still carrying the request's routing holder value.
It("ImageGeneration does not propagate request cancellation to the model load", func() {
canceledCtx, cancel := context.WithCancel(reqCtx)
cancel() // client disconnected while the (slow) load was still running

_, err := backend.ImageGeneration(canceledCtx, 64, 64, 1, 0, "p", "", "", "/tmp/out.png", loader, modelCfg, appCfg, nil)
// The load reached the router (short-circuit sentinel), i.e. it was
// NOT aborted early by the already-canceled request context.
Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("router short-circuit (test)"))

routerCtx := routerCtxOf()
Expect(routerCtx).ToNot(BeNil(), "router callback must have been invoked")
Expect(routerCtx.Err()).To(BeNil(),
"a canceled request must not cancel the model load")
// The routing holder value still propagates despite the decoupling.
stampViaRouterCtx()
})

It("does NOT leak the holder when the app context is used instead", func() {
// Sanity: the bug being fixed manifests as the router getting
// appCfg.Context (no holder) instead of reqCtx (holder). A direct
Expand Down
12 changes: 8 additions & 4 deletions core/backend/embeddings.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,14 @@ func (e *modelEmbedder) Embed(ctx context.Context, text string) ([]float32, erro

func ModelEmbedding(ctx context.Context, s string, tokens []int, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) {

// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// model.WithContext carries the request context into the load so distributed
// routing decisions reach the request's X-LocalAI-Node holder via
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
// the request's cancellation, so a slow first load still completes and
// caches if the client disconnects instead of aborting the LoadModel RPC and
// tearing down the backend process (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))

inferenceModel, err := loader.Load(opts...)
if err != nil {
Expand Down
12 changes: 8 additions & 4 deletions core/backend/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,14 @@ import (

func ImageGeneration(ctx context.Context, height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) {

// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// model.WithContext carries the request context into the load so distributed
// routing decisions reach the request's X-LocalAI-Node holder via
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
// the request's cancellation, so a slow first load still completes and
// caches if the client disconnects instead of aborting the LoadModel RPC and
// tearing down the backend process (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
inferenceModel, err := loader.Load(
opts...,
)
Expand Down
7 changes: 6 additions & 1 deletion core/backend/llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,12 @@ func ModelInference(ctx context.Context, s string, messages schema.Messages, ima
}
ctx = distributedhdr.MaybeWithPrefixChain(ctx, c.ModelID(), chainSource)

opts := ModelOptions(*c, o, model.WithContext(ctx))
// context.WithoutCancel decouples the model load from the request's
// cancellation while preserving its routing values, so a slow load still
// completes and caches if the client disconnects instead of aborting the
// LoadModel RPC mid-load (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(*c, o, model.WithContext(context.WithoutCancel(ctx)))
inferenceModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile})
Expand Down
12 changes: 8 additions & 4 deletions core/backend/rerank.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,14 @@ func (r *modelReranker) Rerank(ctx context.Context, query string, documents []st
}

func Rerank(ctx context.Context, request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) {
// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// model.WithContext carries the request context into the load so distributed
// routing decisions reach the request's X-LocalAI-Node holder via
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
// the request's cancellation, so a slow first load still completes and
// caches if the client disconnects instead of aborting the LoadModel RPC and
// tearing down the backend process (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
rerankModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
Expand Down
12 changes: 8 additions & 4 deletions core/backend/transcript.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,14 @@ func loadTranscriptionModel(ctx context.Context, ml *model.ModelLoader, modelCon
if modelConfig.Backend == "" {
modelConfig.Backend = model.WhisperBackend
}
// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// model.WithContext carries the request context into the load so distributed
// routing decisions reach the request's X-LocalAI-Node holder via
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
// the request's cancellation, so a slow first load still completes and
// caches if the client disconnects instead of aborting the LoadModel RPC and
// tearing down the backend process (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
transcriptionModel, err := ml.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
Expand Down
16 changes: 11 additions & 5 deletions core/backend/tts.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,14 @@ func ModelTTS(
appConfig *config.ApplicationConfig,
modelConfig config.ModelConfig,
) (string, *proto.Result, error) {
// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// model.WithContext carries the request context into the load so distributed
// routing decisions reach the request's X-LocalAI-Node holder via
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
// the request's cancellation, so a slow first load still completes and
// caches if the client disconnects instead of aborting the LoadModel RPC and
// tearing down the backend process (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
ttsModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
Expand Down Expand Up @@ -153,7 +157,9 @@ func ModelTTSStream(
modelConfig config.ModelConfig,
audioCallback func([]byte) error,
) error {
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// See ModelTTS above: WithoutCancel decouples the load from request
// cancellation while preserving routing values (issue #10636).
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
ttsModel, err := loader.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
Expand Down
12 changes: 8 additions & 4 deletions core/backend/vad.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,14 @@ func VAD(request *schema.VADRequest,
ml *model.ModelLoader,
appConfig *config.ApplicationConfig,
modelConfig config.ModelConfig) (*schema.VADResponse, error) {
// model.WithContext(ctx) overrides the app-context default set in
// ModelOptions so distributed routing decisions reach the request's
// X-LocalAI-Node holder via distributedhdr.Stamp.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(ctx))
// model.WithContext carries the request context into the load so distributed
// routing decisions reach the request's X-LocalAI-Node holder via
// distributedhdr.Stamp. context.WithoutCancel keeps those values but drops
// the request's cancellation, so a slow first load still completes and
// caches if the client disconnects instead of aborting the LoadModel RPC and
// tearing down the backend process (issue #10636). Inference below keeps the
// cancellable ctx, so a disconnect still stops generation.
opts := ModelOptions(modelConfig, appConfig, model.WithContext(context.WithoutCancel(ctx)))
vadModel, err := ml.Load(opts...)
if err != nil {
recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil)
Expand Down
Loading