Skip to content

Commit

Permalink
Fixed waveform alignment in likelihood.
Browse files Browse the repository at this point in the history
  • Loading branch information
transientlunatic committed Aug 16, 2024
1 parent 918d53f commit 9a07570
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 16 deletions.
48 changes: 48 additions & 0 deletions heron/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,54 @@ def make_injection(
return injections


def make_injection_zero_noise(
waveform=IMRPhenomPv2,
injection_parameters={},
times=None,
detectors=None,
framefile=None,
):

parameters = {"ra": 0, "dec": 0, "psi": 0, "theta_jn": 0, "phase": 0, 'gpstime': 4000}
parameters.update(injection_parameters)

waveform = waveform()

if times is None:
times = np.linspace(-0.5, 0.1, int(0.6 * 4096)) + parameters['gpstime']
waveform = waveform.time_domain(
parameters,
times=times,
)

injections = {}
for detector, psd_model in detectors.items():
detector = KNOWN_IFOS[detector]()
channel = f"{detector.abbreviation}:Injection"
logger.info(f"Making injection for {detector} in channel {channel}")
psd_model = KNOWN_PSDS[psd_model]()
#data = psd_model.time_series(times)

# import matplotlib
# matplotlib.use("agg")
# from gwpy.plot import Plot
# f = Plot(data, waveform.project(detector), data+waveform.project(detector), separate=False)
# f.savefig(f"{detector.abbreviation}_injected_waveform.png")

injection = waveform.project(detector)
injection.channel = channel
injections[detector.abbreviation] = injection
likelihood = TimeDomainLikelihood(injection, psd=psd_model)
snr = likelihood.snr(waveform.project(detector))
logger.info(f"Optimal Filter SNR: {snr}")

if framefile:
filename = f"{detector.abbreviation}_{framefile}.gwf"
logger.info(f"Saving framefile to {filename}")
injection.write(filename, format="gwf")

return injections

def injection_parameters_add_units(parameters):
UNITS = {"luminosity_distance": u.megaparsec, "m1": u.solMass, "m2": u.solMass}

Expand Down
14 changes: 8 additions & 6 deletions heron/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def __init__(
timing_basis=None,
):
self.psd = psd

self.timeseries = data
self.data = np.array(data.data)
self.times = data.times

Expand Down Expand Up @@ -86,19 +86,21 @@ def snr(self, waveform):
"""
dt = (self.times[1] - self.times[0]).value
N = len(self.times)
w = np.array(waveform.data)
h_h = (
(np.array(waveform.data).T @ self.solve(self.C, np.array(waveform.data)))
(w.T @ self.solve(self.C, w))
* (dt * dt / N / 4)
/ 4
)
return np.sqrt(np.abs(h_h))

def log_likelihood(self, waveform):
residual = np.array(self.data.data) - np.array(waveform.data)
def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
residual = np.array(self.data.data[a[0]:a[1]]) - np.array(waveform.data[b[0]:b[1]])
weighted_residual = (
(residual) @ self.solve(self.C, residual) * (self.dt * self.dt / 4) / 4
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C)
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return -0.5 * weighted_residual + 0.5 * normalisation

def __call__(self, parameters):
Expand Down
62 changes: 54 additions & 8 deletions heron/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from lalinference import DetFrameToEquatorial

import numpy as array_library
import numpy as np
import matplotlib.pyplot as plt


Expand All @@ -16,8 +17,58 @@ class TimeSeries(TimeSeries):
Overload the GWPy timeseries so that additional methods can be defined upon it.
"""

pass
def determine_overlap(self, timeseries_a, timeseries_b):
def is_in(time, timeseries):
diff = np.min(np.abs(timeseries - time))
if diff < (timeseries[1] - timeseries[0]):
return True, diff
else:
return False, diff

overlap = None
if (
is_in(timeseries_a.times[-1], timeseries_b.times)[0]
and is_in(timeseries_b.times[0], timeseries_a.times)[0]
):
overlap = timeseries_b.times[0], timeseries_a.times[-1]
elif (
is_in(timeseries_a.times[0], timeseries_b.times)[0]
and is_in(timeseries_b.times[-1], timeseries_a.times)[0]
):
overlap = timeseries_a.times[0], timeseries_b.times[-1]
elif (
is_in(timeseries_b.times[0], timeseries_a.times)[0]
and is_in(timeseries_b.times[-1], timeseries_a.times)[0]
and not is_in(timeseries_a.times[-1], timeseries_b.times)[0]
):
overlap = timeseries_b.times[0], timeseries_b.times[-1]
elif (
is_in(timeseries_a.times[0], timeseries_b.times)[0]
and is_in(timeseries_a.times[-1], timeseries_b.times)[0]
and not is_in(timeseries_b.times[-1], timeseries_a.times)[0]
):
overlap = timeseries_a.times[0], timeseries_a.times[-1]
else:
overlap = None
return None

start_a = np.argmin(np.abs(timeseries_a.times - overlap[0]))
finish_a = np.argmin(np.abs(timeseries_a.times - overlap[-1]))

start_b = np.argmin(np.abs(timeseries_b.times - overlap[0]))
finish_b = np.argmin(np.abs(timeseries_b.times - overlap[-1]))
return (start_a, finish_a), (start_b, finish_b)

def align(self, waveform_b):
"""
Align this waveform with another one by altering the phase.
"""

indices = self.determine_overlap(self, waveform_b)

return self[indices[0][0]:indices[0][1]], waveform_b[indices[1][0]: indices[1][1]]



class PSD(FrequencySeries):
def __init__(self, data, frequencies, *args, **kwargs):
Expand All @@ -40,7 +91,7 @@ def __init__(self, variance=None, covariance=None, *args, **kwargs):
def __new__(self, variance=None, covariance=None, *args, **kwargs):
# if "covariance" in kwargs:
# self.covariance = kwargs.pop("covariance")
waveform = super(Waveform, self).__new__(TimeSeriesBase, *args, **kwargs)
waveform = super(Waveform, self).__new__(TimeSeries, *args, **kwargs)
waveform.covariance = covariance
waveform.variance = variance

Expand All @@ -50,11 +101,6 @@ def __new__(self, variance=None, covariance=None, *args, **kwargs):
# def dt(self):
# return self.waveform.times[1] - self.waveform.times[0]

def align(self, waveform_b):
"""
Align this waveform with another one by altering the phase.
"""
pass


class WaveformDict:
Expand Down Expand Up @@ -194,7 +240,7 @@ def project(
covariance=projected_covariance,
times=self.waveforms["plus"].times,
)

projected_waveform.shift(dt)

return projected_waveform
Expand Down
47 changes: 45 additions & 2 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

from heron.models.lalsimulation import SEOBNRv3, IMRPhenomPv2, IMRPhenomPv2_FakeUncertainty
from heron.models.lalnoise import AdvancedLIGO
from heron.injection import make_injection
from heron.injection import make_injection, make_injection_zero_noise
from heron.detector import Detector, AdvancedLIGOHanford, AdvancedLIGOLivingston, AdvancedVirgo
from heron.likelihood import MultiDetector, TimeDomainLikelihood, TimeDomainLikelihoodModelUncertainty
# TimeDomainLikelihoodPyTorch, TimeDomainLikelihoodModelUncertaintyPyTorch
Expand All @@ -29,6 +29,49 @@
CUDA_NOT_AVAILABLE = not is_available()


class Test_Likelihood_ZeroNoise(unittest.TestCase):
"""
Test likelihoods on a zero noise injection.
"""

def setUp(self):
self.waveform = IMRPhenomPv2()
self.psd_model = AdvancedLIGO()

self.injections = make_injection_zero_noise(waveform=IMRPhenomPv2,
injection_parameters={"distance": 1000*u.megaparsec,
"mass_ratio": 0.6,
"gpstime": 0,
"total_mass": 60 * u.solMass},
detectors={"AdvancedLIGOHanford": "AdvancedLIGO",
"AdvancedLIGOLivingston": "AdvancedLIGO"}
)

def test_likelihood_no_norm(self):
data = self.injections['H1']

from gwpy.plot import Plot

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": 0.6,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

f = Plot(data, projected_waveform)
f.savefig("projected_waveform.png")

log_like = likelihood.log_likelihood(projected_waveform, norm=False)

self.assertTrue(log_like <= 1e-5)


class Test_Filter(unittest.TestCase):
"""Test that filters can be applied correctly to data."""

Expand All @@ -44,6 +87,7 @@ def setUp(self):
"AdvancedLIGOLivingston": "AdvancedLIGO"}
)


def test_timedomain_psd(self):
noise = self.psd_model.time_domain(times=self.injections['H1'].times)
#print(noise)
Expand All @@ -66,7 +110,6 @@ def test_snr(self):
f.savefig("projected_waveform.png")

snr = likelihood.snr(projected_waveform)
print("snr", snr)
self.assertTrue(snr > 40 and snr < 45)

# def test_snr_f(self):
Expand Down

0 comments on commit 9a07570

Please sign in to comment.