diff --git a/Project.toml b/Project.toml index 7d8fee8..d290647 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "SparseArraysBase" uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208" -version = "0.9.23" +version = "0.10.0" authors = ["ITensor developers and contributors"] [workspace] @@ -12,20 +12,16 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MapBroadcast = "ebd9b9da-f48d-417c-9660-449667d60261" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138" [weakdeps] -NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a" [extensions] -SparseArraysBaseNamedDimsArraysExt = "NamedDimsArrays" SparseArraysBaseTensorAlgebraExt = ["TensorAlgebra", "SparseArrays"] [compat] @@ -34,13 +30,10 @@ Adapt = "4.3" ArrayLayouts = "1.11" Dictionaries = "0.4.3" FillArrays = "1.13" -FunctionImplementations = "0.4" GPUArraysCore = "0.2" LinearAlgebra = "1.10" MapBroadcast = "0.1.5" -NamedDimsArrays = "0.13, 0.14, 0.15" Random = "1.10" SparseArrays = "1.10" -TensorAlgebra = "0.6.2, 0.7, 0.8, 0.9" -TypeParameterAccessors = "0.4.3" +TensorAlgebra = "0.10" julia = "1.10" diff --git a/docs/Project.toml b/docs/Project.toml index 19c9ce6..95fe37e 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -13,4 +13,4 @@ Dictionaries = "0.4.4" Documenter = "1.8.1" ITensorFormatter = "0.2.27" Literate = "2.20.1" -SparseArraysBase = "0.9" +SparseArraysBase = "0.10" diff --git a/examples/Project.toml b/examples/Project.toml index e2ef73d..0c07597 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -8,5 +8,5 @@ path = ".." [compat] Dictionaries = "0.4.4" -SparseArraysBase = "0.9" +SparseArraysBase = "0.10" Test = "<0.0.1, 1" diff --git a/ext/SparseArraysBaseNamedDimsArraysExt/SparseArraysBaseNamedDimsArraysExt.jl b/ext/SparseArraysBaseNamedDimsArraysExt/SparseArraysBaseNamedDimsArraysExt.jl deleted file mode 100644 index 8630347..0000000 --- a/ext/SparseArraysBaseNamedDimsArraysExt/SparseArraysBaseNamedDimsArraysExt.jl +++ /dev/null @@ -1,19 +0,0 @@ -module SparseArraysBaseNamedDimsArraysExt - -using NamedDimsArrays: AbstractNamedDimsArray, AbstractNamedUnitRange, denamed, inds, name, - nameddims, nameddimsof -using SparseArraysBase: SparseArraysBase, dense, oneelement - -function SparseArraysBase.dense(a::AbstractNamedDimsArray) - # TODO: Use `NamedDimsArrays.nameddimsof(a, dense(unname(a)))` once that is defined, - # see: https://github.com/ITensor/NamedDimsArrays.jl/issues/138 - return nameddimsof(a, dense(denamed(a))) -end - -function SparseArraysBase.oneelement( - value, index::NTuple{N, Int}, ax::NTuple{N, AbstractNamedUnitRange} - ) where {N} - return nameddims(oneelement(value, index, denamed.(ax)), name.(ax)) -end - -end diff --git a/src/SparseArraysBase.jl b/src/SparseArraysBase.jl index 0f5ce2f..582a863 100644 --- a/src/SparseArraysBase.jl +++ b/src/SparseArraysBase.jl @@ -17,6 +17,13 @@ export SparseArrayDOK, storedpairs, storedvalues +# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` +# and is useful for sparse array logic, since it can be used to empty +# the sparse array storage. SparseArraysBase owns its own `zero!` rather +# than relying on an external definition. +function zero! end + +include("concatenate.jl") include("abstractsparsearraystyle.jl") include("sparsearraystyle.jl") include("indexing.jl") diff --git a/src/abstractsparsearray.jl b/src/abstractsparsearray.jl index 3452ba2..f1209ee 100644 --- a/src/abstractsparsearray.jl +++ b/src/abstractsparsearray.jl @@ -23,11 +23,6 @@ const AnyAbstractSparseVecOrMat{T} = Union{ Base.convert(T::Type{<:AbstractSparseArray}, a::AbstractArray) = a isa T ? a : T(a) -using FunctionImplementations: FunctionImplementations -function FunctionImplementations.ImplementationStyle(::Type{<:AnyAbstractSparseArray}) - return SparseArrayImplementationStyle() -end - function Base.copy(a::AnyAbstractSparseArray) return copyto!(similar(a), a) end @@ -66,40 +61,66 @@ end using ArrayLayouts: ArrayLayouts using LinearAlgebra: LinearAlgebra -Base.getindex(a::AnyAbstractSparseArray, I::Any...) = style(a)(getindex)(a, I...) -Base.getindex(a::AnyAbstractSparseArray, I::Int...) = style(a)(getindex)(a, I...) -Base.setindex!(a::AnyAbstractSparseArray, x, I::Any...) = style(a)(setindex!)(a, x, I...) -Base.setindex!(a::AnyAbstractSparseArray, x, I::Int...) = style(a)(setindex!)(a, x, I...) -Base.copy!(dst::AbstractArray, src::AnyAbstractSparseArray) = style(src)(copy!)(dst, src) +# `copy!` and `real` are routed to sparse implementations that are not (yet) +# defined, matching the prior behavior where these threw a `MethodError`. +function copy!_sparse end +function real_sparse end + +Base.getindex(a::AnyAbstractSparseArray, I::Any...) = getindex_sparse(a, I...) +Base.getindex(a::AnyAbstractSparseArray, I::Int...) = getindex_sparse(a, I...) +Base.setindex!(a::AnyAbstractSparseArray, x, I::Any...) = setindex!_sparse(a, x, I...) +Base.setindex!(a::AnyAbstractSparseArray, x, I::Int...) = setindex!_sparse(a, x, I...) +Base.copy!(dst::AbstractArray, src::AnyAbstractSparseArray) = copy!_sparse(dst, src) function Base.copyto!(dst::AbstractArray, src::AnyAbstractSparseArray) - return style(src)(copyto!)(dst, src) + return copyto!_sparse(dst, src) end -Base.map(f, as::AnyAbstractSparseArray...) = style(as...)(map)(f, as...) +Base.map(f, as::AnyAbstractSparseArray...) = map_sparse(f, as...) function Base.map!(f, dst::AbstractArray, as::AnyAbstractSparseArray...) - return style(as...)(map!)(f, dst, as...) + return map!_sparse(f, dst, as...) end function Base.mapreduce(f, op, as::AnyAbstractSparseArray...; kwargs...) - return style(as...)(mapreduce)(f, op, as...; kwargs...) + return mapreduce_sparse(f, op, as...; kwargs...) end function Base.reduce(f, as::AnyAbstractSparseArray...; kwargs...) - return style(as...)(reduce)(f, as...; kwargs...) -end -Base.all(f::Function, a::AnyAbstractSparseArray) = style(a)(all)(f, a) -Base.all(a::AnyAbstractSparseArray) = style(a)(all)(a) -Base.iszero(a::AnyAbstractSparseArray) = style(a)(iszero)(a) -Base.isreal(a::AnyAbstractSparseArray) = style(a)(isreal)(a) -Base.real(a::AnyAbstractSparseArray) = style(a)(real)(a) -Base.fill!(a::AnyAbstractSparseArray, x) = style(a)(fill!)(a, x) -FunctionImplementations.zero!(a::AnyAbstractSparseArray) = style(a)(zero!)(a) -Base.zero(a::AnyAbstractSparseArray) = style(a)(zero)(a) + return reduce_sparse(f, as...; kwargs...) +end +Base.all(f::Function, a::AnyAbstractSparseArray) = all_sparse(f, a) +Base.all(a::AnyAbstractSparseArray) = all_sparse(a) +Base.iszero(a::AnyAbstractSparseArray) = iszero_sparse(a) +Base.isreal(a::AnyAbstractSparseArray) = isreal_sparse(a) +Base.real(a::AnyAbstractSparseArray) = real_sparse(a) +Base.fill!(a::AnyAbstractSparseArray, x) = fill!_sparse(a, x) +zero!(a::AnyAbstractSparseArray) = zero!_sparse(a) +Base.zero(a::AnyAbstractSparseArray) = zero_sparse(a) function Base.permutedims!(dst, a::AnyAbstractSparseArray, perm) - return style(a)(permutedims!)(dst, a, perm) + return permutedims!_sparse(dst, a, perm) end function LinearAlgebra.mul!( dst::AbstractMatrix, a1::AnyAbstractSparseArray, a2::AnyAbstractSparseArray, α::Number, β::Number ) - return style(a1, a2)(mul!)(dst, a1, a2, α, β) + return mul!_sparse(dst, a1, a2, α, β) +end + +# Wire the sparse stored-entry implementations (defined in `indexing.jl`) to the +# generic interface functions for sparse array types. Concrete sparse types and +# wrappers may override the canonical methods directly. +@inline getstoredindex(a::AnyAbstractSparseArray, I::Int...) = + getstoredindex_sparse(a, I...) +@inline function getunstoredindex(a::AnyAbstractSparseArray, I::Int...) + return getunstoredindex_sparse(a, I...) +end +@inline isstored(a::AbstractSparseArray, i::Int, I::Int...) = isstored_sparse(a, i, I...) +@inline function setstoredindex!(a::AnyAbstractSparseArray, v, I::Int...) + return setstoredindex!_sparse(a, v, I...) +end +@inline function setunstoredindex!(a::AnyAbstractSparseArray, v, I::Int...) + return setunstoredindex!_sparse(a, v, I...) +end +storedvalues(a::AnyAbstractSparseArray) = storedvalues_sparse(a) +storedpairs(a::AnyAbstractSparseArray) = storedpairs_sparse(a) +function eachstoredindex(style::IndexStyle, a::AnyAbstractSparseArray, bs::AbstractArray...) + return eachstoredindex_sparse(style, a, bs...) end function Base.Broadcast.BroadcastStyle(type::Type{<:AnyAbstractSparseArray}) @@ -109,7 +130,7 @@ end using ArrayLayouts: ArrayLayouts ArrayLayouts.MemoryLayout(type::Type{<:AnyAbstractSparseArray}) = SparseLayout() -using FunctionImplementations.Concatenate: concatenate +using .Concatenate: concatenate # We overload `Base._cat` instead of `Base.cat` since it # is friendlier for invalidations/compile times, see: # https://github.com/ITensor/SparseArraysBase.jl/issues/25 diff --git a/src/abstractsparsearraystyle.jl b/src/abstractsparsearraystyle.jl index 3892687..dd727b9 100644 --- a/src/abstractsparsearraystyle.jl +++ b/src/abstractsparsearraystyle.jl @@ -1,6 +1,5 @@ using Base: @_propagate_inbounds_meta using FillArrays: Zeros -using FunctionImplementations: FunctionImplementations function unstored end function eachstoredindex end @@ -38,7 +37,28 @@ unstoredsimilar(a::AbstractArray) = a # Generic functionality for converting to a # dense array, trying to preserve information # about the array (such as which device it is on). -using TypeParameterAccessors: unspecify_type_parameters, unwrap_array, unwrap_array_type +# Local equivalents of the `TypeParameterAccessors` helpers; usage is +# limited to `densetype`/`dense`. +unspecify_type_parameters(::Type{T}) where {T} = Base.typename(T).wrapper +function unwrap_array(a::AbstractArray) + p = parent(a) + p isa typeof(a) && return a + return unwrap_array(p) +end +using LinearAlgebra: Adjoint, Transpose +unwrap_array_type(arraytype::Type{<:AbstractArray}) = arraytype +unwrap_array_type(a::AbstractArray) = unwrap_array_type(typeof(a)) +unwrap_array_type(::Type{<:Adjoint{<:Any, P}}) where {P} = unwrap_array_type(P) +unwrap_array_type(::Type{<:Transpose{<:Any, P}}) where {P} = unwrap_array_type(P) +function unwrap_array_type( + ::Type{<:PermutedDimsArray{<:Any, <:Any, <:Any, <:Any, P}} + ) where {P} + return unwrap_array_type(P) +end +function unwrap_array_type(::Type{<:Base.ReshapedArray{<:Any, <:Any, P}}) where {P} + return unwrap_array_type(P) +end +unwrap_array_type(::Type{<:SubArray{<:Any, <:Any, P}}) where {P} = unwrap_array_type(P) function densetype(arraytype::Type{<:AbstractArray}) return unspecify_type_parameters(unwrap_array_type(arraytype)) end @@ -51,34 +71,6 @@ function dense(a::AbstractArray) return @allowscalar convert(densetype(a), a) end -# Minimal interface for `SparseArrayImplementationStyle`. -# Fallbacks for dense/non-sparse arrays. - -using FunctionImplementations: AbstractArrayImplementationStyle -abstract type AbstractSparseArrayImplementationStyle <: AbstractArrayImplementationStyle end - -function FunctionImplementations.ImplementationStyle( - style1::AbstractSparseArrayImplementationStyle, - style2::AbstractSparseArrayImplementationStyle - ) - return SparseArrayImplementationStyle() -end -function FunctionImplementations.ImplementationStyle( - style1::AbstractSparseArrayImplementationStyle, - style2::AbstractArrayImplementationStyle - ) - return style1 -end -# Fix ambiguity error with -# `ImplementationStyle(::AbstractSparseArrayImplementationStyle, ::AbstractArrayImplementationStyle)`. -using FunctionImplementations: DefaultArrayImplementationStyle -function FunctionImplementations.ImplementationStyle( - style1::AbstractSparseArrayImplementationStyle, - style2::DefaultArrayImplementationStyle - ) - return style1 -end - to_vec(x) = vec(collect(x)) to_vec(x::AbstractArray) = vec(x) diff --git a/src/concatenate.jl b/src/concatenate.jl new file mode 100644 index 0000000..6fb92ca --- /dev/null +++ b/src/concatenate.jl @@ -0,0 +1,194 @@ +# Alternative implementation for `Base.cat` through `Concatenate.cat(!)`. +# This is mostly a copy of the Base implementation, with the main difference being +# that the destination is chosen based on all inputs instead of just the first. +# There is an intermediate representation in terms of a `Concatenated` object, +# reminiscent of how Broadcast works. Destination selection can be customized through +# `Base.similar(::Concatenated{Style}, ::Type{T}, axes)`, and the operation itself +# through `Base.copy`/`Base.copyto!` on a `Concatenated`. +module Concatenate + +import Base.Broadcast as BC +using ..SparseArraysBase: zero! +using Base: promote_eltypeof + +unval(::Val{x}) where {x} = x + +function _Concatenated end + +# Lazy representation of the concatenation of various `Args` along `Dims`, in order to +# provide hooks to customize the implementation. +struct Concatenated{Style, Dims, Args <: Tuple} + style::Style + dims::Val{Dims} + args::Args + global @inline function _Concatenated( + style::Style, dims::Val{Dims}, args::Args + ) where {Style, Dims, Args <: Tuple} + return new{Style, Dims, Args}(style, dims, args) + end +end + +function Concatenated( + style::Union{BC.AbstractArrayStyle, Nothing}, dims::Val, args::Tuple + ) + return _Concatenated(style, dims, args) +end +function Concatenated(dims::Val, args::Tuple) + return Concatenated(cat_style(dims, args...), dims, args) +end +function Concatenated{Style}( + dims::Val, args::Tuple + ) where {Style <: Union{BC.AbstractArrayStyle, Nothing}} + return Concatenated(Style(), dims, args) +end + +dims(::Concatenated{<:Any, D}) where {D} = D +style(concat::Concatenated) = getfield(concat, :style) + +concatenated(dims, args...) = concatenated(Val(dims), args...) +concatenated(dims::Val, args...) = Concatenated(dims, args) + +function Base.convert( + ::Type{Concatenated{NewStyle}}, concat::Concatenated{<:Any, Dims, Args} + ) where {NewStyle, Dims, Args} + return Concatenated{NewStyle}( + concat.dims, concat.args + )::Concatenated{NewStyle, Dims, Args} +end + +# allocating the destination container +# ------------------------------------ +Base.similar(concat::Concatenated) = similar(concat, eltype(concat)) +Base.similar(concat::Concatenated, ::Type{T}) where {T} = similar(concat, T, axes(concat)) +function Base.similar(concat::Concatenated, ax) + return similar(concat, eltype(concat), ax) +end + +function Base.similar(concat::Concatenated, ::Type{T}, ax) where {T} + # Convert to a broadcasted to leverage its similar implementation. + bc = BC.Broadcasted(style(concat), identity, concat.args, ax) + return similar(bc, T) +end + +function cat_axis( + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... + ) + return cat_axis(cat_axis(a1, a2), a_rest...) +end +function cat_axis(a1::AbstractUnitRange, a2::AbstractUnitRange) + first(a1) == first(a2) == 1 || throw(ArgumentError("Concatenated axes must start at 1")) + return Base.OneTo(length(a1) + length(a2)) +end + +function cat_ndims(dims, as::AbstractArray...) + return max(maximum(dims), maximum(ndims, as)) +end +function cat_ndims(dims::Val, as::AbstractArray...) + return cat_ndims(unval(dims), as...) +end + +function cat_axes(dims, a::AbstractArray, as::AbstractArray...) + return ntuple(cat_ndims(dims, a, as...)) do dim + return if dim in dims + cat_axis(map(Base.Fix2(axes, dim), (a, as...))...) + else + axes(a, dim) + end + end +end +function cat_axes(dims::Val, as::AbstractArray...) + return cat_axes(unval(dims), as...) +end + +function cat_style(dims, as::AbstractArray...) + N = cat_ndims(dims, as...) + return typeof(BC.combine_styles(as...))(Val(N)) +end + +Base.eltype(concat::Concatenated) = promote_eltypeof(concat.args...) +Base.axes(concat::Concatenated) = cat_axes(dims(concat), concat.args...) +Base.size(concat::Concatenated) = length.(axes(concat)) +Base.ndims(concat::Concatenated) = cat_ndims(dims(concat), concat.args...) + +# Main logic +# ---------- +# Concatenate the supplied `args` along dimensions `dims`. +concatenate(dims, args...) = Base.materialize(concatenated(dims, args...)) + +# Concatenate the supplied `args` along dimensions `dims`. +cat(args...; dims) = concatenate(dims, args...) +Base.materialize(concat::Concatenated) = copy(concat) + +# Concatenate the supplied `args` along dimensions `dims`, placing the result into `dest`. +function cat!(dest, args...; dims) + Base.materialize!(dest, concatenated(dims, args...)) + return dest +end +Base.materialize!(dest, concat::Concatenated) = copyto!(dest, concat) + +Base.copy(concat::Concatenated) = copyto!(similar(concat), concat) + +# The following is largely copied from the Base implementation of `Base.cat`, see: +# https://github.com/JuliaLang/julia/blob/885b1cd875f101f227b345f681cc36879124d80d/base/abstractarray.jl#L1778-L1887 +_copy_or_fill!(A, inds, x) = fill!(view(A, inds...), x) +_copy_or_fill!(A, inds, x::AbstractArray) = (A[inds...] = x) + +cat_size(A) = (1,) +cat_size(A::AbstractArray) = size(A) +cat_size(A, d) = 1 +cat_size(A::AbstractArray, d) = size(A, d) + +cat_indices(A, d) = Base.OneTo(1) +cat_indices(A::AbstractArray, d) = axes(A, d) + +function __cat!(A, shape, catdims, X...) + return __cat_offset!(A, shape, catdims, ntuple(zero, length(shape)), X...) +end +function __cat_offset!(A, shape, catdims, offsets, x, X...) + # splitting the "work" on x from X... may reduce latency (fewer costly specializations) + newoffsets = __cat_offset1!(A, shape, catdims, offsets, x) + return __cat_offset!(A, shape, catdims, newoffsets, X...) +end +__cat_offset!(A, shape, catdims, offsets) = A +function __cat_offset1!(A, shape, catdims, offsets, x) + inds = ntuple(length(offsets)) do i + return if (i <= length(catdims) && catdims[i]) + offsets[i] .+ cat_indices(x, i) + else + 1:shape[i] + end + end + _copy_or_fill!(A, inds, x) + newoffsets = ntuple(length(offsets)) do i + return if (i <= length(catdims) && catdims[i]) + offsets[i] + cat_size(x, i) + else + offsets[i] + end + end + return newoffsets +end + +dims2cat(dims::Val) = dims2cat(unval(dims)) +function dims2cat(dims) + if any(≤(0), dims) + throw(ArgumentError("All cat dimensions must be positive integers, but got $dims")) + end + return ntuple(in(dims), maximum(dims)) +end + +# default falls back to replacing style with Nothing +# this permits specializing on typeof(dest) without ambiguities +# Note: this needs to be defined for AbstractArray specifically to avoid ambiguities with Base. +@inline function Base.copyto!(dest::AbstractArray, concat::Concatenated) + return copyto!(dest, convert(Concatenated{Nothing}, concat)) +end + +function Base.copyto!(dest::AbstractArray, concat::Concatenated{Nothing}) + catdims = dims2cat(dims(concat)) + shape = size(concat) + count(!iszero, catdims)::Int > 1 && zero!(dest) + return __cat!(dest, shape, catdims, concat.args...) +end + +end diff --git a/src/indexing.jl b/src/indexing.jl index 9ead7f0..781f82f 100644 --- a/src/indexing.jl +++ b/src/indexing.jl @@ -1,5 +1,4 @@ using Base: @_propagate_inbounds_meta -using FunctionImplementations: Implementation, style # Indexing interface # ------------------ @@ -11,9 +10,7 @@ Obtain `getindex(A, I...)` with the guarantee that there is a stored entry at th Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. """ -@inline function getstoredindex(A::AbstractArray, I...) - return style(A)(getstoredindex)(A, I...) -end +function getstoredindex end """ getunstoredindex(A::AbstractArray, I...) -> eltype(A) @@ -25,9 +22,7 @@ instantiated object. Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. """ -@inline function getunstoredindex(A::AbstractArray, I...) - return style(A)(getunstoredindex)(A, I...) -end +function getunstoredindex end """ isstored(A::AbstractArray, I...) -> Bool @@ -38,9 +33,7 @@ sparse array types might overload this function when appropriate. Similar to `Base.getindex`, new definitions should be in line with `IndexStyle(A)`. """ -@inline function isstored(A::AbstractArray, I...) - return style(A)(isstored)(A, I...) -end +function isstored end """ setstoredindex!(A::AbstractArray, v, I...) -> A @@ -49,9 +42,7 @@ end Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle(A)`. """ -@inline function setstoredindex!(A::AbstractArray, v, I...) - return style(A)(setstoredindex!)(A, v, I...) -end +function setstoredindex! end """ setunstoredindex!(A::AbstractArray, v, I...) -> A @@ -60,9 +51,7 @@ end Similar to `Base.setindex!`, new definitions should be in line with `IndexStyle(A)`. """ -@inline function setunstoredindex!(A::AbstractArray, v, I...) - return style(A)(setunstoredindex!)(A, v, I...) -end +function setunstoredindex! end # Indices interface # ----------------- @@ -110,14 +99,6 @@ to be the same as [`eachstoredindex`](@ref). """ function storedvalues end -eachstoredindex(as::AbstractArray...) = style(as...)(eachstoredindex)(as...) -function eachstoredindex(indexstyle::IndexStyle, as::AbstractArray...) - return style(as...)(eachstoredindex)(indexstyle, as...) -end -storedlength(a::AbstractArray) = style(a)(storedlength)(a) -storedpairs(a::AbstractArray) = style(a)(storedpairs)(a) -storedvalues(a::AbstractArray) = style(a)(storedvalues)(a) - # canonical indexing # ------------------ # ensure functions only have to be defined in terms of a single canonical f: @@ -128,7 +109,7 @@ for f in (:isstored, :getunstoredindex, :getstoredindex) _f = Symbol(:_, f) error_if_canonical = Symbol(:error_if_canonical_, f) @eval begin - function (::Implementation{typeof($f)})(A::AbstractArray, I...) + function $f(A::AbstractArray, I...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical(style, A, I...) @@ -180,7 +161,7 @@ for f! in (:setstoredindex!, :setunstoredindex!) _f! = Symbol(:_, f!) error_if_canonical = Symbol(:error_if_canonical_, f!) @eval begin - function (::Implementation{typeof($f!)})(A::AbstractArray, v, I...) + function $f!(A::AbstractArray, v, I...) @_propagate_inbounds_meta style = IndexStyle(A) $error_if_canonical(style, A, I...) @@ -231,50 +212,52 @@ for f! in (:setstoredindex!, :setunstoredindex!) end end -# AbstractArrayStyle fallback definitions -# ------------------------------------------- -function (::Implementation{typeof(isstored)})(A::AbstractArray, i::Int, I::Int...) +# AbstractArray fallback definitions +# ---------------------------------- +function isstored(A::AbstractArray, i::Int, I::Int...) @inline @boundscheck checkbounds(A, i, I...) return true end -function (::Implementation{typeof(getunstoredindex)})(A::AbstractArray, I::Int...) +function getunstoredindex(A::AbstractArray, I::Int...) @inline @boundscheck checkbounds(A, I...) return zero(eltype(A)) end -function (::Implementation{typeof(getstoredindex)})(A::AbstractArray, I::Int...) +function getstoredindex(A::AbstractArray, I::Int...) @inline return getindex(A, I...) end -function (::Implementation{typeof(setstoredindex!)})(A::AbstractArray, v, I::Int...) +function setstoredindex!(A::AbstractArray, v, I::Int...) @inline return setindex!(A, v, I...) end -function (::Implementation{typeof(setunstoredindex!)})(A::AbstractArray, v, I::Int...) +function setunstoredindex!(A::AbstractArray, v, I::Int...) return error("setunstoredindex! for $(typeof(A)) is not supported") end -function (::Implementation{typeof(eachstoredindex)})(A::AbstractArray, B::AbstractArray...) +function eachstoredindex(A::AbstractArray, B::AbstractArray...) return eachstoredindex(IndexStyle(A, B...), A, B...) end -function (::Implementation{typeof(eachstoredindex)})( +function eachstoredindex( style::IndexStyle, A::AbstractArray, B::AbstractArray... ) return eachindex(style, A, B...) end -(::Implementation{typeof(storedvalues)})(a::AbstractArray) = values(a) -(::Implementation{typeof(storedpairs)})(a::AbstractArray) = pairs(a) -(::Implementation{typeof(storedlength)})(a::AbstractArray) = length(storedvalues(a)) +storedvalues(a::AbstractArray) = values(a) +storedpairs(a::AbstractArray) = pairs(a) +storedlength(a::AbstractArray) = length(storedvalues(a)) -# SparseArrayInterface implementations -# ------------------------------------ +# Sparse array implementations +# ---------------------------- +# These are the implementations for arrays whose stored entries are a strict subset +# of all entries (`AnyAbstractSparseArray`); they are wired to the interface functions +# above in `abstractsparsearray.jl`. # canonical errors are moved to `isstored`, `getstoredindex` and `getunstoredindex` # so no errors at this level by defining both IndexLinear and IndexCartesian -const getindex_sparse = sparse_style(getindex) function getindex_sparse(A::AbstractArray{<:Any, N}, I::Vararg{Int, N}) where {N} @_propagate_inbounds_meta @boundscheck checkbounds(A, I...) # generally isstored requires bounds checking @@ -302,7 +285,6 @@ function getindex_sparse(a::AbstractArray, I...) return ArrayLayouts.layout_getindex(a, I...) end -const setindex!_sparse = sparse_style(setindex!) function setindex!_sparse( A::AbstractArray{<:Any, N}, v, I::Vararg{Int, N} ) where {N} @@ -354,7 +336,6 @@ end end # required: one implementation for canonical index style -const eachstoredindex_sparse = sparse_style(eachstoredindex) function eachstoredindex_sparse(style::IndexStyle, A::AbstractArray) error_if_canonical_eachstoredindex(style, A) inds = eachstoredindex(A) @@ -376,13 +357,11 @@ function eachstoredindex_sparse( return union(map(Base.Fix1(eachstoredindex, style), (A, B...))...) end -const storedvalues_sparse = sparse_style(storedvalues) storedvalues_sparse(A::AbstractArray) = StoredValues(A) # default implementation is a bit tricky here: we don't know if this is the "canonical" # implementation, so we check this and otherwise map back to `_isstored` to canonicalize the # indices -const isstored_sparse = sparse_style(isstored) function isstored_sparse(A::AbstractArray, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) @@ -402,7 +381,6 @@ function isstored_sparse(A::AbstractArray, I::Int...) return _isstored(style, A, Base.to_indices(A, I)...) end -const getunstoredindex_sparse = sparse_style(getunstoredindex) function getunstoredindex_sparse(A::AbstractArray, I::Int...) @_propagate_inbounds_meta style = IndexStyle(A) @@ -420,7 +398,6 @@ function getunstoredindex_sparse(A::AbstractArray, I::Int...) return _getunstoredindex(style, A, Base.to_indices(A, I)...) end -const getstoredindex_sparse = sparse_style(getstoredindex) function getstoredindex_sparse( A::AbstractArray, I::Int... ) @@ -430,8 +407,6 @@ function getstoredindex_sparse( return _getstoredindex(style, A, Base.to_indices(A, I)...) end -const setstoredindex!_sparse = sparse_style(setstoredindex!) -const setunstoredindex!_sparse = sparse_style(setunstoredindex!) for f! in (:setstoredindex!, :setunstoredindex!) f!_sparse = Symbol(f!, :_sparse) _f! = Symbol(:_, f!) @@ -446,7 +421,6 @@ for f! in (:setstoredindex!, :setunstoredindex!) end end -const storedpairs_sparse = sparse_style(storedpairs) function storedpairs_sparse(A::AbstractArray) return Iterators.map(I -> (I => A[I]), eachstoredindex(A)) end diff --git a/src/map.jl b/src/map.jl index b18bcd0..475f82a 100644 --- a/src/map.jl +++ b/src/map.jl @@ -68,7 +68,6 @@ end # map(!) # ------ -const map_sparse = sparse_style(map) function map_sparse( f, A::AbstractArray, Bs::AbstractArray... ) @@ -89,7 +88,6 @@ function map_sparse(f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...) return map!_sparse(f, C, A, Bs...) end -const map!_sparse = sparse_style(map!) function map!_sparse( f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray... ) @@ -124,7 +122,6 @@ end # Derived functions # ----------------- -const copyto!_sparse = sparse_style(copyto!) function copyto!_sparse( dest::AbstractArray, src::AbstractArray ) @@ -132,7 +129,6 @@ function copyto!_sparse( return dest end -const permutedims!_sparse = sparse_style(permutedims!) function permutedims!_sparse( a_dest::AbstractArray, a_src::AbstractArray, perm ) @@ -142,7 +138,6 @@ end # Only map the stored values of the inputs. function map_stored! end -const map_stored!_sparse = sparse_style(map_stored!) function map_stored!_sparse( f, a_dest::AbstractArray, as::AbstractArray... ) @@ -153,7 +148,6 @@ end # Only map all values, not just the stored ones. function map_all! end -const map_all!_sparse = sparse_style(map_all!) function map_all!_sparse( f, a_dest::AbstractArray, as::AbstractArray... ) @@ -162,12 +156,10 @@ function map_all!_sparse( end # TODO: Generalize to multiple inputs. -const reduce_sparse = sparse_style(reduce) function reduce_sparse(f, a::AbstractArray; kwargs...) return mapreduce(identity, f, a; kwargs...) end -const all_sparse = sparse_style(all) function all_sparse(a::AbstractArray) return reduce(&, a; init = true) end @@ -175,10 +167,8 @@ function all_sparse(f::Function, a::AbstractArray) return mapreduce(f, &, a; init = true) end -const isreal_sparse = sparse_style(isreal) isreal_sparse(a::AbstractArray) = all(isreal, a) -const iszero_sparse = sparse_style(iszero) iszero_sparse(a::AbstractArray) = all(iszero, a) # Utility functions diff --git a/src/sparsearraydok.jl b/src/sparsearraydok.jl index 0050691..c08ca18 100644 --- a/src/sparsearraydok.jl +++ b/src/sparsearraydok.jl @@ -1,6 +1,5 @@ using Accessors: @set using Dictionaries: Dictionary, IndexError, set! -using FunctionImplementations: FunctionImplementations, zero! const DOKStorage{T, N} = Dictionary{CartesianIndex{N}, T} @@ -108,7 +107,7 @@ end storedpairs(a::SparseArrayDOK) = pairs(storage(a)) # TODO: Also handle wrappers. -function FunctionImplementations.zero!(a::SparseArrayDOK) +function zero!(a::SparseArrayDOK) empty!(storage(a)) return a end diff --git a/src/sparsearraystyle.jl b/src/sparsearraystyle.jl index f998f2a..132fef0 100644 --- a/src/sparsearraystyle.jl +++ b/src/sparsearraystyle.jl @@ -1,25 +1,15 @@ -using FunctionImplementations: FunctionImplementations - -struct SparseArrayImplementationStyle <: AbstractSparseArrayImplementationStyle end - -# Convenient shorthand to refer to the sparse style. -# Can turn a function into a sparse function with the syntax `sparse_style(f)`, -# i.e. `sparse_style(map)(x -> 2x, randn(2, 2))` while use the sparse -# version of `map`. -const sparse_style = SparseArrayImplementationStyle() - -const fill!_sparse = sparse_style(fill!) function fill!_sparse(a::AbstractArray, value) return map!(Returns(value), a, a) end -using FunctionImplementations: FunctionImplementations, zero! +# Generic fallback for SparseArraysBase's owned `zero!` (see the declaration +# in `SparseArraysBase.jl`). +function zero!(a::AbstractArray) + fill!(a, zero(eltype(a))) + return a +end -# `zero!` isn't defined in `Base`, but it is defined in `ArrayLayouts` -# and is useful for sparse array logic, since it can be used to empty -# the sparse array storage. # We use a single function definition to minimize method ambiguities. -const zero!_sparse = sparse_style(zero!) function zero!_sparse(a::AbstractArray) # More generally, this codepath could be taking if `zero(eltype(a))` # is defined and the elements are immutable. @@ -30,7 +20,6 @@ function zero!_sparse(a::AbstractArray) return a end -const zero_sparse = sparse_style(zero) # Specialized version of `Base.zero` written in terms of `zero!`. # This is friendlier for sparse arrays since `zero!` makes it easier # to handle the logic of dropping all elements of the sparse array when possible. @@ -57,14 +46,12 @@ end # This is defined in this way so we can rely on the Broadcast logic # for determining the destination of the operation (element type, shape, etc.). -const map_sparse = sparse_style(map) function map_sparse(f, as::AbstractArray...) # Broadcasting is used here to determine the destination array but that # could be done manually here. return f.(as...) end -const mapreduce_sparse = sparse_style(mapreduce) function mapreduce_sparse( f, op, as::AbstractArray...; init = reduce_init(f, op, as...), kwargs... ) @@ -100,7 +87,6 @@ function Base.similar( end using ArrayLayouts: ArrayLayouts -const mul!_sparse = sparse_style(mul!) function mul!_sparse( a_dest::AbstractVecOrMat, a1::AbstractVecOrMat, a2::AbstractVecOrMat, α::Number, β::Number diff --git a/src/wrappers.jl b/src/wrappers.jl index 7deee7c..6a70a51 100644 --- a/src/wrappers.jl +++ b/src/wrappers.jl @@ -194,11 +194,8 @@ for type in (:Adjoint, :PermutedDimsArray, :ReshapedArray, :SubArray, :Transpose end end -using FunctionImplementations: ImplementationStyle using LinearAlgebra: LinearAlgebra, Diagonal -const diag_style = ImplementationStyle(Diagonal) -const storedvalues_diag = diag_style(storedvalues) -storedvalues_diag(D::AbstractMatrix) = LinearAlgebra.diag(D) +storedvalues(D::Diagonal) = LinearAlgebra.diag(D) # compat with LTS: @static if VERSION ≥ v"1.11" @@ -208,21 +205,16 @@ else return view(CartesianIndices(x), LinearAlgebra.diagind(x)) end end -const eachstoredindex_diag = diag_style(eachstoredindex) -eachstoredindex_diag(D::AbstractMatrix) = _diagind(D, IndexCartesian()) +eachstoredindex(D::Diagonal) = _diagind(D, IndexCartesian()) -const isstored_diag = diag_style(isstored) -function isstored_diag(D::AbstractMatrix, i::Int, j::Int) +function isstored(D::Diagonal, i::Int, j::Int) return i == j && checkbounds(Bool, D, i, j) end -const getstoredindex_diag = diag_style(getstoredindex) -getstoredindex_diag(D::AbstractMatrix, i::Int, j::Int) = D.diag[i] -const getunstoredindex_diag = diag_style(getunstoredindex) -function getunstoredindex_diag(D::AbstractMatrix, i::Int, j::Int) +getstoredindex(D::Diagonal, i::Int, j::Int) = D.diag[i] +function getunstoredindex(D::Diagonal, i::Int, j::Int) return zero(eltype(D)) end -const setstoredindex!_diag = diag_style(setstoredindex!) -function setstoredindex!_diag(D::AbstractMatrix, v, i::Int, j::Int) +function setstoredindex!(D::Diagonal, v, i::Int, j::Int) D.diag[i] = v return D end diff --git a/test/Project.toml b/test/Project.toml index cfd958c..d19da54 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -4,11 +4,9 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a" Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -FunctionImplementations = "7c7cc465-9c6a-495f-bdd1-f42428e86d0c" ITensorPkgSkeleton = "3d388ab1-018a-49f4-ae50-18094d5f71ea" JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -NamedDimsArrays = "60cbd0c0-df58-4cb7-918c-6f5607b73fde" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -27,16 +25,14 @@ Aqua = "0.8.11" ArrayLayouts = "1.11.1" Dictionaries = "0.4.4" FillArrays = "1.13" -FunctionImplementations = "0.4" ITensorPkgSkeleton = "0.3.42" JLArrays = "0.2, 0.3" LinearAlgebra = "<0.0.1, 1" -NamedDimsArrays = "0.13, 0.14, 0.15" Random = "<0.0.1, 1" SafeTestsets = "0.1" SparseArrays = "1.10" -SparseArraysBase = "0.9" +SparseArraysBase = "0.10" StableRNGs = "1.0.2" Suppressor = "0.2.8" -TensorAlgebra = "0.6, 0.7, 0.8, 0.9" +TensorAlgebra = "0.10" Test = "<0.0.1, 1" diff --git a/test/test_nameddimsarraysext.jl b/test/test_nameddimsarraysext.jl deleted file mode 100644 index 85de589..0000000 --- a/test/test_nameddimsarraysext.jl +++ /dev/null @@ -1,35 +0,0 @@ -using NamedDimsArrays: NamedDimsArray, denamed, inds, nameddims, namedoneto -using SparseArraysBase: dense, oneelement, sparsezeros -using Test: @test, @testset - -@testset "SparseArraysBaseExt (eltype=$elt)" for elt in (Float64, ComplexF64) - @testset "oneelement" begin - i = namedoneto(3, "i") - a = oneelement(i => 2) - @test a isa NamedDimsArray{Bool} - @test ndims(a) == 1 - @test issetequal(inds(a), (i,)) - @test eltype(a) ≡ Bool - @test a[1] == 0 - @test a[2] == 1 - @test a[3] == 0 - - i = namedoneto(3, "i") - a = oneelement(elt, i => 2) - @test a isa NamedDimsArray{elt} - @test ndims(a) == 1 - @test issetequal(inds(a), (i,)) - @test eltype(a) ≡ elt - @test a[1] == 0 - @test a[2] == 1 - @test a[3] == 0 - end - @testset "dense" begin - s = sparsezeros(elt, 3, 4) - a = nameddims(s, (:a, :b)) - b = dense(a) - @test denamed(b) == dense(denamed(a)) - @test denamed(b) isa Array{elt, 2} - @test inds(b) == inds(a) - end -end diff --git a/test/test_sparse_style.jl b/test/test_sparse_style.jl deleted file mode 100644 index 0b71162..0000000 --- a/test/test_sparse_style.jl +++ /dev/null @@ -1,53 +0,0 @@ -using FunctionImplementations: DefaultArrayImplementationStyle, ImplementationStyle, style -using SparseArraysBase: AbstractSparseArrayImplementationStyle, - SparseArrayImplementationStyle, sparse_style, sparsezeros -using Test: @test, @testset - -module TestSparseImplementationStyleUtils - using SparseArraysBase: AbstractSparseArray, AbstractSparseArrayImplementationStyle - using FunctionImplementations: FunctionImplementations - struct MySparseArrayImplementationStyle <: AbstractSparseArrayImplementationStyle end - struct MySparseArray{T, N} <: AbstractSparseArray{T, N} - size::NTuple{N, Int} - end - function FunctionImplementations.ImplementationStyle(::Type{<:MySparseArray}) - return MySparseArrayImplementationStyle() - end -end - -@testset "Combine Sparse Styles" begin - @test sparse_style ≡ SparseArrayImplementationStyle() - @test ImplementationStyle( - SparseArrayImplementationStyle(), - SparseArrayImplementationStyle() - ) ≡ SparseArrayImplementationStyle() - @test ImplementationStyle( - TestSparseImplementationStyleUtils.MySparseArrayImplementationStyle(), - SparseArrayImplementationStyle() - ) ≡ - SparseArrayImplementationStyle() - @test ImplementationStyle( - SparseArrayImplementationStyle(), - TestSparseImplementationStyleUtils.MySparseArrayImplementationStyle() - ) ≡ - SparseArrayImplementationStyle() - @test style(TestSparseImplementationStyleUtils.MySparseArray{Float64, 2}((2, 2))) ≡ - TestSparseImplementationStyleUtils.MySparseArrayImplementationStyle() - @test style( - sparsezeros(2, 2), - TestSparseImplementationStyleUtils.MySparseArray{Float64, 2}((2, 2)) - ) ≡ - SparseArrayImplementationStyle() - # Regression tests for ambiguity caused by combining AbstractSparseArrayStyle with - # DefaultArrayStyle. - @test ImplementationStyle( - TestSparseImplementationStyleUtils.MySparseArrayImplementationStyle(), - DefaultArrayImplementationStyle() - ) ≡ - TestSparseImplementationStyleUtils.MySparseArrayImplementationStyle() - @test style( - TestSparseImplementationStyleUtils.MySparseArray{Float64, 2}((2, 2)), - randn(2, 2) - ) ≡ - TestSparseImplementationStyleUtils.MySparseArrayImplementationStyle() -end