split_whole_dataset_to_indices.py
'''
This script splits the dataset of songs into training/validation/test splits.
The split is made at song granularity in order to prevent leaking information
within a song (as would happen when splitting at the block level). This
approach is also invariant to the block/hop size of features such as the
chromagram, which makes different feature types comparable.
The output is a TSV file recording which song belongs to which split and its
relative order within the split.
'''
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def split_songs(song_file, song_index_file, split_index_file, random_state):
    # Each line of the input file is an 'artist/album/song' path.
    df = pd.read_csv(song_file, sep='\t', header=None, names=['path'])
    songs = np.array([p.split('/') for p in df['path']])
    df['artist'] = songs[:, 0]
    df['album'] = songs[:, 1]
    df['song'] = songs[:, 2]

    def split_dataset(index, random_state):
        index = list(index)
        # 20 % of the songs form the test split; of the remaining 80 %,
        # 0.2 / (1 - 0.2) = 25 % go to validation, i.e. 20 % of the whole dataset.
        ix_train, ix_test = train_test_split(index, test_size=0.2, random_state=random_state)
        ix_train, ix_valid = train_test_split(ix_train, test_size=0.2 / (1 - 0.2), random_state=random_state)
        return {'train': ix_train, 'valid': ix_valid, 'test': ix_test}

    split_indices = split_dataset(df.index, random_state)

    # Mark each song with the split it belongs to.
    df['split'] = ''
    for name in split_indices:
        df.loc[split_indices[name], 'split'] = name
    # Concatenated permutation of song indices: train, then valid, then test.
    df['order'] = np.hstack([split_indices['train'], split_indices['valid'], split_indices['test']])
    df.to_csv(song_index_file, sep='\t', index=False)

    # One line per split: its name and the comma-separated song indices.
    with open(split_index_file, 'w') as file:
        for name in split_indices:
            print(name + '\t' + ','.join(str(i) for i in split_indices[name]), file=file)


if __name__ == '__main__':
    data_dir = '../data/beatles'
    song_file = data_dir + '/isophonic-songs.txt'
    song_index_file = data_dir + '/songs-dataset-split.tsv'
    split_index_file = data_dir + '/dataset-split-indexes.tsv'
    split_songs(song_file, song_index_file, split_index_file, random_state=42)
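
As a quick sanity check, here is a minimal sketch (assuming the script above has already been run, so that songs-dataset-split.tsv exists at the path used in the script) of reading the per-song split assignment back and counting songs per split:

import pandas as pd

# Read the per-song split assignment produced by split_songs above.
splits = pd.read_csv('../data/beatles/songs-dataset-split.tsv', sep='\t')
# With the nested 0.2 / 0.25 split this should be roughly 60/20/20 % of the songs.
print(splits['split'].value_counts())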