
Optimisation

The zodiax.optimisation module contains a single function, get_optimiser. It provides a simple interface for applying Optax optimisers to individual leaves of a pytree.

Full API

get_optimiser(pytree, parameters, optimisers)

Returns a tuple (optimiser, state): an optax.GradientTransformation that applies the optimisers specified by optimisers to the leaves specified by parameters, together with its initialised optimisation state.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pytree | Base | A zodiax.base.Base object. | required |
| parameters | Union[str, List[str]] | A path, a list of paths, or a list of nested paths to the leaves to be optimised. | required |
| optimisers | Union[optax.GradientTransformation, list] | An optax.GradientTransformation, or a list of them, to be applied to the leaves specified by parameters. When lists are given, the i-th optimiser is applied to the i-th entry of parameters. | required |
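The pairing between parameters and optimisers can be illustrated without zodiax. The sketch below uses a hypothetical, pure-Python helper, assign_groups, that mirrors the label logic of get_optimiser on a flat list of leaf paths (it does not handle nested paths): each requested path receives the label of its optimiser ("0", "1", ...) and every other leaf falls into the "null" group. The optimisers are stand-in strings here, not Optax objects, and the leaf names are made up for illustration.

```python
from typing import Any, Dict, List, Sequence, Tuple, Union


def assign_groups(
    leaf_paths: Sequence[str],
    parameters: Union[str, List[str]],
    optimisers: Union[Any, List[Any]],
    null_optimiser: Any = None,
) -> Tuple[Dict[str, str], Dict[str, Any]]:
    """Hypothetical helper mirroring get_optimiser's grouping on flat paths."""
    # Wrap single inputs so that paths and optimisers pair up by position
    if not isinstance(optimisers, list):
        optimisers = [optimisers]
    if isinstance(parameters, str):
        parameters = [parameters]

    # One group label per optimiser: "0", "1", ...
    groups = [str(i) for i in range(len(optimisers))]

    # Every leaf starts in the "null" group; the i-th path moves to group "i"
    labels = {path: "null" for path in leaf_paths}
    for path, group in zip(parameters, groups):
        labels[path] = group

    # Map each group label to its optimiser; "null" gets the no-op optimiser
    opt_dict = dict(zip(groups, optimisers))
    opt_dict["null"] = null_optimiser
    return labels, opt_dict


labels, opt_dict = assign_groups(
    ["position", "flux", "wavelength"],  # hypothetical leaf paths
    ["position", "flux"],
    ["sgd", "adam"],
    null_optimiser="frozen",
)
print(labels)    # {'position': '0', 'flux': '1', 'wavelength': 'null'}
print(opt_dict)  # {'0': 'sgd', '1': 'adam', 'null': 'frozen'}
```

A single path with a single optimiser is the degenerate case: both are wrapped in lists, giving one group "0" and leaving every other leaf in "null".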

Returns:

| Name | Type | Description |
| --- | --- | --- |
| optimiser | optax.GradientTransformation | The combined optimiser. Leaves not specified by parameters are assigned a zero-learning-rate Adam optimiser, so they are not updated. |
| state | optax.MultiTransformState | The optimisation state, initialised on the array leaves of pytree. |

Both are returned together as a tuple (optimiser, state).
Source code in zodiax/optimisation.py
```python
from typing import List, Union

from equinox import filter as eqx_filter, is_array
from jax.tree_util import tree_map
from optax import GradientTransformation, adam, multi_transform

from zodiax.base import Base

# Type aliases matching the documented parameter types
Params = Union[str, List[str]]
Optimisers = Union[GradientTransformation, list]


def get_optimiser(pytree     : Base,
                  parameters : Params,
                  optimisers : Optimisers,
                  ) -> tuple:
    """
    Returns a tuple of (optax.GradientTransformation, optimisation state),
    with the optimisers specified by optimisers applied to the leaves
    specified by parameters.

    Parameters
    ----------
    pytree : Base
        A zodiax.base.Base object.
    parameters : Union[str, List[str]]
        A path, a list of paths, or a list of nested paths to the leaves
        to be optimised.
    optimisers : Union[optax.GradientTransformation, list]
        An optax.GradientTransformation, or a list of them, to be applied
        to the leaves specified by parameters.

    Returns
    -------
    optimiser : optax.GradientTransformation
        The combined optimiser. Leaves not specified by parameters are
        assigned a zero-learning-rate Adam optimiser and are not updated.
    state : optax.MultiTransformState
        The optimisation state, initialised on the array leaves of pytree.
    """
    # Wrap a single optimiser in a list. A GradientTransformation is a
    # NamedTuple of (init, update), so its length cannot be used to tell a
    # single optimiser apart from a list of them.
    if not isinstance(optimisers, list):
        optimisers = [optimisers]

    # Wrap a single path in a list so that paths pair up with optimisers
    if isinstance(parameters, str):
        parameters = [parameters]

    # One group label per optimiser ("0", "1", ...). Every leaf starts in
    # the "null" group and the specified leaves are moved into their group.
    groups = [str(i) for i in range(len(optimisers))]
    param_spec = tree_map(lambda _: "null", pytree)
    param_spec = param_spec.set(parameters, groups)

    # Map each group label to its optimiser, and freeze the "null" group
    # with a zero learning rate
    opt_dict = dict(zip(groups, optimisers))
    opt_dict["null"] = adam(0.0)

    # Build the combined optimiser and initialise it on the array leaves only
    optim = multi_transform(opt_dict, param_spec)
    opt_state = optim.init(eqx_filter(pytree, is_array))

    return (optim, opt_state)
```
