diff --git a/CHANGELOG.md b/CHANGELOG.md index b47613109..de4f070ff 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] +- feat: Expose `attention_type` in `Llama.__init__` for non-causal embedding models by @jamesbiederbeck in #2143 - fix(ci): Build Docker images from the checked-out source and sanitize branch tags by @abetlen in #2156 - fix(ci): Fix the CUDA wheel workflow and keep release tags aligned with the built toolkit by @abetlen in #2155 - fix(ci): Speed up release wheel builds by moving arm64 off QEMU and parallelizing riscv64 by @abetlen in #2154 diff --git a/llama_cpp/llama.py b/llama_cpp/llama.py index 88bc2e5bb..ad484c4d5 100644 --- a/llama_cpp/llama.py +++ b/llama_cpp/llama.py @@ -81,6 +81,7 @@ def __init__( int ] = llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, pooling_type: int = llama_cpp.LLAMA_POOLING_TYPE_UNSPECIFIED, + attention_type: int = llama_cpp.LLAMA_ATTENTION_TYPE_UNSPECIFIED, rope_freq_base: float = 0.0, rope_freq_scale: float = 0.0, yarn_ext_factor: float = -1.0, @@ -163,6 +164,7 @@ def __init__( n_threads_batch: Number of threads to use for batch processing rope_scaling_type: RoPE scaling type, from `enum llama_rope_scaling_type`. ref: https://github.com/ggerganov/llama.cpp/pull/2054 pooling_type: Pooling type, from `enum llama_pooling_type`. + attention_type: Attention type, from `enum llama_attention_type`. rope_freq_base: RoPE base frequency, 0 = from model rope_freq_scale: RoPE frequency scaling factor, 0 = from model yarn_ext_factor: YaRN extrapolation mix factor, negative = from model @@ -319,6 +321,7 @@ def __init__( else llama_cpp.LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED ) self.context_params.pooling_type = pooling_type + self.context_params.attention_type = attention_type self.context_params.rope_freq_base = ( rope_freq_base if rope_freq_base != 0.0 else 0 ) @@ -2100,6 +2103,7 @@ def __getstate__(self): n_threads_batch=self.context_params.n_threads_batch, rope_scaling_type=self.context_params.rope_scaling_type, pooling_type=self.context_params.pooling_type, + attention_type=self.context_params.attention_type, rope_freq_base=self.context_params.rope_freq_base, rope_freq_scale=self.context_params.rope_freq_scale, yarn_ext_factor=self.context_params.yarn_ext_factor,