-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_data_partitions.py
47 lines (37 loc) · 1.35 KB
/
create_data_partitions.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
import sys
import pandas as pd
from math import floor
def import_data(test_ratio,
                url='https://query.data.world/s/culciexydc2njqbyaqxayl7rleyhwf',
                out_path="data.csv",
                random_state=None):
    """Load a CSV dataset, shuffle it, save a copy and split it into train/test.

    The CSV at ``url`` is read with pandas, its rows are shuffled, the
    shuffled table is written to ``out_path`` (without the index column) and
    the first ``int(test_ratio * n_rows)`` rows become the test set; the
    remaining rows become the training set.

    Args:
        test_ratio: Fraction of rows assigned to the test set. Must lie in
            the closed interval [0, 1]; the resulting row count is truncated
            towards zero.
        url: Path or URL handed to ``pandas.read_csv``. Defaults to the
            data.world query URL this script was written for.
        out_path: Destination of the shuffled full dataset. Defaults to
            ``data.csv`` in the current working directory.
        random_state: Seed forwarded to ``DataFrame.sample``. ``None`` (the
            default) gives a different shuffle on every call; pass an integer
            to make the partitions reproducible.

    Returns:
        A ``(trainset, testset)`` tuple of DataFrame slices of the shuffled
        data.

    Raises:
        ValueError: If ``test_ratio`` is outside [0, 1]. A negative ratio
            would otherwise turn into a negative slice bound and silently
            swap the sizes of the two sets.
    """
    if not 0 <= test_ratio <= 1:
        raise ValueError(
            "test_ratio must be between 0 and 1, got {!r}".format(test_ratio))

    data = pd.read_csv(url)
    # frac=1 samples every row once, i.e. a full shuffle; the index is reset
    # so the positional slices below line up with the new row order.
    data = data.sample(frac=1, random_state=random_state).reset_index(drop=True)
    data.to_csv(out_path, index=False)

    num_test = int(test_ratio * data.shape[0])
    testset = data[:num_test]
    trainset = data[num_test:]
    return trainset, testset
def splitset(dataset, parts):
    """Cut ``dataset`` into ``parts`` consecutive slices of equal length.

    Every slice holds ``dataset.shape[0] // parts`` rows. Rows left over
    after the last full slice are not included in any slice.

    Args:
        dataset: Sliceable table exposing ``shape`` (e.g. a DataFrame).
        parts: Number of slices to produce.

    Returns:
        A list of ``parts`` slices of ``dataset``, in original row order.
    """
    chunk = dataset.shape[0] // parts
    return [dataset[k * chunk:(k + 1) * chunk] for k in range(parts)]
if __name__ == '__main__':
    # Number of client partitions: first command-line argument, default 10.
    # sys.argv entries are strings, so the value has to be converted before
    # it can be used as a divisor and as a range() bound.
    if len(sys.argv) < 2:
        nr_of_datasets = 10
    else:
        try:
            nr_of_datasets = int(sys.argv[1])
        except ValueError:
            sys.exit("number of datasets must be an integer, got {!r}".format(sys.argv[1]))
    if nr_of_datasets < 1:
        sys.exit("number of datasets must be at least 1, got {}".format(nr_of_datasets))

    # Hold out 10% of the shuffled data as test rows, then give every client
    # an equally sized share of both the training and the test rows.
    trainset, testset = import_data(0.1)
    trainsets = splitset(trainset, nr_of_datasets)
    testsets = splitset(testset, nr_of_datasets)

    # makedirs creates missing parents and tolerates existing directories,
    # so re-running the script simply overwrites the per-client CSV files.
    clients_root = os.path.join('data', 'clients')
    os.makedirs(clients_root, exist_ok=True)
    for i, (train_part, test_part) in enumerate(zip(trainsets, testsets)):
        client_dir = os.path.join(clients_root, str(i))
        os.makedirs(client_dir, exist_ok=True)
        train_part.to_csv(os.path.join(client_dir, 'train.csv'), index=False)
        test_part.to_csv(os.path.join(client_dir, 'test.csv'), index=False)