{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Interactive fits\n",
    "\n",
    "This notebook showcases the interactive fitting capability of iminuit. Interactive fitting is useful to find good starting values and to debug the fit.\n",
    "\n",
    "**Note:** If you see this notebook on ReadTheDocs or otherwise statically rendered, changing the sliders won't change the plot. This requires a running Jupyter kernel."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from iminuit import cost\n",
    "from iminuit import Minuit\n",
    "from numba_stats import norm, t, bernstein, truncexpon\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## UnbinnedNLL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Toy sample: normal signal plus exponential background, background cut to b < 1.\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "s = rng.normal(0.5, 0.1, size=1000)\n",
    "b = rng.exponential(1, size=1000)\n",
    "b = b[b < 1]\n",
    "x = np.append(s, b)\n",
    "\n",
    "# Starting values in model order (f, mu, sigma, slope), taken from how the sample was generated.\n",
    "truth = len(s) / len(x), 0.5, 0.1, 1.0\n",
    "\n",
    "\n",
    "def model(x, f, mu, sigma, slope):\n",
    "    \"\"\"Return the density at x: f times a normal pdf plus (1 - f) times a truncated-exponential pdf.\"\"\"\n",
    "    return f * norm.pdf(x, mu, sigma) + (1 - f) * truncexpon.pdf(x, 0, 1, 0, slope)\n",
    "\n",
    "\n",
    "c = cost.UnbinnedNLL(x, model)\n",
    "m = Minuit(c, *truth)\n",
    "m.limits[\"f\", \"mu\"] = (0, 1)\n",
    "m.limits[\"sigma\", \"slope\"] = (0, None)\n",
    "\n",
    "# model_points is forwarded to the plotting routine of the cost function.\n",
    "m.interactive(model_points=1000)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ExtendedUnbinnedNLL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Same toy sample as for UnbinnedNLL, regenerated so this section runs on its own.\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "s = rng.normal(0.5, 0.1, size=1000)\n",
    "b = rng.exponential(1, size=1000)\n",
    "b = b[b < 1]\n",
    "x = np.append(s, b)\n",
    "\n",
    "# Starting values in model order (s, mu, sigma, b, slope); s and b are sample sizes here, not a fraction.\n",
    "truth = len(s), 0.5, 0.1, len(b), 1.0\n",
    "\n",
    "\n",
    "def model(x, s, mu, sigma, b, slope):\n",
    "    \"\"\"Return the pair (s + b, density at x), with the density scaled by the yields s and b.\"\"\"\n",
    "    # Use a separate name for the result instead of rebinding the argument x.\n",
    "    density = s * norm.pdf(x, mu, sigma) + b * truncexpon.pdf(x, 0, 1, 0, slope)\n",
    "    return s + b, density\n",
    "\n",
    "\n",
    "c = cost.ExtendedUnbinnedNLL(x, model)\n",
    "m = Minuit(c, *truth)\n",
    "m.limits[\"mu\"] = (0, 1)\n",
    "m.limits[\"sigma\", \"slope\", \"s\", \"b\"] = (0, None)\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BinnedNLL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(xe, f, mu, sigma, nuinv, slope):\n",
    "    \"\"\"Return the mixture cdf evaluated at the bin edges xe.\n",
    "\n",
    "    The signal part is a t cdf rescaled with its values at 0 and 1, weighted by f;\n",
    "    the background part is a truncated-exponential cdf weighted by (1 - f).\n",
    "    The fit parameter is nuinv = 1 / nu.\n",
    "    \"\"\"\n",
    "    nu = 1 / nuinv\n",
    "    # cdf values at the interval boundaries, used to renormalise the signal to [0, 1]\n",
    "    a, b = t.cdf((0, 1), nu, mu, sigma)\n",
    "    sn = f * (t.cdf(xe, nu, mu, sigma) - a) / (b - a)\n",
    "    bn = (1 - f) * truncexpon.cdf(xe, 0, 1, 0, slope)\n",
    "    return sn + bn\n",
    "\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "# Values in model order (f, mu, sigma, nuinv, slope)\n",
    "truth = 0.5, 0.5, 0.1, 0.1, 1\n",
    "\n",
    "xe = np.linspace(0, 1, 100)\n",
    "# Poisson-fluctuated bin contents; np.diff of the cdf gives the probability per bin.\n",
    "n = rng.poisson(1000 * np.diff(model(xe, *truth)))\n",
    "\n",
    "c = cost.BinnedNLL(n, xe, model)\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "m.limits[\"sigma\", \"slope\"] = (0, None)\n",
    "m.limits[\"mu\", \"f\", \"nuinv\"] = (0, 1)\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bin centers, needed to build the mask.\n",
    "cx = 0.5 * (xe[:-1] + xe[1:])\n",
    "\n",
    "c = cost.BinnedNLL(n, xe, model)\n",
    "# The mask is True for bins whose center is more than 0.3 away from 0.5.\n",
    "c.mask = np.abs(cx - 0.5) > 0.3\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "m.limits[\"mu\", \"f\", \"nuinv\"] = (0, 1)\n",
    "m.limits[\"sigma\", \"slope\"] = (0, None)\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ExtendedBinnedNLL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(xe, s, mu, sigma, nuinv, b1, b2, b3):\n",
    "    \"\"\"Return the sum of a t cdf scaled by s and a Bernstein integral with coefficients (b1, b2, b3), at the edges xe.\"\"\"\n",
    "    signal = s * t.cdf(xe, 1 / nuinv, mu, sigma)\n",
    "    background = bernstein.integral(xe, (b1, b2, b3), 0, 1)\n",
    "    return signal + background\n",
    "\n",
    "\n",
    "# Values in model order (s, mu, sigma, nuinv, b1, b2, b3)\n",
    "truth = (1000.0, 0.5, 0.1, 0.1, 1000.0, 3000.0, 2000.0)\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "xe = np.linspace(0, 1, 100)\n",
    "\n",
    "# Poisson-fluctuated counts around the per-bin differences of the model.\n",
    "n = rng.poisson(np.diff(model(xe, *truth)))\n",
    "\n",
    "c = cost.ExtendedBinnedNLL(n, xe, model)\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "m.limits[\"mu\", \"nuinv\"] = (0, 1)\n",
    "m.limits[\"s\", \"sigma\", \"b1\", \"b2\", \"b3\"] = (0, None)\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "You can pass a custom plotting routine to `Minuit.interactive` to draw more detail. A simple function that accesses data from the outer scope would work, but in the following example we create a class that stores the cost function, which holds all the data we need, because this notebook keeps overwriting the variables in the outer scope."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Plotter:\n",
    "    \"\"\"Callable that draws the data with the background and signal components in different colors.\"\"\"\n",
    "\n",
    "    def __init__(self, cost):\n",
    "        # Keep the cost function; it provides the bin edges, the data and the model.\n",
    "        self.cost = cost\n",
    "\n",
    "    def __call__(self, args):\n",
    "        \"\"\"Draw data points and stacked components for the parameter values in args.\"\"\"\n",
    "        edges = self.cost.xe\n",
    "        counts = self.cost.data\n",
    "        centers = 0.5 * (edges[1:] + edges[:-1])\n",
    "        plt.errorbar(centers, counts, counts ** 0.5, fmt=\"ok\")\n",
    "        # Signal only: first four parameters, the last three set to zero.\n",
    "        signal = np.diff(self.cost.scaled_cdf(edges, *args[:4], 0, 0, 0))\n",
    "        # Background only: first parameter set to zero.\n",
    "        background = np.diff(self.cost.scaled_cdf(edges, 0, *args[1:]))\n",
    "        plt.stairs(background, edges, fill=True, color=\"C1\")\n",
    "        # Signal is stacked on top of the background.\n",
    "        plt.stairs(background + signal, edges, baseline=background, fill=True, color=\"C0\")\n",
    "\n",
    "\n",
    "m.interactive(Plotter(c))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = cost.ExtendedBinnedNLL(n, xe, model)\n",
    "\n",
    "# The mask is True for bins whose center is more than 0.3 away from 0.5.\n",
    "cx = 0.5 * (xe[1:] + xe[:-1])\n",
    "c.mask = np.abs(cx - 0.5) > 0.3\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "# nuinv gets its limits together with mu on the next line, so it is not listed here.\n",
    "m.limits[\"s\", \"sigma\", \"b1\", \"b2\", \"b3\"] = (0, None)\n",
    "m.limits[\"mu\", \"nuinv\"] = (0, 1)\n",
    "# Fix the first four parameters and set s to zero; only b1, b2, b3 remain free.\n",
    "m.fixed[\"s\", \"mu\", \"sigma\", \"nuinv\"] = True\n",
    "m.values[\"s\"] = 0\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## BarlowBeestonLite"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "rng = np.random.default_rng(1)\n",
    "\n",
    "xe = np.linspace(0, 1, 20)\n",
    "# Per-bin differences of the background and signal cdfs.\n",
    "bm = np.diff(truncexpon.cdf(xe, 0, 1, 0, 1))\n",
    "sm = np.diff(norm.cdf(xe, 0.5, 0.1))\n",
    "\n",
    "# Data first, then the two templates; the order of the draws fixes the random numbers.\n",
    "n = rng.poisson(1000 * bm + 100 * sm)\n",
    "b = rng.poisson(1e4 * bm)\n",
    "s = rng.poisson(1e2 * sm)\n",
    "\n",
    "c = cost.BarlowBeestonLite(n, xe, (b, s))\n",
    "\n",
    "# The cost function is built from arrays, so the parameter names are given explicitly.\n",
    "m = Minuit(c, 1000, 100, name=(\"b\", \"s\"))\n",
    "m.limits = (0, None)\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bin centers, needed to build the mask.\n",
    "cx = 0.5 * (xe[:-1] + xe[1:])\n",
    "\n",
    "c = cost.BarlowBeestonLite(n, xe, (b, s))\n",
    "# The mask is True for bins whose center is more than 0.2 away from 0.5.\n",
    "c.mask = np.abs(cx - 0.5) > 0.2\n",
    "\n",
    "m = Minuit(c, 1000, 100, name=(\"b\", \"s\"))\n",
    "m.limits = (0, None)\n",
    "\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## LeastSquares"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(x, a, b):\n",
    "    \"\"\"Return the straight line a + b * x.\"\"\"\n",
    "    return a + b * x\n",
    "\n",
    "\n",
    "# Create a seeded generator here so this section does not depend on the\n",
    "# generator state left behind by an earlier cell.\n",
    "rng = np.random.default_rng(1)\n",
    "\n",
    "truth = (1., 2.)\n",
    "x = np.linspace(0, 1)\n",
    "ym = model(x, *truth)\n",
    "ye = 0.1\n",
    "# Observations scattered around the model values ym with standard deviation ye.\n",
    "y = rng.normal(ym, ye)\n",
    "\n",
    "c = cost.LeastSquares(x, y, ye, model)\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "c = cost.LeastSquares(x, y, ye, model)\n",
    "# The mask is True for points with x below 0.2 or above 0.6.\n",
    "c.mask = (x < 0.2) | (x > 0.6)\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "m.interactive()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cost functions with shared parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def model(xe, s, mu, sigma, nuinv):\n",
    "    \"\"\"Return the t cdf at the bin edges xe, scaled by s; the fit parameter is nuinv = 1 / nu.\"\"\"\n",
    "    nu = 1 / nuinv\n",
    "    return s * t.cdf(xe, nu, mu, sigma)\n",
    "\n",
    "\n",
    "# Values in model order (s, mu, sigma, nuinv)\n",
    "truth = 100., 0.5, 0.1, 0.5\n",
    "\n",
    "rng = np.random.default_rng(1)\n",
    "xe = np.linspace(0, 1, 20)\n",
    "# Expected counts per bin; kept under its own name so that m only ever refers to the Minuit object.\n",
    "expected = np.diff(model(xe, *truth))\n",
    "n = rng.poisson(expected)\n",
    "\n",
    "# Sum of the binned likelihood and a constraint on mu and sigma (values 0.5 and 0.1, errors 0.1 each).\n",
    "c = cost.ExtendedBinnedNLL(n, xe, model) + cost.NormalConstraint(\n",
    "    [\"mu\", \"sigma\"], [0.5, 0.1], [0.1, 0.1]\n",
    ")\n",
    "\n",
    "m = Minuit(c, *truth)\n",
    "# Bound the parameters: s and sigma must not become negative, and the model divides by nuinv.\n",
    "m.limits[\"s\", \"sigma\"] = (0, None)\n",
    "m.limits[\"mu\", \"nuinv\"] = (0, 1)\n",
    "\n",
    "m.interactive()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.13 ('venv': venv)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "vscode": {
   "interpreter": {
    "hash": "bdbf20ff2e92a3ae3002db8b02bd1dd1b287e934c884beb29a73dced9dbd0fa3"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
