Source code for csnlp.wrappers.wrapper

from collections.abc import Iterable
from typing import Any, Generic, TypeVar, Union

import casadi as cs
from numpy import typing as npt

from ..core.solutions import Solution
from ..nlps.nlp import Nlp

# Type variable constrained to CasADi's two symbolic types, ``cs.SX`` and ``cs.MX``;
# it parametrizes the generic NLP and wrapper classes in this module.
SymType = TypeVar("SymType", cs.SX, cs.MX)


class Wrapper(Generic[SymType]):
    """Wraps an instance of :class:`csnlp.Nlp` to allow a modular transformation of
    its methods. This class is the base class for all wrappers. The subclass can then
    override some methods to change the behavior of the original NLP without touching
    the original code.

    The base class is retroactive, in the sense that it can be applied to any NLP
    instance that already defines variables, parameters, and/or objective. Use
    :class:`NonRetroactiveWrapper` for wrappers that need to wrap an NLP before it is
    defined.

    Parameters
    ----------
    nlp : Nlp or subclass
        The NLP to wrap.
    """

    def __init__(self, nlp: Nlp[SymType]) -> None:
        super().__init__()
        self.nlp = nlp

    @property
    def unwrapped(self) -> Nlp[SymType]:
        """Returns the original NLP of the wrapper."""
        return self.nlp.unwrapped

    def is_wrapped(self, wrapper_type: type["Wrapper[SymType]"]) -> bool:
        """Gets whether the NLP instance is wrapped or not by the given wrapper type.

        Parameters
        ----------
        wrapper_type : type of Wrapper
            Type of wrapper to check if the NLP is wrapped with.

        Returns
        -------
        bool
            ``True`` if wrapped by an instance of ``wrapper_type``; ``False``,
            otherwise.
        """
        if isinstance(self, wrapper_type):
            return True
        # not this layer: delegate the check to the wrapped object
        return self.nlp.is_wrapped(wrapper_type)

    def __getattr__(self, name: str) -> Any:
        """Reroutes attributes to the wrapped NLP instance.

        Parameters
        ----------
        name : str
            Name of the attribute that was not found on the wrapper itself.

        Returns
        -------
        Any
            The attribute ``name`` of the wrapped NLP.

        Raises
        ------
        AttributeError
            Raises if ``name`` starts with an underscore (private attributes are not
            rerouted), or if ``name`` is ``"nlp"`` and the wrapped NLP has not been
            assigned to this wrapper yet.
        """
        if name.startswith("_"):
            raise AttributeError(f"Accessing private attribute '{name}' is prohibited.")
        if name == "nlp":
            # ``__getattr__`` only runs after normal lookup has failed, so getting
            # here means ``self.nlp`` was never assigned (e.g., the instance was
            # created via ``__new__`` without ``__init__``). Evaluating ``self.nlp``
            # below would re-enter this method endlessly; fail fast instead.
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute 'nlp'."
            )
        return getattr(self.nlp, name)

    def __call__(
        self,
        pars: Union[
            None, dict[str, npt.ArrayLike], Iterable[dict[str, npt.ArrayLike]]
        ] = None,
        vals0: Union[
            None, dict[str, npt.ArrayLike], Iterable[dict[str, npt.ArrayLike]]
        ] = None,
        **kwargs: Any,
    ) -> Union[Solution[SymType], list[Solution[SymType]]]:
        """Solves the wrapped NLP, dispatching to ``solve`` or ``solve_multi``.

        Parameters
        ----------
        pars : dict or iterable of dict, optional
            Parameter values, forwarded as first argument to the chosen solve method.
        vals0 : dict or iterable of dict, optional
            Initial values, forwarded as second argument to the chosen solve method.
        kwargs
            Additional keyword arguments; these are forwarded only to
            ``solve_multi``, not to ``solve``.

        Returns
        -------
        Solution or list of Solutions
            Whatever the chosen solve method returns.
        """
        # Similar logic to `MultiStartNlp.__call__`: call solve_multi only if either
        # pars or vals0 is an iterable; otherwise, run the single, base NLP
        if not self.nlp.is_multi or (
            (pars is None or isinstance(pars, dict))
            and (vals0 is None or isinstance(vals0, dict))
        ):
            return self.solve(pars, vals0)
        return self.solve_multi(pars, vals0, **kwargs)

    def __str__(self) -> str:
        """Returns the wrapped NLP string."""
        return f"<{self.__class__.__name__}: {self.nlp.__str__()}>"

    def __repr__(self) -> str:
        """Returns the wrapped NLP representation."""
        return f"<{self.__class__.__name__}: {self.nlp.__repr__()}>"
class NonRetroactiveWrapper(Wrapper[SymType], Generic[SymType]):
    """Variant of :class:`Wrapper` that must be applied to a still-empty NLP, i.e.,
    the wrapping has to happen before any variable, parameter or objective is
    specified in the NLP instance.

    Parameters
    ----------
    nlp : Nlp
        The NLP instance to be wrapped.

    Raises
    ------
    ValueError
        Raises if the objective, variables, dual variables, parameters or
        constraints are already defined in this NLP instance.
    """

    def __init__(self, nlp: Nlp[SymType]) -> None:
        super().__init__(nlp)
        core = nlp.unwrapped
        # the objective is checked against ``None``; the containers by truthiness,
        # lazily and in the order listed
        has_objective = core._f is not None
        if has_objective or any(
            getattr(core, attr) for attr in ("_vars", "_dual_vars", "_pars", "_cons")
        ):
            raise ValueError("Nlp already defined.")