Skip to content

Commit

Permalink
Make setters and getters GPU aware
Browse files Browse the repository at this point in the history
  • Loading branch information
RemiLehe committed Jul 16, 2024
1 parent d92a196 commit 22960d9
Showing 1 changed file with 28 additions and 19 deletions.
47 changes: 28 additions & 19 deletions lasy/utils/grid.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from lasy.backend import xp
import numpy as np
from lasy.backend import xp, use_cupy

time_axis_indx = -1

Expand Down Expand Up @@ -77,7 +78,9 @@ def set_temporal_field(self, field):
"""
assert field.shape == self.temporal_field.shape
assert field.dtype == "complex128"
self.temporal_field[:, :, :] = field
if use_cupy and type(field) == np.ndarray:
field = xp.asarray(field) # Copy to GPU
self.temporal_field[:,:,:] = field
self.temporal_field_valid = True
self.spectral_field_valid = False # Invalidates the spectral field

Expand All @@ -92,11 +95,13 @@ def set_spectral_field(self, field):
"""
assert field.shape == self.spectral_field.shape
assert field.dtype == "complex128"
self.spectral_field[:, :, :] = field
if use_cupy and type(field) == np.ndarray:
field = xp.asarray(field) # Copy to GPU
self.spectral_field[:,:,:] = field
self.spectral_field_valid = True
self.temporal_field_valid = False # Invalidates the temporal field

def get_temporal_field(self):
def get_temporal_field(self, to_cpu=False):
"""
Return a copy of the temporal field.
Expand All @@ -108,37 +113,41 @@ def get_temporal_field(self):
field : ndarray of complexs
The temporal field.
"""
# We return a copy, so that the user cannot modify
# the original field, unless get_temporal_field is called
if self.temporal_field_valid:
return self.temporal_field.copy()
elif self.spectral_field_valid:
if not self.temporal_field_valid:
self.spectral2temporal_fft()
return self.temporal_field.copy()
# Return a copy of the field, either on CPU or GPU, so that the user
# cannot modify the original field, unless set_spectral_field is called
if to_cpu and use_cupy:
return xp.asnumpy(self.temporal_field)
else:
raise ValueError("Both temporal and spectral fields are invalid")
return self.temporal_field.copy()

def get_spectral_field(self):
def get_spectral_field(self, to_cpu=False):
"""
Return a copy of the spectral field.
(Modifying the returned object will not modify the original field stored
in the Grid object ; one must use set_spectral_field to do so.)
Parameters
----------
to_cpu : bool
If True, the returned field is always returned as a numpy array on CPU
(even when the lasy backend is cupy)
Returns
-------
field : ndarray of complexs
The spectral field.
"""
# We return a copy, so that the user cannot modify
# the original field, unless set_spectral_field is called
if self.spectral_field_valid:
return self.spectral_field.copy()
elif self.temporal_field_valid:
if not self.spectral_field_valid:
self.temporal2spectral_fft()
return self.spectral_field.copy()
# Return a copy of the field, either on CPU or GPU, so that the user
# cannot modify the original field, unless set_spectral_field is called
if to_cpu and use_cupy:
return xp.asnumpy(self.spectral_field)
else:
raise ValueError("Both temporal and spectral fields are invalid")
return self.spectral_field.copy()

def temporal2spectral_fft(self):
"""
Expand Down

0 comments on commit 22960d9

Please sign in to comment.