
Tree

The Tree module provides helper functions for pytree manipulation. It implements two functions, boolean_filter(pytree, parameters) and set_array(pytree, parameters).

boolean_filter(pytree, parameters) returns a matching pytree with boolean leaves, where the leaves specified by parameters are True and the rest are False.

set_array(pytree, parameters) returns a matching pytree in which the leaves specified by parameters are converted to arrays. This ensures they have a shape attribute, which is needed to create dynamic array shapes for the bayesian module.

Full API

boolean_filter(pytree, parameters, inverse=False)

Returns a pytree of matching structure with boolean values at the leaves. Leaves specified by parameters will be True and all others will be False; if inverse is True, these values are swapped.

TODO: Possibly improve by setting both true and false simultaneously. Maybe do this with jax keypaths?

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pytree | PyTree | The pytree to be filtered. | required |
| parameters | Union[str, list] | A path, a list of paths, or a list of nested paths. | required |
| inverse | bool | If True, the boolean values will be inverted. | False |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| args | PyTree | A pytree of matching structure with boolean values at the leaves. |
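The behaviour described above can be illustrated without Zodiax installed. The following is a minimal sketch, not the library's implementation: it assumes a nested dict stands in for the pytree and that each entry in parameters is a dotted path to a single leaf. The names map_leaves, set_path and boolean_filter_sketch are hypothetical, and nested path lists are not handled.

```python
def map_leaves(tree, fn):
    """Apply fn to every leaf of a nested dict, keeping the structure."""
    if isinstance(tree, dict):
        return {key: map_leaves(value, fn) for key, value in tree.items()}
    return fn(tree)


def set_path(tree, path, value):
    """Return a copy of tree with the leaf at a dotted path replaced."""
    head, _, rest = path.partition(".")
    new_tree = dict(tree)
    new_tree[head] = set_path(tree[head], rest, value) if rest else value
    return new_tree


def boolean_filter_sketch(tree, parameters, inverse=False):
    """Hypothetical stand-in for boolean_filter on nested dicts."""
    # Accept a single path or a list of paths, as the real function does.
    parameters = parameters if isinstance(parameters, list) else [parameters]
    # Fill every leaf with the background value, then flip the selected leaves.
    mask = map_leaves(tree, lambda _: inverse)
    for path in parameters:
        mask = set_path(mask, path, not inverse)
    return mask


# Illustrative model; the field names are made up for this example.
model = {"optics": {"focal_length": 1.5, "aperture": 0.1}, "flux": 100.0}
mask = boolean_filter_sketch(model, ["optics.focal_length", "flux"])
inverse_mask = boolean_filter_sketch(model, "flux", inverse=True)
```

Here mask marks only optics.focal_length and flux as True, while inverse_mask marks every leaf except flux as True. The input model is left unchanged because each update works on a copy.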

Source code in zodiax/tree.py
def boolean_filter(
    pytree     : Base(), 
    parameters : Params, 
    inverse    : bool = False) -> Base():
    """
    Returns a pytree of matching structure with boolean values at the leaves.
    Leaves specified by parameters will be True, all others will be False.

    TODO: Possibly improve by setting both true and false simultaneously.
    Maybe do this with jax keypaths?

    Parameters
    ----------
    pytree : PyTree
        The pytree to be filtered.
    parameters : Union[str, list]
        A path or list of paths or list of nested paths.
    inverse : bool = False
        If True, the boolean values will be inverted, by default False

    Returns
    -------
    args : PyTree
        A pytree of matching structure with boolean values at the leaves.
    """
    parameters = parameters if isinstance(parameters, list) else [parameters]
    if not inverse:
        false_pytree = jtu.tree_map(lambda _: False, pytree)
        return false_pytree.set(parameters, len(parameters) * [True])
    else:
        true_pytree = jtu.tree_map(lambda _: True, pytree)
        return true_pytree.set(parameters, len(parameters) * [False])

set_array(pytree, parameters)

Converts all leaves specified by parameters in the pytree to arrays to ensure they have a .shape property for static dimensionality and size checks. This allows array shapes to be generated dynamically from the path-based parameters input, and is used to build the latent X parameter needed to calculate the Hessian.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pytree | Base() | The pytree to be converted. | required |
| parameters | Params | The leaves to be converted to arrays. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| pytree | Base() | The pytree with the specified leaves converted to arrays. |
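To show why a .shape property matters, here is a rough sketch under stated assumptions rather than the library's code. It uses NumPy instead of JAX, treats a flat dict as the pytree, and takes parameters to be top-level keys. The helper to_array is a hypothetical stand-in for the private _to_array function, whose exact behaviour is not shown in this page.

```python
import numpy as np


def to_array(leaf):
    """Hypothetical stand-in for _to_array: wrap non-array leaves as float arrays."""
    return leaf if isinstance(leaf, np.ndarray) else np.asarray(leaf, dtype=float)


def set_array_sketch(tree, parameters):
    """Hypothetical stand-in for set_array on a flat dict."""
    parameters = parameters if isinstance(parameters, list) else [parameters]
    new_tree = dict(tree)  # copy so the input is not mutated
    for name in parameters:
        new_tree[name] = to_array(tree[name])
    return new_tree


# Illustrative model; the field names are made up for this example.
model = {"flux": 100.0, "position": [0.5, -0.5], "name": "star"}
converted = set_array_sketch(model, ["flux", "position"])

# Every selected leaf now reports a shape and a size, so the length of a
# flattened parameter vector can be computed from the paths alone.
shapes = {name: converted[name].shape for name in ["flux", "position"]}
total_size = sum(converted[name].size for name in ["flux", "position"])
```

With these inputs the scalar flux becomes a zero-dimensional array and position a length-2 array, giving a combined size of 3. Leaves not named in parameters, such as name, are passed through untouched.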

Source code in zodiax/tree.py
def set_array(pytree : Base(), parameters : Params) -> Base():
    """
    Converts all leaves specified by parameters in the pytree to arrays to
    ensure they have a .shape property for static dimensionality and size
    checks. This allows array shapes to be generated dynamically from the
    path-based `parameters` input, and is used to build the latent X
    parameter needed to calculate the Hessian.

    Parameters
    ----------
    pytree : Base()
        The pytree to be converted.
    parameters : Params
        The leaves to be converted to arrays.

    Returns
    -------
    pytree : Base()
        The pytree with the specified leaves converted to arrays.
    """
    new_leaves = jtu.tree_map(_to_array, pytree.get(parameters))
    return pytree.set(parameters, new_leaves)
