
Math Utility Functions

This module contains some basic maths functions used under the hood.

factorial

Calculate n! in a JAX-friendly way by evaluating exp(lgamma(n + 1)), which also accepts non-integer n. Note that n == 0 is not a safe case.

Parameters:

Name | Type | Description | Default
n | float | The value to calculate the factorial of. | required

Returns:

Name | Type | Description
n! | float | The factorial of the value.

Source code in src/dLux/utils/math.py
def factorial(n: float) -> float:
    """
    Calculate n! in a JAX-friendly way. Note that n == 0 is not a
    safe case.

    Parameters
    ----------
    n : float
        The value to calculate the factorial of.

    Returns
    -------
    n! : float
        The factorial of the value.
    """
    return lax.exp(lax.lgamma(n + 1.0))
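The function relies on the identity n! = Γ(n + 1), evaluated as exp(ln Γ(n + 1)) with `lax.lgamma`. The standalone sketch below re-declares `factorial` locally rather than importing it from dLux, and only assumes that `jax` is installed. It checks small integer values against `math.factorial`, evaluates a non-integer argument, and takes a gradient, which is what makes the log-gamma form convenient under JAX.

```python
import math

import jax
import jax.numpy as jnp
from jax import lax


def factorial(n):
    # n! = Gamma(n + 1) = exp(lgamma(n + 1)); lgamma is differentiable and
    # works on floating-point arrays inside jit/grad/vmap.
    return lax.exp(lax.lgamma(n + 1.0))


# Integer-valued inputs 1.0 ... 7.0 reproduce the usual factorials.
ns = jnp.arange(1.0, 8.0)
approx = factorial(ns)
exact = jnp.array([math.factorial(k) for k in range(1, 8)], dtype=jnp.float32)

# Non-integer inputs follow the gamma function: 0.5! = Gamma(1.5) = sqrt(pi) / 2.
half = factorial(jnp.asarray(0.5))

# The derivative is n! * digamma(n + 1), so d/dn n! at n = 3 is 6 * digamma(4).
grad_at_3 = jax.grad(factorial)(3.0)
```

Working in log space also avoids forming very large intermediate products before the final exponential.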

triangular_number

Calculate the nth triangular number, n(n + 1)/2.

Parameters:

Name | Type | Description | Default
n | int | The index of the triangular number to calculate. | required

Returns:

Name | Type | Description
n | int | The nth triangular number.

Source code in src/dLux/utils/math.py
def triangular_number(n: int) -> int:
    """
    Calculate the nth triangular number.

    Parameters
    ----------
    n : int
        The index of the triangular number to calculate.

    Returns
    -------
    n : int
        The nth triangular number.
    """
    return n * (n + 1) // 2
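As a quick check of the closed form T_n = n(n + 1)/2, the sketch below compares it with the explicit sum 1 + 2 + ... + n. It uses floor division (`//`) rather than `/`: in Python 3, `/` always returns a float, while `//` keeps integer inputs as integers, matching the `-> int` annotation. The division is exact because one of n and n + 1 is always even. The function is re-declared locally, and `jax` is assumed to be installed for the array case.

```python
import jax.numpy as jnp


def triangular_number(n):
    # T_n = 1 + 2 + ... + n = n (n + 1) / 2. One of n and n + 1 is always
    # even, so floor division is exact and integer input stays integer.
    return n * (n + 1) // 2


# The closed form agrees with the explicit sum for small n.
checks = [triangular_number(n) == sum(range(n + 1)) for n in range(20)]

# It also works elementwise on JAX integer arrays.
tri = triangular_number(jnp.arange(6))
```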

eval_basis

Performs an n-dimensional dot product between the basis and coefficients arrays, contracting the leading coefficients.ndim axes of basis against every axis of coefficients.

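Parameters:

Name | Type | Description | Default
basis | Array | The basis to use, with the basis vectors indexed by its leading axes. | required
coefficients | Array | The coefficients to be applied to each basis vector. | required

Returns:

Type | Description
Array | The sum of the basis vectors weighted by the coefficients.
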
Source code in src/dLux/utils/math.py
def eval_basis(basis: Array, coefficients: Array) -> Array:
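    """
    Performs an n-dimensional dot product between the basis and coefficients
    arrays, contracting the leading coefficients.ndim axes of basis against
    every axis of coefficients.

    Parameters
    ----------
    basis : Array
        The basis to use.
    coefficients : Array
        The coefficients to be applied to each basis vector.

    Returns
    -------
    Array
        The sum of the basis vectors weighted by the coefficients.
    """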
    ndim = coefficients.ndim
    return np.tensordot(basis, coefficients, axes=2 * (tuple(range(ndim)),))
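The axis bookkeeping is easier to follow with concrete shapes. In the sketch below (assuming `jax` is installed, with `eval_basis` re-declared locally), the `axes` argument is written as `(axes, axes)`, which is what `2 * (tuple(range(ndim)),)` evaluates to. The basis and coefficient arrays are made-up examples: a stack of three 4 x 4 maps weighted by three coefficients, then a basis indexed by two leading axes.

```python
import jax.numpy as jnp


def eval_basis(basis, coefficients):
    # Contract the leading `coefficients.ndim` axes of `basis` against all
    # axes of `coefficients`; the trailing axes of `basis` are kept.
    axes = tuple(range(coefficients.ndim))
    return jnp.tensordot(basis, coefficients, axes=(axes, axes))


# Hypothetical example: 3 basis maps, each sampled on a 4 x 4 grid.
basis = jnp.arange(3 * 4 * 4, dtype=jnp.float32).reshape(3, 4, 4)
coefficients = jnp.array([1.0, -2.0, 0.5])
result = eval_basis(basis, coefficients)

# The same thing written as an explicit weighted sum over the first axis.
expected = jnp.sum(coefficients[:, None, None] * basis, axis=0)

# With 2-D coefficients, both leading axes of the basis are contracted.
basis_2d = jnp.ones((2, 3, 5, 5))
coefficients_2d = jnp.ones((2, 3))
result_2d = eval_basis(basis_2d, coefficients_2d)
```

Because the contraction is driven by `coefficients.ndim`, the same function handles a flat list of basis vectors and a basis indexed by several leading axes without reshaping.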

nandiv

Divides two arrays, replacing the result with a fill value wherever the denominator is zero.

Parameters:

Name | Type | Description | Default
a | Array | The numerator. | required
b | Array | The denominator. | required
fill | Any | The value returned wherever b is zero. | np.inf

Returns:

Name | Type | Description
a / b | Array | The result of the division.

Source code in src/dLux/utils/math.py
def nandiv(a: Array, b: Array, fill: Any = np.inf) -> Array:
    """
    Divides two arrays, replacing the result with a fill value wherever
    the denominator is zero.

    Parameters
    ----------
    a : Array
        The numerator.
    b : Array
        The denominator.
    fill : Any = np.inf
        The value returned wherever b is zero.

    Returns
    -------
    a / b : Array
        The result of the division.
    """
    return np.where(b == 0, fill, a / b)
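The edge cases are easiest to see on a small array. The sketch below re-declares `nandiv` with `jax.numpy` (assumed installed) and shows that the replacement is keyed on `b == 0`: both 1/0 and 0/0 are replaced, but a NaN already present in `a` passes through. It also illustrates a general JAX caveat with this `where` pattern: `a / b` is still computed where `b == 0`, so reverse-mode gradients there come out as NaN. `safe_nandiv` is a hypothetical variant, not part of dLux, that applies the commonly used "double `where`" workaround.

```python
import jax
import jax.numpy as jnp


def nandiv(a, b, fill=jnp.inf):
    # Same logic as above: wherever b == 0, return `fill` instead of a / b.
    return jnp.where(b == 0, fill, a / b)


a = jnp.array([1.0, 0.0, jnp.nan, 6.0])
b = jnp.array([0.0, 0.0, 2.0, 3.0])
out = nandiv(a, b, fill=0.0)
# 1/0 (inf) and 0/0 (nan) are replaced; the NaN coming from `a` is not.

# jnp.where still evaluates a / b everywhere, so the gradient at b == 0
# picks up 0 * inf = nan even though that branch is never selected.
grad_naive = jax.grad(lambda x: nandiv(1.0, x, fill=0.0))(0.0)


def safe_nandiv(a, b, fill=jnp.inf):
    # Hypothetical variant: swap zeros in b for 1 before dividing, so the
    # discarded branch stays finite and its gradient is 0 rather than nan.
    safe_b = jnp.where(b == 0, 1.0, b)
    return jnp.where(b == 0, fill, a / safe_b)


grad_safe = jax.grad(lambda x: safe_nandiv(1.0, x, fill=0.0))(0.0)
grad_safe_at_2 = jax.grad(lambda x: safe_nandiv(1.0, x, fill=0.0))(2.0)
```

Away from zero the two versions agree in both value and gradient; they differ only in the gradient at `b == 0`.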