Skip to content

Commit

Permalink
minor: adapt to dichotomy implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
adebardo committed Feb 16, 2024
1 parent 68d827a commit cb187bd
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 32 deletions.
5 changes: 4 additions & 1 deletion docs/source/userguide/step_by_step/refinement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,10 @@ Configuration and parameters
| *if above, will be bound to 9*
| **Optical flow**
| >0
- Yes
- | **Dichotomy**
| Yes
| **Optical flow**
| No
* - *filter*
- Name of the filter to use
- str
Expand Down
6 changes: 3 additions & 3 deletions notebooks/introduction_and_basic_usage.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -347,7 +347,7 @@
" },\n",
" \"refinement\" : {\n",
" \"refinement_method\" : \"optical_flow\",\n",
" \"nbr_iteration\": 4\n",
" \"iterations\": 4\n",
" }\n",
" }\n",
"}"
Expand Down Expand Up @@ -547,7 +547,7 @@
" },\n",
" \"refinement\":{\n",
" \"refinement_method\" : \"optical_flow\",\n",
" \"nbr_iteration\": 4\n",
" \"iterations\": 4\n",
" }\n",
" }\n",
"}"
Expand Down Expand Up @@ -775,7 +775,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
15 changes: 11 additions & 4 deletions pandora2d/refinement/dichotomy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) 2024 Centre National d'Etudes Spatiales (CNES).
# Copyright (c) 2024 CS GROUP France
#
# This file is part of PANDORA2D
#
Expand Down Expand Up @@ -66,14 +67,20 @@ def check_conf(cls, cfg: Dict) -> Dict:
def margins(self):
return Margins(2, 2, 2, 2)

def refinement_method(self, cost_volumes: xr.Dataset, pixel_maps: xr.Dataset) -> None:
def refinement_method(
self, cost_volumes: xr.Dataset, disp_map: xr.Dataset, img_left: xr.Dataset, img_right: xr.Dataset
) -> None:
"""
Return the subpixel disparity maps
:param cost_volumes: cost_volumes 4D row, col, disp_col, disp_row
:type cost_volumes: xarray.dataset
:param pixel_maps: pixels disparity maps
:type pixel_maps: xarray.dataset
:type cost_volumes: xarray.Dataset
:param disp_map: pixel disparity maps
:type disp_map: xarray.Dataset
param img_left: left image dataset
:type img_left: xarray.Dataset
:param img_right: right image dataset
:type img_right: xarray.Dataset
:return: the refined disparity maps
:rtype: Tuple[np.ndarray, np.ndarray]
"""
Expand Down
42 changes: 21 additions & 21 deletions pandora2d/refinement/optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

import numpy as np
import xarray as xr
from json_checker import And, Checker
from json_checker import And
from scipy.ndimage import map_coordinates
from pandora.margins import Margins

Expand All @@ -38,30 +38,32 @@ class OpticalFlow(refinement.AbstractRefinement):
OpticalFLow class allows to perform the subpixel cost refinement step
"""

_nbr_iteration = None
_iterations = None
_invalid_disp = None

_NBR_ITERATION = 4
_ITERATIONS = 4

schema = {"refinement_method": And(str, lambda x: x in ["optical_flow"]), "iterations": And(int, lambda it: it > 0)}

def __init__(self, cfg: dict = None, step: list = None, window_size: int = 5) -> None:
"""
:param cfg: optional configuration, {}
:type cfg: dict
:param step: list containing row and col step
:type step: list
:param window_size: window size
:type window_size: int
:return: None
"""
super().__init__(cfg)

self.cfg = self.check_conf(cfg)
self._nbr_iteration = self.cfg["nbr_iteration"]
self._iterations = self.cfg["iterations"]
self._refinement_method = self.cfg["refinement_method"]
self._window_size = window_size
self._step = [1, 1] if step is None else step

@property
def margins(self):
values = (self._window_size // 2 * ele for _ in range(2) for ele in self._step)
return Margins(*values)

def check_conf(self, cfg: Dict) -> Dict:
@classmethod
def check_conf(cls, cfg: Dict) -> Dict:
"""
Check the refinement configuration
Expand All @@ -71,19 +73,17 @@ def check_conf(self, cfg: Dict) -> Dict:
:rtype: cfg: dict
"""

if "nbr_iteration" not in cfg:
cfg["nbr_iteration"] = self._NBR_ITERATION
cfg["iterations"] = cfg.get("iterations", cls._ITERATIONS)

schema = {
"refinement_method": And(str, lambda x: x in ["optical_flow"]),
"nbr_iteration": And(int, lambda nbr_i: nbr_i > 0),
}

checker = Checker(schema)
checker.validate(cfg)
cfg = super().check_conf(cfg)

return cfg

@property
def margins(self):
values = (self._window_size // 2 * ele for _ in range(2) for ele in self._step)
return Margins(*values)

def reshape_to_matching_cost_window(
self,
img: xr.Dataset,
Expand Down Expand Up @@ -303,7 +303,7 @@ def refinement_method(

idx_to_compute = np.arange(reshaped_left.shape[2]).tolist()

for _ in range(self._nbr_iteration):
for _ in range(self._iterations):

computed_drow, computed_dcol, idx_to_compute = self.optical_flow(
reshaped_left, reshaped_right, idx_to_compute
Expand Down
6 changes: 5 additions & 1 deletion pandora2d/refinement/refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,14 @@ def decorator(subclass):

return decorator

def __init__(self, cfg: Dict) -> None:
def __init__(self, cfg: Dict, _: list = None, __: int = 5) -> None:
"""
:param cfg: optional configuration, {}
:type cfg: dict
:param step: list containing row and col step
:type step: list
:param window_size: window size
:type window_size: int
:return: None
"""
self.cfg = self.check_conf(cfg)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_dichotomy.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_refinement_method(config, caplog, mocker: MockerFixture):
dichotomy_instance = refinement.dichotomy.Dichotomy(config)

# We can pass anything as it is not yet implemented
dichotomy_instance.refinement_method(mocker.ANY, mocker.ANY)
dichotomy_instance.refinement_method(mocker.ANY, mocker.ANY, mocker.ANY, mocker.ANY)

assert "refinement_method of Dichotomy not yet implemented" in caplog.messages

Expand Down
2 changes: 1 addition & 1 deletion tests/test_refinement.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def test_check_conf_passes(refinement_method):


@pytest.mark.parametrize(
"refinement_config", [{"refinement_method": "wta"}, {"refinement_method": "optical_flow", "nbr_iteration": 0}]
"refinement_config", [{"refinement_method": "wta"}, {"refinement_method": "optical_flow", "iterations": 0}]
)
def test_check_conf_fails(refinement_config):
"""
Expand Down

0 comments on commit cb187bd

Please sign in to comment.