"""A collection of utilities for input/output operations. The goals of this module are:
- compatibility of pickling/deepcopying with CasADi objects and classes that hold such
objects (since these are often not picklable).
- saving and loading data to/from files, possibly compressed.
"""
import pickle
from functools import partial as _partial
from os.path import splitext as _splitext
from pickletools import optimize as _optimize
from typing import TYPE_CHECKING, Callable, Literal, Optional
from typing import Any as _Any
if TYPE_CHECKING:
from scipy.io.matlab import mat_struct
_COMPRESSION_EXTS: dict[
str, Optional[Literal["lzma", "bz2", "gzip", "brotli", "blosc2", "matlab", "numpy"]]
] = {
".pkl": None,
".xz": "lzma",
".pbz2": "bz2",
".gz": "gzip",
".bt": "brotli",
".bl2": "blosc2",
".mat": "matlab",
".npz": "numpy",
}
[docs]
def save(
filename: str,
compression: Optional[
Literal["lzma", "bz2", "gzip", "brotli", "blosc2", "matlab", "numpy"]
] = None,
**data: _Any,
) -> str:
"""Saves data to a (possibly compressed) file. Inspired by
`this discussion <https://stackoverflow.com/a/57983757/19648688>`_
and `this other discussion <https://stackoverflow.com/a/8832212/19648688>`_.
Parameters
----------
filename : str
The name of the file to save to. If the filename does not end in the correct
extension, then it is automatically added. The extensions are
- ``"pickle"``: .pkl
- ``"lzma"``: .xz
- ``"bz2"``: .pbz2
- ``"gzip"``: .gz
- ``"brotli"``: .bt
- ``"blosc2"``: .bl2
- ``"matlab"``: .mat
- ``"numpy"``: .npz.
**data : dict
Any data to be saved to a file.
compression : {"lzma", "bz2", "gzip", "brotli", "blosc2", "matlab", "npz"}
Type of compression to apply to the file. By default, `pickle` is used.
Returns
-------
filename : str
The complete name of the file where the data was written to.
Notes
-----
Note that the compression types ``brotli`` and ``blosc2`` require the installation
of the corresponding pip packages (see `Brotli <https://github.com/google/brotli>`_
and `Blosc2 <https://www.blosc.org/python-blosc/python-blosc.html>`_).
``matlab`` requires instead the installation of :mod:`scipy` to save as .mat file
(see :func:`scipy.io.savemat` and :func:`scipy.io.loadmat` for more details).
"""
actual_ext = _splitext(filename)[1]
if compression is None:
compression = _COMPRESSION_EXTS.get(actual_ext)
open_fun: Callable
compress_fun: Callable
if compression is None:
expected_ext = ".pkl"
open_fun = open
compress_fun = lambda o: o
elif compression == "lzma":
import lzma
expected_ext = ".xz"
open_fun = lzma.open
compress_fun = lambda o: o
elif compression == "bz2":
import bz2
expected_ext = ".pbz2"
open_fun = bz2.BZ2File
compress_fun = lambda o: o
elif compression == "gzip":
import gzip
expected_ext = ".gz"
open_fun = gzip.open
compress_fun = lambda o: o
elif compression == "brotli":
import brotli
expected_ext = ".bt"
open_fun = open
compress_fun = brotli.compress
elif compression == "blosc2":
import blosc2
expected_ext = ".bl2"
open_fun = open
compress_fun = _partial(blosc2.compress, typesize=None)
elif compression == "matlab":
expected_ext = ".mat"
elif compression == "numpy":
expected_ext = ".npz"
else:
raise ValueError(f"Unknown compression method {compression}.")
if expected_ext != actual_ext:
filename += expected_ext
# address first special cases that do not adhere to the open/compress scheme
if compression == "matlab":
import scipy.io as spio
spio.savemat(filename, data, do_compression=True, oned_as="column")
elif compression == "numpy":
import numpy as np
np.savez_compressed(filename, **data)
# address all other cases that do adhere to the open/compress scheme
else:
pickled = pickle.dumps(data)
optimized = _optimize(pickled)
compressed = compress_fun(optimized)
with open_fun(filename, "wb") as f:
f.write(compressed)
return filename
[docs]
def load(filename: str) -> dict[str, _Any]:
"""Loads data from a (possibly compressed) file.
Parameters
----------
filename : str, optional
The name of the file to load. If the filename does not end in a known extension,
then it fails. The known extensions are
- ``"pickle"``: .pkl
- ``"lzma"``: .xz
- ``"bz2"``: .pbz2
- ``"gzip"``: .gz
- ``"brotli"``: .bt
- ``"blosc2"``: .bl2
- ``"matlab"``: .mat
- ``"numpy"``: .npz.
Returns
-------
data : dict
The saved data in the shape of a dictionary.
"""
ext = _splitext(filename)[1]
compression = _COMPRESSION_EXTS[ext]
open_fun: Callable
decompress_fun: Callable
if compression is None:
open_fun = open
decompress_fun = pickle.loads
elif compression == "lzma":
import lzma
open_fun = lzma.open
decompress_fun = pickle.loads
elif compression == "bz2":
import bz2
open_fun = bz2.BZ2File
decompress_fun = pickle.loads
elif compression == "gzip":
import gzip
open_fun = gzip.open
decompress_fun = pickle.loads
elif compression == "brotli":
import brotli
open_fun = open
decompress_fun = lambda o: pickle.loads(brotli.decompress(o))
elif compression == "blosc2":
import blosc2
open_fun = open
decompress_fun = lambda o: pickle.loads(blosc2.decompress(o))
elif compression not in ("matlab", "numpy"):
raise ValueError(f"Unknown file extension {ext}.")
# address first special cases that do not adhere to the open/decompress scheme
if compression == "matlab":
import scipy.io as spio
data = _check_mat_keys(
spio.loadmat(filename, struct_as_record=False, squeeze_me=True),
spio.matlab.mat_struct,
)
elif compression == "numpy":
import numpy as np
with np.load(filename, allow_pickle=True) as file:
data = dict(file)
# address all other cases that do adhere to the open/decompress scheme
else:
with open_fun(filename, "rb") as f:
data = decompress_fun(f.read())
# if it is only a dict with one key, return the value of the key directly.
if isinstance(data, dict) and len(data.keys()) == 1:
data = data[next(iter(data.keys()))]
return data
def _check_mat_keys(dictionary: dict, mat_struct_type: type) -> dict:
"""Internal utility to check if entries in dictionary are mat-objects. If yes,
todict is called to change them to nested dictionaries."""
def _todict_recursive(matobj: "mat_struct") -> dict:
dictionary = {}
for strg in matobj._fieldnames:
elem = matobj.__dict__[strg]
dictionary[strg] = (
_todict_recursive(elem) if isinstance(elem, mat_struct_type) else elem
)
return dictionary
for bad_key in ("__header__", "__version__", "__globals__"):
dictionary.pop(bad_key, None)
for key, value in dictionary.items():
if isinstance(value, mat_struct_type):
dictionary[key] = _todict_recursive(value)
return dictionary