Skip to content

Commit

Permalink
Merge branch 'v0.4-preview' of github.com:transientlunatic/heron into…
Browse files Browse the repository at this point in the history
… v0.4-preview
  • Loading branch information
transientlunatic committed Oct 1, 2024
2 parents f7df371 + e17b9a2 commit 62ce6c1
Show file tree
Hide file tree
Showing 9 changed files with 38 additions and 21 deletions.
9 changes: 6 additions & 3 deletions heron/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def heron_inference(settings):
if "data files" in settings.get("data", {}):
# Load frame files from disk
for ifo in settings["interferometers"]:
print(f"Loading {ifo} data")
logger.info(
f"Loading {ifo} data from "
f"{settings['data']['data files'][ifo]}/{settings['data']['channels'][ifo]}"
Expand All @@ -93,14 +94,16 @@ def heron_inference(settings):
channel=settings["data"]["channels"][ifo],
format="gwf",
)
elif "injection" in other_settings:
pass
#elif "injection" in other_settings:
# pass

# Make Likelihood
if len(settings["interferometers"]) > 1:
likelihoods = []
print("Creating likelihoods")
waveform_model = KNOWN_WAVEFORMS[settings["waveform"]["model"]]()
for ifo in settings["interferometers"]:
print(f"\t {ifo}")
likelihoods.append(
KNOWN_LIKELIHOODS[settings.get("likelihood").get("function")](
data[ifo],
Expand All @@ -113,7 +116,7 @@ def heron_inference(settings):
),
)
)
likelihood = MultiDetector(*likelihoods)
likelihood = MultiDetector(*likelihoods)

priors = heron.priors.PriorDict()
priors.from_dictionary(settings["priors"])
Expand Down
2 changes: 1 addition & 1 deletion heron/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def injection_parameters_add_units(parameters):
UNITS = {"luminosity_distance": u.megaparsec, "m1": u.solMass, "m2": u.solMass}

for parameter, value in parameters.items():
if not isinstance(value, u.Quantity):
if not isinstance(value, u.Quantity) and parameter in UNITS:
parameters[parameter] = value * UNITS[parameter]
return parameters

Expand Down
9 changes: 6 additions & 3 deletions heron/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
self.timeseries = data
self.data = np.array(data.data)
self.times = data.times

self.C = self.psd.covariance_matrix(times=self.times)
self.inverse_C = np.linalg.inv(self.C)

Expand Down Expand Up @@ -96,7 +95,11 @@ def snr(self, waveform):
return np.sqrt(np.abs(h_h))

def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
w = self.timeseries.determine_overlap(self, waveform)
if w is not None:
(a,b) = w
else:
return -np.inf
residual = np.array(self.data.data[a[0]:a[1]]) - np.array(waveform.data[b[0]:b[1]])
weighted_residual = (
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
Expand All @@ -108,7 +111,7 @@ def __call__(self, parameters):
self.logger.info(parameters)

keys = set(parameters.keys())
extrinsic = {"phase", "psi", "ra", "dec", "theta_jn"}
extrinsic = {"phase", "psi", "ra", "dec", "theta_jn", "gpstime", "geocent_time"}
conversions = {"mass_ratio", "total_mass", "luminosity_distance"}
bad_keys = keys - set(self.waveform._args.keys()) - extrinsic - conversions
if len(bad_keys) > 0:
Expand Down
25 changes: 16 additions & 9 deletions heron/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
from lal import antenna
import numpy as np
from lal import antenna, MSUN_SI
from astropy import units as u


Expand All @@ -26,16 +27,22 @@ def _convert_luminosity_distance(self, args):
return args

def _convert_mass_ratio_total_mass(self, args):

args["m1"] = (args["total_mass"] / (1 + args["mass_ratio"])).to_value(
u.kilogram
)
args["m2"] = (args["total_mass"] / (1 + 1 / args["mass_ratio"])).to_value(
u.kilogram
)
args["m1"] = (args["total_mass"] / (1 + args["mass_ratio"]))
args["m2"] = (args["total_mass"] / (1 + (1 / args["mass_ratio"])))
# Do these have units?
# If not then we can skip some relatively expensive operations and apply a heuristic.
if isinstance(args["m1"], u.Quantity):
args["m1"] = args["m1"].to_value(u.kilogram)
args["m2"] = args["m2"].to_value(u.kilogram)
if (not isinstance(args["m1"], u.Quantity)) and (args["m1"] < 1000):
# This appears to be in solar masses
args["m1"] *= MSUN_SI
if (not isinstance(args["m2"], u.Quantity)) and (args["m2"] < 1000):
# This appears to be in solar masses
args["m2"] *= MSUN_SI

args.pop("total_mass")
args.pop("mass_ratio")

return args


Expand Down
1 change: 0 additions & 1 deletion heron/models/lalsimulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def _convert_units(self, args):
args[name] = argument.to_value(units[mappings[name]])
elif name in mappings.keys() and argument:
# This is commented out as it causes problems if e.g. lalnative values are passed
print(f"Performing a mapping on {name}, {argument}")
args[name] = (argument * default_units[mappings[name]]).to_value(
units[mappings[name]]
)
Expand Down
1 change: 1 addition & 0 deletions heron/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"Uniform": bilby.prior.Uniform,
"PowerLaw": bilby.prior.PowerLaw,
"Sine": bilby.prior.Sine,
"Cosine": bilby.prior.Cosine,
"UniformSourceFrame": bilby.gw.prior.UniformSourceFrame,
"UniformInComponentsMassRatio": bilby.gw.prior.UniformInComponentsMassRatio,
}
Expand Down
3 changes: 3 additions & 0 deletions heron/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def is_in(time, timeseries):
overlap = timeseries_a.times[0], timeseries_a.times[-1]
else:
overlap = None
#print("No overlap found")
#print(timeseries_a.times[0], timeseries_a.times[-1])
#print(timeseries_b.times[0], timeseries_b.times[-1])
return None

start_a = np.argmin(np.abs(timeseries_a.times - overlap[0]))
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ cython
george
nestle
elk-waveform>=0.13
gpytorch
torch
lalsuite
h5py
click
asimov
pesummary
nessai
gpytorch
torch
torchvision
4 changes: 2 additions & 2 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ nestle

sphinx
sphinx-versions
numpydoc
kentigern
nbsphinx
sphinxcontrib-bibtex
numpydoc==1.8.0
nbsphinx==0.9.5

0 comments on commit 62ce6c1

Please sign in to comment.