Skip to content

Commit 15585fb

Browse files
committed
Select backend devices via arg
1 parent 0e52afc commit 15585fb

3 files changed

Lines changed: 185 additions & 131 deletions

File tree

examples/common/common.hpp

Lines changed: 42 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -447,16 +447,20 @@ struct SDContextParams {
447447
std::string tensor_type_rules;
448448
std::string lora_model_dir;
449449

450+
std::string main_backend_device;
451+
std::string diffusion_backend_device;
452+
std::string clip_backend_device;
453+
std::string vae_backend_device;
454+
std::string tae_backend_device;
455+
std::string control_net_backend_device;
456+
450457
std::map<std::string, std::string> embedding_map;
451458
std::vector<sd_embedding_t> embedding_vec;
452459

453460
rng_type_t rng_type = CUDA_RNG;
454461
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
455462
bool offload_params_to_cpu = false;
456463
bool enable_mmap = false;
457-
bool control_net_cpu = false;
458-
bool clip_on_cpu = false;
459-
bool vae_on_cpu = false;
460464
bool diffusion_flash_attn = false;
461465
bool diffusion_conv_direct = false;
462466
bool vae_conv_direct = false;
@@ -561,6 +565,31 @@ struct SDContextParams {
561565
"--upscale-model",
562566
"path to esrgan model.",
563567
&esrgan_path},
568+
{"",
569+
"--main-backend-device",
570+
"default device to use for all backends (defaults to main gpu device if hardware acceleration is available, otherwise cpu)",
571+
&main_backend_device},
572+
{"",
573+
"--diffusion-backend-device",
574+
"device to use for diffusion (defaults to main-backend-device)",
575+
&diffusion_backend_device},
576+
{"",
577+
"--clip-backend-device",
578+
"device to use for clip (defaults to main-backend-device)",
579+
&clip_backend_device},
580+
{"",
581+
"--vae-backend-device",
582+
"device to use for vae (defaults to main-backend-device). Also applies to tae, unless tae-backend-device is specified",
583+
&vae_backend_device},
584+
{"",
585+
"--tae-backend-device",
586+
"device to use for tae (defaults to vae-backend-device)",
587+
&tae_backend_device},
588+
{"",
589+
"--control-net-backend-device",
590+
"device to use for control net (defaults to main-backend-device)",
591+
&control_net_backend_device},
592+
564593
};
565594

566595
options.int_options = {
@@ -603,18 +632,6 @@ struct SDContextParams {
603632
"--mmap",
604633
"whether to memory-map model",
605634
true, &enable_mmap},
606-
{"",
607-
"--control-net-cpu",
608-
"keep controlnet in cpu (for low vram)",
609-
true, &control_net_cpu},
610-
{"",
611-
"--clip-on-cpu",
612-
"keep clip in cpu (for low vram)",
613-
true, &clip_on_cpu},
614-
{"",
615-
"--vae-on-cpu",
616-
"keep vae in cpu (for low vram)",
617-
true, &vae_on_cpu},
618635
{"",
619636
"--diffusion-fa",
620637
"use flash attention in the diffusion model",
@@ -875,6 +892,7 @@ struct SDContextParams {
875892

876893
std::string embeddings_str = emb_ss.str();
877894
std::ostringstream oss;
895+
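// TODO: include the backend device fields (main, diffusion, clip, vae, tae, control_net) in this dump; they replace the removed control_net_cpu / clip_on_cpu / vae_on_cpu flags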
878896
oss << "SDContextParams {\n"
879897
<< " n_threads: " << n_threads << ",\n"
880898
<< " model_path: \"" << model_path << "\",\n"
@@ -901,9 +919,9 @@ struct SDContextParams {
901919
<< " flow_shift: " << (std::isinf(flow_shift) ? "INF" : std::to_string(flow_shift)) << "\n"
902920
<< " offload_params_to_cpu: " << (offload_params_to_cpu ? "true" : "false") << ",\n"
903921
<< " enable_mmap: " << (enable_mmap ? "true" : "false") << ",\n"
904-
<< " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
905-
<< " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
906-
<< " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
922+
// << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
923+
// << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
924+
// << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
907925
<< " diffusion_flash_attn: " << (diffusion_flash_attn ? "true" : "false") << ",\n"
908926
<< " diffusion_conv_direct: " << (diffusion_conv_direct ? "true" : "false") << ",\n"
909927
<< " vae_conv_direct: " << (vae_conv_direct ? "true" : "false") << ",\n"
@@ -965,9 +983,6 @@ struct SDContextParams {
965983
lora_apply_mode,
966984
offload_params_to_cpu,
967985
enable_mmap,
968-
clip_on_cpu,
969-
control_net_cpu,
970-
vae_on_cpu,
971986
diffusion_flash_attn,
972987
taesd_preview,
973988
diffusion_conv_direct,
@@ -980,6 +995,12 @@ struct SDContextParams {
980995
chroma_t5_mask_pad,
981996
qwen_image_zero_cond_t,
982997
flow_shift,
998+
main_backend_device.c_str(),
999+
diffusion_backend_device.c_str(),
1000+
clip_backend_device.c_str(),
1001+
vae_backend_device.c_str(),
1002+
tae_backend_device.c_str(),
1003+
control_net_backend_device.c_str(),
9831004
};
9841005
return sd_ctx_params;
9851006
}

0 commit comments

Comments
 (0)