diff --git a/lasy/utils/grid.py b/lasy/utils/grid.py index 1168cec7..dc091f87 100644 --- a/lasy/utils/grid.py +++ b/lasy/utils/grid.py @@ -1,4 +1,5 @@ -from lasy.backend import xp +import numpy as np +from lasy.backend import xp, use_cupy time_axis_indx = -1 @@ -77,7 +78,9 @@ def set_temporal_field(self, field): """ assert field.shape == self.temporal_field.shape assert field.dtype == "complex128" - self.temporal_field[:, :, :] = field + if use_cupy and type(field) == np.ndarray: + field = xp.asarray(field) # Copy to GPU + self.temporal_field[:,:,:] = field self.temporal_field_valid = True self.spectral_field_valid = False # Invalidates the spectral field @@ -92,11 +95,13 @@ def set_spectral_field(self, field): """ assert field.shape == self.spectral_field.shape assert field.dtype == "complex128" - self.spectral_field[:, :, :] = field + if use_cupy and type(field) == np.ndarray: + field = xp.asarray(field) # Copy to GPU + self.spectral_field[:,:,:] = field self.spectral_field_valid = True self.temporal_field_valid = False # Invalidates the temporal field - def get_temporal_field(self): + def get_temporal_field(self, to_cpu=False): """ Return a copy of the temporal field. @@ -108,37 +113,41 @@ def get_temporal_field(self): field : ndarray of complexs The temporal field. """ - # We return a copy, so that the user cannot modify - # the original field, unless get_temporal_field is called - if self.temporal_field_valid: - return self.temporal_field.copy() - elif self.spectral_field_valid: + if not self.temporal_field_valid: self.spectral2temporal_fft() - return self.temporal_field.copy() + # Return a copy of the field, either on CPU or GPU, so that the user + # cannot modify the original field, unless set_spectral_field is called + if to_cpu and use_cupy: + return xp.asnumpy(self.temporal_field) else: - raise ValueError("Both temporal and spectral fields are invalid") + return self.temporal_field.copy() - def get_spectral_field(self): + def get_spectral_field(self, to_cpu=False): """ Return a copy of the spectral field. (Modifying the returned object will not modify the original field stored in the Grid object ; one must use set_spectral_field to do so.) + Parameters + ---------- + to_cpu : bool + If True, the returned field is always returned as a numpy array on CPU + (even when the lasy backend is cupy) + Returns ------- field : ndarray of complexs The spectral field. """ - # We return a copy, so that the user cannot modify - # the original field, unless set_spectral_field is called - if self.spectral_field_valid: - return self.spectral_field.copy() - elif self.temporal_field_valid: + if not self.spectral_field_valid: self.temporal2spectral_fft() - return self.spectral_field.copy() + # Return a copy of the field, either on CPU or GPU, so that the user + # cannot modify the original field, unless set_spectral_field is called + if to_cpu and use_cupy: + return xp.asnumpy(self.spectral_field) else: - raise ValueError("Both temporal and spectral fields are invalid") + return self.spectral_field.copy() def temporal2spectral_fft(self): """