Skip to content

Commit

Permalink
Improve RV algebra
Browse files Browse the repository at this point in the history
  • Loading branch information
matejak committed Jun 16, 2023
1 parent bc2cee9 commit 0d21097
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 3 deletions.
9 changes: 9 additions & 0 deletions estimage/entities/estimate.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,15 @@ def divide_by_gauss_pdf(self, num_samples, mean, stdev):
return ret

def _divide_by_gauss_general_estimate(self, dom, mean, stdev):
    """Divide this estimate by a Gaussian over domain *dom*.

    Dispatches between the degenerate case (zero stdev, i.e. division by a
    point mass at *mean*) and the general Gaussian case.
    """
    if stdev != 0:
        return self._divide_by_general_gauss_general_estimate(dom, mean, stdev)
    return self._divide_by_point_gauss_general_estimate(dom, mean)

def _divide_by_point_gauss_general_estimate(self, dom, mean):
    """Handle division by a zero-width Gaussian (a point mass at *mean*).

    Returns the PERT values sampled at ``len(dom)`` points.
    NOTE(review): *mean* is unused here — presumably the caller rescales the
    domain by *mean* separately; confirm against the dispatch site.
    """
    return self.get_pert(len(dom))[1]

def _divide_by_general_gauss_general_estimate(self, dom, mean, stdev):
values = np.zeros_like(dom)
inner_resolution = len(dom) * 2
inner_domain, pert = self.get_pert(inner_resolution)
Expand Down
17 changes: 17 additions & 0 deletions estimage/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,20 @@ def norm_pdf(values, dx):
norming_factor = values.sum() * dx
if norming_factor:
values[:] /= norming_factor


def get_random_variable(dom, values):
    """Create a frozen scipy random variable from a sampled density.

    *dom* holds the sample points and *values* the (possibly unnormalized)
    density at those points.  The pdf, cdf and ppf of the returned variable
    are linear interpolations over those samples; the pdf is zero outside
    ``[dom[0], dom[-1]]``.
    """
    import numpy as np
    import scipy as sp

    # Normalized cumulative mass, reused for both the cdf and its inverse.
    cumulative = np.cumsum(values) / sum(values)

    class _InterpolatedRv(sp.stats.rv_continuous):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Restrict the support to the sampled domain.
            self.a = dom[0]
            self.b = dom[-1]

            # Instance attributes shadow the rv_continuous methods, so scipy
            # evaluates these interpolators directly.
            self._pdf = sp.interpolate.interp1d(dom, values, kind="linear", bounds_error=False, fill_value=0)
            self._cdf = sp.interpolate.interp1d(dom, cumulative, kind="linear")
            self._ppf = sp.interpolate.interp1d(cumulative, dom, kind="linear")

    return _InterpolatedRv()(loc=0, scale=1)
24 changes: 21 additions & 3 deletions tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import estimage.data as tm
from estimage.entities import estimate
from estimage import utilities


@pytest.fixture
Expand Down Expand Up @@ -433,18 +434,21 @@ def assert_sampling_corresponds_to_pdf(domain, generated, predicted, relative_di
predicted /= np.max(predicted)

high_diff = np.quantile(np.abs(predicted - histogram), 0.95)
# plot_diffs(domain, histogram, predicted)
if high_diff > relative_diff:
plot_diffs(domain, histogram, predicted)
assert high_diff < relative_diff


def plot_diffs(dom, histogram, predicted):
    """Plot simulated vs predicted distributions and save them to testfail.png.

    Uses the non-interactive Agg backend so the plot can be produced in a
    headless test environment.
    """
    import matplotlib
    import matplotlib.pyplot as plt
    matplotlib.use("Agg")
    figure, axes = plt.subplots()
    axes.plot(dom, histogram, label="simulation")
    axes.plot(dom, predicted, label="prediction")
    axes.legend()
    axes.grid()
    # With the Agg backend show() is a no-op; the file on disk is the output.
    plt.show()
    figure.savefig("testfail.png")


def test_rv_algebra_addition():
Expand Down Expand Up @@ -484,7 +488,7 @@ def test_rv_algebra_gauss_division():

def test_rv_algebra_division():
num_trials = 100000
num_samples = 100
num_samples = 150

gauss_mean = 1.5
gauss_std = 0.4
Expand All @@ -508,6 +512,20 @@ def test_rv_algebra_division():
generated_pert = e1.pert_rvs(num_trials)
generated_normal = sp.stats.norm.rvs(loc=gauss_mean, scale=gauss_std, size=num_trials)
assert_sampling_corresponds_to_pdf(dom, generated_pert / generated_normal, pdf)
random_variable = utilities.get_random_variable(dom, pdf)
# TODO: The following test is somewhat fragile. Perhaps there is a sampling difficulty of the PDF?
assert_sampling_corresponds_to_pdf(dom, random_variable.rvs(num_trials * 5), pdf)

old_dom, old_pdf = e1.get_pert(num_samples)
dom, pdf = e1.divide_by_gauss_pdf(num_samples, gauss_mean, 0)
assert (old_dom / dom).mean() == pytest.approx(gauss_mean)
assert (old_pdf - pdf).mean() == pytest.approx(0)

dom, pdf = e1.divide_by_gauss_pdf(num_samples, gauss_mean, 0.0001)
assert (old_dom / dom).mean() == pytest.approx(gauss_mean, rel=1e-2)
assert (old_pdf - pdf).mean() == pytest.approx(0)
generated_pert = e1.pert_rvs(num_trials)
assert_sampling_corresponds_to_pdf(dom, generated_pert / gauss_mean, pdf)

gauss_mean = 2
gauss_std = 0.4
Expand Down

0 comments on commit 0d21097

Please sign in to comment.