Tree
Deprecation Warning
This module is in the process of being deprecated from Zodiax. The v0.5.0 release will be the last release to include this module, and it will be removed in v0.6.0.
The Tree module provides helpful pytree manipulation functions. It implements two of them: boolean_filter(pytree, parameters) and set_array(pytree, parameters).
boolean_filter(pytree, parameters) returns a matching pytree with boolean leaves, where the leaves specified by parameters are True and the rest are False.
set_array(pytree, parameters) returns a matching pytree in which the leaves specified by parameters are replaced by array versions of their values. This ensures that those leaves have a .shape attribute, which the bayesian module needs in order to create dynamic array shapes.
zodiax.tree
boolean_filter(pytree, parameters, inverse=False)
Returns a pytree of matching structure with boolean values at the leaves. Leaves specified by parameters will be True; all others will be False.
TODO: Possibly improve this by setting the True and False leaves simultaneously, perhaps using JAX key paths.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pytree | PyTree | The pytree to be filtered. | required |
| parameters | Params | A path, a list of paths, or a list of nested paths. | required |
| inverse | bool | If True, the boolean values are inverted, so the specified leaves are False and all others are True. | False |
Returns:

| Name | Type | Description |
|---|---|---|
| args | PyTree | A pytree of matching structure with boolean values at the leaves. |
Source code in zodiax/tree.py (lines 16-48).
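The filtering behaviour described above can be illustrated with a small stand-alone sketch. It is not Zodiax's implementation: boolean_filter_sketch, _path_to_str and _flatten_paths are hypothetical names, paths are assumed to be dot-delimited strings such as "b.c", and a path is assumed to also select every leaf beneath it. It relies only on jax.tree_util.tree_map_with_path.

```python
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path


def _path_to_str(keypath):
    """Join a JAX key path into a dot-delimited string such as 'b.c' (assumed path format)."""
    parts = []
    for key in keypath:
        if isinstance(key, DictKey):        # dict entries
            parts.append(str(key.key))
        elif isinstance(key, GetAttrKey):   # attribute-based nodes
            parts.append(key.name)
        elif isinstance(key, SequenceKey):  # list/tuple entries
            parts.append(str(key.idx))
        else:
            parts.append(str(key))
    return ".".join(parts)


def _flatten_paths(parameters):
    """Flatten a path, a list of paths, or nested lists of paths into one flat list."""
    if isinstance(parameters, str):
        return [parameters]
    flat = []
    for p in parameters:
        flat.extend(_flatten_paths(p))
    return flat


def boolean_filter_sketch(pytree, parameters, inverse=False):
    """Hypothetical re-implementation of the documented boolean_filter behaviour."""
    targets = _flatten_paths(parameters)

    def mark(keypath, _leaf):
        path = _path_to_str(keypath)
        # A leaf is selected if its path is a target or sits below one.
        selected = any(path == t or path.startswith(t + ".") for t in targets)
        # Comparing with `inverse` flips every leaf when inverse=True.
        return selected != inverse

    return tree_map_with_path(mark, pytree)


params = {"a": 1.0, "b": {"c": 2.0, "d": 3.0}}
print(boolean_filter_sketch(params, "b.c"))
# {'a': False, 'b': {'c': True, 'd': False}}
```

A boolean pytree of this kind has the form of a filter specification, for example the filter_spec argument of equinox.partition, which splits a pytree into its selected and unselected parts.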
set_array(pytree, parameters=None)
Converts all leaves in the pytree to arrays to ensure they have a .shape property for static dimensionality and size checks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pytree | PyTree | The pytree to be converted. | required |
Returns:

| Name | Type | Description |
|---|---|---|
| pytree | PyTree | The pytree with the leaves converted to arrays. |
Source code in zodiax/tree.py (lines 51-87).
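Why the conversion matters can be seen in a short sketch: a plain Python float has no .shape, while jax.numpy.asarray turns it into a 0-dimensional array that does. This is not Zodiax's implementation; set_array_sketch and _path_to_str are hypothetical names, parameters=None is assumed to mean "convert every leaf" (matching the docstring), and other values are assumed to be dot-delimited paths.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path


def _path_to_str(keypath):
    """Join a JAX key path into a dot-delimited string (assumed path format)."""
    parts = []
    for key in keypath:
        if isinstance(key, DictKey):
            parts.append(str(key.key))
        elif isinstance(key, GetAttrKey):
            parts.append(key.name)
        elif isinstance(key, SequenceKey):
            parts.append(str(key.idx))
        else:
            parts.append(str(key))
    return ".".join(parts)


def set_array_sketch(pytree, parameters=None):
    """Hypothetical sketch: convert leaves to JAX arrays so each one has a .shape."""
    if parameters is None:
        targets = None  # assumption: None means convert every leaf
    elif isinstance(parameters, str):
        targets = [parameters]
    else:
        targets = list(parameters)

    def convert(keypath, leaf):
        path = _path_to_str(keypath)
        if targets is None or any(path == t or path.startswith(t + ".") for t in targets):
            return jnp.asarray(leaf)  # floats become 0-d arrays with shape ()
        return leaf  # unselected leaves are returned unchanged

    return tree_map_with_path(convert, pytree)


# A NumPy array is a single pytree leaf; a Python list would be split into leaves.
params = {"flux": 1.0, "position": np.array([0.0, 0.5])}
arrays = set_array_sketch(params)
print(arrays["flux"].shape, arrays["position"].shape)  # () (2,)
```

Once every leaf has a .shape, sizes can be read uniformly, for example sum(leaf.size for leaf in jax.tree_util.tree_leaves(arrays)) gives the total number of scalar values (3 here), which the original float leaf alone could not provide.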