Skip to content

Commit

Permalink
Updated tests to ensure compatibility between torch and numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
transientlunatic committed Aug 19, 2024
1 parent 7f20b36 commit 97590b7
Show file tree
Hide file tree
Showing 5 changed files with 124 additions and 18 deletions.
2 changes: 1 addition & 1 deletion heron/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def make_injection_zero_noise(
logger.info(f"Saving framefile to {filename}")
injection.write(filename, format="gwf")

return injections
return injections

def injection_parameters_add_units(parameters):
UNITS = {"luminosity_distance": u.megaparsec, "m1": u.solMass, "m2": u.solMass}
Expand Down
24 changes: 18 additions & 6 deletions heron/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def log_likelihood(self, waveform, norm=True):
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return -0.5 * weighted_residual + 0.5 * normalisation
return 0.5 * weighted_residual + 0.5 * normalisation

def __call__(self, parameters):
self.logger.info(parameters)
Expand Down Expand Up @@ -233,6 +233,8 @@ def __init__(
self.logger.info(f"Using device {device}")
self.psd = psd

self.timeseries = data

self.data = torch.tensor(data.data, device=self.device, dtype=torch.double)
self.times = data.times

Expand Down Expand Up @@ -263,14 +265,24 @@ def snr(self, waveform):
h_h = (waveform_d.T @ self.solve(self.C, waveform_d)) * (dt * dt / N / 4) / 4
return torch.sqrt(torch.abs(h_h))

def log_likelihood(self, waveform):
def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
residual = np.array(self.data.data[a[0]:a[1]]) - np.array(waveform.data[b[0]:b[1]])
weighted_residual = (
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return 0.5 * weighted_residual + 0.5 * normalisation

def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
waveform_d = torch.tensor(waveform.data, device=self.device, dtype=torch.double)
residual = self.data - waveform_d
residual = self.data[a[0]:a[1]] - waveform_d[b[0]:b[1]]
weighted_residual = (
(residual) @ self.solve(self.C, residual) * (self.dt * self.dt / 4) / 4
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C)
return -0.5 * weighted_residual + 0.5 * normalisation
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return 0.5 * weighted_residual + 0.5 * normalisation

def __call__(self, parameters):
self.logger.info(parameters)
Expand Down
File renamed without changes.
File renamed without changes.
116 changes: 105 additions & 11 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from heron.models.lalnoise import AdvancedLIGO
from heron.injection import make_injection, make_injection_zero_noise
from heron.detector import Detector, AdvancedLIGOHanford, AdvancedLIGOLivingston, AdvancedVirgo
from heron.likelihood import MultiDetector, TimeDomainLikelihood, TimeDomainLikelihoodModelUncertainty
# TimeDomainLikelihoodPyTorch, TimeDomainLikelihoodModelUncertaintyPyTorch
from heron.likelihood import MultiDetector, TimeDomainLikelihood, TimeDomainLikelihoodModelUncertainty, TimeDomainLikelihoodPyTorch
#, TimeDomainLikelihoodModelUncertaintyPyTorch

from heron.inference import heron_inference, parse_dict, load_yaml

Expand Down Expand Up @@ -50,7 +50,7 @@ def setUp(self):
def test_likelihood_no_norm(self):
data = self.injections['H1']

from gwpy.plot import Plot
# from gwpy.plot import Plot

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)

Expand All @@ -64,20 +64,63 @@ def test_likelihood_no_norm(self):
phi_0=0, psi=0,
iota=0)

f = Plot(data, projected_waveform)
f.savefig("projected_waveform.png")
# f = Plot(data, projected_waveform)
# f.savefig("projected_waveform.png")

log_like = likelihood.log_likelihood(projected_waveform, norm=False)

self.assertTrue(log_like <= 1e-5)


def test_likelihood_maximum_at_true_value_mass_ratio(self):

data = self.injections['H1']

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)
mass_ratios = np.linspace(0.1, 1.0, 100)

log_likes = []
for mass_ratio in mass_ratios:

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": mass_ratio,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

log_likes.append(likelihood.log_likelihood(projected_waveform))

self.assertTrue(mass_ratios[np.argmax(log_likes)] == 0.6)


class Test_PyTorch_Likelihood_ZeroNoise(unittest.TestCase):
"""
Test likelihoods on a zero noise injection.
"""

def setUp(self):
self.waveform = IMRPhenomPv2()
self.psd_model = AdvancedLIGO()

self.injections = make_injection_zero_noise(waveform=IMRPhenomPv2,
injection_parameters={"distance": 1000*u.megaparsec,
"mass_ratio": 0.6,
"gpstime": 0,
"total_mass": 60 * u.solMass},
detectors={"AdvancedLIGOHanford": "AdvancedLIGO",
"AdvancedLIGOLivingston": "AdvancedLIGO"}
)

def test_likelihood_no_norm(self):
data = self.injections['H1']

from gwpy.plot import Plot
# from gwpy.plot import Plot

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)
likelihood = TimeDomainLikelihoodPyTorch(data, psd=self.psd_model)

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": 0.6,
Expand All @@ -89,12 +132,63 @@ def test_likelihood_no_norm(self):
phi_0=0, psi=0,
iota=0)

f = Plot(data, projected_waveform)
f.savefig("projected_waveform.png")
log_like = likelihood.log_likelihood(projected_waveform, norm=False)

log_like = likelihood.log_likelihood(projected_waveform)
self.assertTrue(log_like.cpu().numpy() <= 1e-5)

self.assertTrue(log_like <= 1e-5)

def test_likelihood_maximum_at_true_value_mass_ratio(self):

data = self.injections['H1']

likelihood = TimeDomainLikelihoodPyTorch(data, psd=self.psd_model)
mass_ratios = np.linspace(0.1, 1.0, 100)

log_likes = []
for mass_ratio in mass_ratios:

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": mass_ratio,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

log_likes.append(likelihood.log_likelihood(projected_waveform).cpu().numpy())

self.assertTrue(mass_ratios[np.argmax(log_likes)] == 0.6)


def test_likelihood_numpy_equivalent(self):

data = self.injections['H1']

likelihood = TimeDomainLikelihoodPyTorch(data, psd=self.psd_model)
numpy_likelihood = TimeDomainLikelihood(data, psd=self.psd_model)
mass_ratios = np.linspace(0.1, 1.0, 100)

log_likes = []
log_likes_n = []
for mass_ratio in mass_ratios:

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": mass_ratio,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

log_likes.append(likelihood.log_likelihood(projected_waveform).cpu().numpy())
log_likes_n.append(numpy_likelihood.log_likelihood(projected_waveform))

self.assertTrue(mass_ratios[np.argmax(log_likes)] == 0.6)
self.assertTrue(np.all((np.array(log_likes) - np.array(log_likes_n)) < 0.001))


class Test_Filter(unittest.TestCase):
Expand Down

0 comments on commit 97590b7

Please sign in to comment.