From ed5ae944947dd2691a8d65c5af951d6d149d7993 Mon Sep 17 00:00:00 2001
From: mattiasakesson <33224977+mattiasakesson@users.noreply.github.com>
Date: Thu, 12 Sep 2024 17:16:19 +0200
Subject: [PATCH] Feature/SK-811 | Example using Differential Privacy (#698)

---
 examples/mnist-pytorch-DPSGD/README.rst       |  56 ++++++
 examples/mnist-pytorch-DPSGD/client/data.py   |  99 ++++++++++
 examples/mnist-pytorch-DPSGD/client/fedn.yaml |  12 ++
 examples/mnist-pytorch-DPSGD/client/model.py  |  76 +++++++
 .../client/python_env.yaml                    |  15 ++
 examples/mnist-pytorch-DPSGD/client/train.py  | 185 ++++++++++++++++++
 .../mnist-pytorch-DPSGD/client/validate.py    |  55 ++++++
 7 files changed, 498 insertions(+)
 create mode 100644 examples/mnist-pytorch-DPSGD/README.rst
 create mode 100644 examples/mnist-pytorch-DPSGD/client/data.py
 create mode 100644 examples/mnist-pytorch-DPSGD/client/fedn.yaml
 create mode 100644 examples/mnist-pytorch-DPSGD/client/model.py
 create mode 100644 examples/mnist-pytorch-DPSGD/client/python_env.yaml
 create mode 100644 examples/mnist-pytorch-DPSGD/client/train.py
 create mode 100644 examples/mnist-pytorch-DPSGD/client/validate.py

diff --git a/examples/mnist-pytorch-DPSGD/README.rst b/examples/mnist-pytorch-DPSGD/README.rst
new file mode 100644
index 000000000..88220584a
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/README.rst
@@ -0,0 +1,56 @@
+FEDn Project: Federated Differential Privacy MNIST (Opacus + PyTorch)
+----------------------------------------------------------------------
+
+This example FEDn project demonstrates how Differential Privacy can be integrated to enhance the confidentiality of client data.
+It extends the baseline MNIST-PyTorch example with the Opacus framework, which adds differentially private training (DP-SGD) to PyTorch models.
+
+
+
+Prerequisites
+-------------
+
+- `Python >=3.8, <=3.12 `__
+- `A project in FEDn Studio `__
+
+Edit Differential Privacy budget
+--------------------------------
+- The **Differential Privacy budget** (`FINAL_EPSILON`, `DELTA`) is configured in the compute package at `client/train.py` (lines 35 and 39).
+- If `HARDLIMIT` (line 40) is set to `True`, the accumulated privacy spend will never exceed `FINAL_EPSILON`; once the budget is exhausted, the client stops producing model updates.
+- If `HARDLIMIT` is set to `False`, the accumulated privacy spend will end up close to `FINAL_EPSILON`, provided the server runs the number of rounds set by `GLOBAL_ROUNDS` (line 36).
+
+Creating the compute package and seed model
+-------------------------------------------
+
+Install fedn:
+
+.. code-block::
+
+   pip install fedn
+
+Clone this repository, then navigate to this directory:
+
+.. code-block::
+
+   git clone https://github.com/scaleoutsystems/fedn.git
+   cd fedn/examples/mnist-pytorch-DPSGD
+
+Create the compute package:
+
+.. code-block::
+
+   fedn package create --path client
+
+This creates a file 'package.tgz' in the project folder.
+
+Next, generate the seed model:
+
+.. code-block::
+
+   fedn run build --path client
+
+This will create a model file 'seed.npz' in the root of the project. This step will take a few minutes, depending on hardware and internet connection (it builds a virtualenv).
+
+Running the project on FEDn
+----------------------------
+
+To learn how to set up your FEDn Studio project and connect clients, take the quickstart tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html.
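
The budget settings described in the README come down to a small amount of arithmetic: `client/train.py` asks Opacus for a per-round target of `FINAL_EPSILON / GLOBAL_ROUNDS`, accumulates the spend reported by the privacy accountant in `epsilon.npy`, and, when `HARDLIMIT` is enabled, stops producing updates once the total reaches `FINAL_EPSILON`. The sketch below only mirrors those constants and the accumulation rule for illustration; the values marked as examples are hypothetical, and in the real client the per-round spend comes from `privacy_engine.get_epsilon(DELTA)`.

.. code-block:: python

   # Illustration of the budget bookkeeping in client/train.py (not part of the compute package).
   FINAL_EPSILON = 8.0   # overall target budget
   GLOBAL_ROUNDS = 4     # rounds the server is expected to run
   HARDLIMIT = False     # if True, stop producing updates once the budget is spent

   EPSILON = FINAL_EPSILON / GLOBAL_ROUNDS  # per-round target handed to Opacus

   tot_epsilon = 6.5  # example value; train.py reads this from epsilon.npy
   d_epsilon = 2.1    # example value; train.py gets this from privacy_engine.get_epsilon(DELTA)

   # Same accumulation rule as train.py (root-sum-square of per-round spends)
   tot_epsilon = (tot_epsilon**2 + d_epsilon**2) ** 0.5
   send_update = not (HARDLIMIT and tot_epsilon >= FINAL_EPSILON)
   print(f"accumulated epsilon ~ {tot_epsilon:.2f} of {FINAL_EPSILON}, send_update={send_update}")
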
diff --git a/examples/mnist-pytorch-DPSGD/client/data.py b/examples/mnist-pytorch-DPSGD/client/data.py
new file mode 100644
index 000000000..b7cc5b8e7
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/client/data.py
@@ -0,0 +1,99 @@
+import os
+from math import floor
+
+import torch
+import torchvision
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+abs_path = os.path.abspath(dir_path)
+
+
+def get_data(out_dir="data"):
+    # Make dir if necessary
+    if not os.path.exists(out_dir):
+        os.mkdir(out_dir)
+
+    # Only download if not already downloaded
+    if not os.path.exists(f"{out_dir}/train"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor(), train=True, download=True)
+    if not os.path.exists(f"{out_dir}/test"):
+        torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor(), train=False, download=True)
+
+
+def load_data(data_path, is_train=True):
+    """Load data from disk.
+
+    :param data_path: Path to data file.
+    :type data_path: str
+    :param is_train: Whether to load training or test data.
+    :type is_train: bool
+    :return: Tuple of data and labels.
+    :rtype: tuple
+    """
+    print("data_path is None: ", data_path is None)
+    if data_path is None:
+        data_path = os.environ.get("FEDN_DATA_PATH", abs_path + "/data/clients/1/mnist.pt")
+
+    print("data path: ", data_path)
+    data = torch.load(data_path)
+
+    if is_train:
+        X = data["x_train"]
+        y = data["y_train"]
+    else:
+        X = data["x_test"]
+        y = data["y_test"]
+
+    # Normalize
+    X = X / 255
+
+    return X, y
+
+
+def splitset(dataset, parts):
+    n = dataset.shape[0]
+    local_n = floor(n / parts)
+    result = []
+    for i in range(parts):
+        result.append(dataset[i * local_n : (i + 1) * local_n])
+    return result
+
+
+def split(out_dir="data"):
+    n_splits = int(os.environ.get("FEDN_NUM_DATA_SPLITS", 2))
+
+    # Make dir
+    if not os.path.exists(f"{out_dir}/clients"):
+        os.mkdir(f"{out_dir}/clients")
+
+    # Load and convert to dict
+    train_data = torchvision.datasets.MNIST(root=f"{out_dir}/train", transform=torchvision.transforms.ToTensor(), train=True)
+    test_data = torchvision.datasets.MNIST(root=f"{out_dir}/test", transform=torchvision.transforms.ToTensor(), train=False)
+    data = {
+        "x_train": splitset(train_data.data, n_splits),
+        "y_train": splitset(train_data.targets, n_splits),
+        "x_test": splitset(test_data.data, n_splits),
+        "y_test": splitset(test_data.targets, n_splits),
+    }
+
+    # Make splits
+    for i in range(n_splits):
+        subdir = f"{out_dir}/clients/{str(i+1)}"
+        if not os.path.exists(subdir):
+            os.mkdir(subdir)
+        torch.save(
+            {
+                "x_train": data["x_train"][i],
+                "y_train": data["y_train"][i],
+                "x_test": data["x_test"][i],
+                "y_test": data["y_test"][i],
+            },
+            f"{subdir}/mnist.pt",
+        )
+
+
+if __name__ == "__main__":
+    # Prepare data if not already done
+    if not os.path.exists(abs_path + "/data/clients/1"):
+        get_data()
+        split()
diff --git a/examples/mnist-pytorch-DPSGD/client/fedn.yaml b/examples/mnist-pytorch-DPSGD/client/fedn.yaml
new file mode 100644
index 000000000..30873488b
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/client/fedn.yaml
@@ -0,0 +1,12 @@
+python_env: python_env.yaml
+entry_points:
+  build:
+    command: python model.py
+  startup:
+    command: python data.py
+  train:
+    command: python train.py
+  validate:
+    command: python validate.py
+  predict:
+    command: python predict.py
\ No newline at end of file
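
The `startup` entry point above runs `data.py`, which downloads MNIST and writes one `mnist.pt` partition per client under `data/clients/<n>/`, while `load_data` resolves which partition to read from the `FEDN_DATA_PATH` environment variable. The following is a minimal local sketch of that flow, assuming it is run from the `client/` directory with the example's dependencies installed (the printed shape is only indicative):

.. code-block:: python

   # Sketch: generate the client partitions and inspect the first one locally.
   import os

   os.environ.setdefault("FEDN_NUM_DATA_SPLITS", "2")  # number of partitions data.py creates

   from data import get_data, load_data, split

   get_data()  # downloads MNIST into ./data/{train,test} if not already present
   split()     # writes ./data/clients/<n>/mnist.pt for each partition

   # load_data falls back to FEDN_DATA_PATH (or data/clients/1/mnist.pt) when no path is given
   os.environ["FEDN_DATA_PATH"] = os.path.abspath("data/clients/1/mnist.pt")
   x_train, y_train = load_data(None)
   print(x_train.shape, y_train.shape)  # e.g. torch.Size([30000, 28, 28]) for two splits
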
diff --git a/examples/mnist-pytorch-DPSGD/client/model.py b/examples/mnist-pytorch-DPSGD/client/model.py
new file mode 100644
index 000000000..6ad344770
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/client/model.py
@@ -0,0 +1,76 @@
+import collections
+
+import torch
+
+from fedn.utils.helpers.helpers import get_helper
+
+HELPER_MODULE = "numpyhelper"
+helper = get_helper(HELPER_MODULE)
+
+
+def compile_model():
+    """Compile the PyTorch model.
+
+    :return: The compiled model.
+    :rtype: torch.nn.Module
+    """
+
+    class Net(torch.nn.Module):
+        def __init__(self):
+            super(Net, self).__init__()
+            self.fc1 = torch.nn.Linear(784, 64)
+            self.fc2 = torch.nn.Linear(64, 32)
+            self.fc3 = torch.nn.Linear(32, 10)
+
+        def forward(self, x):
+            x = torch.nn.functional.relu(self.fc1(x.reshape(x.size(0), 784)))
+            x = torch.nn.functional.dropout(x, p=0.5, training=self.training)
+            x = torch.nn.functional.relu(self.fc2(x))
+            x = torch.nn.functional.log_softmax(self.fc3(x), dim=1)
+            return x
+
+    return Net()
+
+
+def save_parameters(model, out_path):
+    """Save model parameters to file.
+
+    :param model: The model to serialize.
+    :type model: torch.nn.Module
+    :param out_path: The path to save to.
+    :type out_path: str
+    """
+    parameters_np = [val.cpu().numpy() for _, val in model.state_dict().items()]
+    helper.save(parameters_np, out_path)
+
+
+def load_parameters(model_path):
+    """Load model parameters from file and populate model.
+
+    :param model_path: The path to load from.
+    :type model_path: str
+    :return: The loaded model.
+    :rtype: torch.nn.Module
+    """
+    model = compile_model()
+    parameters_np = helper.load(model_path)
+
+    params_dict = zip(model.state_dict().keys(), parameters_np)
+    state_dict = collections.OrderedDict({key: torch.tensor(x) for key, x in params_dict})
+    model.load_state_dict(state_dict, strict=True)
+    return model
+
+
+def init_seed(out_path="seed.npz"):
+    """Initialize seed model and save it to file.
+
+    :param out_path: The path to save the seed model to.
+    :type out_path: str
+    """
+    # Init and save
+    model = compile_model()
+    save_parameters(model, out_path)
+
+
+if __name__ == "__main__":
+    init_seed("../seed.npz")
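
`save_parameters` and `load_parameters` above are the only serialization points shared by the seed model, `train.py`, and `validate.py`, so a quick way to sanity-check the numpyhelper integration is a save/load round trip. A minimal sketch, assuming it is run from the `client/` directory with fedn and torch installed (the file name `roundtrip.npz` is arbitrary):

.. code-block:: python

   # Sketch: verify that model weights survive a save/load round trip via the numpyhelper.
   import torch

   from model import compile_model, load_parameters, save_parameters

   model = compile_model()
   save_parameters(model, "roundtrip.npz")      # state_dict values serialized as numpy arrays
   restored = load_parameters("roundtrip.npz")  # fresh model populated from the saved arrays

   for (name, a), (_, b) in zip(model.state_dict().items(), restored.state_dict().items()):
       assert torch.equal(a, b), f"mismatch in {name}"
   print("round trip OK")
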
diff --git a/examples/mnist-pytorch-DPSGD/client/python_env.yaml b/examples/mnist-pytorch-DPSGD/client/python_env.yaml
new file mode 100644
index 000000000..a8175e20c
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/client/python_env.yaml
@@ -0,0 +1,15 @@
+name: mnist-pytorch
+build_dependencies:
+  - pip
+  - setuptools
+  - wheel
+dependencies:
+  - fedn
+  - torch==2.4.1; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")
+  # PyTorch has deprecated macOS x86_64 builds, so pin older versions there
+  - torch==2.2.2; sys_platform == "darwin" and platform_machine == "x86_64"
+  - torchvision==0.19.1; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")
+  - torchvision==0.17.2; sys_platform == "darwin" and platform_machine == "x86_64"
+  - numpy==2.0.2; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")
+  - numpy==1.26.4; sys_platform == "darwin" and platform_machine == "x86_64"
+  - opacus
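
The platform markers in `python_env.yaml` pin older `torch`, `torchvision`, and `numpy` builds on Intel macOS and newer builds elsewhere. If it is unclear which pin applies on a given machine, the markers can be evaluated directly with the `packaging` library (a sketch; `packaging` is assumed to be available, as pip itself depends on it):

.. code-block:: python

   # Sketch: check which python_env.yaml environment markers match this machine.
   from packaging.markers import Marker

   markers = {
       "arm64 mac / windows / linux": '(sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "linux")',
       "intel mac": 'sys_platform == "darwin" and platform_machine == "x86_64"',
   }

   for name, expr in markers.items():
       print(f"{name}: {Marker(expr).evaluate()}")
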
diff --git a/examples/mnist-pytorch-DPSGD/client/train.py b/examples/mnist-pytorch-DPSGD/client/train.py
new file mode 100644
index 000000000..1210b7ad5
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/client/train.py
@@ -0,0 +1,185 @@
+import os
+import sys
+
+import numpy as np
+import torch
+from model import load_parameters, save_parameters
+from opacus import PrivacyEngine
+from opacus.utils.batch_memory_manager import BatchMemoryManager
+from torch.utils.data import Dataset
+
+from data import load_data
+from fedn.utils.helpers.helpers import save_metadata
+
+
+# Define a custom Dataset class
+class CustomDataset(Dataset):
+    def __init__(self, x_data, y_data):
+        self.x_data = x_data
+        self.y_data = y_data
+
+    def __len__(self):
+        return len(self.x_data)
+
+    def __getitem__(self, idx):
+        x_data = self.x_data[idx]
+        y_data = self.y_data[idx]
+        return x_data, y_data
+
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+MAX_GRAD_NORM = 1.2
+FINAL_EPSILON = 8.0
+GLOBAL_ROUNDS = 4
+EPOCHS = 5
+EPSILON = FINAL_EPSILON / GLOBAL_ROUNDS
+DELTA = 1e-5
+HARDLIMIT = False
+
+MAX_PHYSICAL_BATCH_SIZE = 32
+
+
+def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1, lr=0.01):
+    """Complete a model update.
+
+    Load model parameters from in_model_path (managed by the FEDn client),
+    perform a model update, and write updated parameters
+    to out_model_path (picked up by the FEDn client).
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_model_path: The path to save the output model to.
+    :type out_model_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    :param batch_size: The batch size to use.
+    :type batch_size: int
+    :param epochs: The number of epochs to train.
+    :type epochs: int
+    :param lr: The learning rate to use.
+    :type lr: float
+    """
+    # Load data
+    print("data_path: ", data_path)
+    x_train, y_train = load_data(data_path)
+    trainset = CustomDataset(x_train, y_train)
+    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2)
+
+    # Load parameters and initialize model
+    model = load_parameters(in_model_path)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    model = model.to(device)
+
+    # Load the privacy spend accumulated in previous rounds, if any
+    if os.path.isfile("epsilon.npy"):
+        tot_epsilon = np.load("epsilon.npy")
+        print("load consumed epsilon: ", tot_epsilon)
+    else:
+        print("initiate tot_epsilon")
+        tot_epsilon = 0.0
+
+    # Train
+    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
+    privacy_engine = PrivacyEngine()
+
+    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
+        module=model,
+        optimizer=optimizer,
+        data_loader=train_loader,
+        epochs=EPOCHS,
+        target_epsilon=EPSILON,
+        target_delta=DELTA,
+        max_grad_norm=MAX_GRAD_NORM,
+    )
+
+    print(f"Using sigma={optimizer.noise_multiplier} and C={MAX_GRAD_NORM}")
+
+    for epoch in range(EPOCHS):
+        train_dp(model, train_loader, optimizer, epoch + 1, device, privacy_engine)
+
+    d_epsilon = privacy_engine.get_epsilon(DELTA)
+    print("epsilon spent: ", d_epsilon)
+    tot_epsilon = np.sqrt(tot_epsilon**2 + d_epsilon**2)
+    print("saving tot_epsilon: ", tot_epsilon)
+    np.save("epsilon.npy", tot_epsilon)
+
+    if HARDLIMIT and tot_epsilon >= FINAL_EPSILON:
+        print("DP budget exceeded: the differential privacy budget has been exhausted; no model update is sent, to preserve the privacy guarantees.")
+    else:
+        # Metadata needed for aggregation server side
+        metadata = {
+            # num_examples are mandatory
+            "num_examples": len(x_train),
+            "batch_size": batch_size,
+            "epochs": epochs,
+            "lr": lr,
+        }
+
+        # Save JSON metadata file (mandatory)
+        save_metadata(metadata, out_model_path)
+
+        # Save model update (mandatory)
+        save_parameters(model, out_model_path)
+
+
+def accuracy(preds, labels):
+    return (preds == labels).mean()
+
+
+def train_dp(model, train_loader, optimizer, epoch, device, privacy_engine):
+    model.train()
+    criterion = torch.nn.NLLLoss()  # the model outputs log-probabilities (log_softmax)
+
+    losses = []
+    top1_acc = []
+
+    with BatchMemoryManager(
+        data_loader=train_loader,
+        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
+        optimizer=optimizer,
+    ) as memory_safe_data_loader:
+        for i, (images, target) in enumerate(memory_safe_data_loader):
+            optimizer.zero_grad()
+            images = images.to(device)
+            target = target.to(device)
+
+            # compute output
+            output = model(images)
+            loss = criterion(output, target)
+
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
+
+            # measure accuracy and record loss
+            acc = accuracy(preds, labels)
+
+            losses.append(loss.item())
+            top1_acc.append(acc)
+
+            loss.backward()
+            optimizer.step()
+
+            if (i + 1) % 200 == 0:
+                epsilon = privacy_engine.get_epsilon(DELTA)
+                print(
+                    f"\tTrain Epoch: {epoch} \t"
+                    f"Loss: {np.mean(losses):.6f} "
+                    f"Acc@1: {np.mean(top1_acc) * 100:.6f} "
+                    f"(ε = {epsilon:.2f}, δ = {DELTA})"
+                )
+
+
+if __name__ == "__main__":
+    train(sys.argv[1], sys.argv[2])
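
The `train` and `validate` entry points are normally invoked by the FEDn client, but they can be exercised locally against the seed model to check the Opacus integration before the package is deployed. A sketch, assuming the seed model exists at `../seed.npz` (from `fedn run build`), the data partitions have been generated by `data.py`, and it is run from the `client/` directory (the output file names are arbitrary):

.. code-block:: python

   # Sketch: run one local DP training round and a validation pass outside FEDn.
   import subprocess

   import numpy as np

   subprocess.run(["python", "train.py", "../seed.npz", "model_update.npz"], check=True)
   subprocess.run(["python", "validate.py", "model_update.npz", "validation.json"], check=True)

   # train.py persists the accumulated privacy spend between rounds in epsilon.npy
   print("accumulated epsilon:", float(np.load("epsilon.npy")))
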
diff --git a/examples/mnist-pytorch-DPSGD/client/validate.py b/examples/mnist-pytorch-DPSGD/client/validate.py
new file mode 100644
index 000000000..09328181f
--- /dev/null
+++ b/examples/mnist-pytorch-DPSGD/client/validate.py
@@ -0,0 +1,55 @@
+import os
+import sys
+
+import torch
+from model import load_parameters
+
+from data import load_data
+from fedn.utils.helpers.helpers import save_metrics
+
+dir_path = os.path.dirname(os.path.realpath(__file__))
+sys.path.append(os.path.abspath(dir_path))
+
+
+def validate(in_model_path, out_json_path, data_path=None):
+    """Validate model.
+
+    :param in_model_path: The path to the input model.
+    :type in_model_path: str
+    :param out_json_path: The path to save the output JSON to.
+    :type out_json_path: str
+    :param data_path: The path to the data file.
+    :type data_path: str
+    """
+    # Load data
+    x_train, y_train = load_data(data_path)
+    x_test, y_test = load_data(data_path, is_train=False)
+
+    # Load model
+    model = load_parameters(in_model_path)
+    model.eval()
+
+    # Evaluate
+    criterion = torch.nn.NLLLoss()
+    with torch.no_grad():
+        train_out = model(x_train)
+        training_loss = criterion(train_out, y_train)
+        training_accuracy = torch.sum(torch.argmax(train_out, dim=1) == y_train) / len(train_out)
+        test_out = model(x_test)
+        test_loss = criterion(test_out, y_test)
+        test_accuracy = torch.sum(torch.argmax(test_out, dim=1) == y_test) / len(test_out)
+
+    # JSON schema
+    report = {
+        "training_loss": training_loss.item(),
+        "training_accuracy": training_accuracy.item(),
+        "test_loss": test_loss.item(),
+        "test_accuracy": test_accuracy.item(),
+    }
+
+    # Save JSON
+    save_metrics(report, out_json_path)
+
+
+if __name__ == "__main__":
+    validate(sys.argv[1], sys.argv[2])