
ADBench LSTM test #152

Merged
merged 2 commits into from
Nov 14, 2019
Conversation

toelli-msft
Contributor

No description provided.

@toelli-msft
Contributor Author

This collection of code contains a C++ wrapper for adbench-lstm.ks so that it can be tested against a Python reference implementation. It would be good to get at least one pair of eyes on it before merging.
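One common shape for such a test is to drive both implementations with the same inputs and compare the outputs numerically. A minimal sketch with hypothetical helper names (the PR's actual harness may differ):

```python
import numpy as np

def check_against_reference(impl_fn, reference_fn, inputs, tol=1e-8):
    # Hypothetical helper: run both implementations on the same
    # inputs and compare elementwise within a tolerance.
    got = np.asarray(impl_fn(*inputs))
    want = np.asarray(reference_fn(*inputs))
    assert got.shape == want.shape, (got.shape, want.shape)
    assert np.allclose(got, want, atol=tol), np.max(np.abs(got - want))

# Stand-in functions, just to show the call shape:
check_against_reference(lambda x: x * 2.0, lambda x: x + x, (np.arange(3.0),))
```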

@pashminacameron

@toelli-msft It appears you have a simple LSTM implementation. I added a sequence LSTM implementation. As I am not sure what the purpose here is, lstm2.py may not be what you wanted, but having it in there will help the discussion. I think we should test at least the sort of LSTM I have pushed before making any speed claims.

gates = np.concatenate((inp, hidden, inp, hidden), 0) * weight + bias
hidden_size = hidden.shape[0]

forget = sigmoid(gates[0:hidden_size])
Contributor

Maybe assert something here about size/shape of hidden vs size of inp ?
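For illustration, such a precondition might look like the sketch below (a hypothetical helper; the shapes are inferred from the gates line above, where the gates vector concatenates (inp, hidden, inp, hidden)):

```python
import numpy as np

def check_lstm_shapes(weight, bias, hidden, inp):
    # Hypothetical precondition sketch: the diagonal gates multiply
    # the concatenated (inp, hidden, inp, hidden) vector elementwise,
    # so inp and hidden must have the same length, and weight/bias
    # must cover all four gate blocks.
    assert inp.shape == hidden.shape, (inp.shape, hidden.shape)
    assert weight.shape == (4 * hidden.shape[0],), weight.shape
    assert bias.shape == weight.shape, (bias.shape, weight.shape)

check_lstm_shapes(np.ones(12), np.zeros(12), np.zeros(3), np.ones(3))
```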

Contributor Author

I'm trying to change the original source code as little as possible. See the new top-level comment.

ypred, new_state = lstm_predict(main_params, extra_params, all_states[t], _input)
all_states.append(new_state)
ynorm = ypred - np.log(sum(np.exp(ypred), 2))
ygold = sequence[t + 1]
Contributor

super-nit: is that yg_old or y_gold?

Contributor Author

The latter, but: I'm trying to change the original source code as little as possible. See the new top-level comment.
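As an aside, the `ypred - log(sum(exp(ypred)))` expression above is a log-softmax normalization. For reference only, since the PR deliberately keeps ADBench's original form, a numerically stable version subtracts the maximum before exponentiating:

```python
import numpy as np

def log_softmax(ypred):
    # Stable log-softmax: shifting by the max avoids overflow in exp,
    # and is mathematically equal to ypred - log(sum(exp(ypred))).
    m = np.max(ypred)
    return ypred - (m + np.log(np.sum(np.exp(ypred - m))))

ynorm = log_softmax(np.array([1.0, 2.0, 3.0]))
# exp(ynorm) sums to 1 by construction
```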

@@ -0,0 +1,40 @@
# There's a lot of duplication between this and
# build_and_test_mnistcnn.sh, but we will follow the Rule of Three
Contributor

add similar comment to build_and_test_mnistcnn.sh? (I think the Wikipedia reference is probably unnecessary/OTT but it is humorous to include it ;) )

Contributor Author

OK, good idea


"""
There are many formulations of LSTMs. This code follows the formulation from
https://cs224d.stanford.edu/lecture_notes/LectureNotes4.pdf with some simplifications
Contributor

Maybe a comment what this is for? Do I understand correctly that this is not what you are testing against, it's for eyeball comparison only or something?


Alan - here's the reference for what this is for.

from ksc.adbench_lstm.lstm import (
lstm_model, lstm_predict, lstm_objective, sigmoid)

ten = np.ndarray
Contributor

Is this ten used anywhere? I don't see it. (And d ?)

Contributor Author

Thanks, it's not used. (d is used in several places though)

import random
import numpy as np

from ksc.adbench_lstm.lstm import (
Contributor

@acl33 acl33 Nov 14, 2019

Personal preference but I would probably import ksc.adbench_lstm.lstm as k and then I'm testing a.lstm_model against k.lstm_model etc. Up to you...

Contributor Author

I like that idea, thanks!

Contributor

@acl33 acl33 left a comment

LGTM, a few nits/suggestions.

@toelli-msft
Contributor Author

Thanks Alan and Pashmina.

For those who are wondering why this is not a standard LSTM implementation: we have to copy exactly whatever ADBench does so that we are comparing like-for-like. ADBench explicitly titles its graph "D-LSTM" (Diagonal LSTM) to make clear that it is not a standard one. We definitely want ADBench to have a standard LSTM, and there is an open ADBench issue to implement one, but it has not been implemented yet.
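To make the distinction concrete, here is a rough side-by-side sketch (not the PR's code; names and shapes are illustrative only): the D-LSTM applies per-gate weight vectors elementwise, where a standard LSTM applies full weight matrices.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dlstm_gates(weight, bias, hidden, inp):
    # Diagonal LSTM (as tested here): the weight vector multiplies the
    # concatenated (inp, hidden, inp, hidden) vector elementwise, so
    # each gate unit sees only its own component.
    return sigmoid(np.concatenate((inp, hidden, inp, hidden), 0) * weight + bias)

def std_lstm_gates(W, b, hidden, inp):
    # Standard LSTM: a full matrix multiply mixes every input and
    # hidden component into every gate unit.
    return sigmoid(W @ np.concatenate((inp, hidden), 0) + b)

h, x = np.zeros(3), np.ones(3)
g_diag = dlstm_gates(np.ones(12), np.zeros(12), h, x)          # shape (12,)
g_full = std_lstm_gates(np.ones((12, 6)), np.zeros(12), h, x)  # shape (12,)
```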

Hopefully the new comment on top of lstm.py makes this clearer too.

Pashmina, thanks for the reference implementation of a standard LSTM. Perhaps ADBench can use it as a reference.

@pashminacameron pashminacameron left a comment

Thanks for the context, Tom.

@toelli-msft toelli-msft merged commit c20e98c into master Nov 14, 2019
@toelli-msft toelli-msft deleted the toelli/adbench-lstm-test branch November 14, 2019 16:14
3 participants