Optimisation¤

The zodiax.optimisation module provides a simple interface for applying Optax optimisers to individual leaves of a pytree!

zodiax.optimisation ¤

adam(lr, start, *schedule) ¤

Wrapper for the optax Adam optimiser with a piecewise constant learning rate schedule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `lr` | `float` | The initial learning rate. | required |
| `start` | `int` | The starting step (the learning rate is ~0 before this). | required |
| `*schedule` | `tuple` | A variable number of tuples, each containing a step and a multiplier. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `optimiser` | `optax.adam` | The optimiser with the piecewise constant learning rate schedule. |

Source code in zodiax/optimisation.py
def adam(lr: float, start: int, *schedule):
    """
    Wrapper for the optax Adam optimiser with a piecewise constant learning rate
    schedule.

    Parameters
    ----------
    lr : float
        The initial learning rate.
    start : int
        The starting step (learning rate will be ~0 before this).
    *schedule : tuple
        A variable number of tuples, each containing a step and a multiplier.

    Returns
    -------
    optimiser : optax.adam
        The optimiser with the piecewise constant learning rate schedule.
    """
    warnings.warn(
        "adam is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )
    return _base_adam(scheduler(lr, start, *schedule))

debug_nan_check(grads) ¤

Checks for NaN values in the gradients and triggers a breakpoint if any are found.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `PyTree` | The gradients to be checked for NaN values. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `grads` | `PyTree` | The gradients. |

Source code in zodiax/optimisation.py
def debug_nan_check(grads: PyTree) -> PyTree:
    """
    Checks for NaN values in the gradients and triggers a breakpoint if any are found.

    Parameters
    ----------
    grads : PyTree
        The gradients to be checked for NaN values.

    Returns
    -------
    grads : PyTree
        The gradients.
    """
    bool_tree = jtu.map(lambda x: np.isnan(x).any(), grads)
    vals = np.array(jtu.flatten(bool_tree)[0])
    eqx.debug.breakpoint_if(vals.sum() > 0)
    return grads
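
The function above maps `np.isnan(...).any()` over every leaf, flattens the resulting tree of flags, and pauses execution via `eqx.debug.breakpoint_if` when any flag is set. The core logic can be sketched with plain NumPy over a nested dict; the dict structure and the helper name `any_nan` are illustrative, not part of zodiax:

```python
import numpy as np

def any_nan(tree: dict) -> bool:
    """Recursively check a nested dict of arrays for NaN values."""
    found = False
    for leaf in tree.values():
        if isinstance(leaf, dict):
            found = found or any_nan(leaf)
        else:
            found = found or bool(np.isnan(leaf).any())
    return found

grads = {"a": np.array([1.0, 2.0]), "b": {"c": np.array([np.nan])}}
clean = {"a": np.array([1.0]), "b": {"c": np.array([0.0])}}
```

In the real function this boolean drives a JAX-compatible breakpoint rather than a Python-level branch, so it can run inside jitted training loops.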

decompose(matrix, hermitian=True, normalise=False) ¤

Returns:

- `eigvals`: `(D,)` array, sorted descending
- `eigvecs`: `(D, D)` array where each ROW is an eigenvector

Source code in zodiax/optimisation.py
def decompose(
    matrix: Array, hermitian: bool = True, normalise: bool = False
) -> tuple[Array, Array]:
    """
    Returns:
        eigvals: (D,) array sorted descending
        eigvecs: (D, D) array where each ROW is an eigenvector
    """
    if hermitian:
        eigvals, eigvecs = np.linalg.eigh(matrix)
        # Flip to descending and transpose so rows = vectors
        eigvals, eigvecs = eigvals[::-1], eigvecs.T[::-1]
    else:
        eigvals, eigvecs = np.linalg.eig(matrix)
        # Sort manually for non-hermitian consistency
        idx = np.argsort(eigvals.real)[::-1]
        eigvals, eigvecs = eigvals.real[idx], eigvecs.real.T[idx]

    if normalise:
        eigvals /= eigvals[0]

    return eigvals, eigvecs
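
The row and ordering conventions of the hermitian branch can be checked with a small symmetric matrix. A minimal NumPy sketch (the example matrix is an assumption for illustration):

```python
import numpy as np

matrix = np.array([[2.0, 0.0],
                   [0.0, 1.0]])

# np.linalg.eigh returns ascending eigenvalues, with eigenvectors as COLUMNS
eigvals, eigvecs = np.linalg.eigh(matrix)

# Flip to descending order and transpose so each ROW is an eigenvector
eigvals, eigvecs = eigvals[::-1], eigvecs.T[::-1]
```

For this diagonal matrix the eigenvalues come out as `[2.0, 1.0]` and the rows of `eigvecs` are the corresponding axis-aligned unit vectors (up to sign).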

delay(lr, start, length=1) ¤

Delays the learning rate by starting at 0 and linearly increasing to the specified learning rate over a specified number of steps.

Source code in zodiax/optimisation.py
def delay(lr: float, start: int, length: int = 1) -> optax.Schedule:
    """
    Delays the learning rate by starting at 0 and linearly increasing to the specified
    learning rate over a specified number of steps.
    """
    return optax.linear_schedule(0.0, lr, length, start)
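
`optax.linear_schedule(0.0, lr, length, start)` keeps the rate at 0 until step `start`, ramps linearly over `length` steps, then holds at `lr`. A pure-Python sketch of that behaviour (boundary handling approximates optax's; the function name is illustrative):

```python
def linear_warmup(step: int, lr: float, start: int, length: int) -> float:
    """Linearly ramp from 0 to lr over `length` steps, beginning at `start`."""
    frac = (step - start) / length
    return lr * min(max(frac, 0.0), 1.0)  # clamp to [0, lr]
```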

eigen_projection(fmat=None, cov=None) ¤

Projects the parameter space into an orthonormal basis.

TODO: develop docs more

Source code in zodiax/optimisation.py
def eigen_projection(fmat: Array = None, cov: Array = None) -> Array:
    """
    Projects the parameter space into an orthonormal basis.

    TODO: develop docs more
    """
    # Make sure we have one input
    if fmat is None and cov is None:
        raise ValueError("Must provide either fmat or cov")

    # Select the matrix to decompose
    mat = fmat if fmat is not None else cov

    # Ensure symmetry to avoid complex eigvals from numerical noise
    mat = (mat + mat.T) / 2.0

    # 1. Extract physical scales (the diagonal)
    # For Fisher, diag is 1/sigma^2. For Cov, diag is sigma^2.
    diag = np.diag(mat)

    if fmat is not None:
        # Physical 'step' size for each parameter
        phys_scale = 1.0 / np.sqrt(diag)
    else:
        phys_scale = np.sqrt(diag)

    # 2. Convert to Correlation-like matrix (dimensionless)
    # This prevents physical units from dominating the eigenvalue spectrum
    S_inv = 1.0 / np.sqrt(diag)
    norm_mat = S_inv[:, None] * mat * S_inv[None, :]

    # 3. Decompose the normalized matrix
    # All diagonal elements are now 1.0.
    # Eigenvalues now represent 'redundancy' or 'degeneracy' rather than 'units'.
    vals, vecs = decompose(norm_mat, normalise=False)

    # 4. Build the Projection Matrix P
    # P = (Physical Scale) * (Rotation) * (Orthonormal Scaling)
    # We apply the phys_scale back to return to the original units
    if fmat is not None:
        # For Fisher, eigenvalues of norm_mat are 'information'
        inner_scale = 1.0 / np.sqrt(vals)
    else:
        # For Cov, eigenvalues of norm_mat are 'variance'
        inner_scale = np.sqrt(vals)

    # P maps: (Reduced Space) -> (Dimensionless Space) -> (Physical Space)
    P = vecs.T * inner_scale[None, :]
    P = phys_scale[:, None] * P

    return P
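
For the covariance branch, the projection satisfies `P @ P.T ≈ cov`: rescale to a dimensionless correlation matrix, rotate into its eigenbasis, scale by the square roots of the eigenvalues, then restore physical units. A NumPy sketch of that branch verifying the property (the helper name and the 2x2 covariance are assumptions for illustration; `eigh`'s native column-vector convention is used rather than `decompose`'s rows):

```python
import numpy as np

def cov_projection(cov: np.ndarray) -> np.ndarray:
    """Build P so that P @ P.T reconstructs the covariance matrix."""
    cov = (cov + cov.T) / 2.0                      # enforce symmetry
    phys_scale = np.sqrt(np.diag(cov))             # per-parameter sigma
    s_inv = 1.0 / phys_scale
    corr = s_inv[:, None] * cov * s_inv[None, :]   # dimensionless correlation
    vals, vecs = np.linalg.eigh(corr)              # columns are eigenvectors
    P = vecs * np.sqrt(vals)[None, :]              # rotate, then scale
    return phys_scale[:, None] * P                 # restore physical units

cov = np.array([[4.0, 1.2],
                [1.2, 1.0]])
P = cov_projection(cov)
```

Since `P = D^{1/2} V Λ^{1/2}`, we get `P @ P.T = D^{1/2} V Λ V^T D^{1/2} = D^{1/2} C D^{1/2} = cov`, which is why the columns of `P` form a useful decorrelated basis for optimisation.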

get_optimiser(pytree, parameters, optimisers) ¤

Returns an optax.GradientTransformation object, with the optimisers specified by optimisers applied to the leaves specified by parameters.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pytree` | `PyTree` | A zodiax.base.PyTree object. | required |
| `parameters` | `Params` | A path or list of parameters or list of nested parameters. | required |
| `optimisers` | `Optimisers` | An optax.GradientTransformation or list of optax.GradientTransformation objects to be applied to the leaves specified by `parameters`. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `optimiser` | `optax.GradientTransformation` | The optimiser with the transformations applied to the leaves specified by `parameters`. |
| `state` | `optax.MultiTransformState` | The initialised optimisation state. |
Source code in zodiax/optimisation.py
def get_optimiser(
    pytree: PyTree,
    parameters: Params,
    optimisers: Optimisers,
) -> tuple:
    """
    Returns an optax.GradientTransformation object, with the optimisers
    specified by optimisers applied to the leaves specified by parameters.

    Parameters
    ----------
    pytree : PyTree
        A zodiax.base.PyTree object.
    parameters : Params
        A path or list of parameters or list of nested parameters.
    optimisers : Optimisers
        An optax.GradientTransformation or list of
        optax.GradientTransformation objects to be applied to the leaves
        specified by parameters.

    Returns
    -------
    optimiser : optax.GradientTransformation
        The optimiser with the transformations applied to the leaves
        specified by parameters.
    state : optax.MultiTransformState
        The initialised optimisation state.
    """
    warnings.warn(
        "get_optimiser is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )

    # Wrap a single optimiser in a list
    if not isinstance(optimisers, list):
        optimisers = [optimisers]

    # Wrap a single parameter path in a list to match the optimisers
    if isinstance(parameters, str):
        parameters = [parameters]

    # Construct groups and get param_spec
    groups = [str(i) for i in range(len(optimisers))]
    param_spec = jtu.map(lambda _: "null", pytree)
    param_spec = param_spec.set(parameters, groups)

    # Generate optimiser dictionary and Assign the null group
    opt_dict = dict([(groups[i], optimisers[i]) for i in range(len(groups))])
    opt_dict["null"] = optax.sgd(0.0)

    # Get optimiser object and filtered optimiser
    optim = optax.multi_transform(opt_dict, param_spec)
    opt_state = optim.init(eqx.filter(pytree, eqx.is_array))

    # Return
    return (optim, opt_state)
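
The core of the wiring is a label tree: each targeted parameter is tagged with a numbered group, and everything else falls into a `"null"` group that receives a zero-learning-rate optimiser. The labelling step can be sketched with plain dicts (the parameter names and the helper `build_labels` are hypothetical, shown only to illustrate the scheme `optax.multi_transform` consumes):

```python
def build_labels(all_params: list, targeted: list) -> dict:
    """Map each parameter name to its optimiser group ('0', '1', ...) or 'null'."""
    groups = {name: str(i) for i, name in enumerate(targeted)}
    return {name: groups.get(name, "null") for name in all_params}

labels = build_labels(["flux", "position", "aberrations"], ["flux", "position"])
```

Untargeted leaves keep a `"null"` label, so their `optax.sgd(0.0)` transform leaves them unchanged during updates.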

map_optimisers(params, optimisers, strict=False) ¤

Maps optimisers from a dictionary of optax optimisers onto a dictionary of parameters.

TODO: Develop docs more

Source code in zodiax/optimisation.py
def map_optimisers(params: dict, optimisers: dict, strict: bool = False) -> tuple:
    """
    Maps optimisers from a dictionary of optax optimisers onto a dictionary of parameters.

    TODO: Develop docs more
    """

    # unpack the dicts to ensure matching structures
    params = _unpack(params)
    optimisers = _unpack(optimisers)

    # Make sure all parameters are floating-point jax arrays so they return gradients
    params = jtu.map(lambda x: np.array(x, float), params)

    if strict and params.keys() != optimisers.keys():
        raise ValueError("Params and optimisers must have the same keys")

    # For any key in params missing from optimisers, assign an identity optimiser
    for key in params.keys():
        if key not in optimisers.keys():
            optimisers[key] = optax.identity()

    # For any key in optimisers missing from params, add a zero-valued placeholder
    for key in optimisers.keys():
        if key not in params.keys():
            params[key] = 0.0

    param_spec = {param: param for param in optimisers.keys()}
    optim = optax.multi_transform(optimisers, param_spec)
    state = optim.init(params)
    return optim, state
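
The key-alignment logic (identity optimiser for unmatched parameters, zero placeholder for unmatched optimisers) can be sketched without optax; here the string `"identity"` stands in for `optax.identity()` and the helper name `align_keys` is illustrative:

```python
def align_keys(params: dict, optimisers: dict, strict: bool = False):
    """Align two flat dicts so they share exactly the same keys."""
    if strict and params.keys() != optimisers.keys():
        raise ValueError("Params and optimisers must have the same keys")
    for key in params:
        optimisers.setdefault(key, "identity")  # stands in for optax.identity()
    for key in optimisers:
        params.setdefault(key, 0.0)             # zero-valued placeholder leaf
    return params, optimisers
```

After alignment, both dicts have identical key sets, which is what `optax.multi_transform` requires of its parameter spec.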

scheduler(lr, start, *args) ¤

A convenience wrapper around optax for creating a piecewise constant learning rate schedule. The function takes a learning rate, a starting step and, optionally, a variable number of tuples. Each tuple contains a step and a multiplier; the learning rate is multiplied by the corresponding multiplier at the specified step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `lr` | `float` | The initial learning rate. | required |
| `start` | `int` | The starting step (the learning rate is ~0 before this). | required |
| `*args` | `tuple` | A variable number of tuples, each containing a step and a multiplier. | `()` |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `schedule` | `optax.Schedule` | The piecewise constant learning rate schedule. |

Source code in zodiax/optimisation.py
def scheduler(lr: float, start: int, *args):
    """
    Function to easily interface with the optax library to create a piecewise
    constant learning rate schedule. The function takes a learning rate, a
    starting step and optionally, a variable number of tuples. Each tuple
    should contain a step and a multiplier; the learning rate will be multiplied
    by the corresponding multiplier at the specified step.

    Parameters
    ----------
    lr : float
        The initial learning rate.
    start : int
        The starting step (learning rate will be ~0 before this).
    args : tuple
        A variable number of tuples, each containing a step and a multiplier.

    Returns
    -------
    schedule : optax.schedule
        The piecewise constant learning rate schedule.
    """
    warnings.warn(
        "scheduler is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )

    # Create a dictionary to store the schedule
    sched_dict = {start: BIG}

    # Apply each (step, multiplier) update
    for step, mul in args:
        sched_dict[step] = mul

    return optax.piecewise_constant_schedule(lr / BIG, sched_dict)
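
Internally, the base rate is divided by a large constant `BIG`, and a multiplier of `BIG` at `start` restores it: the effective rate is ~0 before `start`, `lr` afterwards, then scaled by each subsequent multiplier. A pure-Python sketch of the resulting schedule (the value of `BIG` is an assumption, and the boundary semantics only approximate optax's `piecewise_constant_schedule`):

```python
BIG = 1e16  # assumed value; zodiax defines its own constant

def piecewise_lr(step: int, lr: float, start: int, *args) -> float:
    """Effective learning rate at `step` for scheduler(lr, start, *args)."""
    boundaries = {start: BIG, **dict(args)}
    rate = lr / BIG  # ~0 until the start boundary restores it
    for boundary, mul in sorted(boundaries.items()):
        if step > boundary:
            rate *= mul
    return rate
```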

sgd(lr, start, *schedule) ¤

Wrapper for the optax SGD optimiser with a piecewise constant learning rate schedule.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `lr` | `float` | The initial learning rate. | required |
| `start` | `int` | The starting step (the learning rate is ~0 before this). | required |
| `*schedule` | `tuple` | A variable number of tuples, each containing a step and a multiplier. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `optimiser` | `optax.sgd` | The optimiser with the piecewise constant learning rate schedule. |

Source code in zodiax/optimisation.py
def sgd(lr: float, start: int, *schedule):
    """
    Wrapper for the optax SGD optimiser with a piecewise constant learning rate
    schedule.

    Parameters
    ----------
    lr : float
        The initial learning rate.
    start : int
        The starting step (learning rate will be ~0 before this).
    *schedule : tuple
        A variable number of tuples, each containing a step and a multiplier.

    Returns
    -------
    optimiser : optax.sgd
        The optimiser with the piecewise constant learning rate schedule.
    """
    warnings.warn(
        "sgd is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )
    return _base_sgd(scheduler(lr, start, *schedule))

zero_nan_check(grads) ¤

Replaces any NaN values in the gradients with zeros.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `grads` | `PyTree` | The gradients to be checked for NaN values. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| `grads` | `PyTree` | The gradients with NaN values replaced by zeros. |

Source code in zodiax/optimisation.py
def zero_nan_check(grads: PyTree) -> PyTree:
    """
    Replaces any NaN values in the gradients with zeros.

    Parameters
    ----------
    grads : PyTree
        The gradients to be checked for NaN values.

    Returns
    -------
    grads : PyTree
        The gradients with NaN values replaced by zeros.
    """
    return jtu.map(lambda x: np.where(np.isnan(x), 0.0, x), grads)
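
The per-leaf operation is just a masked replacement, shown here on a single NumPy array (zodiax maps the same lambda over every leaf of the gradient pytree):

```python
import numpy as np

grad = np.array([1.0, np.nan, -2.0, np.nan])

# NaN entries become 0.0; finite entries pass through unchanged
cleaned = np.where(np.isnan(grad), 0.0, grad)
```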