Tree
The Tree module provides helpful pytree manipulation functions. It implements two functions: boolean_filter(pytree, parameters) and set_array(pytree, parameters).
boolean_filter(pytree, parameters) returns a pytree of matching structure with boolean leaves, where the leaves specified by parameters are True and the rest are False.
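The following pure-Python sketch makes this behaviour concrete. It illustrates the idea and is not the zodiax implementation. It assumes a simplified model in which a pytree is a nested dict, any non-dict value is a leaf, and a path is a dot-separated string such as "source.flux". The name boolean_filter_sketch is hypothetical.

```python
from typing import Any, Union

PyTree = Any  # simplified model (assumption): nested dicts whose non-dict values are leaves


def boolean_filter_sketch(pytree: PyTree, parameters: Union[str, list]) -> PyTree:
    """Return a tree of the same structure with True at the selected leaves."""
    if isinstance(parameters, str):
        parameters = [parameters]  # accept a single path or a list of paths

    def walk(node: PyTree, prefix: str) -> PyTree:
        if isinstance(node, dict):
            # Recurse, extending the dot-separated path with each key
            return {k: walk(v, f"{prefix}.{k}" if prefix else k) for k, v in node.items()}
        # A leaf is True only if its full path was requested
        return prefix in parameters

    return walk(pytree, "")


model = {"source": {"flux": 1.0, "position": [0.0, 0.5]}, "pixel_scale": 0.1}
print(boolean_filter_sketch(model, ["source.flux", "pixel_scale"]))
# {'source': {'flux': True, 'position': False}, 'pixel_scale': True}
```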
set_array(pytree, parameters) returns a matching pytree in which the leaves specified by parameters are converted to arrays holding the values of the corresponding leaves in pytree. This ensures those leaves have a .shape attribute, which is needed to create dynamic array shapes for the Bayesian module.
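The sketch below shows the conversion step under the same simplified assumptions (nested dicts, dot-separated paths). The name set_array_sketch is hypothetical, and NumPy's np.asarray stands in for whichever array constructor zodiax actually uses. Selected leaves gain a .shape, while unselected leaves pass through unchanged.

```python
from typing import Any, Union

import numpy as np

PyTree = Any  # simplified model (assumption): nested dicts whose non-dict values are leaves


def set_array_sketch(pytree: PyTree, parameters: Union[str, list]) -> PyTree:
    """Return a new tree with the selected leaves converted to arrays."""
    if isinstance(parameters, str):
        parameters = [parameters]

    def walk(node: PyTree, prefix: str) -> PyTree:
        if isinstance(node, dict):
            return {k: walk(v, f"{prefix}.{k}" if prefix else k) for k, v in node.items()}
        # Selected leaves become arrays (scalars become 0-d arrays); others pass through
        return np.asarray(node) if prefix in parameters else node

    return walk(pytree, "")


model = {"source": {"flux": 1.0, "position": [0.0, 0.5]}, "pixel_scale": 0.1}
converted = set_array_sketch(model, ["source.flux", "source.position"])
print(converted["source"]["flux"].shape)      # ()
print(converted["source"]["position"].shape)  # (2,)
print(type(converted["pixel_scale"]))         # <class 'float'>
```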
Full API
boolean_filter(pytree, parameters, inverse=False)
Returns a pytree of matching structure with boolean values at the leaves. Leaves specified by parameters will be True and all others will be False; setting inverse=True swaps these values.
TODO: Possibly improve by setting both True and False values simultaneously, perhaps using JAX key paths.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | PyTree | The pytree to be filtered. | required |
parameters | Union[str, list] | A path, a list of paths, or a list of nested paths. | required |
inverse | bool | If True, the boolean values will be inverted. | False |
Returns:
Name | Type | Description |
---|---|---|
args | PyTree | A pytree of matching structure with boolean values at the leaves. |
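A boolean pytree like this is typically used to split a model into the leaves being optimised and everything else. This is similar in spirit to equinox.partition. The sketch below shows the effect of inverse and one such split. It again uses nested dicts with dot-separated paths, and the helper names boolean_filter_sketch and partition_sketch are hypothetical, not part of zodiax.

```python
from typing import Any, Union

PyTree = Any  # simplified model (assumption): nested dicts whose non-dict values are leaves


def boolean_filter_sketch(pytree: PyTree, parameters: Union[str, list], inverse: bool = False) -> PyTree:
    """Return True at the leaves named by parameters (flipped if inverse)."""
    if isinstance(parameters, str):
        parameters = [parameters]

    def walk(node: PyTree, prefix: str) -> PyTree:
        if isinstance(node, dict):
            return {k: walk(v, f"{prefix}.{k}" if prefix else k) for k, v in node.items()}
        # XOR with inverse: when inverse=True, selected leaves become False and the rest True
        return (prefix in parameters) != inverse

    return walk(pytree, "")


def partition_sketch(pytree: PyTree, mask: PyTree) -> tuple:
    """Split a tree into (selected, rest), with None marking the excluded leaves."""
    if isinstance(pytree, dict):
        pairs = {k: partition_sketch(v, mask[k]) for k, v in pytree.items()}
        return {k: p[0] for k, p in pairs.items()}, {k: p[1] for k, p in pairs.items()}
    return (pytree, None) if mask else (None, pytree)


model = {"source": {"flux": 1.0, "position": [0.0, 0.5]}, "pixel_scale": 0.1}
inverted = boolean_filter_sketch(model, "source.flux", inverse=True)
# {'source': {'flux': False, 'position': True}, 'pixel_scale': True}
dynamic, static = partition_sketch(model, boolean_filter_sketch(model, "source.flux"))
# dynamic: {'source': {'flux': 1.0, 'position': None}, 'pixel_scale': None}
# static:  {'source': {'flux': None, 'position': [0.0, 0.5]}, 'pixel_scale': 0.1}
```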
Source code in zodiax/tree.py
set_array(pytree, parameters)
Converts all leaves specified by parameters in the pytree to arrays, ensuring they have a .shape property for static dimensionality and size checks. This allows array shapes to be "dynamically generated" from the path-based parameters input. It is used to dynamically generate the latent X parameter needed to calculate the Hessian.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
pytree | Base() | The pytree to be converted. | required |
parameters | Params | The leaves to be converted to arrays. | required |
Returns:
Name | Type | Description |
---|---|---|
pytree | Base() | The pytree with the specified leaves converted to arrays. |
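This example shows why the .shape property matters. Once the selected leaves are arrays, their shapes and sizes can be read off to pack them into a single flat latent vector X and to unpack it again. The Hessian of a scalar loss with respect to X is then an N-by-N matrix, where N is X.size. The sketch below is not the zodiax code path. It assumes nested-dict pytrees with dot-separated paths, and the helpers get_leaf, flatten_selected and unflatten are hypothetical.

```python
import numpy as np


def get_leaf(pytree: dict, path: str):
    """Follow a dot-separated path such as "source.flux" down a nested dict."""
    node = pytree
    for key in path.split("."):
        node = node[key]
    return node


def flatten_selected(pytree: dict, paths: list):
    """Pack the selected leaves into one 1-D latent vector X, recording their shapes."""
    # Assumes each selected leaf is already an array (e.g. after set_array),
    # so .shape and .ravel() are available
    leaves = [get_leaf(pytree, p) for p in paths]
    shapes = [leaf.shape for leaf in leaves]
    X = np.concatenate([leaf.ravel() for leaf in leaves])
    return X, shapes


def unflatten(X: np.ndarray, shapes: list) -> list:
    """Split X back into arrays with the recorded shapes."""
    out, start = [], 0
    for shape in shapes:
        size = int(np.prod(shape))  # np.prod(()) is 1, so a scalar takes one slot
        out.append(X[start:start + size].reshape(shape))
        start += size
    return out


model = {"source": {"flux": np.asarray(1.0), "position": np.asarray([0.0, 0.5])}, "pixel_scale": 0.1}
X, shapes = flatten_selected(model, ["source.flux", "source.position"])
print(X.shape, shapes)  # (3,) [(), (2,)]
flux, position = unflatten(X, shapes)
```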
Source code in zodiax/tree.py