diff --git a/demos/reanalysis-forced.ipynb b/demos/reanalysis-forced.ipynb index 82607591..72f56527 100644 --- a/demos/reanalysis-forced.ipynb +++ b/demos/reanalysis-forced.ipynb @@ -288,16 +288,14 @@ " glorys_path / \"ic_unprocessed.nc\", # directory where the unprocessed initial condition is stored, as defined earlier\n", " ocean_varnames,\n", " arakawa_grid=\"A\"\n", - " )\n", + " ) \n", "\n", - "# Now iterate through our four boundaries \n", - "for i, orientation in enumerate([\"south\", \"north\", \"west\", \"east\"]):\n", - " expt.rectangular_boundary(\n", - " glorys_path / (orientation + \"_unprocessed.nc\"),\n", + "# Set up the four boundary conditions. Remember that in the glorys_path, we have four boundary files names north_unprocessed.nc etc. \n", + "expt.rectangular_boundaries(\n", + " glorys_path,\n", " ocean_varnames,\n", - " orientation, # Needs to know the cardinal direction of the boundary\n", - " i + 1, # Just a number to identify the boundary. Indexes from 1 \n", - " arakawa_grid=\"A\"\n", + " boundaries = [\"south\", \"north\", \"west\", \"east\"],\n", + " arakawa_grid = \"A\"\n", " )" ] }, diff --git a/regional_mom6/regional_mom6.py b/regional_mom6/regional_mom6.py index d2709f1a..26fa71be 100644 --- a/regional_mom6/regional_mom6.py +++ b/regional_mom6/regional_mom6.py @@ -568,14 +568,18 @@ def _make_vgrid(self): return vcoord def initial_condition( - self, ic_path, varnames, arakawa_grid="A", vcoord_type="height" + self, + raw_ic_path, + varnames, + arakawa_grid="A", + vcoord_type="height", ): """ Reads the initial condition from files in ``ic_path``, interpolates to the model grid, fixes up metadata, and saves back to the input directory. Args: - ic_path (Union[str, Path]): Path to initial condition file. + raw_ic_path (Union[str, Path]): Path to raw initial condition file to read in. varnames (Dict[str, str]): Mapping from MOM6 variable/coordinate names to the names in the input dataset. For example, ``{'xq': 'lonq', 'yh': 'lath', 'salt': 'so', ...}``. arakawa_grid (Optional[str]): Arakawa grid staggering type of the initial condition. @@ -587,7 +591,7 @@ def initial_condition( # Remove time dimension if present in the IC. # Assume that the first time dim is the intended on if more than one is present - ic_raw = xr.open_dataset(ic_path) + ic_raw = xr.open_dataset(raw_ic_path) if varnames["time"] in ic_raw.dims: ic_raw = ic_raw.isel({varnames["time"]: 0}) if varnames["time"] in ic_raw.coords: @@ -771,7 +775,7 @@ def initial_condition( .bfill("lat") ) - ## Make our three horizontal regrideers + ## Make our three horizontal regridders regridder_u = xe.Regridder( ic_raw_u, ugrid, @@ -886,19 +890,67 @@ def initial_condition( "eta_t": {"_FillValue": None}, }, ) - print("Done.\nFinished setting up initial condition.") self.ic_eta = eta_out self.ic_tracers = tracers_out self.ic_vels = vel_out + + print("done setting up initial condition.") + return - def rectangular_boundary( + def rectangular_boundaries( + self, + raw_boundaries_path, + varnames, + boundaries=["south", "north", "west", "east"], + arakawa_grid="A", + ): + """ + This function is a wrapper for `simple_boundary`. Given a list of up to four cardinal directions, + it creates a boundary forcing file for each one. Ensure that the raw boundaries are all saved in the same directory, + and that they are named using the format `east_unprocessed.nc` + + Args: + raw_boundaries_path (str): Path to the directory containing the raw boundary forcing files. + varnames (Dict[str, str]): Mapping from MOM6 variable/coordinate names to the name in the + input dataset. + boundaries (List[str]): List of cardinal directions for which to create boundary forcing files. + Default is `["south", "north", "west", "east"]`. + arakawa_grid (Optional[str]): Arakawa grid staggering type of the boundary forcing. + Either ``'A'`` (default), ``'B'``, or ``'C'``. + """ + for i in boundaries: + if i not in ["south", "north", "west", "east"]: + raise ValueError( + f"Invalid boundary direction: {i}. Must be one of ['south', 'north', 'west', 'east']" + ) + + if len(boundaries) < 4: + print( + "NOTE: the 'setup_run_directories' method assumes that you have four boundaries. You'll need to modify the MOM_input file manually to reflect the number of boundaries you have, and their orientations. You should be able to find the relevant section in the MOM_input file by searching for 'segment_'. Ensure that the segment names match those in your inputdir/forcing folder" + ) + + if len(boundaries) > 4: + raise ValueError( + "This method only supports up to four boundaries. To set up more complex boundary shapes you can manually call the 'simple_boundary' method for each boundary." + ) + # Now iterate through our four boundaries + for i, orientation in enumerate(boundaries, start=1): + self.simple_boundary( + Path(raw_boundaries_path) / (orientation + "_unprocessed.nc"), + varnames, + orientation, # The cardinal direction of the boundary + i, # A number to identify the boundary; indexes from 1 + arakawa_grid=arakawa_grid, + ) + + def simple_boundary( self, path_to_bc, varnames, orientation, segment_number, arakawa_grid="A" ): """ - Set up a boundary forcing file for a given orientation. Here the term 'rectangular' - means boundaries along lines of constant latitude or longitude. + Here 'simple' refers to boundaries that are parallel to lines of constant longitude or latitude. + Set up a boundary forcing file for a given orientation. Args: path_to_bc (str): Path to boundary forcing file. Ideally this should be a pre cut-out @@ -916,7 +968,10 @@ def rectangular_boundary( """ print("Processing {} boundary...".format(orientation), end="") - + if not path_to_bc.exists(): + raise FileNotFoundError( + f"Boundary file not found at {path_to_bc}. Please ensure that the files are named in the format `east_unprocessed.nc`." + ) seg = segment( hgrid=self.hgrid, infile=path_to_bc, # location of raw boundary diff --git a/tests/test_expt_class.py b/tests/test_expt_class.py index c484a355..d0713c2f 100644 --- a/tests/test_expt_class.py +++ b/tests/test_expt_class.py @@ -355,7 +355,7 @@ def test_ocean_forcing( ), ], ) -def test_rectangular_boundary( +def test_rectangular_boundaries( longitude_extent, latitude_extent, date_range, @@ -443,7 +443,7 @@ def test_rectangular_boundary( ), } ) - eastern_boundary.to_netcdf(tmp_path / "east_unprocessed") + eastern_boundary.to_netcdf(tmp_path / "east_unprocessed.nc") eastern_boundary.close() expt = experiment( @@ -471,4 +471,4 @@ def test_rectangular_boundary( "tracers": {"temp": "temp", "salt": "salt"}, } - expt.rectangular_boundary(tmp_path / "east_unprocessed", varnames, "east", 1) + expt.rectangular_boundaries(tmp_path, varnames, ["east"])