Skip to content

Commit

Permalink
Immutable ScaledPlan
Browse files Browse the repository at this point in the history
  • Loading branch information
gaurav-arya committed Jul 16, 2022
1 parent 600b24a commit 10e12af
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions src/definitions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ plan_rfft(x::AbstractArray, region; kws...) = plan_rfft(realfloat(x), region; kw
_pinv_type(p::Plan) = typeof([plan_inv(x) for x in typeof(p)[]])
pinv_type(p::Plan) = eltype(_pinv_type(p))

function plan_inv end

inv(p::Plan) =
isdefined(p, :pinv) ? p.pinv::pinv_type(p) : (p.pinv = plan_inv(p))
\(p::Plan, x::AbstractArray) = inv(p) * x
Expand All @@ -243,10 +245,9 @@ LinearAlgebra.ldiv!(y::AbstractArray, p::Plan, x::AbstractArray) = LinearAlgebra
# implementations only need to provide the unnormalized backwards FFT,
# similar to FFTW, and we do the scaling generically to get the ifft:

mutable struct ScaledPlan{T,P,N} <: Plan{T}
struct ScaledPlan{T,P,N} <: Plan{T}
p::P
scale::N # not T, to avoid unnecessary promotion to Complex
pinv::Plan
ScaledPlan{T,P,N}(p, scale) where {T,P,N} = new(p, scale)
end
ScaledPlan{T}(p::P, scale::N) where {T,P,N} = ScaledPlan{T,P,N}(p, scale)
Expand Down Expand Up @@ -278,7 +279,7 @@ plan_ifft(x::AbstractArray, region; kws...) =
plan_ifft!(x::AbstractArray, region; kws...) =
ScaledPlan(plan_bfft!(x, region; kws...), normalization(x, region))

plan_inv(p::ScaledPlan) = ScaledPlan(inv(p.p), inv(p.scale))
inv(p::ScaledPlan) = ScaledPlan(inv(p.p), inv(p.scale))

LinearAlgebra.mul!(y::AbstractArray, p::ScaledPlan, x::AbstractArray) =
LinearAlgebra.lmul!(p.scale, LinearAlgebra.mul!(y, p.p, x))
Expand Down

0 comments on commit 10e12af

Please sign in to comment.