@@ -447,16 +447,20 @@ struct SDContextParams {
447447 std::string tensor_type_rules;
448448 std::string lora_model_dir;
449449
450+ std::string main_backend_device;
451+ std::string diffusion_backend_device;
452+ std::string clip_backend_device;
453+ std::string vae_backend_device;
454+ std::string tae_backend_device;
455+ std::string control_net_backend_device;
456+
450457 std::map<std::string, std::string> embedding_map;
451458 std::vector<sd_embedding_t > embedding_vec;
452459
453460 rng_type_t rng_type = CUDA_RNG;
454461 rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
455462 bool offload_params_to_cpu = false ;
456463 bool enable_mmap = false ;
457- bool control_net_cpu = false ;
458- bool clip_on_cpu = false ;
459- bool vae_on_cpu = false ;
460464 bool diffusion_flash_attn = false ;
461465 bool diffusion_conv_direct = false ;
462466 bool vae_conv_direct = false ;
@@ -561,6 +565,31 @@ struct SDContextParams {
561565 " --upscale-model" ,
562566 " path to esrgan model." ,
563567 &esrgan_path},
568+ {" " ,
569+ " --main-backend-device" ,
570+ " default device to use for all backends (defaults to main gpu device if hardware acceleration is available, otherwise cpu)" ,
571+ &main_backend_device},
572+ {" " ,
573+ " --diffusion-backend-device" ,
574+ " device to use for diffusion (defaults to main-backend-device)" ,
575+ &diffusion_backend_device},
576+ {" " ,
577+ " --clip-backend-device" ,
578+ " device to use for clip (defaults to main-backend-device)" ,
579+ &clip_backend_device},
580+ {" " ,
581+ " --vae-backend-device" ,
582+ " device to use for vae (defaults to main-backend-device). Also applies to tae, unless tae-backend-device is specified" ,
583+ &vae_backend_device},
584+ {" " ,
585+ " --tae-backend-device" ,
586+ " device to use for tae (defaults to vae-backend-device)" ,
587+ &tae_backend_device},
588+ {" " ,
589+ " --control-net-backend-device" ,
590+ " device to use for control net (defaults to main-backend-device)" ,
591+ &control_net_backend_device},
592+
564593 };
565594
566595 options.int_options = {
@@ -603,18 +632,6 @@ struct SDContextParams {
603632 " --mmap" ,
604633 " whether to memory-map model" ,
605634 true , &enable_mmap},
606- {" " ,
607- " --control-net-cpu" ,
608- " keep controlnet in cpu (for low vram)" ,
609- true , &control_net_cpu},
610- {" " ,
611- " --clip-on-cpu" ,
612- " keep clip in cpu (for low vram)" ,
613- true , &clip_on_cpu},
614- {" " ,
615- " --vae-on-cpu" ,
616- " keep vae in cpu (for low vram)" ,
617- true , &vae_on_cpu},
618635 {" " ,
619636 " --diffusion-fa" ,
620637 " use flash attention in the diffusion model" ,
@@ -875,6 +892,7 @@ struct SDContextParams {
875892
876893 std::string embeddings_str = emb_ss.str ();
877894 std::ostringstream oss;
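895+ // TODO: print the *_backend_device fields (main, diffusion, clip, vae, tae, control_net) in this dump; they replace the removed control_net_cpu / clip_on_cpu / vae_on_cpu flags.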
878896 oss << " SDContextParams {\n "
879897 << " n_threads: " << n_threads << " ,\n "
880898 << " model_path: \" " << model_path << " \" ,\n "
@@ -901,9 +919,9 @@ struct SDContextParams {
901919 << " flow_shift: " << (std::isinf (flow_shift) ? " INF" : std::to_string (flow_shift)) << " \n "
902920 << " offload_params_to_cpu: " << (offload_params_to_cpu ? " true" : " false" ) << " ,\n "
903921 << " enable_mmap: " << (enable_mmap ? " true" : " false" ) << " ,\n "
904- << " control_net_cpu: " << (control_net_cpu ? " true" : " false" ) << " ,\n "
905- << " clip_on_cpu: " << (clip_on_cpu ? " true" : " false" ) << " ,\n "
906- << " vae_on_cpu: " << (vae_on_cpu ? " true" : " false" ) << " ,\n "
922+ // << " control_net_cpu: " << (control_net_cpu ? "true" : "false") << ",\n"
923+ // << " clip_on_cpu: " << (clip_on_cpu ? "true" : "false") << ",\n"
924+ // << " vae_on_cpu: " << (vae_on_cpu ? "true" : "false") << ",\n"
907925 << " diffusion_flash_attn: " << (diffusion_flash_attn ? " true" : " false" ) << " ,\n "
908926 << " diffusion_conv_direct: " << (diffusion_conv_direct ? " true" : " false" ) << " ,\n "
909927 << " vae_conv_direct: " << (vae_conv_direct ? " true" : " false" ) << " ,\n "
@@ -965,9 +983,6 @@ struct SDContextParams {
965983 lora_apply_mode,
966984 offload_params_to_cpu,
967985 enable_mmap,
968- clip_on_cpu,
969- control_net_cpu,
970- vae_on_cpu,
971986 diffusion_flash_attn,
972987 taesd_preview,
973988 diffusion_conv_direct,
@@ -980,6 +995,12 @@ struct SDContextParams {
980995 chroma_t5_mask_pad,
981996 qwen_image_zero_cond_t ,
982997 flow_shift,
998+ main_backend_device.c_str (),
999+ diffusion_backend_device.c_str (),
1000+ clip_backend_device.c_str (),
1001+ vae_backend_device.c_str (),
1002+ tae_backend_device.c_str (),
1003+ control_net_backend_device.c_str (),
9831004 };
9841005 return sd_ctx_params;
9851006 }
0 commit comments