-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
initial splitlearning implementation
- Loading branch information
1 parent
c1b8215
commit 3ad9402
Showing
28 changed files
with
4,177 additions
and
267 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,167 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 133, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from fedn import APIClient" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 134, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"DISCOVER_HOST = '127.0.0.1'\n", | ||
"DISCOVER_PORT = 8092\n", | ||
"client = APIClient(DISCOVER_HOST, DISCOVER_PORT)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 135, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"<fedn.network.api.client.APIClient at 0x107cd93d0>" | ||
] | ||
}, | ||
"execution_count": 135, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"client" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 136, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'count': 2,\n", | ||
" 'result': [{'client_id': 'eb2bcf55-17c5-4e34-9b6f-d53b41e38cd4',\n", | ||
" 'combiner': 'combinerd3486ca7',\n", | ||
" 'id': '6763339936d0caaa05834846',\n", | ||
" 'ip': '127.0.0.1',\n", | ||
" 'last_seen': 'Wed, 18 Dec 2024 21:42:03 GMT',\n", | ||
" 'name': 'clientb0abc7ea',\n", | ||
" 'package': 'local',\n", | ||
" 'status': 'online',\n", | ||
" 'updated_at': '2024-12-18 21:42:01.464088'},\n", | ||
" {'client_id': '8b4a21f9-c1e6-4632-a196-6aa29b268f61',\n", | ||
" 'combiner': 'combinerd3486ca7',\n", | ||
" 'id': '676332f636d0caaa058346f8',\n", | ||
" 'ip': '127.0.0.1',\n", | ||
" 'last_seen': 'Wed, 18 Dec 2024 21:41:07 GMT',\n", | ||
" 'name': 'client1aafa234',\n", | ||
" 'package': 'local',\n", | ||
" 'status': 'online',\n", | ||
" 'updated_at': '2024-12-18 21:39:18.504766'}]}" | ||
] | ||
}, | ||
"execution_count": 136, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"client.get_active_clients()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 137, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'message': 'Compute package set.', 'success': True}" | ||
] | ||
}, | ||
"execution_count": 137, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"client.set_active_package('package.tgz', 'splitlearninghelper')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 138, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"text/plain": [ | ||
"{'config': {'aggregator': 'splitlearningagg',\n", | ||
" 'buffer_size': -1,\n", | ||
" 'clients_requested': 8,\n", | ||
" 'clients_required': 1,\n", | ||
" 'delete_models_storage': True,\n", | ||
" 'helper_type': 'splitlearninghelper',\n", | ||
" 'round_timeout': 60,\n", | ||
" 'rounds': 200,\n", | ||
" 'server_functions': None,\n", | ||
" 'session_id': '1',\n", | ||
" 'task': '',\n", | ||
" 'validate': True},\n", | ||
" 'message': 'Split Learning Session started successfully.',\n", | ||
" 'success': True}" | ||
] | ||
}, | ||
"execution_count": 138, | ||
"metadata": {}, | ||
"output_type": "execute_result" | ||
} | ||
], | ||
"source": [ | ||
"session_id = \"1\"\n", | ||
"\n", | ||
"session_config = {\n", | ||
" \"helper\": \"splitlearninghelper\",\n", | ||
" \"id\": session_id,\n", | ||
" \"aggregator\": \"splitlearningagg\",\n", | ||
" \"rounds\": 200,\n", | ||
" \"round_timeout\": 60\n", | ||
" }\n", | ||
"\n", | ||
"\n", | ||
"client.start_splitlearning_session(**session_config)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venv", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.9.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
network_id: fedn-network | ||
api_url: http://api-server:8092 | ||
discover_host: api-server | ||
discover_port: 8092 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
import os | ||
import sys | ||
|
||
import numpy as np | ||
import torch | ||
from model import load_client_model, save_client_model | ||
from torch import optim | ||
|
||
from fedn.common.log_config import logger | ||
from fedn.utils.helpers.helpers import get_helper, save_metadata | ||
|
||
# Directory that contains this script (symlinks resolved). backward_pass builds
# the embedding file path and the saved-model log message from abs_path.
_script_file = os.path.realpath(__file__)
dir_path = os.path.dirname(_script_file)
abs_path = os.path.abspath(dir_path)

# Helper instance; backward_pass uses helper.load to read the gradient file.
HELPER_MODULE = "splitlearninghelper"
helper = get_helper(HELPER_MODULE)
|
||
|
||
def backward_pass(gradient_path, client_id):
    """Apply the gradients received for this client to its local model.

    Loads the client model for ``client_id``, the embedding stored by the
    previous forward pass and the gradients found at ``gradient_path``; runs a
    backward pass with those gradients, takes one Adam step and saves the
    client model again.

    :param gradient_path: Path of the gradient file, read with ``helper.load``.
    :param client_id: Identifier used to select the client model, the embedding
        file and this client's entry in the loaded gradients.
    :raises FileNotFoundError: If no embedding file exists for ``client_id``.
    """
    # load client model with parameters
    client_model = load_client_model(client_id)
    logger.info(f"Client model loaded from {client_id}")

    # instantiate optimizer
    client_optimizer = optim.Adam(client_model.parameters(), lr=0.01)
    client_optimizer.zero_grad()

    # load local embedding from previous forward pass
    logger.info(f"Loading embedding from {client_id}")
    embedding_file = f"{abs_path}/embeddings/embeddings_{client_id}.npz"
    try:
        # np.load keeps the .npz archive open; the context manager closes it
        # once the first stored array has been read into memory.
        with np.load(embedding_file) as npz_file:
            embedding = next(iter(npz_file.values()))
    except FileNotFoundError as err:
        # Keep the original error as the cause so the missing path stays visible.
        raise FileNotFoundError(f"Embedding file {client_id} not found") from err

    # transform to tensor
    embedding = torch.tensor(embedding, dtype=torch.float32, requires_grad=True)

    # load gradients
    gradients = helper.load(gradient_path)
    logger.info(f"Gradients loaded from {gradient_path}")

    # The received gradients are plain inputs to backward(); they do not need
    # to track gradients themselves.
    local_gradients = torch.tensor(gradients[client_id], dtype=torch.float32)

    # perform backward pass
    # NOTE(review): `embedding` is rebuilt from a saved NumPy array, so it looks
    # detached from client_model's graph. If so, backward() only fills
    # embedding.grad and the optimizer step leaves the model unchanged.
    # TODO confirm; the embedding probably has to be recomputed through
    # client_model before calling backward().
    embedding.backward(local_gradients)
    client_optimizer.step()

    # save updated client model locally
    save_client_model(client_model, client_id)

    logger.info(f"Updated client model saved to {abs_path}/local_models/{client_id}.pth")
|
||
if __name__ == "__main__":
    # First CLI argument is the gradient file path, second is the client id.
    gradient_file = sys.argv[1]
    client = sys.argv[2]
    backward_pass(gradient_file, client)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
import os | ||
|
||
import pandas as pd | ||
import torch | ||
from sklearn.preprocessing import StandardScaler | ||
|
||
# Directory that contains this script (symlinks resolved). The default paths in
# load_data and load_labels and the data-preparation check are built from abs_path.
_script_file = os.path.realpath(__file__)
dir_path = os.path.dirname(_script_file)
abs_path = os.path.abspath(dir_path)
|
||
|
||
def load_data(data_path=None, is_train=True):
    """Return the feature tensor stored in a client data file.

    :param data_path: File to read with ``torch.load``; when None, the file of
        client 1 below ``abs_path`` is used.
    :param is_train: Select the ``"X_train"`` entry when true, otherwise ``"X_test"``.
    :return: The selected entry of the loaded dictionary.
    """
    path = abs_path + "/data/clients/1/titanic.pt" if data_path is None else data_path
    split_key = "X_train" if is_train else "X_test"
    stored = torch.load(path, weights_only=True)
    return stored[split_key]
|
||
def load_labels(data_path=None, is_train=True):
    """Return the label tensor stored in the labels file.

    :param data_path: File to read with ``torch.load``; when None, the shared
        labels file below ``abs_path`` is used.
    :param is_train: Select the ``"y_train"`` entry when true, otherwise ``"y_test"``.
    :return: The selected entry of the loaded dictionary.
    """
    path = abs_path + "/data/clients/labels.pt" if data_path is None else data_path
    split_key = "y_train" if is_train else "y_test"
    stored = torch.load(path, weights_only=True)
    return stored[split_key]
|
||
def preprocess_data(df: pd.DataFrame, scaler=None, is_train=True):
    """Select, impute, scale and encode the feature columns of ``df``.

    For training data a new ``StandardScaler`` is fitted on ``Age`` and ``Fare``
    (any ``scaler`` argument is ignored) and the ``Survived`` column is kept.
    For non-training data ``Survived`` is not selected and the given ``scaler``
    is applied without refitting.

    :param df: Frame containing at least PassengerId, Pclass, Sex, Age and Fare
        (plus Survived when ``is_train`` is true).
    :param scaler: Fitted scaler exposing ``transform``; required when
        ``is_train`` is false.
    :param is_train: Whether ``df`` is the training set.
    :return: Tuple ``(prep_df, scaler)`` with the processed copy of the
        selected columns and the scaler that was used.
    :raises ValueError: If ``is_train`` is false and no scaler is given.
    """
    # Fail early with a clear message instead of an AttributeError on None.
    if not is_train and scaler is None:
        raise ValueError("A fitted scaler is required when is_train is False")

    # select relevant features; Survived should not be in test set
    columns = ["PassengerId", "Pclass", "Sex", "Age", "Fare"]
    if is_train:
        columns.insert(1, "Survived")
    prep_df = df[columns].copy()

    # fill nas with the median of the respective column
    numeric_columns = ["Age", "Fare"]
    for column in numeric_columns:
        prep_df[column] = prep_df[column].fillna(prep_df[column].median())

    # scale data
    if is_train:
        scaler = StandardScaler()
        prep_df[numeric_columns] = scaler.fit_transform(prep_df[numeric_columns])
    else:
        prep_df[numeric_columns] = scaler.transform(prep_df[numeric_columns])

    # categorization
    # NOTE(review): codes are derived from the categories present in this frame
    # only, so train and test encodings match only if both contain the same
    # set of values — confirm for the data used.
    for column in ("Sex", "Pclass"):
        prep_df[column] = prep_df[column].astype("category").cat.codes
    return prep_df, scaler
|
||
|
||
def vertical_split(out_dir="data", data_dir="../data"):
    """Generate vertical splits of the titanic dataset for 2 clients. Hardcoded for now.

    Reads ``train.csv``, ``test.csv`` and ``labels.csv`` from ``data_dir``,
    preprocesses train and test data, and writes one ``titanic.pt`` file per
    client plus a shared ``labels.pt`` file below ``{out_dir}/clients``.

    :param out_dir: Directory below which the ``clients`` folder is created.
    :param data_dir: Directory containing the input CSV files.
    """
    clients_dir = f"{out_dir}/clients"
    # exist_ok avoids the separate existence check and its race.
    os.makedirs(clients_dir, exist_ok=True)

    train_df = pd.read_csv(f"{data_dir}/train.csv")
    test_df = pd.read_csv(f"{data_dir}/test.csv")

    train_df, scaler = preprocess_data(train_df, is_train=True)
    test_df, _ = preprocess_data(test_df, scaler=scaler, is_train=False)

    # vertical split (for 2 clients, hardcoded): one column group per client
    client_columns = [["Sex", "Age"], ["Pclass", "Fare"]]
    train_features = [torch.tensor(train_df[cols].values, dtype=torch.float32) for cols in client_columns]
    test_features = [torch.tensor(test_df[cols].values, dtype=torch.float32) for cols in client_columns]

    # labels, will only be accessed by server
    train_label_tensor = torch.tensor(train_df[["Survived"]].values, dtype=torch.float32)

    # test labels, need to be loaded separately
    # NOTE(review): every column of labels.csv is converted — confirm the file
    # contains only the label column.
    test_label_df = pd.read_csv(f"{data_dir}/labels.csv")
    test_label_tensor = torch.tensor(test_label_df.values, dtype=torch.float32)

    # save features, one sub-directory per client (numbered from 1)
    for client_number, (x_train, x_test) in enumerate(zip(train_features, test_features), start=1):
        subdir = f"{clients_dir}/{client_number}"
        os.makedirs(subdir, exist_ok=True)
        torch.save({"X_train": x_train, "X_test": x_test}, f"{subdir}/titanic.pt")

    # save labels
    torch.save(
        {"y_train": train_label_tensor, "y_test": test_label_tensor},
        f"{clients_dir}/labels.pt",
    )
|
||
if __name__ == "__main__":
    # Prepare data if not already done. The splits are written below abs_path
    # so that this existence check and the default paths used by load_data and
    # load_labels refer to the same directory regardless of the working directory.
    data_root = abs_path + "/data"
    if not os.path.exists(data_root + "/clients/1"):
        vertical_split(out_dir=data_root)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
# python_env: python_env.yaml | ||
entry_points: | ||
startup: | ||
command: python data.py | ||
forward: | ||
command: python forward.py | ||
backward: | ||
command: python backward.py |
Oops, something went wrong.