-
Notifications
You must be signed in to change notification settings - Fork 0
/
generate_dataset.py
159 lines (134 loc) · 4.62 KB
/
generate_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
## @package generate_dataset
# Pulls album data from the Spotify API and randomly partitions it into test and training datasets.
# Takes a playlist URI as an argument to get a list of artists to pull albums from.
# Usage: % python3 generate_dataset.py \<Spotify Playlist URI\>
from __future__ import print_function # (at top of module)

import pickle
import sys
from collections import deque

import numpy as np
import spotipy
from spotipy.oauth2 import SpotifyClientCredentials
## Returns the ids of up to `limit` distinct artists reachable from the given playlist.
# Each track's credited artists are visited first; every newly seen artist's related
# artists are then queued behind them (breadth-first) until the limit is reached.
# NOTE: only the items returned by a single user_playlist_tracks() call are examined;
# further pages of the playlist are not requested.
# @param uri A Spotify playlist URI of the form spotify:user:<user>:playlist:<id>
# @param limit Maximum number of artist ids to return (default 200)
# @param client Object providing user_playlist_tracks() and artist_related_artists();
#        defaults to the module-level Spotify client `sp`
# @returns A list of unique artist ids in the order they were discovered
def get_artists(uri, limit=200, client=None):
    if client is None:
        client = sp
    if limit <= 0:
        return []

    parts = uri.split(':')
    username = parts[2]
    playlist_id = parts[4]
    items = client.user_playlist_tracks(username, playlist_id)['items']

    seen = set()
    ids = []
    for item in items:
        track = item.get('track')
        # Playlist entries without track data cannot contribute artists
        if not track:
            continue
        # Work on a private queue so the API response itself is never mutated
        queue = deque(track['artists'])
        while queue:
            artist_id = queue.popleft()['id']
            # Skip artists without an id and artists that were already collected
            if artist_id is None or artist_id in seen:
                continue
            seen.add(artist_id)
            ids.append(artist_id)
            if len(ids) >= limit:
                return ids
            # Still room: queue this artist's related artists for a later visit
            queue.extend(client.artist_related_artists(artist_id)['artists'])
    return ids
## Returns track ids for a given album, following the paged result to its last page
# @param album A Spotify album id
# @param client Object providing album_tracks() and next(); defaults to the
#        module-level Spotify client `sp`
# @returns A list of track ids in the order the pages returned them
def get_album_tracks(album, client=None):
    if client is None:
        client = sp
    ids = []
    page = client.album_tracks(album)
    while page:
        ids.extend(track['id'] for track in page['items'])
        # A falsy 'next' marks the final page; stop instead of requesting more
        page = client.next(page) if page['next'] else None
    return ids
## Returns a list of album ids produced by a given artist.
# Albums are requested with album_type='album'. When several results share the same
# name (compared case-insensitively) only the first one encountered is kept.
# @param artist A Spotify artist id
# @param client Object providing artist_albums() and next(); defaults to the
#        module-level Spotify client `sp`
# @returns A list of album ids in the order the pages returned them
def get_artist_albums(artist, client=None):
    if client is None:
        client = sp
    ids = []
    seen_names = set()  # lower-cased names of the albums already kept
    page = client.artist_albums(artist, album_type='album')
    while page:
        for album in page['items']:
            name = album['name'].lower()
            if name not in seen_names:
                seen_names.add(name)
                ids.append(album['id'])
        # A falsy 'next' marks the final page; stop instead of requesting more
        page = client.next(page) if page['next'] else None
    return ids
###########################################################################
## Gets Spotify credentials from the source environment
client_credentials_manager = SpotifyClientCredentials()
## Initializes a Spotify object with the credentials
sp = spotipy.Spotify(client_credentials_manager=client_credentials_manager)

if len(sys.argv) > 1:
    ## A Spotify playlist URI
    uri = sys.argv[1]
else:
    # No argument given: fall back to the built-in default playlist
    uri = 'spotify:user:1230457813:playlist:74oqUs80qzfsdCr0Ek9tZV'

## A list of artist ids from the given playlist
artists = get_artists(uri)

# Stage 1: write one pickled list of audio-feature objects per accepted album to
# 'data'. The context manager closes the file even if an API call raises part-way.
## An output file object
with open('data', 'wb') as fout:
    for artist in artists:
        ## A list of album ids
        albums = get_artist_albums(artist)
        for album in albums:
            ## A list of track ids
            results = get_album_tracks(album)
            # Skip albums with no tracks to look up and albums with 50 or more tracks
            if not results or len(results) >= 50:
                continue
            ## A list of json objects with song features
            tracks = sp.audio_features(results)
            # Nothing came back for this album, so there is nothing to store
            if not tracks:
                continue
            ## Indicates whether the album had any track with missing information
            bad_track = False
            for track in tracks:
                if not track:
                    bad_track = True
                    break
                # Drop the fields that are not stored in the dataset
                for key in ('type', 'track_href', 'analysis_url', 'uri'):
                    track.pop(key, None)
                # A track is complete only when exactly 14 fields remain
                if len(track) != 14:
                    bad_track = True
                    break
            # toss the album if there is a bad track
            if bad_track:
                continue
            pickle.dump(tracks, fout)

## Counts the number of albums processed
n = 0
## Counts the number of albums in the training dataset
train_n = 0
## Counts the number of albums in the test dataset
test_n = 0

# Stage 2: replay 'data' and send each album to 'test' with probability 0.1,
# otherwise to 'train'. 'data' was written by stage 1 above, so unpickling it
# here does not read externally supplied input.
## fin: input file object; test / train: output file objects for the two datasets
with open('data', 'rb') as fin, \
        open('test', 'wb') as test, \
        open('train', 'wb') as train:
    while True:
        # Only the load can signal end-of-file, so it is the only guarded call
        try:
            record = pickle.load(fin)
        except EOFError:
            break
        if np.random.binomial(1,0.1) == 0:
            pickle.dump(record, train)
            train_n += 1
        else:
            pickle.dump(record, test)
            test_n += 1
        n += 1

print('{} albums processed'.format(n))
print('Train - {}, Test - {}'.format(train_n, test_n))