
Tree

Deprecation Warning

This module is being deprecated from Zodiax. The v0.5.0 release is the last to include it, and it will be removed in v0.6.0.
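Both functions emit a DeprecationWarning on every call. Python hides DeprecationWarning by default unless it is triggered by code in `__main__`, so callers may never see it. Below is a minimal sketch of surfacing the warning with the standard `warnings` module. `legacy_filter` is a hypothetical stand-in that mirrors the `warnings.warn` call in the source listings; it is not part of Zodiax.

```python
import warnings


def legacy_filter(pytree):
    # Hypothetical stand-in: warns the same way the zodiax.tree functions do.
    warnings.warn(
        "boolean_filter is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )
    return pytree


def call_and_collect(fn, *args):
    """Call fn and return (result, list of DeprecationWarning messages)."""
    with warnings.catch_warnings(record=True) as caught:
        # Override the default filters so every DeprecationWarning is recorded.
        warnings.simplefilter("always", DeprecationWarning)
        result = fn(*args)
    messages = [
        str(w.message) for w in caught if issubclass(w.category, DeprecationWarning)
    ]
    return result, messages


result, messages = call_and_collect(legacy_filter, {"a": 1.0})
print(messages)
```

In a test suite, `warnings.simplefilter("error", DeprecationWarning)` instead turns such warnings into exceptions, which makes remaining calls easy to find before upgrading to v0.6.0.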

The Tree module provides helpful pytree manipulation functions. It implements two of them: boolean_filter(pytree, parameters) and set_array(pytree, parameters).

boolean_filter(pytree, parameters) returns a matching pytree with boolean leaves, where the leaves specified by parameters are True and the rest are False.
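A boolean pytree like this is typically used as a filter, for example to update only the selected leaves during optimization. The sketch below uses nested dicts as a stand-in for a Zodiax pytree and assumes dot-separated paths. `masked_update` and the hand-written mask are hypothetical illustrations, not Zodiax API.

```python
def masked_update(params, grads, mask, lr=0.1):
    """Apply a gradient step only to leaves whose mask value is True.

    params, grads and mask are nested dicts with identical structure,
    standing in for a pytree and its boolean filter.
    """
    if isinstance(params, dict):
        return {k: masked_update(params[k], grads[k], mask[k], lr) for k in params}
    # Leaf: step if selected, otherwise keep the original value.
    return params - lr * grads if mask else params


params = {"optics": {"focal_length": 2.0, "aperture": 1.0}, "flux": 10.0}
grads = {"optics": {"focal_length": 1.0, "aperture": 1.0}, "flux": 5.0}

# The kind of mask a filter on the (assumed) path "optics.focal_length"
# would produce: True at the selected leaf, False everywhere else.
mask = {"optics": {"focal_length": True, "aperture": False}, "flux": False}

new_params = masked_update(params, grads, mask)
print(new_params)
# {'optics': {'focal_length': 1.9, 'aperture': 1.0}, 'flux': 10.0}
```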

set_array(pytree, parameters) returns a matching pytree in which the leaves specified by parameters are converted to arrays holding the same values. This ensures they have a .shape property, which the Bayesian module needs to create dynamic array shapes. If parameters is omitted, all floating-point leaves are converted instead.
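The point of the conversion is that plain Python scalars and lists have no `.shape`, whereas arrays do. Below is a minimal NumPy sketch of the idea using a flat dict in place of a pytree. `set_array_sketch` is a hypothetical helper; Zodiax itself operates on JAX pytrees.

```python
import numpy as np


def set_array_sketch(tree, keys):
    """Return a copy of a flat dict with the named leaves converted to arrays."""
    # Only the selected leaves are converted; everything else is passed through.
    return {k: np.asarray(v) if k in keys else v for k, v in tree.items()}


tree = {"flux": 10.0, "position": [0.5, -0.5], "name": "star"}
converted = set_array_sketch(tree, ["flux", "position"])

print(converted["flux"].shape)              # () : a 0-d array
print(converted["position"].shape)          # (2,)
print(hasattr(converted["name"], "shape"))  # False: left untouched
```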

zodiax.tree

boolean_filter(pytree, parameters, inverse=False)

Returns a pytree of matching structure with boolean values at the leaves. Leaves specified by parameters will be True; all others will be False.

TODO: Possibly improve by setting both true and false simultaneously. Maybe do this with jax keypaths?

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pytree | PyTree | The pytree to be filtered. | required |
| parameters | Params | A path, a list of paths, or a list of nested paths. | required |
| inverse | bool | If True, the boolean values are inverted. | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| args | PyTree | A pytree of matching structure with boolean values at the leaves. |

Source code in zodiax/tree.py
```python
def boolean_filter(pytree: PyTree, parameters: Params, inverse: bool = False) -> PyTree:
    """
    Returns a pytree of matching structure with boolean values at the leaves.
    Leaves specified by paths will be True, all others will be False.

    TODO: Possibly improve by setting both true and false simultaneously.
    Maybe do this with jax keypaths?

    Parameters
    ----------
    pytree : PyTree
        The pytree to be filtered.
    parameters : Params
        A path or list of paths or list of nested paths.
    inverse : bool = False
        If True, the boolean values will be inverted, by default False

    Returns
    -------
    args : PyTree
        A pytree of matching structure with boolean values at the leaves.
    """
    warnings.warn(
        "boolean_filter is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )
    parameters = parameters if isinstance(parameters, list) else [parameters]
    if not inverse:
        false_pytree = jtu.map(lambda _: False, pytree)
        return false_pytree.set(parameters, len(parameters) * [True])
    else:
        true_pytree = jtu.map(lambda _: True, pytree)
        return true_pytree.set(parameters, len(parameters) * [False])
```
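The TODO in the docstring suggests building the mask in a single pass with JAX keypaths, rather than first mapping every leaf to one value and then overwriting the selected ones. On real pytrees, `jax.tree_util.tree_map_with_path` provides a path-aware map for this. The sketch below illustrates the same single-pass idea on nested dicts with dot-separated paths. `path_mask` is hypothetical and, unlike boolean_filter, matches only exact leaf paths.

```python
def path_mask(tree, paths, inverse=False):
    """Build a boolean mask in one pass by tracking each leaf's path."""
    # Accept a single path or a list of paths, as boolean_filter does.
    selected = {paths} if isinstance(paths, str) else set(paths)

    def walk(node, path):
        if isinstance(node, dict):
            return {
                key: walk(child, f"{path}.{key}" if path else key)
                for key, child in node.items()
            }
        # Leaf: True if its exact path was selected, flipped when inverse=True.
        return (path in selected) != inverse

    return walk(tree, "")


tree = {"optics": {"focal_length": 2.0, "aperture": 1.0}, "flux": 10.0}
print(path_mask(tree, "optics.aperture"))
print(path_mask(tree, ["flux"], inverse=True))
```

Because True and False are decided at the same time for each leaf, the `inverse` case needs no separate branch: it is a single XOR (`!=`) on the match result.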

set_array(pytree, parameters=None)

Converts the inexact (floating-point) leaves in the pytree to arrays to ensure they have a .shape property for static dimensionality and size checks. If parameters is given, only the specified leaves are converted.

Parameters:

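| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pytree | PyTree | The pytree to be converted. | required |
| parameters | Params | Path or paths of the leaves to convert. If None, all inexact (floating-point) leaves are converted. | None |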

Returns:

| Name | Type | Description |
| --- | --- | --- |
| pytree | PyTree | The pytree with the leaves converted to arrays. |

Source code in zodiax/tree.py
```python
def set_array(pytree: PyTree, parameters=None) -> PyTree:
    """
    Converts all leaves in the pytree to arrays to ensure they have a
    .shape property for static dimensionality and size checks.

    Parameters
    ----------
    pytree : PyTree
        The pytree to be converted.

    Returns
    -------
    pytree : PyTree
        The pytree with the leaves converted to arrays.
    """
    warnings.warn(
        "set_array is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )
    # Old routine for setting specified parameters
    if parameters is not None:
        new_leaves = jtu.map(_to_array, pytree.get(parameters))
        return pytree.set(parameters, new_leaves)

    # Otherwise, convert all inexact (floating-point) leaves to arrays

    # grabbing float data type
    dtype = np.float64 if config.x64_enabled else np.float32

    # partitioning the pytree into arrays and other
    floats, other = eqx.partition(pytree, eqx.is_inexact_array_like)

    # converting the floats to arrays
    floats = jtu.map(lambda x: np.array(x, dtype=dtype), floats)

    # recombining
    return eqx.combine(floats, other)
```
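When parameters is None, the function partitions the pytree into inexact leaves and everything else, casts the inexact leaves to the default float dtype (float64 when JAX's 64-bit mode is enabled, otherwise float32), and recombines the two parts. Below is a NumPy sketch of that partition, convert, combine pattern on a flat dict. `convert_float_leaves`, `is_float_like` and the `x64_enabled` argument are hypothetical stand-ins for `eqx.partition`/`eqx.combine`, `eqx.is_inexact_array_like` and `config.x64_enabled`, and complex leaves are deliberately left out.

```python
import numpy as np


def is_float_like(x):
    # Real floating-point scalars or arrays; a simplified take on
    # eqx.is_inexact_array_like that ignores complex values.
    if isinstance(x, float):
        return True
    return isinstance(x, np.ndarray) and np.issubdtype(x.dtype, np.floating)


def convert_float_leaves(tree, x64_enabled=False):
    """Cast float leaves of a flat dict to arrays of the default float dtype."""
    dtype = np.float64 if x64_enabled else np.float32

    # Partition into float leaves and everything else.
    floats = {k: v for k, v in tree.items() if is_float_like(v)}
    other = {k: v for k, v in tree.items() if k not in floats}

    # Convert only the float part.
    floats = {k: np.array(v, dtype=dtype) for k, v in floats.items()}

    # Recombine, preserving the original key order.
    return {k: floats[k] if k in floats else other[k] for k in tree}


tree = {"flux": 10.0, "pixels": np.ones(3), "n_pix": 256, "name": "star"}
out32 = convert_float_leaves(tree)
out64 = convert_float_leaves(tree, x64_enabled=True)

print(out32["flux"].dtype, out32["flux"].shape)  # float32 ()
print(out64["pixels"].dtype)                      # float64
print(out32["n_pix"], out32["name"])              # 256 star
```

Integer and string leaves pass through unchanged, which matches the intent of partitioning on inexact leaves only.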