-
Notifications
You must be signed in to change notification settings - Fork 0
/
Mean Shift.py
90 lines (71 loc) · 2.39 KB
/
Mean Shift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#================================================================================================================
#----------------------------------------------------------------------------------------------------------------
# MEAN SHIFT
#----------------------------------------------------------------------------------------------------------------
#================================================================================================================
import math
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import style
import pandas
import datetime
from sklearn import preprocessing, cross_validation
# Apply matplotlib's 'ggplot' style sheet to every figure drawn by this script.
plt.style.use('ggplot')
class CustomMS:
    """Stub for a hand-written clusterer (presumably mean shift, per the file header).

    The class currently keeps no state and defines no behaviour beyond
    construction.
    """

    def __init__(self):
        """Create an instance; no attributes are set."""
def handle_non_numerical_data(df):
    """Replace every non-numeric column of *df* with integer ids.

    Each distinct value in a non-numeric column is assigned an id, starting
    at 0, in order of first appearance, so the encoding is reproducible from
    run to run. Columns that already have a numeric dtype are left untouched.

    The frame is modified in place and the same object is returned.
    """
    for column in df.columns:
        # Any numeric dtype (not just int64/float64) already holds usable
        # numbers; re-labelling it would destroy the real values.
        if pandas.api.types.is_numeric_dtype(df[column]):
            continue

        values = df[column].tolist()
        # dict.fromkeys de-duplicates while preserving first-seen order, which
        # avoids the arbitrary iteration order of a set.
        ids = {value: index
               for index, value in enumerate(dict.fromkeys(values))}
        df[column] = [ids[value] for value in values]
    return df


def main():
    """Load the Titanic sheet, build a scaled feature matrix and fit CustomMS.

    Reads ``data/titanic.xls``, whose columns are:

        Pclass      Passenger Class (1 = 1st; 2 = 2nd; 3 = 3rd)
        survived    Survival (0 = No; 1 = Yes)
        name        Name
        sex         Sex
        age         Age
        sibsp       Number of Siblings/Spouses Aboard
        parch       Number of Parents/Children Aboard
        ticket      Ticket Number
        fare        Passenger Fare (British pound)
        cabin       Cabin
        embarked    Port of Embarkation (C = Cherbourg; Q = Queenstown; S = Southampton)
        boat        Lifeboat
        body        Body Identification Number
        home.dest   Home/Destination
    """
    df = pandas.read_excel('data/titanic.xls')

    # Drop every unused column up front so none of them is encoded needlessly.
    df = df.drop(columns=['body', 'name', 'ticket', 'home.dest'])
    df = df.fillna(0)
    df = handle_non_numerical_data(df)

    # Features are every remaining column except the survival label,
    # standardised by sklearn's preprocessing.scale.
    X = np.array(df.drop(columns=['survived']).astype(float))
    X = preprocessing.scale(X)

    ms = CustomMS()
    # NOTE(review): this relies on CustomMS providing fit(dataset=...);
    # confirm that method exists before running.
    ms.fit(dataset=X)
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()