diff --git a/parallelformers/parallelize.py b/parallelformers/parallelize.py
index df98e00..b24373b 100644
--- a/parallelformers/parallelize.py
+++ b/parallelformers/parallelize.py
@@ -290,6 +290,35 @@ def parallelize(self) -> None:
             traceback.print_exc()
             self.deparallelize()
 
+    @staticmethod
+    def _deallocate(item):
+        """Return ``item`` with CUDA tensors replaced by CPU copies.
+
+        ``Tensor.cpu()`` is not in-place: it returns a new tensor, so the
+        result has to be kept. Only a bare tensor or the direct elements of
+        a list, tuple or dict are handled; nested containers and any other
+        object are returned untouched. Containers are rebuilt rather than
+        mutated, so the caller's own objects are left as they were.
+        """
+
+        def to_host(value):
+            if torch.is_tensor(value) and value.is_cuda:
+                return value.cpu()
+            return value
+
+        if isinstance(item, list):
+            return [to_host(value) for value in item]
+
+        if isinstance(item, tuple):
+            moved = [to_host(value) for value in item]
+            # namedtuples take their fields positionally, not as one iterable
+            return type(item)(*moved) if hasattr(item, "_fields") else tuple(moved)
+
+        if isinstance(item, dict):
+            return {key: to_host(value) for key, value in item.items()}
+
+        return to_host(item)
+
     @torch.no_grad()
     def hijack(
         self,
@@ -314,6 +343,11 @@ def hijack(
                 self.inference_mutexes,
                 self.inputs_queues,
             ):
+                inputs = self._deallocate(inputs)
+
+                for k in kwargs:
+                    kwargs[k] = self._deallocate(kwargs[k])
+
                 i_queue.put((inputs, kwargs, func))
                 i_mutex.set()
                 # producer part
diff --git a/setup.py b/setup.py
index 773dca8..23eb37e 100644
--- a/setup.py
+++ b/setup.py
@@ -26,7 +26,7 @@
 
 setup(
     name='parallelformers',
-    version='1.2.3',
+    version='1.2.4',
     description=
     'An Efficient Model Parallelization Toolkit for Deployment',
     long_description=long_description,