Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/cli/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,7 @@ int main(int argc, const char* argv[]) {

if (gen_params.end_image_path.size() > 0) {
vae_decode_only = false;
if (!load_image_and_update_size(gen_params.init_image_path, end_image)) {
if (!load_image_and_update_size(gen_params.end_image_path, end_image)) {
return 1;
}
}
Expand Down
37 changes: 17 additions & 20 deletions src/anima.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,20 +602,19 @@ namespace Anima {
return Rope::embed_nd(ids, bs, axis_thetas, axes_dim);
}

ggml_cgraph* build_graph(ggml_tensor* x,
ggml_tensor* timesteps,
ggml_tensor* context,
ggml_tensor* t5_ids = nullptr,
ggml_tensor* t5_weights = nullptr) {
ggml_cgraph* build_graph(const sd::Tensor<float>& x_tensor,
const sd::Tensor<float>& timesteps_tensor,
const sd::Tensor<float>& context_tensor = {},
const sd::Tensor<int32_t>& t5_ids_tensor = {},
const sd::Tensor<float>& t5_weights_tensor = {}) {
ggml_tensor* x = make_input(x_tensor);
ggml_tensor* timesteps = make_input(timesteps_tensor);
ggml_tensor* context = make_optional_input(context_tensor);
ggml_tensor* t5_ids = make_optional_input(t5_ids_tensor);
ggml_tensor* t5_weights = make_optional_input(t5_weights_tensor);
GGML_ASSERT(x->ne[3] == 1);
ggml_cgraph* gf = new_graph_custom(ANIMA_GRAPH_SIZE);

x = to_backend(x);
timesteps = to_backend(timesteps);
context = to_backend(context);
t5_ids = to_backend(t5_ids);
t5_weights = to_backend(t5_weights);

int64_t pad_h = (net.patch_size - x->ne[1] % net.patch_size) % net.patch_size;
int64_t pad_w = (net.patch_size - x->ne[0] % net.patch_size) % net.patch_size;
int64_t h_pad = x->ne[1] + pad_h;
Expand Down Expand Up @@ -667,18 +666,16 @@ namespace Anima {
return gf;
}

bool compute(int n_threads,
ggml_tensor* x,
ggml_tensor* timesteps,
ggml_tensor* context,
ggml_tensor* t5_ids = nullptr,
ggml_tensor* t5_weights = nullptr,
ggml_tensor** output = nullptr,
ggml_context* output_ctx = nullptr) {
sd::Tensor<float> compute(int n_threads,
const sd::Tensor<float>& x,
const sd::Tensor<float>& timesteps,
const sd::Tensor<float>& context = {},
const sd::Tensor<int32_t>& t5_ids = {},
const sd::Tensor<float>& t5_weights = {}) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(x, timesteps, context, t5_ids, t5_weights);
};
return GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx);
return restore_trailing_singleton_dims(GGMLRunner::compute<float>(get_graph, n_threads, false), x.dim());
}
};
} // namespace Anima
Expand Down
269 changes: 94 additions & 175 deletions src/auto_encoder_kl.hpp

Large diffs are not rendered by default.

52 changes: 27 additions & 25 deletions src/cache_dit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
#include <unordered_map>
#include <vector>

#include "condition_cache_utils.hpp"
#include "ggml_extend.hpp"
#include "tensor.hpp"

struct DBCacheConfig {
bool enabled = false;
Expand Down Expand Up @@ -771,35 +773,37 @@ struct CacheDitConditionState {
return it != cache_diffs.end() && !it->second.diff.empty();
}

void update_cache(const void* cond, const float* input, const float* output, size_t size) {
void update_cache(const void* cond, const sd::Tensor<float>& input, const sd::Tensor<float>& output) {
CacheEntry& entry = cache_diffs[cond];
entry.diff.resize(size);
for (size_t i = 0; i < size; i++) {
entry.diff[i] = output[i] - input[i];
if (!sd::store_condition_cache_diff(&entry.diff, input, output)) {
entry.prev_input.clear();
entry.prev_output.clear();
entry.has_prev = false;
return;
}

size_t size = static_cast<size_t>(output.numel());
const float* input_data = input.data();
const float* output_data = output.data();
entry.prev_input.resize(size);
entry.prev_output.resize(size);
for (size_t i = 0; i < size; i++) {
entry.prev_input[i] = input[i];
entry.prev_output[i] = output[i];
entry.prev_input[i] = input_data[i];
entry.prev_output[i] = output_data[i];
}
entry.has_prev = true;
}

void apply_cache(const void* cond, const float* input, float* output, size_t size) {
void apply_cache(const void* cond,
const sd::Tensor<float>& input,
sd::Tensor<float>* output) {
auto it = cache_diffs.find(cond);
if (it == cache_diffs.end() || it->second.diff.empty())
return;
if (it->second.diff.size() != size)
return;

for (size_t i = 0; i < size; i++) {
output[i] = input[i] + it->second.diff[i];
}
sd::apply_condition_cache_diff(it->second.diff, input, output);
}

bool before_condition(const void* cond, ggml_tensor* input, ggml_tensor* output, float sigma, int step_index) {
bool before_condition(const void* cond, const sd::Tensor<float>& input, sd::Tensor<float>* output, float sigma, int step_index) {
if (!enabled() || step_index < 0)
return false;

Expand All @@ -819,8 +823,7 @@ struct CacheDitConditionState {

if (skip_current_step) {
if (has_cache(cond)) {
apply_cache(cond, (float*)input->data, (float*)output->data,
static_cast<size_t>(ggml_nelements(output)));
apply_cache(cond, input, output);
return true;
}
return false;
Expand All @@ -833,13 +836,13 @@ struct CacheDitConditionState {
if (it == cache_diffs.end() || !it->second.has_prev)
return false;

size_t ne = static_cast<size_t>(ggml_nelements(input));
size_t ne = static_cast<size_t>(input.numel());
if (it->second.prev_input.size() != ne)
return false;

float* input_data = (float*)input->data;
float diff = CacheDitState::calculate_residual_diff(
it->second.prev_input.data(), input_data, ne);
const float* input_data = input.data();
float diff = CacheDitState::calculate_residual_diff(
it->second.prev_input.data(), input_data, ne);

float effective_threshold = config.residual_diff_threshold;
if (config.Fn_compute_blocks > 0) {
Expand All @@ -859,23 +862,22 @@ struct CacheDitConditionState {
cached_steps.push_back(current_step_index);
continuous_cached_steps++;
accumulated_residual_diff += diff;
apply_cache(cond, input_data, (float*)output->data, ne);
apply_cache(cond, input, output);
return true;
}

continuous_cached_steps = 0;
return false;
}

void after_condition(const void* cond, ggml_tensor* input, ggml_tensor* output) {
void after_condition(const void* cond, const sd::Tensor<float>& input, const sd::Tensor<float>& output) {
if (!step_is_active())
return;

size_t ne = static_cast<size_t>(ggml_nelements(output));
update_cache(cond, (float*)input->data, (float*)output->data, ne);
update_cache(cond, input, output);

if (cond == anchor_condition && taylor_config.enabled) {
taylor_state.update_derivatives((float*)output->data, ne, current_step_index);
taylor_state.update_derivatives(output.data(), static_cast<size_t>(output.numel()), current_step_index);
}
}

Expand Down
29 changes: 15 additions & 14 deletions src/clip.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -957,15 +957,14 @@ struct CLIPTextModelRunner : public GGMLRunner {
return model.forward(ctx, input_ids, embeddings, mask, max_token_idx, return_pooled, clip_skip);
}

ggml_cgraph* build_graph(ggml_tensor* input_ids,
ggml_cgraph* build_graph(const sd::Tensor<int32_t>& input_ids_tensor,
int num_custom_embeddings = 0,
void* custom_embeddings_data = nullptr,
size_t max_token_idx = 0,
bool return_pooled = false,
int clip_skip = -1) {
ggml_cgraph* gf = new_graph_custom(2048);

input_ids = to_backend(input_ids);
ggml_cgraph* gf = new_graph_custom(2048);
ggml_tensor* input_ids = make_input(input_ids_tensor);

ggml_tensor* embeddings = nullptr;

Expand Down Expand Up @@ -1004,19 +1003,21 @@ struct CLIPTextModelRunner : public GGMLRunner {
return gf;
}

bool compute(const int n_threads,
ggml_tensor* input_ids,
int num_custom_embeddings,
void* custom_embeddings_data,
size_t max_token_idx,
bool return_pooled,
int clip_skip,
ggml_tensor** output,
ggml_context* output_ctx = nullptr) {
sd::Tensor<float> compute(const int n_threads,
const sd::Tensor<int32_t>& input_ids,
int num_custom_embeddings,
void* custom_embeddings_data,
size_t max_token_idx,
bool return_pooled,
int clip_skip) {
auto get_graph = [&]() -> ggml_cgraph* {
return build_graph(input_ids, num_custom_embeddings, custom_embeddings_data, max_token_idx, return_pooled, clip_skip);
};
return GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx);
auto result = GGMLRunner::compute<float>(get_graph, n_threads, true);
if (return_pooled) {
return take_or_empty(std::move(result));
}
return restore_trailing_singleton_dims(std::move(result), 3);
}
};

Expand Down
58 changes: 29 additions & 29 deletions src/common_dit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
#include "ggml_extend.hpp"

namespace DiT {
ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x,
int pw,
int ph,
bool patch_last = true) {
inline ggml_tensor* patchify(ggml_context* ctx,
ggml_tensor* x,
int pw,
int ph,
bool patch_last = true) {
// x: [N, C, H, W]
// return: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
int64_t N = x->ne[3];
Expand All @@ -33,13 +33,13 @@ namespace DiT {
return x;
}

ggml_tensor* unpatchify(ggml_context* ctx,
ggml_tensor* x,
int64_t h,
int64_t w,
int ph,
int pw,
bool patch_last = true) {
inline ggml_tensor* unpatchify(ggml_context* ctx,
ggml_tensor* x,
int64_t h,
int64_t w,
int ph,
int pw,
bool patch_last = true) {
// x: [N, h*w, C*ph*pw] if patch_last else [N, h*w, ph*pw*C]
// return: [N, C, H, W]
int64_t N = x->ne[2];
Expand All @@ -64,10 +64,10 @@ namespace DiT {
return x;
}

ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
ggml_tensor* x,
int ph,
int pw) {
inline ggml_tensor* pad_to_patch_size(GGMLRunnerContext* ctx,
ggml_tensor* x,
int ph,
int pw) {
int64_t W = x->ne[0];
int64_t H = x->ne[1];

Expand All @@ -77,23 +77,23 @@ namespace DiT {
return x;
}

ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
ggml_tensor* x,
int ph,
int pw,
bool patch_last = true) {
inline ggml_tensor* pad_and_patchify(GGMLRunnerContext* ctx,
ggml_tensor* x,
int ph,
int pw,
bool patch_last = true) {
x = pad_to_patch_size(ctx, x, ph, pw);
x = patchify(ctx->ggml_ctx, x, ph, pw, patch_last);
return x;
}

ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
ggml_tensor* x,
int64_t H,
int64_t W,
int ph,
int pw,
bool patch_last = true) {
inline ggml_tensor* unpatchify_and_crop(ggml_context* ctx,
ggml_tensor* x,
int64_t H,
int64_t W,
int ph,
int pw,
bool patch_last = true) {
int pad_h = (ph - H % ph) % ph;
int pad_w = (pw - W % pw) % pw;
int64_t h = ((H + pad_h) / ph);
Expand All @@ -105,4 +105,4 @@ namespace DiT {
}
} // namespace DiT

#endif // __COMMON_DIT_HPP__
#endif // __COMMON_DIT_HPP__
64 changes: 64 additions & 0 deletions src/condition_cache_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#ifndef __CONDITION_CACHE_UTILS_HPP__
#define __CONDITION_CACHE_UTILS_HPP__

#include <vector>

#include "tensor.hpp"

namespace sd {

// Computes the element-wise residual (output - input) and stores it in
// *diff, replacing any previous contents.
//
// Parameters:
//   diff   - destination vector; cleared on every failure path so a stale
//            residual is never left behind for callers that treat a
//            non-empty diff as a valid cache entry.
//   input  - tensor the residual is taken against.
//   output - tensor the residual is taken from; numel must match input's.
//
// Returns true when the residual was stored; false on invalid arguments
// (null diff, empty tensor, size mismatch, or null data pointer).
inline bool store_condition_cache_diff(std::vector<float>* diff,
                                       const sd::Tensor<float>& input,
                                       const sd::Tensor<float>& output) {
    if (diff == nullptr) {
        return false;
    }
    if (input.empty() || output.empty()) {
        // Fix: also clear on this path. Previously a stale diff survived an
        // empty-tensor failure while every other failure path cleared it,
        // so a later has_cache()-style check could apply an outdated residual.
        diff->clear();
        return false;
    }

    const size_t input_size  = static_cast<size_t>(input.numel());
    const size_t output_size = static_cast<size_t>(output.numel());
    if (input_size == 0 || input_size != output_size) {
        diff->clear();
        return false;
    }

    const float* input_data  = input.data();
    const float* output_data = output.data();
    if (input_data == nullptr || output_data == nullptr) {
        diff->clear();
        return false;
    }

    diff->resize(output_size);
    for (size_t i = 0; i < output_size; ++i) {
        (*diff)[i] = output_data[i] - input_data[i];
    }
    return true;
}

// Reconstructs a cached result: sets *output to a copy of `input`, then
// adds the stored residual `diff` element-wise (output = input + diff).
//
// Returns false when `output` is null, either the tensor or the residual
// is empty, or their element counts disagree. NOTE: *output is assigned
// the copy of `input` before the data-pointer check, so on a null data
// pointer the copy has already taken place even though false is returned.
inline bool apply_condition_cache_diff(const std::vector<float>& diff,
                                       const sd::Tensor<float>& input,
                                       sd::Tensor<float>* output) {
    if (output == nullptr || diff.empty() || input.empty()) {
        return false;
    }

    const size_t count = static_cast<size_t>(input.numel());
    if (count == 0 || diff.size() != count) {
        return false;
    }

    *output = input;
    float* dst = output->data();
    if (dst == nullptr) {
        return false;
    }

    size_t idx = 0;
    for (const float delta : diff) {
        dst[idx] += delta;
        ++idx;
    }
    return true;
}

} // namespace sd

#endif // __CONDITION_CACHE_UTILS_HPP__
Loading
Loading