
Equinox

Deprecation Warning

This module is deprecated. Zodiax v0.5.0 is the last release that includes it, and it will be removed in v0.6.0.

zodiax.eqx

filter_grad(parameters, *filter_args, **filter_kwargs)

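Applies equinox's `filter_grad` transformation so that gradients are taken only with respect to the leaves selected by `parameters`. The corresponding equinox documentation is at https://docs.kidger.site/equinox/api/filtering/transformations/.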

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `parameters` | `Union[str, List[str]]` | The parameters to filter. Either a single string path or a list of paths. | required |
| `*filter_args` | `Any` | Positional arguments passed to the equinox `filter_grad` function. | `()` |
| `**filter_kwargs` | `Any` | Keyword arguments passed to the equinox `filter_grad` function. | `{}` |

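Returns:

| Type | Description |
|---|---|
| `Callable` | A decorator. Applied to a function `func(pytree, *args, **kwargs)`, it returns a function with the same signature that computes the gradient of `func` with respect to the leaves of `pytree` selected by `parameters`. |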

Source code in zodiax/eqx.py

```python
def filter_grad(parameters: Params, *filter_args, **filter_kwargs) -> Callable:
    """
    Applies the equinox filter_grad function to the input parameters. The
    corresponding equinox docs are found [here](https://docs.kidger.site/
    equinox/api/filtering/transformations/)

    Parameters
    ----------
    parameters : Union[str, List[str]]
        The parameters to filter. Can either be a single string path or a list
        of paths.
    *filter_args : Any
        The args to pass to the equinox filter_grad function.
    **filter_kwargs : Any
        The kwargs to pass to the equinox filter_grad function.

    Returns
    -------
    Callable
        The wrapped function.
    """
    warnings.warn(
        "filter_grad is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )

    def wrapper(func: Callable):
        @wraps(func)
        def inner_wrapper(pytree, *args, **kwargs):
            # Convert parameters
            pytree = set_array(pytree)
            bool_filter = boolean_filter(pytree, parameters)

            # Wrap original function
            @eqx.filter_grad(*filter_args, **filter_kwargs)
            def recombine(traced, static):
                return func(eqx.combine(traced, static), *args, **kwargs)

            # Return wrapped function
            return recombine(*eqx.partition(pytree, bool_filter))

        return inner_wrapper

    return wrapper
```
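The listing above partitions the pytree into "traced" and "static" halves, differentiates only the traced half, and recombines the two inside the differentiated function. The following is a minimal pure-Python sketch of that flow that does not depend on JAX. It rests on a few assumptions. A pytree is modelled as a flat dict keyed by dotted paths, whereas real zodiax pytrees are equinox Modules. A path is assumed to select itself and every leaf beneath it. Central finite differences stand in for JAX autodiff. `filter_grad_sketch` is a hypothetical name and is not part of zodiax or equinox.

```python
from functools import wraps
from typing import Callable, Dict, List, Optional, Union

# Assumption: a pytree is modelled as a flat dict keyed by dotted paths.
FlatTree = Dict[str, Optional[float]]


def _is_selected(path: str, parameters: List[str]) -> bool:
    """True if `path` equals a requested parameter or lies beneath one (assumed rule)."""
    return any(path == p or path.startswith(p + ".") for p in parameters)


def filter_grad_sketch(parameters: Union[str, List[str]], eps: float = 1e-6):
    """Hypothetical decorator factory: partition -> differentiate -> combine."""
    if isinstance(parameters, str):
        parameters = [parameters]

    def wrapper(func: Callable[..., float]):
        @wraps(func)
        def inner_wrapper(tree: FlatTree, *args, **kwargs) -> FlatTree:
            # Partition: selected leaves are "traced", the rest are "static".
            traced = {k: (v if _is_selected(k, parameters) else None) for k, v in tree.items()}
            static = {k: (None if _is_selected(k, parameters) else v) for k, v in tree.items()}

            def recombine(t: FlatTree) -> float:
                # Combine: take each leaf from whichever half is not None.
                merged = {k: (t[k] if t[k] is not None else static[k]) for k in t}
                return func(merged, *args, **kwargs)

            # Central finite differences stand in for JAX autodiff.
            grads: FlatTree = {}
            for key, value in traced.items():
                if value is None:
                    grads[key] = None  # unselected leaf: no gradient
                    continue
                up = recombine({**traced, key: value + eps})
                down = recombine({**traced, key: value - eps})
                grads[key] = (up - down) / (2 * eps)
            return grads

        return inner_wrapper

    return wrapper


# Usage: differentiate only with respect to leaves under "source".
@filter_grad_sketch("source")
def loss(tree: FlatTree, target: float) -> float:
    return (tree["source.flux"] * tree["optics.diameter"] - target) ** 2


grads = loss({"source.flux": 1.0, "optics.diameter": 2.0}, 1.0)
# grads["source.flux"] is about 4.0; grads["optics.diameter"] is None
```

Extra positional and keyword arguments flow through to `func` unchanged, which matches how `inner_wrapper` forwards `*args` and `**kwargs` in the listing.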

filter_value_and_grad(parameters, *filter_args, **filter_kwargs)

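Applies equinox's `filter_value_and_grad` transformation so that the function value is returned together with the gradient taken only with respect to the leaves selected by `parameters`. The corresponding equinox documentation is at https://docs.kidger.site/equinox/api/filtering/transformations/.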

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `parameters` | `Union[str, List[str]]` | The parameters to filter. Either a single string path or a list of paths. | required |
| `*filter_args` | `Any` | Positional arguments passed to the equinox `filter_value_and_grad` function. | `()` |
| `**filter_kwargs` | `Any` | Keyword arguments passed to the equinox `filter_value_and_grad` function. | `{}` |

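Returns:

| Type | Description |
|---|---|
| `Callable` | A decorator. Applied to a function `func(pytree, *args, **kwargs)`, it returns a function with the same signature that computes both the value of `func` and its gradient with respect to the leaves of `pytree` selected by `parameters`. |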

Source code in zodiax/eqx.py

```python
def filter_value_and_grad(
    parameters: Params, *filter_args, **filter_kwargs
) -> Callable:
    """
    Applies the equinox filter_value_and_grad function to the input parameters.
    The corresponding equinox docs are found [here](https://docs.kidger.site/
    equinox/api/filtering/transformations/)

    Parameters
    ----------
    parameters : Union[str, List[str]]
        The parameters to filter. Can either be a single string path or a list
        of paths.
    *filter_args : Any
        The args to pass to the equinox filter_value_and_grad function.
    **filter_kwargs : Any
        The kwargs to pass to the equinox filter_value_and_grad function.

    Returns
    -------
    Callable
        The wrapped function.
    """
    warnings.warn(
        "filter_value_and_grad is deprecated as of v0.5.0 and will be removed in "
        "v0.6.0",
        DeprecationWarning,
    )

    def wrapper(func: Callable):
        @wraps(func)
        def inner_wrapper(pytree, *args, **kwargs):
            # Convert parameters
            pytree = set_array(pytree)
            bool_filter = boolean_filter(pytree, parameters)

            # Wrap original function
            @eqx.filter_value_and_grad(*filter_args, **filter_kwargs)
            def recombine(traced, static):
                return func(eqx.combine(traced, static), *args, **kwargs)

            # Return wrapped function
            return recombine(*eqx.partition(pytree, bool_filter))

        return inner_wrapper

    return wrapper
```
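`filter_value_and_grad` differs from `filter_grad` only in what the wrapped function returns. `eqx.filter_value_and_grad` returns the function value together with the gradient. The following is a minimal pure-Python sketch of that `(value, grads)` convention. It assumes a pytree modelled as a flat dict keyed by dotted paths, whereas real zodiax pytrees are equinox Modules. It assumes a path selects itself and every leaf beneath it, and it uses central finite differences in place of JAX autodiff. `filter_value_and_grad_sketch` is a hypothetical name.

```python
from functools import wraps
from typing import Callable, Dict, List, Optional, Tuple, Union

# Assumption: a pytree is modelled as a flat dict keyed by dotted paths.
FlatTree = Dict[str, Optional[float]]


def filter_value_and_grad_sketch(parameters: Union[str, List[str]], eps: float = 1e-6):
    """Hypothetical decorator factory returning (value, grads) like eqx.filter_value_and_grad."""
    if isinstance(parameters, str):
        parameters = [parameters]

    def is_selected(path: str) -> bool:
        # Assumed rule: a path selects itself and every leaf beneath it.
        return any(path == p or path.startswith(p + ".") for p in parameters)

    def wrapper(func: Callable[..., float]):
        @wraps(func)
        def inner_wrapper(tree: FlatTree, *args, **kwargs) -> Tuple[float, FlatTree]:
            # Value at the unperturbed point, evaluated once.
            value = func(tree, *args, **kwargs)
            grads: FlatTree = {}
            for key, leaf in tree.items():
                if not is_selected(key):
                    grads[key] = None  # static leaf: no gradient
                    continue
                # Perturb only the selected leaf; all other leaves stay fixed.
                up = func({**tree, key: leaf + eps}, *args, **kwargs)
                down = func({**tree, key: leaf - eps}, *args, **kwargs)
                grads[key] = (up - down) / (2 * eps)
            return value, grads

        return inner_wrapper

    return wrapper


# Usage: the value and the gradient with respect to "source.flux" only.
@filter_value_and_grad_sketch(["source.flux"])
def loss(tree: FlatTree) -> float:
    return tree["source.flux"] ** 2 + tree["optics.diameter"]


value, grads = loss({"source.flux": 3.0, "optics.diameter": 2.0})
# value == 11.0; grads["source.flux"] is about 6.0; grads["optics.diameter"] is None
```

Returning the value alongside the gradient avoids a second forward evaluation when a training loop needs both the loss and its gradient.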

partition(pytree, parameters, *partition_args, **partition_kwargs)

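Wraps the equinox `partition` function so that the leaves to split out can be selected by path, using either a single string or a list of strings. See the equinox documentation for `eqx.partition` for details of the underlying function.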

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `pytree` | `PyTree` | The pytree to partition. | required |
| `parameters` | `Union[str, List[str]]` | The parameters to partition. Either a single string path or a list of paths. | required |
| `*partition_args` | `Any` | Positional arguments passed to the equinox `partition` function. | `()` |
| `**partition_kwargs` | `Any` | Keyword arguments passed to the equinox `partition` function. | `{}` |

Returns:

| Name | Type | Description |
|---|---|---|
| `pytree1` | `PyTree` | A matching pytree with `None` at every leaf not selected by `parameters`. |
| `pytree2` | `PyTree` | A matching pytree with `None` at every leaf selected by `parameters`. |

Source code in zodiax/eqx.py

```python
def partition(
    pytree: PyTree, parameters: Params, *partition_args, **partition_kwargs
) -> tuple:
    """
    Wraps the equinox partition function to take in a list of parameters to
    partition. The corresponding equinox docs are found [here](https://docs.
    kidger.site/equinox/api/filtering/transformations/)

    Parameters
    ----------
    pytree : PyTree
        The pytree to partition.
    parameters : Union[str, List[str]]
        The parameters to partition. Can either be a single string path or a
        list of paths.
    *partition_args : Any
        The args to pass to the equinox partition function.
    **partition_kwargs : Any
        The kwargs to pass to the equinox partition function.

    Returns
    -------
    pytree1 : PyTree
        A matching pytree with Nones at all leaves not specified by the
        parameters.
    pytree2 : PyTree
        A matching pytree with Nones at all leaves specified by the parameters.
    """

    if isinstance(parameters, str):
        parameters = [parameters]

    pytree = set_array(pytree)
    bool_filter = boolean_filter(pytree, parameters)
    return eqx.partition(pytree, bool_filter, *partition_args, **partition_kwargs)
```
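`partition` does three things: it normalises a single string path into a list, builds a boolean filter from the paths, and splits the pytree into two halves that are padded with `None`. The following is a minimal pure-Python sketch of that behaviour. It assumes a pytree modelled as a flat dict keyed by dotted paths, whereas real zodiax pytrees are equinox Modules. The prefix-matching rule is an assumption about how paths select leaves and is not the actual `boolean_filter` implementation. `partition_sketch` and `combine_sketch` are hypothetical helpers.

```python
from typing import Dict, List, Optional, Tuple, Union

# Assumption: a pytree is modelled as a flat dict keyed by dotted paths.
FlatTree = Dict[str, Optional[float]]


def select(path: str, parameters: List[str]) -> bool:
    """True if `path` equals a requested parameter or lies beneath one (assumed rule)."""
    return any(path == p or path.startswith(p + ".") for p in parameters)


def partition_sketch(
    tree: FlatTree, parameters: Union[str, List[str]]
) -> Tuple[FlatTree, FlatTree]:
    """Split `tree` into (selected, rest), filling the gaps with None."""
    # Mirror zodiax.eqx.partition: promote a single path to a list.
    if isinstance(parameters, str):
        parameters = [parameters]
    selected = {k: (v if select(k, parameters) else None) for k, v in tree.items()}
    rest = {k: (None if select(k, parameters) else v) for k, v in tree.items()}
    return selected, rest


def combine_sketch(a: FlatTree, b: FlatTree) -> FlatTree:
    """Inverse of partition: take each leaf from whichever half is not None."""
    return {k: (a[k] if a[k] is not None else b[k]) for k in a}


tree = {"source.flux": 1.0, "source.position": 0.5, "optics.diameter": 2.0}
params, other = partition_sketch(tree, "source")
# params: source.* leaves kept, optics.diameter -> None
# other:  source.* leaves -> None, optics.diameter kept
```

In this sketch `combine_sketch` plays the role of `eqx.combine`, which the wrapped functions use to reassemble the two halves before calling the user's function.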