Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 16 additions & 16 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,9 @@ class StableDiffusionGGML {
ggml_backend_t control_net_backend = nullptr;
ggml_backend_t vae_backend = nullptr;

GGMLBackendPtr main_backend;
GGMLBackendPtr cpu_backend;

SDVersion version;
bool vae_decode_only = false;
bool external_vae_is_invalid = false;
Expand Down Expand Up @@ -163,21 +166,18 @@ class StableDiffusionGGML {

StableDiffusionGGML() = default;

~StableDiffusionGGML() {
if (clip_backend != backend) {
ggml_backend_free(clip_backend);
}
if (control_net_backend != backend) {
ggml_backend_free(control_net_backend);
}
if (vae_backend != backend) {
ggml_backend_free(vae_backend);
}
ggml_backend_free(backend);
}
~StableDiffusionGGML() = default;

void init_backend() {
backend = sd_get_default_backend();
main_backend = sd_get_default_backend();
backend = main_backend.get();
}

ggml_backend_t get_cpu_backend() {
if (!cpu_backend) {
cpu_backend = GGMLBackendPtr(ggml_backend_cpu_init());
}
return cpu_backend.get();
}

std::shared_ptr<RNG> get_rng(rng_type_t rng_type) {
Expand Down Expand Up @@ -434,7 +434,7 @@ class StableDiffusionGGML {
clip_backend = backend;
if (clip_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("CLIP: Using CPU backend");
clip_backend = ggml_backend_cpu_init();
clip_backend = get_cpu_backend();
}
if (sd_version_is_sd3(version)) {
cond_stage_model = std::make_shared<SD3CLIPEmbedder>(clip_backend,
Expand Down Expand Up @@ -620,7 +620,7 @@ class StableDiffusionGGML {

if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_INFO("VAE Autoencoder: Using CPU backend");
vae_backend = ggml_backend_cpu_init();
vae_backend = get_cpu_backend();
} else {
vae_backend = backend;
}
Expand Down Expand Up @@ -715,7 +715,7 @@ class StableDiffusionGGML {
ggml_backend_t controlnet_backend = nullptr;
if (sd_ctx_params->keep_control_net_on_cpu && !ggml_backend_is_cpu(backend)) {
LOG_DEBUG("ControlNet: Using CPU backend");
controlnet_backend = ggml_backend_cpu_init();
controlnet_backend = get_cpu_backend();
} else {
controlnet_backend = backend;
}
Expand Down
6 changes: 1 addition & 5 deletions src/upscaler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,8 @@ bool UpscalerGGML::load_from_file(const std::string& esrgan_path,
LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str());
}
model_loader.set_wtype_override(model_data_type);
if (!backend) {
LOG_DEBUG("Using CPU backend");
backend = ggml_backend_cpu_init();
}
LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type));
esrgan_upscaler = std::make_shared<ESRGAN>(backend, offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
esrgan_upscaler = std::make_shared<ESRGAN>(backend.get(), offload_params_to_cpu, tile_size, model_loader.get_tensor_storage_map());
esrgan_upscaler->set_max_graph_vram_bytes(max_graph_vram_bytes);
if (direct) {
esrgan_upscaler->set_conv2d_direct_enabled(true);
Expand Down
2 changes: 1 addition & 1 deletion src/upscaler.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#include <string>

struct UpscalerGGML {
ggml_backend_t backend = nullptr; // general backend
GGMLBackendPtr backend; // general backend
ggml_type model_data_type = GGML_TYPE_F16;
std::shared_ptr<ESRGAN> esrgan_upscaler;
std::string esrgan_path;
Expand Down
4 changes: 2 additions & 2 deletions src/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -770,7 +770,7 @@ bool sd_backend_is(ggml_backend_t backend, const std::string& name) {
return dev_name.find(name) != std::string::npos;
}

ggml_backend_t sd_get_default_backend() {
GGMLBackendPtr sd_get_default_backend() {
ggml_backend_load_all_once();
static std::once_flag once;
std::call_once(once, []() {
Expand Down Expand Up @@ -825,7 +825,7 @@ ggml_backend_t sd_get_default_backend() {
LOG_DEBUG("Using CPU backend");
}

return backend;
return GGMLBackendPtr(backend);
}

// namespace is needed to avoid conflicts with ggml_backend_extend.hpp
Expand Down
10 changes: 9 additions & 1 deletion src/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,17 @@ int sd_get_preview_interval();
bool sd_should_preview_denoised();
bool sd_should_preview_noisy();

struct GGMLBackendDeleter {
void operator()(ggml_backend_t backend) const noexcept {
ggml_backend_free(backend);
}
};

using GGMLBackendPtr = std::unique_ptr<struct ggml_backend, GGMLBackendDeleter>;

// test if the backend is a specific one, e.g. "CUDA", "ROCm", "Vulkan" etc.
bool sd_backend_is(ggml_backend_t backend, const std::string& name);
ggml_backend_t sd_get_default_backend();
GGMLBackendPtr sd_get_default_backend();

#define LOG_DEBUG(format, ...) log_printf(SD_LOG_DEBUG, __FILE__, __LINE__, format, ##__VA_ARGS__)
#define LOG_INFO(format, ...) log_printf(SD_LOG_INFO, __FILE__, __LINE__, format, ##__VA_ARGS__)
Expand Down
Loading