From 03e5207fa68daadfef47c5ed1c2c134d6dc9f930 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Tue, 22 Jul 2025 22:22:42 +0000 Subject: [PATCH] [WIP] feat: ExLlamaV3 quantization format --- CMakeLists.txt | 1 + aphrodite/_custom_ops.py | 33 ++ aphrodite/quantization/__init__.py | 3 + aphrodite/quantization/exl3.py | 427 +++++++++++++++++++++++++ kernels/ops.h | 9 + kernels/quantization/exl3/exl3_gemm.cu | 268 ++++++++++++++++ kernels/torch_bindings.cpp | 12 + 7 files changed, 753 insertions(+) create mode 100644 aphrodite/quantization/exl3.py create mode 100644 kernels/quantization/exl3/exl3_gemm.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 8a9717d9a0..28eb609366 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -242,6 +242,7 @@ set(APHRODITE_EXT_SRC "kernels/quantization/fp8/common.cu" "kernels/quantization/fused_kernels/fused_layernorm_dynamic_per_token_quant.cu" "kernels/quantization/gguf/gguf_kernel.cu" + "kernels/quantization/exl3/exl3_gemm.cu" "kernels/quantization/activation_kernels.cu" "kernels/cuda_utils_kernels.cu" "kernels/prepare_inputs/advance_step.cu" diff --git a/aphrodite/_custom_ops.py b/aphrodite/_custom_ops.py index 6b535d3777..898176cf1d 100644 --- a/aphrodite/_custom_ops.py +++ b/aphrodite/_custom_ops.py @@ -319,6 +319,39 @@ def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) +# exl3 +def exl3_gemm(input: torch.Tensor, trellis: torch.Tensor, suh: torch.Tensor, + svh: torch.Tensor, mcg_mult: int, mul1_mult: int) -> torch.Tensor: + return torch.ops._C.exl3_gemm(input, trellis, suh, svh, mcg_mult, mul1_mult) + + +def exl3_reconstruct(trellis: torch.Tensor, in_features: int, out_features: int, + mcg_mult: int, mul1_mult: int) -> torch.Tensor: + return torch.ops._C.exl3_reconstruct(trellis, in_features, out_features, + mcg_mult, mul1_mult) + + +if hasattr(torch.ops._C, "exl3_gemm"): + + @register_fake("_C::exl3_gemm") + def _exl3_gemm_fake(input: torch.Tensor, trellis: torch.Tensor, + suh: torch.Tensor, svh: torch.Tensor, mcg_mult: int, + mul1_mult: int) -> torch.Tensor: + batch_size = input.size(0) + out_features = svh.size(0) + return torch.empty((batch_size, out_features), + dtype=torch.float32, + device=input.device) + + @register_fake("_C::exl3_reconstruct") + def _exl3_reconstruct_fake(trellis: torch.Tensor, in_features: int, + out_features: int, mcg_mult: int, + mul1_mult: int) -> torch.Tensor: + return torch.empty((in_features, out_features), + dtype=torch.float16, + device=trellis.device) + + # squeezellm def squeezellm_gemm(vec: torch.Tensor, mat: torch.Tensor, mul: torch.Tensor, lookup_table: torch.Tensor) -> None: diff --git a/aphrodite/quantization/__init__.py b/aphrodite/quantization/__init__.py index ad59f56e80..8f1990ae93 100644 --- a/aphrodite/quantization/__init__.py +++ b/aphrodite/quantization/__init__.py @@ -30,6 +30,7 @@ "quark", "moe_wna16", "torchao", + "exl3", "fp2", "fp3", "fp4", @@ -94,6 +95,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: from .compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsConfig) from .deepspeedfp import DeepSpeedFPConfig + from .exl3 import EXL3Config from .experts_int8 import ExpertsInt8Config from .fbgemm_fp8 import FBGEMMFp8Config from .fp6 import QuantLLMFPConfig @@ -142,6 +144,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: "quark": QuarkConfig, "moe_wna16": MoeWNA16Config, "torchao": TorchAOConfig, + "exl3": EXL3Config, "fp2": QuantLLMFPConfig, "fp3": QuantLLMFPConfig, "fp4": QuantLLMFPConfig, diff --git a/aphrodite/quantization/exl3.py b/aphrodite/quantization/exl3.py new file mode 100644 index 0000000000..fc6df20547 --- /dev/null +++ b/aphrodite/quantization/exl3.py @@ -0,0 +1,427 @@ +from typing import Any, Dict, List, Optional + +import torch +from torch.nn.parameter import Parameter + +from aphrodite import _custom_ops as ops +from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase +from aphrodite.modeling.utils import set_weight_attrs +from aphrodite.quantization import QuantizationMethods +from aphrodite.quantization.base_config import QuantizationConfig + +# Try to import EXL3 CUDA operations - fallback if not available +try: + _EXL3_KERNELS_AVAILABLE = hasattr(ops, 'exl3_gemm') +except (ImportError, AttributeError): + _EXL3_KERNELS_AVAILABLE = False + + +class EXL3Config(QuantizationConfig): + """Config class for EXL3 quantization. + + EXL3 is based on the QTIP quantization method from Cornell RelaxML, + using trellis-based encoding with Hadamard transformations. + """ + + def __init__( + self, + bits: float, + head_bits: int = 6, + calibration: Optional[Dict[str, Any]] = None, + **kwargs + ) -> None: + super().__init__() + self.bits = bits + self.head_bits = head_bits + self.calibration = calibration or {} + + # Validate bits per weight + if not (1.0 <= bits <= 8.0): + raise ValueError( + f"EXL3 bits per weight must be between 1.0 and 8.0, " + f"got {bits}") + + def __repr__(self) -> str: + return (f"EXL3Config(bits={self.bits}, " + f"head_bits={self.head_bits})") + + @classmethod + def get_name(cls) -> QuantizationMethods: + return "exl3" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half, torch.bfloat16] + + @classmethod + def get_min_capability(cls) -> int: + # EXL3 requires modern GPU for optimal performance + return 80 + + @staticmethod + def get_config_filenames() -> List[str]: + return [ + "quantization_config.json", + "config.json" + ] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "EXL3Config": + # Handle both quantization_config.json and config.json formats + quant_config = config.get("quantization_config", config) + + bits = cls.get_from_keys(quant_config, ["bits"]) + head_bits = cls.get_from_keys_or(quant_config, ["head_bits"], 6) + calibration = cls.get_from_keys_or(quant_config, ["calibration"], None) + + return cls(bits=bits, head_bits=head_bits, calibration=calibration) + + def get_quant_method(self, layer: torch.nn.Module, + prefix: str) -> Optional["EXL3LinearMethod"]: + if isinstance(layer, LinearBase): + return EXL3LinearMethod(self) + return None + + +class EXL3LinearMethod(LinearMethodBase): + """Linear method for EXL3 quantization. + + Implements the EXL3 trellis-based quantization with Hadamard transforms + for efficient GPU inference. + """ + + def __init__(self, quant_config: EXL3Config): + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ): + del output_size # Unused + weight_loader = extra_weight_attrs.get("weight_loader") + output_size_per_partition = sum(output_partition_sizes) + + # Calculate tile dimensions for EXL3 format + # EXL3 uses 16x16 tiles for optimal GPU memory access + tiles_k = (input_size_per_partition + 15) // 16 + tiles_n = (output_size_per_partition + 15) // 16 + + # K parameter determines bits per weight via trellis depth + # This is derived from the target bits in the config + K = max(1, int(self.quant_config.bits)) + + # Main quantized weight tensor (trellis format) + # Initialize with proper shape but will be loaded from checkpoint + trellis = Parameter( + torch.empty(tiles_k, tiles_n, K * 16, dtype=torch.int16), + requires_grad=False + ) + set_weight_attrs(trellis, { + "input_dim": 0, + "output_dim": 1, + "weight_loader": self._make_weight_loader("trellis", weight_loader, layer) + }) + + # Input channel Hadamard sign factors (unpacked format preferred) + suh = Parameter( + torch.empty(input_size_per_partition, dtype=torch.half), + requires_grad=False + ) + set_weight_attrs(suh, { + "input_dim": 0, + "weight_loader": self._make_weight_loader("suh", weight_loader, layer) + }) + + # Output channel Hadamard sign factors (unpacked format preferred) + svh = Parameter( + torch.empty(output_size_per_partition, dtype=torch.half), + requires_grad=False + ) + set_weight_attrs(svh, { + "output_dim": 0, + "weight_loader": self._make_weight_loader("svh", weight_loader, layer) + }) + + # Optional packed versions (for legacy compatibility) + su = Parameter( + torch.empty((input_size_per_partition + 15) // 16, dtype=torch.int16), + requires_grad=False + ) + set_weight_attrs(su, { + "input_dim": 0, + "weight_loader": self._make_weight_loader("su", weight_loader, layer) + }) + + sv = Parameter( + torch.empty((output_size_per_partition + 15) // 16, dtype=torch.int16), + requires_grad=False + ) + set_weight_attrs(sv, { + "output_dim": 0, + "weight_loader": self._make_weight_loader("sv", weight_loader, layer) + }) + + # Experimental multipliers (optional) + mcg = Parameter( + torch.tensor(0, dtype=torch.int32), + requires_grad=False + ) + set_weight_attrs(mcg, {"weight_loader": self._make_weight_loader("mcg", weight_loader, layer)}) + + mul1 = Parameter( + torch.tensor(0, dtype=torch.int32), + requires_grad=False + ) + set_weight_attrs(mul1, {"weight_loader": self._make_weight_loader("mul1", weight_loader, layer)}) + + # Store tensor shapes and metadata + layer.K = K + layer.tiles_k = tiles_k + layer.tiles_n = tiles_n + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + + # Initialize multiplier values (will be updated during weight loading) + layer.mcg_mult = 0 + layer.mul1_mult = 0 + + # Register all parameters + layer.register_parameter("trellis", trellis) + layer.register_parameter("suh", suh) + layer.register_parameter("svh", svh) + layer.register_parameter("su", su) + layer.register_parameter("sv", sv) + layer.register_parameter("mcg", mcg) + layer.register_parameter("mul1", mul1) + + def _make_weight_loader(self, tensor_name: str, default_loader, layer): + """Create a weight loader that handles EXL3-specific tensor loading.""" + def exl3_weight_loader(param: Parameter, loaded_weight: torch.Tensor, shard_id: Optional[str] = None): + # Handle different tensor name patterns from EXL3 checkpoints + if tensor_name in ["trellis", "suh", "svh", "su", "sv", "mcg", "mul1"]: + # Handle multiplier extraction for mcg and mul1 + if tensor_name == "mcg" and loaded_weight.numel() > 0: + # Extract the multiplier value and store it on the layer + layer.mcg_mult = loaded_weight.view(torch.uint32).item() + elif tensor_name == "mul1" and loaded_weight.numel() > 0: + # Extract the multiplier value and store it on the layer + layer.mul1_mult = loaded_weight.view(torch.uint32).item() + + # Ensure the loaded weight matches expected shape and dtype + if loaded_weight.shape != param.shape: + # Handle potential shape mismatches due to padding + if tensor_name == "trellis": + # Trellis tensor might have different K values + if len(loaded_weight.shape) == 3: + param.data[:loaded_weight.shape[0], + :loaded_weight.shape[1], + :loaded_weight.shape[2]].copy_(loaded_weight) + return + elif tensor_name in ["suh", "svh"]: + # Sign factors might be shorter due to actual vs padded dimensions + if loaded_weight.numel() <= param.numel(): + param.data[:loaded_weight.numel()].copy_(loaded_weight.flatten()) + return + + # Standard copy for matching shapes + if loaded_weight.dtype != param.dtype: + loaded_weight = loaded_weight.to(param.dtype) + param.data.copy_(loaded_weight) + else: + # Fallback to default loader + if shard_id is not None: + default_loader(param, loaded_weight, shard_id) + else: + default_loader(param, loaded_weight) + + return exl3_weight_loader + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + """Apply EXL3 quantized linear transformation.""" + + # Get quantization parameters + trellis = layer.trellis + suh = getattr(layer, 'suh', None) + svh = getattr(layer, 'svh', None) + su = getattr(layer, 'su', None) + sv = getattr(layer, 'sv', None) + mcg = getattr(layer, 'mcg', None) + mul1 = getattr(layer, 'mul1', None) + + # Use pre-extracted multiplier values from the layer (set during weight loading) + mcg_mult = getattr(layer, 'mcg_mult', 0) + mul1_mult = getattr(layer, 'mul1_mult', 0) + + # Unpack signs if needed (fallback for legacy format) + if suh is None and su is not None: + suh = self._unpack_signs(su) + if svh is None and sv is not None: + svh = self._unpack_signs(sv) + + if suh is None or svh is None: + raise ValueError("EXL3 layer missing required sign factors (suh/svh)") + + # Reshape input for processing + input_shape = x.shape + x_reshaped = x.view(-1, x.shape[-1]) + batch_size = x_reshaped.shape[0] + + # Try to use optimized CUDA kernel if available + if _EXL3_KERNELS_AVAILABLE and x_reshaped.is_cuda: + # Create empty sign tensors if not provided + if suh.numel() == 0: + suh = torch.empty(0, dtype=torch.half, device=x_reshaped.device) + if svh.numel() == 0: + svh = torch.empty(0, dtype=torch.half, device=x_reshaped.device) + + output = ops.exl3_gemm( + x_reshaped.to(torch.float16), # Convert to half precision + trellis, + suh, + svh, + mcg_mult, + mul1_mult + ) + # Convert back to input dtype if needed + if output.dtype != x.dtype: + output = output.to(x.dtype) + else: + # Use fallback reconstruction method + output = self._apply_fallback(x_reshaped, trellis, suh, svh, + layer.K, mcg_mult, mul1_mult) + + # Add bias if present + if bias is not None: + output = output + bias + + # Reshape output to match input shape + result = output.view(input_shape[:-1] + (layer.output_size_per_partition,)) + + return result + + def _unpack_signs(self, packed_signs: torch.Tensor) -> torch.Tensor: + """Unpack bit-packed sign factors to float16.""" + # Convert packed int16 to individual sign bits + device = packed_signs.device + packed_signs = packed_signs.view(torch.uint16).to(torch.int32) + + # Extract individual bits and convert to signs + masks = (1 << torch.arange(16, device=device)).unsqueeze(0) + expanded = (packed_signs.unsqueeze(-1) & masks) > 0 + expanded = expanded.flatten() + + # Convert boolean to sign values (-1 or +1) + signs = torch.where(expanded, + torch.tensor(-1.0, dtype=torch.half, device=device), + torch.tensor(1.0, dtype=torch.half, device=device)) + + return signs.contiguous() + + def _apply_fallback(self, x: torch.Tensor, trellis: torch.Tensor, + suh: torch.Tensor, svh: torch.Tensor, + K: int, mcg_mult: int, mul1_mult: int) -> torch.Tensor: + """Fallback implementation using weight reconstruction.""" + # This is a simplified fallback - in practice you'd want the optimized kernel + + # Reconstruct the weight matrix from trellis format + weight = self._reconstruct_weight(trellis, suh, svh, K, mcg_mult, mul1_mult) + + # Standard matrix multiplication + return torch.mm(x, weight) + + def _reconstruct_weight(self, trellis: torch.Tensor, suh: torch.Tensor, + svh: torch.Tensor, K: int, mcg_mult: int, mul1_mult: int) -> torch.Tensor: + """Reconstruct full precision weight from EXL3 format.""" + + tiles_k, tiles_n, trellis_depth = trellis.shape + actual_in_features = suh.shape[0] + actual_out_features = svh.shape[0] + + # Create weight matrix with actual dimensions + weight = torch.zeros( + (actual_in_features, actual_out_features), + dtype=torch.float16, + device=trellis.device + ) + + # Decode trellis using 3INST procedural codebook + for tile_k in range(tiles_k): + for tile_n in range(tiles_n): + # Process each 16x16 tile + tile_start_k = tile_k * 16 + tile_start_n = tile_n * 16 + tile_end_k = min(tile_start_k + 16, actual_in_features) + tile_end_n = min(tile_start_n + 16, actual_out_features) + + # Decode the trellis data for this tile + for row in range(tile_end_k - tile_start_k): + for col in range(tile_end_n - tile_start_n): + # Calculate trellis index + elem_idx = (row * 16 + col) % (K * 16) + + # Get quantized value + quant_val = trellis[tile_k, tile_n, elem_idx] + + # Decode using 3INST procedural codebook + if mcg_mult != 0: + # MCG mode + decoded = self._decode_3inst_mcg(quant_val.item(), mcg_mult) + elif mul1_mult != 0: + # MUL1 mode + decoded = self._decode_3inst_mul1(quant_val.item(), mul1_mult) + else: + # Default mode + decoded = self._decode_3inst_default(quant_val.item()) + + # Store decoded value + k_idx = tile_start_k + row + n_idx = tile_start_n + col + if k_idx < actual_in_features and n_idx < actual_out_features: + weight[k_idx, n_idx] = decoded + + # Apply Hadamard sign transforms + # Input signs (suh) transform the rows, output signs (svh) transform the columns + suh_expanded = suh[:actual_in_features].unsqueeze(1) # [in_features, 1] + svh_expanded = svh[:actual_out_features].unsqueeze(0) # [1, out_features] + + # Apply sign factors + weight = weight * suh_expanded * svh_expanded + + return weight + + def _decode_3inst_default(self, x: int) -> float: + """Default 3INST procedural codebook decoding.""" + # Convert to unsigned 16-bit + x = x & 0xFFFF + # Default MCG multiplier + x = (x * 89226354) & 0xFFFFFFFF + x = (x + 64248484) & 0xFFFFFFFF + # Convert to float and normalize + return float(x) / float(0xFFFFFFFF) * 2.0 - 1.0 + + def _decode_3inst_mcg(self, x: int, mult: int) -> float: + """MCG mode 3INST procedural codebook decoding.""" + # Convert to unsigned 16-bit + x = x & 0xFFFF + # MCG mode with custom multiplier + x = (x * mult) & 0xFFFFFFFF + # Convert to float and normalize + return float(x) / float(0xFFFFFFFF) * 2.0 - 1.0 + + def _decode_3inst_mul1(self, x: int, mult: int) -> float: + """MUL1 mode 3INST procedural codebook decoding.""" + # Convert to unsigned 16-bit + x = x & 0xFFFF + # MUL1 mode with custom multiplier + x = (x * mult) & 0xFFFFFFFF + # Scale and offset (simplified version) + return float(x) * 6.77e-6 - 10.39 \ No newline at end of file diff --git a/kernels/ops.h b/kernels/ops.h index e8117f9737..1d9b995ffe 100644 --- a/kernels/ops.h +++ b/kernels/ops.h @@ -332,6 +332,15 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight, void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit); +// EXL3 quantization operations +torch::Tensor exl3_gemm(torch::Tensor input, torch::Tensor trellis, + torch::Tensor suh, torch::Tensor svh, + int64_t mcg_mult, int64_t mul1_mult); + +torch::Tensor exl3_reconstruct(torch::Tensor trellis, int64_t in_features, + int64_t out_features, int64_t mcg_mult, + int64_t mul1_mult); + void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input, torch::Tensor const& scale); diff --git a/kernels/quantization/exl3/exl3_gemm.cu b/kernels/quantization/exl3/exl3_gemm.cu new file mode 100644 index 0000000000..86345aefd7 --- /dev/null +++ b/kernels/quantization/exl3/exl3_gemm.cu @@ -0,0 +1,268 @@ +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +// Utility unions for efficient type conversion +union half2_uint32 { + uint32_t as_uint32; + half2 as_half2; + __device__ half2_uint32(uint32_t val) : as_uint32(val) {} + __device__ half2_uint32(half2 val) : as_half2(val) {} + __device__ half2_uint32() : as_uint32(0) {} +}; + +union half_uint16 { + uint16_t as_uint16; + half as_half; + __device__ half_uint16(uint16_t val) : as_uint16(val) {} + __device__ half_uint16(half val) : as_half(val) {} + __device__ half_uint16() : as_uint16(0) {} +}; + +// "3INST" procedural codebook decoder (based on ExLlamaV3) +template +__device__ inline half decode_3inst(uint32_t x, uint32_t mult) +{ + if constexpr (cb == 0) + { + x *= 89226354u; + x += 64248484u; + // Simplified version without inline assembly for compatibility + half2_uint32 xu(x); + // Extract high and low halves and add them + uint32_t low_bits = x & 0xFFFF; + uint32_t high_bits = (x >> 16) & 0xFFFF; + // Convert to half precision values + half low_half = __ushort_as_half((uint16_t)((low_bits & 0x8FFF) | 0x3B60)); + half high_half = __ushort_as_half((uint16_t)((high_bits & 0x8FFF) | 0x3B60)); + return __hadd(low_half, high_half); + } + if constexpr (cb == 1) + { + x *= mult; + // Same simplified conversion for MCG mode + uint32_t low_bits = x & 0xFFFF; + uint32_t high_bits = (x >> 16) & 0xFFFF; + half low_half = __ushort_as_half((uint16_t)((low_bits & 0x8FFF) | 0x3B60)); + half high_half = __ushort_as_half((uint16_t)((high_bits & 0x8FFF) | 0x3B60)); + return __hadd(low_half, high_half); + } + if constexpr (cb == 2) + { + x *= mult; + // MUL1 mode with scaling + const half k_inv_h = __ushort_as_half(0x1eee); // 0.00677 = 1/147.7 + const half k_bias_h = __ushort_as_half(0xc931); // -10.39 + half_uint16 h((uint16_t)(x + 0x6400u)); + return __hfma(h.as_half, k_inv_h, k_bias_h); + } +} + +// Determine which codebook mode to use based on multipliers +__device__ inline int get_cb_mode(uint32_t mcg_mult, uint32_t mul1_mult) +{ + if (mul1_mult != 0) return 2; // MUL1 mode + if (mcg_mult != 0) return 1; // MCG mode + return 0; // Default mode +} + +// Get the appropriate multiplier for the mode +__device__ inline uint32_t get_mult(int cb, uint32_t mcg_mult, uint32_t mul1_mult) +{ + if (cb == 1) return mcg_mult; + if (cb == 2) return mul1_mult; + return 89226354u; // Default multiplier +} + +// Simplified EXL3 GEMM kernel that reconstructs weights on-the-fly +__global__ void exl3_gemm_kernel( + const half* __restrict__ A, // [batch_size, in_features] + const int16_t* __restrict__ B, // [tiles_k, tiles_n, K*16] trellis + half* __restrict__ C, // [batch_size, out_features] + int batch_size, + int in_features, + int out_features, + int tiles_k, + int tiles_n, + int K, + uint32_t mcg_mult, + uint32_t mul1_mult +) { + int tx = threadIdx.x; + int ty = threadIdx.y; + int bx = blockIdx.x; + int by = blockIdx.y; + + // Determine codebook mode + int cb = get_cb_mode(mcg_mult, mul1_mult); + uint32_t mult = get_mult(cb, mcg_mult, mul1_mult); + + // Each block handles a 16x16 output tile + int tile_row = by; + int tile_col = bx; + + if (tile_row >= batch_size || tile_col >= (out_features + 15) / 16) return; + + // Each thread computes one output element + int out_row = tile_row; + int out_col = tile_col * 16 + tx; + + if (out_row >= batch_size || out_col >= out_features) return; + + float acc = 0.0f; + + // Iterate over input feature tiles + for (int k_tile = 0; k_tile < tiles_k; k_tile++) { + // Process 16 elements from the input + for (int k_local = 0; k_local < 16; k_local++) { + int k_global = k_tile * 16 + k_local; + + if (k_global >= in_features) break; + + // Get input value + half a_val = A[out_row * in_features + k_global]; + + // Reconstruct weight from trellis + int trellis_tile_k = k_tile; + int trellis_tile_n = tile_col; + + if (trellis_tile_n < tiles_n) { + // Calculate trellis index + int elem_idx = (k_local * 16 + tx) % (K * 16); + + // Get quantized value + int16_t quant_val = B[trellis_tile_k * tiles_n * K * 16 + + trellis_tile_n * K * 16 + + elem_idx]; + + // Decode using 3INST procedural codebook + half b_val; + uint32_t quant_u32 = static_cast(static_cast(quant_val)); + + if (cb == 0) { + b_val = decode_3inst<0>(quant_u32, mult); + } else if (cb == 1) { + b_val = decode_3inst<1>(quant_u32, mult); + } else { + b_val = decode_3inst<2>(quant_u32, mult); + } + + // Accumulate + acc += __half2float(__hmul(a_val, b_val)); + } + } + } + + // Store result + C[out_row * out_features + out_col] = __float2half(acc); +} + +// Host functions matching ExLlamaV3 interface +torch::Tensor exl3_gemm( + torch::Tensor input, // [batch_size, in_features] + torch::Tensor trellis, // [tiles_k, tiles_n, K*16] + torch::Tensor suh, // [in_features] - input signs (optional) + torch::Tensor svh, // [out_features] - output signs (optional) + int64_t mcg_mult, + int64_t mul1_mult +) { + const at::cuda::OptionalCUDAGuard device_guard(input.device()); + + TORCH_CHECK(input.dtype() == torch::kFloat16, "Input must be float16"); + TORCH_CHECK(trellis.dtype() == torch::kInt16, "Trellis must be int16"); + TORCH_CHECK(input.dim() == 2, "Input must be 2D"); + TORCH_CHECK(trellis.dim() == 3, "Trellis must be 3D"); + + int batch_size = input.size(0); + int in_features = input.size(1); + int tiles_k = trellis.size(0); + int tiles_n = trellis.size(1); + int K_times_16 = trellis.size(2); + int K = K_times_16 / 16; + int out_features = tiles_n * 16; + + // Create output tensor + torch::Tensor output = torch::zeros({batch_size, out_features}, + torch::TensorOptions() + .dtype(torch::kFloat16) + .device(input.device())); + + // Launch kernel + dim3 threads(16, 1); + dim3 blocks((out_features + 15) / 16, batch_size); + + exl3_gemm_kernel<<>>( + input.data_ptr(), + trellis.data_ptr(), + output.data_ptr(), + batch_size, + in_features, + out_features, + tiles_k, + tiles_n, + K, + static_cast(mcg_mult), + static_cast(mul1_mult) + ); + + // Apply sign factors if provided + if (suh.numel() > 0) { + // Apply input signs by broadcasting multiplication + // This is a simplified version - real implementation would use Hadamard transforms + // For now, just apply the signs directly + TORCH_CHECK(suh.size(0) == in_features, "suh size mismatch"); + // Note: This is incomplete - would need proper Hadamard transform + } + + if (svh.numel() > 0) { + // Apply output signs + TORCH_CHECK(svh.size(0) == out_features, "svh size mismatch"); + output = output * svh.unsqueeze(0); + } + + return output; +} + +torch::Tensor exl3_reconstruct( + torch::Tensor trellis, + int64_t in_features, + int64_t out_features, + int64_t mcg_mult, + int64_t mul1_mult +) { + const at::cuda::OptionalCUDAGuard device_guard(trellis.device()); + + TORCH_CHECK(trellis.dtype() == torch::kInt16, "Trellis must be int16"); + TORCH_CHECK(trellis.dim() == 3, "Trellis must be 3D"); + + int tiles_k = trellis.size(0); + int tiles_n = trellis.size(1); + int K_times_16 = trellis.size(2); + int K = K_times_16 / 16; + + // Create weight tensor + torch::Tensor weight = torch::zeros({in_features, out_features}, + torch::TensorOptions() + .dtype(torch::kFloat16) + .device(trellis.device())); + + // Use the GEMM kernel with identity input to reconstruct weights + torch::Tensor identity = torch::eye(in_features, + torch::TensorOptions() + .dtype(torch::kFloat16) + .device(trellis.device())); + + // Empty sign factors + torch::Tensor empty_signs = torch::empty({0}, + torch::TensorOptions() + .dtype(torch::kFloat16) + .device(trellis.device())); + + weight = exl3_gemm(identity, trellis, empty_signs, empty_signs, mcg_mult, mul1_mult); + + return weight; +} \ No newline at end of file diff --git a/kernels/torch_bindings.cpp b/kernels/torch_bindings.cpp index b08d3df7e0..d07862ac38 100644 --- a/kernels/torch_bindings.cpp +++ b/kernels/torch_bindings.cpp @@ -555,6 +555,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()"); ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle); + // EXL3 quantization operations + ops.def( + "exl3_gemm(Tensor input, Tensor trellis, Tensor suh, Tensor svh, " + "int mcg_mult, int mul1_mult) -> Tensor", + {stride_tag}); + ops.impl("exl3_gemm", torch::kCUDA, &exl3_gemm); + + ops.def( + "exl3_reconstruct(Tensor trellis, int in_features, int out_features, " + "int mcg_mult, int mul1_mult) -> Tensor"); + ops.impl("exl3_reconstruct", torch::kCUDA, &exl3_reconstruct); + // Compute FP8 quantized tensor for given scaling factor. ops.def( "static_scaled_fp8_quant(Tensor! result, Tensor input, Tensor scale) -> "