Skip to content

Math¤

gaussian

dLux.utils.math.gaussian(mean=0.0, std=1.0, npixels=64, extent=5.0) ¤

Generates a normalized n-dimensional Gaussian function.

Parameters:

Name Type Description Default
mean float | Array = 0.0

The center position(s) of the Gaussian. Scalar for 1D, array for nD.

0.0
std float | Array = 1.0

The standard deviation(s) of the Gaussian. Scalar for 1D, array for nD.

1.0
npixels int | tuple[int, ...] = 64

The number of pixels along each axis. Scalar for 1D, tuple for nD.

64
extent float = 5.0

The half-width of the sampling grid: each axis is sampled on [-extent, extent]. This equals the number of standard deviations on each side only when std = 1.

5.0

Returns:

Name Type Description
kernel Array

The normalized n-dimensional Gaussian kernel.

Source code in dLux/utils/math.py
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
def gaussian(
    mean: float | Array = 0.0,
    std: float | Array = 1.0,
    npixels: int | tuple[int, ...] = 64,
    extent: float = 5.0,
) -> Array:
    """
    Builds a separable n-dimensional Gaussian kernel that sums to one.

    The kernel is the outer product of one 1D normal PDF per axis, divided by
    its total so that the returned array sums to unity.

    Parameters
    ----------
    mean : float | Array = 0.0
        The center position(s) of the Gaussian. Scalar for 1D, array for nD.
    std : float | Array = 1.0
        The standard deviation(s) of the Gaussian. Scalar for 1D, array for nD.
    npixels : int | tuple[int, ...] = 64
        The number of samples along each axis. Scalar for 1D, tuple for nD.
    extent : float = 5.0
        Half-width of the sampling grid. Every axis is sampled on
        ``linspace(-extent, extent, n)``; the grid is not rescaled by ``std``
        nor re-centred on ``mean``.
        NOTE(review): this matches "units of standard deviation" only when
        ``std == 1`` -- confirm which behaviour is intended.

    Returns
    -------
    kernel : Array
        The normalized n-dimensional Gaussian kernel. The per-axis profiles
        are combined with ``meshgrid(..., indexing="xy")``.
    """
    # Validate the inputs through the module's private helpers (defined
    # elsewhere in this file) and work out the dimensionality from the
    # longest of the three inputs.
    npixels = _cast_tuple(npixels, "npixels")
    n_mean = _input_len(mean, "mean")
    n_std = _input_len(std, "std")
    ndim = max(len(npixels), n_mean, n_std)
    mean = _cast_scalar(mean, ndim, "mean")
    std = _cast_scalar(std, ndim, "std")

    # Repeat the pixel counts when they do not already match ndim
    if len(npixels) != ndim:
        npixels *= ndim

    def _axis_coords(n):
        # Sample positions along a single axis
        return np.linspace(-extent, extent, n)

    def _axis_profile(coords, centre, width):
        # 1D normal PDF evaluated on one axis
        return jsp.stats.norm.pdf(coords, loc=centre, scale=width)

    axes = jtu.map(_axis_coords, npixels)
    marginals = jtu.map(_axis_profile, axes, mean, std)

    # Broadcast every 1D marginal over the full grid, multiply them together,
    # then normalise to unit sum
    grids = np.meshgrid(*marginals, indexing="xy")
    kernel = np.array(grids).prod(0)
    total = np.sum(kernel)
    return kernel / total
mv_gaussian

dLux.utils.math.mv_gaussian(mean, cov, npix=64, extent=5.0) ¤

Generates a normalized multivariate Gaussian function.

Parameters:

Name Type Description Default
mean Array

The mean vector of the multivariate Gaussian. Shape (ndim,).

required
cov Array

The covariance matrix of the multivariate Gaussian. Shape (ndim, ndim).

required
npix int | Array = 64

The number of pixels along each axis.

64
extent float = 5.0

The extent of the grid in units of standard deviation on each side.

5.0

Returns:

Name Type Description
kernel Array

The normalized multivariate Gaussian kernel.

Source code in dLux/utils/math.py
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def mv_gaussian(
    mean: Array,
    cov: Array,
    npix: int | Array = 64,
    extent: float = 5.0,
) -> Array:
    """
    Generates a normalized multivariate Gaussian function.

    This function is currently disabled: it raises ``NotImplementedError``
    before doing any work, so no kernel is ever returned.

    Parameters
    ----------
    mean : Array
        The mean vector of the multivariate Gaussian. Shape (ndim,).
    cov : Array
        The covariance matrix of the multivariate Gaussian. Shape (ndim, ndim).
    npix : int | Array = 64
        The number of pixels along each axis.
    extent : float = 5.0
        The extent of the grid in units of standard deviation on each side.

    Returns
    -------
    kernel : Array
        The normalized multivariate Gaussian kernel.

    Raises
    ------
    NotImplementedError
        Always, as the first statement of the function.
    """
    raise NotImplementedError("Multivariate Gaussian generation is under development.")

    # NOTE: everything below is unreachable while the raise above is in place.
    # It appears to be a draft implementation; the review notes below should
    # be resolved before the raise is removed.

    mean = np.asarray(mean, dtype=float)
    cov = np.asarray(cov, dtype=float)
    npix_arr = np.atleast_1d(np.asarray(npix, dtype=int))
    ndim = mean.size

    # Get standard deviations from covariance matrix diagonal
    stds = np.sqrt(np.diag(cov))

    # Handle npix broadcasting
    if npix_arr.size == 1:
        npix_arr = np.repeat(npix_arr, ndim)

    # Create linspace function
    def make_axis(i):
        # Axis i spans mean[i] +/- extent standard deviations
        return np.linspace(
            mean[i] - extent * stds[i],
            mean[i] + extent * stds[i],
            npix_arr[i],
        )

    # Generate coordinate arrays for each dimension
    # NOTE(review): if jtu.map treats the array from np.arange as a single
    # leaf, make_axis would receive the whole index array rather than one
    # index at a time -- confirm before enabling.
    linspaces = jtu.map(make_axis, np.arange(ndim))

    # Create meshgrid
    grids = np.meshgrid(*linspaces, indexing="xy")

    # Stack grids into points array and reshape for computation
    # NOTE(review): grid_shape follows the order of linspaces, whereas
    # meshgrid(indexing="xy") swaps the first two output axes -- verify the
    # final reshape when npix differs between axes.
    grid_shape = tuple(len(ls) for ls in linspaces)
    points = np.stack(grids, axis=0).reshape(ndim, -1).T  # (n_points, ndim)

    # Compute multivariate Gaussian
    kernel = jsp.stats.multivariate_normal.pdf(points, mean=mean, cov=cov)
    kernel = kernel.reshape(grid_shape)

    # Normalize
    return kernel / np.sum(kernel)
factorial

dLux.utils.math.factorial(n) ¤

Calculate n! in a JAX-friendly way.

Parameters:

Name Type Description Default
n float

The value to calculate the factorial of.

required

Returns:

Type Description
n! : float

The factorial of the value.

Source code in dLux/utils/math.py
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
def factorial(n: float) -> float:
    """
    Calculate n! in a JAX-friendly way.

    Parameters
    ----------
    n : float
        The value to calculate the factorial of.

    Returns
    -------
    n! : float
        The factorial of the value.
    """
    n = np.asarray(n, float)
    return lax.cond(
        n == 0,
        lambda x: np.asarray(1.0, dtype=x.dtype),
        lambda x: lax.exp(lax.lgamma(x + 1.0)),
        n,
    )
triangular_number

dLux.utils.math.triangular_number(n) ¤

Calculate the nth triangular number.

Parameters:

Name Type Description Default
n int

The nth triangular number to calculate.

required

Returns:

Name Type Description
n int

The nth triangular number.

Source code in dLux/utils/math.py
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
def triangular_number(n: int) -> int:
    """
    Calculate the nth triangular number, 1 + 2 + ... + n.

    Parameters
    ----------
    n : int
        The index of the triangular number to calculate.

    Returns
    -------
    t : int
        The nth triangular number, n * (n + 1) / 2.
    """
    # n * (n + 1) is always even for integer n, so floor division is exact and
    # keeps the result an integer, as the signature promises (true division
    # would return a float).
    return n * (n + 1) // 2
eval_basis

dLux.utils.math.eval_basis(basis, coefficients) ¤

Performs an n-dimensional dot-product between the basis and coefficients arrays.

Parameters:

Name Type Description Default
basis Array

The basis to use.

required
coefficients Array

The Array of coefficients to be applied to each basis vector.

required
Source code in dLux/utils/math.py
176
177
178
179
180
181
182
183
184
185
186
187
188
def eval_basis(basis: Array, coefficients: Array) -> Array:
    """
    Contracts a basis with a matching array of coefficients.

    The leading ``coefficients.ndim`` axes of ``basis`` are summed against all
    axes of ``coefficients``, i.e. a generalised n-dimensional dot product.

    Parameters
    ----------
    basis: Array
        The basis to use. Its leading axes are contracted with the
        coefficients.
    coefficients: Array
        The Array of coefficients to be applied to each basis vector.

    Returns
    -------
    result : Array
        The coefficient-weighted sum of the basis, with the remaining
        (uncontracted) axes of ``basis``.
    """
    # The same leading axis indices are contracted on both operands
    contracted = tuple(range(coefficients.ndim))
    return np.tensordot(basis, coefficients, axes=(contracted, contracted))
nandiv

dLux.utils.math.nandiv(a, b, fill=np.inf) ¤

Divides two arrays, replacing the result with a fill value wherever the denominator is zero.

Parameters:

Name Type Description Default
a Array

The numerator.

required
b Array

The denominator.

required
fill Any = np.inf

The value used in place of the quotient wherever the denominator is zero.

inf

Returns:

Type Description
a / b : Array

The result of the division.

Source code in dLux/utils/math.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
def nandiv(a: Array, b: Array, fill: Any = np.inf) -> Array:
    """
    Divides two arrays, substituting a fill value wherever the denominator is
    zero.

    Parameters
    ----------
    a : Array
        The numerator.
    b : Array
        The denominator.
    fill : Any = np.inf
        The value placed in the output at every position where ``b == 0``.

    Returns
    -------
    a / b : Array
        The result of the division, with ``fill`` at the positions where the
        denominator is zero.
    """
    zero_denominator = b == 0

    # Swap zeros for ones before dividing so that a/0 is never evaluated
    # (avoids tripping jax_debug_nans); those entries are overwritten below.
    quotient = a / np.where(zero_denominator, 1, b)
    return np.where(zero_denominator, fill, quotient)