Skip to content

Commit

Permalink
DAS-2232 - replaced get_geo_grid_corners method to get 2 valid geo gr…
Browse files Browse the repository at this point in the history
…id points
  • Loading branch information
sudha-murthy committed Oct 16, 2024
1 parent dd98e81 commit de350b6
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 250 deletions.
297 changes: 183 additions & 114 deletions hoss/coordinate_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,14 @@
import numpy as np
from netCDF4 import Dataset
from numpy import ndarray
from pyproj import CRS
from pyproj import CRS, Transformer
from varinfo import VariableFromDmr, VarInfoFromDmr

from hoss.exceptions import (
IrregularCoordinateDatasets,
MissingCoordinateDataset,
MissingValidCoordinateDataset,
)
from hoss.projection_utilities import (
get_x_y_extents_from_geographic_points,
CannotComputeDimensionResolution,
InvalidCoordinateVariable,
IrregularCoordinateVariables,
MissingCoordinateVariable,
)


Expand Down Expand Up @@ -74,16 +72,16 @@ def get_override_projected_dimensions(


def get_variables_with_anonymous_dims(
varinfo: VarInfoFromDmr, required_variables: set[str]
) -> bool:
varinfo: VarInfoFromDmr, variables: set[str]
) -> set[str]:
"""
returns the list of required variables without any
returns a set of variables without any
dimensions
"""
return set(
required_variable
for required_variable in required_variables
if len(varinfo.get_variable(required_variable).dimensions) == 0
variable
for variable in variables
if len(varinfo.get_variable(variable).dimensions) == 0
)


Expand Down Expand Up @@ -146,36 +144,19 @@ def update_dimension_variables(
lat_arr,
lon_arr,
)

geo_grid_corners = get_geo_grid_corners(
lat_arr,
lon_arr,
lat_fill,
lon_fill,
geo_grid_points = get_two_valid_geo_grid_points(
lat_arr, lon_arr, lat_fill, lon_fill, row_size, col_size
)

x_y_extents = get_x_y_extents_from_geographic_points(geo_grid_corners, crs)
x_y_values = get_x_y_values_from_geographic_points(geo_grid_points, crs)

# get grid size and resolution
x_min = x_y_extents['x_min']
x_max = x_y_extents['x_max']
y_min = x_y_extents['y_min']
y_max = x_y_extents['y_max']
x_resolution = (x_max - x_min) / row_size
y_resolution = (y_max - y_min) / col_size
row_indices, col_indices = zip(*list(x_y_values.keys()))

# create the xy dim scales
lat_asc, lon_asc = is_lat_lon_ascending(lat_arr, lon_arr, lat_fill, lon_fill)
x_values, y_values = zip(*list(x_y_values.values()))

if lon_asc:
x_dim = np.arange(x_min, x_max, x_resolution)
else:
x_dim = np.arange(x_min, x_max, -x_resolution)
y_dim = get_dimension_scale_from_dimvalues(y_values, row_indices, row_size)

if lat_asc:
y_dim = np.arange(y_max, y_min, y_resolution)
else:
y_dim = np.arange(y_max, y_min, -y_resolution)
x_dim = get_dimension_scale_from_dimvalues(x_values, col_indices, col_size)

return {'projected_y': y_dim, 'projected_x': x_dim}

Expand All @@ -190,44 +171,17 @@ def get_row_col_sizes_from_coordinate_datasets(
"""
row_size = 0
col_size = 0
if lat_arr.ndim > 1:
if lat_arr.ndim > 1 and lon_arr.shape == lat_arr.shape:
col_size = lat_arr.shape[0]
row_size = lat_arr.shape[1]
if (lon_arr.shape[0] != lat_arr.shape[0]) or (lon_arr.shape[1] != lat_arr.shape[1]):
raise IrregularCoordinateDatasets(lon_arr.shape, lat_arr.shape)
if lat_arr.ndim and lon_arr.ndim == 1:
elif lat_arr.ndim == 1 and lon_arr.ndim == 1:
col_size = lat_arr.size
row_size = lon_arr.size
elif lon_arr.shape != lat_arr.shape:
raise IrregularCoordinateVariables(lon_arr.ndim, lat_arr.ndim)
return row_size, col_size


def is_lat_lon_ascending(
lat_arr: ndarray,
lon_arr: ndarray,
lat_fill: float,
lon_fill: float,
) -> tuple[bool, bool]:
"""
Checks if the latitude and longitude cooordinate datasets have values
that are ascending
"""

lat_col = lat_arr[:, 0]
lon_row = lon_arr[0, :]

lat_col_valid_indices = get_valid_indices(lat_col, lat_fill, 'latitude')
latitude_ascending = (
lat_col[lat_col_valid_indices[1]] > lat_col[lat_col_valid_indices[0]]
)

lon_row_valid_indices = get_valid_indices(lon_row, lon_fill, 'longitude')
longitude_ascending = (
lon_row[lon_row_valid_indices[1]] > lon_row[lon_row_valid_indices[0]]
)

return latitude_ascending, longitude_ascending


def get_lat_lon_arrays(
prefetch_dataset: Dataset,
latitude_coordinate: VariableFromDmr,
Expand All @@ -242,61 +196,184 @@ def get_lat_lon_arrays(
lon_arr = prefetch_dataset[longitude_coordinate.full_name_path][:]
return lat_arr, lon_arr
except Exception as exception:
raise MissingCoordinateDataset('latitude/longitude') from exception
raise MissingCoordinateVariable('latitude/longitude') from exception


def get_geo_grid_corners(
def get_two_valid_geo_grid_points(
lat_arr: ndarray,
lon_arr: ndarray,
lat_fill: float,
lon_fill: float,
) -> list[Tuple[float, float]]:
row_size: float,
col_size: float,
) -> dict[int, tuple]:
"""
This method is used to return the lat lon corners from a 2D
This method is used to return two valid lat lon points from a 2D
coordinate dataset. It gets the row and column of the latitude and longitude
arrays to get the corner points. This does a check for fill values and
This method does not check if there are fill values in the corner points
to go down to the next row and col. The fill values in the corner points
still needs to be addressed. It will raise an exception in those
cases.
arrays to get two valid points. This does a check for fill values and
This method does not go down to the next row and col. if the selected row and
column all have fills, it will raise an exception in those cases.
"""
first_row_col_index = -1
first_row_row_index = 0
next_col_row_index = -1
next_col_col_index = 1
lat_row_valid_indices = lon_row_valid_indices = np.empty((0, 0))

# get the first row with points that are valid in the lat and lon rows
first_row_row_index, lat_row_valid_indices = get_valid_indices_in_dataset(
lat_arr, row_size, lat_fill, 'latitude', 'row', first_row_row_index
)
first_row_row_index1, lon_row_valid_indices = get_valid_indices_in_dataset(
lon_arr, row_size, lon_fill, 'longitude', 'row', first_row_row_index
)
# get a point that is common on both row datasets
if (
(first_row_row_index == first_row_row_index1)
and (lat_row_valid_indices.size > 0)
and (lon_row_valid_indices.size > 0)
):
first_row_col_index = np.intersect1d(
lat_row_valid_indices, lon_row_valid_indices
)[0]

# get a valid column from the latitude and longitude datasets
next_col_col_index, lon_col_valid_indices = get_valid_indices_in_dataset(
lon_arr, col_size, lon_fill, 'longitude', 'col', next_col_col_index
)
next_col_col_index1, lat_col_valid_indices = get_valid_indices_in_dataset(
lat_arr, col_size, lat_fill, 'latitude', 'col', next_col_col_index
)

# get a point that is common to both column datasets
if (
(next_col_col_index == next_col_col_index1)
and (lat_col_valid_indices.size > 0)
and (lon_col_valid_indices.size > 0)
):
next_col_row_index = np.intersect1d(
lat_col_valid_indices, lon_col_valid_indices
)[-1]

# if the whole row and whole column has no valid indices
# we throw an exception now. This can be extended to move
# to the next row/col
if first_row_col_index == -1:
raise InvalidCoordinateVariable('latitude/longitude')
if next_col_row_index == -1:
raise InvalidCoordinateVariable('latitude/longitude')

geo_grid_indexes = [
(first_row_row_index, first_row_col_index),
(next_col_row_index, next_col_col_index),
]

geo_grid_points = [
(
lon_arr[first_row_row_index][first_row_col_index],
lat_arr[first_row_row_index][first_row_col_index],
),
(
lon_arr[next_col_row_index][next_col_col_index],
lat_arr[next_col_row_index][next_col_col_index],
),
]

return {
geo_grid_indexes[0]: geo_grid_points[0],
geo_grid_indexes[1]: geo_grid_points[1],
}


def get_x_y_values_from_geographic_points(points: Dict, crs: CRS) -> Dict[tuple, tuple]:
"""Take an input list of (longitude, latitude) coordinates and project
those points to the target grid. Then return the x-y dimscales
"""
point_longitudes, point_latitudes = zip(*list(points.values()))

top_left_row_idx = 0
top_left_col_idx = 0
from_geo_transformer = Transformer.from_crs(4326, crs)
points_x, points_y = ( # pylint: disable=unpacking-non-sequence
from_geo_transformer.transform(point_latitudes, point_longitudes)
)

# get the first row from the longitude dataset
lon_row = lon_arr[top_left_row_idx, :]
lon_row_valid_indices = get_valid_indices(lon_row, lon_fill, 'longitude')
x_y_points = {}
for index, point_x, point_y in zip(list(points.keys()), points_x, points_y):
x_y_points.update({index: (point_x, point_y)})

# get the index of the minimum longitude after checking for invalid entries
top_left_col_idx = lon_row_valid_indices[lon_row[lon_row_valid_indices].argmin()]
min_lon = lon_row[top_left_col_idx]
return x_y_points

# get the index of the maximum longitude after checking for invalid entries
top_right_col_idx = lon_row_valid_indices[lon_row[lon_row_valid_indices].argmax()]
max_lon = lon_row[top_right_col_idx]

# get the last valid longitude column to get the latitude array
lat_col = lat_arr[:, top_right_col_idx]
lat_col_valid_indices = get_valid_indices(lat_col, lat_fill, 'latitude')
def get_dimension_scale_from_dimvalues(
dim_values: ndarray, dim_indices: ndarray, dim_size: float
) -> ndarray:
"""
return a full dimension scale based on the 2 projected points and
grid size
"""
dim_resolution = 0.0
if (dim_indices[1] != dim_indices[0]) and (dim_values[1] != dim_values[0]):
dim_resolution = (dim_values[1] - dim_values[0]) / (
dim_indices[1] - dim_indices[0]
)
if dim_resolution == 0.0:
raise CannotComputeDimensionResolution(dim_values[0], dim_indices[0])

# get the index of minimum latitude after checking for valid values
bottom_right_row_idx = lat_col_valid_indices[
lat_col[lat_col_valid_indices].argmin()
]
min_lat = lat_col[bottom_right_row_idx]
# create the dim scale
dim_asc = dim_values[1] > dim_values[0]

# get the index of maximum latitude after checking for valid values
top_right_row_idx = lat_col_valid_indices[lat_col[lat_col_valid_indices].argmax()]
max_lat = lat_col[top_right_row_idx]
if dim_asc:
dim_min = dim_values[0] + (dim_resolution * dim_indices[0])
dim_max = dim_values[0] + (dim_resolution * (dim_size - dim_indices[0] - 1))
dim_data = np.linspace(dim_min, dim_max, dim_size)
else:
dim_max = dim_values[0] + (dim_resolution * dim_indices[0])
dim_min = dim_values[0] - (-dim_resolution * (dim_size - dim_indices[0] - 1))
dim_data = np.linspace(dim_max, dim_min, dim_size)

geo_grid_corners = [
(min_lon, max_lat),
(max_lon, max_lat),
(max_lon, min_lat),
(min_lon, min_lat),
]
return geo_grid_corners
return dim_data


def get_valid_indices_in_dataset(
coordinate_arr: ndarray,
dim_size: int,
coordinate_fill: float,
coordinate_name: str,
span_type: str,
start_index: int,
) -> tuple[int, ndarray]:
"""
This method gets valid indices in a row or column of a
coordinate dataset
"""
coordinate_index = start_index
valid_indices = []
if span_type == 'row':
valid_indices = get_valid_indices(
coordinate_arr[coordinate_index, :], coordinate_fill, coordinate_name
)
else:
valid_indices = get_valid_indices(
coordinate_arr[:, coordinate_index], coordinate_fill, coordinate_name
)
while valid_indices.size == 0:
if coordinate_index < dim_size:
coordinate_index = coordinate_index + 1
if span_type == 'row':
valid_indices = get_valid_indices(
coordinate_arr[coordinate_index, :],
coordinate_fill,
coordinate_name,
)
else:
valid_indices = get_valid_indices(
coordinate_arr[:, coordinate_index],
coordinate_fill,
coordinate_name,
)
else:
raise InvalidCoordinateVariable(coordinate_name)
return coordinate_index, valid_indices


def get_valid_indices(
Expand All @@ -319,11 +396,6 @@ def get_valid_indices(
(coordinate_row_col >= -90.0) & (coordinate_row_col <= 90.0)
)[0]

# if the first row does not have valid indices,
# should go down to the next row. We throw an exception
# for now till that gets addressed
if not valid_indices.size:
raise MissingValidCoordinateDataset(coordinate_name)
return valid_indices


Expand All @@ -341,9 +413,6 @@ def get_fill_values_for_coordinates(
lon_fill = None
lat_fill_value = latitude_coordinate.get_attribute_value('_FillValue')
lon_fill_value = longitude_coordinate.get_attribute_value('_FillValue')
# if fill_value is None:
# check if there are overrides in hoss_config.json using varinfo
# else

if lat_fill_value is not None:
lat_fill = float(lat_fill_value)
Expand Down
Loading

0 comments on commit de350b6

Please sign in to comment.