From c8098e8188421568e713e319f7cdb119fdee2b6e Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sat, 25 May 2024 21:18:59 +0300 Subject: [PATCH 01/11] feature: enabled use of cupy arrays --- pylops_mpi/DistributedArray.py | 22 +++++++++++++++++++++- pylops_mpi/basicoperators/BlockDiag.py | 15 ++++++++++----- pylops_mpi/basicoperators/VStack.py | 13 +++++++++---- pylops_mpi/optimization/cls_basic.py | 8 ++++---- pylops_mpi/utils/decorators.py | 3 ++- 5 files changed, 46 insertions(+), 15 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index f6b367db..a80d1165 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -5,6 +5,7 @@ from enum import Enum from pylops.utils import DTypeLike, NDArray +from pylops.utils.backend import get_module class Partition(Enum): @@ -78,6 +79,8 @@ class DistributedArray: Axis along which distribution occurs. Defaults to ``0``. local_shapes : :obj:`list`, optional List of tuples representing local shapes at each rank. + engine : :obj:`str`, optional + Engine used to store array (``numpy`` or ``cupy``) dtype : :obj:`str`, optional Type of elements in input array. Defaults to ``numpy.float64``. """ @@ -86,6 +89,7 @@ def __init__(self, global_shape: Union[Tuple, Integral], base_comm: Optional[MPI.Comm] = MPI.COMM_WORLD, partition: Partition = Partition.SCATTER, axis: int = 0, local_shapes: Optional[List[Tuple]] = None, + engine: Optional[str] = "numpy", dtype: Optional[DTypeLike] = np.float64): if isinstance(global_shape, Integral): global_shape = (global_shape,) @@ -103,7 +107,8 @@ def __init__(self, global_shape: Union[Tuple, Integral], self._check_local_shapes(local_shapes) self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm, partition, axis) - self._local_array = np.empty(shape=self.local_shape, dtype=self.dtype) + self._engine = engine + self._local_array = get_module(engine).empty(shape=self.local_shape, dtype=self.dtype) def __getitem__(self, index): return self.local_array[index] @@ -160,6 +165,16 @@ def local_shape(self): """ return self._local_shape + @property + def engine(self): + """Engine of the Distributed array + + Returns + ------- + engine : :obj:`str` + """ + return self._engine + @property def local_array(self): """View of the Local Array @@ -334,6 +349,7 @@ def __neg__(self): partition=self.partition, axis=self.axis, local_shapes=self.local_shapes, + engine=self.engine, dtype=self.dtype) arr[:] = -self.local_array return arr @@ -365,6 +381,7 @@ def add(self, dist_array): dtype=self.dtype, partition=self.partition, local_shapes=self.local_shapes, + engine=self.engine, axis=self.axis) SumArray[:] = self.local_array + dist_array.local_array return SumArray @@ -387,6 +404,7 @@ def multiply(self, dist_array): dtype=self.dtype, partition=self.partition, local_shapes=self.local_shapes, + engine=self.engine, axis=self.axis) if isinstance(dist_array, DistributedArray): # multiply two DistributedArray @@ -480,6 +498,7 @@ def conj(self): partition=self.partition, axis=self.axis, local_shapes=self.local_shapes, + engine=self.engine, dtype=self.dtype) conj[:] = self.local_array.conj() return conj @@ -492,6 +511,7 @@ def copy(self): partition=self.partition, axis=self.axis, local_shapes=self.local_shapes, + engine=self.engine, dtype=self.dtype) arr[:] = self.local_array return arr diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py index 685638f7..c96e6e10 100644 --- 
a/pylops_mpi/basicoperators/BlockDiag.py +++ b/pylops_mpi/basicoperators/BlockDiag.py @@ -3,8 +3,9 @@ from mpi4py import MPI from typing import Optional, Sequence -from pylops.utils import DTypeLike from pylops import LinearOperator +from pylops.utils import DTypeLike +from pylops.utils.backend import get_module from pylops_mpi import MPILinearOperator, MPIStackedLinearOperator from pylops_mpi import DistributedArray, StackedDistributedArray @@ -113,22 +114,26 @@ def __init__(self, ops: Sequence[LinearOperator], @reshaped(forward=True, stacking=True) def _matvec(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.matvec(x.local_array[self.mmops[iop]: self.mmops[iop + 1]])) - y[:] = np.concatenate(y1) + y[:] = ncp.concatenate(ncp.asarray(y1)) return y @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=self.shape[1], local_shapes=self.local_shapes_m, + engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) - y[:] = np.concatenate(y1) + y[:] = ncp.concatenate(ncp.asarray(y1)) return y diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index daf26c61..480cbd22 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -5,6 +5,7 @@ from pylops import LinearOperator from pylops.utils import DTypeLike +from pylops.utils.backend import get_module from pylops_mpi import ( MPILinearOperator, @@ -116,22 +117,26 @@ def __init__(self, ops: Sequence[LinearOperator], super().__init__(shape=shape, dtype=dtype, base_comm=base_comm) def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) if x.partition is not Partition.BROADCAST: raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}") - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, dtype=self.dtype) + y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.matvec(x.local_array)) - y[:] = np.concatenate(y1) + y[:] = ncp.concatenate(y1) return y @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, + engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) - y1 = np.sum(y1, axis=0) + y1 = ncp.sum(ncp.asarray(y1), axis=0) y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) return y diff --git a/pylops_mpi/optimization/cls_basic.py b/pylops_mpi/optimization/cls_basic.py index b71617c9..8e9e1ef2 100644 --- a/pylops_mpi/optimization/cls_basic.py +++ 
b/pylops_mpi/optimization/cls_basic.py @@ -337,7 +337,7 @@ def setup(self, self.rank = x.rank self.c = r.copy() self.q = self.Op.matvec(self.c) - self.kold = np.abs(r.dot(r.conj())) + self.kold = float(np.abs(r.dot(r.conj()))) # create variables to track the residual norm and iterations self.cost = [] @@ -373,13 +373,13 @@ def step(self, x: Union[DistributedArray, StackedDistributedArray], """ - a = self.kold / (self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj())) + a = float(self.kold / (self.q.dot(self.q.conj()) + self.damp * self.c.dot(self.c.conj()))) x += a * self.c self.s -= a * self.q damped_x = self.damp * x r = self.Op.rmatvec(self.s) - damped_x - k = np.abs(r.dot(r.conj())) - b = k / self.kold + k = float(np.abs(r.dot(r.conj()))) + b = float(k / self.kold) self.c = r + b * self.c self.q = self.Op.matvec(self.c) self.kold = k diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py index cfc736f8..1ba1c277 100644 --- a/pylops_mpi/utils/decorators.py +++ b/pylops_mpi/utils/decorators.py @@ -54,7 +54,8 @@ def wrapper(self, x: DistributedArray): local_shapes = None global_shape = getattr(self, "dims") arr = DistributedArray(global_shape=global_shape, - local_shapes=local_shapes, axis=0, dtype=x.dtype) + local_shapes=local_shapes, axis=0, + engine=x.engine, dtype=x.dtype) arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape))) x_local_shapes = np.asarray(x.base_comm.allgather(np.prod(x.local_shape))) # Calculate num_ghost_cells required for each rank From 48b7a836c512b0fb04b83d570420198e3dcb6cf2 Mon Sep 17 00:00:00 2001 From: rohanbabbar04 Date: Sun, 26 May 2024 19:50:46 +0530 Subject: [PATCH 02/11] Fix lint --- pylops_mpi/basicoperators/BlockDiag.py | 2 +- pylops_mpi/basicoperators/VStack.py | 4 ++-- pylops_mpi/utils/decorators.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py index c96e6e10..a644c969 100644 --- a/pylops_mpi/basicoperators/BlockDiag.py +++ b/pylops_mpi/basicoperators/BlockDiag.py @@ -115,7 +115,7 @@ def __init__(self, ops: Sequence[LinearOperator], @reshaped(forward=True, stacking=True) def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 480cbd22..45870519 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -120,7 +120,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) if x.partition is not Partition.BROADCAST: raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}") - y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, + y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): @@ -131,7 +131,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: @reshaped(forward=False, stacking=True) def _rmatvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, + y = 
DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) y1 = [] for iop, oper in enumerate(self.ops): diff --git a/pylops_mpi/utils/decorators.py b/pylops_mpi/utils/decorators.py index 1ba1c277..457b559b 100644 --- a/pylops_mpi/utils/decorators.py +++ b/pylops_mpi/utils/decorators.py @@ -54,7 +54,7 @@ def wrapper(self, x: DistributedArray): local_shapes = None global_shape = getattr(self, "dims") arr = DistributedArray(global_shape=global_shape, - local_shapes=local_shapes, axis=0, + local_shapes=local_shapes, axis=0, engine=x.engine, dtype=x.dtype) arr_local_shapes = np.asarray(arr.base_comm.allgather(np.prod(arr.local_shape))) x_local_shapes = np.asarray(x.base_comm.allgather(np.prod(x.local_shape))) From f2c9e4cb3d45cd9415e5b3fbad2cdf27c0b07cac Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 31 May 2024 18:24:05 +0300 Subject: [PATCH 03/11] feature: enable cupy in to_dist --- pylops_mpi/DistributedArray.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index a80d1165..8cba4098 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -5,7 +5,7 @@ from enum import Enum from pylops.utils import DTypeLike, NDArray -from pylops.utils.backend import get_module +from pylops.utils.backend import get_module, get_array_module, get_module_name class Partition(Enum): @@ -294,6 +294,7 @@ def to_dist(cls, x: NDArray, partition=partition, axis=axis, local_shapes=local_shapes, + engine=get_module_name(get_array_module(x)), dtype=x.dtype) if partition == Partition.BROADCAST: dist_array[:] = x From 93c2b22130b55c7484d29997a3fbfaa0dccb1687 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 31 May 2024 18:24:20 +0300 Subject: [PATCH 04/11] feature: enable cupy for cg --- pylops_mpi/optimization/cls_basic.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pylops_mpi/optimization/cls_basic.py b/pylops_mpi/optimization/cls_basic.py index 8e9e1ef2..9ad9a34e 100644 --- a/pylops_mpi/optimization/cls_basic.py +++ b/pylops_mpi/optimization/cls_basic.py @@ -86,7 +86,7 @@ def setup( self.r = self.y - self.Op.matvec(x) self.rank = x.rank self.c = self.r.copy() - self.kold = np.abs(self.r.dot(self.r.conj())) + self.kold = float(np.abs(self.r.dot(self.r.conj()))) # create variables to track the residual norm and iterations self.cost: List = [] @@ -120,11 +120,11 @@ def step(self, x: Union[DistributedArray, StackedDistributedArray], """ Opc = self.Op.matvec(self.c) cOpc = np.abs(self.c.dot(Opc.conj())) - a = self.kold / cOpc + a = float(self.kold / cOpc) x += a * self.c self.r -= a * Opc - k = np.abs(self.r.dot(self.r.conj())) - b = k / self.kold + k = float(np.abs(self.r.dot(self.r.conj()))) + b = float(k / self.kold) self.c = self.r + b * self.c self.kold = k self.iiter += 1 From 84c5eebc5fec345d14218615a674a51e7956782d Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 31 May 2024 22:59:23 +0300 Subject: [PATCH 05/11] feature: enable cupy in FirstDerivative and SecondDerivative --- pylops_mpi/DistributedArray.py | 1 + pylops_mpi/LinearOperator.py | 2 + pylops_mpi/basicoperators/FirstDerivative.py | 69 ++++++++++++------- pylops_mpi/basicoperators/SecondDerivative.py | 50 +++++++++----- 4 files changed, 77 insertions(+), 45 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 8cba4098..eec4b057 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -535,6 
+535,7 @@ def ravel(self, order: Optional[str] = "C"): arr = DistributedArray(global_shape=np.prod(self.global_shape), local_shapes=local_shapes, partition=self.partition, + engine=self.engine, dtype=self.dtype) local_array = np.ravel(self.local_array, order=order) x = local_array.copy() diff --git a/pylops_mpi/LinearOperator.py b/pylops_mpi/LinearOperator.py index 7ac851f6..a7bc9bea 100644 --- a/pylops_mpi/LinearOperator.py +++ b/pylops_mpi/LinearOperator.py @@ -81,6 +81,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: base_comm=self.base_comm, partition=x.partition, axis=x.axis, + engine=x.engine, dtype=self.dtype) y[:] = self.Op._matvec(x.local_array) return y @@ -117,6 +118,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: base_comm=self.base_comm, partition=x.partition, axis=x.axis, + engine=x.engine, dtype=self.dtype) y[:] = self.Op._rmatvec(x.local_array) return y diff --git a/pylops_mpi/basicoperators/FirstDerivative.py b/pylops_mpi/basicoperators/FirstDerivative.py index d4c7671c..9fa9ea03 100644 --- a/pylops_mpi/basicoperators/FirstDerivative.py +++ b/pylops_mpi/basicoperators/FirstDerivative.py @@ -2,8 +2,8 @@ import numpy as np from mpi4py import MPI -from pylops.utils import DTypeLike -from pylops.utils.typing import InputDimsLike +from pylops.utils.backend import get_module +from pylops.utils.typing import DTypeLike, InputDimsLike from pylops.utils._internal import _value_or_sized_to_tuple from pylops_mpi import ( @@ -140,17 +140,21 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=1) y_forward = ghosted_x[1:] - ghosted_x[:-1] if self.rank == self.size - 1: - y_forward = np.append(y_forward, np.zeros((1,) + self.dims[1:]), axis=0) + y_forward = ncp.append(y_forward, ncp.zeros((1,) + self.dims[1:]), axis=0) y[:] = y_forward / self.sampling return y @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 1: y[:-1] -= x[:-1] @@ -159,29 +163,33 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ghosted_x = x.add_ghost_cells(cells_front=1) y_forward = ghosted_x[:-1] if self.rank == 0: - y_forward = np.insert(y_forward, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_forward = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_forward, axis=0) y[:] += y_forward y[:] /= self.sampling return y @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1) y_backward = ghosted_x[1:] - ghosted_x[:-1] if self.rank == 0: - y_backward = np.insert(y_backward, 0, np.zeros((1,) + self.dims[1:]), 
axis=0) + y_backward = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_backward, axis=0) y[:] = y_backward / self.sampling return y @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=1) y_backward = ghosted_x[1:] if self.rank == self.size - 1: - y_backward = np.append(y_backward, np.zeros((1,) + self.dims[1:]), axis=0) + y_backward = ncp.append(y_backward, ncp.zeros((1,) + self.dims[1:]), axis=0) y[:] -= y_backward if self.rank == 0: y[1:] += x[1:] @@ -192,13 +200,15 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2]) if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_centered, axis=0) if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(y.global_shape[0] - 1, 1), ) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(y.global_shape[0] - 1, 1), ) + self.dims[1:]), axis=0) y[:] = y_centered if self.edge: if self.rank == 0: @@ -210,18 +220,21 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 + ghosted_x = x.add_ghost_cells(cells_back=2) y_centered = 0.5 * ghosted_x[1:-1] if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) y[:] -= y_centered ghosted_x = x.add_ghost_cells(cells_front=2) y_centered = 0.5 * ghosted_x[1:-1] if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), y_centered, axis=0) y[:] += y_centered if self.edge: if self.rank == 0: @@ -235,7 +248,9 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2) y_centered = ( ghosted_x[:-4] / 12.0 @@ -244,9 +259,9 @@ def 
_matvec_centered5(self, x: DistributedArray) -> DistributedArray: - ghosted_x[4:] / 12.0 ) if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), y_centered, axis=0) if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(y.global_shape[0] - 2, 2),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(y.global_shape[0] - 2, 2),) + self.dims[1:]), axis=0) y[:] = y_centered if self.edge: if self.rank == 0: @@ -260,34 +275,36 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=4) y_centered = ghosted_x[2:-2] / 12.0 if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(y.global_shape[0], 4),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(y.global_shape[0], 4),) + self.dims[1:]), axis=0) y[:] += y_centered ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=3) y_centered = 2.0 * ghosted_x[2:-2] / 3.0 if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_centered, axis=0) if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(y.global_shape[0] - 1, 3),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(y.global_shape[0] - 1, 3),) + self.dims[1:]), axis=0) y[:] -= y_centered ghosted_x = x.add_ghost_cells(cells_front=3, cells_back=1) y_centered = 2.0 * ghosted_x[2:-2] / 3.0 if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((min(y.global_shape[0], 3),) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((min(y.global_shape[0], 3),) + self.dims[1:]), y_centered, axis=0) if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(y.global_shape[0] - 3, 1),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(y.global_shape[0] - 3, 1),) + self.dims[1:]), axis=0) y[:] += y_centered ghosted_x = x.add_ghost_cells(cells_front=4) y_centered = ghosted_x[2:-2] / 12.0 if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((min(y.global_shape[0], 4),) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((min(y.global_shape[0], 4),) + self.dims[1:]), y_centered, axis=0) y[:] -= y_centered if self.edge: if self.rank == 0: diff --git a/pylops_mpi/basicoperators/SecondDerivative.py b/pylops_mpi/basicoperators/SecondDerivative.py index 7d77aa9e..6c4fb961 100644 --- a/pylops_mpi/basicoperators/SecondDerivative.py +++ b/pylops_mpi/basicoperators/SecondDerivative.py @@ -2,6 +2,7 @@ import numpy as np from mpi4py import MPI +from pylops.utils.backend import get_module from pylops.utils.typing import DTypeLike, InputDimsLike from pylops.utils._internal import _value_or_sized_to_tuple @@ -122,16 +123,19 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: - y = 
DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=2) y_forward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == self.size - 1: - y_forward = np.append(y_forward, np.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) + y_forward = ncp.append(y_forward, ncp.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) y[:] = y_forward / self.sampling ** 2 return y @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 1: @@ -142,45 +146,49 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_forward = ghosted_x[:-2] if self.rank == 0: - y_forward = np.insert(y_forward, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_forward = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_forward, axis=0) if self.rank == self.size - 1: - y_forward = np.append(y_forward, np.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) + y_forward = ncp.append(y_forward, ncp.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) y[:] -= 2 * y_forward ghosted_x = x.add_ghost_cells(cells_front=2) y_forward = ghosted_x[:-2] if self.rank == 0: - y_forward = np.insert(y_forward, 0, np.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) + y_forward = ncp.append(ncp.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), y_forward, axis=0) y[:] += y_forward y[:] /= self.sampling ** 2 return y @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2) y_backward = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == 0: - y_backward = np.insert(y_backward, 0, np.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), axis=0) + y_backward = ncp.append(ncp.zeros((min(y.global_shape[0], 2),) + self.dims[1:]), y_backward, axis=0) y[:] = y_backward / self.sampling ** 2 return y @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=2) y_backward = ghosted_x[2:] if self.rank == self.size - 1: - y_backward = np.append(y_backward, np.zeros((min(2, y.global_shape[0]),) + self.dims[1:]), axis=0) + y_backward = ncp.append(y_backward, ncp.zeros((min(2, y.global_shape[0]),) + self.dims[1:]), axis=0) y[:] += y_backward ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_backward = 2 * ghosted_x[2:] if self.rank == 0: - y_backward = np.insert(y_backward, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_backward = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_backward, 
axis=0) if self.rank == self.size - 1: - y_backward = np.append(y_backward, np.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) + y_backward = ncp.append(y_backward, ncp.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) y[:] -= y_backward if self.rank == 0: @@ -192,13 +200,15 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = ghosted_x[2:] - 2 * ghosted_x[1:-1] + ghosted_x[:-2] if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_centered, axis=0) if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) y[:] = y_centered if self.edge: if self.rank == 0: @@ -210,26 +220,28 @@ def _matvec_centered(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered(self, x: DistributedArray) -> DistributedArray: - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, dtype=self.dtype) + ncp = get_module(x.engine) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=2) y_centered = ghosted_x[1:-1] if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(2, y.global_shape[0]),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(2, y.global_shape[0]),) + self.dims[1:]), axis=0) y[:] += y_centered ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = 2 * ghosted_x[1:-1] if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((1,) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((1,) + self.dims[1:]), y_centered, axis=0) if self.rank == self.size - 1: - y_centered = np.append(y_centered, np.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) + y_centered = ncp.append(y_centered, ncp.zeros((min(1, y.global_shape[0] - 1),) + self.dims[1:]), axis=0) y[:] -= y_centered ghosted_x = x.add_ghost_cells(cells_front=2) y_centered = ghosted_x[1:-1] if self.rank == 0: - y_centered = np.insert(y_centered, 0, np.zeros((min(2, y.global_shape[0]),) + self.dims[1:]), axis=0) + y_centered = ncp.append(ncp.zeros((min(2, y.global_shape[0]),) + self.dims[1:]), y_centered, axis=0) y[:] += y_centered if self.edge: if self.rank == 0: From 47a7a37bd4f9e9481ac022ca6b316a717bd7f5e9 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 31 May 2024 23:03:00 +0300 Subject: [PATCH 06/11] minor: fix linting issue --- pylops_mpi/basicoperators/FirstDerivative.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pylops_mpi/basicoperators/FirstDerivative.py b/pylops_mpi/basicoperators/FirstDerivative.py index 9fa9ea03..73ddb510 100644 --- a/pylops_mpi/basicoperators/FirstDerivative.py +++ b/pylops_mpi/basicoperators/FirstDerivative.py @@ -141,7 +141,7 @@ def 
_rmatvec(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_back=1) y_forward = ghosted_x[1:] - ghosted_x[:-1] @@ -153,7 +153,7 @@ def _matvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 if self.rank == self.size - 1: @@ -171,8 +171,8 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1) y_backward = ghosted_x[1:] - ghosted_x[:-1] if self.rank == 0: @@ -183,7 +183,7 @@ def _matvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=1) @@ -201,7 +201,7 @@ def _rmatvec_backward(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1, cells_back=1) y_centered = 0.5 * (ghosted_x[2:] - ghosted_x[:-2]) @@ -221,7 +221,7 @@ def _matvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 @@ -249,7 +249,7 @@ def _rmatvec_centered3(self, x: DistributedArray) -> DistributedArray: @reshaped def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=2, cells_back=2) y_centered = ( @@ -276,7 +276,7 @@ def _matvec_centered5(self, x: DistributedArray) -> DistributedArray: @reshaped def _rmatvec_centered5(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - y = 
DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, + y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, axis=x.axis, engine=x.engine, dtype=self.dtype) y[:] = 0 ghosted_x = x.add_ghost_cells(cells_back=4) From c7c33044dfb3dc9262dde97d3e7c4a3562dad220 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Fri, 31 May 2024 23:04:36 +0300 Subject: [PATCH 07/11] minor: fix another linting issue --- pylops_mpi/basicoperators/FirstDerivative.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylops_mpi/basicoperators/FirstDerivative.py b/pylops_mpi/basicoperators/FirstDerivative.py index 73ddb510..5adbe284 100644 --- a/pylops_mpi/basicoperators/FirstDerivative.py +++ b/pylops_mpi/basicoperators/FirstDerivative.py @@ -172,7 +172,7 @@ def _rmatvec_forward(self, x: DistributedArray) -> DistributedArray: def _matvec_backward(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) y = DistributedArray(global_shape=x.global_shape, local_shapes=x.local_shapes, - axis=x.axis, engine=x.engine, dtype=self.dtype) + axis=x.axis, engine=x.engine, dtype=self.dtype) ghosted_x = x.add_ghost_cells(cells_front=1) y_backward = ghosted_x[1:] - ghosted_x[:-1] if self.rank == 0: From aa72eccbcc09c81bd5578004911eef60e623a24f Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sat, 1 Jun 2024 23:08:20 +0300 Subject: [PATCH 08/11] minor: small doc fixes --- examples/plot_stacked_array.py | 2 +- pylops_mpi/DistributedArray.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/plot_stacked_array.py b/examples/plot_stacked_array.py index a92a98f5..a6b2f03f 100644 --- a/examples/plot_stacked_array.py +++ b/examples/plot_stacked_array.py @@ -1,6 +1,6 @@ """ Stacked Array -========================= +============= This example shows how to use the :py:class:`pylops_mpi.StackedDistributedArray`. This class provides a way to combine and act on multiple :py:class:`pylops_mpi.DistributedArray` within the same program. This is very useful in scenarios where an array can be logically diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index eec4b057..84893e04 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -284,6 +284,7 @@ def to_dist(cls, x: NDArray, Axis of Distribution local_shapes : :obj:`list`, optional Local Shapes at each rank. + Returns ---------- dist_array : :obj:`DistributedArray` From 41583fb98b5dc215a0313e8ccf95028a38b0931f Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 2 Jun 2024 22:20:06 +0300 Subject: [PATCH 09/11] minor: fix lint issue --- pylops_mpi/DistributedArray.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 84893e04..6e5a471a 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -284,7 +284,7 @@ def to_dist(cls, x: NDArray, Axis of Distribution local_shapes : :obj:`list`, optional Local Shapes at each rank. 
- + Returns ---------- dist_array : :obj:`DistributedArray` From 076868dfad17683c2b710ff47bd08556dcfe6d54 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 2 Jun 2024 22:38:06 +0300 Subject: [PATCH 10/11] doc: added gpu section to doc --- docs/source/gpu.rst | 85 ++++++++++++++++++++++++++ docs/source/index.rst | 1 + pylops_mpi/basicoperators/BlockDiag.py | 4 +- 3 files changed, 88 insertions(+), 2 deletions(-) create mode 100644 docs/source/gpu.rst diff --git a/docs/source/gpu.rst b/docs/source/gpu.rst new file mode 100644 index 00000000..43c9e768 --- /dev/null +++ b/docs/source/gpu.rst @@ -0,0 +1,85 @@ +.. _gpu: + +GPU Support +=========== + +Overview +-------- +PyLops-mpi supports computations on GPUs leveraging the GPU backend of PyLops. Under the hood, +`CuPy `_ (``cupy-cudaXX>=v13.0.0``) is used to perform all of the operations. +This library must be installed *before* PyLops-mpi is installed. + +.. note:: + + Set environment variable ``CUPY_PYLOPS=0`` to force PyLops to ignore the ``cupy`` backend. + This can be also used if a previous (or faulty) version of ``cupy`` is installed in your system, + otherwise you will get an error when importing PyLops. + + +The :class:`pylops_mpi.DistributedArray` and :class:`pylops_mpi.StackedDistributedArray` objects can be +generated using both ``numpy`` and ``cupy`` based local arrays, and all of the operators and solvers in PyLops-mpi +can handle both scenarios. Note that, since most operators in PyLops-mpi are thin-wrappers around PyLops operators, +some of the operators in PyLops that lack a GPU implementation cannot be used also in PyLops-mpi when working with +cupy arrays. + + +Example +------- + +Finally, let's briefly look at an example. First we write a code snippet using +``numpy`` arrays which PyLops-mpi will run on your CPU: + +.. code-block:: python + + # MPI helpers + comm = MPI.COMM_WORLD + rank = MPI.COMM_WORLD.Get_rank() + size = MPI.COMM_WORLD.Get_size() + + # Create distributed data (broadcast) + nxl, nt = 20, 20 + dtype = np.float32 + d_dist = pylops_mpi.DistributedArray(global_shape=nxl * nt, + partition=pylops_mpi.Partition.BROADCAST, + engine="numpy", dtype=dtype) + d_dist[:] = np.ones(d_dist.local_shape, dtype=dtype) + + # Create and apply VStack operator + Sop = pylops.MatrixMult(np.ones((nxl, nxl)), otherdims=(nt, )) + HOp = pylops_mpi.MPIVStack(ops=[Sop, ]) + y_dist = HOp @ d_dist + + +Now we write a code snippet using ``cupy`` arrays which PyLops will run on +your GPU: + +.. code-block:: python + + # MPI helpers + comm = MPI.COMM_WORLD + rank = MPI.COMM_WORLD.Get_rank() + size = MPI.COMM_WORLD.Get_size() + + # Define gpu to use + cp.cuda.Device(device=rank).use() + + # Create distributed data (broadcast) + nxl, nt = 20, 20 + dtype = np.float32 + d_dist = pylops_mpi.DistributedArray(global_shape=nxl * nt, + partition=pylops_mpi.Partition.BROADCAST, + engine="cupy", dtype=dtype) + d_dist[:] = cp.ones(d_dist.local_shape, dtype=dtype) + + # Create and apply VStack operator + Sop = pylops.MatrixMult(cp.ones((nxl, nxl)), otherdims=(nt, )) + HOp = pylops_mpi.MPIVStack(ops=[Sop, ]) + y_dist = HOp @ d_dist + +The code is almost unchanged apart from the fact that we now use ``cupy`` arrays, +PyLops-mpi will figure this out! + +.. note:: + + The CuPy backend is in active development, with many examples not yet in the docs. + You can find many `other examples `_ from the `PyLops Notebooks repository `_. 
\ No newline at end of file diff --git a/docs/source/index.rst b/docs/source/index.rst index 52485d7e..b5d538ee 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -71,6 +71,7 @@ class and implementing the ``_matvec`` and ``_rmatvec``. self installation.rst + gpu.rst .. toctree:: :maxdepth: 2 diff --git a/pylops_mpi/basicoperators/BlockDiag.py b/pylops_mpi/basicoperators/BlockDiag.py index a644c969..7911692f 100644 --- a/pylops_mpi/basicoperators/BlockDiag.py +++ b/pylops_mpi/basicoperators/BlockDiag.py @@ -121,7 +121,7 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: for iop, oper in enumerate(self.ops): y1.append(oper.matvec(x.local_array[self.mmops[iop]: self.mmops[iop + 1]])) - y[:] = ncp.concatenate(ncp.asarray(y1)) + y[:] = ncp.concatenate(y1) return y @reshaped(forward=False, stacking=True) @@ -133,7 +133,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) - y[:] = ncp.concatenate(ncp.asarray(y1)) + y[:] = ncp.concatenate(y1) return y From 45186ef9fe5fd76210f7e686cbcf5a438387afa0 Mon Sep 17 00:00:00 2001 From: mrava87 Date: Sun, 2 Jun 2024 22:44:14 +0300 Subject: [PATCH 11/11] minor: remove asarray from MPIVStack --- pylops_mpi/basicoperators/VStack.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 45870519..14b24f89 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -136,7 +136,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray: y1 = [] for iop, oper in enumerate(self.ops): y1.append(oper.rmatvec(x.local_array[self.nnops[iop]: self.nnops[iop + 1]])) - y1 = ncp.sum(ncp.asarray(y1), axis=0) + y1 = ncp.sum(y1, axis=0) y[:] = self.base_comm.allreduce(y1, op=MPI.SUM) return y
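
Taken together, the patches above let the whole forward/adjoint/solver chain run on GPU-resident arrays: ``DistributedArray`` gains an ``engine`` keyword (PATCH 01), the stacked operators concatenate/sum with the array's own module (PATCH 01/10/11), and the CG/CGLS bookkeeping scalars are cast to host floats (PATCH 02/04). The snippet below is a minimal end-to-end sketch of what the series enables. It is not part of the patch set: it assumes ``cupy`` is installed, that each MPI rank can see its own GPU, and that a top-level functional wrapper ``pylops_mpi.cgls`` around the patched ``CGLS`` class is available (only the class-based solver is touched here, so that wrapper name is an assumption).

.. code-block:: python

    # Minimal sketch of GPU usage enabled by this series (assumptions: cupy
    # installed, one visible GPU per rank, pylops_mpi.cgls wrapper available).
    # Run with e.g.: mpiexec -n 2 python cupy_cgls_sketch.py
    import numpy as np
    import cupy as cp
    from mpi4py import MPI

    import pylops
    import pylops_mpi

    rank = MPI.COMM_WORLD.Get_rank()
    size = MPI.COMM_WORLD.Get_size()
    cp.cuda.Device(device=rank).use()  # pin one GPU per rank, as in gpu.rst

    n = 10
    dtype = np.float32

    # Distributed model vector backed by cupy local arrays (engine kwarg, PATCH 01)
    x = pylops_mpi.DistributedArray(global_shape=n * size,
                                    engine="cupy", dtype=dtype)
    x[:] = cp.ones(x.local_shape, dtype=dtype)

    # Block-diagonal operator: each rank applies its own dense matrix
    A = cp.asarray(np.random.default_rng(rank).standard_normal((n, n)).astype(dtype))
    Op = pylops_mpi.MPIBlockDiag(ops=[pylops.MatrixMult(A, dtype=dtype)])

    # Forward stays on the GPU; BlockDiag now concatenates with cupy
    y = Op @ x

    # Zero initial guess on the same engine, then least-squares inversion;
    # the float(...) casts keep scalar bookkeeping on the host while the
    # arrays remain on the device
    x0 = pylops_mpi.DistributedArray(global_shape=n * size,
                                     engine="cupy", dtype=dtype)
    x0[:] = 0
    xinv = pylops_mpi.cgls(Op, y, x0=x0, niter=10)[0]

    print(rank, type(xinv.local_array))  # expected: <class 'cupy.ndarray'>

The design choice mirrored here is that the engine is carried by the data (``x.engine``) rather than by the operators, so existing operator code only needs ``get_module(x.engine)`` to stay backend-agnostic.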