Skip to content

Commit

Permalink
Merge pull request #25 from ColmTalbot/general-generator-access
Browse files Browse the repository at this point in the history
Add general generator interface
  • Loading branch information
ColmTalbot authored Mar 1, 2024
2 parents e9b50c4 + 0924068 commit 406a71d
Show file tree
Hide file tree
Showing 10 changed files with 146 additions and 76 deletions.
10 changes: 2 additions & 8 deletions .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,14 @@ jobs:
run: git fetch --prune --unshallow
- uses: s-weigand/setup-conda@v1
with:
python-version: 3.9
python-version: 3.11

- name: Install dependencies
run: |
sudo apt-get install texlive-latex-extra texlive-fonts-recommended dvipng cm-super
conda install pip setuptools
conda install flake8 pytest-cov
conda install -c conda-forge numpy scipy sympy
conda update -c conda-forge numpy scipy sympy
conda install -c conda-forge gwsurrogate python-lalsimulation sxs
conda install -c conda-forge matplotlib "basemap>=1.3.6"
conda install -c conda-forge --file requirements.txt --file optional_requirements.txt --file pages_requirements.txt
python -m pip install nrsur7dq2
conda install -c conda-forge ipykernel jupyter ipython
conda install -c conda-forge ipython_genutils jinja2 nbsphinx numpydoc pandoc pygments sphinx sphinx_rtd_theme
- name: Install gwmemory
run: |
pip install .
Expand Down
10 changes: 1 addition & 9 deletions .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,22 +38,14 @@ jobs:
run: |
conda install pip setuptools
conda install flake8 pytest-cov
conda install -c conda-forge numpy pandas scipy
conda install -c conda-forge gwsurrogate python-lalsimulation
conda install -c conda-forge sympy sxs
conda install -c conda-forge --file requirements.txt --file optional_requirements.txt
python -m pip install nrsur7dq2
- name: Install gwmemory
run: |
pip install .
- name: List installed
run: |
conda list
# - name: Run pre-commit checks
# run: |
# pre-commit run --all-files --verbose --show-diff-on-failure
# jupyter nbconvert --clear-output --inplace examples/*.ipynb
# codespell examples/*.ipynb -L "hist"
# git reset --hard
- name: Test with pytest
run: |
pytest --cov gwmemory -ra --color yes --cov-report=xml --junitxml=pytest.xml
Expand Down
20 changes: 17 additions & 3 deletions examples/Comparison.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
"outputs": [],
"source": [
"print(\"GWMemory time:\")\n",
"%time h_mem, times = tdm(h_lm=h, times=t)\n",
"%time h_mem, times = tdm(h_lm=h, times=t, l_max=4)\n",
"print(\"SXS time:\")\n",
"%time h_mem_sxs, times_sxs = sxs_memory(h, t)"
]
Expand All @@ -91,7 +91,10 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Now we plot the various memory modes along with the mismatches between the waveforms obtained with both methods."
"Now we plot the various memory modes along with the mismatches between the waveforms obtained with both methods.\n",
"\n",
"There are some differences in strongly subdominant modes, this is likely due to numerical error accumulating for small amplitude modes.\n",
"We explicitly skip these in the plotting below."
]
},
{
Expand All @@ -101,7 +104,11 @@
"outputs": [],
"source": [
"modes = set(h_mem.keys()).intersection(h_mem_sxs.keys())\n",
"fig, axes = plt.subplots(nrows=7, ncols=3, sharex=True, figsize=(20, 16))\n",
"for mode in modes.copy():\n",
" if max(abs(h_mem[mode])) < 1e-15:\n",
" modes.remove(mode)\n",
"\n",
"fig, axes = plt.subplots(nrows=4, ncols=3, sharex=True, figsize=(20, 16))\n",
"for ii, mode in enumerate(modes):\n",
" gwmem = h_mem[mode]\n",
" sxsmem = h_mem_sxs[mode]\n",
Expand All @@ -126,6 +133,13 @@
"plt.show()\n",
"plt.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {},
Expand Down
44 changes: 10 additions & 34 deletions gwmemory/gwmemory.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,49 +89,25 @@ def time_domain_memory(
times, array
Time series corresponding to the memory waveform.
"""
if h_lm is not None and times is not None:
wave = waveforms.MemoryGenerator(name=model, h_lm=h_lm, times=times)
elif "NRSur" in model or "NRHybSur" in model:
all_keys = inspect.signature(waveforms.Surrogate).parameters.keys()
model_kwargs = {key: kwargs[key] for key in all_keys if key in kwargs}
wave = waveforms.Surrogate(
q=q,
from .waveforms import memory_generator

if h_lm is not None and times is not None and model is None:
model = "base"

kwargs.update(
dict(
name=model,
total_mass=total_mass,
spin_1=spin_1,
spin_2=spin_2,
distance=distance,
times=times,
**model_kwargs,
)
elif "EOBNR" in model or "Phenom" in model:
all_keys = inspect.signature(waveforms.Approximant).parameters.keys()
model_kwargs = {key: kwargs[key] for key in all_keys if key in kwargs}
wave = waveforms.Approximant(
h_lm=h_lm,
q=q,
name=model,
total_mass=total_mass,
spin_1=spin_1,
spin_2=spin_2,
distance=distance,
times=times,
**model_kwargs,
)
elif model == "MWM":
all_keys = inspect.signature(waveforms.MWM).parameters.keys()
model_kwargs = {key: kwargs[key] for key in all_keys if key in kwargs}
wave = waveforms.MWM(
q=q,
name=model,
total_mass=total_mass,
distance=distance,
times=times,
**model_kwargs,
)
else:
print(f"Model {model} unknown")
return None
)

wave = memory_generator(model=model, **kwargs)
all_keys = inspect.signature(wave.time_domain_memory).parameters.keys()
function_kwargs = {key: kwargs[key] for key in all_keys if key in kwargs}
h_mem, times = wave.time_domain_memory(inc=inc, phase=phase, **function_kwargs)
Expand Down
78 changes: 73 additions & 5 deletions gwmemory/waveforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,74 @@
import inspect

from .base import MemoryGenerator # isort: skip
from . import approximant, mwm, nr, surrogate
from .approximant import Approximant
from .mwm import MWM
from .nr import SXSNumericalRelativity
from .surrogate import Surrogate
from . import approximant, mwm, nr, surrogate # isort: skip
from .approximant import Approximant # isort: skip
from .mwm import MWM # isort: skip
from .nr import SXSNumericalRelativity # isort: skip
from .surrogate import Surrogate # isort: skip


GENERATORS = dict(
base=MemoryGenerator,
EOBNR=Approximant,
MWM=MWM,
NRHybSur=Surrogate,
NRSur=Surrogate,
Phenom=Approximant,
SXS=SXSNumericalRelativity,
)


def memory_generator(model, **kwargs):
"""
Create a memory generator from any of the registered classes.
Parameters
==========
model: str
The name of the model to use.
kwargs:
Arguments to pass to the :code:`__init__` method of the generator.
Returns
=======
MemoryGenerator
The memory generator instance.
"""
cls_ = None
for key, value in GENERATORS.items():
if key in model:
cls_ = value
if cls_ is None:
raise ValueError(
f"Unknown waveform generator {model}. "
"Should match one of {GENERATORS.keys()}"
)
all_keys = inspect.signature(cls_).parameters.keys()
model_kwargs = {key: kwargs[key] for key in all_keys if key in kwargs}
return cls_(**model_kwargs)


def register_generator(model: str, generator: MemoryGenerator):
"""
Register a new memory generator.
If you have implemented a new memory generator (that matches the API)
here, you can use this to automatically have it be found by
:code:`gwmemory.time_domain_memory`.
Parameters
==========
model: str
The name to register the model as.
generator: MemoryGenerator
The new model to register.
"""
import warnings

if model in GENERATORS:
warnings.warn(f"Overwriting previously registered model {model}.")
else:
for key in GENERATORS:
if key in model:
warnings.warn(f"Collision with existing model {key}.")
GENERATORS[model] = generator
10 changes: 1 addition & 9 deletions gwmemory/waveforms/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import warnings
from typing import Tuple

import numpy as np
Expand All @@ -10,7 +9,7 @@
from ..utils import CC, GG, MPC, SOLAR_MASS, combine_modes


class MemoryGenerator(object):
class MemoryGenerator:
def __init__(self, name: str, h_lm: dict, times: np.ndarray, l_max: int = 4):

self.name = name
Expand Down Expand Up @@ -53,7 +52,6 @@ def time_domain_memory(
inc: float = None,
phase: float = None,
modes: list = None,
gamma_lmlm: dict = None,
) -> Tuple[dict, np.ndarray]:
"""
Calculate the spherical harmonic decomposition of the nonlinear
Expand All @@ -70,9 +68,6 @@ def time_domain_memory(
modes: list
The modes to consider when computing the memory. By default all
available modes will be used.
gamma_lmlm: dict, deprecated
Dictionary of arrays defining the angular dependence of the
different memory modes, these are now computed/cached on the fly.
Return
------
Expand Down Expand Up @@ -100,9 +95,6 @@ def time_domain_memory(
except KeyError:
pass

if gamma_lmlm is not None:
warnings.warn(f"The gamma_lmlm argument is deprecated and will be removed.")

# constant terms in SI units
const = 1 / 4 / np.pi
if self.distance is not None:
Expand Down
6 changes: 4 additions & 2 deletions optional_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
lalsuite
NRSur7dq2
python-lalsimulation
gwsurrogate
sxs
pytest
12 changes: 9 additions & 3 deletions pages_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
basemap>=1.3.6
ipykernel
ipython
ipython_genutils
jinja2
sphinx
pygments
numpydoc
nbsphinx
numpydoc
pandoc
pygments
pytest-cov
sphinx
sphinx_rtd_theme
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
numpy
scipy
pandas
deepdish
sympy
29 changes: 28 additions & 1 deletion test/waveform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@

import gwmemory
from gwmemory import frequency_domain_memory, time_domain_memory
from gwmemory.waveforms import Approximant, Surrogate, SXSNumericalRelativity
from gwmemory.waveforms import (
GENERATORS,
Approximant,
Surrogate,
SXSNumericalRelativity,
memory_generator,
register_generator,
)

TEST_MODELS = [
"IMRPhenomD",
Expand Down Expand Up @@ -112,6 +119,10 @@ def test_memory_matches_sxs():
modes = set(h_mem.keys()).intersection(h_mem_sxs.keys())

for ii, mode in enumerate(modes):
# skip small amplitude modes as there is some kind of
# numerical noise issue
if max(abs(h_mem[mode])) < 1e-10:
continue
gwmem = h_mem[mode]
sxsmem = h_mem_sxs[mode]
overlap = (
Expand All @@ -121,3 +132,19 @@ def test_memory_matches_sxs():
)
assert overlap.real > 1 - 1e-8
assert abs(overlap.imag) < 1e-5


def test_register_generator():
with pytest.warns() as record:
register_generator("test", Surrogate)
register_generator("test", Surrogate)
register_generator("test2", Surrogate)
assert len(record) == 2
assert str(record[0].message).startswith("Overwriting")
assert str(record[1].message).startswith("Collision")
assert GENERATORS["test"] == Surrogate


def test_unknown_model_raises():
with pytest.raises(ValueError):
memory_generator("unknown")

0 comments on commit 406a71d

Please sign in to comment.