From d98732ac476aa68786063d119a6a1ea70113a75d Mon Sep 17 00:00:00 2001 From: David Ndungu Date: Mon, 6 Apr 2026 15:31:15 -0700 Subject: [PATCH] refactor(compute): split gpu_engine.go into focused files Extract methods from gpu_engine.go (3521 -> 2245 lines) into three focused files to improve navigability: - gpu_engine_elementwise.go (400 lines): Add/Sub/Mul/Div, scalar ops, Exp/Log/Sin/Cos/Tanh/Pow, Sqrt/Rsqrt, fused RoPE/SwiGLU/RMSNorm, CosineSimilarity, HadamardTransform - gpu_engine_reduction.go (221 lines): Sum, Softmax, ReduceSum/Max/Mean, GPUArgmax, GPUScaledSoftmax, GPUFusedSoftmaxVMul - gpu_engine_memory.go (695 lines): Transpose, Zero/Zeros/Copy, Gather, ScatterAdd, Fill, Split/Concat/Repeat/RepeatInterleave, Reshape, OneHot, ConvertFP16ToF32 Zero behavioral changes. All method signatures identical. Build, vet, and race-detector tests pass. --- compute/gpu_engine.go | 1276 ----------------------------- compute/gpu_engine_elementwise.go | 400 +++++++++ compute/gpu_engine_memory.go | 695 ++++++++++++++++ compute/gpu_engine_reduction.go | 221 +++++ 4 files changed, 1316 insertions(+), 1276 deletions(-) create mode 100644 compute/gpu_engine_elementwise.go create mode 100644 compute/gpu_engine_memory.go create mode 100644 compute/gpu_engine_reduction.go diff --git a/compute/gpu_engine.go b/compute/gpu_engine.go index 34edc20..e57c847 100644 --- a/compute/gpu_engine.go +++ b/compute/gpu_engine.go @@ -2231,1270 +2231,6 @@ func (e *GPUEngine[T]) mmapDevicePtr(ms *tensor.MmapStorage) (unsafe.Pointer, fu return devPtr, cleanup, nil } -// --- GPU-accelerated and fallback methods --- - -func (e *GPUEngine[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.cpu.UnaryOp(ctx, a, op, dst...) -} - -// Add performs element-wise addition. 
-func (e *GPUEngine[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuAdd(ctx, a, b, dst...) -} - -// Sub performs element-wise subtraction. -func (e *GPUEngine[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuSub(ctx, a, b, dst...) -} - -// Mul performs element-wise multiplication. -func (e *GPUEngine[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuMul(ctx, a, b, dst...) -} - -// Div performs element-wise division. -func (e *GPUEngine[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuDiv(ctx, a, b, dst...) -} - -// Transpose transposes a tensor along the given axes. -func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - if !isFloat32[T]() { - return e.cpu.Transpose(ctx, a, axes, dst...) - } - - // Only use GPU path for GPU-resident tensors (Phase 6 behavior). - // CPU-backed tensors fall back to CPU transpose to avoid unexpected - // H2D copies that may interfere with CUDA graph capture/replay. - _, isGPU := a.GetStorage().(*tensor.GPUStorage[T]) - isFP16 := false - if e.dtype != DTypeF32 { - _, isFP16 = any(a.GetStorage()).(*tensor.Float16Storage) - } - if !isGPU && !isFP16 { - return e.cpu.Transpose(ctx, a, axes, dst...) - } - - e.setDevice() - - shape := a.Shape() - rank := len(shape) - - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE: shape=%v rank=%d axes=%v storage=%T\n", shape, rank, axes, a.GetStorage()) - } - - // Default: reverse axes (same as CPU Transpose with nil axes). 
- if len(axes) == 0 { - axes = make([]int, rank) - for i := range rank { - axes[i] = rank - 1 - i - } - } - - if len(axes) != rank { - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE CPU FALLBACK: reason=axes_rank_mismatch shape=%v\n", shape) - } - return e.cpu.Transpose(ctx, a, axes, dst...) - } - - // GPU transpose kernel supports up to 4D; fall back to CPU for higher ranks. - if rank > 4 { - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE CPU FALLBACK: reason=rank_gt_4 shape=%v\n", shape) - } - return e.cpu.Transpose(ctx, a, axes, dst...) - } - - // Compute output shape. - outShape := make([]int, rank) - for i, ax := range axes { - outShape[i] = shape[ax] - } - - // Fast path: if the transpose only swaps unit-sized dimensions, it is - // equivalent to a reshape (no data movement). This is common during - // single-token generation where seqLen=1. Check by comparing the - // non-unit dimensions in input vs output order. - if isTransposeReshape(shape, outShape) { - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE: reshape fast path shape=%v outShape=%v storage=%T\n", shape, outShape, a.GetStorage()) - } - if e.dtype != DTypeF32 { - if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok { - storageT := any(fs).(tensor.Storage[T]) - t, tErr := tensor.NewWithStorage[T](outShape, storageT) - if tErr != nil { - return nil, tErr - } - return t, nil - } - } - gs := a.GetStorage().(*tensor.GPUStorage[T]) - viewGS := gs.View(gs.Len()) - t, tErr := tensor.NewWithStorage[T](outShape, viewGS) - if tErr != nil { - return nil, tErr - } - if len(dst) > 0 && dst[0] != nil { - dst[0].SetStorage(viewGS) - dst[0].SetShape(outShape) - return dst[0], nil - } - return t, nil - } - - // Compute total elements. - total := 1 - for _, d := range shape { - total *= d - } - - // Compute input strides. 
- inStrides := make([]int, rank) - stride := 1 - for i := rank - 1; i >= 0; i-- { - inStrides[i] = stride - stride *= shape[i] - } - - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE getDevicePtr: storage=%T\n", a.GetStorage()) - } - devIn, cleanupIn, err := getDevicePtr(e, a) - if err != nil { - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE CPU FALLBACK: reason=getDevicePtr_failed shape=%v\n", shape) - } - return e.cpu.Transpose(ctx, a, axes, dst...) - } - defer cleanupIn() - if debugGPU { - fmt.Fprintf(os.Stderr, "TRANSPOSE getDevicePtr OK: ptr=%p\n", devIn) - } - - byteSize := total * f32Size - devOut, err := e.pool.Alloc(e.deviceID, byteSize) - if err != nil { - return e.cpu.Transpose(ctx, a, axes, dst...) - } - - // Fast path: 2D transpose. - if rank == 2 && axes[0] == 1 && axes[1] == 0 { - if debugGPU { - e.logger.Debug("TRANSPOSE: using 2D fast path", - "rows", fmt.Sprintf("%d", shape[0]), - "cols", fmt.Sprintf("%d", shape[1])) - } - if err := e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, byteSize) - return nil, err - } - return makeGPUResult[T](e, outShape, devOut, total, dst...) - } - - // General N-D transpose via stride-based kernel. - // Precompute output strides on the host so the kernel avoids O(ndim^2) per thread. 
- if debugGPU { - e.logger.Debug("TRANSPOSE: using general N-D path", - "rank", fmt.Sprintf("%d", rank), - "axes", fmt.Sprintf("%v", axes)) - } - outStrides := make([]int, rank) - outStride := 1 - for i := rank - 1; i >= 0; i-- { - outStrides[i] = outStride - outStride *= outShape[i] - } - - inStrides32 := make([]int32, rank) - outStrides32 := make([]int32, rank) - perm32 := make([]int32, rank) - for i := range rank { - inStrides32[i] = int32(inStrides[i]) - outStrides32[i] = int32(outStrides[i]) - perm32[i] = int32(axes[i]) - } - - if err := e.kernels.TransposeND(devIn, devOut, inStrides32, outStrides32, perm32, rank, total, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, byteSize) - return nil, err - } - - return makeGPUResult[T](e, outShape, devOut, total, dst...) -} - -// Sum computes the sum of elements along an axis. -func (e *GPUEngine[T]) Sum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuSum(ctx, a, axis, keepDims, dst...) -} - -// Exp computes the element-wise exponential. -func (e *GPUEngine[T]) Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuExp(ctx, a, dst...) -} - -// Log computes the element-wise natural logarithm. -func (e *GPUEngine[T]) Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuLog(ctx, a, dst...) -} - -// Sin computes the element-wise sine. -func (e *GPUEngine[T]) Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuSin(ctx, a, dst...) -} - -// Cos computes the element-wise cosine. -func (e *GPUEngine[T]) Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuCos(ctx, a, dst...) 
-} - -// Tanh computes the element-wise hyperbolic tangent. -func (e *GPUEngine[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuTanh(ctx, a, dst...) -} - -// TanhPrime computes the element-wise gradient of tanh. -func (e *GPUEngine[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuTanhPrime(ctx, a, upstream, dst...) -} - -// Pow raises each element to the given power. -func (e *GPUEngine[T]) Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuPow(ctx, base, exponent, dst...) -} - -// Zero sets all elements to zero. -func (e *GPUEngine[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error { - // GPU path: use cudaMemsetAsync on the engine's stream. - if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok { - return e.runtime.MemsetAsync(gs.Ptr(), 0, gs.ByteSize(), e.stream) - } - // CPU fallback for non-GPU tensors. - return e.cpu.Zero(ctx, a) -} - -// Zeros fills the tensor with zeros. -func (e *GPUEngine[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error { - return e.cpu.Zeros(ctx, a, shape) -} - -// Copy copies data from source to destination tensor. -func (e *GPUEngine[T]) Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error { - dstGS, dstIsGPU := dst.GetStorage().(*tensor.GPUStorage[T]) - srcGS, srcIsGPU := src.GetStorage().(*tensor.GPUStorage[T]) - if dstIsGPU && srcIsGPU { - // D2D copy on engine stream. - return e.runtime.MemcpyAsync(dstGS.Ptr(), srcGS.Ptr(), dstGS.ByteSize(), gpuapi.MemcpyDeviceToDevice, e.stream) - } - // Fall back to CPU for mixed or CPU-only tensors. - return e.cpu.Copy(ctx, dst, src) -} - -// Gather performs an embedding-style gather. 
-func (e *GPUEngine[T]) Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error { - if !isFloat32[T]() { - return e.cpu.Gather(ctx, params, indices, output) - } - - // Q8 GPU gather: dequantize only the requested rows on GPU. - if qs, ok := any(params.GetStorage()).(*tensor.Q8Storage); ok { - if ptr, _, _ := qs.GPUPtr(); ptr != nil { - return e.gatherQ8(params, indices, output, qs, ptr) - } - } - - // Check whether params are GPU-resident (F32 or FP16 storage). - _, isGPU := params.GetStorage().(*tensor.GPUStorage[T]) - var fp16Stor *tensor.Float16Storage - isFP16 := false - if e.dtype != DTypeF32 { - fp16Stor, isFP16 = any(params.GetStorage()).(*tensor.Float16Storage) - } - if !isGPU && !isFP16 { - return e.cpu.Gather(ctx, params, indices, output) - } - - e.setDevice() - - pShape := params.Shape() - if len(pShape) != 2 { - return e.cpu.Gather(ctx, params, indices, output) - } - V := pShape[0] - D := pShape[1] - - // Flatten indices to get N. - idxData := indices.Data() - N := len(idxData) - if N == 0 { - return nil - } - - // Get device pointer for params. For Float16Storage, convert FP16->F32 - // into a temporary buffer so the F32 Gather kernel can operate on it. 
- var devParams unsafe.Pointer - var cleanupParams func() - if isFP16 { - fp16Ptr, _, _ := fp16Stor.GPUPtr() - if fp16Ptr == nil { - return e.cpu.Gather(ctx, params, indices, output) - } - nElems := V * D - f32Bytes := nElems * f32Size - f32Ptr, err := e.pool.Alloc(e.deviceID, f32Bytes) - if err != nil { - return e.cpu.Gather(ctx, params, indices, output) - } - if err := e.kernels.FP16ToF32(fp16Ptr, f32Ptr, nElems, e.stream); err != nil { - e.pool.Free(e.deviceID, f32Ptr, f32Bytes) - return e.cpu.Gather(ctx, params, indices, output) - } - devParams = f32Ptr - cleanupParams = func() { e.pool.Free(e.deviceID, f32Ptr, f32Bytes) } - } else { - var err error - devParams, cleanupParams, err = getDevicePtr(e, params) - if err != nil { - return e.cpu.Gather(ctx, params, indices, output) - } - } - defer cleanupParams() - - // Upload indices to GPU as int64 (Go int on 64-bit platforms). - // The gather kernel accepts int64 indices directly, avoiding the - // CPU-side int64→int32 conversion that would trigger a D2H copy - // for GPU-resident indices and block CUDA graph capture. - intSize := int(unsafe.Sizeof(int(0))) - idxByteSize := N * intSize - devIdx, err := e.pool.Alloc(e.deviceID, idxByteSize) - if err != nil { - return e.cpu.Gather(ctx, params, indices, output) - } - defer e.pool.Free(e.deviceID, devIdx, idxByteSize) - - if err := e.runtime.Memcpy(devIdx, unsafe.Pointer(&idxData[0]), idxByteSize, gpuapi.MemcpyHostToDevice); err != nil { - return e.cpu.Gather(ctx, params, indices, output) - } - - // Allocate output on GPU. - outByteSize := N * D * f32Size - devOut, err := e.pool.Alloc(e.deviceID, outByteSize) - if err != nil { - return e.cpu.Gather(ctx, params, indices, output) - } - - if err := e.kernels.Gather(devParams, devIdx, devOut, N, D, V, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outByteSize) - return fmt.Errorf("GPU Gather: %w", err) - } - - // When dtype is FP16, convert the F32 gather output to FP16 on GPU. 
- // This is the single F32->FP16 conversion point for the entire forward pass; - // all downstream ops receive Float16Storage and operate in FP16 natively. - if e.dtype == DTypeFP16 { - outElems := N * D - fp16Bytes := outElems * fp16Size - fp16Ptr, err := e.pool.Alloc(e.deviceID, fp16Bytes) - if err != nil { - e.pool.Free(e.deviceID, devOut, outByteSize) - return fmt.Errorf("Gather FP16 alloc: %w", err) - } - if err := e.kernels.F32ToFP16(devOut, fp16Ptr, outElems, e.stream); err != nil { - e.pool.Free(e.deviceID, fp16Ptr, fp16Bytes) - e.pool.Free(e.deviceID, devOut, outByteSize) - return fmt.Errorf("Gather F32->FP16: %w", err) - } - e.pool.Free(e.deviceID, devOut, outByteSize) - fs := any(tensor.NewFloat16StorageGPU(fp16Ptr, outElems, e.deviceID)).(tensor.Storage[T]) - output.SetStorage(fs) - return nil - } - - // Set output storage to GPU (pool-backed so Free returns to pool, not cudaFree). - gs, err := tensor.NewGPUStorageFromPool[T](devOut, N*D, e.pool, e.deviceID) - if err != nil { - e.pool.Free(e.deviceID, devOut, outByteSize) - return err - } - output.SetStorage(gs) - - return nil -} - -// gatherQ8 performs Q8_0 embedding gather on GPU using the Q8 gather kernel. -// Dequantizes only the requested rows, keeping the full Q8 table compressed. -func (e *GPUEngine[T]) gatherQ8( - params *tensor.TensorNumeric[T], - indices *tensor.TensorNumeric[int], - output *tensor.TensorNumeric[T], - qs *tensor.Q8Storage, - devQ8 unsafe.Pointer, -) error { - e.setDevice() - - pShape := params.Shape() - V := pShape[0] - D := pShape[1] - - idxData := indices.Data() - N := len(idxData) - if N == 0 { - return nil - } - - // Upload indices as int32 to GPU. 
- idx32 := make([]int32, N) - for i, id := range idxData { - idx32[i] = int32(id) - } - idxBytes := N * 4 - devIdx, err := e.pool.Alloc(e.deviceID, idxBytes) - if err != nil { - return e.cpu.Gather(context.Background(), params, indices, output) - } - defer e.pool.Free(e.deviceID, devIdx, idxBytes) - - if err := e.runtime.Memcpy(devIdx, unsafe.Pointer(&idx32[0]), idxBytes, gpuapi.MemcpyHostToDevice); err != nil { - return e.cpu.Gather(context.Background(), params, indices, output) - } - - // Allocate output [N, D] on GPU. - outElems := N * D - outBytes := outElems * f32Size - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return e.cpu.Gather(context.Background(), params, indices, output) - } - - // Launch Q8 gather kernel. - if err := e.kernels.GatherQ8F32(devQ8, devIdx, devOut, N, D, V, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return e.cpu.Gather(context.Background(), params, indices, output) - } - - // Write result into output tensor as GPUStorage (pool-backed). - gs, err := tensor.NewGPUStorageFromPool[float32](devOut, outElems, e.pool, e.deviceID) - if err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return fmt.Errorf("gatherQ8: create GPU storage: %w", err) - } - output.SetStorage(any(gs).(tensor.Storage[T])) - return nil -} - -// ScatterAdd performs a row-wise scatter-add for embeddings. -func (e *GPUEngine[T]) ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error { - return e.cpu.ScatterAdd(ctx, dEmbeddingTable, indices, dOut) -} - -// RandomUniform fills the tensor with uniform random values. -func (e *GPUEngine[T]) RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error { - return e.cpu.RandomUniform(ctx, t, minVal, maxVal) -} - -// Fill fills the tensor with a scalar value. 
-func (e *GPUEngine[T]) Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error { - return e.gpuFill(ctx, t, value) -} - -// MulScalar multiplies each element by a scalar. -func (e *GPUEngine[T]) MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuMulScalar(ctx, a, scalar, dst...) -} - -// DivScalar divides each element by a scalar. -func (e *GPUEngine[T]) DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuDivScalar(ctx, a, scalar, dst...) -} - -// Softmax applies the softmax function along an axis. -func (e *GPUEngine[T]) Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuSoftmax(ctx, a, axis, dst...) -} - -// ReduceSum computes the sum of elements along an axis. -func (e *GPUEngine[T]) ReduceSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuReduceSum(ctx, a, axis, keepDims, dst...) -} - -// ReduceMax computes the maximum of elements along an axis (CPU fallback). -func (e *GPUEngine[T]) ReduceMax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.cpu.ReduceMax(ctx, a, axis, keepDims, dst...) -} - -// AddScalar adds a scalar to each element. -func (e *GPUEngine[T]) AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuAddScalar(ctx, a, scalar, dst...) -} - -// Sqrt computes the element-wise square root. 
-func (e *GPUEngine[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuSqrt(ctx, a, dst...) -} - -// Split splits a tensor into multiple tensors along an axis. -func (e *GPUEngine[T]) Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error) { - if !isFloat32[T]() { - return e.cpu.Split(ctx, a, numSplits, axis) - } - if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok { - return e.gpuSplit(gs.Ptr(), a.Shape(), numSplits, axis) - } - if e.dtype != DTypeF32 { - if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok { - ptr, _, _ := fs.GPUPtr() - if ptr != nil { - return e.gpuSplitFP16(ptr, a.Shape(), numSplits, axis) - } - } - } - return e.cpu.Split(ctx, a, numSplits, axis) -} - -// Concat concatenates tensors along an axis. -func (e *GPUEngine[T]) Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - if !isFloat32[T]() || len(tensors) == 0 { - return e.cpu.Concat(ctx, tensors, axis, dst...) - } - // Check all inputs are GPU-resident (GPUStorage or Float16Storage). - ptrs := make([]unsafe.Pointer, len(tensors)) - allFP16 := true - for i, t := range tensors { - if gs, ok := t.GetStorage().(*tensor.GPUStorage[T]); ok { - ptrs[i] = gs.Ptr() - allFP16 = false - } else if e.dtype != DTypeF32 { - if fs, ok := any(t.GetStorage()).(*tensor.Float16Storage); ok { - p, _, _ := fs.GPUPtr() - if p == nil { - return e.cpu.Concat(ctx, tensors, axis, dst...) - } - ptrs[i] = p - } else { - return e.cpu.Concat(ctx, tensors, axis, dst...) - } - } else { - return e.cpu.Concat(ctx, tensors, axis, dst...) - } - } - if allFP16 && e.dtype != DTypeF32 { - return e.gpuConcatFP16(ptrs, tensors, axis, dst...) - } - return e.gpuConcat(ptrs, tensors, axis, dst...) -} - -// Repeat repeats the tensor along an axis. 
-func (e *GPUEngine[T]) Repeat(ctx context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - if !isFloat32[T]() || a == nil { - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - - shape := a.Shape() - if axis < 0 || axis >= len(shape) { - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - if repetitions <= 0 { - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - - // Get device pointer (handles GPUStorage[T] and Float16Storage). - isFP16 := false - if e.dtype != DTypeF32 { - _, isFP16 = any(a.GetStorage()).(*tensor.Float16Storage) - } - gs, isGPU := a.GetStorage().(*tensor.GPUStorage[T]) - if !isGPU && !isFP16 { - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - - e.setDevice() - - newShape := make([]int, len(shape)) - copy(newShape, shape) - newShape[axis] *= repetitions - - outElems := 1 - for _, d := range newShape { - outElems *= d - } - outBytes := outElems * f32Size - - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - - var devA unsafe.Pointer - var cleanupA func() - if isGPU { - devA = gs.Ptr() - cleanupA = func() {} - } else { - // Float16Storage: convert FP16→F32 for the F32 repeat kernel. - f32Engine, ok := any(e).(*GPUEngine[float32]) - if !ok { - e.pool.Free(e.deviceID, devOut, outBytes) - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - devA, cleanupA, err = getDevicePtr(f32Engine, any(a).(*tensor.TensorNumeric[float32])) - if err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - } - defer cleanupA() - - // Compute dimensions for the repeat kernel. 
- outerSize := 1 - for i := 0; i < axis; i++ { - outerSize *= shape[i] - } - axisDim := shape[axis] - innerSize := 1 - for i := axis + 1; i < len(shape); i++ { - innerSize *= shape[i] - } - - if err := e.kernels.Repeat(devA, devOut, outerSize, axisDim, innerSize, repetitions, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - - // For FP16 inputs, convert the F32 output back to Float16Storage. - if isFP16 { - fp16Bytes := outElems * fp16Size - fp16Out, allocErr := e.pool.Alloc(e.deviceID, fp16Bytes) - if allocErr != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - if convErr := e.kernels.F32ToFP16(devOut, fp16Out, outElems, e.stream); convErr != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - e.pool.Free(e.deviceID, fp16Out, fp16Bytes) - return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) - } - e.pool.Free(e.deviceID, devOut, outBytes) - fs := tensor.NewFloat16StorageGPU(fp16Out, outElems, e.deviceID) - storageT := any(fs).(tensor.Storage[T]) - return tensor.NewWithStorage[T](newShape, storageT) - } - - return makeGPUResult[T](e, newShape, devOut, outElems, dst...) -} - -// RepeatInterleave expands a 4D tensor from [B, numKV, S, D] to [B, numQ, S, D] -// by repeating each head along axis 1 (the head dimension) `reps` times. -// This is a fused kernel for GQA key/value head expansion, replacing the -// Reshape -> Repeat -> Reshape chain with a single kernel launch. -// axis must be 1 and the input must be 4D [B, numKV, S, D]. -func (e *GPUEngine[T]) RepeatInterleave(ctx context.Context, a *tensor.TensorNumeric[T], axis int, reps int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - shape := a.Shape() - if !isFloat32[T]() || a == nil || axis != 1 || len(shape) != 4 || reps <= 0 { - // Fall back to generic Repeat path for unsupported configurations. - return e.Repeat(ctx, a, axis, reps, dst...) 
- } - - B, numKV, S, D := shape[0], shape[1], shape[2], shape[3] - numQ := numKV * reps - - gs, isGPU := a.GetStorage().(*tensor.GPUStorage[T]) - if !isGPU { - return e.Repeat(ctx, a, axis, reps, dst...) - } - - e.setDevice() - - outElems := B * numQ * S * D - outBytes := outElems * f32Size - - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return e.Repeat(ctx, a, axis, reps, dst...) - } - - if err := e.kernels.RepeatInterleaveF32(gs.Ptr(), devOut, B, numKV, S, D, reps, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return e.Repeat(ctx, a, axis, reps, dst...) - } - - outShape := []int{B, numQ, S, D} - return makeGPUResult[T](e, outShape, devOut, outElems, dst...) -} - -// OneHot creates a one-hot encoding. -func (e *GPUEngine[T]) OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.cpu.OneHot(ctx, input, depth, dst...) -} - -// Reshape changes the shape without changing data. -func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - // Resolve -1 dimension and verify size. - currentSize := a.Size() - inferredShape := make([]int, len(shape)) - copy(inferredShape, shape) - inferIdx := -1 - knownSize := 1 - for i, d := range inferredShape { - if d == -1 { - inferIdx = i - } else { - knownSize *= d - } - } - if inferIdx >= 0 { - inferredShape[inferIdx] = currentSize / knownSize - } - newSize := 1 - for _, d := range inferredShape { - newSize *= d - } - - // Float16Storage: zero-copy reshape (same GPU pointer, new shape). - if e.dtype != DTypeF32 { - if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok && newSize == currentSize { - return tensor.NewWithStorage[T](inferredShape, any(fs).(tensor.Storage[T])) - } - } - - // GPUStorage[T]: zero-copy reshape. 
- if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize { - return tensor.NewWithStorage[T](inferredShape, gs.View(gs.Len())) - } - - return e.cpu.Reshape(ctx, a, shape, dst...) -} - -// ReduceMean computes the mean of elements along an axis. -func (e *GPUEngine[T]) ReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuReduceMean(ctx, a, axis, keepDims, dst...) -} - -// Rsqrt computes the element-wise reciprocal square root. -func (e *GPUEngine[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.gpuRsqrt(ctx, a, dst...) -} - -// GPUArgmax finds the index of the maximum element in a GPU-resident float32 tensor. -// Returns the index as an int without copying the full tensor to the host. -// Only copies back a single int32 (4 bytes) instead of the entire tensor. -func (e *GPUEngine[T]) GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error) { - gs, ok := t.GetStorage().(*tensor.GPUStorage[float32]) - if !ok { - return 0, fmt.Errorf("GPUArgmax: tensor not GPU-resident") - } - - e.setDevice() - - n := gs.Len() - devInput := gs.Ptr() - - // Allocate scratch: 2 * ceil(n/256) * 4 bytes (blockVals + blockIdxs). - numBlocks := (n + 255) / 256 - scratchSize := 2 * numBlocks * 4 - devScratch, err := e.pool.Alloc(e.deviceID, scratchSize) - if err != nil { - return 0, fmt.Errorf("GPUArgmax: scratch alloc: %w", err) - } - defer e.pool.Free(e.deviceID, devScratch, scratchSize) - - // Allocate device result (single int32). 
- devResult, err := e.pool.Alloc(e.deviceID, 4) - if err != nil { - return 0, fmt.Errorf("GPUArgmax: result alloc: %w", err) - } - defer e.pool.Free(e.deviceID, devResult, 4) - - if err := e.kernels.Argmax(devInput, devResult, devScratch, n, e.stream); err != nil { - return 0, fmt.Errorf("GPUArgmax: %w", err) - } - - // Copy single int32 result back to host. - var result int32 - if err := e.runtime.Memcpy(unsafe.Pointer(&result), devResult, 4, gpuapi.MemcpyDeviceToHost); err != nil { - return 0, fmt.Errorf("GPUArgmax: D2H copy: %w", err) - } - - return int(result), nil -} - -// ConvertFP16ToF32 converts a tensor with Float16Storage to a regular float32 -// GPU tensor using the FP16->F32 kernel. Returns the input unchanged if it -// does not have Float16Storage. -func (e *GPUEngine[T]) ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error) { - fs, ok := any(t.GetStorage()).(*tensor.Float16Storage) - if !ok { - return t, nil - } - - fp16Ptr, _, _ := fs.GPUPtr() - if fp16Ptr == nil { - // CPU-side Float16Storage: decode via Slice (no GPU conversion possible). 
- data := fs.Slice() - out, err := tensor.New(t.Shape(), data) - if err != nil { - return nil, fmt.Errorf("ConvertFP16ToF32: create f32 tensor: %w", err) - } - return out, nil - } - - e.setDevice() - - nElems := fs.Len() - f32Bytes := nElems * f32Size - f32Ptr, err := e.pool.Alloc(e.deviceID, f32Bytes) - if err != nil { - return nil, fmt.Errorf("ConvertFP16ToF32: alloc: %w", err) - } - - if err := e.kernels.FP16ToF32(fp16Ptr, f32Ptr, nElems, e.stream); err != nil { - e.pool.Free(e.deviceID, f32Ptr, f32Bytes) - return nil, fmt.Errorf("ConvertFP16ToF32: kernel: %w", err) - } - - gs, err := tensor.NewGPUStorageFromPool[float32](f32Ptr, nElems, e.pool, e.deviceID) - if err != nil { - e.pool.Free(e.deviceID, f32Ptr, f32Bytes) - return nil, fmt.Errorf("ConvertFP16ToF32: gpu storage: %w", err) - } - out, err := tensor.NewWithStorage[float32](t.Shape(), gs) - if err != nil { - return nil, fmt.Errorf("ConvertFP16ToF32: wrap tensor: %w", err) - } - return out, nil -} - -// GPUFusedRoPE applies rotary position embeddings in a single GPU kernel launch. -// This replaces Split + 4 Mul + Sub + Add + Concat (8 operations, ~10 D2D memcpy) with 1 kernel. -func (e *GPUEngine[T]) GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error) { - shape := input.Shape() - if len(shape) != 3 { - return nil, fmt.Errorf("GPUFusedRoPE: expected 3D input [batch, seq, dim], got %dD", len(shape)) - } - - batch := shape[0] - seqLen := shape[1] - headDim := shape[2] - halfRotary := rotaryDim / 2 - - cosShape := cosAngles.Shape() - if len(cosShape) != 2 || cosShape[0] < seqLen || cosShape[1] < halfRotary { - return nil, fmt.Errorf("GPUFusedRoPE: cos shape %v incompatible with seq_len=%d half_rotary=%d", cosShape, seqLen, halfRotary) - } - cosStride := cosShape[1] - - // Get device pointers for input, cos, sin. 
- inPtr, inCleanup, err := getDevicePtr(e, input) - if err != nil { - return nil, fmt.Errorf("GPUFusedRoPE input: %w", err) - } - defer inCleanup() - - cosPtr, cosCleanup, err := getDevicePtr(e, cosAngles) - if err != nil { - return nil, fmt.Errorf("GPUFusedRoPE cos: %w", err) - } - defer cosCleanup() - - sinPtr, sinCleanup, err := getDevicePtr(e, sinAngles) - if err != nil { - return nil, fmt.Errorf("GPUFusedRoPE sin: %w", err) - } - defer sinCleanup() - - // Allocate output. - outElems := batch * seqLen * headDim - outBytes := outElems * f32Size - e.setDevice() - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return nil, fmt.Errorf("GPUFusedRoPE alloc: %w", err) - } - - if err := e.kernels.FusedRoPEF32(inPtr, cosPtr, sinPtr, devOut, batch, seqLen, headDim, halfRotary, cosStride, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return nil, err - } - - return makeGPUResult[T](e, shape, devOut, outElems) -} - -// GPUFusedSwiGLU computes SwiGLU(w1, w3) = w1 * sigmoid(w1) * w3 in a single GPU kernel. -// This replaces Concat + Split + sigmoid + Mul + Mul (5 operations, ~4 D2D memcpy per layer) with 1 kernel. -func (e *GPUEngine[T]) GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - w1Shape := w1.Shape() - w3Shape := w3.Shape() - if len(w1Shape) == 0 || len(w3Shape) == 0 { - return nil, fmt.Errorf("GPUFusedSwiGLU: empty shape") - } - - // Validate shapes match. 
- n1 := 1 - for _, d := range w1Shape { - n1 *= d - } - n3 := 1 - for _, d := range w3Shape { - n3 *= d - } - if n1 != n3 { - return nil, fmt.Errorf("GPUFusedSwiGLU: w1 (%d elems) and w3 (%d elems) size mismatch", n1, n3) - } - - w1Ptr, w1Cleanup, err := getDevicePtr(e, w1) - if err != nil { - return nil, fmt.Errorf("GPUFusedSwiGLU w1: %w", err) - } - defer w1Cleanup() - - w3Ptr, w3Cleanup, err := getDevicePtr(e, w3) - if err != nil { - return nil, fmt.Errorf("GPUFusedSwiGLU w3: %w", err) - } - defer w3Cleanup() - - outBytes := n1 * f32Size - e.setDevice() - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return nil, fmt.Errorf("GPUFusedSwiGLU alloc: %w", err) - } - - if err := e.kernels.FusedSwiGLUF32(w1Ptr, w3Ptr, devOut, n1, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return nil, err - } - - return makeGPUResult[T](e, w1Shape, devOut, n1) -} - -// GPUScaledSoftmax computes softmax(input * scale) in a single GPU kernel launch. -// This replaces MulScalar + Softmax (2 kernel launches) with 1, saving 26 launches -// per token for 26 transformer layers. 
-func (e *GPUEngine[T]) GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error) { - if !isFloat32[T]() { - return nil, fmt.Errorf("GPUScaledSoftmax: only float32 supported") - } - - e.setDevice() - - if input == nil { - return nil, fmt.Errorf("GPUScaledSoftmax: input tensor must not be nil") - } - - shape := input.Shape() - rank := len(shape) - - if rank == 0 { - return nil, fmt.Errorf("GPUScaledSoftmax: scalar tensors not supported") - } - - if axis < 0 { - axis = rank + axis - } - - if axis < 0 || axis >= rank { - return nil, fmt.Errorf("GPUScaledSoftmax: axis %d out of bounds for %d dimensions", axis, rank) - } - - n := input.GetStorage().Len() - - inner := 1 - for i := axis + 1; i < rank; i++ { - inner *= shape[i] - } - - outer := 1 - for i := 0; i < axis; i++ { - outer *= shape[i] - } - - axisSize := shape[axis] - - // FP16 paths — skip entirely for F32 compute. - if e.dtype != DTypeF32 { - // Native FP16 path: input already has Float16Storage on GPU — no conversion needed. - if fs, ok := any(input.GetStorage()).(*tensor.Float16Storage); ok { - return fp16ScaledSoftmaxNative(e, fs, input.Shape(), scale, outer, inner, axisSize) - } - // FP16 path: convert to FP16, run FP16 scaled softmax, convert back. - return fp16ScaledSoftmax(e, input, scale, outer, inner, axisSize) - } - - devIn, cleanupIn, err := getDevicePtr(e, input) - if err != nil { - return nil, fmt.Errorf("GPUScaledSoftmax input: %w", err) - } - defer cleanupIn() - - byteSize := n * f32Size - devOut, err := e.pool.Alloc(e.deviceID, byteSize) - if err != nil { - return nil, fmt.Errorf("GPUScaledSoftmax alloc: %w", err) - } - - if err := e.kernels.ScaledSoftmaxF32(devIn, devOut, outer, inner, axisSize, scale, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, byteSize) - return nil, err - } - - return makeGPUResult[T](e, shape, devOut, n) -} - -// GPUFusedSoftmaxVMul computes softmax(scores * scale) @ V in a single GPU -// kernel launch. 
Decode-optimized (seqQ=1): avoids materializing the attention -// weights tensor, saving one kernel launch and the associated memory traffic. -// scores: [BH, 1, seqKV], V: [BH, seqKV, D]. Returns output: [BH, 1, D]. -func (e *GPUEngine[T]) GPUFusedSoftmaxVMul(scores, V *tensor.TensorNumeric[T], scale float32) (*tensor.TensorNumeric[T], error) { - if !isFloat32[T]() { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul: only float32 supported") - } - - if scores == nil || V == nil { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul: input tensors must not be nil") - } - - e.setDevice() - - sShape := scores.Shape() - vShape := V.Shape() - - // scores must be [BH, 1, seqKV] or [BH, seqKV] - var BH, seqKV int - switch len(sShape) { - case 3: - if sShape[1] != 1 { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul: scores seqQ must be 1 for decode, got %d", sShape[1]) - } - BH, seqKV = sShape[0], sShape[2] - case 2: - BH, seqKV = sShape[0], sShape[1] - default: - return nil, fmt.Errorf("GPUFusedSoftmaxVMul: scores must be 2D or 3D, got %dD", len(sShape)) - } - - // V must be [BH, seqKV, D] - if len(vShape) != 3 || vShape[0] != BH || vShape[1] != seqKV { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul: V shape mismatch: want [%d, %d, D], got %v", BH, seqKV, vShape) - } - D := vShape[2] - - scoresPtr, scoresCleanup, err := getDevicePtr(e, scores) - if err != nil { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul scores: %w", err) - } - defer scoresCleanup() - - vPtr, vCleanup, err := getDevicePtr(e, V) - if err != nil { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul V: %w", err) - } - defer vCleanup() - - outElems := BH * D - outBytes := outElems * f32Size - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return nil, fmt.Errorf("GPUFusedSoftmaxVMul alloc: %w", err) - } - - if err := e.kernels.FusedSoftmaxVMulF32(scoresPtr, vPtr, devOut, scale, BH, seqKV, D, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return nil, err - } - - outShape := 
[]int{BH, 1, D} - return makeGPUResult[T](e, outShape, devOut, outElems) -} - -// GPUFusedAddRMSNorm computes sum = input + residual and -// normed = rmsnorm(sum, weight, eps) in a single GPU kernel launch. -// Both inputs are read-only; outputs go to separate buffers. -// This replaces Add + RMSNorm (2 kernel launches) with 1. -func (e *GPUEngine[T]) GPUFusedAddRMSNorm( - input, residual *tensor.TensorNumeric[T], - weight *tensor.TensorNumeric[T], - eps float32, -) (normed *tensor.TensorNumeric[T], residualOut *tensor.TensorNumeric[T], scales *tensor.TensorNumeric[T], err error) { - // FP16 paths — skip entirely for F32 compute. - if e.dtype != DTypeF32 { - // Native FP16 path: input and residual already have Float16Storage — no conversion needed. - inFS, inOK := any(input.GetStorage()).(*tensor.Float16Storage) - resFS, resOK := any(residual.GetStorage()).(*tensor.Float16Storage) - if inOK && resOK { - return fp16FusedAddRMSNormNative(e, inFS, resFS, input, weight, eps) - } - // FP16 path: decompose into F32 Add + FP16 RMSNorm. - return fp16FusedAddRMSNorm(e, input, residual, weight, eps) - } - - inShape := input.Shape() - if len(inShape) < 2 { - return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm: input must be at least 2D, got %v", inShape) - } - D := inShape[len(inShape)-1] - rows := 1 - for i := 0; i < len(inShape)-1; i++ { - rows *= inShape[i] - } - - inPtr, inCleanup, err := getDevicePtr(e, input) - if err != nil { - return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm input: %w", err) - } - defer inCleanup() - - // Residual is updated in-place. We need a mutable device pointer. 
- resPtr, resCleanup, err := getDevicePtr(e, residual) - if err != nil { - return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm residual: %w", err) - } - defer resCleanup() - - wPtr, wCleanup, err := getDevicePtr(e, weight) - if err != nil { - return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm weight: %w", err) - } - defer wCleanup() - - outBytes := rows * D * f32Size - e.setDevice() - devNormed, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm alloc normed: %w", err) - } - - devSum, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - e.pool.Free(e.deviceID, devNormed, outBytes) - return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm alloc sum: %w", err) - } - - if err := e.kernels.FusedAddRMSNormF32(inPtr, resPtr, wPtr, devNormed, devSum, eps, rows, D, e.stream); err != nil { - e.pool.Free(e.deviceID, devNormed, outBytes) - e.pool.Free(e.deviceID, devSum, outBytes) - return nil, nil, nil, err - } - - normed, err = makeGPUResult[T](e, inShape, devNormed, rows*D) - if err != nil { - e.pool.Free(e.deviceID, devSum, outBytes) - return nil, nil, nil, err - } - - residualOut, err = makeGPUResult[T](e, inShape, devSum, rows*D) - if err != nil { - return nil, nil, nil, err - } - - return normed, residualOut, nil, nil -} - -// GPUFusedNormAdd computes output = rmsnorm(input, weight, eps) + residual -// in a single GPU kernel launch. Replaces separate RMSNorm + Add (2 launches → 1). 
-func (e *GPUEngine[T]) GPUFusedNormAdd(input, weight, residual *tensor.TensorNumeric[T], eps float32) (*tensor.TensorNumeric[T], error) { - inShape := input.Shape() - if len(inShape) < 2 { - return nil, fmt.Errorf("GPUFusedNormAdd: input must be at least 2D, got %v", inShape) - } - D := inShape[len(inShape)-1] - rows := 1 - for i := 0; i < len(inShape)-1; i++ { - rows *= inShape[i] - } - - inPtr, inCleanup, err := getDevicePtr(e, input) - if err != nil { - return nil, fmt.Errorf("GPUFusedNormAdd input: %w", err) - } - defer inCleanup() - - wPtr, wCleanup, err := getDevicePtr(e, weight) - if err != nil { - return nil, fmt.Errorf("GPUFusedNormAdd weight: %w", err) - } - defer wCleanup() - - resPtr, resCleanup, err := getDevicePtr(e, residual) - if err != nil { - return nil, fmt.Errorf("GPUFusedNormAdd residual: %w", err) - } - defer resCleanup() - - outElems := rows * D - outBytes := outElems * f32Size - e.setDevice() - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return nil, fmt.Errorf("GPUFusedNormAdd alloc: %w", err) - } - - if err := e.kernels.FusedNormAddF32(inPtr, wPtr, resPtr, devOut, eps, rows, D, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return nil, err - } - - return makeGPUResult[T](e, inShape, devOut, outElems) -} - -// GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K heads -// in a single GPU kernel launch. This replaces 4 kernel launches per GQA layer. -// input: [totalHeads, headDim], weightQ/weightK: [headDim], -// cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim]. 
-func (e *GPUEngine[T]) GPUFusedQKNormRoPE( - input *tensor.TensorNumeric[T], - weightQ, weightK *tensor.TensorNumeric[T], - cosAngles, sinAngles *tensor.TensorNumeric[T], - eps float32, - totalHeads, headDim, numQHeads, halfRotary int, -) (*tensor.TensorNumeric[T], error) { - inPtr, inCleanup, err := getDevicePtr(e, input) - if err != nil { - return nil, fmt.Errorf("GPUFusedQKNormRoPE input: %w", err) - } - defer inCleanup() - - wqPtr, wqCleanup, err := getDevicePtr(e, weightQ) - if err != nil { - return nil, fmt.Errorf("GPUFusedQKNormRoPE weightQ: %w", err) - } - defer wqCleanup() - - wkPtr, wkCleanup, err := getDevicePtr(e, weightK) - if err != nil { - return nil, fmt.Errorf("GPUFusedQKNormRoPE weightK: %w", err) - } - defer wkCleanup() - - cosPtr, cosCleanup, err := getDevicePtr(e, cosAngles) - if err != nil { - return nil, fmt.Errorf("GPUFusedQKNormRoPE cos: %w", err) - } - defer cosCleanup() - - sinPtr, sinCleanup, err := getDevicePtr(e, sinAngles) - if err != nil { - return nil, fmt.Errorf("GPUFusedQKNormRoPE sin: %w", err) - } - defer sinCleanup() - - outElems := totalHeads * headDim - outBytes := outElems * f32Size - e.setDevice() - devOut, err := e.pool.Alloc(e.deviceID, outBytes) - if err != nil { - return nil, fmt.Errorf("GPUFusedQKNormRoPE alloc: %w", err) - } - - if err := e.kernels.FusedQKNormRoPEF32(inPtr, wqPtr, wkPtr, cosPtr, sinPtr, devOut, eps, totalHeads, headDim, numQHeads, halfRotary, e.stream); err != nil { - e.pool.Free(e.deviceID, devOut, outBytes) - return nil, err - } - - return makeGPUResult[T](e, []int{totalHeads, headDim}, devOut, outElems) -} - // Sync synchronizes the GPU stream, blocking until all enqueued operations complete. // Use for benchmarking or when explicit synchronization is needed. func (e *GPUEngine[T]) Sync() error { @@ -3505,17 +2241,5 @@ func (e *GPUEngine[T]) Sync() error { return nil } -// CosineSimilarity computes pairwise cosine similarity between rows of two 2D tensors. 
-// a has shape [M, D], b has shape [N, D]. Result has shape [M, N]. -// Currently delegates to CPUEngine; a dedicated GPU kernel will be added later. -func (e *GPUEngine[T]) CosineSimilarity(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.cpu.CosineSimilarity(ctx, a, b, dst...) -} - -// HadamardTransform delegates to the CPU engine. -func (e *GPUEngine[T]) HadamardTransform(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { - return e.cpu.HadamardTransform(ctx, a, dst...) -} - // Static type assertion: GPUEngine satisfies Engine. var _ Engine[float32] = (*GPUEngine[float32])(nil) diff --git a/compute/gpu_engine_elementwise.go b/compute/gpu_engine_elementwise.go new file mode 100644 index 0000000..998a83e --- /dev/null +++ b/compute/gpu_engine_elementwise.go @@ -0,0 +1,400 @@ +package compute + +// gpu_engine_elementwise.go contains all element-wise operations, scalar +// operations, fused element-wise kernels (RoPE, SwiGLU, RMSNorm), and +// pairwise operations (CosineSimilarity, HadamardTransform) for GPUEngine. + +import ( + "context" + "fmt" + + "github.com/zerfoo/ztensor/tensor" +) + +// UnaryOp applies an arbitrary unary function element-wise (CPU fallback). +func (e *GPUEngine[T]) UnaryOp(ctx context.Context, a *tensor.TensorNumeric[T], op func(T) T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.cpu.UnaryOp(ctx, a, op, dst...) +} + +// Add performs element-wise addition. +func (e *GPUEngine[T]) Add(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuAdd(ctx, a, b, dst...) +} + +// Sub performs element-wise subtraction. 
+func (e *GPUEngine[T]) Sub(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuSub(ctx, a, b, dst...) +} + +// Mul performs element-wise multiplication. +func (e *GPUEngine[T]) Mul(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuMul(ctx, a, b, dst...) +} + +// Div performs element-wise division. +func (e *GPUEngine[T]) Div(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuDiv(ctx, a, b, dst...) +} + +// Exp computes the element-wise exponential. +func (e *GPUEngine[T]) Exp(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuExp(ctx, a, dst...) +} + +// Log computes the element-wise natural logarithm. +func (e *GPUEngine[T]) Log(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuLog(ctx, a, dst...) +} + +// Sin computes the element-wise sine. +func (e *GPUEngine[T]) Sin(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuSin(ctx, a, dst...) +} + +// Cos computes the element-wise cosine. +func (e *GPUEngine[T]) Cos(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuCos(ctx, a, dst...) +} + +// Tanh computes the element-wise hyperbolic tangent. +func (e *GPUEngine[T]) Tanh(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuTanh(ctx, a, dst...) +} + +// TanhPrime computes the element-wise gradient of tanh. 
+func (e *GPUEngine[T]) TanhPrime(ctx context.Context, a, upstream *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuTanhPrime(ctx, a, upstream, dst...) +} + +// Pow raises each element to the given power. +func (e *GPUEngine[T]) Pow(ctx context.Context, base, exponent *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuPow(ctx, base, exponent, dst...) +} + +// MulScalar multiplies each element by a scalar. +func (e *GPUEngine[T]) MulScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuMulScalar(ctx, a, scalar, dst...) +} + +// DivScalar divides each element by a scalar. +func (e *GPUEngine[T]) DivScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuDivScalar(ctx, a, scalar, dst...) +} + +// AddScalar adds a scalar to each element. +func (e *GPUEngine[T]) AddScalar(ctx context.Context, a *tensor.TensorNumeric[T], scalar T, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuAddScalar(ctx, a, scalar, dst...) +} + +// Sqrt computes the element-wise square root. +func (e *GPUEngine[T]) Sqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuSqrt(ctx, a, dst...) +} + +// Rsqrt computes the element-wise reciprocal square root. +func (e *GPUEngine[T]) Rsqrt(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuRsqrt(ctx, a, dst...) +} + +// CosineSimilarity computes pairwise cosine similarity between rows of two 2D tensors. +// a has shape [M, D], b has shape [N, D]. Result has shape [M, N]. +// Currently delegates to CPUEngine; a dedicated GPU kernel will be added later. 
+func (e *GPUEngine[T]) CosineSimilarity(ctx context.Context, a, b *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.cpu.CosineSimilarity(ctx, a, b, dst...) +} + +// HadamardTransform delegates to the CPU engine. +func (e *GPUEngine[T]) HadamardTransform(ctx context.Context, a *tensor.TensorNumeric[T], dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.cpu.HadamardTransform(ctx, a, dst...) +} + +// GPUFusedRoPE applies rotary position embeddings in a single GPU kernel launch. +// This replaces Split + 4 Mul + Sub + Add + Concat (8 operations, ~10 D2D memcpy) with 1 kernel. +func (e *GPUEngine[T]) GPUFusedRoPE(input, cosAngles, sinAngles *tensor.TensorNumeric[T], rotaryDim int) (*tensor.TensorNumeric[T], error) { + shape := input.Shape() + if len(shape) != 3 { + return nil, fmt.Errorf("GPUFusedRoPE: expected 3D input [batch, seq, dim], got %dD", len(shape)) + } + + batch := shape[0] + seqLen := shape[1] + headDim := shape[2] + halfRotary := rotaryDim / 2 + + cosShape := cosAngles.Shape() + if len(cosShape) != 2 || cosShape[0] < seqLen || cosShape[1] < halfRotary { + return nil, fmt.Errorf("GPUFusedRoPE: cos shape %v incompatible with seq_len=%d half_rotary=%d", cosShape, seqLen, halfRotary) + } + cosStride := cosShape[1] + + // Get device pointers for input, cos, sin. + inPtr, inCleanup, err := getDevicePtr(e, input) + if err != nil { + return nil, fmt.Errorf("GPUFusedRoPE input: %w", err) + } + defer inCleanup() + + cosPtr, cosCleanup, err := getDevicePtr(e, cosAngles) + if err != nil { + return nil, fmt.Errorf("GPUFusedRoPE cos: %w", err) + } + defer cosCleanup() + + sinPtr, sinCleanup, err := getDevicePtr(e, sinAngles) + if err != nil { + return nil, fmt.Errorf("GPUFusedRoPE sin: %w", err) + } + defer sinCleanup() + + // Allocate output. 
+ outElems := batch * seqLen * headDim + outBytes := outElems * f32Size + e.setDevice() + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return nil, fmt.Errorf("GPUFusedRoPE alloc: %w", err) + } + + if err := e.kernels.FusedRoPEF32(inPtr, cosPtr, sinPtr, devOut, batch, seqLen, headDim, halfRotary, cosStride, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return nil, err + } + + return makeGPUResult[T](e, shape, devOut, outElems) +} + +// GPUFusedSwiGLU computes SwiGLU(w1, w3) = w1 * sigmoid(w1) * w3 in a single GPU kernel. +// This replaces Concat + Split + sigmoid + Mul + Mul (5 operations, ~4 D2D memcpy per layer) with 1 kernel. +func (e *GPUEngine[T]) GPUFusedSwiGLU(w1, w3 *tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + w1Shape := w1.Shape() + w3Shape := w3.Shape() + if len(w1Shape) == 0 || len(w3Shape) == 0 { + return nil, fmt.Errorf("GPUFusedSwiGLU: empty shape") + } + + // Validate shapes match. + n1 := 1 + for _, d := range w1Shape { + n1 *= d + } + n3 := 1 + for _, d := range w3Shape { + n3 *= d + } + if n1 != n3 { + return nil, fmt.Errorf("GPUFusedSwiGLU: w1 (%d elems) and w3 (%d elems) size mismatch", n1, n3) + } + + w1Ptr, w1Cleanup, err := getDevicePtr(e, w1) + if err != nil { + return nil, fmt.Errorf("GPUFusedSwiGLU w1: %w", err) + } + defer w1Cleanup() + + w3Ptr, w3Cleanup, err := getDevicePtr(e, w3) + if err != nil { + return nil, fmt.Errorf("GPUFusedSwiGLU w3: %w", err) + } + defer w3Cleanup() + + outBytes := n1 * f32Size + e.setDevice() + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return nil, fmt.Errorf("GPUFusedSwiGLU alloc: %w", err) + } + + if err := e.kernels.FusedSwiGLUF32(w1Ptr, w3Ptr, devOut, n1, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return nil, err + } + + return makeGPUResult[T](e, w1Shape, devOut, n1) +} + +// GPUFusedAddRMSNorm computes sum = input + residual and +// normed = rmsnorm(sum, weight, eps) in a single GPU 
kernel launch.
+// Both inputs are read-only; outputs go to separate buffers.
+// This replaces Add + RMSNorm (2 kernel launches) with 1.
+func (e *GPUEngine[T]) GPUFusedAddRMSNorm(
+ input, residual *tensor.TensorNumeric[T],
+ weight *tensor.TensorNumeric[T],
+ eps float32,
+) (normed *tensor.TensorNumeric[T], residualOut *tensor.TensorNumeric[T], scales *tensor.TensorNumeric[T], err error) {
+ // FP16 paths — skip entirely for F32 compute.
+ if e.dtype != DTypeF32 {
+ // Native FP16 path: input and residual already have Float16Storage — no conversion needed.
+ inFS, inOK := any(input.GetStorage()).(*tensor.Float16Storage)
+ resFS, resOK := any(residual.GetStorage()).(*tensor.Float16Storage)
+ if inOK && resOK {
+ return fp16FusedAddRMSNormNative(e, inFS, resFS, input, weight, eps)
+ }
+ // FP16 path: decompose into F32 Add + FP16 RMSNorm.
+ return fp16FusedAddRMSNorm(e, input, residual, weight, eps)
+ }
+
+ inShape := input.Shape()
+ if len(inShape) < 2 {
+ return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm: input must be at least 2D, got %v", inShape)
+ }
+ D := inShape[len(inShape)-1]
+ rows := 1
+ for i := 0; i < len(inShape)-1; i++ {
+ rows *= inShape[i]
+ }
+
+ inPtr, inCleanup, err := getDevicePtr(e, input)
+ if err != nil {
+ return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm input: %w", err)
+ }
+ defer inCleanup()
+
+ // Residual is read-only; the input+residual sum is written to the separate devSum buffer below.
+ resPtr, resCleanup, err := getDevicePtr(e, residual) + if err != nil { + return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm residual: %w", err) + } + defer resCleanup() + + wPtr, wCleanup, err := getDevicePtr(e, weight) + if err != nil { + return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm weight: %w", err) + } + defer wCleanup() + + outBytes := rows * D * f32Size + e.setDevice() + devNormed, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm alloc normed: %w", err) + } + + devSum, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + e.pool.Free(e.deviceID, devNormed, outBytes) + return nil, nil, nil, fmt.Errorf("GPUFusedAddRMSNorm alloc sum: %w", err) + } + + if err := e.kernels.FusedAddRMSNormF32(inPtr, resPtr, wPtr, devNormed, devSum, eps, rows, D, e.stream); err != nil { + e.pool.Free(e.deviceID, devNormed, outBytes) + e.pool.Free(e.deviceID, devSum, outBytes) + return nil, nil, nil, err + } + + normed, err = makeGPUResult[T](e, inShape, devNormed, rows*D) + if err != nil { + e.pool.Free(e.deviceID, devSum, outBytes) + return nil, nil, nil, err + } + + residualOut, err = makeGPUResult[T](e, inShape, devSum, rows*D) + if err != nil { + return nil, nil, nil, err + } + + return normed, residualOut, nil, nil +} + +// GPUFusedNormAdd computes output = rmsnorm(input, weight, eps) + residual +// in a single GPU kernel launch. Replaces separate RMSNorm + Add (2 launches → 1). 
+func (e *GPUEngine[T]) GPUFusedNormAdd(input, weight, residual *tensor.TensorNumeric[T], eps float32) (*tensor.TensorNumeric[T], error) { + inShape := input.Shape() + if len(inShape) < 2 { + return nil, fmt.Errorf("GPUFusedNormAdd: input must be at least 2D, got %v", inShape) + } + D := inShape[len(inShape)-1] + rows := 1 + for i := 0; i < len(inShape)-1; i++ { + rows *= inShape[i] + } + + inPtr, inCleanup, err := getDevicePtr(e, input) + if err != nil { + return nil, fmt.Errorf("GPUFusedNormAdd input: %w", err) + } + defer inCleanup() + + wPtr, wCleanup, err := getDevicePtr(e, weight) + if err != nil { + return nil, fmt.Errorf("GPUFusedNormAdd weight: %w", err) + } + defer wCleanup() + + resPtr, resCleanup, err := getDevicePtr(e, residual) + if err != nil { + return nil, fmt.Errorf("GPUFusedNormAdd residual: %w", err) + } + defer resCleanup() + + outElems := rows * D + outBytes := outElems * f32Size + e.setDevice() + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return nil, fmt.Errorf("GPUFusedNormAdd alloc: %w", err) + } + + if err := e.kernels.FusedNormAddF32(inPtr, wPtr, resPtr, devOut, eps, rows, D, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return nil, err + } + + return makeGPUResult[T](e, inShape, devOut, outElems) +} + +// GPUFusedQKNormRoPE applies per-head RMSNorm + RoPE to combined Q+K heads +// in a single GPU kernel launch. This replaces 4 kernel launches per GQA layer. +// input: [totalHeads, headDim], weightQ/weightK: [headDim], +// cosAngles/sinAngles: [halfRotary], output: [totalHeads, headDim]. 
+func (e *GPUEngine[T]) GPUFusedQKNormRoPE( + input *tensor.TensorNumeric[T], + weightQ, weightK *tensor.TensorNumeric[T], + cosAngles, sinAngles *tensor.TensorNumeric[T], + eps float32, + totalHeads, headDim, numQHeads, halfRotary int, +) (*tensor.TensorNumeric[T], error) { + inPtr, inCleanup, err := getDevicePtr(e, input) + if err != nil { + return nil, fmt.Errorf("GPUFusedQKNormRoPE input: %w", err) + } + defer inCleanup() + + wqPtr, wqCleanup, err := getDevicePtr(e, weightQ) + if err != nil { + return nil, fmt.Errorf("GPUFusedQKNormRoPE weightQ: %w", err) + } + defer wqCleanup() + + wkPtr, wkCleanup, err := getDevicePtr(e, weightK) + if err != nil { + return nil, fmt.Errorf("GPUFusedQKNormRoPE weightK: %w", err) + } + defer wkCleanup() + + cosPtr, cosCleanup, err := getDevicePtr(e, cosAngles) + if err != nil { + return nil, fmt.Errorf("GPUFusedQKNormRoPE cos: %w", err) + } + defer cosCleanup() + + sinPtr, sinCleanup, err := getDevicePtr(e, sinAngles) + if err != nil { + return nil, fmt.Errorf("GPUFusedQKNormRoPE sin: %w", err) + } + defer sinCleanup() + + outElems := totalHeads * headDim + outBytes := outElems * f32Size + e.setDevice() + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return nil, fmt.Errorf("GPUFusedQKNormRoPE alloc: %w", err) + } + + if err := e.kernels.FusedQKNormRoPEF32(inPtr, wqPtr, wkPtr, cosPtr, sinPtr, devOut, eps, totalHeads, headDim, numQHeads, halfRotary, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return nil, err + } + + return makeGPUResult[T](e, []int{totalHeads, headDim}, devOut, outElems) +} diff --git a/compute/gpu_engine_memory.go b/compute/gpu_engine_memory.go new file mode 100644 index 0000000..be6d118 --- /dev/null +++ b/compute/gpu_engine_memory.go @@ -0,0 +1,695 @@ +package compute + +// gpu_engine_memory.go contains data movement, layout transformation, +// and memory operations for GPUEngine: Copy, Zero, Gather, Scatter, +// Transpose, Split, Concat, Repeat, Reshape, and 
format conversion. + +import ( + "context" + "fmt" + "os" + "unsafe" + + "github.com/zerfoo/ztensor/internal/gpuapi" + "github.com/zerfoo/ztensor/tensor" +) + +// Transpose transposes a tensor along the given axes. +func (e *GPUEngine[T]) Transpose(ctx context.Context, a *tensor.TensorNumeric[T], axes []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if !isFloat32[T]() { + return e.cpu.Transpose(ctx, a, axes, dst...) + } + + // Only use GPU path for GPU-resident tensors (Phase 6 behavior). + // CPU-backed tensors fall back to CPU transpose to avoid unexpected + // H2D copies that may interfere with CUDA graph capture/replay. + _, isGPU := a.GetStorage().(*tensor.GPUStorage[T]) + isFP16 := false + if e.dtype != DTypeF32 { + _, isFP16 = any(a.GetStorage()).(*tensor.Float16Storage) + } + if !isGPU && !isFP16 { + return e.cpu.Transpose(ctx, a, axes, dst...) + } + + e.setDevice() + + shape := a.Shape() + rank := len(shape) + + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE: shape=%v rank=%d axes=%v storage=%T\n", shape, rank, axes, a.GetStorage()) + } + + // Default: reverse axes (same as CPU Transpose with nil axes). + if len(axes) == 0 { + axes = make([]int, rank) + for i := range rank { + axes[i] = rank - 1 - i + } + } + + if len(axes) != rank { + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE CPU FALLBACK: reason=axes_rank_mismatch shape=%v\n", shape) + } + return e.cpu.Transpose(ctx, a, axes, dst...) + } + + // GPU transpose kernel supports up to 4D; fall back to CPU for higher ranks. + if rank > 4 { + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE CPU FALLBACK: reason=rank_gt_4 shape=%v\n", shape) + } + return e.cpu.Transpose(ctx, a, axes, dst...) + } + + // Compute output shape. + outShape := make([]int, rank) + for i, ax := range axes { + outShape[i] = shape[ax] + } + + // Fast path: if the transpose only swaps unit-sized dimensions, it is + // equivalent to a reshape (no data movement). 
This is common during + // single-token generation where seqLen=1. Check by comparing the + // non-unit dimensions in input vs output order. + if isTransposeReshape(shape, outShape) { + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE: reshape fast path shape=%v outShape=%v storage=%T\n", shape, outShape, a.GetStorage()) + } + if e.dtype != DTypeF32 { + if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok { + storageT := any(fs).(tensor.Storage[T]) + t, tErr := tensor.NewWithStorage[T](outShape, storageT) + if tErr != nil { + return nil, tErr + } + return t, nil + } + } + gs := a.GetStorage().(*tensor.GPUStorage[T]) + viewGS := gs.View(gs.Len()) + t, tErr := tensor.NewWithStorage[T](outShape, viewGS) + if tErr != nil { + return nil, tErr + } + if len(dst) > 0 && dst[0] != nil { + dst[0].SetStorage(viewGS) + dst[0].SetShape(outShape) + return dst[0], nil + } + return t, nil + } + + // Compute total elements. + total := 1 + for _, d := range shape { + total *= d + } + + // Compute input strides. + inStrides := make([]int, rank) + stride := 1 + for i := rank - 1; i >= 0; i-- { + inStrides[i] = stride + stride *= shape[i] + } + + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE getDevicePtr: storage=%T\n", a.GetStorage()) + } + devIn, cleanupIn, err := getDevicePtr(e, a) + if err != nil { + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE CPU FALLBACK: reason=getDevicePtr_failed shape=%v\n", shape) + } + return e.cpu.Transpose(ctx, a, axes, dst...) + } + defer cleanupIn() + if debugGPU { + fmt.Fprintf(os.Stderr, "TRANSPOSE getDevicePtr OK: ptr=%p\n", devIn) + } + + byteSize := total * f32Size + devOut, err := e.pool.Alloc(e.deviceID, byteSize) + if err != nil { + return e.cpu.Transpose(ctx, a, axes, dst...) + } + + // Fast path: 2D transpose. 
+ if rank == 2 && axes[0] == 1 && axes[1] == 0 { + if debugGPU { + e.logger.Debug("TRANSPOSE: using 2D fast path", + "rows", fmt.Sprintf("%d", shape[0]), + "cols", fmt.Sprintf("%d", shape[1])) + } + if err := e.kernels.Transpose2D(devIn, devOut, shape[0], shape[1], e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, byteSize) + return nil, err + } + return makeGPUResult[T](e, outShape, devOut, total, dst...) + } + + // General N-D transpose via stride-based kernel. + // Precompute output strides on the host so the kernel avoids O(ndim^2) per thread. + if debugGPU { + e.logger.Debug("TRANSPOSE: using general N-D path", + "rank", fmt.Sprintf("%d", rank), + "axes", fmt.Sprintf("%v", axes)) + } + outStrides := make([]int, rank) + outStride := 1 + for i := rank - 1; i >= 0; i-- { + outStrides[i] = outStride + outStride *= outShape[i] + } + + inStrides32 := make([]int32, rank) + outStrides32 := make([]int32, rank) + perm32 := make([]int32, rank) + for i := range rank { + inStrides32[i] = int32(inStrides[i]) + outStrides32[i] = int32(outStrides[i]) + perm32[i] = int32(axes[i]) + } + + if err := e.kernels.TransposeND(devIn, devOut, inStrides32, outStrides32, perm32, rank, total, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, byteSize) + return nil, err + } + + return makeGPUResult[T](e, outShape, devOut, total, dst...) +} + +// Zero sets all elements to zero. +func (e *GPUEngine[T]) Zero(ctx context.Context, a *tensor.TensorNumeric[T]) error { + // GPU path: use cudaMemsetAsync on the engine's stream. + if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok { + return e.runtime.MemsetAsync(gs.Ptr(), 0, gs.ByteSize(), e.stream) + } + // CPU fallback for non-GPU tensors. + return e.cpu.Zero(ctx, a) +} + +// Zeros fills the tensor with zeros. +func (e *GPUEngine[T]) Zeros(ctx context.Context, a *tensor.TensorNumeric[T], shape []int) error { + return e.cpu.Zeros(ctx, a, shape) +} + +// Copy copies data from source to destination tensor. 
+func (e *GPUEngine[T]) Copy(ctx context.Context, dst, src *tensor.TensorNumeric[T]) error { + dstGS, dstIsGPU := dst.GetStorage().(*tensor.GPUStorage[T]) + srcGS, srcIsGPU := src.GetStorage().(*tensor.GPUStorage[T]) + if dstIsGPU && srcIsGPU { + // D2D copy on engine stream. + return e.runtime.MemcpyAsync(dstGS.Ptr(), srcGS.Ptr(), dstGS.ByteSize(), gpuapi.MemcpyDeviceToDevice, e.stream) + } + // Fall back to CPU for mixed or CPU-only tensors. + return e.cpu.Copy(ctx, dst, src) +} + +// Gather performs an embedding-style gather. +func (e *GPUEngine[T]) Gather(ctx context.Context, params *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], output *tensor.TensorNumeric[T]) error { + if !isFloat32[T]() { + return e.cpu.Gather(ctx, params, indices, output) + } + + // Q8 GPU gather: dequantize only the requested rows on GPU. + if qs, ok := any(params.GetStorage()).(*tensor.Q8Storage); ok { + if ptr, _, _ := qs.GPUPtr(); ptr != nil { + return e.gatherQ8(params, indices, output, qs, ptr) + } + } + + // Check whether params are GPU-resident (F32 or FP16 storage). + _, isGPU := params.GetStorage().(*tensor.GPUStorage[T]) + var fp16Stor *tensor.Float16Storage + isFP16 := false + if e.dtype != DTypeF32 { + fp16Stor, isFP16 = any(params.GetStorage()).(*tensor.Float16Storage) + } + if !isGPU && !isFP16 { + return e.cpu.Gather(ctx, params, indices, output) + } + + e.setDevice() + + pShape := params.Shape() + if len(pShape) != 2 { + return e.cpu.Gather(ctx, params, indices, output) + } + V := pShape[0] + D := pShape[1] + + // Flatten indices to get N. + idxData := indices.Data() + N := len(idxData) + if N == 0 { + return nil + } + + // Get device pointer for params. For Float16Storage, convert FP16->F32 + // into a temporary buffer so the F32 Gather kernel can operate on it. 
+ var devParams unsafe.Pointer + var cleanupParams func() + if isFP16 { + fp16Ptr, _, _ := fp16Stor.GPUPtr() + if fp16Ptr == nil { + return e.cpu.Gather(ctx, params, indices, output) + } + nElems := V * D + f32Bytes := nElems * f32Size + f32Ptr, err := e.pool.Alloc(e.deviceID, f32Bytes) + if err != nil { + return e.cpu.Gather(ctx, params, indices, output) + } + if err := e.kernels.FP16ToF32(fp16Ptr, f32Ptr, nElems, e.stream); err != nil { + e.pool.Free(e.deviceID, f32Ptr, f32Bytes) + return e.cpu.Gather(ctx, params, indices, output) + } + devParams = f32Ptr + cleanupParams = func() { e.pool.Free(e.deviceID, f32Ptr, f32Bytes) } + } else { + var err error + devParams, cleanupParams, err = getDevicePtr(e, params) + if err != nil { + return e.cpu.Gather(ctx, params, indices, output) + } + } + defer cleanupParams() + + // Upload indices to GPU as int64 (Go int on 64-bit platforms). + // The gather kernel accepts int64 indices directly, avoiding the + // CPU-side int64→int32 conversion that would trigger a D2H copy + // for GPU-resident indices and block CUDA graph capture. + intSize := int(unsafe.Sizeof(int(0))) + idxByteSize := N * intSize + devIdx, err := e.pool.Alloc(e.deviceID, idxByteSize) + if err != nil { + return e.cpu.Gather(ctx, params, indices, output) + } + defer e.pool.Free(e.deviceID, devIdx, idxByteSize) + + if err := e.runtime.Memcpy(devIdx, unsafe.Pointer(&idxData[0]), idxByteSize, gpuapi.MemcpyHostToDevice); err != nil { + return e.cpu.Gather(ctx, params, indices, output) + } + + // Allocate output on GPU. + outByteSize := N * D * f32Size + devOut, err := e.pool.Alloc(e.deviceID, outByteSize) + if err != nil { + return e.cpu.Gather(ctx, params, indices, output) + } + + if err := e.kernels.Gather(devParams, devIdx, devOut, N, D, V, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outByteSize) + return fmt.Errorf("GPU Gather: %w", err) + } + + // When dtype is FP16, convert the F32 gather output to FP16 on GPU. 
+ // This is the single F32->FP16 conversion point for the entire forward pass; + // all downstream ops receive Float16Storage and operate in FP16 natively. + if e.dtype == DTypeFP16 { + outElems := N * D + fp16Bytes := outElems * fp16Size + fp16Ptr, err := e.pool.Alloc(e.deviceID, fp16Bytes) + if err != nil { + e.pool.Free(e.deviceID, devOut, outByteSize) + return fmt.Errorf("Gather FP16 alloc: %w", err) + } + if err := e.kernels.F32ToFP16(devOut, fp16Ptr, outElems, e.stream); err != nil { + e.pool.Free(e.deviceID, fp16Ptr, fp16Bytes) + e.pool.Free(e.deviceID, devOut, outByteSize) + return fmt.Errorf("Gather F32->FP16: %w", err) + } + e.pool.Free(e.deviceID, devOut, outByteSize) + fs := any(tensor.NewFloat16StorageGPU(fp16Ptr, outElems, e.deviceID)).(tensor.Storage[T]) + output.SetStorage(fs) + return nil + } + + // Set output storage to GPU (pool-backed so Free returns to pool, not cudaFree). + gs, err := tensor.NewGPUStorageFromPool[T](devOut, N*D, e.pool, e.deviceID) + if err != nil { + e.pool.Free(e.deviceID, devOut, outByteSize) + return err + } + output.SetStorage(gs) + + return nil +} + +// gatherQ8 performs Q8_0 embedding gather on GPU using the Q8 gather kernel. +// Dequantizes only the requested rows, keeping the full Q8 table compressed. +func (e *GPUEngine[T]) gatherQ8( + params *tensor.TensorNumeric[T], + indices *tensor.TensorNumeric[int], + output *tensor.TensorNumeric[T], + qs *tensor.Q8Storage, + devQ8 unsafe.Pointer, +) error { + e.setDevice() + + pShape := params.Shape() + V := pShape[0] + D := pShape[1] + + idxData := indices.Data() + N := len(idxData) + if N == 0 { + return nil + } + + // Upload indices as int32 to GPU. 
+ idx32 := make([]int32, N) + for i, id := range idxData { + idx32[i] = int32(id) + } + idxBytes := N * 4 + devIdx, err := e.pool.Alloc(e.deviceID, idxBytes) + if err != nil { + return e.cpu.Gather(context.Background(), params, indices, output) + } + defer e.pool.Free(e.deviceID, devIdx, idxBytes) + + if err := e.runtime.Memcpy(devIdx, unsafe.Pointer(&idx32[0]), idxBytes, gpuapi.MemcpyHostToDevice); err != nil { + return e.cpu.Gather(context.Background(), params, indices, output) + } + + // Allocate output [N, D] on GPU. + outElems := N * D + outBytes := outElems * f32Size + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return e.cpu.Gather(context.Background(), params, indices, output) + } + + // Launch Q8 gather kernel. + if err := e.kernels.GatherQ8F32(devQ8, devIdx, devOut, N, D, V, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return e.cpu.Gather(context.Background(), params, indices, output) + } + + // Write result into output tensor as GPUStorage (pool-backed). + gs, err := tensor.NewGPUStorageFromPool[float32](devOut, outElems, e.pool, e.deviceID) + if err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return fmt.Errorf("gatherQ8: create GPU storage: %w", err) + } + output.SetStorage(any(gs).(tensor.Storage[T])) + return nil +} + +// ScatterAdd performs a row-wise scatter-add for embeddings. +func (e *GPUEngine[T]) ScatterAdd(ctx context.Context, dEmbeddingTable *tensor.TensorNumeric[T], indices *tensor.TensorNumeric[int], dOut *tensor.TensorNumeric[T]) error { + return e.cpu.ScatterAdd(ctx, dEmbeddingTable, indices, dOut) +} + +// RandomUniform fills the tensor with uniform random values. +func (e *GPUEngine[T]) RandomUniform(ctx context.Context, t *tensor.TensorNumeric[T], minVal, maxVal T) error { + return e.cpu.RandomUniform(ctx, t, minVal, maxVal) +} + +// Fill fills the tensor with a scalar value. 
+func (e *GPUEngine[T]) Fill(ctx context.Context, t *tensor.TensorNumeric[T], value T) error { + return e.gpuFill(ctx, t, value) +} + +// Split splits a tensor into multiple tensors along an axis. +func (e *GPUEngine[T]) Split(ctx context.Context, a *tensor.TensorNumeric[T], numSplits int, axis int) ([]*tensor.TensorNumeric[T], error) { + if !isFloat32[T]() { + return e.cpu.Split(ctx, a, numSplits, axis) + } + if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok { + return e.gpuSplit(gs.Ptr(), a.Shape(), numSplits, axis) + } + if e.dtype != DTypeF32 { + if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok { + ptr, _, _ := fs.GPUPtr() + if ptr != nil { + return e.gpuSplitFP16(ptr, a.Shape(), numSplits, axis) + } + } + } + return e.cpu.Split(ctx, a, numSplits, axis) +} + +// Concat concatenates tensors along an axis. +func (e *GPUEngine[T]) Concat(ctx context.Context, tensors []*tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if !isFloat32[T]() || len(tensors) == 0 { + return e.cpu.Concat(ctx, tensors, axis, dst...) + } + // Check all inputs are GPU-resident (GPUStorage or Float16Storage). + ptrs := make([]unsafe.Pointer, len(tensors)) + allFP16 := true + for i, t := range tensors { + if gs, ok := t.GetStorage().(*tensor.GPUStorage[T]); ok { + ptrs[i] = gs.Ptr() + allFP16 = false + } else if e.dtype != DTypeF32 { + if fs, ok := any(t.GetStorage()).(*tensor.Float16Storage); ok { + p, _, _ := fs.GPUPtr() + if p == nil { + return e.cpu.Concat(ctx, tensors, axis, dst...) + } + ptrs[i] = p + } else { + return e.cpu.Concat(ctx, tensors, axis, dst...) + } + } else { + return e.cpu.Concat(ctx, tensors, axis, dst...) + } + } + if allFP16 && e.dtype != DTypeF32 { + return e.gpuConcatFP16(ptrs, tensors, axis, dst...) + } + return e.gpuConcat(ptrs, tensors, axis, dst...) +} + +// Repeat repeats the tensor along an axis. 
+func (e *GPUEngine[T]) Repeat(ctx context.Context, a *tensor.TensorNumeric[T], axis int, repetitions int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + if !isFloat32[T]() || a == nil { + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + + shape := a.Shape() + if axis < 0 || axis >= len(shape) { + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + if repetitions <= 0 { + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + + // Get device pointer (handles GPUStorage[T] and Float16Storage). + isFP16 := false + if e.dtype != DTypeF32 { + _, isFP16 = any(a.GetStorage()).(*tensor.Float16Storage) + } + gs, isGPU := a.GetStorage().(*tensor.GPUStorage[T]) + if !isGPU && !isFP16 { + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + + e.setDevice() + + newShape := make([]int, len(shape)) + copy(newShape, shape) + newShape[axis] *= repetitions + + outElems := 1 + for _, d := range newShape { + outElems *= d + } + outBytes := outElems * f32Size + + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + + var devA unsafe.Pointer + var cleanupA func() + if isGPU { + devA = gs.Ptr() + cleanupA = func() {} + } else { + // Float16Storage: convert FP16→F32 for the F32 repeat kernel. + f32Engine, ok := any(e).(*GPUEngine[float32]) + if !ok { + e.pool.Free(e.deviceID, devOut, outBytes) + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + devA, cleanupA, err = getDevicePtr(f32Engine, any(a).(*tensor.TensorNumeric[float32])) + if err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + } + defer cleanupA() + + // Compute dimensions for the repeat kernel. 
+ outerSize := 1 + for i := 0; i < axis; i++ { + outerSize *= shape[i] + } + axisDim := shape[axis] + innerSize := 1 + for i := axis + 1; i < len(shape); i++ { + innerSize *= shape[i] + } + + if err := e.kernels.Repeat(devA, devOut, outerSize, axisDim, innerSize, repetitions, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + + // For FP16 inputs, convert the F32 output back to Float16Storage. + if isFP16 { + fp16Bytes := outElems * fp16Size + fp16Out, allocErr := e.pool.Alloc(e.deviceID, fp16Bytes) + if allocErr != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + if convErr := e.kernels.F32ToFP16(devOut, fp16Out, outElems, e.stream); convErr != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + e.pool.Free(e.deviceID, fp16Out, fp16Bytes) + return e.cpu.Repeat(ctx, a, axis, repetitions, dst...) + } + e.pool.Free(e.deviceID, devOut, outBytes) + fs := tensor.NewFloat16StorageGPU(fp16Out, outElems, e.deviceID) + storageT := any(fs).(tensor.Storage[T]) + return tensor.NewWithStorage[T](newShape, storageT) + } + + return makeGPUResult[T](e, newShape, devOut, outElems, dst...) +} + +// RepeatInterleave expands a 4D tensor from [B, numKV, S, D] to [B, numQ, S, D] +// by repeating each head along axis 1 (the head dimension) `reps` times. +// This is a fused kernel for GQA key/value head expansion, replacing the +// Reshape -> Repeat -> Reshape chain with a single kernel launch. +// axis must be 1 and the input must be 4D [B, numKV, S, D]. +func (e *GPUEngine[T]) RepeatInterleave(ctx context.Context, a *tensor.TensorNumeric[T], axis int, reps int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + shape := a.Shape() + if !isFloat32[T]() || a == nil || axis != 1 || len(shape) != 4 || reps <= 0 { + // Fall back to generic Repeat path for unsupported configurations. + return e.Repeat(ctx, a, axis, reps, dst...) 
+ } + + B, numKV, S, D := shape[0], shape[1], shape[2], shape[3] + numQ := numKV * reps + + gs, isGPU := a.GetStorage().(*tensor.GPUStorage[T]) + if !isGPU { + return e.Repeat(ctx, a, axis, reps, dst...) + } + + e.setDevice() + + outElems := B * numQ * S * D + outBytes := outElems * f32Size + + devOut, err := e.pool.Alloc(e.deviceID, outBytes) + if err != nil { + return e.Repeat(ctx, a, axis, reps, dst...) + } + + if err := e.kernels.RepeatInterleaveF32(gs.Ptr(), devOut, B, numKV, S, D, reps, e.stream); err != nil { + e.pool.Free(e.deviceID, devOut, outBytes) + return e.Repeat(ctx, a, axis, reps, dst...) + } + + outShape := []int{B, numQ, S, D} + return makeGPUResult[T](e, outShape, devOut, outElems, dst...) +} + +// OneHot creates a one-hot encoding. +func (e *GPUEngine[T]) OneHot(ctx context.Context, input *tensor.TensorNumeric[int], depth int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.cpu.OneHot(ctx, input, depth, dst...) +} + +// Reshape changes the shape without changing data. +func (e *GPUEngine[T]) Reshape(ctx context.Context, a *tensor.TensorNumeric[T], shape []int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + // Resolve -1 dimension and verify size. + currentSize := a.Size() + inferredShape := make([]int, len(shape)) + copy(inferredShape, shape) + inferIdx := -1 + knownSize := 1 + for i, d := range inferredShape { + if d == -1 { + inferIdx = i + } else { + knownSize *= d + } + } + if inferIdx >= 0 { + inferredShape[inferIdx] = currentSize / knownSize + } + newSize := 1 + for _, d := range inferredShape { + newSize *= d + } + + // Float16Storage: zero-copy reshape (same GPU pointer, new shape). + if e.dtype != DTypeF32 { + if fs, ok := any(a.GetStorage()).(*tensor.Float16Storage); ok && newSize == currentSize { + return tensor.NewWithStorage[T](inferredShape, any(fs).(tensor.Storage[T])) + } + } + + // GPUStorage[T]: zero-copy reshape. 
+ if gs, ok := a.GetStorage().(*tensor.GPUStorage[T]); ok && isFloat32[T]() && newSize == currentSize { + return tensor.NewWithStorage[T](inferredShape, gs.View(gs.Len())) + } + + return e.cpu.Reshape(ctx, a, shape, dst...) +} + +// ConvertFP16ToF32 converts a tensor with Float16Storage to a regular float32 +// GPU tensor using the FP16->F32 kernel. Returns the input unchanged if it +// does not have Float16Storage. +func (e *GPUEngine[T]) ConvertFP16ToF32(t *tensor.TensorNumeric[float32]) (*tensor.TensorNumeric[float32], error) { + fs, ok := any(t.GetStorage()).(*tensor.Float16Storage) + if !ok { + return t, nil + } + + fp16Ptr, _, _ := fs.GPUPtr() + if fp16Ptr == nil { + // CPU-side Float16Storage: decode via Slice (no GPU conversion possible). + data := fs.Slice() + out, err := tensor.New(t.Shape(), data) + if err != nil { + return nil, fmt.Errorf("ConvertFP16ToF32: create f32 tensor: %w", err) + } + return out, nil + } + + e.setDevice() + + nElems := fs.Len() + f32Bytes := nElems * f32Size + f32Ptr, err := e.pool.Alloc(e.deviceID, f32Bytes) + if err != nil { + return nil, fmt.Errorf("ConvertFP16ToF32: alloc: %w", err) + } + + if err := e.kernels.FP16ToF32(fp16Ptr, f32Ptr, nElems, e.stream); err != nil { + e.pool.Free(e.deviceID, f32Ptr, f32Bytes) + return nil, fmt.Errorf("ConvertFP16ToF32: kernel: %w", err) + } + + gsOut, err := tensor.NewGPUStorageFromPool[float32](f32Ptr, nElems, e.pool, e.deviceID) + if err != nil { + e.pool.Free(e.deviceID, f32Ptr, f32Bytes) + return nil, fmt.Errorf("ConvertFP16ToF32: gpu storage: %w", err) + } + out, err := tensor.NewWithStorage[float32](t.Shape(), gsOut) + if err != nil { + return nil, fmt.Errorf("ConvertFP16ToF32: wrap tensor: %w", err) + } + return out, nil +} diff --git a/compute/gpu_engine_reduction.go b/compute/gpu_engine_reduction.go new file mode 100644 index 0000000..78be741 --- /dev/null +++ b/compute/gpu_engine_reduction.go @@ -0,0 +1,221 @@ +package compute + +// gpu_engine_reduction.go contains reduction 
operations (Softmax, Sum, +// ReduceSum, ReduceMax, ReduceMean), argmax, and fused scaled-softmax +// kernels for GPUEngine. + +import ( + "context" + "fmt" + "unsafe" + + "github.com/zerfoo/ztensor/internal/gpuapi" + "github.com/zerfoo/ztensor/tensor" +) + +// Sum computes the sum of elements along an axis. +func (e *GPUEngine[T]) Sum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuSum(ctx, a, axis, keepDims, dst...) +} + +// Softmax applies the softmax function along an axis. +func (e *GPUEngine[T]) Softmax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuSoftmax(ctx, a, axis, dst...) +} + +// ReduceSum computes the sum of elements along an axis. +func (e *GPUEngine[T]) ReduceSum(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuReduceSum(ctx, a, axis, keepDims, dst...) +} + +// ReduceMax computes the maximum of elements along an axis (CPU fallback). +func (e *GPUEngine[T]) ReduceMax(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.cpu.ReduceMax(ctx, a, axis, keepDims, dst...) +} + +// ReduceMean computes the mean of elements along an axis. +func (e *GPUEngine[T]) ReduceMean(ctx context.Context, a *tensor.TensorNumeric[T], axis int, keepDims bool, dst ...*tensor.TensorNumeric[T]) (*tensor.TensorNumeric[T], error) { + return e.gpuReduceMean(ctx, a, axis, keepDims, dst...) +} + +// GPUArgmax finds the index of the maximum element in a GPU-resident float32 tensor. +// Returns the index as an int without copying the full tensor to the host. +// Only copies back a single int32 (4 bytes) instead of the entire tensor. 
+func (e *GPUEngine[T]) GPUArgmax(t *tensor.TensorNumeric[float32]) (int, error) { + gs, ok := t.GetStorage().(*tensor.GPUStorage[float32]) + if !ok { + return 0, fmt.Errorf("GPUArgmax: tensor not GPU-resident") + } + + e.setDevice() + + n := gs.Len() + devInput := gs.Ptr() + + // Allocate scratch: 2 * ceil(n/256) * 4 bytes (blockVals + blockIdxs). + numBlocks := (n + 255) / 256 + scratchSize := 2 * numBlocks * 4 + devScratch, err := e.pool.Alloc(e.deviceID, scratchSize) + if err != nil { + return 0, fmt.Errorf("GPUArgmax: scratch alloc: %w", err) + } + defer e.pool.Free(e.deviceID, devScratch, scratchSize) + + // Allocate device result (single int32). + devResult, err := e.pool.Alloc(e.deviceID, 4) + if err != nil { + return 0, fmt.Errorf("GPUArgmax: result alloc: %w", err) + } + defer e.pool.Free(e.deviceID, devResult, 4) + + if err := e.kernels.Argmax(devInput, devResult, devScratch, n, e.stream); err != nil { + return 0, fmt.Errorf("GPUArgmax: %w", err) + } + + // Copy single int32 result back to host. + var result int32 + if err := e.runtime.Memcpy(unsafe.Pointer(&result), devResult, 4, gpuapi.MemcpyDeviceToHost); err != nil { + return 0, fmt.Errorf("GPUArgmax: D2H copy: %w", err) + } + + return int(result), nil +} + +// GPUScaledSoftmax computes softmax(input * scale) in a single GPU kernel launch. +// This replaces MulScalar + Softmax (2 kernel launches) with 1, saving 26 launches +// per token for 26 transformer layers. 
func (e *GPUEngine[T]) GPUScaledSoftmax(input *tensor.TensorNumeric[T], scale float32, axis int) (*tensor.TensorNumeric[T], error) {
	if !isFloat32[T]() {
		return nil, fmt.Errorf("GPUScaledSoftmax: only float32 supported")
	}

	e.setDevice()

	if input == nil {
		return nil, fmt.Errorf("GPUScaledSoftmax: input tensor must not be nil")
	}

	shape := input.Shape()
	rank := len(shape)

	if rank == 0 {
		return nil, fmt.Errorf("GPUScaledSoftmax: scalar tensors not supported")
	}

	// Normalize a negative axis (Python-style) into [0, rank).
	if axis < 0 {
		axis = rank + axis
	}

	if axis < 0 || axis >= rank {
		return nil, fmt.Errorf("GPUScaledSoftmax: axis %d out of bounds for %d dimensions", axis, rank)
	}

	n := input.GetStorage().Len()

	// Collapse the shape into [outer, axisSize, inner] around the softmax axis;
	// the kernel iterates rows of length axisSize with stride inner.
	inner := 1
	for i := axis + 1; i < rank; i++ {
		inner *= shape[i]
	}

	outer := 1
	for i := 0; i < axis; i++ {
		outer *= shape[i]
	}

	axisSize := shape[axis]

	// FP16 paths — skip entirely for F32 compute.
	if e.dtype != DTypeF32 {
		// Native FP16 path: input already has Float16Storage on GPU — no conversion needed.
		if fs, ok := any(input.GetStorage()).(*tensor.Float16Storage); ok {
			return fp16ScaledSoftmaxNative(e, fs, input.Shape(), scale, outer, inner, axisSize)
		}
		// FP16 path: convert to FP16, run FP16 scaled softmax, convert back.
		return fp16ScaledSoftmax(e, input, scale, outer, inner, axisSize)
	}

	devIn, cleanupIn, err := getDevicePtr(e, input)
	if err != nil {
		return nil, fmt.Errorf("GPUScaledSoftmax input: %w", err)
	}
	defer cleanupIn()

	byteSize := n * f32Size
	devOut, err := e.pool.Alloc(e.deviceID, byteSize)
	if err != nil {
		return nil, fmt.Errorf("GPUScaledSoftmax alloc: %w", err)
	}

	if err := e.kernels.ScaledSoftmaxF32(devIn, devOut, outer, inner, axisSize, scale, e.stream); err != nil {
		e.pool.Free(e.deviceID, devOut, byteSize)
		return nil, err
	}

	return makeGPUResult[T](e, shape, devOut, n)
}

// GPUFusedSoftmaxVMul computes softmax(scores * scale) @ V in a single GPU
// kernel launch. Decode-optimized (seqQ=1): avoids materializing the attention
// weights tensor, saving one kernel launch and the associated memory traffic.
// scores: [BH, 1, seqKV], V: [BH, seqKV, D]. Returns output: [BH, 1, D].
func (e *GPUEngine[T]) GPUFusedSoftmaxVMul(scores, V *tensor.TensorNumeric[T], scale float32) (*tensor.TensorNumeric[T], error) {
	if !isFloat32[T]() {
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul: only float32 supported")
	}

	if scores == nil || V == nil {
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul: input tensors must not be nil")
	}

	e.setDevice()

	sShape := scores.Shape()
	vShape := V.Shape()

	// scores must be [BH, 1, seqKV] or [BH, seqKV]
	var BH, seqKV int
	switch len(sShape) {
	case 3:
		if sShape[1] != 1 {
			return nil, fmt.Errorf("GPUFusedSoftmaxVMul: scores seqQ must be 1 for decode, got %d", sShape[1])
		}
		BH, seqKV = sShape[0], sShape[2]
	case 2:
		BH, seqKV = sShape[0], sShape[1]
	default:
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul: scores must be 2D or 3D, got %dD", len(sShape))
	}

	// V must be [BH, seqKV, D]
	if len(vShape) != 3 || vShape[0] != BH || vShape[1] != seqKV {
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul: V shape mismatch: want [%d, %d, D], got %v", BH, seqKV, vShape)
	}
	D := vShape[2]

	scoresPtr, scoresCleanup, err := getDevicePtr(e, scores)
	if err != nil {
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul scores: %w", err)
	}
	defer scoresCleanup()

	vPtr, vCleanup, err := getDevicePtr(e, V)
	if err != nil {
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul V: %w", err)
	}
	defer vCleanup()

	outElems := BH * D
	outBytes := outElems * f32Size
	devOut, err := e.pool.Alloc(e.deviceID, outBytes)
	if err != nil {
		return nil, fmt.Errorf("GPUFusedSoftmaxVMul alloc: %w", err)
	}

	if err := e.kernels.FusedSoftmaxVMulF32(scoresPtr, vPtr, devOut, scale, BH, seqKV, D, e.stream); err != nil {
		e.pool.Free(e.deviceID, devOut, outBytes)
		return nil, err
	}

	outShape := []int{BH, 1, D}
	return makeGPUResult[T](e, outShape, devOut, outElems)
}