Equinox
Zodiax is designed to be a 'drop in' replacement for Equinox, so all Equinox functions are available through Zodiax. Functions in the main Equinox namespace are raised into the Zodiax namespace, meaning these two lines import the same function:
from equinox import filter_jit
from zodiax import filter_jit
Some Equinox functions are overwritten to give them a path-based interface. Currently three functions are overwritten: filter_grad, filter_value_and_grad, and partition. This means that the following two lines import different functions:
from equinox import filter_grad
from zodiax import filter_grad
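The practical difference is in how leaves are selected. Equinox uses filter specs (for example, partition takes a pytree of booleans or a callable, and filter_grad differentiates with respect to the floating-point arrays in its first argument), whereas the overwritten Zodiax functions take paths. The minimal sketch below shows the conversion from paths to a boolean filter spec in pure Python. It assumes nested dicts stand in for pytrees and dot-separated keys for paths; `paths_to_filter` is a hypothetical helper, not part of Zodiax or Equinox.

```python
from typing import Dict, List, Union

# Assumption: a pytree is modelled as a nested dict whose non-dict values are leaves.
Tree = Dict[str, object]


def paths_to_filter(tree: Tree, parameters: Union[str, List[str]]) -> Tree:
    """Hypothetical helper: turn dot-separated leaf paths into a boolean
    filter spec with the same structure as `tree`."""
    wanted = {parameters} if isinstance(parameters, str) else set(parameters)

    def build(node: Tree, prefix: str) -> Tree:
        spec = {}
        for key, value in node.items():
            path = prefix + key
            if isinstance(value, dict):
                spec[key] = build(value, path + ".")  # recurse into sub-trees
            else:
                spec[key] = path in wanted  # True only for requested leaves
        return spec

    return build(tree, "")


model = {"a": 1.0, "b": {"c": 2.0, "d": 3.0}}
print(paths_to_filter(model, ["a", "b.d"]))
# {'a': True, 'b': {'c': False, 'd': True}}
```

A boolean tree like this is the kind of filter spec eqx.partition accepts, so one way to build a path-based wrapper is to convert the paths first and then delegate to Equinox.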
Submodules in Equinox are also raised into the Zodiax namespace through the zodiax.eqx submodule. This is how you would import the nn submodule from either Equinox or Zodiax:
from equinox import nn
from zodiax.eqx import nn
The path-based versions of filter_grad, filter_value_and_grad, and partition are demonstrated in the 'usage' tutorials.
Full API
filter_grad(parameters, *filter_args, **filter_kwargs)
Applies the Equinox filter_grad function with respect to the leaves at the given parameter paths. See the Equinox documentation for filter_grad for details of the underlying transformation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameters | Union[str, List[str]] | The parameters to filter. Can either be a single string path or a list of paths. | required |
*filter_args | Any | The args to pass to the equinox filter_grad function. | () |
**filter_kwargs | Any | The kwargs to pass to the equinox filter_grad function. | {} |
Returns:
Type | Description |
---|---|
Callable | The wrapped function. |
Source code in zodiax/eqx.py
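To make the path-based behavior concrete, here is a toy, dependency-free sketch of what a path-based filter_grad computes: the gradient only with respect to the named leaves, with None everywhere else. It assumes nested dicts as pytrees, dot-separated paths, and central finite differences in place of JAX autodiff. `path_filter_grad` and its `eps` argument are hypothetical, and the decorator-factory shape is one reading of the documented signature (parameters in, callable out), not Zodiax's code.

```python
from typing import Callable, Dict, List, Union

# Assumptions (not Zodiax's implementation): the pytree is a nested dict of
# floats, paths are dot-separated keys, and derivatives are estimated with
# central finite differences in place of JAX autodiff.
Tree = Dict[str, object]


def _get(tree: Tree, path: str) -> float:
    """Follow a dot-separated path down to a leaf."""
    node = tree
    for key in path.split("."):
        node = node[key]
    return node


def _set(tree: Tree, path: str, value: object) -> Tree:
    """Return a copy of `tree` with the leaf at `path` replaced by `value`."""
    head, _, rest = path.partition(".")
    new = dict(tree)
    new[head] = _set(tree[head], rest, value) if rest else value
    return new


def _nones_like(tree: Tree) -> Tree:
    """Same structure as `tree`, with every leaf set to None."""
    return {
        k: _nones_like(v) if isinstance(v, dict) else None for k, v in tree.items()
    }


def path_filter_grad(parameters: Union[str, List[str]], eps: float = 1e-5):
    """Hypothetical decorator: gradient of `func` w.r.t. the leaves at
    `parameters`; unselected leaves are None in the returned tree."""
    paths = [parameters] if isinstance(parameters, str) else list(parameters)

    def wrapper(func: Callable[..., float]) -> Callable[..., Tree]:
        def inner(tree: Tree, *args, **kwargs) -> Tree:
            grads = _nones_like(tree)
            for path in paths:
                x = _get(tree, path)
                up = func(_set(tree, path, x + eps), *args, **kwargs)
                down = func(_set(tree, path, x - eps), *args, **kwargs)
                grads = _set(grads, path, (up - down) / (2 * eps))
            return grads

        return inner

    return wrapper


@path_filter_grad("b.c")
def loss(tree: Tree, scale: float) -> float:
    return scale * tree["a"] ** 2 * tree["b"]["c"]


grads = loss({"a": 2.0, "b": {"c": 3.0}}, 1.0)
# grads["a"] is None; grads["b"]["c"] is close to a**2 = 4.0
```

In the real function the gradient comes from Equinox's filter_grad (and therefore JAX), so it is exact and works on arrays rather than being a finite-difference estimate on floats.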
filter_value_and_grad(parameters, *filter_args, **filter_kwargs)
Applies the Equinox filter_value_and_grad function with respect to the leaves at the given parameter paths. See the Equinox documentation for filter_value_and_grad for details of the underlying transformation.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parameters | Union[str, List[str]] | The parameters to filter. Can either be a single string path or a list of paths. | required |
*filter_args | Any | The args to pass to the equinox filter_value_and_grad function. | () |
**filter_kwargs | Any | The kwargs to pass to the equinox filter_value_and_grad function. | {} |
Returns:
Type | Description |
---|---|
Callable | The wrapped function. |
Source code in zodiax/eqx.py
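The value-and-gradient variant can be sketched the same way: the wrapped function returns the loss value at the input together with a gradient tree that is None except at the named leaves. As before, this is a toy under stated assumptions (nested dicts as pytrees, dot-separated paths, central finite differences instead of JAX autodiff), and `path_filter_value_and_grad` is a hypothetical name, not Zodiax's implementation.

```python
from typing import Callable, Dict, List, Tuple, Union

# Assumptions (not Zodiax's implementation): the pytree is a nested dict of
# floats, paths are dot-separated keys, and derivatives come from central
# finite differences in place of JAX autodiff.
Tree = Dict[str, object]


def _get(tree: Tree, path: str) -> float:
    """Follow a dot-separated path down to a leaf."""
    node = tree
    for key in path.split("."):
        node = node[key]
    return node


def _set(tree: Tree, path: str, value: object) -> Tree:
    """Return a copy of `tree` with the leaf at `path` replaced by `value`."""
    head, _, rest = path.partition(".")
    new = dict(tree)
    new[head] = _set(tree[head], rest, value) if rest else value
    return new


def _nones_like(tree: Tree) -> Tree:
    """Same structure as `tree`, with every leaf set to None."""
    return {
        k: _nones_like(v) if isinstance(v, dict) else None for k, v in tree.items()
    }


def path_filter_value_and_grad(parameters: Union[str, List[str]], eps: float = 1e-5):
    """Hypothetical decorator: return (value, grads), where grads is None
    everywhere except at the leaves named by `parameters`."""
    paths = [parameters] if isinstance(parameters, str) else list(parameters)

    def wrapper(func: Callable[..., float]) -> Callable[..., Tuple[float, Tree]]:
        def inner(tree: Tree, *args, **kwargs) -> Tuple[float, Tree]:
            value = func(tree, *args, **kwargs)  # evaluated once at the input
            grads = _nones_like(tree)
            for path in paths:
                x = _get(tree, path)
                up = func(_set(tree, path, x + eps), *args, **kwargs)
                down = func(_set(tree, path, x - eps), *args, **kwargs)
                grads = _set(grads, path, (up - down) / (2 * eps))
            return value, grads

        return inner

    return wrapper


@path_filter_value_and_grad("a")
def loss(tree: Tree) -> float:
    return (tree["a"] - 1.0) ** 2 + tree["b"]["c"]


value, grads = loss({"a": 3.0, "b": {"c": 0.5}})
# value == 4.5; grads["a"] is close to 2 * (3 - 1) = 4.0; grads["b"]["c"] is None
```

Returning the value alongside the gradient avoids a second forward pass when both are needed, for example when logging the loss during optimization.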
partition(pytree, parameters, *partition_args, **partition_kwargs)
Wraps the Equinox partition function so that the leaves to partition are selected by a single path or a list of paths. See the Equinox documentation for partition for details of the underlying function.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | Base() | The pytree to partition. | required |
parameters | Union[str, List[str]] | The parameters to partition. Can either be a single string path or a list of paths. | required |
*partition_args | Any | The args to pass to the equinox partition function. | () |
**partition_kwargs | Any | The kwargs to pass to the equinox partition function. | {} |
|
Returns:
Name | Type | Description |
---|---|---|
pytree1 | Base() | A matching pytree with Nones at all leaves not specified by the parameters. |
pytree2 | Base() | A matching pytree with Nones at all leaves specified by the parameters. |
Source code in zodiax/eqx.py
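The documented return values can be mimicked with a pure-Python sketch: the first tree keeps the requested leaves and the second keeps the rest, each padded with None. Nested dicts stand in for pytrees, paths are dot-separated leaf keys, and both `path_partition` and `combine` are hypothetical helpers written for illustration; Equinox's own inverse of partition is eqx.combine.

```python
from typing import Dict, List, Tuple, Union

# Assumptions: nested dicts stand in for Zodiax/Equinox pytrees and paths are
# dot-separated keys naming individual leaves. Both helpers are hypothetical.
Tree = Dict[str, object]


def path_partition(tree: Tree, parameters: Union[str, List[str]]) -> Tuple[Tree, Tree]:
    """Split `tree` into (selected, rest), padding each with None."""
    wanted = {parameters} if isinstance(parameters, str) else set(parameters)

    def split(node: Tree, prefix: str) -> Tuple[Tree, Tree]:
        selected, rest = {}, {}
        for key, value in node.items():
            path = prefix + key
            if isinstance(value, dict):
                selected[key], rest[key] = split(value, path + ".")
            elif path in wanted:
                selected[key], rest[key] = value, None  # requested leaf
            else:
                selected[key], rest[key] = None, value  # everything else
        return selected, rest

    return split(tree, "")


def combine(first: Tree, second: Tree) -> Tree:
    """Merge two partitions, taking the non-None leaf at each position."""
    out = {}
    for key, value in first.items():
        if isinstance(value, dict):
            out[key] = combine(value, second[key])
        else:
            out[key] = value if value is not None else second[key]
    return out


model = {"a": 1.0, "b": {"c": 2.0, "d": 3.0}}
params, others = path_partition(model, ["a", "b.d"])
# params -> {'a': 1.0, 'b': {'c': None, 'd': 3.0}}
# others -> {'a': None, 'b': {'c': 2.0, 'd': None}}
```

Recombining the two halves recovers the original tree, which is the usual pattern when only some parameters should be traced or optimized while the rest are held fixed.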