Skip to content

Commit

Permalink
dispatch vec mat
Browse files Browse the repository at this point in the history
  • Loading branch information
geoffroyleconte committed Sep 22, 2023
1 parent 63dc663 commit 63465cf
Showing 1 changed file with 50 additions and 44 deletions.
94 changes: 50 additions & 44 deletions src/coo_linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,23 @@ function coo_mul!(C::AbstractMatrix, Arows, Acols, Avals, B::AbstractMatrix, α,
end
end

function LinearAlgebra.mul!(
C::StridedVecOrMat,
A::AbstractSparseMatrixCOO,
B::SparseArrays.DenseInputVecOrMat,
α::Number,
β::Number,
)
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
for T in (AbstractVector, AbstractMatrix)
@eval function LinearAlgebra.mul!(
C::StridedVecOrMat,
A::AbstractSparseMatrixCOO,
B::$T,
α::Number,
β::Number,
)
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
coo_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A))
C
end
coo_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A))
C
end

function coo_adjtrans_mul!(C::AbstractVector, Arows, Acols, Avals, B::AbstractVector, α, Annz, t)
Expand All @@ -44,22 +46,24 @@ function coo_adjtrans_mul!(C::AbstractMatrix, Arows, Acols, Avals, B::AbstractMa
end

for (T, t) in ((Adjoint, adjoint), (Transpose, transpose))
@eval function LinearAlgebra.mul!(
C::StridedVecOrMat,
xA::$T{<:Any, <:AbstractSparseMatrixCOO},
B::SparseArrays.DenseInputVecOrMat,
α::Number,
β::Number,
)
A = xA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
for Tb in (AbstractVector, AbstractMatrix)
@eval function LinearAlgebra.mul!(
C::StridedVecOrMat,
xA::$T{<:Any, <:AbstractSparseMatrixCOO},
B::$Tb,
α::Number,
β::Number,
)
A = xA.parent
size(A, 2) == size(C, 1) || throw(DimensionMismatch())
size(A, 1) == size(B, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
coo_adjtrans_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t)
C
end
coo_adjtrans_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t)
C
end
end

Expand All @@ -86,22 +90,24 @@ function coo_sym_mul!(C::AbstractMatrix, Arows, Acols, Avals, B::AbstractMatrix,
end

for (T, t) in ((Hermitian, adjoint), (Symmetric, transpose))
@eval function LinearAlgebra.mul!(
C::StridedVecOrMat,
xA::$T{<:Any, <:AbstractSparseMatrixCOO},
B::SparseArrays.DenseInputVecOrMat,
α::Number,
β::Number,
)
A = xA.data
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
for Tb in (AbstractVector, AbstractMatrix)
@eval function LinearAlgebra.mul!(
C::StridedVecOrMat,
xA::$T{<:Any, <:AbstractSparseMatrixCOO},
B::$Tb,
α::Number,
β::Number,
)
A = xA.data
size(A, 2) == size(B, 1) || throw(DimensionMismatch())
size(A, 1) == size(C, 1) || throw(DimensionMismatch())
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
if β != 1
β != 0 ? rmul!(C, β) : fill!(C, zero(eltype(C)))
end
coo_sym_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t, xA.uplo)
C
end
coo_sym_mul!(C, A.rows, A.cols, A.vals, B, α, nnz(A), $t, xA.uplo)
C
end
end

Expand Down

0 comments on commit 63465cf

Please sign in to comment.