-
Notifications
You must be signed in to change notification settings - Fork 0
/
perceptron_iris.py
73 lines (59 loc) · 2.24 KB
/
perceptron_iris.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
class Perceptron:
    """Binary perceptron classifier trained with the perceptron learning rule.

    A prediction is the unit step of ``w . x + b`` and is therefore always
    0 or 1, so training targets are expected to be encoded as 0/1.
    """

    def __init__(self, learning_rate=0.1, n_iters=10000):
        """Store the hyperparameters; the model itself is created by ``fit``.

        Args:
            learning_rate: Step size applied to every correction.
            n_iters: Maximum number of full passes (epochs) over the data.
        """
        self.lr = learning_rate
        self.n_iters = n_iters
        # Learned parameters; both stay None until fit() has been called.
        self.weights = None
        self.bias = None

    def fit(self, X, y):
        """Learn weights and bias from the training data.

        Training restarts from zero weights on every call. It stops as soon
        as one full pass over the data makes no correction: from that point
        on the parameters can no longer change, so the result is the same as
        running all ``n_iters`` passes, only faster.

        Args:
            X: 2-D array-like of shape (n_samples, n_features).
            y: Array-like of targets, indexed positionally alongside ``X``.

        Raises:
            ValueError: If ``X`` is not two-dimensional.
        """
        # Accept lists and other array-likes, and make y[idx] positional.
        X = np.asarray(X)
        y = np.asarray(y)
        _, n_features = X.shape

        self.weights = np.zeros(n_features)
        self.bias = 0.0

        # Perceptron learning rule: nudge the parameters only on mistakes.
        for _ in range(self.n_iters):
            corrected = False
            for idx, x_i in enumerate(X):
                linear_output = np.dot(x_i, self.weights) + self.bias
                y_predicted = self.step_function(linear_output)

                # Zero when the sample is already classified correctly.
                update = self.lr * (y[idx] - y_predicted)
                if np.any(update):
                    self.weights += update * x_i
                    self.bias += update
                    corrected = True

            if not corrected:
                # A clean epoch means later epochs would be no-ops as well.
                break

    def predict(self, X):
        """Return the 0/1 prediction for each row of ``X``.

        Must be called after ``fit``; before that the parameters are None.

        Args:
            X: Array-like of shape (n_samples, n_features), or a single
                sample of shape (n_features,).

        Returns:
            NumPy array of 0/1 predictions.
        """
        linear_output = np.dot(X, self.weights) + self.bias
        return self.step_function(linear_output)

    def step_function(self, x):
        """Unit step activation: 1 where ``x >= 0``, otherwise 0."""
        return np.where(x >= 0, 1, 0)
def main():
    """Train the perceptron on a two-class version of Iris and plot the data.

    The Perceptron in this file predicts only 0 or 1, so the three species
    codes are collapsed into a binary target (species code 0 versus the other
    two) before training. The test-set accuracy is printed and the first two
    (standardized) training features are shown in a scatter plot.
    """
    # Load iris dataset
    iris = load_iris()
    X = iris.data
    # Binary target: 0 = species code 0, 1 = any other species. Feeding the
    # raw 0/1/2 codes to a 0/1 classifier makes class 2 unpredictable.
    y = (iris.target != 0).astype(int)

    # Split data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Feature scaling (statistics are fitted on the training split only)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)
    X_test = scaler.transform(X_test)

    # Initialize and train perceptron
    perceptron = Perceptron(learning_rate=0.2, n_iters=20000)
    perceptron.fit(X_train, y_train)

    # Make predictions
    predictions = perceptron.predict(X_test)

    # Evaluate accuracy
    accuracy = np.mean(predictions == y_test)
    print("Accuracy:", accuracy)

    # Plot the data; the features were standardized above, so the axes are
    # no longer in centimetres.
    plt.figure(figsize=(10, 6))
    scatter = plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap='viridis')
    plt.xlabel('Sepal Length (standardized)')
    plt.ylabel('Sepal Width (standardized)')
    plt.title('Iris Dataset - Sepal Length vs. Sepal Width')
    # One legend entry per class colour, in ascending order of target value.
    handles, _ = scatter.legend_elements()
    plt.legend(handles, [str(iris.target_names[0]), 'other species'], title='Training Data')
    plt.show()


# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()