From ab3b340ae9bf36f5bdd35a2082b4dd122116490c Mon Sep 17 00:00:00 2001 From: ChrisRackauckas-Claude Date: Wed, 15 Apr 2026 16:29:40 -0400 Subject: [PATCH] Add promote_eltype for CuArray / ROCArray / MtlArray MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `ArrayInterface.promote_eltype` only had a method for plain `Array{T, N}` (with an explicit "no generic fallback is given" note in the docstring). Downstream packages that pass GPU array types through `promote_eltype` therefore hit a `MethodError` — for example, SciML/NonlinearSolve.jl#910 tripped this on `test/cuda_tests.jl:33 "GeneralizedFirstOrderAlgorithm"` when deriving a Dual-eltype wrapper-signature array type for `CuArray{Float32}`: MethodError: no method matching promote_eltype( ::Type{CuArray{Float32, 1, CUDA.DeviceMemory}}, ::Type{ForwardDiff.Dual{Tag{NonlinearSolveBase.NonlinearSolveTag, Float32}, Float32, 1}}) Adds the obvious eltype-swapping method in each GPU extension, preserving the non-eltype type parameters (`M` for `CuArray` memory kind, `B` for `ROCArray` buffer type, `S` for `MtlArray` storage mode): ArrayInterface.promote_eltype( ::Type{<:CuArray{T, N, M}}, ::Type{T2} ) where {T, N, M, T2} = CuArray{promote_type(T, T2), N, M} ArrayInterface.promote_eltype( ::Type{<:ROCArray{T, N, B}}, ::Type{T2} ) where {T, N, B, T2} = ROCArray{promote_type(T, T2), N, B} ArrayInterface.promote_eltype( ::Type{<:MtlArray{T, N, S}}, ::Type{T2} ) where {T, N, S, T2} = MtlArray{promote_type(T, T2), N, S} Bumps patch version 7.23.0 → 7.24.0 so downstream packages can compat-bound the new method. Co-Authored-By: Chris Rackauckas Co-Authored-By: Claude Opus 4.6 (1M context) --- Project.toml | 2 +- ext/ArrayInterfaceAMDGPUExt.jl | 6 ++++++ ext/ArrayInterfaceCUDAExt.jl | 6 ++++++ ext/ArrayInterfaceMetalExt.jl | 6 ++++++ 4 files changed, 19 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index c62c3c4c..2c4d81e5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ArrayInterface" uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -version = "7.23.0" +version = "7.24.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/ext/ArrayInterfaceAMDGPUExt.jl b/ext/ArrayInterfaceAMDGPUExt.jl index 47d3b08f..5d8e2f91 100644 --- a/ext/ArrayInterfaceAMDGPUExt.jl +++ b/ext/ArrayInterfaceAMDGPUExt.jl @@ -12,4 +12,10 @@ end ArrayInterface.device(::Type{<:AMDGPU.ROCArray}) = ArrayInterface.GPU() +function ArrayInterface.promote_eltype( + ::Type{<:AMDGPU.ROCArray{T, N, B}}, ::Type{T2} + ) where {T, N, B, T2} + return AMDGPU.ROCArray{promote_type(T, T2), N, B} +end + end # module diff --git a/ext/ArrayInterfaceCUDAExt.jl b/ext/ArrayInterfaceCUDAExt.jl index 4febac2e..715f8107 100644 --- a/ext/ArrayInterfaceCUDAExt.jl +++ b/ext/ArrayInterfaceCUDAExt.jl @@ -13,4 +13,10 @@ end ArrayInterface.device(::Type{<:CUDA.CuArray}) = ArrayInterface.GPU() +function ArrayInterface.promote_eltype( + ::Type{<:CUDA.CuArray{T, N, M}}, ::Type{T2} + ) where {T, N, M, T2} + return CUDA.CuArray{promote_type(T, T2), N, M} +end + end # module diff --git a/ext/ArrayInterfaceMetalExt.jl b/ext/ArrayInterfaceMetalExt.jl index c27f67f2..36655fad 100644 --- a/ext/ArrayInterfaceMetalExt.jl +++ b/ext/ArrayInterfaceMetalExt.jl @@ -12,4 +12,10 @@ end ArrayInterface.device(::Type{<:Metal.MtlArray}) = ArrayInterface.GPU() +function ArrayInterface.promote_eltype( + ::Type{<:Metal.MtlArray{T, N, S}}, ::Type{T2} + ) where {T, N, S, T2} + return Metal.MtlArray{promote_type(T, T2), N, S} +end + end # module \ No newline at end of file