diff --git a/pylops_mpi/basicoperators/VStack.py b/pylops_mpi/basicoperators/VStack.py index 14b24f8..96727f7 100644 --- a/pylops_mpi/basicoperators/VStack.py +++ b/pylops_mpi/basicoperators/VStack.py @@ -118,8 +118,9 @@ def __init__(self, ops: Sequence[LinearOperator], def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - if x.partition is not Partition.BROADCAST: - raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}") + if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: + raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}" + f"Got {x.partition} instead...") y = DistributedArray(global_shape=self.shape[0], local_shapes=self.local_shapes_n, engine=x.engine, dtype=self.dtype) y1 = [] diff --git a/pylops_mpi/signalprocessing/Fredholm1.py b/pylops_mpi/signalprocessing/Fredholm1.py index d2e0327..3b8a9ef 100644 --- a/pylops_mpi/signalprocessing/Fredholm1.py +++ b/pylops_mpi/signalprocessing/Fredholm1.py @@ -108,8 +108,9 @@ def __init__( def _matvec(self, x: DistributedArray) -> DistributedArray: ncp = get_module(x.engine) - if x.partition is not Partition.BROADCAST: - raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}") + if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: + raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}" + f"Got {x.partition} instead...") y = DistributedArray(global_shape=self.shape[0], partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) x = x.local_array.reshape(self.dims).squeeze() @@ -129,8 +130,9 @@ def _matvec(self, x: DistributedArray) -> DistributedArray: def _rmatvec(self, x: NDArray) -> NDArray: ncp = get_module(x.engine) - if x.partition is not Partition.BROADCAST: - raise ValueError(f"x should have partition={Partition.BROADCAST}, {x.partition} != {Partition.BROADCAST}") + if x.partition not in [Partition.BROADCAST, Partition.UNSAFE_BROADCAST]: + raise ValueError(f"x should have partition={Partition.BROADCAST},{Partition.UNSAFE_BROADCAST}" + f"Got {x.partition} instead...") y = DistributedArray(global_shape=self.shape[1], partition=Partition.BROADCAST, engine=x.engine, dtype=self.dtype) x = x.local_array.reshape(self.dimsd).squeeze()