Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add new tests #5

Merged
merged 1 commit into from
Aug 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion heron/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def make_injection_zero_noise(
logger.info(f"Saving framefile to {filename}")
injection.write(filename, format="gwf")

return injections
return injections

def injection_parameters_add_units(parameters):
UNITS = {"luminosity_distance": u.megaparsec, "m1": u.solMass, "m2": u.solMass}
Expand Down
24 changes: 18 additions & 6 deletions heron/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def log_likelihood(self, waveform, norm=True):
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return -0.5 * weighted_residual + 0.5 * normalisation
return 0.5 * weighted_residual + 0.5 * normalisation

def __call__(self, parameters):
self.logger.info(parameters)
Expand Down Expand Up @@ -233,6 +233,8 @@ def __init__(
self.logger.info(f"Using device {device}")
self.psd = psd

self.timeseries = data

self.data = torch.tensor(data.data, device=self.device, dtype=torch.double)
self.times = data.times

Expand Down Expand Up @@ -263,14 +265,24 @@ def snr(self, waveform):
h_h = (waveform_d.T @ self.solve(self.C, waveform_d)) * (dt * dt / N / 4) / 4
return torch.sqrt(torch.abs(h_h))

def log_likelihood(self, waveform):
def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
residual = np.array(self.data.data[a[0]:a[1]]) - np.array(waveform.data[b[0]:b[1]])
weighted_residual = (
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return 0.5 * weighted_residual + 0.5 * normalisation

def log_likelihood(self, waveform, norm=True):
a, b = self.timeseries.determine_overlap(self, waveform)
waveform_d = torch.tensor(waveform.data, device=self.device, dtype=torch.double)
residual = self.data - waveform_d
residual = self.data[a[0]:a[1]] - waveform_d[b[0]:b[1]]
weighted_residual = (
(residual) @ self.solve(self.C, residual) * (self.dt * self.dt / 4) / 4
(residual) @ self.solve(self.C[a[0]:a[1],b[0]:b[1]], residual) * (self.dt * self.dt / 4) / 4
)
normalisation = self.logdet(2 * np.pi * self.C)
return -0.5 * weighted_residual + 0.5 * normalisation
normalisation = self.logdet(2 * np.pi * self.C[a[0]:a[1],b[0]:b[1]]) if norm else 0
return 0.5 * weighted_residual + 0.5 * normalisation

def __call__(self, parameters):
self.logger.info(parameters)
Expand Down
File renamed without changes.
File renamed without changes.
116 changes: 105 additions & 11 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
from heron.models.lalnoise import AdvancedLIGO
from heron.injection import make_injection, make_injection_zero_noise
from heron.detector import Detector, AdvancedLIGOHanford, AdvancedLIGOLivingston, AdvancedVirgo
from heron.likelihood import MultiDetector, TimeDomainLikelihood, TimeDomainLikelihoodModelUncertainty
# TimeDomainLikelihoodPyTorch, TimeDomainLikelihoodModelUncertaintyPyTorch
from heron.likelihood import MultiDetector, TimeDomainLikelihood, TimeDomainLikelihoodModelUncertainty, TimeDomainLikelihoodPyTorch
#, TimeDomainLikelihoodModelUncertaintyPyTorch

from heron.inference import heron_inference, parse_dict, load_yaml

Expand Down Expand Up @@ -50,7 +50,7 @@ def setUp(self):
def test_likelihood_no_norm(self):
data = self.injections['H1']

from gwpy.plot import Plot
# from gwpy.plot import Plot

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)

Expand All @@ -64,20 +64,63 @@ def test_likelihood_no_norm(self):
phi_0=0, psi=0,
iota=0)

f = Plot(data, projected_waveform)
f.savefig("projected_waveform.png")
# f = Plot(data, projected_waveform)
# f.savefig("projected_waveform.png")

log_like = likelihood.log_likelihood(projected_waveform, norm=False)

self.assertTrue(log_like <= 1e-5)


def test_likelihood_maximum_at_true_value_mass_ratio(self):

data = self.injections['H1']

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)
mass_ratios = np.linspace(0.1, 1.0, 100)

log_likes = []
for mass_ratio in mass_ratios:

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": mass_ratio,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

log_likes.append(likelihood.log_likelihood(projected_waveform))

self.assertTrue(mass_ratios[np.argmax(log_likes)] == 0.6)


class Test_PyTorch_Likelihood_ZeroNoise(unittest.TestCase):
"""
Test likelihoods on a zero noise injection.
"""

def setUp(self):
self.waveform = IMRPhenomPv2()
self.psd_model = AdvancedLIGO()

self.injections = make_injection_zero_noise(waveform=IMRPhenomPv2,
injection_parameters={"distance": 1000*u.megaparsec,
"mass_ratio": 0.6,
"gpstime": 0,
"total_mass": 60 * u.solMass},
detectors={"AdvancedLIGOHanford": "AdvancedLIGO",
"AdvancedLIGOLivingston": "AdvancedLIGO"}
)

def test_likelihood_no_norm(self):
data = self.injections['H1']

from gwpy.plot import Plot
# from gwpy.plot import Plot

likelihood = TimeDomainLikelihood(data, psd=self.psd_model)
likelihood = TimeDomainLikelihoodPyTorch(data, psd=self.psd_model)

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": 0.6,
Expand All @@ -89,12 +132,63 @@ def test_likelihood_no_norm(self):
phi_0=0, psi=0,
iota=0)

f = Plot(data, projected_waveform)
f.savefig("projected_waveform.png")
log_like = likelihood.log_likelihood(projected_waveform, norm=False)

log_like = likelihood.log_likelihood(projected_waveform)
self.assertTrue(log_like.cpu().numpy() <= 1e-5)

self.assertTrue(log_like <= 1e-5)

def test_likelihood_maximum_at_true_value_mass_ratio(self):

data = self.injections['H1']

likelihood = TimeDomainLikelihoodPyTorch(data, psd=self.psd_model)
mass_ratios = np.linspace(0.1, 1.0, 100)

log_likes = []
for mass_ratio in mass_ratios:

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": mass_ratio,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

log_likes.append(likelihood.log_likelihood(projected_waveform).cpu().numpy())

self.assertTrue(mass_ratios[np.argmax(log_likes)] == 0.6)


def test_likelihood_numpy_equivalent(self):

data = self.injections['H1']

likelihood = TimeDomainLikelihoodPyTorch(data, psd=self.psd_model)
numpy_likelihood = TimeDomainLikelihood(data, psd=self.psd_model)
mass_ratios = np.linspace(0.1, 1.0, 100)

log_likes = []
log_likes_n = []
for mass_ratio in mass_ratios:

test_waveform = self.waveform.time_domain(parameters={"distance": 1000*u.megaparsec,
"mass_ratio": mass_ratio,
"gpstime": 0,
"total_mass": 60 * u.solMass}, times=likelihood.times)
projected_waveform = test_waveform.project(AdvancedLIGOHanford(),
ra=0, dec=0,
gpstime=0,
phi_0=0, psi=0,
iota=0)

log_likes.append(likelihood.log_likelihood(projected_waveform).cpu().numpy())
log_likes_n.append(numpy_likelihood.log_likelihood(projected_waveform))

self.assertTrue(mass_ratios[np.argmax(log_likes)] == 0.6)
self.assertTrue(np.all((np.array(log_likes) - np.array(log_likes_n)) < 0.001))


class Test_Filter(unittest.TestCase):
Expand Down
Loading