Skip to content

Commit

Permalink
feat: add DataFetcher (#23)
Browse files Browse the repository at this point in the history
  • Loading branch information
mrsmrynk authored Apr 26, 2024
1 parent bd553c4 commit 76a4a1c
Show file tree
Hide file tree
Showing 11 changed files with 534 additions and 0 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
geopandas==0.14.3
numpy==1.26.4
pyproj==3.6.1
rasterio==1.3.10
shapely==2.0.4
6 changes: 6 additions & 0 deletions src/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
MaskFilter,
SetFilter,
)
from .data_fetcher import (
DataFetcher,
VRTDataFetcher,
)
from .grid_generator import GridGenerator

__all__ = [
'CompositeFilter',
'CoordinatesFilter',
'DataFetcher',
'DuplicatesFilter',
'GeospatialFilter',
'GridGenerator',
'MaskFilter',
'SetFilter',
'VRTDataFetcher',
]
85 changes: 85 additions & 0 deletions src/data/data_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from pathlib import Path
from typing import Protocol

import numpy.typing as npt

from src.functional.data.data_fetcher import (
vrt_data_fetcher,
)
from src.utils.types import (
BufferSize,
GroundSamplingDistance,
InterpolationMode,
TileSize,
XMin,
YMin,
)


class DataFetcher(Protocol):

def __call__(
self,
x_min: XMin,
y_min: YMin,
) -> npt.NDArray:
"""
| Fetches the data.
:param x_min: minimum x coordinate
:param y_min: minimum y coordinate
:return: data
"""
...


class VRTDataFetcher:
_FILL_VALUE = 0

def __init__(
self,
path: Path,
tile_size: TileSize,
ground_sampling_distance: GroundSamplingDistance,
interpolation_mode: InterpolationMode = InterpolationMode.BILINEAR,
buffer_size: BufferSize = None,
drop_channels: list[int] = None,
) -> None:
"""
:param path: path to the VRT file
:param tile_size: tile size in meters
:param ground_sampling_distance: ground sampling distance in meters
:param interpolation_mode: interpolation mode (InterpolationMode.BILINEAR or InterpolationMode.NEAREST)
:param buffer_size: buffer size in meters
:param drop_channels: channel indices to drop
"""
self.path = path
self.tile_size = tile_size
self.ground_sampling_distance = ground_sampling_distance
self.interpolation_mode = interpolation_mode
self.buffer_size = buffer_size
self.drop_channels = drop_channels

def __call__(
self,
x_min: XMin,
y_min: YMin,
) -> npt.NDArray:
"""
| Fetches the data from the VRT file.
:param x_min: minimum x coordinate
:param y_min: minimum y coordinate
:return: data
"""
return vrt_data_fetcher(
x_min=x_min,
y_min=y_min,
path=self.path,
tile_size=self.tile_size,
ground_sampling_distance=self.ground_sampling_distance,
interpolation_mode=self.interpolation_mode,
buffer_size=self.buffer_size,
drop_channels=self.drop_channels,
fill_value=self._FILL_VALUE,
)
24 changes: 24 additions & 0 deletions src/data/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from pathlib import Path

import geopandas as gpd
import numpy as np
import pytest
Expand All @@ -9,9 +11,13 @@
MaskFilter,
SetFilter,
)
from src.data.data_fetcher import (
VRTDataFetcher,
)
from src.data.grid_generator import GridGenerator
from src.utils.types import (
GeospatialFilterMode,
InterpolationMode,
SetFilterMode,
)

Expand Down Expand Up @@ -78,3 +84,21 @@ def set_filter() -> SetFilter:
additional_coordinates=additional_coordinates,
mode=mode,
)


@pytest.fixture(scope='session')
def vrt_data_fetcher() -> VRTDataFetcher:
path = Path('test/test.vrt')
tile_size = 128
ground_sampling_distance = .2
interpolation_mode = InterpolationMode.BILINEAR
buffer_size = None
drop_channels = None
return VRTDataFetcher(
path=path,
tile_size=tile_size,
ground_sampling_distance=ground_sampling_distance,
interpolation_mode=interpolation_mode,
buffer_size=buffer_size,
drop_channels=drop_channels,
)
61 changes: 61 additions & 0 deletions src/data/tests/test_data_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from pathlib import Path
from unittest.mock import patch

from src.data.data_fetcher import (
VRTDataFetcher,
)
from src.utils.types import (
InterpolationMode,
)


def test_init_vrt_data_fetcher() -> None:
path = Path('test/test.vrt')
tile_size = 128
ground_sampling_distance = .2
interpolation_mode = InterpolationMode.BILINEAR
buffer_size = None
drop_channels = None
vrt_data_fetcher = VRTDataFetcher(
path=path,
tile_size=tile_size,
ground_sampling_distance=ground_sampling_distance,
interpolation_mode=interpolation_mode,
buffer_size=buffer_size,
drop_channels=drop_channels,
)

assert vrt_data_fetcher.path == path
assert vrt_data_fetcher.tile_size == tile_size
assert vrt_data_fetcher.ground_sampling_distance == ground_sampling_distance
assert vrt_data_fetcher.interpolation_mode == interpolation_mode
assert vrt_data_fetcher.buffer_size == buffer_size
assert vrt_data_fetcher.drop_channels == drop_channels


@patch('src.data.data_fetcher.vrt_data_fetcher')
def test_call_vrt_data_fetcher(
mocked_vrt_data_fetcher,
vrt_data_fetcher: VRTDataFetcher,
) -> None:
x_min = -128
y_min = -128
expected = 'expected'
mocked_vrt_data_fetcher.return_value = expected
data = vrt_data_fetcher(
x_min=x_min,
y_min=y_min,
)

mocked_vrt_data_fetcher.assert_called_once_with(
x_min=x_min,
y_min=y_min,
path=vrt_data_fetcher.path,
tile_size=vrt_data_fetcher.tile_size,
ground_sampling_distance=vrt_data_fetcher.ground_sampling_distance,
interpolation_mode=vrt_data_fetcher.interpolation_mode,
buffer_size=vrt_data_fetcher.buffer_size,
drop_channels=vrt_data_fetcher.drop_channels,
fill_value=vrt_data_fetcher._FILL_VALUE,
)
assert data == expected
158 changes: 158 additions & 0 deletions src/functional/data/data_fetcher.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
from pathlib import Path

import numpy as np
import numpy.typing as npt
import rasterio as rio
import rasterio.windows

from src.utils.types import (
BoundingBox,
BufferSize,
GroundSamplingDistance,
InterpolationMode,
TileSize,
XMin,
YMin,
)


def vrt_data_fetcher(
x_min: XMin,
y_min: YMin,
path: Path,
tile_size: TileSize,
ground_sampling_distance: GroundSamplingDistance,
interpolation_mode: InterpolationMode = InterpolationMode.BILINEAR,
buffer_size: BufferSize = None,
drop_channels: list[int] = None,
fill_value: int = 0,
) -> npt.NDArray:
"""
| Fetches the data from the VRT file.
:param x_min: minimum x coordinate
:param y_min: minimum y coordinate
:param path: path to the VRT file
:param tile_size: tile size in meters
:param ground_sampling_distance: ground sampling distance in meters
:param interpolation_mode: interpolation mode (InterpolationMode.BILINEAR or InterpolationMode.NEAREST)
:param buffer_size: buffer size in meters
:param drop_channels: channel indices to drop
:param fill_value: fill value of nodata pixels
:return: data
"""
x_min, y_min, x_max, y_max = _compute_bounding_box(
x_min=x_min,
y_min=y_min,
tile_size=tile_size,
buffer_size=buffer_size,
)
tile_size_pixels = _compute_tile_size_pixels(
tile_size=tile_size,
buffer_size=buffer_size,
ground_sampling_distance=ground_sampling_distance,
)

with rio.open(path) as src:
window = rio.windows.from_bounds(
left=x_min,
bottom=y_min,
right=x_max,
top=y_max,
transform=src.transform,
)
data = src.read(
window=window,
out_shape=(src.count, tile_size_pixels, tile_size_pixels),
resampling=interpolation_mode.to_rio(),
fill_value=fill_value,
)

data = _permute_data(
data=data,
)
data = _drop_channels(
data=data,
drop_channels=drop_channels,
)
return data


def _compute_bounding_box(
x_min: XMin,
y_min: YMin,
tile_size: TileSize,
buffer_size: BufferSize | None,
) -> BoundingBox:
"""
| Computes the bounding box of the tile.
:param x_min: minimum x coordinate
:param y_min: minimum y coordinate
:param tile_size: tile size in meters
:param buffer_size: buffer size in meters
:return: bounding box (x_min, y_min, x_max, y_max) of the tile
"""
if buffer_size is None:
return (
x_min,
y_min,
x_min + tile_size,
y_min + tile_size,
)

return (
x_min - buffer_size,
y_min - buffer_size,
x_min + tile_size + buffer_size,
y_min + tile_size + buffer_size,
)


def _compute_tile_size_pixels(
tile_size: TileSize,
buffer_size: BufferSize | None,
ground_sampling_distance: GroundSamplingDistance,
) -> int:
"""
| Computes the tile size in pixels.
:param tile_size: tile size in meters
:param ground_sampling_distance: ground sampling distance in meters
:return: tile size in pixels
"""
if buffer_size is None:
return int(tile_size / ground_sampling_distance)

return int((tile_size + 2 * buffer_size) / ground_sampling_distance)


def _drop_channels(
data: npt.NDArray,
drop_channels: list[int] | None,
) -> npt.NDArray:
"""
| Drops the specified channels from the data.
:param data: data
:param drop_channels: channel indices to drop
:return: data
"""
if drop_channels is None:
return data

channels = np.arange(data.shape[-1])
keep_channels = np.delete(channels, drop_channels)
return data[..., keep_channels]


def _permute_data(
data: npt.NDArray,
) -> npt.NDArray:
"""
| Permutes the data from channels-first to channels-last.
:param data: data
:return: data
"""
return np.transpose(data, (1, 2, 0))
Loading

0 comments on commit 76a4a1c

Please sign in to comment.