r"""
A simple problem with multiple minima
=====================================

In this example, we show how to solve a simple nonlinear optimization problem with
multiple minima using the multistart capabilities of :mod:`csnlp.multistart`.

The problem we consider is

.. math::
    \min_{x \in [-0.5, 1.4]}{
        -p_0 x^2 - \exp(-p_1 x^2) + \exp(-10 p_1 (x - 1)^2) + \exp(-10 p_1 (x - 1.5)^2)
    },

which, in the given range, has three local minima.
"""

# %%
# Without :mod:`csnlp.multistart`
# -------------------------------
# If we attempt to solve the given problem using the standard NLP formulation, i.e.,
# with a single initial condition, we might end up in a local minimum. The chances of
# this occurring increase with the complexity and dimension of the problem.
#
# First, we'll solve the problem the usual way with :class:`csnlp.Nlp`. Let's define
# the imports and the function to optimize.

from itertools import chain

import casadi as cs
import matplotlib.pyplot as plt
import numpy as np

from csnlp import Nlp
from csnlp import multistart as ms


def func(x, p0, p1):
    """Evaluate the example objective at ``x`` for the parameters ``p0`` and ``p1``.

    The result is ``-p0 * x**2`` minus an exponential term centred at ``0`` plus two
    exponential terms centred at ``1`` and ``1.5``; the latter two use the exponent
    scale ``10 * p1`` instead of ``p1``. Only arithmetic operators and ``np.exp``
    are applied to the arguments.
    """
    quadratic_term = -p0 * x**2
    dip_at_zero = np.exp(-p1 * x**2)
    bump_at_one = np.exp(-10 * p1 * (x - 1) ** 2)
    bump_at_one_and_half = np.exp(-10 * p1 * (x - 1.5) ** 2)
    return quadratic_term - dip_at_zero + bump_at_one + bump_at_one_and_half


# %%
# Then, let's build the NLP.

# Bounds of the scalar decision variable; reused below for sampling and plotting.
LB, UB = -0.5, 1.4
# NLP whose expressions are built from CasADi ``SX`` symbols.
nlp = Nlp[cs.SX]()
# ``variable`` returns an indexable result; only its first entry (taken here as the
# symbolic variable itself) is needed to build the objective.
x = nlp.variable("x", lb=LB, ub=UB)[0]
# Symbolic parameters; their numerical values are supplied at solve time via ``pars``.
p0 = nlp.parameter("p0")
p1 = nlp.parameter("p1")
nlp.minimize(func(x, p0, p1))
# Keep the solver quiet: no timing report and IPOPT print level 0
# (``"sb": "yes"`` presumably suppresses the IPOPT banner - confirm in IPOPT docs).
opts = {"print_time": False, "ipopt": {"sb": "yes", "print_level": 0}}
nlp.init_solver(opts)

# %%
# We can solve it with some (single) initial conditions, e.g., :math:`x_0 = 0.97`. We
# fix the parameters to some constant values throughout the example.

# Initial guess for ``x`` used by the single-start solve.
x0 = 0.97
# Numerical values of the parameters, keyed by the names they were declared with.
pars = {"p0": 0.3, "p1": 10}
sol_single = nlp.solve(pars=pars, vals0={"x": x0})

# %%
# If we plot this solution, we'll see that the gradient-based solver IPOPT got stuck in
# a local minimum.

fig, ax = plt.subplots(constrained_layout=True)


def F(x):
    """Evaluate the objective at ``x`` with the parameters fixed to ``pars``."""
    return func(x, **pars)


# Objective sampled over the whole feasible interval (dashed black reference curve).
xs = np.linspace(LB, UB, 200)
Fs = F(xs)
ax.plot(xs, Fs, "k--")

# Segment of the objective between the initial guess and the returned solution.
xf = float(sol_single.vals["x"])
xs_sol = np.linspace(x0, xf, 100)
lbl = rf"$x_0={{{x0:.3f}}} \rightarrow f^{{\star}}={{{F(xf):.3f}}}$"
ax.plot(xs_sol, F(xs_sol), "-", lw=2, label=lbl, color="C0")
ax.plot(x0, F(x0), "o", markersize=6, color="C0")  # circle: initial guess
ax.plot(xf, F(xf), "*", markersize=10, color="C0")  # star: returned solution

ax.set_xlabel("$x$")
ax.set_ylabel("$f(x)$")
ax.set_xlim(LB, UB)
ax.set_ylim(-1.1, 0.8)
ax.legend()
plt.show()

# %%
# With :mod:`csnlp.multistart`
# ----------------------------
# Now, we'll solve the same problem using the multistart capabilities of :mod:`csnlp`.
#
# The submodule :mod:`csnlp.multistart` provides three classes to solve an NLP problem
# in a multistart parallelized fashion: :class:`csnlp.multistart.StackedMultistartNlp`,
# :class:`csnlp.multistart.MappedMultistartNlp`, and
# :class:`csnlp.multistart.ParallelMultistartNlp`. The interfaces of these classes
# remain mostly unchanged with respect to the base class :class:`csnlp.Nlp`.
#
# The only novelty is the introduction of the method
# :meth:`csnlp.multistart.MultistartNlp.solve_multi` (which each class implements
# differently), which allows solving the problem from multiple initial conditions and
# multiple parameters. The method returns the best solution found, or all solutions if
# requested.
#
# To build the multistart NLP, we follow a very similar procedure as the one above.

# Number of starts the multistart NLP is built for.
n_multistarts = 4
nlp = ms.StackedMultistartNlp[cs.SX](starts=n_multistarts)
# Alternative multistart classes; swap one in for the line above to try it.
# nlp = ms.MappedMultistartNlp[cs.SX](starts=n_multistarts, parallelization="thread")
# nlp = ms.ParallelMultistartNlp[cs.SX](
#     starts=n_multistarts, parallel_kwargs={"n_jobs": n_multistarts}
# )
# Same variable, parameters, objective and solver options as the single-start NLP.
x = nlp.variable("x", lb=LB, ub=UB)[0]
p0 = nlp.parameter("p0")
p1 = nlp.parameter("p1")
nlp.minimize(func(x, p0, p1))
nlp.init_solver(opts)

# %%
# The :meth:`csnlp.multistart.MultistartNlp.solve_multi` method can accept as inputs an
# iterable of dictionaries for its arguments ``pars`` and ``vals0``. Each dictionary
# defines the conditions of a single start. However, it is also possible to pass a
# single dictionary, in which case it will be used across all starts.
#
# These iterables of dictionaries can be manually generated by the user. However,
# :mod:`csnlp.multistart` offers classes to generate the starting points automatically
# in a random or deterministic (structured) way. In what follows, we'll generate four
# starting locations while keeping the parameters fixed. Two of these locations will be
# uniformly randomly generated, the other two are linearly spaced in the given domain.

# Split the starts between the two generators so that together they always yield
# exactly ``n_multistarts`` points, also when that number is odd (the structured
# generator takes the remainder).
n_random = n_multistarts // 2
n_structured = n_multistarts - n_random
random_points = ms.RandomStartPoints(
    points={"x": ms.RandomStartPoint("uniform", LB, UB)},
    multistarts=n_random,
    seed=42,  # fixed seed, presumably to make the random starts reproducible
)
structured_points = ms.StructuredStartPoints(
    points={"x": ms.StructuredStartPoint(LB, UB)},
    multistarts=n_structured,
)

# %%
# The :class:`csnlp.multistart.RandomStartPoints` and
# :class:`csnlp.multistart.StructuredStartPoints` can be iterated over to yield the
# initial conditions for each start. We'll chain them together to pass them to the
# solver (we also convert them to a list just for plotting reasons).

# Materialize the random starts followed by the structured ones into a single list.
x0s = list(chain.from_iterable((random_points, structured_points)))

# %%
# Now, we can call the multistart solver. By passing ``return_all_sols=True``, we are
# requesting the method to return all the solutions found for each start, instead of
# returning just the best one.

# Default call: only the best solution is returned.
sol_best = nlp.solve_multi(pars=pars, vals0=x0s)  # to get just the best solution
# Same problem solved again, this time keeping the solution of every start.
sols_multi = nlp.solve_multi(pars=pars, vals0=x0s, return_all_sols=True)  # to get all

# %%
# We can plot the convergence of each starting point as we did earlier.

fig, ax = plt.subplots(constrained_layout=True)

# Dashed reference curve of the objective over the whole feasible interval.
ax.plot(xs, Fs, "k--")

# One coloured trace per start. Dedicated names are used inside the loop so that the
# module-level ``xs``, ``x0`` and ``x`` are not overwritten: ``xs`` must keep matching
# ``Fs`` for the dashed curve above, otherwise this cell cannot be re-run.
for i, (start, sol) in enumerate(zip(x0s, sols_multi), start=1):
    x_init = start["x"]
    x_opt = float(sol.vals["x"])
    xs_path = np.linspace(x_init, x_opt, 100)
    label = rf"$x_0={{{x_init:.3f}}} \rightarrow f^{{\star}}={{{F(x_opt):.3f}}}$"
    ax.plot(xs_path, F(xs_path), "-", lw=2, label=label, color=f"C{i}")
    ax.plot(x_init, F(x_init), "o", markersize=6, color=f"C{i}")  # circle: start
    ax.plot(x_opt, F(x_opt), "*", markersize=10, color=f"C{i}")  # star: solution

# Highlight the best solution with a hollow square and guide lines to both axes.
# The optimal ``x`` is read once, the same way as for the individual solutions above.
y_min, y_max = -1.1, 0.8
x_best = float(sol_best.vals["x"])
f_best = sol_best.f
ax.plot(
    x_best,
    f_best,
    "s",
    markersize=14,
    fillstyle="none",
    color="k",
    label="Best",
)
ax.vlines(x_best, y_min, f_best, "k", ls="-.")
ax.hlines(f_best, LB, x_best, "k", ls="-.")

ax.set_xlabel("$x$")
ax.set_ylabel("$f(x)$")
ax.set_xlim(LB, UB)
ax.set_ylim(y_min, y_max)
ax.legend()
plt.show()
