
Fisher

Deprecation Warning

This module is being deprecated from Zodiax. The v0.5.0 release will be the last release to include it, and it will be removed in v0.6.0.

The zodiax.fisher module contains helper functions for Hessian, Fisher matrix, and covariance matrix calculations.

zodiax.fisher

covariance_matrix(pytree, parameters, loglike_fn, *loglike_args, shape_dict={}, save_memory=False, **loglike_kwargs)

Calculates the covariance matrix of the pytree parameters. The shape_dict parameter specifies the shape of the differentiated vector for specific parameters. For example, if the parameter param is a 1D array of shape (5,) and we want the covariance matrix of its mean, we can pass in shape_dict={'param': (1,)}. This differentiates the log likelihood with respect to the mean of that parameter rather than with respect to each of its five elements.
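As a rough illustration of what this reduction means, the sketch below uses plain JAX rather than Zodiax. It assumes (the text above does not confirm this) that a (1,)-shaped entry in shape_dict corresponds to a single perturbation broadcast across all five elements, i.e. a shift of the parameter's mean. The names data, sigma, loglike, and loglike_offset are hypothetical. Under that assumption the reduced 1×1 Hessian equals the sum of all entries of the full 5×5 Hessian.

```python
import jax
import jax.numpy as jnp

# Hypothetical data and noise level for a length-5 parameter
data = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
sigma = 0.5


def loglike(param):
    # Gaussian log likelihood (up to a constant) of a length-5 parameter
    return -0.5 * jnp.sum(((data - param) / sigma) ** 2)


param = jnp.zeros(5)

# Full Hessian with respect to every element: shape (5, 5)
H_full = jax.hessian(loglike)(param)


def loglike_offset(dx):
    # dx has shape (1,) and is broadcast over all five elements,
    # so it acts as a shift of the parameter's mean
    return loglike(param + dx)


# Reduced Hessian with respect to the single shared offset: shape (1, 1)
H_mean = jax.hessian(loglike_offset)(jnp.zeros(1))
```

Here each diagonal entry of H_full is -1/sigma**2 = -4, so both routes give -20, but the reduced form only requires a 1×1 Hessian.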

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pytree | PyTree | Pytree with a .model() function. | required |
| parameters | Params | A path, a list of paths, or a list of nested paths. | required |
| loglike_fn | callable | The log likelihood function to differentiate. | required |
| shape_dict | dict | A dictionary specifying the shape of the differentiated vector for specific parameters. | {} |
| save_memory | bool | If True, the Hessian is calculated column by column to save memory. This is slightly slower. | False |
| *loglike_args | Any | The args to pass to the log likelihood function. | () |
| **loglike_kwargs | Any | The kwargs to pass to the log likelihood function. | {} |

Returns:

| Name | Type | Description |
|---|---|---|
| covariance_matrix | Array | The covariance matrix of the pytree parameters. |

Source code in zodiax/fisher.py
def covariance_matrix(
    pytree: PyTree,
    parameters: Params,
    loglike_fn: callable,
    *loglike_args: Any,
    shape_dict: dict = {},
    save_memory: bool = False,
    **loglike_kwargs: Any,
) -> Array:
    """
    Calculates the covariance matrix of the pytree parameters. The
    `shape_dict` parameter is used to specify the shape of the differentiated
    vector for specific parameters. For example, if the parameter `param` is
    a 1D array of shape (5,) and we want the covariance matrix of its mean,
    we can pass in `shape_dict={'param': (1,)}`. This differentiates the log
    likelihood with respect to the mean of that parameter.

    Parameters
    ----------
    pytree : PyTree
        Pytree with a .model() function.
    parameters : Params
        A path or list of paths or list of nested paths.
    loglike_fn : callable
        The log likelihood function to differentiate.
    shape_dict : dict = {}
        A dictionary specifying the shape of the differentiated vector for
        specific parameters.
    save_memory : bool = False
        If True, the Hessian is calculated column by column to save memory.
        This is slightly slower.
    *loglike_args : Any
        The args to pass to the log likelihood function.
    **loglike_kwargs : Any
        The kwargs to pass to the log likelihood function.

    Returns
    -------
    covariance_matrix : Array
        The covariance matrix of the pytree parameters.
    """
    warnings.warn(
        "covariance_matrix is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )

    # The covariance matrix is the inverse of the Fisher information matrix
    return np.linalg.inv(
        fisher_matrix(
            pytree,
            parameters,
            loglike_fn,
            *loglike_args,
            shape_dict=shape_dict,
            save_memory=save_memory,
            **loglike_kwargs,
        )
    )
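The function above inverts the Fisher matrix. The minimal sketch below reproduces that step with plain JAX for a hypothetical straight-line model y = a + b·x with known Gaussian noise, where the result can be checked against the standard linear-regression covariance σ²(AᵀA)⁻¹. The arrays x and y, the value of sigma, and loglike are illustrative assumptions and are not part of Zodiax.

```python
import jax
import jax.numpy as jnp

# Hypothetical data for a straight-line fit with known noise level
x = jnp.array([0.0, 1.0, 2.0, 3.0])
y = jnp.array([1.1, 2.9, 5.2, 6.8])
sigma = 0.5


def loglike(theta):
    # Gaussian log likelihood (up to a constant) of intercept and slope
    model = theta[0] + theta[1] * x
    return -0.5 * jnp.sum(((y - model) / sigma) ** 2)


theta0 = jnp.array([1.0, 2.0])

# Fisher matrix = negative Hessian; covariance = its inverse
fisher = -jax.hessian(loglike)(theta0)
cov = jnp.linalg.inv(fisher)
param_errors = jnp.sqrt(jnp.diag(cov))  # 1-sigma uncertainties on (a, b)

# Analytic check for a linear model: cov = sigma^2 (A^T A)^-1
A = jnp.stack([jnp.ones_like(x), x], axis=1)
cov_analytic = sigma**2 * jnp.linalg.inv(A.T @ A)
```

Because this model is linear in its parameters, the Fisher matrix does not depend on theta0 or y; for nonlinear models it does, which is why the Zodiax functions evaluate it at the current pytree values.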

fisher_matrix(pytree, parameters, loglike_fn, *loglike_args, shape_dict={}, save_memory=False, **loglike_kwargs)

Calculates the Fisher information matrix of the log likelihood function with respect to the parameters of the pytree. It is evaluated at the current values of the parameters as listed in the pytree. Simply returns the negative Hessian.

Source code in zodiax/fisher.py
def fisher_matrix(
    pytree: PyTree,
    parameters: Params,
    loglike_fn: callable,
    *loglike_args,
    shape_dict: dict = {},
    save_memory: bool = False,
    **loglike_kwargs,
) -> Array:
    """
    Calculates the Fisher information matrix of the log likelihood function with
    respect to the parameters of the pytree. It is evaluated at the current values
    of the parameters as listed in the pytree. Simply returns the negative Hessian.
    """
    warnings.warn(
        "fisher_matrix is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )

    # The Fisher information matrix is the negative Hessian of the log likelihood
    return -hessian(
        pytree,
        parameters,
        loglike_fn,
        *loglike_args,
        shape_dict=shape_dict,
        save_memory=save_memory,
        **loglike_kwargs,
    )
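fisher_matrix returns the negative Hessian of the log likelihood. For a Gaussian likelihood evaluated where the model reproduces the data exactly, the residual-weighted second-derivative terms vanish, so the negative Hessian should match the familiar JᵀJ/σ² form. The sketch below checks this with plain JAX on a hypothetical exponential-decay model. model, t, sigma, and theta_true are illustrative only.

```python
import jax
import jax.numpy as jnp

# Hypothetical sample times and noise level
t = jnp.linspace(0.0, 2.0, 6)
sigma = 0.1


def model(theta):
    # Exponential decay with amplitude theta[0] and rate theta[1]
    return theta[0] * jnp.exp(-theta[1] * t)


theta_true = jnp.array([2.0, 0.5])
data = model(theta_true)  # noiseless synthetic data, so residuals are zero


def loglike(theta):
    return -0.5 * jnp.sum(((data - model(theta)) / sigma) ** 2)


# Fisher matrix as the negative Hessian, evaluated at theta_true
fisher = -jax.hessian(loglike)(theta_true)

# Gauss-Newton form from the model Jacobian, shape (6, 2)
J = jax.jacobian(model)(theta_true)
fisher_gn = J.T @ J / sigma**2
```

Away from a perfect fit the two differ by terms proportional to the residuals; the negative Hessian is what fisher_matrix computes.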

hessian(pytree, parameters, fn, *fn_args, shape_dict={}, save_memory=False, **fn_kwargs)

Calculates the Hessian of the function with respect to the parameters of the pytree. It is evaluated at the current values of the parameters as listed in the pytree.
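The same idea can be sketched with plain JAX: flatten the selected parameters into one vector, differentiate with respect to a zero-valued perturbation added to their current values, and read off an N×N Hessian. This sketch swaps in jax.flatten_util.ravel_pytree for Zodiax's private helpers (_shapes_and_lengths, _perturb), whose details are not shown here, and uses a plain dict as the pytree. params and fn are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Hypothetical pytree: a scalar and a length-3 vector
params = {"amp": jnp.array(2.0), "offsets": jnp.array([0.1, -0.2, 0.3])}


def fn(p):
    # Hypothetical scalar objective that couples the two parameters
    return -0.5 * jnp.sum((p["amp"] * p["offsets"] - 1.0) ** 2)


# Flatten the current values; dict keys are flattened in sorted order
flat0, unravel = ravel_pytree(params)


def fn_vec(X):
    # X is a perturbation on top of the current parameter values
    return fn(unravel(flat0 + X))


X = jnp.zeros_like(flat0)   # zero perturbation: evaluate at current values
H = jax.hessian(fn_vec)(X)  # shape (4, 4), ordered [amp, offsets[0..2]]
```

Differentiating with respect to a perturbation, rather than the values themselves, keeps the original pytree untouched and gives a single flat vector regardless of how the parameters are nested.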

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| pytree | PyTree | Pytree holding the parameter values. | required |
| parameters | Params | Names of parameters to be used in the Hessian calculation. | required |
| fn | callable | The function to differentiate. | required |
| *fn_args | Any | The args to pass to fn. | () |
| shape_dict | dict | A dictionary specifying the shape of the differentiated vector for specific parameters. | {} |
| save_memory | bool | If True, the Hessian is calculated column by column to save memory. | False |
| **fn_kwargs | Any | The kwargs to pass to fn. | {} |

Returns:

| Name | Type | Description |
|---|---|---|
| hessian | Array | The Hessian of fn with respect to the parameters of the pytree. |

Source code in zodiax/fisher.py
def hessian(
    pytree: PyTree,
    parameters: Params,
    fn: callable,
    *fn_args,
    shape_dict: dict = {},
    save_memory: bool = False,
    **fn_kwargs,
) -> Array:
    """
    Calculates the Hessian of the function with respect to the parameters of the pytree.
    It is evaluated at the current values of the parameters as listed in the pytree.

    Parameters
    ----------
    pytree : PyTree
        Pytree holding the parameter values.
    parameters : Params
        Names of parameters to be used in the Hessian calculation.
    fn : callable
        The function to differentiate.
    *fn_args : Any
        The args to pass to fn.
    shape_dict : dict = {}
        A dictionary specifying the shape of the differentiated vector for
        specific parameters.
    save_memory : bool = False
        If True, the Hessian is calculated column by column to save memory.
    **fn_kwargs : Any
        The kwargs to pass to fn.

    Returns
    -------
    hessian : Array
        The Hessian of fn with respect to the parameters of the pytree.
    """
    warnings.warn(
        "fisher.hessian is deprecated as of v0.5.0 and will be removed in v0.6.0",
        DeprecationWarning,
    )

    # If a single parameter path is passed as a string, wrap it in a list
    if isinstance(parameters, str):
        parameters = [parameters]

    # Build empty vector to perturb
    pytree = set_array(pytree)
    shapes, lengths = _shapes_and_lengths(pytree, parameters, shape_dict)
    X = np.zeros(_lengths_to_N(lengths))

    # Build perturbation function to differentiate
    def loglike_fn_vec(X):
        parametric_pytree = _perturb(X, pytree, parameters, shapes, lengths)
        return fn(parametric_pytree, *fn_args, **fn_kwargs)

    # Optionally compute the Hessian column by column to reduce peak memory
    if save_memory:
        return _hessian_col_by_col(loglike_fn_vec, X)

    return jax.hessian(loglike_fn_vec)(X)
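With save_memory=True the source delegates to a private helper, _hessian_col_by_col, whose implementation is not shown on this page. One common way to build a Hessian column by column is to compute one Hessian-vector product per unit basis vector, using jax.jvp applied to jax.grad (forward-over-reverse). The sketch below shows that approach as an assumption, not as Zodiax's actual code; hessian_col_by_col and f are hypothetical names. jax.hessian is jacfwd(jacrev(f)), which pushes all N tangent vectors through in one vectorized pass; the loop trades that speed for a smaller intermediate working set, although the final N×N result still has to be stored.

```python
import jax
import jax.numpy as jnp


def hessian_col_by_col(f, x):
    """Hypothetical helper: Hessian of scalar f at 1D x, one column at a time."""
    grad_f = jax.grad(f)
    columns = []
    for j in range(x.size):
        # Unit basis vector e_j, built one at a time
        e_j = jnp.zeros_like(x).at[j].set(1.0)
        # Forward-over-reverse Hessian-vector product gives column j: H @ e_j
        _, hvp = jax.jvp(grad_f, (x,), (e_j,))
        columns.append(hvp)
    return jnp.stack(columns, axis=1)


def f(x):
    # Hypothetical smooth scalar function with cross terms
    return jnp.sum(jnp.sin(x) * jnp.roll(x, 1)) + jnp.prod(x)


x = jnp.array([0.3, -1.2, 0.7, 2.0])
H_cols = hessian_col_by_col(f, x)
H_full = jax.hessian(f)(x)
```

The Python loop keeps the sketch readable; it is neither compiled nor batched, so it is slower than jax.hessian for small problems like this one.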