Bugfix/SK-811 | DP-SGD #769

Merged: 10 commits, Dec 14, 2024
Changes from 6 commits
19 changes: 14 additions & 5 deletions examples/mnist-pytorch-DPSGD/README.rst
FEDn Project: Federated Differential Privacy MNIST (Opacus + PyTorch)
----------------------------------------------------------------------

This example FEDn Project demonstrates how Differential Privacy can be integrated to enhance the confidentiality of client data.
We have expanded our baseline MNIST-PyTorch example by incorporating the Opacus framework, which is specifically designed for PyTorch models. If you want to learn more about Differential Privacy, read our `blogpost <https://www.scaleoutsystems.com/post/guaranteeing-data-privacy-for-clients-in-federated-machine-learning>`__ on the topic.

Prerequisites
-------------

- `Python >=3.9, <=3.12 <https://www.python.org/downloads>`__
- `A project in FEDn Studio <https://fedn.scaleoutsystems.com/signup>`__

Edit Client-Specific Differential Privacy Parameters
----------------------------------------------------

The **Differential Privacy budget** (``epsilon``, ``delta``), along with other settings, is configurable in the ``client_settings.yaml`` file:

- **epochs**: Number of local epochs per round.
- **epsilon**: Total epsilon budget to spend across the ``global_rounds`` configured on the server side.
- **delta**: Total delta budget to spend.
- **max_grad_norm**: Clipping threshold for gradients.
- **global_rounds**: Number of rounds the server will run.
- **hardlimit**:

- If ``hardlimit`` is set to ``True``, the ``epsilon`` budget will not exceed its specified limit, even if it means skipping updates for some rounds.
- If ``hardlimit`` is set to ``False``, the expected ``epsilon`` will be approximately equal to its specified value, assuming the server completes the specified ``global_rounds`` of updates.
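
The per-round privacy target in ``client/train.py`` is derived by treating per-round epsilons as composing in a root-sum-square fashion. A minimal sketch of that arithmetic (the helper name is ours, for illustration only):

.. code-block:: python

    import numpy as np

    def round_epsilon_target(epsilon_spent, epsilon, global_rounds):
        # Equal per-round share of the total budget.
        per_round = epsilon / np.sqrt(global_rounds)
        # Target after stacking one more share on top of the budget already
        # spent, composed as a root sum of squares.
        return np.sqrt((epsilon_spent / per_round) ** 2 + 1) * per_round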

Creating the compute package and seed model
-------------------------------------------
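
As in the baseline MNIST-PyTorch example, the compute package and seed model are typically created with ``fedn package create --path client`` followed by ``fedn run build --path client``.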
4 changes: 3 additions & 1 deletion examples/mnist-pytorch-DPSGD/client/fedn.yaml
# Remove the python_env tag below to handle the environment manually
python_env: python_env.yaml

entry_points:
  build:
    command: python model.py
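
The ``build`` entry point (``python model.py``) is what generates the seed model for the example; as the comment above notes, the ``python_env`` tag can be removed if you manage the Python environment yourself.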
147 changes: 64 additions & 83 deletions examples/mnist-pytorch-DPSGD/client/train.py
import os
import sys

import numpy as np
import torch
import yaml
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

from data import load_data
from model import load_parameters, save_parameters

from fedn.utils.helpers.helpers import save_metadata

dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.abspath(dir_path))


# Define a custom Dataset class
class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, x_data, y_data):
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        x_data = self.x_data[idx]
        y_data = self.y_data[idx]
        return x_data, y_data


MAX_PHYSICAL_BATCH_SIZE = 32
EPOCHS = 1
EPSILON = 1000.0
DELTA = 1e-5
MAX_GRAD_NORM = 1.2
GLOBAL_ROUNDS = 10
HARDLIMIT = True


def train(in_model_path, out_model_path, data_path=None, batch_size=32, lr=0.01):
"""Complete a model update.

Load model paramters from in_model_path (managed by the FEDn client),
Expand All @@ -60,68 +58,74 @@ def train(in_model_path, out_model_path, data_path=None, batch_size=32, epochs=1
:param lr: The learning rate to use.
:type lr: float
"""

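    # NOTE: DP settings are re-read from client_settings.yaml at each update,
    # so the privacy budget can be tuned without rebuilding the compute package.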
    with open("../../client_settings.yaml", "r") as fh:
        try:
            settings = yaml.safe_load(fh)
            EPSILON = float(settings["epsilon"])
            DELTA = float(settings["delta"])
            MAX_GRAD_NORM = float(settings["max_grad_norm"])
            GLOBAL_ROUNDS = int(settings["global_rounds"])
            HARDLIMIT = bool(settings["hardlimit"])
            EPOCHS = int(settings["epochs"])
            global MAX_PHYSICAL_BATCH_SIZE
            MAX_PHYSICAL_BATCH_SIZE = int(settings["max_physical_batch_size"])
        except yaml.YAMLError as exc:
            print(exc)

    # Load data
    print("data_path: ", data_path)
    x_train, y_train = load_data(data_path)

    # Load parameters and initialize model
    model = load_parameters(in_model_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    privacy_engine = PrivacyEngine()

    # Load privacy accountant state from previous rounds, if present
    if os.path.isfile("privacy_accountant.state"):
        privacy_engine.accountant = torch.load("privacy_accountant.state")

    trainset = CustomDataset(x_train, y_train)
    train_loader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, num_workers=2)

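    # get_epsilon can fail on a fresh accountant (no steps recorded yet);
    # treat that as zero budget spent so far.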
    try:
        epsilon_spent = privacy_engine.get_epsilon(DELTA)
    except Exception:
        epsilon_spent = 0
    print("epsilon before training: ", epsilon_spent)

    # Train
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    round_epsilon = np.sqrt((epsilon_spent / EPSILON * np.sqrt(GLOBAL_ROUNDS)) ** 2 + 1) * EPSILON / np.sqrt(GLOBAL_ROUNDS)
    print("target epsilon: ", round_epsilon)

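    # make_private_with_epsilon calibrates the DP-SGD noise multiplier so that
    # the epsilon spent by this round's training stays near the computed target.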
    model, optimizer, train_loader = privacy_engine.make_private_with_epsilon(
        module=model,
        optimizer=optimizer,
        data_loader=train_loader,
        epochs=EPOCHS,
        target_epsilon=round_epsilon,
        target_delta=DELTA,
        max_grad_norm=MAX_GRAD_NORM,
    )

    print(f"Using sigma={optimizer.noise_multiplier} and C={MAX_GRAD_NORM}")

    train_dp(model, train_loader, optimizer, EPOCHS, device, privacy_engine)
    try:
        print("epsilon after training: ", privacy_engine.get_epsilon(DELTA))
    except Exception:
        print("can't calculate epsilon")

    if HARDLIMIT and privacy_engine.get_epsilon(DELTA) < EPSILON:
        # Metadata needed for aggregation server side
        metadata = {
            # num_examples are mandatory
            "num_examples": len(x_train),
            "batch_size": batch_size,
            "epochs": EPOCHS,
            "lr": lr,
        }

        # Save JSON metadata file (mandatory)
        save_metadata(metadata, out_model_path)

        # Save model update (mandatory)
        save_parameters(model, out_model_path)
    else:
        print("Epsilon too high, not saving model")

    # Save privacy accountant
    torch.save(privacy_engine.accountant, "privacy_accountant.state")


def train_dp(model, train_loader, optimizer, epoch, device, privacy_engine):
    model.train()
    criterion = torch.nn.NLLLoss()  # nn.CrossEntropyLoss()

    with BatchMemoryManager(
        data_loader=train_loader,
        max_physical_batch_size=MAX_PHYSICAL_BATCH_SIZE,
        optimizer=optimizer,
    ) as memory_safe_data_loader:
        for images, target in memory_safe_data_loader:
            optimizer.zero_grad()
            images = images.to(device)
            target = target.to(device)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            loss.backward()
            optimizer.step()


if __name__ == "__main__":
    train(sys.argv[1], sys.argv[2])
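
When executed by the FEDn client, the script is invoked with the paths to the incoming and outgoing model updates, i.e. ``python train.py <in_model_path> <out_model_path>``.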
8 changes: 8 additions & 0 deletions examples/mnist-pytorch-DPSGD/client_settings.yaml
# Constants
max_physical_batch_size: 32
epochs: 1
epsilon: 100.0
delta: 1e-5
max_grad_norm: 1.2
global_rounds: 10
hardlimit: true
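
With these defaults the per-round share is ``100 / sqrt(10) ≈ 31.6``; under the root-sum-square accounting sketched above, the cumulative spend after ``k`` rounds grows roughly as ``sqrt(k) * 31.6``, exhausting the ``epsilon: 100.0`` budget after the configured 10 global rounds.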