Skip to content

Commit

Permalink
Use new interpolation routine in Var
Browse files Browse the repository at this point in the history
This commit removes Interpolations.jl from Var.jl. To do this, the
function `_make_interpolant` was removed. Three new functions are added
which are `_check_interpolant`, `interpolate_point`, and
`interpolate_points`, where the latter two functions replace the
functionality of `_make_interpolant`. Furthermore, the function
`_find_extp_bound_cond` was refactored to `_find_extp_bound_conds` which
find multiple extrapolation condtions using `_find_extp_bound_cond`
which is refactored to find the extrapolation condition for a single
point. All functions that use an interpolant are updated to use the new
interpolation routine. One test worth mentioning is the test for
computing the bias in Atmos, which changes to check approximately close
to 0.0, due to floating point errors.
  • Loading branch information
ph-kev committed Dec 4, 2024
1 parent 770346e commit 7e0017f
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 58 deletions.
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ However, functions like `resampled_as` and interpolating using a `OutputVar` wil
as an interpolant must be generated. This means repeated calls to these functions will be
slower compared to the previous versions of ClimaAnalysis.

## Add interpolation routine
With this release, any functions that rely on interpolation now uses the interpolation
routine written for ClimaAnalysis instead of Interpolations.jl. This substantially reduce
the number and size of allocations when using these functions.

v0.5.12
-------

Expand Down
112 changes: 77 additions & 35 deletions src/Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ struct OutputVar{T <: AbstractArray, A <: AbstractArray, B, C}
end

"""
_make_interpolant(dims, data)
_check_interpolant(dims, data)
Make a linear interpolant from `dims`, a dictionary mapping dimension name to array and
`data`, an array containing data. Used in constructing a `OutputVar`.
Check if it is possible to create an interpolant.
If any element of the arrays in `dims` is a Dates.DateTime, then no interpolant is returned.
Interpolations.jl does not support interpolating on dates. If the longitudes span the entire
Expand All @@ -99,41 +98,86 @@ dimension. If the latitudes span the entire range and are equispaced, then a fla
condition is added for the latitude dimension. In all other cases, an error is thrown when
extrapolating outside of `dim_array`.
"""
function _make_interpolant(dims, data)
function _check_interpolant(dims)
# If any element is DateTime, then return nothing for the interpolant because
# Interpolations.jl do not support DateTimes
for dim_array in values(dims)
eltype(dim_array) <: Dates.DateTime && return nothing
eltype(dim_array) <: Dates.DateTime && return error(
"An interpolant cannot be made because interpolating on dates is not possible",
)
end

# We can only create interpolants when we have 1D dimensions
if isempty(dims) || any(d -> ndims(d) != 1 || length(d) == 1, values(dims))
return nothing
return error(
"An interpolant cannot be made because the dimensions are not 1D",
)
end

# Dimensions are all 1D, check that the knots are in increasing order (as required by
# Interpolations.jl)
for (dim_name, dim_array) in dims
if !issorted(dim_array)
@warn "Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order"
return nothing
return error(
"Dimension $dim_name is not in increasing order. An interpolant will not be created. See Var.reverse_dim if the dimension is in decreasing order",
)
end
end
return nothing
end

# Find boundary conditions for extrapolation
extp_bound_conds = (
_find_extp_bound_cond(dim_name, dim_array) for
(dim_name, dim_array) in dims
)
"""
interpolate_point(point, dims, data)
dims_tuple = tuple(values(dims)...)
extp_bound_conds_tuple = tuple(extp_bound_conds...)
return Intp.extrapolate(
Intp.interpolate(dims_tuple, data, Intp.Gridded(Intp.Linear())),
extp_bound_conds_tuple,
Linearly interpolate the point using `dims` and `data`.
Extrapolation conditions are determined by `_find_extp_bound_conds`.
"""
function interpolate_point(point, dims, data)
_check_interpolant(dims)
extp_bound_conds = _find_extp_bound_conds(dims)
return Numerics.linear_interpolate(
point,
Tuple(values(dims)),
data,
extp_bound_conds,
)
end

"""
interpolate_points(points, dims, data)
Linearly interpolate the points using `dims` and `data`.
Extrapolation conditions are determined by `_find_extp_bound_conds`.
"""
function interpolate_points(points, dims, data)
_check_interpolant(dims)
extp_bound_conds = _find_extp_bound_conds(dims)
dim_arrays_tuple = Tuple(values(dims))
interpolated_arr = [
Numerics.linear_interpolate(
point,
dim_arrays_tuple,
data,
extp_bound_conds,
) for point in points
]
return interpolated_arr
end

"""
_find_extp_bound_conds(dims)
Find the appropriate boundary conditions given the `dims` of an `OutputVar`.
"""
function _find_extp_bound_conds(dims)
return (
_find_extp_bound_cond(dim_name, dim_array) for
(dim_name, dim_array) in dims
) |> Tuple
end

"""
_find_extp_bound_cond(dim_name, dim_array)
Expand All @@ -151,17 +195,17 @@ function _find_extp_bound_cond(dim_name, dim_array)
conventional_dim_name(dim_name) == "longitude" &&
_isequispaced(dim_array) &&
isapprox(dim_size + dsize, 360.0)
) && return Intp.Periodic()
) && return Numerics.extp_cond_periodic()
(
conventional_dim_name(dim_name) == "longitude" &&
(dim_array[end] - dim_array[begin]) 360.0
) && return Intp.Periodic()
) && return Numerics.extp_cond_periodic()
(
conventional_dim_name(dim_name) == "latitude" &&
_isequispaced(dim_array) &&
isapprox(dim_size + dsize, 180.0)
) && return Intp.Flat()
return Intp.Throw()
) && return Numerics.extp_cond_flat()
return Numerics.extp_cond_throw()
end

function OutputVar(attribs, dims, dim_attribs, data)
Expand Down Expand Up @@ -1005,8 +1049,7 @@ julia> var2d = ClimaAnalysis.OutputVar(Dict("time" => time, "z" => z), data); va
```
"""
function (x::OutputVar)(target_coord)
itp = _make_interpolant(x.dims, x.data)
return itp(target_coord...)
return interpolate_point(target_coord, x.dims, x.data)
end

"""
Expand Down Expand Up @@ -1143,9 +1186,8 @@ function resampled_as(src_var::OutputVar, dest_var::OutputVar)
src_var = reordered_as(src_var, dest_var)
_check_dims_consistent(src_var, dest_var)

itp = _make_interpolant(src_var.dims, src_var.data)
src_resampled_data =
[itp(pt...) for pt in Base.product(values(dest_var.dims)...)]
coords = Base.product(values(dest_var.dims)...)
src_resampled_data = interpolate_points(coords, src_var.dims, src_var.data)

# Construct new OutputVar to return
src_var_ret_dims = empty(src_var.dims)
Expand Down Expand Up @@ -1756,14 +1798,14 @@ function make_lonlat_mask(

# Resample so that the mask match up with the grid of var
# Round because linear resampling is done and we want the mask to be only ones and zeros
intp = _make_interpolant(mask_var.dims, mask_var.data)
mask_arr =
[
intp(pt...) for pt in Base.product(
input_var.dims[longitude_name(input_var)],
input_var.dims[latitude_name(input_var)],
)
] .|> round
coords = [
pt for pt in Base.product(
input_var.dims[longitude_name(input_var)],
input_var.dims[latitude_name(input_var)],
)
]
mask_arr = interpolate_points(coords, mask_var.dims, mask_var.data)
mask_arr .= mask_arr .|> round

# Reshape data for broadcasting
lon_idx = input_var.dim2index[longitude_name(input_var)]
Expand Down
2 changes: 1 addition & 1 deletion test/test_Atmos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ end
sim_pressure = pressure3D,
obs_pressure = pressure3D,
)
@test global_rmse_pfull == 0.0
@test isapprox(global_rmse_pfull, 0.0, atol = 1e-11)

# Test if the computation is the same as a manual computation
zero_data = zeros(size(data))
Expand Down
53 changes: 31 additions & 22 deletions test/test_Var.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,35 +92,43 @@ end
lon = 0.5:1.0:359.5 |> collect
lat = -89.5:1.0:89.5 |> collect
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Periodic(), Intp.Flat(), Intp.Throw())
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (
ClimaAnalysis.Numerics.extp_cond_periodic(),
ClimaAnalysis.Numerics.extp_cond_flat(),
ClimaAnalysis.Numerics.extp_cond_throw(),
)

# Not equispaced for lon and lat
lon = 0.5:1.0:359.5 |> collect |> x -> push!(x, 42.0) |> sort
lat = -89.5:1.0:89.5 |> collect |> x -> push!(x, 42.0) |> sort
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (
ClimaAnalysis.Numerics.extp_cond_throw(),
ClimaAnalysis.Numerics.extp_cond_throw(),
ClimaAnalysis.Numerics.extp_cond_throw(),
)

# Does not span entire range for and lat
lon = 0.5:1.0:350.5 |> collect
lat = -89.5:1.0:80.5 |> collect
time = 1.0:100 |> collect
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Throw(), Intp.Throw(), Intp.Throw())
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (
ClimaAnalysis.Numerics.extp_cond_throw(),
ClimaAnalysis.Numerics.extp_cond_throw(),
ClimaAnalysis.Numerics.extp_cond_throw(),
)

# Lon is exactly 360 degrees
lon = 0.0:1.0:360.0 |> collect
data = ones(length(lon))
dims = OrderedDict(["lon" => lon])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test intp.et == (Intp.Periodic(),)
extp_conds = ClimaAnalysis.Var._find_extp_bound_conds(dims)
@test extp_conds == (ClimaAnalysis.Numerics.extp_cond_periodic(),)

# Dates for the time dimension
lon = 0.5:1.0:359.5 |> collect
Expand All @@ -130,17 +138,18 @@ end
Dates.DateTime(2020, 3, 1, 1, 2),
Dates.DateTime(2020, 3, 1, 1, 3),
]
data = ones(length(lon), length(lat), length(time))
dims = OrderedDict(["lon" => lon, "lat" => lat, "time" => time])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test isnothing(intp)
@test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims)

# 2D dimensions
arb_dim = reshape(collect(range(-89.5, 89.5, 16)), (4, 4))
data = collect(1:16)
dims = OrderedDict(["arb_dim" => arb_dim])
intp = ClimaAnalysis.Var._make_interpolant(dims, data)
@test isnothing(intp)
@test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims)

# Dimensions are not in increasing order
lon = [0.5, 42.0, 1.5, 110.0]
dims = OrderedDict(["lon" => lon])
@test_throws ErrorException ClimaAnalysis.Var._check_interpolant(dims)
end

@testset "empty" begin
Expand Down Expand Up @@ -497,6 +506,7 @@ end
@test ClimaAnalysis.pressure_name(pressure_var) == "pfull"
end

# FIX THIS
@testset "Interpolation" begin
# 1D interpolation with linear data, should yield correct results
long = -175.0:175.0 |> collect
Expand All @@ -507,7 +517,7 @@ end
@test longvar.([10.5, 20.5]) == [10.5, 20.5]

# Test error for data outside of range
@test_throws BoundsError longvar(200.0)
@test_throws ErrorException longvar(200.0)

# 2D interpolation with linear data, should yield correct results
time = 100.0:110.0 |> collect
Expand Down Expand Up @@ -812,7 +822,7 @@ end
@test src_var.data == ClimaAnalysis.resampled_as(src_var, src_var).data
resampled_var = ClimaAnalysis.resampled_as(src_var, dest_var)
@test resampled_var.data == reshape(1.0:(181 * 91), (181, 91))[1:91, 1:46]
@test_throws BoundsError ClimaAnalysis.resampled_as(dest_var, src_var)
@test_throws ErrorException ClimaAnalysis.resampled_as(dest_var, src_var)

# BoundsError check
src_long = 90.0:120.0 |> collect
Expand All @@ -837,7 +847,7 @@ end
dest_data,
)

@test_throws BoundsError ClimaAnalysis.resampled_as(src_var, dest_var)
@test_throws ErrorException ClimaAnalysis.resampled_as(src_var, dest_var)
end

@testset "Units" begin
Expand Down Expand Up @@ -1889,7 +1899,6 @@ end
attribs = Dict("long_name" => "hi")
dim_attribs = OrderedDict(["lon" => Dict("units" => "deg")])
var = ClimaAnalysis.OutputVar(attribs, dims, dim_attribs, data)
@test isnothing(ClimaAnalysis.Var._make_interpolant(dims, data))

reverse_var = ClimaAnalysis.reverse_dim(var, "lat")
@test reverse(lat) == reverse_var.dims["lat"]
Expand Down

0 comments on commit 7e0017f

Please sign in to comment.