Serialisation

Warning

This module is experimental and is subject to change!

Overview

This module saves models created in zodiax to a file and loads them back. This is useful for keeping a model for later use or for sharing it with others. Serialisation is generally difficult to make fully robust, but it can be built to cover most use cases.

There are two main groups of functions in this module: the structure functions and the serialisation functions. The two structure functions are:

  1. structure = build_structure(obj)
  2. obj = load_structure(structure)

The build_structure function traverses the input object and returns a structure dictionary that can be serialised. Each parameter in the object is either a 'container' node or a 'leaf' node, allowing the full structure to be represented along with any necessary metadata required to reconstruct the object. The load_structure() function takes this structure dictionary and returns a pytree of the same structure that can be used in conjunction with equinox.tree_deserialise_leaves() to return an identical object.
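
The container/leaf idea can be illustrated without zodiax at all. The following is a minimal sketch using only the standard library; the `describe` function and the `Lens` class are hypothetical names invented for this example, and a flat list of floats stands in for an array. It is not the zodiax implementation, only an instance of the same recursive pattern.

```python
import json
from typing import Any


def describe(obj: Any) -> dict:
    """Return a JSON-friendly description of `obj` (illustrative only)."""
    type_name = type(obj).__name__

    # Leaf: strings keep their value directly in the structure.
    if isinstance(obj, str):
        return {"node_type": "leaf", "type": type_name, "value": obj}

    # Leaf: a flat list of floats stands in for an array. Only metadata is
    # stored here; the numbers themselves would be saved separately.
    if isinstance(obj, list) and all(isinstance(x, float) for x in obj):
        return {"node_type": "leaf", "type": type_name,
                "shape": [len(obj)], "dtype": "float"}

    # Container: dicts and plain objects are traversed recursively.
    if isinstance(obj, dict):
        children = obj
    elif hasattr(obj, "__dict__"):
        children = vars(obj)
    else:
        raise TypeError(f"Unsupported leaf type: {type_name}")

    return {"node_type": "container", "type": type_name,
            "node": {key: describe(val) for key, val in children.items()}}


class Lens:
    """Hypothetical user class used only for this example."""
    def __init__(self):
        self.name = "primary"
        self.coefficients = [0.1, 0.2, 0.3]


structure = describe({"lens": Lens(), "label": "demo"})
text = json.dumps(structure)  # the structure holds only JSON-friendly values
```

Because every node holds only strings, lists and dictionaries, the resulting structure can be written out by any generic serialiser, while the numerical leaf data is handled separately.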

The serialisation functions are:

  1. serialise(path, obj)
  2. obj = deserialise(path)

The serialise function takes a path and an object, and saves the serialised object to that path. The deserialise function takes a path and returns the deserialised object.
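
The source listings in the API section show that both functions use a single file holding two parts: the pickled structure dictionary, followed by the leaf data. The sketch below imitates that layout with the standard library only. The second `pickle.dump`/`pickle.load` pair is a stand-in for the equinox leaf functions, and the `structure` and `leaves` values are made up for the example.

```python
import os
import pickle
import tempfile

# Hypothetical stand-ins: a structure dictionary and the leaf values it describes.
structure = {"node_type": "container", "type": "dict",
             "node": {"weights": {"node_type": "leaf", "type": "list",
                                  "shape": [3], "dtype": "float"}}}
leaves = [[1.0, 2.0, 3.0]]

path = os.path.join(tempfile.mkdtemp(), "model.zdx")

# Write: structure first, leaf data second, both through the same file handle.
with open(path, "wb") as f:
    pickle.dump(structure, f)
    pickle.dump(leaves, f)  # stand-in for tree_serialise_leaves(f, obj)

# Read: the two parts must be consumed in the order they were written.
with open(path, "rb") as f:
    loaded_structure = pickle.load(f)
    loaded_leaves = pickle.load(f)  # stand-in for tree_deserialise_leaves(f, like)
```

Reading the structure first matters because it is needed to build the template object before any leaf data can be placed into it.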


Future changes

There are some future improvements that are planned for this module, hence the present experimental status!

  • Serialise package versions:

To try and ensure that the serialised object can be deserialised, the package versions should be serialised. This will allow the code to automatically check imported versions and raise warnings for imported package discrepancies.

  • Add support for serialising functions:

This should also raise a warning, since functions cannot in general be robustly serialised, but it should still be supported.

  • Deal with static_fields:

There is a general issue with parameters in models that are marked as equinox.static_field(). Although this should rarely, if ever, be used by the user, it is still a potential issue. Since the equinox.tree_serialise_leaves() function uses tree_map functions, it is blind to these parameters. If such a parameter is a string it is fine, but if it is some other data type it will at present not be serialised. This can be fixed by using the tree_flatten() function to determine which parameters are static and serialising them using a different method.

  • Add support for serialising general objects:

In order to deal with the above static_field() issue, we must add support for serialising general Python types, along with array types.

  • Implement robust tests:

The tests for this module are currently very basic, primarily because the tests are run in isolated environments, so classes that are created for the tests cannot be re-imported.

When these changes have been implemented this module can be moved into the main zodiax package.
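
As an illustration of the first item in the list above, package versions could be recorded at save time and compared at load time. This is a hedged sketch of one possible approach using `importlib.metadata` from the standard library; `record_versions` and `check_versions` are hypothetical names and are not part of zodiax.

```python
import warnings
from importlib import metadata


def record_versions(packages):
    """Map each package name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def check_versions(saved):
    """Warn for every package whose installed version differs from `saved`."""
    current = record_versions(saved.keys())
    mismatched = [name for name in saved if saved[name] != current[name]]
    for name in mismatched:
        warnings.warn(
            f"{name}: saved with version {saved[name]}, "
            f"found {current[name]}"
        )
    return mismatched
```

The dictionary returned by `record_versions` contains only strings and `None`, so it could be stored alongside the structure dictionary and passed to `check_versions` when the file is loaded.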

Full API

build_structure(obj, self_key=None, depth=0, _print=False)

Recursively iterates over the input object in order to return a dictionary detailing the structure of the object. Each node can be either a container node or a leaf node. Each node is a dictionary with the following structure:

    {'node_type': 'container' or 'leaf',
     'type': str,
     'node': {
        param1 : {'node_type' : 'container', ...},  -> If container
        param2 : {'node_type' : 'leaf', '...' : ...},  -> If leaf, containing any leaf metadata
        }
    }

Specific leaf metadata:

  • Strings: String values are stored in the 'value' key and serialised via the returned structure dictionary.
  • Jax/Numpy Arrays: The array shape and dtype are stored in the 'shape' and 'dtype' keys respectively.

This method can be developed further to support more leaf types, since each individual leaf type can be made to store any arbitrary metadata, as long as it can be serialised by json and used to deserialise it later.

This dictionary can then be serialised using pickle and then later used to deserialise the object in conjunction with equinox leaf serialise/deserialise methods.

NOTE: This method is not equipped to handle equinox.static_field() parameters. They can be arbitrary data types but are not serialised by equinox.tree_serialise_leaves(), so they would require custom serialisation via this method. This case is currently neither handled nor checked for, so it will silently break or result in unexpected behaviour.
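
To make the static-field problem in the note above concrete, the sketch below separates fields marked as static from ordinary fields and flags the static values that are not plain strings, which are the ones that would currently be lost. It uses standard-library dataclasses with a `static` metadata flag as a stand-in for equinox.static_field(); the `Model` class and both helper functions are hypothetical.

```python
from dataclasses import dataclass, field, fields


@dataclass
class Model:
    """Hypothetical model; the `static` metadata flag is this example's own marker."""
    weights: list
    name: str = field(default="model", metadata={"static": True})
    grid: tuple = field(default=(0, 1), metadata={"static": True})


def split_fields(obj):
    """Separate dynamic (leaf) fields from fields marked as static."""
    dynamic, static = {}, {}
    for f in fields(obj):
        target = static if f.metadata.get("static", False) else dynamic
        target[f.name] = getattr(obj, f.name)
    return dynamic, static


def unhandled_static(static):
    """Return the names of static values that are not plain strings."""
    return [key for key, val in static.items() if not isinstance(val, str)]


model = Model(weights=[1.0, 2.0])
dynamic, static = split_fields(model)
needs_custom = unhandled_static(static)
```

A check like `unhandled_static` could be used to raise a warning at save time rather than failing silently.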

TODO: Serialise package versions in order to raise warnings when deserialising about inconsistent versions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| obj | Any | The object to get the leaves of. | required |
| self_key | str | The key of the object in the parent container. Used to print the tree structure for debugging. | None |
| depth | int | The depth of the object in the tree. Used to print the tree structure for debugging. | 0 |
| _print | bool | If True, print the tree structure for debugging. | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| structure | dict | The dictionary detailing the structure of the object. |

Source code in zodiax/experimental/serialisation.py
def build_structure(obj      : Any, 
                    self_key : str  = None, 
                    depth    : int  = 0,
                    _print   : bool = False):
    """
    Recursively iterates over the input object in order to return a dictionary 
    detailing the structure of the object. Each node can be either a
    container node or leaf node. Each node is a dictionary with the following
    structure:

    {'node_type': 'container' or 'leaf',
     'type': str,
     'node': {
        param1 : {'node_type' : 'container', ...}, -> If container
        param2 : {'node_type' : 'leaf',
                  '...' : ...}, -> If leaf containing any leaf metadata
        }

    Specific leaf metadata:
        Strings:
            String values are stored in the 'value' key and serialised via the
            returned structure dictionary.
        Jax/Numpy Arrays:
            Both the array shape and dtype are stored in the 'shape' and
            'dtype' keys respectively. 

    This method can be developed further to support more leaf types, since each
    individual leaf type can be made to store any arbitrary metadata, as long
    as it can be serialised by json and used to deserialise it later.

    This dictionary can then be serialised using pickle and then later
    used to deserialise the object in conjunction with equinox leaf 
    serialise/deserialise methods.

    NOTE: This method is not equipped to handle `equinox.static_field()` 
    parameters, as they can be arbitrary data types but do not get serialised
    by the `equinox.tree_serialise_leaves()` methods and hence require custom 
    serialisation via this method. Therefore this method currently does not
    handle this case correctly. This is not checked for currently so will
    silently break or result in unexpected behaviour.

    TODO: Serialise package versions in order to raise warnings when 
    deserialising about inconsistent versions.

    Parameters
    ----------
    obj : Any
        The object to get the leaves of.
    self_key : str = None
        The key of the object in the parent container. Use to print the tree
        structure for debugging.
    depth : int = 0
        The depth of the object in the tree. Use to print the tree structure
        for debugging.
    _print : bool = False
        If True, print the tree structure for debugging.

    Returns
    -------
    structure : dict
        The dictionary detailing the structure of the object.
    """
    structure = {}
    is_container = _check_node(obj, self_key, depth, _print=_print)

    # Recursive case
    if is_container:
        keys, accessor = _get_accessor(obj)

        # Iterate over parameters
        for key in keys:
            sub_obj = accessor(obj, key)
            structure[key] = _build_node(sub_obj, key, depth, _print)

        # Deal with outermost container
        if depth == 0:
            return _build_conatiner_node(obj, structure)
        else:
            return structure

    # Base case    
    else:
        return obj

deserialise(path)

Deserialises the object stored at the input path.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | The path to deserialise the object from. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| obj | Any | The deserialised object. |

Source code in zodiax/experimental/serialisation.py
def deserialise(path : str):
    """
    Deserialises the object stored at the input path.

    Parameters
    ----------
    path : str
        The path to deserialise the object from.

    Returns
    -------
    obj : Any
        The deserialised object.
    """
    with open(path, 'rb') as f:
        structure = pickle.load(f)
        like = load_structure(structure)
        obj = tree_deserialise_leaves(f, like)
    return obj

load_structure(structure)

Load a structure from a dictionary to later be used in conjunction with eqx.tree_deserialise_leaves().

Custom leaf node deserialisation is handled by the _load_leaf function.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| structure | dict | The structure to load. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| obj | object | The loaded structure. |

Source code in zodiax/experimental/serialisation.py
def load_structure(structure : dict) -> object:
    """
    Load a structure from a dictionary to later be used in conjunction with
    `eqx.tree_deserialise_leaves()`.

    Custom leaf node deserialisation is handled by the `_load_leaf` function.

    Parameters
    ----------
    structure : dict
        The structure to load.

    Returns
    -------
    obj : object
        The loaded structure.
    """
    # Construct the object
    obj = _construct_class(structure['type'])

    # Container Node
    if structure['node_type'] == 'container': 

        # Iterate over all parameters and update the object
        for key, value in structure['node'].items():
            obj = _load_container(obj, key, value)
        return obj

    # Leaf Node
    else: 
        return _load_leaf(obj, structure)
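
The helper functions used in this listing (`_construct_class`, `_load_container`, `_load_leaf`) are not shown on this page. The following standalone sketch shows one way the same idea could work, under stated assumptions: classes are looked up by name in a hypothetical `REGISTRY` dictionary, objects are created without calling `__init__`, and array-like leaves become zero-filled placeholder lists to be overwritten by the saved data. None of this is taken from the zodiax source.

```python
class Lens:
    """Hypothetical user class; it must be available when loading."""
    def __init__(self, name, coefficients):
        self.name = name
        self.coefficients = coefficients


# Hypothetical registry mapping stored type names to classes.
REGISTRY = {"Lens": Lens}


def rebuild(structure):
    """Rebuild a skeleton object from a structure dictionary (sketch)."""
    if structure["node_type"] == "leaf":
        # Strings are restored from the structure itself.
        if "value" in structure:
            return structure["value"]
        # Array-like leaves become zero-filled placeholders of the right
        # length, to be overwritten with the saved data afterwards.
        return [0.0] * structure["shape"][0]

    children = {k: rebuild(v) for k, v in structure["node"].items()}
    if structure["type"] == "dict":
        return children

    try:
        cls = REGISTRY[structure["type"]]
    except KeyError:
        raise ImportError(
            f"Class {structure['type']!r} is not available; import or "
            "register it before loading."
        ) from None

    # Bypass __init__ and set the attributes directly.
    obj = object.__new__(cls)
    for key, value in children.items():
        setattr(obj, key, value)
    return obj


structure = {
    "node_type": "container", "type": "Lens",
    "node": {
        "name": {"node_type": "leaf", "type": "str", "value": "primary"},
        "coefficients": {"node_type": "leaf", "type": "list",
                         "shape": [3], "dtype": "float"},
    },
}
skeleton = rebuild(structure)
```

The explicit error for an unknown type name mirrors the requirement that user-defined classes be importable at load time.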

serialise(path, obj)

Serialises the input zodiax pytree to the input path. This method works by creating a dictionary detailing the structure of the object to be serialised. This dictionary is then serialised using pickle and the pytree leaves are serialised using equinox.tree_serialise_leaves(). This object can then be deserialised using the deserialise() method.

This method is currently considered experimental for a number of reasons:

  • Some objects cannot be guaranteed to be deserialised correctly.
  • User-defined classes can be serialised, but it is up to the user to import the class into the global namespace when deserialising.
  • User-defined functions cannot be guaranteed to be deserialised correctly.
  • Different versions of packages can cause issues when deserialising. This metadata is planned to be serialised in the future, with warnings raised when deserialising.
  • static_field() parameters are not handled correctly. Since array types can be set as static_field() parameters, they are not serialised by equinox.tree_serialise_leaves() and hence require custom serialisation via this method. This is not currently checked for, so it will silently break. This can be fixed with some pre-filtering and type checking using the .tree_flatten() method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| path | str | The path to serialise the object to. | required |
| obj | Any | The object to serialise. | required |

Source code in zodiax/experimental/serialisation.py
def serialise(path : str, obj : Any) -> None:
    """
    Serialises the input zodiax pytree to the input path. This method works by
    creating a dictionary detailing the structure of the object to be
    serialised. This dictionary is then serialised using `pickle` and the 
    pytree leaves are serialised using `equinox.tree_serialise_leaves()`. This
    object can then be deserialised using the `deserialise()` method. 

    This method is currently considered experimental for a number of reasons:
     - Some objects can not be guaranteed to be deserialised correctly. 
     - User-defined classes _can_ be serialised but it is up to the user to 
     import the class into the global namespace when deserialising. 
     - User defined functions can not be guaranteed to be deserialised
     correctly.
     - Different versions of packages can cause issues when deserialising. This
    metadata is planned to be serialised in the future and have warnings raised
    when deserialising.
     - static_field() parameters are not handled correctly. Since array types
     can be set as static_field() parameters, they are not serialised by
     `equinox.tree_serialise_leaves()` and hence require custom serialisation
     via this method. This is not checked for currently so will silently break.
     This can be fixed with some pre-filtering and type checking using the
     `.tree_flatten()` method.

    Parameters
    ----------
    path : str
        The path to serialise the object to.
    obj : Any
        The object to serialise.
    """
    # Check path type
    if not isinstance(path, (str, Path)):
        raise TypeError(f'path must be a string or Path, not {type(path)}')
    else:
        # Convert to string in case of Path for adding .zdx extension
        path = str(path)

    # Add default .zdx extension
    if len(path.split('.')) == 1:
        path += '.zdx'

    # Serialise
    structure = build_structure(obj)
    with open(path, 'wb') as f:
        pickle.dump(structure, f)
        tree_serialise_leaves(f, obj)
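
One detail of the listing worth noting: the default-extension check splits the whole path string on '.', so a dot anywhere in the path, such as in a directory name like results.d/model, counts as an existing extension. The sketch below shows an alternative check based on `pathlib` that inspects only the file name. The function name `with_default_extension` is hypothetical, and this is a possible variation rather than the library's behaviour.

```python
from pathlib import Path


def with_default_extension(path, extension=".zdx"):
    """Append `extension` only when the file name itself has no suffix."""
    path = Path(path)
    if path.suffix == "":
        path = path.with_suffix(extension)
    return str(path)
```

With this approach a path that already ends in an extension is returned unchanged, while a bare file name inside a dotted directory still receives the default extension.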
