From db0fba5a4462a5f79b4db5677c4f70a073bca70c Mon Sep 17 00:00:00 2001 From: mrava87 Date: Tue, 22 Oct 2024 00:06:10 +0300 Subject: [PATCH] minor: small code simplication --- pylops_mpi/DistributedArray.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pylops_mpi/DistributedArray.py b/pylops_mpi/DistributedArray.py index 4caa2db..0b2e2c6 100644 --- a/pylops_mpi/DistributedArray.py +++ b/pylops_mpi/DistributedArray.py @@ -79,7 +79,7 @@ class DistributedArray: axis : :obj:`int`, optional Axis along which distribution occurs. Defaults to ``0``. local_shapes : :obj:`list`, optional - List of tuples of integers representing local shapes at each rank. + List of tuples or integers representing local shapes at each rank. engine : :obj:`str`, optional Engine used to store array (``numpy`` or ``cupy``) dtype : :obj:`str`, optional @@ -106,8 +106,7 @@ def __init__(self, global_shape: Union[Tuple, Integral], self._partition = partition self._axis = axis - if local_shapes is not None: - local_shapes = [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes] + local_shapes = local_shapes if local_shapes is None else [_value_or_sized_to_tuple(local_shape) for local_shape in local_shapes] self._check_local_shapes(local_shapes) self._local_shape = local_shapes[base_comm.rank] if local_shapes else local_split(global_shape, base_comm, partition, axis)