1 change: 1 addition & 0 deletions .vitepress/config/apiReferenceSidebar.ts
@@ -53,6 +53,7 @@ const chatWrappersOrder = [
"Llama3ChatWrapper",
"Llama2ChatWrapper",
"MistralChatWrapper",
"Gemma4ChatWrapper",
"GemmaChatWrapper",
"ChatMLChatWrapper",
"FalconChatWrapper",
6 changes: 3 additions & 3 deletions llama/CMakeLists.txt
@@ -120,16 +120,16 @@ list(REMOVE_DUPLICATES GPU_INFO_HEADERS)
list(REMOVE_DUPLICATES GPU_INFO_SOURCES)
list(REMOVE_DUPLICATES GPU_INFO_EXTRA_LIBS)

addVariantSuffix(llama ${NLC_VARIANT})
addVariantSuffix(ggml ${NLC_VARIANT})
addVariantSuffix(llama "${NLC_VARIANT}")
addVariantSuffix(ggml "${NLC_VARIANT}")

file(GLOB SOURCE_FILES "addon/*.cpp" "addon/**/*.cpp" ${GPU_INFO_SOURCES})

add_library(${PROJECT_NAME} SHARED ${SOURCE_FILES} ${CMAKE_JS_SRC} ${GPU_INFO_HEADERS})
set_target_properties(${PROJECT_NAME} PROPERTIES PREFIX "" SUFFIX ".node")
target_link_libraries(${PROJECT_NAME} ${CMAKE_JS_LIB})
target_link_libraries(${PROJECT_NAME} "llama")
target_link_libraries(${PROJECT_NAME} "common")
target_link_libraries(${PROJECT_NAME} "llama-common")

if (DEFINED GPU_INFO_EXTRA_LIBS)
target_link_libraries(${PROJECT_NAME} ${GPU_INFO_EXTRA_LIBS})
114 changes: 79 additions & 35 deletions llama/addon/AddonContext.cpp
@@ -2,6 +2,7 @@
#include <algorithm>
#include <cmath>
#include "common/common.h"
#include "llama-context.h"
#include "llama-vocab.h"
#include "llama.h"

@@ -107,15 +108,15 @@ class AddonContextLoadContextWorker : public Napi::AsyncWorker {
try {
context->ctx = llama_init_from_model(context->model->model, context->context_params);

context->contextLoaded = context->ctx != nullptr && context->ctx != NULL;
context->contextLoaded = context->ctx != nullptr;
} catch (const std::exception& e) {
SetError(e.what());
} catch(...) {
SetError("Unknown error when calling \"llama_init_from_model\"");
}
}
void OnOK() {
if (context->contextLoaded) {
if (context->contextLoaded && !context->model->model_params.no_alloc) {
uint64_t contextMemorySize = llama_state_get_size(context->ctx);
adjustNapiExternalMemoryAdd(Env(), contextMemorySize);
context->loadedContextMemorySize = contextMemorySize;
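
The new `no_alloc` condition skips the external-memory bookkeeping when the model was loaded without allocating buffers, since there is no real native allocation to report to V8. For orientation, a minimal sketch of what the `adjustNapiExternalMemoryAdd`/`adjustNapiExternalMemorySubtract` helpers plausibly wrap — an assumption, as their implementation is outside this diff:

```cpp
#include <napi.h>
#include <cstdint>

// Hypothetical sketch (not part of this PR): the helpers presumably forward
// to node-addon-api's external-memory accounting, so V8's GC knows how much
// native memory is held on behalf of JS objects.
void adjustNapiExternalMemoryAdd(Napi::Env env, uint64_t size) {
    // Positive delta: report `size` bytes of newly allocated native memory.
    Napi::MemoryManagement::AdjustExternalMemory(env, (int64_t)size);
}

void adjustNapiExternalMemorySubtract(Napi::Env env, uint64_t size) {
    // Negative delta: report that `size` bytes were released.
    Napi::MemoryManagement::AdjustExternalMemory(env, -(int64_t)size);
}
```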
@@ -173,8 +174,10 @@ class AddonContextUnloadContextWorker : public Napi::AsyncWorker {
}
}
void OnOK() {
adjustNapiExternalMemorySubtract(Env(), context->loadedContextMemorySize);
context->loadedContextMemorySize = 0;
if (!context->model->model_params.no_alloc) {
adjustNapiExternalMemorySubtract(Env(), context->loadedContextMemorySize);
context->loadedContextMemorySize = 0;
}

adjustNapiExternalMemorySubtract(Env(), context->batchMemorySize);
context->batchMemorySize = 0;
@@ -251,22 +254,8 @@

sampler->rebuildChainIfNeeded();

const auto * logits = llama_get_logits_ith(ctx->ctx, batchLogitIndex);
const int n_vocab = llama_vocab_n_tokens(ctx->model->vocab);

auto & candidates = sampler->tokenCandidates;
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
candidates[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
}

llama_token_data_array cur_p = {
/* .data = */ candidates.data(),
/* .size = */ candidates.size(),
/* .selected = */ -1,
/* .sorted = */ false,
};

llama_sampler_apply(sampler->chain, &cur_p);
llama_token_data_array cur_p;
sampler->sample(ctx->ctx, batchLogitIndex, cur_p, returnProbabilities || returnConfidence);

if (!(cur_p.selected >= 0 && cur_p.selected < (int32_t)cur_p.size)) {
no_output = true;
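
The inline candidate-building loop and `llama_sampler_apply` call were folded into the sampler. Based on the removed lines, a sketch of what the new `sample()` helper likely does; the actual implementation in this PR may differ, notably in how the final boolean argument (`returnProbabilities || returnConfidence`) is used:

```cpp
// Hypothetical reconstruction from the removed inline code above.
void AddonSampler::sample(llama_context * ctx, int32_t batchLogitIndex,
                          llama_token_data_array & cur_p, bool keepProbabilities) {
    const float * logits = llama_get_logits_ith(ctx, batchLogitIndex);
    const auto * vocab = llama_model_get_vocab(llama_get_model(ctx));
    const int n_vocab = llama_vocab_n_tokens(vocab);

    // Fill the reusable candidate buffer with one entry per vocab token.
    auto & candidates = tokenCandidates;
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
        candidates[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
    }

    cur_p = {
        /* .data = */ candidates.data(),
        /* .size = */ candidates.size(),
        /* .selected = */ -1,
        /* .sorted = */ false,
    };

    // The chain picks a token into cur_p.selected; keepProbabilities
    // presumably controls whether the distribution is preserved so the
    // caller can read probabilities/confidence from it (an assumption).
    llama_sampler_apply(chain, &cur_p);
}
```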
@@ -397,13 +386,13 @@ AddonContext::AddonContext(const Napi::CallbackInfo& info) : Napi::ObjectWrap<AddonContext>

context_params = llama_context_default_params();
context_params.n_ctx = 4096;
context_params.n_threads = std::max(cpu_get_num_math(), 1);
context_params.n_threads = std::max(common_cpu_get_num_math(), 1);
context_params.n_threads_batch = context_params.n_threads;
context_params.no_perf = true;
context_params.swa_full = false;

if (info.Length() > 1 && info[1].IsObject()) {
Napi::Object options = info[1].As<Napi::Object>();
const auto options = info[1].As<Napi::Object>();

if (options.Has("contextSize")) {
context_params.n_ctx = options.Get("contextSize").As<Napi::Number>().Uint32Value();
@@ -427,31 +416,41 @@ }
}

if (options.Has("flashAttention")) {
bool flashAttention = options.Get("flashAttention").As<Napi::Boolean>().Value();
context_params.flash_attn_type = flashAttention ? LLAMA_FLASH_ATTN_TYPE_ENABLED : LLAMA_FLASH_ATTN_TYPE_DISABLED;
const auto flashAttention = options.Get("flashAttention");

if (flashAttention.IsString() && flashAttention.As<Napi::String>().Utf8Value() == "auto") {
context_params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
} else {
const bool flashAttentionEnabled = flashAttention.As<Napi::Boolean>().Value();
context_params.flash_attn_type = flashAttentionEnabled
? LLAMA_FLASH_ATTN_TYPE_ENABLED
: LLAMA_FLASH_ATTN_TYPE_DISABLED;
}
}

if (options.Has("threads")) {
const auto n_threads = options.Get("threads").As<Napi::Number>().Int32Value();
const auto resolved_n_threads = n_threads == 0 ? std::max((int32_t)std::thread::hardware_concurrency(), context_params.n_threads) : n_threads;
const auto threads = options.Get("threads").As<Napi::Number>().Int32Value();
const auto resolvedThreads = threads == 0
? std::max((int32_t)std::thread::hardware_concurrency(), context_params.n_threads)
: threads;

context_params.n_threads = resolved_n_threads;
context_params.n_threads_batch = resolved_n_threads;
context_params.n_threads = resolvedThreads;
context_params.n_threads_batch = resolvedThreads;
}

if (options.Has("performanceTracking")) {
context_params.no_perf = !(options.Get("performanceTracking").As<Napi::Boolean>().Value());
}

if (options.Has("kvCacheKeyType") && options.Get("kvCacheKeyType").IsNumber()) {
auto keyType = options.Get("kvCacheKeyType").As<Napi::Number>().Int32Value();
const auto keyType = options.Get("kvCacheKeyType").As<Napi::Number>().Int32Value();
if (keyType >= 0 && keyType < GGML_TYPE_COUNT) {
context_params.type_k = static_cast<ggml_type>(keyType);
}
}

if (options.Has("kvCacheValueType") && options.Get("kvCacheValueType").IsNumber()) {
auto valueType = options.Get("kvCacheValueType").As<Napi::Number>().Int32Value();
const auto valueType = options.Get("kvCacheValueType").As<Napi::Number>().Int32Value();
if (valueType >= 0 && valueType < GGML_TYPE_COUNT) {
context_params.type_v = static_cast<ggml_type>(valueType);
}
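
`flashAttention` now accepts the string `"auto"` alongside a boolean, mapping onto llama.cpp's three-state `llama_flash_attn_type`. The same mapping condensed into a standalone helper — hypothetical, since the PR keeps it inline as shown above:

```cpp
#include <napi.h>
#include "llama.h"

// Hypothetical helper (not in the PR): "auto" -> AUTO, true -> ENABLED,
// any other value -> DISABLED.
llama_flash_attn_type parseFlashAttention(const Napi::Value & flashAttention) {
    if (flashAttention.IsString() && flashAttention.As<Napi::String>().Utf8Value() == "auto")
        return LLAMA_FLASH_ATTN_TYPE_AUTO;

    return flashAttention.As<Napi::Boolean>().Value()
        ? LLAMA_FLASH_ATTN_TYPE_ENABLED
        : LLAMA_FLASH_ATTN_TYPE_DISABLED;
}
```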
@@ -476,8 +475,10 @@ void AddonContext::dispose() {
contextLoaded = false;
llama_free(ctx);

adjustNapiExternalMemorySubtract(Env(), loadedContextMemorySize);
loadedContextMemorySize = 0;
if (!model->model_params.no_alloc) {
adjustNapiExternalMemorySubtract(Env(), loadedContextMemorySize);
loadedContextMemorySize = 0;
}
}

model->Unref();
@@ -728,6 +729,49 @@ Napi::Value AddonContext::GetStateSize(const Napi::CallbackInfo& info) {
return Napi::Number::From(info.Env(), llama_state_get_size(ctx));
}

Napi::Value AddonContext::GetMemoryBreakdown(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

if (!contextLoaded || ctx == nullptr) {
Napi::Error::New(info.Env(), "Context is not loaded").ThrowAsJavaScriptException();
return info.Env().Undefined();
}

std::size_t cpuRam = 0;
std::size_t gpuVram = 0;

for (const auto& [bufferType, memoryBreakdown] : ctx->memory_breakdown()) {
const std::size_t size = memoryBreakdown.context + memoryBreakdown.compute;
if (size == 0) {
continue;
}

if (ggml_backend_buft_is_host(bufferType)) {
cpuRam += size;
} else {
ggml_backend_dev_t device = ggml_backend_buft_get_device(bufferType);
if (device != nullptr) {
auto deviceType = ggml_backend_dev_type(device);
if (deviceType == GGML_BACKEND_DEVICE_TYPE_GPU || deviceType == GGML_BACKEND_DEVICE_TYPE_IGPU) {
gpuVram += size;
} else {
cpuRam += size;
}
} else {
cpuRam += size;
}
}
}

Napi::Object result = Napi::Object::New(info.Env());
result.Set("cpuRam", Napi::Number::New(info.Env(), cpuRam));
result.Set("gpuVram", Napi::Number::New(info.Env(), gpuVram));
return result;
}

Napi::Value AddonContext::GetThreads(const Napi::CallbackInfo& info) {
if (disposed) {
Napi::Error::New(info.Env(), "Context is disposed").ThrowAsJavaScriptException();
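
The newly included `llama-context.h` is what exposes `ctx->memory_breakdown()` above. `GetMemoryBreakdown`'s classification rule condenses to: host buffers count as CPU RAM, buffers on GPU or integrated-GPU devices count as VRAM, and anything else falls back to CPU RAM. The same logic as a standalone predicate — a hypothetical restatement, not part of the PR:

```cpp
#include "ggml-backend.h"

// Returns true when a buffer type's memory should be attributed to GPU VRAM
// rather than CPU RAM (hypothetical restatement of the logic above).
static bool countsAsGpuVram(ggml_backend_buffer_type_t bufferType) {
    if (ggml_backend_buft_is_host(bufferType))
        return false; // host-visible buffers live in system RAM

    ggml_backend_dev_t device = ggml_backend_buft_get_device(bufferType);
    if (device == nullptr)
        return false; // nothing to attribute the memory to; treat as RAM

    const auto deviceType = ggml_backend_dev_type(device);
    return deviceType == GGML_BACKEND_DEVICE_TYPE_GPU
        || deviceType == GGML_BACKEND_DEVICE_TYPE_IGPU;
}
```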
@@ -744,9 +788,8 @@ Napi::Value AddonContext::SetThreads(const Napi::CallbackInfo& info) {
}

const auto threads = info[0].As<Napi::Number>().Int32Value();
const auto resolvedThreads = threads == 0
? std::max((int32_t)std::thread::hardware_concurrency(), std::max(cpu_get_num_math(), 1))
: threads;
const auto resolvedThreads =
threads == 0 ? std::max((int32_t)std::thread::hardware_concurrency(), std::max(common_cpu_get_num_math(), 1)) : threads;

if (llama_n_threads(ctx) != resolvedThreads) {
llama_set_n_threads(ctx, resolvedThreads, resolvedThreads);
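
Both the constructor and `SetThreads` share the same convention: `threads == 0` means "auto", resolving to all hardware threads but never fewer than llama.cpp's math-thread default (`common_cpu_get_num_math()`). Isolated into a helper for clarity — hypothetical, the PR keeps it inline:

```cpp
#include <algorithm>
#include <cstdint>
#include <thread>

// Hypothetical helper (not in the PR): 0 means "use every hardware thread,
// but at least the default computed for math-heavy work".
int32_t resolveThreadCount(int32_t requested, int32_t defaultMathThreads) {
    if (requested != 0)
        return requested;

    return std::max((int32_t)std::thread::hardware_concurrency(), defaultMathThreads);
}
```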
@@ -1062,6 +1105,7 @@ void AddonContext::init(Napi::Object exports) {
InstanceMethod("sampleToken", &AddonContext::SampleToken),
InstanceMethod("getEmbedding", &AddonContext::GetEmbedding),
InstanceMethod("getStateSize", &AddonContext::GetStateSize),
InstanceMethod("getMemoryBreakdown", &AddonContext::GetMemoryBreakdown),
InstanceMethod("getThreads", &AddonContext::GetThreads),
InstanceMethod("setThreads", &AddonContext::SetThreads),
InstanceMethod("printTimings", &AddonContext::PrintTimings),
1 change: 1 addition & 0 deletions llama/addon/AddonContext.h
@@ -46,6 +46,7 @@ class AddonContext : public Napi::ObjectWrap<AddonContext> {

Napi::Value GetEmbedding(const Napi::CallbackInfo& info);
Napi::Value GetStateSize(const Napi::CallbackInfo& info);
Napi::Value GetMemoryBreakdown(const Napi::CallbackInfo& info);
Napi::Value GetThreads(const Napi::CallbackInfo& info);
Napi::Value SetThreads(const Napi::CallbackInfo& info);
