Skip to content

Commit

Permalink
Merge pull request #82 from openghg/Iss62-refactor-basis-temp-files
Browse files Browse the repository at this point in the history
Iss62 refactor basis temp files
  • Loading branch information
brendan-m-murphy authored Mar 26, 2024
2 parents 245f30d + 6219259 commit 7d38d68
Show file tree
Hide file tree
Showing 8 changed files with 262 additions and 307 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

# Version 0.2 (current devel)

- Refactored basis functions so that they return an xr.Dataset, rather than writing to temporary files. If an output directory is specified, they will save the basis functions as a side effect.

- Added option to run an inversion without boundary conditions. This is specified by adding `use_bc = False` in an .ini file. This assumes that the baseline has already been factored into the observations.

- Added tests to test `get_data.py`, including creating, saving, and loading merged data. Refactored inversions tests to reload merged data, instead of creating merged data.
Expand Down
92 changes: 40 additions & 52 deletions openghg_inversions/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,14 @@
# the inversion runs.
# *****************************************************************************

import os
import uuid
import getpass
import scipy.optimize
import os
from typing import Optional

import numpy as np
import xarray as xr
import pandas as pd
import scipy.optimize
import xarray as xr


class quadTreeNode:
Expand Down Expand Up @@ -108,21 +109,22 @@ def quadTreeGrid(grid, limit):
outputGrid[leaf.xStart : leaf.xEnd, leaf.yStart : leaf.yEnd] = i
boxList.append([leaf.xStart, leaf.xEnd, leaf.yStart, leaf.yEnd])

return outputGrid, boxList
return outputGrid


def quadtreebasisfunction(
emissions_name,
fp_all,
sites,
start_date,
domain,
species,
outputname,
outputdir=None,
nbasis=100,
abs_flux=False,
):
emissions_name: list[str],
fp_all: dict,
sites: list[str],
start_date: str,
domain: str,
species: str,
outputname: Optional[str] = None,
outputdir: Optional[str] = None,
nbasis: int = 100,
abs_flux: bool = False,
seed: Optional[int] = None,
) -> xr.Dataset:
"""
Creates a basis function with nbasis grid cells using a quadtree algorithm.
The domain is split with smaller grid cells for regions which contribute
Expand Down Expand Up @@ -160,10 +162,12 @@ def quadtreebasisfunction(
i.e. nbasis % 4 = 1.
abs_flux (bool):
If True this will take the absolute value of the flux
seed:
Optional seed to pass to scipy.optimize.dual_annealing. Used for testing.
Returns:
If outputdir is None, then returns a Temp directory. The new basis function is saved in this Temp directory.
If outputdir is not None, then does not return anything but saves the basis function in outputdir.
xr.Dataset with lat/lon dimensions and basis regions encoded by integers.
If outputdir is not None, then saves the basis function in outputdir.
-----------------------------------
"""
if abs_flux:
Expand Down Expand Up @@ -199,18 +203,20 @@ def quadtreebasisfunction(
fps = meanfp * meanflux

def qtoptim(x):
basisQuad, boxes = quadTreeGrid(fps, x)
basisQuad = quadTreeGrid(fps, x)
return (nbasis - np.max(basisQuad) - 1) ** 2

cost = 1e6
pwr = 0
while cost > 3.0:
optim = scipy.optimize.dual_annealing(qtoptim, np.expand_dims([0, 100 / 10**pwr], axis=0))
optim = scipy.optimize.dual_annealing(
qtoptim, np.expand_dims([0, 100 / 10**pwr], axis=0), seed=seed
)
cost = np.sqrt(optim.fun)
pwr += 1
if pwr > 10:
raise Exception("Quadtree did not converge after max iterations.")
basisQuad, boxes = quadTreeGrid(fps, optim.x[0])
raise RuntimeError("Quadtree did not converge after max iterations.")
basisQuad = quadTreeGrid(fps, optim.x[0])

lon = fp_all[sites[0]].lon.values
lat = fp_all[sites[0]].lat.values
Expand All @@ -229,20 +235,10 @@ def qtoptim(x):
newds.attrs["creator"] = getpass.getuser()
newds.attrs["date created"] = str(pd.Timestamp.today())

if outputdir is None:
cwd = os.getcwd()
tempdir = os.path.join(cwd, f"Temp_{str(uuid.uuid4())}")
os.mkdir(tempdir)
os.mkdir(os.path.join(tempdir, f"{domain}/"))
newds.to_netcdf(
os.path.join(
tempdir, domain, f"quadtree_{species}-{outputname}_{domain}_{start_date.split('-')[0]}.nc"
),
mode="w",
)
return tempdir
else:
if outputdir is not None:
basisoutpath = os.path.join(outputdir, domain)
if outputname is None:
outputname = "output_name"
if not os.path.exists(basisoutpath):
os.makedirs(basisoutpath)
newds.to_netcdf(
Expand All @@ -251,7 +247,8 @@ def qtoptim(x):
),
mode="w",
)
return outputdir

return newds


# BUCKET BASIS FUNCTIONS
Expand Down Expand Up @@ -421,7 +418,7 @@ def bucketbasisfunction(
outputdir=None,
nbasis=100,
abs_flux=False,
):
) -> xr.Dataset:
"""
Basis functions calculated using a weighted region approach
where each basis function / scaling region contains approximately
Expand Down Expand Up @@ -450,6 +447,9 @@ def bucketbasisfunction(
abs_flux (bool):
When set to True uses absolute values of a flux array
Returns:
xr.Dataset with lat/lon dimensions and basis regions encoded by integers.
If outputdir is not None, then saves the basis function in outputdir.
"""
if abs_flux:
print("Using absolute values of flux array")
Expand Down Expand Up @@ -502,21 +502,7 @@ def bucketbasisfunction(
newds.attrs["creator"] = getpass.getuser()
newds.attrs["date created"] = str(pd.Timestamp.today())

if outputdir is None:
cwd = os.getcwd()
tempdir = os.path.join(cwd, f"Temp_{str(uuid.uuid4())}")
os.mkdir(tempdir)
os.mkdir(os.path.join(tempdir, f"{domain}/"))
newds.to_netcdf(
os.path.join(
tempdir,
domain,
f"weighted_{species}-{outputname}_{domain}_{start_date.split('-')[0]}{start_date.split('-')[1]}.nc",
),
mode="w",
)
return tempdir
else:
if outputdir is not None:
basisoutpath = os.path.join(outputdir, domain)
if not os.path.exists(basisoutpath):
os.makedirs(basisoutpath)
Expand All @@ -527,3 +513,5 @@ def bucketbasisfunction(
),
mode="w",
)

return newds
Loading

0 comments on commit 7d38d68

Please sign in to comment.