diff --git a/nlmod/dims/base.py b/nlmod/dims/base.py index 56e5c1e0..caf14c68 100644 --- a/nlmod/dims/base.py +++ b/nlmod/dims/base.py @@ -77,6 +77,7 @@ def to_model_ds( angrot=0.0, drop_attributes=True, transport=False, + remove_nan_layers=True, ): """Transform an input dataset to a groundwater model dataset. @@ -125,6 +126,9 @@ def to_model_ds( transport : bool, optional flag indicating whether dataset includes data for a groundwater transport model (GWT). Default is False, no transport. + remove_nan_layers : bool, optional + if True remove layers with only nan values in the botm. Default is + True. Returns ------- @@ -174,6 +178,7 @@ def to_model_ds( anisotropy=anisotropy, fill_value_kh=fill_value_kh, fill_value_kv=fill_value_kv, + remove_nan_layers=remove_nan_layers, ) return ds diff --git a/nlmod/dims/layers.py b/nlmod/dims/layers.py index 4df04b70..750b67ed 100644 --- a/nlmod/dims/layers.py +++ b/nlmod/dims/layers.py @@ -223,6 +223,9 @@ def split_layers_ds( layers = list(ds.layer.data) + # Work on a shallow copy of split_dict + split_dict = split_dict.copy() + # do some input-checking on split_dict for lay0 in list(split_dict): if isinstance(lay0, int) & (ds.layer.dtype != int): @@ -230,13 +233,20 @@ def split_layers_ds( # replace lay0 by the name of the layer split_dict[layers[lay0]] = split_dict.pop(lay0) lay0 = layers[lay0] - if isinstance(split_dict[lay0], int): + if isinstance(split_dict[lay0], (int, np.integer)): # If split_dict[lay0] is of integer type # split the layer in evenly thick layers split_dict[lay0] = [1 / split_dict[lay0]] * split_dict[lay0] - else: + elif hasattr(split_dict[lay0], "__iter__"): # make sure the fractions add up to 1 + assert np.isclose(np.sum(split_dict[lay0]), 1), ( + f"Fractions for splitting layer '{lay0}' do not add up to 1." + ) split_dict[lay0] = split_dict[lay0] / np.sum(split_dict[lay0]) + else: + raise ValueError( + "split_dict should contain an iterable of factors or an integer" + ) logger.info(f"Splitting layers {list(split_dict)}") diff --git a/nlmod/read/regis.py b/nlmod/read/regis.py index 05f21227..729d7b44 100644 --- a/nlmod/read/regis.py +++ b/nlmod/read/regis.py @@ -177,6 +177,7 @@ def get_regis( variables = variables + ("sdh", "sdv") ds = ds[list(variables)] + ds.attrs["gridtype"] = "structured" ds.attrs["extent"] = extent for datavar in ds: ds[datavar].attrs["grid_mapping"] = "crs"