-
Notifications
You must be signed in to change notification settings - Fork 0
/
helpers.py
38 lines (30 loc) · 982 Bytes
/
helpers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# -*- coding: utf-8 -*-
"""some helper functions for project 1."""
import csv
import numpy as np
def load_csv_data(data_path, sub_sample=False):
    """Load a labelled CSV dataset.

    The file must have one header row (skipped). Column 0 holds the event
    id, column 1 the class label as a string, and columns 2 onwards the
    numeric features.

    Args:
        data_path: path of the CSV file to read.
        sub_sample: if True, keep only every 50th row of the result.

    Returns:
        A tuple ``(yb, input_data, ids)``:
        yb: float array of class labels, -1.0 where the label is ``'b'``
            and 1.0 for any other label.
        input_data: 2-D float array of features (columns 2 onwards).
        ids: integer array of event ids taken from column 0.
    """
    # Labels are read in a separate pass as strings. In the numeric pass
    # below the label column parses as NaN, so features start at column 2.
    y = np.genfromtxt(
        data_path, delimiter=",", skip_header=1, dtype=str, usecols=1)
    x = np.genfromtxt(
        data_path, delimiter=",", skip_header=1)
    # genfromtxt squeezes its output. A file with a single data row would
    # yield a 0-d label array and a 1-D feature array. Restore the
    # (rows,) / (rows, columns) shapes that the indexing below relies on.
    y = np.atleast_1d(y)
    x = np.atleast_2d(x)

    # `np.int` (a plain alias of the builtin `int`) was removed in NumPy 1.24.
    ids = x[:, 0].astype(int)
    input_data = x[:, 2:]

    # convert class labels from strings to binary (-1, 1)
    yb = np.ones(len(y))
    yb[y == 'b'] = -1

    # sub-sample: keep every 50th row of labels, features and ids alike
    if sub_sample:
        yb = yb[::50]
        input_data = input_data[::50]
        ids = ids[::50]

    return yb, input_data, ids
def predict_labels(weights, data):
    """Turn linear-model scores into -1 / 1 class predictions.

    The scores are ``np.dot(data, weights)``. Every non-positive score
    becomes -1 and every positive score becomes 1. A score that is neither
    (NaN) is left as is, and the returned array keeps the dtype of the
    scores.
    """
    scores = np.dot(data, weights)
    # Both masks come from the raw scores. Writing -1 can never make an
    # entry satisfy `> 0`, so computing the masks up front is equivalent.
    is_positive = scores > 0
    is_non_positive = scores <= 0
    scores[is_non_positive] = -1
    scores[is_positive] = 1
    return scores