Added Gradient Descent [Python] #348
Closed

Commits (22)
a64e36d  Added Gradient Descent
84c1a6d  Fixed typo
e7246c4  Merge branch 'master' into master
9ddb8b9  Changed gradient_descent.py according to PEP8 guidelines.
f3cb691  Fixed typos in comments
3f7c2d1  Update gradient_descent.py
29ce52d  Update gradient_descent.py
d88fe72  Merge branch 'master' into master
3556dc3  Update gradient_descent.py
a2d3d46  Merge branch 'master' into master
c4f79bf  Merge branch 'master' into master
f0a6367  Merge branch 'master' into master
77c58df  Modified gradient_descent.py
1e0a530  Updated gradient_descent.py
6da9dab  Updated gradient_descent.py
c3d6150  Update gradient_descent.py
c74af4b  Merge branch 'master' into master
3802026  Graph is shown when sys argument is passed.
7a0b0c9  Update gradient_descent.py
5eea0f9  Update pip2-requirements.txt
601d0fb  Merge branch 'master' into master
d592d32  Merge branch 'master' into master
gradient_descent.py
@@ -0,0 +1,121 @@
""" | ||
Implementation of gradient descent algorithm for minimizing cost of a linear hypothesis function. | ||
""" | ||
import numpy | ||
|
||
# List of input, output pairs | ||
train_data = (((5, 2, 3), 15), ((6, 5, 9), 25), | ||
((11, 12, 13), 41), ((1, 1, 1), 8), ((11, 12, 13), 41)) | ||
test_data = (((515, 22, 13), 555), ((61, 35, 49), 150)) | ||
parameter_vector = [2, 4, 1, 5] | ||
m = len(train_data) | ||
LEARNING_RATE = 0.009 | ||
|
||
|
||
def _error(example_no, data_set='train'): | ||
""" | ||
:param data_set: train data or test data | ||
:param example_no: example number whose error has to be checked | ||
:return: error in example pointed by example number. | ||
""" | ||
return calculate_hypothesis_value(example_no, data_set) - output(example_no, data_set) | ||
|
||
|
||
def _hypothesis_value(data_input_tuple): | ||
""" | ||
Calculates hypothesis function value for a given input | ||
:param data_input_tuple: Input tuple of a particular example | ||
:return: Value of hypothesis function at that point. | ||
Note that there is an 'biased input' whose value is fixed as 1. | ||
It is not explicitly mentioned in input data.. But, ML hypothesis functions use it. | ||
So, we have to take care of it separately. Line 36 takes care of it. | ||
""" | ||
hyp_val = 0 | ||
for i in range(len(parameter_vector) - 1): | ||
hyp_val += data_input_tuple[i]*parameter_vector[i+1] | ||
hyp_val += parameter_vector[0] | ||
return hyp_val | ||
|
||
|
||
def output(example_no, data_set): | ||
""" | ||
:param data_set: test data or train data | ||
:param example_no: example whose output is to be fetched | ||
:return: output for that example | ||
""" | ||
if data_set == 'train': | ||
return train_data[example_no][1] | ||
elif data_set == 'test': | ||
return test_data[example_no][1] | ||
|
||
|
||
def calculate_hypothesis_value(example_no, data_set): | ||
""" | ||
Calculates hypothesis value for a given example | ||
:param data_set: test data or train_data | ||
:param example_no: example whose hypothesis value is to be calculated | ||
:return: hypothesis value for that example | ||
""" | ||
if data_set == "train": | ||
return _hypothesis_value(train_data[example_no][0]) | ||
elif data_set == "test": | ||
return _hypothesis_value(test_data[example_no][0]) | ||
|
||
|
||
def summation_of_cost_derivative(index, end=m): | ||
""" | ||
Calculates the sum of cost function derivative | ||
:param index: index wrt derivative is being calculated | ||
:param end: value where summation ends, default is m, number of examples | ||
:return: Returns the summation of cost derivative | ||
Note: If index is -1, this means we are calculating summation wrt to biased parameter. | ||
""" | ||
summation_value = 0 | ||
for i in range(end): | ||
if index == -1: | ||
summation_value += _error(i) | ||
else: | ||
summation_value += _error(i)*train_data[i][0][index] | ||
return summation_value | ||
|
||
|
||
def get_cost_derivative(index): | ||
""" | ||
:param index: index of the parameter vector wrt to derivative is to be calculated | ||
:return: derivative wrt to that index | ||
Note: If index is -1, this means we are calculating summation wrt to biased parameter. | ||
""" | ||
cost_derivative_value = summation_of_cost_derivative(index, m)/m | ||
return cost_derivative_value | ||
|
||
|
||
def run_gradient_descent(): | ||
global parameter_vector | ||
# Tune these values to set a tolerance value for predicted output | ||
absolute_error_limit = 0.000002 | ||
relative_error_limit = 0 | ||
j = 0 | ||
while True: | ||
j += 1 | ||
temp_parameter_vector = [0, 0, 0, 0] | ||
for i in range(0, len(parameter_vector)): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
cost_derivative = get_cost_derivative(i-1) | ||
temp_parameter_vector[i] = parameter_vector[i] - \ | ||
LEARNING_RATE*cost_derivative | ||
if numpy.allclose(parameter_vector, temp_parameter_vector, | ||
atol=absolute_error_limit, rtol=relative_error_limit): | ||
break | ||
parameter_vector = temp_parameter_vector | ||
print("Number of iterations:", j) | ||
|
||
|
||
def test_gradient_descent(): | ||
for i in range(len(test_data)): | ||
print("Actual output value:", output(i, 'test')) | ||
print("Hypothesis output:", calculate_hypothesis_value(i, 'test')) | ||
|
||
|
||
if __name__ == '__main__': | ||
run_gradient_descent() | ||
print("\nTesting gradient descent for a linear hypothesis function.\n") | ||
test_gradient_descent() |
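
A side note for reviewers: the per-parameter loop above implements the batch update theta_j := theta_j - alpha * (1/m) * sum((h(x_i) - y_i) * x_ij). Below is a minimal vectorized sketch of the same update, reusing train_data, parameter_vector, and LEARNING_RATE from the file; the names X, y, and theta are illustrative and not part of the PR.

import numpy as np

# Build the design matrix with a leading column of ones for the bias term,
# mirroring the separate parameter_vector[0] handling in the loop version.
X = np.array([features for features, _ in train_data], dtype=float)
X = np.hstack([np.ones((X.shape[0], 1)), X])
y = np.array([target for _, target in train_data], dtype=float)
theta = np.array(parameter_vector, dtype=float)

for iteration in range(1, 1_000_000):
    error = X @ theta - y            # hypothesis minus actual, all examples at once
    gradient = X.T @ error / len(y)  # same averages as get_cost_derivative
    new_theta = theta - LEARNING_RATE * gradient
    if np.allclose(theta, new_theta, atol=2e-6, rtol=0):
        break
    theta = new_theta

Both versions compute identical gradients; the matrix form just folds the bias handling into the column of ones instead of a special-cased index -1.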
Remove Trailing Whitespaces

Where is the issue exactly?

At the end of Line 31

Got it, thanks.
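
For anyone tracking down the same review note: a quick way to list the offending lines is a short Python check (a sketch; it assumes the file is saved locally as gradient_descent.py):

# Print every line of gradient_descent.py that ends in spaces or tabs.
with open("gradient_descent.py") as f:
    for line_no, line in enumerate(f, start=1):
        text = line.rstrip("\n")
        if text != text.rstrip():
            print(line_no, repr(text))

Git's built-in check (git diff --check) reports the same problem before a commit.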