Adds Imagen to rosetta #278

Merged · 6 commits · Oct 5, 2023
Changes from all commits
1 change: 1 addition & 0 deletions README.md
@@ -139,6 +139,7 @@ We currently enable training and evaluation for the following models:
| [GPT-3(paxml)](./rosetta/rosetta/projects/pax) | ✔️ | | ✔️ |
| [t5(t5x)](./rosetta/rosetta/projects/t5x) | ✔️ | ✔️ | ✔️ |
| [ViT](./rosetta/rosetta/projects/vit) | ✔️ | ✔️ | ✔️ |
| [Imagen](./rosetta/rosetta/projects/imagen) | ✔️ | | ✔️ |

We will update this table as new models become available, so stay tuned.

190 changes: 190 additions & 0 deletions rosetta/rosetta/data/multiloader.py
@@ -0,0 +1,190 @@
#
# Copyright (c) 2017-2023 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""An alternative to DataLoader using ZMQ.
This implements MultiLoader, an alternative to DataLoader when torch
is not available. Subprocesses communicate with the loader through
ZMQ, provided for high performance multithreaded queueing.
"""

import multiprocessing as mp
import pickle
import uuid
import weakref
import threading
import queue
import logging

import zmq
import os
from multiprocessing import Lock

the_protocol = pickle.HIGHEST_PROTOCOL

all_pids = weakref.WeakSet()


class EOF:
"""A class that indicates that a data stream is finished."""

def __init__(self, **kw):
"""Initialize the class with the kw as instance variables."""
self.__dict__.update(kw)

class BufferState():
    """Counts in-flight samples across processes via a bounded queue.

    `increment` blocks once `max_size` samples are buffered, which throttles
    the reader subprocesses until the consumer catches up.
    """

    def __init__(self, max_size):
        self.q = mp.Queue(maxsize=max_size)

def increment(self):
self.q.put(0)

def decrement(self):
self.q.get()

def get_len(self):
return self.q.qsize()

def reset(self):
while not self.q.empty():
self.q.get_nowait()

def async_depickler(out_queue, in_zmq_pipe, stop_signal):
    """Receive pickled samples from the ZMQ socket, unpickle them, and put them on out_queue."""
    while True:
if stop_signal:
return
data = in_zmq_pipe.recv()
data = pickle.loads(data)
out_queue.put(data)

def reader(dataset, sockname, index, num_workers, buf_state, signal_state):
"""Read samples from the dataset and send them over the socket.
:param dataset: source dataset
:param sockname: name for the socket to send data to
:param index: index for this reader, using to indicate EOF
"""
global the_protocol
os.environ["WORKER"] = str(index)
os.environ["NUM_WORKERS"] = str(num_workers)
ctx = zmq.Context.instance()
sock = ctx.socket(zmq.PUSH)
# sock.set_hwm(prefetch_buffer_len)
sock.connect(sockname)
for sample in dataset:
buf_state.increment()
data = pickle.dumps(sample, protocol=the_protocol)
sock.send(data)
if signal_state.value != 0:
break
sock.send(pickle.dumps(EOF(index=index)))
sock.close()


class MultiLoader:
"""Alternative to PyTorch DataLoader based on ZMQ."""

def __init__(
self, dataset, workers=4, verbose=True, nokill=True, prefix="/tmp/_multi-", prefetch_buf_max=128
):
"""Create a MultiLoader for a dataset.
This creates ZMQ sockets, spawns `workers` subprocesses, and has them send data
to the socket.
:param dataset: source dataset
:param workers: number of workers
:param verbose: report progress verbosely
:param nokill: don't kill old processes when restarting (allows multiple loaders)
:param prefix: directory prefix for the ZMQ socket
"""
self.dataset = dataset
self.workers = workers
self.orig_workers = workers
self.max_workers = workers * 2
self.retune_period = 100
self.verbose = verbose
self.pids = []
self.socket = None
self.ctx = zmq.Context.instance()
self.nokill = nokill
self.prefix = prefix
# self.prefetch_buf_per_worker = prefetch_buf_per_worker
self.prefetch_buf_max = prefetch_buf_max
self.buf_state = BufferState(prefetch_buf_max)
self.signal_vals = []
self.buffer_low_mark=int(prefetch_buf_max * .15)
assert self.buffer_low_mark < self.prefetch_buf_max
self.depickled_queue = queue.Queue()
self.async_depickler = None
self.async_depickler_stop_signal = False
self.has_started = False

    def kill(self):
        """Stop the depickler thread, kill and join all reader subprocesses, and close the socket."""
self.async_depickler_stop_signal = True
self.async_depickler = None

for pid in self.pids:
if pid is None:
continue
print("killing", pid)
pid.kill()

for pid in self.pids:
# pid.join(1.0)
if pid is None:
continue
print("joining", pid)
pid.join()

self.pids = []
if self.socket is not None:
print("closing", self.socket)
self.socket.close(linger=0)
print("Closed")
self.socket = None
self.buf_state.reset()

def __iter__(self):
"""Return an iterator over this dataloader."""
if self.has_started:
logging.warning("RESTARTING LOADER")
if not self.nokill:
self.kill()
if not self.has_started or not self.nokill:
self.sockname = "ipc://" + self.prefix + str(uuid.uuid4())
self.socket = self.ctx.socket(zmq.PULL)
self.socket.set(zmq.LINGER, 0)
self.socket.bind(self.sockname)
if self.verbose:
print("#", self.sockname)
self.pids = [None] * self.max_workers
self.signal_vals = [None] * self.max_workers
for index in range(self.workers):
signal = mp.Value('i', 0)
args = (self.dataset, self.sockname, index, self.workers, self.buf_state, signal)
self.pids[index] = mp.Process(target=reader, args=args)
self.signal_vals[index] = signal
all_pids.update(self.pids[:self.workers])
for pid in self.pids:
if pid is not None:
pid.start()

# Async depickler setup
self.async_depickler_stop_signal = False
self.async_depickler = threading.Thread(target=async_depickler, args=(self.depickled_queue, self.socket, self.async_depickler_stop_signal), daemon=True)
self.async_depickler.start()

self.has_started = True
count = 0
while self.pids.count(None) < len(self.pids):
sample = self.depickled_queue.get(block=True)
if isinstance(sample, EOF):
if self.verbose:
print("# subprocess finished", sample.index)
self.pids[sample.index].join(1.0)
self.pids[sample.index] = None
else:
self.buf_state.decrement()
yield sample
count += 1
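A hedged usage sketch for the loader above. The import path is inferred from the file location, and `ShardedRange` is a hypothetical stand-in for a real dataset; the only requirements are that the dataset argument is picklable (or fork-safe) and iterable inside each reader subprocess. Iteration yields samples until every reader has sent its EOF marker.

```python
import os

from rosetta.data.multiloader import MultiLoader  # import path inferred from this PR

class ShardedRange:
    """Toy dataset: each reader subprocess yields only its own shard,
    using the WORKER / NUM_WORKERS env vars set by `reader` above."""

    def __init__(self, n):
        self.n = n

    def __iter__(self):
        worker = int(os.environ.get("WORKER", "0"))
        num_workers = int(os.environ.get("NUM_WORKERS", "1"))
        for i in range(worker, self.n, num_workers):
            yield {"index": i}

if __name__ == "__main__":
    loader = MultiLoader(ShardedRange(1_000), workers=4, verbose=False)
    for sample in loader:
        pass  # samples arrive unordered across the four reader subprocesses
```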
2 changes: 2 additions & 0 deletions rosetta/rosetta/projects/diffusion/__init__.py
@@ -0,0 +1,2 @@
import tensorflow as tf
tf.config.set_visible_devices([], 'GPU')
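These two lines hide every GPU from TensorFlow so that input pipelines in this package do not reserve device memory needed by JAX. A quick sanity check, assuming both TensorFlow and a GPU-enabled JAX are installed (the package import path is inferred from the file location):

```python
import rosetta.projects.diffusion  # runs the __init__ above; path inferred from this PR
import tensorflow as tf
import jax

print(tf.config.get_visible_devices('GPU'))  # expected: [] after the __init__ runs
print(jax.devices())                         # JAX still sees the GPU devices
```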
74 changes: 74 additions & 0 deletions rosetta/rosetta/projects/diffusion/augmentations.py
@@ -0,0 +1,74 @@
# Copyright (c) 2022-2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Sample and conditioning augmentations for diffusion and multimodal
model training
"""

from typing import Callable, Tuple, Dict, Optional, Union
import typing_extensions

import jax
import jax.numpy as jnp
import jax.lax as lax

Augmentable = Union[jnp.ndarray, Dict[str, jnp.ndarray]]

class AugmentationCallable(typing_extensions.Protocol):
""" Call signature for a sample augmentation function.
Returns the augmented sample and fresh rng """
def __call__(self,
to_aug: Augmentable,
rng: jax.random.KeyArray
) -> Tuple[Augmentable, jax.random.KeyArray]:
...

def text_conditioning_dropout(to_aug: Augmentable,
rng: jax.random.KeyArray,
dropout_rate: float = 0.1,
drop_key: Optional[str] = None,
null_value = None,
) -> Tuple[Augmentable, jax.random.KeyArray]:
"""
Can take either a dictionary, where it will dropout on the 'text_mask' key by default,
or drop_key if supplied. If given just an array, it will dropout assuming shape = [b, ...]
(setting to 0, or null_value if supplied)
"""
if drop_key is None:
drop_key = 'text_mask'
if null_value is None:
null_value = 0
cond = to_aug
if isinstance(to_aug, dict):
cond = to_aug[drop_key]

my_rng, rng = jax.random.split(rng)
keep_prob = 1 - dropout_rate

mask_shape = list(cond.shape)
for i in range(1, len(mask_shape)):
mask_shape[i] = 1

mask = jax.random.bernoulli(my_rng, p=keep_prob, shape=mask_shape)
mask = jnp.broadcast_to(mask, cond.shape)

cond = lax.select(mask, cond, null_value * jnp.ones_like(cond))

if isinstance(to_aug, dict):
to_aug[drop_key] = cond
else:
to_aug = cond

return to_aug, rng
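A minimal sketch of applying this augmentation to a training batch, assuming a batch dictionary that carries a `text_mask` entry; the shapes, the `text` key, and the import path are illustrative or inferred from this PR. Dropping the conditioning for a fraction of samples supplies the unconditional examples commonly used for classifier-free guidance.

```python
import jax
import jax.numpy as jnp

from rosetta.projects.diffusion.augmentations import text_conditioning_dropout  # path inferred

rng = jax.random.PRNGKey(0)
batch = {
    "text": jnp.ones((8, 77, 512)),   # illustrative text embeddings
    "text_mask": jnp.ones((8, 77)),   # per-token mask that gets zeroed out
}

# Zeroes the entire text_mask for roughly 10% of the samples in the batch and
# returns a fresh rng for the next augmentation or training step.
batch, rng = text_conditioning_dropout(batch, rng, dropout_rate=0.1)
```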
Empty file.
@@ -0,0 +1,91 @@
# Copyright (c) 2022-2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
import jax.experimental.maps
import jax.numpy as jnp
import numpy as np

import torch.utils.data
from torchvision import datasets, transforms

from torch.utils.data import Dataset

from rosetta.projects.diffusion.common.generative_metrics.fid_metric import fid
from rosetta.projects.diffusion.common.generative_metrics.inception_v3 import load_pretrained_inception_v3
import sys
import os
import glob
from PIL import Image

class MyDataset(Dataset):
def __init__(self, root):
self.image_paths = self.get_image_paths(root)
        self.transform = transforms.ToTensor()

def get_image_paths(self,directory):
image_paths = []

# Recursive search for all image files
for root, dirs, files in os.walk(directory):
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')):
image_paths.append(os.path.join(root, file))

return image_paths

def __getitem__(self, index):
image_path = self.image_paths[index]
x = Image.open(image_path)
if self.transform is not None:
x = self.transform(x)
        if x.shape != (3, 256, 256):
            if x.shape[0] == 1:   # grayscale: replicate to 3 channels
                x = torch.cat([x, x, x], dim=0)
            if x.shape[0] == 4:   # RGBA: drop the alpha channel
                x = x[:3]
return x

def __len__(self):
return len(self.image_paths)

def collate_fn(batch):
return torch.stack(batch)

TRAIN_ROOT = sys.argv[1]
TEST_ROOT = sys.argv[2]
NUM_SAMPLES = int(sys.argv[3])

def load_cifar10(batch_size=128):
    """Build train/test DataLoaders over the image folders passed on the command line."""

train_dataset = MyDataset(TRAIN_ROOT)
test_dataset = MyDataset(TEST_ROOT)
train_dataset = torch.utils.data.Subset(train_dataset, np.arange(NUM_SAMPLES))
test_dataset = torch.utils.data.Subset(test_dataset, np.arange(NUM_SAMPLES))
print(len(train_dataset))
print(len(test_dataset))

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn,
)
test_loader = torch.utils.data.DataLoader(
test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn,
)
return train_loader, test_loader


train, test = load_cifar10(1024)

fid(train, test, 'inception_weights', 1)
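The module computes FID end to end when executed directly, reading its inputs from `sys.argv`. A hedged launch sketch; the filename is not visible in this diff, so `fid_eval.py` is a placeholder, and the directory paths and sample count are illustrative:

```python
import subprocess

subprocess.run([
    "python", "fid_eval.py",      # placeholder name for the module above
    "/data/real_images",          # TRAIN_ROOT: directory of reference images
    "/data/generated_images",     # TEST_ROOT: directory of generated samples
    "10000",                      # NUM_SAMPLES: images drawn from each set
], check=True)
```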