@@ -516,6 +516,14 @@ def _warn_deprecated(symbol: str, hint: str) -> None:
516516LLAMA_SPLIT_MODE_TENSOR = 3
517517
518518
519+ # enum llama_context_type {
520+ # LLAMA_CONTEXT_TYPE_DEFAULT = 0,
521+ # LLAMA_CONTEXT_TYPE_MTP = 1,
522+ # };
523+ LLAMA_CONTEXT_TYPE_DEFAULT = 0
524+ LLAMA_CONTEXT_TYPE_MTP = 1
525+
526+
519527# typedef struct llama_token_data {
520528# llama_token id; // token id
521529# float logit; // log-odds of the token
@@ -894,9 +902,11 @@ class llama_sampler_seq_config(ctypes.Structure):
894902# uint32_t n_batch; // logical maximum batch size that can be submitted to llama_decode
895903# uint32_t n_ubatch; // physical maximum batch size
896904# uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
905+ # uint32_t n_rs_seq; // number of recurrent-state snapshots per seq for rollback (0 = no rollback) [EXPERIMENTAL]
897906# int32_t n_threads; // number of threads to use for generation
898907# int32_t n_threads_batch; // number of threads to use for batch processing
899908
909+ # enum llama_context_type ctx_type; // set the context type (e.g. MTP)
900910# enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
901911# enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id
902912# enum llama_attention_type attention_type; // attention type to use for embeddings
@@ -947,8 +957,10 @@ class llama_context_params(ctypes.Structure):
947957 n_batch (int): logical maximum batch size that can be submitted to llama_decode
948958 n_ubatch (int): physical maximum batch size
949959 n_seq_max (int): max number of sequences (i.e. distinct states for recurrent models)
960+ n_rs_seq (int): number of recurrent-state snapshots per sequence for rollback
950961 n_threads (int): number of threads to use for generation
951962 n_threads_batch (int): number of threads to use for batch processing
963+ ctx_type (int): context type, from `enum llama_context_type`
952964 rope_scaling_type (int): RoPE scaling type, from `enum llama_rope_scaling_type`
953965 pooling_type (int): whether to pool (sum) embedding results by sequence id (ignored if no pooling layer)
954966 attention_type (int): attention type to use for embeddings
@@ -982,8 +994,10 @@ class llama_context_params(ctypes.Structure):
982994 n_batch : int
983995 n_ubatch : int
984996 n_seq_max : int
997+ n_rs_seq : int
985998 n_threads : int
986999 n_threads_batch : int
1000+ ctx_type : int
9871001 rope_scaling_type : int
9881002 pooling_type : int
9891003 attention_type : int
@@ -1016,8 +1030,10 @@ class llama_context_params(ctypes.Structure):
10161030 ("n_batch" , ctypes .c_uint32 ),
10171031 ("n_ubatch" , ctypes .c_uint32 ),
10181032 ("n_seq_max" , ctypes .c_uint32 ),
1033+ ("n_rs_seq" , ctypes .c_uint32 ),
10191034 ("n_threads" , ctypes .c_int32 ),
10201035 ("n_threads_batch" , ctypes .c_int32 ),
1036+ ("ctx_type" , ctypes .c_int ),
10211037 ("rope_scaling_type" , ctypes .c_int ),
10221038 ("pooling_type" , ctypes .c_int ),
10231039 ("attention_type" , ctypes .c_int ),
@@ -1591,6 +1607,11 @@ def llama_n_ubatch(ctx: llama_context_p, /) -> int: ...
15911607def llama_n_seq_max (ctx : llama_context_p , / ) -> int : ...
15921608
15931609
1610+ # LLAMA_API uint32_t llama_n_rs_seq (const struct llama_context * ctx);
1611+ @ctypes_function ("llama_n_rs_seq" , [llama_context_p_ctypes ], ctypes .c_uint32 )
1612+ def llama_n_rs_seq (ctx : llama_context_p , / ) -> int : ...
1613+
1614+
15941615# DEPRECATED(LLAMA_API int32_t llama_n_ctx_train(const struct llama_model * model), "use llama_model_n_ctx_train instead");
15951616@ctypes_function ("llama_n_ctx_train" , [llama_model_p_ctypes ], ctypes .c_int32 )
15961617def llama_n_ctx_train (model : llama_model_p , / ) -> int : ...
0 commit comments