External Minimizers: Plotting Fit Progress

This example demonstrates how to run a typical fitting task in BornAgain with a third-party minimizer while plotting the fit progress. As in the previous example, lmfit is used for the sake of illustration.

To plot the fit progress, we use lmfit's iteration callback, a user-supplied function that lmfit calls at each fit iteration. It is convenient to define this plotting callback as a callable class:

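from bornagain import ba_fitmonitor

class Plotter:
    """
    Adapts standard plotter for lmfit minimizer.
    Redraws the GISAS plots every `every_nth` iterations.
    """
    def __init__(self, fit_objective, every_nth=10):
        self.fit_objective = fit_objective
        self.plotter_gisas = ba_fitmonitor.PlotterGISAS()
        self.every_nth = every_nth

    def __call__(self, params, i, resid):
        # Signature expected by lmfit's iter_cb: parameters, iteration number, residuals
        if i % self.every_nth == 0:
            self.plotter_gisas.plot(self.fit_objective)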
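The throttling logic in __call__ does not depend on BornAgain. The sketch below isolates it so it can be run on its own. The EveryNth class is a hypothetical helper that wraps an arbitrary action (a stand-in for plotter_gisas.plot) and forwards only every n-th call. The plain loop is a stand-in for the minimizer calling the callback.

```python
class EveryNth:
    """Callable that forwards only every n-th call to an action.

    Hypothetical helper for illustration; `action` stands in for
    PlotterGISAS.plot(fit_objective) in the Plotter class.
    """

    def __init__(self, action, every_nth=10):
        self.action = action
        self.every_nth = every_nth

    def __call__(self, params, i, resid):
        # Same argument order as an lmfit iteration callback.
        if i % self.every_nth == 0:
            self.action(i)


# Record which iteration numbers would trigger a redraw.
triggered = []
callback = EveryNth(triggered.append, every_nth=10)

# Fake minimizer loop: iterations 1..25 with dummy params and residuals.
for i in range(1, 26):
    callback(None, i, [0.0])

print(triggered)  # [10, 20]
```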

An instance of this class is then passed to lmfit.minimize through its iter_cb argument:

    plotter = Plotter(fit_objective)
    result = lmfit.minimize(fit_objective.evaluate_residuals, params, iter_cb=plotter)
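According to the lmfit documentation, iter_cb is called with the current parameters, the iteration number and the residual array (plus any extra arguments given to the objective function). If the callback returns a true value, the fit is aborted. The sketch below mimics that contract so the control flow can be run without lmfit or BornAgain. The toy_minimize function is hypothetical: it performs a crude one-parameter step search, not lmfit's algorithm. A plain dict stands in for lmfit.Parameters.

```python
def toy_minimize(residual, x0, iter_cb=None, step=1.0, max_iter=100):
    """Hypothetical stand-in for lmfit.minimize (1-D step search).

    Calls iter_cb(params, iteration, resid) after every residual
    evaluation; a truthy return value aborts the search.
    """
    x = x0
    cost = sum(r * r for r in residual(x))
    for iteration in range(1, max_iter + 1):
        trial = x + step
        resid = residual(trial)
        trial_cost = sum(r * r for r in resid)
        if trial_cost < cost:
            x, cost = trial, trial_cost   # accept the step
        else:
            step = -step / 2              # reverse and shrink the step
        # A plain dict stands in for lmfit.Parameters here.
        if iter_cb is not None and iter_cb({"x": trial}, iteration, resid):
            return x, iteration, True     # aborted by the callback
    return x, max_iter, False


def residual(x):
    # Residual of a one-parameter model against the target value 3.
    return [x - 3.0]


seen = []

def stop_after_five(params, i, resid):
    seen.append(i)
    return i >= 5  # a truthy return asks the minimizer to stop


x, n_iter, aborted = toy_minimize(residual, x0=0.0, iter_cb=stop_after_five)
print(x, n_iter, aborted)  # 3.0 5 True
```

The Plotter class above always returns None, so it never stops the fit. Returning True from a callback is one way to end a long fit early.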

The complete script and the image it produces are shown below.

Plotting the fitting progress of external minimizers

#!/usr/bin/env python3
"""
External minimizer: using lmfit minimizers for BornAgain fits.
Fit progress is plotted using lmfit iteration callback function.
"""
import bornagain as ba
from bornagain import ba_fitmonitor, ba_plot as bp, nm
import lmfit
from matplotlib import pyplot as plt
import model2_hexlattice as model


class LMFITPlotter:
    """
    Adapts standard plotter for lmfit minimizer.
    """

    def __init__(self, fit_objective, every_nth=10):
        self.fit_objective = fit_objective
        self.plotter_gisas = ba_fitmonitor.PlotterGISAS()
        self.every_nth = every_nth

    def __call__(self, params, i, resid):
        if i % self.every_nth == 0:
            self.plotter_gisas.plot(self.fit_objective)


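if __name__ == '__main__':
    bp.parse_args(sim_n=100)

    # Reference data produced by the model script
    real_data = model.create_real_data()

    # Couple the simulation builder to the data and
    # print the fit status every 10 iterations
    fit_objective = ba.FitObjective()
    fit_objective.addSimulationAndData(model.get_simulation, real_data, 1)
    fit_objective.initPrint(10)

    # lmfit parameters with start values and bounds; the names must
    # match those read by model.get_simulation
    params = lmfit.Parameters()
    params.add('radius', value=7*nm, min=5*nm, max=8*nm)
    params.add('length', value=10*nm, min=8*nm, max=14*nm)

    # Minimize the residuals; the plotter is called at each iteration
    # and redraws every 10th one
    plotter = LMFITPlotter(fit_objective)
    result = lmfit.minimize(fit_objective.evaluate_residuals,
                            params,
                            iter_cb=plotter)
    # Tell BornAgain that the fit has finished
    fit_objective.finalize(result)

    # Report the fitted parameters and lmfit statistics
    result.params.pretty_print()
    print(lmfit.fit_report(result))
    plt.show()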
Examples/fit/scatter2d/lmfit_with_plotting.py
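The resid argument that lmfit passes to the callback can also be used directly, for example to record a convergence history. The following is a minimal sketch, and the ConvergenceRecorder name is hypothetical. The callable stores the sum of squared residuals at each call. Because iter_cb takes a single callable, the recorder can optionally forward each call to another callback, such as an LMFITPlotter instance. Whether this sum matches the value BornAgain prints depends on how evaluate_residuals normalizes the residuals, so treat it as a relative measure.

```python
class ConvergenceRecorder:
    """Hypothetical lmfit iter_cb that records the residual sum of squares.

    Optionally forwards every call to another callback, since lmfit
    accepts only one iter_cb.
    """

    def __init__(self, forward=None):
        self.forward = forward
        self.iterations = []
        self.cost = []

    def __call__(self, params, i, resid, *args, **kws):
        # Sum of squared residuals; works for lists and NumPy arrays.
        self.iterations.append(i)
        self.cost.append(sum(float(r) ** 2 for r in resid))
        if self.forward is not None:
            return self.forward(params, i, resid)
        return None


# Stand-alone demonstration with made-up residual arrays.
recorder = ConvergenceRecorder()
for i, resid in enumerate([[3.0, 4.0], [1.0, 2.0], [0.5, 0.5]], start=1):
    recorder(None, i, resid)

print(recorder.cost)  # [25.0, 5.0, 0.5]
```

With the script above, one could pass iter_cb=ConvergenceRecorder(forward=LMFITPlotter(fit_objective)). After the fit, recorder.cost could then be plotted against recorder.iterations with matplotlib.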