Skip to content

Commit

Permalink
issue fatiando#439
Browse files Browse the repository at this point in the history
  • Loading branch information
Pedro Silva committed Mar 19, 2024
1 parent 878f195 commit 1b58b40
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
13 changes: 13 additions & 0 deletions verde/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,22 @@
meshgrid_to_1d,
parse_engine,
partition_by_sum,
fill_nans
)


def test_fill_nans():
    """
    Check that fill_nans removes every NaN and leaves valid cells untouched.
    """
    grid = np.array(
        [
            [1, np.nan, 3],
            [4, 5, np.nan],
            [np.nan, 7, 8],
        ]
    )
    # Snapshot the input before the call: fill_nans assigns into the grid it
    # receives, so comparing against ``grid`` afterwards would prove nothing.
    original = grid.copy()
    valid = ~np.isnan(original)

    filled_grid = fill_nans(grid)

    assert filled_grid.shape == original.shape
    assert np.isnan(filled_grid).sum() == 0
    # Cells that already held data must come back unchanged.
    assert np.array_equal(filled_grid[valid], original[valid])


def test_parse_engine():
"Check that it works for common input"
assert parse_engine("numba") == "numba"
Expand Down
31 changes: 31 additions & 0 deletions verde/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pandas as pd
import xarray as xr
from scipy.spatial import cKDTree
from sklearn.impute import KNNImputer

try:
from pykdtree.kdtree import KDTree as pyKDTree
Expand Down Expand Up @@ -681,6 +682,36 @@ def kdtree(coordinates, use_pykdtree=True, **kwargs):
return tree


def fill_nans(grid, n_neighbors=1):
    """
    Fill the NaN values of a grid using its nearest valid neighbors.

    Each NaN cell is replaced by the mean of the values of the
    ``n_neighbors`` closest non-NaN cells. Closeness is the Euclidean
    distance between cell *indices* (row, column, ...), not between
    coordinates. When several valid cells are equally distant, which of them
    is picked is decided by the KD-tree query.

    Parameters
    ----------
    grid : array-like
        Grid with the data values. It must be convertible with
        :func:`numpy.asarray` and accept assignment through a tuple of
        integer indices (``grid[(i, j)] = value``), as a
        :class:`numpy.ndarray` does. The grid is modified in place.
    n_neighbors : int
        Number of nearest valid cells averaged to fill each NaN. Values
        larger than the number of valid cells are capped to that number.
        Larger values increase the processing time.

    Returns
    -------
    grid : array-like
        The same object that was passed in, with its NaN values filled.

    Raises
    ------
    ValueError
        If ``n_neighbors`` is smaller than 1 or if the grid contains only
        NaN values (there is nothing to fill from).
    """
    if n_neighbors < 1:
        raise ValueError(
            f"Invalid n_neighbors '{n_neighbors}'. Must be 1 or greater."
        )

    values = np.asarray(grid, dtype="float64")
    is_nan = np.isnan(values)
    if not is_nan.any():
        # Nothing to fill; avoid building a tree for no reason.
        return grid
    if is_nan.all():
        raise ValueError("Cannot fill NaNs: the grid has no valid values.")

    # argwhere and boolean indexing both walk the array in C order, so row k
    # of known_indices is the position of known_values[k].
    known_indices = np.argwhere(~is_nan)
    known_values = values[~is_nan]
    unknown_indices = np.argwhere(is_nan)

    n_used = min(int(n_neighbors), known_indices.shape[0])
    _, neighbors = cKDTree(known_indices).query(unknown_indices, k=n_used)
    # query returns a 1D array when k == 1; normalize to (n_unknown, k).
    neighbors = np.reshape(neighbors, (unknown_indices.shape[0], -1))
    # All fill values are computed before any write, so filled cells never
    # feed back into the estimate of other cells.
    fill_values = known_values[neighbors].mean(axis=1)

    for index, value in zip(unknown_indices, fill_values):
        grid[tuple(index)] = value

    return grid

def partition_by_sum(array, parts):
"""
Partition an array into parts of approximately equal sum.
Expand Down

0 comments on commit 1b58b40

Please sign in to comment.