Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update DALI pipeline for ViT #911

Merged
merged 2 commits into from
Jul 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 6 additions & 25 deletions rosetta/rosetta/data/dali.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,32 +78,9 @@ def __init__(self,
self.prefetch = wds_config.prefetch
self.training = training

## set up the wds reader
self.pipe = self.get_wds_pipeline()
self.pipe.build()
## dataset metadata will be stored here and set in the iterator
self.meta = None

## dataset metadata
meta_dict = self.pipe.reader_meta()
assert(len(meta_dict) == 1), 'Pipeline has multiple readers but is expected to have only one'
self.meta = list(meta_dict.values())[0]

@abc.abstractmethod
def get_wds_pipeline(self):
"""Returns the pipeline which loads the wds files.

Expected to have the following format:

@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=1, device_id=None)
def wds_pipeline():
outputs = fn.readers.webdataset(
...)
return outputs
return wds_pipeline()

See ViT's `dali_utils.py` for an example

"""
pass

@abc.abstractmethod
def get_dali_pipeline(self):
Expand Down Expand Up @@ -144,6 +121,10 @@ def __init__(self, dali_wrapped_pipeline: BaseDALIPipeline):
self.wrapped_pipeline = dali_wrapped_pipeline
self.pipeline = dali_wrapped_pipeline.get_dali_pipeline()
self.pipeline.build()

meta_dict = self.pipeline.reader_meta()
assert(len(meta_dict) == 1), 'Pipeline has multiple readers but is expected to have only one'
dali_wrapped_pipeline.meta = list(meta_dict.values())[0]

self.training = dali_wrapped_pipeline.training

Expand Down
39 changes: 10 additions & 29 deletions rosetta/rosetta/data/dali_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
from rosetta.data.dali import BaseDALIPipeline


def class_preproc(raw_text):
    """Decode a raw byte-encoded class label into a 1-element int array.

    `raw_text` is a byte sequence (e.g. a DALI tensor of uint8 ASCII
    digits); it is decoded as UTF-8, parsed as an integer label, and
    wrapped in a length-1 numpy array for the pipeline.
    """
    decoded = bytes(raw_text).decode('utf-8')
    label = int(decoded)
    return np.array([label])

class DummyPipeline(BaseDALIPipeline):

def __init__(self,
Expand Down Expand Up @@ -52,9 +55,11 @@ def __init__(self,
num_shards=num_shards,
training=False)

def get_wds_pipeline(self):
@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=1, device_id=None)
def wds_pipeline():

def get_dali_pipeline(self):
@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=self.num_workers, device_id=None)
def main_pipeline():
# img, labels = fn.external_source(source=self.data_source(), num_outputs=2)
img, clss = fn.readers.webdataset(
paths=self.urls,
index_paths=self.index_paths,
Expand All @@ -64,34 +69,10 @@ def wds_pipeline():
shard_id=self.shard_id,
num_shards=self.num_shards,
pad_last_batch=False)
return img, clss
return wds_pipeline()

## non-image preprocessing
def class_preproc(self, raw_text):
bs = len(raw_text.shape())
ascii = [np.asarray(raw_text[i]) for i in range(bs)]

labels = np.zeros((bs, ))
for i, el in enumerate(ascii):
idx = int(bytes(el).decode('utf-8'))
labels[i] = idx

return labels

def data_source(self):
while True:
img, clss = self.pipe.run()
clss = self.class_preproc(clss)
yield img, clss


def get_dali_pipeline(self):
@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=self.num_workers, device_id=None)
def main_pipeline():
img, labels = fn.external_source(source=self.data_source(), num_outputs=2)

img = fn.decoders.image(img, device='cpu', output_type=types.RGB)
labels = fn.python_function(clss, function=class_preproc, num_outputs=1)

return img, labels

return main_pipeline()
Expand Down
47 changes: 13 additions & 34 deletions rosetta/rosetta/projects/vit/dali_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
from rosetta.data.dali import BaseDALIPipeline
from rosetta.data.wds_utils import ModalityConfig

def non_image_preprocessing(raw_text):
    """Preprocess a class label: decode bytes to UTF-8 text, parse as int.

    Returns a length-1 numpy array holding the integer label, suitable
    as a `fn.python_function` callback in the DALI pipeline.
    """
    text = bytes(raw_text).decode('utf-8')
    return np.array([int(text)])


class ViTPipeline(BaseDALIPipeline):

Expand Down Expand Up @@ -51,11 +55,12 @@ def __init__(self,
training=training)


def get_wds_pipeline(self):
@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=1, device_id=None, seed=self.seed)
def wds_vit_pipeline():
## assumes a particular order to the ftypes
img, clss = fn.readers.webdataset(
def get_dali_pipeline(self):

## need to enable conditionals for auto-augment
@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=self.num_workers, device_id=None, enable_conditionals=True, seed=self.seed, prefetch_queue_depth=self.prefetch)
def main_vit_pipeline():
jpegs, clss = fn.readers.webdataset(
paths=self.urls,
index_paths=self.index_paths,
ext=[m.ftype for m in self.modalities],
Expand All @@ -64,36 +69,10 @@ def wds_vit_pipeline():
shard_id=self.shard_id,
num_shards=self.num_shards,
pad_last_batch=False if self.training else True)
return img, clss
return wds_vit_pipeline()

def non_image_preprocessing(self, raw_text, num_classes):
""" preprocessing of class labels. """
bs = len(raw_text.shape())
ascii = [np.asarray(raw_text[i]) for i in range(bs)]

one_hot = np.zeros((bs, num_classes))
for i, el in enumerate(ascii):
idx = int(bytes(el).decode('utf-8'))
one_hot[i][idx] = 1

return one_hot

def data_source(self, num_classes):
while True:
preprocessed_img, raw_text = self.pipe.run()
preprocessed_label = self.non_image_preprocessing(raw_text, num_classes)
yield preprocessed_img, preprocessed_label


def get_dali_pipeline(self):

## need to enable conditionals for auto-augment
@pipeline_def(batch_size=self.per_shard_batch_size, num_threads=self.num_workers, device_id=None, enable_conditionals=True, seed=self.seed, prefetch_queue_depth=self.prefetch)
def main_vit_pipeline():
jpegs, labels = fn.external_source(source=self.data_source(self.num_classes), num_outputs=2)

img = fn.decoders.image(jpegs, device='cpu', output_type=types.RGB)

labels = fn.python_function(clss, function=non_image_preprocessing, num_outputs=1)
labels = fn.one_hot(labels, num_classes=self.num_classes)

if self.training:
img = fn.random_resized_crop(img, size=self.image_shape[:-1], seed=self.seed)
Expand Down
Loading