
Wavefronts

Wavefront

Bases: Base

A simple class to hold the state of some wavefront as it is transformed and propagated throughout an optical system. All wavefronts assume square arrays.

Attributes:

Name Type Description
wavelength (float, meters)

The wavelength of the Wavefront.

amplitude (Array, power)

The electric field amplitude of the Wavefront.

phase (Array, radians)

The electric field phase of the Wavefront.

pixel_scale (float, meters / pixel or radians / pixel)

The pixel scale of the phase and amplitude arrays. If units='Cartesian' then the pixel scale is in meters/pixel, else if units='Angular' then the pixel scale is in radians/pixel.

plane str

The current plane type of wavefront, can be 'Pupil', 'Focal' or 'Intermediate'.

units str

The current units of the wavefront, can be 'Cartesian' or 'Angular'.
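As a rough illustration of how these attributes relate to one another, the sketch below mirrors in plain NumPy the values that the constructor in the source listing assigns. It is not dLux itself: `initial_state` and the sample inputs are hypothetical names chosen for this example.

```python
import numpy as np


def initial_state(npixels, diameter, wavelength):
    """Plain-NumPy mirror of the values set at construction (illustrative only)."""
    return {
        "wavelength": np.asarray(wavelength, float),
        "pixel_scale": np.asarray(diameter / npixels, float),  # meters/pixel
        "amplitude": np.ones((npixels, npixels)) / npixels**2,
        "phase": np.zeros((npixels, npixels)),
        "plane": "Pupil",      # a new wavefront always starts in the pupil plane
        "units": "Cartesian",  # so pixel_scale is in meters/pixel
    }


state = initial_state(npixels=16, diameter=1.0, wavelength=1e-6)

# The diameter is recovered from the array side length and the pixel scale.
diameter = state["amplitude"].shape[-1] * state["pixel_scale"]
```

Storing only `pixel_scale` and deriving `diameter` from it keeps the two consistent when the arrays are later resized or rescaled.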

Source code in src/dLux/wavefronts.py
class Wavefront(Base):
    """
    A simple class to hold the state of some wavefront as it is transformed and
    propagated throughout an optical system. All wavefronts assume square arrays.

    Attributes
    ----------
    wavelength : float, meters
        The wavelength of the `Wavefront`.
    amplitude : Array, power
        The electric field amplitude of the `Wavefront`.
    phase : Array, radians
        The electric field phase of the `Wavefront`.
    pixel_scale : float, meters/pixel or radians/pixel
        The pixel scale of the phase and amplitude arrays. If `units='Cartesian'` then
        the pixel scale is in meters/pixel, else if `units='Angular'` then the pixel
        scale is in radians/pixel.
    plane : str
        The current plane type of wavefront, can be 'Pupil', 'Focal' or 'Intermediate'.
    units : str
        The current units of the wavefront, can be 'Cartesian' or 'Angular'.
    """

    wavelength: float
    pixel_scale: float
    amplitude: Array
    phase: Array
    plane: str
    units: str

    def __init__(
        self: Wavefront, npixels: int, diameter: float, wavelength: float
    ):
        """
        Parameters
        ----------
        npixels : int
            The number of pixels that represent the `Wavefront`.
        diameter : float, meters
            The total diameter of the `Wavefront`.
        wavelength : float, meters
            The wavelength of the `Wavefront`.
        """
        self.wavelength = np.asarray(wavelength, float)
        self.pixel_scale = np.asarray(diameter / npixels, float)
        self.amplitude = (
            np.ones((npixels, npixels), dtype=float) / npixels**2
        )
        self.phase = np.zeros((npixels, npixels), dtype=float)

        # Always initialised in Pupil plane with Cartesian Coords
        self.plane = "Pupil"
        self.units = "Cartesian"

    ####################
    # Getter Functions #
    ####################
    @property
    def diameter(self: Wavefront) -> Array:
        """
        Returns the current wavefront diameter calculated using the pixel scale and
        number of pixels.

        Returns
        -------
        diameter : Array, meters or radians
            The current diameter of the wavefront.
        """
        return self.npixels * self.pixel_scale

    @property
    def npixels(self: Wavefront) -> int:
        """
        Returns the side length of the arrays currently representing the wavefront.
        Taken from the last axis of the amplitude array.

        Returns
        -------
        pixels : int
            The number of pixels that represent the `Wavefront`.
        """
        return self.amplitude.shape[-1]

    @property
    def real(self: Wavefront) -> Array:
        """
        Returns the real component of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The real component of the `Wavefront` phasor.
        """
        return self.amplitude * np.cos(self.phase)

    @property
    def imaginary(self: Wavefront) -> Array:
        """
        Returns the imaginary component of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The imaginary component of the `Wavefront` phasor.
        """
        return self.amplitude * np.sin(self.phase)

    @property
    def phasor(self: Wavefront) -> Array:
        """
        The electric field phasor described by this Wavefront in complex form.

        Returns
        -------
        field : Array
            The electric field phasor of the wavefront.
        """
        return self.amplitude * np.exp(1j * self.phase)

    @property
    def psf(self: Wavefront) -> Array:
        """
        Calculates the Point Spread Function (PSF), i.e. the squared modulus
        of the complex wavefront.

        Returns
        -------
        psf : Array
            The PSF of the wavefront.
        """
        return self.amplitude**2

    @property
    def coordinates(self: Wavefront) -> Array:
        """
        Returns the physical positions of the wavefront pixels in meters.

        Returns
        -------
        coordinates : Array
            The coordinates of the centers of each pixel representing the
            wavefront.
        """
        return dlu.pixel_coords(self.npixels, self.diameter)

    @property
    def wavenumber(self: Wavefront) -> Array:
        """
        Returns the wavenumber of the wavefront (2 * pi / wavelength).

        Returns
        -------
        wavenumber : Array, 1/meters
            The wavenumber of the wavefront.
        """
        return 2 * np.pi / self.wavelength

    @property
    def fringe_size(self: Wavefront) -> Array:
        """
        Returns the size of the fringes in angular units.

        TODO Units check from focal plane
        Returns
        -------
        fringe_size : Array, radians
            The size of the linear diffraction fringe of the wavefront.
        """
        return self.wavelength / self.diameter

    @property
    def ndim(self: Wavefront) -> int:
        """
        Returns the number of 'dimensions' of the wavefront. This is used to track the
        vectorised version of the wavefront returned from vmapping.

        Returns
        -------
        ndim : int
            The 'dimensionality' of the wavefront.
        """
        return self.pixel_scale.ndim

    #################
    # Magic Methods #
    #################
    def __add__(self: Wavefront, other: Any) -> Wavefront:
        """
        Adds the input 'other' to the wavefront. If the input is a numeric type, it is
        treated as an OPD, else if it is an optical layer, it will be applied to the
        wavefront.

        Parameters
        ----------
        other : Any
            The input to add to the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The output wavefront.
        """
        # None Type
        if other is None:
            return self

        # Some Optical Layer
        if isinstance(other, OpticalLayer()):
            return other.apply(self)

        # Array based inputs - Defaults to OPD
        if isinstance(other, (Array, float, int)):
            return self.add_opd(other)

        # Other
        else:
            raise TypeError(
                "Can only add an array or OpticalLayer to "
                f"Wavefront. Got: {type(other)}."
            )

    def __iadd__(self: Wavefront, other: Any) -> Wavefront:
        """
        Provides the += operator for the wavefront, calling the __add__ method.

        Parameters
        ----------
        other : Any
            The input to add to the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The output wavefront.
        """
        return self.__add__(other)

    def __mul__(self: Wavefront, other: Any) -> Wavefront:
        """
        Multiplies the input 'other' to the wavefront. If the input is a numeric type,
        it is treated as an array of transmission values and is multiplied by the
        wavefront amplitude, unless it is a complex number, in which case it will be
        multiplied with the wavefront phasor. If it is an optical layer, it will be
        applied to the wavefront.

        Parameters
        ----------
        other : Any
            The input to multiply with the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The output wavefront.
        """
        # None Type, return the wavefront unchanged
        if other is None:
            return self

        # Some Optical Layer, apply it
        if isinstance(other, OpticalLayer()):
            return other.apply(self)

        # Array based inputs
        if isinstance(other, (Array, float, int)):
            # Complex array - Multiply the phasors
            if isinstance(other, Array) and other.dtype.kind == "c":
                phasor = self.phasor * other
                return self.set(
                    ["amplitude", "phase"], [np.abs(phasor), np.angle(phasor)]
                )

            # Scalar array - Multiply amplitude
            else:
                return self.multiply("amplitude", other)

        # Other
        else:
            raise TypeError(
                "Can only multiply Wavefront by array or "
                f"OpticalLayer. Got: {type(other)}."
            )

    def __imul__(self: Wavefront, other: Any) -> Wavefront:
        """
        Provides the *= operator for the wavefront, calling the __mul__ method.

        Parameters
        ----------
        other : Any
            The input to multiply with the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The output wavefront.
        """
        return self.__mul__(other)

    ###################
    # Adder Functions #
    ###################
    def add_opd(self: Wavefront, opd: Array) -> Wavefront:
        """
        Adds an optical path difference (OPD) to the wavefront.

        Parameters
        ----------
        opd : Array, meters
            The opd to add to the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The new wavefront with the phases updated according to the supplied opd.
        """
        return self.add("phase", self.wavenumber * opd)

    def add_phase(self: Wavefront, phase: Array) -> Wavefront:
        """
        Adds a phase to the wavefront.

        Parameters
        ----------
        phase : Array, radians
            The phase to be added to the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The new wavefront with updated phases.
        """
        # Add this extra None check to allow PhaseOptics to have a None phase
        # and still be able to be 'added' to it, making this the phase
        # equivalent of `wf += opd` -> `wf = wf.add_phase(phase)`
        if phase is not None:
            return self.add("phase", phase)
        return self

    ###################
    # Other Functions #
    ###################
    def tilt(self: Wavefront, angles: Array) -> Wavefront:
        """
        Tilts the wavefront by the (x, y) angles.

        Parameters
        ----------
        angles : Array, radians
            The (x, y) angles by which to tilt the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The tilted wavefront.
        """
        if not isinstance(angles, Array) or angles.shape != (2,):
            raise ValueError("angles must be an array of shape (2,).")
        opd = -(angles[:, None, None] * self.coordinates).sum(0)
        return self.add_opd(opd)

    def normalise(self: Wavefront) -> Wavefront:
        """
        Normalises the total power of the wavefront to 1.

        Returns
        -------
        wavefront : Wavefront
            The normalised wavefront.
        """
        return self.divide("amplitude", np.linalg.norm(self.amplitude))

    def _to_field(self: Wavefront, complex: bool = False) -> Array:
        """
        Returns the wavefront in either (amplitude, phase) or (real, imaginary) form.

        Parameters
        ----------
        complex : bool = False
            Whether to return the wavefront in (real, imaginary) form.

        Returns
        -------
        field : Array
            The wavefront in either (amplitude, phase) or (real, imaginary) form.
        """
        if complex:
            return np.array([self.real, self.imaginary])
        return np.array([self.amplitude, self.phase])

    def _to_amplitude_phase(self: Wavefront, field: Array) -> Array:
        """
        Transforms the input field in (real, imaginary) to (amplitude, phase) form.

        Parameters
        ----------
        field : Array
            The wavefront field in (real, imaginary) form.

        Returns
        -------
        field : Array
            The wavefront field in (amplitude, phase) form.
        """
        amplitude = np.hypot(field[0], field[1])
        phase = np.arctan2(field[1], field[0])
        return np.array([amplitude, phase])

    def flip(self: Wavefront, axis: tuple) -> Wavefront:
        """
        Flips the wavefront along the specified axes. Note we use 'ij' indexing, so
        axis 0 is the y-axis and axis 1 is the x-axis.

        Parameters
        ----------
        axis : tuple
            The axes along which to flip the wavefront.

        Returns
        -------
        wavefront : Wavefront
            The new flipped wavefront.
        """
        field = self._to_field()
        flipper = vmap(np.flip, (0, None))
        amplitude, phase = flipper(field, axis)
        return self.set(["amplitude", "phase"], [amplitude, phase])

    def scale_to(
        self: Wavefront,
        npixels: int,
        pixel_scale: Array,
        complex: bool = False,
    ) -> Wavefront:
        """
        Interpolates the wavefront to a given npixels and pixel_scale. Can be done on
        the real and imaginary components by passing in complex=True.

        Parameters
        ----------
        npixels : int
            The number of pixels to interpolate to.
        pixel_scale : Array
            The pixel scale to interpolate to.
        complex : bool = False
            Whether to interpolate the real and imaginary representation of the
            wavefront as opposed to the amplitude and phase representation.

        Returns
        -------
        wavefront : Wavefront
            The new interpolated wavefront.
        """
        # Get field in either (amplitude, phase) or (real, imaginary)
        field = self._to_field(complex=complex)

        # Scale the field
        scale_fn = vmap(dlu.scale, (0, None, None))
        field = scale_fn(field, npixels, pixel_scale / self.pixel_scale)

        # Cast back to (amplitude, phase) if needed
        if complex:
            field = self._to_amplitude_phase(field)

        # Return new wavefront
        return self.set(
            ["amplitude", "phase", "pixel_scale"],
            [field[0], field[1], pixel_scale],
        )

    def rotate(
        self: Wavefront, angle: Array, order: int = 1, complex: bool = False
    ) -> Wavefront:
        """
        Rotates the wavefront by a given angle via interpolation. Can be done on the
        real and imaginary components by passing in complex=True.

        Parameters
        ----------
        angle : Array, radians
            The angle by which to rotate the wavefront in a clockwise
            direction.
        order : int = 1
            The interpolation order to use.
        complex : bool = False
            Whether to rotate the real and imaginary representation of the wavefront as
            opposed to the amplitude and phase representation.

        Returns
        -------
        wavefront : Wavefront
            The new wavefront rotated by angle in the clockwise direction.
        """
        # Get field in either (amplitude, phase) or (real, imaginary)
        field = self._to_field(complex=complex)

        # Rotate the field
        rotator = vmap(dlu.rotate, (0, None, None))
        field = rotator(field, angle, order)

        # Cast back to (amplitude, phase) if needed
        if complex:
            field = self._to_amplitude_phase(field)

        # Return new wavefront
        return self.set(["amplitude", "phase"], [field[0], field[1]])

    def resize(self: Wavefront, npixels: int) -> Wavefront:
        """
        Resizes the wavefront via a zero-padding or cropping operation.

        Parameters
        ----------
        npixels : int
            The size to resize the wavefront to.

        Returns
        -------
        wavefront : Wavefront
            The resized wavefront.
        """
        field = self._to_field()
        amplitude, phase = vmap(dlu.resize, (0, None))(field, npixels)
        return self.set(["amplitude", "phase"], [amplitude, phase])

    #########################
    # Propagation Functions #
    #########################
    def _prep_prop(self: Wavefront, focal_length) -> tuple:
        """
        Determines the propagation direction, output plane and output units.

        Parameters
        ----------
        focal_length : Union[float, None]
            The focal length of the propagation.

        Returns
        -------
        inverse : bool
            Whether the propagation is inverse or not.
        plane : str
            The output plane of the propagation.
        units : str
            The output units of the propagation.
        """
        # Determine propagation direction, output plane and output units
        if self.plane == "Pupil":
            inverse = False
            plane = "Focal"
            if focal_length is None:
                units = "Angular"
            else:
                units = "Cartesian"
        else:
            if focal_length is not None and self.units == "Angular":
                raise ValueError(
                    "focal_length can not be specified when "
                    "propagating from a Focal plane with angular units."
                )
            inverse = True
            plane = "Pupil"
            units = "Cartesian"

        return inverse, plane, units

    def propagate_FFT(
        self: Wavefront,
        focal_length: float = None,
        pad: int = 2,
    ) -> Wavefront:
        """
        Propagates the wavefront by performing a Fast Fourier Transform.

        Parameters
        ----------
        focal_length : float = None
            The focal length of the propagation. If None, the output pixel scale has
            units of radians, else meters.
        pad : int = 2
            The padding factor to apply to the input wavefront before the FFT.

        Returns
        -------
        wavefront : Wavefront
            The propagated wavefront.
        """
        inverse, plane, units = self._prep_prop(focal_length)

        # Calculate
        phasor, pixel_scale = dlu.FFT(
            self.phasor,
            self.wavelength,
            self.pixel_scale,
            focal_length,
            pad,
            inverse,
        )

        # Return new wavefront
        return self.set(
            ["amplitude", "phase", "pixel_scale", "plane", "units"],
            [np.abs(phasor), np.angle(phasor), pixel_scale, plane, units],
        )

    # TODO: Class method this?
    def _MFT(
        self: Wavefront,
        phasor: Array,
        wavelength: float,
        pixel_scale: float,
        *args: tuple,
    ) -> Array:
        """
        Simple alias for the MFT function to allow for vectorisation over phasors,
        wavelengths, pixel_scales, etc.

        Parameters
        ----------
        phasor : Array
            The phasor to propagate.
        wavelength : float
            The wavelength of the wavefront.
        pixel_scale : float
            The pixel scale of the wavefront.
        args : tuple
            The propagation arguments.

        Returns
        -------
        phasor : Array
            The propagated phasor.
        """
        return dlu.MFT(phasor, wavelength, pixel_scale, *args)

    def propagate(
        self: Wavefront,
        npixels: int,
        pixel_scale: float,
        focal_length: float = None,
        shift: Array = np.zeros(2),
        pixel: bool = True,
    ) -> Wavefront:
        """
        Propagates the wavefront by performing an MFT, allowing for the output pixel
        scale and npixels to be specified.

        Parameters
        ----------
        npixels : int
            The number of pixels in the output plane.
        pixel_scale : float, meters/pixel or radians/pixel
            The pixel scale of the output plane.
        focal_length : float = None
            The focal length of the propagation. If None, the propagation is angular
            and pixel_scale is taken in radians/pixel, else meters/pixel.
        shift : Array = np.zeros(2)
            The shift in the center of the output plane.
        pixel : bool = True
            Should the shift be taken in units of pixels, or pixel scale.

        Returns
        -------
        wavefront : Wavefront
            The propagated wavefront.
        """
        inverse, plane, units = self._prep_prop(focal_length)

        # Enforce array so output can be vectorised by vmap
        pixel_scale = np.asarray(pixel_scale, float)

        # Calculate
        # Using a self._MFT here allows for broadband wavefronts to define
        # vectorised propagation fn over phasors, wavels, px_scales, etc.
        # It also makes the code muuuuch nicer to read
        args = (npixels, pixel_scale, focal_length, shift, pixel, inverse)
        phasor = self._MFT(
            self.phasor, self.wavelength, self.pixel_scale, *args
        )

        # Update
        return self.set(
            ["amplitude", "phase", "pixel_scale", "plane", "units"],
            [np.abs(phasor), np.angle(phasor), pixel_scale, plane, units],
        )

    # # TODO: Class method this?
    # def _fresnel(self, phasor, wavelength, pixel_scale, focal_shift, *args):
    #     return dlu.fresnel_MFT(phasor, wavelength, pixel_scale, *args)

    def propagate_fresnel(
        self: Wavefront,
        npixels: int,
        pixel_scale: float,
        focal_length: float,
        focal_shift: float = 0.0,
        shift: Array = np.zeros(2),
        pixel: bool = True,
    ) -> Wavefront:
        """
        Propagates the phasor using a far-field Fresnel propagation. This allows for
        PSFs to be better modelled a few wavelengths from the focal plane.

        Parameters
        ----------
        npixels : int
            The number of pixels in the output plane.
        pixel_scale : float, meters/pixel or radians/pixel
            The pixel scale of the output plane.
        focal_length : float
            The focal length of the propagation.
        focal_shift: float, meters = 0.
            The shift from focus to propagate to.
        shift : Array = np.zeros(2)
            The shift in the center of the output plane.
        pixel : bool = True
            Should the shift be taken in units of pixels, or pixel scale.

        Returns
        -------
        wavefront : Wavefront
            The propagated wavefront.
        """
        # TODO: Try inverse propagation to see if it works, it probably will
        if self.plane == "Pupil":
            inverse = False
        else:
            inverse = True
        plane = "Intermediate"
        units = "Cartesian"

        # We can't fresnel from a focal plane
        if self.plane != "Pupil":
            raise ValueError(
                "Can only do a Fresnel propagation from a Pupil plane, "
                f"current plane is {self.plane}."
            )

        # Calculate
        phasor = dlu.fresnel_MFT(
            self.phasor,
            self.wavelength,
            self.pixel_scale,
            npixels,
            pixel_scale,
            focal_length,
            focal_shift,
            shift,
            pixel,
            inverse,
        )

        # Update
        return self.set(
            ["amplitude", "phase", "pixel_scale", "plane", "units"],
            [np.abs(phasor), np.angle(phasor), pixel_scale, plane, units],
        )

coordinates: Array property

Returns the physical positions of the wavefront pixels in meters.

Returns:

Name Type Description
coordinates Array

The coordinates of the centers of each pixel representing the wavefront.

diameter: Array property

Returns the current wavefront diameter calculated using the pixel scale and number of pixels.

Returns:

Name Type Description
diameter (Array, meters or radians)

The current diameter of the wavefront.

fringe_size: Array property

Returns the size of the fringes in angular units.

TODO Units check from focal plane

Returns:

Name Type Description
fringe_size (Array, radians)

The size of the linear diffraction fringe of the wavefront.
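For a sense of scale, here is a small worked example of the `wavelength / diameter` ratio that this property returns. The numbers are hypothetical and the snippet is plain Python rather than a call into dLux.

```python
import math

wavelength = 1e-6  # meters (hypothetical value)
diameter = 2.0     # meters (hypothetical value)

# Fringe size in radians, as computed by the property: wavelength / diameter.
fringe_size = wavelength / diameter

# Converted to arcseconds for readability (roughly 0.1 arcsec here).
fringe_arcsec = math.degrees(fringe_size) * 3600
```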

imaginary: Array property

Returns the imaginary component of the Wavefront.

Returns:

Name Type Description
wavefront Array

The imaginary component of the Wavefront phasor.

ndim: int property

Returns the number of 'dimensions' of the wavefront. This is used to track the vectorised version of the wavefront returned from vmapping.

Returns:

Name Type Description
ndim int

The 'dimensionality' of the wavefront.

npixels: int property

Returns the side length of the arrays currently representing the wavefront. Taken from the last axis of the amplitude array.

Returns:

Name Type Description
pixels int

The number of pixels that represent the Wavefront.

phasor: Array property

The electric field phasor described by this Wavefront in complex form.

Returns:

Name Type Description
field Array

The electric field phasor of the wavefront.
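The relationship between `phasor`, `real`, `imaginary`, `amplitude` and `phase` can be checked with a short NumPy sketch. The arrays below are made-up sample data, and NumPy stands in for the JAX arrays used by the library.

```python
import numpy as np

# Hypothetical sample wavefront state.
amplitude = np.full((4, 4), 0.5)
phase = np.linspace(-np.pi / 2, np.pi / 2, 16).reshape(4, 4)

# Complex phasor, and the real and imaginary parts as the properties define them.
phasor = amplitude * np.exp(1j * phase)
real = amplitude * np.cos(phase)
imaginary = amplitude * np.sin(phase)

# Amplitude and phase can be recovered from the complex form.
recovered_amplitude = np.abs(phasor)
recovered_phase = np.angle(phasor)
```

Note that `np.angle` returns values in the interval (-pi, pi], so a phase outside that range is recovered only up to a multiple of 2 pi.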

psf: Array property

Calculates the Point Spread Function (PSF), i.e. the squared modulus of the complex wavefront.

Returns:

Name Type Description
psf Array

The PSF of the wavefront.
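Since the phase factor has unit modulus, the squared modulus of the phasor reduces to `amplitude**2`, which is what the property returns. The sketch below checks this on random sample data and also shows the effect of dividing the amplitude by its norm, as `normalise` does in the source listing. It uses NumPy and hypothetical inputs.

```python
import numpy as np

rng = np.random.default_rng(0)
amplitude = rng.uniform(0.1, 1.0, (8, 8))    # hypothetical amplitudes
phase = rng.uniform(-np.pi, np.pi, (8, 8))   # hypothetical phases
phasor = amplitude * np.exp(1j * phase)

# The PSF as the property computes it, and via the squared modulus.
psf = amplitude**2
psf_from_phasor = np.abs(phasor) ** 2

# Dividing by the norm of the amplitude array makes the PSF sum to 1.
normalised_amplitude = amplitude / np.linalg.norm(amplitude)
total_power = (normalised_amplitude**2).sum()
```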

real: Array property

Returns the real component of the Wavefront.

Returns:

Name Type Description
wavefront Array

The real component of the Wavefront phasor.

wavenumber: Array property

Returns the wavenumber of the wavefront (2 * pi / wavelength).

Returns:

Name Type Description
wavenumber (Array, 1 / meters)

The wavenumber of the wavefront.
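The wavenumber is what converts an optical path difference in meters into a phase in radians, as `add_opd` does in the source listing via `wavenumber * opd`. A minimal NumPy sketch with hypothetical values:

```python
import numpy as np

wavelength = 500e-9                      # meters (hypothetical value)
wavenumber = 2 * np.pi / wavelength      # 1/meters

# OPDs of zero, a quarter wave and a half wave.
opd = np.array([0.0, wavelength / 4, wavelength / 2])

# Corresponding phase shifts: 0, pi/2 and pi radians.
phase_shift = wavenumber * opd
```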

__add__(other)

Adds the input 'other' to the wavefront. If the input is a numeric type, it is treated as an OPD, else if it is an optical layer, it will be applied to the wavefront.

Parameters:

Name Type Description Default
other Any

The input to add to the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The output wavefront.
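The dispatch order described above (None, then optical layer, then numeric OPD, otherwise an error) can be sketched with toy stand-ins. `ToyWavefront` and `ToyLayer` are hypothetical classes written for this example only; they are not part of dLux and hold far less state than the real `Wavefront`.

```python
import numpy as np


class ToyLayer:
    """Hypothetical stand-in for an optical layer exposing `apply`."""

    def apply(self, wf):
        # Adds one radian of phase, purely for illustration.
        return ToyWavefront(wf.wavelength, wf.phase + 1.0)


class ToyWavefront:
    """Hypothetical minimal wavefront holding only a wavelength and a phase."""

    def __init__(self, wavelength, phase):
        self.wavelength = wavelength
        self.phase = np.asarray(phase, float)

    def __add__(self, other):
        if other is None:                      # nothing to add
            return self
        if isinstance(other, ToyLayer):        # layers apply themselves
            return other.apply(self)
        if isinstance(other, (np.ndarray, float, int)):
            # Numeric inputs are treated as an OPD in meters.
            wavenumber = 2 * np.pi / self.wavelength
            return ToyWavefront(self.wavelength, self.phase + wavenumber * other)
        raise TypeError(f"Cannot add {type(other)} to ToyWavefront.")


wf = ToyWavefront(1e-6, np.zeros(3))
same = wf + None           # returned unchanged
shifted = wf + 0.5e-6      # half-wave OPD, so the phase grows by pi
layered = wf + ToyLayer()  # delegated to the layer's apply method
```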

Source code in src/dLux/wavefronts.py
def __add__(self: Wavefront, other: Any) -> Wavefront:
    """
    Adds the input 'other' to the wavefront. If the input is a numeric type, it is
    treated as an OPD, else if it is an optical layer, it will be applied to the
    wavefront.

    Parameters
    ----------
    other : Any
        The input to add to the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The output wavefront.
    """
    # None Type
    if other is None:
        return self

    # Some Optical Layer
    if isinstance(other, OpticalLayer()):
        return other.apply(self)

    # Array based inputs - Defaults to OPD
    if isinstance(other, (Array, float, int)):
        return self.add_opd(other)

    # Other
    else:
        raise TypeError(
            "Can only add an array or OpticalLayer to "
            f"Wavefront. Got: {type(other)}."
        )

__iadd__(other)

Provides the += operator for the wavefront, calling the __add__ method.

Parameters:

Name Type Description Default
other Any

The input to add to the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The output wavefront.

Source code in src/dLux/wavefronts.py
def __iadd__(self: Wavefront, other: Any) -> Wavefront:
    """
    Provides the += operator for the wavefront, calling the __add__ method.

    Parameters
    ----------
    other : Any
        The input to add to the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The output wavefront.
    """
    return self.__add__(other)

__imul__(other)

Provides the *= operator for the wavefront, calling the __mul__ method.

Parameters:

Name Type Description Default
other Any

The input to multiply with the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The output wavefront.

Source code in src/dLux/wavefronts.py
def __imul__(self: Wavefront, other: Any) -> Wavefront:
    """
    Provides the *= operator for the wavefront, calling the __mul__ method.

    Parameters
    ----------
    other : Any
        The input to multiply with the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The output wavefront.
    """
    return self.__mul__(other)

__init__(npixels, diameter, wavelength)

Parameters:

Name Type Description Default
npixels int

The number of pixels that represent the Wavefront.

required
diameter (float, meters)

The total diameter of the Wavefront.

required
wavelength (float, meters)

The wavelength of the Wavefront.

required
Source code in src/dLux/wavefronts.py
def __init__(
    self: Wavefront, npixels: int, diameter: float, wavelength: float
):
    """
    Parameters
    ----------
    npixels : int
        The number of pixels that represent the `Wavefront`.
    diameter : float, meters
        The total diameter of the `Wavefront`.
    wavelength : float, meters
        The wavelength of the `Wavefront`.
    """
    self.wavelength = np.asarray(wavelength, float)
    self.pixel_scale = np.asarray(diameter / npixels, float)
    self.amplitude = (
        np.ones((npixels, npixels), dtype=float) / npixels**2
    )
    self.phase = np.zeros((npixels, npixels), dtype=float)

    # Always initialised in Pupil plane with Cartesian Coords
    self.plane = "Pupil"
    self.units = "Cartesian"
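As a quick check of the initial state set up by `__init__`, the NumPy sketch below reproduces the same arithmetic outside dLux. The values of `npixels` and `diameter` are arbitrary choices for illustration.

```python
import numpy as np

npixels, diameter = 16, 1.0  # illustrative values

# Same arithmetic as the constructor above
pixel_scale = diameter / npixels
amplitude = np.ones((npixels, npixels), dtype=float) / npixels**2
phase = np.zeros((npixels, npixels), dtype=float)

# The amplitudes sum to one, while the summed squared amplitude is 1 / npixels**2
total_amplitude = amplitude.sum()
total_power = (amplitude**2).sum()
```

A freshly constructed wavefront therefore does not have unit summed squared amplitude; `normalise()` is the method that rescales it.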

__mul__(other)

Multiplies the input 'other' to the wavefront. If the input is a numeric type, it is treated as an array of transmission values and is multiplied by the wavefront amplitude, unless it is a complex number, in which case it will be multiplied with the wavefront phasor. If it is an optical layer, it will be applied to the wavefront.

Parameters:

Name Type Description Default
other Any

The input to multiply with the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The output wavefront.

Source code in src/dLux/wavefronts.py
def __mul__(self: Wavefront, other: Any) -> Wavefront:
    """
    Multiplies the input 'other' to the wavefront. If the input is a numeric type,
    it is treated as an array of transmission values and is multiplied by the
    wavefront amplitude, unless it is a complex number, in which case it will be
    multiplied with the wavefront phasor. If it is an optical layer, it will be
    applied to the wavefront.

    Parameters
    ----------
    other : Any
        The input to multiply with the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The output wavefront.
    """
    # None Type, return the wavefront unchanged
    if other is None:
        return self

    # Some Optical Layer, apply it
    if isinstance(other, OpticalLayer()):
        return other.apply(self)

    # Array based inputs
    if isinstance(other, (Array, float, int)):
        # Complex array - Multiply the phasors
        if isinstance(other, Array) and other.dtype.kind == "c":
            phasor = self.phasor * other
            return self.set(
                ["amplitude", "phase"], [np.abs(phasor), np.angle(phasor)]
            )

        # Scalar array - Multiply amplitude
        else:
            return self.multiply("amplitude", other)

    # Other
    else:
        raise TypeError(
            "Can only multiply Wavefront by array or "
            f"OpticalLayer. Got: {type(other)}."
        )
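The complex branch of `__mul__` can be illustrated with plain NumPy. This sketch assumes the phasor is `amplitude * exp(1j * phase)`, which is the standard definition but is not shown in this section.

```python
import numpy as np

amplitude = np.full((2, 2), 0.5)
phase = np.full((2, 2), 0.1)

# Assumed phasor definition: amplitude * exp(i * phase)
phasor = amplitude * np.exp(1j * phase)

# A complex multiplier scales the amplitude by 2 and advances the phase by 0.3 rad
other = 2.0 * np.exp(1j * 0.3) * np.ones((2, 2))
product = phasor * other

# Split back into amplitude and phase, as the method does with np.abs and np.angle
new_amplitude = np.abs(product)
new_phase = np.angle(product)
```

Note that `np.angle` returns values in (-π, π], so phases that leave this range come back wrapped.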

add_opd(opd)

Adds an optical path difference (OPD) to the wavefront.

Parameters:

Name Type Description Default
opd (Array, meters)

The opd to add to the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The new wavefront with the phases updated according to the supplied opd.

Source code in src/dLux/wavefronts.py
def add_opd(self: Wavefront, opd: Array) -> Wavefront:
    """
    Adds an optical path difference (OPD) to the wavefront.

    Parameters
    ----------
    opd : Array, meters
        The opd to add to the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The new wavefront with the phases updated according to the supplied opd.
    """
    return self.add("phase", self.wavenumber * opd)
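The OPD-to-phase conversion can be checked by hand. The sketch below assumes `wavenumber` is 2π/λ (the property itself is not shown in this section) and uses an illustrative wavelength of one micron.

```python
import numpy as np

wavelength = 1e-6  # meters (illustrative)
wavenumber = 2 * np.pi / wavelength  # radians per meter (assumed definition)

# OPD map in meters: zero, a quarter wave, a half wave and a full wave
opd = np.array([[0.0, 0.25e-6], [0.5e-6, 1.0e-6]])
phase = np.zeros((2, 2)) + wavenumber * opd

# A full wave of OPD leaves the phasor unchanged; a half wave flips its sign
phasor = np.exp(1j * phase)
```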

add_phase(phase)

Adds a phase to the wavefront.

Parameters:

Name Type Description Default
phase (Array, radians)

The phase to be added to the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The new wavefront with updated phases.

Source code in src/dLux/wavefronts.py
def add_phase(self: Wavefront, phase: Array) -> Wavefront:
    """
    Adds a phase to the wavefront.

    Parameters
    ----------
    phase : Array, radians
        The phase to be added to the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The new wavefront with updated phases.
    """
    # The extra None check allows PhaseOptics with a None phase to still be
    # 'added' to the wavefront, making this the phase equivalent of
    # `wf += opd` -> `wf = wf.add_phase(phase)`
    if phase is not None:
        return self.add("phase", phase)
    return self

flip(axis)

Flips the wavefront along the specified axes. Note we use 'ij' indexing, so axis 0 is the y-axis and axis 1 is the x-axis.

Parameters:

Name Type Description Default
axis tuple

The axes along which to flip the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The new flipped wavefront.

Source code in src/dLux/wavefronts.py
def flip(self: Wavefront, axis: tuple) -> Wavefront:
    """
    Flips the wavefront along the specified axes. Note we use 'ij' indexing, so
    axis 0 is the y-axis and axis 1 is the x-axis.

    Parameters
    ----------
    axis : tuple
        The axes along which to flip the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The new flipped wavefront.
    """
    field = self._to_field()
    flipper = vmap(np.flip, (0, None))
    amplitude, phase = flipper(field, axis)
    return self.set(["amplitude", "phase"], [amplitude, phase])
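The 'ij' indexing convention can be seen on a small array with `np.flip`, the same function that `flip` maps over the amplitude and phase. The array below is illustrative only.

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4]])

flip_y = np.flip(a, 0)          # axis 0 (y): rows swap
flip_x = np.flip(a, 1)          # axis 1 (x): columns swap
flip_both = np.flip(a, (0, 1))  # both axes, equivalent to a 180 degree rotation
```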

normalise()

Normalises the total power of the wavefront to 1.

Returns:

Name Type Description
wavefront Wavefront

The normalised wavefront.

Source code in src/dLux/wavefronts.py
def normalise(self: Wavefront) -> Wavefront:
    """
    Normalises the total power of the wavefront to 1.

    Returns
    -------
    wavefront : Wavefront
        The normalised wavefront.
    """
    return self.divide("amplitude", np.linalg.norm(self.amplitude))
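Dividing by `np.linalg.norm(amplitude)` (the Frobenius norm for a 2D array) makes the summed squared amplitude equal to one. A minimal NumPy check, using an arbitrary amplitude map:

```python
import numpy as np

amplitude = np.arange(1.0, 5.0).reshape(2, 2)  # arbitrary, unnormalised amplitudes

# Frobenius norm: sqrt of the sum of squared elements
norm = np.linalg.norm(amplitude)
normalised = amplitude / norm

total_power = (normalised**2).sum()
```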

propagate(npixels, pixel_scale, focal_length=None, shift=np.zeros(2), pixel=True)

Propagates the wavefront by performing an MFT, allowing for the output pixel scale and npixels to be specified.

Parameters:

Name Type Description Default
npixels int

The number of pixels in the output plane.

required
pixel_scale (float, meters / pixel or radians / pixel)

The pixel scale of the output plane.

required
focal_length float = None

The focal length of the propagation. If None, the propagation is angular and pixel_scale is taken in radians/pixel, else meters/pixel.

None
shift Array = np.zeros(2)

The shift in the center of the output plane.

np.zeros(2)
pixel bool = True

Whether the shift is given in units of pixels rather than in units of the pixel scale.

True

Returns:

Name Type Description
wavefront Wavefront

The propagated wavefront.

Source code in src/dLux/wavefronts.py
def propagate(
    self: Wavefront,
    npixels: int,
    pixel_scale: float,
    focal_length: float = None,
    shift: Array = np.zeros(2),
    pixel: bool = True,
) -> Wavefront:
    """
    Propagates the wavefront by performing an MFT, allowing for the output pixel
    scale and npixels to be specified.

    Parameters
    ----------
    npixels : int
        The number of pixels in the output plane.
    pixel_scale : float, meters/pixel or radians/pixel
        The pixel scale of the output plane.
    focal_length : float = None
        The focal length of the propagation. If None, the propagation is angular
        and pixel_scale is taken in radians/pixel, else meters/pixel.
    shift : Array = np.zeros(2)
        The shift in the center of the output plane.
    pixel : bool = True
        Whether the shift is given in units of pixels rather than pixel scale.

    Returns
    -------
    wavefront : Wavefront
        The propagated wavefront.
    """
    inverse, plane, units = self._prep_prop(focal_length)

    # Enforce array so output can be vectorised by vmap
    pixel_scale = np.asarray(pixel_scale, float)

    # Calculate
    # Using a self._MFT here allows for broadband wavefronts to define
    # vectorised propagation fn over phasors, wavels, px_scales, etc.
    # It also makes the code muuuuch nicer to read
    args = (npixels, pixel_scale, focal_length, shift, pixel, inverse)
    phasor = self._MFT(
        self.phasor, self.wavelength, self.pixel_scale, *args
    )

    # Update
    return self.set(
        ["amplitude", "phase", "pixel_scale", "plane", "units"],
        [np.abs(phasor), np.angle(phasor), pixel_scale, plane, units],
    )
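The idea behind a matrix Fourier transform (MFT) is that the output frequencies are chosen freely, which is what lets `propagate` take an arbitrary `npixels` and `pixel_scale`. The sketch below is a bare two-sided matrix DFT written for this illustration. It is not dLux's `_MFT`: the function `mft` is hypothetical and it omits centering, shifts, normalisation and physical units.

```python
import numpy as np


def mft(field, freqs):
    """Sample the 2D Fourier transform of a square `field` at `freqs`.

    `freqs` are in cycles per input pixel and are used along both axes.
    """
    n = field.shape[0]
    x = np.arange(n)
    # One DFT kernel row per requested output frequency: shape (len(freqs), n)
    kernel = np.exp(-2j * np.pi * np.outer(freqs, x))
    return kernel @ field @ kernel.T


rng = np.random.default_rng(0)
field = rng.normal(size=(8, 8))

# With the FFT's own frequency grid the result matches np.fft.fft2
fft_freqs = np.arange(8) / 8
matches_fft = np.allclose(mft(field, fft_freqs), np.fft.fft2(field))

# Any other grid works too: 20 samples over a quarter of the frequency range
zoomed = mft(field, np.linspace(0.0, 0.25, 20))
```

The cost is two dense matrix products instead of an FFT, which is the trade made for free choice of output sampling.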

propagate_FFT(focal_length=None, pad=2)

Propagates the wavefront by performing a Fast Fourier Transform.

Parameters:

Name Type Description Default
focal_length float = None

The focal length of the propagation. If None, the output pixel scale has units of radians/pixel, else meters/pixel.

None
pad int = 2

The padding factor to apply to the input wavefront before the FFT.

2

Returns:

Name Type Description
wavefront Wavefront

The propagated wavefront.

Source code in src/dLux/wavefronts.py
def propagate_FFT(
    self: Wavefront,
    focal_length: float = None,
    pad: int = 2,
) -> Wavefront:
    """
    Propagates the wavefront by performing a Fast Fourier Transform.

    Parameters
    ----------
    focal_length : float = None
        The focal length of the propagation. If None, the output pixel scale has
        units of radians/pixel, else meters/pixel.
    pad : int = 2
        The padding factor to apply to the input wavefront before the FFT.

    Returns
    -------
    wavefront : Wavefront
        The propagated wavefront.
    """
    inverse, plane, units = self._prep_prop(focal_length)

    # Calculate
    phasor, pixel_scale = dlu.FFT(
        self.phasor,
        self.wavelength,
        self.pixel_scale,
        focal_length,
        pad,
        inverse,
    )

    # Return new wavefront
    return self.set(
        ["amplitude", "phase", "pixel_scale", "plane", "units"],
        [np.abs(phasor), np.angle(phasor), pixel_scale, plane, units],
    )
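The effect of `pad` can be sketched with NumPy alone. This is an illustration of zero-padding before an FFT, not the body of `dlu.FFT`. The output pixel scale uses the standard sampling relation λ / (pad × diameter) for the angular case, which is assumed here rather than taken from the dLux source.

```python
import numpy as np

npixels, pad = 4, 2
wavelength, diameter = 1e-6, 1.0  # meters (illustrative)

phasor = np.ones((npixels, npixels), dtype=complex)

# Zero-pad symmetrically so the array grows from npixels to npixels * pad
border = (npixels * pad - npixels) // 2
padded = np.pad(phasor, border)

# Centered FFT: shift the origin before and after the transform
focal = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(padded)))

# Assumed angular sampling of the output, in radians/pixel
pixel_scale_out = wavelength / (pad * diameter)
```

Larger `pad` values sample the focal plane more finely at the cost of a larger transform. For a Cartesian output the angular scale would be multiplied by the focal length.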

propagate_fresnel(npixels, pixel_scale, focal_length, focal_shift=0.0, shift=np.zeros(2), pixel=True)

Propagates the phasor using a Far-Field Fresnel propagation. This allows PSFs to be modelled more accurately a few wavelengths from the focal plane.

Parameters:

Name Type Description Default
npixels int

The number of pixels in the output plane.

required
pixel_scale (float, meters / pixel or radians / pixel)

The pixel scale of the output plane.

required
focal_length float

The focal length of the propagation.

required
focal_shift float

The shift from focus to propagate to.

0.0
shift Array = np.zeros(2)

The shift in the center of the output plane.

np.zeros(2)
pixel bool = True

Whether the shift is given in units of pixels rather than in units of the pixel scale.

True

Returns:

Name Type Description
wavefront Wavefront

The propagated wavefront.

Source code in src/dLux/wavefronts.py
def propagate_fresnel(
    self: Wavefront,
    npixels: int,
    pixel_scale: float,
    focal_length: float,
    focal_shift: float = 0.0,
    shift: Array = np.zeros(2),
    pixel: bool = True,
) -> Wavefront:
    """
    Propagates the phasor using a Far-Field Fresnel propagation. This allows
    PSFs to be modelled more accurately a few wavelengths from the focal plane.

    Parameters
    ----------
    npixels : int
        The number of pixels in the output plane.
    pixel_scale : float, meters/pixel or radians/pixel
        The pixel scale of the output plane.
    focal_length : float
        The focal length of the propagation.
    focal_shift : float, meters = 0.0
        The shift from focus to propagate to.
    shift : Array = np.zeros(2)
        The shift in the center of the output plane.
    pixel : bool = True
        Whether the shift is given in units of pixels rather than pixel scale.

    Returns
    -------
    wavefront : Wavefront
        The propagated wavefront.
    """
    # TODO: Try inverse propagation to see if it works, it probably will
    if self.plane == "Pupil":
        inverse = False
    else:
        inverse = True
    plane = "Intermediate"
    units = "Cartesian"

    # We can't fresnel from a focal plane
    if self.plane != "Pupil":
        raise ValueError(
            "Can only do a Fresnel propagation from a Pupil plane, "
            f"current plane is {self.plane}."
        )

    # Calculate
    phasor = dlu.fresnel_MFT(
        self.phasor,
        self.wavelength,
        self.pixel_scale,
        npixels,
        pixel_scale,
        focal_length,
        focal_shift,
        shift,
        pixel,
        inverse,
    )

    # Update
    return self.set(
        ["amplitude", "phase", "pixel_scale", "plane", "units"],
        [np.abs(phasor), np.angle(phasor), pixel_scale, plane, units],
    )

resize(npixels)

Resizes the wavefront via a zero-padding or cropping operation.

Parameters:

Name Type Description Default
npixels int

The size to resize the wavefront to.

required

Returns:

Name Type Description
wavefront Wavefront

The resized wavefront.

Source code in src/dLux/wavefronts.py
def resize(self: Wavefront, npixels: int) -> Wavefront:
    """
    Resizes the wavefront via a zero-padding or cropping operation.

    Parameters
    ----------
    npixels : int
        The size to resize the wavefront to.

    Returns
    -------
    wavefront : Wavefront
        The resized wavefront.
    """
    field = self._to_field()
    amplitude, phase = vmap(dlu.resize, (0, None))(field, npixels)
    return self.set(["amplitude", "phase"], [amplitude, phase])
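Zero-padding and cropping about the array centre can be sketched as follows. `resize_centered` is a hypothetical helper written for this illustration. The behaviour of `dlu.resize` for odd size differences is not shown in this section and may differ.

```python
import numpy as np


def resize_centered(arr, npixels):
    """Zero-pad or crop a square 2D array about its centre (illustrative)."""
    n = arr.shape[0]
    if npixels >= n:
        before = (npixels - n) // 2
        after = npixels - n - before
        return np.pad(arr, ((before, after), (before, after)))
    start = (n - npixels) // 2
    return arr[start:start + npixels, start:start + npixels]


padded = resize_centered(np.ones((2, 2)), 4)                    # 2x2 -> 4x4
cropped = resize_centered(np.arange(16).reshape(4, 4), 2)       # 4x4 -> 2x2
```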

rotate(angle, order=1, complex=False)

Rotates the wavefront by a given angle via interpolation. Can be done on the real and imaginary components by passing in complex=True.

Parameters:

Name Type Description Default
angle (Array, radians)

The angle by which to rotate the wavefront in a clockwise direction.

required
order int = 1

The interpolation order to use.

1
complex bool = False

Whether to rotate the real and imaginary representation of the wavefront as opposed to the amplitude and phase representation.

False

Returns:

Name Type Description
wavefront Wavefront

The new wavefront rotated by angle in the clockwise direction.

Source code in src/dLux/wavefronts.py
def rotate(
    self: Wavefront, angle: Array, order: int = 1, complex: bool = False
) -> Wavefront:
    """
    Rotates the wavefront by a given angle via interpolation. Can be done on the
    real and imaginary components by passing in complex=True.

    Parameters
    ----------
    angle : Array, radians
        The angle by which to rotate the wavefront in a clockwise
        direction.
    order : int = 1
        The interpolation order to use.
    complex : bool = False
        Whether to rotate the real and imaginary representation of the wavefront as
        opposed to the amplitude and phase representation.

    Returns
    -------
    wavefront : Wavefront
        The new wavefront rotated by angle in the clockwise direction.
    """
    # Get field in either (amplitude, phase) or (real, imaginary)
    field = self._to_field(complex=complex)

    # Rotate the field
    rotator = vmap(dlu.rotate, (0, None, None))
    field = rotator(field, angle, order)

    # Cast back to (amplitude, phase) if needed
    if complex:
        field = self._to_amplitude_phase(field)

    # Return new wavefront
    return self.set(["amplitude", "phase"], [field[0], field[1]])
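The two representations selected by `complex` can give different results near a phase wrap. The toy example below uses a plain average of two samples in place of dLux's interpolation, so it only illustrates the principle.

```python
import numpy as np

# Two neighbouring samples with unit amplitude and phases either side of the wrap
amplitude = np.array([1.0, 1.0])
phase = np.array([3.0, -3.0])

# Averaging amplitude and phase separately gives a phase of zero
polar_mean = amplitude.mean() * np.exp(1j * phase.mean())

# Averaging the real and imaginary parts keeps the phasor close to -1
phasor = amplitude * np.exp(1j * phase)
complex_mean = phasor.real.mean() + 1j * phasor.imag.mean()
```

Here the amplitude and phase average points in the opposite direction to both inputs, while the real and imaginary average does not, which is one situation where `complex=True` may be the safer choice.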

scale_to(npixels, pixel_scale, complex=False)

Interpolates the wavefront to a given npixels and pixel_scale. Can be done on the real and imaginary components by passing in complex=True.

Parameters:

Name Type Description Default
npixels int

The number of pixels to interpolate to.

required
pixel_scale Array

The pixel scale to interpolate to.

required
complex bool = False

Whether to interpolate the real and imaginary representation of the wavefront as opposed to the amplitude and phase representation.

False

Returns:

Name Type Description
wavefront Wavefront

The new interpolated wavefront.

Source code in src/dLux/wavefronts.py
def scale_to(
    self: Wavefront,
    npixels: int,
    pixel_scale: Array,
    complex: bool = False,
) -> Wavefront:
    """
    Interpolates the wavefront to a given npixels and pixel_scale. Can be done on
    the real and imaginary components by passing in complex=True.

    Parameters
    ----------
    npixels : int
        The number of pixels to interpolate to.
    pixel_scale : Array
        The pixel scale to interpolate to.
    complex : bool = False
        Whether to interpolate the real and imaginary representation of the
        wavefront as opposed to the amplitude and phase representation.

    Returns
    -------
    wavefront : Wavefront
        The new interpolated wavefront.
    """
    # Get field in either (amplitude, phase) or (real, imaginary)
    field = self._to_field(complex=complex)

    # Scale the field
    scale_fn = vmap(dlu.scale, (0, None, None))
    field = scale_fn(field, npixels, pixel_scale / self.pixel_scale)

    # Cast back to (amplitude, phase) if needed
    if complex:
        field = self._to_amplitude_phase(field)

    # Return new wavefront
    return self.set(
        ["amplitude", "phase", "pixel_scale"],
        [field[0], field[1], pixel_scale],
    )

tilt(angles)

Tilts the wavefront by the (x, y) angles.

Parameters:

Name Type Description Default
angles (Array, radians)

The (x, y) angles by which to tilt the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

The tilted wavefront.

Source code in src/dLux/wavefronts.py
def tilt(self: Wavefront, angles: Array) -> Wavefront:
    """
    Tilts the wavefront by the (x, y) angles.

    Parameters
    ----------
    angles : Array, radians
        The (x, y) angles by which to tilt the wavefront.

    Returns
    -------
    wavefront : Wavefront
        The tilted wavefront.
    """
    if not isinstance(angles, Array) or angles.shape != (2,):
        raise ValueError("angles must be an array of shape (2,).")
    opd = -(angles[:, None, None] * self.coordinates).sum(0)
    return self.add_opd(opd)
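The broadcasting in the OPD calculation can be checked with NumPy. The sketch assumes `coordinates` has shape (2, npixels, npixels), ordered (x, y), and holds pixel-centre positions in meters. That property is not shown in this section, so the grid below is a stand-in.

```python
import numpy as np

npixels, diameter = 4, 1.0  # illustrative values
pixel_scale = diameter / npixels

# Assumed coordinate layout: pixel centres, symmetric about zero
xs = (np.arange(npixels) - (npixels - 1) / 2) * pixel_scale
coordinates = np.array(np.meshgrid(xs, xs))  # shape (2, 4, 4): x grid, then y grid

# Tilt of one microradian about x only
angles = np.array([1e-6, 0.0])

# angles[:, None, None] has shape (2, 1, 1) and broadcasts over both grids
opd = -(angles[:, None, None] * coordinates).sum(0)
```

The result is a plane that varies linearly along x and is constant along y, with zero mean because the grid is symmetric.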