From 787b6179cef9cd01f0dceafc424600232da70b7e Mon Sep 17 00:00:00 2001
From: Vittoria Tommasini
Date: Fri, 24 Apr 2026 15:30:46 -0700
Subject: [PATCH] Add GPR loader and unit tests

---
 gpr_loader.py      | 244 +++++++++++++++++++++++++++++++++++
 test_gpr_loader.py | 311 +++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 555 insertions(+)
 create mode 100644 gpr_loader.py
 create mode 100644 test_gpr_loader.py

diff --git a/gpr_loader.py b/gpr_loader.py
new file mode 100644
index 0000000..ae94864
--- /dev/null
+++ b/gpr_loader.py
@@ -0,0 +1,244 @@
+# Distributed under the MIT License.
+# See LICENSE.txt for details.
+
+"""Load saved GPR models and run inference on new binary black hole runs.
+
+No training is done here.
+
+How to use:
+
+  1. A single simulation (the whole metadata string, quoted, on one line):
+       python gpr_loader.py --metadata "RunID=0111 ZwickyDays=176 q=8.0
+           chiA=-0.555366763183,-0.575801728447,-0.00249659459858
+           chiB=-0.0516668737278,0.102310708412,0.791746644364
+           D0=14.5189208984 Omega0=0.0164280601103
+           adot0=-2.39687233838e-05"
+  2. Multiple simulations (one metadata string per line of a text file):
+       python gpr_loader.py --file my_runs.txt
+  3. Custom model paths (optional):
+       python gpr_loader.py --file my_runs.txt
+           --model_omega path/to/gpr_model_omega.pth
+           --model_adot path/to/gpr_model_adot.pth
+"""
+
+import os
+import argparse
+import numpy as np
+
+# Import torch and related libraries
+import torch
+import gpytorch  # GP library built on PyTorch
+
+# Import GPR library
+import importlib
+import GPR_library as gpr
+from GPR_library import GPRegressionModel
+
+# Top-level entries that load_gpr_checkpoint reads from every checkpoint
+_REQUIRED_CKPT_KEYS = (
+    "metadata",
+    "model_state_dict",
+    "likelihood_state_dict",
+    "normalization",
+)
+
+
+# Model Loading
+def load_gpr_checkpoint(ckpt_path):
+    """
+    Loads a saved GPR checkpoint file from disk.
+
+    Restores the model weights, the likelihood, and the normalization
+    parameters stored in the checkpoint.
+
+    Args:
+        ckpt_path (str): path to the saved .pth checkpoint file
+
+    Returns:
+        model (GPRegressionModel): GPR model in evaluation mode with
+            trained weights loaded
+        likelihood (gpytorch.likelihoods.GaussianLikelihood): likelihood
+            in evaluation mode
+        meta (dict): metadata dictionary containing the input_features,
+            base_column, etc.
+
+    Raises:
+        KeyError: if the checkpoint lacks one of the required entries
+    """
+    # Load the saved checkpoint file into memory as a Python dictionary.
+    # NOTE: torch.load unpickles the file, so only load checkpoints that
+    # come from a trusted source.
+    ckpt = torch.load(ckpt_path, map_location="cpu")
+
+    # Fail early, naming the file, rather than with a bare KeyError later
+    missing = [key for key in _REQUIRED_CKPT_KEYS if key not in ckpt]
+    if missing:
+        raise KeyError(
+            f"Checkpoint {ckpt_path!r} is missing required entries: {missing}"
+        )
+
+    # Unpack metadata - describes what the model was trained on
+    # (features, base_column, etc.)
+    meta = ckpt["metadata"]
+    features = meta["input_features"]
+    D = len(features)  # Count number of input features
+
+    # Create dummy input and output tensors as placeholders so the model
+    # object can be constructed with the right structure. The trained
+    # parameters are loaded afterwards by model.load_state_dict and
+    # likelihood.load_state_dict
+    dummy_x = torch.zeros(1, D)  # one row with D input features
+    dummy_y = torch.zeros(1)  # ensures scalar output
+
+    # Initialize likelihood and model with the right structure
+    # likelihood represents the assumed noise model of the data
+    likelihood = gpytorch.likelihoods.GaussianLikelihood()
+    # construct the GP model object
+    model = gpr.GPRegressionModel(dummy_x, dummy_y, likelihood)
+
+    # Load the trained parameters back into the model and likelihood
+    model.load_state_dict(ckpt["model_state_dict"])
+    likelihood.load_state_dict(ckpt["likelihood_state_dict"])
+
+    # Restore the normalization statistics used during training
+    # so predictions are returned in the correct physical scale instead of
+    # in raw standardized numbers
+    norm = ckpt["normalization"]
+    model.set_normalization(
+        input_mean=np.array(norm["input_mean"]),  # mean of the training features
+        input_std=np.array(norm["input_std"]),  # SD of the training features
+        output_mean=norm["output_mean"],  # mean of the training targets (deltas)
+        output_std=norm["output_std"],  # SD of the training targets (deltas)
+    )
+
+    # Put model and likelihood into evaluation mode (for inference only)
+    model.eval()
+    likelihood.eval()
+
+    return model, likelihood, meta
+
+
+# Inference
+def run_inference(metadata_strings, model_omega_pth, model_adot_pth):
+    """
+    Runs GPR inference on a list of metadata strings.
+
+    Args:
+        - metadata_strings (list of str): each string is one simulation's metadata
+          in SpEC format, e.g.: "RunID=0111 q=8.0 chiA=... chiB=... D0=...
+          Omega0=... adot0=..."
+        - model_omega_pth (str): path to the saved GPR model checkpoint for Omega0
+        - model_adot_pth (str): path to the saved GPR model checkpoint for adot0
+
+    Returns:
+        - df_test (pd.DataFrame): DataFrame with input parameters, predicted corrections,
+          uncertainties, and GPR-corrected Omega0 and adot0 values.
+    """
+    # Parse metadata strings into a DataFrame
+    df_test = gpr.parse_test_runs(metadata_strings)
+
+    # Load models
+    model_omega, likelihood_omega, meta_omega = load_gpr_checkpoint(model_omega_pth)
+    model_adot, likelihood_adot, meta_adot = load_gpr_checkpoint(model_adot_pth)
+
+    # Extract input features. Each model gets the columns listed in its own
+    # checkpoint, so the two models may have been trained on different
+    # feature sets (or feature orderings).
+    X_omega = df_test[meta_omega["input_features"]].values
+    X_adot = df_test[meta_adot["input_features"]].values
+
+    # Predict corrections (deltas) and uncertainties
+    delta_omega_pred, omega_unc = gpr.predict_with_gpr_model(
+        X_omega, model_omega, likelihood_omega)
+    delta_adot_pred, adot_unc = gpr.predict_with_gpr_model(
+        X_adot, model_adot, likelihood_adot)
+
+    # Store Omega0 predictions and uncertainties
+    df_test["delta_pred_omega"] = delta_omega_pred
+    df_test["uncertainty_omega"] = omega_unc
+
+    # Store adot0 predictions and uncertainties
+    df_test["delta_pred_adot"] = delta_adot_pred
+    df_test["uncertainty_adot"] = adot_unc
+
+    # Apply corrections to the PN baseline guesses
+    df_test["gpr_corrected_omega"] = df_test["spec_pn_guess_omega"] + delta_omega_pred
+    df_test["gpr_corrected_adot"] = df_test["spec_pn_guess_adot"] + delta_adot_pred
+
+    return df_test
+
+
+# CLI
+def parse_args():
+    """Parse the command-line arguments of the inference script."""
+    parser = argparse.ArgumentParser(
+        description="Run GPR inference on new BBH simulations.",
+        formatter_class=argparse.RawDescriptionHelpFormatter,
+        # The module docstring doubles as the usage examples shown by --help
+        epilog=__doc__,
+    )
+
+    # Input: either a single metadata string or a file of multiple strings
+    input_group = parser.add_mutually_exclusive_group(required=True)
+    input_group.add_argument(
+        "--metadata", type=str,
+        help="Single simulation metadata string in SpEC format.",
+    )
+    input_group.add_argument(
+        "--file", type=str,
+        help="Path to a text file with one metadata string per line.",
+    )
+
+    # Model paths (optional)
+    parser.add_argument(
+        "--model_omega", type=str, default="gpr_model_omega.pth",
+        help="Path to saved GPR model for Omega0 (default is gpr_model_omega.pth).",
+    )
+    parser.add_argument(
+        "--model_adot", type=str, default="gpr_model_adot.pth",
+        help="Path to saved GPR model for adot0 (default: gpr_model_adot.pth).",
+    )
+
+    # Output
+    parser.add_argument(
+        "--output", type=str, default=None,
+        help="Optional path to save results as a CSV file.",
+    )
+
+    return parser.parse_args()
+
+
+def main():
+    """Run inference for the CLI inputs, print the results, optionally save a CSV."""
+    args = parse_args()
+
+    # Load metadata strings
+    if args.metadata is not None:
+        metadata_strings = [args.metadata]
+    else:
+        with open(args.file, "r", encoding="utf-8") as f:
+            metadata_strings = [line.strip() for line in f if line.strip()]
+        if not metadata_strings:
+            raise SystemExit(f"No metadata strings found in {args.file}")
+
+    # Run inference
+    df_results = run_inference(metadata_strings, args.model_omega, args.model_adot)
+
+    # Print results
+    print("\nGPR Inference Results:")
+    print("=" * 60)
+    output_cols = [
+        "name",
+        "spec_pn_guess_omega", "gpr_corrected_omega", "uncertainty_omega",
+        "spec_pn_guess_adot", "gpr_corrected_adot", "uncertainty_adot",
+    ]
+    print(df_results[output_cols].to_string(index=False))
+
+    # Optional: save to CSV
+    if args.output:
+        df_results.to_csv(args.output, index=False)
+        print(f"\nResults saved to {args.output}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/test_gpr_loader.py b/test_gpr_loader.py
new file mode 100644
index 0000000..b3d12d4
--- /dev/null
+++ b/test_gpr_loader.py
@@ -0,0 +1,311 @@
+"""
+test_gpr_loader.py
+
+Unit tests for gpr_loader.py using Python built-in unittest.
+
+Run with:
+    python -m unittest test_gpr_loader.py -v
+
+No real model files or Julia/SXS dependencies are required - all heavy
+dependencies are mocked.
+"""
+
+import unittest
+import numpy as np
+import pandas as pd
+import sys
+from unittest.mock import patch, MagicMock
+
+# Mock heavy dependencies before importing the module being tested
+# This prevents Julia, sxs, qgrid, gpytorch, torch, etc. from being imported
+# during the test
+
+sys.modules["qgrid"] = MagicMock()
+sys.modules["qgridnext"] = MagicMock()
+sys.modules["sxs"] = MagicMock()
+sys.modules["sxs.julia"] = MagicMock()
+sys.modules["sxs.julia.PostNewtonian"] = MagicMock()
+sys.modules["GPR_library"] = MagicMock()
+sys.modules["plotly"] = MagicMock()
+sys.modules["plotly.graph_objects"] = MagicMock()
+sys.modules["torch"] = MagicMock()
+sys.modules["gpytorch"] = MagicMock()
+
+import gpr_loader
+
+
+# Test load_gpr_checkpoint
+class TestLoadGPRCheckpoint(unittest.TestCase):
+    def setUp(self):
+        """Build a fake checkpoint dictionary that imitates what torch.load would return."""
+        self.mock_checkpoint = {
+            "metadata": {
+                "input_features": [
+                    "mass_ratio",
+                    "initial_separation",
+                    "S1x", "S1y", "S1z",
+                    "S2x", "S2y", "S2z",
+                ],
+                "base_column": "spec_pn_guess_omega",
+            },
+            "model_state_dict": {},
+            "likelihood_state_dict": {},
+            "normalization": {
+                "input_mean": [0.0],
+                "input_std": [1.0],
+                "output_mean": 0.0,
+                "output_std": 1.0,
+            },
+        }
+
+    def _load(self, checkpoint=None):
+        """Call load_gpr_checkpoint with torch.load and both constructors patched.
+
+        Returns (result, mock_model, mock_likelihood).
+        """
+        if checkpoint is None:
+            checkpoint = self.mock_checkpoint
+        mock_model = MagicMock()
+        mock_likelihood = MagicMock()
+        with patch("gpr_loader.torch.load", return_value=checkpoint), \
+             patch("gpr_loader.gpr.GPRegressionModel", return_value=mock_model), \
+             patch("gpr_loader.gpytorch.likelihoods.GaussianLikelihood",
+                   return_value=mock_likelihood):
+            result = gpr_loader.load_gpr_checkpoint("fake_path.pth")
+        return result, mock_model, mock_likelihood
+
+    def test_returns_model_likelihood_meta(self):
+        """Test that load_gpr_checkpoint returns the model, likelihood, and metadata."""
+        (model, likelihood, meta), mock_model, mock_likelihood = self._load()
+        self.assertIs(model, mock_model)
+        self.assertIs(likelihood, mock_likelihood)
+        self.assertEqual(meta, self.mock_checkpoint["metadata"])
+
+    def test_normalization_is_set(self):
+        """Test that load_gpr_checkpoint calls set_normalization correctly."""
+        _, mock_model, _ = self._load()
+        mock_model.set_normalization.assert_called_once()
+        call_kwargs = mock_model.set_normalization.call_args.kwargs
+        for key in ("input_mean", "input_std", "output_mean", "output_std"):
+            self.assertIn(key, call_kwargs)
+
+    def test_model_loaded_in_eval_mode(self):
+        """Both model and likelihood should be put into eval mode."""
+        _, mock_model, mock_likelihood = self._load()
+        mock_model.eval.assert_called_once()
+        mock_likelihood.eval.assert_called_once()
+
+    def test_missing_checkpoint_entry_raises_key_error(self):
+        """Test that a checkpoint without a required entry raises KeyError."""
+        incomplete = dict(self.mock_checkpoint)
+        del incomplete["normalization"]
+        with self.assertRaises(KeyError):
+            self._load(incomplete)
+
+
+# Test run_inference
+class TestRunInference(unittest.TestCase):
+    def setUp(self):
+        """Set up test data for run_inference tests."""
+        self.sample_metadata_strings = [
+            "RunID=0111 ZwickyDays=176 q=8.0 chiA=-0.555366763183,-0.575801728447,"
+            "-0.00249659459858 chiB=-0.0516668737278,0.102310708412,0.791746644364"
+            " D0=14.5189208984 Omega0=0.0164280601103 adot0=-2.39687233838e-05",
+            "RunID=0271 ZwickyDays=138 q=7.16339643655 chiA=-0.247507893273,"
+            "0.666875889694,0.327541836441 chiB=-0.143315584105,0.537208136535,"
+            "-0.545488940029 D0=14.6237792969 Omega0=0.0162422365835 adot0=-2.43238483539e-05",
+            "RunID=0449 ZwickyDays=92 q=5.31391992299 chiA=-0.676142027445,0.39956096948"
+            ",-0.0775314966678 chiB=-0.339102219344,0.441481250187,-0.535621006618"
+            " D0=15.8556518555 Omega0=0.0146017859056 adot0=-2.369980335e-05",
+        ]
+        self.mock_meta = {
+            "input_features": [
+                "mass_ratio",
+                "initial_separation",
+                "S1x", "S1y", "S1z",
+                "S2x", "S2y", "S2z",
+            ],
+            "base_column": "spec_pn_guess_omega",
+        }
+        self.mock_dataframe = pd.DataFrame({
+            "name": ["test_0111", "test_0271", "test_0449"],
+            "initial_separation": [14.518921, 14.623779, 15.855652],
+            "spec_pn_guess_omega": [0.016428, 0.016242, 0.014602],
+            "spec_pn_guess_adot": [-0.000024, -0.000024, -0.000024],
+            "mass_ratio": [8.0, 7.163396, 5.313920],
+            "S1x": [-0.555367, -0.247508, -0.676142],
+            "S1y": [-0.575802, 0.666876, 0.399561],
+            "S1z": [-0.002497, 0.327542, -0.077531],
+            "S2x": [-0.051667, -0.143316, -0.339102],
+            "S2y": [0.102311, 0.537208, 0.441481],
+            "S2z": [0.791747, -0.545489, -0.535621],
+            "initial_mass1": [None, None, None],
+            "initial_mass2": [None, None, None],
+            "eccentricity": [None, None, None],
+        })
+        self.mock_model = MagicMock()
+        self.mock_likelihood = MagicMock()
+
+    def _run(self, adot_meta=None, **predict_kwargs):
+        """Call run_inference with loading, parsing, and prediction patched.
+
+        adot_meta overrides the metadata returned for the adot0 checkpoint;
+        predict_kwargs configure the predict_with_gpr_model mock.
+        Returns (df, mock_predict).
+        """
+        checkpoints = [
+            (self.mock_model, self.mock_likelihood, self.mock_meta),
+            (self.mock_model, self.mock_likelihood, adot_meta or self.mock_meta),
+        ]
+        with patch("gpr_loader.load_gpr_checkpoint", side_effect=checkpoints), \
+             patch("gpr_loader.gpr.parse_test_runs",
+                   return_value=self.mock_dataframe.copy()), \
+             patch("gpr_loader.gpr.predict_with_gpr_model",
+                   **predict_kwargs) as mock_predict:
+            df = gpr_loader.run_inference(
+                self.sample_metadata_strings, "omega.pth", "adot.pth"
+            )
+        return df, mock_predict
+
+    def test_output_columns_present(self):
+        """Test that run_inference adds the GPR corrections and uncertainty columns correctly."""
+        df, _ = self._run(return_value=(np.zeros(3), np.zeros(3)))
+        expected_columns = [
+            "delta_pred_omega", "uncertainty_omega",
+            "delta_pred_adot", "uncertainty_adot",
+            "gpr_corrected_omega", "gpr_corrected_adot",
+        ]
+        for col in expected_columns:
+            self.assertIn(col, df.columns)
+
+    def test_correction_applied(self):
+        """Test whether GPR-corrected values are in fact the sum of the PN baseline
+        + the predicted deltas.
+        """
+        delta_omega = np.array([-0.000560, -0.000539, -0.000439])
+        delta_adot = np.array([0.000714, -0.000123, -0.000261])  # shape (3, )
+
+        df, _ = self._run(side_effect=[
+            (delta_omega, np.zeros(3)),
+            (delta_adot, np.zeros(3)),
+        ])
+        np.testing.assert_allclose(
+            df["gpr_corrected_omega"].values,
+            self.mock_dataframe["spec_pn_guess_omega"].values + delta_omega,
+            rtol=1e-6,
+        )
+        np.testing.assert_allclose(
+            df["gpr_corrected_adot"].values,
+            self.mock_dataframe["spec_pn_guess_adot"].values + delta_adot,
+            rtol=1e-6,
+        )
+
+    def test_correct_number_rows(self):
+        """Test that the output DataFrame has the same number of rows as the input list."""
+        df, _ = self._run(return_value=(np.zeros(3), np.zeros(3)))
+        self.assertEqual(len(df), len(self.sample_metadata_strings))
+
+    def test_uncertainties_are_nonneg(self):
+        """Test that the uncertainty values are always non-negative."""
+        df, _ = self._run(
+            return_value=(np.zeros(3), np.array([0.0001, 0.0002, 0.0003]))
+        )
+        self.assertTrue((df["uncertainty_omega"] >= 0).all())
+        self.assertTrue((df["uncertainty_adot"] >= 0).all())
+
+    def test_features_used_correctly(self):
+        """Test that predict_with_gpr_model receives a matrix with one column per feature."""
+        _, mock_predict = self._run(return_value=(np.zeros(3), np.zeros(3)))
+        self.assertEqual(mock_predict.call_count, 2)
+        for call in mock_predict.call_args_list:
+            X_passed = call.args[0]
+            self.assertEqual(X_passed.shape[1], len(self.mock_meta["input_features"]))
+
+    def test_each_model_uses_its_own_features(self):
+        """Test that each model is fed the features listed in its own checkpoint."""
+        adot_meta = {"input_features": ["mass_ratio", "initial_separation"]}
+        _, mock_predict = self._run(
+            adot_meta=adot_meta, return_value=(np.zeros(3), np.zeros(3))
+        )
+        omega_call, adot_call = mock_predict.call_args_list
+        self.assertEqual(
+            omega_call.args[0].shape[1], len(self.mock_meta["input_features"])
+        )
+        self.assertEqual(adot_call.args[0].shape[1], len(adot_meta["input_features"]))
+
+
+# Test CLI argument parsing
+class TestArgParsing(unittest.TestCase):
+
+    def test_metadata_arg(self):
+        """Test that --metadata flag is stored and --file defaults to None."""
+        with patch("sys.argv", ["gpr_loader.py", "--metadata", "RunID=0111 q=8.0"]):
+            args = gpr_loader.parse_args()
+        self.assertEqual(args.metadata, "RunID=0111 q=8.0")
+        self.assertIsNone(args.file)
+
+    def test_file_arg(self):
+        """Test that --file flag is stored and --metadata defaults to None."""
+        with patch("sys.argv", ["gpr_loader.py", "--file", "my_runs.txt"]):
+            args = gpr_loader.parse_args()
+        self.assertEqual(args.file, "my_runs.txt")
+        self.assertIsNone(args.metadata)
+
+    def test_metadata_and_file_are_mutually_exclusive(self):
+        """Test that if user passes both --metadata and --file, it raises SystemExit."""
+        with patch("sys.argv", ["gpr_loader.py",
+                                "--metadata", "q=8.0",
+                                "--file", "runs.txt"]):
+            with self.assertRaises(SystemExit):
+                gpr_loader.parse_args()
+
+    def test_no_input_raises_system_exit(self):
+        """Test that if user doesn't pass --metadata or --file, it raises SystemExit."""
+        with patch("sys.argv", ["gpr_loader.py"]):
+            with self.assertRaises(SystemExit):
+                gpr_loader.parse_args()
+
+    def test_default_model_paths(self):
+        """Test that the default model paths are gpr_model_omega.pth and gpr_model_adot.pth."""
+        with patch("sys.argv", ["gpr_loader.py", "--metadata", "q=8.0"]):
+            args = gpr_loader.parse_args()
+        self.assertEqual(args.model_omega, "gpr_model_omega.pth")
+        self.assertEqual(args.model_adot, "gpr_model_adot.pth")
+
+    def test_custom_model_paths(self):
+        """Test that the custom model paths override the defaults."""
+        with patch("sys.argv", [
+            "gpr_loader.py",
+            "--metadata",
+            "q=8.0",
+            "--model_omega",
+            "custom_omega.pth",
+            "--model_adot",
+            "custom_adot.pth",
+        ]):
+            args = gpr_loader.parse_args()
+        self.assertEqual(args.model_omega, "custom_omega.pth")
+        self.assertEqual(args.model_adot, "custom_adot.pth")
+
+    def test_output_arg(self):
+        """Test that --output flag is stored correctly."""
+        with patch("sys.argv", ["gpr_loader.py",
+                                "--metadata", "q=8.0",
+                                "--output", "results.csv"]):
+            args = gpr_loader.parse_args()
+        self.assertEqual(args.output, "results.csv")
+
+    def test_output_defaults_to_none(self):
+        """Test that --output defaults to None if not provided."""
+        with patch("sys.argv", ["gpr_loader.py", "--metadata", "q=8.0"]):
+            args = gpr_loader.parse_args()
+        self.assertIsNone(args.output)
+
+    def test_module_docstring_provides_help_epilog(self):
+        """Test that the usage examples passed as the --help epilog exist."""
+        self.assertIsNotNone(gpr_loader.__doc__)
+        self.assertIn("--metadata", gpr_loader.__doc__)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)