diff --git a/src/definitions.jl b/src/definitions.jl index 7901966..fe62e96 100644 --- a/src/definitions.jl +++ b/src/definitions.jl @@ -278,7 +278,7 @@ plan_ifft(x::AbstractArray, region; kws...) = plan_ifft!(x::AbstractArray, region; kws...) = ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region)) -plan_inv(p::ScaledPlan) = ScaledPlan(plan_inv(p.p), inv(p.scale)) +plan_inv(p::ScaledPlan) = ScaledPlan(inv(p.p), inv(p.scale)) LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) = LinearAlgebra.lmul!(p.scale, LinearAlgebra.mul!(y, p.p, x)) diff --git a/test/testplans.jl b/test/testplans.jl index 5949da2..7abecfe 100644 --- a/test/testplans.jl +++ b/test/testplans.jl @@ -27,21 +27,20 @@ end function AbstractFFTs.plan_bfft(x::AbstractArray{T}, region; kwargs...) where {T} return InverseTestPlan{T}(region, size(x)) end + function AbstractFFTs.plan_inv(p::TestPlan{T}) where {T} unscaled_pinv = InverseTestPlan{T}(p.region, p.sz) - unscaled_pinv.pinv = p - pinv = AbstractFFTs.ScaledPlan( - unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region), - ) + N = AbstractFFTs.normalization(T, p.sz, p.region) + unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, N) + pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, N) return pinv end -function AbstractFFTs.plan_inv(p::InverseTestPlan{T}) where {T} - unscaled_pinv = TestPlan{T}(p.region, p.sz) - unscaled_pinv.pinv = p - pinv = AbstractFFTs.ScaledPlan( - unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region), - ) - return pinv +function AbstractFFTs.plan_inv(pinv::InverseTestPlan{T}) where {T} + unscaled_p = TestPlan{T}(pinv.region, pinv.sz) + N = AbstractFFTs.normalization(T, pinv.sz, pinv.region) + unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, N) + p = AbstractFFTs.ScaledPlan(unscaled_p, N) + return p end # Just a helper function since forward and backward are nearly identical @@ -118,22 +117,23 @@ function AbstractFFTs.plan_inv(p::TestRPlan{T,N}) where {T,N} firstdim = first(p.region)::Int d = p.sz[firstdim] sz = ntuple(i -> i == firstdim ? d รท 2 + 1 : p.sz[i], Val(N)) + _N = AbstractFFTs.normalization(T, p.sz, p.region) + unscaled_pinv = InverseTestRPlan{T}(d, p.region, sz) - unscaled_pinv.pinv = p - pinv = AbstractFFTs.ScaledPlan( - unscaled_pinv, AbstractFFTs.normalization(T, p.sz, p.region), - ) + unscaled_pinv.pinv = AbstractFFTs.ScaledPlan(p, _N) + pinv = AbstractFFTs.ScaledPlan(unscaled_pinv, _N) return pinv end -function AbstractFFTs.plan_inv(p::InverseTestRPlan{T,N}) where {T,N} - firstdim = first(p.region)::Int - sz = ntuple(i -> i == firstdim ? p.d : p.sz[i], Val(N)) - unscaled_pinv = TestRPlan{T}(p.region, sz) - unscaled_pinv.pinv = p - pinv = AbstractFFTs.ScaledPlan( - unscaled_pinv, AbstractFFTs.normalization(T, sz, p.region), - ) - return pinv + +function AbstractFFTs.plan_inv(pinv::InverseTestRPlan{T,N}) where {T,N} + firstdim = first(pinv.region)::Int + sz = ntuple(i -> i == firstdim ? pinv.d : pinv.sz[i], Val(N)) + _N = AbstractFFTs.normalization(T, sz, pinv.region) + + unscaled_p = TestRPlan{T}(pinv.region, sz) + unscaled_p.pinv = AbstractFFTs.ScaledPlan(pinv, _N) + p = AbstractFFTs.ScaledPlan(unscaled_p, _N) + return p end Base.size(p::TestRPlan) = p.sz