From 2a39e9093d2a6de61aa67c37dca3b9d658a6e0bc Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 22 Apr 2024 14:11:08 -0700 Subject: [PATCH] Benchmark --- python/lbann/contrib/data/molecule_dataset.py | 45 ++++++++++++++----- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/python/lbann/contrib/data/molecule_dataset.py b/python/lbann/contrib/data/molecule_dataset.py index 7e987c3adc1..7f01aa7e565 100644 --- a/python/lbann/contrib/data/molecule_dataset.py +++ b/python/lbann/contrib/data/molecule_dataset.py @@ -56,7 +56,8 @@ def __init__(self, file_or_files: Union[str, List[str]]): np.memmap(f + '.offsets', dtype=np.uint64) for f in file_or_files ] self.samples = [o.shape[0] for o in self.offsets] - self.cs = np.cumsum(np.array(self.samples, dtype=np.uint64), dtype=np.uint64) + self.cs = np.cumsum(np.array(self.samples, dtype=np.uint64), + dtype=np.uint64) self.total_samples = sum(self.samples) # Clean memmapped files so that the object can be pickled @@ -222,17 +223,41 @@ def trim_and_pad(self, sample, random: bool): if __name__ == '__main__': import sys - if len(sys.argv) != 4: - print('USAGE: dataloader_mlm.py ' - '') + if len(sys.argv) < 4: + print('USAGE: dataloader_mlm.py ' + ' [other dataset files]') exit(1) - dataset = ChemTextDataset( - fname=[sys.argv[1]], - vocab=sys.argv[2], - seqlen=64, - tokenizer_type=ChemTokenType[sys.argv[3].upper()]) - print('Dataset samples:', len(dataset)) + _, vocab, toktype, *files = sys.argv + + dataset = ChemTextDataset(fname=files, + vocab=vocab, + seqlen=64, + tokenizer_type=ChemTokenType[toktype.upper()]) + + # Test 1: Arbitrary sample retrieval + num_samples = len(dataset) + print('Dataset samples:', num_samples) print('Dataset sample -1:') print( dataset.tokenizer.decode(dataset[-1].sample[:dataset.sequence_length])) + + # Test 2: Retrieval bandwidth + import time + try: + from tqdm import trange + except (ModuleNotFoundError, ImportError): + trange = range + + # Warmup + for _ in range(10): + samp = np.random.randint(0, num_samples - 1) + _ = dataset[samp] + + SAMPLES = 5000 + start = time.time() + for i in trange(SAMPLES): + samp = np.random.randint(0, num_samples - 1) + _ = dataset[samp] + end = time.time() + print('Samples per second:', SAMPLES / (end - start))