diff --git a/src/coo_linalg.jl b/src/coo_linalg.jl index 84732a6..0086406 100644 --- a/src/coo_linalg.jl +++ b/src/coo_linalg.jl @@ -12,21 +12,23 @@ function coo_mul!(C::AbstractMatrix, Arows, Acols, Avals, B::AbstractMatrix, α, end end -function LinearAlgebra.mul!( - C::StridedVecOrMat, - A::AbstractSparseMatrixCOO, - B::SparseArrays.DenseInputVecOrMat, - α::Number, - β::Number, -) - size(A, 2) == size(B, 1) || throw(DimensionMismatch()) - size(A, 1) == size(C, 1) || throw(DimensionMismatch()) - size(B, 2) == size(C, 2) || throw(DimensionMismatch()) - if β != 1 - β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) +for T in (AbstractVector, AbstractMatrix) + @eval function LinearAlgebra.mul!( + C::StridedVecOrMat, + A::AbstractSparseMatrixCOO, + B::$T, + α::Number, + β::Number, + ) + size(A, 2) == size(B, 1) || throw(DimensionMismatch()) + size(A, 1) == size(C, 1) || throw(DimensionMismatch()) + size(B, 2) == size(C, 2) || throw(DimensionMismatch()) + if β != 1 + β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) + end + coo_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A)) + C end - coo_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A)) - C end function coo_adjtrans_mul!(C::AbstractVector, Arows, Acols, Avals, B::AbstractVector, α, Annz, t) @@ -44,22 +46,24 @@ function coo_adjtrans_mul!(C::AbstractMatrix, Arows, Acols, Avals, B::AbstractMa end for (T, t) in ((Adjoint, adjoint), (Transpose, transpose)) - @eval function LinearAlgebra.mul!( - C::StridedVecOrMat, - xA::$T{<:Any, <:AbstractSparseMatrixCOO}, - B::SparseArrays.DenseInputVecOrMat, - α::Number, - β::Number, - ) - A = xA.parent - size(A, 2) == size(C, 1) || throw(DimensionMismatch()) - size(A, 1) == size(B, 1) || throw(DimensionMismatch()) - size(B, 2) == size(C, 2) || throw(DimensionMismatch()) - if β != 1 - β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) + for Tb in (AbstractVector, AbstractMatrix) + @eval function LinearAlgebra.mul!( + C::StridedVecOrMat, + xA::$T{<:Any, <:AbstractSparseMatrixCOO}, + B::$Tb, + α::Number, + β::Number, + ) + A = xA.parent + size(A, 2) == size(C, 1) || throw(DimensionMismatch()) + size(A, 1) == size(B, 1) || throw(DimensionMismatch()) + size(B, 2) == size(C, 2) || throw(DimensionMismatch()) + if β != 1 + β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) + end + coo_adjtrans_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t) + C end - coo_adjtrans_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t) - C end end @@ -86,22 +90,24 @@ function coo_sym_mul!(C::AbstractMatrix, Arows, Acols, Avals, B::AbstractMatrix, end for (T, t) in ((Hermitian, adjoint), (Symmetric, transpose)) - @eval function LinearAlgebra.mul!( - C::StridedVecOrMat, - xA::$T{<:Any, <:AbstractSparseMatrixCOO}, - B::SparseArrays.DenseInputVecOrMat, - α::Number, - β::Number, - ) - A = xA.data - size(A, 2) == size(B, 1) || throw(DimensionMismatch()) - size(A, 1) == size(C, 1) || throw(DimensionMismatch()) - size(B, 2) == size(C, 2) || throw(DimensionMismatch()) - if β != 1 - β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) + for Tb in (AbstractVector, AbstractMatrix) + @eval function LinearAlgebra.mul!( + C::StridedVecOrMat, + xA::$T{<:Any, <:AbstractSparseMatrixCOO}, + B::$Tb, + α::Number, + β::Number, + ) + A = xA.data + size(A, 2) == size(B, 1) || throw(DimensionMismatch()) + size(A, 1) == size(C, 1) || throw(DimensionMismatch()) + size(B, 2) == size(C, 2) || throw(DimensionMismatch()) + if β != 1 + β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C))) + end + coo_sym_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t, xA.uplo) + C end - coo_sym_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t, xA.uplo) - C end end