diff --git a/.gitignore b/.gitignore index 8fc080c3..4c715ad7 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,9 @@ /root-* /dev /venv +/out +/.venv +/.vs /build /dist diff --git a/doc/notebooks/interactive.ipynb b/doc/notebooks/interactive.ipynb index 1e806405..996cb797 100644 --- a/doc/notebooks/interactive.ipynb +++ b/doc/notebooks/interactive.ipynb @@ -393,7 +393,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.4" + "version": "3.12.6" }, "vscode": { "interpreter": { diff --git a/src/iminuit/ipywidget.py b/src/iminuit/ipywidget.py new file mode 100644 index 00000000..f56bdeef --- /dev/null +++ b/src/iminuit/ipywidget.py @@ -0,0 +1,294 @@ +"""Interactive fitting widget for Jupyter notebooks.""" + +import warnings +import numpy as np +from typing import Dict, Any, Callable +import sys + +with warnings.catch_warnings(): + # ipywidgets produces deprecation warnings through use of internal APIs :( + warnings.simplefilter("ignore") + try: + import ipywidgets as widgets + from ipywidgets.widgets.interaction import show_inline_matplotlib_plots + from IPython.display import clear_output + from matplotlib import pyplot as plt + except ModuleNotFoundError as e: + e.msg += ( + "\n\nPlease install ipywidgets, IPython, and matplotlib to " + "enable interactive" + ) + raise + + +def make_widget( + minuit: Any, + plot: Callable[..., None], + kwargs: Dict[str, Any], + raise_on_exception: bool, +): + """Make interactive fitting widget.""" + # Implementations makes heavy use of closures, + # we frequently use variables which are defined + # near the end of the function. + original_values = minuit.values[:] + original_limits = minuit.limits[:] + + def plot_with_frame(from_fit, report_success): + trans = plt.gca().transAxes + try: + with warnings.catch_warnings(): + minuit.visualize(plot, **kwargs) + except Exception: + if raise_on_exception: + raise + + import traceback + + plt.figtext( + 0, + 0.5, + traceback.format_exc(limit=-1), + fontdict={"family": "monospace", "size": "x-small"}, + va="center", + color="r", + backgroundcolor="w", + wrap=True, + ) + return + + fval = minuit.fmin.fval if from_fit else minuit._fcn(minuit.values) + plt.text( + 0.05, + 1.05, + f"FCN = {fval:.3f}", + transform=trans, + fontsize="x-large", + ) + if from_fit and report_success: + plt.text( + 0.95, + 1.05, + f"{'success' if minuit.valid and minuit.accurate else 'FAILURE'}", + transform=trans, + fontsize="x-large", + ha="right", + ) + + def fit(): + if algo_choice.value == "Migrad": + minuit.migrad() + elif algo_choice.value == "Scipy": + minuit.scipy() + elif algo_choice.value == "Simplex": + minuit.simplex() + return False + else: + assert False # pragma: no cover, should never happen + return True + + class OnParameterChange: + # Ugly implementation notes: + # We want the plot when the user moves the slider widget, but not when + # we update the slider value manually from our code. Unfortunately, + # the latter also calls OnParameterChange, which leads to superfluous plotting. + # I could not find a nice way to prevent that (and I tried many), so as a workaround + # we optionally skip a number of calls, when the slider is updated. + def __init__(self, skip: int = 0): + self.skip = skip + + def __call__(self, change: Dict[str, Any] = {}): + if self.skip > 0: + self.skip -= 1 + return + + from_fit = change.get("from_fit", False) + report_success = change.get("report_success", False) + if not from_fit: + for i, x in enumerate(parameters): + minuit.values[i] = x.slider.value + + if any(x.fit.value for x in parameters): + saved = minuit.fixed[:] + for i, x in enumerate(parameters): + minuit.fixed[i] = not x.fit.value + from_fit = True + report_success = do_fit(None) + minuit.fixed = saved + + # Implementation like in ipywidegts.interaction.interactive_output + with out: + clear_output(wait=True) + plot_with_frame(from_fit, report_success) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + show_inline_matplotlib_plots() + + def do_fit(change): + report_success = fit() + for i, x in enumerate(parameters): + x.reset(minuit.values[i]) + if change is None: + return report_success + OnParameterChange()({"from_fit": True, "report_success": report_success}) + + def on_update_button_clicked(change): + for x in parameters: + x.slider.continuous_update = not x.slider.continuous_update + + def on_reset_button_clicked(change): + minuit.reset() + minuit.values = original_values + minuit.limits = original_limits + for i, x in enumerate(parameters): + x.reset(minuit.values[i], minuit.limits[i]) + OnParameterChange()() + + class Parameter(widgets.HBox): + def __init__(self, minuit, par): + val = minuit.values[par] + vmin, vmax = minuit.limits[par] + step = _guess_initial_step(val, vmin, vmax) + vmin2 = vmin if np.isfinite(vmin) else val - 100 * step + vmax2 = vmax if np.isfinite(vmax) else val + 100 * step + + tlabel = widgets.Label(par, layout=widgets.Layout(width=f"{longest_par}em")) + + tmin = widgets.BoundedFloatText( + _round(vmin2), + min=_make_finite(vmin), + max=vmax2, + step=1e-1 * (vmax2 - vmin2), + layout=widgets.Layout(width="4.1em"), + ) + + tmax = widgets.BoundedFloatText( + _round(vmax2), + min=vmin2, + max=_make_finite(vmax), + step=1e-1 * (vmax2 - vmin2), + layout=widgets.Layout(width="4.1em"), + ) + + self.slider = widgets.FloatSlider( + val, + min=vmin2, + max=vmax2, + step=step, + continuous_update=True, + readout_format=".3g", + layout=widgets.Layout(min_width="50%"), + ) + self.slider.observe(OnParameterChange(), "value") + + def on_min_change(change): + self.slider.min = change["new"] + tmax.min = change["new"] + lim = minuit.limits[par] + minuit.limits[par] = (self.slider.min, lim[1]) + + def on_max_change(change): + self.slider.max = change["new"] + tmin.max = change["new"] + lim = minuit.limits[par] + minuit.limits[par] = (lim[0], self.slider.max) + + tmin.observe(on_min_change, "value") + tmax.observe(on_max_change, "value") + + self.fix = widgets.ToggleButton( + minuit.fixed[par], + description="Fix", + tooltip="Fix", + layout=widgets.Layout(width="3.1em"), + ) + + self.fit = widgets.ToggleButton( + False, + description="Fit", + tooltip="Fit", + layout=widgets.Layout(width="3.5em"), + ) + + def on_fix_toggled(change): + minuit.fixed[par] = change["new"] + if change["new"]: + self.fit.value = False + + def on_fit_toggled(change): + self.slider.disabled = change["new"] + if change["new"]: + self.fix.value = False + OnParameterChange()() + + self.fix.observe(on_fix_toggled, "value") + self.fit.observe(on_fit_toggled, "value") + super().__init__([tlabel, tmin, self.slider, tmax, self.fix, self.fit]) + + def reset(self, value, limits=None): + self.slider.unobserve_all("value") + self.slider.value = value + if limits: + self.slider.min, self.slider.max = limits + # Installing the observer actually triggers a notification, + # we skip it. See notes in OnParameterChange. + self.slider.observe(OnParameterChange(1), "value") + + longest_par = max(len(par) for par in minuit.parameters) + parameters = [Parameter(minuit, par) for par in minuit.parameters] + + button_layout = widgets.Layout(max_width="8em") + + fit_button = widgets.Button( + description="Fit", + button_style="primary", + layout=button_layout, + ) + fit_button.on_click(do_fit) + + update_button = widgets.ToggleButton( + True, + description="Continuous", + layout=button_layout, + ) + update_button.observe(on_update_button_clicked) + + reset_button = widgets.Button( + description="Reset", + button_style="danger", + layout=button_layout, + ) + reset_button.on_click(on_reset_button_clicked) + + algo_choice = widgets.Dropdown( + options=["Migrad", "Scipy", "Simplex"], + value="Migrad", + layout=button_layout, + ) + + ui = widgets.VBox( + [ + widgets.HBox([fit_button, update_button, reset_button, algo_choice]), + widgets.VBox(parameters), + ] + ) + out = widgets.Output() + OnParameterChange()() + return widgets.HBox([out, ui]) + + +def _make_finite(x: float) -> float: + sign = -1 if x < 0 else 1 + if abs(x) == np.inf: + return sign * sys.float_info.max + return x + + +def _guess_initial_step(val: float, vmin: float, vmax: float) -> float: + if np.isfinite(vmin) and np.isfinite(vmax): + return 1e-2 * (vmax - vmin) + return 1e-2 + + +def _round(x: float) -> float: + return float(f"{x:.1g}") diff --git a/src/iminuit/minuit.py b/src/iminuit/minuit.py index 8d61a09a..63d6deca 100644 --- a/src/iminuit/minuit.py +++ b/src/iminuit/minuit.py @@ -2347,211 +2347,10 @@ def interactive( -------- Minuit.visualize """ - with warnings.catch_warnings(): - # ipywidgets produces deprecation warnings through use of internal APIs :( - warnings.simplefilter("ignore") - try: - from ipywidgets import ( - HBox, - VBox, - Output, - FloatSlider, - Button, - ToggleButton, - Layout, - Dropdown, - ) - from ipywidgets.widgets.interaction import show_inline_matplotlib_plots - from IPython.display import clear_output - from matplotlib import pyplot as plt - except ModuleNotFoundError as e: - e.msg += ( - "\n\nPlease install ipywidgets, IPython, and matplotlib to " - "enable interactive" - ) - raise - - plot = self._visualize(plot) - - def plot_with_frame(args, from_fit, report_success): - trans = plt.gca().transAxes - try: - with warnings.catch_warnings(): - if self._fcn._array_call: - plot([args], **kwargs) # prevent unpacking of array - else: - plot(args, **kwargs) - except Exception: - if raise_on_exception: - raise - - import traceback - - plt.figtext( - 0.01, - 0.5, - traceback.format_exc(), - ha="left", - va="center", - transform=trans, - color="r", - ) - return - if from_fit: - fval = self.fmin.fval - else: - fval = self._fcn(args) - plt.text( - 0.05, - 1.05, - f"FCN = {fval:.3f}", - transform=trans, - fontsize="x-large", - ) - if from_fit and report_success: - plt.text( - 0.95, - 1.05, - f"{'success' if self.valid and self.accurate else 'FAILURE'}", - transform=trans, - fontsize="x-large", - ha="right", - ) - - class ParameterBox(HBox): - def __init__(self, par, val, min, max, step, fix): - self.par = par - self.slider = FloatSlider( - val, - min=a, - max=b, - step=step, - description=par, - continuous_update=True, - readout_format=".4g", - layout=Layout(min_width="70%"), - ) - self.fix = ToggleButton( - fix, description="Fix", layout=Layout(width="3.1em") - ) - self.opt = ToggleButton( - False, description="Opt", layout=Layout(width="3.5em") - ) - self.opt.observe(self.on_opt_toggled, "value") - super().__init__([self.slider, self.fix, self.opt]) - - def on_opt_toggled(self, change): - self.slider.disabled = self.opt.value - on_slider_change(None) - - def fit(): - if algo_choice.value == "Migrad": - self.migrad() - elif algo_choice.value == "Scipy": - self.scipy() - elif algo_choice.value == "Simplex": - self.simplex() - return False - else: - assert False # pragma: no cover, should never happen - return True - - def on_slider_change(change): - if out.block: - return - args = [x.slider.value for x in parameters] - from_fit = False - report_success = False - if any(x.opt.value for x in parameters): - save = self.fixed[:] - self.fixed = [not x.opt.value for x in parameters] - self.values = args - report_success = fit() - args = self.values[:] - out.block = True - for x, val in zip(parameters, args): - x.slider.value = val - out.block = False - self.fixed = save - from_fit = True - with out: - clear_output(wait=True) - plot_with_frame(args, from_fit, report_success) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - show_inline_matplotlib_plots() - - def on_fit_button_clicked(change): - for x in parameters: - self.values[x.par] = x.slider.value - self.fixed[x.par] = x.fix.value - report_success = fit() - out.block = True - for x in parameters: - val = self.values[x.par] - if val < x.slider.min: - x.slider.min = val - elif val > x.slider.max: - x.slider.max = val - x.slider.value = val - out.block = False - with out: - clear_output(wait=True) - plot_with_frame(self.values, True, report_success) - show_inline_matplotlib_plots() - - def on_update_button_clicked(change): - for x in parameters: - x.slider.continuous_update = not x.slider.continuous_update - - def on_reset_button_clicked(change): - self.reset() - out.block = True - for x in parameters: - x.slider.value = self.values[x.par] - out.block = False - on_slider_change(None) - - parameters = [] - for par in self.parameters: - val = self.values[par] - step = mutil._guess_initial_step(val) - a, b = self.limits[par] - # safety margin to avoid overflow warnings - a = a + 1e-300 if np.isfinite(a) else val - 100 * step - b = b - 1e-300 if np.isfinite(b) else val + 100 * step - parameters.append(ParameterBox(par, val, a, b, step, self.fixed[par])) - - fit_button = Button(description="Fit") - fit_button.on_click(on_fit_button_clicked) - - update_button = ToggleButton(True, description="Continuous") - update_button.observe(on_update_button_clicked) - - reset_button = Button(description="Reset") - reset_button.on_click(on_reset_button_clicked) - - algo_choice = Dropdown( - options=["Migrad", "Scipy", "Simplex"], value="Migrad" - ) - - ui = VBox( - [ - HBox([fit_button, update_button, reset_button, algo_choice]), - VBox(parameters), - ] - ) - - out = Output() - out.block = False - - for x in parameters: - x.slider.observe(on_slider_change, "value") - - show_inline_matplotlib_plots() - on_slider_change(None) + from iminuit.ipywidget import make_widget - return HBox([out, ui]) + plot = self._visualize(plot) + return make_widget(self, plot, kwargs, raise_on_exception) def _free_parameters(self) -> Set[str]: return set(mp.name for mp in self._last_state if not mp.is_fixed) diff --git a/src/iminuit/util.py b/src/iminuit/util.py index 4c79b4f1..3db8bb51 100644 --- a/src/iminuit/util.py +++ b/src/iminuit/util.py @@ -1564,7 +1564,7 @@ def _histogram_segments(mask, xe, masked): return segments -def _smart_sampling(f, xmin, xmax, start=5, tol=5e-3, maxiter=20, maxtime=10): +def _smart_sampling(f, xmin, xmax, start=20, tol=5e-3, maxiter=20, maxtime=10): t0 = monotonic() x = np.linspace(xmin, xmax, start) ynew = f(x) diff --git a/tests/test_cost.py b/tests/test_cost.py index 6ffcf335..c04ab141 100644 --- a/tests/test_cost.py +++ b/tests/test_cost.py @@ -1375,7 +1375,7 @@ def test_LeastSquares_visualize(): assert_equal(x, (1, 2)) assert_equal(y, (2, 3)) assert_equal(ye, 0.1) - assert len(xm) < 10 + assert len(xm) == 39 # linear spacing (x, y, ye), (xm, ym) = c.visualize((1, 2), model_points=10) assert len(xm) == 10 diff --git a/tests/test_draw.py b/tests/test_draw.py index 3abaeff5..b4e9e088 100644 --- a/tests/test_draw.py +++ b/tests/test_draw.py @@ -189,10 +189,12 @@ def assert_call(self): update_button.value = False with plot.assert_call(): - parameters.children[0].slider.value = 0.4 # change first slider + # because of implementation details, we have to trigger the slider several times + for i in range(5): + parameters.children[0].slider.value = i # change first slider parameters.children[0].fix.value = True with plot.assert_call(): - parameters.children[0].opt.value = True + parameters.children[0].fit.value = True class Cost: def visualize(self, args): @@ -210,8 +212,8 @@ def __call__(self, a, b): ui = out.children[1] header, parameters = ui.children fit_button, update_button, reset_button, algo_select = header.children - assert parameters.children[0].slider.max < 100 - assert parameters.children[1].slider.min > -100 + assert parameters.children[0].slider.max == 1 + assert parameters.children[1].slider.min == -1 with plot.assert_call(): fit_button.click() assert_allclose(m.values, (100, -100), atol=1e-5) @@ -253,4 +255,4 @@ def __call__(self, par): trace_args = TraceArgs() m = Minuit(cost, (1, 2)) m.interactive(trace_args) - assert trace_args.nargs == 1 + assert trace_args.nargs > 0 diff --git a/tests/test_util.py b/tests/test_util.py index a20131cf..b529ea4c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -702,7 +702,7 @@ def test_histogram_segments(mask_expected): @pytest.mark.parametrize( - "fn_expected", ((lambda x: x, 15), (lambda x: x**11, 40), (np.log, 80)) + "fn_expected", ((lambda x: x, 40), (lambda x: x**11, 60), (np.log, 80)) ) def test_smart_sampling_1(fn_expected): fn, expected = fn_expected @@ -734,7 +734,7 @@ def step(x): with pytest.warns(RuntimeWarning, match="Time limit"): x, y = util._smart_sampling(step, 0, 1, maxtime=0) - assert 0 < len(x) < 10 + assert 0 < len(x) < 30 @pytest.mark.parametrize( diff --git a/tests/test_without_ipywidgets.py b/tests/test_without_ipywidgets.py index 57cd7266..fbd9b508 100644 --- a/tests/test_without_ipywidgets.py +++ b/tests/test_without_ipywidgets.py @@ -13,6 +13,6 @@ def test_interactive(): iminuit.Minuit(cost, 1).interactive() - with hide_modules("ipywidgets", reload="iminuit"): + with hide_modules("ipywidgets", reload="iminuit.ipywidget"): with pytest.raises(ModuleNotFoundError, match="Please install"): iminuit.Minuit(cost, 1).interactive()