
Helpers

map2array

dLux.utils.helpers.map2array(fn, tree, leaf_fn=None)

Maps a function over every leaf of a pytree, then flattens the results and stacks them into a single array.

Parameters:

fn (Callable, required): The function to be mapped across the pytree.
tree (Any, required): The pytree to be mapped across.
leaf_fn (Callable, default None): Optional predicate, passed to the tree map as is_leaf, that decides whether a node is treated as a leaf.

Returns:

array (Array): The results of fn for every leaf, stacked into a single array in pytree flattening order. fn should therefore return values of a common shape, typically scalars.

Source code in dLux/utils/helpers.py
def map2array(fn: Callable, tree: Any, leaf_fn: Callable = None) -> Array:
    """
    Maps a function across a pytree, flattening it and turning it into an
    array.

    Parameters
    ----------
    fn : Callable
        The function to be mapped across the pytree.
    tree : Any
        The pytree to be mapped across.
    leaf_fn : Callable = None
        The function to be used to determine whether a leaf is reached.

    Returns
    -------
    array : Array
        The flattened array of the pytree.
    """
    if leaf_fn is not None:
        return np.array(jtu.flatten(jtu.map(fn, tree, is_leaf=leaf_fn))[0])
    else:
        return np.array(jtu.flatten(jtu.map(fn, tree))[0])
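The behaviour of map2array (apply fn to every leaf, flatten in a fixed order, then stack) can be illustrated without JAX. The sketch below is a hypothetical pure-Python stand-in, tree_map_leaves, for jtu.flatten(jtu.map(...))[0]. It assumes the pytree is built only from dicts, lists and tuples, and mirrors JAX in visiting dict entries in sorted-key order and treating None as an empty subtree. It returns a plain list, which map2array would then convert with np.array.

```python
from typing import Any, Callable, List, Optional


def tree_map_leaves(
    fn: Callable, tree: Any, is_leaf: Optional[Callable] = None
) -> List[Any]:
    """Hypothetical stand-in: apply fn to each leaf and return results in flatten order."""
    # A user-supplied is_leaf predicate stops the recursion early (the role of leaf_fn)
    if is_leaf is not None and is_leaf(tree):
        return [fn(tree)]
    # Dict children are visited in sorted-key order, as JAX does
    if isinstance(tree, dict):
        return [r for k in sorted(tree) for r in tree_map_leaves(fn, tree[k], is_leaf)]
    # Lists and tuples are visited in order
    if isinstance(tree, (list, tuple)):
        return [r for child in tree for r in tree_map_leaves(fn, child, is_leaf)]
    # None is an empty subtree in JAX, so it contributes no leaves
    if tree is None:
        return []
    # Anything else is a leaf
    return [fn(tree)]


tree = {"b": [1.0, 2.0], "a": [3.0, 4.0, 5.0]}

# Every float is a leaf; "a" comes before "b" because keys are sorted
print(tree_map_leaves(lambda x: 2 * x, tree))  # [6.0, 8.0, 10.0, 2.0, 4.0]

# Treating each list as a leaf (what a leaf_fn does) changes what fn receives
print(tree_map_leaves(len, tree, is_leaf=lambda node: isinstance(node, list)))  # [3, 2]
```

In the second call each sub-list is collapsed to one value, which is how a leaf_fn lets map2array produce one entry per sub-structure rather than one per scalar.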
list2dictionary

dLux.utils.helpers.list2dictionary(list_in, ordered, allowed_types=())

Converts an input list into a dictionary. Each entry can be either an object, in which case its key is taken from its class name, or a (key, object) tuple that specifies the key explicitly.

If duplicate keys are found, each occurrence is suffixed with an index. For example, if two list entries share the key 'layer', they are assigned 'layer_0' and 'layer_1' according to their order in the input list. Keys may not contain spaces; a ValueError is raised if one does.

Parameters:

list_in (list, required): The list of objects, or (key, object) tuples, to be converted into a dictionary.
ordered (bool, required): If True, return a collections.OrderedDict; otherwise return a regular dict.
allowed_types (tuple, default ()): The types that every object must be an instance of. If empty, no type checking is performed; otherwise a TypeError is raised for any object of another type.

Returns:

dictionary (dict): The equivalent dictionary or ordered dictionary.

Source code in dLux/utils/helpers.py
def list2dictionary(list_in: list, ordered: bool, allowed_types: tuple = ()) -> dict:
    """
    Converts some input list to a dictionary. The input list entries can either be
    objects, in which case the keys are taken as the class name, else a (key, object)
    tuple can be used to specify a key.

    If any duplicate keys are found, the key is appended with an index value. i.e. if
    two of the list entries have the same key 'layer', they will be assigned 'layer_0'
    and 'layer_1' respectively, depending on their input order in the list.

    Parameters
    ----------
    list_in : list
        The list of objects to be converted into a dictionary.
    ordered : bool
        Whether to return an ordered or regular dictionary.
    allowed_types : tuple
        The allowed types of layers to be included in the dictionary.

    Returns
    -------
    dictionary : dict
        The equivalent dictionary or ordered dictionary.
    """
    # Construct names list and identify repeats
    names, repeats = [], []
    for item in list_in:
        # Check for specified names
        if isinstance(item, tuple):
            name, item = item
        else:
            name = item.__class__.__name__

        # Check input types
        if allowed_types != () and not isinstance(item, allowed_types):
            raise TypeError(f"Item {name} is not an allowed type, got " f"{type(item)}")

        # Check for Repeats
        if name in names:
            repeats.append(name)
        names.append(name)

    # Get list of unique repeats
    repeats = list(set(repeats))

    # Iterate over repeat names
    for i in range(len(repeats)):
        # Iterate over names list and append index value to name
        idx = 0
        for j in range(len(names)):
            if repeats[i] == names[j]:
                names[j] = names[j] + "_{}".format(idx)
                idx += 1

    # Turn list into Dictionary
    dict_out = OrderedDict() if ordered else {}
    for i in range(len(names)):
        # Check for spaces in names
        if " " in names[i]:
            raise ValueError(f"Names cannot contain spaces, got {names[i]}")

        # Add to dict
        if isinstance(list_in[i], tuple):
            item = list_in[i][1]
        else:
            item = list_in[i]
        dict_out[names[i]] = item
    return dict_out
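The key-assignment rule can be checked with a small, self-contained sketch. keys_for is a hypothetical helper (not part of dLux) that resolves each entry's key the same way as described above (the tuple's key, else the class name) and applies the _0, _1, ... suffixing to repeated keys. Aperture and Lens are placeholder classes, and the sketch assumes no key already ends in such a suffix.

```python
from collections import Counter
from typing import Any, List


def keys_for(items: List[Any]) -> List[str]:
    """Hypothetical helper: the keys the list2dictionary rule assigns to each entry."""
    # A (key, obj) tuple supplies its own key; any other object uses its class name
    names = [it[0] if isinstance(it, tuple) else type(it).__name__ for it in items]
    totals = Counter(names)  # how often each name occurs overall
    seen: Counter = Counter()  # how many of each repeated name have been numbered so far
    keys = []
    for name in names:
        if totals[name] > 1:
            keys.append(f"{name}_{seen[name]}")
            seen[name] += 1
        else:
            keys.append(name)
    return keys


class Aperture:  # placeholder layer type
    pass


class Lens:  # placeholder layer type
    pass


items = [Aperture(), ("lens", Lens()), Aperture()]
print(keys_for(items))  # ['Aperture_0', 'lens', 'Aperture_1']
```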
insert_layer

dLux.utils.helpers.insert_layer(layers, layer, index, allowed_type)

Inserts a layer into a dictionary of layers at a specified index. The existing entries and the new layer are passed through list2dictionary (with ordered=True) to keep all keys unique, so existing keys may be renamed if they duplicate the new key. The input layer can be a (key, layer) tuple to specify its key; otherwise the key is taken from the layer's class name.

Parameters:

layers (dict, required): The dictionary of layers to insert the layer into.
layer (Any, required): The layer to be inserted, optionally given as a (key, layer) tuple.
index (int, required): The position at which to insert the layer, following list.insert semantics.
allowed_type (Any, required): The type, or tuple of types, that every layer (existing and new) must be an instance of; a TypeError is raised otherwise.

Returns:

layers (dict): The updated dictionary of layers, returned as an OrderedDict. The input dictionary is not modified.

Source code in dLux/utils/helpers.py
def insert_layer(
    layers: dict,
    layer: Any,
    index: int,
    allowed_type: Any,
) -> dict:
    """
    Inserts a layer into a dictionary of layers at a specified index. This function
    calls the list2dictionary function to ensure all keys remain unique. Note that this
    can result in some keys being modified if they are duplicates. The input 'layer'
    can be a tuple of (key, layer) to specify a key, else the key is taken as the
    class name of the layer.

    Parameters
    ----------
    layers : dict
        The dictionary of layers to insert the layer into.
    layer : Any
        The layer to be inserted.
    index : int
        The index at which to insert the layer.
    allowed_type : Any
        The type of layer to be inserted. Used for type-checking.

    Returns
    -------
    layers : dict
        The updated dictionary of layers.
    """
    layers_list = list(zip(layers.keys(), layers.values()))
    layers_list.insert(index, layer)
    return list2dictionary(layers_list, True, allowed_type)
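The effect of inserting a layer whose key already exists can be sketched with plain dictionaries. insert_unique is a hypothetical, simplified stand-in for insert_layer that omits the type checking and returns a regular dict instead of an OrderedDict; string values stand in for layer objects.

```python
from collections import Counter
from typing import Any, Dict


def insert_unique(layers: Dict[str, Any], key: str, value: Any, index: int) -> Dict[str, Any]:
    """Hypothetical sketch: insert (key, value) at index, then suffix duplicate keys."""
    items = list(layers.items())  # work on a copy, so the input dict is left untouched
    items.insert(index, (key, value))
    totals = Counter(k for k, _ in items)
    seen: Counter = Counter()
    out: Dict[str, Any] = {}
    for k, v in items:
        if totals[k] > 1:
            out[f"{k}_{seen[k]}"] = v
            seen[k] += 1
        else:
            out[k] = v
    return out


layers = {"pupil": "P", "lens": "L", "detector": "D"}
print(insert_unique(layers, "lens", "L2", 2))
# {'pupil': 'P', 'lens_0': 'L', 'lens_1': 'L2', 'detector': 'D'}
```

Note how the pre-existing 'lens' entry becomes 'lens_0': code that later looks up layers['lens'] would need updating after such an insert.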
remove_layer

dLux.utils.helpers.remove_layer(layers, key)

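Removes a layer from a dictionary of layers, specified by its key. The dictionary is modified in place (via dict.pop) and also returned; a KeyError is raised if the key is not present.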

Parameters:

layers (dict, required): The dictionary of layers to remove the layer from.
key (str, required): The key of the layer to be removed.

Returns:

layers (dict): The same dictionary object, with the layer removed.

Source code in dLux/utils/helpers.py
def remove_layer(layers: dict, key: str) -> dict:
    """
    Removes a layer from a dictionary of layers, specified by its key.

    Parameters
    ----------
    layers : dict
        The dictionary of layers to remove the layer from.
    key : str
        The key of the layer to be removed.

    Returns
    -------
    layers : dict
        The updated dictionary of layers.
    """
    layers.pop(key)
    return layers
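Because remove_layer calls layers.pop(key), the caller's dictionary is mutated and a missing key raises KeyError. Where the original must be preserved, a copy-based variant avoids the side effect; remove_layer_copy below is a hypothetical example, not part of dLux.

```python
from typing import Any, Dict


def remove_layer_copy(layers: Dict[str, Any], key: str) -> Dict[str, Any]:
    """Hypothetical variant: return a new dict without key, leaving layers unchanged."""
    if key not in layers:
        raise KeyError(key)  # match dict.pop's behaviour for missing keys
    return {k: v for k, v in layers.items() if k != key}


layers = {"pupil": "P", "lens": "L"}
trimmed = remove_layer_copy(layers, "lens")
print(trimmed)  # {'pupil': 'P'}
print(layers)   # {'pupil': 'P', 'lens': 'L'}  (unchanged)

# By contrast, dict.pop (used by remove_layer) mutates its input
layers.pop("lens")
print(layers)   # {'pupil': 'P'}
```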
imshow_extent

dLux.utils.helpers.imshow_extent(size)

Returns a square extent, centred on the origin, in [xmin, xmax, ymin, ymax] order for use with matplotlib's imshow.

Parameters:

size (float, required): The total width of the image in the relevant physical units.

Returns:

extent (Array): The array [-size/2, size/2, -size/2, size/2], which can be passed directly as the extent argument of matplotlib's imshow.

Source code in dLux/utils/helpers.py
def imshow_extent(size: float) -> Array:
    """
    Returns a square imshow extent in [xmin, xmax, ymin, ymax] order.

    Parameters
    ----------
    size : float
        The total width of the image in the relevant physical units.

    Returns
    -------
    extent : Array
        The extent array to pass directly to matplotlib imshow.
    """
    half_size = np.asarray(size, dtype=float) / 2
    return np.array([-half_size, half_size, -half_size, half_size])
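In practice size is usually the number of pixels times the pixel scale. The stdlib sketch below (square_extent and pixel_centres are hypothetical helpers) computes the same four values as imshow_extent and shows that they mark the outer pixel edges, so pixel centres sit half a pixel inside them. The result would be used as plt.imshow(image, extent=...).

```python
from typing import List


def square_extent(npixels: int, pixel_scale: float) -> List[float]:
    """Same [xmin, xmax, ymin, ymax] layout as imshow_extent(npixels * pixel_scale)."""
    half = npixels * pixel_scale / 2
    return [-half, half, -half, half]


def pixel_centres(npixels: int, pixel_scale: float) -> List[float]:
    """Centre coordinate of each pixel along one axis for that extent."""
    half = npixels * pixel_scale / 2
    return [-half + (i + 0.5) * pixel_scale for i in range(npixels)]


print(square_extent(128, 0.25))  # [-16.0, 16.0, -16.0, 16.0]
print(pixel_centres(4, 0.5))     # [-0.75, -0.25, 0.25, 0.75]
```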
inherit_docstrings

dLux.utils.helpers.inherit_docstrings(cls, method_names=None)

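Inherits docstrings and annotations from parent classes for the specified methods.

For each named method that cls defines itself, this function walks the method resolution order (MRO). A missing docstring is copied from the first parent whose version of the method has one, and missing annotations are copied from the first parent that defines the method itself.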

Parameters:

cls (type, required): The class being created, typically passed in from __init_subclass__.
method_names (list[str] | None, default None): Names of the methods whose docstrings and annotations should be inherited. If None, only '__call__' is checked.

Returns:

None: cls is modified in place.

Source code in dLux/utils/helpers.py
def inherit_docstrings(cls, method_names=None):
    """
    Inherit docstrings and annotations from parent classes for specified methods.

    This function walks the MRO to find the first parent class with a docstring
    or annotations for each method, and copies them to the child class if missing.

    Parameters
    ----------
    cls : type
        The class being created via __init_subclass__.
    method_names : list[str] | None
        List of method names to inherit docstrings/annotations for.
        If None, only '__call__' is checked.

    Returns
    -------
    None
        Modifies cls in place.
    """
    if method_names is None:
        method_names = ["__call__"]

    for method_name in method_names:
        # Only process if method is defined in this class
        if method_name in cls.__dict__:
            method = cls.__dict__[method_name]

            # Inherit docstring if missing
            if method.__doc__ is None:
                for base in cls.__mro__[1:]:
                    if (
                        hasattr(base, method_name)
                        and getattr(base, method_name).__doc__ is not None
                    ):
                        method.__doc__ = getattr(base, method_name).__doc__
                        break

            # Inherit annotations if missing
            if not hasattr(method, "__annotations__") or not method.__annotations__:
                for base in cls.__mro__[1:]:
                    if method_name in base.__dict__ and hasattr(
                        base.__dict__[method_name], "__annotations__"
                    ):
                        method.__annotations__ = base.__dict__[
                            method_name
                        ].__annotations__
                        break
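The docstring refers to calling this function from __init_subclass__. That hook pattern is sketched below with a hypothetical, docstring-only helper, copy_missing_docs, standing in for inherit_docstrings; BaseLayer and MyLayer are illustrative classes. For simplicity the sketch looks only at each base's own __dict__.

```python
def copy_missing_docs(cls, method_names=("__call__",)):
    """Hypothetical docstring-only stand-in for inherit_docstrings."""
    for name in method_names:
        method = cls.__dict__.get(name)  # only methods defined on cls itself
        if method is None or method.__doc__ is not None:
            continue
        for base in cls.__mro__[1:]:
            parent = base.__dict__.get(name)
            if parent is not None and parent.__doc__ is not None:
                method.__doc__ = parent.__doc__
                break


class BaseLayer:
    def __init_subclass__(cls, **kwargs):
        # Runs once for every subclass, right after the class body is executed
        super().__init_subclass__(**kwargs)
        copy_missing_docs(cls)

    def __call__(self, wavefront):
        """Apply this layer to a wavefront."""
        return wavefront


class MyLayer(BaseLayer):
    def __call__(self, wavefront):  # no docstring: inherited from BaseLayer
        return wavefront


print(MyLayer.__call__.__doc__)  # Apply this layer to a wavefront.
```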
missing_attribute_error

dLux.utils.helpers.missing_attribute_error(owner, key, valid_attrs=None, hint=None)

Builds an AttributeError with a consistently formatted message for a missing attribute. The error is returned rather than raised, so callers raise it themselves.

Parameters:

owner (Any, required): The object raising the error; its class name starts the message.
key (str, required): The missing attribute name.
valid_attrs (list[str], default None): Optional list of valid attribute names to surface. They are sorted, and only the first six are listed, followed by "..." if there are more.
hint (str, default None): Optional additional guidance appended to the message.

Returns:

error (AttributeError): The formatted AttributeError instance.

Source code in dLux/utils/helpers.py
def missing_attribute_error(
    owner: Any,
    key: str,
    valid_attrs: list[str] = None,
    hint: str = None,
) -> AttributeError:
    """
    Builds a consistent AttributeError message for missing attributes.

    Parameters
    ----------
    owner : Any
        The object raising the error.
    key : str
        The missing attribute name.
    valid_attrs : list[str] = None
        Optional list of valid attribute names to surface.
    hint : str = None
        Optional additional guidance appended to the message.

    Returns
    -------
    error : AttributeError
        The formatted AttributeError instance.
    """
    message = f"{owner.__class__.__name__} has no attribute '{key}'."
    if valid_attrs:
        attrs = sorted(valid_attrs)
        attrs_str = ", ".join(attrs[:6])
        ellipsis = "..." if len(attrs) > 6 else ""
        message += f" Valid attributes: {attrs_str}{ellipsis}"
    if hint:
        message += f" {hint}"
    return AttributeError(message)
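Since missing_attribute_error returns the error rather than raising it, a typical call site is a __getattr__ fallback that raises it. The sketch below uses attr_error, a hypothetical local stand-in that reproduces the message format of the source above so the example runs without dLux; ToyOptics is an illustrative class.

```python
from typing import Any, List, Optional


def attr_error(owner: Any, key: str, valid_attrs: Optional[List[str]] = None,
               hint: Optional[str] = None) -> AttributeError:
    """Hypothetical stand-in reproducing missing_attribute_error's message format."""
    message = f"{owner.__class__.__name__} has no attribute '{key}'."
    if valid_attrs:
        attrs = sorted(valid_attrs)
        ellipsis = "..." if len(attrs) > 6 else ""  # only the first six are listed
        message += f" Valid attributes: {', '.join(attrs[:6])}{ellipsis}"
    if hint:
        message += f" {hint}"
    return AttributeError(message)


class ToyOptics:
    def __init__(self):
        self.layers = {"pupil": 1.0, "detector": 2.0}

    def __getattr__(self, key):
        # Only called when normal lookup fails: fall back to the layers dict
        layers = self.__dict__.get("layers", {})
        if key in layers:
            return layers[key]
        raise attr_error(self, key, valid_attrs=list(layers))


optics = ToyOptics()
print(optics.pupil)  # 1.0
try:
    optics.lens
except AttributeError as err:
    print(err)  # ToyOptics has no attribute 'lens'. Valid attributes: detector, pupil
```

Raising an AttributeError (rather than some other exception) from __getattr__ keeps built-ins such as hasattr and getattr with a default working as expected.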