
Wavefronts

Wavefront

dLux.wavefronts.Wavefront

Bases: Base

Holds the state of a wavefront as it is transformed and propagated through an optical system. All wavefronts assume square arrays.


Attributes:

wavelength (float, meters): The wavelength of the Wavefront.
phasor (Array[complex]): The electric field of the Wavefront.
pixel_scale (float, meters/pixel): The pixel scale of the phasor array.
center (Array): The centre coordinate of the wavefront grid.
diameter (Array, property): Derived from pixel_scale and npixels; the wavefront diameter.
npixels (int, property): Derived from phasor; the side length of the wavefront arrays.
real (Array, property): Derived from phasor; the real component of the electric field.
imaginary (Array, property): Derived from phasor; the imaginary component of the electric field.
amplitude (Array, property): Derived from phasor; the field amplitude abs(phasor).
phase (Array, property): Derived from phasor; the field phase angle.
complex (Array, property): Derived from phasor; the real and imaginary components stacked along a leading axis.
polar (Array, property): Derived from phasor; the amplitude and phase stacked along a leading axis.
psf (Array, property): Derived from phasor; the intensity image abs(phasor) ** 2.
wavenumber (Array, property): Derived from wavelength; the scalar 2 * pi / wavelength.
ndim (int, property): Derived from pixel_scale; the vectorisation rank of the wavefront state.
power (Array, property): Derived from amplitude; the total wavefront power, sum(amplitude ** 2).
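The derived properties above are simple functions of the stored state. Below is a minimal numpy sketch of them. Plain arrays stand in for the Wavefront fields, and the values are illustrative. The initial amplitude of 1 / npixels**2 follows `Wavefront.__init__`. One consequence is that a freshly constructed wavefront has a power of 1 / npixels**2 rather than 1.

```python
import numpy as np

# Illustrative stand-ins for the stored Wavefront state
wavelength = 1e-6      # meters
npixels = 8
pixel_scale = 0.125    # meters / pixel

# Uniform, flat-phase field with the same initial amplitude as Wavefront.__init__
amplitude0 = np.ones((npixels, npixels)) / npixels**2
phasor = amplitude0 * np.exp(1j * np.zeros((npixels, npixels)))

# Derived quantities, mirroring the property definitions
diameter = npixels * pixel_scale        # meters
amplitude = np.abs(phasor)
phase = np.angle(phasor)
psf = np.abs(phasor) ** 2               # intensity image
power = np.sum(psf)                     # discrete sum of |E|^2 over pixels
wavenumber = 2 * np.pi / wavelength     # 1 / meters
```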

Source code in dLux/wavefronts.py
class Wavefront(zdx.Base):
    """
    Holds the state of a wavefront as it is transformed and propagated
    through an optical system. All wavefronts assume square arrays.

    ??? abstract "UML"
        ![UML](../../assets/uml/Wavefront.png)

    Attributes
    ----------
    wavelength : float, meters
        The wavelength of the `Wavefront`.
    phasor : Array[complex]
        The electric field of the `Wavefront`.
    pixel_scale : float, meters/pixel
        The pixel scale of the phasor array.
    center : Array
        The centre coordinate of the wavefront grid.
    diameter : Array, property
        Derived property from `pixel_scale` and `npixels`; wavefront diameter.
    npixels : int, property
        Derived property from `phasor`; side length of wavefront arrays.
    real : Array, property
        Derived property from `phasor`; real component of the electric field.
    imaginary : Array, property
        Derived property from `phasor`; imaginary component of the electric field.
    amplitude : Array, property
        Derived property from `phasor`; field amplitude `abs(phasor)`.
    phase : Array, property
        Derived property from `phasor`; field phase angle.
    complex : Array, property
        Derived property from `phasor`; `(real, imaginary)` stacked on a leading axis.
    polar : Array, property
        Derived property from `phasor`; `(amplitude, phase)` stacked on a leading axis.
    psf : Array, property
        Derived property from `phasor`; intensity image `abs(phasor) ** 2`.
    wavenumber : Array, property
        Derived property from `wavelength`; scalar `2 * pi / wavelength`.
    ndim : int, property
        Derived property from `pixel_scale`; vectorisation rank of wavefront state.
    power : Array, property
        Derived property from `amplitude`; total wavefront power.
    """

    phasor: Array[complex]
    wavelength: float
    pixel_scale: float
    center: Array

    def __init__(
        self: Wavefront,
        wavelength: float,
        npixels: int,
        diameter: float = None,
        pixel_scale: float = None,
        center: Array = None,
    ):
        """
        Parameters
        ----------
        wavelength : float, meters
            The wavelength of the `Wavefront`.
        npixels : int
            The number of pixels that represent the `Wavefront`.
        diameter : float = None, meters
            The total diameter of the `Wavefront`. Either `diameter` or `pixel_scale`
            must be provided.
        pixel_scale : float = None, meters/pixel
            The pixel scale of the `Wavefront`. Either `diameter` or `pixel_scale`
            must be provided.
        center : Array = None
            The centre coordinate of the wavefront grid, in metres. Defaults to zero.
        """
        # Handle diameter vs pixel_scale
        if diameter is None and pixel_scale is None:
            raise ValueError("Either 'diameter' or 'pixel_scale' must be provided.")
        if diameter is not None and pixel_scale is not None:
            raise ValueError(
                "Cannot specify both 'diameter' and 'pixel_scale' - they are "
                "interdependent (diameter = pixel_scale × npixels). Choose one: "
                "use 'diameter' for wavefront diameter, or 'pixel_scale' for "
                "wavefront sampling."
            )

        self.wavelength = np.asarray(wavelength, float)
        if diameter is not None:
            self.pixel_scale = np.asarray(diameter / npixels, float)
        else:
            self.pixel_scale = np.asarray(pixel_scale, float)

        amplitude = np.ones((npixels, npixels), dtype=float) / npixels**2
        phase = np.zeros((npixels, npixels), dtype=float)
        self.phasor = amplitude * np.exp(1j * phase)

        if center is not None:
            self.center = np.asarray(center, float)

            # NOTE: only 1d offsets are presently supported
            if self.center.shape != (1,):
                raise ValueError("center must have shape (1,).")
        else:
            self.center = np.zeros(1, float)

    @classmethod
    def from_phasor(
        cls,
        phasor: Array[complex],
        wavelength: float,
        pixel_scale: float = None,
        diameter: float = None,
        center: Array = None,
    ) -> Wavefront:
        """
        Create a Wavefront from an existing phasor array.

        Parameters
        ----------
        phasor : Array[complex]
            The complex electric field array.
        wavelength : float, meters
            The wavelength of the wavefront.
        pixel_scale : float = None, meters/pixel
            The pixel scale of the phasor array. Either `pixel_scale` or
            `diameter` must be provided.
        diameter : float = None, meters
            The diameter of the phasor array. Either `pixel_scale` or
            `diameter` must be provided.
        center : Array = None
            The centre coordinate of the wavefront grid, in metres. Defaults to zero.

        Returns
        -------
        wavefront : Wavefront
            A new Wavefront object with the specified phasor.
        """
        # Infer npixels from phasor shape
        phasor_arr = np.asarray(phasor, complex)
        npixels = phasor_arr.shape[-1]

        # Create instance with appropriate parameters and set the phasor
        return cls(
            npixels=npixels,
            wavelength=wavelength,
            diameter=diameter,
            pixel_scale=pixel_scale,
            center=center,
        ).set(phasor=phasor_arr)

    @property
    def diameter(self: Wavefront) -> Array:
        """
        Returns the current wavefront diameter calculated using the pixel scale and
        number of pixels.

        Returns
        -------
        diameter : Array, meters or radians
            The current diameter of the wavefront.
        """
        return self.npixels * self.pixel_scale

    @property
    def npixels(self: Wavefront) -> int:
        """
        Returns the side length of the arrays currently representing the wavefront,
        taken from the last axis of the phasor array.

        Returns
        -------
        pixels : int
            The number of pixels that represent the `Wavefront`.
        """
        return self.phasor.shape[-1]

    @property
    def real(self: Wavefront) -> Array:
        """
        Returns the real component of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The real component of the `Wavefront` phasor.
        """
        return self.phasor.real

    @property
    def imaginary(self: Wavefront) -> Array:
        """
        Returns the imaginary component of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The imaginary component of the `Wavefront` phasor.
        """
        return self.phasor.imag

    @property
    def amplitude(self: Wavefront) -> Array:
        """
        Returns the amplitude component of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The amplitude component of the `Wavefront` phasor.
        """
        return np.abs(self.phasor)

    @property
    def phase(self: Wavefront) -> Array:
        """
        Returns the phase component of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The phase component of the `Wavefront` phasor.
        """
        return np.angle(self.phasor)

    @property
    def complex(self: Wavefront) -> Array:
        """
        Returns the real and imaginary components of the `Wavefront` phasor,
        stacked along a new leading axis.

        Returns
        -------
        wavefront : Array
            The stacked `(real, imaginary)` components of the `Wavefront` phasor.
        """
        return np.stack([self.phasor.real, self.phasor.imag], axis=0)

    @property
    def polar(self: Wavefront) -> Array:
        """
        Returns the polar representation (amplitude, phase) of the `Wavefront`.

        Returns
        -------
        wavefront : Array
            The polar representation of the `Wavefront` as a stack of amplitude and
            phase.
        """
        return np.stack([self.amplitude, self.phase], axis=0)

    @property
    def psf(self: Wavefront) -> Array:
        """
        Calculates the Point Spread Function (PSF), i.e. the squared modulus
        of the complex wavefront.

        Returns
        -------
        psf : Array
            The PSF of the wavefront.
        """
        return np.abs(self.phasor) ** 2

    def to_psf(self: Wavefront) -> PSF:
        """
        Converts the wavefront to a dLux PSF object.

        Returns
        -------
        psf : PSF
            A PSF object containing the current wavefront intensity and
            pixel scale.
        """
        return PSF(self.psf, self.pixel_scale)

    @property
    def wavenumber(self: Wavefront) -> Array:
        """
        Returns the wavenumber of the wavefront (2 * pi / wavelength).

        Returns
        -------
        wavenumber : Array, 1/meters
            The wavenumber of the wavefront.
        """
        return 2 * np.pi / self.wavelength

    @property
    def ndim(self: Wavefront) -> int:
        """
        Returns the number of 'dimensions' of the wavefront. This is used to track the
        vectorised version of the wavefront returned from vmapping.

        NOTE: May clash with future polarised wavefront.

        Returns
        -------
        ndim : int
            The number of vectorised dimensions of the wavefront (the rank of
            `pixel_scale`).
        """
        return self.pixel_scale.ndim

    @property
    def power(self: Wavefront) -> Array:
        """
        Returns the total power of the wavefront (sum of |E|^2 over pixels).

        Returns
        -------
        power : Array
            The total power of the wavefront.
        """
        return np.sum(np.abs(self.phasor) ** 2)

    def add_phase(self: Wavefront, phase: Array) -> Wavefront:
        """
        Applies a phase (in radians) to the wavefront by multiplying the phasor
        by exp(1j * phase). Supports broadcasting.

        Parameters
        ----------
        phase : Array, radians
            The phase to be added to the wavefront.

        Returns
        -------
        wavefront : Wavefront
            New wavefront whose phasor is self.phasor * exp(1j * phase).
        """
        if phase is None:
            return self
        return self.multiply("phasor", np.exp(1j * phase))

    def add_opd(self: Wavefront, opd: Array) -> Wavefront:
        """
        Applies an optical path difference (in meters) by multiplying the phasor
        by exp(1j * k * opd), where k = 2*pi / wavelength. Supports broadcasting.

        Parameters
        ----------
        opd : Array, meters
            The optical path difference to apply.

        Returns
        -------
        wavefront : Wavefront
            New wavefront with phasor multiplied by exp(1j * k * opd).
        """
        if opd is None:
            return self
        return self.add_phase(self.wavenumber * np.asarray(opd))

    def tilt(self: Wavefront, angles: Array, unit: str = "rad") -> Wavefront:
        """
        Tilts the wavefront by the (x, y) angles.

        Parameters
        ----------
        angles : Array
            The (x, y) angles by which to tilt the wavefront, in `unit`.
        unit : str
            The units of the angles, e.g. "rad", "deg", "arcmin", "arcsec", and
            prefixed forms like "mrad", "mas", etc. (as supported by `utils/units.py`).

        Returns
        -------
        wavefront : Wavefront
            The tilted wavefront.
        """
        angles = np.asarray(angles, dtype=float)
        if angles.shape != (2,):
            raise ValueError("angles must be a 1d array of shape (2,).")

        # Calculate scaled coordinates
        coords = self.coordinates(scale=dlu.unit_factor_to_rad(unit))

        # Tilt the wavefront
        return self.add_opd(np.sum(angles[:, None, None] * coords, axis=0))

    def normalise(
        self: Wavefront,
        mode: str = "power",
        value: float = 1.0,
    ) -> Wavefront:
        """
        Normalise the wavefront.

        Parameters
        ----------
        mode : {"power","peak"} = "power"
            - "power": scales so sum(|E|^2) == value (discrete sum over pixels).
            - "peak" : scales so max(|E|^2) == value.
        value : float = 1.0
            Target value for the selected mode.

        Returns
        -------
        wavefront : Wavefront
            New wavefront with phasor scaled to achieve the normalisation.
        """
        if mode == "power":
            scale = np.sqrt(value / self.power)
        elif mode == "peak":
            scale = np.sqrt(value / self.psf.max())
        else:
            raise ValueError("mode must be 'power' or 'peak'")
        return self.multiply("phasor", scale)

    def flip(self: Wavefront, axis: tuple[int] | int) -> Wavefront:
        """
        Flip the complex phasor along one or more axes (ij indexing: 0=y, 1=x).

        Parameters
        ----------
        axis : int or tuple of ints
            Axes to flip.

        Returns
        -------
        wavefront : Wavefront
            New wavefront with phasor flipped.
        """
        return self.set(phasor=np.flip(self.phasor, axis))

    def scale_to(
        self: Wavefront,
        npixels: int,
        pixel_scale: Array,
        complex: bool = True,
    ) -> Wavefront:
        """
        Interpolates the wavefront to a given npixels and pixel_scale. By default
        (complex=True) the real and imaginary components are interpolated; pass
        complex=False to interpolate the amplitude and phase instead.

        Parameters
        ----------
        npixels : int
            The number of pixels to interpolate to.
        pixel_scale : Array
            The pixel scale to interpolate to.
        complex : bool = True
            If True, interpolate the real and imaginary components. If False,
            interpolate the amplitude and phase components.

        Returns
        -------
        wavefront : Wavefront
            The new interpolated wavefront.
        """
        # Get field in either (amplitude, phase) or (real, imaginary)
        fields = self.complex if complex else self.polar

        # Scale the field
        scale_fn = vmap(dlu.scale, (0, None, None))
        fields = scale_fn(fields, npixels, pixel_scale / self.pixel_scale)

        # Convert back to complex form
        if complex:
            phasor = fields[0] + 1j * fields[1]
        else:
            phasor = fields[0] * np.exp(1j * fields[1])

        # Return new wavefront
        return self.set(phasor=phasor, pixel_scale=pixel_scale)

    def rotate(
        self: Wavefront,
        angle: Array,
        method: str = "linear",
        complex: bool = True,
    ) -> Wavefront:
        """
        Rotates the wavefront by a given angle via interpolation. By default
        (complex=True) the real and imaginary components are rotated; pass
        complex=False to rotate the amplitude and phase instead.

        Parameters
        ----------
        angle : Array, radians
            The angle by which to rotate the wavefront in a clockwise
            direction.
        method : str = "linear"
            The interpolation method.
        complex : bool = True
            If True, rotate the real and imaginary components. If False, rotate the
            amplitude and phase components.

        Returns
        -------
        wavefront : Wavefront
            The new wavefront rotated by angle in the clockwise direction.
        """
        # Get field in either (amplitude, phase) or (real, imaginary)
        fields = self.complex if complex else self.polar

        # Rotate the field
        rotator = vmap(dlu.rotate, (0, None, None))
        fields = rotator(fields, angle, method)

        # Convert back to complex form
        if complex:
            phasor = fields[0] + 1j * fields[1]
        else:
            phasor = fields[0] * np.exp(1j * fields[1])

        # Return new wavefront
        return self.set(phasor=phasor)

    def resize(self: Wavefront, npixels: int) -> Wavefront:
        """
        Resizes the wavefront via a zero-padding or cropping operation.

        Parameters
        ----------
        npixels : int
            The size to resize the wavefront to.

        Returns
        -------
        wavefront : Wavefront
            The resized wavefront.
        """
        return self.set(phasor=dlu.resize(self.phasor, npixels, 0j))

    def coordinates(
        self: Wavefront,
        scale=1.0,
        polar: bool = False,
    ) -> Array:
        """
        Returns the positions of the wavefront pixel centres in the units of
        `pixel_scale` (meters, or radians in an angular plane), optionally
        multiplied by a scaling factor for numerical stability.

        Parameters
        ----------
        scale : float = 1.0
            Optional factor by which the coordinates are multiplied.
        polar : bool = False
            Output the coordinates in polar (r, phi) coordinates.

        Returns
        -------
        coordinates : Array
            The coordinates of the centers of each pixel representing the wavefront.
        """
        xs = self.xs * scale
        coords = np.array(np.meshgrid(xs, xs))
        if polar:
            return dlu.cart2polar(coords)
        return coords

    @property
    def spec(self):
        """
        Returns the current wavefront sampling as a `CoordSpec`.

        Returns
        -------
        spec : CoordSpec
            Coordinate specification with `n`, `d`, and `c` set from the
            current wavefront state.
        """
        return CoordSpec(self.npixels, self.pixel_scale, self.center)

    @property
    def xs(self):
        """
        1D array of pixel centre coordinates along one axis.

        Returns
        -------
        xs : Array
            Coordinates of pixel centres, in metres.
        """
        return self.spec.xs

    def set_spec(self, spec: CoordSpec):
        """
        Updates the wavefront pixel scale and centre from a `CoordSpec`.

        Parameters
        ----------
        spec : CoordSpec
            The coordinate specification to apply.

        Returns
        -------
        wavefront : Wavefront
            New wavefront with updated `pixel_scale` and `center`.
        """
        return self.set(pixel_scale=spec.d, center=spec.c)

    def propagate_FFT(
        self,
        pad=2,
        focal_length=None,
        spec_out: CoordSpec = None,
        inverse=False,
    ):
        """
        Propagates the wavefront using an FFT-based method.

        Parameters
        ----------
        pad : int = 2
            Zero-padding factor applied before the FFT.
        focal_length : float | None = None
            Focal length for Cartesian focal sampling. Pass `None` for
            angular (far-field) sampling.
        spec_out : CoordSpec | None = None
            Output coordinate specification. If provided, only `c` (centre)
            may be set; `n` and `d` are determined by the propagation.
        inverse : bool = False
            If False, propagate forward through the system. If True, propagate
            backward through the system.

        Returns
        -------
        wavefront : Wavefront
            Propagated wavefront with updated phasor and sampling metadata.
        """
        # Input spec
        spec_in = self.spec
        wl = self.wavelength

        # FFT output size, pixel scale and default centre
        n_out = spec_in.n * pad
        d_fft, c_fft = dlu.fft_spec(n_out, spec_in.d, wl, focal_length)

        # Get the phase ramp and the output center
        if spec_out is not None:
            if spec_out.d is not None:
                raise ValueError("Output spec cannot specify d; FFT output d is fixed.")
            if spec_out.n is not None:
                raise ValueError(
                    "Output spec cannot specify n; FFT output n is determined by the "
                    "pad parameter."
                )

            # Calculate the input phase ramp for the FFT propagation
            shift = c_fft - spec_out.c
            in_ramp = dlu.fft_phase_ramp(spec_in.xs, wl, shift, focal_length, inverse)

            # Calculate the output phase ramp correction
            spec_out = spec_out.set(n=n_out, d=d_fft)
            shift = dlu.fft_spec(spec_out.n, spec_out.d, wl, focal_length)[1]
            out_ramp = dlu.fft_phase_ramp(spec_out.xs, wl, shift, focal_length, inverse)

        else:
            in_ramp, out_ramp = 1.0, 1.0
            spec_out = CoordSpec(n=n_out, d=d_fft, c=c_fft)

        # Apply ramp and FFT
        phasor, pixel_scale = dlu.FFT(
            phasor=self.phasor * in_ramp,
            wavelength=self.wavelength,
            pixel_scale=self.pixel_scale,
            focal_length=focal_length,
            inverse=inverse,
            pad=pad,
        )

        # Update the values
        return self.set(
            phasor=phasor * out_ramp, pixel_scale=pixel_scale, center=spec_out.c
        )

    def propagate(
        self: Wavefront,
        npixels: int,
        pixel_scale: float,
        focal_length: float = None,
        inverse: bool = False,
    ) -> Wavefront:
        """
        Legacy MFT propagation function without CoordSpec.

        Parameters
        ----------
        npixels : int
            Output array size (square).
        pixel_scale : float
            Desired output pixel scale (meters/pixel or radians/pixel depending on
            units).
        focal_length : float | None
            Focal length for Cartesian focal sampling; None for angular focal sampling.
        inverse : bool = False
            If False, propagate forward through the system. If True, propagate
            backward through the system.

        Returns
        -------
        wavefront : Wavefront
            Propagated wavefront with new phasor and sampling metadata.

        Notes
        -----
        - Ideal for generating PSFs at arbitrary sampling.
        - For broadband propagation, vmap this function over wavelength and pixel_scale.
        """
        # Propagate
        phasor = dlu.MFT(
            phasor=self.phasor,
            wavelength=self.wavelength,
            pixel_scale_in=self.pixel_scale,
            npixels_out=npixels,
            pixel_scale_out=pixel_scale,
            focal_length=focal_length,
            inverse=inverse,
        )
        return self.set(phasor=phasor, pixel_scale=np.array(pixel_scale, float))

    def propagate_MFT(self, spec_out, focal_length=None, inverse=None):
        """
        Propagates the wavefront using an MFT-based method with a `CoordSpec`.

        Parameters
        ----------
        spec_out : CoordSpec
            Output coordinate specification defining the number of pixels
            and pixel scale of the propagated field.
        focal_length : float | None = None
            Focal length for Cartesian focal sampling. Pass `None` for
            angular (far-field) sampling.
        inverse : bool | None = None
            If False or None, propagate forward through the system. If True,
            propagate backward through the system.

        Returns
        -------
        wavefront : Wavefront
            Propagated wavefront with updated phasor and pixel scale.
        """
        # Propagate
        phasor = dlu.MFT(
            phasor=self.phasor,
            wavelength=self.wavelength,
            pixel_scale_in=self.pixel_scale,
            npixels_out=spec_out.n,
            pixel_scale_out=spec_out.d,
            focal_length=focal_length,
            inverse=inverse,
        )
        return self.set(phasor=phasor, pixel_scale=np.array(spec_out.d, float))

    #######################
    ### New Propagators ###
    #######################
    def propagate_ASM(self):
        """Angular spectrum free-space propagation"""
        raise NotImplementedError()

    def propagate_fresnel(self):
        """LCT-based MFT Fresnel propagation"""
        raise NotImplementedError()

    def propagate_fresnel_fft(self):
        """LCT-based FFT Fresnel propagation"""
        raise NotImplementedError()

    def propagate_fraunhofer(self):
        """
        Fraunhofer propagation via MFT (same as `propagate_MFT`, but with the abcdLux backend)
        """
        raise NotImplementedError()

    def propagate_fraunhofer_fft(self):
        """
        Fraunhofer propagation via FFT (same as `propagate_FFT`, but with the abcdLux backend)
        """
        raise NotImplementedError()

    def _magic_unified_op(
        self: Wavefront, other: Wavefront | Array | None, op: str
    ) -> Wavefront:
        """
        Internal helper function to unify the logic of the magic methods for addition,
        subtraction, multiplication and division.

        Parameters
        ----------
        other : Wavefront | Array | None
            The object to operate with: a complex array, a scalar, a Wavefront, or None.
        op : str
            The operation to perform: 'add', 'subtract', 'multiply', or 'divide'.

        Returns
        -------
        wavefront : Wavefront
            The resulting wavefront after applying the operation.
        """
        # Nones always return unchanged
        if other is None:
            return self

        # Check for supported types
        if not isinstance(other, (Wavefront, Array, float, int, complex)):
            raise TypeError(
                f"Unsupported type for {op}: {type(other)}. Must be an array, "
                "scalar, Wavefront, or None."
            )

        # Extract phasor if other is a Wavefront
        if isinstance(other, Wavefront):
            other = other.phasor

        # Apply the operation
        if op == "add":
            return self.add("phasor", other)
        elif op == "subtract":
            return self.add("phasor", -other)
        elif op == "multiply":
            return self.multiply("phasor", other)
        elif op == "divide":
            return self.multiply("phasor", 1 / other)
        else:
            raise ValueError(f"Unsupported operation '{op}'.")

    def __add__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """
        Allows complex phasors or Wavefront objects to be added together. None values
        are ignored.
        """
        return self._magic_unified_op(other, "add")

    def __sub__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """
        Allows complex phasors or Wavefront objects to be subtracted. None values are
        ignored.
        """
        return self._magic_unified_op(other, "subtract")

    def __mul__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """
        Allows complex phasors or Wavefront objects to be multiplied. None values are
        ignored.
        """
        return self._magic_unified_op(other, "multiply")

    def __truediv__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """
        Allows complex phasors or Wavefront objects to be divided. None values are
        ignored.
        """
        return self._magic_unified_op(other, "divide")

    def __iadd__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """In-place addition."""
        return self.__add__(other)

    def __isub__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """In-place subtraction."""
        return self.__sub__(other)

    def __imul__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """In-place multiplication."""
        return self.__mul__(other)

    def __itruediv__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
        """In-place division."""
        return self.__truediv__(other)
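
All of the arithmetic overloads above route through `_magic_unified_op`: `None` is returned unchanged, a `Wavefront` operand is unwrapped to its phasor, and subtraction and division are rewritten as adding `-other` and multiplying by `1 / other`. The sketch below reproduces that dispatch on a hypothetical frozen `ToyWavefront` dataclass (not part of dLux), using plain NumPy arrays instead of JAX ones. It also shows that, for an immutable object, an augmented assignment such as `b *= 2.0` only rebinds the name `b`; any other reference to the original keeps its old phasor.

```python
from dataclasses import dataclass, replace

import numpy as np


@dataclass(frozen=True)
class ToyWavefront:
    """Hypothetical immutable stand-in for a wavefront: just a phasor."""

    phasor: np.ndarray

    def _unified_op(self, other, op: str) -> "ToyWavefront":
        # None is a no-op, mirroring the documented behaviour
        if other is None:
            return self
        # Unwrap another wavefront to its complex phasor
        if isinstance(other, ToyWavefront):
            other = other.phasor
        if op == "add":
            new = self.phasor + other
        elif op == "subtract":
            new = self.phasor + (-other)
        elif op == "multiply":
            new = self.phasor * other
        elif op == "divide":
            new = self.phasor * (1 / other)
        else:
            raise ValueError(f"Unsupported operation '{op}'.")
        # Always return a new object; the original is never mutated
        return replace(self, phasor=new)

    def __add__(self, other):
        return self._unified_op(other, "add")

    def __sub__(self, other):
        return self._unified_op(other, "subtract")

    def __mul__(self, other):
        return self._unified_op(other, "multiply")

    def __truediv__(self, other):
        return self._unified_op(other, "divide")


a = ToyWavefront(np.ones((4, 4), dtype=complex))
b = a
b *= 2.0  # no __imul__ here, so Python falls back to __mul__ and rebinds `b`
```

dLux defines `__iadd__` and its siblings explicitly, but since each one simply returns the result of the corresponding binary operator, the observable behaviour matches the fallback shown here.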

amplitude property

Returns the amplitude component of the Wavefront.

Returns:

Name Type Description
wavefront Array

The amplitude component of the Wavefront phasor.

complex property

Returns the Cartesian (real, imaginary) representation of the Wavefront phasor.

Returns:

Name Type Description
wavefront Array

The real and imaginary components of the Wavefront phasor, stacked along a leading axis.

diameter property

Returns the current wavefront diameter calculated using the pixel scale and number of pixels.

Returns:

Name Type Description
diameter (Array, meters or radians)

The current diameter of the wavefront.

imaginary property

Returns the imaginary component of the Wavefront.

Returns:

Name Type Description
wavefront Array

The imaginary component of the Wavefront phasor.

ndim property

Returns the number of vectorised dimensions of the wavefront, derived from the shape of its pixel_scale. This is used to keep track of wavefronts that have been batched with vmap.

NOTE: This may clash with a future polarised wavefront.

Returns:

Name Type Description
ndim int

The number of vectorised dimensions of the wavefront.

npixels property

Returns the side length of the arrays currently representing the wavefront. Taken from the last axis of the amplitude array.

Returns:

Name Type Description
pixels int

The number of pixels that represent the Wavefront.

phase property

Returns the phase component of the Wavefront.

Returns:

Name Type Description
wavefront Array

The phase component of the Wavefront phasor.

polar property

Returns the polar representation (amplitude, phase) of the Wavefront.

Returns:

Name Type Description
wavefront Array

The polar representation of the Wavefront as a stack of amplitude and phase.

power property

Returns the total power of the wavefront (sum of |E|^2 over pixels).

Returns:

Name Type Description
power Array

The total power of the wavefront.

psf property

Calculates the Point Spread Function (PSF), i.e. the squared modulus of the complex wavefront.

Returns:

Name Type Description
psf Array

The PSF of the wavefront.
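
Each of the field properties above is a direct function of the complex phasor, and wavenumber depends only on the wavelength. The following NumPy sketch writes those relationships out on a random field. It assumes, as the descriptions of complex and polar suggest, that both are returned as arrays stacked along a leading axis of length 2.

```python
import numpy as np

rng = np.random.default_rng(0)
wavelength = 1e-6  # metres
phasor = rng.normal(size=(8, 8)) + 1j * rng.normal(size=(8, 8))

amplitude = np.abs(phasor)                    # field amplitude
phase = np.angle(phasor)                      # field phase, radians
real, imaginary = phasor.real, phasor.imag
complex_stack = np.stack([real, imaginary])   # assumed layout: (2, n, n)
polar_stack = np.stack([amplitude, phase])    # assumed layout: (2, n, n)
psf = np.abs(phasor) ** 2                     # intensity per pixel
power = psf.sum()                             # discrete sum over pixels
wavenumber = 2 * np.pi / wavelength           # 1 / metres

# The polar stack recombines into the original phasor
rebuilt = polar_stack[0] * np.exp(1j * polar_stack[1])
```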

real property

Returns the real component of the Wavefront.

Returns:

Name Type Description
wavefront Array

The real component of the Wavefront phasor.

spec property

Returns the current wavefront sampling as a CoordSpec.

Returns:

Name Type Description
spec CoordSpec

Coordinate specification with n (number of pixels), d (pixel scale), and c (centre) set from the current wavefront state.

wavenumber property

Returns the wavenumber of the wavefront (2 * pi / wavelength).

Returns:

Name Type Description
wavenumber (Array, 1 / meters)

The wavenumber of the wavefront.

xs property

1D array of pixel centre coordinates along one axis.

Returns:

Name Type Description
xs Array

Coordinates of pixel centres, in metres.
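
The sampling properties (xs, diameter, spec) are tied together by the pixel scale. The sketch below uses one common convention: pixel-centred coordinates symmetric about the grid centre, then shifted by center. The hypothetical pixel_centres helper is an assumption about how xs is built, not the dLux source.

```python
import numpy as np


def pixel_centres(npixels: int, pixel_scale: float, center: float = 0.0) -> np.ndarray:
    """Hypothetical helper: centres of a uniform grid, symmetric about `center`."""
    return (np.arange(npixels) - (npixels - 1) / 2) * pixel_scale + center


npixels, pixel_scale = 4, 0.5  # four pixels of 0.5 m each
xs = pixel_centres(npixels, pixel_scale)
diameter = npixels * pixel_scale  # consistent with pixel_scale = diameter / npixels
```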

__add__(other)

Allows complex phasors or Wavefront objects to be added together. None values are ignored.

Source code in dLux/wavefronts.py
def __add__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """
    Allows complex phasors or Wavefront objects to be added together. None values
    are ignored.
    """
    return self._magic_unified_op(other, "add")

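__iadd__(other)

Augmented addition (wf += other). Because Wavefront objects are immutable, this returns a new Wavefront rather than modifying the original in place.

Source code in dLux/wavefronts.py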
def __iadd__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """In-place addition."""
    return self.__add__(other)

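__imul__(other)

Augmented multiplication (wf *= other). Because Wavefront objects are immutable, this returns a new Wavefront rather than modifying the original in place.

Source code in dLux/wavefronts.py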
def __imul__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """In-place multiplication."""
    return self.__mul__(other)

__init__(wavelength, npixels, diameter=None, pixel_scale=None, center=None)

Parameters:

Name Type Description Default
wavelength (float, meters)

The wavelength of the Wavefront.

required
npixels int

The number of pixels that represent the Wavefront.

required
diameter float = None, meters

The total diameter of the Wavefront. Exactly one of diameter or pixel_scale must be provided.

None
pixel_scale float = None, meters/pixel

The pixel scale of the Wavefront. Exactly one of diameter or pixel_scale must be provided.

None
center Array = None

The centre coordinate of the wavefront grid, in metres. Defaults to zero.

None
Source code in dLux/wavefronts.py
def __init__(
    self: Wavefront,
    wavelength: float,
    npixels: int,
    diameter: float = None,
    pixel_scale: float = None,
    center: Array = None,
):
    """
    Parameters
    ----------
    wavelength : float, meters
        The wavelength of the `Wavefront`.
    npixels : int
        The number of pixels that represent the `Wavefront`.
    diameter : float = None, meters
        The total diameter of the `Wavefront`. Either `diameter` or `pixel_scale`
        must be provided.
    pixel_scale : float = None, meters/pixel
        The pixel scale of the `Wavefront`. Either `diameter` or `pixel_scale`
        must be provided.
    center : Array = None
        The centre coordinate of the wavefront grid, in metres. Defaults to zero.
    """
    # Handle diameter vs pixel_scale
    if diameter is None and pixel_scale is None:
        raise ValueError("Provide one: diameter or pixel_scale.")
    if diameter is not None and pixel_scale is not None:
        raise ValueError(
            "Cannot specify both 'diameter' and 'pixel_scale' - they are "
            "interdependent (diameter = pixel_scale × npixels). Choose one: "
            "use 'diameter' for wavefront diameter, or 'pixel_scale' for "
            "wavefront sampling."
        )

    self.wavelength = np.asarray(wavelength, float)
    if diameter is not None:
        self.pixel_scale = np.asarray(diameter / npixels, float)
    else:
        self.pixel_scale = np.asarray(pixel_scale, float)

    amplitude = np.ones((npixels, npixels), dtype=float) / npixels**2
    phase = np.zeros((npixels, npixels), dtype=float)
    self.phasor = amplitude * np.exp(1j * phase)

    if center is not None:
        self.center = np.asarray(center, float)

        # NOTE: only 1d offsets are presently supported
        if self.center.shape != (1,):
            raise ValueError("center must have shape (1,).")
    else:
        self.center = np.zeros(1, float)
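
The constructor's sampling logic can be restated in a few lines of NumPy. The hypothetical init_state function below mirrors the diameter/pixel_scale checks and the initial phasor shown above; in dLux itself the equivalent call would be Wavefront(wavelength=1e-6, npixels=256, diameter=1.0).

```python
import numpy as np


def init_state(npixels, diameter=None, pixel_scale=None):
    """Hypothetical mirror of the constructor's sampling and phasor set-up."""
    # Exactly one of diameter / pixel_scale must be given
    if (diameter is None) == (pixel_scale is None):
        raise ValueError("Provide exactly one of diameter or pixel_scale.")
    if diameter is not None:
        pixel_scale = diameter / npixels
    # Uniform amplitude 1 / npixels**2 and zero phase, as in __init__
    amplitude = np.ones((npixels, npixels)) / npixels**2
    phasor = amplitude * np.exp(1j * np.zeros((npixels, npixels)))
    return float(pixel_scale), phasor


pixel_scale, phasor = init_state(256, diameter=1.0)
```

With this starting amplitude the total power is 1 / npixels**2 rather than 1, so a call such as normalise() is needed if unit power is wanted.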

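__isub__(other)

Augmented subtraction (wf -= other). Because Wavefront objects are immutable, this returns a new Wavefront rather than modifying the original in place.

Source code in dLux/wavefronts.py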
def __isub__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """In-place subtraction."""
    return self.__sub__(other)

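__itruediv__(other)

Augmented division (wf /= other). Because Wavefront objects are immutable, this returns a new Wavefront rather than modifying the original in place.

Source code in dLux/wavefronts.py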
def __itruediv__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """In-place division."""
    return self.__truediv__(other)

__mul__(other)

Allows complex phasors or Wavefront objects to be multiplied. None values are ignored.

Source code in dLux/wavefronts.py
def __mul__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """
    Allows complex phasors or Wavefront objects to be multiplied. None values are
    ignored.
    """
    return self._magic_unified_op(other, "multiply")

__sub__(other)

Allows complex phasors or Wavefront objects to be subtracted. None values are ignored.

Source code in dLux/wavefronts.py
def __sub__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """
    Allows complex phasors or Wavefront objects to be subtracted. None values are
    ignored.
    """
    return self._magic_unified_op(other, "subtract")

__truediv__(other)

Allows complex phasors or Wavefront objects to be divided. None values are ignored.

Source code in dLux/wavefronts.py
def __truediv__(self: Wavefront, other: Wavefront | Array | None) -> Wavefront:
    """
    Allows complex phasors or Wavefront objects to be divided. None values are
    ignored.
    """
    return self._magic_unified_op(other, "divide")

add_opd(opd)

Applies an optical path difference (in meters) by multiplying the phasor by exp(1j * k * opd), where k = 2*pi / wavelength. Supports broadcasting.

Parameters:

Name Type Description Default
opd (Array, meters)

The optical path difference to apply.

required

Returns:

Name Type Description
wavefront Wavefront

New wavefront with phasor multiplied by exp(1j * k * opd).

Source code in dLux/wavefronts.py
def add_opd(self: Wavefront, opd: Array) -> Wavefront:
    """
    Applies an optical path difference (in meters) by multiplying the phasor
    by exp(1j * k * opd), where k = 2*pi / wavelength. Supports broadcasting.

    Parameters
    ----------
    opd : Array, meters
        The optical path difference to apply.

    Returns
    -------
    wavefront : Wavefront
        New wavefront with phasor multiplied by exp(1j * k * opd).
    """
    if opd is None:
        return self
    return self.add_phase(self.wavenumber * np.asarray(opd))
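
add_opd converts an optical path difference to a phase via the wavenumber k = 2π / λ and then defers to add_phase. A NumPy sketch of the same arithmetic, showing that an OPD of λ/4 corresponds to a phase of π/2 and that a 1D OPD broadcasts across the rows of the field:

```python
import numpy as np

wavelength = 500e-9                 # metres
k = 2 * np.pi / wavelength          # wavenumber, 1/m
phasor = np.ones((4, 4), dtype=complex)

# A uniform quarter-wave OPD is a phase of pi/2 radians
opd = np.full((4, 4), wavelength / 4)
shifted = phasor * np.exp(1j * k * opd)

# Broadcasting: a length-4 OPD applies one value per column, on every row
column_opd = np.linspace(0, wavelength / 2, 4)
ramped = phasor * np.exp(1j * k * column_opd)
```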

add_phase(phase)

Applies a phase (in radians) to the wavefront by multiplying the phasor by exp(1j * phase). Supports broadcasting.

Parameters:

Name Type Description Default
phase (Array, radians)

The phase to be added to the wavefront.

required

Returns:

Name Type Description
wavefront Wavefront

New wavefront whose phasor is self.phasor * exp(1j * phase).

Source code in dLux/wavefronts.py
def add_phase(self: Wavefront, phase: Array) -> Wavefront:
    """
    Applies a phase (in radians) to the wavefront by multiplying the phasor
    by exp(1j * phase). Supports broadcasting.

    Parameters
    ----------
    phase : Array, radians
        The phase to be added to the wavefront.

    Returns
    -------
    wavefront : Wavefront
        New wavefront whose phasor is self.phasor * exp(1j * phase).
    """
    if phase is None:
        return self
    return self.multiply("phasor", np.exp(1j * phase))

coordinates(scale=1.0, polar=False)

Returns the physical positions of the wavefront pixel centres in meters, optionally multiplied by a scaling factor.

Parameters:

Name Type Description Default
scale float = 1.0

Optional factor by which the coordinates are multiplied, e.g. for numerical stability or unit conversion.

1.0
polar bool = False

If True, output the coordinates in polar (r, phi) form.

False

Returns:

Name Type Description
coordinates Array

The coordinates of the centre of each pixel, stacked as (x, y), or as (r, phi) when polar is True.

Source code in dLux/wavefronts.py
def coordinates(
    self: Wavefront,
    scale=1.0,
    polar: bool = False,
) -> Array:
    """
    Returns the physical positions of the wavefront pixels in meters, with an
    optional scaling factor for numerical stability.

    Parameters
    ----------
    scale : float = 1.0
        Optional scaling factor applied to the diameter for numerical stability.
    polar : bool = False
        Output the coordinates in polar (r, phi) coordinates.

    Returns
    -------
    coordinates : Array
        The coordinates of the centers of each pixel representing the wavefront.
    """
    xs = self.xs * scale
    coords = np.array(np.meshgrid(xs, xs))
    if polar:
        return dlu.cart2polar(coords)
    return coords
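
The coordinate grid is a meshgrid of the 1D pixel-centre axis, so coords[0] holds x (varying along columns) and coords[1] holds y (varying along rows), consistent with the ij indexing noted under flip. The polar branch below assumes dlu.cart2polar returns r = hypot(x, y) and phi = arctan2(y, x); the original only states the (r, phi) order, and pixel_coordinates is a hypothetical name.

```python
import numpy as np


def pixel_coordinates(xs, polar=False):
    """Hypothetical sketch: 2D pixel coordinates from a 1D axis of centres."""
    coords = np.array(np.meshgrid(xs, xs))  # (2, n, n): coords[0] = x, coords[1] = y
    if polar:
        x, y = coords
        return np.array([np.hypot(x, y), np.arctan2(y, x)])  # (r, phi)
    return coords


xs = np.array([-1.0, 0.0, 1.0])
cart = pixel_coordinates(xs)
pol = pixel_coordinates(xs, polar=True)
```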

flip(axis)

Flip the complex phasor along one or more axes (ij indexing: 0=y, 1=x).

Parameters:

Name Type Description Default
axis int or tuple of ints

Axes to flip.

required

Returns:

Name Type Description
wavefront Wavefront

New wavefront with phasor flipped.

Source code in dLux/wavefronts.py
def flip(self: Wavefront, axis: tuple[int] | int) -> Wavefront:
    """
    Flip the complex phasor along one or more axes (ij indexing: 0=y, 1=x).

    Parameters
    ----------
    axis : int or tuple of ints
        Axes to flip.

    Returns
    -------
    wavefront : Wavefront
        New wavefront with phasor flipped.
    """
    return self.set(phasor=np.flip(self.phasor, axis))

from_phasor(phasor, wavelength, pixel_scale=None, diameter=None, center=None) classmethod

Create a Wavefront from an existing phasor array. The number of pixels is inferred from the last axis of phasor.

Parameters:

Name Type Description Default
phasor Array[complex]

The complex electric field array.

required
wavelength (float, meters)

The wavelength of the wavefront.

required
pixel_scale float = None, meters/pixel

The pixel scale of the phasor array. Exactly one of pixel_scale or diameter must be provided.

None
diameter float = None, meters

The diameter of the phasor array. Exactly one of pixel_scale or diameter must be provided.

None
center Array = None

The centre coordinate of the wavefront grid, in metres. Defaults to zero.

None

Returns:

Name Type Description
wavefront Wavefront

A new Wavefront object with the specified phasor.

Source code in dLux/wavefronts.py
@classmethod
def from_phasor(
    cls,
    phasor: Array[complex],
    wavelength: float,
    pixel_scale: float = None,
    diameter: float = None,
    center: Array = None,
) -> Wavefront:
    """
    Create a Wavefront from an existing phasor array.

    Parameters
    ----------
    phasor : Array[complex]
        The complex electric field array.
    wavelength : float, meters
        The wavelength of the wavefront.
    pixel_scale : float = None, meters/pixel
        The pixel scale of the phasor array. Either `pixel_scale` or
        `diameter` must be provided.
    diameter : float = None, meters
        The diameter of the phasor array. Either `pixel_scale` or
        `diameter` must be provided.
    center : Array = None
        The centre coordinate of the wavefront grid, in metres. Defaults to zero.

    Returns
    -------
    wavefront : Wavefront
        A new Wavefront object with the specified phasor.
    """
    # Infer npixels from phasor shape
    phasor_arr = np.asarray(phasor, complex)
    npixels = phasor_arr.shape[-1]

    # Create instance with appropriate parameters and set the phasor
    return cls(
        npixels=npixels,
        wavelength=wavelength,
        diameter=diameter,
        pixel_scale=pixel_scale,
        center=center,
    ).set(phasor=phasor_arr)

normalise(mode='power', value=1.0)

Normalise the wavefront by rescaling its phasor.

Parameters:

Name Type Description Default
mode {"power", "peak"}

  • "power": scales the phasor so that sum(|E|^2) == value (discrete sum over pixels).
  • "peak": scales the phasor so that max(|E|^2) == value.

'power'
value float = 1.0

Target value for the selected mode.

1.0

Returns:

Name Type Description
wavefront Wavefront

New wavefront with phasor scaled to achieve the normalisation.

Source code in dLux/wavefronts.py
def normalise(
    self: Wavefront,
    mode: str = "power",
    value: float = 1.0,
) -> Wavefront:
    """
    Normalise the wavefront.

    Parameters
    ----------
    mode : {"power","peak"} = "power"
        - "power": scales so sum(|E|^2) == value (discrete sum over pixels).
        - "peak" : scales so max(|E|^2) == value.
    value : float = 1.0
        Target value for the selected mode.

    Returns
    -------
    wavefront : Wavefront
        New wavefront with phasor scaled to achieve the normalisation.
    """
    if mode == "power":
        scale = np.sqrt(value / self.power)
    elif mode == "peak":
        scale = np.sqrt(value / self.psf.max())
    else:
        raise ValueError("mode must be 'power' or 'peak'")
    return self.multiply("phasor", scale)
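
Because intensity scales with the square of the field, both modes multiply the phasor by the square root of the required intensity ratio. A NumPy sketch of the same logic applied to a random field:

```python
import numpy as np


def normalise(phasor, mode="power", value=1.0):
    """Sketch of power / peak normalisation of a complex field."""
    psf = np.abs(phasor) ** 2
    if mode == "power":
        scale = np.sqrt(value / psf.sum())   # total intensity -> value
    elif mode == "peak":
        scale = np.sqrt(value / psf.max())   # brightest pixel -> value
    else:
        raise ValueError("mode must be 'power' or 'peak'")
    return phasor * scale


rng = np.random.default_rng(1)
field = rng.normal(size=(16, 16)) + 1j * rng.normal(size=(16, 16))
unit_power = normalise(field)
unit_peak = normalise(field, mode="peak")
```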

propagate(npixels, pixel_scale, focal_length=None, inverse=False)

Legacy MFT propagation that takes the output npixels and pixel_scale directly instead of a CoordSpec.

Parameters:

Name Type Description Default
npixels int

Output array size (square).

required
pixel_scale float

Desired output pixel scale: meters/pixel when focal_length is given, radians/pixel otherwise.

required
focal_length float | None

Focal length for Cartesian focal sampling; None for angular focal sampling.

None
inverse bool = False

If False, propagate forward through the system. If True, propagate backward through the system.

False

Returns:

Name Type Description
wavefront Wavefront

Propagated wavefront with new phasor and sampling metadata.

Notes
  • Ideal for generating PSFs at arbitrary sampling.
  • For broadband propagation, vmap this function over wavelength and pixel_scale.
Source code in dLux/wavefronts.py
def propagate(
    self: Wavefront,
    npixels: int,
    pixel_scale: float,
    focal_length: float = None,
    inverse: bool = False,
) -> Wavefront:
    """
    Legacy MFT propagation function without CoordSpec.

    Parameters
    ----------
    npixels : int
        Output array size (square).
    pixel_scale : float
        Desired output pixel scale (meters/pixel or radians/pixel depending on
        units).
    focal_length : float | None
        Focal length for Cartesian focal sampling; None for angular focal sampling.
    inverse : bool = False
        If False, propagate forward through the system. If True, propagate
        backward through the system.

    Returns
    -------
    wavefront : Wavefront
        Propagated wavefront with new phasor and sampling metadata.

    Notes
    -----
    - Ideal for generating PSFs at arbitrary sampling.
    - For broadband propagation, vmap this function over wavelength and pixel_scale.
    """
    # Propagate
    phasor = dlu.MFT(
        phasor=self.phasor,
        wavelength=self.wavelength,
        pixel_scale_in=self.pixel_scale,
        npixels_out=npixels,
        pixel_scale_out=pixel_scale,
        focal_length=focal_length,
        inverse=inverse,
    )
    return self.set(phasor=phasor, pixel_scale=np.array(pixel_scale, float))
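
An MFT evaluates the discrete Fourier sum on an arbitrary output grid as two small matrix products, one per axis. The sketch below is a plain-NumPy illustration rather than dlu.MFT itself: it assumes pixel-centred grids symmetric about zero and a normalisation of d_in * d_out / (wavelength * f), with f = 1 in the angular case. dLux's own centring and normalisation may differ. With that choice, sampling the output at the FFT's native resolution (d_out = wavelength * f / (n * d_in)) makes the transform unitary, so power is conserved and the inverse recovers the input.

```python
import numpy as np


def _centred(n, d):
    # Pixel-centred coordinates symmetric about zero
    return (np.arange(n) - (n - 1) / 2) * d


def mft(phasor, wavelength, d_in, n_out, d_out, focal_length=None, inverse=False):
    """Sketch of a matrix Fourier transform between square, centred grids."""
    f = 1.0 if focal_length is None else focal_length  # angular case: f = 1
    x = _centred(phasor.shape[-1], d_in)
    u = _centred(n_out, d_out)
    sign = 1.0 if inverse else -1.0
    # (n_out, n_in) kernel, applied to both the row (y) and column (x) axes
    kernel = np.exp(sign * 2j * np.pi * np.outer(u, x) / (wavelength * f))
    norm = d_in * d_out / (wavelength * f)
    return norm * (kernel @ phasor @ kernel.T)


rng = np.random.default_rng(2)
n, wavelength, d_pupil = 16, 1e-6, 1e-3
pupil = rng.normal(size=(n, n)) + 1j * rng.normal(size=(n, n))

# Native FFT resolution in angular units (radians/pixel)
d_focal = wavelength / (n * d_pupil)
focal = mft(pupil, wavelength, d_pupil, n, d_focal)
back = mft(focal, wavelength, d_focal, n, d_pupil, inverse=True)
```

Because n_out and d_out are free parameters, the same function can sample the PSF core more finely than an FFT would, which is the usual reason to choose an MFT; the exact power check in this sketch relies on the native sampling used here.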

propagate_ASM()

Angular spectrum free-space propagation. Not yet implemented: calling it raises NotImplementedError.

Source code in dLux/wavefronts.py
def propagate_ASM(self):
    """Angular spectrum free-space propagation"""
    raise NotImplementedError()

propagate_FFT(pad=2, focal_length=None, spec_out=None, inverse=False)

Propagates the wavefront using an FFT-based method.

Parameters:

Name Type Description Default
pad int = 2

Zero-padding factor applied before the FFT.

2
focal_length float | None = None

Focal length for Cartesian focal sampling. Pass None for angular (far-field) sampling.

None
spec_out CoordSpec | None = None

Output coordinate specification. If provided, only c (centre) may be set; n and d are determined by the propagation.

None
inverse bool = False

If False, propagate forward through the system. If True, propagate backward through the system.

False

Returns:

Name Type Description
wavefront Wavefront

Propagated wavefront with updated phasor and sampling metadata.

Source code in dLux/wavefronts.py
def propagate_FFT(
    self,
    pad=2,
    focal_length=None,
    spec_out: CoordSpec = None,
    inverse=False,
):
    """
    Propagates the wavefront using an FFT-based method.

    Parameters
    ----------
    pad : int = 2
        Zero-padding factor applied before the FFT.
    focal_length : float | None = None
        Focal length for Cartesian focal sampling. Pass `None` for
        angular (far-field) sampling.
    spec_out : CoordSpec | None = None
        Output coordinate specification. If provided, only `c` (centre)
        may be set; `n` and `d` are determined by the propagation.
    inverse : bool = False
        If False, propagate forward through the system. If True, propagate
        backward through the system.

    Returns
    -------
    wavefront : Wavefront
        Propagated wavefront with updated phasor and sampling metadata.
    """
    # Input spec
    spec_in = self.spec
    wl = self.wavelength

    # Default FFT output sampling: size, pixel scale and centre
    n_out = spec_in.n * pad
    d_fft, c_fft = dlu.fft_spec(n_out, spec_in.d, wl, focal_length)

    # Get the phase ramp and the output center
    if spec_out is not None:
        if spec_out.d is not None:
            raise ValueError("Output spec cannot specify d; FFT output d is fixed.")
        if spec_out.n is not None:
            raise ValueError(
                "Output spec cannot specify n; FFT output n is determined by the "
                "pad parameter."
            )

        # Calculate the input phase ramp for the FFT propagation
        shift = c_fft - spec_out.c
        in_ramp = dlu.fft_phase_ramp(spec_in.xs, wl, shift, focal_length, inverse)

        # Calculate the output phase ramp correction
        spec_out = spec_out.set(n=n_out, d=d_fft)
        shift = dlu.fft_spec(spec_out.n, spec_out.d, wl, focal_length)[1]
        out_ramp = dlu.fft_phase_ramp(spec_out.xs, wl, shift, focal_length, inverse)

    else:
        in_ramp, out_ramp = 1.0, 1.0
        spec_out = CoordSpec(n=n_out, d=d_fft, c=c_fft)

    # Apply ramp and FFT
    phasor, pixel_scale = dlu.FFT(
        phasor=self.phasor * in_ramp,
        wavelength=self.wavelength,
        pixel_scale=self.pixel_scale,
        focal_length=focal_length,
        inverse=inverse,
        pad=pad,
    )

    # Update the values
    return self.set(
        phasor=phasor * out_ramp, pixel_scale=pixel_scale, center=spec_out.c
    )
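
The core of an FFT propagation is zero-padding followed by a discrete Fourier transform, which fixes the output pixel scale at λf / (pad · npixels · d_in), with f = 1 for angular sampling. The sketch below is a NumPy illustration under stated assumptions: centred padding, norm="ortho" so that power is conserved, and fftshift centring. It omits the phase ramps and spec_out centre handling of propagate_FFT, and the normalisation and centring of dlu.FFT may differ.

```python
import numpy as np


def fft_propagate(phasor, wavelength, d_in, pad=2, focal_length=None, inverse=False):
    """Sketch: zero-pad, FFT, and return the output field and pixel scale."""
    n = phasor.shape[-1]
    n_out = n * pad
    # Embed the field in the centre of a larger array of zeros
    padded = np.zeros((n_out, n_out), dtype=complex)
    start = (n_out - n) // 2
    padded[start:start + n, start:start + n] = phasor
    # Orthonormal FFT with the zero frequency moved to the array centre
    transform = np.fft.ifft2 if inverse else np.fft.fft2
    out = np.fft.fftshift(transform(np.fft.ifftshift(padded), norm="ortho"))
    f = 1.0 if focal_length is None else focal_length
    d_out = wavelength * f / (n_out * d_in)
    return out, d_out


pupil = np.ones((32, 32), dtype=complex)
psf_field, d_out = fft_propagate(pupil, wavelength=1e-6, d_in=1e-3, pad=4)
```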

propagate_MFT(spec_out, focal_length=None, inverse=None)

Propagates the wavefront using an MFT-based method with a CoordSpec.

Parameters:

Name Type Description Default
spec_out CoordSpec

Output coordinate specification: n sets the number of pixels and d the pixel scale of the propagated field. The centre c is not used by this method.

required
focal_length float | None = None

Focal length for Cartesian focal sampling. Pass None for angular (far-field) sampling.

None
inverse bool | None = None

If False or None, propagate forward through the system. If True, propagate backward through the system.

None

Returns:

Name Type Description
wavefront Wavefront

Propagated wavefront with updated phasor and pixel scale.

Source code in dLux/wavefronts.py
def propagate_MFT(self, spec_out, focal_length=None, inverse=None):
    """
    Propagates the wavefront using an MFT-based method with a `CoordSpec`.

    Parameters
    ----------
    spec_out : CoordSpec
        Output coordinate specification defining the number of pixels
        and pixel scale of the propagated field.
    focal_length : float | None = None
        Focal length for Cartesian focal sampling. Pass `None` for
        angular (far-field) sampling.
    inverse : bool | None = None
        If False or None, propagate forward through the system. If True,
        propagate backward through the system.

    Returns
    -------
    wavefront : Wavefront
        Propagated wavefront with updated phasor and pixel scale.
    """
    # Propagate
    phasor = dlu.MFT(
        phasor=self.phasor,
        wavelength=self.wavelength,
        pixel_scale_in=self.pixel_scale,
        npixels_out=spec_out.n,
        pixel_scale_out=spec_out.d,
        focal_length=focal_length,
        inverse=inverse,
    )
    return self.set(phasor=phasor, pixel_scale=np.array(spec_out.d, float))

propagate_fraunhofer()

Fraunhofer propagation via MFT (equivalent to propagate_MFT, but using the abcdLux backend). Not yet implemented: calling it raises NotImplementedError.

Source code in dLux/wavefronts.py
def propagate_fraunhofer(self):
    """
    Fraunhofer propagation via MFT (same as propagate MFT, but with abcdLux backend)
    """
    raise NotImplementedError()

propagate_fraunhofer_fft()

Fraunhofer propagation via FFT (equivalent to propagate_FFT, but using the abcdLux backend). Not yet implemented: calling it raises NotImplementedError.

Source code in dLux/wavefronts.py
def propagate_fraunhofer_fft(self):
    """
    Fraunhofer propagation via FFT (same as propagate FFT, but with abcdLux backend)
    """
    raise NotImplementedError()

propagate_fresnel()

LCT-based MFT Fresnel propagation. Not yet implemented: calling it raises NotImplementedError.

Source code in dLux/wavefronts.py
def propagate_fresnel(self):
    """LCT-based MFT Fresnel propagation"""
    raise NotImplementedError()

propagate_fresnel_fft()

LCT-based FFT Fresnel propagation. Not yet implemented: calling it raises NotImplementedError.

Source code in dLux/wavefronts.py
def propagate_fresnel_fft(self):
    """LCT-based FFT Fresnel propagation"""
    raise NotImplementedError()

resize(npixels)

Resizes the wavefront by zero-padding or cropping the phasor. The pixel scale is unchanged, so the physical diameter changes in proportion to npixels.

Parameters:

Name Type Description Default
npixels int

The size to resize the wavefront to.

required

Returns:

Name Type Description
wavefront Wavefront

The resized wavefront.

Source code in dLux/wavefronts.py
def resize(self: Wavefront, npixels: int) -> Wavefront:
    """
    Resizes the wavefront via a zero-padding or cropping operation.

    Parameters
    ----------
    npixels : int
        The size to resize the wavefront to.

    Returns
    -------
    wavefront : Wavefront
        The resized wavefront.
    """
    return self.set(phasor=dlu.resize(self.phasor, npixels, 0j))
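
Since only the phasor changes, padding keeps the pixel scale fixed and enlarges the physical extent of the grid, while cropping discards the outer pixels. A NumPy sketch of centred padding and cropping; the hypothetical resize_centred helper places the extra pixel after the original data when the size difference is odd, which is an assumption and may not match dlu.resize.

```python
import numpy as np


def resize_centred(array, npixels, fill=0j):
    """Hypothetical sketch of centred zero-padding / cropping of a square array."""
    n = array.shape[-1]
    if npixels >= n:
        # Pad: place the original array in the middle of a filled array
        out = np.full((npixels, npixels), fill, dtype=array.dtype)
        start = (npixels - n) // 2
        out[start:start + n, start:start + n] = array
        return out
    # Crop: keep the central npixels x npixels block
    start = (n - npixels) // 2
    return array[start:start + npixels, start:start + npixels]


field = np.ones((4, 4), dtype=complex)
bigger = resize_centred(field, 8)
smaller = resize_centred(bigger, 4)
```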

rotate(angle, method='linear', complex=True)

Rotates the wavefront clockwise by a given angle via interpolation. By default (complex=True) the real and imaginary components are interpolated; pass complex=False to interpolate the amplitude and phase instead.

Parameters:

Name Type Description Default
angle (Array, radians)

The angle by which to rotate the wavefront in a clockwise direction.

required
method str = "linear"

The interpolation method.

'linear'
complex bool = True

If True, rotate the real and imaginary components. If False, rotate the amplitude and phase components.

True

Returns:

Name Type Description
wavefront Wavefront

The new wavefront rotated by angle in the clockwise direction.

Source code in dLux/wavefronts.py
def rotate(
    self: Wavefront,
    angle: Array,
    method: str = "linear",
    complex: bool = True,
) -> Wavefront:
    """
    Rotates the wavefront by a given angle via interpolation. Can be done on the
    real and imaginary components by passing in complex=True.

    Parameters
    ----------
    angle : Array, radians
        The angle by which to rotate the wavefront in a clockwise
        direction.
    method : str = "linear"
        The interpolation method.
    complex : bool = True
        If True, rotate the real and imaginary components. If False, rotate the
        amplitude and phase components.

    Returns
    -------
    wavefront : Wavefront
        The new wavefront rotated by angle in the clockwise direction.
    """
    # Get field in either (amplitude, phase) or (real, imaginary)
    fields = self.complex if complex else self.polar

    # Rotate the field
    rotator = vmap(dlu.rotate, (0, None, None))
    fields = rotator(fields, angle, method)

    # Convert back to complex form
    if complex:
        phasor = fields[0] + 1j * fields[1]
    else:
        phasor = fields[0] * np.exp(1j * fields[1])

    # Return new wavefront
    return self.set(phasor=phasor)
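
The complex flag matters where the phase wraps or changes quickly between pixels. The sketch below linearly interpolates halfway between two unit-amplitude pixels whose phases sit on either side of the ±π wrap, once on the real and imaginary parts and once on amplitude and phase:

```python
import numpy as np

# Two neighbouring pixels whose phases straddle the +/- pi wrap
a = 0.9 * np.pi
left, right = np.exp(1j * a), np.exp(-1j * a)

# complex=True analogue: interpolate real and imaginary parts (linear midpoint)
mid_complex = 0.5 * (left + right)

# complex=False analogue: interpolate amplitude and wrapped phase, then recombine
amp = 0.5 * (np.abs(left) + np.abs(right))
pha = 0.5 * (np.angle(left) + np.angle(right))
mid_polar = amp * np.exp(1j * pha)
```

Interpolating the real and imaginary parts gives a field with the expected phase of about π, at the cost of a slightly reduced amplitude (|cos(0.9π)| ≈ 0.95). Interpolating the wrapped phase gives a phase of 0, i.e. a field of the wrong sign. This is presumably why complex=True is the default; complex=False suits smooth, unwrapped phase better.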

scale_to(npixels, pixel_scale, complex=True)

Interpolates the wavefront onto a grid of npixels pixels with the given pixel_scale. By default (complex=True) the real and imaginary components are interpolated; pass complex=False to interpolate the amplitude and phase instead.

Parameters:

Name Type Description Default
npixels int

The number of pixels to interpolate to.

required
pixel_scale Array

The pixel scale to interpolate to.

required
complex bool = True

If True, interpolate the real and imaginary components. If False, interpolate the amplitude and phase components.

True

Returns:

Name Type Description
wavefront Wavefront

The new interpolated wavefront.

Source code in dLux/wavefronts.py
def scale_to(
    self: Wavefront,
    npixels: int,
    pixel_scale: Array,
    complex: bool = True,
) -> Wavefront:
    """
    Interpolates the wavefront to a given npixels and pixel_scale. Can be done on
    the real and imaginary components by passing in complex=True.

    Parameters
    ----------
    npixels : int
        The number of pixels to interpolate to.
    pixel_scale : Array
        The pixel scale to interpolate to.
    complex : bool = True
        If True, interpolate the real and imaginary components. If False,
        interpolate the amplitude and phase components.

    Returns
    -------
    wavefront : Wavefront
        The new interpolated wavefront.
    """
    # Get field in either (amplitude, phase) or (real, imaginary)
    fields = self.complex if complex else self.polar

    # Scale the field
    scale_fn = vmap(dlu.scale, (0, None, None))
    fields = scale_fn(fields, npixels, pixel_scale / self.pixel_scale)

    # Convert back to complex form
    if complex:
        phasor = fields[0] + 1j * fields[1]
    else:
        phasor = fields[0] * np.exp(1j * fields[1])

    # Return new wavefront
    return self.set(phasor=phasor, pixel_scale=pixel_scale)
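
scale_to passes dlu.scale the ratio of the new to the old pixel scale and interpolates the field components onto the new grid. A 1D NumPy sketch of the same idea with np.interp, interpolating the real and imaginary parts as complex=True does. The centred grids, the zero fill outside the input grid and the hypothetical scale_to_1d name are assumptions.

```python
import numpy as np


def scale_to_1d(field, d_in, npixels, d_out):
    """Hypothetical sketch: resample a complex 1D field onto a new centred grid."""
    x_in = (np.arange(field.size) - (field.size - 1) / 2) * d_in
    x_out = (np.arange(npixels) - (npixels - 1) / 2) * d_out
    # Interpolate real and imaginary parts separately;
    # points outside the input grid are set to zero
    re = np.interp(x_out, x_in, field.real, left=0.0, right=0.0)
    im = np.interp(x_out, x_in, field.imag, left=0.0, right=0.0)
    return re + 1j * im


x = (np.arange(9) - 4) * 0.5
field = x + 2j * x  # linear in x, so linear interpolation is exact
resampled = scale_to_1d(field, d_in=0.5, npixels=5, d_out=0.75)
```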

set_spec(spec)

Updates the wavefront pixel scale and centre from the d and c fields of a CoordSpec. The number of pixels n is not used, so the phasor array is left unchanged.

Parameters:

Name Type Description Default
spec CoordSpec

The coordinate specification to apply.

required

Returns:

Name Type Description
wavefront Wavefront

New wavefront with updated pixel_scale and center.

Source code in dLux/wavefronts.py
def set_spec(self, spec: CoordSpec):
    """
    Updates the wavefront pixel scale and centre from a `CoordSpec`.

    Parameters
    ----------
    spec : CoordSpec
        The coordinate specification to apply.

    Returns
    -------
    wavefront : Wavefront
        New wavefront with updated `pixel_scale` and `center`.
    """
    return self.set(pixel_scale=spec.d, center=spec.c)

tilt(angles, unit='rad')

Tilts the wavefront by the (x, y) angles by adding a linear optical path difference across the grid.

Parameters:

Name Type Description Default
angles Array

The (x, y) angles by which to tilt the wavefront, in unit.

required
unit str = "rad"

The units of the angles, e.g. "rad", "deg", "arcmin", "arcsec", and prefixed forms like "mrad", "mas", etc. (as supported by utils/units.py).

'rad'

Returns:

Name Type Description
wavefront Wavefront

The tilted wavefront.

Source code in dLux/wavefronts.py
def tilt(self: Wavefront, angles: Array, unit: str = "rad") -> Wavefront:
    """
    Tilts the wavefront by the (x, y) angles.

    Parameters
    ----------
    angles : Array
        The (x, y) angles by which to tilt the wavefront, in `unit`.
    unit : str
        The units of the angles, e.g. "rad", "deg", "arcmin", "arcsec", and
        prefixed forms like "mrad", "mas", etc. (as supported by utils/units.py).

    Returns
    -------
    wavefront : Wavefront
        The tilted wavefront.
    """
    angles = np.asarray(angles, dtype=float)
    if angles.shape != (2,):
        raise ValueError("angles must be a 1d array of shape (2,).")

    # Calculate scaled coordinates
    coords = self.coordinates(scale=dlu.unit_factor_to_rad(unit))

    # Tilt the wavefront
    return self.add_opd(np.sum(angles[:, None, None] * coords, axis=0))
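
tilt multiplies the pixel coordinates by the unit's conversion factor to radians and adds the linear OPD ramp angle_x * x + angle_y * y via add_opd. A NumPy sketch for a 0.1 arcsec tilt in x across a 1 m pupil, assuming 1 arcsec = π / (180 · 3600) rad and a pixel-centred grid:

```python
import numpy as np

ARCSEC_TO_RAD = np.pi / (180 * 3600)

npixels, diameter, wavelength = 64, 1.0, 1e-6
pixel_scale = diameter / npixels
xs = (np.arange(npixels) - (npixels - 1) / 2) * pixel_scale

# Scale the coordinates by the unit factor, as tilt() does via coordinates(scale=...)
coords = np.array(np.meshgrid(xs, xs)) * ARCSEC_TO_RAD  # (2, n, n): x then y

angles = np.array([0.1, 0.0])  # (x, y) tilt in arcsec
opd = np.sum(angles[:, None, None] * coords, axis=0)  # metres
phasor = np.exp(1j * 2 * np.pi / wavelength * opd)
```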

to_psf()

Converts the wavefront to a dLux PSF object.

Returns:

Name Type Description
psf PSF

A PSF object containing the current wavefront intensity and pixel scale.

Source code in dLux/wavefronts.py
def to_psf(self: Wavefront) -> PSF:
    """
    Converts the wavefront to a dLux PSF object.

    Returns
    -------
    psf : PSF
        A PSF object containing the current wavefront intensity and
        pixel scale.
    """
    return PSF(self.psf, self.pixel_scale)