
Equinox

Zodiax is designed to be a 'drop in' replacement for Equinox, which means that all Equinox functions are available through Zodiax. Functions in the main Equinox namespace are raised into the Zodiax namespace, so these two lines import the same function:

from equinox import filter_jit
from zodiax import filter_jit

Some Equinox functions are overwritten to give them a path-based interface. Currently three functions are overwritten: filter_grad, filter_value_and_grad, and partition. This means that the following two lines import different functions:

from equinox import filter_grad
from zodiax import filter_grad

Equinox submodules are also available through the zodiax.eqx submodule. This is how you would import the nn submodule from either Equinox or Zodiax:

from equinox import nn
from zodiax.eqx import nn

Examples of the path-based filter_grad, filter_value_and_grad, and partition can be found in the 'usage' tutorials.

Full API

filter_grad(parameters, *filter_args, **filter_kwargs)

Applies the Equinox filter_grad transformation to the leaves selected by parameters. The corresponding Equinox docs are at https://docs.kidger.site/equinox/api/filtering/transformations/

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| parameters | Union[str, List[str]] | The parameters to filter. Can either be a single string path or a list of paths. | required |
| *filter_args | Any | The args to pass to the equinox filter_grad function. | () |
| **filter_kwargs | Any | The kwargs to pass to the equinox filter_grad function. | {} |

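Returns:

| Type | Description |
|---|---|
| Callable | A decorator. The decorated function returns the gradient of the original function with respect to the leaves selected by parameters. |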

Source code in zodiax/eqx.py
def filter_grad(
    parameters : Params,
    *filter_args,
    **filter_kwargs
    ) -> Callable:
    """
    Applies the equinox filter_grad function to the input parameters. The
    corresponding equinox docs are found [here](https://docs.kidger.site/
    equinox/api/filtering/transformations/)

    Parameters
    ----------
    parameters : Union[str, List[str]]
        The parameters to filter. Can either be a single string path or a list
        of paths.
    *filter_args : Any
        The args to pass to the equinox filter_grad function.
    **filter_kwargs : Any
        The kwargs to pass to the equinox filter_grad function.

    Returns
    -------
    Callable
        A decorator. The decorated function returns the gradient with respect
        to the leaves selected by parameters.
    """
    def wrapper(func : Callable):

        @wraps(func)
        def inner_wrapper(pytree : PyTree, *args, **kwargs):

            # Boolean pytree: True at the leaves named by parameters
            boolean_filter = zodiax.tree.boolean_filter(pytree, parameters)

            # Differentiate only the first argument (the selected leaves);
            # the remaining leaves are passed through undifferentiated
            @equinox.filter_grad(*filter_args, **filter_kwargs)
            def recombine(traced : PyTree, static : PyTree):
                return func(equinox.combine(traced, static), *args, **kwargs)

            # Split the pytree and evaluate the gradient
            return recombine(*equinox.partition(pytree, boolean_filter))
        return inner_wrapper
    return wrapper

filter_value_and_grad(parameters, *filter_args, **filter_kwargs)

Applies the Equinox filter_value_and_grad transformation to the leaves selected by parameters. The corresponding Equinox docs are at https://docs.kidger.site/equinox/api/filtering/transformations/

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| parameters | Union[str, List[str]] | The parameters to filter. Can either be a single string path or a list of paths. | required |
| *filter_args | Any | The args to pass to the equinox filter_value_and_grad function. | () |
| **filter_kwargs | Any | The kwargs to pass to the equinox filter_value_and_grad function. | {} |

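Returns:

| Type | Description |
|---|---|
| Callable | A decorator. The decorated function returns a (value, gradient) pair, with the gradient taken with respect to the leaves selected by parameters. |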

Source code in zodiax/eqx.py
def filter_value_and_grad(
    parameters : Params,
    *filter_args,
    **filter_kwargs
    ) -> Callable:
    """
    Applies the equinox filter_value_and_grad function to the input parameters.
    The corresponding equinox docs are found [here](https://docs.kidger.site/
    equinox/api/filtering/transformations/)

    Parameters
    ----------
    parameters : Union[str, List[str]]
        The parameters to filter. Can either be a single string path or a list
        of paths.
    *filter_args : Any
        The args to pass to the equinox filter_value_and_grad function.
    **filter_kwargs : Any
        The kwargs to pass to the equinox filter_value_and_grad function.

    Returns
    -------
    Callable
        A decorator. The decorated function returns a (value, gradient) pair,
        with the gradient taken with respect to the leaves selected by
        parameters.
    """
    def wrapper(func : Callable):

        @wraps(func)
        def inner_wrapper(pytree : PyTree, *args, **kwargs):

            # Boolean pytree: True at the leaves named by parameters
            boolean_filter = zodiax.tree.boolean_filter(pytree, parameters)

            # Differentiate only the first argument (the selected leaves);
            # the remaining leaves are passed through undifferentiated
            @equinox.filter_value_and_grad(*filter_args, **filter_kwargs)
            def recombine(traced : PyTree, static : PyTree):
                return func(equinox.combine(traced, static), *args, **kwargs)

            # Split the pytree and evaluate the value and gradient
            return recombine(*equinox.partition(pytree, boolean_filter))
        return inner_wrapper
    return wrapper

partition(pytree, parameters, *partition_args, **partition_kwargs)

Wraps the Equinox partition function so that the leaves to partition can be given as a single path or a list of paths. The corresponding Equinox docs are at https://docs.kidger.site/equinox/api/filtering/transformations/

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pytree | Base | The pytree to partition. | required |
| parameters | Union[str, List[str]] | The parameters to partition. Can either be a single string path or a list of paths. | required |
| *partition_args | Any | The args to pass to the equinox partition function. | () |
| **partition_kwargs | Any | The kwargs to pass to the equinox partition function. | {} |

Returns:

| Name | Type | Description |
|---|---|---|
| pytree1 | Base | A matching pytree with Nones at all leaves not specified by the parameters. |
| pytree2 | Base | A matching pytree with Nones at all leaves specified by the parameters. |

Source code in zodiax/eqx.py
def partition(
    pytree : Base,
    parameters : Params,
    *partition_args,
    **partition_kwargs) -> tuple:
    """
    Wraps the equinox partition function to take in a list of parameters to
    partition. The corresponding equinox docs are found [here](https://docs.
    kidger.site/equinox/api/filtering/transformations/)

    Parameters
    ----------
    pytree : Base
        The pytree to partition.
    parameters : Union[str, List[str]]
        The parameters to partition. Can either be a single string path or a
        list of paths.
    *partition_args : Any
        The args to pass to the equinox partition function.
    **partition_kwargs : Any
        The kwargs to pass to the equinox partition function.

    Returns
    -------
    pytree1 : Base
        A matching pytree with Nones at all leaves not specified by the
        parameters.
    pytree2 : Base
        A matching pytree with Nones at all leaves specified by the parameters.
    """
    # Accept a single path as well as a list of paths
    if isinstance(parameters, str):
        parameters = [parameters]

    # Boolean pytree: True at the leaves named by parameters
    boolean_filter = zodiax.tree.boolean_filter(pytree, parameters)
    return equinox.partition(pytree, boolean_filter, *partition_args,
        **partition_kwargs)
