diff --git a/src/stable-diffusion.cpp b/src/stable-diffusion.cpp index ccd52bd98..959ede388 100644 --- a/src/stable-diffusion.cpp +++ b/src/stable-diffusion.cpp @@ -117,6 +117,9 @@ class StableDiffusionGGML { ggml_backend_t control_net_backend = nullptr; ggml_backend_t vae_backend = nullptr; + GGMLBackendPtr main_backend; + GGMLBackendPtr cpu_backend; + SDVersion version; bool vae_decode_only = false; bool external_vae_is_invalid = false; @@ -163,21 +166,18 @@ class StableDiffusionGGML { StableDiffusionGGML() = default; - ~StableDiffusionGGML() { - if (clip_backend != backend) { - ggml_backend_free(clip_backend); - } - if (control_net_backend != backend) { - ggml_backend_free(control_net_backend); - } - if (vae_backend != backend) { - ggml_backend_free(vae_backend); - } - ggml_backend_free(backend); - } + ~StableDiffusionGGML() = default; void init_backend() { - backend = sd_get_default_backend(); + main_backend = sd_get_default_backend(); + backend = main_backend.get(); + } + + ggml_backend_t get_cpu_backend() { + if (!cpu_backend) { + cpu_backend = GGMLBackendPtr(ggml_backend_cpu_init()); + } + return cpu_backend.get(); } std::shared_ptr get_rng(rng_type_t rng_type) { @@ -434,7 +434,7 @@ class StableDiffusionGGML { clip_backend = backend; if (clip_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_INFO("CLIP: Using CPU backend"); - clip_backend = ggml_backend_cpu_init(); + clip_backend = get_cpu_backend(); } if (sd_version_is_sd3(version)) { cond_stage_model = std::make_shared(clip_backend, @@ -620,7 +620,7 @@ class StableDiffusionGGML { if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_INFO("VAE Autoencoder: Using CPU backend"); - vae_backend = ggml_backend_cpu_init(); + vae_backend = get_cpu_backend(); } else { vae_backend = backend; } @@ -715,7 +715,7 @@ class StableDiffusionGGML { ggml_backend_t controlnet_backend = nullptr; if (sd_ctx_params->keep_control_net_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_DEBUG("ControlNet: Using CPU backend"); - controlnet_backend = ggml_backend_cpu_init(); + controlnet_backend = get_cpu_backend(); } else { controlnet_backend = backend; } diff --git a/src/upscaler.cpp b/src/upscaler.cpp index 25fc0c5df..53073803b 100644 --- a/src/upscaler.cpp +++ b/src/upscaler.cpp @@ -31,12 +31,8 @@ bool UpscalerGGML::load_from_file(const std::string& esrgan_path, LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); } model_loader.set_wtype_override(model_data_type); - if (!backend) { - LOG_DEBUG("Using CPU backend"); - backend = ggml_backend_cpu_init(); - } LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type)); - esrgan_upscaler = std::make_shared(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map()); + esrgan_upscaler = std::make_shared(backend.get(), offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map()); esrgan_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes); if (direct) { esrgan_upscaler->set_conv2d_direct_enabled(true); diff --git a/src/upscaler.h b/src/upscaler.h index d667a6f15..056e290a5 100644 --- a/src/upscaler.h +++ b/src/upscaler.h @@ -9,7 +9,7 @@ #include struct UpscalerGGML { - ggml_backend_t backend = nullptr; // general backend + GGMLBackendPtr backend; // general backend ggml_type model_data_type = GGML_TYPE_F16; std::shared_ptr esrgan_upscaler; std::string esrgan_path; diff --git a/src/util.cpp b/src/util.cpp index 586284c84..a30a6e084 100644 --- a/src/util.cpp +++ b/src/util.cpp @@ -770,7 +770,7 @@ bool sd_backend_is(ggml_backend_t backend, const std::string& name) { return dev_name.find(name) != std::string::npos; } -ggml_backend_t sd_get_default_backend() { +GGMLBackendPtr sd_get_default_backend() { ggml_backend_load_all_once(); static std::once_flag once; std::call_once(once, []() { @@ -825,7 +825,7 @@ ggml_backend_t sd_get_default_backend() { LOG_DEBUG("Using CPU backend"); } - return backend; + return GGMLBackendPtr(backend); } // namespace is needed to avoid conflicts with ggml_backend_extend.hpp diff --git a/src/util.h b/src/util.h index 628a1f9d7..33fd32453 100644 --- a/src/util.h +++ b/src/util.h @@ -84,9 +84,17 @@ int sd_get_preview_interval(); bool sd_should_preview_denoised(); bool sd_should_preview_noisy(); +struct GGMLBackendDeleter { + void operator()(ggml_backend_t backend) const noexcept { + ggml_backend_free(backend); + } +}; + +using GGMLBackendPtr = std::unique_ptr; + // test if the backend is a specific one, e.g. "CUDA", "ROCm", "Vulkan" etc. bool sd_backend_is(ggml_backend_t backend, const std::string& name); -ggml_backend_t sd_get_default_backend(); +GGMLBackendPtr sd_get_default_backend(); #define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__) #define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)