Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions common/chat-auto-parser-generator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -369,9 +369,7 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte
arguments.name_suffix) +
arguments.value_prefix +
(schema_info.resolves_to_string(param_schema) ?
p.tool_arg_string_value(p.schema(until_suffix,
"tool-" + name + "-arg-" + param_name + "-schema",
param_schema, true)) :
p.tool_arg_string_value(until_suffix) :
p.tool_arg_json_value(p.schema(
p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) +
p.space()) +
Expand Down
2 changes: 0 additions & 2 deletions common/reasoning-budget.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,6 @@ static void common_reasoning_budget_apply(struct llama_sampler * smpl, llama_tok
for (size_t i = 0; i < cur_p->size; i++) {
if (cur_p->data[i].id != forced) {
cur_p->data[i].logit = -INFINITY;
} else {
cur_p->data[i].logit = +INFINITY; // force the token
}
}
}
Expand Down
30 changes: 27 additions & 3 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,15 +710,15 @@ def _generate_nvfp4_tensors(self):
self._repack_nvfp4(name, weight, scale, scale2, input_scale)

# Flush any remaining experts (fallback if n_experts was unknown)
for bid, proj_type in expert_blocks.keys():
for bid, proj_type in list(expert_blocks.keys()):
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_input_scales, expert_shapes, bid, proj_type)

# Remove consumed tensors so get_tensors/modify_tensors won't see them
for name in consumed:
self.model_tensors.pop(name, None)

# Remove any remaining unused auxiliary tensors
for name in self.model_tensors.keys():
for name in list(self.model_tensors.keys()):
if name.endswith((".k_scale", ".v_scale")):
del self.model_tensors[name]

Expand Down Expand Up @@ -7988,13 +7988,37 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_freqs_full = torch.tensor(values, dtype=torch.float32)
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), rope_freqs_full)

def _generate_nvfp4_tensors(self) -> None:
    """Fold the per-expert router scale into the experts' NVFP4 scales.

    Gemma-4 stores a per-layer ``router.per_expert_scale`` ([n_expert]) that
    scales each expert's contribution. It's mathematically equivalent to a
    per-expert scalar on the down_proj output, which is exactly where
    ffn_down_exps_s is applied at inference. Fold it into each expert's NVFP4
    weight_scale_2 so the existing NVFP4 path produces the right scales.
    """
    # Missing hparam -> 0 experts, which makes w2_targets empty below and
    # effectively disables the fold-in while still running the super() pass.
    n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
    for name in [n for n in self.model_tensors if n.endswith(".router.per_expert_scale")]:
        # Extract the layer id so we can address this layer's expert tensors.
        bid_match = re.search(r"\.layers\.(\d+)\.", name)
        if bid_match is None:
            continue
        bid = bid_match.group(1)
        # Tensor-name prefix up to and including ".layers.<bid>.".
        prefix = name[: name.index(f".layers.{bid}.") + len(f".layers.{bid}.")]
        w2_targets = [f"{prefix}experts.{e}.down_proj.weight_scale_2" for e in range(n_experts)]
        present = [w2 in self.model_tensors for w2 in w2_targets]
        if not any(present):
            # Layer isn't NVFP4-quantized; leave the router scale tensor as-is.
            continue
        # Either every expert has a weight_scale_2 or none do; a mix is unsupported.
        assert all(present), f"layer {bid}: partial NVFP4 quantization across experts"
        r = self.model_tensors.pop(name)
        for e, w2 in enumerate(w2_targets):
            s = self.model_tensors[w2]
            # Lazy thunk: default args capture s/r/e now; product evaluated on demand.
            self.model_tensors[w2] = lambda s=s, r=r, i=e: s() * r()[i]
    super()._generate_nvfp4_tensors()

@classmethod
def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None:
    """Normalize tensor names before delegating to the base-class filter.

    Scalar parameters stored without a ``.weight`` suffix (per_dim_scale,
    layer_scalar) and expert tensors lacking a recognized quantization
    suffix get ``.weight`` appended so downstream mapping sees a uniform
    naming scheme.
    """
    name, gen = item

    if name.endswith(("per_dim_scale", "layer_scalar")):
        name += ".weight"

    expert_suffixes = (".weight", ".weight_scale", ".weight_scale_2", ".input_scale")
    if ".experts." in name and not name.endswith(expert_suffixes):
        name += ".weight"

    return super().filter_tensors((name, gen))
Expand Down
4 changes: 2 additions & 2 deletions examples/sycl/start-svr.sh
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,14 @@ if [ $GGML_SYCL_DEVICE -ne -1 ]; then
echo "Use $GGML_SYCL_DEVICE as main GPU"
# use single GPU only
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
export ONEAPI_DEVICE_SELECTOR="level_zero:${$GGML_SYCL_DEVICE}"
export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
else
echo "Use all Intel GPUs, including iGPU & dGPU"
GPUS_SETTING="-sm ${SPLIT_MODE}"
fi

echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 200 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap --host 0.0.0.0 --port 8000"
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap --host 0.0.0.0 --port 8000


2 changes: 1 addition & 1 deletion ggml/include/ggml-backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ extern "C" {
// device type
enum ggml_backend_dev_type type;
// device id
// for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
// for PCI devices, this should be the lower-case PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:c1:00.0")
// if the id is unknown, this should be NULL
const char * device_id;
// device capabilities
Expand Down
2 changes: 1 addition & 1 deletion ggml/src/ggml-backend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -965,7 +965,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
}
if (sched->debug > 1) {
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_desc(node), node->name,
fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
for (int j = 0; j < GGML_MAX_SRC; j++) {
Expand Down
33 changes: 33 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/roll.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/snake.cuh"
#include "ggml-cuda/softcap.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
Expand Down Expand Up @@ -3757,6 +3758,35 @@ static int ggml_cuda_try_fuse(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph
return 2;
}

// Snake activation: y = x + sin(a*x)^2 * inv_b
// Naive 5-op decomposition emitted by frontends: mul -> sin -> sqr -> mul -> add
if (ggml_can_fuse_subgraph(cgraph, i,
{ GGML_OP_MUL, GGML_OP_SIN, GGML_OP_SQR, GGML_OP_MUL, GGML_OP_ADD },
{ i + 4 })) {
const ggml_tensor * mul0 = cgraph->nodes[i];
const ggml_tensor * sqr = cgraph->nodes[i + 2];
const ggml_tensor * mul1 = cgraph->nodes[i + 3];
ggml_tensor * add = cgraph->nodes[i + 4];

// x carries the full activation shape, a is the broadcast operand
const ggml_tensor * x = ggml_are_same_shape(mul0, mul0->src[0]) ? mul0->src[0] : mul0->src[1];
const ggml_tensor * a = (x == mul0->src[0]) ? mul0->src[1] : mul0->src[0];

// mul1 reads sqr and inv_b in either operand order
const ggml_tensor * inv_b = (mul1->src[0] == sqr) ? mul1->src[1] : mul1->src[0];

// closure check: the trailing add must read the same x as the leading mul
const ggml_tensor * x_in_add = (add->src[0] == mul1) ? add->src[1] : add->src[0];

const bool type_ok = (x->type == GGML_TYPE_F32 || x->type == GGML_TYPE_F16 || x->type == GGML_TYPE_BF16);
const bool shape_ok = ggml_are_same_shape(a, inv_b) && a->ne[0] == 1 && a->ne[1] == x->ne[1];

if (type_ok && shape_ok && x_in_add == x && add->type == x->type) {
ggml_cuda_op_snake_fused(*cuda_ctx, x, a, inv_b, add);
return 4;
}
}

// multi-(add or mul)
if (node->op == GGML_OP_ADD || node->op == GGML_OP_MUL) {
int n_fuse = 0;
Expand Down Expand Up @@ -5434,6 +5464,9 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
char pci_bus_id[32] = {};
CUDA_CHECK(cudaDeviceGetPCIBusId(pci_bus_id, sizeof(pci_bus_id), i));
dev_ctx->pci_bus_id = pci_bus_id;
for (char & c : dev_ctx->pci_bus_id) {
c = std::tolower(c);
}
dev_ctx->op_offload_min_batch_size = min_batch_size;

ggml_backend_dev_t dev = new ggml_backend_device {
Expand Down
72 changes: 72 additions & 0 deletions ggml/src/ggml-cuda/snake.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#include "snake.cuh"
#include "convert.cuh"

// Fused Snake activation: y = x + sin^2(a * x) * inv_b
// x: [T, C] (T contiguous), a: [1, C], inv_b: [1, C]
// Supports F32, F16, BF16 data with F32 compute.

// One thread per element of the [T, C] activation tensor; elements are
// enumerated T-major. The channel index is recovered with a precomputed
// fastdiv by the T length, so the per-channel a/inv_b lookups avoid a
// hardware integer division. Math is done in f32 regardless of storage type.
template <typename T>
static __global__ void snake_kernel(
        const T * __restrict__ x,
        const float * __restrict__ a,
        const float * __restrict__ inv_b,
        T * __restrict__ dst,
        const int total,
        const uint3 T_len_fastdiv) {
    const int gid = blockIdx.x * blockDim.x + threadIdx.x;
    if (gid >= total) {
        return; // grid tail: grids rarely divide `total` evenly
    }

    // channel = gid / T_len, via magic-number division
    const int ch = (int) fastdiv((uint32_t) gid, T_len_fastdiv);

    // y = x + sin(a*x)^2 * inv_b
    const float xv = ggml_cuda_cast<float>(x[gid]);
    const float sv = sinf(a[ch] * xv);
    dst[gid] = ggml_cuda_cast<T>(xv + sv * sv * inv_b[ch]);
}

// Internal launcher with explicit x/a/inv_b/dst tensors.
// Shared by the public op (reads dst->src) and the fusion path (explicit args).
//
// Layout contract (see kernel comment): x is [T, C] with T contiguous and
// a/inv_b are F32 with one value per channel C. The kernel flattens only
// ne[0]*ne[1], so higher dims must be singleton; indexing is 32-bit.
static void launch_snake(ggml_backend_cuda_context & ctx,
                         const ggml_tensor * x,
                         const ggml_tensor * a,
                         const ggml_tensor * inv_b,
                         ggml_tensor * dst) {
    GGML_ASSERT(x->ne[2] == 1 && x->ne[3] == 1); // kernel only covers the first two dims
    GGML_ASSERT(dst->type == x->type);

    const float * a_d     = (const float *)a->data;
    const float * inv_b_d = (const float *)inv_b->data;

    // compute the element count in 64 bits first: (int)T * (int)C can silently
    // overflow for large activations, and fastdiv/kernel indices are 32-bit
    const int64_t total64 = x->ne[0] * x->ne[1];
    if (total64 == 0) {
        return; // empty tensor: a zero-sized grid launch would be a CUDA error
    }
    GGML_ASSERT(total64 <= (int64_t) 0x7fffffff);

    const int   total         = (int) total64;
    const uint3 T_len_fastdiv = init_fastdiv_values((uint64_t) x->ne[0]);

    const int block_size = 256;
    const int grid_size  = (total + block_size - 1) / block_size; // ceil-div

    cudaStream_t stream = ctx.stream();

    switch (x->type) {
        case GGML_TYPE_F32: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const float *)x->data, a_d, inv_b_d, (float *)dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_F16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const half *)x->data, a_d, inv_b_d, (half *)dst->data, total, T_len_fastdiv);
        } break;
        case GGML_TYPE_BF16: {
            snake_kernel<<<grid_size, block_size, 0, stream>>>(
                (const nv_bfloat16 *)x->data, a_d, inv_b_d, (nv_bfloat16 *)dst->data, total, T_len_fastdiv);
        } break;
        default:
            GGML_ABORT("snake: unsupported type");
    }
}

// Fusion entry: caller supplies x/a/inv_b explicitly from the matched
// mul -> sin -> sqr -> mul -> add pattern. The dst is the trailing add output.
// Thin wrapper: forwards all tensors unchanged to launch_snake, which
// dispatches the fused kernel on ctx's stream.
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * x,
                              const ggml_tensor * a,
                              const ggml_tensor * inv_b,
                              ggml_tensor * dst) {
    launch_snake(ctx, x, a, inv_b, dst);
}
8 changes: 8 additions & 0 deletions ggml/src/ggml-cuda/snake.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
#include "common.cuh"

// Fused Snake activation: dst = x + sin(a*x)^2 * inv_b.
// Fusion entry point. Caller supplies x/a/inv_b explicitly.
// x holds the activations; a and inv_b hold the per-channel parameters
// (see snake.cu for the exact layout and supported types).
void ggml_cuda_op_snake_fused(ggml_backend_cuda_context & ctx,
                              const ggml_tensor * x,
                              const ggml_tensor * a,
                              const ggml_tensor * inv_b,
                              ggml_tensor * dst);
Loading
Loading