From d79a37c408301bf7f845c74e32e0c973138fdbeb Mon Sep 17 00:00:00 2001
From: Dudu Moshe <53430514+dudulightricks@users.noreply.github.com>
Date: Fri, 20 Sep 2024 18:22:47 +0300
Subject: [PATCH] parallel_loader: Fix TPU memory leak when calling __iter__. (#8039)

---
 torch_xla/distributed/parallel_loader.py | 68 +++++++++++++++---------
 1 file changed, 43 insertions(+), 25 deletions(-)

diff --git a/torch_xla/distributed/parallel_loader.py b/torch_xla/distributed/parallel_loader.py
index 3d98ff4a225..a0304b4523a 100644
--- a/torch_xla/distributed/parallel_loader.py
+++ b/torch_xla/distributed/parallel_loader.py
@@ -95,12 +95,14 @@ def __init__(self,
     self._done = False
     self._queues = dict()
     self._input_sharding = input_sharding
+    self._threads = []
     for device in self._devices:
       self._queues[device] = PerDeviceQueue(device, loader_prefetch_size,
                                             device_prefetch_size)
     thread = threading.Thread(target=self._loader_worker)
     thread.daemon = True
     thread.start()
+    self._threads.append(thread)
     for dqueue in self._queues.values():
       for i in range(host_to_device_transfer_threads):
         thread = threading.Thread(
@@ -111,6 +113,7 @@ def __init__(self,
             ))
         thread.daemon = True
         thread.start()
+        self._threads.append(thread)
 
   def per_device_loader(self, device):
     """Retrieves the loader iterator object for the given device.
@@ -139,6 +142,9 @@ def close(self):
       dqueue.queue.close()
       dqueue.loader_queue.close()
 
+    for thread in self._threads:
+      thread.join()
+
   @property
   def batches_per_execution(self):
     return self._batches_per_execution
@@ -147,18 +153,21 @@ def _loader_worker(self):
     queues = list(self._queues.values())
     data_iter = enumerate(self._loader)
     batch = []
-    while not self._done:
-      try:
-        _, data = next(data_iter)
-      except StopIteration:
-        break
-      batch.append(data)
-      if len(batch) == len(self._devices):
-        for queue_no, device_batch in enumerate(batch):
-          queues[queue_no].loader_queue.put(device_batch)
-        batch = []
-    for dqueue in queues:
-      dqueue.loader_queue.close_write()
+
+    try:
+      while not self._done:
+        try:
+          _, data = next(data_iter)
+        except StopIteration:
+          break
+        batch.append(data)
+        if len(batch) == len(self._devices):
+          for queue_no, device_batch in enumerate(batch):
+            queues[queue_no].loader_queue.put(device_batch)
+          batch = []
+    finally:
+      for dqueue in queues:
+        dqueue.loader_queue.close_write()
 
   def _get_batch(self, dqueue):
     batch = []
@@ -171,16 +180,21 @@ def _get_batch(self, dqueue):
 
   def _worker(self, dqueue, host_to_device_transfer_threads):
     device = torch.device(dqueue.device)
-    while True:
-      batch = self._get_batch(dqueue)
-      if not batch:
-        break
-      batch = xm.send_cpu_data_to_device(batch, device, self._input_sharding)
-      for data in batch:
-        dqueue.queue.put(data)
-      close_queue_count = next(dqueue.close_queue_count)
-      if close_queue_count == host_to_device_transfer_threads - 1:
-        dqueue.queue.close_write()
+
+    try:
+      while True:
+        batch = self._get_batch(dqueue)
+        if not batch:
+          break
+        with torch.no_grad():
+          batch = xm.send_cpu_data_to_device(batch, device,
+                                             self._input_sharding)
+        for data in batch:
+          dqueue.queue.put(data)
+    finally:
+      close_queue_count = next(dqueue.close_queue_count)
+      if close_queue_count == host_to_device_transfer_threads - 1:
+        dqueue.queue.close_write()
 
 
 class MpDeviceLoader(object):
@@ -206,11 +220,15 @@ def __init__(self, loader, device, **kwargs):
     self._loader = loader
     self._device = device
     self._parallel_loader_kwargs = kwargs
+    self._parallel_loader = None
 
   def __iter__(self):
-    parallel_loader = ParallelLoader(self._loader, [self._device],
-                                     **self._parallel_loader_kwargs)
-    return parallel_loader.per_device_loader(self._device)
+    if self._parallel_loader is not None:
+      self._parallel_loader.close()
+      self._parallel_loader = None
+    self._parallel_loader = ParallelLoader(self._loader, [self._device],
+                                           **self._parallel_loader_kwargs)
+    return self._parallel_loader.per_device_loader(self._device)
 
   def __len__(self):
     return len(self._loader)
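
Note: the leak addressed above shows up when an MpDeviceLoader is iterated more than once, e.g. once per epoch; before this change every __iter__ call created a fresh ParallelLoader whose worker threads and queued device tensors were never released. The snippet below is a minimal usage sketch of that pattern, not part of the patch; the dataset and the empty loop body are placeholder assumptions, and it requires a working torch_xla install with an available XLA device.

# Minimal sketch (not part of the patch): iterate one MpDeviceLoader
# across several epochs, the pattern whose per-__iter__ leak this fixes.
# The random dataset and the empty loop body are placeholders.
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl

device = xm.xla_device()
loader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(torch.randn(64, 8)), batch_size=8)

# Each `for ... in mp_loader` calls __iter__; with this patch the previous
# ParallelLoader is close()d and its threads joined before a new one starts.
mp_loader = pl.MpDeviceLoader(loader, device)
for epoch in range(3):
  for (data,) in mp_loader:
    pass  # training step would go here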