-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
in terms of a Python reference implementation
- Loading branch information
1 parent
fe6a927
commit e667ba4
Showing
5 changed files
with
319 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# A reference implementation of D-LSTM, so that we can test our | ||
# Knossos implementation. | ||
# | ||
# See also | ||
# | ||
# https://github.com/awf/ADBench/blob/e5f72ab5dcb453b1bb72dd000e5add6b90502ec4/src/python/modules/Tensorflow/TensorflowLSTM.py | ||
|
||
import numpy as np | ||
|
||
def sigmoid(x): | ||
return 1.0 / (1.0 + np.exp(-x)) | ||
|
||
def lstm_model(weight, bias, hidden, cell, inp): | ||
gates = np.concatenate((inp, hidden, inp, hidden), 0) * weight + bias | ||
hidden_size = hidden.shape[0] | ||
|
||
forget = sigmoid(gates[0:hidden_size]) | ||
ingate = sigmoid(gates[hidden_size:2*hidden_size]) | ||
outgate = sigmoid(gates[2*hidden_size:3*hidden_size]) | ||
change = np.tanh(gates[3*hidden_size:]) | ||
|
||
cell = cell * forget + ingate * change | ||
hidden = outgate * np.tanh(cell) | ||
|
||
return (hidden, cell) | ||
|
||
def lstm_predict(w, w2, s, x): | ||
s2 = s.copy() | ||
# NOTE not sure if this should be element-wise or matrix multiplication | ||
x = x * w2[0] | ||
for i in range(0, len(s), 2): | ||
(s2[i], s2[i + 1]) = lstm_model(w[i], w[i + 1], s[i], s[i + 1], x) | ||
x = s2[i] | ||
return (x * w2[1] + w2[2], s2) | ||
|
||
def lstm_objective(main_params, extra_params, state, sequence, _range=None): | ||
if _range is None: | ||
_range = range(0, len(sequence) - 1) | ||
|
||
total = 0.0 | ||
count = 0 | ||
_input = sequence[_range[0]] | ||
all_states = [state] | ||
for t in _range: | ||
ypred, new_state = lstm_predict(main_params, extra_params, all_states[t], _input) | ||
all_states.append(new_state) | ||
ynorm = ypred - np.log(sum(np.exp(ypred), 2)) | ||
ygold = sequence[t + 1] | ||
total += sum(ygold * ynorm) | ||
count += ygold.shape[0] | ||
_input = ygold | ||
return -total / count |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
# Remarkably, adbench-lstm.ks agrees exactly with our reference | ||
# implementation, so we set the almost equal check to check to a very | ||
# large number of decimal places. | ||
|
||
import adbench_lstm as a | ||
import random | ||
import numpy as np | ||
|
||
from ksc.adbench_lstm.lstm import ( | ||
lstm_model, lstm_predict, lstm_objective, sigmoid) | ||
|
||
ten = np.ndarray | ||
d = a.vec_double | ||
|
||
def r(): | ||
return random.random() * 2 - 1 | ||
|
||
def rv(n): | ||
return [r() for _ in range(n)] | ||
|
||
def rvv(n, m): | ||
return [rv(m) for _ in range(n)] | ||
|
||
def concat(l): | ||
return sum(l, []) | ||
|
||
# The ks::vec __iter__ method that is automatically generated by | ||
# pybind11 is one that keeps going off the end of the vec and never | ||
# stops. Until I get time to dig into how to make it generate a | ||
# better one, here's a handy utility function. | ||
def to_list(x): | ||
return [x[i] for i in range(len(x))] | ||
|
||
def main(): | ||
assert_equal_model() | ||
assert_equal_predict_and_objective() | ||
print("The assertions didn't throw any errors, so " | ||
"everything must be good!") | ||
|
||
def assert_equal_model(): | ||
h = 2 | ||
|
||
w1 = rv(h) | ||
w2 = rv(h) | ||
w3 = rv(h) | ||
w4 = rv(h) | ||
|
||
b1 = rv(h) | ||
b2 = rv(h) | ||
b3 = rv(h) | ||
b4 = rv(h) | ||
|
||
hidden = rv(h) | ||
cell = rv(h) | ||
input_ = rv(h) | ||
|
||
weight = concat([w1, w2, w3, w4]) | ||
bias = concat([b1, b2, b3, b4]) | ||
|
||
(ao0, ao1) = a.lstm_model(d(w1), | ||
d(b1), | ||
d(w2), | ||
d(b2), | ||
d(w3), | ||
d(b3), | ||
d(w4), | ||
d(b4), | ||
d(hidden), | ||
d(cell), | ||
d(input_)) | ||
|
||
ao0l = to_list(ao0) | ||
ao1l = to_list(ao1) | ||
|
||
nd_weight = np.array(weight) | ||
|
||
(mo0, mo1) = lstm_model(nd_weight, | ||
np.array(bias), | ||
np.array(hidden), | ||
np.array(cell), | ||
np.array(input_)) | ||
|
||
print(mo0) | ||
print(ao0l) | ||
print(mo1) | ||
print(ao1l) | ||
|
||
np.testing.assert_almost_equal(ao0l, mo0, decimal=12, err_msg="Model 1") | ||
np.testing.assert_almost_equal(ao1l, mo1, decimal=12, err_msg="Model 2") | ||
|
||
def assert_equal_predict_and_objective(): | ||
l = 2 | ||
h = 10 | ||
|
||
w1 = rvv(l, h) | ||
w2 = rvv(l, h) | ||
w3 = rvv(l, h) | ||
w4 = rvv(l, h) | ||
|
||
b1 = rvv(l, h) | ||
b2 = rvv(l, h) | ||
b3 = rvv(l, h) | ||
b4 = rvv(l, h) | ||
|
||
hidden = rvv(l, h) | ||
cell = rvv(l, h) | ||
|
||
input_ = rv(h) | ||
|
||
input_weight = rv(h) | ||
output_weight = rv(h) | ||
output_bias = rv(h) | ||
|
||
tww = np.array(concat([concat([w1i, w2i, w3i, w4i]), | ||
concat([b1i, b2i, b3i, b4i])] | ||
for (w1i, w2i, w3i, w4i, b1i, b2i, b3i, b4i) | ||
in zip(w1, w2, w3, w4, b1, b2, b3, b4))) | ||
|
||
ts = np.array(concat(([hiddeni, celli] | ||
for (hiddeni, celli) | ||
in zip(hidden, cell)))) | ||
|
||
tww2 = np.array([input_weight, output_weight, output_bias]) | ||
|
||
tinput_ = np.array(input_) | ||
|
||
print(tww.shape) | ||
print(tww2.shape) | ||
print(ts.shape) | ||
print(tinput_.shape) | ||
|
||
(tp0, tp1) = lstm_predict(tww, tww2, ts, tinput_) | ||
|
||
|
||
tp0l = tp0.tolist() | ||
tp1l = tp1.tolist() | ||
|
||
wf_etc = [tuple(d(i) for i in tu) | ||
for tu in zip(w1, b1, w2, b2, w3, b3, w4, b4, hidden, cell)] | ||
|
||
(v, vtvv) = a.lstm_predict(a.vec_tuple_vec10(wf_etc), | ||
d(input_weight), | ||
d(output_weight), | ||
d(output_bias), | ||
d(input_)) | ||
|
||
vl = to_list(v) | ||
vtvvl = concat([to_list(v1), to_list(v2)] for (v1, v2) in to_list(vtvv)) | ||
|
||
to = lstm_objective(tww, tww2, ts, [tinput_, tinput_]) | ||
tol = to.tolist() | ||
|
||
print(tol) | ||
|
||
aol = a.lstm_objective(a.vec_tuple_vec10(wf_etc), | ||
d(input_weight), | ||
d(output_weight), | ||
d(output_bias), | ||
a.vec_tuple_vec2([(d(input_), d(input_))])) | ||
|
||
print(tp0l) | ||
print(vl) | ||
print(tp1l) | ||
print(vtvvl) | ||
print(tol) | ||
print(aol) | ||
|
||
np.testing.assert_almost_equal(tp0l, vl, decimal=12, err_msg="Predict 1") | ||
np.testing.assert_almost_equal(tp1l, vtvvl, decimal=12, err_msg="Predict 2") | ||
np.testing.assert_almost_equal(tol, aol, decimal=12, err_msg="Objective") | ||
|
||
if __name__ == '__main__': main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# There's a lot of duplication between this and | ||
# build_and_test_mnistcnn.sh, but we will follow the Rule of Three | ||
# | ||
# https://en.wikipedia.org/wiki/Rule_of_three_(computer_programming) | ||
|
||
set -e | ||
|
||
KNOSSOS=$1 | ||
PYBIND11=$2 | ||
|
||
RUNTIME=$KNOSSOS/src/runtime | ||
OBJ=$KNOSSOS/obj/test/ksc | ||
PYBIND11_INCLUDE=$PYBIND11/include | ||
|
||
PYTHON3_CONFIG_EXTENSION_SUFFIX=$(python3-config --extension-suffix) | ||
|
||
MODULE_NAME=adbench_lstm | ||
MODULE_FILE="$OBJ/$MODULE_NAME$PYTHON3_CONFIG_EXTENSION_SUFFIX" | ||
|
||
echo Compiling... | ||
|
||
g++-7 -fmax-errors=5 \ | ||
-fdiagnostics-color=always \ | ||
-Wall \ | ||
-Wno-unused \ | ||
-Wno-maybe-uninitialized \ | ||
-I$RUNTIME \ | ||
-I$OBJ \ | ||
-I$PYBIND11_INCLUDE \ | ||
$(PYTHONPATH=$PYBIND11 python3 -m pybind11 --includes) \ | ||
-O3 \ | ||
-std=c++17 \ | ||
-shared \ | ||
-fPIC \ | ||
-o $MODULE_FILE \ | ||
-DMNISTCNNCPP_MODULE_NAME=$MODULE_NAME \ | ||
$KNOSSOS/test/ksc/adbench-lstmpy.cpp | ||
|
||
KSCPY=$KNOSSOS/src/python | ||
PYTHONPATH=$OBJ:$KSCPY python3 -m ksc.adbench_lstm.test |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* There's a lot of duplication between this and mnistcnnpy.cpp, but | ||
* we will follow the Rule of Three | ||
* | ||
* https://en.wikipedia.org/wiki/Rule_of_three_(computer_programming) | ||
*/ | ||
|
||
#include <pybind11/pybind11.h> | ||
#include <pybind11/stl.h> | ||
#include <pybind11/operators.h> | ||
|
||
namespace py = pybind11; | ||
|
||
#include "adbench-lstm.cpp" | ||
|
||
int ks::main() { return 0; }; | ||
|
||
template<typename T> | ||
void declare_vec(py::module &m, std::string typestr) { | ||
using Class = ks::vec<T>; | ||
std::string pyclass_name = std::string("vec_") + typestr; | ||
py::class_<Class>(m, pyclass_name.c_str()) | ||
.def(py::init<>()) | ||
.def(py::init<std::vector<T> const&>()) | ||
.def("is_zero", &Class::is_zero) | ||
.def("__getitem__", [](const ks::vec<T> &a, const int &b) { | ||
return a[b]; | ||
}) | ||
.def("__len__", [](const ks::vec<T> &a) { return a.size(); }); | ||
} | ||
|
||
// In the future it might make more sense to move the vec type | ||
// definitions to a general Knossos CPP types Python module. | ||
// | ||
// I don't know how to make a single Python type that works for vecs | ||
// of many different sorts of contents. It seems like it must be | ||
// possible because Python tuples map to std::tuples regardless of | ||
// their contents. I'll look into it later. For now I'll just have a | ||
// bunch of verbose replication. | ||
PYBIND11_MODULE(MNISTCNNCPP_MODULE_NAME, m) { | ||
declare_vec<double>(m, std::string("double")); | ||
declare_vec<std::tuple<ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>, ks::vec<double>>>(m, std::string("tuple_vec10")); | ||
declare_vec<std::tuple<ks::vec<double>, ks::vec<double>>>(m, std::string("tuple_vec2")); | ||
declare_vec<ks::vec<double> >(m, std::string("vec_double")); | ||
declare_vec<ks::vec<ks::vec<double> > >(m, std::string("vec_vec_double")); | ||
declare_vec<ks::vec<ks::vec<ks::vec<double> > > >(m, std::string("vec_vec_vec_double")); | ||
declare_vec<ks::vec<ks::vec<ks::vec<ks::vec<double> > > > >(m, std::string("vec_vec_vec_vec_double")); | ||
m.def("sigmoid", &ks::sigmoid); | ||
m.def("logsumexp", &ks::logsumexp); | ||
m.def("lstm_model", &ks::lstm_model); | ||
m.def("lstm_predict", &ks::lstm_predict); | ||
m.def("lstm_objective", &ks::lstm_objective); | ||
} |