1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

+- feat: Expose `attention_type` in `Llama.__init__` for non-causal embedding models by @jamesbiederbeck in #2143
- fix(ci): Build Docker images from the checked-out source and sanitize branch tags by @abetlen in #2156
- fix(ci): Fix the CUDA wheel workflow and keep release tags aligned with the built toolkit by @abetlen in #2155
- fix(ci): Speed up release wheel builds by moving arm64 off QEMU and parallelizing riscv64 by @abetlen in #2154
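For context on the new flag, here is a minimal usage sketch (not from the PR itself): it assumes a local GGUF embedding model at a placeholder path, and uses `LLAMA_ATTENTION_TYPE_NON_CAUSAL` from the same `llama_attention_type` enum that supplies the `LLAMA_ATTENTION_TYPE_UNSPECIFIED` default below.

```python
import llama_cpp

# Placeholder path: any BERT-style / non-causal embedding model in GGUF format.
llm = llama_cpp.Llama(
    model_path="./embedding-model.gguf",
    embedding=True,
    pooling_type=llama_cpp.LLAMA_POOLING_TYPE_MEAN,
    # New in this PR: request bidirectional (non-causal) attention,
    # which non-causal embedding models expect.
    attention_type=llama_cpp.LLAMA_ATTENTION_TYPE_NON_CAUSAL,
)

vector = llm.embed("Example sentence to embed.")
```

Passing `LLAMA_ATTENTION_TYPE_UNSPECIFIED` (the default) keeps llama.cpp's existing behavior of inferring the attention type from the model, so the change is backward compatible.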
4 changes: 4 additions & 0 deletions llama_cpp/llama.py
@@ -81,6 +81,7 @@ def __init__(
int
] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED,
+attention_type: int = llama_cpp.LLAMA_ATTENTION_TYPE_UNSPECIFIED,
rope_freq_base: float = 0.0,
rope_freq_scale: float = 0.0,
yarn_ext_factor: float = -1.0,
@@ -163,6 +164,7 @@ def __init__(
n_threads_batch: Number of threads to use for batch processing
rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054
pooling_type: Pooling type, from `enum llama_pooling_type`.
+attention_type: Attention type, from `enum llama_attention_type`.
rope_freq_base: RoPE base frequency, 0 = from model
rope_freq_scale: RoPE frequency scaling factor, 0 = from model
yarn_ext_factor: YaRN extrapolation mix factor, negative = from model
@@ -319,6 +321,7 @@ def __init__(
else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED
)
self.context_params.pooling_type = pooling_type
+self.context_params.attention_type = attention_type
self.context_params.rope_freq_base = (
rope_freq_base if rope_freq_base != 0.0 else 0
)
@@ -2100,6 +2103,7 @@ def __getstate__(self):
n_threads_batch=self.context_params.n_threads_batch,
rope_scaling_type=self.context_params.rope_scaling_type,
pooling_type=self.context_params.pooling_type,
+attention_type=self.context_params.attention_type,
rope_freq_base=self.context_params.rope_freq_base,
rope_freq_scale=self.context_params.rope_freq_scale,
yarn_ext_factor=self.context_params.yarn_ext_factor,
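The `__getstate__` hunk above also means the new field survives pickling. A small round-trip sketch, under the same placeholder-path assumption as the earlier example:

```python
import pickle

import llama_cpp

llm = llama_cpp.Llama(
    model_path="./embedding-model.gguf",  # placeholder path
    embedding=True,
    attention_type=llama_cpp.LLAMA_ATTENTION_TYPE_NON_CAUSAL,
)

# __getstate__ now records attention_type, so unpickling rebuilds the
# context with the same attention setting (the model itself is reloaded
# from model_path when the state is restored).
clone = pickle.loads(pickle.dumps(llm))
assert (
    clone.context_params.attention_type
    == llama_cpp.LLAMA_ATTENTION_TYPE_NON_CAUSAL
)
```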