initial splitlearning implementation
FrankJonasmoelle committed Dec 18, 2024
1 parent c1b8215 commit 3ad9402
Showing 28 changed files with 4,177 additions and 267 deletions.
167 changes: 167 additions & 0 deletions examples/splitlearning_titanic/api.ipynb
@@ -0,0 +1,167 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 133,
"metadata": {},
"outputs": [],
"source": [
"from fedn import APIClient"
]
},
{
"cell_type": "code",
"execution_count": 134,
"metadata": {},
"outputs": [],
"source": [
"DISCOVER_HOST = '127.0.0.1'\n",
"DISCOVER_PORT = 8092\n",
"client = APIClient(DISCOVER_HOST, DISCOVER_PORT)"
]
},
{
"cell_type": "code",
"execution_count": 135,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<fedn.network.api.client.APIClient at 0x107cd93d0>"
]
},
"execution_count": 135,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client"
]
},
{
"cell_type": "code",
"execution_count": 136,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'count': 2,\n",
" 'result': [{'client_id': 'eb2bcf55-17c5-4e34-9b6f-d53b41e38cd4',\n",
" 'combiner': 'combinerd3486ca7',\n",
" 'id': '6763339936d0caaa05834846',\n",
" 'ip': '127.0.0.1',\n",
" 'last_seen': 'Wed, 18 Dec 2024 21:42:03 GMT',\n",
" 'name': 'clientb0abc7ea',\n",
" 'package': 'local',\n",
" 'status': 'online',\n",
" 'updated_at': '2024-12-18 21:42:01.464088'},\n",
" {'client_id': '8b4a21f9-c1e6-4632-a196-6aa29b268f61',\n",
" 'combiner': 'combinerd3486ca7',\n",
" 'id': '676332f636d0caaa058346f8',\n",
" 'ip': '127.0.0.1',\n",
" 'last_seen': 'Wed, 18 Dec 2024 21:41:07 GMT',\n",
" 'name': 'client1aafa234',\n",
" 'package': 'local',\n",
" 'status': 'online',\n",
" 'updated_at': '2024-12-18 21:39:18.504766'}]}"
]
},
"execution_count": 136,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.get_active_clients()"
]
},
{
"cell_type": "code",
"execution_count": 137,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'message': 'Compute package set.', 'success': True}"
]
},
"execution_count": 137,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"client.set_active_package('package.tgz', 'splitlearninghelper')"
]
},
{
"cell_type": "code",
"execution_count": 138,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'config': {'aggregator': 'splitlearningagg',\n",
" 'buffer_size': -1,\n",
" 'clients_requested': 8,\n",
" 'clients_required': 1,\n",
" 'delete_models_storage': True,\n",
" 'helper_type': 'splitlearninghelper',\n",
" 'round_timeout': 60,\n",
" 'rounds': 200,\n",
" 'server_functions': None,\n",
" 'session_id': '1',\n",
" 'task': '',\n",
" 'validate': True},\n",
" 'message': 'Split Learning Session started successfully.',\n",
" 'success': True}"
]
},
"execution_count": 138,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"session_id = \"1\"\n",
"\n",
"session_config = {\n",
" \"helper\": \"splitlearninghelper\",\n",
" \"id\": session_id,\n",
" \"aggregator\": \"splitlearningagg\",\n",
" \"rounds\": 200,\n",
" \"round_timeout\": 60\n",
" }\n",
"\n",
"\n",
"client.start_splitlearning_session(**session_config)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
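
Condensed into a plain script, the notebook above amounts to the following (an illustrative sketch that reuses only the APIClient calls shown in the cells and assumes a local combiner with two connected clients):

from fedn import APIClient

client = APIClient("127.0.0.1", 8092)

# both clients should be online before the session starts
assert client.get_active_clients()["count"] == 2

# upload the compute package for this example
client.set_active_package("package.tgz", "splitlearninghelper")

# start a split-learning session: 200 rounds with a 60-second round timeout
client.start_splitlearning_session(
    helper="splitlearninghelper",
    id="1",
    aggregator="splitlearningagg",
    rounds=200,
    round_timeout=60,
)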
4 changes: 4 additions & 0 deletions examples/splitlearning_titanic/client.yaml
@@ -0,0 +1,4 @@
network_id: fedn-network
api_url: http://api-server:8092
discover_host: api-server
discover_port: 8092
59 changes: 59 additions & 0 deletions examples/splitlearning_titanic/client/backward.py
@@ -0,0 +1,59 @@
import os
import sys

import numpy as np
import torch
from model import load_client_model, save_client_model
from torch import optim

from fedn.common.log_config import logger
from fedn.utils.helpers.helpers import get_helper, save_metadata

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)

HELPER_MODULE = "splitlearninghelper"
helper = get_helper(HELPER_MODULE)


def backward_pass(gradient_path, client_id):
    """Load gradients from gradient_path and the locally stored embedding from the
    previous forward pass, perform a backward pass to update the parameters of the
    client model, and save the updated model locally.
    """
    # load client model with parameters
    client_model = load_client_model(client_id)
    logger.info(f"Client model loaded for client {client_id}")

    # instantiate optimizer
    client_optimizer = optim.Adam(client_model.parameters(), lr=0.01)
    client_optimizer.zero_grad()

    # load local embedding from the previous forward pass
    logger.info(f"Loading embedding for client {client_id}")
    try:
        npz_file = np.load(f"{abs_path}/embeddings/embeddings_{client_id}.npz")
        embedding = next(iter(npz_file.values()))
    except FileNotFoundError as e:
        raise FileNotFoundError(f"Embedding file for client {client_id} not found") from e

    # transform to tensor
    embedding = torch.tensor(embedding, dtype=torch.float32, requires_grad=True)

    # load gradients
    gradients = helper.load(gradient_path)
    logger.info(f"Gradients loaded from {gradient_path}")

    local_gradients = gradients[client_id]
    local_gradients = torch.tensor(local_gradients, dtype=torch.float32, requires_grad=True)

    # perform backward pass
    embedding.backward(local_gradients)
    client_optimizer.step()

    # save updated client model locally
    save_client_model(client_model, client_id)

    logger.info(f"Updated client model saved to {abs_path}/local_models/{client_id}.pth")


if __name__ == "__main__":
    backward_pass(sys.argv[1], sys.argv[2])
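
For intuition, the gradient exchange that backward.py implements looks as follows in pure PyTorch, with FEDn stripped away (a minimal sketch with made-up shapes; server_grad stands in for the d(loss)/d(embedding) the server sends back):

import torch

client_model = torch.nn.Linear(4, 2)       # stand-in for a client-side model up to the cut layer
x = torch.randn(8, 4)                      # a batch of local features
embedding = client_model(x)                # client forward pass up to the cut

server_grad = torch.randn_like(embedding)  # gradient of the loss w.r.t. the embedding
embedding.backward(server_grad)            # backpropagate through the client model only
print(client_model.weight.grad.shape)      # client parameters now hold gradients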
114 changes: 114 additions & 0 deletions examples/splitlearning_titanic/client/data.py
@@ -0,0 +1,114 @@
import os

import pandas as pd
import torch
from sklearn.preprocessing import StandardScaler

dir_path = os.path.dirname(os.path.realpath(__file__))
abs_path = os.path.abspath(dir_path)


def load_data(data_path=None, is_train=True):
    """Load features from data_path. If data_path is None, load them from the default path."""
    if data_path is None:
        data_path = abs_path + "/data/clients/1/titanic.pt"
    data = torch.load(data_path, weights_only=True)
    if is_train:
        return data["X_train"]
    else:
        return data["X_test"]


def load_labels(data_path=None, is_train=True):
    """Load labels from data_path. If data_path is None, load them from the default path."""
    if data_path is None:
        data_path = abs_path + "/data/clients/labels.pt"
    data = torch.load(data_path, weights_only=True)
    if is_train:
        return data["y_train"]
    else:
        return data["y_test"]


def preprocess_data(df: pd.DataFrame, scaler=None, is_train=True):
    """Preprocess the data. For training data, fit a new scaler; for test data, use the
    provided scaler and omit the Survived label column."""
    if is_train:
        prep_df = df[["PassengerId", "Survived", "Pclass", "Sex", "Age", "Fare"]].copy()  # select relevant features
    else:
        prep_df = df[["PassengerId", "Pclass", "Sex", "Age", "Fare"]].copy()  # Survived is not in the test set

    # fill missing values
    prep_df["Age"] = prep_df["Age"].fillna(prep_df["Age"].median())
    prep_df["Fare"] = prep_df["Fare"].fillna(prep_df["Fare"].median())

    # scale data
    if is_train:
        scaler = StandardScaler()
        prep_df[["Age", "Fare"]] = scaler.fit_transform(prep_df[["Age", "Fare"]])
    else:
        prep_df[["Age", "Fare"]] = scaler.transform(prep_df[["Age", "Fare"]])

    # encode categorical columns
    prep_df["Sex"] = prep_df["Sex"].astype("category").cat.codes
    prep_df["Pclass"] = prep_df["Pclass"].astype("category").cat.codes
    return prep_df, scaler


def vertical_split(out_dir="data"):
    """Generate vertical splits of the Titanic dataset for 2 clients. Hardcoded for now."""
    n_splits = 2

    # make dir
    if not os.path.exists(f"{out_dir}/clients"):
        os.makedirs(f"{out_dir}/clients")

    train_df = pd.read_csv("../data/train.csv")
    test_df = pd.read_csv("../data/test.csv")

    train_df, scaler = preprocess_data(train_df, is_train=True)
    test_df, _ = preprocess_data(test_df, scaler=scaler, is_train=False)

    # vertical train data split (for 2 clients, hardcoded)
    client_1_data_tensor = torch.tensor(train_df[["Sex", "Age"]].values, dtype=torch.float32)
    client_2_data_tensor = torch.tensor(train_df[["Pclass", "Fare"]].values, dtype=torch.float32)
    # labels, will only be accessed by the server
    train_label_tensor = torch.tensor(train_df[["Survived"]].values, dtype=torch.float32)

    # vertical test data split (for 2 clients, hardcoded)
    test_client_1_tensor = torch.tensor(test_df[["Sex", "Age"]].values, dtype=torch.float32)
    test_client_2_tensor = torch.tensor(test_df[["Pclass", "Fare"]].values, dtype=torch.float32)
    # test labels need to be loaded separately
    test_label_df = pd.read_csv("../data/labels.csv")
    test_label_tensor = torch.tensor(test_label_df.values, dtype=torch.float32)

    data = {
        "train_features": [client_1_data_tensor, client_2_data_tensor],
        "train_labels": train_label_tensor,
        "test_features": [test_client_1_tensor, test_client_2_tensor],
        "test_labels": test_label_tensor,
    }

    # save one vertical feature split per client
    for i in range(n_splits):
        subdir = f"{out_dir}/clients/{str(i + 1)}"
        if not os.path.exists(subdir):
            os.mkdir(subdir)
        torch.save(
            {
                "X_train": data["train_features"][i],
                "X_test": data["test_features"][i],
            },
            f"{subdir}/titanic.pt",
        )

    # save labels once, for the server
    torch.save(
        {
            "y_train": data["train_labels"],
            "y_test": data["test_labels"],
        },
        f"{out_dir}/clients/labels.pt",
    )


if __name__ == "__main__":
    # prepare data if not already done
    if not os.path.exists(abs_path + "/data/clients/1"):
        vertical_split()
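
A quick sanity check of the generated splits (illustrative; paths follow the defaults above, assuming the script is run from the client directory):

from data import load_data, load_labels

X1 = load_data("data/clients/1/titanic.pt")  # client 1 features: Sex, Age
X2 = load_data("data/clients/2/titanic.pt")  # client 2 features: Pclass, Fare
y = load_labels("data/clients/labels.pt")    # server-side labels: Survived
print(X1.shape, X2.shape, y.shape)           # all three should have the same number of rows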
8 changes: 8 additions & 0 deletions examples/splitlearning_titanic/client/fedn.yaml
@@ -0,0 +1,8 @@
# python_env: python_env.yaml
entry_points:
startup:
command: python data.py
forward:
command: python forward.py
backward:
command: python backward.py
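
The entry points map the split-learning protocol onto plain script invocations: startup prepares the local data once, while forward and backward run each round. Judging from backward_pass(sys.argv[1], sys.argv[2]) above, the client appends round-specific arguments to the configured command, effectively running python backward.py <gradient_path> <client_id>; the exact dispatch is an assumption based on this file and backward.py.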