diff --git a/lightllm/common/basemodel/attention/triton/fp.py b/lightllm/common/basemodel/attention/triton/fp.py index d29f15ec3b..1902960769 100644 --- a/lightllm/common/basemodel/attention/triton/fp.py +++ b/lightllm/common/basemodel/attention/triton/fp.py @@ -25,12 +25,12 @@ def prefill_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ) -> torch.Tensor: - assert att_control.use_sliding_window is False and att_control.use_att_sink is False if att_control.use_alibi: + assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None return self._alibi_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: - return self._nomarl_prefill_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._nomarl_prefill_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) def _alibi_prefill_att( self, @@ -59,9 +59,21 @@ def _alibi_prefill_att( ) return out - def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, alloc_func=torch.empty): + def _nomarl_prefill_att( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + att_control: AttControl = AttControl(), + alloc_func=torch.empty, + ): from ...triton_kernel.att.prefill_att.context_flashattention_nopad import context_attention_fwd + if att_control.use_sliding_window: + sliding_window = int(att_control.sliding_window[0]) + else: + sliding_window = -1 + out = alloc_func(q.shape, q.dtype) context_attention_fwd( q, @@ -74,6 +86,7 @@ def _nomarl_prefill_att(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, self.infer_state.b_ready_cache_len, self.infer_state.max_q_seq_len, self.infer_state.req_manager.req_to_token_indexs, + sliding_window=sliding_window, ) return out @@ -94,8 +107,8 @@ def decode_att( att_control: AttControl = AttControl(), alloc_func=torch.empty, ): - assert att_control.use_sliding_window is False and att_control.use_att_sink is False if att_control.use_alibi: + assert att_control.use_sliding_window is False, "alibi + sliding_window not supported" assert att_control.tp_alibi is not None return self._alibi_decode_att(q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func) else: @@ -104,7 +117,9 @@ def decode_att( if q_head_num == k_head_num: return self._normal_decode_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) elif q_head_num > k_head_num: - return self._normal_decode_gqa_flash_decoding_att(q=q, k=k, v=v, alloc_func=alloc_func) + return self._normal_decode_gqa_flash_decoding_att( + q=q, k=k, v=v, att_control=att_control, alloc_func=alloc_func + ) else: raise NotImplementedError("error") @@ -163,12 +178,18 @@ def _normal_decode_gqa_flash_decoding_att( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, + att_control: AttControl = AttControl(), alloc_func=torch.empty, ): from ...triton_kernel.att.decode_att.gqa.flash_decoding.gqa_flash_decoding import ( gqa_token_decode_attention_flash_decoding, ) + if att_control.use_sliding_window: + sliding_window = int(att_control.sliding_window[0]) + else: + sliding_window = -1 + out = alloc_func(q.shape, q.dtype) gqa_token_decode_attention_flash_decoding( @@ -178,6 +199,7 @@ def _normal_decode_gqa_flash_decoding_att( cache_v=v, out=out, alloc_tensor_func=alloc_func, + sliding_window=sliding_window, ) return out diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py index 276b5856f9..f0cc129c09 
100755 --- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py @@ -103,11 +103,13 @@ def _context_attention_wrapper_run( ) -> torch.Tensor: if torch.cuda.is_current_stream_capturing(): q = q.contiguous() - cache_kv = cache_kv.contiguous() - _q, _cache_kv = ( - tensor_to_no_ref_tensor(q), - tensor_to_no_ref_tensor(cache_kv), - ) + # cache_kv is None for layers that own no K/V slot (e.g. gemma4 + # KV-shared layers, which read K/V from a prior layer's cache and + # ignore this arg in _context_attention_kernel). Skip the + # graph-input plumbing for it instead of crashing on None. + cache_kv = cache_kv.contiguous() if cache_kv is not None else None + _q = tensor_to_no_ref_tensor(q) + _cache_kv = tensor_to_no_ref_tensor(cache_kv) if cache_kv is not None else None pre_capture_graph = infer_state.prefill_cuda_graph_get_current_capture_graph() pre_capture_graph.__exit__(None, None, None) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py index 8f54e14a72..dd99616b6b 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/fused_moe_weight.py @@ -33,12 +33,14 @@ def __init__( num_fused_shared_experts: int = 0, layer_num: int = 0, network_config: Dict[str, Any] = None, + per_expert_scale_name: str = "", ) -> None: super().__init__(data_type=data_type) self.w1_weight_name = gate_proj_name self.w2_weight_name = down_proj_name self.w3_weight_name = up_proj_name self.e_score_correction_bias_name = e_score_correction_bias_name + self.per_expert_scale_name = per_expert_scale_name self.weight_prefix = weight_prefix self.layer_num_ = layer_num self.global_rank_ = get_global_rank() @@ -130,6 +132,8 @@ def experts( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + use_gelu: bool = False, ) -> torch.Tensor: """Backward compatible method that routes to platform-specific implementation.""" return self.fuse_moe_impl( @@ -145,6 +149,8 @@ def experts( topk_group=topk_group, num_expert_group=num_expert_group, is_prefill=is_prefill, + per_expert_scale=per_expert_scale, + use_gelu=use_gelu, ) def low_latency_dispatch( @@ -263,16 +269,22 @@ def load_hf_weights(self, weights): # Load bias if self.e_score_correction_bias_name in weights: self.e_score_correction_bias.copy_(weights[self.e_score_correction_bias_name]) + self._load_per_expert_scale(weights) self._load_weight(self.expert_idx_to_local_idx, weights) if self.redundancy_expert_num > 0: self._load_weight(self.redundancy_expert_idx_to_local_idx, weights) def verify_load(self): - return all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + weight_load_ok = all(all(_weight_pack.load_ok) for _weight_pack in self.w1_list + self.w2_list + self.w3_list) + per_expert_scale_load_ok = ( + True if self.per_expert_scale is None else getattr(self.per_expert_scale, "load_ok", False) + ) + return weight_load_ok and per_expert_scale_load_ok def _create_weight(self): intermediate_size = self.split_inter_size self.e_score_correction_bias = None + self.per_expert_scale = None # Create e_score_correction_bias if self.e_score_correction_bias_name: self.e_score_correction_bias = torch.empty( @@ -280,6 +292,13 @@ 
def _create_weight(self): dtype=self.data_type_, device=f"cuda:{self.device_id_}", ) + if self.per_expert_scale_name: + self.per_expert_scale = torch.empty( + (self.n_routed_experts,), + dtype=torch.float32, + device=f"cuda:{self.device_id_}", + ) + self.per_expert_scale.load_ok = False self.w13, w13_param_list = self.quant_method.create_moe_weight( out_dims=[intermediate_size, intermediate_size], @@ -299,6 +318,11 @@ def _create_weight(self): self.w3_list: List[WeightPack] = self._get_expert_weight_list(w13_param_list[1]) self.w2_list: List[WeightPack] = self._get_expert_weight_list(self.w2) + def _load_per_expert_scale(self, weights: Dict[str, torch.Tensor]): + if self.per_expert_scale_name and self.per_expert_scale_name in weights: + self.per_expert_scale.copy_(weights[self.per_expert_scale_name].to(self.per_expert_scale.dtype)) + self.per_expert_scale.load_ok = True + def _get_expert_weight_list(self, weight_pack: WeightPack): weight_list = [] for idx in range(self.local_n_routed_experts): @@ -307,7 +331,6 @@ def _get_expert_weight_list(self, weight_pack: WeightPack): return weight_list def _load_weight(self, expert_idx_to_local_idx: Dict[int, int], weights: Dict[str, torch.Tensor]): - # Load each expert with TP slicing for expert_idx, local_expert_idx in expert_idx_to_local_idx.items(): with self.lock: diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py new file mode 100644 index 0000000000..1df39993ec --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/gemma4_packed_fused_moe_weight.py @@ -0,0 +1,34 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.fused_moe_weight import FusedMoeWeight + + +class Gemma4PackedFusedMoeWeight(FusedMoeWeight): + def load_hf_weights(self, weights): + gate_up_name = f"{self.weight_prefix}.gate_up_proj" + down_name = f"{self.weight_prefix}.down_proj" + if gate_up_name not in weights and down_name not in weights and self.per_expert_scale_name not in weights: + return super().load_hf_weights(weights) + + assert self.quant_method.method_name == "none", "Gemma-4 packed MoE currently supports bf16/no-quant weights." + assert not self.enable_ep_moe, "Gemma-4 packed MoE currently supports TP mode only." 
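For orientation, a minimal host-side sketch (not part of the patch) of the TP slicing the loader performs below, assuming the packed layout the indexing implies: gate_up_proj is [n_experts, 2 * moe_intermediate_size, hidden_size] with the gate rows stacked above the up rows, and down_proj is [n_experts, hidden_size, moe_intermediate_size]. The helper name and shapes here are illustrative assumptions, not checkpoint facts.

import torch

def slice_packed_expert(gate_up_proj: torch.Tensor,
                        down_proj: torch.Tensor,
                        expert_idx: int,
                        tp_rank: int,
                        tp_world_size: int):
    # Mirrors the slicing in load_hf_weights below under the assumed layouts.
    inter = gate_up_proj.shape[1] // 2
    split = inter // tp_world_size                 # plays the role of self.split_inter_size
    start, end = split * tp_rank, split * (tp_rank + 1)
    gate = gate_up_proj[expert_idx, start:end, :].contiguous()                 # [split, hidden]
    up = gate_up_proj[expert_idx, inter + start:inter + end, :].contiguous()   # [split, hidden]
    down = down_proj[expert_idx, :, start:end].contiguous()                    # [hidden, split]
    return gate, up, down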
+ + start = self.split_inter_size * self.tp_rank_ + end = self.split_inter_size * (self.tp_rank_ + 1) + moe_intermediate_size = self.moe_intermediate_size + + if gate_up_name in weights: + gate_up_weight = weights[gate_up_name] + for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items(): + gate_weight = gate_up_weight[expert_idx, start:end, :].contiguous() + up_weight = gate_up_weight[ + expert_idx, moe_intermediate_size + start : moe_intermediate_size + end, : + ].contiguous() + self.quant_method.load_weight(gate_weight, self.w1_list[local_expert_idx]) + self.quant_method.load_weight(up_weight, self.w3_list[local_expert_idx]) + + if down_name in weights: + down_weight = weights[down_name] + for expert_idx, local_expert_idx in self.expert_idx_to_local_idx.items(): + down_weight_slice = down_weight[expert_idx, :, start:end].contiguous() + self.quant_method.load_weight(down_weight_slice, self.w2_list[local_expert_idx]) + + self._load_per_expert_scale(weights) diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py index 00587ac185..3e6ab8accf 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/base_impl.py @@ -62,5 +62,7 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + use_gelu: bool = False, ) -> torch.Tensor: pass diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py index bdd86eb51e..bf0c350138 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/deepgemm_impl.py @@ -31,6 +31,7 @@ def _select_experts( topk_group: int, num_expert_group: int, scoring_func: str, + per_expert_scale: Optional[torch.Tensor] = None, ): """Select experts and return topk weights and ids.""" from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts @@ -48,6 +49,8 @@ def _select_experts( ) if self.routed_scaling_factor != 1.0: topk_weights.mul_(self.routed_scaling_factor) + if per_expert_scale is not None: + topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) if self.redundancy_expert_num > 0: redundancy_topk_ids_repair( topk_ids=topk_ids, @@ -68,8 +71,8 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + use_gelu: bool = False, ): - w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale use_fp8_w8a8 = self.quant_method.method_name != "none" @@ -88,6 +91,7 @@ def _fused_experts( w1_scale=w13_scale, w2_scale=w2_scale, previous_event=None, # for overlap + use_gelu=use_gelu, ) return output @@ -210,11 +214,20 @@ def masked_group_gemm( masked_m: torch.Tensor, dtype: torch.dtype, expected_m: int, + use_gelu: bool = False, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale return masked_group_gemm( - recv_x, masked_m, dtype, w13_weight, w13_scale, w2_weight, w2_scale, expected_m=expected_m + recv_x, + masked_m, + dtype, + w13_weight, + w13_scale, + w2_weight, + w2_scale, + 
expected_m=expected_m, + use_gelu=use_gelu, ) def prefilled_group_gemm( @@ -226,6 +239,7 @@ def prefilled_group_gemm( w13: WeightPack, w2: WeightPack, hidden_dtype=torch.bfloat16, + use_gelu: bool = False, ): device = recv_x[0].device w13_weight, w13_scale = w13.weight, w13.weight_scale @@ -278,7 +292,7 @@ def prefilled_group_gemm( # TODO fused kernel silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype) - silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) + silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out, use_gelu=use_gelu) qsilu_out, qsilu_out_scale = per_token_group_quant_fp8( silu_out, block_size, dtype=w13_weight.dtype, column_major_scales=True, scale_tma_aligned=True ) @@ -298,7 +312,7 @@ def prefilled_group_gemm( if Autotuner.is_autotune_warmup(): _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype) _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype) - silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) + silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out, use_gelu=use_gelu) _gemm_out_a, _silu_out = None, None return gather_out diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py index 6391a10800..67087d2151 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/marlin_impl.py @@ -29,7 +29,9 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: Optional[bool] = None, + use_gelu: bool = False, ): + assert not use_gelu, "FuseMoeMarlin does not support GELU expert activation." w1_weight, w1_scale, w1_zero_point = w13.weight, w13.weight_scale, w13.weight_zero_point w2_weight, w2_scale, w2_zero_point = w2.weight, w2.weight_scale, w2.weight_zero_point diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py index d6e923a115..c634ed59ad 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/fused_moe/impl/triton_impl.py @@ -42,6 +42,7 @@ def _select_experts( topk_group: int, num_expert_group: int, scoring_func: str, + per_expert_scale: Optional[torch.Tensor] = None, ): """Select experts and return topk weights and ids.""" from lightllm.common.basemodel.triton_kernel.fused_moe.topk_select import select_experts @@ -59,6 +60,8 @@ def _select_experts( ) if self.routed_scaling_factor != 1.0: topk_weights.mul_(self.routed_scaling_factor) + if per_expert_scale is not None: + topk_weights = topk_weights * per_expert_scale[topk_ids.to(torch.long)].to(topk_weights.dtype) if self.num_fused_shared_experts > 0: pad_topk_ids = ( torch.arange( @@ -91,6 +94,7 @@ def _fused_experts( topk_ids: torch.Tensor, router_logits: Optional[torch.Tensor] = None, is_prefill: bool = False, + use_gelu: bool = False, ): w13_weight, w13_scale = w13.weight, w13.weight_scale w2_weight, w2_scale = w2.weight, w2.weight_scale @@ -108,6 +112,7 @@ def _fused_experts( use_fp8_w8a8=use_fp8_w8a8, w1_scale=w13_scale, w2_scale=w2_scale, + use_gelu=use_gelu, ) return input_tensor @@ -125,6 +130,8 @@ def __call__( topk_group: int, num_expert_group: int, is_prefill: Optional[bool] = None, + per_expert_scale: Optional[torch.Tensor] = None, + use_gelu: bool = False, 
): topk_weights, topk_ids = self._select_experts( input_tensor=input_tensor, @@ -136,6 +143,7 @@ def __call__( topk_group=topk_group, num_expert_group=num_expert_group, scoring_func=scoring_func, + per_expert_scale=per_expert_scale, ) output = self._fused_experts( input_tensor=input_tensor, @@ -145,5 +153,6 @@ def __call__( topk_ids=topk_ids, router_logits=router_logits, is_prefill=is_prefill, + use_gelu=use_gelu, ) return output diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py index 5021699143..895482b491 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/mm_weight/mm_weight.py @@ -54,10 +54,23 @@ def __init__( self.gen_weight_quant_param_names() def mm( - self, input_tensor: torch.Tensor, out: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True + self, + input_tensor: torch.Tensor, + out: Optional[torch.Tensor] = None, + use_custom_tensor_mananger: bool = True, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + # out_dtype: optional override that asks the quant backend to produce + # an output of a specified dtype (e.g. fp32) directly from the GEMM + # accumulator. Only NoQuantization currently honors values that differ + # from input dtype; other quant impls will assert. return self.quant_method.apply( - input_tensor, self.mm_param, out, use_custom_tensor_mananger=use_custom_tensor_mananger, bias=self.bias + input_tensor, + self.mm_param, + out, + use_custom_tensor_mananger=use_custom_tensor_mananger, + bias=self.bias, + out_dtype=out_dtype, ) def gen_weight_quant_param_names(self): diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py index e549298e3b..55180d7adb 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding.py @@ -2,7 +2,13 @@ def gqa_token_decode_attention_flash_decoding( - q: torch.Tensor, infer_state, cache_k: torch.Tensor, cache_v: torch.Tensor, out=None, alloc_tensor_func=torch.empty + q: torch.Tensor, + infer_state, + cache_k: torch.Tensor, + cache_v: torch.Tensor, + out=None, + alloc_tensor_func=torch.empty, + sliding_window: int = -1, ): batch_size = infer_state.batch_size q_head_num, head_dim = q.shape[1], q.shape[2] @@ -39,6 +45,7 @@ def gqa_token_decode_attention_flash_decoding( mid_out=mid_o, mid_out_logsumexp=mid_o_logexpsum, block_seq=BLOCK_SEQ, + sliding_window=sliding_window, ) flash_decode_stage2( mid_out=mid_o, @@ -46,5 +53,6 @@ def gqa_token_decode_attention_flash_decoding( B_Seqlen=infer_state.b_seq_len, out=o_tensor.view(calcu_shape1), block_seq=BLOCK_SEQ, + sliding_window=sliding_window, ) return o_tensor diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py index f484e7850f..d60e434627 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage1.py @@ -39,6 +39,8 @@ def 
_fwd_kernel_flash_decode_stage1( BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -46,6 +48,12 @@ def _fwd_kernel_flash_decode_stage1( grid_block_num = tl.num_programs(2) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + if USE_SLIDING_WINDOW: + kv_start_index = tl.maximum(cur_batch_seq_len - 1 - SLIDING_WINDOW_SIZE, 0) + cur_batch_seq_len = cur_batch_seq_len - kv_start_index + else: + kv_start_index = 0 + req_total_block_num = tl.cdiv(cur_batch_seq_len, BLOCK_SEQ) if block_index >= req_total_block_num: return @@ -77,7 +85,7 @@ def _fwd_kernel_flash_decode_stage1( offs_n_new = start_n * BLOCK_N + offs_n n_mask = offs_n_new < cur_batch_end_index k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + kv_start_index + offs_n_new, mask=n_mask, other=0, ).to(tl.int64) @@ -110,14 +118,8 @@ def _fwd_kernel_flash_decode_stage1( + offs_d[None, :] ) off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + block_index - tl.store( - Mid_O + off_mid_o, - acc / sum_exp[:, None], - ) - tl.store( - Mid_O_LogExpSum + off_mid_o_logexpsum, - max_logic + tl.log(sum_exp), - ) + tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None]) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) return @@ -170,6 +172,7 @@ def flash_decode_stage1( mid_out, mid_out_logsumexp, block_seq, + sliding_window: int = -1, run_config: Optional[dict] = None, ): """ """ @@ -185,8 +188,8 @@ def flash_decode_stage1( # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256} - if Lk == 256: + assert Lk in {16, 32, 64, 128, 256, 512} + if Lk >= 256: BLOCK_N = min(BLOCK_N, 16) assert BLOCK_SEQ % BLOCK_N == 0 sm_scale = 1.0 / (Lk ** 0.5) @@ -194,6 +197,8 @@ def flash_decode_stage1( block_num = mid_out.shape[2] grid = (batch, kv_head_num, block_num) gqa_group_size = q.shape[1] // k.shape[1] + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 _fwd_kernel_flash_decode_stage1[grid]( q, @@ -228,6 +233,8 @@ def flash_decode_stage1( BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=Lk, BLOCK_N=BLOCK_N, + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=num_warps, num_stages=num_stages, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py index 50739e8305..810abe1efa 100644 --- a/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py +++ b/lightllm/common/basemodel/triton_kernel/att/decode_att/gqa/flash_decoding/gqa_flash_decoding_stage2.py @@ -22,12 +22,17 @@ def _fwd_kernel_flash_decode_stage2( block_num, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + if USE_SLIDING_WINDOW: + kv_start_index = tl.maximum(cur_batch_seq_len - 1 - SLIDING_WINDOW_SIZE, 0) + cur_batch_seq_len = cur_batch_seq_len - kv_start_index block_num = tl.minimum(tl.cdiv(cur_batch_seq_len, 
BLOCK_SEQ), block_num) @@ -54,12 +59,14 @@ def _fwd_kernel_flash_decode_stage2( @torch.no_grad() -def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq, sliding_window: int = -1): Lk = mid_out.shape[-1] - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} batch, head_num = mid_out.shape[0], mid_out.shape[1] grid = (batch, head_num) block_num = mid_out.shape[2] + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 _fwd_kernel_flash_decode_stage2[grid]( B_Seqlen, @@ -79,6 +86,8 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, out, block_seq): block_num, BLOCK_SEQ=block_seq, BLOCK_DMODEL=Lk, + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=4, num_stages=2, ) diff --git a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py index 5ba6d0beb6..7daf3a12e8 100644 --- a/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py +++ b/lightllm/common/basemodel/triton_kernel/att/prefill_att/context_flashattention_nopad.py @@ -41,6 +41,8 @@ def _fwd_kernel( BLOCK_DMODEL: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, ): start_m = tl.program_id(0) cur_bh = tl.program_id(1) @@ -60,6 +62,7 @@ def _fwd_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = block_start_loc + tl.arange(0, BLOCK_M) + q_pos = offs_m + prompt_cache_len off_q = ( (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh @@ -76,20 +79,31 @@ def _fwd_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) - # causal mask - for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + if USE_SLIDING_WINDOW: + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE + kv_start_index = tl.maximum(kv_start_index, 0) + block_kv_len = block_end_loc - kv_start_index + else: + kv_start_index = 0 + block_kv_len = block_end_loc + + # causal (+ sliding-window) mask + for start_n in range(0, block_mask * block_kv_len, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = kv_start_index + start_n + offs_n # -- compute qk ---- kv_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), - mask=(start_n + offs_n) < block_end_loc, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_pos < block_end_loc, other=0, ).to(tl.int64) off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd - k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0) + k = tl.load(K + off_k, mask=k_pos[None, :] < block_end_loc, other=0.0) qk = tl.dot(q, k) - mask = offs_m[:, None] + prompt_cache_len >= (start_n + offs_n[None, :]) + mask = q_pos[:, None] >= k_pos[None, :] + if USE_SLIDING_WINDOW: + mask = mask & ((q_pos[:, None] - k_pos[None, :]) <= SLIDING_WINDOW_SIZE) qk = tl.where(mask, qk * sm_scale, -1.0e8) m_ij = tl.maximum(m_i, tl.max(qk, 1)) qk -= m_ij[:, None] @@ -103,7 +117,7 @@ def _fwd_kernel( 
acc = acc * alpha[:, None] # update acc off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd - v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0) + v = tl.load(V + off_v, mask=k_pos[:, None] < block_end_loc, other=0.0) p = p.to(v.dtype) acc = tl.dot(p, v, acc) # update m_i and l_i @@ -121,13 +135,30 @@ def _fwd_kernel( @torch.no_grad() def context_attention_fwd( - q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + sliding_window: int = -1, ): BLOCK_M = 128 if not is_tesla() else 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 @@ -140,6 +171,8 @@ def context_attention_fwd( BLOCK_N = BLOCK_M num_warps = 4 if Lk <= 64 else 8 num_stages = 1 + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 _fwd_kernel[grid]( q, @@ -171,6 +204,8 @@ def context_attention_fwd( BLOCK_DMODEL=Lk, BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, num_warps=num_warps, num_stages=num_stages, ) @@ -291,7 +326,14 @@ def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, ma # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 @@ -463,7 +505,14 @@ def context_attention_fwd_contiguous_kv( # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 128, 256, 512} + # Larger head_dim needs smaller tiles to fit in SM shared memory. + # H100/H200 has ~228KB shared memory per SM; a 128x512 bf16 tile already + # consumes 128KB, leaving no room for K/V/scores buffers. 
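A quick back-of-envelope check of the shared-memory numbers quoted in the comment above (a sketch; it assumes bf16 tiles at 2 bytes per element and counts only the Q tile, while the real budget must also hold the K/V tiles and the score matrix):

def q_tile_bytes(block_m: int, head_dim: int, elem_bytes: int = 2) -> int:
    # shared-memory footprint of a single BLOCK_M x head_dim query tile
    return block_m * head_dim * elem_bytes

print(q_tile_bytes(128, 512) // 1024)  # 128 KB: already over half of ~228 KB SMEM before K/V/scores
print(q_tile_bytes(32, 512) // 1024)   # 32 KB: hence BLOCK_M is capped at 32 for head_dim >= 512
print(q_tile_bytes(64, 256) // 1024)   # 32 KB: hence the cap of 64 for head_dim >= 256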
+ if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634, # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。 diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py index 638abbd6ca..bed3754960 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe.py @@ -114,7 +114,6 @@ def moe_align1_kernel( TOKEN_BLOCK_SIZE: tl.constexpr, NUM_STAGE: tl.constexpr, ): - expert_id = tl.program_id(axis=0) off_n = tl.arange(0, TOKEN_BLOCK_SIZE) @@ -308,7 +307,6 @@ def moe_align2_kernel( BLOCK_M: tl.constexpr, BLOCK_EXPERT: tl.constexpr, ): - expert_id = tl.program_id(axis=0) off_expert = tl.arange(0, BLOCK_EXPERT) expert_to_token_num = tl.load(experts_token_num_ptr + off_expert, mask=off_expert < expert_num, other=0) @@ -911,6 +909,7 @@ def fused_experts_impl( layout="blocked", limit=None, alpha=None, + use_gelu: bool = False, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -990,6 +989,7 @@ def fused_experts_impl( limit=limit, alpha=alpha, layout=layout, + use_gelu=use_gelu, ) grouped_matmul( @@ -1035,6 +1035,7 @@ def inplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: fused_experts_impl( hidden_states, @@ -1054,6 +1055,7 @@ def inplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) @@ -1075,6 +1077,7 @@ def inplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: pass @@ -1105,6 +1108,7 @@ def outplace_fused_experts_impl( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: return fused_experts_impl( hidden_states, @@ -1124,6 +1128,7 @@ def outplace_fused_experts_impl( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) @@ -1145,6 +1150,7 @@ def outplace_fused_experts_impl_fake( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ) -> None: return torch.empty_like(hidden_states) @@ -1176,6 +1182,7 @@ def fused_experts( layout: str = "blocked", alpha: Optional[float] = None, limit: Optional[float] = None, + use_gelu: bool = False, ): if inplace: torch.ops.lightllm.inplace_fused_experts_impl( @@ -1195,6 +1202,7 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) return hidden_states else: @@ -1215,4 +1223,5 @@ def fused_experts( layout=layout, alpha=alpha, limit=limit, + use_gelu=use_gelu, ) diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py index 2c6d013bd5..8f31bde57a 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/grouped_fused_moe_ep.py @@ -40,6 +40,7 @@ def masked_group_gemm( w2: torch.Tensor, w2_scale: torch.Tensor, expected_m: int, + use_gelu: bool = False, ): padded_m = recv_x[0].shape[1] E, N, _ = w1.shape @@ -54,7 +55,7 @@ def masked_group_gemm( _deepgemm_grouped_fp8_nt_masked(recv_x, (w1, w1_scale), gemm_out_a, masked_m, expected_m) - 
silu_and_mul_masked_post_quant_fwd(gemm_out_a, qsilu_out, qsilu_out_scale, block_size, masked_m) + silu_and_mul_masked_post_quant_fwd(gemm_out_a, qsilu_out, qsilu_out_scale, block_size, masked_m, use_gelu=use_gelu) _deepgemm_grouped_fp8_nt_masked((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, masked_m, expected_m) return gemm_out_b @@ -74,6 +75,7 @@ def fused_experts_impl( w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, previous_event: Optional["EventOverlap"] = None, + use_gelu: bool = False, ): # Check constraints. assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" @@ -175,7 +177,7 @@ def fused_experts_impl( # TODO fused kernel silu_out = torch.empty((all_tokens, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out) + silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out, use_gelu=use_gelu) qsilu_out, qsilu_out_scale = per_token_group_quant_fp8( silu_out, block_size_k, dtype=w1.dtype, column_major_scales=True, scale_tma_aligned=True ) @@ -194,7 +196,7 @@ def fused_experts_impl( if Autotuner.is_autotune_warmup(): _gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype) _silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype) - silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out) + silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out, use_gelu=use_gelu) _gemm_out_a, _silu_out = None, None # normal combine @@ -220,7 +222,9 @@ def fused_experts_impl( return_recv_hook=False, ) # deepgemm - gemm_out_b = masked_group_gemm(recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m) + gemm_out_b = masked_group_gemm( + recv_x, masked_m, hidden_states.dtype, w1, w1_scale, w2, w2_scale, expected_m, use_gelu=use_gelu + ) # low latency combine combined_x, event_overlap, hook = buffer.low_latency_combine( gemm_out_b, topk_idx, topk_weights, handle, async_finish=False, return_recv_hook=False diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py index d7bcc17743..2b7a9d30b4 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul.py @@ -23,6 +23,7 @@ def _silu_and_mul_kernel_fast( NEED_MASK: tl.constexpr, layout: tl.constexpr = "blocked", # "blocked" or "interleaved" USE_LIMIT_AND_ALPHA: tl.constexpr = False, + USE_GELU: tl.constexpr = False, ): stride_input_m = tl.cast(stride_input_m, dtype=tl.int64) stride_output_m = tl.cast(stride_output_m, dtype=tl.int64) @@ -74,7 +75,14 @@ def _silu_and_mul_kernel_fast( mask=mask, ) else: - gate = gate / (1 + tl.exp(-gate)) + if USE_GELU: + # tanh-approx GELU, matching Gemma's gelu_pytorch_tanh MLP. 
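As a host-side sanity check of the tanh-approximate GELU the kernel computes below (a sketch, not part of the patch): the constant 0.7978845608028654 is sqrt(2/pi), and tanh(t) is rewritten as 2*sigmoid(2t) - 1 because only tl.exp is available in the kernel. The same algebra can be compared against torch's approximate="tanh" GELU:

import math
import torch
import torch.nn.functional as F

def gelu_tanh_ref(x: torch.Tensor) -> torch.Tensor:
    # same formulation as the Triton code: tanh(t) expressed via a sigmoid
    t = math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)
    tanh_t = 2.0 / (1.0 + torch.exp(-2.0 * t)) - 1.0
    return 0.5 * x * (1.0 + tanh_t)

x = torch.randn(1024, dtype=torch.float32)
assert torch.allclose(gelu_tanh_ref(x), F.gelu(x, approximate="tanh"), atol=1e-6)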
+ gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) gate = gate.to(input_ptr.dtype.element_ty) tl.store( @@ -106,7 +114,13 @@ def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor): mutates_args=["output"], ) def silu_and_mul_fwd( - input: torch.Tensor, output: torch.Tensor, layout="blocked", limit=None, alpha=None, run_config=None + input: torch.Tensor, + output: torch.Tensor, + layout="blocked", + limit=None, + alpha=None, + run_config=None, + use_gelu: bool = False, ): assert input.is_contiguous() assert output.is_contiguous() @@ -157,5 +171,6 @@ def silu_and_mul_fwd( num_warps=num_warps, layout=layout, USE_LIMIT_AND_ALPHA=USE_LIMIT_AND_ALPHA, + USE_GELU=use_gelu, ) return diff --git a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py index d2c44b2953..30124cc2b2 100644 --- a/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py +++ b/lightllm/common/basemodel/triton_kernel/fused_moe/moe_silu_and_mul_mix_quant_ep.py @@ -24,6 +24,7 @@ def _silu_and_mul_post_quant_kernel( fp8_min, BLOCK_N: tl.constexpr, NUM_STAGE: tl.constexpr, + USE_GELU: tl.constexpr = False, ): expert_id = tl.program_id(2) token_id = tl.program_id(1) @@ -48,7 +49,13 @@ def _silu_and_mul_post_quant_kernel( for token_index in tl.range(token_id, token_num_cur_expert, block_num_per_expert, num_stages=NUM_STAGE): gate = tl.load(input_ptr_offs + token_index * stride_input_1, mask=offs_in_d < size_n, other=0.0).to(tl.float32) up = tl.load(input_ptr_offs + token_index * stride_input_1 + size_n, mask=offs_in_d < size_n, other=0.0) - gate = gate / (1 + tl.exp(-gate)) + if USE_GELU: + gate_cubed = gate * gate * gate + tanh_arg = 0.7978845608028654 * (gate + 0.044715 * gate_cubed) + tanh_val = 2.0 / (1.0 + tl.exp(-2.0 * tanh_arg)) - 1.0 + gate = 0.5 * gate * (1.0 + tanh_val) + else: + gate = gate / (1 + tl.exp(-gate)) gate = gate.to(input_ptr.dtype.element_ty) gate_up = up * gate _absmax = tl.maximum(tl.max(tl.abs(gate_up)), 1e-10) @@ -66,7 +73,12 @@ def _silu_and_mul_post_quant_kernel( def silu_and_mul_masked_post_quant_fwd( - input: torch.Tensor, output: torch.Tensor, output_scale: torch.Tensor, quant_group_size: int, masked_m: torch.Tensor + input: torch.Tensor, + output: torch.Tensor, + output_scale: torch.Tensor, + quant_group_size: int, + masked_m: torch.Tensor, + use_gelu: bool = False, ): """ input shape [expert_num, token_num_padded, hidden_dim] @@ -122,6 +134,7 @@ def silu_and_mul_masked_post_quant_fwd( fp8_min, BLOCK_N=BLOCK_N, NUM_STAGE=NUM_STAGES, + USE_GELU=use_gelu, num_warps=num_warps, ) return diff --git a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py index e2d4aea587..05d678e41b 100644 --- a/lightllm/common/basemodel/triton_kernel/multimodal_emb.py +++ b/lightllm/common/basemodel/triton_kernel/multimodal_emb.py @@ -23,6 +23,8 @@ def _fwd_kernel( tp_text_end_token_id, hidden_size, tp_world_size, + APPLY_TEXT_EMBED_SCALE: tl.constexpr, + TEXT_EMBED_SCALE: tl.constexpr, BLOCK_HIDDEN_DIM: tl.constexpr, ): @@ -43,6 +45,8 @@ def _fwd_kernel( mask=off_d < hidden_size, other=0, ) + if APPLY_TEXT_EMBED_SCALE: + load_emb *= TEXT_EMBED_SCALE tl.store(Out + stride_out_s * seq_index + stride_out_d * off_d, 
load_emb, mask=off_d < hidden_size) img_start_token_id = tl.load(Img_start_token_ids + img_handle_id - 1, mask=img_handle_id >= 1, other=0) @@ -84,9 +88,12 @@ def multimodal_emb( tp_text_start_token_id: int, tp_text_end_token_id: int, tp_world_size: int, + text_embed_scale: float = 1.0, ): total_len = prompt_ids.shape[0] BLOCK = triton.next_power_of_2(out.shape[1]) + text_embed_scale = float(text_embed_scale) + apply_text_embed_scale = text_embed_scale != 1.0 # print(len(img_token_lens)) grid = (total_len, len(img_token_lens) + 1) num_warps = 1 @@ -109,6 +116,8 @@ def multimodal_emb( tp_text_end_token_id=tp_text_end_token_id, hidden_size=out.shape[1], tp_world_size=float(tp_world_size), + APPLY_TEXT_EMBED_SCALE=apply_text_embed_scale, + TEXT_EMBED_SCALE=text_embed_scale, BLOCK_HIDDEN_DIM=BLOCK, num_warps=num_warps, num_stages=1, diff --git a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py index ca8f9a1c81..8dc8558922 100644 --- a/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py +++ b/lightllm/common/basemodel/triton_kernel/norm/rmsnorm.py @@ -18,6 +18,7 @@ def _rms_norm_fwd_fused( y_stride1, N, # number of columns in X eps, # epsilon to avoid division by zero + HAS_WEIGHT: tl.constexpr, BLOCK_SIZE: tl.constexpr, ): # Map the program id to the row of X and Y it should compute. @@ -32,14 +33,17 @@ def _rms_norm_fwd_fused( _var += x * x var = tl.sum(_var, axis=0) / N rstd = 1 / tl.sqrt(var + eps) - # Normalize and apply linear transformation + # Normalize and optionally apply linear transformation for off in range(0, N, BLOCK_SIZE): cols = off + tl.arange(0, BLOCK_SIZE) mask = cols < N - w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_WEIGHT: + w = tl.load(W + cols, mask=mask).to(tl.float32) x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32) x_hat = x * rstd - y = x_hat * w + y = x_hat + if HAS_WEIGHT: + y = x_hat * w # Write output tl.store(Y + cols * y_stride1, y.to(Y.dtype.element_ty), mask=mask) @@ -50,7 +54,9 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) # reshape input data into 2D tensor x_arg = x.view(-1, x.shape[-1]) y_arg = y.view(-1, x.shape[-1]) - assert x_arg.shape[-1] == weight.shape[0] and x_arg.shape == y_arg.shape + assert x_arg.shape == y_arg.shape + if weight is not None: + assert x_arg.shape[-1] == weight.shape[0] assert y.data_ptr() == y_arg.data_ptr() M, N = x_arg.shape # Less than 64KB per feature: enqueue fused kernel @@ -73,6 +79,7 @@ def rmsnorm_forward(x: torch.Tensor, weight: torch.Tensor, eps: float, out=None) y_arg.stride(1), N, eps, + HAS_WEIGHT=weight is not None, BLOCK_SIZE=BLOCK_SIZE, num_warps=rmsnorm_num_warps, ) diff --git a/lightllm/common/quantization/awq.py b/lightllm/common/quantization/awq.py index f3c7623975..96cd2b2926 100644 --- a/lightllm/common/quantization/awq.py +++ b/lightllm/common/quantization/awq.py @@ -58,6 +58,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("AWQ online quantization is not supported yet.") @@ -92,7 +93,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "awq quant does not support out_dtype" qweight = weight_pack.weight weight_scale = 
weight_pack.weight_scale qzeros = weight_pack.weight_zero_point @@ -167,7 +170,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "awq_marlin quant does not support out_dtype" qweight = weight_pack.weight weight_scale = weight_pack.weight_scale qzeros = weight_pack.weight_zero_point diff --git a/lightllm/common/quantization/deepgemm.py b/lightllm/common/quantization/deepgemm.py index 137455a821..901ec142e1 100644 --- a/lightllm/common/quantization/deepgemm.py +++ b/lightllm/common/quantization/deepgemm.py @@ -35,6 +35,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -75,7 +76,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "deepgemm-fp8w8a8-b128 quant does not support out_dtype" qweight = weight_pack.weight weight_scale = weight_pack.weight_scale input_scale = None diff --git a/lightllm/common/quantization/no_quant.py b/lightllm/common/quantization/no_quant.py index fa926ad6f0..b0deaca9a4 100644 --- a/lightllm/common/quantization/no_quant.py +++ b/lightllm/common/quantization/no_quant.py @@ -18,20 +18,27 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager weight = weight_pack.weight.t() + target_dtype = out_dtype if out_dtype is not None else input_tensor.dtype if out is None: shape = (input_tensor.shape[0], weight.shape[1]) - dtype = input_tensor.dtype device = input_tensor.device if use_custom_tensor_mananger: - out = g_cache_manager.alloc_tensor(shape, dtype, device=device) + out = g_cache_manager.alloc_tensor(shape, target_dtype, device=device) else: - out = torch.empty(shape, dtype=dtype, device=device) + out = torch.empty(shape, dtype=target_dtype, device=device) + else: + assert out.dtype == target_dtype, ( + f"NoQuantization.apply: pre-allocated out.dtype={out.dtype} does not match " + f"requested out_dtype={target_dtype}" + ) if bias is None: - return torch.mm(input_tensor, weight, out=out) + return torch.mm(input_tensor, weight, out=out, out_dtype=target_dtype) + assert out_dtype is None, "NoQuantization.apply: out_dtype not supported when bias is set" return torch.addmm(bias, input_tensor, weight, out=out) def _create_weight( diff --git a/lightllm/common/quantization/quantize_method.py b/lightllm/common/quantization/quantize_method.py index 95d8d806f9..d3f251ec84 100644 --- a/lightllm/common/quantization/quantize_method.py +++ b/lightllm/common/quantization/quantize_method.py @@ -55,6 +55,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: pass diff --git a/lightllm/common/quantization/w8a8.py b/lightllm/common/quantization/w8a8.py index 65ec6cd145..3ce3d92345 100644 --- a/lightllm/common/quantization/w8a8.py +++ b/lightllm/common/quantization/w8a8.py @@ -41,6 +41,7 @@ 
def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -78,7 +79,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "w8a8 quant does not support out_dtype" input_scale = None qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale @@ -140,7 +143,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "fp8w8a8 quant does not support out_dtype" qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale x_q, x_scale = scaled_fp8_quant(input_tensor, scale=None, scale_ub=None, use_per_token_if_dynamic=True) @@ -207,7 +212,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "fp8w8a8-b128 quant does not support out_dtype" qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale.t() input_scale = None # dynamic quantization for input tensor diff --git a/lightllm/common/quantization/w8a8gx.py b/lightllm/common/quantization/w8a8gx.py index c25136697d..a6a4065745 100644 --- a/lightllm/common/quantization/w8a8gx.py +++ b/lightllm/common/quantization/w8a8gx.py @@ -24,6 +24,7 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: raise NotImplementedError("Not implemented") @@ -62,7 +63,9 @@ def apply( workspace: Optional[torch.Tensor] = None, use_custom_tensor_mananger: bool = True, bias: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, ) -> torch.Tensor: + assert out_dtype is None, "fp8w8a8gxx quant does not support out_dtype" qweight = weight_pack.weight.t() weight_scale = weight_pack.weight_scale from lightllm.common.basemodel.triton_kernel.quantization.fp8act_quant_kernel import ( diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 2caee91709..f619b1d88f 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -33,6 +33,7 @@ from lightllm.models.qwen3_vl.model import Qwen3VLTpPartModel from lightllm.models.qwen3_vl_moe.model import Qwen3VLMOETpPartModel from lightllm.models.gemma3.model import Gemma3TpPartModel +from lightllm.models.gemma4.model import Gemma4TpPartModel from lightllm.models.tarsier2.model import ( Tarsier2Qwen2TpPartModel, Tarsier2Qwen2VLTpPartModel, diff --git a/lightllm/models/gemma4/__init__.py b/lightllm/models/gemma4/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/gemma4_visual.py b/lightllm/models/gemma4/gemma4_visual.py new file mode 100644 index 0000000000..7ed64108b3 --- /dev/null +++ b/lightllm/models/gemma4/gemma4_visual.py @@ -0,0 +1,146 @@ +import json +import os +from io import BytesIO +from typing import List + +import torch +from PIL import Image +from safetensors import safe_open +from transformers import AutoConfig, AutoProcessor + +from 
lightllm.server.embed_cache.utils import get_shm_name_data, read_shm +from lightllm.server.multimodal_params import ImageItem +from lightllm.utils.log_utils import init_logger +from lightllm.utils.torch_dtype_utils import get_torch_dtype + + +logger = init_logger(__name__) + + +class Gemma4VisionModel: + def __init__(self, data_type="bfloat16"): + self.vision_tower = None + self.embed_vision = None + self.image_processor = None + self.data_type = data_type if isinstance(data_type, torch.dtype) else get_torch_dtype(data_type) + self.device = torch.device("cpu") + + def _weight_files(self, weight_dir): + index_path = os.path.join(weight_dir, "model.safetensors.index.json") + if os.path.exists(index_path): + with open(index_path, "r") as f: + weight_map = json.load(f)["weight_map"] + return sorted(set(weight_map.values())) + return sorted(f for f in os.listdir(weight_dir) if f.endswith(".safetensors")) + + def _load_prefix_state_dict(self, weight_dir, prefix): + state_dict = {} + for file_name in self._weight_files(weight_dir): + file_path = os.path.join(weight_dir, file_name) + with safe_open(file_path, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith(prefix): + state_dict[key[len(prefix) :]] = f.get_tensor(key) + return state_dict + + def load_model(self, weight_dir): + try: + from transformers.models.gemma4.modeling_gemma4 import ( + Gemma4MultimodalEmbedder, + Gemma4VisionModel as HFGemma4VisionModel, + ) + except ImportError as e: + raise ImportError("Gemma-4 vision requires a transformers build with Gemma4 support.") from e + + config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True) + if config.vision_config is None: + raise ValueError("Gemma-4 checkpoint does not contain vision_config") + + processor = AutoProcessor.from_pretrained(weight_dir) + self.image_processor = processor.image_processor + self.vision_tower = HFGemma4VisionModel(config.vision_config).eval() + self.embed_vision = Gemma4MultimodalEmbedder(config.vision_config, config.text_config).eval() + + vision_state = self._load_prefix_state_dict(weight_dir, "model.vision_tower.") + embed_state = self._load_prefix_state_dict(weight_dir, "model.embed_vision.") + missing, unexpected = self.vision_tower.load_state_dict(vision_state, strict=False) + if missing or unexpected: + raise RuntimeError(f"Gemma-4 vision_tower weight mismatch: missing={missing}, unexpected={unexpected}") + missing, unexpected = self.embed_vision.load_state_dict(embed_state, strict=False) + if missing or unexpected: + raise RuntimeError(f"Gemma-4 embed_vision weight mismatch: missing={missing}, unexpected={unexpected}") + + return self + + def cuda(self): + self.device = torch.device("cuda") + self.vision_tower = self.vision_tower.cuda() + self.embed_vision = self.embed_vision.cuda() + return self + + def forward(self, pixel_values, image_position_ids): + pixel_values = pixel_values.to(self.device, non_blocking=True) + image_position_ids = image_position_ids.to(self.device, non_blocking=True) + pooling_k = self.vision_tower.config.pooling_kernel_size + pooling_k2 = pooling_k * pooling_k + + # Per-image vision-tower call. `output_length` MUST match the per-image + # num_soft_tokens the image processor declared; otherwise HF's pooler + # falls back to config.image_seq_length and silently emits a different + # token count than what `valid_ids` expects. 
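To make the soft-token accounting concrete (a sketch with made-up numbers; the patch counts and pooling size below are illustrative assumptions), this mirrors the output_length = num_patches // pooling_kernel_size**2 relation used in forward below and the cumulative valid_ids spans built in encode:

# Hypothetical per-image patch counts after the image processor, with a 2x2 pooler.
patch_counts = [4096, 4096, 1024]
pooling_k = 2
soft_tokens = [p // (pooling_k * pooling_k) for p in patch_counts]  # [1024, 1024, 256]

# encode() lays the per-image embeddings out flat, so valid_ids are the
# cumulative [start, end) spans into that flat tensor.
valid_ids, start = [], 0
for n in soft_tokens:
    valid_ids.append([start, start + n])
    start += n
assert start == sum(soft_tokens)  # must equal all_img_embeds.shape[0]
print(valid_ids)  # [[0, 1024], [1024, 2048], [2048, 2304]]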
+ per_image_hidden = [] + for i in range(pixel_values.shape[0]): + pv = pixel_values[i : i + 1] + pp = image_position_ids[i : i + 1] + output_length = pv.shape[1] // pooling_k2 + per_image_hidden.append( + self.vision_tower( + pixel_values=pv, + pixel_position_ids=pp, + output_length=output_length, + ).last_hidden_state + ) + + # embed_vision is token-independent (RMSNorm + Linear); cat once and + # project once instead of looping like vllm — same numerics, fewer + # Python launches, lines up naturally with our flat embed-cache output. + flat_hidden = torch.cat(per_image_hidden, dim=0) + target_dtype = self.embed_vision.embedding_projection.weight.dtype + image_features = self.embed_vision(inputs_embeds=flat_hidden.unsqueeze(0).to(target_dtype)).squeeze(0) + return image_features.to(self.data_type) + + @torch.inference_mode() + def encode(self, images: List[ImageItem]): + pil_images = [] + uuids = [] + for img in images: + if not isinstance(img, ImageItem): + raise TypeError(f"Unsupported Gemma-4 image input type: {type(img)}") + uuids.append(img.uuid) + image_data = read_shm(get_shm_name_data(img.uuid)) + with Image.open(BytesIO(image_data)) as image: + pil_images.append(image.convert("RGB")) + + if not pil_images: + return None + + image_inputs = self.image_processor(pil_images, return_tensors="pt") + token_nums = image_inputs.pop("num_soft_tokens_per_image") + pixel_values = image_inputs["pixel_values"] + image_position_ids = image_inputs["image_position_ids"] + + valid_ids = [] + valid_start = 0 + for img, token_num in zip(images, token_nums): + token_num = int(token_num) + if img.token_num != token_num: + raise ValueError(f"Gemma-4 image token mismatch: allocated={img.token_num}, encoded={token_num}") + valid_ids.append([valid_start, valid_start + token_num]) + valid_start += token_num + + all_img_embeds = self.forward(pixel_values, image_position_ids) + if all_img_embeds.shape[0] != valid_start: + raise ValueError( + f"Gemma-4 image embed length mismatch: embeds={all_img_embeds.shape[0]}, tokens={valid_start}" + ) + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/gemma4/infer_struct.py b/lightllm/models/gemma4/infer_struct.py new file mode 100644 index 0000000000..a6ad2f9c8b --- /dev/null +++ b/lightllm/models/gemma4/infer_struct.py @@ -0,0 +1,71 @@ +import torch +from lightllm.common.basemodel import InferStateInfo +from lightllm.models.gemma4.triton_kernel.build_b_image_token_end import build_b_image_token_end + + +class Gemma4InferStateInfo(InferStateInfo): + def __init__(self): + super().__init__() + # Gemma-4 uses two RoPE frequency tables (one per layer type): + # * sliding_attention layers: theta=10000, full rotation over head_dim=256 + # * full_attention layers: theta=1_000_000, partial rotation (first 25% of head_dim=512) + self.position_cos_sliding = None + self.position_sin_sliding = None + self.position_cos_full = None + self.position_sin_full = None + self.b_image_token_end = None + + def init_some_extra_state(self, model): + super().init_some_extra_state(model) + position_ids = self.position_ids + self.position_cos_sliding = torch.index_select(model._cos_cached_sliding, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_sin_sliding = torch.index_select(model._sin_cached_sliding, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_cos_full = torch.index_select(model._cos_cached_full, 0, position_ids).view( + position_ids.shape[0], -1 + ) + self.position_sin_full = torch.index_select(model._sin_cached_full, 0, 
position_ids).view( + position_ids.shape[0], -1 + ) + if self.is_prefill: + self.max_seq_len = self.max_kv_seq_len + self._build_b_image_token_end() + return + + def _build_b_image_token_end(self): + device = self.position_ids.device + self.b_image_token_end = torch.zeros(self.position_ids.shape[0], dtype=torch.int32, device=device) + + if not self.multimodal_params: + return + + b_image_start_idx = [] + b_image_len = [] + b_image_nums = [] + b_image_start_num = [] + image_start_num = 0 + for params in self.multimodal_params: + b_image_start_num.append(image_start_num) + images = params.get("images", []) + b_image_nums.append(len(images)) + for img in images: + b_image_start_idx.append(img["start_idx"]) + b_image_len.append(img["token_num"]) + image_start_num += 1 + + if image_start_num == 0: + return + + build_b_image_token_end( + b_image_start_idx=torch.tensor(b_image_start_idx, dtype=torch.int32).cuda(non_blocking=True), + b_image_len=torch.tensor(b_image_len, dtype=torch.int32).cuda(non_blocking=True), + b_image_nums=torch.tensor(b_image_nums, dtype=torch.int32).cuda(non_blocking=True), + b_image_start_num=torch.tensor(b_image_start_num, dtype=torch.int32).cuda(non_blocking=True), + b_q_start_loc=self.b_q_start_loc, + b_ready_cache_len=self.b_ready_cache_len, + b_q_seq_len=self.b_q_seq_len, + b_image_token_end=self.b_image_token_end, + ) diff --git a/lightllm/models/gemma4/layer_infer/__init__.py b/lightllm/models/gemma4/layer_infer/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/layer_infer/post_layer_infer.py b/lightllm/models/gemma4/layer_infer/post_layer_infer.py new file mode 100644 index 0000000000..22bcf0508d --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/post_layer_infer.py @@ -0,0 +1,20 @@ +import torch +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer + + +class Gemma4PostLayerInfer(LlamaPostLayerInfer): + """ + Same final RMSNorm + tied lm_head path as Llama, with an extra tanh-based + logit softcap at the end: logits = softcap * tanh(logits / softcap). 
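+    Worked example (softcap value illustrative only): with softcap = 30.0 a raw
+    logit of 100 becomes 30 * tanh(100 / 30) ≈ 29.9, i.e. logits are smoothly
+    squashed into (-softcap, softcap) without changing their ordering.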
+ """ + + def __init__(self, network_config): + super().__init__(network_config) + self.final_logit_softcapping = float(network_config.get("final_logit_softcapping")) + + def token_forward(self, input_embdings, infer_state, layer_weight): + logits = super().token_forward(input_embdings, infer_state, layer_weight) + if self.final_logit_softcapping is not None and self.final_logit_softcapping > 0: + cap = self.final_logit_softcapping + logits = torch.tanh(logits / cap) * cap + return logits diff --git a/lightllm/models/gemma4/layer_infer/pre_layer_infer.py b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py new file mode 100644 index 0000000000..50d459f642 --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/pre_layer_infer.py @@ -0,0 +1,69 @@ +import math +import torch +import torch.distributed as dist +from lightllm.distributed.communication_op import all_reduce +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.qwen_vl.layer_infer.pre_layer_infer import LlamaMultimodalPreLayerInfer +from lightllm.utils.envs_utils import get_env_start_args + + +class Gemma4PreLayerInfer(LlamaMultimodalPreLayerInfer): + def __init__(self, network_config): + super().__init__(network_config) + self.embed_scale = float(network_config["hidden_size"]) ** 0.5 + self.multimodal_text_embed_scale_ = self.embed_scale + self.pad_token_id_ = network_config.get("pad_token_id", 0) + + self.has_ple = bool(network_config.get("hidden_size_per_layer_input")) + if self.has_ple: + self.num_layers_ = network_config["num_hidden_layers"] + self.ple_dim_ = network_config["hidden_size_per_layer_input"] + self.ple_embed_scale_ = math.sqrt(self.ple_dim_) + self.ple_proj_scale_ = float(network_config["hidden_size"]) ** -0.5 + self.ple_combine_scale_ = 2.0 ** -0.5 + self.rms_norm_eps_ = network_config.get("rms_norm_eps", 1e-6) + self.ple_static_buffer = None + + def _compute_per_layer_embeds(self, input_ids_for_ple, input_embdings, infer_state, layer_weight): + ple_embeds = layer_weight.embed_tokens_per_layer_weight_(input_ids_for_ple) + if self.tp_world_size_ > 1: + all_reduce(ple_embeds, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + ple_embeds = ple_embeds * self.ple_embed_scale_ + + ple_proj = layer_weight.per_layer_model_projection_weight_.mm(input_embdings) + ple_proj = ple_proj * self.ple_proj_scale_ + ple_proj = ple_proj.reshape(*ple_proj.shape[:-1], self.num_layers_, self.ple_dim_) + ple_proj = layer_weight.per_layer_projection_norm_weight_( + input=ple_proj, eps=self.rms_norm_eps_, alloc_func=self.alloc_tensor + ) + + ple_embeds = ple_embeds.reshape(*ple_embeds.shape[:-1], self.num_layers_, self.ple_dim_) + buf = self.ple_static_buffer + N = input_embdings.shape[0] + out = buf[:N] + torch.add(ple_proj, ple_embeds, out=out) + out.mul_(self.ple_combine_scale_) + + def context_forward(self, input_ids, infer_state, layer_weight): + input_embdings = LlamaMultimodalPreLayerInfer.context_forward(self, input_ids, infer_state, layer_weight) + if self.has_ple: + input_ids_for_ple = input_ids.masked_fill(infer_state.b_image_token_end != 0, self.pad_token_id_) + self._compute_per_layer_embeds(input_ids_for_ple, input_embdings, infer_state, layer_weight) + return input_embdings + + def token_forward(self, input_ids, infer_state, layer_weight): + input_embdings = LlamaPreLayerInfer.token_forward(self, input_ids, infer_state, layer_weight) + input_embdings = input_embdings * self.embed_scale + if self.has_ple: + self._compute_per_layer_embeds(input_ids, 
input_embdings, infer_state, layer_weight) + return input_embdings + + def _tpsp_sp_split(self, input: torch.Tensor, infer_state): + if self.tp_world_size_ > 1 and get_env_start_args().enable_tpsp_mix_mode: + # SP would need a per-rank slice (N/world_size tokens), but the + # PLE static buffer is sized/written for the full N tokens. If you + # ever need SP + PLE, refactor _compute_per_layer_embeds to do an + # sp_pad_copy into a per-rank buffer. + assert not self.has_ple, "gemma4 PLE + enable_tpsp_mix_mode not implemented" + return super()._tpsp_sp_split(input=input, infer_state=infer_state) + return input diff --git a/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py new file mode 100644 index 0000000000..a52c87c32e --- /dev/null +++ b/lightllm/models/gemma4/layer_infer/transformer_layer_infer.py @@ -0,0 +1,383 @@ +import math +import torch +import torch.nn as nn + +from lightllm.common.basemodel.attention.base_att import AttControl +from lightllm.common.basemodel.infer_struct import InferStateInfo +from lightllm.common.basemodel.triton_kernel.norm.rmsnorm import rmsnorm_forward +from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight +from lightllm.models.gemma4.triton_kernel.context_attention_fwd_gemma4_mm import ( + context_attention_fwd_gemma4_mm, +) +from lightllm.common.basemodel.triton_kernel.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd + + +class Gemma4TransformerLayerInfer(LlamaTransformerLayerInfer): + """ + Gemma-4 decoder block. Per-layer heterogeneity (sliding vs full attention) + is handled by switching shape / RoPE table / sliding-window flag at init + time. The KV cache layout is uniform (sliding shape: num_kv_heads=16, + head_dim=256); full-attention layers pack their (4, 512) tensor into the + first 8 heads of the 16-head slot at cache-write time, then reshape on + read. See Gemma4TpPartModel._init_mem_manager for context. + """ + + def __init__(self, layer_num, network_config): + super().__init__(layer_num, network_config) + self.eps_ = network_config.get("rms_norm_eps", 1e-6) + self.embed_dim_ = network_config["hidden_size"] + self.is_moe = bool(network_config.get("enable_moe_block", False)) + self.num_experts_per_tok = network_config.get("num_experts_per_tok", network_config.get("top_k_experts", 0)) + self.norm_topk_prob = network_config.get("norm_topk_prob", True) + self.router_root_scale = self.embed_dim_ ** -0.5 + + layer_type = network_config["layer_types"][layer_num] + self.is_sliding = layer_type == "sliding_attention" + + # Some E-series checkpoints leave num_global_key_value_heads = null; + # HF treats that as "fall back to num_key_value_heads". + num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] + + # Override parent's head_dim_ (hidden_size/num_heads = 224 on 31B, wrong + # for Gemma-4 — actual is 256 sliding / 512 full). 
+ if self.is_sliding: + self.head_dim_ = network_config["head_dim"] + total_kv_heads = network_config["num_key_value_heads"] + self.k_eq_v = False + else: + self.head_dim_ = network_config["global_head_dim"] + total_kv_heads = num_global_kv + self.k_eq_v = network_config.get("attention_k_eq_v", True) + + # TP shard counts for this layer + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.tp_world_size_ + self.tp_k_head_num_ = max(total_kv_heads // self.tp_world_size_, 1) + self.tp_v_head_num_ = self.tp_k_head_num_ + self.tp_o_head_num_ = self.tp_q_head_num_ + + self.kv_cache_slot_dim_ = network_config["head_dim"] + sliding_total = network_config["num_key_value_heads"] * network_config["head_dim"] + full_total = num_global_kv * network_config["global_head_dim"] + per_token_k_width = max(sliding_total, full_total) + assert ( + per_token_k_width % self.kv_cache_slot_dim_ == 0 + ), f"per-token K width {per_token_k_width} not aligned to kv_cache_slot_dim {self.kv_cache_slot_dim_}" + self.kv_cache_slot_num_ = (per_token_k_width // self.kv_cache_slot_dim_) // self.tp_world_size_ + + # Sliding window (None on full-attn layers) + if self.is_sliding: + sw = network_config.get("sliding_window", 0) + self.sliding_window_ = int(sw) if sw else 0 + else: + self.sliding_window_ = 0 + + # E-series Per-Layer Embeddings gate (HF: config.hidden_size_per_layer_input, + # absent or 0 on 31B). + self.has_ple_ = bool(network_config.get("hidden_size_per_layer_input")) + if self.has_ple_: + self.ple_dim_ = network_config["hidden_size_per_layer_input"] + + # HF: config.num_kv_shared_layers (may be missing or null on non-E + # checkpoints — treat as 0). + kv_shared_count = network_config.get("num_kv_shared_layers") or 0 + total_layers = network_config["num_hidden_layers"] + self.is_kv_shared_ = kv_shared_count > 0 and layer_num >= total_layers - kv_shared_count + self.kv_share_target_layer_ = None + if self.is_kv_shared_: + cutoff = total_layers - kv_shared_count + for j in range(cutoff - 1, -1, -1): + if network_config["layer_types"][j] == layer_type: + self.kv_share_target_layer_ = j + break + assert self.kv_share_target_layer_ is not None, ( + f"layer {layer_num} ({layer_type}) is KV-shared but no earlier non-shared " + f"layer of the same type found below cutoff={cutoff}" + ) + + # Always 1.0: NoPE dims for full-attn layers are zero-padded into + # cos/sin (cos=1, sin=0 → identity), so the kernel walks the whole + # head_dim. Don't change to 0.25 — that double-counts with the table. + self.partial_rotary_factor_ = 1.0 + + self.ple_static_buffer = None + + def _rope_cos_sin(self, infer_state): + # Tables are built in the model dtype (Gemma4TpPartModel._init_to_get_rotary_gemma4), + # so they already match q/k dtype — no cast needed. + if self.is_sliding: + return infer_state.position_cos_sliding, infer_state.position_sin_sliding + return infer_state.position_cos_full, infer_state.position_sin_full + + def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + input = self._tpsp_allgather(input=input, infer_state=infer_state) + + head_dim = self.head_dim_ + q_heads = self.tp_q_head_num_ + kv_heads = self.tp_k_head_num_ + + q = layer_weight.q_proj.mm(input).view(-1, q_heads, head_dim) + q = layer_weight.q_norm_weight_(input=q, eps=self.eps_, alloc_func=self.alloc_tensor) + + cos, sin = self._rope_cos_sin(infer_state) + + if self.is_kv_shared_: + # K/V come from target layer's already-rotated, already-normed cache. 
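+            # Only Q needs RoPE here (k=None below); _post_cache_kv is also a
+            # no-op for this layer, so nothing is written back to the cache.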
+ rotary_emb_fwd(q, None, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) + q = q * math.sqrt(head_dim) + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + return q, None + + # ---- non-shared: full K/V path ---- + k = layer_weight.k_proj.mm(input).view(-1, kv_heads, head_dim) + if self.k_eq_v: + # Full-attn k_eq_v variant (e.g. 31B): K weights serve as V. + v = k + else: + v = layer_weight.v_proj.mm(input).view(-1, kv_heads, head_dim) + + k = layer_weight.k_norm_weight_(input=k, eps=self.eps_, alloc_func=self.alloc_tensor) + + # V-norm: unweighted RMSNorm over head_dim (matches vllm's Gemma4 has_weight=False). + v = rmsnorm_forward( + x=v, + weight=None, + eps=self.eps_, + out=self.alloc_tensor(v.shape, dtype=v.dtype, device=v.device), + ) + + rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=self.partial_rotary_factor_) + + # Gemma-4 uses scaling=1.0 in attention. The attention kernel hardcodes + # sm_scale = 1/sqrt(head_dim); pre-scale Q by sqrt(head_dim) so the + # kernel's division cancels out, yielding scores = Q @ K^T. + q = q * math.sqrt(head_dim) + + # Pack into the uniform KV-cache layout (N, 2*slot_num, slot_dim). + # K occupies slots [0, used_slots); V occupies + # [slot_num, slot_num + used_slots). If this layer's K/V width is + # smaller than the allocated cache slot width, pad with zeros. + cache_slot_num = self.kv_cache_slot_num_ + cache_slot_dim = self.kv_cache_slot_dim_ + N = k.shape[0] + k_packed = k.reshape(N, -1, cache_slot_dim) + v_packed = v.reshape(N, -1, cache_slot_dim) + used_cache_slots = k_packed.shape[1] + if used_cache_slots == cache_slot_num: + cache_kv = torch.cat([k_packed, v_packed], dim=1) + else: + cache_kv = self.alloc_tensor((N, 2 * cache_slot_num, cache_slot_dim), dtype=k.dtype) + cache_kv.zero_() + cache_kv[:, :used_cache_slots, :] = k_packed + cache_kv[:, cache_slot_num : cache_slot_num + used_cache_slots, :] = v_packed + + if infer_state.need_dp_prefill_balance: + q = infer_state._all_to_all_unbalance_get(data=q) + cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv) + + return q, cache_kv + + def _post_cache_kv(self, cache_kv, infer_state, layer_weight): + if self.is_kv_shared_ or cache_kv is None: + return + return super()._post_cache_kv(cache_kv, infer_state, layer_weight) + + # ----- Attention kernels (sliding window + per-layer KV reshape) --- + + def _att_control(self): + if self.is_sliding and self.sliding_window_ > 0: + w = self.sliding_window_ - 1 + return AttControl(use_sliding_window=True, sliding_window=(w, w)) + return AttControl(use_sliding_window=False, sliding_window=(-1, -1)) + + def _get_layer_kv(self, infer_state: InferStateInfo): + # KV-shared layers read from the target layer's cache slot. + layer_idx = self.kv_share_target_layer_ if self.is_kv_shared_ else self.layer_num_ + _k_raw, _v_raw = infer_state.mem_manager.get_att_input_params(layer_index=layer_idx) + # _k_raw / _v_raw shape (S, cache_slot_num, cache_slot_dim). Use .view + # (not .reshape) so any non-contiguous layout from a future mem_manager + # backend fails loudly instead of silently copying — slice + view is + # O(1) on the standard MemoryManager layout (inner (kv_heads, head_dim) + # span is contiguous). + kv_heads = self.tp_k_head_num_ + head_dim = self.head_dim_ + cache_slot_dim = self.kv_cache_slot_dim_ + used_cache_slots = kv_heads * head_dim // cache_slot_dim + if used_cache_slots == _k_raw.shape[1]: + # Layout already matches this layer's natural shape. 
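+            # (e.g. a 31B sliding layer at TP=1: 16 kv_heads * 256 dims fills all
+            # 16 cache slots of 256, so the raw view is returned unchanged.)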
+ return _k_raw.view(-1, kv_heads, head_dim), _v_raw.view(-1, kv_heads, head_dim) + # Otherwise the K/V live in the first used_cache_slots; the rest is zero pad. + _k = _k_raw[:, :used_cache_slots, :].view(-1, kv_heads, head_dim) + _v = _v_raw[:, :used_cache_slots, :].view(-1, kv_heads, head_dim) + return _k, _v + + def _context_attention_kernel( + self, + q: torch.Tensor, + kv, + infer_state: InferStateInfo, + layer_weight: Gemma4TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = self._get_layer_kv(infer_state) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + if self.is_sliding: + # Sliding layers always go through the gemma4_mm Triton kernel: it + # handles SWA + image bidirectional masking in one pass. + o_tensor = self.alloc_tensor(_q.shape, q.dtype) + sw = self.sliding_window_ - 1 if self.sliding_window_ > 0 else -1 + context_attention_fwd_gemma4_mm( + _q, + _k, + _v, + o_tensor, + infer_state.b_req_idx, + infer_state.b_q_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_q_seq_len, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_image_token_end, + sliding_window=sw, + ) + return o_tensor.view(q.shape) + + # Full-attn layers: head_dim=512, no SWA, no image bidi — standard + # triton via the primary backend. + o_tensor = infer_state.prefill_att_state.prefill_att( + q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor + ) + return o_tensor.view(q.shape) + + def _token_attention_kernel( + self, + q: torch.Tensor, + infer_state: InferStateInfo, + layer_weight: Gemma4TransformerLayerWeight, + out=None, + ) -> torch.Tensor: + _k, _v = self._get_layer_kv(infer_state) + _q = q.view(-1, self.tp_q_head_num_, self.head_dim_) + att_state = infer_state.decode_att_state1 if self.is_sliding else infer_state.decode_att_state + o_tensor = att_state.decode_att(q=_q, k=_k, v=_v, att_control=self._att_control(), alloc_func=self.alloc_tensor) + return o_tensor.view(q.shape) + + # ----- FFN (Gemma gelu-tanh, fused gate_up + down) ----------------- + + def _ffn_dense( + self, input, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + gate_up = layer_weight.gate_up_proj.mm(input) + ffn1 = self.alloc_tensor((input.size(0), gate_up.size(1) // 2), input.dtype) + silu_and_mul_fwd(gate_up, ffn1, use_gelu=True) + gate_up = None + ffn2 = layer_weight.down_proj.mm(ffn1) + ffn1 = None + ffn2 = self._tpsp_reduce(input=ffn2, infer_state=infer_state) + return ffn2 + + def _router_logits(self, residual, layer_weight: Gemma4TransformerLayerWeight) -> torch.Tensor: + # Mirrors vllm Gemma4Router: unweighted RMSNorm -> 1/sqrt(hidden) -> + # per-channel scale -> bf16xbf16 -> fp32 gate matmul for stable top-k. 
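+        # i.e. logits = ((rmsnorm(x) / sqrt(hidden)) * per_channel_scale) @ W_router^T,
+        # with the gate matmul output produced in fp32.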
+ x = residual.view(-1, self.embed_dim_) + x = rmsnorm_forward(x=x, weight=None, eps=self.eps_, out=self.alloc_tensor(x.shape, dtype=x.dtype)) + x = x * self.router_root_scale * layer_weight.router_input_scale_.weight + return layer_weight.moe_gate.mm(x, out_dtype=torch.float32) + + def _ffn_moe(self, input, router_logits, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input = input.view(-1, self.embed_dim_) + input = self._tpsp_allgather(input=input, infer_state=infer_state) + moe_out = layer_weight.experts.experts( + input, + router_logits=router_logits, + top_k=self.num_experts_per_tok, + renormalize=self.norm_topk_prob, + use_grouped_topk=False, + topk_group=None, + num_expert_group=None, + is_prefill=infer_state.is_prefill, + per_expert_scale=layer_weight.experts.per_expert_scale, + use_gelu=True, + ) + moe_out = self._tpsp_reduce(input=moe_out, infer_state=infer_state) + return moe_out + + def _ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + residual = input_embdings + dense_input = layer_weight.pre_feedforward_layernorm_weight_( + input=residual, eps=self.eps_, alloc_func=self.alloc_tensor + ) + dense_out = self._ffn_dense(dense_input, infer_state, layer_weight) + dense_input = None + + if self.is_moe: + dense_out = layer_weight.post_feedforward_layernorm_1_weight_( + input=dense_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) + + router_logits = self._router_logits(residual, layer_weight) + moe_input = layer_weight.pre_feedforward_layernorm_2_weight_( + input=residual, eps=self.eps_, alloc_func=self.alloc_tensor + ) + moe_out = self._ffn_moe(moe_input, router_logits, infer_state, layer_weight) + moe_input = None + router_logits = None + moe_out = layer_weight.post_feedforward_layernorm_2_weight_( + input=moe_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) + dense_out.add_(moe_out) + moe_out = None + + ffn_out = layer_weight.post_feedforward_layernorm_weight_( + input=dense_out, eps=self.eps_, alloc_func=self.alloc_tensor + ) + dense_out = None + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + # ----- block-level forwards (PLE fusion + layer_scalar at the end) ---- + + def _block_epilogue(self, hidden_states, infer_state, layer_weight): + if self.has_ple_: + flat = hidden_states.view(-1, self.embed_dim_) + N = flat.shape[0] + ple_slice = self.ple_static_buffer[:N, self.layer_num_, :] + gate = layer_weight.per_layer_input_gate_.mm(flat) + gated = nn.functional.gelu(gate, approximate="tanh") * ple_slice + contrib = layer_weight.per_layer_projection_.mm(gated) + contrib = layer_weight.post_per_layer_input_norm_weight_( + input=contrib, eps=self.eps_, alloc_func=self.alloc_tensor + ) + flat.add_(contrib) + hidden_states.mul_(layer_weight.layer_scalar_.weight) + return hidden_states + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) + o = self.context_attention_forward(input1, infer_state, layer_weight) + input1 = None + # Gemma sandwich norm: post_attention_layernorm on the attn branch + # before the residual add, not on the post-add residual stream. 
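+        # Dataflow: x = x + post_attn_norm(attn(input_norm(x))); the FFN half then
+        # applies its own pre/post norms inside self._ffn above.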
+ o = self._ffn_norm(o, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input_embdings = self._ffn(input_embdings, infer_state, layer_weight) + + return self._block_epilogue(input_embdings, infer_state, layer_weight) + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight: Gemma4TransformerLayerWeight): + input1 = self._att_norm(input_embdings.view(-1, self.embed_dim_), infer_state, layer_weight) + o = self.token_attention_forward(input1, infer_state, layer_weight) + input1 = None + o = self._ffn_norm(o, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input_embdings = self._ffn(input_embdings, infer_state, layer_weight) + + return self._block_epilogue(input_embdings, infer_state, layer_weight) diff --git a/lightllm/models/gemma4/layer_weights/__init__.py b/lightllm/models/gemma4/layer_weights/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 0000000000..22a2fc4dc7 --- /dev/null +++ b/lightllm/models/gemma4/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,64 @@ +from lightllm.common.basemodel import PreAndPostLayerWeight +from lightllm.common.basemodel.layer_weights.meta_weights import ( + EmbeddingWeight, + LMHeadWeight, + ROWMMWeight, + RMSNormWeight, +) + + +class Gemma4PreAndPostLayerWeight(PreAndPostLayerWeight): + def __init__(self, data_type, network_config): + super().__init__(data_type, network_config) + hidden_size = network_config["hidden_size"] + vocab_size = network_config["vocab_size"] + + self.wte_weight_ = EmbeddingWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="model.language_model.embed_tokens.weight", + data_type=self.data_type_, + ) + # lm_head is tied to input embedding for Gemma-4 (no separate lm_head.weight). + self.lm_head_weight_ = LMHeadWeight( + dim=hidden_size, + vocab_size=vocab_size, + weight_name="lm_head.weight", + data_type=self.data_type_, + embedding_weight=self.wte_weight_, + ) + + # Gemma-4 uses standard RMSNorm (not the gemma2/3 (1+w) variant). + self.final_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name="model.language_model.norm.weight", + data_type=self.data_type_, + ) + + if network_config.get("hidden_size_per_layer_input"): + num_layers = network_config["num_hidden_layers"] + ple_dim = network_config["hidden_size_per_layer_input"] + ple_vocab = network_config.get("vocab_size_per_layer_input", vocab_size) + self.embed_tokens_per_layer_weight_ = EmbeddingWeight( + dim=num_layers * ple_dim, + vocab_size=ple_vocab, + weight_name="model.language_model.embed_tokens_per_layer.weight", + data_type=self.data_type_, + ) + # nn.Linear(in=hidden_size, out=num_layers*ple_dim); HF storage is + # (out, in). Replicated across TP ranks. + self.per_layer_model_projection_weight_ = ROWMMWeight( + in_dim=hidden_size, + out_dims=[num_layers * ple_dim], + weight_names="model.language_model.per_layer_model_projection.weight", + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + # RMSNorm over the ple_dim of the projection output. 
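+            # (shape flow in Gemma4PreLayerInfer: (N, num_layers * ple_dim) is
+            # reshaped to (N, num_layers, ple_dim) and normalized over the last dim.)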
+ self.per_layer_projection_norm_weight_ = RMSNormWeight( + dim=ple_dim, + weight_name="model.language_model.per_layer_projection_norm.weight", + data_type=self.data_type_, + ) + return diff --git a/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py new file mode 100644 index 0000000000..6d9a5c2613 --- /dev/null +++ b/lightllm/models/gemma4/layer_weights/transformer_layer_weight.py @@ -0,0 +1,274 @@ +from lightllm.common.basemodel.layer_weights.meta_weights.mm_weight import ROWMMWeight, COLMMWeight +from lightllm.common.basemodel.layer_weights.meta_weights import RMSNormWeight, ParameterWeight +from lightllm.common.basemodel.layer_weights.meta_weights.fused_moe.gemma4_packed_fused_moe_weight import ( + Gemma4PackedFusedMoeWeight, +) +from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight +from lightllm.utils.envs_utils import get_env_start_args + + +class Gemma4TransformerLayerWeight(LlamaTransformerLayerWeight): + def __init__( + self, + layer_num, + data_type, + network_config, + quant_cfg=None, + ): + self._pre_parse_layer_shape(layer_num, network_config) + super().__init__(layer_num, data_type, network_config, quant_cfg) + return + + def _pre_parse_layer_shape(self, layer_num, network_config): + self._is_moe = bool(network_config.get("enable_moe_block", False)) + layer_type = network_config["layer_types"][layer_num] + self._is_sliding = layer_type == "sliding_attention" + # Some E-series checkpoints leave num_global_key_value_heads = null; + # HF treats that as "fall back to num_key_value_heads". + num_global_kv = network_config.get("num_global_key_value_heads") or network_config["num_key_value_heads"] + if self._is_sliding: + self._layer_head_dim = network_config["head_dim"] + self._layer_kv_head_num = network_config["num_key_value_heads"] + self._layer_k_eq_v = False + else: + self._layer_head_dim = network_config["global_head_dim"] + self._layer_kv_head_num = num_global_kv + self._layer_k_eq_v = network_config.get("attention_k_eq_v", True) + + def _parse_config(self): + self.n_head = self.network_config_["num_attention_heads"] + self.q_head_num_ = self.network_config_["num_attention_heads"] + self.k_head_num_ = self._layer_kv_head_num + self.v_head_num_ = self._layer_kv_head_num + self.o_head_num_ = self.q_head_num_ + self.head_dim = self._layer_head_dim + self.n_embed = self.network_config_["hidden_size"] + self.n_inter = self.network_config_["intermediate_size"] + + def _init_weight_names(self): + prefix = f"model.language_model.layers.{self.layer_num_}" + self._q_weight_name = f"{prefix}.self_attn.q_proj.weight" + self._q_bias_name = None + self._k_weight_name = f"{prefix}.self_attn.k_proj.weight" + self._k_bias_name = None + self._v_weight_name = f"{prefix}.self_attn.v_proj.weight" + self._v_bias_name = None + self._o_weight_name = f"{prefix}.self_attn.o_proj.weight" + self._o_bias_name = None + + self._q_norm_weight_name = f"{prefix}.self_attn.q_norm.weight" + self._k_norm_weight_name = f"{prefix}.self_attn.k_norm.weight" + + self._gate_weight_name = f"{prefix}.mlp.gate_proj.weight" + self._up_weight_name = f"{prefix}.mlp.up_proj.weight" + self._down_weight_name = f"{prefix}.mlp.down_proj.weight" + + self._att_norm_weight_name = f"{prefix}.input_layernorm.weight" + self._ffn_norm_weight_name = f"{prefix}.post_attention_layernorm.weight" + self._pre_feedforward_layernorm_name = f"{prefix}.pre_feedforward_layernorm.weight" + 
self._post_feedforward_layernorm_name = f"{prefix}.post_feedforward_layernorm.weight" + self._post_feedforward_layernorm_1_name = f"{prefix}.post_feedforward_layernorm_1.weight" + self._pre_feedforward_layernorm_2_name = f"{prefix}.pre_feedforward_layernorm_2.weight" + self._post_feedforward_layernorm_2_name = f"{prefix}.post_feedforward_layernorm_2.weight" + + self._router_input_scale_name = f"{prefix}.router.scale" + self._router_weight_name = f"{prefix}.router.proj.weight" + + self._layer_scalar_name = f"{prefix}.layer_scalar" + + # E-series Per-Layer Embeddings names (only loaded when PLE enabled). + self._per_layer_input_gate_name = f"{prefix}.per_layer_input_gate.weight" + self._per_layer_projection_name = f"{prefix}.per_layer_projection.weight" + self._post_per_layer_input_norm_name = f"{prefix}.post_per_layer_input_norm.weight" + + def _init_weight(self): + self._init_qkv() + self._init_o() + self._init_ffn() + if self._is_moe: + self._init_moe() + self._init_norm() + if self.network_config_.get("hidden_size_per_layer_input"): + self._init_ple() + + def _init_ple(self): + ple_dim = self.network_config_["hidden_size_per_layer_input"] + hidden_size = self.network_config_["hidden_size"] + self.per_layer_input_gate_ = ROWMMWeight( + in_dim=hidden_size, + out_dims=[ple_dim], + weight_names=self._per_layer_input_gate_name, + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.per_layer_projection_ = ROWMMWeight( + in_dim=ple_dim, + out_dims=[hidden_size], + weight_names=self._per_layer_projection_name, + data_type=self.data_type_, + tp_rank=0, + tp_world_size=1, + ) + self.post_per_layer_input_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_per_layer_input_norm_name, + data_type=self.data_type_, + ) + + def _init_qkv(self): + in_dim = self.n_embed + q_out_dim = self.q_head_num_ * self.head_dim + kv_out_dim = self.k_head_num_ * self.head_dim + + self.q_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[q_out_dim], + weight_names=self._q_weight_name, + data_type=self.data_type_, + bias_names=self._q_bias_name, + quant_method=self.get_quant_method("q_proj"), + ) + self.k_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim], + weight_names=self._k_weight_name, + data_type=self.data_type_, + bias_names=self._k_bias_name, + quant_method=self.get_quant_method("k_proj"), + ) + if not self._layer_k_eq_v: + self.v_proj = ROWMMWeight( + in_dim=in_dim, + out_dims=[kv_out_dim], + weight_names=self._v_weight_name, + data_type=self.data_type_, + bias_names=self._v_bias_name, + quant_method=self.get_quant_method("v_proj"), + ) + # For k_eq_v layers HF checkpoint has no v_proj weight; the inference + # code aliases v = k at compute time, so no weight object is created. + + def _init_o(self): + in_dim = self.o_head_num_ * self.head_dim + out_dim = self.n_embed + self.o_proj = COLMMWeight( + in_dim=in_dim, + out_dims=[out_dim], + weight_names=self._o_weight_name, + data_type=self.data_type_, + bias_names=self._o_bias_name, + quant_method=self.get_quant_method("o_proj"), + ) + + def _init_ffn(self): + # Packed gate+up: ROWMMWeight stitches `gate_proj` and `up_proj` weights + # along the output dim so the dense FFN runs one matmul + a fused + # gelu*mul kernel (mirrors llama's gate_up_proj path). 
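+        # Packed column layout: [0, n_inter) = gate_proj, [n_inter, 2*n_inter) = up_proj;
+        # the fused kernel in _ffn_dense computes gelu(gate) * up into a half-width output.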
+ self.gate_up_proj = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.n_inter, self.n_inter], + weight_names=[self._gate_weight_name, self._up_weight_name], + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("gate_up_proj"), + ) + self.down_proj = COLMMWeight( + in_dim=self.n_inter, + out_dims=[self.n_embed], + weight_names=self._down_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("down_proj"), + ) + + def _init_moe(self): + enable_ep_moe = get_env_start_args().enable_ep_moe + assert not enable_ep_moe, "Gemma-4 MoE packed expert weights currently support TP mode only." + + self.router_input_scale_ = ParameterWeight( + weight_name=self._router_input_scale_name, + data_type=self.data_type_, + weight_shape=(self.n_embed,), + ) + self.moe_gate = ROWMMWeight( + in_dim=self.n_embed, + out_dims=[self.network_config_["num_experts"]], + weight_names=self._router_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_method=self.get_quant_method("moe_gate"), + tp_rank=0, + tp_world_size=1, + ) + self.experts = Gemma4PackedFusedMoeWeight( + gate_proj_name="gate_proj", + down_proj_name="down_proj", + up_proj_name="up_proj", + e_score_correction_bias_name="", + weight_prefix=f"model.language_model.layers.{self.layer_num_}.experts", + n_routed_experts=self.network_config_["num_experts"], + hidden_size=self.network_config_["hidden_size"], + moe_intermediate_size=self.network_config_["moe_intermediate_size"], + data_type=self.data_type_, + quant_method=self.quant_cfg.get_quant_method(self.layer_num_, "fused_moe"), + layer_num=self.layer_num_, + network_config=self.network_config_, + per_expert_scale_name=f"model.language_model.layers.{self.layer_num_}.router.per_expert_scale", + ) + + def _init_norm(self): + hidden_size = self.network_config_["hidden_size"] + # Gemma-4 uses standard RMSNorm (x * rsqrt(var+eps) * w), NOT the + # gemma2/3 (1+w) variant - do not swap in NoTpGEMMANormWeight. 
+ self.q_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._q_norm_weight_name, + data_type=self.data_type_, + ) + self.k_norm_weight_ = RMSNormWeight( + dim=self._layer_head_dim, + weight_name=self._k_norm_weight_name, + data_type=self.data_type_, + ) + self.att_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._att_norm_weight_name, + data_type=self.data_type_, + ) + self.ffn_norm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._ffn_norm_weight_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_name, + data_type=self.data_type_, + ) + if self._is_moe: + self.post_feedforward_layernorm_1_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_1_name, + data_type=self.data_type_, + ) + self.pre_feedforward_layernorm_2_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._pre_feedforward_layernorm_2_name, + data_type=self.data_type_, + ) + self.post_feedforward_layernorm_2_weight_ = RMSNormWeight( + dim=hidden_size, + weight_name=self._post_feedforward_layernorm_2_name, + data_type=self.data_type_, + ) + self.layer_scalar_ = ParameterWeight( + weight_name=self._layer_scalar_name, + data_type=self.data_type_, + weight_shape=(1,), + ) diff --git a/lightllm/models/gemma4/model.py b/lightllm/models/gemma4/model.py new file mode 100644 index 0000000000..7d70c31bed --- /dev/null +++ b/lightllm/models/gemma4/model.py @@ -0,0 +1,336 @@ +import math +import os +import json +import torch +from lightllm.models.registry import ModelRegistry +from lightllm.common.basemodel.multimodal_tokenizer import BaseMultiModalTokenizer +from lightllm.common.basemodel.attention.triton.fp import TritonAttBackend +from lightllm.common.kv_cache_mem_manager.mem_utils import select_mem_manager_class +from lightllm.common.build_utils import repair_config +from lightllm.models.llama.model import LlamaTpPartModel +from lightllm.models.gemma4.infer_struct import Gemma4InferStateInfo +from lightllm.models.gemma4.layer_infer.pre_layer_infer import Gemma4PreLayerInfer +from lightllm.models.gemma4.layer_infer.post_layer_infer import Gemma4PostLayerInfer +from lightllm.models.gemma4.layer_infer.transformer_layer_infer import Gemma4TransformerLayerInfer +from lightllm.models.gemma4.layer_weights.pre_and_post_layer_weight import Gemma4PreAndPostLayerWeight +from lightllm.models.gemma4.layer_weights.transformer_layer_weight import Gemma4TransformerLayerWeight +from lightllm.utils.envs_utils import get_added_mtp_kv_layer_num, get_env_start_args +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager + +logger = init_logger(__name__) + + +class Gemma4Tokenizer(BaseMultiModalTokenizer): + def __init__(self, tokenizer, model_cfg, image_processor=None): + super().__init__(tokenizer) + self.image_token_index = model_cfg.get("image_token_id", 258880) + self.boi_token_index = model_cfg.get("boi_token_id", 255999) + self.eoi_token_index = model_cfg.get("eoi_token_id", 258882) + self.image_processor = image_processor + self.image_length = model_cfg.get("vision_soft_tokens_per_image", 280) + self.patch_size = getattr(self.image_processor, "patch_size", 16) + self.pooling_kernel_size = 
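+        # For reference, the gemma2/3 variant computes x * rsqrt(var + eps) * (1 + w),
+        # so routing these weights through that path would shift every norm by the (1 + w) term.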
getattr(self.image_processor, "pooling_kernel_size", 3) + self.max_soft_tokens = getattr(self.image_processor, "max_soft_tokens", self.image_length) + # HF Gemma-4 tokenizer does not prepend BOS even with add_special_tokens=True. + self.bos_token_id = tokenizer.bos_token_id + + def init_imageitem_extral_params(self, img, multi_params, sampling_params): + return + + def init_audioitem_extral_params(self, audio, multi_params, sampling_params): + raise NotImplementedError + + def get_image_token_length(self, img): + if self.image_processor is None or img.image_w <= 0 or img.image_h <= 0: + return self.image_length + + patch, kernel = self.patch_size, self.pooling_kernel_size + unit = patch * kernel + num_patches_orig = (img.image_h / patch) * (img.image_w / patch) + scale = math.sqrt(self.max_soft_tokens * kernel ** 2 / num_patches_orig) + target_h = max(unit, int(math.floor(img.image_h * scale / unit)) * unit) + target_w = max(unit, int(math.floor(img.image_w * scale / unit)) * unit) + num_patches = (target_h // patch) * (target_w // patch) + return min(num_patches // kernel ** 2, self.max_soft_tokens) + + def get_audio_token_length(self, audio): + raise NotImplementedError + + def encode(self, prompt, multimodal_params=None, add_special_tokens=False): + origin_ids = self.tokenizer(prompt, add_special_tokens=False).input_ids + if ( + add_special_tokens + and self.bos_token_id is not None + and (len(origin_ids) == 0 or origin_ids[0] != self.bos_token_id) + ): + origin_ids = [self.bos_token_id] + origin_ids + + images = [] if multimodal_params is None else getattr(multimodal_params, "images", []) + if not images: + return origin_ids + + input_ids = [] + image_id = 0 + start = 0 + while True: + try: + image_start = origin_ids.index(self.image_token_index, start) + except ValueError: + break + + input_ids.extend(origin_ids[start:image_start]) + image_end = image_start + 1 + while image_end < len(origin_ids) and origin_ids[image_end] == self.image_token_index: + image_end += 1 + if image_id >= len(images): + raise ValueError("image token error") + + img = images[image_id] + if not input_ids or input_ids[-1] != self.boi_token_index: + input_ids.append(self.boi_token_index) + img.start_idx = len(input_ids) + input_ids.extend(range(img.token_id, img.token_id + img.token_num)) + input_ids.append(self.eoi_token_index) + + if image_end < len(origin_ids) and origin_ids[image_end] == self.eoi_token_index: + image_end += 1 + start = image_end + image_id += 1 + + input_ids.extend(origin_ids[start:]) + image_cnt = len(images) + if image_cnt != image_id: + raise ValueError(f"invalid image tag num: {image_cnt} vs {image_id}!") + return input_ids + + +@ModelRegistry("gemma4", is_multimodal=True) +class Gemma4TpPartModel(LlamaTpPartModel): + pre_and_post_weight_class = Gemma4PreAndPostLayerWeight + transformer_weight_class = Gemma4TransformerLayerWeight + + pre_layer_infer_class = Gemma4PreLayerInfer + transformer_layer_infer_class = Gemma4TransformerLayerInfer + post_layer_infer_class = Gemma4PostLayerInfer + + infer_state_class = Gemma4InferStateInfo + + def __init__(self, kvargs): + # head_dim_ is used by the default _init_to_get_rotary which we + # override; still set it to the sliding-layer head_dim for consistency + # with the mem manager and any generic helpers. 
+ self.head_dim_ = 256 + super().__init__(kvargs) + return + + def _init_config(self): + with open(os.path.join(self.weight_dir_, "config.json"), "r") as json_file: + self.config = json.load(json_file) + # The shipped checkpoint is a multimodal config wrapping a Gemma4TextConfig + # under text_config; flatten it so downstream code sees text-model fields + # at the top level (mirrors the gemma3 approach). + if "text_config" in self.config: + self.config = self.config["text_config"].copy() + + repair_config(self.config, same_names=["num_attention_heads", "n_head"]) + repair_config(self.config, same_names=["hidden_size", "n_embd", "n_embed"]) + repair_config(self.config, same_names=["num_hidden_layers", "n_layer"]) + self._reset_num_key_value_heads() + + if self.finetune_config: + self.config["vocab_size"] = self.finetune_config.vocab_size + + if self.config.get("enable_moe_block", False): + # LightLLM's MoE helpers use Qwen/DeepSeek-style field names. + # Gemma-4 checkpoints expose equivalent values as top_k_experts + # and moe_intermediate_size. + self.config.setdefault("num_experts_per_tok", self.config["top_k_experts"]) + self.config.setdefault("norm_topk_prob", True) + self.config.setdefault("scoring_func", "softmax") + return + + def _verify_params(self): + assert self.load_way == "HF", "Gemma-4 only supports HF format." + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + assert self.config["num_key_value_heads"] % self.tp_world_size_ == 0 + # Use `or` rather than the dict.get default: E4B-style configs ship + # `num_global_key_value_heads: null`, which the default form would + # leave as None. + num_global_kv = self.config.get("num_global_key_value_heads") or self.config["num_key_value_heads"] + assert ( + num_global_kv % self.tp_world_size_ == 0 + ), f"num_global_key_value_heads={num_global_kv} must be divisible by tp={self.tp_world_size_}" + kv_shared = self.config.get("num_kv_shared_layers") or 0 + assert 0 <= kv_shared < self.config["num_hidden_layers"], ( + f"num_kv_shared_layers={kv_shared} out of range for " + f"num_hidden_layers={self.config['num_hidden_layers']}" + ) + return + + def _init_mem_manager(self): + # Uniform per-layer KV cache layout. The per-layer cache slot must fit + # whichever layer type has the largest per-token K/V width: sliding + # (num_key_value_heads * head_dim) or full + # (num_global_kv * global_head_dim). Keep cache_slot_dim = head_dim + # and pick cache_slot_num = max-width / head_dim. For 31B this + # collapses to num_key_value_heads; for E4B the full-attn shape wins + # (2*512 > 2*256), so it uses 4 storage slots of 256 dims. + # Gemma4TransformerLayerInfer.__init__ computes the same value and + # uses it to pack/unpack K/V at write/read time. 
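+        # Worked sizes from the two configs discussed above: 31B max(16 * 256, 4 * 512) = 4096
+        # -> 16 slots of 256; E4B max(2 * 256, 2 * 512) = 1024 -> 4 slots of 256
+        # (both before the TP split).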
+ head_dim = self.config["head_dim"] + num_global_kv = self.config.get("num_global_key_value_heads") or self.config["num_key_value_heads"] + sliding_total = self.config["num_key_value_heads"] * self.config["head_dim"] + full_total = num_global_kv * self.config["global_head_dim"] + per_token_k_width = max(sliding_total, full_total) + head_num_per_rank = (per_token_k_width // head_dim) // self.tp_world_size_ + self.mem_manager = select_mem_manager_class()( + self.max_total_token_num, + dtype=self.data_type, + head_num=head_num_per_rank, + head_dim=head_dim, + layer_num=self.config["num_hidden_layers"] + get_added_mtp_kv_layer_num(), + mem_fraction=self.mem_fraction, + ) + return + + def _init_att_backend(self): + # Gemma-4 has per-layer heterogeneous attention: sliding layers use + # (head_dim=256, kv_heads=16); full-attn layers use (head_dim=512, + # kv_heads=4, k_eq_v). No single generic backend setup covers both: + # - FA3 caps head_dim at 256 -> can't run full-attn layers. + # - Flashinfer plans once per infer_state on a single shape -> can't + # accommodate heterogeneous layout at all. + # Strategy: + # - Prefill: sliding layers go through the gemma4_mm Triton kernel + # directly (handles SWA + image bidi); full-attn layers use the + # primary triton backend below. No FA3 in prefill — its + # image_token_end build asserts incompatible with SWA. Revisit + # when fa3 supports both simultaneously. + # - Decode: full-attn layers on triton (primary); sliding layers on + # fa3 (with SWA) when available — secondary backend set in + # _init_att_backend1. + fa3_loadable = self._gemma4_fa3_loadable() + + # Full-attn layers always go through triton. + self.prefill_att_backend = TritonAttBackend(model=self) + self.decode_att_backend = TritonAttBackend(model=self) + + self._gemma4_sliding_decode_backend_kind = self._resolve_gemma4_sliding_backend( + self.args.llm_decode_att_backend[0], fa3_loadable + ) + + def _init_att_backend1(self): + # Only decode needs the sliding-layer backend; prefill sliding goes + # through gemma4_mm Triton directly in the layer. + self.prefill_att_backend1 = None + self.decode_att_backend1 = self._build_gemma4_sliding_backend(self._gemma4_sliding_decode_backend_kind) + + @staticmethod + def _gemma4_fa3_loadable(): + from lightllm.utils.sgl_utils import flash_attn_with_kvcache + + return flash_attn_with_kvcache is not None + + @staticmethod + def _resolve_gemma4_sliding_backend(backend_name, fa3_loadable): + assert backend_name in ("auto", "triton", "fa3"), ( + "Gemma-4 requires triton or fa3 for sliding layers; flashinfer is " + f"not wired for the heterogeneous layout. Got backend={backend_name!r}." + ) + if backend_name == "auto": + return "fa3" if fa3_loadable else "triton" + if backend_name == "fa3": + assert fa3_loadable, ( + "Requested --llm_*_att_backend=fa3 but flash_attn_with_kvcache " + "did not import (sgl_kernel missing or wrong arch)." 
+ ) + return backend_name + + def _build_gemma4_sliding_backend(self, backend_kind): + if backend_kind == "fa3": + from lightllm.common.basemodel.attention.fa3.fp import Fa3AttBackend + + return Fa3AttBackend(model=self) + assert backend_kind == "triton" + return TritonAttBackend(model=self) + + def _init_custom(self): + self._init_to_get_rotary_gemma4() + if self.config.get("enable_moe_block", False): + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + self._init_ple_static_buffer() + + def _init_ple_static_buffer(self): + ple_dim = self.config.get("hidden_size_per_layer_input") or 0 + if ple_dim <= 0: + return + args = get_env_start_args() + max_tokens = max( + int(self.batch_max_tokens or 0), + int(self.graph_max_batch_size or 0), + int(getattr(args, "prefill_cudagraph_max_handle_token", 0) or 0), + ) + assert max_tokens > 0, "PLE static buffer needs a positive max-token bound" + num_layers = self.config["num_hidden_layers"] + buf = torch.zeros((max_tokens, num_layers, ple_dim), dtype=self.data_type, device="cuda") + self.pre_infer.ple_static_buffer = buf + for layer_infer in self.layers_infer: + layer_infer.ple_static_buffer = buf + logger.info( + f"Allocated PLE static buffer: tokens={max_tokens}, layers={num_layers}, " + f"ple_dim={ple_dim}, dtype={self.data_type}" + ) + + def _init_to_get_rotary_gemma4(self): + rope_params = self.config["rope_parameters"] + + # Cap the rotary table at something we can fit in memory — Gemma-4's + # advertised max_position_embeddings is 262144 which would require + # ~200MB per table in fp32. Rely on the server's max_seq_length instead. + max_seq_len = max(self.max_seq_length + 1024, 16384) + + t = torch.arange(max_seq_len, dtype=torch.float32, device="cpu") + + # Sliding layers: default RoPE, theta=10000, full rotation over head_dim=256. + sliding_params = rope_params["sliding_attention"] + sliding_head_dim = self.config["head_dim"] + sliding_theta = sliding_params["rope_theta"] + sliding_partial = sliding_params.get("partial_rotary_factor", 1.0) + sliding_rot_dim = int(sliding_head_dim * sliding_partial) + inv_freq_sliding = 1.0 / ( + sliding_theta ** (torch.arange(0, sliding_rot_dim, 2, dtype=torch.float32) / sliding_rot_dim) + ) + freqs_s = torch.outer(t, inv_freq_sliding) + self._cos_cached_sliding = torch.cos(freqs_s).to(self.data_type).cuda() + self._sin_cached_sliding = torch.sin(freqs_s).to(self.data_type).cuda() + + # Full-attention layers: proportional RoPE, theta=1_000_000, + # partial_rotary_factor=0.25 over global_head_dim=512. 
+ # Proportional semantics (HF transformers): + # rope_angles = int(partial * head_dim // 2) -> 64 + # inv_freq[0:rope_angles] = 1 / base ** (arange(0, 2*rope_angles, 2) / head_dim) + # inv_freq[rope_angles:head_dim//2] = 0 (identity rotation for "no-pe" dims) + full_params = rope_params["full_attention"] + full_head_dim = self.config["global_head_dim"] + full_theta = full_params["rope_theta"] + full_partial = full_params.get("partial_rotary_factor", 1.0) + rope_type = full_params.get("rope_type", "default") + if rope_type == "proportional": + rope_angles = int(full_partial * full_head_dim // 2) + inv_freq_rot = 1.0 / ( + full_theta ** (torch.arange(0, 2 * rope_angles, 2, dtype=torch.float32) / full_head_dim) + ) + nope_angles = full_head_dim // 2 - rope_angles + if nope_angles > 0: + inv_freq_full = torch.cat([inv_freq_rot, torch.zeros(nope_angles, dtype=torch.float32)]) + else: + inv_freq_full = inv_freq_rot + else: + full_rot_dim = int(full_head_dim * full_partial) + inv_freq_full = 1.0 / (full_theta ** (torch.arange(0, full_rot_dim, 2, dtype=torch.float32) / full_rot_dim)) + + freqs_f = torch.outer(t, inv_freq_full) + self._cos_cached_full = torch.cos(freqs_f).to(self.data_type).cuda() + self._sin_cached_full = torch.sin(freqs_f).to(self.data_type).cuda() + return diff --git a/lightllm/models/gemma4/triton_kernel/__init__.py b/lightllm/models/gemma4/triton_kernel/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py b/lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py new file mode 100644 index 0000000000..bb5f383611 --- /dev/null +++ b/lightllm/models/gemma4/triton_kernel/build_b_image_token_end.py @@ -0,0 +1,172 @@ +"""GPU-resident builder for ``b_image_token_end``. + +Replaces a 3× D2H sync + Python per-batch-image slice-fill in CPU memory +with a single small H2D copy (image metadata) + one Triton kernel that +scatters the image-end markers into the flat-Q-token tensor on GPU. + +Adapted from neo_chat_moe's `get_neo_position_triton`. Same per-batch +program structure; we only emit the `b_image_token_end` scatter (no 3D +position_ids — gemma-4 uses 1D position ids). +""" + +import torch +import triton +import triton.language as tl + + +@triton.jit +def _build_b_image_token_end_kernel( + B_Image_Start_Idx, # (num_imgs,) int32, image span start in absolute request position + B_Image_Len, # (num_imgs,) int32, image token count + B_Image_Nums, # (batch,) int32, per-batch image count + B_Image_Start_Num, # (batch,) int32, prefix-sum offset into flat per-image arrays + B_Q_Start_Loc, # (batch,) int32, per-batch start in flat layout + B_Ready_Cache_Len, # (batch,) int32, per-batch prompt-cache length + B_Q_Seq_Len, # (batch,) int32, per-batch new-token count + B_Image_Token_End, # (sum_q,) int32, output scatter target + BLOCK_SIZE: tl.constexpr, +): + cur_batch = tl.program_id(0) + cache_len = tl.load(B_Ready_Cache_Len + cur_batch) + q_seq_len = tl.load(B_Q_Seq_Len + cur_batch) + image_num = tl.load(B_Image_Nums + cur_batch) + image_start_num = tl.load(B_Image_Start_Num + cur_batch) + flat_start = tl.load(B_Q_Start_Loc + cur_batch) + + for i in range(image_num): + image_start_idx = tl.load(B_Image_Start_Idx + image_start_num + i) + image_len = tl.load(B_Image_Len + image_start_num + i) + image_end_idx = image_start_idx + image_len + # Flat layout offset of the image's first token within this batch. 
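+        # e.g. with cache_len=2, flat_start=0 and an image spanning req[5..9):
+        # flat_image_start = 0 + 5 - 2 = 3 (matches batch 0 of _check below).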
+ flat_image_start = flat_start + image_start_idx - cache_len + + for j in range(0, image_len, BLOCK_SIZE): + off = j + tl.arange(0, BLOCK_SIZE) + in_image = off < image_len + # Only fill positions that fall inside this batch's NEW-tokens range + # (i.e., the part of the image that hasn't already been processed + # in a previous chunked-prefill chunk and isn't past the chunk's end). + in_new_tokens = (image_start_idx - cache_len + off >= 0) & (image_start_idx - cache_len + off < q_seq_len) + tl.store( + B_Image_Token_End + flat_image_start + off, + image_end_idx, + mask=in_image & in_new_tokens, + ) + + +def build_b_image_token_end( + b_image_start_idx: torch.Tensor, + b_image_len: torch.Tensor, + b_image_nums: torch.Tensor, + b_image_start_num: torch.Tensor, + b_q_start_loc: torch.Tensor, + b_ready_cache_len: torch.Tensor, + b_q_seq_len: torch.Tensor, + b_image_token_end: torch.Tensor, +): + batch_size = b_q_start_loc.shape[0] + assert b_image_nums.shape[0] == batch_size + grid = (batch_size,) + BLOCK_SIZE = 64 + _build_b_image_token_end_kernel[grid]( + b_image_start_idx, + b_image_len, + b_image_nums, + b_image_start_num, + b_q_start_loc, + b_ready_cache_len, + b_q_seq_len, + b_image_token_end, + BLOCK_SIZE=BLOCK_SIZE, + ) + + +# --------------------------------------------------------------------------- +# Standalone correctness check +# --------------------------------------------------------------------------- + + +def _reference( + multimodal_params, + b_q_start_loc_cpu, + b_ready_cache_len_cpu, + b_q_seq_len_cpu, + sum_q, +): + out = torch.zeros((sum_q,), dtype=torch.int32) + for batch_idx, params in enumerate(multimodal_params): + cache_len = b_ready_cache_len_cpu[batch_idx] + new_len = b_q_seq_len_cpu[batch_idx] + flat_start = b_q_start_loc_cpu[batch_idx] + for img in params.get("images", []): + image_start_idx = img["start_idx"] + image_end_idx = image_start_idx + img["token_num"] + for j in range(img["token_num"]): + req_off = image_start_idx - cache_len + j + if req_off < 0 or req_off >= new_len: + continue + out[flat_start + req_off] = image_end_idx + return out + + +def _check(): + device = "cuda" + # Two batches. b0 has 1 image overlapping new tokens; b1 has 2 images, one + # fully cached and one in the new-token range. 
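+    # Expected scatter (agrees with _reference): flat positions 3..5 get 9 (batch 0
+    # image end), positions 9..13 get 13 (batch 1 second image end), everything else 0.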
+ multimodal = [ + {"images": [{"start_idx": 5, "token_num": 4}]}, # b0: image at req[5..9) + { + "images": [ + {"start_idx": 0, "token_num": 3}, # fully cached + {"start_idx": 8, "token_num": 5}, # in new tokens + ] + }, + ] + b_q_start_loc = torch.tensor([0, 6], dtype=torch.int32) # b0 new=6, b1 new=10 + b_ready_cache_len = torch.tensor([2, 5], dtype=torch.int32) + b_q_seq_len = torch.tensor([6, 10], dtype=torch.int32) + sum_q = int(b_q_seq_len.sum().item()) + + ref = _reference( + multimodal, + b_q_start_loc.tolist(), + b_ready_cache_len.tolist(), + b_q_seq_len.tolist(), + sum_q, + ) + + b_image_start_idx = [] + b_image_len = [] + b_image_nums = [] + b_image_start_num = [] + image_start_num = 0 + for params in multimodal: + b_image_start_num.append(image_start_num) + b_image_nums.append(len(params["images"])) + for img in params["images"]: + b_image_start_idx.append(img["start_idx"]) + b_image_len.append(img["token_num"]) + image_start_num += 1 + + out_gpu = torch.zeros((sum_q,), dtype=torch.int32, device=device) + build_b_image_token_end( + b_image_start_idx=torch.tensor(b_image_start_idx, dtype=torch.int32, device=device), + b_image_len=torch.tensor(b_image_len, dtype=torch.int32, device=device), + b_image_nums=torch.tensor(b_image_nums, dtype=torch.int32, device=device), + b_image_start_num=torch.tensor(b_image_start_num, dtype=torch.int32, device=device), + b_q_start_loc=b_q_start_loc.to(device), + b_ready_cache_len=b_ready_cache_len.to(device), + b_q_seq_len=b_q_seq_len.to(device), + b_image_token_end=out_gpu, + ) + + out_cpu = out_gpu.cpu() + assert torch.equal(out_cpu, ref), f"\n got {out_cpu.tolist()}\n ref {ref.tolist()}" + print("ok", out_cpu.tolist()) + + +if __name__ == "__main__": + if torch.cuda.is_available(): + _check() + else: + print("No CUDA, skip.") diff --git a/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py new file mode 100644 index 0000000000..b0ab70d7c7 --- /dev/null +++ b/lightllm/models/gemma4/triton_kernel/context_attention_fwd_gemma4_mm.py @@ -0,0 +1,470 @@ +"""Gemma-4 prefill attention kernel with image bidirectional masking. + +Gemma-4 was trained with bidirectional attention inside each image span on its +sliding-window layers (matches HF/vllm `use_bidirectional_attention="vision"`). +Other lightllm multimodal models use causal attention on image tokens, so the +shared prefill kernel does not need this — keep the modification scoped to +this gemma4-private file rather than the common path. + +The kernel mirrors `context_flashattention_nopad._fwd_kernel` (paged KV via +req_to_token_indexs, prompt_cache_len for chunked prefill, sliding window +support, head_dim=256/512 with BLOCK_M reduction) and adds two ideas borrowed +from `lightllm-neo/.../context_attention_fwd_neo`: + +1. Per-Q `b_image_token_end` tensor of shape (sum_q,). For Q tokens inside an + image span it carries the span's end index; for text tokens it is 0. + The attention mask becomes `causal_mask | (k_pos < q_image_end)`. +2. K/V iteration upper bound is extended to `max(causal_end, block_image_end)` + so a Q tile in the middle of an image span actually loads K/V tiles past + its causal end. Without this, the bidi mask in the original diff was a + no-op on every tile but the last one of the image span. + +The standalone `reference_attention` and `check_once` are runnable as a script +for unit testing image bidi correctness. 
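+The effective mask is the union of the SWA-causal and image-bidirectional terms:
+    attend(q, k) = (k_pos <= q_pos and q_pos - k_pos <= sliding_window) or (k_pos < q_image_end)
+(the window term only applies when sliding_window >= 0; q_image_end is 0 for text
+tokens, so text attention stays plain causal).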
+""" + +import math +import torch +import triton +import triton.language as tl + +from lightllm.utils.device_utils import is_tesla + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + Out, + B_Start_Loc, + B_Seqlen, + Req_to_tokens, + B_req_idx, + B_Image_Token_End, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + H: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + USE_SLIDING_WINDOW: tl.constexpr, + SLIDING_WINDOW_SIZE: tl.constexpr, +): + start_m = tl.program_id(0) + cur_bh = tl.program_id(1) + cur_batch = cur_bh // H + cur_head = cur_bh % H + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + total_len = tl.load(B_Seqlen + cur_batch) + cur_batch_seq_len = total_len - prompt_cache_len # new tokens this step + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + block_start_loc = BLOCK_M * start_m + if block_start_loc >= cur_batch_seq_len: + return + + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = block_start_loc + tl.arange(0, BLOCK_M) + q_valid = offs_m < cur_batch_seq_len + + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + q = tl.load(Q + off_q, mask=q_valid[:, None], other=0.0) + + # Per-Q image_end. 0 for non-image tokens, image-span end for image tokens. + q_image_end = tl.load( + B_Image_Token_End + cur_batch_in_all_start_index + offs_m, + mask=q_valid, + other=0, + ).to(tl.int32) + + # Absolute position in the request (prompt_cache_len + offset within new tokens). + q_pos = prompt_cache_len + offs_m # [M] + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + causal_end = tl.minimum(prompt_cache_len + block_start_loc + BLOCK_M, total_len) + block_image_end = tl.minimum(tl.max(q_image_end, axis=0), total_len) + block_end_loc = tl.maximum(causal_end, block_image_end) + + if USE_SLIDING_WINDOW: + kv_start_index = block_start_loc + prompt_cache_len - SLIDING_WINDOW_SIZE + kv_start_index = tl.maximum(kv_start_index, 0) + block_kv_len = block_end_loc - kv_start_index + else: + kv_start_index = 0 + block_kv_len = block_end_loc + + for start_n in range(0, block_kv_len, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + k_pos = kv_start_index + start_n + offs_n # [N] + k_valid = k_pos < block_end_loc + + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * k_pos, + mask=k_valid, + other=0, + ).to(tl.int64) + + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + k = tl.load(K + off_k, mask=k_valid[None, :], other=0.0) + qk = tl.dot(q, k) + + causal_mask = q_pos[:, None] >= k_pos[None, :] + if USE_SLIDING_WINDOW: + # SLIDING_WINDOW_SIZE is the FA-style offset (window = offset + 1 tokens). + causal_mask = causal_mask & ((q_pos[:, None] - k_pos[None, :]) <= SLIDING_WINDOW_SIZE) + # Image bidi: a Q in image span [_, e) attends to all K with k_pos < e. 
+ # For text Q (q_image_end == 0) this is k_pos < 0 = always False, so + # the union with causal_mask leaves text-attention unchanged. + image_mask = k_pos[None, :] < q_image_end[:, None] + mask = (causal_mask | image_mask) & k_valid[None, :] + + qk = tl.where(mask, qk * sm_scale, -1.0e8) + + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk -= m_ij[:, None] + p = tl.math.exp2(qk) + l_ij = tl.sum(p, 1) + + alpha = tl.math.exp2(m_i - m_ij) + l_i = l_i * alpha + l_ij + acc = acc * alpha[:, None] + + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load(V + off_v, mask=k_valid[:, None], other=0.0) + p = p.to(v.dtype) + acc = tl.dot(p, v, acc) + + m_i = m_ij + + acc = acc / l_i[:, None] + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + tl.store(Out + off_o, acc, mask=q_valid[:, None]) + + +@torch.no_grad() +def context_attention_fwd_gemma4_mm( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + b_image_token_end, + sliding_window: int = -1, +): + """Prefill attention with image bidirectional masking on sliding layers. + + Args: + b_image_token_end: int32 tensor of shape (sum_q,). For each Q token + position (in the flattened new-token layout), value is the image + span's end index (in absolute request position) if the token is + inside an image span, else 0. + """ + BLOCK_M = 128 if not is_tesla() else 64 + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128, 256, 512} + if Lk >= 512: + BLOCK_M = min(BLOCK_M, 32) + elif Lk >= 256: + BLOCK_M = min(BLOCK_M, 64) + + sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634 + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = lambda meta: (triton.cdiv(max_input_len, meta["BLOCK_M"]), batch * head, 1) + BLOCK_N = BLOCK_M + num_warps = 4 if Lk <= 64 else 8 + num_stages = 1 + use_sliding_window = sliding_window >= 0 + sliding_window_size = int(sliding_window) if use_sliding_window else 0 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + o, + b_start_loc, + b_seq_len, + req_to_token_indexs, + b_req_idx, + b_image_token_end, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + H=head, + BLOCK_DMODEL=Lk, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + USE_SLIDING_WINDOW=use_sliding_window, + SLIDING_WINDOW_SIZE=sliding_window_size, + num_warps=num_warps, + num_stages=num_stages, + ) + + +# --------------------------------------------------------------------------- +# Reference implementation + standalone test harness +# --------------------------------------------------------------------------- + + +def reference_attention( + q, + k, + v, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + b_image_token_end, + sliding_window=-1, +): + """Slow torch reference for the gemma4 mm prefill kernel. + + `sliding_window` is the FA-style offset (window = sliding_window + 1 tokens). + < 0 disables SWA. 
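+
+    Worked mask example (illustrative numbers): with prompt_len=2, an image
+    span at absolute positions [5, 9) and sliding_window=3, the query at
+    absolute position 6 is allowed k_pos in [3, 6] by the causal+window term
+    and k_pos in [0, 9) by the image term, so the union covers k_pos 0..8.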
+ """ + device = q.device + dtype = q.dtype + sum_q, Hq, D = q.shape + Hk = k.shape[1] + kv_group_num = Hq // Hk + + out = torch.empty_like(q) + scale = 1.0 / math.sqrt(D) + + batch = b_seq_len.shape[0] + for b in range(batch): + req = int(b_req_idx[b].item()) + total_len = int(b_seq_len[b].item()) + prompt_len = int(b_prompt_cache_len[b].item()) + new_len = total_len - prompt_len + q_start = int(b_start_loc[b].item()) + + q_blk = q[q_start : q_start + new_len] # [M, Hq, D] + q_image_end = b_image_token_end[q_start : q_start + new_len].to(torch.int64) # [M] + + token_locs = req_to_token_indexs[req, :total_len].to(torch.int64) + k_blk = k[token_locs] + v_blk = v[token_locs] + + k_hq = k_blk.repeat_interleave(kv_group_num, dim=1) + v_hq = v_blk.repeat_interleave(kv_group_num, dim=1) + + q_pos = torch.arange(prompt_len, total_len, device=device, dtype=torch.int64) + k_pos = torch.arange(0, total_len, device=device, dtype=torch.int64) + + causal = k_pos[None, :] <= q_pos[:, None] + if sliding_window >= 0: + causal = causal & ((q_pos[:, None] - k_pos[None, :]) <= sliding_window) + image = k_pos[None, :] < q_image_end[:, None] + allow = causal | image + + q_t = q_blk.permute(1, 0, 2).to(torch.float32) + k_t = k_hq.permute(1, 2, 0).to(torch.float32) + scores = torch.matmul(q_t, k_t) * scale + + neg = torch.tensor(-1.0e9, device=device, dtype=torch.float32) + scores = torch.where(allow[None, :, :], scores, neg) + p = torch.softmax(scores, dim=-1) + v_t = v_hq.permute(1, 0, 2).to(torch.float32) + out_hq = torch.matmul(p, v_t) + out[q_start : q_start + new_len] = out_hq.permute(1, 0, 2).to(dtype) + + return out + + +def make_test_case( + device="cuda", + dtype=torch.bfloat16, + batch=3, + Hq=8, + Hk=4, + D=256, + seed=0, + base_index=50000, + sliding_window=-1, +): + torch.manual_seed(seed) + + prompt_lens = torch.randint(low=0, high=8, size=(batch,), device=device) + new_lens = torch.randint(low=4, high=24, size=(batch,), device=device) + total_lens = (prompt_lens + new_lens).to(torch.int32) + max_total_len = int(total_lens.max().item()) + max_new_len = int(new_lens.max().item()) + + b_start_loc = torch.zeros((batch,), device=device, dtype=torch.int32) + cur = 0 + for b in range(batch): + b_start_loc[b] = cur + cur += int(new_lens[b].item()) + sum_q = cur + + b_seq_len = total_lens + b_prompt_cache_len = prompt_lens.to(torch.int32) + b_req_idx = torch.arange(batch, device=device, dtype=torch.int32) + + sum_kv = int(total_lens.sum().item()) + kv_size = base_index + sum_kv + 1024 + pool = torch.randperm(kv_size - base_index, device=device, dtype=torch.int64)[:sum_kv] + base_index + + req_to_token_indexs = torch.zeros((batch, max_total_len), device=device, dtype=torch.int32) + p = 0 + for r in range(batch): + L = int(total_lens[r].item()) + req_to_token_indexs[r, :L] = pool[p : p + L].to(torch.int32) + p += L + + # Inject one image span per batch into the new-token region with prob 0.7. 
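+    # Batches that draw no span (or have fewer than 4 new tokens) exercise the
+    # plain causal/sliding path; check_once prints has_image so a run with no
+    # image spans at all is still visible in the log.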
+ b_image_token_end = torch.zeros((sum_q,), device=device, dtype=torch.int32) + for b in range(batch): + M = int(new_lens[b].item()) + P = int(prompt_lens[b].item()) + start = int(b_start_loc[b].item()) + if M >= 4 and torch.rand((), device=device).item() > 0.3: + s = int(torch.randint(0, M - 2, (1,), device=device).item()) + span_len = int(torch.randint(2, max(3, M - s + 1), (1,), device=device).item()) + e = min(M, s + span_len) + # image_end is absolute (request-position) = prompt_len + new-offset + b_image_token_end[start + s : start + e] = P + e + + q = torch.randn((sum_q, Hq, D), device=device, dtype=dtype) + k = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + v = torch.randn((kv_size, Hk, D), device=device, dtype=dtype) + o = torch.empty((sum_q, Hq, D), device=device, dtype=dtype) + + return ( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + b_image_token_end, + sliding_window, + ) + + +def check_once(seed=0, dtype=torch.bfloat16, sliding_window=-1, D=256): + case = make_test_case(seed=seed, dtype=dtype, sliding_window=sliding_window, D=D) + ( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + b_image_token_end, + sliding_window, + ) = case + + context_attention_fwd_gemma4_mm( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + max_new_len, + req_to_token_indexs, + b_image_token_end, + sliding_window=sliding_window, + ) + + ref = reference_attention( + q, + k, + v, + b_req_idx, + b_start_loc, + b_seq_len, + b_prompt_cache_len, + req_to_token_indexs, + b_image_token_end, + sliding_window=sliding_window, + ) + + diff = (o - ref).abs() + max_abs = diff.max().item() + denom = ref.abs().max().item() + 1e-6 + max_rel = max_abs / denom + has_image = (b_image_token_end > 0).any().item() + print( + f"seed={seed} dtype={dtype} D={D} sw={sliding_window} has_image={has_image} " + f"max_abs={max_abs:.4e} max_rel={max_rel:.4e}" + ) + assert max_abs < 5e-2, f"max_abs too large: {max_abs}" + + +if __name__ == "__main__": + if not torch.cuda.is_available(): + print("No CUDA, skip.") + else: + # Vary D, sliding window, and image presence. 
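+        # D=128 keeps the launcher's default BLOCK_M, while D=256 takes the
+        # reduced BLOCK_M (<= 64) branch, so both tile-size paths get covered.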
+ for seed in (0, 1, 2): + check_once(seed=seed, D=128, sliding_window=-1) + check_once(seed=seed, D=128, sliding_window=4096) + check_once(seed=seed, D=256, sliding_window=4096) + print("ok") diff --git a/lightllm/models/llama/triton_kernel/rotary_emb.py b/lightllm/models/llama/triton_kernel/rotary_emb.py index c6d4f3010d..f87b9d9e02 100755 --- a/lightllm/models/llama/triton_kernel/rotary_emb.py +++ b/lightllm/models/llama/triton_kernel/rotary_emb.py @@ -23,6 +23,7 @@ def _rotary_kernel( max_total_len, HEAD_Q, HEAD_K, # N_CTX 代表要计算的上下文长度 + HAS_K: tl.constexpr, BLOCK_HEAD: tl.constexpr, BLOCK_SEQ: tl.constexpr, BLOCK_DMODEL: tl.constexpr, @@ -73,55 +74,59 @@ def _rotary_kernel( Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q) ) - off_k0 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range0[None, None, :] * stride_kd - ) - off_k1 = ( - cur_seq_range[:, None, None] * stride_kbs - + cur_head_range[None, :, None] * stride_kh - + dim_range1[None, None, :] * stride_kd - ) - - off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd - - k0 = tl.load( - K + off_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - k1 = tl.load( - K + off_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - other=0.0, - ) - cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) - - out_k0 = k0 * cos - k1 * sin - out_k1 = k0 * sin + k1 * cos - - tl.store( - K + off_k0, - out_k0, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) - tl.store( - K + off_k1, - out_k1, - mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), - ) + if HAS_K: + off_k0 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range0[None, None, :] * stride_kd + ) + off_k1 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range1[None, None, :] * stride_kd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd + + k0 = tl.load( + K + off_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + k1 = tl.load( + K + off_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + other=0.0, + ) + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos + + tl.store( + K + off_k0, + out_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) + tl.store( + K + off_k1, + out_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K), + ) return @torch.no_grad() -def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): +def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): total_len = q.shape[0] - head_num_q, head_num_k = q.shape[1], k.shape[1] + has_k = k is not None + head_num_q = 
q.shape[1] + head_num_k = k.shape[1] if has_k else 0 head_dim = int(q.shape[2] * partial_rotary_factor) assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" - assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + if has_k: + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" BLOCK_SEQ = 16 BLOCK_HEAD = 4 @@ -139,9 +144,9 @@ def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): q.stride(0), q.stride(1), q.stride(2), - k.stride(0), - k.stride(1), - k.stride(2), + k.stride(0) if has_k else 0, + k.stride(1) if has_k else 0, + k.stride(2) if has_k else 0, cos.stride(0), cos.stride(1), sin.stride(0), @@ -149,6 +154,7 @@ def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.): total_len, head_num_q, head_num_k, + HAS_K=has_k, BLOCK_HEAD=BLOCK_HEAD, BLOCK_SEQ=BLOCK_SEQ, BLOCK_DMODEL=head_dim, diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 9b9fe2569c..ce09632d2c 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -81,6 +81,7 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei tp_text_start_token_id=layer_weight.wte_weight_.tp_vocab_start_id, tp_text_end_token_id=layer_weight.wte_weight_.tp_vocab_end_id, tp_world_size=self.tp_world_size_, + text_embed_scale=getattr(self, "multimodal_text_embed_scale_", 1.0), ) if self.tp_world_size_ > 1: all_reduce(out, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 4a345000b0..7af5aa6b89 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -193,6 +193,7 @@ def make_argument_parser() -> argparse.ArgumentParser: "step3", "nano_v3", "interns1", + "gemma4", ], default=None, help="reasoning parser type", diff --git a/lightllm/server/api_openai.py b/lightllm/server/api_openai.py index fe4f3b50b0..2f79d730d7 100644 --- a/lightllm/server/api_openai.py +++ b/lightllm/server/api_openai.py @@ -165,8 +165,8 @@ def _is_force_thinking_mode(request: ChatCompletionRequest) -> bool: return False if reasoning_parser in ["deepseek-v3"]: return request.chat_template_kwargs is not None and request.chat_template_kwargs.get("thinking") is True - if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1"]: - # qwen3, glm45, nano_v3, and interns1 are reasoning by default + if reasoning_parser in ["qwen3", "glm45", "nano_v3", "interns1", "gemma4"]: + # qwen3, glm45, nano_v3, interns1, and gemma4 are reasoning by default; return not request.chat_template_kwargs or request.chat_template_kwargs.get("enable_thinking", True) is True return True # default @@ -315,6 +315,16 @@ async def chat_completions_impl(request: ChatCompletionRequest, raw_request: Req "seed": request.seed, } + # Gemma-4's reasoning delimiters (<|channel>=100, =101) are + # special tokens. The default skip_special_tokens=True would drop them + # from the decoded stream and the Gemma4Detector would be unable to + # find the reasoning boundary. Mirrors vllm's + # Gemma4ReasoningParser.adjust_request behaviour. Only applied when no + # explicit value is supplied so callers can still opt back into the + # default if they want. 
+ if get_env_start_args().reasoning_parser == "gemma4" and "skip_special_tokens" not in sampling_params_dict: + sampling_params_dict["skip_special_tokens"] = False + if request.max_completion_tokens is not None: sampling_params_dict["max_new_tokens"] = request.max_completion_tokens elif request.max_tokens is not None: diff --git a/lightllm/server/build_prompt.py b/lightllm/server/build_prompt.py index 913cb67107..7fad0a4bc9 100644 --- a/lightllm/server/build_prompt.py +++ b/lightllm/server/build_prompt.py @@ -106,11 +106,31 @@ def _alias_reasoning_to_reasoning_content(messages: list) -> None: msg["reasoning_content"] = reasoning +def _normalize_multimodal_content_types(messages: list) -> None: + # OpenAI requests use content part types like `image_url` and `audio_url`. + # Model chat templates generally render modality tokens from `image` and + # `audio` parts while the raw media payload is carried separately in + # MultimodalParams. Preserve the original fields and normalize only the + # template-facing type to keep prompt tags aligned with media counts. + for msg in messages: + content = msg.get("content") + if not isinstance(content, list): + continue + for part in content: + if not isinstance(part, dict): + continue + if part.get("type") == "image_url": + part["type"] = "image" + elif part.get("type") == "audio_url": + part["type"] = "audio" + + async def build_prompt(request, tools) -> str: # pydantic格式转成dict, 否则,当根据tokenizer_config.json拼template时,Jinja判断无法识别 messages = [m.model_dump(by_alias=True, exclude_none=True) for m in request.messages] _normalize_tool_call_arguments(messages) _alias_reasoning_to_reasoning_content(messages) + _normalize_multimodal_content_types(messages) kwargs = {"conversation": messages} if request.character_settings: diff --git a/lightllm/server/reasoning_parser.py b/lightllm/server/reasoning_parser.py index 024be4f769..fc80cb2fa6 100644 --- a/lightllm/server/reasoning_parser.py +++ b/lightllm/server/reasoning_parser.py @@ -862,6 +862,33 @@ def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False) ) +class Gemma4Detector(BaseReasoningFormatDetector): + """ + Detector for Google Gemma-4 thinking models. + + Format: ``<|channel>thought\\n...reasoning...\\nanswer``. + Role label ``thought\\n`` is baked into the start token (cf. + GptOssDetector) so the base class strips it for free. + + Note: ``<|channel>`` and ```` are special tokens (ids 100/101). + The API layer forces ``skip_special_tokens=False`` when this parser is + active so the delimiters survive decoding (see ``api_openai.py``). + """ + + THINK_START_TOKEN = "<|channel>thought\n" + THINK_END_TOKEN = "" + + def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False): + # force_reasoning ignored: Gemma-4's template never starts generation + # inside an open channel (ReasoningParser pins it to False too). 
+ super().__init__( + self.THINK_START_TOKEN, + self.THINK_END_TOKEN, + force_reasoning=False, + stream_reasoning=stream_reasoning, + ) + + class ReasoningParser: """ Parser that handles both streaming and non-streaming scenarios for extracting @@ -887,6 +914,7 @@ class ReasoningParser: "step3": DeepSeekR1Detector, "nano_v3": NanoV3Detector, "interns1": Qwen3Detector, + "gemma4": Gemma4Detector, } def __init__( @@ -905,6 +933,12 @@ def __init__( # Special cases where we override force_reasoning if model_type.lower() in {"qwen3-thinking", "gpt-oss", "minimax"}: force_reasoning = True + elif model_type.lower() == "gemma4": + # Gemma-4's chat template never positions generation inside an open + # channel — see Gemma4Detector docstring. Pin to False so a + # request_enable_reasoning=True from the caller can't accidentally + # mark the parser as already inside reasoning. + force_reasoning = False # Only pass force_reasoning if explicitly set, let detectors use their defaults kwargs = {"stream_reasoning": stream_reasoning} diff --git a/lightllm/server/tokenizer.py b/lightllm/server/tokenizer.py index 25726b2578..c353ee6d35 100644 --- a/lightllm/server/tokenizer.py +++ b/lightllm/server/tokenizer.py @@ -31,6 +31,7 @@ from ..models.qwen3_vl.model import QWen3VLTokenizer from ..models.internvl.model import InternvlTokenizer from ..models.gemma3.model import Gemma3Tokenizer +from ..models.gemma4.model import Gemma4Tokenizer from ..models.qwen3_omni_moe_thinker.model import QWen3OmniTokenizer # A fast LLaMA tokenizer with the pre-processed `tokenizer.json` file. @@ -130,5 +131,13 @@ def get_tokenizer( tokenizer = InternvlTokenizer(tokenizer, model_cfg, weight_dir=tokenizer_name) elif model_type == "gemma3": tokenizer = Gemma3Tokenizer(tokenizer, model_cfg) + elif model_type == "gemma4": + image_processor = None + if "vision_config" in model_cfg and model_cfg["vision_config"] is not None: + from transformers import AutoProcessor + + processor = AutoProcessor.from_pretrained(tokenizer_name) + image_processor = processor.image_processor + tokenizer = Gemma4Tokenizer(tokenizer, model_cfg, image_processor=image_processor) return tokenizer diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 92ca2e3836..50bc12fd23 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -13,6 +13,7 @@ from lightllm.models.llava.llava_visual import LlavaVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.gemma3.gemma3_visual import Gemma3VisionModel +from lightllm.models.gemma4.gemma4_visual import Gemma4VisionModel from lightllm.models.vit.model import VisionTransformer from lightllm.server.multimodal_params import MultimodalParams, ImageItem from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel @@ -97,6 +98,8 @@ def exposed_init_model(self, kvargs): # self.model = InternVLVisionModel() elif self.model_type == "gemma3": self.model = Gemma3VisionModel() + elif self.model_type == "gemma4": + self.model = Gemma4VisionModel(data_type=kvargs["data_type"]) elif ( model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type") == "qwen3_omni_moe_vision_encoder" diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py index c64e8a912b..548a36aeb0 100644 --- a/lightllm/utils/config_utils.py +++ b/lightllm/utils/config_utils.py @@ -236,12 +236,28 @@ def 
get_eos_token_ids(model_path: str) -> Optional[List[int]]:
     eos_token_id = _get_config_llm_keyvalue(model_path=model_path, key_name=["eos_token_id"])
     if isinstance(eos_token_id, int):
-        return [eos_token_id]
-    if isinstance(eos_token_id, list):
-        return eos_token_id
+        eos_token_ids = [eos_token_id]
+    elif isinstance(eos_token_id, list):
+        eos_token_ids = list(eos_token_id)
+    else:
+        raise ValueError("error eos_token_id format in config.json")
+
+    generation_config_path = os.path.join(model_path, "generation_config.json")
+    if os.path.exists(generation_config_path):
+        try:
+            with open(generation_config_path, "r") as file:
+                generation_eos = json.load(file).get("eos_token_id")
+        except Exception as exc:
+            logger.warning(f"failed to load eos_token_id from generation_config.json: {exc}")
+            generation_eos = None
+        if isinstance(generation_eos, int):
+            generation_eos = [generation_eos]
+        if isinstance(generation_eos, list):
+            for token_id in generation_eos:
+                if isinstance(token_id, int) and token_id not in eos_token_ids:
+                    eos_token_ids.append(token_id)
 
-    assert False, "error eos_token_id format in config.json"
-    return
+    return eos_token_ids
 
 
 def get_model_architectures(model_path: str):
@@ -327,6 +343,9 @@ def has_vision_module(model_path: str) -> bool:
         return True
     elif model_type == "gemma3":
         return True
+    elif model_type == "gemma4":
+        vision_cfg = model_cfg.get("vision_config")
+        return vision_cfg is not None
     elif (
         model_cfg.get("thinker_config", {}).get("vision_config", {}).get("model_type")
         == "qwen3_omni_moe_vision_encoder"
@@ -450,4 +469,8 @@ def get_reasoning_parser_for_model(model_path: str) -> Optional[str]:
     if model_type == "deepseek_r1":
         return "deepseek-r1"
 
+    # Gemma-4 (all variants share the same Harmony-like <|channel>... format)
+    if model_type == "gemma4":
+        return "gemma4"
+
     return None