Skip to content

Commit

Permalink
Fix bad interaction with StaticArrays broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
projekter committed Jul 8, 2024
1 parent 9909ac2 commit f16d9a5
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 3 deletions.
11 changes: 9 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,22 +1,29 @@
name = "StandardPacked"
uuid = "65f29c17-9e56-4606-972f-81e04007695c"
authors = ["Benjamin Desef <[email protected]>"]
version = "1.0.2"
version = "1.0.3"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"

[weakdeps]
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Documenter", "Test"]
test = ["Documenter", "Test", "StaticArrays"]

[extensions]
StandardPackedStaticArrays = "StaticArrays"

[compat]
julia = "1.8"
Documenter = "1"
LinearAlgebra = "1"
SparseArrays = "1"
StaticArrays = "1.4.2"
Test = "1"
36 changes: 36 additions & 0 deletions ext/StandardPackedStaticArrays.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
module StandardPackedStaticArrays

import StandardPacked
import StaticArrays

# We need to be extra careful here - StaticArrays's broadcasting has higher precedence, but will try to write n^2 elements to
# our linear index. But we can spell this out and only copy the upper triangle.
@generated function StaticArrays._broadcast!(f, ::StaticArrays.Size{newsize}, dest::StandardPacked.SPMatrix,
s::Tuple{Vararg{StaticArrays.Size}}, a...) where {newsize}
sizes = [sz.parameters[1] for sz in s.parameters]

indices = CartesianIndices(newsize)
ps = StandardPacked.packedsize(newsize[1])
exprs_eval = similar(indices, Expr, ps)
exprs_setindex = similar(indices, Expr, ps)
cmp = StandardPacked.packed_isupper(dest) ? Base.: : Base.:
j = 1
for current_ind indices
if cmp(current_ind[1], current_ind[2])
exprs_vals = (StaticArrays.broadcast_getindex(sz, i, current_ind) for (i, sz) in enumerate(sizes))
symb_val_j = Symbol(:val_, j)
exprs_eval[j] = :($symb_val_j = f($(exprs_vals...)))
exprs_setindex[j] = :(dest[$j] = $symb_val_j)
j += 1
end
end

return quote
Base.@_inline_meta
$(Expr(:block, exprs_eval...))
@inbounds $(Expr(:block, exprs_setindex...))
return dest
end
end

end
15 changes: 14 additions & 1 deletion test/StandardPacked.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ using Test
using StandardPacked
using LinearAlgebra, SparseArrays
using Base: _realtype
using StaticArrays

setprecision(8)

Expand Down Expand Up @@ -1084,4 +1085,16 @@ Base.eps(::Type{Complex{R}}) where {R} = 10eps(R)
end
end
end
end end
end end

@testset "StaticArrays broadcasting" begin
base = collect(1:16)
pm = SPMatrix(4, @view(base[1:packedsize(4)]))
s = SMatrix{4,4}([100 200 300 400; 500 600 700 800; 900 1000 1100 1200; 1300 1400 1500 1600])
pm .= s
@test base == [100, 200, 600, 300, 700, 1100, 400, 800, 1200, 1600, 11, 12, 13, 14, 15, 16]
copyto!(base, 1:16)
pm2 = SPMatrix(4, @view(base[1:packedsize(4)]), :L)
pm2 .= s
@test base == [100, 500, 900, 1300, 600, 1000, 1400, 1100, 1500, 1600, 11, 12, 13, 14, 15, 16]
end

0 comments on commit f16d9a5

Please sign in to comment.