diff --git a/src/interpolation.jl b/src/interpolation.jl index 7d2bf8e8e..e49044063 100644 --- a/src/interpolation.jl +++ b/src/interpolation.jl @@ -67,41 +67,37 @@ to AD-based derivatives. i = 2 # Start the search thinking it's between t[1] and t[2] tdir*tvals[idx[end]] > tdir*t[end] && error("Solution interpolation cannot extrapolate past the final timepoint. Either solve on a longer timespan or use the local extrapolation from the integrator interface.") tdir*tvals[idx[1]] < tdir*t[1] && error("Solution interpolation cannot extrapolate before the first timepoint. Either start solving earlier or use the local extrapolation from the integrator interface.") - if typeof(idxs) <: Number - vals = Vector{eltype(first(u))}(undef, length(tvals)) - elseif typeof(idxs) <: AbstractVector - vals = Vector{Vector{eltype(first(u))}}(undef, length(tvals)) - else - vals = Vector{eltype(u)}(undef, length(tvals)) - end - @inbounds for j in idx - tval = tvals[j] - i = searchsortedfirst(@view(t[i:end]),tval,rev=tdir<0)+i-1 # It's in the interval t[i-1] to t[i] - avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual - avoid_constant_ends && i==1 && (i+=1) - if !avoid_constant_ends && t[i] == tval - lasti = lastindex(t) - k = continuity == :right && i+1 <= lasti && t[i+1] == tval ? i+1 : i - if idxs === nothing - vals[j] = u[k] - else - vals[j] = u[k][idxs] - end - elseif !avoid_constant_ends && t[i-1] == tval # Can happen if it's the first value! - if idxs === nothing - vals[j] = u[i-1] - else - vals[j] = u[i-1][idxs] - end - else - typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) - dt = t[i] - t[i-1] - Θ = (tval-t[i-1])/dt - idxs_internal = idxs - if typeof(id) <: HermiteInterpolation - vals[j] = interpolant(Θ,id,dt,u[i-1],u[i],du[i-1],du[i],idxs_internal,deriv) + + map(idx) do j + @inbounds begin + tval = tvals[j] + i = searchsortedfirst(@view(t[i:end]),tval,rev=tdir<0)+i-1 # It's in the interval t[i-1] to t[i] + avoid_constant_ends = deriv != Val{0} #|| typeof(tval) <: ForwardDiff.Dual + avoid_constant_ends && i==1 && (i+=1) + if !avoid_constant_ends && t[i] == tval + lasti = lastindex(t) + k = continuity == :right && i+1 <= lasti && t[i+1] == tval ? i+1 : i + if idxs === nothing + return u[k] + else + return u[k][idxs] + end + elseif !avoid_constant_ends && t[i-1] == tval # Can happen if it's the first value! + if idxs === nothing + return u[i-1] + else + return u[i-1][idxs] + end else - vals[j] = interpolant(Θ,id,dt,u[i-1],u[i],idxs_internal,deriv) + typeof(id) <: SensitivityInterpolation && error(SENSITIVITY_INTERP_MESSAGE) + dt = t[i] - t[i-1] + Θ = (tval-t[i-1])/dt + idxs_internal = idxs + if typeof(id) <: HermiteInterpolation + return interpolant(Θ,id,dt,u[i-1],u[i],du[i-1],du[i],idxs_internal,deriv) + else + return interpolant(Θ,id,dt,u[i-1],u[i],idxs_internal,deriv) + end end end end