diff --git a/README.md b/README.md index cfde95063..ca41b0cf8 100644 --- a/README.md +++ b/README.md @@ -139,6 +139,7 @@ We currently enable training and evaluation for the following models: | [GPT-3(paxml)](./rosetta/rosetta/projects/pax) | ✔️ | | ✔️ | | [t5(t5x)](./rosetta/rosetta/projects/t5x) | ✔️ | ✔️ | ✔️ | | [ViT](./rosetta/rosetta/projects/vit) | ✔️ | ✔️ | ✔️ | +| [Imagen](./rosetta/rosetta/projects/imagen) | ✔️ | | ✔️ | We will update this table as new models become available, so stay tuned. diff --git a/rosetta/rosetta/data/multiloader.py b/rosetta/rosetta/data/multiloader.py new file mode 100644 index 000000000..8d6a02526 --- /dev/null +++ b/rosetta/rosetta/data/multiloader.py @@ -0,0 +1,190 @@ +# +# Copyright (c) 2017-2023 NVIDIA CORPORATION. All rights reserved. +# This file is part of the WebDataset library. +# See the LICENSE file for licensing terms (BSD-style). +# + +"""An alternative to DataLoader using ZMQ. +This implements MultiLoader, an alternative to DataLoader when torch +is not available. Subprocesses communicate with the loader through +ZMQ, provided for high performance multithreaded queueing. +""" + +import multiprocessing as mp +import pickle +import uuid +import weakref +import threading +import queue +import logging + +import zmq +import os +from multiprocessing import Lock + +the_protocol = pickle.HIGHEST_PROTOCOL + +all_pids = weakref.WeakSet() + + +class EOF: + """A class that indicates that a data stream is finished.""" + + def __init__(self, **kw): + """Initialize the class with the kw as instance variables.""" + self.__dict__.update(kw) + +class BufferState(): + def __init__(self, max_size): + self.q = mp.Queue(maxsize=max_size) + + def increment(self): + self.q.put(0) + + def decrement(self): + self.q.get() + + def get_len(self): + return self.q.qsize() + + def reset(self): + while not self.q.empty(): + self.q.get_nowait() + +def async_depickler(out_queue, in_zmq_pipe, stop_signal): + while True: + if stop_signal: + return + data = in_zmq_pipe.recv() + data = pickle.loads(data) + out_queue.put(data) + +def reader(dataset, sockname, index, num_workers, buf_state, signal_state): + """Read samples from the dataset and send them over the socket. + :param dataset: source dataset + :param sockname: name for the socket to send data to + :param index: index for this reader, using to indicate EOF + """ + global the_protocol + os.environ["WORKER"] = str(index) + os.environ["NUM_WORKERS"] = str(num_workers) + ctx = zmq.Context.instance() + sock = ctx.socket(zmq.PUSH) + # sock.set_hwm(prefetch_buffer_len) + sock.connect(sockname) + for sample in dataset: + buf_state.increment() + data = pickle.dumps(sample, protocol=the_protocol) + sock.send(data) + if signal_state.value != 0: + break + sock.send(pickle.dumps(EOF(index=index))) + sock.close() + + +class MultiLoader: + """Alternative to PyTorch DataLoader based on ZMQ.""" + + def __init__( + self, dataset, workers=4, verbose=True, nokill=True, prefix="/tmp/_multi-", prefetch_buf_max=128 + ): + """Create a MultiLoader for a dataset. + This creates ZMQ sockets, spawns `workers` subprocesses, and has them send data + to the socket. + :param dataset: source dataset + :param workers: number of workers + :param verbose: report progress verbosely + :param nokill: don't kill old processes when restarting (allows multiple loaders) + :param prefix: directory prefix for the ZMQ socket + """ + self.dataset = dataset + self.workers = workers + self.orig_workers = workers + self.max_workers = workers * 2 + self.retune_period = 100 + self.verbose = verbose + self.pids = [] + self.socket = None + self.ctx = zmq.Context.instance() + self.nokill = nokill + self.prefix = prefix + # self.prefetch_buf_per_worker = prefetch_buf_per_worker + self.prefetch_buf_max = prefetch_buf_max + self.buf_state = BufferState(prefetch_buf_max) + self.signal_vals = [] + self.buffer_low_mark=int(prefetch_buf_max * .15) + assert self.buffer_low_mark < self.prefetch_buf_max + self.depickled_queue = queue.Queue() + self.async_depickler = None + self.async_depickler_stop_signal = False + self.has_started = False + + def kill(self): + """kill.""" + self.async_depickler_stop_signal = True + self.async_depickler = None + + for pid in self.pids: + if pid is None: + continue + print("killing", pid) + pid.kill() + + for pid in self.pids: + # pid.join(1.0) + if pid is None: + continue + print("joining", pid) + pid.join() + + self.pids = [] + if self.socket is not None: + print("closing", self.socket) + self.socket.close(linger=0) + print("Closed") + self.socket = None + self.buf_state.reset() + + def __iter__(self): + """Return an iterator over this dataloader.""" + if self.has_started: + logging.warning("RESTARTING LOADER") + if not self.nokill: + self.kill() + if not self.has_started or not self.nokill: + self.sockname = "ipc://" + self.prefix + str(uuid.uuid4()) + self.socket = self.ctx.socket(zmq.PULL) + self.socket.set(zmq.LINGER, 0) + self.socket.bind(self.sockname) + if self.verbose: + print("#", self.sockname) + self.pids = [None] * self.max_workers + self.signal_vals = [None] * self.max_workers + for index in range(self.workers): + signal = mp.Value('i', 0) + args = (self.dataset, self.sockname, index, self.workers, self.buf_state, signal) + self.pids[index] = mp.Process(target=reader, args=args) + self.signal_vals[index] = signal + all_pids.update(self.pids[:self.workers]) + for pid in self.pids: + if pid is not None: + pid.start() + + # Async depickler setup + self.async_depickler_stop_signal = False + self.async_depickler = threading.Thread(target=async_depickler, args=(self.depickled_queue, self.socket, self.async_depickler_stop_signal), daemon=True) + self.async_depickler.start() + + self.has_started = True + count = 0 + while self.pids.count(None) < len(self.pids): + sample = self.depickled_queue.get(block=True) + if isinstance(sample, EOF): + if self.verbose: + print("# subprocess finished", sample.index) + self.pids[sample.index].join(1.0) + self.pids[sample.index] = None + else: + self.buf_state.decrement() + yield sample + count += 1 diff --git a/rosetta/rosetta/projects/diffusion/__init__.py b/rosetta/rosetta/projects/diffusion/__init__.py new file mode 100644 index 000000000..2f3a18785 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/__init__.py @@ -0,0 +1,2 @@ +import tensorflow as tf +tf.config.set_visible_devices([], 'GPU') \ No newline at end of file diff --git a/rosetta/rosetta/projects/diffusion/augmentations.py b/rosetta/rosetta/projects/diffusion/augmentations.py new file mode 100644 index 000000000..2d671d58e --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/augmentations.py @@ -0,0 +1,74 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Sample and conditioning augmentations for diffusion and multimodal +model training +""" + +from typing import Callable, Tuple, Dict, Optional, Union +import typing_extensions + +import jax +import jax.numpy as jnp +import jax.lax as lax + +Augmentable = Union[jnp.ndarray, Dict[str, jnp.ndarray]] + +class AugmentationCallable(typing_extensions.Protocol): + """ Call signature for a sample augmentation function. + Returns the augmented sample and fresh rng """ + def __call__(self, + to_aug: Augmentable, + rng: jax.random.KeyArray + ) -> Tuple[Augmentable, jax.random.KeyArray]: + ... + +def text_conditioning_dropout(to_aug: Augmentable, + rng: jax.random.KeyArray, + dropout_rate: float = 0.1, + drop_key: Optional[str] = None, + null_value = None, + ) -> Tuple[Augmentable, jax.random.KeyArray]: + """ + Can take either a dictionary, where it will dropout on the 'text_mask' key by default, + or drop_key if supplied. If given just an array, it will dropout assuming shape = [b, ...] + (setting to 0, or null_value if supplied) + """ + if drop_key is None: + drop_key = 'text_mask' + if null_value is None: + null_value = 0 + cond = to_aug + if isinstance(to_aug, dict): + cond = to_aug[drop_key] + + my_rng, rng = jax.random.split(rng) + keep_prob = 1 - dropout_rate + + mask_shape = list(cond.shape) + for i in range(1, len(mask_shape)): + mask_shape[i] = 1 + + mask = jax.random.bernoulli(my_rng, p=keep_prob, shape=mask_shape) + mask = jnp.broadcast_to(mask, cond.shape) + + cond = lax.select(mask, cond, null_value * jnp.ones_like(cond)) + + if isinstance(to_aug, dict): + to_aug[drop_key] = cond + else: + to_aug = cond + + return to_aug, rng diff --git a/rosetta/rosetta/projects/diffusion/common/__init__.py b/rosetta/rosetta/projects/diffusion/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rosetta/rosetta/projects/diffusion/common/generative_metrics/evaluate_fid.py b/rosetta/rosetta/projects/diffusion/common/generative_metrics/evaluate_fid.py new file mode 100644 index 000000000..6910fedd1 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/common/generative_metrics/evaluate_fid.py @@ -0,0 +1,91 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import jax +import jax.experimental.maps +import jax.numpy as jnp +import numpy as np + +import torch.utils.data +from torchvision import datasets, transforms + +from torch.utils.data import Dataset + +from rosetta.projects.diffusion.common.generative_metrics.fid_metric import fid +from rosetta.projects.diffusion.common.generative_metrics.inception_v3 import load_pretrained_inception_v3 +import sys +import os +import glob +from PIL import Image + +class MyDataset(Dataset): + def __init__(self, root): + self.image_paths = self.get_image_paths(root) + self.transform = transforms.ToTensor()#lambda x: np.array(x) + + def get_image_paths(self,directory): + image_paths = [] + + # Recursive search for all image files + for root, dirs, files in os.walk(directory): + for file in files: + if file.lower().endswith(('.png', '.jpg', '.jpeg', '.gif')): + image_paths.append(os.path.join(root, file)) + + return image_paths + + def __getitem__(self, index): + image_path = self.image_paths[index] + x = Image.open(image_path) + if self.transform is not None: + x = self.transform(x) + if x.shape != (3, 256, 256): + if x.shape[0] == 1: + x = torch.cat([x,x,x], dim=0) + if x.shape[0] == 4: + x = x[:3] + return x + + def __len__(self): + return len(self.image_paths) + +def collate_fn(batch): + return torch.stack(batch) + +TRAIN_ROOT = sys.argv[1] +TEST_ROOT = sys.argv[2] +NUM_SAMPLES = int(sys.argv[3]) + +def load_cifar10(batch_size=128): + """Load CIFAR10 dataset.""" + + train_dataset = MyDataset(TRAIN_ROOT) + test_dataset = MyDataset(TEST_ROOT) + train_dataset = torch.utils.data.Subset(train_dataset, np.arange(NUM_SAMPLES)) + test_dataset = torch.utils.data.Subset(test_dataset, np.arange(NUM_SAMPLES)) + print(len(train_dataset)) + print(len(test_dataset)) + + train_loader = torch.utils.data.DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, num_workers=4, collate_fn=collate_fn, + ) + test_loader = torch.utils.data.DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn, + ) + return train_loader, test_loader + + +train, test = load_cifar10(1024) + +fid(train, test, 'inception_weights', 1) diff --git a/rosetta/rosetta/projects/diffusion/common/generative_metrics/fid_metric.py b/rosetta/rosetta/projects/diffusion/common/generative_metrics/fid_metric.py new file mode 100644 index 000000000..b56c4744a --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/common/generative_metrics/fid_metric.py @@ -0,0 +1,155 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This implementation is basing on: https://github.com/mseitzer/pytorch-fid +Which is not the original implementation, but is widely used as the best PyTorch port of the original TF version. +""" +import jax +from rosetta.projects.diffusion.common.generative_metrics.inception_v3 import load_pretrained_inception_v3 +import jax.numpy as jnp +from jax.scipy import linalg +from tqdm import tqdm +import logging +import time + +import numpy as np + + +def get_activations(params, model, batches): + """Calculates the activations of the pool_3 layer for all images. + + Returns: + -- A jax array of dimension (num images, dims) that contains the activations of the batches. + """ + all_outs = [] + for batch in tqdm(batches, desc="Calculating activations"): + if isinstance(batch, tuple) or isinstance(batch, list): + batch = batch[0] + + batch = jnp.array(batch) + + if batch.shape[1] == 3: + batch = batch.transpose(0, 2, 3, 1) + + logging.info("batched") + outs = model(params, batch) + all_outs.append(outs) + logging.info("Completed 1 batch") + time.sleep(1.0) + + all_outs = jnp.concatenate(all_outs, axis=0) + return all_outs + + +def calculate_activation_statistics(params, model, batches, num_pads=0): + """Calculation of the statistics used by the FID. + + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of the inception model. + """ + act = get_activations(params, model, batches) + act = act[:act.shape[0] - num_pads] + mu = jnp.mean(act, axis=0) + sigma = jnp.cov(act, rowvar=False) + return mu, sigma + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = jnp.atleast_1d(mu1) + mu2 = jnp.atleast_1d(mu2) + + sigma1 = jnp.atleast_2d(sigma1) + sigma2 = jnp.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean = linalg.sqrtm(sigma1.dot(sigma2)) + if not jnp.isfinite(covmean).all(): + msg = 'fid calculation produces singular product; adding %s to diagonal of cov estimates' % eps + print(msg) + offset = jnp.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if jnp.iscomplexobj(covmean): + if not jnp.allclose(jnp.diagonal(covmean).imag, 0, atol=1e-3): + m = jnp.max(jnp.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = jnp.trace(covmean) + + return diff.dot(diff) + jnp.trace(sigma1) + jnp.trace(sigma2) - 2 * tr_covmean + + +def fid(samples1, samples2, inception_weight_path, inception_batch_size=32): + """Load pretrained Inception and calculate the FID of two set of batches""" + params, inception = load_pretrained_inception_v3(jax_weight_restore_path=inception_weight_path) + + def pad_and_batch_array(array, divisor=inception_batch_size): + remainder = divisor - (array.shape[0] % divisor) + pad_instances = jnp.repeat(array[:1], remainder, axis=0) + num_batches = (array.shape[0] + remainder) // divisor + array = jnp.concatenate((array, pad_instances), axis=0) + return array.reshape((num_batches, divisor, *(array.shape[1:]))), remainder + + def run(params, batch): + return inception.apply( + params, + batch, + resize_input=True, + normalize_input=True, + return_featuremap=True, + ) + + jitted_run = jax.jit(run) + logging.info("Jitted run") + + m1, s1 = calculate_activation_statistics(params, jitted_run, samples1) + m2, s2 = calculate_activation_statistics(params, jitted_run, samples2) + cpu_device = jax.devices("cpu")[0] + m1 = jax.device_put(m1, cpu_device) + s1 = jax.device_put(s1, cpu_device) + m2 = jax.device_put(m2, cpu_device) + s2 = jax.device_put(s2, cpu_device) + + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + jax.debug.print(f'fid_value {fid_value}') + return fid_value \ No newline at end of file diff --git a/rosetta/rosetta/projects/diffusion/common/generative_metrics/inception_v3.py b/rosetta/rosetta/projects/diffusion/common/generative_metrics/inception_v3.py new file mode 100644 index 000000000..df7507ee0 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/common/generative_metrics/inception_v3.py @@ -0,0 +1,318 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Implementation of Inception v3 network in Flax.""" + +import os +import pickle +import io + +import flax.linen as nn +import jax +import jax.numpy as jnp +from jax.experimental import multihost_utils +import torch +from flax.core import FrozenDict +from flax.traverse_util import flatten_dict, unflatten_dict +import requests + + +# DOWNLOADED FROM: +PYT_LINK = "https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth" + +# LOCAL_PATH = "FID/pt_inception-2015-12-05-6726825d.pth" + + +class ConvBlock(nn.Module): + """Convolutional block with batch normalization and ReLU activation.""" + + features: int + kernel_size: int + strides: int = 1 + padding: str = "SAME" + + @nn.compact + def __call__(self, x): + x = nn.Conv(self.features, self.kernel_size, self.strides, self.padding, use_bias=False)(x) + x = nn.BatchNorm(use_running_average=True, momentum=0.9, epsilon=0.001)(x) + return nn.relu(x) + + +class FIDInceptionA(nn.Module): + """InceptionA module.""" + + pool_features: int + + # 255,904 parameters for 288 input channels and 32 pool features. + + @nn.compact + def __call__(self, x): + branch1x1 = ConvBlock(features=64, kernel_size=(1, 1))(x) + + branch3x3dbl = ConvBlock(features=64, kernel_size=(1, 1))(x) + branch3x3dbl = ConvBlock(features=96, kernel_size=(3, 3))(branch3x3dbl) + branch3x3dbl = ConvBlock(features=96, kernel_size=(3, 3))(branch3x3dbl) + + branch5x5 = ConvBlock(features=48, kernel_size=(1, 1))(x) + branch5x5 = ConvBlock(features=64, kernel_size=(5, 5))(branch5x5) + + branch_pool = nn.avg_pool( + x, window_shape=(3, 3), strides=(1, 1), padding="SAME", count_include_pad=False + ) + branch_pool = ConvBlock(features=self.pool_features, kernel_size=(1, 1))(branch_pool) + + return jax.lax.concatenate([branch1x1, branch5x5, branch3x3dbl, branch_pool], dimension=3) + + +class FIDInceptionB(nn.Module): + """InceptionB module.""" + + # 1,153,280 parameters for 288 input channels. + + @nn.compact + def __call__(self, x): + branch3x3 = ConvBlock(features=384, kernel_size=(3, 3), strides=2, padding="VALID")(x) + + branch3x3dbl = ConvBlock(features=64, kernel_size=(1, 1))(x) + branch3x3dbl = ConvBlock(features=96, kernel_size=(3, 3))(branch3x3dbl) + + # No padding here in PyTorch version + branch3x3dbl = ConvBlock(features=96, kernel_size=(3, 3), strides=2, padding="VALID")(branch3x3dbl) + + branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID") + + return jax.lax.concatenate([branch3x3, branch3x3dbl, branch_pool], dimension=3) + + +class FIDInceptionC(nn.Module): + """InceptionC module.""" + + channels7x7: int + + # 1,297,408 parameters for 768 input channels and 128 channels7x7. + + @nn.compact + def __call__(self, x): + branch1x1 = ConvBlock(features=192, kernel_size=(1, 1))(x) + + branch7x7 = ConvBlock(features=self.channels7x7, kernel_size=(1, 1))(x) + branch7x7 = ConvBlock(features=self.channels7x7, kernel_size=(1, 7))(branch7x7) + branch7x7 = ConvBlock(features=192, kernel_size=(7, 1))(branch7x7) + + branch7x7dbl = ConvBlock(features=self.channels7x7, kernel_size=(1, 1))(x) + branch7x7dbl = ConvBlock(features=self.channels7x7, kernel_size=(7, 1))(branch7x7dbl) + branch7x7dbl = ConvBlock(features=self.channels7x7, kernel_size=(1, 7))(branch7x7dbl) + branch7x7dbl = ConvBlock(features=self.channels7x7, kernel_size=(7, 1))(branch7x7dbl) + branch7x7dbl = ConvBlock(features=192, kernel_size=(1, 7))(branch7x7dbl) + + branch_pool = nn.avg_pool( + x, window_shape=(3, 3), strides=(1, 1), padding="SAME", count_include_pad=False + ) + branch_pool = ConvBlock(features=192, kernel_size=(1, 1))(branch_pool) + + return jax.lax.concatenate([branch1x1, branch7x7, branch7x7dbl, branch_pool], dimension=3) + + +class FIDInceptionD(nn.Module): + # 1,698,304 parameters for 768 input channels. + + @nn.compact + def __call__(self, x): + branch3x3 = ConvBlock(features=192, kernel_size=(1, 1))(x) + branch3x3 = ConvBlock(features=320, kernel_size=(3, 3), strides=2, padding="VALID")(branch3x3) + + branch7x7x3 = ConvBlock(features=192, kernel_size=(1, 1))(x) + branch7x7x3 = ConvBlock(features=192, kernel_size=(1, 7))(branch7x7x3) + branch7x7x3 = ConvBlock(features=192, kernel_size=(7, 1))(branch7x7x3) + branch7x7x3 = ConvBlock(features=192, kernel_size=(3, 3), strides=2, padding="VALID")(branch7x7x3) + + branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID") + + return jax.lax.concatenate([branch3x3, branch7x7x3, branch_pool], dimension=3) + + +class FIDInceptionE_1(nn.Module): + # 5,044,608 parameters for 1024 channels + + @nn.compact + def __call__(self, x): + branch1x1 = ConvBlock(features=320, kernel_size=(1, 1))(x) + + branch3x3 = ConvBlock(features=384, kernel_size=(1, 1))(x) + branch3x3_1 = ConvBlock(features=384, kernel_size=(1, 3))(branch3x3) + branch3x3_2 = ConvBlock(features=384, kernel_size=(3, 1))(branch3x3) + branch3x3 = jax.lax.concatenate([branch3x3_1, branch3x3_2], dimension=3) + + branch3x3dbl = ConvBlock(features=448, kernel_size=(1, 1))(x) + branch3x3dbl = ConvBlock(features=384, kernel_size=(3, 3))(branch3x3dbl) + branch3x3dbl_1 = ConvBlock(features=384, kernel_size=(1, 3))(branch3x3dbl) + branch3x3dbl_2 = ConvBlock(features=384, kernel_size=(3, 1))(branch3x3dbl) + branch3x3dbl = jax.lax.concatenate([branch3x3dbl_1, branch3x3dbl_2], dimension=3) + + branch_pool = nn.avg_pool( + x, window_shape=(3, 3), strides=(1, 1), padding="SAME", count_include_pad=False + ) + branch_pool = ConvBlock(features=192, kernel_size=(1, 1))(branch_pool) + + return jax.lax.concatenate([branch1x1, branch3x3, branch3x3dbl, branch_pool], dimension=3) + + +class FIDInceptionE_2(nn.Module): + # 6,076,800 parameters for 2048 channels + + @nn.compact + def __call__(self, x): + branch1x1 = ConvBlock(features=320, kernel_size=(1, 1))(x) + + branch3x3 = ConvBlock(features=384, kernel_size=(1, 1))(x) + branch3x3_1 = ConvBlock(features=384, kernel_size=(1, 3))(branch3x3) + branch3x3_2 = ConvBlock(features=384, kernel_size=(3, 1))(branch3x3) + branch3x3 = jax.lax.concatenate([branch3x3_1, branch3x3_2], dimension=3) + + branch3x3dbl = ConvBlock(features=448, kernel_size=(1, 1))(x) + branch3x3dbl = ConvBlock(features=384, kernel_size=(3, 3))(branch3x3dbl) + branch3x3dbl_1 = ConvBlock(features=384, kernel_size=(1, 3))(branch3x3dbl) + branch3x3dbl_2 = ConvBlock(features=384, kernel_size=(3, 1))(branch3x3dbl) + branch3x3dbl = jax.lax.concatenate([branch3x3dbl_1, branch3x3dbl_2], dimension=3) + + branch_pool = nn.max_pool(x, window_shape=(3, 3), strides=(1, 1), padding="SAME") + branch_pool = ConvBlock(features=192, kernel_size=(1, 1))(branch_pool) + + return jax.lax.concatenate([branch1x1, branch3x3, branch3x3dbl, branch_pool], dimension=3) + + +class FIDInceptionV3(nn.Module): + dropout: float = 0.5 + num_classes: int = 1008 + + @nn.compact + def __call__(self, x, resize_input=False, normalize_input=False, return_featuremap=True): + if resize_input: + x = jax.image.resize(x, shape=(x.shape[0], 299, 299, x.shape[3]), method="bilinear") + + if normalize_input: + x = 2 * x - 1 + + # N x 3 x 299 x 299 + x = ConvBlock(features=32, kernel_size=(3, 3), strides=2, padding="VALID")(x) # N x 32 x 149 x 149 + x = ConvBlock(features=32, kernel_size=(3, 3), padding="VALID")(x) # N x 32 x 147 x 147 + x = ConvBlock(features=64, kernel_size=(3, 3))(x) # N x 64 x 147 x 147 + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID") # N x 64 x 73 x 73 + + x = ConvBlock(features=80, kernel_size=(1, 1))(x) # N x 80 x 73 x 73 + x = ConvBlock(features=192, kernel_size=(3, 3), padding="VALID")(x) # N x 192 x 71 x 71 + x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="VALID") # N x 192 x 35 x 35 + + x = FIDInceptionA(32)(x) # N x 256 x 35 x 35 + x = FIDInceptionA(64)(x) # N x 288 x 35 x 35 + x = FIDInceptionA(64)(x) # N x 288 x 35 x 35 + + x = FIDInceptionB()(x) # N x 768 x 17 x 17 + + x = FIDInceptionC(128)(x) # N x 768 x 17 x 17 + x = FIDInceptionC(160)(x) # N x 768 x 17 x 17 + x = FIDInceptionC(160)(x) # N x 768 x 17 x 17 + x = FIDInceptionC(192)(x) # N x 768 x 17 x 17 + + x = FIDInceptionD()(x) # N x 1280 x 8 x 8 + x = FIDInceptionE_1()(x) # N x 2048 x 8 x 8 + x = FIDInceptionE_2()(x) # N x 2048 x 8 x 8 + + x = jnp.mean(x, axis=(1, 2)) # Global pooling: N x 2048 + # THIS x SHOULD BE RETURNED FOR THE FID CALCULATIONS + if return_featuremap: + return x + + x = nn.Dense(features=self.num_classes)(x) # N x 1008 + return x + + +def convert_array(pyt_w): + """Convert array from PyTorch to Jax.""" + arr = jnp.array(pyt_w) + if len(pyt_w.shape) == 4: # Convolution + return jnp.transpose(arr, (2, 3, 1, 0)) + elif len(pyt_w.shape) == 2: # Dense + return jnp.transpose(arr, (1, 0)) + return arr + + +def convert_all(pyt_params, jax_params, verbose=True): + new_jax_params = {} + flat_jax_params = flatten_dict(jax_params, sep=".") + + if verbose: + print("CONVERTING WEIGHTS FROM PYT TO JAX") + print("FOLLOWING CONVERSION WILL BE APPLIED:") + for pyt_key, jax_key in zip(pyt_params, flat_jax_params): + pyt_key = str(pyt_key) + jax_key = str(jax_key) + print(f"{pyt_key.ljust(50)} -> {jax_key.ljust(50)}") + + for (k1, v1), (k2, v2) in zip(pyt_params.items(), flat_jax_params.items()): + new_jax_params[k2] = convert_array(v1) + msg = f"Tried to pass weight of {k1} {v1.shape} to {k2} {v2.shape}!" + assert new_jax_params[k2].shape == v2.shape, msg + + new_jax_params = unflatten_dict(new_jax_params, sep=".") + return FrozenDict(new_jax_params) + + +def load_pretrained_inception_v3(convert_pyt_weights=None, jax_weight_restore_path=None): + network = FIDInceptionV3() + assert convert_pyt_weights is not None or jax_weight_restore_path is not None, "Either pytorch or jax weights must be given" + + if convert_pyt_weights is not None: + rnd_params = network.init(jax.random.PRNGKey(0), jnp.ones((1, 299, 299, 3))) + + pyt_params = torch.load(convert_pyt_weights) + pyt_params_batch_stats = {k: v for k, v in pyt_params.items() if "running" in k} + jax_batch_stats = convert_all(pyt_params_batch_stats, rnd_params["batch_stats"], verbose=True) + pyt_params = { + k: v for k, v in pyt_params.items() if "num_batches_tracked" not in k and "running" not in k + } + + # Every ConvBlock in PyTorch has the following order: + # bn.bias + # bn.weight + # conv.weight + # Meanwhile, in Jax the order is reversed: + # conv.weight + # bn.weight + # bn.bias + # We fix this by reversing PyTorch triplets. + + pyt_keys = list(pyt_params) + pyt_keys_in_groups = [pyt_keys[i: i + 3][::-1] for i in range(0, len(pyt_keys), 3)] + pyt_keys = [key for group in pyt_keys_in_groups for key in group] + pyt_params = {key: pyt_params[key] for key in pyt_keys} + + jax_params = convert_all(pyt_params, rnd_params["params"]) + final_jax_params = FrozenDict(params=jax_params, batch_stats=jax_batch_stats) + elif jax_weight_restore_path is not None: + if not os.path.exists(jax_weight_restore_path): + # If we don't have jax weights on hand, download the pytorch ones and convert + weights = requests.get(PYT_LINK, allow_redirects=True).content + weights_io = io.BytesIO(weights) + params, _ = load_pretrained_inception_v3(convert_pyt_weights=weights_io) + with open(jax_weight_restore_path, 'wb') as f: + pickle.dump(params, f) + + multihost_utils.sync_global_devices("download_inception") + + final_jax_params = pickle.load(open(jax_weight_restore_path, "rb")) + else: + raise NotImplementedError + + return final_jax_params, network diff --git a/rosetta/rosetta/projects/diffusion/common/gin_utils.py b/rosetta/rosetta/projects/diffusion/common/gin_utils.py new file mode 100644 index 000000000..4601081fe --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/common/gin_utils.py @@ -0,0 +1,65 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools + + +def call(func, args=None, **keywords): + ''' + Same as partial, except will call the output of partial assuming all args are already specified + ''' + return partial(func=func, args=args, **keywords)() + + +def partial(func, args=None, **keywords): + ''' + Use this in gin-configs if you want to use functools.partial. + + In gin-config, there is no support for positional arguments as bindings. Also, to use functools.partial, + you must pass the function as a positional argument, so this function exists as a workaround. To be + concrete, none of the following would work in a gin file + + ``` + # Case 1 + train/functools.partial: + my_awesome_fn + a_kwarg=3 + + # Case 2 + train/functools.partial: + func=my_awesome_fn + a_kwarg=3 + + # Case 3 + train/functools.partial: + func=my_awesome_fn + {'a_dict': 'as a pos_arg'} + ``` + + This function can support Case 2, and Case 1 needs to be written like Case 2 to work. As for + Case 3, you'd need to write it: + ``` + train/functools.partial: + func=my_awesome_fn + args=[{'a_dict': 'as a pos_arg'}] + + The list is required so that we can unpack it as positional args to functools.partial + ``` + ''' + if args is not None and not isinstance(args, (list, tuple)): + raise TypeError(f'If you specify args, it must be a tuple or list, but you passed {type(args)}: {args}') + + if args is None: + args = [] + return functools.partial(func, *args, **keywords) diff --git a/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh new file mode 100644 index 000000000..8eeeaa854 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/common/set_gpu_xla_flags.sh @@ -0,0 +1 @@ +export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=false --xla_gpu_enable_async_all_gather=false --xla_gpu_enable_async_reduce_scatter=false --xla_gpu_enable_triton_gemm=false --xla_gpu_cuda_graph_level=0 --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_async_all_reduce=false ${XLA_FLAGS}" \ No newline at end of file diff --git a/rosetta/rosetta/projects/diffusion/configs/adamw_ema_opt.gin b/rosetta/rosetta/projects/diffusion/configs/adamw_ema_opt.gin new file mode 100644 index 000000000..79da9cb8b --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/configs/adamw_ema_opt.gin @@ -0,0 +1,36 @@ +from __gin__ import dynamic_registration + +import optax +from t5x import optimizers +from rosetta.projects.diffusion.common import gin_utils as gin_utils + +EMA=%gin.REQUIRED + +OPTIMIZER = @optimizers.chain() +optimizers.chain: + transformations = [@optax.clip_by_global_norm(), @adamw/gin_utils.call(), @optimizers.apply_ema_weights()] + +optimizers.apply_ema_weights: + decay = %EMA + debias = False + +optax.clip_by_global_norm: + max_norm = 1.0 + +adamw/optimizers.inject_hyperparams: + inner_factory = @optax.adamw + # If jax.config.x64_enabled, you may want to pass your model parameter dtype to + # avoid inject_hyperparams inferring the dtype, which may be wrong, i.e. float64, + # and that causes the gradient updates to be promoted. + # hyperparam_dtype = @jnp.float32 + +# Same as gin_utils.partial, except it will call the wrapped function as well. +# This is a workaround since in gin-config you cannot do @functools.partial()() +adamw/gin_utils.call: + func = @adamw/optimizers.inject_hyperparams() + #learning_rate = @utils.create_learning_rate_scheduler() + learning_rate = @optax.join_schedules() + weight_decay = 0.01 + b1 = 0.9 + b2 = 0.999 + eps = 1e-8 diff --git a/rosetta/rosetta/projects/diffusion/data_utils/__init__.py b/rosetta/rosetta/projects/diffusion/data_utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rosetta/rosetta/projects/diffusion/data_utils/wds_helper_cli.py b/rosetta/rosetta/projects/diffusion/data_utils/wds_helper_cli.py new file mode 100644 index 000000000..dff53ab41 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/data_utils/wds_helper_cli.py @@ -0,0 +1,62 @@ +'''Debugging tool: prints out the first element in a Webdataset, print cardinality, and create small subset for testing''' +import argparse +import json +import time + +import jax +import numpy as np +import tqdm +import webdataset as wds +from jax import tree_util + + +def list_of_dict_to_dict_of_list(samples): + outer = tree_util.tree_structure([0 for _ in samples]) + inner = tree_util.tree_structure(samples[0]) + return tree_util.tree_transpose(outer, inner, samples) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('urls', type=str, help='Urls for webdataset. Supports braceexpand syntax') + parser.add_argument('--len', default=False, action='store_true') + parser.add_argument('--batch_size', default=1, type=int, help='If provided will call batched(N) on the Webdataset') + parser.add_argument('--first100', default=None, type=str, help='If provided, will make a sample of 100 and write it to the file named here') + + args = parser.parse_args() + + dataset = wds.WebDataset(args.urls).decode('rgb') + if args.first100: + if not args.first100.endswith('.tar'): + raise ValueError(f'--first100={args.first100} should end with .tar') + writer = wds.TarWriter(args.first100) + for i, x in enumerate(tqdm.tqdm(dataset, desc=f'Writing to samples to {args.first100}')): + if i == 100: + break + writer.write(x) + writer.close() + single_elem = next(iter(dataset)) + keys = list(single_elem.keys()) + + if args.batch_size > 1: + dataset = dataset.to_tuple(*keys).batched(args.batch_size, collation_fn=list_of_dict_to_dict_of_list) + + def printer(obj): + if isinstance(obj, (str, bytes, int, float)): + return obj + elif isinstance(obj, np.ndarray): + return f'np.ndarray(shape={obj.shape}, elem[:3]={np.ravel(obj)[:3]})' + else: + raise ValueError(f'Not sure how to print type {type(obj)}: {obj}') + print('== SINGLE EXAMPLE ==') + print(json.dumps(jax.tree_map(printer, single_elem), indent=2)) + # if args.batch_size > 1: + # print(f'== BATCH [N={args.batch_size}] EXAMPLE ==') + # print(json.dumps(jax.tree_map(printer, next(iter(dataset))), indent=2)) + if args.len: + start = time.time() + for i, x in enumerate(tqdm.tqdm(dataset, desc='iterating thru dataset for len')): + pass + elapsed = time.time() - start + print(f'Dataset length: {i+1}') + print(f'example/sec: {args.batch_size*(i+1)/elapsed:.3f}') + print(f'batch/sec: {(i+1)/elapsed:.3f}') diff --git a/rosetta/rosetta/projects/diffusion/denoisers.py b/rosetta/rosetta/projects/diffusion/denoisers.py new file mode 100644 index 000000000..c86ec7de6 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/denoisers.py @@ -0,0 +1,383 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diffusion-based denoisers + +This module builds a denoising model that can be used as a black box +for training and sampling with arbitrary methods +""" + +from typing import Mapping, Optional, Tuple, Union, Type +import abc +import typing_extensions + +from flax import linen as nn +from flax.core import scope as flax_scope +import jax +import jax.numpy as jnp + +PyTreeDef = Type[type(jax.tree_util.tree_structure(None))] + +class PrecondSigmaFnCallable(typing_extensions.Protocol): + """ Call signature for sigma-dependant preconditioning function """ + def __call__(self, sigma:jnp.ndarray) -> jnp.ndarray: + ... + +class DenoisingFunctionCallableWithParams(typing_extensions.Protocol): + """ Call signature for a denoising function """ + def __call__(self, + params: PyTreeDef, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None + ) -> jnp.ndarray: + ... + +class DenoisingFunctionCallable(typing_extensions.Protocol): + """ Call signature for a denoising function """ + def __call__(self, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None + ) -> jnp.ndarray: + ... + +class DenoisingFunctionWithAuxCallable(typing_extensions.Protocol): + """ Call signature for a denoising function """ + def __call__(self, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + ... + +class NoisePredictorCallable(typing_extensions.Protocol): + """ Call signature for a 'eps' or 'v' noise predicting function """ + def __call__(self, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None + ) -> jnp.ndarray: + ... + +class NoisePredictorWithAuxCallable(typing_extensions.Protocol): + """ Call signature for a 'eps' or 'v' noise predicting function """ + def __call__(self, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + ... + +class Denoiser(abc.ABC): + """ + Model that returns a denoised sample given a noised sample and noise conditioning. + Implements: + $$D_\\theta(x; \\sigma) = c_{skip}(\\sigma)x + + c_{out}(\\sigma)F_\\theta(c_{in}(\\sigma)x; c_{noise}(\\sigma))$$ + from Karras et. al. EDM + """ + def __init__(self, + raw_model: nn.Module, + c_skip_fn: PrecondSigmaFnCallable, + c_out_fn: PrecondSigmaFnCallable, + c_in_fn: PrecondSigmaFnCallable, + c_noise_fn: PrecondSigmaFnCallable): + """ + Args: + raw_model: nn.Module that corresponds to $F_\\theta$ + c_skip_fn, c_out_fn, + c_in_fn, c_noise_fn: Functions of $\\sigma$ that correspond to their terms + in the $D_\\theta$ equation. + """ + self.module = raw_model + self.c_skip_fn = c_skip_fn + self.c_out_fn = c_out_fn + self.c_in_fn = c_in_fn + self.c_noise_fn = c_noise_fn + + @abc.abstractmethod + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """ Apply raw neural net """ + + def prob_grad(self, params, noised_sample, sigma) -> jnp.ndarray: + """ + Computes the gradient of the probability distribution wrt x: + $ \\nabla_x log(p(x; \\sigma) = (D_\\theta - x) / \\sigma**2 $ + """ + return (self.denoise_sample(params, noised_sample, sigma) - noised_sample) / sigma ** 2 + + def denoise_sample(self, + params: PyTreeDef, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None, + flax_mutables: Optional[PyTreeDef]=None, + ) -> jnp.ndarray: + """ Returns denoised sample given noised sample and conditioning """ + sigma = expand_dims_like(sigma, noised_sample) + skip_scale = self.c_skip_fn(sigma) + out_scale = self.c_out_fn(sigma) + in_scale = self.c_in_fn(sigma) + noise_cond = self.c_noise_fn(sigma) + + batch = {'samples': in_scale * noised_sample, 'noise_cond': noise_cond} + if other_cond is not None: + batch.update(other_cond) + + if 'low_res_images' in batch.keys(): + noise_aug_level = batch.get('noise_aug_level', jnp.ones_like(sigma) * 0.002) + low_res_noise_cond = self.c_noise_fn(noise_aug_level) + low_res_in_scale = self.c_in_fn(noise_aug_level) + low_res_batch = {'low_res_images': low_res_in_scale * batch['low_res_images'], 'noise_aug_level': low_res_noise_cond} + batch.update(low_res_batch) + + return skip_scale * noised_sample + \ + out_scale * self.apply_module(params, batch, dropout_rng, other_variables=flax_mutables) + +class EDMUnconditionalDenoiser(Denoiser): + """ Denoiser that implements the training regime from Karras et. al. EDM """ + def __init__(self, + raw_model: nn.Module, + sigma_data: float=0.5): + self.sigma_data = sigma_data + c_skip_fn = lambda sigma: (sigma_data ** 2) / (sigma ** 2 + sigma_data ** 2) + c_out_fn = lambda sigma: (sigma * sigma_data) / jnp.sqrt(sigma_data ** 2 + sigma ** 2) + c_in_fn = lambda sigma: 1.0 / jnp.sqrt(sigma ** 2 + sigma_data ** 2) + c_noise_fn = lambda sigma: 0.25 * jnp.log(sigma) + + super().__init__(raw_model, c_skip_fn, c_out_fn, c_in_fn, c_noise_fn) + + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes module output via a forward pass of `self.module`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': rngs} if rngs is not None else None + if other_variables is None: + other_variables = {} + return self.module.apply( + { + 'params': params, + **other_variables + }, + batch['samples'], + batch['noise_cond'], + enable_dropout=rngs is not None, + rngs=rngs, + mutable=mutable) + +class EDMTextConditionedDenoiser(EDMUnconditionalDenoiser): + """ + Denoiser that implements the training regime from Karras et. al. EDM + and accepts text conditioning. + """ + + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes module output via a forward pass of `self.module`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': rngs} if rngs is not None else None + if other_variables is None: + other_variables = {} + return self.module.apply( + { + 'params': params, + **other_variables + }, + batch['samples'], + batch['noise_cond'], + batch['text'], + batch['text_mask'], + enable_dropout=rngs is not None, + rngs=rngs, + mutable=mutable) + +class EDMTextConditionedSuperResDenoiser(EDMTextConditionedDenoiser): + """ + Denoiser that implements the training regime from Karras et. al. EDM + and accepts text and lowres conditioning. + """ + + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes module output via a forward pass of `self.module`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': rngs} if rngs is not None else None + if other_variables is None: + other_variables = {} + # jax.debug.print('noise aug {n}', n=batch.get('noise_aug_level')) + return self.module.apply( + { + 'params': params, + **other_variables + }, + batch['samples'], + batch['noise_cond'], + text_enc=batch['text'], + text_lens=batch['text_mask'], + low_res_images=batch['low_res_images'], + noise_aug_level=batch.get('noise_aug_level', None), + enable_dropout=rngs is not None, + rngs=rngs, + mutable=mutable) + +class EDMLatentConditionalDenoiser(EDMUnconditionalDenoiser): + """ + Denoiser that implements the EDM training regime and accepts latent-preconditioning + for RIN networks + """ + + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes module output via a forward pass of `self.module`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': rngs} if rngs is not None else None + if other_variables is None: + other_variables = {} + return self.module.apply( + { + 'params': params, + **other_variables + }, + batch['samples'], + batch['noise_cond'], + batch['prev_latents'], + enable_dropout=rngs is not None, + rngs=rngs, + mutable=mutable) + +class VP_EPSNoisePredictor(abc.ABC): + """ + Model that returns a denoised sample given a noised sample and noise conditioning. + Implements: + $$pred = c_{out}(\\sigma)F_\\theta(c_{in}(\\sigma)x; c_{noise}(\\sigma))$$ + from Karras et. al. EDM + """ + def __init__(self, + raw_model: nn.Module): + """ + Args: + raw_model: nn.Module that corresponds to $F_\\theta$ + c_noise_fn, c_out_fn, + c_in_fn : Functions of $\\sigma$ that correspond to their terms + in the $pred$ equation. + """ + self.module = raw_model + + @abc.abstractmethod + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """ Apply raw neural net """ + + def prob_grad(self, params, noised_sample, sigma) -> jnp.ndarray: + """ + Computes the gradient of the probability distribution wrt x: + $ \\nabla_x log(p(x; \\sigma) = (D_\\theta - x) / \\sigma**2 $ + """ + return (self.pred_sample(params, noised_sample, sigma) - noised_sample) / sigma ** 2 + + def pred_sample(self, + params: PyTreeDef, + noised_sample: jnp.ndarray, + sigma: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None, + dropout_rng: Optional[jax.random.KeyArray]=None + ) -> jnp.ndarray: + """ Returns denoised sample given noised sample and conditioning """ + sigma = expand_dims_like(sigma, noised_sample) + + batch = {'samples': noised_sample, 'noise_cond': sigma} + if other_cond is not None: + batch.update(other_cond) + + return self.apply_module(params, batch, dropout_rng) + +class VP_EPSNoisePredictorTextConditional(VP_EPSNoisePredictor): + """ + Denoiser that implements the VP Eps prediction + and accepts text conditioning. + """ + + def apply_module( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rngs: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes module output via a forward pass of `self.module`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': rngs} if rngs is not None else None + if other_variables is None: + other_variables = {} + return self.module.apply( + { + 'params': params, + **other_variables + }, + batch['samples'], + batch['noise_cond'], + batch['text'], + enable_dropout=rngs is not None, + rngs=rngs, + mutable=mutable) + + +def expand_dims_like(target, source): + return jnp.reshape(target, target.shape + (1, ) * (len(source.shape) - len(target.shape))) \ No newline at end of file diff --git a/rosetta/rosetta/projects/diffusion/losses.py b/rosetta/rosetta/projects/diffusion/losses.py new file mode 100644 index 000000000..7ec3b37ec --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/losses.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diffusion training losses + +This module includes loss functions for various diffusion model +training regimes +""" + +from typing import Callable, Tuple, Mapping, Optional, Union +import typing_extensions +import functools + +import jax +import jax.numpy as jnp +import jax.debug +from rosetta.projects.diffusion.augmentations import AugmentationCallable +from rosetta.projects.diffusion.denoisers import DenoisingFunctionCallable, NoisePredictorCallable + +PyTreeDef = type(jax.tree_util.tree_structure(None)) + +class DiffusionLossCallable(typing_extensions.Protocol): + """ Call signature for a diffusion loss function. + Returns the loss and the noises used """ + def __call__(self, + denoise_fn: Callable, + rng: jax.random.KeyArray, + samples: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ... + +class EPSDiffusionLossCallable(typing_extensions.Protocol): + """ Call signature for a diffusion loss function based on noise(eps) prediction. + Returns the loss and the noises used """ + def __call__(self, + eps_predictor: Callable, + rng: jax.random.KeyArray, + samples: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]] + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + ... + +class EDMLoss: + """ EDM loss from Karras et. al. 2022""" + def __init__(self, p_mean=-1.2, p_std=1.2, sigma_data=0.5, + sample_aug_fn: Optional[AugmentationCallable]=None, + cond_aug_fn: Optional[AugmentationCallable]=None, + dim_noise_scalar=1.): + self.p_mean = p_mean + self.p_std = p_std + self.sigma_data = sigma_data + self.sample_aug_fn = sample_aug_fn + self.cond_aug_fn = cond_aug_fn + self.dim_noise_scalar = dim_noise_scalar + + def _loss_weight(self, sigma): + sigma = sigma / self.dim_noise_scalar + return (sigma ** 2 + self.sigma_data ** 2) / (sigma * self.sigma_data) ** 2 + + def _noise_sampler(self, rng: jax.random.KeyArray, count: int, dim_scalar:float=1.): + rnd_normal = jax.random.normal(rng, (count, )) + return jnp.exp(rnd_normal * self.p_std + self.p_mean) * dim_scalar + + + + def __call__(self, + denoise_fn: Callable, + rng: jax.random.KeyArray, + samples: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + Returns the EDM loss and the noises used given a denoiser and samples. + Args: + denoise_fn: black-box function that denoises an image given + samples, sigmas, and other conditioning + rng: rng for sampling sigmas, noise, and optionally dropout + samples: array of samples to be diffused + other_cond: arbitrary other conditioning to pass to the + denoiser. Could be text conditioning. + enable_dropout: should use dropout + """ + dropout_rng, sigma_rng, noise_rng = jax.random.split(rng, 3) + + if self.sample_aug_fn: + samples, dropout_rng = self.sample_aug_fn(samples, dropout_rng) + if self.cond_aug_fn: + other_cond, dropout_rng = self.cond_aug_fn(other_cond, dropout_rng) + + batch_dim = samples.shape[0] + sigma = self._noise_sampler(sigma_rng, batch_dim, self.dim_noise_scalar) + sigma = expand_dims_like(sigma, samples) + weight = jnp.reshape(self._loss_weight(sigma), batch_dim) + + noise = jax.random.normal(noise_rng, samples.shape, samples.dtype) + noised_sample = samples + noise * sigma + + denoised = denoise_fn(noised_sample, sigma, other_cond, dropout_rng=dropout_rng) + sq_err = (denoised - samples) ** 2 + loss_unweighted = jnp.mean(jnp.reshape(sq_err, (batch_dim, -1)), axis=-1) + return weight * loss_unweighted, jnp.mean(loss_unweighted), sigma + +class EDMSuperResolutionLoss(EDMLoss): + def __call__(self, + denoise_fn: Callable, + rng: jax.random.KeyArray, + samples: jnp.ndarray, + other_cond: Optional[Mapping[str, jnp.ndarray]]=None + ) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + lowres_aug_rng, noise_rng, rng = jax.random.split(rng, 3) + + assert other_cond and 'low_res_images' in other_cond.keys(), f'Superresolution loss requires a low_res_image in the other_cond of the sample. One was not found' + lowres = other_cond['low_res_images'] + batch_dim = samples.shape[0] + sigma = self._noise_sampler(lowres_aug_rng, batch_dim, 1.) + sigma = expand_dims_like(sigma, lowres) + + noise = jax.random.normal(noise_rng, lowres.shape, lowres.dtype) + noised_low_res = lowres + noise * sigma + + other_cond = {'low_res_samples': noised_low_res, 'noise_aug_level': sigma, **other_cond} + + return super().__call__(denoise_fn, rng, samples, other_cond) + +def expand_dims_like(target, source): + return jnp.reshape(target, target.shape + (1, ) * (len(source.shape) - len(target.shape))) \ No newline at end of file diff --git a/rosetta/rosetta/projects/diffusion/mm_utils.py b/rosetta/rosetta/projects/diffusion/mm_utils.py new file mode 100644 index 000000000..43f16ca79 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/mm_utils.py @@ -0,0 +1,555 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""General utility functions for diffusion/wds """ +from jax.experimental import multihost_utils +import tensorflow as tf +import typing_extensions +from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union, List +from flax.linen import partitioning as flax_partitioning +from t5x import partitioning +import jax +import jax.numpy as jnp +import numpy as np +import seqio +import time + +from absl import logging +import dataclasses +import collections +import collections.abc +from concurrent.futures import thread +import contextlib +import dataclasses +import functools +import importlib +import inspect +import os +import re +import time +import typing +from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple, Type, Union +import warnings +import gc + +from absl import logging +import clu.data +from flax import traverse_util +import flax.core +from flax.core import scope as flax_scope +from flax.linen import partitioning as flax_partitioning +import jax +from jax.experimental import multihost_utils +import jax.numpy as jnp +import numpy as np +import orbax.checkpoint +import seqio +from t5x import checkpoints +from t5x import optimizers +from t5x import partitioning +from t5x import state_utils +from t5x import train_state as train_state_lib +import tensorflow as tf +from tensorflow.io import gfile +import typing_extensions +from rosetta.projects.diffusion import wds_utils +from rosetta.projects.diffusion.denoisers import DenoisingFunctionCallableWithParams, DenoisingFunctionCallable + +Array = Union[np.ndarray, jnp.ndarray, tf.Tensor] +PyTreeDef = type(jax.tree_structure(None)) +PartitionSpec = partitioning.PartitionSpec +DType = Union[np.dtype, type(jnp.bfloat16)] +Shape = Tuple[int, ...] + +# ----------------------------------------------------------------------------- +# SeqIO utility functions. +# ----------------------------------------------------------------------------- + + +def import_module(module: str): + """Imports the given module at runtime.""" + logging.info('Importing %s.', module) + try: + importlib.import_module(module) + except RuntimeError as e: + if (str(e) == + 'Attempted to add a new configurable after the config was locked.'): + raise RuntimeError( + 'Your Task/Mixture module contains gin configurables that must be ' + 'loaded before gin flag parsing. One fix is to add ' + f"'import {module}' in your gin file.") + raise e + +class ShardInfo: + def __init__(self, id, ct): + self.id = id + self.ct = ct + +def get_dataset(cfg: wds_utils.WebDatasetConfig, + shard_id: int, + num_shards: int, + feature_converter_cls: Type[seqio.FeatureConverter], + num_epochs: Optional[int] = None, + continue_from_last_checkpoint: bool = False, should_batch=True) -> tf.data.Dataset: + """Returns a dataset from webdataset based on a `WebDatasetConfig`.""" + if continue_from_last_checkpoint: + raise ValueError( + '`continue_from_last_checkpoint` must be set to False as this is not ' + 'supported by this dataset fn.') + del continue_from_last_checkpoint + + if cfg.batch_size % num_shards: + raise ValueError( + f'Batch size ({cfg.batch_size}) must be divisible by number of ' + f'shards ({num_shards}).') + + if cfg.seed is None: + # Use a shared timestamp across devices as the seed. + seed = multihost_utils.broadcast_one_to_all(np.int32(time.time())) + else: + seed = cfg.seed + + shard_info = ShardInfo(shard_id, num_shards) + return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed, + num_epochs, should_batch) + +def get_dataset_inner(cfg: wds_utils.WebDatasetConfig, + shard_info: ShardInfo, + feature_converter_cls: Callable[..., + seqio.FeatureConverter], + seed: Optional[int] = None, + num_epochs: Optional[int] = None, + should_batch=True): + """Internal fn to load a dataset from WebDataset based on a `WebDatasetConfig`.""" + batch_size = cfg.batch_size // shard_info.ct + if seed is not None: + multihost_utils.assert_equal( + np.array(seed), + f'`seed` is not same across hosts; {jax.process_index} has a seed of ' + f'{seed}') + logging.info( + "Initializing dataset for task '%s' with a replica batch size of %d and " + 'a seed of %d', cfg.mixture_or_task_name, batch_size, seed) + + # should_batch implies that we will add a batch dimension ourselves to the loaded data + print("SHOULD BATCH?", should_batch) + ds, out_shapes, out_types = wds_utils.get_mm_wds_from_urls(cfg, batch_size=batch_size if should_batch else -1) + if should_batch: + for k in out_shapes.keys(): + out_shapes[k] = (batch_size,) + out_shapes[k] + + ds = tf.data.Dataset.from_generator( + generator=ds.__iter__, output_types=out_types, output_shapes=out_shapes + ) + if cfg.samples: + add_fake_length_method(ds, cfg.samples) + return ds + +class GetDatasetCallable(typing_extensions.Protocol): + + def __call__(self, + cfg: wds_utils.WebDatasetConfig, + shard_id: int, + num_shards: int, + feature_converter_cls: Callable[..., seqio.FeatureConverter], + num_epochs: Optional[int] = None, + continue_from_last_checkpoint: bool = True) -> tf.data.Dataset: + ... + +def multihost_assert_equal(input_tree, fail_message: str = ''): + """Verifies that all the hosts have the same tree of values.""" + # Internal mock TPU handling + multihost_utils.assert_equal(input_tree, fail_message) +class InferStepWithRngCallable(typing_extensions.Protocol): + + def __call__(self, + params: Mapping[str, Any], + batch: Mapping[str, jnp.ndarray], + rng: jnp.ndarray = None) -> PyTreeDef: + """Runs an inference step returning a prediction or score.""" + ... + + +class InferStepWithoutRngCallable(typing_extensions.Protocol): + + def __call__(self, params: Mapping[str, Any], + batch: Mapping[str, jnp.ndarray]) -> PyTreeDef: + """Runs an inference step returning a prediction or score.""" + ... + + +InferStepCallable = Union[InferStepWithRngCallable, InferStepWithoutRngCallable] + +# NOTE: We're not more prescriptive than PyTreeDef because that's what +# InferStepCallable expects. +_InferFnResult = Sequence[Tuple[int, PyTreeDef]] +_InferFnWithAuxResult = Tuple[_InferFnResult, Mapping[str, Sequence[Any]]] + + +class InferFnCallable(typing_extensions.Protocol): + + def __call__( + self, + ds: tf.data.Dataset, + train_state: train_state_lib.TrainState, + rng: Optional[jnp.ndarray] = None + ) -> Union[_InferFnResult, _InferFnWithAuxResult]: + """Runs inference on the dataset.""" + ... + + +def _remove_padding(all_inferences, all_indices): + """Remove padded examples. + + Args: + all_inferences: PyTree[total_examples + padding_count, ...]. + all_indices: [total_examples + padding_count]. + + Returns: + all_inferences in shape PyTree[total_examples, ...]. + all_indices in shape [total_exmamples]. + """ + non_pad_idxs = np.where(all_indices >= 0) + all_indices = all_indices[non_pad_idxs] + all_inferences = jax.tree_map(lambda x: x[non_pad_idxs], all_inferences) + return all_inferences, all_indices + + +def get_infer_fn(infer_step: InferStepCallable, batch_size: int, + train_state_axes: train_state_lib.TrainState, + partitioner: partitioning.BasePartitioner, num_samples:Optional[int]=None, + return_batch_keys:Optional[List[str]]=None) -> InferFnCallable: + """Get prediction function for the SeqIO evaluator. + + The returned prediction function should take in an enumerated dataset, make + predictions and return in an enumerated form with the original indices and + examples zipped together. This ensures that the predictions are compared to + the targets in a correct order even if the dataset is sharded across + multiple hosts and gathered in a nondeterministic way. + + jax.process_index == 0 is used as a "main host", i.e., it gathers all + inference results and returns. + + Shape notation: + Per replica set num replicas: R + Per replica set batch size: B + Number of replica sets: H + Length: L + + Some transformations have shape transformation annotation, e.g., + [B, L] -> [R, B/R, L]. + + Args: + infer_step: a callable that executes one prediction step. Should not yet be + partitioned or pmapped. + batch_size: the number of examples in the global infer batch. + train_state_axes: Partitioning info for the train state object. + partitioner: partitioner to use. + + Returns: + predict_fn: a callable which takes in the enumerated infer dataset and an + optimizer and runs the prediction. + """ + + return_batch = return_batch_keys is not None + def infer_step_with_indices(params, batch, rng, indices): + if 'rng' in inspect.signature(infer_step).parameters: + res = typing.cast(InferStepWithRngCallable, infer_step)(params, batch, + rng) + else: + res = typing.cast(InferStepWithoutRngCallable, infer_step)(params, batch) + if return_batch: + return indices, res, batch + else: + return indices, res + + outs = (None, ) * 3 if return_batch else (None, ) * 2 + partitioned_infer_step = partitioner.partition( + infer_step_with_indices, + in_axis_resources=(train_state_axes.params, + partitioner.data_partition_spec, None, + partitioner.data_partition_spec), + out_axis_resources=outs) + + data_layout = partitioner.get_data_layout(batch_size) + shard_id = data_layout.shard_id + num_shards = data_layout.num_shards + + per_shard_batch_size = batch_size // num_shards + num_batches = num_samples // batch_size + + def infer_fn(ds: tf.data.Dataset, + train_state: train_state_lib.TrainState, + rng: Optional[jnp.ndarray] = None): + ds_shapes = jax.tree_map(lambda x: jnp.array(x.shape), ds.element_spec) + ds = ds.enumerate() + multihost_assert_equal( + ds_shapes, 'Dataset element shapes do not agree across hosts. ' + 'This could be an indication that the dataset is nondeterministic.') + try: + original_ds_length = num_samples #len(ds) + dataset_remainder = original_ds_length % batch_size # pytype:disable=wrong-arg-types + logging.info('length of dataset = %s', num_samples)# len(ds)) + except TypeError as e: + if str(e) == 'dataset length is unknown.': + logging.warning( + 'The following error is likely due to the use of TensorFlow v1 in ' + 'your dataset pipeline. Verify you are not importing from ' + '`tf.compat.v1` as part of your pipeline.') + raise e + + if dataset_remainder: + dataset_pad_amt = batch_size - dataset_remainder + logging.info( + 'Padding infer dataset with %d examples for even per-replica shards.', + dataset_pad_amt) + # Pad with the first example using an index of -1 so seqio will ignore. + pad_ds = ds.take(1).map(lambda i, x: (np.int64(-1), x)).repeat( + dataset_pad_amt) + ds = ds.concatenate(pad_ds) + + # Shard the infer dataset across replica sets. + sharded_ds = ds.shard(num_shards, shard_id).batch( + per_shard_batch_size, drop_remainder=True) + # multihost_assert_equal( + # jnp.array(len(sharded_ds)), + # 'Dataset lengths do not agree across hosts.') + + logging.info( + 'The infer dataset is sharded into %d shards with per-shard ' + 'batch size of %d', num_shards, per_shard_batch_size) + + # Run inference for each replica set. + batched_results, all_indices, batched_return = [], [], [] + for batch_idx, (index, infer_batch) in enumerate(sharded_ds.as_numpy_iterator()): + logging.info(str(index)) + if batch_idx >= num_batches: + break + if rng is None: + step_rng = None + else: + step_rng, rng = jax.random.split(rng) + # Run fast inference on batch. + # [B, ...] -> [B * shard_count, ...] + # partitioned_infer_step executes infer_step on sharded batched data, and + # returns de-sharded batched indices and result replicated on all hosts. + + if jax.config.jax_array and jax.process_count() > 1: + logging.info('in array conf. array shape is ' + str(jax.tree_map(lambda x: x.shape, infer_batch))) + inputs = multihost_utils.host_local_array_to_global_array( + (infer_batch, step_rng, index), partitioner.mesh, + (partitioner.data_partition_spec, None, + partitioner.data_partition_spec)) + logging.info('input batch shape' + str(tree_shape(inputs[0]))) + if return_batch: + batch_indices, batch_result, batch_ret = partitioned_infer_step( + train_state.params, *inputs) + logging.info('out batch shape' + str(tree_shape(batch_ret))) + batch_indices, batch_result, batch_ret = multihost_utils.global_array_to_host_local_array( + (batch_indices, batch_result, batch_ret), partitioner.mesh, (None, None, None)) + + else: + batch_indices, batch_result = partitioned_infer_step( + train_state.params, *inputs) + + batch_indices, batch_result = multihost_utils.global_array_to_host_local_array( + (batch_indices, batch_result), partitioner.mesh, (None, None)) + + logging.info('out shape' + str(jax.tree_map(lambda x: x.shape, batch_result))) + logging.info('out idx shape' + str(jax.tree_map(lambda x: x.shape, batch_indices))) + else: + if return_batch: + batch_indices, batch_result, batch_ret = partitioned_infer_step( + train_state.params, infer_batch, step_rng, index) + else: + batch_indices, batch_result = partitioned_infer_step( + train_state.params, infer_batch, step_rng, index) + logging.info('Inference of batch %s done.', index) + + + def _copy_to_host_async(x): + if hasattr(x, 'addressable_data'): + # Array is fully replicated. + x.addressable_data(0).copy_to_host_async() + return x.addressable_data(0) + else: + x.copy_to_host_async() + return x + + try: + logging.info("full result " + str(jax.tree_map(lambda x: x.shape, batch_result))) + batch_result = jax.tree_map(_copy_to_host_async, batch_result) + if return_batch_keys: + if return_batch_keys == True: + ret = batch_ret + else: + ret = {} + for k in return_batch_keys: + ret[k] = batch_ret[k] + batch_return = ret + else: + batch_return = None + batch_indices = jax.tree_map(_copy_to_host_async, batch_indices) + except AttributeError: + # Similar to jax.device_get, we skip transfers for non DeviceArrays. + pass + + logging.info('out idx shape after copy' + str(jax.tree_map(lambda x: x.shape, batch_indices))) + + batched_results.append(batch_result) + if return_batch_keys: + batched_return.append(batch_return) + all_indices.append(batch_indices) + logging.info('returns' + str(tree_shape(batched_return))) + + logging.info('Inference of all batches done.') + all_inferences = batched_results + + # List[B * shard_count, ...] -> [B * shard_count * batch_count, ...] + all_inferences = jax.tree_map(lambda *args: np.concatenate(args), + *all_inferences) + all_indices = np.concatenate(all_indices) + logging.info(str(tree_shape(all_inferences)) + str(tree_shape(all_indices))) + + all_inferences, all_indices = _remove_padding(all_inferences, all_indices) + + # Results are returned from infer_step out of order due to shard operation. + # Note: remove padding first, as -1 indices would mess up this operation. + # Note: all_inferences may be a PyTree, not just an array, e.g. if + # `infer_step` is `model.predict_batch_with_aux`. + logging.info(str(tree_shape(all_inferences)) + str(tree_shape(all_indices))) + if return_batch_keys: + all_batches = jax.tree_map(lambda *args: np.concatenate(args), + *batched_return) + + # aux_values is supposed to be a dictionary that maps strings to a set of + # auxiliary values. + # + # We don't want to flatten/unflatten the aux values. We want to preserve the + # unflattened values with the type List[Mapping[str, Sequence[Any]]]. We do + # this as a memory optimization to avoid lots of redundant keys if we'd + # instead had List[Mapping[str, Any]]. + # + # It has shape Mapping[str, [B * shard_count * batch_count, ...]]. That is, + # the first dimension of each of the values in aux_values is equal to + # len(all_inferences). + aux_values = None + if (isinstance(all_inferences, tuple) and len(all_inferences) == 2 and + isinstance(all_inferences[1], Mapping)): + all_inferences, aux_values = all_inferences + + # Translate to List[...] by flattening inferences making sure to + # preserve structure of individual elements (inferences are not assumed to + # be simple np.array). Finally, zip inferences with corresponding indices + # and convert leaf np.arrays into lists. + if return_batch_keys: + indices_and_outputs = (all_indices, all_inferences, all_batches) + + else: + indices_and_outputs = (all_indices, all_inferences) + + logging.info('final idxes ' + str(all_indices)) + logging.info('final out ' + str(tree_shape(indices_and_outputs))) + if indices_and_outputs[0].shape[0] != original_ds_length: + raise ValueError( + 'Size of indices_and_outputs does not match length of original ' + 'dataset: %d versus %d' % + (indices_and_outputs[0].shape[0], original_ds_length)) + + if aux_values is None: + return indices_and_outputs + else: + aux_values = jax.tree_map(lambda x: np.array(x).tolist(), aux_values) + return indices_and_outputs, aux_values + + return infer_fn + +class DiffusionSamplingEvaluator: + def __init__(self, dataset_cfg, dataset, log_dir=None, fixed_rng=True): + self.dataset_cfg = dataset_cfg + self.dataset = dataset + self.log_dir = log_dir + self.rng = jax.random.PRNGKey(0) + self.keep_random = fixed_rng + self.eval_tasks = [1] #non empty + + def evaluate(self, + compute_metrics: bool, + step: int, + predict_fn: Optional[Callable] = None, + score_fn: Optional[Callable] = None, + predict_with_aux_fn: Optional[Callable] = None, + ): + samples = predict_fn(self.dataset) + + save_dir = '{}/samples/{}_trainsteps'.format(self.log_dir, step) + try: + os.makedirs(save_dir) + except: + pass + + if jax.process_index() == 0: + import matplotlib.image as matimg + logging.info('Saving samples to {}'.format(save_dir)) + for i in range(len(samples)): + np_arr = np.clip(samples[i][1], a_min = 0, a_max = 1) + matimg.imsave(os.path.join(save_dir, '{}-{}.png'.format(jax.process_index(), i)), np_arr) + if len(samples) == 3: + np_arr_batch = (samples[2]['low_res_images'][i] + 1) / 2. + logging.info(str(np_arr_batch)) + np_arr_batch = np.clip(np_arr_batch, a_min = 0, a_max = 1) + matimg.imsave(os.path.join(save_dir, 'dataset-{}-{}.png'.format(jax.process_index(), i)), np_arr_batch) + + multihost_utils.sync_global_devices('eval') + return None, None + +def expand_dims_like(target, source): + return jnp.reshape(target, target.shape + (1, ) * (len(source.shape) - len(target.shape))) + +def tree_shape(tree): + return jax.tree_map(lambda x: x.shape, tree) + +def add_fake_length_method(obj, size): + def length(self): + return size + + Combined = type( + obj.__class__.__name__ + "_Length", + (obj.__class__,), + {"__len__": length}, + ) + obj.__class__ = Combined + return obj + +def _copy_to_host_async(x): + if hasattr(x, 'addressable_data'): + # Array is fully replicated. + x.addressable_data(0).copy_to_host_async() + return x.addressable_data(0) + else: + x.copy_to_host_async() + return x diff --git a/rosetta/rosetta/projects/diffusion/models.py b/rosetta/rosetta/projects/diffusion/models.py new file mode 100644 index 000000000..29d0f32d2 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/models.py @@ -0,0 +1,236 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Diffusion Models. + +This module wraps around networks.py to integrate training, sampling, and construction. +""" + +from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union, Type + +import abc +import clu.metrics as clu_metrics +from flax import core as flax_core +from flax import linen as nn +from flax.core import scope as flax_scope +from flax.linen import partitioning as flax_partitioning +from flax.training import common_utils +import jax +import jax.numpy as jnp +import numpy as np +from einops import rearrange +from t5x import metrics as metrics_lib +from t5x import optimizers +from t5x.models import BaseModel +import tensorflow as tf +import typing_extensions +import functools + +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import samplers + +Array = Union[np.ndarray, jnp.ndarray, tf.Tensor] +MetricsMap = metrics_lib.MetricsMap +PyTreeDef = Type[type(jax.tree_util.tree_structure(None))] +BatchType = Mapping[str, jnp.ndarray] + +class DiffusionBase(BaseModel): + def __init__(self, + optimizer_def: optimizers.OptimizerDefType): + super().__init__(optimizer_def=optimizer_def) + self.FEATURE_CONVERTER_CLS=None # for t5x trainer compatibility + + def eval_fn(self, + params: PyTreeDef, + batch: BatchType, + ) -> Tuple[jnp.ndarray, MetricsMap]: + return self.loss_fn(params, batch, dropout_rng=None) + + def score_batch(self, + params: PyTreeDef, + batch: BatchType, + return_intermediates: bool = False) -> jnp.ndarray: + raise NotImplementedError("Batch scoring not supported by Diffusion Models") + + +class DenoisingDiffusionModel(DiffusionBase): + """ Wrapper for a denoiser with an arbirary training scheme """ + def __init__(self, + denoiser: denoisers.Denoiser, + diffusion_loss: losses.DiffusionLossCallable, + diffusion_sampler: samplers.DiffusionSampler, + optimizer_def: optimizers.OptimizerDefType, + sampling_cfg: Optional[samplers.SamplingConfig]=None): + self.denoiser = denoiser + self.diffusion_loss = diffusion_loss + self.sampler = diffusion_sampler + self.sampling_cfg = sampling_cfg + super().__init__(optimizer_def=optimizer_def) + + def _compute_metrics(self, + loss: jnp.ndarray, + loss_unweighted: jnp.ndarray, + avg_sigma: jnp.ndarray, + num_examples: int) -> MetricsMap: + return compute_basic_diffusion_metrics(loss, loss_unweighted, avg_sigma, num_examples) + + def _denoise_fn(self, + params: PyTreeDef, + flax_mutables: Optional[PyTreeDef] = None, + ): + return functools.partial(self.denoiser.denoise_sample, params, flax_mutables=flax_mutables) + + def loss_fn(self, + params: PyTreeDef, + batch: BatchType, + dropout_rng: Optional[jax.random.KeyArray], + flax_mutables: Optional[PyTreeDef] = None, + ) -> Tuple[jnp.ndarray, MetricsMap]: + denoise_fn = self._denoise_fn(params, flax_mutables) + + samples = batch['samples'] + other_cond = {k: batch[k] for k in batch if k != 'samples'} + batch_dim = samples.shape[0] + + loss, loss_unweighted, sigma = self.diffusion_loss(denoise_fn, dropout_rng, samples, other_cond) + + loss = jnp.mean(loss) + avg_sigma = jnp.mean(sigma) + return loss, self._compute_metrics(loss, loss_unweighted, avg_sigma, batch_dim) + + def predict_batch(self, + params: PyTreeDef, + batch: BatchType, + rng: Optional[jax.random.KeyArray] = None, + *, + sampling_cfg: Optional[samplers.SamplingConfig] = None, + ) -> jnp.ndarray: + return self.predict_batch_with_aux(params, batch, rng=rng, sampling_cfg=sampling_cfg)[0] + + def predict_batch_with_aux(self, + params: PyTreeDef, + batch: BatchType, + rng: Optional[jax.random.KeyArray] = None, + *, + sampling_cfg: Optional[samplers.SamplingConfig] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + denoise_fn = self._denoise_fn(params) + sampling_cfg = sampling_cfg if sampling_cfg is not None else self.sampling_cfg + + if rng is None: + ValueError("RNG is not optional for diffusion model sampling") + exit() + else: + l_rng, rng = jax.random.split(rng) + + batch_samples=batch['samples'] + other_cond = {k: batch[k] for k in batch if k != 'samples'} + latent = jax.random.normal(l_rng, batch_samples.shape) + + print("Running sampling at resolution: ", batch_samples.shape) + step_idxs = jnp.arange(0, sampling_cfg.num_steps) + samples = self.sampler.sample(denoise_fn, step_idxs, latent, rng, other_cond, + sampling_cfg=sampling_cfg) + return samples, {'None': None} + + def get_initial_variables( + self, + rng: jax.random.KeyArray, + input_shapes: Mapping[str, Array], + input_types: Optional[Mapping[str, jnp.dtype]] = None + ) -> flax_scope.FrozenVariableDict: + """Returns the initial variables of the model.""" + input_types = {} if input_types is None else input_types + sample_shape = input_shapes['samples'] + print("sample shape", sample_shape) + sample_dtype = input_types.get('samples', jnp.float32) + sigma_shape = input_shapes.get('timesteps', (sample_shape[0],)) + + if len(sigma_shape) != 1: + print("BAD SIGMA SHAPE: ", str(sigma_shape), " going to ", sample_shape[0]) + sigma_shape = sample_shape[0] + sigma_dtype = input_types.get('timesteps', jnp.float32) + print("Init Shapes: Sample: ", sample_shape, " Sigma: ", sigma_shape) + + inits = (jnp.ones(sample_shape, sample_dtype), jnp.ones(sigma_shape, sigma_dtype)) + + low_res_type = input_types.get('low_res_images', None) + # jax.debug.print(str(input_shapes)) + + text_enc_dtype = input_types.get('text', None) + if text_enc_dtype is not None: + text_enc_shape = input_shapes.get('text',None) + text_mask_dtype = input_types.get('text_mask', None) + text_mask_shape = input_shapes.get('text_mask', None) + + init_txt = jnp.ones(text_enc_shape, text_enc_dtype) + init_txt_mask = jnp.ones(text_mask_shape, text_mask_dtype) + inits = inits + (init_txt, init_txt_mask) + + if low_res_type is not None: + low_res_shape = input_shapes.get('low_res_images', None) + aug_level_shape = input_shapes.get('noise_aug_level', sigma_shape) + aug_level_type = input_types.get('noise_aug_level', sigma_dtype) + jax.debug.print(str(low_res_shape)) + inits = inits + (jnp.ones(low_res_shape, low_res_type), jnp.ones(aug_level_shape, aug_level_type)) + + initial_variables = self.denoiser.module.init( + rng, + *inits, + enable_dropout=False) + return initial_variables + +def compute_basic_diffusion_metrics( + loss: jnp.ndarray, + loss_unweighted: jnp.ndarray, + avg_sigma: jnp.ndarray, + num_examples: int, +) -> MetricsMap: + """Compute summary metrics. + + Args: + loss: loss (float) + mean_sigma: mean sigma noises used (float) + num_examples (int) number of examples in batch + + Returns: + Dict of metrics. + """ + num_devices = jax.device_count() + assert num_devices, 'JAX is reporting no devices, but it should.' + # Note: apply mask again even though mask has already been applied to loss. + # This is needed to divide by mask sum, but should not affect correctness of + # the numerator. + metrics = { + 'loss': + metrics_lib.AveragePerStep(total=loss), + 'loss_unweighted': + metrics_lib.AveragePerStep(total=loss_unweighted), + 'timing/images_per_second': + metrics_lib.TimeRate.from_model_output(numerator=num_examples), + 'timing/steps_per_second': + metrics_lib.StepsPerTime.from_model_output(), + 'timing/seconds': + metrics_lib.Time(), + 'timing/images': + metrics_lib.Sum(num_examples), + 'timing/images_per_second_per_core': + metrics_lib.TimeRate.from_model_output(numerator=num_examples / + num_devices), + 'diff_stats/avg_sigma': + metrics_lib.AveragePerStep(total=avg_sigma), + + } + return metrics diff --git a/rosetta/rosetta/projects/diffusion/samplers.py b/rosetta/rosetta/projects/diffusion/samplers.py new file mode 100644 index 000000000..d3918e511 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/samplers.py @@ -0,0 +1,339 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Diffusion/Score Matching Samplers + +This module holds samplers that use black box denoisers +""" + +from typing import Mapping, Optional, Tuple, Callable, Sequence +import typing_extensions +import abc +import dataclasses + +import jax +import jax.numpy as jnp +from rosetta.projects.diffusion.denoisers import DenoisingFunctionCallable +from rosetta.projects.diffusion.mm_utils import expand_dims_like + +PyTreeDef = type(jax.tree_util.tree_structure(None)) +BatchType = Mapping[str, jnp.ndarray] + +@dataclasses.dataclass +class SamplingConfig: + num_steps: int = 50 + generation_shape: Optional[Sequence[int]] = None + +@dataclasses.dataclass +class CFGSamplingConfig(SamplingConfig): + cf_guidance_weight: Optional[float] = None + cf_guidance_nulls: Optional[Mapping[str, Optional[jnp.ndarray]]] = None + +class DiffusionSamplerCallable(typing_extensions.Protocol): + """ Call signature for a diffusion sampling function. + Returns the samples.""" + def __call__(self, + denoise_fn: Callable, + step_indices: jnp.ndarray, + latent: jnp.ndarray, + rng: Optional[jax.random.KeyArray], + other_cond: Optional[Mapping[str, jnp.ndarray]], + sampling_cfg: Optional[SamplingConfig]=None + ) -> jnp.ndarray: + ... + +class DiffusionSampler(abc.ABC): + @abc.abstractmethod + def sample(self, + denoise_fn: Callable, + step_indices: jnp.ndarray, + latent: jnp.ndarray, + rng: Optional[jax.random.KeyArray], + other_cond: Optional[Mapping[str, jnp.ndarray]], + sampling_cfg: Optional[SamplingConfig]=None + ) -> jnp.ndarray: + pass + + def apply_cf_guidance(self, with_cond: jnp.ndarray, no_cond: jnp.ndarray, guidance_weight:float) -> jnp.ndarray: + """ + Applies classifier-free guidance. + + Args: + with_cond: Model output, assumed to have shape [b, ...] + no_cond: Model output, assumed to have shape [b, ...] + guidance_weight: cf guidance weight + """ + diff = with_cond - no_cond + + guided = with_cond + guidance_weight * diff + return guided + + +identity = lambda x:x +class EDMSampler(DiffusionSampler): + """ + Samples using a denoising model as per Karras et. al. EDM Algorithm 2 + """ + def __init__(self, + sigma_min: float = 0.002, + sigma_max: float = 80, + rho: float = 7, + S_churn: float = 0, + S_min: float = 0, + S_max: float = float('inf'), + S_noise: float = 1.0, + round_sigma: Callable = identity, + dim_noise_scalar: float = 1.0, + ): + self.sigma_min = sigma_min + self.sigma_max = sigma_max + self.rho = rho + self.S_churn = S_churn + self.S_min = S_min + self.S_max = S_max + self.S_noise = S_noise + self.round_sigma = round_sigma + self.dim_noise_scalar = dim_noise_scalar + + + def _sample_noise(self, shape: Tuple, rng: jax.random.KeyArray): + return jax.random.normal(rng, shape) * self.S_noise + + def _scannable_single_step(self, denoise_fn, num_steps, t_steps, other_cond, null_cond, second_order_correct=True, cf_guidance_weight=None): + """ Wraps single_step_sample to make it usable in jax.lax.scan """ + def wrapped_fn(x_rng_state: Tuple[jnp.ndarray, jax.random.KeyArray], idx: int): + return self.single_step_sample(denoise_fn, num_steps, t_steps, x_rng_state[1], idx, \ + second_order_correct=second_order_correct, x_curr=x_rng_state[0], \ + other_cond=other_cond, null_cond=null_cond, cf_guidance_weight=cf_guidance_weight) + return wrapped_fn + + def _get_fori_body(self, denoise_fn: DenoisingFunctionCallable, + num_steps: int, + t_steps: jnp.ndarray, + other_cond:Optional[BatchType]=None, + null_cond:Optional[BatchType]=None, + second_order_correct=True, + cf_guidance_weight:Optional[float]=None): + def loop_body(step, args): + x, curr_rng = args + args, _ = self.single_step_sample(denoise_fn, num_steps, t_steps, curr_rng, step, \ + second_order_correct=second_order_correct, x_curr=x, \ + other_cond=other_cond, null_cond=null_cond, cf_guidance_weight=cf_guidance_weight) + return args + + return loop_body + + def _get_eps(self, denoise_fn:DenoisingFunctionCallable, + noised_x, t, + other_cond:Optional[BatchType]=None, + null_cond:Optional[BatchType]=None, + cf_guidance_weight:Optional[float]=None): + #Calculates a potentially CF-guided eps in one forward pass + batch_dim = noised_x.shape[0] + + # Setup concats for cf_guidance + if cf_guidance_weight is not None: + assert null_cond is not None, f"Using CF-guidance {cf_guidance_weight}. \ + You must provide a null_cond if doing classifier-free guidance. \ + It's currently None" + noised_x = jnp.concatenate([noised_x, noised_x], axis=0) + t = jnp.concatenate([t, t], axis=0) + concatenate_fn = lambda x, y: jnp.concatenate([x, y], axis=0) + other_cond = jax.tree_util.tree_map(concatenate_fn, other_cond, null_cond) + + denoised = denoise_fn(noised_sample=noised_x, sigma=t, other_cond=other_cond) + denoised = dynamic_thresholding(denoised) + eps = (noised_x - denoised) / t + + #Apply CF Guidance + if cf_guidance_weight is not None: + eps = self.apply_cf_guidance(eps[:batch_dim], eps[batch_dim:], cf_guidance_weight) + + return eps + + def single_step_sample(self, denoise_fn: DenoisingFunctionCallable, + num_steps: int, + t_steps: jnp.ndarray, + rng: jax.random.KeyArray, + t_idx:int, + x_curr: jnp.ndarray=None, + other_cond:Optional[BatchType]=None, + null_cond:Optional[BatchType]=None, + second_order_correct=True, + cf_guidance_weight:Optional[float]=None + ) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], jax.random.KeyArray]: + """ Single step of sampling """ + rng, step_rng = jax.random.split(rng) + + t_curr = t_steps[t_idx] + t_next = t_steps[t_idx + 1] + + # Increase noise temporarily + m_gamma = jax.lax.min(self.S_churn / num_steps, jnp.sqrt(2) - 1) + gamma = jax.lax.cond(self.S_min <= t_curr, + lambda:jax.lax.cond(t_curr <= self.S_max, lambda:m_gamma, lambda:0.0), lambda:0.0) + t_hat = self.round_sigma(t_curr + gamma * t_curr) + x_hat = x_curr + jnp.sqrt((t_hat ** 2 - t_curr ** 2)) * \ + self.S_noise * self._sample_noise(x_curr.shape, step_rng) + + # Shape matching + t_hat = batch_expand(t_hat, x_curr) + t_next = batch_expand(t_next, x_curr) + + # Denoising + eps = self._get_eps(denoise_fn, x_hat, t_hat, other_cond, null_cond, cf_guidance_weight) + + # Euler Step + x_next = x_hat + (t_next - t_hat) * eps + + # Second order correction if t_idx < num_steps - 1 + if second_order_correct: + corrected = self.second_order_correct(x_next, denoise_fn, x_hat, t_hat, t_next, eps, other_cond, null_cond, cf_guidance_weight) + else: + corrected = x_next + return (corrected, rng), None #denoised + + def second_order_correct(self, x_next, denoise_fn, x_hat, t_hat, t_next, eps, other_cond, null_cond, cf_guidance_weight=None + ) -> jnp.ndarray: + # Denoising + eps_prime = self._get_eps(denoise_fn, x_next, t_next, other_cond, null_cond, cf_guidance_weight) + + # 2nd order correction + x_next = x_hat + (t_next - t_hat) * (0.5 * eps + 0.5 * eps_prime) + return x_next + + def sample(self, + denoise_fn: Callable, + step_indices: jnp.ndarray, + latent: jnp.ndarray, + rng: jax.random.KeyArray, + other_cond: Optional[BatchType]=None, + sampling_cfg: Optional[SamplingConfig]=None + ) -> jnp.ndarray: + # Classifier-free guidance will be enabled if cf_guidance_weight is not None + if sampling_cfg is None or not hasattr(sampling_cfg, 'cf_guidance_weight'): + cf_guidance_weight = None + cf_guidance_nulls = None + else: + cf_guidance_weight = sampling_cfg.cf_guidance_weight + cf_guidance_nulls = sampling_cfg.cf_guidance_nulls + jax.debug.print("Using CF-Guidance weight {}".format(cf_guidance_weight)) + + # Find timesteps + r_rho = 1 / self.rho + timesteps = (self.sigma_max ** r_rho + step_indices / (step_indices.shape[0] - 1) * \ + (self.sigma_min ** r_rho - self.sigma_max ** r_rho)) ** self.rho + timesteps = self.dim_noise_scalar * timesteps + timesteps = jnp.concatenate((self.round_sigma(timesteps), jnp.zeros_like(timesteps[:1]))) + + # Sampling Loop + null_cond = None + if cf_guidance_weight is not None: + assert other_cond is not None, "other_cond is None. Cannot do cf-guidance without any conditioning" + null_cond = assemble_cf_guidance_conds(other_cond, cf_guidance_nulls) + + prior = latent * timesteps[0] + loop_body = self._get_fori_body(denoise_fn, num_steps=step_indices.shape[0], \ + t_steps=timesteps, other_cond=other_cond, null_cond=null_cond, \ + second_order_correct=True,cf_guidance_weight=cf_guidance_weight) + samples, rng = jax.lax.fori_loop(0, step_indices.shape[0] - 1, loop_body, (prior, rng)) + + # Last step (no second_order_correct) + (samples, _), denoised = self.single_step_sample(denoise_fn, step_indices.shape[0], timesteps, rng, step_indices.shape[0] - 1, other_cond=other_cond, \ + null_cond=null_cond, x_curr=samples, second_order_correct=False, cf_guidance_weight=cf_guidance_weight) + jax.debug.print("single final step") + + return (samples + 1) / 2 + + # A Data Parallel sampling loop that uses a pjitted denoise_fn call + def sample_loop(self, + denoise_fn: Callable, + sampling_cfg: SamplingConfig, + latent: jnp.ndarray, + rng: jax.random.KeyArray, + other_cond: Optional[BatchType]=None, + )-> jnp.ndarray: + # Classifier-free guidance will be enabled if cf_guidance_weight is not None + if not hasattr(sampling_cfg, 'cf_guidance_weight'): + cf_guidance_weight = None + cf_guidance_nulls = None + else: + cf_guidance_weight = sampling_cfg.cf_guidance_weight + cf_guidance_nulls = sampling_cfg.cf_guidance_nulls + jax.debug.print("Using CF-Guidance weight {}".format(cf_guidance_weight)) + + # Find timesteps + step_indices = jnp.arange(sampling_cfg.num_steps) + r_rho = 1 / self.rho + timesteps = (self.sigma_max ** r_rho + step_indices / (step_indices.shape[0] - 1) * \ + (self.sigma_min ** r_rho - self.sigma_max ** r_rho)) ** self.rho + timesteps = jnp.concatenate((self.round_sigma(timesteps), jnp.zeros_like(timesteps[:1]))) + + # Sampling Loop + null_cond = None + if cf_guidance_weight is not None: + assert other_cond is not None, "other_cond is None. Cannot do cf-guidance without any conditioning" + jax.debug.inspect_array_sharding(other_cond, callback=print) + null_cond = assemble_cf_guidance_conds(other_cond, cf_guidance_nulls) + jax.debug.print("Assembed conds") + + prior = jnp.asarray(latent, jnp.float64) * timesteps[0] + for time_idx in range(sampling_cfg.num_steps - 1): + timestep = timesteps[sampling_cfg.num_steps - 1 - time_idx] + step_fn = self._scannable_single_step(denoise_fn, step_indices.shape[0], timesteps, other_cond, null_cond, second_order_correct=True, cf_guidance_weight=cf_guidance_weight) + (samples, rng), denoised = jax.lax.scan(step_fn, (prior, rng), jnp.arange(0, step_indices.shape[0] - 1)) + jax.debug.print("scanned") + + # Last step (no second_order_correct) + (samples, _), denoised = self.single_step_sample(denoise_fn, step_indices.shape[0], timesteps, rng, step_indices.shape[0] - 1, other_cond=other_cond, \ + null_cond=null_cond, x_curr=samples, second_order_correct=False, cf_guidance_weight=cf_guidance_weight) + jax.debug.print("single final step") + + samples = (samples + 1) / 2 + + repl_samples = jax.pjit(lambda x: x, in_shardings=None, out_shardings=None)(samples) + return repl_samples + + + +def assemble_cf_guidance_conds(other_cond: BatchType, + guidance_nulls:Optional[Mapping[str, Optional[jnp.ndarray]]]) -> BatchType: + null_cond = {} + for k, v in other_cond.items(): + if guidance_nulls is None or k in guidance_nulls.keys(): + null_cond_val = None + # If no explicit 'null' is provided, use zeros_like + if guidance_nulls is None or guidance_nulls[k] is None: + null_cond_val = jnp.zeros_like(v) + else: + null_cond_val = guidance_nulls[k] + null_cond[k] = null_cond_val + else: + null_cond[k] = v + + return null_cond + +def dynamic_thresholding(denoised, p=99.5): + s = jnp.percentile( + jnp.abs(denoised), p, + axis=tuple(range(1, denoised.ndim)), + keepdims=True) + s = jnp.max(jnp.concatenate([s, jnp.ones_like(s)]), axis=0) + return jnp.clip(denoised, -s, s) / s + +def batch_expand(scalar: jnp.ndarray, imitate: jnp.ndarray): + """ Match batch dimension and expand rank to match 'imitate' """ + out = scalar * jnp.ones(imitate.shape[0], scalar.dtype) + return expand_dims_like(out, imitate) \ No newline at end of file diff --git a/rosetta/rosetta/projects/diffusion/tests/augmentations_test.py b/rosetta/rosetta/projects/diffusion/tests/augmentations_test.py new file mode 100644 index 000000000..9e22495e7 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/tests/augmentations_test.py @@ -0,0 +1,78 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for rosetta.projects.diffusion.augmentations.""" + +from absl.testing import absltest +import jax +import jax.numpy as jnp +import numpy as np +import sys +from rosetta.projects.diffusion import augmentations + +class AugTest(absltest.TestCase): + + def test_text_cond_aug_array(self): + in_arr = jnp.ones((3, 4, 2)) + rng_in = jax.random.PRNGKey(0) + masked, rng = augmentations.text_conditioning_dropout(in_arr, rng_in, dropout_rate=0.5) + + output = jnp.ones((3,4,2)) + output = output.at[0, :].set(0) + output = output.at[1, :].set(0) + assert jnp.allclose(output, masked), f'expected: {output}, got: {masked}' + assert (rng_in != rng).all() + assert rng is not None + + def test_text_cond_aug_mapping(self): + in_arr = {'text_mask': jnp.ones((3, 4, 2)), 'another_key':jnp.ones((1, 4, 2))} + rng = jax.random.PRNGKey(0) + masked, rng = augmentations.text_conditioning_dropout(in_arr, rng, dropout_rate=0.5) + + output = jnp.ones((3,4,2)) + output = output.at[0, :].set(0) + output = output.at[1, :].set(0) + masked_arr = masked['text_mask'] + assert jnp.allclose(output, masked_arr), f'expected: {output}, got: {masked_arr}' + masked_ones = masked['another_key'] + assert jnp.allclose(masked_ones, jnp.ones((1, 4, 2))), f'expected ones, got: {masked_ones}' + assert list(masked.keys()) == ['text_mask', 'another_key'], 'modified keys in batch' + + def test_text_cond_aug_array_preserved(self): + in_arr = jnp.ones((3, 4, 2)) + in_arr = in_arr.at[2, 2:, 1:].set(0) + rng = jax.random.PRNGKey(0) + masked, rng = augmentations.text_conditioning_dropout(in_arr, rng, dropout_rate=0.5) + + output = jnp.ones((3,4,2)) + output = output.at[0, :].set(0) + output = output.at[1, :].set(0) + output = output.at[2, 2:, 1:].set(0) + assert jnp.allclose(output, masked) + + def test_text_cond_aug_jit(self): + in_arr = jnp.ones((3, 4, 2)) + in_arr = in_arr.at[2, 2:, 1:].set(0) + rng = jax.random.PRNGKey(0) + masked, rng = jax.jit(augmentations.text_conditioning_dropout)(in_arr, rng, dropout_rate=0.5) + output = jnp.ones((3,4,2)) + output = output.at[0, :].set(0) + output = output.at[1, :].set(0) + output = output.at[2, 2:, 1:].set(0) + assert jnp.allclose(output, masked) + + +if __name__ == '__main__': + sys.path.append('../') + import rosetta.projects.diffusion.augmentations as augmentations + absltest.main() diff --git a/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.tar b/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.tar new file mode 100644 index 000000000..1d57efbb7 Binary files /dev/null and b/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.tar differ diff --git a/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt b/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt new file mode 100644 index 000000000..df0caad47 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt @@ -0,0 +1,8 @@ +a photograph of an astronaut riding a horse +a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal +a highly realistic picture of a dog wearing a green spotted bowtie and black top hat. +a black apple and green backpack +a photorealistic dragon flying over a grassy mountain +drawing of a cartoon computer displaying an exclamation mark +picture of a dystopian city with an ominous red portal in the sky. Aliens are emerging from the portal. +a selfie of a man wearing a blue jacket diff --git a/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/make_custom_prompt_wds.py b/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/make_custom_prompt_wds.py new file mode 100644 index 000000000..3ce0978cf --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts/make_custom_prompt_wds.py @@ -0,0 +1,16 @@ +import webdataset as wds +import sys + +BASENAME='custom_eval_prompts' +with open(f'{BASENAME}.txt', 'r') as f: + prompts = f.readlines() + +sink = wds.TarWriter(f"{BASENAME}.tar") +for index, line in enumerate(prompts): + if index%1000==0: + print(f"{index:6d}", end="\r", flush=True, file=sys.stderr) + sink.write({ + "__key__": "sample%06d" % index, + "txt": line.strip(), + }) +sink.close() diff --git a/rosetta/rosetta/projects/diffusion/wds_utils.py b/rosetta/rosetta/projects/diffusion/wds_utils.py new file mode 100644 index 000000000..7ad220af8 --- /dev/null +++ b/rosetta/rosetta/projects/diffusion/wds_utils.py @@ -0,0 +1,551 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import dataclasses +import functools +import logging +import os +import pickle as pkl +import random +from pathlib import Path +from PIL import Image +import numpy as np +from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, Iterable + +import jax +from jax import tree_util +import seqio +import t5.data +import tensorflow as tf +import webdataset as wds +from pytriton.client import ModelClient +import rosetta.data.multiloader as multiloader +import braceexpand + +seqio_vocab = t5.data.get_default_vocabulary() +server_list = None + +# ----------------------------------------------------------------------------- +# Configurations +# ----------------------------------------------------------------------------- + +@dataclasses.dataclass +class ModalityConfig: + ftype: Optional[str] + out_type: Tuple[Any] + shape: Tuple[int] + process_func: Optional[Callable] + prefilter_func: Optional[Callable]=None + no_load: bool = False # Don't load modality from webdataset pipeline. Used for modalities that are created from others' process_funcs + +@dataclasses.dataclass +class WebDatasetConfig: + """Configuration for loading a WebDataset""" + mixture_or_task_name: Union[str, Iterable[str]] + batch_size: int + shuffle: bool + seed: Optional[int] + + # Controls total number of samples. Ignored in training + samples: Optional[int] = None + modalities: Optional[Mapping[str, ModalityConfig]] = None #will error if None + batch_proc: Optional[Callable] = None + hostnames_file: Optional[str] = None + num_parallel_processes: int = 16 + pack=False # Webdataset doesn't currently support text packing + +# ----------------------------------------------------------------------------- +# Data processing utils +# ----------------------------------------------------------------------------- + +# ------------------ +# Image +# ------------------ + +def image_crop_scale(image, out_img_shape=(32, 32, 3), nhwc=True): + """ + Resizes image to out_img_shape by first doing an aspect-preserving resize + to match the short side of the image to the out_img_shape, then doing a + random crop. + + Assumes image is ranged [0, 1] + """ + if not nhwc: + # if non nhwc output is desired, out_img_shape should be given in nchw format. + # We transpose it here to make it compatible with processing and transpose later + out_img_shape = (out_img_shape[1], out_img_shape[2], out_img_shape[0]) + curr_img_shape = image.shape + + # square crop randomly with min dimension. + min_dim = min(curr_img_shape[:2]) + + left = random.randint(0, curr_img_shape[1] - min_dim) + right = left + min_dim + top = random.randint(0, curr_img_shape[0] - min_dim) + bottom = top + min_dim + image = image[top:bottom, left:right] + + #resize to final dimensions + image = np.asarray(Image.fromarray((image * 255).astype(np.uint8)).resize(out_img_shape[:2], resample=Image.BILINEAR)) / 255. + + image = image * 2 - 1 # [-1, 1] ranging + if not nhwc: + image = np.transpose(image, (2, 0, 1)) + return image + +def image_crop_scale_with_lowres(image, out_img_shape=(32, 32, 3), low_res_img_shape=(32, 32, 3), nhwc=True): + """ + Resizes image to out_img_shape by first doing an aspect-preserving resize + to match the short side of the image to the out_img_shape, then doing a + random crop. + Further returns a downsized version of this final image for superresolution training + + Assumes image is ranged [0, 1] + """ + if not nhwc: + # if non nhwc output is desired, out_img_shape should be given in nchw format. + # We transpose it here to make it compatible with processing and transpose later + out_img_shape = (out_img_shape[1], out_img_shape[2], out_img_shape[0]) + low_res_img_shape = (low_res_img_shape[1], low_res_img_shape[2], low_res_img_shape[0]) + curr_img_shape = image.shape + + # square crop randomly with min dimension. + min_dim = min(curr_img_shape[:2]) + + left = random.randint(0, curr_img_shape[1] - min_dim) + right = left + min_dim + top = random.randint(0, curr_img_shape[0] - min_dim) + bottom = top + min_dim + image = image[top:bottom, left:right] + + #resize to final dimensions + image = Image.fromarray((image * 255).astype(np.uint8)).resize(out_img_shape[:2], resample=Image.BILINEAR) + image_large = np.asarray(image) / 255. + image_lowres = np.asarray(image.resize(low_res_img_shape[:2], resample=Image.BILINEAR)) / 255. + + image_large = image_large * 2 - 1 # [-1, 1] ranging + image_lowres = image_lowres * 2 - 1 # [-1, 1] ranging + if not nhwc: + image_large = np.transpose(image_large, (2, 0, 1)) + image_lowres = np.transpose(image_lowres, (2, 0, 1)) + return {'samples': image_large, 'low_res_images': image_lowres} + +def image_subcrop_scale_with_lowres(image, init_image_shape=(1024,1024,3), crop_shape=(256,256,3), low_res_img_shape=(64,64,3), nhwc=True): + """ + Does a random crop of an image by first resizing it to a target resolution then doing a random crop. + Further returns a downsized version of this final image for superresolution training + + Assumes image is ranged [0, 1] + """ + + image = image_crop_scale(image, out_img_shape=init_image_shape, nhwc=nhwc) + if not nhwc: + # if non nhwc output is desired, out_img_shape should be given in nchw format. + # We transpose it here to make it compatible with processing and transpose later + out_img_shape = (out_img_shape[1], out_img_shape[2], out_img_shape[0]) + low_res_img_shape = (low_res_img_shape[1], low_res_img_shape[2], low_res_img_shape[0]) + crop_shape = (crop_shape[1], crop_shape[2], crop_shape[0]) + curr_img_shape = image.shape + + # square crop randomly + # min_dim = min(curr_img_shape[:2]) + crop_width = crop_shape[1] + crop_height = crop_shape[0] + + left = random.randint(0, curr_img_shape[1] - crop_width) + right = left + crop_width + top = random.randint(0, curr_img_shape[0] - crop_height) + bottom = top + crop_height + image_large = image[top:bottom, left:right] + + #resize to final dimensions + image = Image.fromarray((image_large * 255).astype(np.uint8))#.resize(out_img_shape[:2], resample=Image.BILINEAR) + image_lowres = np.asarray(image.resize(low_res_img_shape[:2], resample=Image.BILINEAR)) / 255. + + image_large = image_large * 2 - 1 # [-1, 1] ranging + image_lowres = image_lowres * 2 - 1 # [-1, 1] ranging + if not nhwc: + image_large = np.transpose(image_large, (2, 0, 1)) + image_lowres = np.transpose(image_lowres, (2, 0, 1)) + return {'samples': image_large, 'low_res_images': image_lowres} + +def blank_image(out_img_shape=(32, 32, 3), nhwc=True): + """ Dummy image creator. Used for sampling and dummy datasets """ + img = np.zeros(out_img_shape) + if not nhwc: + return np.transpose(img, (2, 0, 1)) + return img + +def filter_lowres(image, min_dims=(64,64,3), nhwc=True): + # returns false for images that don't meet the minimum dimensions specified. + # min_dims should be specified in the same format as images (NHWC or NCHW) + shape = image.shape + assert len(shape) == len(min_dims), f"Minimum dimension spec and image shape length need to match.\ + Given min_dims {min_dims} and image shape {shape}" + if not nhwc: + dims = [1] * len(min_dims) + dims[0] = min_dims[1] + dims[1] = min_dims[2] + dims[2] = min_dims[0] + min_dims = tuple(dims) + + for i in range(len(shape)): + if shape[i] < min_dims[i]: + return False + + return True + +# ------------------ +# Text +# ------------------ + +def triton_textencode(text_batch: List[str]): + """ Encodes a list of python strings into numpy character arrays """ + enc = np.array([[np.char.encode(i, 'utf-8')] for i in text_batch]) + enc = np.reshape(enc, (enc.shape[0], 1)) + return enc + +def seqio_tokenizer(shape=(128,)): + tok_config = {'text': seqio.Feature(vocabulary=t5.data.get_default_vocabulary(), add_eos=True)} + def tok(text): + unpad = seqio.preprocessors.tokenize_impl({'text': text}, tok_config, copy_pretokenized=False, with_eos=True)['text'] + padded = tf.pad(unpad, [[0, shape[0] - unpad.shape[0]]]) + return padded + return tok + +def mscoco_text_process(text_in, shape, vocab=seqio_vocab): + text = text_in['caption'] + + mask = np.zeros(shape[0]) #dummy mask + if text is None or not isinstance(text, str): + text = '' + + return {'text': text, 'text_mask': mask} + +def bare_txt_process(text, shape): + mask = np.ones(shape[0]) #dummy mask + if text is None or not isinstance(text, str): + text = '' + logging.info("WARNING: no text") + return {'text': text, 'text_mask': mask} + +def sd_clip_text_tokenize(string, tokenizer): + tok = np.array(tokenizer(string, max_length=tokenizer.model_max_length, padding="max_length", truncation=True).input_ids) + return {'input_ids': tok} + +def cls_process(cls): + return {'cls': cls} + +# ------------------ +# External inference +# ------------------ + +def dummy_batch_infer(batch, server_list=None, text_emb_shape:Tuple=(256, 4096), model_name:str='t5_xxl'): + IS_BATCHED = True + if not isinstance(batch['text'], list): + batch_dim = 1 + IS_BATCHED = False + else: + batch_dim = len(batch['text']) + + # construct text masks and padding + mask = np.zeros((batch_dim, text_emb_shape[0])).astype('int32') + rand_idxs = np.random.randint(text_emb_shape[0], size=batch_dim) + padded_embeds = np.random.normal(size=(batch_dim, *text_emb_shape)).astype('float32') + for i in range(rand_idxs.shape[0]): + mask[i, :rand_idxs[i]] = 1 + padded_embeds[i, rand_idxs[i]:] = 0 + + if not IS_BATCHED: + padded_embeds = padded_embeds[0] + mask = mask[0] + + batch['text'] = padded_embeds + batch['text_mask'] = mask + + return batch + + +def batch_infer_extern(batch, server_list=None, text_emb_shape:Tuple=(128, 4096), model_name:str='t5_xxl'): + """ + Calls a remote PyTriton server to get logits from a text transformer. + This is done batched for efficiency + """ + if server_list is None: + raise ValueError("server_list is required") + + text_batch = batch['text'] + IS_BATCHED=True + + full_text = text_batch + if not isinstance(full_text, list): + IS_BATCHED=False + full_text = [full_text] + + # encoding text batch for triton inference + encoded_batch = triton_textencode(full_text) + + recieved_out = False + try_ctr = 0 + + # keep trying until successful inference + padded_embeds = None + mask = None + while not recieved_out or padded_embeds is None: + rand_server = server_list[random.randint(0, len(server_list) - 1)] + with ModelClient(rand_server, model_name=model_name) as client: + try: + text_emb_dict = client.infer_batch(encoded_batch) + except: + text_emb_dict = None + + if text_emb_dict: + #embeds = [pkl.loads(inst) for inst in text_emb_dict['encodings']] + seqlens = text_emb_dict['encodings_seqlens'] + batch_dim = seqlens.shape[0] + padded_embeds = text_emb_dict['encodings_padded'] + # padding embeds up to full seqlen + if padded_embeds.shape[1] < text_emb_shape[0]: + embeds = np.zeros((batch_dim, *text_emb_shape), dtype='float16') + embeds[:, :padded_embeds.shape[1]] = padded_embeds + padded_embeds = embeds + mask = np.zeros((batch_dim, text_emb_shape[0])).astype('int32') + for idx in range(batch_dim): + mask[idx, :seqlens[idx]] = 1 + + if padded_embeds is None: + try_ctr += 1 + sam = text_batch[0] + logging.info(f"Server returned nothing. Retrying. Attempt {try_ctr}. {sam}") + else: + recieved_out = True + + assert padded_embeds is not None and mask is not None + if padded_embeds.shape[0] != len(full_text): + logging.info("WARNING: somehow, embeds and full_text don't match!") + + # construct text masks and padding + if not IS_BATCHED: + padded_embeds = padded_embeds[0] + mask = mask[0] + + batch['text'] = padded_embeds + batch['text_mask'] = mask + + return batch + +# ----------------------------------------------------------------------------- +# WebDataset construction and setup +# ----------------------------------------------------------------------------- + +def type_proc(dtype:str): + if dtype == 'float32': + return np.float32 + elif dtype == 'int': + return np.int32 + elif dtype == 'float16': + return np.float16 + elif dtype == 'bfloat16': + return jax.numpy.bfloat16 + else: + raise ValueError("Could not parse dtype: %s" % dtype) + +def run_preproc(sample:Any, keys:List[str]=[], modalities: Mapping[str, ModalityConfig]={}): + datapoint = {} + + non_loaded_ctr = 0 + for i in range(len(keys)): + k = keys[i] + process_fn = modalities[k].process_func + if modalities[k].no_load: + non_loaded_ctr += 1 + elif process_fn is not None and modalities[k].ftype is None: # for generating a fixed shape dummy sample + mod_out = process_fn() + unpack_assign_or_assign(key=k, value=mod_out, dictionary=datapoint) + non_loaded_ctr += 1 + else: + mod_out = process_fn(sample[i - non_loaded_ctr]) if process_fn is not None else sample[i - non_loaded_ctr] + unpack_assign_or_assign(key=k, value=mod_out, dictionary=datapoint) + + return datapoint + +def run_prefilter(sample:Any, keys:List[str]=[], modalities: Mapping[str, ModalityConfig]={}): + datapoint = {} + + non_loaded_ctr = 0 + for i in range(len(keys)): + k = keys[i] + prefilter_fn = modalities[k].prefilter_func + if modalities[k].no_load: + non_loaded_ctr += 1 + elif prefilter_fn is not None: + if not prefilter_fn(sample[i - non_loaded_ctr]): + return False + return True + +def unpack_assign_or_assign(key: Any, value: Any, dictionary: Dict, strict: bool = True): + '''Tries to unpack value ignoring key, if value is a dictionary; or assigns value to key as fallback''' + if isinstance(value, Mapping): + for nested_key, nested_value in value.items(): + if strict and nested_key in dictionary: + raise ValueError(f'strict=True, and key={nested_key} already in dictionary={dictionary}') + dictionary[nested_key] = nested_value + else: + if strict and key in dictionary: + raise ValueError(f'strict=True, and key={key} already in dictionary={dictionary}') + dictionary[key] = value + +def dict_batch(samples): + """ batches elements of the same key """ + outer = tree_util.tree_structure([0 for _ in samples]) + inner = tree_util.tree_structure(samples[0]) + return tree_util.tree_transpose(outer, inner, samples) + +def _simple_map(data, f, handler=wds.filters.reraise_exception): + """Map samples.""" + for sample in data: + try: + result = f(sample) + except Exception as exn: + if handler(exn): + continue + else: + break + if result is None: + continue + yield result + +simple_map = wds.filters.pipelinefilter(_simple_map) + +def url_process(url_str:Union[str, Iterable[str]]) -> Union[List[str], str]: + """ + If url_str is a directory, this function will return a list of + all .tar files found recursively in it. + If url_str is an iterable, expands all urls contained as directories + or braceexpands and concatenates them together + """ + logging.info(f'Processing URLS for {url_str}') + if isinstance(url_str, str): + url_str = [url_str] + paths = [] + for url in url_str: + if os.path.isdir(url): + url_paths = [str(p) for p in Path(url).rglob('*.tar')] + logging.info("{} tarfiles found in {}".format(len(url_paths), url)) + paths += url_paths + else: + logging.info(f'{url} doesn\'t seem to be a directory. Treating it as a path with braceexpand') + paths += list(braceexpand.braceexpand(url)) + return paths + +def get_mm_wds_from_urls(cfg: WebDatasetConfig, batch_size:int =-1) -> Tuple[Any, Mapping[str, Tuple[int]], Mapping[str, Any]]: + global server_list + + # Getting all urls + urls = url_process(cfg.mixture_or_task_name) + + # Setting up modalities (shapes and types) + modalities = cfg.modalities + assert modalities is not None, "Modalities cannot be None. Don't know how to process data!" + keys = list(modalities.keys()) + in_ftypes = [] + out_shapes = {} + out_types = {} + for k in keys: + m = modalities[k] + if m.ftype is not None: + in_ftypes.append(m.ftype) + unpack_assign_or_assign(key=k, value=m.shape, dictionary=out_shapes) + unpack_assign_or_assign(key=k, value=jax.tree_map(type_proc, m.out_type), dictionary=out_types) + + # Inference Server determination + if server_list is None: + if cfg.hostnames_file not in (None, "", "None"): + with open(cfg.hostnames_file, 'r') as f: + server_list = f.readlines() + for i in range(len(server_list)): + server_list[i] = server_list[i].strip() + else: + logging.info("No hostnames file. Will not initialize remote inferencing") + server_list = None + else: + logging.info('SERVER LIST GIVEN. Not reading from cfg.hostnames_file') + logging.info(server_list) + + preprocessor = functools.partial(run_preproc, keys=keys, modalities=modalities) + pre_filter = functools.partial(run_prefilter, keys=keys, modalities=modalities) + dataset = wds.WebDataset(urls, resampled=True).shuffle(0).decode("rgb").to_tuple(*in_ftypes).select(pre_filter).map(preprocessor) + if cfg.samples: + dataset = dataset.with_length(cfg.samples) + if batch_size > 0: + dataset = dataset.batched(batch_size, collation_fn=dict_batch) + if cfg.batch_proc is not None: + bp = functools.partial(cfg.batch_proc, server_list=server_list) + dataset = dataset.compose(simple_map(bp)) + + load = dataset + if cfg.num_parallel_processes > 1: + load = multiloader.MultiLoader(dataset, workers=cfg.num_parallel_processes) + return load, out_shapes, out_types + +def get_random_wds(cfg: WebDatasetConfig) -> Tuple[Any, Mapping[str, Tuple[int]], Mapping[str, Any]]: + ''' + same as get_mm_wds_from_urls, except random dataset for SOL test or if you don't have a dataset yet + ''' + # Setting up modalities (shapes and types) + modalities = cfg.modalities + assert modalities is not None, "Modalities cannot be None. Don't know how to process data!" + keys = list(modalities.keys()) + in_ftypes = [] + out_shapes = {} + out_types = {} + for k in keys: + m = modalities[k] + if m.ftype is not None: + in_ftypes.append(m.ftype) + unpack_assign_or_assign(key=k, value=m.shape, dictionary=out_shapes) + unpack_assign_or_assign(key=k, value=jax.tree_map(type_proc, m.out_type), dictionary=out_types) + + preprocessor = functools.partial(run_preproc, keys=keys, modalities=modalities) + def random_generator(wds_config: WebDatasetConfig, num_elements: int = 100): + for _ in range(num_elements): + datum = {} + for _, modality_config in wds_config.modalities.items(): + if isinstance(modality_config.shape, (tuple, list)): + datum[modality_config.ftype] = np.random.randint(size=modality_config.shape, low=0, high=2).astype(modality_config.out_type) + else: + datum[modality_config.ftype] = jax.tree_map( + lambda shape, dtype: np.random.randint(size=shape, low=0, high=2).astype(dtype), + modality_config.shape, + modality_config.out_type, + is_leaf=lambda shape: isinstance(shape, (list, tuple)), + ) + yield datum + + dataset = wds.DataPipeline( + lambda: random_generator(cfg), + wds.to_tuple(*in_ftypes), + wds.map(preprocessor), + ) + print(len(list(dataset)), 'a') + if cfg.batch_size > 0: + dataset.pipeline.append(wds.batched(cfg.batch_size, collation_fn=dict_batch)) + if cfg.batch_proc is not None: + bp = functools.partial(cfg.batch_proc, server_list=server_list) + dataset = dataset.compose(simple_map(bp)) + print(len(list(dataset)), 'b') + + if cfg.num_parallel_processes > 1: + raise NotImplementedError(f'No suport for parallel processes for random data generation') + return dataset, out_shapes, out_types \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/README.md b/rosetta/rosetta/projects/imagen/README.md new file mode 100644 index 000000000..4959a4118 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/README.md @@ -0,0 +1,141 @@ +# Imagen +[Imagen](https://arxiv.org/abs/2205.11487) is a text-to-image generative diffusion model that operates in pixel-space. This repository contains the necessary tools and scripts for performantly training Imagen from base model to its superresolution models in JAX on GPUs. + +![A racoon wearing a hat and leather jacket in front of a backyard window. There are raindrops on the window.](assets/A%20raccoon%20wearing%20a%20hat%20and%20black%20leather%20jacket%20is%20behind%20the%20backyard%20window.%20Rain%20droplets%20on%20the%20window_16.png) +![A blue colored pizza](assets/A%20blue%20coloured%20pizza_14.png) +![mystical portal man](assets/a%20highly%20detailed%20digital%20painting%20of%20a%20portal%20in%20a%20mystic%20forest%20with%20many%20beautiful%20trees.%20A%20person%20is%20standing%20in%20front%20of%20the%20portal_20.png) + +Prompts: +- A racoon wearning a hat and leather jacketin front of a backyard window. There are raindrops on the window +- A blue colored pizza +- a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal. + +## Architecture +For maximum flexibility and low disk requirements, this repo supports a **distributed architecture** for text embedding in diffusion model training. Upon launching training, it will spawn LLM inference servers that will performantly calculate text embeddings online (with no latency hit). It does this by creating several inference **clients** in the diffusion model trainer's dataloaders, which send embedding requests to the inference servers. These servers are based on [NVIDIA PyTriton](https://github.com/triton-inference-server/pytriton), so execute all requests batched. Currently, this inference server supports T5x LLMs, but can be changed to be based on anything (doesn't even have to be JAX!) since the diffusion model trainer's client is simply making PyTriton (http) calls. + +## GPU Scripts and Usage +We provide [scripts](scripts) to run [interactively](scripts/singlenode_inf_train.sh) or on [SLURM](scripts/example_slurm_inf_train.sub). + +### Container +We provide a fully built and ready-to-use container here: `ghcr.io/nvidia/t5x:imagen-2023-10-02`. + +We do not currently have custom-built container workflows, but are actively working on supporting this, stay tuned for updates! +Imagen will also be available in our T5x container in future releases. + +### Dataset +This model accepts webdataset-format datasets for training. For reference, we have an imagenet webdataset example [here](https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/vit#downloading-the-dataset).(NOTE: imagen is not directly compatible with imagenet). For imagen training with a compatible dataset, you can find or create your own webdataset (with image and text modalities). + +Once you have your webdataset, update the dataset configs {[base](configs/img-txt-ds-base.gin), [sr1](configs/img-txt-ds-sr1.gin), [sr2](configs/img-txt-ds-sr2.gin)} with the paths to your dataset(s) under ```MIXTURE_OR_TASK_NAME```. + + +The 'img-txt-ds' configs assume a webdataset with a text and image modality. The images are in jpg format and the text is raw text in a ```'.txt'``` file. Currently, the configs are set up to do resolution-based filtering, scale-preserved square random cropping, and low-resolution image generation for SR model training. This can be changed (i.e. if you want your text in ```.json``` format and want to do additional processing) in the dataset configuration files {[base](configs/img-txt-ds-base.gin), [sr1](configs/img-txt-ds-sr1.gin), [sr2](configs/img-txt-ds-sr2.gin)}. + +### Downloading the LLM checkpoint +You will need to acquire the LLM checkpoint for T5 (for multimodal training) from T5x [here](https://t5x.readthedocs.io/en/latest/models.html#t5-1-1-checkpoints). All models use T51.1 format T5-xxl by default. Once you have the checkpoint, place it at ```rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl``` (appending the ```_{size}``` to the checkpoint folder). **NOTE**: We're working on adding TransformerEngine support to the inference server, but for now, please run with the ```DISABLE_TE=True``` environment variable (example scripts include this). + +### Running interactively +**Note**: this should only be done with singlenode jobs + +```bash +CONTAINER=ghcr.io/nvidia/t5x:imagen-2023-10-02 +docker run --rm --gpus=all -it --net=host --ipc=host -v ${PWD}:/opt/rosetta -v ${DATASET_PATH}:/mnt/datasets --privileged $CONTAINER bash +``` + +### Single Node runs +Pretraining can be done on multiple gpus within 1 host with `scripts/singlenode_inf_train.sh`. This will build an Imagen model with the Adam optimizer and relevant parameters. It will also launch the relevant LLM inference servers. + +```bash +#### Pretraining (interactive: already inside container) with example args +bash rosetta/projects/imagen/scripts/singlenode_inf_train.sh {DATASET NAME} {MODEL NAME} {PRECISION} {NUM GPUS} {BSIZE/GPU} {LOGDIR} {MODEL DIR} {NUM LLM INFERENCE GPUS} {INFERENCE SERVER LLM SIZE} + +#### Pretraining (non-interactive) +docker run --rm --gpus=all --net=host --ipc=host -v ${DATASET_PATH}:/mnt/datasets $CONTAINER bash rosetta/projects/imagen/scripts/singlenode_inf_train.sh {args from above} +``` + +### Multi Node runs +For a SLURM+pyxis cluster, the `scripts/example_slurm_inf_train.sub` file provides an example slurm submit file (edit with your details), which calls `scripts/multinode_train.sh` and `scripts/specialized_run.py` to execute training. + +### Pretraining run commands +All commands below assume you are in `$ROSETTA_DIR=/opt/rosetta` and have the scripts and slurm scripts locally. + +### Multinode +Arguments are set as such: +```sh +sbatch -N {NODE_CT} rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \ +{DATASET NAME} {MODEL NAME} {PRECISION} {NUM GPUS / NODE} {BSIZE/GPU} {MODEL DIR} {NUM LLM INFERENCE GPUS} {INFERENCE SERVER LLM SIZE} +``` + +All parameters can be found in the relevant script. + +### Example training Commands +Assumes 8GPU 80GB A100/H100 Nodes. + +#### Imagen-base small (500M): +```sh +sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \ +{DATASET} imagen_base_500M bfloat16 8 32 runs/imagen-base 48 xxl +``` + +#### Imagen-base large (2B): +```sh +sbatch -N 20 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \ +{DATASET} imagen_base_2B bfloat16 8 16 runs/imagen-base 32 xxl +``` + +#### Imagen-sr1 (efficient unet) (600M): +```sh +sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \ +{DATASET} imagen_sr1_efficientunet_600M bfloat16 8 32 runs/imagen-sr1 48 xxl +``` + +#### Imagen-sr2 (efficient unet) (600M): +```sh +sbatch -N 14 rosetta/projects/imagen/scripts/example_slurm_inf_train.sub \ +{DATASET} imagen_sr2_efficientunet_600M bfloat16 8 32 runs/imagen-sr2 48 xxl +``` + + +### Sampling +You can find example sampling scripts that use the 500M base model and EfficientUnet SR models in [scripts](scripts). Prompts should be specified as in [example](../diffusion/tests/custom_eval_prompts/custom_eval_prompts.txt) + +#### Sampling 256x256 images +Defaults to [imagen_256_sample.gin](configs/imagen_256_sample.gin) config (can be adjusted in script) +``` +CUDA_VISIBLE_DEVICES= CFG=5.0 BASE_PATH= SR1_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_256.sh +``` + +#### Sampling 1024x1024 images +Defaults to [imagen_1024_sample.gin](configs/imagen_1024_sample.gin) config (can be adjusted in script). +``` +CUDA_VISIBLE_DEVICES= CFG=5.0 BASE_PATH= SR1_PATH= SR2_PATH= PROMPT_TEXT_FILES= ./rosetta/projects/imagen/scripts/sample_imagen_1024.sh +``` + + +## Convergence and Performance +Global Batch size = 2048. We assume 2.5B Training examples in these calculations. LLM Inference server nodes are not included in these numbers. +| size | GPU | Precision | #GPUs | BS / GPU | Images/Sec | Im/Sec/GPU | Est. Walltime (hr) | GPU-days | Config | +| ----------------------- | ------------ | --------- | ----- | -------- | ---------- | ---------- | ------------------ | -------- | ----------------------------------- | +| Imagen-base-500M | A100-80G-SXM | BF16 | 8 | 64 | 858 | 107.0 | 809 | 269 | [cfg](configs/imagen_base_500M.gin) | +| Imagen-base-500M | A100-80G-SXM | BF16 | 32 | 64 | 3056 | 95.5 | 227 | 303 | [cfg](configs/imagen_base_500M.gin) | +| Imagen-base-2B | A100-80G-SXM | BF16 | 8 | 16 | 219 | 27.4 | 3170 | 1057 | [cfg](configs/imagen_base_2B.gin) | +| Imagen-base-2B | A100-80G-SXM | BF16 | 32 | 16 | 795 | 24.8 | 873 | 1164 | [cfg](configs/imagen_base_2B.gin) | +| Imagen-base-2B | A100-80G-SXM | BF16 | 128 | 16 | 2934 | 22.9 | 236 | 1258 | [cfg](configs/imagen_base_2B.gin) | +| Imagen-SR1-600M-EffUNet | A100-80G-SXM | BF16 | 8 | 64 | 674 | 84.3 | 1030 | 343 | [cfg](configs/imagen_sr1_efficientunet_600M.gin) | +| Imagen-SR1-600M-EffUNet | A100-80G-SXM | BF16 | 32 | 64 | 2529 | 79.1 | 274 | 365 | [cfg](configs/imagen_sr1_efficientunet_600M.gin) | +| Imagen-SR2-600M-EffUNet | A100-80G-SXM | BF16 | 8 | 64 | 678 | 84.8 | 1024 | 341 | [cfg](configs/imagen_sr2_efficientunet_600M.gin) | +| Imagen-SR2-600M-EffUNet | A100-80G-SXM | BF16 | 32 | 64 | 2601 | 81.3 | 267 | 356 | [cfg](configs/imagen_sr2_efficientunet_600M.gin) | +| Imagen-SR1-430M-UNet | A100-80G-SXM | BF16 | 8 | 16 | 194 | 24.3 | 3580 | 1193 | [cfg](configs/imagen_sr1_unet_430M.gin) | + +`Imagen-SR1-430M-UNet` is not currently supported. You can use the sr1-efficient-unet instead. Coming Soon! + + +Imagen base 500M + Efficient SR1 (600M): +|cfg|FID-30K (256x256)| +| - |-----------------------------------------------| +| 2 | 11.30 | +| 3 | 10.23 | +| 4 | 11.33 | +| 6 | 12.34 | + +## Known Issues +* Currently, the nightly images will not be able to run Imagen since they lack a patch that needs refactoring. This will be released soon! diff --git a/rosetta/rosetta/projects/imagen/__init__.py b/rosetta/rosetta/projects/imagen/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/rosetta/rosetta/projects/imagen/assets/A blue coloured pizza_14.png b/rosetta/rosetta/projects/imagen/assets/A blue coloured pizza_14.png new file mode 100755 index 000000000..003f6cfde Binary files /dev/null and b/rosetta/rosetta/projects/imagen/assets/A blue coloured pizza_14.png differ diff --git a/rosetta/rosetta/projects/imagen/assets/A raccoon wearing a hat and black leather jacket is behind the backyard window. Rain droplets on the window_16.png b/rosetta/rosetta/projects/imagen/assets/A raccoon wearing a hat and black leather jacket is behind the backyard window. Rain droplets on the window_16.png new file mode 100755 index 000000000..557e3403a Binary files /dev/null and b/rosetta/rosetta/projects/imagen/assets/A raccoon wearing a hat and black leather jacket is behind the backyard window. Rain droplets on the window_16.png differ diff --git a/rosetta/rosetta/projects/imagen/assets/a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal_20.png b/rosetta/rosetta/projects/imagen/assets/a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal_20.png new file mode 100755 index 000000000..9c1575927 Binary files /dev/null and b/rosetta/rosetta/projects/imagen/assets/a highly detailed digital painting of a portal in a mystic forest with many beautiful trees. A person is standing in front of the portal_20.png differ diff --git a/rosetta/rosetta/projects/imagen/configs/dummy-base.gin b/rosetta/rosetta/projects/imagen/configs/dummy-base.gin new file mode 100644 index 000000000..ccef461ef --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/dummy-base.gin @@ -0,0 +1,71 @@ +from __gin__ import dynamic_registration +import jax.numpy as jnp +from rosetta.projects.diffusion import mm_utils +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import wds_utils +import t5.data + +IM_SHAPE=%gin.REQUIRED #[h, w, c] i.e. (64,64,3) +TXT_SHAPE=%gin.REQUIRED #[l, c] i.e. (128, 4096) +TXT_SEQLEN=%gin.REQURED #[l,] i.e. (128). Should match dim[0] above. Must be a tuple (include trailing comma) + +MIXTURE_OR_TASK_NAME = "/opt/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts" +MIXTURE_OR_TASK_NAME_SAMPLING = "/opt/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts" +DROPOUT_RATE = 0.0 +MODALITIES = {'samples': @images/wds_utils.ModalityConfig(), 'text': @text/wds_utils.ModalityConfig(), 'text_mask': @text_mask/wds_utils.ModalityConfig()} +MODALITIES_SAMPLE = {'samples': @images_sample/wds_utils.ModalityConfig(), 'text': @text_sample/wds_utils.ModalityConfig(), 'text_mask': @text_mask_sample/wds_utils.ModalityConfig()} +BATCH_PROC = @wds_utils.dummy_batch_infer + +SAMPLING_CONFIG = @samplers.CFGSamplingConfig() +samplers.CFGSamplingConfig: + num_steps=5 + cf_guidance_weight=7.50 + cf_guidance_nulls=None + +images/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.blank_image + +text/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +wds_utils.bare_txt_process: + shape = %TXT_SHAPE + +text_mask/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True + +wds_utils.dummy_batch_infer: + text_emb_shape=%TXT_SHAPE + +# Sampling modalities +images_sample/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.blank_image + +wds_utils.blank_image: + out_img_shape=%IM_SHAPE + +text_sample/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +text_mask_sample/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin b/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin new file mode 100644 index 000000000..2294a72d1 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_1024_sample.gin @@ -0,0 +1,78 @@ +# Imagen Sampling pipeline +include "rosetta/projects/imagen/configs/imagen_256_sample.gin" + +from __gin__ import dynamic_registration +import __main__ as sample_script +from t5x import gin_utils +from t5x import utils +from t5x import partitioning + +from rosetta.projects.imagen import network_sr +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +#---------------- SR1024 Model ------------------------------------------------- + +# ------------------- Model ---------------------------------------------------- +SR1024 = @sr1024/models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +sr1024/models.DenoisingDiffusionModel: + denoiser= @sr1024/denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= None + diffusion_sampler= @sr1024/samplers.EDMSampler() + optimizer_def = None + +# |--- Denoiser +sr1024/denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @sr1024/network_sr.ImagenEfficientUNet() + +sr1024/samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +sr1024/network_sr.ImagenEfficientUNet.config = @sr1024/network_sr.ImagenEfficientUNetConfig() +sr1024/network_sr.ImagenEfficientUNetConfig: + dtype = %DTYPE + model_dim = 128 + cond_dim = 1024 + resblocks_per_level = (2, 4, 8, 8, 8) + width_multipliers = (1, 2, 4, 6, 6) + attn_resolutions_divs = {16: 'cross'} + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + resblock_zero_out = True + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True + +sr1024/samplers.CFGSamplingConfig: + num_steps=30 + cf_guidance_weight=0.0 + cf_guidance_nulls={'text': None, 'text_mask': None} + +sr1024/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +sr1024/utils.RestoreCheckpointConfig: + mode = 'specific' + dtype = 'bfloat16' + +sr1024/sample_script.DiffusionModelSetupData: + model = %SR1024 + sampling_cfg = @sr1024/samplers.CFGSamplingConfig() + restore_checkpoint_cfg = @sr1024/utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + input_shapes = {'samples': (1, 1024, 1024, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 256, 256, 3)} + input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} + +sample_script.sample: + sr1024_setupdata = @sr1024/sample_script.DiffusionModelSetupData() \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin b/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin new file mode 100644 index 000000000..93e87d066 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_256_sample.gin @@ -0,0 +1,219 @@ +# Imagen Sampling pipeline +from __gin__ import dynamic_registration + +import __main__ as sample_script +from t5x import gin_utils +from t5x import utils +from t5x import partitioning + +SAVE_DIR='generations' +PROMPT_TEXT_FILE='custom_text.txt' +GLOBAL_BATCH_SIZE=32 +MAX_GENERATE=50000000 +GEN_PER_PROMPT=2 +NOISE_COND_AUG=0.002 + +TXT_SHAPE=(1, 128, 4096) #T5 xxl, seqlen x embed_dim +TXT_SEQLEN=(1, 128, ) +TXT_SEQLEN_SINGLE=128 +DTYPE='bfloat16' +DROPOUT_RATE=0 +RESUME_FROM=0 #Sampling count to resume from +#---------------- Base Model ------------------------------------------------- +from rosetta.projects.imagen import network +from rosetta.projects.imagen import network_sr +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +# ------------------- Model ---------------------------------------------------- +BASE = @base_model/models.DenoisingDiffusionModel() +base_model/models.DenoisingDiffusionModel: + denoiser= @base_model/denoisers.EDMTextConditionedDenoiser() + diffusion_loss = None + diffusion_sampler= @base_model/samplers.EDMSampler() + optimizer_def = None + +# |--- Denoiser +base_model/denoisers.EDMTextConditionedDenoiser: + raw_model= @base_model/network.ImagenUNet() + +# ------------------- Network specification ------------------------------------ +base_model/network.ImagenUNet.config = @base_model/network.DiffusionConfig() +base_model/network.DiffusionConfig: + dtype = %DTYPE + model_dim = 256 + attn_cond_dim = 512 + cond_dim = 1024 + resblocks_per_level = 3 + width_multipliers = (1, 2, 3, 4) + attn_resolutions = (32, 16, 8) + mha_head_dim = 64 + attn_heads = 4 + resblock_activation = 'silu' + dropout_rate = %DROPOUT_RATE + upsample_mode = 'shuffle' + downsample_mode = 'shuffle' + spatial_skip = False + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True + +BASE_SAMPLING_CONFIG = @base_model/samplers.CFGSamplingConfig() +base_model/samplers.CFGSamplingConfig: + num_steps=50 + cf_guidance_weight=5.00 + cf_guidance_nulls=None + +base_model/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +base_model/utils.RestoreCheckpointConfig: + mode = 'specific' + dtype = 'bfloat16' + +base_model/sample_script.DiffusionModelSetupData: + model = %BASE + sampling_cfg = @base_model/samplers.CFGSamplingConfig() + restore_checkpoint_cfg = @base_model/utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + input_shapes = {'samples': (1, 64, 64, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN} + input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int'} + +#---------------- SR256 Model ------------------------------------------------- + +# ------------------- Model ---------------------------------------------------- +SR256 = @sr256/models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +sr256/models.DenoisingDiffusionModel: + denoiser= @sr256/denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= None + diffusion_sampler= @sr256/samplers.EDMSampler() + optimizer_def = None + +# |--- Denoiser +sr256/denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @sr256/network_sr.ImagenEfficientUNet() + +sr256/samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +sr256/network_sr.ImagenEfficientUNet.config = @sr256/network_sr.ImagenEfficientUNetConfig() +sr256/network_sr.ImagenEfficientUNetConfig: + dtype = %DTYPE + model_dim = 128 + cond_dim = 512 + attn_cond_dim = 1024 + resblocks_per_level = (2, 4, 8, 8, 2) + width_multipliers = (1, 2, 4, 8, 8) + attn_resolutions_divs = {8: 'fused', 16: 'fused'} + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + resblock_zero_out = True + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True + +sr256/samplers.CFGSamplingConfig: + num_steps=50 + cf_guidance_weight=4 + cf_guidance_nulls={'text': None, 'text_mask': None} + +sr256/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +sr256/utils.RestoreCheckpointConfig: + mode = 'specific' + dtype = 'bfloat16' + +sr256/sample_script.DiffusionModelSetupData: + model = %SR256 + sampling_cfg = @sr256/samplers.CFGSamplingConfig() + restore_checkpoint_cfg = @sr256/utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + input_shapes = {'samples': (1, 256, 256, 3), 'text': %TXT_SHAPE, 'text_mask': %TXT_SEQLEN, 'low_res_images': (1, 64, 64, 3)} + input_types = {'samples': 'float32', 'text': 'float16', 'text_mask': 'int', 'low_res_images': 'float32'} + +#---------------- Text Model ------------------------------------------------- +import seqio +from rosetta.projects.inference_serving.t5 import network as t5x_network +from rosetta.projects.inference_serving.t5 import models as t5x_models + +# ===================================== +# === T5 Encoder only configuration === +# ===================================== +T5_CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl" +BATCH_SIZE = 256 # Will be overridden +SEQ_LEN = 128 # MAX seqlen + +# Vocabulary +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" +TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. + +# --------------- Model ------------------ +TEXT_ENC = @text_enc/t5x_models.EncoderOnlyModel() +text_enc/t5x_models.EncoderOnlyModel: + module = @t5x_network.TransformerEncoderOnly() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = None + z_loss = 0.0001 + label_smoothing = 0.0 + loss_normalizing_factor = None + +# -------- Network specification --------- +t5x_network.TransformerEncoderOnly.config = @t5x_network.T5Config() +t5x_network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 4096 + num_heads = 64 + num_encoder_layers = 24 + num_decoder_layers = 0 + head_dim = 64 + mlp_dim = 10240 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + +text_enc/partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +text_enc/utils.RestoreCheckpointConfig: + path = %T5_CHECKPOINT_PATH + mode = 'specific' + dtype = 'bfloat16' + +text_enc/sample_script.setup_text_enc: + model=%TEXT_ENC + restore_checkpoint_cfg=@text_enc/utils.RestoreCheckpointConfig() + partitioner=@text_enc/partitioning.PjitPartitioner() + batch_size=1 + seq_len=%TXT_SEQLEN_SINGLE + vocab = %VOCABULARY + +sample_script.sample: + base_setupdata = @base_model/sample_script.DiffusionModelSetupData() + sr256_setupdata = @sr256/sample_script.DiffusionModelSetupData() + sr1024_setupdata = None + out_dir = %SAVE_DIR + gen_per_prompt = %GEN_PER_PROMPT + prompt_file = %PROMPT_TEXT_FILE + batch_size = %GLOBAL_BATCH_SIZE + max_images = %MAX_GENERATE + text_enc_infer = @text_enc/sample_script.setup_text_enc() + noise_conditioning_aug = %NOISE_COND_AUG + resume_from = %RESUME_FROM \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_base_2B.gin b/rosetta/rosetta/projects/imagen/configs/imagen_base_2B.gin new file mode 100644 index 000000000..35b228dcb --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_base_2B.gin @@ -0,0 +1,63 @@ +# Imagen Base model. +from __gin__ import dynamic_registration + +import seqio +from rosetta.projects.imagen import network +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import augmentations +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses + +include 'rosetta/projects/diffusion/configs/adamw_ema_opt.gin' + +# ------------------- Loss HParam ---------------------------------------------- +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED +DTYPE= %gin.REQUIRED +SAMPLING_CONFIG = None + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +models.DenoisingDiffusionModel: + denoiser= @denoisers.EDMTextConditionedDenoiser() + diffusion_loss= @losses.EDMLoss() + diffusion_sampler= @samplers.EDMSampler() + optimizer_def = %OPTIMIZER + sampling_cfg = %SAMPLING_CONFIG + +# |--- Denoiser +denoisers.EDMTextConditionedDenoiser: + raw_model= @network.ImagenUNet() + +# |--- Diffusion Loss/Trainer +losses.EDMLoss: + sigma_data = %SIGMA_DATA + cond_aug_fn = @augmentations.text_conditioning_dropout + +augmentations.text_conditioning_dropout: + dropout_rate = 0.1 + +# ------------------- Network specification ------------------------------------ +network.ImagenUNet.config = @network.DiffusionConfig() +network.DiffusionConfig: + dtype = %DTYPE + model_dim = 512 + attn_cond_dim = 2048 + cond_dim = 2048 + resblocks_per_level = 3 + width_multipliers = (1, 2, 3, 4) + attn_resolutions = (32, 16, 8) + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + dropout_rate = %DROPOUT_RATE + upsample_mode = 'shuffle' + downsample_mode = 'shuffle' + spatial_skip = False + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_base_2B_img-txt-ds.gin b/rosetta/rosetta/projects/imagen/configs/imagen_base_2B_img-txt-ds.gin new file mode 100644 index 000000000..98589b886 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_base_2B_img-txt-ds.gin @@ -0,0 +1,8 @@ +include "rosetta/projects/imagen/configs/imagen_base_2B.gin" +include "rosetta/projects/imagen/configs/pretrain.gin" +include "rosetta/projects/imagen/configs/img-txt-ds-base.gin" + +TRAIN_STEPS = 2500000 +IM_SHAPE=(64,64,3) #nhwc +TXT_SHAPE=(128,4096) #l, c +TXT_SEQLEN=(128,) #l \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_base_500M.gin b/rosetta/rosetta/projects/imagen/configs/imagen_base_500M.gin new file mode 100644 index 000000000..3f6e07b26 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_base_500M.gin @@ -0,0 +1,23 @@ +# Imagen tiny model. +include 'rosetta/projects/imagen/configs/imagen_base_2B.gin' + +network.DiffusionConfig: + dtype = %DTYPE + model_dim = 256 + attn_cond_dim = 512 + cond_dim = 1024 + resblocks_per_level = 3 + width_multipliers = (1, 2, 3, 4) + attn_resolutions = (32, 16, 8) + mha_head_dim = 64 + attn_heads = 4 + resblock_activation = 'silu' + dropout_rate = %DROPOUT_RATE + upsample_mode = 'shuffle' + downsample_mode = 'shuffle' + spatial_skip = False + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_base_500M_dummy.gin b/rosetta/rosetta/projects/imagen/configs/imagen_base_500M_dummy.gin new file mode 100644 index 000000000..d3a1d3f80 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_base_500M_dummy.gin @@ -0,0 +1,8 @@ +include "rosetta/projects/imagen/configs/imagen_base_500M.gin" +include "rosetta/projects/imagen/configs/pretrain.gin" +include "rosetta/projects/imagen/configs/dummy-base.gin" + +TRAIN_STEPS = 2500000 +IM_SHAPE=(64,64,3) #nhwc +TXT_SHAPE=(128,4096) #l, c +TXT_SEQLEN=(128,) #l diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_base_500M_img-txt-ds.gin b/rosetta/rosetta/projects/imagen/configs/imagen_base_500M_img-txt-ds.gin new file mode 100644 index 000000000..e28aa2231 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_base_500M_img-txt-ds.gin @@ -0,0 +1,8 @@ +include "rosetta/projects/imagen/configs/imagen_base_500M.gin" +include "rosetta/projects/imagen/configs/pretrain.gin" +include "rosetta/projects/imagen/configs/img-txt-ds-base.gin" + +TRAIN_STEPS = 2500000 +IM_SHAPE=(64,64,3) #nhwc +TXT_SHAPE=(128,4096) #l, c +TXT_SEQLEN=(128,) #l \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M.gin b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M.gin new file mode 100644 index 000000000..0e1eda741 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M.gin @@ -0,0 +1,69 @@ +# Imagen Base model. +from __gin__ import dynamic_registration + +import seqio +from rosetta.projects.imagen import network_sr +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +include 'rosetta/projects/diffusion/configs/adamw_ema_opt.gin' + +# ------------------- Loss HParam ---------------------------------------------- +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED +DTYPE = %gin.REQUIRED +SAMPLING_CONFIG = None + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +models.DenoisingDiffusionModel: + denoiser= @denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= @losses.EDMSuperResolutionLoss() + diffusion_sampler= @samplers.EDMSampler() + optimizer_def = %OPTIMIZER + sampling_cfg = %SAMPLING_CONFIG + +# |--- Denoiser +denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @network_sr.ImagenEfficientUNet() + +# |--- Diffusion Loss/Trainer +losses.EDMSuperResolutionLoss: + sigma_data = %SIGMA_DATA + cond_aug_fn = @augmentations.text_conditioning_dropout + dim_noise_scalar = 4. #due to running at 256x256 instead of 64x64 + +augmentations.text_conditioning_dropout: + dropout_rate = 0.1 + +samplers.assemble_cf_guidance_conds: + guidance_nulls = {'text': None, 'text_mask': None} + +samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +network_sr.ImagenEfficientUNet.config = @network_sr.ImagenEfficientUNetConfig() +network_sr.ImagenEfficientUNetConfig: + dtype = %DTYPE + model_dim = 128 + cond_dim = 512 + attn_cond_dim = 1024 + resblocks_per_level = (2, 4, 8, 8, 2) + width_multipliers = (1, 2, 4, 8, 8) + attn_resolutions_divs = {8: 'fused', 16: 'fused'} + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + resblock_zero_out = True + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M_img-txt-ds.gin b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M_img-txt-ds.gin new file mode 100644 index 000000000..764fa30dc --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M_img-txt-ds.gin @@ -0,0 +1,9 @@ +include "rosetta/projects/imagen/configs/imagen_sr1_efficientunet_600M.gin" +include "rosetta/projects/imagen/configs/pretrain.gin" +include "rosetta/projects/imagen/configs/img-txt-ds-sr1.gin" + +TRAIN_STEPS = 2500000 +IM_SHAPE=(256,256,3) #nhwc +LOW_RES_SHAPE=(64,64,3) #nhwc +TXT_SHAPE=(128,4096) #l, c +TXT_SEQLEN=(128,) #l \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_sr1_unet_430M.gin b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_unet_430M.gin new file mode 100644 index 000000000..7775fcf37 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_unet_430M.gin @@ -0,0 +1,71 @@ +# Imagen Base model. +from __gin__ import dynamic_registration + +import seqio +from rosetta.projects.imagen import network +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +include 'rosetta/projects/diffusion/configs/adamw_ema_opt.gin' + +# ------------------- Loss HParam ---------------------------------------------- +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED +DTYPE = %gin.REQUIRED +SAMPLING_CONFIG = None + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +models.DenoisingDiffusionModel: + denoiser= @denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= @losses.EDMSuperResolutionLoss() + diffusion_sampler= @samplers.EDMSampler() + optimizer_def = %OPTIMIZER + sampling_cfg = %SAMPLING_CONFIG + +# |--- Denoiser +denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @network.ImagenUNet() + +# |--- Diffusion Loss/Trainer +losses.EDMSuperResolutionLoss: + sigma_data = %SIGMA_DATA + cond_aug_fn = @augmentations.text_conditioning_dropout + dim_noise_scalar = 4. #due to running at 256x256 instead of 64x64 + +augmentations.text_conditioning_dropout: + dropout_rate = 0.1 + +samplers.assemble_cf_guidance_conds: + guidance_nulls = {'text': None, 'text_mask': None} #avoid voiding lowres cond in cfg + +samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +network.ImagenUNet.config = @network.DiffusionConfig() +network.DiffusionConfig: + dtype = %DTYPE + model_dim = 128 + attn_cond_dim = 768 + cond_dim = 768 + resblocks_per_level = (2, 2, 3, 4, 2) + width_multipliers = (1, 2, 4, 6, 6) + attn_resolutions = (32, 16) #{8: 'fused', 16: 'fused'} + mha_head_dim = 64 + attn_heads = 12 + resblock_activation = 'silu' + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + upsample_mode = 'shuffle' + downsample_mode = 'shuffle' + spatial_skip = False + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_sr1_unet_430M_img-txt-ds.gin b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_unet_430M_img-txt-ds.gin new file mode 100644 index 000000000..cc18ff5ba --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_sr1_unet_430M_img-txt-ds.gin @@ -0,0 +1,9 @@ +include "rosetta/projects/imagen/configs/imagen_sr1_unet_430M.gin" +include "rosetta/projects/imagen/configs/pretrain.gin" +include "rosetta/projects/imagen/configs/img-txt-ds-sr1.gin" + +TRAIN_STEPS = 2500000 +IM_SHAPE=(256,256,3) #nhwc +LOW_RES_SHAPE=(64,64,3) #nhwc +TXT_SHAPE=(128,4096) #l, c +TXT_SEQLEN=(128,) #l \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M.gin b/rosetta/rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M.gin new file mode 100644 index 000000000..c5d110f75 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M.gin @@ -0,0 +1,68 @@ +# Imagen Base model. +from __gin__ import dynamic_registration + +import seqio +from rosetta.projects.imagen import network_sr +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import denoisers +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import losses +from rosetta.projects.diffusion import augmentations + +include 'rosetta/projects/diffusion/configs/adamw_ema_opt.gin' + +# ------------------- Loss HParam ---------------------------------------------- +# Dropout should be specified in the "run" files +DROPOUT_RATE = %gin.REQUIRED +DTYPE = %gin.REQUIRED +SAMPLING_CONFIG = None + +# ------------------- Model ---------------------------------------------------- +MODEL = @models.DenoisingDiffusionModel() +SIGMA_DATA = 0.5 +models.DenoisingDiffusionModel: + denoiser= @denoisers.EDMTextConditionedSuperResDenoiser() + diffusion_loss= @losses.EDMSuperResolutionLoss() + diffusion_sampler= @samplers.EDMSampler() + optimizer_def = %OPTIMIZER + sampling_cfg = %SAMPLING_CONFIG + +# |--- Denoiser +denoisers.EDMTextConditionedSuperResDenoiser: + raw_model= @network_sr.ImagenEfficientUNet() + +# |--- Diffusion Loss/Trainer +losses.EDMSuperResolutionLoss: + sigma_data = %SIGMA_DATA + cond_aug_fn = @augmentations.text_conditioning_dropout + dim_noise_scalar = 4. #due to running at 256x256 instead of 64x64 + +augmentations.text_conditioning_dropout: + dropout_rate = 0.1 + +samplers.assemble_cf_guidance_conds: + guidance_nulls = {'text': None, 'text_mask': None} + +samplers.EDMSampler: + dim_noise_scalar = 4. + +# ------------------- Network specification ------------------------------------ +network_sr.ImagenEfficientUNet.config = @network_sr.ImagenEfficientUNetConfig() +network_sr.ImagenEfficientUNetConfig: + dtype = %DTYPE + model_dim = 128 + cond_dim = 1024 + resblocks_per_level = (2, 4, 8, 8, 8) + width_multipliers = (1, 2, 4, 6, 6) + attn_resolutions_divs = {16: 'cross'} + mha_head_dim = 64 + attn_heads = 8 + resblock_activation = 'silu' + resblock_zero_out = True + resblock_scale_skip = True + dropout_rate = %DROPOUT_RATE + cond_strategy = 'shift_scale' + norm_32 = True + scale_attn_logits = True + float32_attention_logits=False + text_conditionable = True diff --git a/rosetta/rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M_img-txt-ds.gin b/rosetta/rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M_img-txt-ds.gin new file mode 100644 index 000000000..d999cd491 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M_img-txt-ds.gin @@ -0,0 +1,12 @@ +include "rosetta/projects/imagen/configs/imagen_sr2_efficientunet_600M.gin" +include "rosetta/projects/imagen/configs/pretrain.gin" +include "rosetta/projects/imagen/configs/img-txt-ds-sr2.gin" + +TRAIN_STEPS = 2500000 + +IM_SHAPE=(1024,1024,3) #nhwc +LOW_RES_SHAPE_SAMPLING=(256,256,3) #[h, w, c] i.e (64,64,3) +CROP_SHAPE=(256,256,3) +LOW_RES_SHAPE=(64,64,3) #nhwc +TXT_SHAPE=(128,4096) #l, c +TXT_SEQLEN=(128,) #l \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/configs/img-txt-ds-base.gin b/rosetta/rosetta/projects/imagen/configs/img-txt-ds-base.gin new file mode 100644 index 000000000..5fee17637 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/img-txt-ds-base.gin @@ -0,0 +1,79 @@ +from __gin__ import dynamic_registration +import jax.numpy as jnp +from rosetta.projects.diffusion import mm_utils +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import wds_utils +import t5.data + +IM_SHAPE=%gin.REQUIRED #[h, w, c] i.e. (64,64,3) +TXT_SHAPE=%gin.REQUIRED #[l, c] i.e. (128, 4096) +TXT_SEQLEN=%gin.REQURED #[l,] i.e. (128). Should match dim[0] above. Must be a tuple (include trailing comma) + +MIXTURE_OR_TASK_NAME = ("/mnt/datasets/folder_containing_image_txt_dataset_shards", "/mnt/datasets/optionally_another_folder_containing_image_txt_dataset_shards") # Can be a single/tuple of directories or file paths in braceexpand format +MIXTURE_OR_TASK_NAME_SAMPLING = "/opt/rosetta/rosetta/projects/diffusion/tests/custom_eval_prompts" +DROPOUT_RATE = 0.0 +MODALITIES = {'samples': @images/wds_utils.ModalityConfig(), 'text': @text/wds_utils.ModalityConfig(), 'text_mask': @text_mask/wds_utils.ModalityConfig()} +MODALITIES_SAMPLE = {'samples': @images_sample/wds_utils.ModalityConfig(), 'text': @text_sample/wds_utils.ModalityConfig(), 'text_mask': @text_mask_sample/wds_utils.ModalityConfig()} +BATCH_PROC = @wds_utils.batch_infer_extern + +SAMPLING_CONFIG = @samplers.CFGSamplingConfig() +samplers.CFGSamplingConfig: + num_steps=150 + cf_guidance_weight=7.50 + cf_guidance_nulls=None + +images/wds_utils.ModalityConfig: + ftype='jpg' + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.image_crop_scale + prefilter_func=@wds_utils.filter_lowres + +wds_utils.image_crop_scale: + out_img_shape=%IM_SHAPE + +wds_utils.filter_lowres: + min_dims=%IM_SHAPE + nhwc=True + +text/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +wds_utils.bare_txt_process: + shape = %TXT_SHAPE + +text_mask/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True + +wds_utils.batch_infer_extern: + text_emb_shape=%TXT_SHAPE + +# Sampling modalities +images_sample/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.blank_image + +wds_utils.blank_image: + out_img_shape=%IM_SHAPE + +text_sample/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +text_mask_sample/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True diff --git a/rosetta/rosetta/projects/imagen/configs/img-txt-ds-sr1.gin b/rosetta/rosetta/projects/imagen/configs/img-txt-ds-sr1.gin new file mode 100644 index 000000000..527f52cbc --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/img-txt-ds-sr1.gin @@ -0,0 +1,93 @@ +from __gin__ import dynamic_registration +import jax.numpy as jnp +from rosetta.projects.diffusion import mm_utils +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import wds_utils +import t5.data + +IM_SHAPE=%gin.REQUIRED #[h, w, c] i.e. (256,256,3) +LOW_RES_SHAPE=%gin.REQUIRED #[h, w, c] i.e (64,64,3) +TXT_SHAPE=%gin.REQUIRED #[l, c] i.e. (128, 4096) +TXT_SEQLEN=%gin.REQURED #[l,] i.e. (128). Should match dim[0] above. Must be a tuple (include trailing comma) + +MIXTURE_OR_TASK_NAME = ("/mnt/datasets/folder_containing_image_txt_dataset_shards", "/mnt/datasets/optionally_another_folder_containing_image_txt_dataset_shards") # Can be a single/tuple of directories or file paths in braceexpand format +MIXTURE_OR_TASK_NAME_SAMPLING = %MIXTURE_OR_TASK_NAME +DROPOUT_RATE = 0.0 +MODALITIES = {'samples': @images/wds_utils.ModalityConfig(), 'text': @text/wds_utils.ModalityConfig(), 'text_mask': @text_mask/wds_utils.ModalityConfig(), 'low_res_images': @low_res_samples/wds_utils.ModalityConfig()} +MODALITIES_SAMPLE = {'samples': @images_sample/wds_utils.ModalityConfig(), 'text': @text_sample/wds_utils.ModalityConfig(), 'text_mask': @text_mask_sample/wds_utils.ModalityConfig(), 'low_res_images': @low_res_samples_sampling/wds_utils.ModalityConfig()} + +BATCH_PROC = @wds_utils.batch_infer_extern + +SAMPLING_CONFIG = @samplers.CFGSamplingConfig() +samplers.CFGSamplingConfig: + num_steps=50 + cf_guidance_weight=3.00 + cf_guidance_nulls=None + +images/wds_utils.ModalityConfig: + ftype='jpg' + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.image_crop_scale_with_lowres + prefilter_func=@wds_utils.filter_lowres + +wds_utils.image_crop_scale_with_lowres: + out_img_shape=%IM_SHAPE + low_res_img_shape=%LOW_RES_SHAPE + +wds_utils.filter_lowres: + min_dims=%IM_SHAPE + nhwc=True + +text/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +wds_utils.bare_txt_process: + shape = %TXT_SHAPE + +text_mask/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True + +low_res_samples/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%LOW_RES_SHAPE + process_func=None + no_load=True + +wds_utils.batch_infer_extern: + text_emb_shape=%TXT_SHAPE + +# Sampling modalities +images_sample/wds_utils.ModalityConfig: + ftype='jpg' + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.image_crop_scale_with_lowres + +text_sample/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +text_mask_sample/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True + +low_res_samples_sampling/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%LOW_RES_SHAPE + process_func=None + no_load=True diff --git a/rosetta/rosetta/projects/imagen/configs/img-txt-ds-sr2.gin b/rosetta/rosetta/projects/imagen/configs/img-txt-ds-sr2.gin new file mode 100644 index 000000000..1afbead89 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/img-txt-ds-sr2.gin @@ -0,0 +1,101 @@ +from __gin__ import dynamic_registration +import jax.numpy as jnp +from rosetta.projects.diffusion import mm_utils +from rosetta.projects.diffusion import samplers +from rosetta.projects.diffusion import wds_utils +import t5.data + +IM_SHAPE=%gin.REQUIRED #[h, w, c] i.e. (256,256,3) +LOW_RES_SHAPE=%gin.REQUIRED #[h, w, c] i.e (64,64,3) +LOW_RES_SHAPE_SAMPLING=%gin.REQUIRED #[h, w, c] i.e (64,64,3) +CROP_SHAPE=%gin.REQUIRED + +TXT_SHAPE=%gin.REQUIRED #[l, c] i.e. (128, 4096) +TXT_SEQLEN=%gin.REQURED #[l,] i.e. (128). Should match dim[0] above. Must be a tuple (include trailing comma) + +MIXTURE_OR_TASK_NAME = ("/mnt/datasets/folder_containing_image_txt_dataset_shards", "/mnt/datasets/optionally_another_folder_containing_image_txt_dataset_shards") # Can be a single/tuple of directories or file paths in braceexpand format +MIXTURE_OR_TASK_NAME_SAMPLING = %MIXTURE_OR_TASK_NAME +DROPOUT_RATE = 0.0 +MODALITIES = {'samples': @images/wds_utils.ModalityConfig(), 'text': @text/wds_utils.ModalityConfig(), 'text_mask': @text_mask/wds_utils.ModalityConfig(), 'low_res_images': @low_res_samples/wds_utils.ModalityConfig()} +MODALITIES_SAMPLE = {'samples': @images_sample/wds_utils.ModalityConfig(), 'text': @text_sample/wds_utils.ModalityConfig(), 'text_mask': @text_mask_sample/wds_utils.ModalityConfig(), 'low_res_images': @low_res_samples_sampling/wds_utils.ModalityConfig()} + +BATCH_PROC = @wds_utils.batch_infer_extern + +SAMPLING_CONFIG = @samplers.CFGSamplingConfig() +samplers.CFGSamplingConfig: + num_steps=50 + cf_guidance_weight=0.00 + +images/wds_utils.ModalityConfig: + ftype='jpg' + out_type='float32' + shape=%CROP_SHAPE + process_func=@wds_utils.image_subcrop_scale_with_lowres + prefilter_func=@wds_utils.filter_lowres + +wds_utils.image_subcrop_scale_with_lowres: + init_image_shape=%IM_SHAPE + crop_shape=%CROP_SHAPE + low_res_img_shape=%LOW_RES_SHAPE + +wds_utils.filter_lowres: + min_dims=%IM_SHAPE + nhwc=True + +text/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +wds_utils.bare_txt_process: + shape = %TXT_SHAPE + +text_mask/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True + +low_res_samples/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%LOW_RES_SHAPE + process_func=None + no_load=True + +wds_utils.batch_infer_extern: + text_emb_shape=%TXT_SHAPE + +# Sampling modalities +images_sample/wds_utils.ModalityConfig: + ftype='jpg' + out_type='float32' + shape=%IM_SHAPE + process_func=@wds_utils.image_crop_scale_with_lowres + +wds_utils.image_crop_scale_with_lowres: + out_img_shape=%IM_SHAPE + low_res_img_shape=%LOW_RES_SHAPE_SAMPLING + nhwc=True + +text_sample/wds_utils.ModalityConfig: + ftype='txt' + out_type='float16' + shape=%TXT_SHAPE + process_func=@wds_utils.bare_txt_process + +text_mask_sample/wds_utils.ModalityConfig: + ftype=None + out_type='int' + shape=%TXT_SEQLEN + process_func=None + no_load=True + +low_res_samples_sampling/wds_utils.ModalityConfig: + ftype=None + out_type='float32' + shape=%LOW_RES_SHAPE_SAMPLING + process_func=None + no_load=True diff --git a/rosetta/rosetta/projects/imagen/configs/pretrain.gin b/rosetta/rosetta/projects/imagen/configs/pretrain.gin new file mode 100644 index 000000000..dcb5b54d8 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/configs/pretrain.gin @@ -0,0 +1,132 @@ +# Defaults for pretraining with train.py. +# +# +# You must also include a binding for MODEL. +# +# Required to be set: +# +# - MIXTURE_OR_TASK_NAME +# - TRAIN_STEPS +# - MODALITIES +# - MODEL_DIR +# +# Commonly overridden options: +# +# - train/DatasetConfig.batch_size +# - train_eval/DatasetConfig.batch_size +# - PjitPartitioner.num_partitions +# - Trainer.num_microbatches +# - DROPOUT_RATE +from __gin__ import dynamic_registration + +import __main__ as train_script +from t5x import gin_utils +from t5x import partitioning +from t5x import utils +from t5x import trainer +from rosetta.projects.diffusion import wds_utils +from rosetta.projects.diffusion import mm_utils +import optax + +MIXTURE_OR_TASK_NAME = %gin.REQUIRED +MIXTURE_OR_TASK_NAME_SAMPLING = %MIXTURE_OR_TASK_NAME +TRAIN_STEPS = %gin.REQUIRED +MODEL_DIR = %gin.REQUIRED +BATCH_SIZE = 128 +INFER_BS = %BATCH_SIZE +INFER_SAMPLES = %INFER_BS +BATCH_PROC=None +HOSTNAMES_FILE=None +SAMPLING_CONFIG= %gin.REQUIRED +SAMPLING_ENABLE=True +EMA = 0.9999 + +# DEPRECATED: Import the this module in your gin file. +MIXTURE_OR_TASK_MODULE = None +SHUFFLE_TRAIN_EXAMPLES = True + +# HW RNG is faster than SW, but has limited determinism. +# Most notably it is not deterministic across different +# submeshes. +USE_HARDWARE_RNG = False +# None always uses faster, hardware RNG +RANDOM_SEED = None + +# Can be overridden with `train.*`.` +train_script.train: + model = %MODEL # imported from separate gin file + model_dir = %MODEL_DIR + train_dataset_cfg = @train/wds_utils.WebDatasetConfig() + train_eval_dataset_cfg = None + infer_eval_dataset_cfg = @sampling/wds_utils.WebDatasetConfig() + inference_evaluator_cls = @mm_utils.DiffusionSamplingEvaluator + checkpoint_cfg = @utils.CheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + trainer_cls = @trainer.Trainer + total_steps = %TRAIN_STEPS + eval_steps = 20 + eval_period = 5000 + random_seed = %RANDOM_SEED + use_hardware_rng = %USE_HARDWARE_RNG + summarize_config_fn = @gin_utils.summarize_gin_config + get_dataset_fn = @mm_utils.get_dataset + run_eval_before_training=True + gc_period=1000 + actions = {'TRAIN': @trainer.TerminateOnNanAction} + verify_matching_vocabs_fn = None + +partitioning.PjitPartitioner: + num_partitions = 1 + model_parallel_submesh = None + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +train/wds_utils.WebDatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME + batch_size = %BATCH_SIZE + shuffle = %SHUFFLE_TRAIN_EXAMPLES + seed = None # use a new seed each run/restart + modalities = %MODALITIES + batch_proc=%BATCH_PROC + hostnames_file=%HOSTNAMES_FILE + + +sampling/wds_utils.WebDatasetConfig: + mixture_or_task_name = %MIXTURE_OR_TASK_NAME_SAMPLING + batch_size = %INFER_BS + shuffle = False + seed = None # use a new seed each run/restart + modalities = %MODALITIES_SAMPLE + samples=%INFER_SAMPLES + batch_proc=%BATCH_PROC + hostnames_file=%HOSTNAMES_FILE + +utils.CheckpointConfig: + restore = @utils.RestoreCheckpointConfig() + save = @utils.SaveCheckpointConfig() +utils.RestoreCheckpointConfig: + path = [] # initialize from scratch + +trainer.Trainer: + num_microbatches = None + +utils.SaveCheckpointConfig: + period = 5000 + dtype = 'float32' + keep = 2 # keep 2 checkpoints + save_dataset = False # checkpoint dataset state + +# This scheduler is made with adam in mind. Use the scheduler from pretrain.gin if using adafactor +WARMUP_STEPS = 10000 +warmup/optax.linear_schedule: + init_value = 0.000001 + end_value = 0.0001 + transition_steps = %WARMUP_STEPS +decay/optax.linear_schedule: + init_value = 0.0001 + end_value = 0.00001 + transition_steps = 2_490_000 +optax.join_schedules: + schedules = [@warmup/optax.linear_schedule(), @decay/optax.linear_schedule()] + boundaries = [%WARMUP_STEPS] + + diff --git a/rosetta/rosetta/projects/imagen/imagen_pipe.py b/rosetta/rosetta/projects/imagen/imagen_pipe.py new file mode 100644 index 000000000..fb96d7ffd --- /dev/null +++ b/rosetta/rosetta/projects/imagen/imagen_pipe.py @@ -0,0 +1,323 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Runs inference of an Imagen base model and a 64->256 superresolution model +import dataclasses +import os +import re +import functools +from typing import Mapping, Any, Optional, Callable, Sequence +import logging + +import numpy as np +import jax +import jax.numpy as jnp +import seqio +from t5x import partitioning +from t5x import utils +from t5x import models as t5x_models +from seqio.vocabularies import PAD_ID +from rosetta.projects.diffusion import models +from rosetta.projects.diffusion import samplers +import matplotlib.image as matimg + +from rosetta.projects.diffusion.mm_utils import expand_dims_like +# Automatically search for gin files relative to the T5X package. +_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] + +@dataclasses.dataclass +class DiffusionModelSetupData: + model: models.DenoisingDiffusionModel + sampling_cfg: samplers.SamplingConfig + restore_checkpoint_cfg: utils.RestoreCheckpointConfig + partitioner: partitioning.BasePartitioner + input_shapes: Mapping[str, Any] + input_types: Mapping[str, Any] + +def pad_right(tokens, seq_len, eos_id,pad_id): + padded, tok_lengths = [], [] + for t in tokens: + diff = seq_len - (len(t) + 1) + #assert diff >= 0 + if diff < 0: + padded.append(t[:seq_len - 1] + [eos_id]) + tok_lengths.append(seq_len) + else: + padded.append(t + [eos_id] + [pad_id] * diff) + tok_lengths.append(len(t) + 1) + + return jnp.array(padded, dtype=jnp.int32), tok_lengths, seq_len + +def seqio_preprocessing(mbatch, vocab:Any, seq_len:int=128): + return pad_right(vocab.encode(mbatch), seq_len=seq_len, eos_id=vocab.eos_id, pad_id=PAD_ID) + +def setup_text_enc(model: t5x_models.BaseTransformerModel, + restore_checkpoint_cfg: utils.RestoreCheckpointConfig, + partitioner: partitioning.BasePartitioner, + batch_size=1, seq_len=128, vocab=None, + ): + input_shapes = {'encoder_input_tokens': (batch_size, seq_len)} + + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=None, # Do not load optimizer state. + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + partitioner=partitioner) + train_state_axes = train_state_initializer.train_state_axes + + # Disable strictness since we are dropping the optimizer state. + restore_checkpoint_cfg.strict = False + + fallback_init_rng = None + + if fallback_init_rng is not None: + fallback_init_rng = jax.random.PRNGKey(fallback_init_rng) + train_state = list(train_state_initializer.from_checkpoints([restore_checkpoint_cfg], init_rng=fallback_init_rng))[0] + logging.warning(f'Restored from Checkpoint: {train_state[1]}') + train_state = train_state[0] + + partitioned_fn = partitioner.partition( + model.score_batch, + in_axis_resources=(train_state_axes.params, partitioning.PartitionSpec('data',)), + out_axis_resources=None) + + def infer_fn(inputs: Sequence[str]): + tokenized_padded, batch_len, curr_seqlen = seqio_preprocessing(inputs, vocab, seq_len=seq_len) + results = partitioned_fn(train_state.params, {"encoder_input_tokens": tokenized_padded}).astype(jnp.float16) + + bs = len(inputs) + individual_shape = results[0].shape + padded_output = np.zeros((bs, *individual_shape), dtype=np.float16) + for idx, (tensor, true_len) in enumerate(zip(results, batch_len)): + padded_output[idx, :true_len] = tensor[:true_len] + + return padded_output, jnp.array(batch_len, dtype=np.int32) + + + return infer_fn + +def get_sample_fn(model_setup:DiffusionModelSetupData): + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=None, # Do not load optimizer state. + init_fn=model_setup.model.get_initial_variables, + input_shapes=model_setup.input_shapes, + input_types=model_setup.input_types, + partitioner=model_setup.partitioner) + + train_state_axes = train_state_initializer.train_state_axes + + # Disable strictness since we are dropping the optimizer state. + model_setup.restore_checkpoint_cfg.strict = False + + fallback_init_rng = None + + if fallback_init_rng is not None: + fallback_init_rng = jax.random.PRNGKey(fallback_init_rng) + train_state = list(train_state_initializer.from_checkpoints([model_setup.restore_checkpoint_cfg], init_rng=fallback_init_rng))[0] + logging.warning(f'Restored from Checkpoint: {train_state[1]}') + train_state = train_state[0] + + model_pred = functools.partial(model_setup.model.predict_batch, sampling_cfg=model_setup.sampling_cfg) + partitioned_fn = model_setup.partitioner.partition( + model_pred, + in_axis_resources=(train_state_axes.params, model_setup.partitioner.data_partition_spec, None), + out_axis_resources=model_setup.partitioner.data_partition_spec) + + return train_state.params, partitioned_fn + +def sanitize_filename(filename): + # Remove leading and trailing whitespaces + filename = filename.strip() + + # Replace spaces with underscores + filename = filename.replace(" ", "_") + + # Remove any other characters that are not allowed in a Linux filename + filename = re.sub(r'[^\w\.-]', '', filename) + + # Remove forward slashes + filename = filename.replace("/", "") + + return filename + +def sample( + base_setupdata: DiffusionModelSetupData, + sr256_setupdata: DiffusionModelSetupData, + out_dir: str, + sr1024_setupdata: DiffusionModelSetupData=None, + gen_per_prompt: int = 1, + text_enc_infer = Callable, + prompt_file=None, + batch_size=32, + max_images=50000000, + base_img_size=(64, 64, 3), + sr256_img_size=(256, 256, 3), + sr1024_img_size=(1024, 1024, 3), + noise_conditioning_aug=0.002, + resume_from=0 + ): + if not os.path.exists(out_dir): + os.makedirs(out_dir) + base_dir = os.path.join(out_dir, 'base') + sr_dir = os.path.join(out_dir, 'sr') + sr2_dir = os.path.join(out_dir, 'sr2') + if not os.path.exists(base_dir): + os.makedirs(base_dir) + if not os.path.exists(sr_dir): + os.makedirs(sr_dir) + if sr1024_setupdata is not None and not os.path.exists(sr_dir): + os.makedirs(sr_dir) + + with open(prompt_file, 'r') as f: + prompts = f.readlines() + prompt_ct = len(prompts) + + # Set up models + base_params, base_fn = get_sample_fn(base_setupdata) + sr256_params, sr256_fn = get_sample_fn(sr256_setupdata) + if sr1024_setupdata is not None: + sr1024_params, sr1024_fn = get_sample_fn(sr1024_setupdata) + text_encoder = text_enc_infer + + sampled_ctr = 0 + rng = jax.random.PRNGKey(0) + for start_idx in range(resume_from, max_images, batch_size // gen_per_prompt): + if start_idx > prompt_ct: + break + prompt_batch = prompts[start_idx: start_idx + (batch_size // gen_per_prompt)] * gen_per_prompt + rng, rng_base, rng_sr, rng_sr2, rng_aug = jax.random.split(rng, 5) + + # Encode Text + encoded_text, text_lens = text_encoder(prompt_batch) + text_mask = np.zeros(encoded_text.shape[:2]) + for i in range(text_lens.shape[0]): + text_mask[i][:text_lens[i]] = 1 + + # Base model generation + base_img_inputs = jnp.zeros((len(prompt_batch), *base_img_size)) + sr256_img_inputs = jnp.zeros((len(prompt_batch), *sr256_img_size)) + sr1024_img_inputs = jnp.zeros((len(prompt_batch), *sr1024_img_size)) + base_batch = {'samples': base_img_inputs, 'text': encoded_text, 'text_mask': text_mask} + base_out = base_fn(base_params, base_batch, rng_base) + for i in range(base_out.shape[0]): + matimg.imsave(os.path.join(base_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), np.clip(base_out[i], a_min=0, a_max=1)) + + # Stage 2: Super Resolution (64-> 256) + base_aug = (base_out * 2 - 1) + noise_aug_level = expand_dims_like(jnp.ones((base_aug.shape[0], )) * noise_conditioning_aug, base_aug) + sr256_batch = {'samples': sr256_img_inputs, 'text':encoded_text, 'text_mask':text_mask, 'low_res_images': base_aug, 'noise_aug_level': noise_aug_level} + sr_out = sr256_fn(sr256_params, sr256_batch, rng_sr) + sr_out = jnp.clip(sr_out, a_min = 0, a_max = 1) + for i in range(sr_out.shape[0]): + matimg.imsave(os.path.join(sr_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), sr_out[i]) + + # Stage 3: Super Resolution (256-> 1024) + if sr1024_setupdata is not None: + sr_aug = (sr_out * 2 - 1) + noise_aug_level = expand_dims_like(jnp.ones((sr_aug.shape[0], )) * noise_conditioning_aug, base_aug) + sr1024_batch = {'samples': sr1024_img_inputs, 'text':encoded_text, 'text_mask':text_mask, 'low_res_images': sr_aug, 'noise_aug_level': noise_aug_level} + sr_out = sr1024_fn(sr1024_params, sr1024_batch, rng_sr2) + sr_out = jnp.clip(sr_out, a_min = 0, a_max = 1) + for i in range(sr_out.shape[0]): + matimg.imsave(os.path.join(sr2_dir, sanitize_filename(f'{prompt_batch[i]}_{sampled_ctr + i}.png')), sr_out[i]) + + sampled_ctr += sr_out.shape[0] + + +if __name__ == '__main__': + # pylint: disable=g-import-not-at-top + from absl import app + from absl import flags + import gin + from t5x import gin_utils + import tensorflow as tf + # pylint: enable=g-import-not-at-top + FLAGS = flags.FLAGS + + jax.config.parse_flags_with_absl() + + flags.DEFINE_multi_string( + 'gin_file', + default=None, + help='Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.') + + flags.DEFINE_multi_string( + 'gin_bindings', default=[], help='Individual gin bindings.') + + flags.DEFINE_list( + 'gin_search_paths', + default=['.'], + help='Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.') + + flags.DEFINE_boolean( + 'multiprocess_gpu', + False, + help='Initialize JAX distributed system for multi-host GPU, using ' + '`coordinator_address`, `process_count`, and `process_index`.') + + flags.DEFINE_string( + 'coordinator_address', + None, + help='IP address:port for multi-host GPU coordinator.') + + flags.DEFINE_integer( + 'process_count', None, help='Number of processes for multi-host GPU.') + + flags.DEFINE_integer('process_index', None, help='Index of this process.') + + + def main(argv: Sequence[str]): + """Wrapper for pdb post mortems.""" + _main(argv) + + def _main(argv: Sequence[str]): + """True main function.""" + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + # OOM fix. Prevents TF from seeing GPUs to stop conflict with JAX. + # This must go after InitGoogle(), which is called by + # gin_utils.run(main). + tf.config.experimental.set_visible_devices([], 'GPU') + + + if FLAGS.multiprocess_gpu: + logging.info( + 'Initializing distributed system for multi-host GPU:\n' + ' coordinator_address: %s\n process_count: %s\n process_index: %s', + FLAGS.coordinator_address, FLAGS.process_count, FLAGS.process_index) + + jax.distributed.initialize(FLAGS.coordinator_address, FLAGS.process_count, + FLAGS.process_index) + + # Create gin-configurable version of `train`. + sample_using_gin = gin.configurable(sample) + + gin_utils.parse_gin_flags( + # User-provided gin paths take precedence if relative paths conflict. + FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, + FLAGS.gin_file, + FLAGS.gin_bindings) + sample_using_gin() + jax.effects_barrier() + + + gin_utils.run(main) diff --git a/rosetta/rosetta/projects/imagen/layers.py b/rosetta/rosetta/projects/imagen/layers.py new file mode 100644 index 000000000..e245df7dd --- /dev/null +++ b/rosetta/rosetta/projects/imagen/layers.py @@ -0,0 +1,648 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"Diffusion U-Net layers" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic + +#TODO Dropout +import dataclasses +import functools +import operator +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np +from einops import rearrange + +from t5x.contrib.gpu.t5.layers import MultiHeadDotProductAttention, make_attention_mask, combine_biases, dot_product_attention, RelativePositionBiases, MlpBlock +from flax.linen import DenseGeneral, GroupNorm, Conv, LayerNorm +from rosetta.projects.imagen.network import DiffusionConfig + +param_with_axes = nn_partitioning.param_with_axes +with_sharding_constraint = nn_partitioning.with_sharding_constraint + +# Type annotations +Array = jnp.ndarray +OptionalArray = Optional[jnp.ndarray] +DType = jnp.dtype +PRNGKey = jnp.ndarray +Shape = Iterable[int] +Activation = Callable[..., Array] +# Parameter initializers. +Initializer = Callable[[PRNGKey, Shape, DType], Array] +PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], + Tuple[lax.Precision, lax.Precision]] +PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] +LaxPadding = Union[str, Sequence[Tuple[int, int]]] +Dtype = Any + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + +dynamic_vector_slice_in_dim = jax.vmap( + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + +def _convert_to_activation_function( + fn_or_string: Union[str, Callable]) -> Callable: + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError("don't know how to convert %s to an activation function" % + (fn_or_string,)) + +def get_timestep_embedding(timesteps, embedding_dim: int, dtype=jnp.float32): + """Build sinusoidal embeddings + Args: + timesteps: jnp.ndarray: generate embedding vectors at these timesteps + embedding_dim: int: dimension of the embeddings to generate + dtype: data type of the generated embeddings + Returns: + embedding vectors with shape `[len(timesteps), embedding_dim]` + """ + timesteps = jnp.reshape(timesteps, timesteps.shape[0]) + assert len(timesteps.shape) == 1, "timesteps don't have one dimension, " + str(timesteps.shape) + half = embedding_dim // 2 + freqs = jnp.exp( + -np.log(10000) * jnp.arange(0, half, dtype=dtype) / half + ) + args = timesteps[:, None] * freqs[None] + embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1) + + if embedding_dim % 2 == 1: # zero pad + embedding = jax.lax.pad(embedding, dtype(0), ((0, 0, 0), (0, 1, 0))) + assert embedding.shape == (timesteps.shape[0], embedding_dim) + + return embedding + +def FP32Wrap(operator, should_fp32=False): + if should_fp32: + def ret(x): + h = jnp.asarray(x, dtype=jnp.float32) + return jnp.asarray(operator(h), dtype=x.dtype) + return ret + else: + return operator + +class FusedSelfCrossMultiHeadDotProductAttention(nn.Module): + """Fused self attention with cross attention for image-text self and cross attn. + + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + head_dim: dimension of each head. + dtype: the dtype of the computation. + dropout_rate: dropout rate + kernel_init: initializer for the kernel of the Dense layers. + float32_logits: bool, if True then compute logits in float32 to avoid + numerical issues with bfloat16. + zero_out: bool. Will initialize out projection to zero if true + """ + + num_heads: int + head_dim: int + dtype: DType = jnp.float32 + dropout_rate: float = 0. + kernel_init: Initializer = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal') + float32_logits: bool = False # computes logits in float32 for stability. + scale_attn_logits: bool = False + zero_out: bool = True + project_output: bool = True + + @nn.compact + def __call__(self, + inputs_q: Array, + inputs_kv_self: Array, + inputs_kv_cross: Optional[Array] = None, + mask: Optional[Array] = None, + bias: Optional[Array] = None, + *, + deterministic: bool = False) -> Array: + """Applies self and cross attention on the input data. + + Projects the inputs into multi-headed query, key, and value vectors, + applies dot-product attention and project the results to an output vector. + + Args: + inputs_q: input queries of shape `[batch, q_length, q_features]`. + inputs_kv_self: key/values of shape `[batch, kv_self_length, kv_self_features]`. + inputs_kv_cross: key/values of shape `[batch, kv_cross_length, kv_cross_features]' + mask: attention mask of shape `[batch, num_heads, q_length, kv_length]`. + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + decode: Whether to prepare and use an autoregressive cache. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, length, q_features]`. + """ + projection = functools.partial( + DenseGeneral, + axis=-1, + features=(self.num_heads, self.head_dim), + kernel_axes=('embed', 'joined_kv'), + use_bias=True, + dtype=self.dtype) + + # NOTE: T5 does not explicitly rescale the attention logits by + # 1/sqrt(depth_kq)! This is folded into the initializers of the + # linear transformations, which is equivalent under Adafactor. + depth_scaling = jnp.sqrt(self.head_dim).astype(self.dtype) + query_init = lambda *args: self.kernel_init(*args) / depth_scaling + + # Project inputs_q to multi-headed q/k/v + # dimensions are then [batch, length, num_heads, head_dim] + query = projection(kernel_init=query_init, name='query')( \ + (inputs_q / depth_scaling) if self.scale_attn_logits else inputs_q) + key = projection(kernel_init=self.kernel_init, name='key')(inputs_kv_self) + value = projection(kernel_init=self.kernel_init, name='value')(inputs_kv_self) + + query = with_sharding_constraint(query, ('batch', 'length', 'heads', 'kv')) + + if inputs_kv_cross is not None: + key_cross = projection(kernel_init=self.kernel_init, name='key_cross')(inputs_kv_cross) + value_cross = projection(kernel_init=self.kernel_init, name='value_cross')(inputs_kv_cross) + + # Concatenate on length axis + key = jnp.concatenate((key, key_cross), axis=1) + value = jnp.concatenate((value, value_cross), axis = 1) + + key = with_sharding_constraint(key, ('batch', 'length', 'heads', 'kv')) + value = with_sharding_constraint(value, ('batch', 'length', 'heads', 'kv')) + + # Convert the boolean attention mask to an attention bias. + if mask is not None: + # attention mask in the form of attention bias + attention_bias = lax.select( + mask > 0, + jnp.full(mask.shape, 0.).astype(self.dtype), + jnp.full(mask.shape, -1e10).astype(self.dtype)) + else: + attention_bias = None + + # Add provided bias term (e.g. relative position embedding). + if bias is not None: + attention_bias = combine_biases(attention_bias, bias) + + dropout_rng = None + if not deterministic and self.dropout_rate > 0.: + dropout_rng = self.make_rng('dropout') + + # Apply attention. + x = dot_product_attention( + query, + key, + value, + bias=attention_bias, + dropout_rng=dropout_rng, + dropout_rate=self.dropout_rate, + deterministic=deterministic, + dtype=self.dtype, + float32_logits=self.float32_logits) + + if self.project_output: + # Back to the original inputs dimensions. + out = DenseGeneral( + features=inputs_q.shape[-1], # output dim is set to the input dim. + axis=(-2, -1), + kernel_init=self.kernel_init if not self.zero_out else nn.initializers.zeros, + kernel_axes=('joined_kv', 'embed'), + use_bias=True, + dtype=self.dtype, + name='out')( + x) + return out + else: + return x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) + +class AttentionPoolingBlock(nn.Module): + """ + Attention Pooling via self+cross attention with the mean. Assumes inputs + are normalized already. Uses RelativePositionBiases. + + Attributes: + cfg: Model-wide configuration. This will use the same parameters + as all other attention layers. + num_heads: optional override on the number of heads in this layer + """ + cfg: DiffusionConfig + num_heads: Optional[int] = None + + @nn.compact + def __call__(self, + inputs: Array, + text_lens: Optional[Array] = None, + *, + deterministic: bool = False) -> Array: + """ Performs attention pooling by doing cross attention between mean embedding and + all tokens. + + Args: + inputs: input sequence of shape `[batch, seq_length, features]`. + text_lens: Array of text masks (for masking) [batch, seq_length] + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, features]`. + """ + cfg = self.cfg + num_heads = self.num_heads if self.num_heads is not None else cfg.attn_heads + + pos_bias = RelativePositionBiases( + num_buckets=32, + max_distance=128, + num_heads=num_heads, + dtype=cfg.dtype, + embedding_init=nn.initializers.variance_scaling(1.0, 'fan_avg', + 'uniform'), + name='relpos_bias')(1, inputs.shape[1] + 1, False) + + mask = None + if text_lens is not None: + m = jnp.ones((inputs.shape[0], 1)) + mt = jnp.concatenate((m, text_lens), axis = 1) + mask = make_attention_mask(m, mt, dtype=cfg.dtype) + + text_mask = rearrange(text_lens, 'b l -> b l 1') + masked_inputs = inputs * text_mask + mean = jnp.mean(masked_inputs, axis=1, keepdims=True) # [batch, 1, q_features] + out = MultiHeadDotProductAttention(num_heads=num_heads, + head_dim=cfg.mha_head_dim, + dtype=cfg.dtype, + dropout_rate=cfg.dropout_rate, + float32_logits=cfg.float32_attention_logits, + scale_attn_logits=cfg.scale_attn_logits) \ + (mean, jnp.concatenate((mean, inputs), axis=1), mask=mask, bias=pos_bias, deterministic=deterministic) + out = rearrange(out, 'b l e -> b (l e)') # l should equal 1 + assert out.shape[1] == inputs.shape[2] + return out + +class DeepFloydAttentionPoolingBlock(nn.Module): + """ + Attention Pooling via self+cross attention with the mean. Assumes inputs + are normalized already. Uses Absolute Position Embeddings. + + Attributes: + cfg: Model-wide configuration. This will use the same parameters + as all other attention layers. + num_heads: optional override on the number of heads in this layer + """ + cfg: DiffusionConfig + num_heads: Optional[int] = None + + @nn.compact + def __call__(self, + inputs: Array, + text_lens: Optional[Array] = None, + *, + deterministic: bool = False) -> Array: + """ Performs attention pooling by doing cross attention between mean embedding and + all tokens. + + Args: + inputs: input sequence of shape `[batch, seq_length, features]`. + text_lens: Array of text masks (for masking) [batch, seq_length] + bias: attention bias of shape `[batch, num_heads, q_length, kv_length]`. + deterministic: Disables dropout if set to True. + + Returns: + output of shape `[batch, features]`. + """ + cfg = self.cfg + num_heads = self.num_heads if self.num_heads is not None else cfg.attn_heads + + embedding_init = nn.initializers.normal(stddev=jnp.sqrt(inputs.shape[2])) + pos_bias = param_with_axes('position_embed', embedding_init, (1, inputs.shape[2],), jnp.float32, axes=('empty', 'embed')) + pos_bias = jnp.asarray(pos_bias, dtype=inputs.dtype) + pos_bias = rearrange(pos_bias, '1 e -> 1 1 e') + mask = None + if text_lens is not None: + m = jnp.ones((inputs.shape[0], 1)) + mt = jnp.concatenate((m, text_lens), axis = 1) + mask = make_attention_mask(m, mt, dtype=cfg.dtype) + + text_mask = rearrange(text_lens, 'b l -> b l 1') + masked_inputs = inputs * text_mask + mean = jnp.mean(masked_inputs, axis=1, keepdims=True) + pos_bias # [batch, 1, q_features] + out = FusedSelfCrossMultiHeadDotProductAttention(num_heads=num_heads, + head_dim=cfg.mha_head_dim, + dtype=cfg.dtype, + dropout_rate=cfg.dropout_rate, + float32_logits=cfg.float32_attention_logits, + scale_attn_logits=cfg.scale_attn_logits, + project_output=False) \ + (mean, jnp.concatenate((mean, inputs), axis=1), None, mask=mask, bias=None, deterministic=deterministic) + out = rearrange(out, 'b l e -> b (l e)') # l should equal 1 + assert out.shape[1] == inputs.shape[2] + return out + + +class ImgAttentionBlock(nn.Module): + """ Residual MHA block with normalization and reshaping for images + and optional text conditioning. + + cfg: DiffusionConfig + num_heads: Optional number of heads for attention. Will default to + cfg.attn_heads + """ + cfg: DiffusionConfig + num_heads: Optional[int] + + @nn.compact + def __call__(self, + inputs: Array, + text_enc: Optional[Array]=None, + text_mask: Optional[Array]=None, + deterministic=True) -> Array: + """ + Applies self-attention to an image an optionally text encodings. Assumes text_enc + has been normalized already. + + Args: + inputs: Images [b, h, w, c] + text_enc: Text encodings [b, l, c] + text_mask: Array of text masks (for masking) [b, l] + deterministic: Whether to enable dropout + """ + + cfg = self.cfg + x = rearrange(inputs, 'b h w c -> b (h w) c') + if cfg.unified_qkv_norm: + x_q = FP32Wrap(GroupNorm(num_groups=32, name='img_attn_gn'), should_fp32=cfg.norm_32)(x) + x_kv = x_q + else: + x_q = FP32Wrap(GroupNorm(num_groups=32, name='img_attn_gn_q'), should_fp32=cfg.norm_32)(x) + x_kv = FP32Wrap(GroupNorm(num_groups=32, name='img_attn_gn_kv'), should_fp32=cfg.norm_32)(x) + + mask = None + if text_enc is not None: + text_enc = FP32Wrap(GroupNorm(num_groups=32, name='text_enc_ln'), should_fp32=cfg.norm_32)(text_enc) + if text_mask is not None: + m = jnp.ones((x.shape[0], x.shape[1])) + mt = jnp.concatenate((m, text_mask), axis = 1) + mask = make_attention_mask(m, mt, dtype=cfg.dtype) + + num_heads = self.num_heads if self.num_heads is not None else cfg.attn_heads + x = FusedSelfCrossMultiHeadDotProductAttention(num_heads=num_heads, + head_dim=cfg.mha_head_dim, + dtype=cfg.dtype, + dropout_rate=cfg.dropout_rate, + float32_logits=cfg.float32_attention_logits, + scale_attn_logits=cfg.scale_attn_logits, + name='mha_layer')(x_q, x_kv, text_enc, mask=mask, deterministic=deterministic) + + x = rearrange(x, 'b (h w) c -> b h w c', h=inputs.shape[1], w=inputs.shape[2]) + return x + inputs + +def identity(x: Array) -> Array: + return x + +class ResBlock(nn.Module): + """ Residual block in the style of Nichol et. al. Used in a UNet + + Attributes: + cfg: DiffuionConfig + out_channels: Output channel count + up_down_sample: 'up', 'down', or 'none'. Sets if upscaling, downscaling or neither + kernel_init: Kernel init function for embedding layers and first conv + """ + cfg: DiffusionConfig + out_channels: int + up_down_sample: str = 'none' + kernel_init: Optional[Initializer] = nn.initializers.lecun_normal() + + def _get_scaling_block(self, up_down:str): + cfg = self.cfg + if up_down == 'up': + return Upsample(mode=self.cfg.upsample_mode, + kernel_init=self.kernel_init, + dtype=cfg.dtype, + norm_32=self.cfg.norm_32) + elif up_down == 'down': + return Downsample(mode=self.cfg.downsample_mode, + kernel_init=self.kernel_init, + dtype=cfg.dtype) + elif up_down == 'none': + return identity + else: + raise ValueError(f'Attempting to construct a resblock with up_down \ + type {up_down}. Please use one of \'up\', \'down\', or \'none\'') + + @nn.compact + def __call__(self, inputs, + conditioning, + deterministic: bool=False): + """ Apply ResBlock. + + input shape: B H W C_model + conditioning: B C_cond_embedding + """ + cfg = self.cfg + + # normalization = GroupNormFP32 if cfg.norm_32 else GroupNorm + activation = _convert_to_activation_function(cfg.resblock_activation) + spatial_scaling_fn = self._get_scaling_block(self.up_down_sample) + spatial_conv = functools.partial(Conv, + self.out_channels, + kernel_size=(3,3), + strides=1, + padding='SAME', + dtype=cfg.dtype) + + # conditioning embedding calculation + cond_dim = self.out_channels * (2 if cfg.cond_strategy == 'shift_scale' else 1) + cond = activation(conditioning) + cond = DenseGeneral(cond_dim, + axis=-1, + dtype=cfg.dtype, + use_bias=True, + kernel_axes=('embed',))(cond) + cond = rearrange(cond, 'b c -> b 1 1 c') + + # spatial scaling residual + res = spatial_scaling_fn(inputs) + + # in block + # ensure input channels % 32 == 0 + h = FP32Wrap(GroupNorm(num_groups=32, name='resblock_pre_in_conv_groupnorm'), should_fp32=cfg.norm_32)(inputs) + h = activation(h) + h = spatial_scaling_fn(h) + h = spatial_conv(kernel_init=self.kernel_init, name='resblock_in_conv')(h) + + # combine embedding + out_block + out_norm = FP32Wrap(GroupNorm(num_groups=32, name='resblock_pre_out_conv_groupnorm'), should_fp32=cfg.norm_32) + if cfg.cond_strategy == 'shift_scale': + h = out_norm(h) + # combine embedding + shift, scale = jnp.split(cond, 2, axis=-1) + h = h * (scale + 1) + shift + elif cfg.cond_strategy == 'addition': + h = h + cond + h = out_norm(h) + else: + NotImplementedError(cfg.cond_strategy + " conditioning strategy not implemented.\ + Use \'shift_scale\' or \'addition\' instead.") + h = activation(h) + h = nn.Dropout(rate=cfg.dropout_rate)(h, deterministic=deterministic) + h = spatial_conv(kernel_init=nn.initializers.zeros, name='resblock_out_conv')(h) + + # residual channel adjustment + if self.out_channels != inputs.shape[-1]: + if cfg.spatial_skip: + res = spatial_conv(kernel_init=self.kernel_init, name='resblock_skip_conv')(res) + else: + res = Conv(self.out_channels, + kernel_size=(1,1), + dtype=cfg.dtype, + kernel_init=self.kernel_init, + name='resblock_skip_conv')(res) + + # residual addition + out_sum = h + res + if cfg.resblock_scale_skip: + return out_sum * .7071 # 1/sqrt(2) + else: + return out_sum + +def _pixel_shuffle_kernel_init(key, shape, dtype, window=2): + """ + Conv kernel init with replication over the shuffled axis. + Replicated such that initial initial upscales will be like interpolated onces. + """ + h, w, i, o = shape + jax.debug.print('Conv kernel shape: ', str(shape)) + partial_shape = h, w, i, o // (window ** 2) + init = nn.initializers.kaiming_uniform()(key, partial_shape, dtype) + repl = jnp.repeat(init, window ** 2, axis=3) # H, W, I, O + return repl + +class Upsample(nn.Module): + """ + Upsampling module done optionally with convolution + + Attributes: + scaling_factor: Defaults to 2. Identical for all axes + mode: 'shuffle': pixel shuffle + 'conv' : interpolate -> 3x3 convolution + 'resize' : interpolated scaling + kernel_init: conv kernel init + dtype: conv dtype + """ + scaling_factor: int = 2 + mode: str = 'shuffle' + kernel_init: Optional[Initializer] = nn.initializers.lecun_normal() + dtype: Any = jnp.float32 + norm_32: bool = True + + @nn.compact + def __call__(self, x: Array) -> Array: + """ Upscales input by self.scaling_factor in HW dims assuming NHWC format """ + in_ch = x.shape[-1] + + if self.mode == 'resize' or self.mode == 'conv': + n, h, w, c = x.shape + h1 = h * self.scaling_factor + w1 = w * self.scaling_factor + x = jax.image.resize(x, (n, h1, w1, c), method='bilinear') + if self.mode == 'resize': + return x # early return for simple interpolation + + kernel=(3,3) + out_ch=x.shape[-1] + + elif self.mode == 'shuffle': + kernel=(1,1) + out_ch=x.shape[-1] * self.scaling_factor ** 2 + + else: + ValueError("Upsample mode must be \'resize\',\'conv\', or \ + \'shuffle\'. " + self.mode + " is not supported.") + exit() + + # 'conv' -> out_ch=in_ch, kernel=3 + # 'pix_shuffle' -> out_ch=in_ch * scaling_factor **2, kernel=1 + x = Conv(out_ch, + kernel_size=kernel, + strides=1, + dtype=self.dtype, + kernel_init=self.kernel_init, + name='upsample_convolution')(x) + + if self.mode == 'shuffle': + x = FP32Wrap(GroupNorm(num_groups=32, name='pix_shuffle_gn'), should_fp32=self.norm_32)(x) + x = _convert_to_activation_function('silu')(x) + + # shifting channel dims into square spatial pixel dims + return rearrange(x, 'b h w (s1 s2 c) -> b (h s1) (w s2) c', \ + s1=self.scaling_factor, s2=self.scaling_factor) + else: #mode == 'conv' + return x + + +class Downsample(nn.Module): + """ + Downsampling module done optionally with convolution + + Attributes: + scaling_factor: Defaults to 2. Identical for all axes + mode: 'shuffle': SP-conv from: https://arxiv.org/pdf/2208.03641.pdf. + Basically pixel unshuffle + 'conv' : strided convolution downsampling + 'resize' : average pooling + kernel_init: conv kernel init + dtype: conv dtype + """ + scaling_factor: int = 2 + mode: str = 'shuffle' + kernel_init: Optional[Initializer] = nn.initializers.lecun_normal() + dtype: Any = jnp.float32 + + @nn.compact + def __call__(self, x:Array) -> Array: + channels = x.shape[-1] + if self.mode == 'resize': + window_tuple = (self.scaling_factor, self.scaling_factor) + return nn.avg_pool(x, window_tuple, window_tuple) + elif self.mode == 'conv': + kernel_size = (3,3) + stride=self.scaling_factor + padding = 1 + elif self.mode == 'shuffle': + kernel_size = (1,1) + stride=1 + padding = 0 + x = rearrange(x, 'b (h s1) (w s2) c -> b h w (s1 s2 c)', + s1 = self.scaling_factor, s2 = self.scaling_factor) + else: + raise ValueError('Downsampling mode must be \'resize\', \'conv\',\ + or \'shuffle\'. ' + self.mode + " not supported") + + return Conv(channels, + kernel_size=kernel_size, + strides=stride, + padding=padding, + dtype=self.dtype, + kernel_init=self.kernel_init, + name='downsample_convolution')(x) \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/layers_sr.py b/rosetta/rosetta/projects/imagen/layers_sr.py new file mode 100644 index 000000000..aa4fa3711 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/layers_sr.py @@ -0,0 +1,299 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"Diffusion Efficient U-Net layers. Mostly used for superresolution" + +# pylint: disable=attribute-defined-outside-init,g-bare-generic + +#TODO Dropout +import dataclasses +import functools +import operator +from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +import jax +from jax import lax +from jax import random +import jax.numpy as jnp +import numpy as np +from einops import rearrange + +from t5x.contrib.gpu.t5.layers import MultiHeadDotProductAttention, make_attention_mask, combine_biases, dot_product_attention, RelativePositionBiases, MlpBlock +from flax.linen import DenseGeneral, GroupNorm, Conv, ConvTranspose, LayerNorm +from rosetta.projects.imagen.layers import FusedSelfCrossMultiHeadDotProductAttention + +param_with_axes = nn_partitioning.param_with_axes +with_sharding_constraint = nn_partitioning.with_sharding_constraint + +# Type annotations +Array = jnp.ndarray +DType = jnp.dtype +PRNGKey = jnp.ndarray +Shape = Iterable[int] +Activation = Callable[..., Array] +# Parameter initializers. +Initializer = Callable[[PRNGKey, Shape, DType], Array] +PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], + Tuple[lax.Precision, lax.Precision]] +PaddingLike = Union[str, int, Sequence[Union[int, Tuple[int, int]]]] +LaxPadding = Union[str, Sequence[Tuple[int, int]]] +Dtype = Any + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + +dynamic_vector_slice_in_dim = jax.vmap( + lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None)) + +def _convert_to_activation_function( + fn_or_string: Union[str, Callable]) -> Callable: + """Convert a string to an activation function.""" + if fn_or_string == 'linear': + return lambda x: x + elif isinstance(fn_or_string, str): + return getattr(nn, fn_or_string) + elif callable(fn_or_string): + return fn_or_string + else: + raise ValueError("don't know how to convert %s to an activation function" % + (fn_or_string,)) + +def FP32Wrap(operator, should_fp32=False): + if should_fp32: + def ret(x): + h = jnp.asarray(x, dtype=jnp.float32) + return jnp.asarray(operator(h), dtype=x.dtype) + return ret + else: + return operator + +class EfficientResBlock(nn.Module): + """ Residual block in the style of Saharia et. al. from Imagen + + Attributes: + out_channels: Output channel count + kernel_init: Kernel init function for embedding layers and first conv + """ + out_channels: int + dtype: Any = jnp.float32 + norm_32: bool = False + activation: str = 'silu' + cond_strategy: str = 'shift_scale' + dropout_rate: float = 0.0 + kernel_init: Optional[Initializer] = nn.initializers.lecun_normal() + zero_out: bool = True + scale_skip: bool = True + + @nn.compact + def __call__(self, inputs, + conditioning, + deterministic: bool=False): + """ Apply Efficient ResBlock. + + input shape: B H W C_model + conditioning: B C_cond_embedding + """ + + activation = _convert_to_activation_function(self.activation) + spatial_conv = functools.partial(Conv, + self.out_channels, + kernel_size=(3,3), + strides=(1,1), + padding='SAME', + dtype=self.dtype) + + # conditioning embedding calculation + cond_dim = self.out_channels * (2 if self.cond_strategy == 'shift_scale' else 1) + cond = activation(conditioning) + cond = DenseGeneral(cond_dim, + axis=-1, + dtype=self.dtype, + use_bias=True, + kernel_axes=('embed',))(cond) + cond = rearrange(cond, 'b c -> b 1 1 c') + + # in block + # ensure input channels % 32 == 0 + h = FP32Wrap(GroupNorm(num_groups=32, name='resblock_pre_in_conv_groupnorm'), should_fp32=self.norm_32)(inputs) + h = activation(h) + h = spatial_conv(kernel_init=self.kernel_init, name='resblock_in_conv')(h) + + # combine embedding + out_block + out_norm = FP32Wrap(GroupNorm(num_groups=32, name='resblock_pre_out_conv_groupnorm'), should_fp32=self.norm_32) + if self.cond_strategy == 'shift_scale': + h = out_norm(h) + # combine embedding + shift, scale = jnp.split(cond, 2, axis=-1) + h = h * (scale + 1) + shift + elif self.cond_strategy == 'addition': + h = h + cond + h = out_norm(h) + else: + NotImplementedError(self.cond_strategy + " conditioning strategy not implemented.\ + Use \'shift_scale\' or \'addition\' instead.") + h = activation(h) + h = nn.Dropout(rate=self.dropout_rate)(h, deterministic=deterministic) + out_init = nn.initializers.zeros if self.zero_out else self.kernel_init + h = spatial_conv(kernel_init=out_init, name='resblock_out_conv')(h) + + # residual channel adjustment + res = Conv(self.out_channels, + kernel_size=(1,1), + dtype=self.dtype, + kernel_init=self.kernel_init, + name='resblock_skip_conv')(inputs) + + # residual addition + if self.scale_skip: + res = res / jnp.sqrt(2) + return h + res + +class EfficientBlock(nn.Module): + """ + Down and Up Block from Saharia et. al. Imagen Efficient U-Net. + """ + out_channels: int + dtype: Any = jnp.float32 + norm_32: bool = False + activation: str = 'silu' + cond_strategy: str = 'shift_scale' + zero_out: bool = True + scale_skip: bool = True + dropout_rate: float = 0.0 + + use_attn: bool = False + attn_heads: int = 8 + mha_head_dim: int = 64 + attn_type: str = 'fused' # 'self', 'fused', or 'cross' + + up_down: str = 'down' # 'up', 'down' or None + num_resblocks: int = 2 + strides: tuple = (1,1) + + @nn.compact + def __call__(self, + inputs: Array, + conditioning: Array, + text_enc: Optional[Array]=None, + text_mask: Optional[Array]=None, + deterministic=False) -> Array: + # dblock in conv + if self.up_down == 'down': + inputs = Conv(self.out_channels, + kernel_size=(3,3), + strides=self.strides, + padding='SAME', + dtype=self.dtype, + name='dblock_in_conv')(inputs) + + h = inputs + for res_idx in range(self.num_resblocks): + h = EfficientResBlock(out_channels=self.out_channels, + dtype=self.dtype, + norm_32=self.norm_32, + activation=self.activation, + cond_strategy=self.cond_strategy, + dropout_rate=self.dropout_rate, + zero_out=self.zero_out, + scale_skip=self.scale_skip, + name=f'resblock_{res_idx}')(h, conditioning=conditioning, deterministic=deterministic) + if self.use_attn: + h = ImgAttentionBlock(attn_heads=self.attn_heads, + head_dim=self.mha_head_dim, + attn_type=self.attn_type, + dtype=self.dtype, + dropout_rate=self.dropout_rate, + scale_attn_logits=True, + name='attention')(h, text_enc, text_mask, deterministic=deterministic) + if self.up_down == 'up': + h = ConvTranspose(self.out_channels, kernel_size=(3,3), strides=self.strides, padding='SAME', dtype=self.dtype, name='ublock_out_conv')(h) + return h + +class ImgAttentionBlock(nn.Module): + """ Residual MHA block with normalization and reshaping for images + and optional text conditioning. + """ + norm_32: bool = False + attn_heads: int = 8 + head_dim: int = 64 + attn_type: str = 'fused' # 'self', 'fused', or 'cross' + float32_attention_logits: bool = False + scale_attn_logits: bool = True + zero_out: bool = True + dropout_rate: float = 0.0 + dtype: Any = jnp.float32 + + @nn.compact + def __call__(self, + inputs: Array, + text_enc: Optional[Array]=None, + text_mask: Optional[Array]=None, + deterministic=True) -> Array: + """ + Applies self-attention to an image an optionally text encodings. Assumes text_enc + has been normalized already. + + Args: + inputs: Images [b, h, w, c] + text_enc: Text encodings [b, l, c] + text_mask: Array of text masks (for masking) [b, l] + deterministic: Whether to enable dropout + """ + x = rearrange(inputs, 'b h w c -> b (h w) c') + x_q = FP32Wrap(GroupNorm(num_groups=32, name='img_attn_gn_q'), should_fp32=self.norm_32)(x) + x_kv = FP32Wrap(GroupNorm(num_groups=32, name='img_attn_gn_kv'), should_fp32=self.norm_32)(x) + + mask = None + if text_enc is not None: + text_enc = FP32Wrap(GroupNorm(num_groups=32, name='text_enc_ln'), should_fp32=self.norm_32)(text_enc) + if text_mask is None and text_enc is not None: + text_mask = jnp.ones((text_enc.shape[0], text_enc.shape[1])) + else: + if self.attn_type == 'cross': + raise ValueError('Cannot have both cross attention and no text conditioning.') + if self.attn_type == 'fused': + self.attn_type = 'self' + + m = jnp.ones((x.shape[0], x.shape[1])) + q = x_q + if self.attn_type == 'self': + mask = make_attention_mask(m, m, dtype=self.dtype) + kv_self = x_kv + kv_cross = None + elif self.attn_type == 'fused': + mt = jnp.concatenate((m, text_mask), axis = 1) + mask = make_attention_mask(m, mt, dtype=self.dtype) + kv_self = x_kv + kv_cross = text_enc + elif self.attn_type == 'cross': + mt = text_mask + mask = make_attention_mask(m, mt, dtype=self.dtype) + kv_self = text_enc + kv_cross = None + else: + raise NotImplementedError(f'attention type {self.attn_type} is not implemented. Please choose from self, cross, and fused') + + x = FusedSelfCrossMultiHeadDotProductAttention(num_heads=self.attn_heads, + head_dim=self.head_dim, + dtype=self.dtype, + dropout_rate=self.dropout_rate, + float32_logits=self.float32_attention_logits, + scale_attn_logits=self.scale_attn_logits, + zero_out=self.zero_out, + name='mha_layer')(q, kv_self, kv_cross, mask=mask, deterministic=deterministic) + + x = rearrange(x, 'b (h w) c -> b h w c', h=inputs.shape[1], w=inputs.shape[2]) + return x + inputs \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/network.py b/rosetta/rosetta/projects/imagen/network.py new file mode 100644 index 000000000..0b955b41c --- /dev/null +++ b/rosetta/rosetta/projects/imagen/network.py @@ -0,0 +1,265 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Diffusion Model Backbones""" + +from typing import Any, Sequence, Union, Iterable, Optional +import functools + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +from flax import struct +import jax.numpy as jnp +from t5x.contrib.gpu.t5.layers import MultiHeadDotProductAttention, make_attention_mask +from flax.linen import DenseGeneral, Conv, LayerNorm, GroupNorm +import jax +from einops import rearrange + +param_with_axes = nn_partitioning.param_with_axes +OptionalArray = Optional[jnp.ndarray] +Array = jnp.ndarray + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + +@struct.dataclass +class DiffusionConfig: + dtype: Any = jnp.float32 + model_dim: int = 64 + attn_cond_dim: int = 512 + cond_dim: Optional[int] = None #timestep/pooled text embedding channels. Defaults to model_dim*4 + resblocks_per_level: Union[int, Sequence[int]] = 2 + width_multipliers: Sequence[int] = (1, 2, 3, 4) + attn_resolutions: Sequence[int] = (32, 16, 8) + mha_head_dim: int = 64 + attn_heads: Union[int, Sequence[int]] = 2 + unified_qkv_norm: bool = False + deepfloyd_attnpooling: bool = False + resblock_activation: str = 'swish' + resblock_scale_skip: bool = False + output_ch: int = 3 + dropout_rate: float = 0.1 + + # modes are 'shuffle', 'conv' and 'resize'. + # (upsampling , downsampling) + # shuffle -> pixel shuffle / unshuffle + # conv -> resize + conv / strided conv + # resize -> interpolation / average pool + upsample_mode: str = 'shuffle' + downsample_mode: str = 'shuffle' + + # If True, block will use a 3x3 conv instead of 1x1 conv to correct + # channel dim mismatch in skip connection (if one exists) + spatial_skip: bool = False + + # 'shift_scale' or 'addition'. Strategy for incorporating the conditioning vector. + cond_strategy: str = 'shift_scale' + + # force groupnorm in fp32 + norm_32: bool = True + + # scale attention logits by \sqrt(head_dim) + scale_attn_logits: bool = True + float32_attention_logits: bool =False + text_conditionable: bool = True + null_text_emb_ct: int = 0 + +def single_or_idx(possibly_iter, idx): + if isinstance(possibly_iter, Sequence): + return possibly_iter[idx] + else: + return possibly_iter + +class ImagenUNet(nn.Module): + """ An Imagen diffusion U-net """ + config: DiffusionConfig + + @nn.compact + def __call__(self, images, time, + text_enc: OptionalArray=None, + text_lens: OptionalArray=None, + low_res_images: OptionalArray=None, + noise_aug_level: OptionalArray=None, + enable_dropout=True): + """ + Args: + images: samples to denoise. [b, h, w, c] + time: time conditioning. [b, 1] + text_enc: text embeddings (required for text_conditionable) [b, seq_len, embed] + text_lens: text sequence lengths in binary mask format [b, seq_len] + """ + from rosetta.projects.imagen import layers # to avoid circular import + cfg = self.config + activation = layers._convert_to_activation_function(cfg.resblock_activation) + deterministic=not enable_dropout + linear = functools.partial(DenseGeneral, + axis=-1, + use_bias=True, + dtype=cfg.dtype, + kernel_axes=('embed',)) + + spatial_conv = functools.partial(Conv, + kernel_size=(3,3), + strides=1, + padding='SAME', + dtype=cfg.dtype) + input_ch = images.shape[-1] + + # create time embedding + if cfg.cond_dim is None: + time_embed_dim = cfg.model_dim * 4 + else: + time_embed_dim = cfg.cond_dim + cond_embed = layers.get_timestep_embedding(time, cfg.model_dim, dtype=jnp.float32) + cond_embed = linear(features=time_embed_dim, name='time_dense_1')(cond_embed) + cond_embed = activation(cond_embed) + cond_embed = linear(features=time_embed_dim, name='time_dense_2')(cond_embed) + + if low_res_images is not None: + print('Low res images available. Running as a superresolution network') + if noise_aug_level is None: + jax.debug.print('noise_aug not given but it *really* should be') + noise_aug_level = jnp.ones_like(time) * (.25 * jnp.log(0.002)) # fallback EDM c_noise_fn on minimal noise + aug_embed = layers.get_timestep_embedding(noise_aug_level, cfg.model_dim, dtype=jnp.float32) + aug_embed = linear(features=time_embed_dim, name='aug_dense_1')(aug_embed) + aug_embed = activation(aug_embed) + aug_embed = linear(features=time_embed_dim, name='aug_dense_2')(aug_embed) + + cond_embed = cond_embed + aug_embed + + scaled_low_res = jax.image.resize(low_res_images, images.shape, "bicubic") + images = jnp.concatenate([images, scaled_low_res], axis=-1) + + # create attn pooled text embedding and project text to cond_dim + if cfg.text_conditionable: + assert text_lens is not None, "text_lens cannot be None. If you're trying to null condition, pass in a 0 mask" + assert text_enc is not None, "text_enc cannot be None. If you're trying to null condition, pass in 0s of appropriate shape. This network will add in the requisite null tokens" + + # setup null tokens + if cfg.null_text_emb_ct > 0: + null_text_mask = jnp.sum(text_lens, axis=1, keepdims=True) + + embedding_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) + null_text_embedding = param_with_axes( + 'null_text_embeddings', + embedding_init, + (cfg.null_text_emb_ct, text_enc.shape[-1]), + jnp.float32, + axes=('vocab', 'embed')) + + # first null_text_emb_ct tokens to null_text_embedding + null_text_enc = jnp.zeros_like(text_enc[0:1]).at[:, :cfg.null_text_emb_ct].set(null_text_embedding) + null_text_lens = jnp.zeros_like(text_lens[0:1]).at[:, :cfg.null_text_emb_ct].set(1) + text_enc = jnp.where(jnp.expand_dims(null_text_mask, axis=-1) > 0, text_enc, null_text_enc) + text_lens = jnp.where(null_text_mask > 0, text_lens, null_text_lens) + + # attention pooling + if isinstance(cfg.attn_heads, Iterable): + num_heads = text_enc.shape[-1] // cfg.mha_head_dim + else: + num_heads = None + + if cfg.deepfloyd_attnpooling: + attn_pooled = LayerNorm(dtype=jnp.float32 if cfg.norm_32 else cfg.dtype, name='attn_pool_ln_0')(text_enc) + attn_pooled = layers.DeepFloydAttentionPoolingBlock(cfg=cfg, num_heads=num_heads)(attn_pooled, text_lens) + attn_pooled = linear(features=cfg.attn_cond_dim, name='attn_pool_dense_0')(attn_pooled) + attn_pooled = LayerNorm(dtype=jnp.float32 if cfg.norm_32 else cfg.dtype, name='attn_pool_ln_1')(attn_pooled) + else: + attn_pooled = layers.AttentionPoolingBlock(cfg=cfg, num_heads=num_heads)(text_enc, text_lens) + attn_pooled = LayerNorm(dtype=jnp.float32 if cfg.norm_32 else cfg.dtype, name='attn_pool_ln_0')(attn_pooled) + attn_pooled = linear(features=cfg.attn_cond_dim, name='attn_pool_dense_0')(attn_pooled) + attn_pooled = LayerNorm(dtype=jnp.float32 if cfg.norm_32 else cfg.dtype, name='attn_pool_ln_1')(attn_pooled) + attn_pooled = linear(features=time_embed_dim, name='attn_pool_dense_1')(attn_pooled) + attn_pooled = activation(attn_pooled) + attn_pooled = linear(features=time_embed_dim, name='attn_pool_dense_2')(attn_pooled) + + cond_embed = attn_pooled + cond_embed # Has dimension of time_embed_dim + + # text embedding projection to cond_dim + text_enc = linear(features=cfg.attn_cond_dim, name='text_enc_projection')(text_enc) + + # make image_channel -> model_dim convolution + x = spatial_conv(features=cfg.model_dim * cfg.width_multipliers[0], name='unet_in_conv')(images) + + # down branch resblocks + attn on designated resolutions + down_outputs = [x] + for level, width_mult in enumerate(cfg.width_multipliers): + level_channels = cfg.model_dim * width_mult + img_res = images.shape[1] // (2 ** level) + + for res_idx in range(single_or_idx(cfg.resblocks_per_level, level)): + x = layers.ResBlock(cfg=cfg, out_channels=level_channels, up_down_sample='none', + name='resblock_enc_{}_{}'.format(img_res, res_idx)) \ + (x, cond_embed, deterministic=deterministic) + print("Encoder ResBlock #", res_idx, " at resolution: ", img_res, " level: ", level, " width: ", level_channels) + + # attend if image is of designated resolution + if img_res in cfg.attn_resolutions: + attn_idx = level - (len(cfg.width_multipliers) - len(cfg.attn_resolutions)) + x = layers.ImgAttentionBlock(cfg=cfg, num_heads=single_or_idx(cfg.attn_heads, attn_idx), name='attnblock_enc_{}_{}'.format(img_res, res_idx)) \ + (x, text_enc, text_lens, deterministic=deterministic) + print("SelfAttentionBlock #", res_idx, " at resolution: ", img_res, " level: ", level, " width: ", level_channels) + + down_outputs.append(x) + + # if not on the last level, downsample + if level != len(cfg.width_multipliers) - 1: + x = layers.ResBlock(cfg=cfg, out_channels=level_channels, up_down_sample='down', + name='resblock_enc_downsampling_{}_{}'.format(img_res, res_idx)) \ + (x, cond_embed, deterministic=deterministic) + print("Downsampling ResBlock at resolution (to half): ", img_res, " level: ", level, " width: ", level_channels) + down_outputs.append(x) + + # middle layers + mid_channels = cfg.model_dim * cfg.width_multipliers[-1] + x = layers.ResBlock(cfg=cfg, out_channels=mid_channels, up_down_sample='none', + name='resblock_mid_1')(x, cond_embed, deterministic=deterministic) + x = layers.ImgAttentionBlock(cfg=cfg, num_heads=single_or_idx(cfg.attn_heads, -1),name='attnblock_mid_1')(x, text_enc, text_lens, deterministic=deterministic) + x = layers.ResBlock(cfg=cfg, out_channels=mid_channels, up_down_sample='none', + name='resblock_mid_2')(x, cond_embed, deterministic=deterministic) + + print('Encoder Skip Shapes: ',list(map(lambda x : x.shape, down_outputs))) + # up branch resblocks + attn on designated resolutions + skip connections + for level, width_mult in list(enumerate(cfg.width_multipliers))[::-1]: + level_channels = cfg.model_dim * width_mult + img_res = images.shape[1] // (2 ** level) + + for res_idx in range(single_or_idx(cfg.resblocks_per_level, level) + 1): + u_skip = down_outputs.pop() + print("Decoder ResBlock #", res_idx, " at resolution: ", img_res, " level: ", level, " width: ", level_channels, " skip shape: ", u_skip.shape) + x = jnp.concatenate([x, u_skip], axis=-1) + x = layers.ResBlock(cfg=cfg, out_channels=level_channels, up_down_sample='none', + name='resblock_dec_{}_{}'.format(img_res, res_idx)) \ + (x, cond_embed, deterministic=deterministic) + + # attend if image is of designated resolution + if img_res in cfg.attn_resolutions: + print("SelfAttentionBlock #", res_idx, " at resolution: ", img_res, " level: ", level, " width: ", level_channels) + attn_idx = level - (len(cfg.width_multipliers) - len(cfg.attn_resolutions)) + x = layers.ImgAttentionBlock(cfg=cfg, num_heads=single_or_idx(cfg.attn_heads, attn_idx), name='attnblock_dec_{}_{}'.format(img_res, res_idx))\ + (x, text_enc, text_lens, deterministic=deterministic) + + # upsample if on last resblock and not the highest level + if res_idx == single_or_idx(cfg.resblocks_per_level, level) and level != 0: + print("Upsamling ResBlock at resolution (to double): ", img_res, " level: ", level, " width: ", level_channels, " no skip") + x = layers.ResBlock(cfg=cfg, out_channels=level_channels, up_down_sample='up', + name='resblock_dec_upsampling_{}_{}'.format(img_res, res_idx)) \ + (x, cond_embed, deterministic=deterministic) + + # out convolution model_dim -> image_channels + x = layers.FP32Wrap(GroupNorm(num_groups=32, name='unet_out_gn'))(x) + x = activation(x) + # x = spatial_conv(features=cfg.output_ch, name='unet_out_conv', kernel_init=nn.initializers.zeros, no_embed_axis=True)(x) + x = spatial_conv(features=cfg.output_ch, name='unet_out_conv', kernel_init=nn.initializers.zeros)(x) + return x \ No newline at end of file diff --git a/rosetta/rosetta/projects/imagen/network_sr.py b/rosetta/rosetta/projects/imagen/network_sr.py new file mode 100644 index 000000000..003992f01 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/network_sr.py @@ -0,0 +1,251 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Super resolution Diffusion Model Backbones""" + +from typing import Any, Sequence, Union, Iterable, Optional, Mapping +import functools + +from flax import linen as nn +from flax.linen import partitioning as nn_partitioning +from flax import struct +import jax.numpy as jnp +from t5x.contrib.gpu.t5.layers import MultiHeadDotProductAttention, make_attention_mask +from flax.linen import DenseGeneral, Conv, LayerNorm +import jax +from einops import rearrange + +param_with_axes = nn_partitioning.param_with_axes +OptionalArray = Optional[jnp.ndarray] +Array = jnp.ndarray + +default_embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0) + +def single_or_idx(possibly_iter, idx): + if isinstance(possibly_iter, Iterable): + return possibly_iter[idx] + else: + return possibly_iter + +@struct.dataclass +class ImagenEfficientUNetConfig: + dtype: Any = jnp.float32 + model_dim: int = 64 + cond_dim: int = 256 #timestep/pooled text embedding channels + attn_cond_dim: int = 256 + resblocks_per_level: Union[int, Iterable[int]] = (2, 4, 8, 8) + width_multipliers: Iterable[int] = (1, 2, 3, 4) + attn_resolutions_divs: Mapping[int, str] = None #{8: 'fused'} #use attn at resolutions of input / {elems of list} + mha_head_dim: int = 64 + attn_heads: Union[int, Sequence[int]] = 8 + resblock_activation: str = 'swish' + resblock_zero_out: bool = True + resblock_scale_skip: bool = True # enables scaling residuals by 1/\sqrt{2} + dropout_rate: float = 0.1 + + # 'shift_scale' or 'addition'. Strategy for incorporating the conditioning vector. + cond_strategy: str = 'shift_scale' + + # force groupnorm in fp32 + norm_32: bool = True + + # scale attention logits by \sqrt(head_dim) + scale_attn_logits: bool = True + float32_attention_logits: bool =False + text_conditionable: bool = True + null_text_emb_ct: int = 0 + +class ImagenEfficientUNet(nn.Module): + """ An Imagen diffusion U-net """ + config: ImagenEfficientUNetConfig + + @nn.compact + def __call__(self, images, time, + text_enc: OptionalArray=None, + text_lens: OptionalArray=None, + low_res_images: OptionalArray=None, + noise_aug_level: OptionalArray=None, + enable_dropout=True): + """ + Args: + images: samples to denoise. [b, h, w, c] + time: time conditioning. [b, 1] + text_enc: text embeddings (required for text_conditionable) [b, seq_len, embed] + text_lens: text sequence lengths in binary mask format [b, seq_len] + """ + from rosetta.projects.imagen import layers_sr # to avoid circular import + from rosetta.projects.imagen import layers # to avoid circular import + cfg = self.config + activation = layers_sr._convert_to_activation_function(cfg.resblock_activation) + deterministic=not enable_dropout + linear = functools.partial(DenseGeneral, + axis=-1, + use_bias=True, + dtype=cfg.dtype, + kernel_axes=('embed',)) + + spatial_conv = functools.partial(Conv, + kernel_size=(3,3), + strides=(1,1), + padding='SAME', + dtype=cfg.dtype) + input_ch = images.shape[-1] + + # create time embedding + time_embed_dim = cfg.cond_dim + cond_embed = layers.get_timestep_embedding(time, cfg.model_dim, dtype=jnp.float32) + cond_embed = linear(features=time_embed_dim, name='time_dense_1')(cond_embed) + cond_embed = activation(cond_embed) + cond_embed = linear(features=time_embed_dim, name='time_dense_2')(cond_embed) + + if low_res_images is not None: + print('Low res images available. Running as a superresolution network') + if noise_aug_level is None: + print('noise_aug not given but it *really* should be') + noise_aug_level = jnp.ones_like(time) * (.25 * jnp.log(0.002)) # fallback EDM c_noise_fn on minimal noise + aug_embed = layers.get_timestep_embedding(noise_aug_level, cfg.model_dim, dtype=jnp.float32) + aug_embed = linear(features=time_embed_dim, name='aug_dense_1')(aug_embed) + aug_embed = activation(aug_embed) + aug_embed = linear(features=time_embed_dim, name='aug_dense_2')(aug_embed) + + cond_embed = cond_embed + aug_embed + + scaled_low_res = jax.image.resize(low_res_images, images.shape, "bicubic") + images = jnp.concatenate([images, scaled_low_res], axis=-1) + + # create attn pooled text embedding and project text to cond_dim + if cfg.text_conditionable: + assert text_lens is not None, "text_lens cannot be None. If you're trying to null condition, pass in a 0 mask" + assert text_enc is not None, "text_enc cannot be None. If you're trying to null condition, pass in 0s of appropriate shape. This network will add in the requisite null tokens" + + # setup null tokens + if cfg.null_text_emb_ct > 0: + null_text_mask = jnp.sum(text_lens, axis=1, keepdims=True) + + embedding_init = nn.initializers.variance_scaling(1.0, 'fan_in', 'normal', out_axis=0) + null_text_embedding = param_with_axes( + 'null_text_embeddings', + embedding_init, + (cfg.null_text_emb_ct, text_enc.shape[-1]), + jnp.float32, + axes=('vocab', 'embed')) + + # first null_text_emb_ct tokens to null_text_embedding + null_text_enc = jnp.zeros_like(text_enc[0:1]).at[:, :cfg.null_text_emb_ct].set(null_text_embedding) + null_text_lens = jnp.zeros_like(text_lens[0:1]).at[:, :cfg.null_text_emb_ct].set(1) + text_enc = jnp.where(jnp.expand_dims(null_text_mask, axis=-1) > 0, text_enc, null_text_enc) + text_lens = jnp.where(null_text_mask > 0, text_lens, null_text_lens) + + # attention pooling + attn_pooled = layers.AttentionPoolingBlock(cfg=cfg)(text_enc, text_lens) + attn_pooled = LayerNorm(dtype=jnp.float32 if cfg.norm_32 else cfg.dtype, name='attn_pool_ln_0')(attn_pooled) + attn_pooled = linear(features=cfg.cond_dim, name='attn_pool_dense_0')(attn_pooled) + attn_pooled = LayerNorm(dtype=jnp.float32 if cfg.norm_32 else cfg.dtype, name='attn_pool_ln_1')(attn_pooled) + attn_pooled = linear(features=time_embed_dim, name='attn_pool_dense_1')(attn_pooled) + attn_pooled = activation(attn_pooled) + attn_pooled = linear(features=time_embed_dim, name='attn_pool_dense_2')(attn_pooled) + + cond_embed = attn_pooled + cond_embed # Has dimension of time_embed_dim + + # text embedding projection to cond_dim + text_enc = linear(features=cfg.attn_cond_dim, name='text_enc_projection')(text_enc) + + # make image_channel -> model_dim convolution (and/or CrossEmbedLayer inception style?) + x = spatial_conv(features=cfg.model_dim * cfg.width_multipliers[0], name='unet_in_conv')(images) + + # down branch resblocks + attn on designated resolutions + down_outputs = [] + use_attn, attn_type = None, None + + common_block_kwargs = { + 'dtype':cfg.dtype, + 'norm_32':cfg.norm_32, + 'activation':cfg.resblock_activation, + 'cond_strategy':cfg.cond_strategy, + 'dropout_rate':cfg.dropout_rate, + 'attn_heads':cfg.attn_heads, + 'mha_head_dim':cfg.mha_head_dim, + 'zero_out':cfg.resblock_zero_out, + 'scale_skip':cfg.resblock_scale_skip, + } + + resolution_div = 1 + for level, width_mult in enumerate(cfg.width_multipliers[:-1]): + level_channels = cfg.model_dim * width_mult + + use_attn = (2 ** level) in cfg.attn_resolutions_divs.keys() + attn_type = cfg.attn_resolutions_divs[2 ** level] if use_attn else None + x = layers_sr.EfficientBlock(out_channels=level_channels, + use_attn=use_attn, + attn_type=attn_type, + up_down='down', + num_resblocks=single_or_idx(cfg.resblocks_per_level, level), + strides=(2,2), + name=f'DBlock_l{level}', + **common_block_kwargs) \ + (x, cond_embed, text_enc, text_lens, deterministic=deterministic) + print("Encoder DBlock #", level, " out resolution: ", x.shape[1:3], " level: ", level, " width: ", level_channels) + down_outputs.append(x) + resolution_div = 2 ** level + + # middle layers + use_attn = (2 * resolution_div) in cfg.attn_resolutions_divs.keys() + attn_type = cfg.attn_resolutions_divs[2 ** (level + 1)] if use_attn else None + mid_channels = cfg.model_dim * cfg.width_multipliers[-1] + x = layers_sr.EfficientBlock(out_channels=mid_channels, + use_attn=use_attn, + attn_type=attn_type, + up_down='down', + num_resblocks=single_or_idx(cfg.resblocks_per_level, -1), + strides=(1,1), + name='dblock_mid', + **common_block_kwargs) \ + (x, cond_embed, text_enc, text_lens, deterministic=deterministic) + + x = layers_sr.EfficientBlock(out_channels=mid_channels, + use_attn=use_attn, + attn_type=attn_type, + up_down='up', + num_resblocks=single_or_idx(cfg.resblocks_per_level, -1), + strides=(1,1), + name='ublock_mid', + **common_block_kwargs) \ + (x, cond_embed, text_enc, text_lens, deterministic=deterministic) + + print('Encoder Skip Shapes: ',list(map(lambda x : x.shape, down_outputs))) + # up branch resblocks + attn on designated resolutions + skip connections + for level, width_mult in list(enumerate(cfg.width_multipliers[:-1]))[::-1]: + level_channels = cfg.model_dim * width_mult + + # for res_idx in range(cfg.resblocks_per_level + 1): + u_skip = down_outputs.pop() + x = jnp.concatenate([x, u_skip], axis=-1) + use_attn = (2 ** level) in cfg.attn_resolutions_divs.keys() + attn_type = cfg.attn_resolutions_divs[2 ** level] if use_attn else None + x = layers_sr.EfficientBlock(out_channels=level_channels, + use_attn=use_attn, + attn_type=attn_type, + up_down='up', + num_resblocks=single_or_idx(cfg.resblocks_per_level, level), + strides=(2,2), + name=f'UBlock_l{level}', + **common_block_kwargs) \ + (x, cond_embed, text_enc, text_lens, deterministic=deterministic) + print("Decoder UBlock #", level, " out resolution: ", x.shape[1:3], " level: ", level, " width: ", level_channels, " skip shape: ", u_skip.shape) + + # out convolution model_dim -> image_channels + x = spatial_conv(features=input_ch, name='unet_out_conv', kernel_init=nn.initializers.zeros)(x) + return x + diff --git a/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub b/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub new file mode 100755 index 000000000..6da0ccde1 --- /dev/null +++ b/rosetta/rosetta/projects/imagen/scripts/example_slurm_inf_train.sub @@ -0,0 +1,107 @@ +#!/bin/bash +#SBATCH -A +#SBATCH -p +#SBATCH -N 1 # number of nodes +#SBATCH -t 04:00:00 # wall time (8 for backfill, 4 for Luna) +#SBATCH -J # job name (<< CHANGE ! >>) +#SBATCH --exclusive # exclusive node access +#SBATCH --mem=0 # all mem avail +#SBATCH --mail-type=FAIL # only send email on failure +#SBATCH --overcommit # Needed for pytorch +#SBATCH --dependency=singleton +set -x + +# File system and volume glue code +#------------------------------------------------------------------------------- +# << CHANGE ! >> +SLURM_ACCOUNT= +USERID= + +# << CHANGE ! >> +CONTAINER= + +# << CHANGE ! >> +BASE_ROSETTA_DIR="/jax-toolbox-mirror/rosetta/" # path to your clone of the repo +BASE_DATA_DIR="/datasets/" +BASE_WORKSPACE_DIR="${BASE_ROSETTA_DIR}/workspace" # path to where outputs will be dumped +BASE_HOSTNAME_COMM="${BASE_WORKSPACE_DIR}/outputs/multinode/communicators/${SLURM_JOB_ID}-inf-server-comms/" + +# Default env variables for paths required by t5x training scripts +DATA_DIR=/mnt/datasets/ +ROSETTA_DIR=/opt/rosetta/ +WORKSPACE_DIR=/opt/rosetta/workspace +HOSTNAMES_DIR=/inference_srv/ +HOSTNAMES_FILE=${HOSTNAMES_DIR}/hostnames.txt + +# Add the T5x/JAX specific mounts +MOUNTS="--container-mounts=$BASE_ROSETTA_DIR:$ROSETTA_DIR,$BASE_DATA_DIR:$DATA_DIR,$BASE_WORKSPACE_DIR:$WORKSPACE_DIR,$BASE_HOSTNAME_COMM:$HOSTNAMES_DIR" + +# Add T5x/JAX specific exports +EXPORTS="--export=ALL,DATA_DIR=${DATA_DIR},ROSETTA_DIR=${ROSETTA_DIR},WORKSPACE_DIR=${WORKSPACE_DIR}" +#------------------------------------------------------------------------------- + +# Command line arguments needed by the underlying scripts +DATASET=$1 +T5_SIZE=$2 # base +PREC="$3" # bfloat16, float32 +GPUS_PER_NODE=$4 # usually 8 +BSIZE_PER_GPU=$5 # local batch size/gpu +MODEL_DIR_LOCAL=$6 # directory to save checkpoints and config dump to +INF_SERV_CT=$7 # number of inference server processes +INF_SIZE=${8:-"xxl"} # t5 model size of inference server +NUM_MICROBATCHES=${9} # number of gradient accumulation steps +MP=${10} # tensor parallel count + +NUM_GPUS=$(( GPUS_PER_NODE * SLURM_JOB_NUM_NODES )) + +# remove hostnames file if there are no inference servers +if [ -z "${INF_SERV_CT}" ] || [ "${INF_SERV_CT}" -eq 0 ]; then + HOSTNAMES_FILE=None +fi + + +# << CHANGE ! >> +# You can add binding to the command below with the following line (after nvidia-smi). Remove the '&&' on the next bash line. +# && bash <>/bind.sh --cpu=exclusive --ib=single -- \ +read -r -d '' train_cmd < ${LOG_DIR}/${MODEL_TYPE}_${DATASET}_gpu_${TRAIN_GPUS}_${PREC}_gbs_${BSIZE}-${PROC_ID}.log & diff --git a/rosetta/rosetta/projects/imagen/scripts/specialized_run.py b/rosetta/rosetta/projects/imagen/scripts/specialized_run.py new file mode 100755 index 000000000..4ce97a56d --- /dev/null +++ b/rosetta/rosetta/projects/imagen/scripts/specialized_run.py @@ -0,0 +1,80 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import argparse +import subprocess +import os +import fcntl +import time + +parser = argparse.ArgumentParser(description='Run either training or inference server based on process ID') + +parser.add_argument('--proc_id', type=int, required=False, default=-1) +parser.add_argument('--proc_total_ct', type=int, required=True) +parser.add_argument('--inf_server_ct', type=int, required=True) +parser.add_argument('--gpus_per_node', type=int, default=-1, required=False) #Needed for multinode +parser.add_argument('--gpu_collection_size', type=int, default=1, required=False) + +parser.add_argument('--train_run_command', type=str, required=True) +parser.add_argument('--inf_server_run_command', type=str, required=True) +parser.add_argument('--hostnames_file', type=str, required=True) + +parser.add_argument('--inf_log_file', type=str, required=False) + +args = parser.parse_args() + +train_servers = args.proc_total_ct - args.inf_server_ct + +PROCESS_ID = args.proc_id if args.proc_id >= 0 else None + +if PROCESS_ID is None and os.getenv('SLURM_PROCID') is not None: + PROCESS_ID = int(os.getenv('SLURM_PROCID')) + if PROCESS_ID is None: + raise ValueError("Failed to get process ID when specializing") + +gpus_in_device = args.gpus_per_node if args.gpus_per_node > 0 else args.proc_total_ct +local_id = PROCESS_ID % gpus_in_device +# one inference server per node case +if PROCESS_ID == train_servers or (PROCESS_ID >= train_servers and local_id == 0): + inf_id = PROCESS_ID - train_servers + hostname = subprocess.check_output(["hostname", "-I"]) + hostname = hostname.split()[0].decode("utf-8") + port = 2345 + inf_id + hostname = hostname + ':' + str(port) + '\n' + + devices = [str(i) for i in range(local_id, gpus_in_device)] + device_str = ','.join(devices) + + with open(args.hostnames_file, 'a') as hf: + # Should be under buffer size on hostname, preventing + # writing race conditions, but will lock to be safe. + fcntl.flock(hf, fcntl.LOCK_EX) + hf.write(hostname) + fcntl.flock(hf, fcntl.LOCK_UN) + + inf_command_withargs = args.inf_server_run_command + f' --port={port} --devices={device_str} --total_device_first_idx={inf_id}' + # if args.inf_log_file is not None: + # inf_command_withport += f' &> {args.inf_log_file}' + + print("Inference Command: " + inf_command_withargs) + if args.inf_log_file is not None: + with open(args.inf_log_file, 'w') as f: + subprocess.call(inf_command_withargs, stdout=f, stderr=f, shell=True) + else: + os.system(inf_command_withargs) + + +# train server case +elif PROCESS_ID < train_servers: + time.sleep(10) + os.system(f'PROC_ID={PROCESS_ID} ' + args.train_run_command) diff --git a/rosetta/rosetta/projects/inference_serving/configs/t5_large_server.yml b/rosetta/rosetta/projects/inference_serving/configs/t5_large_server.yml new file mode 100644 index 000000000..ac19c6b56 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/configs/t5_large_server.yml @@ -0,0 +1,38 @@ +--- +models: + t5_large: + # Model config + max_bs: + a100_80g: 4096 + a6000: 2048 + gv100_32g: 1024 + default: null + find_max_bs: + cache_dir: "/opt/rosetta/server_bs_cache.json" #null to disable reading/writing batch size information to cache + + # Command to start a server per gpu group + run_command: "/opt/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.sh large" + gpus_per_process: 1 + + # PyTriton config + inputs: + - sequence: + dtype: "bytes_" + shape: !!python/tuple [-1] + outputs: + - encodings_padded: + dtype: "float16" + shape: !!python/tuple [-1] + - encodings_seqlens: + dtype: 'int32' + shape: !!python/tuple [-1] + + # "static" or "dymanic" + batching: + "dynamic" + + # specify fraction or absolute count of allocated GPUs to commit here. + resources: + fraction: 1.0 + count: null +... diff --git a/rosetta/rosetta/projects/inference_serving/configs/t5_xxl_server.yml b/rosetta/rosetta/projects/inference_serving/configs/t5_xxl_server.yml new file mode 100644 index 000000000..e12a67e8f --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/configs/t5_xxl_server.yml @@ -0,0 +1,39 @@ +--- +models: + t5_xxl: + # Model config + max_bs: + a100_80g: 1024 + a100_40g: 512 + a6000: 512 + gv100_32g: 272 + default: null + find_max_bs: + cache_dir: "/opt/rosetta/server_bs_cache.json" #null to disable reading/writing batch size information to cache + + # Command to start a server per gpu group + run_command: "/opt/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.sh xxl" + gpus_per_process: 1 + + # PyTriton config + inputs: + - sequence: + dtype: "bytes_" + shape: !!python/tuple [-1] + outputs: + - encodings_padded: + dtype: 'float16' + shape: !!python/tuple [-1] + - encodings_seqlens: + dtype: 'int32' + shape: !!python/tuple [-1] + + # "static" or "dymanic" + batching: + "dynamic" + + # specify fraction or absolute count of allocated GPUs to commit here. + resources: + fraction: 1.0 + count: null +... diff --git a/rosetta/rosetta/projects/inference_serving/server.py b/rosetta/rosetta/projects/inference_serving/server.py new file mode 100644 index 000000000..a94b58c6a --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/server.py @@ -0,0 +1,330 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import logging +import time +import os +import signal +from typing import List, Dict +import pickle as pkl +import copy +import argparse +import functools +import yaml +from yaml import Loader + +import uuid +import zmq +import subprocess + +from pytriton.model_config import ModelConfig, Tensor, DynamicBatcher +from pytriton.triton import Triton, TritonConfig +from pytriton.decorators import batch + +from rosetta.projects.inference_serving.server_utils import pow2list, triton_textencode +from rosetta.projects.inference_serving.shared_numpy import SharedNPDict + +# List of strings of comma-separated device indexes. i.e. ['0', '1', '2'] or ['0,1', '2,3'] +# Each list element contains the CUDA_VISIBLE_DEVICES visible to an inference process +ModelDevicesType = List[str] + +# ZMQ-based infer function. Sends input over socket and reads output return +def infer_fn(socket, **inputs: np.ndarray): + # start_time = time.time() + # out = [np.array([pkl.dumps(np.zeros((4096, 128), dtype=np.float32))] * list(inputs.values())[0].shape[0])] + # out = [np.zeros((list(inputs.values())[0].shape[0], 4096*128), dtype=np.float32)] + # logging.warning(f'time to create out {time.time() - start_time}') + # return out + for k, v in inputs.items(): + #logging.warning(f'inferring on type {v.dtype}') + inputs[k] = np.array([pkl.dumps(v)]) + start_time = time.time() + shared_inputs = SharedNPDict(dict_to_share=inputs) + socket.send_pyobj(shared_inputs.get_metas()) + out = socket.recv_pyobj() + shared_inputs.close_and_unlink() + # logging.warning(f'[triton] time to backend {time.time() - start_time}') + # out_time = time.strftime('%X %x %Z') + # logging.warning(f'outtime {out_time}') + if isinstance(out, str): + return out + out = SharedNPDict(metadata=out).localize(close_shared=True, unlink_shared=True)#['out'] + #logging.warning(out) + # return [out['padded_outs'], out['seqlens']] + return out + +def get_infer_fns(device_struct: ModelDevicesType, child_command:str): + sockets = [] + infer_fns = [] + for dl in device_struct: + ctx = zmq.Context(io_threads=1) + socket_addr = "ipc:///tmp/pytriton_multi-" + str(uuid.uuid4()) + + socket = ctx.socket(zmq.REQ) + socket.bind(socket_addr) + + sockets.append(socket) + + subprocess.Popen(f'SOCKET_ADDRESS={socket_addr} CUDA_VISIBLE_DEVICES={dl} {child_command} &', shell=True) + + for sock in sockets: + infer_fns.append(batch(functools.partial(infer_fn, socket=sock))) + logging.info(f'Built infer_fn for {sock}') + + return infer_fns + +def find_model_max_bs(devices: str, server_config:dict, model_name:str): + logging.basicConfig(level=logging.INFO) + logging.info(f"Finding the maximum batch size for model: {model_name}") + + + ctx = zmq.Context(io_threads=1) + socket_addr = "ipc:///tmp/pytriton_MAX_BS_multi-" + str(uuid.uuid4()) + socket = ctx.socket(zmq.REQ) + socket.bind(socket_addr) + + command = server_config['models'][model_name]['run_command'] + proc = subprocess.Popen(f'SOCKET_ADDRESS={socket_addr} CUDA_VISIBLE_DEVICES={devices} {command} &', shell=True, preexec_fn=os.setsid) + time.sleep(5) #setup time + + lower = 1 + test = 256 + upper_fail = None + + while upper_fail is None or ((float(upper_fail) / lower >= 1.125) and upper_fail - lower > 1): + logging.info(f'Trying bs {test}') + socket.send_pyobj({'singleton': test}) + out = socket.recv_pyobj() + + if isinstance(out, str): + # logging.warning(f'bs: {test} failed with error: {out}') + if test == 1: + logging.info('bs of 1 has failed. Exiting') + exit() + else: + upper_fail = test + test = (lower + upper_fail) // 2 + else: + logging.info(f'bs: {test} succeeded') + if upper_fail is None: + lower = test + test *= 2 + else: + lower = test + test = (lower + upper_fail) // 2 + + logging.info(f'New lower: {lower}, test: {test}, upper: {upper_fail}') + + socket.close() + os.killpg(os.getpgid(proc.pid), signal.SIGKILL) + return lower + +def get_batchsize(device_struct:ModelDevicesType, gpu_name:str, server_config:dict, model_name:str): + devices = device_struct[0] + if gpu_name is None: + logging.info("No gpu_name given. Finding max_bs") + return find_model_max_bs(devices, server_config, model_name) + + for config_gpu, max_bs in server_config['models'][model_name]['max_bs'].items(): + if gpu_name in config_gpu: + logging.info(f"Matched gpu name: {gpu_name} to configuration {config_gpu} with max_bs {max_bs}") + if max_bs is None: + logging.info("Since found max_bs is None, finding max_bs") + max_bs = find_model_max_bs(devices, server_config, model_name) + return max_bs + + logging.info(f'GPU name {gpu_name} not found in config. Finding max_bs') + return find_model_max_bs(devices, server_config, model_name) + +def config_to_tensor(triton_in_out_config): + out = [] + for inp in triton_in_out_config: + first_key = list(inp.keys())[0] + out += [ + Tensor(name=first_key, + dtype=np.dtype(inp[first_key]['dtype']), + shape=inp[first_key]['shape']) + ] + return out + +def triton_run(port:int, device_structs:Dict[str, ModelDevicesType], gpu_name:str, server_config:dict): + logging.warning(f'port {port}, devices {device_structs}') + triton_config = TritonConfig(http_port=port, grpc_port=port+1000, metrics_port=port+2000, log_verbose=0) + with Triton(config=triton_config) as triton: + for model_name in server_config['models'].keys(): + model_cfg = server_config['models'][model_name] + logging.warning(f'Setting up model {model_name} with configuration: {model_cfg}') + batch_size = get_batchsize(device_structs[model_name], gpu_name, server_config, model_name) + logging.warning(f'Using batch size {batch_size}') + + infer_fns = get_infer_fns(device_structs[model_name], model_cfg['run_command']) + + dyn_batch = DynamicBatcher( + max_queue_delay_microseconds = 100000, + preferred_batch_size = pow2list(batch_size) + ) + + triton.bind( + model_name=model_name, + infer_func=infer_fns, + inputs=config_to_tensor(model_cfg['inputs']), + outputs=config_to_tensor(model_cfg['outputs']), + config=ModelConfig(max_batch_size=batch_size, batching=True, batcher=dyn_batch), + ) + triton.serve() + +def build_visible_device_structs(devices_available, + total_devices, + total_device_first_idx, + server_config) -> Dict[str, ModelDevicesType]: + def single_model_devices(model_name, devices_available): + device_list = [] + devices_per_process = server_config['models'][model_name]['gpus_per_process'] + + # Number of devices must be divisible by the number of devices per process + assert len(devices_available) % devices_per_process == 0 + + for proc_id in range(len(devices_available) // devices_per_process): + device_list.append(','.join(devices_available[proc_id * devices_per_process: (proc_id + 1) * devices_per_process])) + + return device_list + + if len(server_config['models'].keys()) > 1: + assert total_devices is not None, "total_devices must be given if using more than one model" + assert total_device_first_idx is not None, "total_device_first_idx must be given if using more than one model" + + # Figuring out gpus per model + device_ctr = 0 + device_counts: dict[str, int] = {} + gpus_per_model_proc = {} + for model_name in server_config['models'].keys(): + model_cfg_resources = server_config['models'][model_name]['resources'] + gpus_per_model_proc[model_name] = server_config['models'][model_name]['gpus_per_process'] + if model_cfg_resources['fraction'] is not None: + count = round(total_devices * model_cfg_resources['fraction']) + elif model_cfg_resources['count'] is not None: + count = model_cfg_resources['count'] + else: + assert False, f'No resources specified for {model_name}' + + count = count // gpus_per_model_proc[model_name] * gpus_per_model_proc[model_name] + device_ctr += count + device_counts[model_name] = count + + model_names = list(device_counts.keys()) + if device_ctr > total_devices: + logging.info(f"Current resource specification is using too many devices! ({device_counts}; \ + Available: {total_devices} This can be due to rounding errors with fractional resources \ + or too many devices specified under 'count'. Reducing devices until program can run") + + idx = 0 + since_last_update = 0 + while device_counts > total_devices: + model_name = model_names[idx] + if device_counts[model_name] > gpus_per_model_proc[model_name]: + device_counts[model_name] -= gpus_per_model_proc[model_name] + device_ctr -= gpus_per_model_proc[model_name] + since_last_update = 0 + else: + since_last_update += 1 + idx += 1 + idx %= len(model_names) + if since_last_update > len(model_names): + assert False, "There are not enough devices to run 1 process of each model" + + if device_ctr < total_devices: + logging.info(f'Warning, {total_devices - device_ctr} devices idle') + + logging.info(f'Device arrangement: {device_counts}') + + before_this_host = total_device_first_idx + while before_this_host > 0: + for model in model_names: + if device_counts[model] > 0: + device_counts[model] -= gpus_per_model_proc[model] + before_this_host -= gpus_per_model_proc[model] + break + + logging.warning(f'gpus_per_process {gpus_per_model_proc}, device_counts {device_counts}, devices_available {devices_available}') + devices_per_model: Dict[str, List[int]] = {} + remaining_devices = copy.deepcopy(devices_available) + while len(remaining_devices) > 0: + for model in model_names: + logging.warning(f'remaining_devices {remaining_devices}, gpus_per_model_proc {gpus_per_model_proc}') + if device_counts[model] > 0: + assert gpus_per_model_proc[model] <= len(remaining_devices), "Cannot evenly fit model onto this host" + if model not in devices_per_model.keys(): + devices_per_model[model] = [] + devices_per_model[model] += remaining_devices[:gpus_per_model_proc[model]] + remaining_devices = remaining_devices[gpus_per_model_proc[model]:] + break + + # final construction + visible_device_structs = {} + for model, devices in devices_per_model.items(): + visible_device_structs[model] = single_model_devices(model, devices) + + return visible_device_structs + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='PyTriton Inference server with many GPUs communicating over zmq') + parser.add_argument( + '--port', + type=int, + default=1234, + help="port for server") + parser.add_argument( + '--devices', + type=str, + required=True, + help="Comma-separated list of GPU indexes available") + parser.add_argument( + '--total_devices', + type=int, + required=False, + help="Total number of inferencing devices. This is required when using multiple models.") + parser.add_argument( + '--total_device_first_idx', + type=int, + required=False, + help="Index of first device out of all inference devices. I.e. if this is the second 8-gpu host \ + doing inference, then this argument should be 8. Required for multiple model inference") + parser.add_argument( + '--gpu_name', + type=str, + required=False, + help="Used to match up to batch size configs. This program will check if any keys under 'max_bs' \ + are substrings of this name and use the first hit. I.e. a100_80g. If no default is set, \ + it will run the max batch size finder") + parser.add_argument( + '--config_file', + type=str, + required=True, + help='YAML configuration for this server') + + args = parser.parse_args() + + with open(args.config_file, 'r') as f: + server_config = yaml.load(f.read(), Loader=Loader) + + # Figure out devices I can use + all_devices = args.devices.split(',') + visible_device_structs = \ + build_visible_device_structs(all_devices, args.total_devices, \ + args.total_device_first_idx, server_config) + + triton_run(args.port, visible_device_structs, args.gpu_name, server_config) + diff --git a/rosetta/rosetta/projects/inference_serving/server_utils.py b/rosetta/rosetta/projects/inference_serving/server_utils.py new file mode 100644 index 000000000..66e6aaecb --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/server_utils.py @@ -0,0 +1,33 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from typing import List + +def pow2list(n: int): + pow2 = [] + i = 1 + while i < n: + pow2.append(i) + i *= 2 + + pow2.append(n) + return pow2 + +def triton_textencode(text_batch: List[str]): + enc = np.array([[np.char.encode(i, 'utf-8')] for i in text_batch]) + enc = np.reshape(enc, (enc.shape[0], 1)) + + return enc + diff --git a/rosetta/rosetta/projects/inference_serving/shared_numpy.py b/rosetta/rosetta/projects/inference_serving/shared_numpy.py new file mode 100644 index 000000000..ba107c182 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/shared_numpy.py @@ -0,0 +1,141 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +from dataclasses import dataclass +from typing import Tuple, Optional, Dict +from multiprocessing import shared_memory +import logging +import time + +@dataclass +class SharedNPMeta: + shmem_name: str + shape: Tuple[int] + dtype: np.dtype + +class SharedNPArray: + metadata: SharedNPMeta + shmem: shared_memory.SharedMemory + array: np.ndarray + closed: bool = False + name: Optional[str] = None + + def __init__(self, arr_to_share: Optional[np.ndarray]=None, metadata: Optional[SharedNPMeta]=None, name: Optional[str]=None): + if arr_to_share is not None: + start_time = time.time() + # Creates shared memory array + assert metadata is None, "Please provide either an array_to_share or metadata" + + nbytes = arr_to_share.nbytes + self.shmem = shared_memory.SharedMemory(create=True, size=arr_to_share.nbytes) + # logging.warning(f'creating {self.shmem.name} with {arr_to_share.nbytes} bytes') + self.metadata = SharedNPMeta(shmem_name=self.shmem.name, + shape=arr_to_share.shape, + dtype=arr_to_share.dtype) + self.array = np.ndarray(arr_to_share.shape, arr_to_share.dtype, buffer=self.shmem.buf) + self.array[:] = arr_to_share#[:] + # logging.warning(f'just shared {self.array}') + logging.warning(f'time to share {nbytes} {time.time() - start_time}') + else: + # Makes local array with given shared memory + assert metadata is not None, "Please provide either an array_to_share or metadata" + start_time = time.time() + self.metadata = metadata + # print(f'recv side shmem name {metadata.shmem_name}') + self.shmem = shared_memory.SharedMemory(name=metadata.shmem_name) + # logging.warning(f'getting {self.shmem.name}') + self.array = np.ndarray(metadata.shape, dtype=metadata.dtype, buffer=self.shmem.buf) + logging.warning(f'time to recieve {time.time() - start_time}') + self.name = name + + def __repr__(self): + return f'SharedNPArray: name:{self.name}, meta{self.metadata}, closed{self.closed}' + + def localize(self, close_shared=True, unlink_shared=False): + # dump contents into local (unshared memory) + #logging.warning(f'self array {self.array}') + start_time = time.time() + new_array = np.array(self.array, copy=True) + # logging.warning(f'localizing {self.name}') + if close_shared: + self.close() + if unlink_shared: + self.unlink() + logging.warning(f'time to localize {time.time() - start_time}') + return new_array + + def close(self): + if (self.closed): + raise ValueError(f"ERROR: Trying to close an array {self.name} {self.metadata} that has already been closed here") + self.shmem.close() + self.closed = True + del self.array + + def unlink(self): + self.shmem.unlink() + + +class SharedNPDict: + arrays: Dict[str, SharedNPArray] + + def __init__(self, dict_to_share: Optional[Dict[str, np.ndarray]]=None, metadata: Optional[Dict[str, SharedNPMeta]]=None): + self.arrays = {} + if dict_to_share is not None: + # Creates shared memory array + assert metadata is None, "Please provide either an dict_to_share or metadata" + assert isinstance(dict_to_share, dict), f"Dict to share must be a dictionary. got {type(dict_to_share)}" + + for k, v in dict_to_share.items(): + self.arrays[k] = SharedNPArray(arr_to_share=v, name=k) + else: + # Makes local array with given shared memory + assert metadata is not None, "Please provide either an array_to_share or metadata" + for k, v in metadata.items(): + shared_arr = SharedNPArray(metadata=v, name=k) + self.arrays[k] = shared_arr + + def __repr__(self): + out_dict = {} + for k, v in self.arrays.items(): + out_dict[k] = str(v) + return str(out_dict) + + def localize(self, close_shared=False, unlink_shared=False): + # dump contents into local (unshared memory) + # logging.warning(f'I am {self.__repr__()}') + out_dict = {} + for k, v in self.arrays.items(): + local_arr = v.localize(close_shared=close_shared, unlink_shared=unlink_shared) + out_dict[k] = local_arr + return out_dict + + def close(self): + for _, v in self.arrays.items(): + v.close() + + def unlink(self): + for _, v in self.arrays.items(): + v.unlink() + + def close_and_unlink(self): + for _, v in self.arrays.items(): + v.close() + v.unlink() + + def get_metas(self): + meta_dict = {} + for k, v in self.arrays.items(): + meta_dict[k] = v.metadata + return meta_dict diff --git a/rosetta/rosetta/projects/inference_serving/t5/embed_large.gin b/rosetta/rosetta/projects/inference_serving/t5/embed_large.gin new file mode 100644 index 000000000..1bda71dc2 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/t5/embed_large.gin @@ -0,0 +1,87 @@ +# Defaults for a T5 large embedding server +# +# Required to be set: +# +# - CHECKPOINT_PATH: The model checkpoint to evaluate +# - EVAL_OUTPUT_DIR: The dir to write results to. +# +# +# Commonly overridden options: +# +# - BATCH_SIZE +from __gin__ import dynamic_registration + +import __main__ as embed_script +from t5x import partitioning +from t5x import utils +from t5x import adafactor +import seqio +from rosetta.projects.inference_serving.t5 import network +from rosetta.projects.inference_serving.t5 import models + +# ===================================== +# === T5 Encoder only configuration === +# ===================================== +CHECKPOINT_PATH = "/opt/rosetta/checkpoints/t5/checkpoint_1000000_t5_1_1_large" +EVAL_OUTPUT_DIR = %gin.REQUIRED +BATCH_SIZE = 256 # Will be overridden +SEQ_LEN = 128 # MAX seqlen + +# Vocabulary +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" +TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. + +# --------------- Model ------------------ +MODEL = @models.EncoderOnlyModel() +models.EncoderOnlyModel: + module = @network.TransformerEncoderOnly() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = None + z_loss = 0.0001 + label_smoothing = 0.0 + loss_normalizing_factor = None + +# -------- Network specification --------- +network.TransformerEncoderOnly.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 1024 + num_heads = 16 + num_encoder_layers = 24 + num_decoder_layers = 0 + head_dim = 64 + mlp_dim = 2816 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + +# ====================================== +# === Embedding script configuration === +# ====================================== +embed_script.zmq_run: + infer_fn = @embed_script.get_infer_fn() + +embed_script.get_infer_fn: + model = %MODEL # imported from separate gin file + vocab = %VOCABULARY + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + preproc_fn = @embed_script.seqio_preprocessing_pow2 + output_dir = %EVAL_OUTPUT_DIR + batch_size = %BATCH_SIZE + seq_len = %SEQ_LEN + +embed_script.seqio_preprocessing_pow2: + vocab = %VOCABULARY + seq_len = %SEQ_LEN + +partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' + dtype = 'bfloat16' diff --git a/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.py b/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.py new file mode 100644 index 000000000..6716a713a --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.py @@ -0,0 +1,259 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import jax +import jax.numpy as jnp + +import seqio +from t5x import gin_utils +from t5x import models +from t5x import partitioning +from t5x import utils +from seqio.vocabularies import PAD_ID + +import logging +import time +import os +from typing import Any, Callable, Sequence +import pickle as pkl +import zmq +from rosetta.projects.inference_serving import server_utils +from rosetta.projects.inference_serving.shared_numpy import SharedNPDict + +_DEFAULT_GIN_SEARCH_PATHS = [ + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +] + +os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'true' +os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.95' + +def get_singleton_batch(batch_size: int): + in_singleton = ['the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog. \ + the quick brown fox jumped over the lazy dog. the quick brown fox jumped over the lazy dog.'] + batch = in_singleton * batch_size + return {'batch': server_utils.triton_textencode(batch)} + + +def pad_right(tokens, seq_len, eos_id,pad_id): + padded, tok_lengths = [], [] + for t in tokens: + diff = seq_len - (len(t) + 1) + #assert diff >= 0 + if diff < 0: + padded.append(t[:seq_len - 1] + [eos_id]) + tok_lengths.append(seq_len) + else: + padded.append(t + [eos_id] + [pad_id] * diff) + tok_lengths.append(len(t) + 1) + + return jnp.array(padded, dtype=jnp.int32), tok_lengths, seq_len + +def seqio_preprocessing(mbatch, vocab:Any=None, seq_len:int=128): + return pad_right(vocab.encode(mbatch), seq_len=seq_len, eos_id=vocab.eos_id, pad_id=PAD_ID) + +def pow2upper(n: int): + i = 1 + while i < n: + i *= 2 + return i + +def pow_2_pad_right(tokens_batch, seq_len, eos_id, pad_id): + padded, tok_lengths = [], [] + max_seq_len = max([len(t) for t in tokens_batch]) + 1 + seq_len = min(pow2upper(max_seq_len), seq_len) + + for t in tokens_batch: + diff = seq_len - (len(t) + 1) + # assert diff >= 0 + if diff < 0: + padded.append(t[:seq_len - 1] + [eos_id]) + tok_lengths.append(seq_len) + else: + padded.append(t + [eos_id] + [pad_id] * diff) + tok_lengths.append(len(t) + 1) + + return jnp.array(padded, dtype=jnp.int32), tok_lengths, seq_len + +def seqio_preprocessing_pow2(mbatch, vocab:Any=None, seq_len:int=128): + return pow_2_pad_right(vocab.encode(mbatch), seq_len=seq_len, eos_id=vocab.eos_id, pad_id=PAD_ID) + +def get_infer_fn( + *, + model: models.BaseTransformerModel, + vocab: Any, + restore_checkpoint_cfg: utils.RestoreCheckpointConfig, + partitioner: partitioning.BasePartitioner, + output_dir: str, + preproc_fn: Callable, + batch_size: int, + seq_len: int): + + input_shapes = {'encoder_input_tokens': (batch_size, seq_len)} + + train_state_initializer = utils.TrainStateInitializer( + optimizer_def=None, # Do not load optimizer state. + init_fn=model.get_initial_variables, + input_shapes=input_shapes, + partitioner=partitioner) + train_state_axes = train_state_initializer.train_state_axes + # Log the variable shapes information and write to a file. + log_file = os.path.join(output_dir, 'model-info.txt') + utils.log_model_info(log_file, + train_state_initializer.global_train_state_shape, + partitioner) + + # Disable strictness since we are dropping the optimizer state. + restore_checkpoint_cfg.strict = False + + fallback_init_rng = None + + if fallback_init_rng is not None: + fallback_init_rng = jax.random.PRNGKey(fallback_init_rng) + train_state = list(train_state_initializer.from_checkpoints([restore_checkpoint_cfg], init_rng=fallback_init_rng))[0] + logging.warning(f'Restored from Checkpoint: {train_state[1]}') + train_state = train_state[0] + + partitioned_fn = partitioner.partition( + model.score_batch, + in_axis_resources=(train_state_axes.params, partitioning.PartitionSpec('data',)), + out_axis_resources=None) + + CUDA_VIS = os.getenv('CUDA_VISIBLE_DEVICES') + + def infer_fn(**inputs: np.ndarray): + start_time = time.time() + (sequence_batch,) = inputs.values() + batch = np.array([i[0] for i in sequence_batch]) + sequence_batch = to_str_list(np.char.decode(batch.astype("bytes"), "utf-8")) + + tokenized_padded, batch_len, curr_seqlen = preproc_fn(sequence_batch) + results = partitioned_fn(train_state.params, {"encoder_input_tokens": tokenized_padded}).astype(jnp.float16) + + results.block_until_ready() + bs = batch.shape[0] + pre_pad_time = time.time() + individual_shape = results[0].shape + padded_output = np.zeros((bs, *individual_shape), dtype=np.float16) + for idx, (tensor, true_len) in enumerate(zip(results, batch_len)): + padded_output[idx, :true_len] = tensor[:true_len] + + logging.info('Throughput (seq/sec): {}, bs: {}, devices: {}, seqlen: {}, throughput w/o pad {}'.format(bs / (time.time() - start_time), bs, CUDA_VIS, curr_seqlen, bs/(pre_pad_time - start_time))) + # return sliced_output + return padded_output, np.array(batch_len, dtype=np.int32) + + + return infer_fn + +def to_str_list(batch): + b = [] + for i in batch: + b.append(str(i)) + return b + +def zmq_run(socket, infer_fn: Callable): + logging.info("Starting ZMQ Server") + + while True: + socket_in = socket.recv_pyobj() + # logging.warning(f"Recieved from socket, {socket_in}") + localized_inputs = SharedNPDict(metadata=socket_in).localize(close_shared=True) + # logging.warning("Localized socket_in") + if isinstance(localized_inputs, dict): + if 'singleton' in localized_inputs.keys(): + count = localized_inputs['singleton'] + localized_inputs = get_singleton_batch(localized_inputs['singleton']) + logging.info(f"Recieved singleton command {count}") + else: + for k, v in localized_inputs.items(): + localized_inputs[k] = pkl.loads(v[0]) + try: + padded_outs, seqlens = infer_fn(**localized_inputs) + # logging.info("created out") + outputs_shared = SharedNPDict(dict_to_share={'encodings_padded': padded_outs, 'encodings_seqlens': seqlens}) + logging.info("Shared out") + outputs = outputs_shared.get_metas() + outputs_shared.close() + # logging.info("closed out") + except Exception as e: + outputs = str(e) + + socket.send_pyobj(outputs) + +if __name__ == '__main__': + from absl import app + from absl import flags + import gin + + FLAGS = flags.FLAGS + + jax.config.parse_flags_with_absl() + + flags.DEFINE_multi_string( + 'gin_file', + default=None, + help='Path to gin configuration file. Multiple paths may be passed and ' + 'will be imported in the given order, with later configurations ' + 'overriding earlier ones.') + + flags.DEFINE_multi_string( + 'gin_bindings', default=[], help='Individual gin bindings.') + + flags.DEFINE_list( + 'gin_search_paths', + default=['.'], + help='Comma-separated list of gin config path prefixes to be prepended ' + 'to suffixes given via `--gin_file`. If a file appears in. Only the ' + 'first prefix that produces a valid path for each suffix will be ' + 'used.') + + def main(argv: Sequence[str]): + """Wrapper for pdb post mortems.""" + + if jax.process_index() == 0: + main_fn = lambda: _main(argv) + _main(argv) + else: + _main(argv) + + def _main(argv: Sequence[str]): + """True main function.""" + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + socket_name = os.environ.get('SOCKET_ADDRESS') + ctx = zmq.Context.instance() + socket = ctx.socket(zmq.REP) + socket.connect(socket_name) + + # Create gin-configurable version of `eval`. + # tr = functools.partial(triton_run, port=FLAGS.port) + run_using_gin = gin.configurable(zmq_run) + + gin_utils.parse_gin_flags( + # User-provided gin paths take precedence if relative paths conflict. + FLAGS.gin_search_paths + _DEFAULT_GIN_SEARCH_PATHS, + FLAGS.gin_file, + FLAGS.gin_bindings) + run_using_gin(socket) + + gin_utils.run(main) diff --git a/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.sh b/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.sh new file mode 100755 index 000000000..36a9f1cc2 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/t5/embed_t5x.sh @@ -0,0 +1,45 @@ +#! /bin/bash + +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set -x + +MODEL_SIZE=$1 + +# Arguments +PREC="bfloat16" + +BSIZE=8 #Overridden + +MP=1 + +MODEL_DIR_LOCAL="/tmp/${MODEL_SIZE}_t5_inference_debugs" + +MODEL_DIR=${PWD}/${MODEL_DIR_LOCAL} + +mkdir -p $MODEL_DIR + +echo $MODEL_DIR + +CFG_NAME="embed_${MODEL_SIZE}.gin" + +T5X_SERVER_DIR="/opt/rosetta/rosetta/projects/inference_serving/t5/" + +python ${T5X_SERVER_DIR}/embed_t5x.py \ + --gin_file="${T5X_SERVER_DIR}/${CFG_NAME}" \ + --gin.EVAL_OUTPUT_DIR=\"${MODEL_DIR}\" \ + --gin.network.T5Config.dtype=\"${PREC}\" \ + --gin.BATCH_SIZE=$BSIZE \ + --gin.partitioning.PjitPartitioner.num_partitions=$MP diff --git a/rosetta/rosetta/projects/inference_serving/t5/embed_xxl.gin b/rosetta/rosetta/projects/inference_serving/t5/embed_xxl.gin new file mode 100644 index 000000000..ad5e8f4d6 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/t5/embed_xxl.gin @@ -0,0 +1,87 @@ +# Defaults for a T5 large embedding server +# +# Required to be set: +# +# - CHECKPOINT_PATH: The model checkpoint to evaluate +# - EVAL_OUTPUT_DIR: The dir to write results to. +# +# +# Commonly overridden options: +# +# - BATCH_SIZE +from __gin__ import dynamic_registration + +import __main__ as embed_script +from t5x import partitioning +from t5x import utils +from t5x import adafactor +import seqio +from rosetta.projects.inference_serving.t5 import network +from rosetta.projects.inference_serving.t5 import models + +# ===================================== +# === T5 Encoder only configuration === +# ===================================== +CHECKPOINT_PATH = "/opt/rosetta/rosetta/projects/inference_serving/checkpoints/checkpoint_1000000_t5_1_1_xxl" +EVAL_OUTPUT_DIR = %gin.REQUIRED +BATCH_SIZE = 256 # Will be overridden +SEQ_LEN = 128 # MAX seqlen + +# Vocabulary +VOCABULARY = @seqio.SentencePieceVocabulary() +seqio.SentencePieceVocabulary.sentencepiece_model_file = "gs://t5-data/vocabs/cc_all.32000.100extra/sentencepiece.model" +TASK_FEATURE_LENGTHS = None # auto-computes the maximum features length to use. + +# --------------- Model ------------------ +MODEL = @models.EncoderOnlyModel() +models.EncoderOnlyModel: + module = @network.TransformerEncoderOnly() + input_vocabulary = %VOCABULARY + output_vocabulary = %VOCABULARY + optimizer_def = None + z_loss = 0.0001 + label_smoothing = 0.0 + loss_normalizing_factor = None + +# -------- Network specification --------- +network.TransformerEncoderOnly.config = @network.T5Config() +network.T5Config: + vocab_size = 32128 # vocab size rounded to a multiple of 128 for TPU efficiency + dtype = 'bfloat16' + emb_dim = 4096 + num_heads = 64 + num_encoder_layers = 24 + num_decoder_layers = 0 + head_dim = 64 + mlp_dim = 10240 + mlp_activations = ('gelu', 'linear') + dropout_rate = 0.0 + +# ====================================== +# === Embedding script configuration === +# ====================================== +embed_script.zmq_run: + infer_fn = @embed_script.get_infer_fn() + +embed_script.get_infer_fn: + model = %MODEL # imported from separate gin file + vocab = %VOCABULARY + restore_checkpoint_cfg = @utils.RestoreCheckpointConfig() + partitioner = @partitioning.PjitPartitioner() + preproc_fn = @embed_script.seqio_preprocessing_pow2 + output_dir = %EVAL_OUTPUT_DIR + batch_size = %BATCH_SIZE + seq_len = %SEQ_LEN + +embed_script.seqio_preprocessing_pow2: + vocab = %VOCABULARY + seq_len = %SEQ_LEN + +partitioning.PjitPartitioner: + num_partitions = 1 + logical_axis_rules = @partitioning.standard_logical_axis_rules() + +utils.RestoreCheckpointConfig: + path = %CHECKPOINT_PATH + mode = 'specific' + dtype = 'bfloat16' diff --git a/rosetta/rosetta/projects/inference_serving/t5/models.py b/rosetta/rosetta/projects/inference_serving/t5/models.py new file mode 100644 index 000000000..42fb13f06 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/t5/models.py @@ -0,0 +1,176 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5X EncoderOnly Model""" + +from typing import Any, Callable, Mapping, MutableMapping, Optional, Tuple, Union + +from flax import linen as nn +from flax.core import scope as flax_scope +import jax +import jax.numpy as jnp +import seqio +from t5x import decoding +from t5x import optimizers +from t5x.models import BaseTransformerModel, DecodeFnCallable, Array, PyTreeDef + +class EncoderOnlyModel(BaseTransformerModel): + """Wrapper class for the TransformerEncoderOnly nn.module.""" + + FEATURE_CONVERTER_CLS = seqio.EncDecFeatureConverter + + def __init__( + self, + module: nn.Module, + input_vocabulary: seqio.Vocabulary, + output_vocabulary: seqio.Vocabulary, + optimizer_def: optimizers.OptimizerDefType, + decode_fn: DecodeFnCallable = decoding.beam_search, + feature_converter_cls: Optional[Callable[..., + seqio.FeatureConverter]] = None, + label_smoothing: float = 0.0, + z_loss: float = 0.0, + loss_normalizing_factor: Optional[float] = None, + ): + if feature_converter_cls is not None: + self.FEATURE_CONVERTER_CLS = feature_converter_cls # pylint: disable=invalid-name + super().__init__( + module=module, + input_vocabulary=input_vocabulary, + output_vocabulary=output_vocabulary, + optimizer_def=optimizer_def, + decode_fn=decode_fn, + label_smoothing=label_smoothing, + z_loss=z_loss, + loss_normalizing_factor=loss_normalizing_factor, + ) + + def get_initial_variables( + self, + rng: jax.random.KeyArray, + input_shapes: Mapping[str, Array], + input_types: Optional[Mapping[str, jnp.dtype]] = None + ) -> flax_scope.FrozenVariableDict: + """Get the initial variables for an encoder-decoder model.""" + input_types = {} if input_types is None else input_types + encoder_shape = input_shapes['encoder_input_tokens'] + encoder_type = input_types.get('encoder_input_tokens', jnp.float32) + if 'encoder_positions' in input_shapes: + encoder_positions = jnp.ones( + input_shapes['encoder_positions'], + input_types.get('encoder_positions', jnp.int32)) + else: + encoder_positions = None + + if 'encoder_segment_ids' in input_shapes: + encoder_segment_ids = jnp.ones( + input_shapes['encoder_segment_ids'], + input_types.get('encoder_segment_ids', jnp.int32)) + else: + encoder_segment_ids = None + initial_variables = self.module.init( + rng, + jnp.ones(encoder_shape, encoder_type), + encoder_segment_ids=encoder_segment_ids, + enable_dropout=False) + return initial_variables + + def _compute_logits( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + dropout_rng: Optional[jax.random.KeyArray] = None, + mutable: flax_scope.CollectionFilter = False, + other_variables: Optional[PyTreeDef] = None, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, flax_scope.FrozenVariableDict]]: + """Computes logits via a forward pass of `self.module_cls`.""" + # Dropout is provided only for the training mode. + rngs = {'dropout': dropout_rng} if dropout_rng is not None else None + if other_variables is None: + other_variables = {} + + variables = { + 'params': params, + **other_variables + } + + return self.module.apply( + variables, + batch['encoder_input_tokens'], + encoder_segment_ids=batch.get('encoder_segment_ids', None), + enable_dropout=False, + rngs=rngs, + mutable=mutable) + + def _compute_logits_from_slice( + self, flat_ids: jnp.ndarray, flat_cache: Mapping[str, jnp.ndarray], + params: PyTreeDef, encoded_inputs: jnp.ndarray, raw_inputs: jnp.ndarray, + max_decode_length: int) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + """Token slice to logits from decoder model.""" + # flat_ids: [batch * beam, seq_len=1] + # cache is expanded inside beam_search to become flat_cache + # flat_cache: [batch * beam, num_heads, depth_per_head, max_decode_len] + # flat_logits: [batch * beam, seq_len=1, vocab] + flat_logits, new_vars = self.module.apply( + { + 'params': params, + 'cache': flat_cache + }, + encoded_inputs, + raw_inputs, # only needed for encoder padding mask + flat_ids, + flat_ids, + enable_dropout=False, + decode=True, + max_decode_length=max_decode_length, + mutable=['cache'], + method=self.module.decode) + # Remove sequence length dimension since it's always 1 during decoding. + flat_logits = jnp.squeeze(flat_logits, axis=1) + new_flat_cache = new_vars['cache'] + return flat_logits, new_flat_cache + + def score_batch( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + return_intermediates: bool = False, + ) -> Union[jnp.ndarray, Tuple[jnp.ndarray, Mapping[str, Any]]]: + """Compute log likelihood score on a batch.""" + + output = self._compute_logits(params, batch) # type: jnp.ndarray + + return output + + def predict_batch_with_aux( + self, + params: PyTreeDef, + batch: Mapping[str, jnp.ndarray], + rng: Optional[jax.random.KeyArray] = None, + ) -> Tuple[jnp.ndarray, Mapping[str, jnp.ndarray]]: + raise NotImplementedError("Predict Batch is not implemented for encoder only. Use score_batch") diff --git a/rosetta/rosetta/projects/inference_serving/t5/network.py b/rosetta/rosetta/projects/inference_serving/t5/network.py new file mode 100644 index 000000000..38fd13774 --- /dev/null +++ b/rosetta/rosetta/projects/inference_serving/t5/network.py @@ -0,0 +1,105 @@ +# Copyright (c) 2022-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright 2022 The T5X Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""T5.1.1 Transformer Encoder only model for getting text embeddings""" + +from flax import linen as nn +import jax.numpy as jnp +from t5x.contrib.gpu.t5 import layers +from t5x.contrib.gpu.t5.network import T5Config, Encoder, SeqDataFormat +from t5x.te_helper import TransformerEngineHelper + +class TransformerEncoderOnly(nn.Module): + """An encoder-only T5 Transformer model.""" + config: T5Config + + def setup(self): + cfg = TransformerEngineHelper.get_t5x_config(self.config) + + self.shared_embedding = layers.Embed( + num_embeddings=cfg.vocab_size, + features=cfg.emb_dim, + dtype=cfg.dtype, + attend_dtype=jnp.float32, # for logit training stability + embedding_init=nn.initializers.normal(stddev=1.0), + one_hot=False, + name='token_embedder') + + self.encoder = Encoder(config=cfg, shared_embedding=self.shared_embedding) + + def encode(self, + encoder_input_tokens, + encoder_segment_ids=None, + enable_dropout=True, + output_format=SeqDataFormat.BATCH_SEQ_HIDDEN): + """Applies Transformer encoder-branch on the inputs.""" + cfg = self.config + assert encoder_input_tokens.ndim == 2 # (batch, len) + + # Make padding attention mask. + encoder_mask = layers.make_attention_mask( + encoder_input_tokens > 0, encoder_input_tokens > 0, dtype=cfg.dtype) + # Add segmentation block-diagonal attention mask if using segmented data. + if encoder_segment_ids is not None: + encoder_mask = layers.combine_masks( + encoder_mask, + layers.make_attention_mask( + encoder_segment_ids, + encoder_segment_ids, + jnp.equal, + dtype=cfg.dtype)) + + encoder_mask = TransformerEngineHelper.get_attn_mask(encoder_mask) + + return self.encoder( + encoder_input_tokens, encoder_mask, deterministic=not enable_dropout, + output_format=output_format) + + def __call__(self, + encoder_input_tokens, + encoder_segment_ids=None, + *, + enable_dropout: bool = True): + """Applies Transformer encoder-only model on the inputs. + + This method requires just encoder inputs + + Args: + encoder_input_tokens: input data to the encoder. + encoder_segment_ids: encoder segmentation info for packed examples. + enable_dropout: Ensables dropout if set to True. + + Returns: + logits array from just the encoder of a T5 model + """ + encoded = self.encode( + encoder_input_tokens, + encoder_segment_ids=encoder_segment_ids, + enable_dropout=enable_dropout) + + return encoded diff --git a/rosetta/setup.py b/rosetta/setup.py index 70b19b1b7..7f0476f46 100644 --- a/rosetta/setup.py +++ b/rosetta/setup.py @@ -26,8 +26,13 @@ }, scripts=[], install_requires=[ - 'nvidia-dali-cuda120', + 'zmq', + 'nvidia-pytriton', + 'einops', + 'pillow', 'webdataset', + 'matplotlib', + 'nvidia-dali-cuda120', ], extras_require={