diff --git a/CHANGELOG.md b/CHANGELOG.md index 2ba6c350..07cda8f9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ # Version 0.2 (current devel) +- Refactored basis functions so that they return an xr.Dataset, rather than writing to temporary files. If an output directory is specified, they will save the basis functions as a side effect. + - Added option to run an inversion without boundary conditions. This is specified by adding `use_bc = False` in an .ini file. This assumes that the baseline has already been factored into the observations. - Added tests to test `get_data.py`, including creating, saving, and loading merged data. Refactored inversions tests to reload merged data, instead of creating merged data. diff --git a/openghg_inversions/basis_functions.py b/openghg_inversions/basis_functions.py index e320bc92..43ab55d4 100644 --- a/openghg_inversions/basis_functions.py +++ b/openghg_inversions/basis_functions.py @@ -10,13 +10,14 @@ # the inversion runs. # ***************************************************************************** -import os -import uuid import getpass -import scipy.optimize +import os +from typing import Optional + import numpy as np -import xarray as xr import pandas as pd +import scipy.optimize +import xarray as xr class quadTreeNode: @@ -108,21 +109,22 @@ def quadTreeGrid(grid, limit): outputGrid[leaf.xStart : leaf.xEnd, leaf.yStart : leaf.yEnd] = i boxList.append([leaf.xStart, leaf.xEnd, leaf.yStart, leaf.yEnd]) - return outputGrid, boxList + return outputGrid def quadtreebasisfunction( - emissions_name, - fp_all, - sites, - start_date, - domain, - species, - outputname, - outputdir=None, - nbasis=100, - abs_flux=False, -): + emissions_name: list[str], + fp_all: dict, + sites: list[str], + start_date: str, + domain: str, + species: str, + outputname: Optional[str] = None, + outputdir: Optional[str] = None, + nbasis: int = 100, + abs_flux: bool = False, + seed: Optional[int] = None, +) -> xr.Dataset: """ Creates a basis function with nbasis grid cells using a quadtree algorithm. The domain is split with smaller grid cells for regions which contribute @@ -160,10 +162,12 @@ def quadtreebasisfunction( i.e. nbasis % 4 = 1. abs_flux (bool): If True this will take the absolute value of the flux + seed: + Optional seed to pass to scipy.optimize.dual_annealing. Used for testing. Returns: - If outputdir is None, then returns a Temp directory. The new basis function is saved in this Temp directory. - If outputdir is not None, then does not return anything but saves the basis function in outputdir. + xr.Dataset with lat/lon dimensions and basis regions encoded by integers. + If outputdir is not None, then saves the basis function in outputdir. ----------------------------------- """ if abs_flux: @@ -199,18 +203,20 @@ def quadtreebasisfunction( fps = meanfp * meanflux def qtoptim(x): - basisQuad, boxes = quadTreeGrid(fps, x) + basisQuad = quadTreeGrid(fps, x) return (nbasis - np.max(basisQuad) - 1) ** 2 cost = 1e6 pwr = 0 while cost > 3.0: - optim = scipy.optimize.dual_annealing(qtoptim, np.expand_dims([0, 100 / 10**pwr], axis=0)) + optim = scipy.optimize.dual_annealing( + qtoptim, np.expand_dims([0, 100 / 10**pwr], axis=0), seed=seed + ) cost = np.sqrt(optim.fun) pwr += 1 if pwr > 10: - raise Exception("Quadtree did not converge after max iterations.") - basisQuad, boxes = quadTreeGrid(fps, optim.x[0]) + raise RuntimeError("Quadtree did not converge after max iterations.") + basisQuad = quadTreeGrid(fps, optim.x[0]) lon = fp_all[sites[0]].lon.values lat = fp_all[sites[0]].lat.values @@ -229,20 +235,10 @@ def qtoptim(x): newds.attrs["creator"] = getpass.getuser() newds.attrs["date created"] = str(pd.Timestamp.today()) - if outputdir is None: - cwd = os.getcwd() - tempdir = os.path.join(cwd, f"Temp_{str(uuid.uuid4())}") - os.mkdir(tempdir) - os.mkdir(os.path.join(tempdir, f"{domain}/")) - newds.to_netcdf( - os.path.join( - tempdir, domain, f"quadtree_{species}-{outputname}_{domain}_{start_date.split('-')[0]}.nc" - ), - mode="w", - ) - return tempdir - else: + if outputdir is not None: basisoutpath = os.path.join(outputdir, domain) + if outputname is None: + outputname = "output_name" if not os.path.exists(basisoutpath): os.makedirs(basisoutpath) newds.to_netcdf( @@ -251,7 +247,8 @@ def qtoptim(x): ), mode="w", ) - return outputdir + + return newds # BUCKET BASIS FUNCTIONS @@ -421,7 +418,7 @@ def bucketbasisfunction( outputdir=None, nbasis=100, abs_flux=False, -): +) -> xr.Dataset: """ Basis functions calculated using a weighted region approach where each basis function / scaling region contains approximately @@ -450,6 +447,9 @@ def bucketbasisfunction( abs_flux (bool): When set to True uses absolute values of a flux array + Returns: + xr.Dataset with lat/lon dimensions and basis regions encoded by integers. + If outputdir is not None, then saves the basis function in outputdir. """ if abs_flux: print("Using absolute values of flux array") @@ -502,21 +502,7 @@ def bucketbasisfunction( newds.attrs["creator"] = getpass.getuser() newds.attrs["date created"] = str(pd.Timestamp.today()) - if outputdir is None: - cwd = os.getcwd() - tempdir = os.path.join(cwd, f"Temp_{str(uuid.uuid4())}") - os.mkdir(tempdir) - os.mkdir(os.path.join(tempdir, f"{domain}/")) - newds.to_netcdf( - os.path.join( - tempdir, - domain, - f"weighted_{species}-{outputname}_{domain}_{start_date.split('-')[0]}{start_date.split('-')[1]}.nc", - ), - mode="w", - ) - return tempdir - else: + if outputdir is not None: basisoutpath = os.path.join(outputdir, domain) if not os.path.exists(basisoutpath): os.makedirs(basisoutpath) @@ -527,3 +513,5 @@ def bucketbasisfunction( ), mode="w", ) + + return newds diff --git a/openghg_inversions/hbmcmc/hbmcmc.py b/openghg_inversions/hbmcmc/hbmcmc.py index aba826c4..d9383a48 100644 --- a/openghg_inversions/hbmcmc/hbmcmc.py +++ b/openghg_inversions/hbmcmc/hbmcmc.py @@ -32,32 +32,32 @@ import os import pickle -import shutil +from pathlib import Path +from typing import Optional + import numpy as np -import openghg_inversions.hbmcmc.inversionsetup as setup -import openghg_inversions.hbmcmc.inversion_pymc as mcmc import openghg_inversions.basis_functions as basis -from openghg_inversions import utils -from openghg_inversions import get_data -from pathlib import Path +import openghg_inversions.hbmcmc.inversion_pymc as mcmc +import openghg_inversions.hbmcmc.inversionsetup as setup +from openghg_inversions import get_data, utils def basis_functions_wrapper( - basis_algorithm, - nbasis, - fp_basis_case, - bc_basis_case, - basis_directory, - bc_basis_directory, - use_bc, - fp_all, - species, - sites, - domain, - start_date, - emissions_name, - outputname, - output_path=None, + fp_all: dict, + species: str, + sites: list[str], + domain: str, + start_date: str, + emissions_name: list[str], + nbasis: int, + use_bc: bool, + basis_algorithm: Optional[str] = None, + fp_basis_case: Optional[str] = None, + bc_basis_case: Optional[str] = None, + basis_directory: Optional[str] = None, + bc_basis_directory: Optional[str] = None, + outputname: Optional[str] = None, + output_path: Optional[str] = None, ): """ Wrapper function for selecting basis function @@ -110,69 +110,50 @@ def basis_functions_wrapper( fp_data (dict): Dictionary object similar to fp_all but with information on basis functions and sensitivities - - basis_directory (str): - Path to emissions basis funciton directory - - bc_basis_directory (str): - Path to bc basis functinon directory - """ - if basis_algorithm == "quadtree": - print("Using Quadtree algorithm to derive basis functions") - if fp_basis_case is not None: - print("Basis case %s supplied but quadtree_basis set to True" % fp_basis_case) - print("Assuming you want to use %s " % fp_basis_case) - tempdir = None - else: - tempdir = basis.quadtreebasisfunction( - emissions_name, - fp_all, - sites, - start_date, - domain, - species, - outputname, - nbasis=nbasis, - outputdir=output_path, + if fp_basis_case is not None: + if basis_algorithm: + print( + f"Basis algorithm {basis_algorithm} and basis case {fp_basis_case} supplied; using {fp_basis_case}." ) + basis_func = utils.basis(domain=domain, basis_case=fp_basis_case, basis_directory=basis_directory) - fp_basis_case = "quadtree_" + species + "-" + outputname - basis_directory = tempdir + elif basis_algorithm is None: + raise ValueError("One of `fp_basis_case` or `basis_algorithm` must be specified.") + + elif basis_algorithm == "quadtree": + print("Using Quadtree algorithm to derive basis functions") + basis_func = basis.quadtreebasisfunction( + emissions_name, + fp_all, + sites, + start_date, + domain, + species, + outputname, + nbasis=nbasis, + outputdir=output_path, + ) elif basis_algorithm == "weighted": print("Using weighted by data algorithm to derive basis functions") - if fp_basis_case is not None: - print("Basis case %s supplied but bucket_basis set to True" % fp_basis_case) - print("Assuming you want to use %s " % fp_basis_case) - tempdir = None - else: - tempdir = basis.bucketbasisfunction( - emissions_name, - fp_all, - sites, - start_date, - domain, - species, - outputname, - nbasis=nbasis, - ) - - fp_basis_case = "weighted_" + species + "-" + outputname - basis_directory = tempdir - - elif basis_algorithm is None: - basis_directory = basis_directory - tempdir = None + basis_func = basis.bucketbasisfunction( + emissions_name, + fp_all, + sites, + start_date, + domain, + species, + outputname, + nbasis=nbasis, + ) else: raise ValueError( "Basis algorithm not recognised. Please use either 'quadtree' or 'weighted', or input a basis function file" ) - fp_data = utils.fp_sensitivity( - fp_all, domain=domain, basis_case=fp_basis_case, basis_directory=basis_directory - ) + fp_data = utils.fp_sensitivity(fp_all, basis_func=basis_func) if use_bc is True: fp_data = utils.bc_sensitivity( @@ -182,7 +163,7 @@ def basis_functions_wrapper( bc_basis_directory=bc_basis_directory, ) - return fp_data, tempdir, basis_directory, bc_basis_directory + return fp_data def fixedbasisMCMC( @@ -214,7 +195,6 @@ def fixedbasisMCMC( bc_basis_directory=None, country_file=None, bc_input=None, - max_level=None, basis_algorithm="weighted", nbasis=100, filters=[], @@ -504,21 +484,21 @@ def fixedbasisMCMC( raise ValueError("Model does not currently include tracer model. Watch this space") # Basis function regions and sensitivity matrices - fp_data, tempdir, basis_dir, bc_basis_dir = basis_functions_wrapper( - basis_algorithm, - nbasis, - fp_basis_case, - bc_basis_case, - basis_directory, - bc_basis_directory, - use_bc, - fp_all, - species, - sites, - domain, - start_date, - emissions_name, - outputname, + fp_data = basis_functions_wrapper( + basis_algorithm=basis_algorithm, + nbasis=nbasis, + fp_basis_case=fp_basis_case, + bc_basis_case=bc_basis_case, + basis_directory=basis_directory, + bc_basis_directory=bc_basis_directory, + fp_all=fp_all, + use_bc=use_bc, + species=species, + sites=sites, + domain=domain, + start_date=start_date, + emissions_name=emissions_name, + outputname=outputname, output_path=basis_output_path, ) @@ -613,7 +593,6 @@ def fixedbasisMCMC( "fp_data": fp_data, "emissions_name": emissions_name, "emissions_store": emissions_store, - "basis_directory": basis_dir, "country_file": country_file, } @@ -637,18 +616,6 @@ def fixedbasisMCMC( elif use_tracer: raise ValueError("Model does not currently include tracer model. Watch this space") - if basis_algorithm is not None: - # remove the temporary basis function directory - delete = True - if not os.path.dirname(tempdir).startswith("Temp_"): - delete = False - for _, _, files in os.walk(tempdir): - for file in files: - if not file.startswith("quadtree"): # TODO: update this to look for other basis types - delete = False - if delete: - shutil.rmtree(tempdir) - print("---- Inversion completed ----") return out diff --git a/openghg_inversions/hbmcmc/inversion_pymc.py b/openghg_inversions/hbmcmc/inversion_pymc.py index c2d03c30..581aef46 100644 --- a/openghg_inversions/hbmcmc/inversion_pymc.py +++ b/openghg_inversions/hbmcmc/inversion_pymc.py @@ -330,7 +330,6 @@ def inferpymc_postprocessouts( bcouts: Optional[np.ndarray] = None, Hbc: Optional[np.ndarray] = None, fp_data=None, - basis_directory=None, country_file=None, add_offset=False, rerun_file=None, diff --git a/openghg_inversions/utils.py b/openghg_inversions/utils.py index e84abc56..84744376 100644 --- a/openghg_inversions/utils.py +++ b/openghg_inversions/utils.py @@ -13,7 +13,6 @@ import json from pathlib import Path import os -import sys from types import SimpleNamespace import pandas as pd @@ -714,7 +713,7 @@ def timeseries_HiTRes( """ if verbose: print(f"\nCalculating timeseries with {time_resolution} resolution, this might take a few minutes") - ### get the high time res footprint + # get the high time res footprint if fp_HiTRes_ds is None and fp_file is None: print("Must provide either a footprint Dataset or footprint filename") return None @@ -827,7 +826,7 @@ def timeseries_HiTRes( # put the time array into tqdm if we want a progress bar to show throughout the loop iters = tqdm(time_array) if verbose else time_array - ### iterate through the time coord to get the total mf at each time step using the H back coord + # iterate through the time coord to get the total mf at each time step using the H back coord # at each release time we disaggregate the particles backwards over the previous 24hrs for tt, time in enumerate(iters): # get 4 dimensional chunk of high time res footprint for this timestep @@ -954,7 +953,7 @@ def timeseries_HiTRes( print(f"Saving to {output_file}") timeseries.to_netcdf(output_file) elif output_file is not None: - print(f"output type must be dataset to save to file") + print("output type must be dataset to save to file") if output_fpXflux: return timeseries, fpXflux @@ -962,7 +961,7 @@ def timeseries_HiTRes( return timeseries -def fp_sensitivity(fp_and_data, domain, basis_case, basis_directory=None, verbose=True): +def fp_sensitivity(fp_and_data, basis_func, verbose=True): """ The fp_sensitivity function adds a sensitivity matrix, H, to each site xarray dataframe in fp_and_data. @@ -982,10 +981,6 @@ def fp_sensitivity(fp_and_data, domain, basis_case, basis_directory=None, verbos String if only one basis case is required. Dict if there are multiple sources that require separate basis cases. In which case, keys in dict should reflect keys in emissions_name dict used in fp_data_merge. - basis_directory (str): - basis_directory can be specified if files are not in the default - directory. Must point to a directory which contains subfolders organized - by domain. (optional) Returns: dict (xarray.Dataset): @@ -997,173 +992,148 @@ def fp_sensitivity(fp_and_data, domain, basis_case, basis_directory=None, verbos flux_sources = list(fp_and_data[".flux"].keys()) - if type(basis_case) is not dict: + if not isinstance(basis_func, dict): if len(flux_sources) == 1: - basis_case = {flux_sources[0]: basis_case} + basis_func = {flux_sources[0]: basis_func} else: - basis_case = {"all": basis_case} + basis_func = {"all": basis_func} - if len(list(basis_case.keys())) != len(flux_sources): - if len(list(basis_case.keys())) == 1: - print("Using %s as the basis case for all sources" % basis_case[list(basis_case.keys())[0]]) + if len(list(basis_func.keys())) != len(flux_sources): + if len(list(basis_func.keys())) == 1: + print(f"Using {basis_func[list(basis_func.keys())[0]]} as the basis case for all sources") else: print( - "There should either only be one basis_case, or it should be a dictionary the same length\ + "There should either only be one basis_func, or it should be a dictionary the same length\ as the number of sources." ) return None for site in sites: - for si, source in enumerate(flux_sources): - if source in list(basis_case.keys()): - basis_func = basis( - domain=domain, basis_case=basis_case[source], basis_directory=basis_directory - ) + site_sensitivities = [] + for source in flux_sources: + if source in list(basis_func.keys()): + current_basis_func = basis_func[source] else: - basis_func = basis( - domain=domain, basis_case=basis_case["all"], basis_directory=basis_directory - ) + current_basis_func = basis_func["all"] + + sensitivity, site_bf = fp_sensitivity_single_site_basis_func( + scenario=fp_and_data[site], + flux=fp_and_data[".flux"][source], + source=source, + basis_func=current_basis_func, + verbose=verbose, + ) - if type(fp_and_data[".flux"][source]) == dict: - if "fp_HiTRes" in list(fp_and_data[site].keys()): - site_bf = xr.Dataset( - {"fp_HiTRes": fp_and_data[site]["fp_HiTRes"], "fp": fp_and_data[site]["fp"]} - ) - - fp_time = ( - (fp_and_data[site].time[1] - fp_and_data[site].time[0]) - .values.astype("timedelta64[h]") - .astype(int) - ) - - # calculate the H matrix - H_all = timeseries_HiTRes( - fp_HiTRes_ds=site_bf, - flux_dict=fp_and_data[".flux"][source], - output_TS=False, - output_fpXflux=True, - output_type="DataArray", - time_resolution=f"{fp_time}H", - verbose=verbose, - ) - else: - print( - "fp_and_data needs the variable fp_HiTRes to use the emissions dictionary with high_freq and low_freq emissions." - ) + site_sensitivities.append(sensitivity) - else: - site_bf = combine_datasets( - fp_and_data[site]["fp"].to_dataset(), fp_and_data[".flux"][source].data - ) - H_all = site_bf.fp * site_bf.flux + fp_and_data[site]["H"] = xr.concat(site_sensitivities, dim="region") + fp_and_data[".basis"] = site_bf.basis[:, :, 0] + # TODO: this will only contain the last value in the loop... - H_all_v = H_all.values.reshape((len(site_bf.lat) * len(site_bf.lon), len(site_bf.time))) + return fp_and_data - if "region" in list(basis_func.dims.keys()): - if "time" in basis_func.basis.dims: - basis_func = basis_func.isel(time=0) - site_bf = xr.merge([site_bf, basis_func]) +def fp_sensitivity_single_site_basis_func(scenario, flux, source, basis_func, verbose=True): + """ + The fp_sensitivity function adds a sensitivity matrix, H, to each + site xarray dataframe in fp_and_data. + Basis function data in an array: lat, lon, no. regions. + In each 'region'element of array there is a lat-lon grid with 1 in + region and 0 outside region. - H = np.zeros((len(site_bf.region), len(site_bf.time))) + Region numbering must start from 1 + ----------------------------------- + Args: + scenario: + Output from footprints_data_merge() function; e.g. `fp_all["TAC"]` + flux: + array with flux values + source: + name of flux source + domain (str): + Domain name. The footprint files should be sub-categorised by the domain. + basis_func: + basis functions - base_v = site_bf.basis.values.reshape( - (len(site_bf.lat) * len(site_bf.lon), len(site_bf.region)) - ) + Returns: + sensitivity ("H") xr.DataArray and site_bf xr.Dataset + ----------------------------------- + """ + if isinstance(flux, dict): + if "fp_HiTRes" in list(scenario.keys()): + site_bf = xr.Dataset({"fp_HiTRes": scenario["fp_HiTRes"], "fp": scenario["fp"]}) + + fp_time = (scenario.time[1] - scenario.time[0]).values.astype("timedelta64[h]").astype(int) + + # calculate the H matrix + H_all = timeseries_HiTRes( + fp_HiTRes_ds=site_bf, + flux_dict=flux, + output_TS=False, + output_fpXflux=True, + output_type="DataArray", + time_resolution=f"{fp_time}H", + verbose=verbose, + ) + else: + raise ValueError( + "fp_and_data needs the variable fp_HiTRes to use the emissions dictionary with high_freq and low_freq emissions." + ) - for i in range(len(site_bf.region)): - H[i, :] = np.nansum(H_all_v * base_v[:, i, np.newaxis], axis=0) - - if source == all: - if sys.version_info < (3, 0): - region_name = site_bf.region - else: - region_name = site_bf.region.decode("ascii") - else: - if sys.version_info < (3, 0): - region_name = [source + "-" + reg for reg in site_bf.region.values] - else: - region_name = [source + "-" + reg.decode("ascii") for reg in site_bf.region.values] - - sensitivity = xr.DataArray( - H, coords=[("region", region_name), ("time", fp_and_data[site].coords["time"])] - ) + else: + site_bf = combine_datasets(scenario["fp"].to_dataset(), flux.data) + H_all = site_bf.fp * site_bf.flux - else: - print("Warning: Using basis functions without a region dimension may be deprecated shortly.") + H_all_v = H_all.values.reshape((len(site_bf.lat) * len(site_bf.lon), len(site_bf.time))) - site_bf = combine_datasets(site_bf, basis_func, method="ffill") + if "region" in list(basis_func.dims.keys()): + if "time" in basis_func.basis.dims: + basis_func = basis_func.isel(time=0) - H = np.zeros((int(np.max(site_bf.basis)), len(site_bf.time))) + site_bf = xr.merge([site_bf, basis_func]) - basis_scale = xr.Dataset( - {"basis_scale": (["lat", "lon", "time"], np.zeros(np.shape(site_bf.basis)))}, - coords=site_bf.coords, - ) - site_bf = site_bf.merge(basis_scale) - - base_v = np.ravel(site_bf.basis.values[:, :, 0]) - for i in range(int(np.max(site_bf.basis))): - wh_ri = np.where(base_v == i + 1) - H[i, :] = np.nansum(H_all_v[wh_ri[0], :], axis=0) - - if source == all: - region_name = list(range(1, np.max(site_bf.basis.values) + 1)) - else: - region_name = [ - source + "-" + str(reg) for reg in range(1, int(np.max(site_bf.basis.values) + 1)) - ] - - sensitivity = xr.DataArray( - H.data, coords=[("region", region_name), ("time", fp_and_data[site].coords["time"].data)] - ) + H = np.zeros((len(site_bf.region), len(site_bf.time))) - if si == 0: - concat_sensitivity = sensitivity - else: - concat_sensitivity = xr.concat((concat_sensitivity, sensitivity), dim="region") - - sub_basis_cases = 0 - - if basis_case[source].startswith("sub"): - """ - To genrate sub_lon and sub_lat grids basis case must start with 'sub' - e.g. - 'sub-transd', 'sub_transd', sub-intem' will work - 'transd' or 'transd-sub' won't work - """ - sub_basis_cases += 1 - if sub_basis_cases > 1: - print("Can currently only use a sub basis case for one source. Skipping...") - else: - sub_fp_temp = site_bf.fp.sel(lon=site_bf.sub_lon, lat=site_bf.sub_lat, method="nearest") - sub_fp = xr.Dataset( - {"sub_fp": (["sub_lat", "sub_lon", "time"], sub_fp_temp.data)}, - coords={ - "sub_lat": (site_bf.coords["sub_lat"].data), - "sub_lon": (site_bf.coords["sub_lon"].data), - "time": (fp_and_data[site].coords["time"].data), - }, - ) - - sub_H_temp = H_all.sel(lon=site_bf.sub_lon, lat=site_bf.sub_lat, method="nearest") - sub_H = xr.Dataset( - {"sub_H": (["sub_lat", "sub_lon", "time"], sub_H_temp.data)}, - coords={ - "sub_lat": (site_bf.coords["sub_lat"].data), - "sub_lon": (site_bf.coords["sub_lon"].data), - "time": (fp_and_data[site].coords["time"].data), - }, - attrs={"flux_source_used_to_create_sub_H": source}, - ) - - fp_and_data[site] = fp_and_data[site].merge(sub_fp) - fp_and_data[site] = fp_and_data[site].merge(sub_H) - - fp_and_data[site]["H"] = concat_sensitivity - fp_and_data[".basis"] = site_bf.basis[:, :, 0] + base_v = site_bf.basis.values.reshape((len(site_bf.lat) * len(site_bf.lon), len(site_bf.region))) - return fp_and_data + for i in range(len(site_bf.region)): + H[i, :] = np.nansum(H_all_v * base_v[:, i, np.newaxis], axis=0) + + if source == "all": + region_name = site_bf.region.decode("ascii") + else: + region_name = [source + "-" + reg.decode("ascii") for reg in site_bf.region.values] + + sensitivity = xr.DataArray(H, coords=[("region", region_name), ("time", scenario.coords["time"])]) + + else: + print("Warning: Using basis functions without a region dimension may be deprecated shortly.") + + site_bf = combine_datasets(site_bf, basis_func, method="ffill") + + H = np.zeros((int(np.max(site_bf.basis)), len(site_bf.time))) + + basis_scale = xr.Dataset( + {"basis_scale": (["lat", "lon", "time"], np.zeros(np.shape(site_bf.basis)))}, + coords=site_bf.coords, + ) + site_bf = site_bf.merge(basis_scale) + + base_v = np.ravel(site_bf.basis.values[:, :, 0]) + for i in range(int(np.max(site_bf.basis))): + wh_ri = np.where(base_v == i + 1) + H[i, :] = np.nansum(H_all_v[wh_ri[0], :], axis=0) + + if source == "all": + region_name = list(range(1, np.max(site_bf.basis.values) + 1)) + else: + region_name = [source + "-" + str(reg) for reg in range(1, int(np.max(site_bf.basis.values) + 1))] + + sensitivity = xr.DataArray( + H.data, coords=[("region", region_name), ("time", scenario.coords["time"].data)] + ) + + return sensitivity, site_bf def bc_sensitivity(fp_and_data, domain, basis_case, bc_basis_directory=None): diff --git a/tests/data/basis/EUROPE/quadtree_ch4-test_basis_EUROPE_2019.nc b/tests/data/basis/EUROPE/quadtree_ch4-test_basis_EUROPE_2019.nc new file mode 100644 index 00000000..4fbb7976 Binary files /dev/null and b/tests/data/basis/EUROPE/quadtree_ch4-test_basis_EUROPE_2019.nc differ diff --git a/tests/test_basis_functions.py b/tests/test_basis_functions.py new file mode 100644 index 00000000..8dcb64b3 --- /dev/null +++ b/tests/test_basis_functions.py @@ -0,0 +1,30 @@ +import xarray as xr +from openghg_inversions import utils +from openghg_inversions.basis_functions import quadtreebasisfunction +from openghg_inversions.get_data import data_processing_surface_notracer + + +def test_quadtree_basis_function(tac_ch4_data_args, raw_data_path): + """Check if quadtree basis created with seed 42 and TAC CH4 args matches + a basis created with the same arguments and saved to file. + + This is to check against changes in the code from when this test was made + (13 Feb 2024) + """ + fp_all, *_ = data_processing_surface_notracer(**tac_ch4_data_args) + emissions_name = next(iter(fp_all[".flux"].keys())) + basis_func = quadtreebasisfunction( + emissions_name=[emissions_name], + fp_all=fp_all, + sites=["TAC"], + start_date="2019-01-01", + domain="EUROPE", + species="ch4", + seed=42, + ) + + basis_func_reloaded = utils.basis( + domain="EUROPE", basis_case="quadtree_ch4-test_basis", basis_directory=raw_data_path / "basis" + ) + + xr.testing.assert_allclose(basis_func, basis_func_reloaded) diff --git a/tests/test_get_data.py b/tests/test_get_data.py index 5756ae11..b52f6460 100644 --- a/tests/test_get_data.py +++ b/tests/test_get_data.py @@ -31,7 +31,6 @@ def test_data_processing_surface_notracer(tac_ch4_data_args, raw_data_path): def test_save_load_merged_data(tac_ch4_data_args, merged_data_dir): - merged_data_name = "test_save_load_merged_data" # make merged data dir