From c9d643b88e47ea85291426c8495008a497aa185b Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 18 May 2023 12:48:45 -0700 Subject: [PATCH] Add an LBANN Python unit test wrapper and utilities (#2264) * Add an LBANN Python unit test wrapper and utilities * Add capability for extra metrics and callbacks * Add a simple test that uses the new interface * Relax bounds of NASNet further * Improve support for multidimensional tensors * Add single-tensor test data reader * Improve readability of pytest assertions * Fix weighted sum operation and add test * Make weighted sum in-place-capable, ensure backprop runs all the way through in testing --- .../single_tensor_data_reader.py | 57 ++++ ci_test/common_python/test_util.py | 322 ++++++++++++++++++ ci_test/common_python/tools.py | 6 +- .../test_integration_nasnet.py | 2 +- .../unit_tests/test_unit_layer_addconstant.py | 22 ++ .../unit_tests/test_unit_layer_weightedsum.py | 71 ++++ .../lbann/layers/transform/weighted_sum.hpp | 17 +- 7 files changed, 490 insertions(+), 7 deletions(-) create mode 100644 ci_test/common_python/single_tensor_data_reader.py create mode 100644 ci_test/common_python/test_util.py create mode 100644 ci_test/unit_tests/test_unit_layer_addconstant.py create mode 100644 ci_test/unit_tests/test_unit_layer_weightedsum.py diff --git a/ci_test/common_python/single_tensor_data_reader.py b/ci_test/common_python/single_tensor_data_reader.py new file mode 100644 index 00000000000..93169c98b17 --- /dev/null +++ b/ci_test/common_python/single_tensor_data_reader.py @@ -0,0 +1,57 @@ +################################################################################ +# Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LLNL/LBANN. +# +# Licensed under the Apache License, Version 2.0 (the "Licensee"); you +# may not use this file except in compliance with the License. You may +# obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the license. +# +################################################################################ +""" +Simple data reader that opens one file with one tensor. Used for unit testing. +""" +import numpy as np + +# Lazy-load tensor +tensor = None + + +def lazy_load(): + # This file operates under the assumption that the working directory is set + # to a specific experiment. 
+ global tensor + if tensor is None: + tensor = np.load('data.npy') + assert len(tensor.shape) == 2 + + +def get_sample(idx): + lazy_load() + return tensor[idx] + + +def num_samples(): + lazy_load() + return tensor.shape[0] + + +def sample_dims(): + lazy_load() + return (tensor.shape[1], ) diff --git a/ci_test/common_python/test_util.py b/ci_test/common_python/test_util.py new file mode 100644 index 00000000000..70042b41ac2 --- /dev/null +++ b/ci_test/common_python/test_util.py @@ -0,0 +1,322 @@ +################################################################################ +# Copyright (c) 2014-2023, Lawrence Livermore National Security, LLC. +# Produced at the Lawrence Livermore National Laboratory. +# Written by the LBANN Research Team (B. Van Essen, et al.) listed in +# the CONTRIBUTORS file. +# +# LLNL-CODE-697807. +# All rights reserved. +# +# This file is part of LBANN: Livermore Big Artificial Neural Network +# Toolkit. For details, see http://software.llnl.gov/LBANN or +# https://github.com/LLNL/LBANN. +# +# Licensed under the Apache License, Version 2.0 (the "Licensee"); you +# may not use this file except in compliance with the License. You may +# obtain a copy of the License at: +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the license. +# +################################################################################ +import lbann +from dataclasses import dataclass, field +import functools +import inspect +from typing import Any, Callable, List, Optional, Tuple, Union +import numpy as np +import os +import re +import tools +import single_tensor_data_reader + + +def lbann_test(check_gradients=False, **decorator_kwargs): + """ + A decorator that wraps an LBANN-enabled model unit test. + Use it before a function named ``test_*`` to run it automatically in pytest. + The unit test in the wrapped function must return a ``test_util.ModelTester`` + object, which contains all the necessary information to test the model (e.g., + model, input/reference tensors). + + The decorator wraps the test with the appropriate setup phase, data reading, + callbacks, and metrics so that the test functions properly. 
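+
+    A minimal sketch of the intended usage (assumes ``import numpy as np`` and
+    ``import test_util``; the tested layer is illustrative)::
+
+        @test_util.lbann_test(check_gradients=True)
+        def test_addconstant():
+            x = np.random.rand(2, 16).astype(np.float32)  # [batch, features]
+            tester = test_util.ModelTester()
+            x_lbann = tester.inputs(x)
+            ref = tester.make_reference(x + 1)
+            y = lbann.AddConstant(x_lbann, constant=1)
+            tester.set_loss(lbann.MeanSquaredError(y, ref))
+            return tester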
+ """ + + def internal_tester(f): + + @functools.wraps(f) + def wrapped(*args, **kwargs): + # Call model constructor + tester = f(*args, **kwargs) + + # Check return value + if not isinstance(tester, ModelTester): + raise ValueError('LBANN test must return a ModelTester object') + if tester.loss is None: + raise ValueError( + 'LBANN test did not define a loss function, ' + 'use ``ModelTester.set_loss`` or ``set_loss_function``.') + if tester.input_tensor is None: + raise ValueError('LBANN test did not define an input, call ' + '``ModelTester.inputs`` or ``inputs_like``.') + if (tester.reference_tensor is not None + and tester.reference_tensor.shape[0] != + tester.input_tensor.shape[0]): + raise ValueError( + 'Input and reference tensors in LBANN test ' + 'must match in the first (minibatch) dimension') + full_graph = lbann.traverse_layer_graph(tester.loss) + callbacks = [] + callbacks.append( + lbann.CallbackCheckMetric(metric='test', + lower_bound=0, + upper_bound=tester.tolerance, + error_on_failure=True, + execution_modes='test')) + if check_gradients: + callbacks.append( + lbann.CallbackCheckGradients(error_on_failure=True)) + callbacks.extend(tester.extra_callbacks) + + metrics = [lbann.Metric(tester.loss, name='test')] + metrics.extend(tester.extra_metrics) + model = lbann.Model(epochs=0, + layers=full_graph, + metrics=metrics, + callbacks=callbacks) + + # Get file + file = inspect.getfile(f) + + def setup_func(lbann, weekly): + # Get minibatch size from tensor + mini_batch_size = tester.input_tensor.shape[0] + + # Save combined input/reference data to file + work_dir = _get_work_dir(file) + os.makedirs(work_dir, exist_ok=True) + if tester.reference_tensor is not None: + flat_inp = tester.input_tensor.reshape(mini_batch_size, -1) + flat_ref = tester.reference_tensor.reshape( + mini_batch_size, -1) + np.save(os.path.join(work_dir, 'data.npy'), + np.concatenate((flat_inp, flat_ref), axis=1)) + else: + np.save(os.path.join(work_dir, 'data.npy'), + tester.input_tensor.reshape(mini_batch_size, -1)) + + # Setup data reader + data_reader = lbann.reader_pb2.DataReader() + data_reader.reader.extend([ + tools.create_python_data_reader( + lbann, single_tensor_data_reader.__file__, + 'get_sample', 'num_samples', 'sample_dims', 'train'), + tools.create_python_data_reader( + lbann, single_tensor_data_reader.__file__, + 'get_sample', 'num_samples', 'sample_dims', 'test') + ]) + + trainer = lbann.Trainer(mini_batch_size) + optimizer = lbann.NoOptimizer() + return trainer, model, data_reader, optimizer, None # Don't request any specific number of nodes + + test = tools.create_tests(setup_func, file, **decorator_kwargs)[0] + cluster = kwargs.get('cluster', 'unset') + weekly = kwargs.get('weekly', False) + test(cluster, weekly, False, **decorator_kwargs) + + return wrapped + + return internal_tester + + +@dataclass +class ModelTester: + """ + An object that is constructed within an ``lbann_test``-wrapped unit test. 
+ """ + + # Input tensor (required for test to construct) + input_tensor: Optional[Any] = None + + reference: Optional[lbann.Layer] = None #: Reference LBANN node (optional) + reference_tensor: Optional[ + Any] = None #: Optional reference tensor to compare with + + loss: Optional[lbann.Layer] = None # Optional loss test + tolerance: float = 0.0 #: Tolerance value for loss test + + # Optional additional metrics to use in test + extra_metrics: List[lbann.Metric] = field(default_factory=list) + + # Optional additional callbacks to use in test + extra_callbacks: List[lbann.Callback] = field(default_factory=list) + + def inputs(self, tensor: Any) -> lbann.Layer: + """ + Marks the given tensor as an input of the tested LBANN model, and + returns a matching LBANN Input node (or a Slice/Reshape thereof). + + :param tensor: The input NumPy array to use. + :return: An LBANN layer object that will serve as the input. + """ + self.input_tensor = tensor + inp = lbann.Input(data_field='samples') + return slice_to_tensors(inp, tensor) + + def inputs_like(self, *tensors) -> List[lbann.Layer]: + """ + Marks the given tensors as input of the tested LBANN model, and + returns a list of matching LBANN Slice nodes, potentially reshaped to + be like the input tensors. + + :param tensors: The input NumPy arrays to use. + :return: A list of LBANN layer objects that will serve as the inputs. + """ + minibatch_size = tensors[0].shape[0] # Assume the first dimension + + # All tensors concatenated on the non-batch dimension + all_tensors_combined = np.concatenate( + [t.reshape(minibatch_size, -1) for t in tensors], axis=1) + + self.input_tensor = all_tensors_combined + x = lbann.Input(data_field='samples') + return slice_to_tensors(x, *tensors) + + def make_reference(self, ref: Any) -> lbann.Input: + """ + Marks the given tensor as a reference output of the tested LBANN model, + and returns a matching LBANN node. + + :param ref: The reference NumPy array to use. + :return: An LBANN layer object that will serve as the reference. + """ + # The reference is the second part of the input "samples" + refnode = lbann.Input(data_field='samples') + if self.input_tensor is None: + raise ValueError('Please call ``inputs`` or ``inputs_like`` prior ' + 'to calling ``make_reference`` for correctness.') + mbsize = self.input_tensor.shape[0] + + # Obtain reference + refnode = lbann.Reshape(lbann.Identity( + lbann.Slice( + refnode, + slice_points=[ + numel(self.input_tensor) // mbsize, + (numel(self.input_tensor) + numel(ref)) // mbsize + ], + )), + dims=ref.shape[1:]) + + # Store reference + self.reference = refnode + self.reference_tensor = ref + return self.reference + + def set_loss_function(self, + func: Callable[[lbann.Layer, lbann.Layer], + lbann.Layer], + output: lbann.Layer, + tolerance=None): + """ + Sets a loss function and the LBANN test output to be measured for the + test. + This assumes that the first argument has two parameters (e.g., + ``MeanSquaredError``), where the first argument will be used for the + LBANN output and the second will be used for the reference. + + :param func: The loss function. + :param output: The LBANN model output to use. + :param tolerance: Optional tolerance to set for the test. If ``None``, + the default tolerance of ``8*eps*mean(reference)`` + will be used. + """ + return self.set_loss(func(output, self.reference), tolerance) + + def set_loss(self, + loss: lbann.Layer, + tolerance: Optional[float] = None) -> None: + """ + Sets an LBANN node to be measured for the test. 
+ + :param loss: The LBANN graph node to use for the test. + :param tolerance: Optional tolerance to set for the test. If ``None``, + the default tolerance of ``8*eps*mean(reference)`` + will be used. + """ + # Set loss node + self.loss = loss + + # Set tolerance + if tolerance is not None: + self.tolerance = tolerance + else: + if self.reference_tensor is None: + raise ValueError( + 'Cannot set tolerance on loss function automatically ' + 'without a reference tensor. Either set tolerance ' + 'explicitly or call ``ModelTester.make_reference``.') + # Default tolerance + self.tolerance = abs(8 * np.mean(self.reference_tensor) * + np.finfo(self.reference_tensor.dtype).eps) + + +def slice_to_tensors(x: lbann.Layer, *tensors) -> List[lbann.Layer]: + """ + Slices an LBANN layer into multiple tensors that match the dimensions of + the given numpy arrays. + """ + slice_points = [0] + offset = 0 + for tensor in tensors: + offset += numel(tensor) // tensor.shape[0] + + slice_points.append(offset) + lslice = lbann.Slice(x, slice_points=slice_points) + return [ + lbann.Reshape(_ensure_bp(t, lbann.Identity(lslice)), dims=t.shape[1:]) + for t in tensors + ] + + +def numel(tensor) -> int: + """ + Returns the number of elements in a NumPy array, PyTorch array, or integer. + """ + if isinstance(tensor, int): # Integer + return tensor + elif hasattr(tensor, 'numel'): # PyTorch array + return tensor.numel() + else: # NumPy array + return tensor.size + + +# Mimics the other tester's determination of working directory +def _get_work_dir(test_file: str) -> str: + test_fname = os.path.realpath(test_file) + # Create test name by removing '.py' from file name + test_fname = os.path.splitext(os.path.basename(test_fname))[0] + if not re.match('^test_.', test_fname): + # Make sure test name is prefixed with 'test_' + test_fname = 'test_' + test_fname + return os.path.join(os.path.dirname(test_file), 'experiments', test_fname) + + +# Ensures that backpropagation would be run through the entire model +def _ensure_bp(tensor: Any, node: lbann.Layer) -> lbann.Sum: + # Note: Sum with a weights layer so that gradient checking will + # verify that error signals are correct. + x_weights = lbann.Weights(initializer=lbann.ConstantInitializer(value=0.0)) + return lbann.Sum( + node, + lbann.WeightsLayer( + weights=x_weights, + dims=[numel(tensor) // tensor.shape[0]], + )) diff --git a/ci_test/common_python/tools.py b/ci_test/common_python/tools.py index f433c686b28..e53ed98fe42 100644 --- a/ci_test/common_python/tools.py +++ b/ci_test/common_python/tools.py @@ -672,14 +672,14 @@ def assert_success(return_code, error_file_name): if return_code != 0: error_line = get_error_line(error_file_name) raise AssertionError( - 'return_code={rc}\n{el}\nSee {efn}'.format( + '{el}\nreturn_code={rc}\nSee {efn}'.format( rc=return_code, el=error_line, efn=error_file_name)) def assert_failure(return_code, expected_error, error_file_name): if return_code == 0: raise AssertionError( - 'return_code={rc}\nSuccess when expecting failure.\nSee {efn}'.format( + 'Success when expecting failure. return_code={rc}\nSee {efn}'.format( rc=return_code, efn=error_file_name)) with open(error_file_name, 'r') as error_file: for line in error_file: @@ -689,7 +689,7 @@ def assert_failure(return_code, expected_error, error_file_name): # but we didn't get the expected error. 
actual_error = get_error_line(error_file_name) raise AssertionError( - 'return_code={rc}\nFailed with error different than expected.\nactual_error={ae}\nexpected_error={ee}\nSee {efn}'.format( + 'Failed with error different than expected: actual_error={ae}, expected_error={ee}\nreturn_code={rc}\nSee {efn}'.format( rc=return_code, ae=actual_error, ee=expected_error, efn=error_file_name)) diff --git a/ci_test/integration_tests/test_integration_nasnet.py b/ci_test/integration_tests/test_integration_nasnet.py index 1214f19c277..fa1968124aa 100644 --- a/ci_test/integration_tests/test_integration_nasnet.py +++ b/ci_test/integration_tests/test_integration_nasnet.py @@ -60,7 +60,7 @@ 'num_nodes': 2, 'num_epochs': 4, 'mini_batch_size': 64, - 'expected_train_accuracy_range': (48, 65), # BVE relaxed lower bound from 50 9/21/22 + 'expected_train_accuracy_range': (47.9, 65), # BVE relaxed lower bound from 50 9/21/22, TBN relaxed further to 47.9 5/16/23 'expected_test_accuracy_range': (49, 65), # BVE relaxed lower bound from 50 9/22/22 'expected_mini_batch_times': { 'lassen': 0.075, diff --git a/ci_test/unit_tests/test_unit_layer_addconstant.py b/ci_test/unit_tests/test_unit_layer_addconstant.py new file mode 100644 index 00000000000..464aef99d81 --- /dev/null +++ b/ci_test/unit_tests/test_unit_layer_addconstant.py @@ -0,0 +1,22 @@ +import lbann +import numpy as np +import test_util +import pytest + + +@pytest.mark.parametrize('constant', [0, 1]) +@test_util.lbann_test(check_gradients=True) +def test_simple(constant): + np.random.seed(20230515) + # Two samples of 2x3 tensors + x = np.random.rand(2, 2, 3).astype(np.float32) + ref = x + constant + + tester = test_util.ModelTester() + x = tester.inputs(x) + ref = tester.make_reference(ref) + + # Test layer + y = lbann.AddConstant(x, constant=constant) + tester.set_loss(lbann.MeanSquaredError(y, ref)) + return tester diff --git a/ci_test/unit_tests/test_unit_layer_weightedsum.py b/ci_test/unit_tests/test_unit_layer_weightedsum.py new file mode 100644 index 00000000000..393727edc8b --- /dev/null +++ b/ci_test/unit_tests/test_unit_layer_weightedsum.py @@ -0,0 +1,71 @@ +import lbann +import numpy as np +import test_util +import pytest + + +@test_util.lbann_test(check_gradients=True) +def test_weightedsum_twoinputs(): + # Prepare reference output + np.random.seed(20230516) + x1 = np.random.rand(20, 20) + x2 = np.random.rand(20, 20) + reference_numpy = 0.25 * x1 + 0.5 * x2 + + tester = test_util.ModelTester() + + x1, x2 = tester.inputs_like(x1, x2) + reference = tester.make_reference(reference_numpy) + + # Test layer + y = lbann.WeightedSum(x1, x2, scaling_factors=[0.25, 0.5]) + + # Set test loss + tester.set_loss(lbann.MeanSquaredError(y, reference)) + return tester + + +@pytest.mark.parametrize('inputs', [3, 5]) +@test_util.lbann_test(check_gradients=True) +def test_weightedsum_n_inputs(inputs): + # Prepare reference output + np.random.seed(20230516) + x = [np.random.rand(3, 20) for _ in range(inputs)] + factors = [np.random.rand() for _ in range(inputs)] + reference_numpy = sum(f * xi for f, xi in zip(factors, x)) + + tester = test_util.ModelTester() + + x = tester.inputs_like(*x) + reference = tester.make_reference(reference_numpy) + + # Test layer + y = lbann.WeightedSum(*x, scaling_factors=factors) + + # Set test loss + tester.set_loss(lbann.MeanSquaredError(y, reference)) + tester.tolerance *= inputs # Make it more tolerant towards more inputs + return tester + + +@pytest.mark.parametrize('dims', [1, 3]) +@test_util.lbann_test(check_gradients=True) +def 
test_weightedsum_oneinput(dims): + # Prepare reference output + np.random.seed(20230516) + shape = [3] + [2] * dims + x = np.random.rand(*shape).astype(np.float32) + a = 0.4 + ref = a * x + + tester = test_util.ModelTester() + + x = tester.inputs(x) + reference = tester.make_reference(ref) + + # Test layer + y = lbann.WeightedSum(x, scaling_factors=[a]) + + # Set test loss + tester.set_loss(lbann.MeanSquaredError(y, reference)) + return tester diff --git a/include/lbann/layers/transform/weighted_sum.hpp b/include/lbann/layers/transform/weighted_sum.hpp index 7a0c72ecd72..ee286ae65c0 100644 --- a/include/lbann/layers/transform/weighted_sum.hpp +++ b/include/lbann/layers/transform/weighted_sum.hpp @@ -135,8 +135,13 @@ class weighted_sum_layer : public data_type_layer void fp_compute() override { auto& output = this->get_activations(); - El::Zero(output); - for (int i = 0; i < this->get_num_parents(); ++i) { + + // Special case for the first input so that in-place operation works + if (!this->m_runs_inplace) + El::Copy(this->get_prev_activations(0), output); + + El::Scale(m_scaling_factors[0], output); + for (int i = 1; i < this->get_num_parents(); ++i) { El::Axpy(m_scaling_factors[i], this->get_prev_activations(i), output); } } @@ -144,11 +149,17 @@ class weighted_sum_layer : public data_type_layer void bp_compute() override { const auto& gradient_wrt_output = this->get_prev_error_signals(); - for (int i = 0; i < this->get_num_parents(); ++i) { + + for (int i = 1; i < this->get_num_parents(); ++i) { auto& gradient_wrt_input = this->get_error_signals(i); El::Zero(gradient_wrt_input); El::Axpy(m_scaling_factors[i], gradient_wrt_output, gradient_wrt_input); } + + // Special case for the first input so that in-place operation works + if (!this->m_runs_inplace) + El::Copy(gradient_wrt_output, this->get_error_signals(0)); + El::Scale(m_scaling_factors[0], this->get_error_signals(0)); } };
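
Example usage of the new wrapper (an illustrative sketch, not an additional file
in this patch): the snippet below combines the pieces introduced above, namely the
``lbann_test`` decorator, ``ModelTester.inputs_like``, ``make_reference``, and
``set_loss_function``, plus the ``extra_callbacks`` hook. The file name, the layer
under test (``lbann.Sum``), and the ``lbann.CallbackTimer`` extra are illustrative
choices rather than part of this change.

    # test_unit_layer_example.py (hypothetical file, for illustration only)
    import lbann
    import numpy as np
    import test_util


    @test_util.lbann_test(check_gradients=True)
    def test_sum_twoinputs():
        np.random.seed(20230518)
        # Two input tensors with a minibatch size of 4
        x1 = np.random.rand(4, 16).astype(np.float32)
        x2 = np.random.rand(4, 16).astype(np.float32)

        tester = test_util.ModelTester()
        x1_lbann, x2_lbann = tester.inputs_like(x1, x2)
        reference = tester.make_reference(x1 + x2)

        # Layer under test: elementwise sum of the two inputs
        y = lbann.Sum(x1_lbann, x2_lbann)

        # Builds MeanSquaredError(y, reference) and derives the default
        # tolerance of 8 * eps * mean(reference)
        tester.set_loss_function(lbann.MeanSquaredError, y)

        # Extra metrics/callbacks are forwarded to the generated lbann.Model
        tester.extra_callbacks.append(lbann.CallbackTimer())
        return tester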