-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
534 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
geopandas==0.14.3 | ||
numpy==1.26.4 | ||
pyproj==3.6.1 | ||
rasterio==1.3.10 | ||
shapely==2.0.4 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from pathlib import Path | ||
from typing import Protocol | ||
|
||
import numpy.typing as npt | ||
|
||
from src.functional.data.data_fetcher import ( | ||
vrt_data_fetcher, | ||
) | ||
from src.utils.types import ( | ||
BufferSize, | ||
GroundSamplingDistance, | ||
InterpolationMode, | ||
TileSize, | ||
XMin, | ||
YMin, | ||
) | ||
|
||
|
||
class DataFetcher(Protocol): | ||
|
||
def __call__( | ||
self, | ||
x_min: XMin, | ||
y_min: YMin, | ||
) -> npt.NDArray: | ||
""" | ||
| Fetches the data. | ||
:param x_min: minimum x coordinate | ||
:param y_min: minimum y coordinate | ||
:return: data | ||
""" | ||
... | ||
|
||
|
||
class VRTDataFetcher: | ||
_FILL_VALUE = 0 | ||
|
||
def __init__( | ||
self, | ||
path: Path, | ||
tile_size: TileSize, | ||
ground_sampling_distance: GroundSamplingDistance, | ||
interpolation_mode: InterpolationMode = InterpolationMode.BILINEAR, | ||
buffer_size: BufferSize = None, | ||
drop_channels: list[int] = None, | ||
) -> None: | ||
""" | ||
:param path: path to the VRT file | ||
:param tile_size: tile size in meters | ||
:param ground_sampling_distance: ground sampling distance in meters | ||
:param interpolation_mode: interpolation mode (InterpolationMode.BILINEAR or InterpolationMode.NEAREST) | ||
:param buffer_size: buffer size in meters | ||
:param drop_channels: channel indices to drop | ||
""" | ||
self.path = path | ||
self.tile_size = tile_size | ||
self.ground_sampling_distance = ground_sampling_distance | ||
self.interpolation_mode = interpolation_mode | ||
self.buffer_size = buffer_size | ||
self.drop_channels = drop_channels | ||
|
||
def __call__( | ||
self, | ||
x_min: XMin, | ||
y_min: YMin, | ||
) -> npt.NDArray: | ||
""" | ||
| Fetches the data from the VRT file. | ||
:param x_min: minimum x coordinate | ||
:param y_min: minimum y coordinate | ||
:return: data | ||
""" | ||
return vrt_data_fetcher( | ||
x_min=x_min, | ||
y_min=y_min, | ||
path=self.path, | ||
tile_size=self.tile_size, | ||
ground_sampling_distance=self.ground_sampling_distance, | ||
interpolation_mode=self.interpolation_mode, | ||
buffer_size=self.buffer_size, | ||
drop_channels=self.drop_channels, | ||
fill_value=self._FILL_VALUE, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from pathlib import Path | ||
from unittest.mock import patch | ||
|
||
from src.data.data_fetcher import ( | ||
VRTDataFetcher, | ||
) | ||
from src.utils.types import ( | ||
InterpolationMode, | ||
) | ||
|
||
|
||
def test_init_vrt_data_fetcher() -> None: | ||
path = Path('test/test.vrt') | ||
tile_size = 128 | ||
ground_sampling_distance = .2 | ||
interpolation_mode = InterpolationMode.BILINEAR | ||
buffer_size = None | ||
drop_channels = None | ||
vrt_data_fetcher = VRTDataFetcher( | ||
path=path, | ||
tile_size=tile_size, | ||
ground_sampling_distance=ground_sampling_distance, | ||
interpolation_mode=interpolation_mode, | ||
buffer_size=buffer_size, | ||
drop_channels=drop_channels, | ||
) | ||
|
||
assert vrt_data_fetcher.path == path | ||
assert vrt_data_fetcher.tile_size == tile_size | ||
assert vrt_data_fetcher.ground_sampling_distance == ground_sampling_distance | ||
assert vrt_data_fetcher.interpolation_mode == interpolation_mode | ||
assert vrt_data_fetcher.buffer_size == buffer_size | ||
assert vrt_data_fetcher.drop_channels == drop_channels | ||
|
||
|
||
@patch('src.data.data_fetcher.vrt_data_fetcher') | ||
def test_call_vrt_data_fetcher( | ||
mocked_vrt_data_fetcher, | ||
vrt_data_fetcher: VRTDataFetcher, | ||
) -> None: | ||
x_min = -128 | ||
y_min = -128 | ||
expected = 'expected' | ||
mocked_vrt_data_fetcher.return_value = expected | ||
data = vrt_data_fetcher( | ||
x_min=x_min, | ||
y_min=y_min, | ||
) | ||
|
||
mocked_vrt_data_fetcher.assert_called_once_with( | ||
x_min=x_min, | ||
y_min=y_min, | ||
path=vrt_data_fetcher.path, | ||
tile_size=vrt_data_fetcher.tile_size, | ||
ground_sampling_distance=vrt_data_fetcher.ground_sampling_distance, | ||
interpolation_mode=vrt_data_fetcher.interpolation_mode, | ||
buffer_size=vrt_data_fetcher.buffer_size, | ||
drop_channels=vrt_data_fetcher.drop_channels, | ||
fill_value=vrt_data_fetcher._FILL_VALUE, | ||
) | ||
assert data == expected |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,158 @@ | ||
from pathlib import Path | ||
|
||
import numpy as np | ||
import numpy.typing as npt | ||
import rasterio as rio | ||
import rasterio.windows | ||
|
||
from src.utils.types import ( | ||
BoundingBox, | ||
BufferSize, | ||
GroundSamplingDistance, | ||
InterpolationMode, | ||
TileSize, | ||
XMin, | ||
YMin, | ||
) | ||
|
||
|
||
def vrt_data_fetcher( | ||
x_min: XMin, | ||
y_min: YMin, | ||
path: Path, | ||
tile_size: TileSize, | ||
ground_sampling_distance: GroundSamplingDistance, | ||
interpolation_mode: InterpolationMode = InterpolationMode.BILINEAR, | ||
buffer_size: BufferSize = None, | ||
drop_channels: list[int] = None, | ||
fill_value: int = 0, | ||
) -> npt.NDArray: | ||
""" | ||
| Fetches the data from the VRT file. | ||
:param x_min: minimum x coordinate | ||
:param y_min: minimum y coordinate | ||
:param path: path to the VRT file | ||
:param tile_size: tile size in meters | ||
:param ground_sampling_distance: ground sampling distance in meters | ||
:param interpolation_mode: interpolation mode (InterpolationMode.BILINEAR or InterpolationMode.NEAREST) | ||
:param buffer_size: buffer size in meters | ||
:param drop_channels: channel indices to drop | ||
:param fill_value: fill value of nodata pixels | ||
:return: data | ||
""" | ||
x_min, y_min, x_max, y_max = _compute_bounding_box( | ||
x_min=x_min, | ||
y_min=y_min, | ||
tile_size=tile_size, | ||
buffer_size=buffer_size, | ||
) | ||
tile_size_pixels = _compute_tile_size_pixels( | ||
tile_size=tile_size, | ||
buffer_size=buffer_size, | ||
ground_sampling_distance=ground_sampling_distance, | ||
) | ||
|
||
with rio.open(path) as src: | ||
window = rio.windows.from_bounds( | ||
left=x_min, | ||
bottom=y_min, | ||
right=x_max, | ||
top=y_max, | ||
transform=src.transform, | ||
) | ||
data = src.read( | ||
window=window, | ||
out_shape=(src.count, tile_size_pixels, tile_size_pixels), | ||
resampling=interpolation_mode.to_rio(), | ||
fill_value=fill_value, | ||
) | ||
|
||
data = _permute_data( | ||
data=data, | ||
) | ||
data = _drop_channels( | ||
data=data, | ||
drop_channels=drop_channels, | ||
) | ||
return data | ||
|
||
|
||
def _compute_bounding_box( | ||
x_min: XMin, | ||
y_min: YMin, | ||
tile_size: TileSize, | ||
buffer_size: BufferSize | None, | ||
) -> BoundingBox: | ||
""" | ||
| Computes the bounding box of the tile. | ||
:param x_min: minimum x coordinate | ||
:param y_min: minimum y coordinate | ||
:param tile_size: tile size in meters | ||
:param buffer_size: buffer size in meters | ||
:return: bounding box (x_min, y_min, x_max, y_max) of the tile | ||
""" | ||
if buffer_size is None: | ||
return ( | ||
x_min, | ||
y_min, | ||
x_min + tile_size, | ||
y_min + tile_size, | ||
) | ||
|
||
return ( | ||
x_min - buffer_size, | ||
y_min - buffer_size, | ||
x_min + tile_size + buffer_size, | ||
y_min + tile_size + buffer_size, | ||
) | ||
|
||
|
||
def _compute_tile_size_pixels( | ||
tile_size: TileSize, | ||
buffer_size: BufferSize | None, | ||
ground_sampling_distance: GroundSamplingDistance, | ||
) -> int: | ||
""" | ||
| Computes the tile size in pixels. | ||
:param tile_size: tile size in meters | ||
:param ground_sampling_distance: ground sampling distance in meters | ||
:return: tile size in pixels | ||
""" | ||
if buffer_size is None: | ||
return int(tile_size / ground_sampling_distance) | ||
|
||
return int((tile_size + 2 * buffer_size) / ground_sampling_distance) | ||
|
||
|
||
def _drop_channels( | ||
data: npt.NDArray, | ||
drop_channels: list[int] | None, | ||
) -> npt.NDArray: | ||
""" | ||
| Drops the specified channels from the data. | ||
:param data: data | ||
:param drop_channels: channel indices to drop | ||
:return: data | ||
""" | ||
if drop_channels is None: | ||
return data | ||
|
||
channels = np.arange(data.shape[-1]) | ||
keep_channels = np.delete(channels, drop_channels) | ||
return data[..., keep_channels] | ||
|
||
|
||
def _permute_data( | ||
data: npt.NDArray, | ||
) -> npt.NDArray: | ||
""" | ||
| Permutes the data from channels-first to channels-last. | ||
:param data: data | ||
:return: data | ||
""" | ||
return np.transpose(data, (1, 2, 0)) |
Oops, something went wrong.