Skip to content

Commit

Permalink
Use Path objects instead of string manipulation
Browse files Browse the repository at this point in the history
Importantly, reduces number of subprocess/os calls, in favour of using
Python-native operations.
  • Loading branch information
angus-g committed Sep 28, 2023
1 parent 7c338db commit fe594ba
Showing 1 changed file with 30 additions and 50 deletions.
80 changes: 30 additions & 50 deletions regional_mom6/regional_mom6.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import numpy as np
from itertools import cycle
import os
from pathlib import Path
import dask.array as da
import dask.bag as db
import numpy as np
Expand Down Expand Up @@ -475,15 +475,12 @@ def __init__(
toolpath,
gridtype="even_spacing",
):
try:
os.mkdir(mom_run_dir)
except:
pass
self.mom_run_dir = Path(mom_run_dir)
self.mom_input_dir = Path(mom_input_dir)

self.mom_run_dir.mkdir(exist_ok=True)
self.mom_input_dir.mkdir(exist_ok=True)

try:
os.mkdir(mom_input_dir)
except:
pass
self.xextent = xextent
self.yextent = yextent
self.daterange = [
Expand All @@ -494,29 +491,17 @@ def __init__(
self.vlayers = vlayers
self.dz_ratio = dz_ratio
self.depth = depth
self.mom_run_dir = mom_run_dir
self.mom_input_dir = mom_input_dir
self.toolpath = toolpath
self.hgrid = self._make_hgrid(gridtype)
self.vgrid = self._make_vgrid()
self.gridtype = gridtype
# if "temp" not in os.listdir(inputdir):
# os.mkdir(inputdir + "temp")

if "weights" not in os.listdir(self.mom_input_dir):
os.mkdir(mom_input_dir + "weights")
if "forcing" not in os.listdir(self.mom_input_dir):
os.mkdir(self.mom_input_dir + "forcing")

# create a simlink from input directory to run directory and vv
subprocess.run(
f"ln -s {self.mom_input_dir} {self.mom_run_dir}/inputdir", shell=True
)
subprocess.run(
f"ln -s {self.mom_run_dir} {self.mom_input_dir}/rundir", shell=True
)
# create additional directories and links
(self.mom_input_dir / "weights").mkdir(exist_ok=True)
(self.mom_input_dir / "forcing").mkdir(exist_ok=True)

return
(self.mom_run_dir / "inputdir").link_to(self.mom_input_dir.resolve())
(self.mom_input_dir / "rundir").link_to(self.mom_run_dir.resolve())

def _make_hgrid(self, gridtype):
"""Sets up hgrid based on users specification of
Expand Down Expand Up @@ -548,7 +533,7 @@ def _make_hgrid(self, gridtype):

y = np.linspace(self.yextent[0], self.yextent[1], ny)
hgrid = rectangular_hgrid(x, y)
hgrid.to_netcdf(self.mom_input_dir + "/hgrid.nc")
hgrid.to_netcdf(self.mom_input_dir / "hgrid.nc")

return hgrid

Expand All @@ -566,7 +551,7 @@ def _make_vgrid(self):
} ## THIS MIGHT BE WRONG REVISIT
)
vcoord["zi"].attrs = {"units": "meters"}
vcoord.to_netcdf(self.mom_input_dir + "/vcoord.nc")
vcoord.to_netcdf(self.mom_input_dir / "vcoord.nc")

return vcoord

Expand Down Expand Up @@ -806,7 +791,7 @@ def ocean_forcing(
eta_out = eta_out.isel(time=0).drop("time")

vel_out.fillna(0).to_netcdf(
self.mom_input_dir + "forcing/init_vel.nc",
self.mom_input_dir / "forcing/init_vel.nc",
mode="w",
encoding={
"u": {"_FillValue": netCDF4.default_fillvals["f4"]},
Expand All @@ -815,7 +800,7 @@ def ocean_forcing(
)

tracers_out.to_netcdf(
self.mom_input_dir + "forcing/init_tracers.nc",
self.mom_input_dir / "forcing/init_tracers.nc",
mode="w",
encoding={
"xh": {"_FillValue": None},
Expand All @@ -826,7 +811,7 @@ def ocean_forcing(
},
)
eta_out.to_netcdf(
self.mom_input_dir + "forcing/init_eta.nc",
self.mom_input_dir / "forcing/init_eta.nc",
mode="w",
encoding={
"xh": {"_FillValue": None},
Expand All @@ -840,12 +825,14 @@ def ocean_forcing(
self.ic_tracers = tracers_out
self.ic_vels = vel_out

if boundaries == None:
if boundaries is None:
return

print("BRUSHCUT BOUNDARIES")

## Generate a rectangular OBC domain. This is the default configuration. For fancier domains, need to use the segment class manually
## Generate a rectangular OBC domain. This is the default
## configuration. For fancier domains, need to use the segment
## class manually
for i, o in enumerate(boundaries):
print(f"Processing {o}...", end="")
seg = segment(
Expand Down Expand Up @@ -956,7 +943,7 @@ def bathymetry(
bathyout.elevation.attrs["long_name"] = "Elevation relative to sea level"
bathyout.elevation.attrs["coordinates"] = "lon lat"
bathyout.to_netcdf(
f"{self.mom_input_dir}bathy_original.nc", mode="w", engine="netcdf4"
self.mom_input_dir / "bathy_original.nc", mode="w", engine="netcdf4"
)

tgrid = xr.Dataset(
Expand Down Expand Up @@ -1003,7 +990,7 @@ def bathymetry(
tgrid.lon.attrs["_FillValue"] = 1e20
tgrid.lat.attrs["units"] = "degrees_north"
tgrid.to_netcdf(
f"{self.mom_input_dir}topog_raw.nc", mode="w", engine="netcdf4"
self.mom_input_dir / "topog_raw.nc", mode="w", engine="netcdf4"
)
tgrid.close()

Expand All @@ -1020,12 +1007,12 @@ def bathymetry(

topog = regridder(bathyout)
topog.to_netcdf(
f"{self.mom_input_dir}topog_raw.nc", mode="w", engine="netcdf4"
self.mom_input_dir / "topog_raw.nc", mode="w", engine="netcdf4"
)

## reopen topography to modify
print("Reading in regridded bathymetry to fix up metadata...", end="")
topog = xr.open_dataset(self.mom_input_dir + "topog_raw.nc", engine="netcdf4")
topog = xr.open_dataset(self.mom_input_dir / "topog_raw.nc", engine="netcdf4")

## Ensure correct encoding
topog = xr.Dataset(
Expand Down Expand Up @@ -1181,32 +1168,26 @@ def bathymetry(
topog["depth"] = topog["depth"].where(topog["depth"] != 0, np.nan)

topog.expand_dims({"ntiles": 1}).to_netcdf(
self.mom_input_dir + "topog_deseas.nc",
self.mom_input_dir / "topog_deseas.nc",
mode="w",
encoding={"depth": {"_FillValue": None}},
)

subprocess.run(
"mv topog_deseas.nc topog.nc", shell=True, cwd=self.mom_input_dir
)
(self.mom_input_dir / "topog_deseas.nc").rename(self.mom_input_dir / "topog.nc")
print("done.")
self.topog = topog
return

def FRE_tools(self, layout):
"""
Just a wrapper for FRE Tools check_mask, make_solo_mosaic and make_quick_mosaic. User provides processor layout tuple of processing units.
"""

if "topog.nc" not in os.listdir(self.mom_input_dir):
if not (self.mom_input_dir / "topog.nc").exists():
print("No topography file! Need to run make_bathymetry first")
return
try:
os.remove(
"mask_table*"
) ## Removes old mask table so as not to clog up inputdir
except:
pass

for p in self.mom_input_dir.glob("mask_table*"):
p.unlink()

print(
"MAKE SOLO MOSAIC",
Expand Down Expand Up @@ -1240,7 +1221,6 @@ def FRE_tools(self, layout):
),
)
self.layout = layout
return


class segment:
Expand Down

0 comments on commit fe594ba

Please sign in to comment.