Skip to content

Commit

Permalink
Added some bug fixes.
Browse files Browse the repository at this point in the history
  • Loading branch information
transientlunatic committed Sep 3, 2024
1 parent 208750c commit 07889ea
Show file tree
Hide file tree
Showing 7 changed files with 20 additions and 11 deletions.
11 changes: 7 additions & 4 deletions heron/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import click

from gwpy.timeseries import TimeSeries
from .types import TimeSeries
import astropy.units as u

from nessai.flowsampler import FlowSampler
Expand Down Expand Up @@ -84,6 +84,7 @@ def heron_inference(settings):
if "data files" in settings.get("data", {}):
# Load frame files from disk
for ifo in settings["interferometers"]:
print(f"Loading {ifo} data")
logger.info(
f"Loading {ifo} data from "
f"{settings['data']['data files'][ifo]}/{settings['data']['channels'][ifo]}"
Expand All @@ -93,14 +94,16 @@ def heron_inference(settings):
channel=settings["data"]["channels"][ifo],
format="gwf",
)
elif "injection" in other_settings:
pass
#elif "injection" in other_settings:
# pass

# Make Likelihood
if len(settings["interferometers"]) > 1:
likelihoods = []
print("Creating likelihoods")
waveform_model = KNOWN_WAVEFORMS[settings["waveform"]["model"]]()
for ifo in settings["interferometers"]:
print(f"\t {ifo}")
likelihoods.append(
KNOWN_LIKELIHOODS[settings.get("likelihood").get("function")](
data[ifo],
Expand All @@ -113,7 +116,7 @@ def heron_inference(settings):
),
)
)
likelihood = MultiDetector(*likelihoods)
likelihood = MultiDetector(*likelihoods)

priors = heron.priors.PriorDict()
priors.from_dictionary(settings["priors"])
Expand Down
2 changes: 1 addition & 1 deletion heron/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def injection_parameters_add_units(parameters):
UNITS = {"luminosity_distance": u.megaparsec, "m1": u.solMass, "m2": u.solMass}

for parameter, value in parameters.items():
if not isinstance(value, u.Quantity):
if not isinstance(value, u.Quantity) and parameter in UNITS:
parameters[parameter] = value * UNITS[parameter]
return parameters

Expand Down
7 changes: 5 additions & 2 deletions heron/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ def __init__(
self.timeseries = data
self.data = np.array(data.data)
self.times = data.times

self.C = self.psd.covariance_matrix(times=self.times)
self.inverse_C = np.linalg.inv(self.C)

Expand Down Expand Up @@ -96,7 +95,11 @@ def snr(self, waveform):
return np.sqrt(np.abs(h_h))

def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
w = self.timeseries.determine_overlap(self, waveform)
if w is not None:
(a,b) = w
else:
return -np.inf
residual = np.array(self.data.data[a[0]:a[1]]) - np.array(waveform.data[b[0]:b[1]])
weighted_residual = (
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
Expand Down
1 change: 0 additions & 1 deletion heron/models/lalsimulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ def _convert_units(self, args):
args[name] = argument.to_value(units[mappings[name]])
elif name in mappings.keys() and argument:
# This is commented out as it causes problems if e.g. lalnative values are passed
print(f"Performing a mapping on {name}, {argument}")
args[name] = (argument * default_units[mappings[name]]).to_value(
units[mappings[name]]
)
Expand Down
1 change: 1 addition & 0 deletions heron/priors.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"Uniform": bilby.prior.Uniform,
"PowerLaw": bilby.prior.PowerLaw,
"Sine": bilby.prior.Sine,
"Cosine": bilby.prior.Cosine,
"UniformSourceFrame": bilby.gw.prior.UniformSourceFrame,
"UniformInComponentsMassRatio": bilby.gw.prior.UniformInComponentsMassRatio,
}
Expand Down
3 changes: 3 additions & 0 deletions heron/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def is_in(time, timeseries):
overlap = timeseries_a.times[0], timeseries_a.times[-1]
else:
overlap = None
#print("No overlap found")
#print(timeseries_a.times[0], timeseries_a.times[-1])
#print(timeseries_b.times[0], timeseries_b.times[-1])
return None

start_a = np.argmin(np.abs(timeseries_a.times - overlap[0]))
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,6 @@ click
asimov
pesummary
nessai
gpytorch==1.0.1
torch==2.4.0
torchvision==0.5.0
gpytorch
torch
torchvision

0 comments on commit 07889ea

Please sign in to comment.