Potential bug in lbann.Scale #2433

Closed
jvwilliams23 opened this issue Mar 12, 2024 · 8 comments

@jvwilliams23
Contributor

Hi,

After turning various things on and off in my code, I seem to have found that lbann.Scale causes problems, or at least what I would call unexpected behaviour, even when using x = lbann.Scale(x, constant=1). I also tested this on the LeNet example to get an MWE to share.

Output from lenet.py (unmodified)

--------------------------------------------------------------------------------------------
[0] Epoch : stats formated [tr/v/te/to] iter/epoch = [844/94/157/0]
            Global MB = [  64/  64/  64/   0]  Global last MB = [  48/  48/  16/   0]
--------------------------------------------------------------------------------------------
model0 (instance 0) training epoch 0 objective function : 0.24518
model0 (instance 0) training epoch 0 accuracy : 92.1185%
model0 (instance 0) training epoch 0 run time : 2.32064s
model0 (instance 0) training epoch 0 mini-batch time statistics : 0.00273541s mean, 0.0024633s median, 0.189072s max, 0.00242463s min, 0.00649489s stdev
model0 (instance 0) validation objective function : 0.119927
model0 (instance 0) validation accuracy : 96.2667%
model0 (instance 0) validation run time : 0.0899092s
model0 (instance 0) validation mini-batch time statistics : 0.000939553s mean, 0.000930496s median, 0.00127683s max, 0.000921108s min, 4.43708e-05s stdev
--------------------------------------------------------------------------------------------
[1] Epoch : stats formated [tr/v/te/to] iter/epoch = [844/94/157/0]
            Global MB = [  64/  64/  64/   0]  Global last MB = [  48/  48/  16/   0]
--------------------------------------------------------------------------------------------
model0 (instance 0) training epoch 1 objective function : 0.0687048
model0 (instance 0) training epoch 1 accuracy : 97.8815%
model0 (instance 0) training epoch 1 run time : 2.13083s
model0 (instance 0) training epoch 1 mini-batch time statistics : 0.00251189s mean, 0.0025022s median, 0.00585491s max, 0.00246355s min, 0.00011753s stdev
model0 (instance 0) validation objective function : 0.0686692
model0 (instance 0) validation accuracy : 97.7833%
model0 (instance 0) validation run time : 0.0891519s
model0 (instance 0) validation mini-batch time statistics : 0.000932478s mean, 0.000927124s median, 0.00124584s max, 0.000917194s min, 3.45411e-05s stdev
--------------------------------------------------------------------------------------------
[2] Epoch : stats formated [tr/v/te/to] iter/epoch = [844/94/157/0]
            Global MB = [  64/  64/  64/   0]  Global last MB = [  48/  48/  16/   0]
--------------------------------------------------------------------------------------------
model0 (instance 0) training epoch 2 objective function : 0.0514064
model0 (instance 0) training epoch 2 accuracy : 98.35%
model0 (instance 0) training epoch 2 run time : 2.12826s
model0 (instance 0) training epoch 2 mini-batch time statistics : 0.00250887s mean, 0.00250032s median, 0.00584337s max, 0.00246867s min, 0.00011689s stdev
model0 (instance 0) validation objective function : 0.0544848
model0 (instance 0) validation accuracy : 98.4833%
model0 (instance 0) validation run time : 0.0893365s
model0 (instance 0) validation mini-batch time statistics : 0.000934504s mean, 0.000928747s median, 0.00124921s max, 0.000918249s min, 3.4701e-05s stdev

Output from modified lenet.py with lbann.Scale

--------------------------------------------------------------------------------------------
[0] Epoch : stats formated [tr/v/te/to] iter/epoch = [844/94/157/0]
            Global MB = [  64/  64/  64/   0]  Global last MB = [  48/  48/  16/   0]
--------------------------------------------------------------------------------------------
model0 (instance 0) training epoch 0 objective function : 2.3036
model0 (instance 0) training epoch 0 accuracy : 11.1389%
model0 (instance 0) training epoch 0 run time : 2.55206s
model0 (instance 0) training epoch 0 mini-batch time statistics : 0.00300853s mean, 0.00274318s median, 0.194486s max, 0.00270395s min, 0.00666651s stdev
model0 (instance 0) validation objective function : 2.30216
model0 (instance 0) validation accuracy : 11.3333%
model0 (instance 0) validation run time : 0.0994602s
model0 (instance 0) validation mini-batch time statistics : 0.00104114s mean, 0.00103314s median, 0.001348s max, 0.00102378s min, 4.18474e-05s stdev
--------------------------------------------------------------------------------------------
[1] Epoch : stats formated [tr/v/te/to] iter/epoch = [844/94/157/0]
            Global MB = [  64/  64/  64/   0]  Global last MB = [  48/  48/  16/   0]
--------------------------------------------------------------------------------------------
model0 (instance 0) training epoch 1 objective function : 2.30162
model0 (instance 0) training epoch 1 accuracy : 11.2259%
model0 (instance 0) training epoch 1 run time : 2.33295s
model0 (instance 0) training epoch 1 mini-batch time statistics : 0.00275055s mean, 0.00274311s median, 0.00602017s max, 0.00270121s min, 0.00011506s stdev
model0 (instance 0) validation objective function : 2.30204
model0 (instance 0) validation accuracy : 11.3333%
model0 (instance 0) validation run time : 0.0992328s
model0 (instance 0) validation mini-batch time statistics : 0.00103966s mean, 0.0010324s median, 0.00138047s max, 0.00102392s min, 4.01673e-05s stdev
--------------------------------------------------------------------------------------------
[2] Epoch : stats formated [tr/v/te/to] iter/epoch = [844/94/157/0]
            Global MB = [  64/  64/  64/   0]  Global last MB = [  48/  48/  16/   0]
--------------------------------------------------------------------------------------------
model0 (instance 0) training epoch 2 objective function : 2.30162
model0 (instance 0) training epoch 2 accuracy : 11.2259%
model0 (instance 0) training epoch 2 run time : 2.35187s
model0 (instance 0) training epoch 2 mini-batch time statistics : 0.00277249s mean, 0.00274219s median, 0.00616076s max, 0.00269846s min, 0.000158618s stdev
model0 (instance 0) validation objective function : 2.3013
model0 (instance 0) validation accuracy : 11.3333%
model0 (instance 0) validation run time : 0.0992319s
model0 (instance 0) validation mini-batch time statistics : 0.00103986s mean, 0.0010322s median, 0.0013561s max, 0.00102321s min, 3.66658e-05s stdev

Here is the modified code, with some annotations:

import argparse
import lbann
import data.mnist
import lbann.contrib.args
import lbann.contrib.launcher

# ----------------------------------
# Command-line arguments
# ----------------------------------

desc = ('Train LeNet on MNIST data using LBANN.')
parser = argparse.ArgumentParser(description=desc)
lbann.contrib.args.add_scheduler_arguments(parser, 'lbann_lenet')
args = parser.parse_args()

# ----------------------------------
# Construct layer graph
# ----------------------------------

# Input data
images = lbann.Input(data_field='samples')
labels = lbann.Input(data_field='labels')

# LeNet
x = lbann.Convolution(images,
                      num_dims = 2,
                      out_channels = 6,
                      groups = 1,
                      kernel_size = 5,
                      stride = 1,
                      dilation = 1,
                      has_bias = True)
x = lbann.Relu(x)
x = lbann.Scale(x, constant=1) # with this line active, the loss stays at ~2.3
x = lbann.Pooling(x,
                  num_dims = 2,
                  pool_dims_i = 2,
                  pool_strides_i = 2,
                  pool_mode = "max")
x = lbann.Convolution(x,
                      num_dims = 2,
                      out_channels = 16,
                      groups = 1,
                      kernel_size = 5,
                      stride = 1,
                      dilation = 1,
                      has_bias = True)
x = lbann.Relu(x)
# x = lbann.Scale(x, constant=1) # uncommenting this line gives loss ~ 2.3
x = lbann.Pooling(x,
                  num_dims = 2,
                  pool_dims_i = 2,
                  pool_strides_i = 2,
                  pool_mode = "max")
x = lbann.FullyConnected(x, num_neurons = 120, has_bias = True)
x = lbann.Relu(x)
# x = lbann.Scale(x, constant=1) # loss converges fine when this line is uncommented
x = lbann.FullyConnected(x, num_neurons = 84, has_bias = True)
x = lbann.Relu(x)
# x = lbann.Scale(x, constant=1) # uncommenting this line gives loss ~ 2.3
x = lbann.FullyConnected(x, num_neurons = 10, has_bias = True)
probs = lbann.Softmax(x)

# Loss function and accuracy
loss = lbann.CrossEntropy(probs, labels)
acc = lbann.CategoricalAccuracy(probs, labels)

# ----------------------------------
# Setup experiment
# ----------------------------------

# Setup model
mini_batch_size = 64
num_epochs = 20
model = lbann.Model(num_epochs,
                    layers=lbann.traverse_layer_graph([images, labels]),
                    objective_function=loss,
                    metrics=[lbann.Metric(acc, name='accuracy', unit='%')],
                    callbacks=[lbann.CallbackPrintModelDescription(),
                               lbann.CallbackPrint(),
                               lbann.CallbackTimer()])

# Setup optimizer
opt = lbann.SGD(learn_rate=0.01, momentum=0.9)

# Setup data reader
data_reader = data.mnist.make_data_reader()

# Setup trainer
trainer = lbann.Trainer(mini_batch_size=mini_batch_size)

# ----------------------------------
# Run experiment
# ----------------------------------
#kwargs = lbann.contrib.args.get_scheduler_kwargs(args)
#lbann.contrib.launcher.run(trainer, model, data_reader, opt,
#                           job_name=args.job_name,
#                           **kwargs)


kwargs = {
    "nodes": 1,
    "scheduler" : "openmpi",
    "time_limit" : 30,
}

lbann.run(trainer, model, data_reader, opt,
            job_name=args.job_name,
            **kwargs)

At first, when using a constant != 1, I thought the scaling was perhaps applied only in the forward pass and not back-propagated, so the gradients were wrong. However, I would expect scaling the activations by 1 to have no effect at all on training, so I have no idea where the issue lies (or is this behaviour expected?).
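For reference, a correct scale layer computes y = c * x in the forward pass and dL/dx = c * dL/dy in the backward pass, so c = 1 should be an exact no-op in both directions. The plain-Python sketch below is independent of LBANN (the toy loss and all function names are illustrative assumptions) and checks that expectation with central finite differences:

```python
"""Reference behaviour of a scale layer: y = c * x, dL/dx = c * dL/dy.

Plain Python, not LBANN code; the toy loss is only for checking gradients.
"""


def scale_forward(x, c):
    # Forward pass: multiply every activation by the constant c.
    return [c * v for v in x]


def scale_backward(dy, c):
    # Backward pass: the Jacobian is c * I, so the gradient is scaled by c too.
    return [c * g for g in dy]


def toy_loss(x, c):
    # Toy objective L = 0.5 * sum(y_i^2) applied to the scaled activations.
    return 0.5 * sum(v * v for v in scale_forward(x, c))


def analytic_grad(x, c):
    # dL/dy = y, then back-propagate through the scale layer.
    y = scale_forward(x, c)
    return scale_backward(y, c)


def numeric_grad(x, c, eps=1e-6):
    # Central finite differences, one coordinate at a time.
    grad = []
    for i in range(len(x)):
        xp = list(x)
        xm = list(x)
        xp[i] += eps
        xm[i] -= eps
        grad.append((toy_loss(xp, c) - toy_loss(xm, c)) / (2 * eps))
    return grad


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0]
    for c in (1.0, 3.0):
        print(c, analytic_grad(x, c), numeric_grad(x, c))
```

With c = 1 the analytic gradient equals the unscaled one, so any change in training behaviour points at the implementation rather than the maths.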

Best wishes,
Josh

@jvwilliams23
Contributor Author

jvwilliams23 commented Mar 12, 2024

Judging from the log file, the gradients seem to be vanishing, since the loss barely changes. I should also mention that including lbann.Scale(x, constant=1) in my StyleGAN implementation makes the loss grow without bound (exploding gradients). That is what prompted my suggestion, at the end of the original post, that the problem lies in back-propagation (although it may be something else entirely).
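The plateau in the modified run is consistent with a network that has stopped learning: the logged loss of ~2.30 matches the cross-entropy of a uniform prediction over the 10 MNIST classes, and the ~11% accuracy is roughly what always predicting the most common digit would score. A quick check of the first number:

```python
import math

num_classes = 10  # MNIST digits 0-9
uniform_prob = 1.0 / num_classes
# Cross-entropy of a prediction that puts equal probability on every class.
chance_loss = -math.log(uniform_prob)
print(f"{chance_loss:.4f}")  # 2.3026
```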

@benson31
Collaborator

Reproduced. Thanks for pointing this out. We'll get this fixed very soon. In the meantime, if memory isn't an issue, you can use lbann.Multiply with an lbann.Constant layer -- this worked in my testing, as inefficient as it may be.

benson31 self-assigned this Mar 12, 2024
benson31 added the bug label Mar 12, 2024
@benson31
Collaborator

I've traced this back to an issue with in-place operator layers. Still working on a fix. In the meantime, you can set LBANN_NO_INPLACE=1 in your runtime environment. This will have some overhead as well.
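The thread does not spell out the exact defect, but a generic example shows how an in-place layer can corrupt gradients even with constant=1. In the hypothetical sketch below (plain Python, not LBANN's code path), a ReLU computes its backward pass from its saved output, and a scale layer reuses that output buffer and then writes its input gradient into it during backprop:

```python
"""Hypothetical illustration of an in-place aliasing hazard in backprop.

Not LBANN code: a ReLU whose backward pass reads its saved *output*, followed
by a scale layer that reuses (aliases) that output buffer.
"""


def relu_forward(x):
    return [max(v, 0.0) for v in x]


def relu_backward(y, dy):
    # Gradient flows only where the saved ReLU output is positive.
    return [g if v > 0.0 else 0.0 for v, g in zip(y, dy)]


def scale_inplace_forward(buf, c):
    # In-place: the scale output *is* the ReLU output buffer.
    for i in range(len(buf)):
        buf[i] *= c
    return buf


def scale_inplace_backward(buf, dy, c):
    # Faulty in-place backward: writes dL/dx into the shared buffer,
    # clobbering the activations that relu_backward still needs.
    for i in range(len(buf)):
        buf[i] = c * dy[i]
    return buf


def input_grad(x, dy, c, inplace):
    y = relu_forward(x)
    if inplace:
        s = scale_inplace_forward(y, c)  # s and y are the same list
        dx_scale = scale_inplace_backward(s, dy, c)
    else:
        dx_scale = [c * g for g in dy]  # separate gradient buffer
    return relu_backward(y, dx_scale)


if __name__ == "__main__":
    x = [-1.0, 2.0, 3.0]
    dy = [0.5, -0.5, 0.25]  # upstream gradient w.r.t. the scale output
    print(input_grad(x, dy, 1.0, inplace=False))  # [0.0, -0.5, 0.25]
    print(input_grad(x, dy, 1.0, inplace=True))   # [0.5, 0.0, 0.25]
```

The ReLU mask ends up being computed from gradient values instead of activations, which is the kind of silent error that disabling in-place execution avoids.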

@jvwilliams23
Contributor Author

OK, thanks. I added os.environ['LBANN_NO_INPLACE'] = "1" to my Python script, and it seems to be working well.
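Setting os.environ inside the driver script should only take effect if the job is started afterwards as a child process of that script. That is an assumption about how lbann.run launches the job; the thread only reports that it worked. For multi-node runs, whether remote ranks see the variable may also depend on the MPI launcher's environment forwarding. A standard-library sketch of the inheritance:

```python
import os
import subprocess
import sys

# Set the variable in the driver script's own environment.
os.environ["LBANN_NO_INPLACE"] = "1"

# A child process started afterwards inherits os.environ by default.
child = subprocess.run(
    [sys.executable, "-c", "import os; print(os.environ.get('LBANN_NO_INPLACE'))"],
    capture_output=True,
    text=True,
    check=True,
)
print(child.stdout.strip())  # 1
```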

@jvwilliams23
Contributor Author

Quick question: how much extra overhead do you expect LBANN_NO_INPLACE=1 to add? I am seeing significantly lower performance than with PyTorch (approx. 3.5 days versus 1 day for 1M iterations of my GAN).

@benson31
Collaborator

benson31 commented May 1, 2024

@jvwilliams23 I think this issue should have been fixed by #2442. Can you confirm?

To answer your other question about LBANN_NO_INPLACE=1, I'm not sure if there's a performance overhead (@tbennun?), but there's a memory overhead as LBANN will allocate separate outputs for all layers in this mode.
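To give a rough sense of scale for the memory overhead, the sketch below counts one extra output buffer per listed layer. It assumes FP32 activations, 28x28 MNIST inputs with no convolution padding in the lenet.py above, and that the four ReLUs are the layers that would otherwise run in place. These are all illustrative assumptions, and any extra gradient buffers are ignored:

```python
def extra_output_bytes(shapes, mini_batch_size, bytes_per_value=4):
    """Bytes for one extra output buffer per listed layer (FP32 by default)."""
    total = 0
    for shape in shapes:
        per_sample = 1
        for dim in shape:
            per_sample *= dim
        total += per_sample * mini_batch_size * bytes_per_value
    return total


# Assumed output shapes (channels, height, width) of the four ReLUs in lenet.py:
# conv1 -> 6x24x24, conv2 -> 16x8x8, then the 120- and 84-neuron FC layers.
relu_shapes = [(6, 24, 24), (16, 8, 8), (120,), (84,)]
extra = extra_output_bytes(relu_shapes, mini_batch_size=64)
print(extra, f"{extra / 2**20:.2f} MiB")  # 1199104 1.14 MiB
```

For LeNet this is negligible, but the cost grows linearly with activation size and mini-batch size, so it matters more for large models such as a StyleGAN.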

@jvwilliams23
Contributor Author

Will check the fix tomorrow morning (UK time).

A quick comment on the performance issue I mentioned: I don't think it is caused by disabling in-place operations. I profiled my job; see discussion #2438.

@jvwilliams23
Contributor Author

Working now on lenet.py and also my StyleGAN implementation. Thanks @benson31 and @fiedorowicz1!
