Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added option to set path to imagenet data #2192

Open
wants to merge 11 commits into
base: develop
Choose a base branch
from
6 changes: 5 additions & 1 deletion applications/vision/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
parser.add_argument(
'--num-classes', action='store', default=1000, type=int,
help='number of ImageNet classes (default: 1000)', metavar='NUM')
parser.add_argument(
'--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
lbann.contrib.args.add_optimizer_arguments(parser)
args = parser.parse_args()

Expand Down Expand Up @@ -64,7 +67,8 @@
opt = lbann.contrib.args.create_optimizer(args)

# Setup data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
data_path=args.data_path)

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size)
Expand Down
49 changes: 29 additions & 20 deletions applications/vision/data/imagenet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import lbann
import lbann.contrib.launcher

def make_data_reader(num_classes=1000, small_testing=False):
def make_data_reader(num_classes=1000, small_testing=False, data_path=None):

# Load Protobuf message from file
current_dir = os.path.dirname(os.path.realpath(__file__))
Expand All @@ -18,27 +18,36 @@ def make_data_reader(num_classes=1000, small_testing=False):
google.protobuf.text_format.Merge(f.read(), message)
message = message.data_reader

# Paths to ImageNet data
# Note: Paths are only known for some compute centers
compute_center = lbann.contrib.launcher.compute_center()
if compute_center == 'lc':
from lbann.contrib.lc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train',
num_classes=num_classes)
train_label_file = imagenet_labels(data_set='train',
num_classes=num_classes)
test_data_dir = imagenet_dir(data_set='val',
num_classes=num_classes)
test_label_file = imagenet_labels(data_set='val',

if data_path is not None:
print("Setting up data reader")
train_data_dir = os.path.join(data_path, 'train')
test_data_dir = os.path.join(data_path, 'val')
train_label_file = os.path.join(data_path, 'labels/train.txt')
test_label_file = os.path.join(data_path, 'labels/val.txt')

elif lbann.contrib.launcher.compute_center() in ['lc', 'nersc']:
# Paths to ImageNet data
# Note: Paths are only known for some compute centers
compute_center = lbann.contrib.launcher.compute_center()
if compute_center == 'lc':
from lbann.contrib.lc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train',
num_classes=num_classes)
elif compute_center == 'nersc':
from lbann.contrib.nersc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train')
train_label_file = imagenet_labels(data_set='train')
test_data_dir = imagenet_dir(data_set='val')
test_label_file = imagenet_labels(data_set='val')
train_label_file = imagenet_labels(data_set='train',
num_classes=num_classes)
test_data_dir = imagenet_dir(data_set='val',
num_classes=num_classes)
test_label_file = imagenet_labels(data_set='val',
num_classes=num_classes)
elif compute_center == 'nersc':
from lbann.contrib.nersc.paths import imagenet_dir, imagenet_labels
train_data_dir = imagenet_dir(data_set='train')
train_label_file = imagenet_labels(data_set='train')
test_data_dir = imagenet_dir(data_set='val')
test_label_file = imagenet_labels(data_set='val')
else:
raise RuntimeError(f'ImageNet data paths are unknown for current compute center ({compute_center})')
raise RuntimeError(f'ImageNet data paths are unknown for current compute center ({compute_center}). Set "--data-path" to the location of your dataset.')

# Check that data paths are accessible
if not os.path.isdir(train_data_dir):
Expand Down
8 changes: 6 additions & 2 deletions applications/vision/densenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,8 @@ def get_args():
parser.add_argument("--print-matrix-summary", dest="print_matrix_summary",
action="store_const",
const=True, default=False)
parser.add_argument('--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
args = parser.parse_args()
return args

Expand All @@ -438,7 +440,7 @@ def set_up_experiment(args,
labels):
algo = lbann.BatchedIterativeOptimizer("sgd", epoch_count=args.num_epochs)


# Set up objective function
cross_entropy = lbann.CrossEntropy([probs, labels])
layers = list(lbann.traverse_layer_graph(input_))
Expand Down Expand Up @@ -472,7 +474,9 @@ def set_up_experiment(args,
callbacks=callbacks)

# Set up data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes, small_testing=True)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
small_testing=True,
data_path=args.data_path)

percentage = 0.001 * 2 * (args.mini_batch_size / 16) * 2

Expand Down
6 changes: 5 additions & 1 deletion applications/vision/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
parser.add_argument(
'--random-seed', action='store', default=0, type=int,
help='random seed for LBANN RNGs', metavar='NUM')
parser.add_argument(
'--data-path', action='store', default=None, type=str,
help='Path to top-level imagenet directory. default: None')
lbann.contrib.args.add_optimizer_arguments(parser, default_learning_rate=0.1)
args = parser.parse_args()

Expand Down Expand Up @@ -145,7 +148,8 @@
opt = lbann.contrib.args.create_optimizer(args)

# Setup data reader
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes)
data_reader = data.imagenet.make_data_reader(num_classes=args.num_classes,
data_path=args.data_path)

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=args.mini_batch_size, random_seed=args.random_seed)
Expand Down
13 changes: 13 additions & 0 deletions docs/data_ingestion.rst
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
.. role:: bash(code)
:language: bash
.. role:: python(code)
:language: python

Data Ingestion
==============

Expand Down Expand Up @@ -27,6 +32,14 @@ Legacy Data Readers
Some of the legacy data readers are the ``MNIST``, ``ImageNet``, and
``CIFAR10`` data readers.

.. note:: The ImageNet data reader uses a path that may not be known
to all compute centers. If the dataset is not found, pass
:bash:`--data-path` to :code:`resnet.py`, :code:`alexnet.py`, or
:code:`densenet.py` to point to the top-level directory of the
dataset. The dataset directory must contain
:code:`labels/train.txt`, :code:`labels/val.txt`,
:code:`train/`, and :code:`val/`.


"New" Data Readers
-------------------
Expand Down