Skip to content

Commit

Permalink
Smart sampling (#769)
Browse files Browse the repository at this point in the history
  • Loading branch information
HDembinski authored Jul 27, 2022
1 parent 9ce3048 commit 83e60f8
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 16 deletions.
35 changes: 22 additions & 13 deletions src/iminuit/cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
make_func_code,
merge_signatures,
PerformanceWarning,
_smart_sampling,
)
import numpy as np
from collections.abc import Sequence
Expand Down Expand Up @@ -673,7 +674,7 @@ def __init__(self, data, model: _tp.Callable, verbose: int, log: bool):
self._log = log
super().__init__(describe(model)[1:], _norm(data), verbose)

def visualize(self, args: _ArrayLike, model_points: int = 50):
def visualize(self, args: _ArrayLike, model_points: int = 0):
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -684,22 +685,26 @@ def visualize(self, args: _ArrayLike, model_points: int = 50):
args : array-like
Parameter values.
model_points : int, optional
How many points to use to draw the model. Default is 50.
How many points to use to draw the model. Default is 0, in this case
an smart sampling algorithm selects the number of points.
"""
from matplotlib import pyplot as plt

if self.data.ndim > 1:
raise ValueError("visualize is not implemented for multi-dimensional data")

# TODO
# - use log-binning if data spans over many orders of magnitude
# - make this configurable
n, xe = np.histogram(self.data, bins=50)
cx = 0.5 * (xe[1:] + xe[:-1])
plt.errorbar(cx, n, n**0.5, fmt="ok")
xm = np.linspace(xe[0], xe[-1], model_points)
if model_points > 0:
if xe[0] > 0 and xe[-1] / xe[0] > 1e2:
xm = np.geomspace(xe[0], xe[-1], model_points)
else:
xm = np.linspace(xe[0], xe[-1], model_points)
ym = self.scaled_pdf(xm, *args)
else:
xm, ym = _smart_sampling(lambda x: self.scaled_pdf(x, *args), xe[0], xe[-1])
dx = xe[1] - xe[0]
ym = self.scaled_pdf(xm, *args)
plt.fill_between(xm, 0, ym * dx, fc="C0")


Expand Down Expand Up @@ -1413,7 +1418,7 @@ def _call(self, args):
ym = _normalize_model_output(ym)
return self._cost(y, yerror, ym)

def visualize(self, args: _ArrayLike, model_points: int = 50):
def visualize(self, args: _ArrayLike, model_points: int = 0):
"""
Visualize data and model agreement (requires matplotlib).
Expand All @@ -1425,7 +1430,8 @@ def visualize(self, args: _ArrayLike, model_points: int = 50):
Parameter values.
model_points : int, optional
How many points to use to draw the model. Default is 50.
How many points to use to draw the model. Default is 0, in this case
an smart sampling algorithm selects the number of points.
"""
from matplotlib import pyplot as plt

Expand All @@ -1437,11 +1443,14 @@ def visualize(self, args: _ArrayLike, model_points: int = 50):

x, y, ye = self._masked.T
plt.errorbar(x, y, ye, fmt="ok")
if x[0] > 0 and x[-1] / x[0] > 1e2:
xm = np.geomspace(x[0], x[-1], model_points)
if model_points > 0:
if x[0] > 0 and x[-1] / x[0] > 1e2:
xm = np.geomspace(x[0], x[-1], model_points)
else:
xm = np.linspace(x[0], x[-1], model_points)
ym = self.model(xm, *args)
else:
xm = np.linspace(x[0], x[-1], model_points)
ym = self.model(xm, *args)
xm, ym = _smart_sampling(lambda x: self.model(x, *args), x[0], x[-1])
plt.plot(xm, ym)


Expand Down
38 changes: 38 additions & 0 deletions src/iminuit/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -1446,3 +1446,41 @@ def _show_inline_matplotlib_plots():
or mpl.get_backend() == "module://matplotlib_inline.backend_inline"
):
flush_figures() # pragma: no cover


def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3):
x = np.linspace(xmin, xmax, start)
ynew = f(x)
ymin = np.min(ynew)
ymax = np.max(ynew)
y = {xi: yi for (xi, yi) in zip(x, ynew)}
a = x[:-1]
b = x[1:]
while len(a):
if len(y) > 10000:
warnings.warn("Too many points", RuntimeWarning) # pragma: no cover
break # pragma: no cover
xnew = 0.5 * (a + b)
ynew = f(xnew)
ymin = min(ymin, np.min(ynew))
ymax = max(ymax, np.max(ynew))
for xi, yi in zip(xnew, ynew):
y[xi] = yi
yint = 0.5 * (
np.fromiter((y[ai] for ai in a), float)
+ np.fromiter((y[bi] for bi in b), float)
)
dy = np.abs(ynew - yint)

mask = dy > tol * (ymax - ymin)

# intervals which do not pass interpolation test
a = a[mask]
b = b[mask]
xnew = xnew[mask]
a = np.append(a, xnew)
b = np.append(xnew, b)

xy = list(y.items())
xy.sort()
return np.transpose(xy)
12 changes: 9 additions & 3 deletions tests/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,12 @@ def test_UnbinnedNLL_properties(log):
@pytest.mark.parametrize("log", (False, True))
def test_UnbinnedNLL_visualize(log):
c = UnbinnedNLL([1, 2], norm_logpdf if log else norm_pdf, log=log)
c.visualize((1, 2))
c.visualize((1, 2)) # auto-sampling
c.visualize((1, 2), model_points=10) # linear spacing

# trigger log-spacing
c = UnbinnedNLL([1, 1000], norm_logpdf if log else norm_pdf, log=log)
c.visualize((1, 2), model_points=10)


@pytest.mark.skipif(not scipy_stats_available, reason="scipy.stats is needed")
Expand Down Expand Up @@ -746,11 +751,12 @@ def line(x, a, b):

c = LeastSquares([1, 2], [2, 3], 0.1, line)

c.visualize((1, 2))
c.visualize((1, 2)) # auto-sampling
c.visualize((1, 2), model_points=10) # linear spacing

# trigger use of log-spacing
c = LeastSquares([1, 2000], [2, 3], 0.1, line)
c.visualize((1, 2))
c.visualize((1, 2), model_points=10)


@pytest.mark.skipif(not matplotlib_available, reason="matplotlib is needed")
Expand Down
14 changes: 14 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -709,3 +709,17 @@ def test_histogram_segments(mask_expected):
masked = np.arange(len(mask))[np.array(mask)]
segments = util._histogram_segments(mask, xe, masked)
assert_equal([s[0] for s in segments], expected)


@pytest.mark.parametrize(
"fn_expected", ((lambda x: x, 15), (lambda x: x**11, 40), (np.log, 80))
)
def test_smart_sampling_1(fn_expected):
fn, expected = fn_expected
x, y = util._smart_sampling(fn, 1e-10, 5)
assert len(y) < expected


def test_smart_sampling_2():
with pytest.warns(RuntimeWarning):
util._smart_sampling(np.log, 1e-10, 1, tol=1e-10)

0 comments on commit 83e60f8

Please sign in to comment.