diff --git a/nn.py b/nn.py index 8fc11ff..6f8b7de 100644 --- a/nn.py +++ b/nn.py @@ -32,11 +32,13 @@ nr_correct += int(np.argmax(o) == np.argmax(l)) # Backpropagation output -> hidden (cost function derivative) - delta_o = o - l - w_h_o += -learn_rate * delta_o @ np.transpose(h) - b_h_o += -learn_rate * delta_o + delta_o = (o - l)*o*(1-o) + # Backpropagation hidden -> input (activation function derivative) delta_h = np.transpose(w_h_o) @ delta_o * (h * (1 - h)) + #update weights + w_h_o += -learn_rate * delta_o @ np.transpose(h) + b_h_o += -learn_rate * delta_o w_i_h += -learn_rate * delta_h @ np.transpose(img) b_i_h += -learn_rate * delta_h