add sensitivity tests for correctness
Algue-Rythme committed Mar 4, 2024
1 parent 48c3eba commit 23b86f2
Showing 8 changed files with 284 additions and 46 deletions.
6 changes: 5 additions & 1 deletion deel/lipdp/dynamic.py
@@ -65,7 +65,11 @@ def on_train_begin(self, logs=None):
self._assign_dp_dict(last_layer)

def get_gradloss(self):
"""Computes the norm of gradient of the loss with respect to the model's output."""
"""Computes the norm of gradient of the loss with respect to the model's output.
Warning: this method is unsafe from a privacy perspective, as the true gradient bound is computed.
It is meant to be used with privacy-preserving methods only, such as the ones implemented in this module.
"""
batch = next(iter(self.ds_train.take(1)))
imgs, labels = batch
self.model.loss.reduction = tf.keras.losses.Reduction.NONE
86 changes: 53 additions & 33 deletions deel/lipdp/sensitivity.py
@@ -28,7 +28,7 @@
from deel.lipdp.model import get_eps_delta


def get_max_epochs(epsilon_max, model, epochs_max=1024, safe=True):
def get_max_epochs(epsilon_max, model, epochs_max=1024, safe=True, atol=1e-2):
"""Return the maximum number of epochs to reach a given epsilon_max value.
The computation of (epsilon, delta) is slow since it involves solving a minimization problem
@@ -47,6 +47,7 @@ def get_max_epochs(epsilon_max, model, epochs_max=1024, safe=True):
If None, the dichotomy search is used to find the upper bound.
safe: If True, the dichotomy search returns the largest number of epochs such that epsilon <= epsilon_max.
Otherwise, it returns the smallest number of epochs such that epsilon >= epsilon_max.
atol: Absolute tolerance on numerical inaccuracy in the search; larger discrepancies abort with an error. Defaults to 1e-2.
Returns:
The maximum number of epochs to reach epsilon_max. It may be zero if epsilon_max is too small.
@@ -57,7 +58,6 @@ def fun(epoch):
if epoch == 0:
epsilon = 0
else:
niter = (epoch + 1) * steps_per_epoch
epsilon, _ = get_eps_delta(model, epoch)
return epsilon

@@ -83,46 +83,66 @@ def fun(epoch):
f"epoch bounds = {epochs_min, epochs_max} and epsilon = {epsilon} at epoch {epoch}"
)

return epochs_min if safe else epochs_max
if safe:
last_epsilon = fun(epochs_min)
error = last_epsilon - epsilon_max
if error <= 0:
return epochs_min
elif error < atol:
# This branch should never be taken if fun is a non-decreasing function of the number of epochs.
# fun is mathematically non-decreasing, but numerical inaccuracy can lead to this case.
print(f"Numerical inaccuracy with error {error:.7f} in the dichotomy search: using a conservative value.")
return epochs_min - 1
else:
assert False, f"Numerical inaccuracy with error {error:.7f}>{atol:.3f} in the dichotomy search."

return epochs_max
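
For context, a minimal usage sketch of this search (a sketch only: `model` is assumed to be an already-compiled lip-dp model, `ds_train` its training dataset, and the `epsilon_max` value is illustrative, not a recommendation):

```python
from deel.lipdp.sensitivity import get_max_epochs

epsilon_max = 3.0  # illustrative privacy budget
# Largest number of epochs whose accounted epsilon stays below the budget.
num_epochs = get_max_epochs(epsilon_max, model, epochs_max=1024, safe=True)
model.fit(ds_train, epochs=num_epochs)
```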


def gradient_norm_check(K_list, model, examples):
    """Verifies that the values of per-sample gradients on a layer never exceed a bound
    determined by our theoretical work.
    Args:
        K_list: The list of theoretical upper bounds we have identified for each layer and want to
            put to the test.
        model: The model containing the layers we are interested in. Layers must only have one trainable variable.
            Model must have a given input_shape or has to be built.
        examples: Relevant examples. Inputting the whole training set might prove very costly to check element-wise Jacobians.
    Returns:
        Boolean value. True corresponds to the upper bound being validated.
    """
    image_axes = tuple(range(1, examples.ndim))
    example_norms = tf.math.reduce_euclidean_norm(examples, axis=image_axes)
    X_max = tf.reduce_max(example_norms).numpy()
    upper_bounds = np.array(K_list) * X_max
    assert len(model.layers) == len(upper_bounds)
    for layer, bound in zip(model.layers, upper_bounds):
        assert check_layer_gradient_norm(bound, layer, examples)


def check_layer_gradient_norm(S, layer, examples):
    l_model = tf.keras.Sequential([layer])
    if not l_model.trainable_variables:
        print("Not a trainable layer, assuming gradient norm < |x|")
    assert len(l_model.trainable_variables) == 1
    with tf.GradientTape() as tape:
        y_pred = l_model(examples, training=True)
    trainable_vars = l_model.trainable_variables[0]
    jacobian = tape.jacobian(y_pred, trainable_vars)
    jacobian = tf.reshape(
        jacobian,
        (y_pred.shape[0], -1, np.prod(trainable_vars.shape)),
        name="Reshaped_Gradient",
    )
    J_sigma = tf.linalg.svd(jacobian, full_matrices=False, compute_uv=False, name=None)
    J_2norm = tf.reduce_max(J_sigma, axis=-1)
    J_2norm = tf.reduce_max(J_2norm).numpy()
    return J_2norm < S


def gradient_norm_check(upper_bounds, model, examples):
    """Verifies that the values of per-sample gradients on a layer never exceed a value
    determined by the theoretical work.
    Args:
        upper_bounds: maximum gradient bounds for each layer (dictionary of 'variable name': 'bound' pairs).
        model: The model containing the layers we are interested in. Layers must only have one trainable variable.
            Model must have a given input_shape or has to be built.
        examples: a batch of examples to test on.
    Returns:
        None. An AssertionError is raised if one of the upper bounds is violated.
    """
    activations = examples
    var_seen = set()
    for layer in model.layers:
        post_activations = layer(activations, training=True)
        assert len(layer.trainable_variables) < 2
        if len(layer.trainable_variables) == 1:
            train_var = layer.trainable_variables[0]
            var_name = train_var.name
            var_seen.add(var_name)
            bound = upper_bounds[var_name]
            assert check_layer_gradient_norm(bound, layer, activations)
        activations = post_activations
    for var_name in upper_bounds:
        assert var_name in var_seen


def check_layer_gradient_norm(S, layer, activations):
    """Checks that the per-sample spectral norm of d(layer output)/d(parameters) never exceeds S (up to a small tolerance)."""
    trainable_vars = layer.trainable_variables[0]
    with tf.GradientTape() as tape:
        y_pred = layer(activations, training=True)
        flat_pred = tf.reshape(y_pred, (y_pred.shape[0], -1))
    jacobians = tape.jacobian(flat_pred, trainable_vars)
    assert jacobians.shape[0] == activations.shape[0]
    assert jacobians.shape[1] == np.prod(y_pred.shape[1:])
    assert np.prod(jacobians.shape[2:]) == np.prod(trainable_vars.shape)
    jacobians = tf.reshape(
        jacobians,
        (y_pred.shape[0], -1, np.prod(trainable_vars.shape)),
        name="Reshaped_Gradient",
    )
    # Per-sample spectral norm of the Jacobian of the flattened outputs w.r.t. the layer parameters.
    J_sigma = tf.linalg.svd(jacobians, full_matrices=False, compute_uv=False, name=None)
    J_2norm = tf.reduce_max(J_sigma, axis=-1)
    J_2norm = tf.reduce_max(J_2norm).numpy()
    atol = 1e-5
    return J_2norm < S + atol
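
As an illustration (not part of this commit), the per-sample spectral-norm check above can be exercised on a plain bias-free Keras `Dense` layer, for which the tight bound is known in closed form; the batch and layer sizes are arbitrary and `check_layer_gradient_norm` is assumed to be in scope:

```python
import tensorflow as tf

# For y = x W (no bias), the per-sample Jacobian dy/dW has spectral norm ||x||_2,
# so the largest input norm in the batch is a valid per-sample bound.
layer = tf.keras.layers.Dense(4, use_bias=False)
x = tf.random.normal((8, 3))
layer.build(x.shape)
bound = float(tf.reduce_max(tf.norm(x, axis=1)))
assert check_layer_gradient_norm(bound, layer, x)
```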
2 changes: 1 addition & 1 deletion experiments/MNIST/main.py
@@ -68,7 +68,7 @@ def default_cfg_mnist():
cfg.loss = "TauCategoricalCrossentropy"
cfg.log_wandb = "disabled"
cfg.noise_multiplier = 1.5
cfg.noisify_strategy = "local"
cfg.noisify_strategy = "per-layer"
cfg.optimizer = "Adam"
cfg.opt_iterations = None
cfg.save = False
1 change: 1 addition & 0 deletions requirements_dev.txt
@@ -2,6 +2,7 @@ setuptools
pre-commit
ml_collections
absl-py
pytest
tox
black
pytest
12 changes: 8 additions & 4 deletions tests/README.md
@@ -1,15 +1,19 @@
# Tests

To run all the tests, simply type
To run all the tests, start from the root and simply type

```bash
cd tests/
pytest .
```

To run a specific test, type

```bash
python test_<name>.py
cd tests/
python <name1>_test.py <name2>Test.test_<name3>
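# for example: python model_test.py ModelTest.test_forward_cnn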
```

where `<name>` is the name of the test file.
where `<name1>`, `<name2>` and `<name3>` are the names of the test file, the test class and the test function, respectively.

By default, tests are not run on GPU to enforce reproducibility.
116 changes: 114 additions & 2 deletions tests/model_test.py
@@ -24,10 +24,122 @@
from absl.testing import absltest
from absl.testing import parameterized

from deel.lipdp.dynamic import AdaptiveQuantileClipping
from deel.lipdp.layers import *
from deel.lipdp.model import DP_Sequential, DPParameters
from deel.lipdp.pipeline import bound_normalize, load_and_prepare_images_data
from deel.lipdp.losses import DP_TauCategoricalCrossentropy


class ModelTest(parameterized.TestCase):
def test_create_model(self):
pass

def _get_mnist_cnn(self):
ds_train, _, dataset_metadata = load_and_prepare_images_data(
"mnist",
batch_size=64,
colorspace="grayscale",
drop_remainder=True,
bound_fct=bound_normalize(),
)

norm_max = 1.0
all_layers = [
DP_BoundedInput(input_shape=(28, 28, 1), upper_bound=norm_max),
DP_SpectralConv2D(
filters=16,
kernel_size=3,
kernel_initializer="orthogonal",
strides=1,
use_bias=False,
),
DP_AddBias(norm_max=norm_max),
DP_GroupSort(2),
DP_ScaledL2NormPooling2D(pool_size=2, strides=2),
DP_LayerCentering(),
DP_Flatten(),
DP_SpectralDense(1024, use_bias=False, kernel_initializer="orthogonal"),
DP_AddBias(norm_max=norm_max),
DP_SpectralDense(10, use_bias=False, kernel_initializer="orthogonal"),
DP_AddBias(norm_max=norm_max),
DP_ClipGradient(
clip_value=2. ** 0.5,
mode="dynamic",
),
]

dp_parameters = DPParameters(
noisify_strategy='per-layer',
noise_multiplier=2.2,
delta=1e-5,
)

model = DP_Sequential(
all_layers,
dp_parameters=dp_parameters,
dataset_metadata=dataset_metadata,
)

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)
loss = DP_TauCategoricalCrossentropy(
tau=1., reduction=tf.keras.losses.Reduction.SUM_OVER_BATCH_SIZE
)
model.compile(optimizer=optimizer, loss=loss, metrics=["accuracy"])

return model, ds_train

def test_forward_cnn(self):
model, ds_train = self._get_mnist_cnn()
batch_x, _ = ds_train.take(1).as_numpy_iterator().next()
logits = model(batch_x)
assert logits.shape == (len(batch_x), 10)

def test_create_residuals(self):
input_shape = (32, 32, 3)

patch_size = 4
seq_len = (input_shape[0] // patch_size) * (
input_shape[1] // patch_size
)
multiplier = 1
mlp_seq_dim = multiplier * seq_len

to_add = [
DP_Permute((2, 1)),
DP_QuickSpectralDense(
units=mlp_seq_dim, use_bias=False, kernel_initializer="orthogonal"
),
]
to_add.append(DP_GroupSort(2))
to_add.append(DP_LayerCentering())
to_add += [
DP_QuickSpectralDense(
units=seq_len, use_bias=False, kernel_initializer="orthogonal"
),
DP_Permute((2, 1)),
]

blocks = make_residuals("1-lip-add", to_add)
input_bound = 1.0 # placeholder
for layer in blocks[:-1]:
input_bound = layer.propagate_inputs(input_bound)
assert len(input_bound) == 2
last = blocks[-1].propagate_inputs(input_bound)
assert isinstance(last, float)

def test_adaptive_clipping(self):
num_steps_test_case = 3
model, ds_train = self._get_mnist_cnn()
ds_train = ds_train.take(num_steps_test_case)
adaptive = AdaptiveQuantileClipping(
ds_train=ds_train,
patience=1,
noise_multiplier=2.2,
quantile=0.9,
learning_rate=1.0,
)
adaptive.set_model(model)
callbacks = [adaptive]
model.fit(ds_train, epochs=2, callbacks=callbacks, steps_per_epoch=num_steps_test_case)


if __name__ == "__main__":
10 changes: 7 additions & 3 deletions tests/pipeline_test.py
@@ -27,9 +27,11 @@
from deel.lipdp.pipeline import bound_clip_value
from deel.lipdp.pipeline import bound_normalize
from deel.lipdp.pipeline import load_and_prepare_images_data
from deel.lipdp.pipeline import default_delta_value


class PipelineTest(parameterized.TestCase):

def test_cifar10_common(self):
batch_size = 64
max_norm = 5e-2
@@ -41,7 +43,7 @@ def test_cifar10_common(self):
colorspace=colorspace,
drop_remainder=True, # accounting assumes fixed batch size.
bound_fct=bound_clip_value(max_norm),
multiplicity=0, # no multiplicity for mnist
multiplicity=0, # no multiplicity for cifar10
)

self.assertEqual(dataset_metadata.nb_classes, 10)
@@ -62,6 +64,8 @@ def test_cifar10_common(self):
self.assertEqual(batch_sizes[-1], batch_size)
self.assertEqual(len(batch_sizes), 50_000 // batch_size)
self.assertEqual(dataset_metadata.nb_steps_per_epochs, len(batch_sizes))
delta_heuristic = default_delta_value(dataset_metadata)
self.assertLessEqual(dataset_metadata.nb_samples_train, 1./delta_heuristic)
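
As an aside (not part of the diff): the assertion above encodes the common rule of thumb that delta should be at most the inverse of the number of training samples, i.e. delta <= 1/50 000 = 2e-5 for the CIFAR-10 training split used in this test.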

@parameterized.parameters(("RGB",), ("grayscale",), ("HSV",))
def test_cifar10_colorspace(self, colorspace):
@@ -74,7 +78,7 @@ def test_cifar10_colorspace(self, colorspace):
colorspace=colorspace,
drop_remainder=True, # accounting assumes fixed batch size.
bound_fct=bound_clip_value(max_norm),
multiplicity=0, # no multiplicity for mnist
multiplicity=0, # no multiplicity for cifar10
)

batch = next(iter(ds_test))
@@ -98,7 +102,7 @@ def test_cifar10_augmult(self, multiplicity: int):
colorspace=colorspace,
drop_remainder=True, # accounting assumes fixed batch size.
bound_fct=bound_clip_value(max_norm),
multiplicity=multiplicity, # no multiplicity for mnist
multiplicity=multiplicity,
)

self.assertEqual(dataset_metadata.batch_size, batch_size)