
Zodiax



Zodiax is a lightweight extension to Equinox, the object-oriented framework for Jax. Equinox provides differentiable classes that Jax recognises as valid types, and Zodiax adds lightweight methods that simplify working with these classes. Zodiax was originally built during the development of dLux and was designed to make working with large, nested class structures simple and flexible.

Zodiax is directly integrated with both Jax and Equinox, so it inherits all of their core features, including just-in-time compilation and automatic differentiation.

Documentation: louisdesdoigts.github.io/zodiax/

Note: The Zodiax tutorials live in a separate repo. This allows users to directly download and run the notebooks, ensuring that the correct packages needed to run them are installed! It also allows for new tutorials and examples to be added quite easily, without needing to update the core library.

Contributors: Louis Desdoigts

Requires: Python 3.10+, Jax 0.4.25+

Installation:

```bash
pip install zodiax
```

Development installation:

```bash
pip install "zodiax[dev]"
```

Coverage:

```bash
pytest --cov=zodiax --cov-report=term-missing --cov-report=xml --cov-report=html tests
```

This writes coverage.xml and an htmlcov/ report for local inspection.


Quickstart¤

Create a regular class that inherits from zodiax.Base:

```python
import jax
import zodiax as zdx
import jax.numpy as np

class Linear(zdx.Base):
    m: jax.Array
    b: jax.Array

    def __init__(self, m, b):
        self.m = m
        self.b = b

    def __call__(self, x):
        return self.m * x + self.b

linear = Linear(1., 1.)
```

It's that simple! The linear object is now fully differentiable and gets all the benefits of Jax through an object-oriented interface. Let's see how to jit-compile it and take gradients with respect to its parameters.

```python
@jax.jit
@jax.grad
def loss_fn(model, xs, ys):
    return np.square(model(xs) - ys).sum()

xs = np.arange(5)
ys = 2*np.arange(5)
grads = loss_fn(linear, xs, ys)
print(grads)
# > Linear(m=f32[], b=f32[])
print(grads.m, grads.b)
# > -40.0 -10.0
```

The grads object is itself an instance of the Linear class, holding the gradient of the loss with respect to each parameter!

Update Signatures (Minimal Overview)

Most Zodiax update methods (set, add, multiply, divide, power, min, max) support three equivalent input styles:

  1. (parameters, values) positional style
  2. {parameter: value} dictionary style
  3. param=value keyword style (and **{"nested.path": value} for nested paths)

```python
# 1) Positional: (parameters, values)
linear = linear.set(["m", "b"], [2.0, 0.5])

# 2) Dictionary: {parameter: value}
linear = linear.add({"m": 0.1, "b": -0.2})

# 3) Keyword: param=value
linear = linear.multiply(m=2.0, b=0.5)
```

Use whichever style is clearest for your workflow. All of these operations are immutable: each returns a new object and leaves the original unchanged.