"""A collection of stand-alone functions to extract information from the CasADi
documentation via code. In particular, this module offers a way to get the solvers that
are available in CasADi (i.e., they have an interface) as well as their options. The
functions are taken from the
`MPCTools <https://bitbucket.org/rawlings-group/mpc-tools-casadi/src/master/mpctools/util.py>`_
repository by the Rawlings' group.
"""
import contextlib
import itertools
import warnings
from typing import Any as _Any
from typing import Callable
from typing import NamedTuple as _NamedTuple
import casadi as cs
class _LambdaType:
    """Callable wrapper that pairs a conversion function with a readable type name.

    Calling an instance forwards the value to the wrapped function, while ``repr``
    and ``str`` both render as ``<type '...'>`` using the stored type name.
    """

    def __init__(self, func: Callable[[_Any], _Any], typerepr: str) -> None:
        # Double-underscore names are mangled, which keeps both attributes private
        # to this class.
        self.__typerepr = typerepr
        self.__func = func

    def __call__(self, val: _Any) -> _Any:
        """Applies the wrapped conversion function to ``val``."""
        convert = self.__func
        return convert(val)

    def __repr__(self) -> str:
        """Returns the type name wrapped as ``<type '...'>``."""
        return "<type '{}'>".format(self.__typerepr)

    def __str__(self) -> str:
        """Returns the same text as :func:`repr`."""
        return repr(self)
class _DocCell(_NamedTuple):
    """One parsed row of an options table found in a CasADi docstring."""

    # Name of the option (text of the first column).
    id: str
    # Default of the option. Not necessarily a string: it can be a value already
    # converted to a Python type, a ``(value, type)`` tuple, or ``None``.
    default: _Any
    # Description of the option (text of the last column).
    doc: str
# Line prefixes used to walk through the options table of a CasADi docstring.
_TABLE_START = "+="  # prefix of the line that starts the table
_CELL_END = "+-"  # prefix of the line that ends a cell
_CELL_CONTENTS = "|"  # prefix of the lines that carry the contents of a cell

# Maps the CasADi type names to the Python callables that convert to that type.
_TYPES: dict[str, Callable] = {
    "OT_INTEGER": int,
    "OT_STRING": str,
    "OT_REAL": float,
    "OT_INT": int,
    "OT_DICT": dict,
    "OT_DOUBLE": float,
    "OT_BOOL": bool,
    "OT_STR": str,
    # Vector types convert their argument element by element.
    "OT_INTVECTOR": _LambdaType(lambda x: list(map(int, x)), "list[int]"),
    "OT_STRINGVECTOR": _LambdaType(lambda x: list(map(str, x)), "list[str]"),
}
def _parse_bool(text: str) -> bool:
    """Converts ``true``/``false``/``1``/``0`` (in any letter case) to a ``bool`` and
    raises ``ValueError`` for any other text."""
    token = text.strip().lower()
    if token in ("true", "1"):
        return True
    if token in ("false", "0"):
        return False
    raise ValueError(f"Cannot interpret '{text}' as a boolean.")


def _get_doc_cell(lines: list[str]) -> _DocCell:
    """Parses the lines of one table row into a ``_DocCell``.

    Every line is split on ``" | "`` and the text of each column is accumulated
    across the lines. The pieces of the id, type and default columns are joined
    without separator, those of the description column with a single space. The
    default is then converted with the callable registered in ``_TYPES`` for the
    type of the option.

    Parameters
    ----------
    lines : list of str
        The lines of the row, each laid out as ``| col | col | col |``. The number of
        columns (3, or 4 when a default column is present) is taken from the first
        line.

    Returns
    -------
    _DocCell
        The id, default and description of the option. The default is

        - the converted value, when the conversion succeeds;
        - a tuple ``(value, type)`` when it does not, where ``value`` is the raw
          text, or ``None`` if the table has no default column or the text is
          ``None`` or ``GenericType()``;
        - the unconverted text (or ``None``) when the type is unknown, in which case
          a warning is also issued.

    Raises
    ------
    ValueError
        Raises if the first line does not have 3 or 4 columns, or if any line has
        fewer columns than the first one.
    """
    ncol = lines[0].count(" | ") + 1
    # An explicit raise instead of ``assert``, which is stripped when running with -O.
    if ncol not in (3, 4):
        raise ValueError(f"Expected 3 or 4 columns in the docstring table; got {ncol}")

    fields: tuple[list[str], ...] = tuple([] for _ in range(ncol))
    for line in lines:
        cells = [c.strip() for c in line.split(" | ") if c]
        # Drop the table borders that are still attached to the outermost cells.
        cells[0] = cells[0].lstrip("|").rstrip()
        cells[-1] = cells[-1].rstrip("|").rstrip()
        if len(cells) < ncol:
            raise ValueError("Wrong number of columns.")
        # ``zip`` stops at ``ncol``, so cells in excess are ignored.
        for column, cell in zip(fields, cells):
            column.append(cell.strip())

    if ncol == 3:
        name, typename, doc = (j.join(f) for j, f in zip(("", "", " "), fields))
        default = None
    else:
        name, typename, default, doc = (
            j.join(f) for j, f in zip(("", "", "", " "), fields)
        )

    typefunc = _TYPES.get(typename)
    if typefunc is None:
        warnings.warn(f"Unknown type for '{name}', '{typename}'.", stacklevel=2)
        return _DocCell(name, default, doc)

    if default is None:
        # No default column in the table: report only the expected type. Passing
        # ``None`` to the converter would make up defaults, since ``str(None)`` is
        # ``"None"`` and ``bool(None)`` is ``False``.
        return _DocCell(name, (None, typefunc), doc)

    # ``bool("false")`` is ``True`` (any non-empty string is truthy), so booleans are
    # parsed from their text rather than passed to ``bool``.
    convert = _parse_bool if typefunc is bool else typefunc
    try:
        value = convert(default)
    except (ValueError, TypeError):
        includetype = True
        if default in ("None", "GenericType()"):
            value = None
        else:
            value = default
            if typefunc is int:
                # Integers written in floating-point notation (e.g., ``1e3``) are
                # rejected by ``int`` but can be read via ``float``. Infinities make
                # ``int`` raise OverflowError and are kept as text.
                with contextlib.suppress(ValueError, TypeError, OverflowError):
                    value = int(float(default))
                    includetype = False
        if includetype:
            value = (value, typefunc)
    return _DocCell(name, value, doc)
def _get_doc_dict(docstring: str) -> dict[str, tuple[_Any, str]]:
    """Extracts the options listed in the table of a docstring.

    Parameters
    ----------
    docstring : str
        The text that contains the table. Everything up to and including the first
        line starting with ``_TABLE_START`` is skipped.

    Returns
    -------
    dict of (str, tuple of (Any, str))
        For each option id, its default (``"N/A"`` when the default is ``None``) and
        its description.

    Raises
    ------
    ValueError
        Raises if no line of the docstring starts with ``_TABLE_START``.
    """
    all_lines = docstring.split("\n")
    body = itertools.dropwhile(
        lambda text: not text.startswith(_TABLE_START), all_lines
    )
    try:
        next(body)  # consume the line that starts the table
    except StopIteration as exc:
        raise ValueError("No table found!") from exc

    options: dict[str, tuple[_Any, str]] = {}
    pending: list[str] = []  # lines collected for the cell being read
    for text in body:
        if text.startswith(_CELL_END):
            cell = _get_doc_cell(pending)
            default = "N/A" if cell.default is None else cell.default
            options[cell.id] = (default, cell.doc)
            pending = []
        elif text.startswith(_CELL_CONTENTS):
            pending.append(text)
        else:
            # The first line that belongs to neither a cell nor a border ends the table.
            break
    return options
[docs]
def get_casadi_plugins() -> dict[str, list[str]]:
    """Returns the available CasADi plugins.

    The plugins are read from a single string made of ``group::name`` entries
    separated by ``;``. Blank entries (e.g., from an empty string or a trailing
    ``;``) are skipped.

    Returns
    -------
    dict of (str, list of str)
        A dictionary containing for each type of problem a list of plugin names that
        are available.

    Raises
    ------
    RuntimeError
        Raises in case the plugins cannot be retrieved because the functions
        :func:`cs.CasadiMeta_getPlugins` or :func:`cs.CasadiMeta_plugins` are not
        available.
    ValueError
        Raises if a non-blank entry does not contain the ``::`` separator.
    """
    func = getattr(cs, "CasadiMeta_getPlugins", getattr(cs, "CasadiMeta_plugins", None))
    if func is None:
        raise RuntimeError("Unable to get Casadi plugins.")
    all_plugins: str = func()

    plugin_dict: dict[str, list[str]] = {}
    for entry in all_plugins.split(";"):
        if not entry.strip():
            # Nothing to unpack here; without this check the line below would fail.
            continue
        # Split on the first separator only, so a name may itself contain ``::``.
        group, name = entry.split("::", 1)
        plugin_dict.setdefault(group, []).append(name)
    return plugin_dict
[docs]
def list_available_solvers() -> dict[str, list[str]]:
    """Returns the available CasADi solvers.

    Returns
    -------
    dict of (str, list of str)
        A dictionary with two entries: ``"nlp"``, with the plugins of the ``Nlpsol``
        group, and ``"qp"``, with the plugins of the ``Conic`` group followed by
        those of the ``Qpsol`` group. A group that is not available contributes no
        names.

    Raises
    ------
    RuntimeError
        Raises in case the plugins cannot be retrieved because the functions
        :func:`cs.CasadiMeta_getPlugins` or :func:`cs.CasadiMeta_plugins` are not
        available.
    """
    plugins = get_casadi_plugins()
    nlp_solvers = plugins.get("Nlpsol", [])
    conic_solvers = plugins.get("Conic", [])
    qpsol_solvers = plugins.get("Qpsol", [])
    return {"nlp": nlp_solvers, "qp": conic_solvers + qpsol_solvers}
[docs]
def get_solver_options(
    solver: str, display: bool = True
) -> dict[str, tuple[_Any, str]]:
    """Returns the solver-specific options, with default value and description whenever
    available.

    Parameters
    ----------
    solver : str
        The solver name.
    display : bool, optional
        Whether to print the options, by default ``True``

    Returns
    -------
    dict of (str, tuple of (Any, str))
        A dictionary containing for each option the default value and a description.

    Raises
    ------
    ValueError
        Raises in case ``solver`` is not included in the available solvers, or in
        case no options table is found in the documentation of the solver.
    """
    known = list_available_solvers()
    is_nlp = solver in known["nlp"]
    if not is_nlp and solver not in known["qp"]:
        raise ValueError(f"Unknown solver: '{solver}'.")

    # NLP solvers take precedence if the name appears in both lists.
    docstring = cs.doc_nlpsol(solver) if is_nlp else cs.doc_conic(solver)
    options = _get_doc_dict(docstring)
    if not display:
        return options

    print(f"Available options [default] for {solver}:\n")
    for name in sorted(options):
        default, doc = options[name]
        print(f"{name} [{default!r}]: {doc}\n")
    return options