Skip to content

Commit

Permalink
Merge pull request #46 from marcpinet/feat-add-lstm-rnn
Browse files Browse the repository at this point in the history
Feat add lstm rnn
  • Loading branch information
marcpinet authored Nov 6, 2024
2 parents 85b8602 + a785241 commit 4106a3d
Show file tree
Hide file tree
Showing 7 changed files with 966 additions and 131 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ I intend to improve the neural networks and add more features in the future.

## 📦 Features

- Many layers (input, activation, dense, dropout, conv1d/2d, maxpooling1d/2d, flatten, embedding, batchnormalization, and more) 🧠
- Many layers (wrappers, dense, dropout, conv1d/2d, pooling1d/2d, flatten, embedding, batchnormalization, lstm, attention and more) 🧠
- Many activation functions (sigmoid, tanh, relu, leaky relu, softmax, linear, elu, selu) 📈
- Many loss functions (mean squared error, mean absolute error, categorical crossentropy, binary crossentropy, huber loss) 📉
- Many optimizers (sgd, momentum, rmsprop, adam) 📊
Expand All @@ -32,8 +32,9 @@ pip install neuralnetlib

## 💡 How to use

See [this file](examples/classification-regression/simple_mnist_multiclass.py) for a simple example of how to use the library.
For a more advanced example, see [this file](examples/cnn-classification/simple_cnn_classification_mnist.py).
See [this file](examples/classification-regression/mnist_multiclass.ipynb) for a simple example of how to use the library.<br>
For a more advanced example, see [this file](examples/cnn-classification/cnn_classification_mnist.ipynb) for using CNN.<br>
You can also check [this file](examples/classification-regression/sentiment_analysis.ipynb) for text classification using RNN.<br>

More examples in [this folder](examples).

Expand Down
72 changes: 40 additions & 32 deletions examples/classification-regression/mnist_multiclass.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:17.470315300Z",
"start_time": "2024-09-22T21:23:15.274765600Z"
"end_time": "2024-11-06T21:20:11.860716600Z",
"start_time": "2024-11-06T21:20:03.030565100Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -52,8 +52,8 @@
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:17.612787400Z",
"start_time": "2024-09-22T21:23:17.472315400Z"
"end_time": "2024-11-06T21:20:12.002523Z",
"start_time": "2024-11-06T21:20:11.862717900Z"
}
},
"outputs": [],
Expand All @@ -73,8 +73,8 @@
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:17.702612600Z",
"start_time": "2024-09-22T21:23:17.609786900Z"
"end_time": "2024-11-06T21:20:12.091137200Z",
"start_time": "2024-11-06T21:20:11.999925Z"
}
},
"outputs": [],
Expand All @@ -97,8 +97,8 @@
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:17.718270700Z",
"start_time": "2024-09-22T21:23:17.704611500Z"
"end_time": "2024-11-06T21:20:12.107204400Z",
"start_time": "2024-11-06T21:20:12.092135900Z"
}
},
"outputs": [],
Expand Down Expand Up @@ -134,8 +134,8 @@
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:17.763653100Z",
"start_time": "2024-09-22T21:23:17.719270900Z"
"end_time": "2024-11-06T21:20:12.152371800Z",
"start_time": "2024-11-06T21:20:12.108612300Z"
}
},
"outputs": [
Expand Down Expand Up @@ -177,26 +177,34 @@
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:28.493706600Z",
"start_time": "2024-09-22T21:23:17.734301400Z"
"end_time": "2024-11-06T21:21:10.172232400Z",
"start_time": "2024-11-06T21:20:12.124120500Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[==============================] 100% Epoch 1/10 - loss: 0.5703 - accuracy_score: 0.8109 - 1.10s\n",
"[==============================] 100% Epoch 2/10 - loss: 0.2287 - accuracy_score: 0.9336 - 1.05s\n",
"[==============================] 100% Epoch 3/10 - loss: 0.1950 - accuracy_score: 0.9437 - 1.13s\n",
"[==============================] 100% Epoch 4/10 - loss: 0.1791 - accuracy_score: 0.9468 - 1.02s\n",
"[==============================] 100% Epoch 5/10 - loss: 0.1600 - accuracy_score: 0.9525 - 1.12s\n",
"[==============================] 100% Epoch 6/10 - loss: 0.1469 - accuracy_score: 0.9567 - 1.01s\n",
"[==============================] 100% Epoch 7/10 - loss: 0.1398 - accuracy_score: 0.9582 - 1.10s\n",
"[==============================] 100% Epoch 8/10 - loss: 0.1337 - accuracy_score: 0.9601 - 1.03s\n",
"[==============================] 100% Epoch 9/10 - loss: 0.1292 - accuracy_score: 0.9620 - 1.12s\n",
"[==============================] 100% Epoch 10/10 - loss: 0.1243 - accuracy_score: 0.9631 - 1.02s\n"
"[==============================] 100% Epoch 1/10 - loss: 0.5703 - accuracy: 0.8109 - 5.33s\n",
"[==============================] 100% Epoch 2/10 - loss: 0.2287 - accuracy: 0.9336 - 5.37s\n",
"[==============================] 100% Epoch 3/10 - loss: 0.1950 - accuracy: 0.9437 - 5.41s\n",
"[==============================] 100% Epoch 4/10 - loss: 0.1791 - accuracy: 0.9468 - 5.75s\n",
"[==============================] 100% Epoch 5/10 - loss: 0.1600 - accuracy: 0.9525 - 5.87s\n",
"[==============================] 100% Epoch 6/10 - loss: 0.1469 - accuracy: 0.9567 - 6.02s\n",
"[==============================] 100% Epoch 7/10 - loss: 0.1398 - accuracy: 0.9582 - 6.17s\n",
"[==============================] 100% Epoch 8/10 - loss: 0.1337 - accuracy: 0.9601 - 6.02s\n",
"[==============================] 100% Epoch 9/10 - loss: 0.1292 - accuracy: 0.9620 - 5.99s\n",
"[==============================] 100% Epoch 10/10 - loss: 0.1243 - accuracy: 0.9631 - 6.00s\n"
]
},
{
"data": {
"text/plain": ""
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -215,8 +223,8 @@
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:28.536750900Z",
"start_time": "2024-09-22T21:23:28.490707700Z"
"end_time": "2024-11-06T21:21:10.188691300Z",
"start_time": "2024-11-06T21:21:10.145550200Z"
}
},
"outputs": [
Expand Down Expand Up @@ -245,8 +253,8 @@
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:28.568699500Z",
"start_time": "2024-09-22T21:23:28.537750700Z"
"end_time": "2024-11-06T21:21:10.223168Z",
"start_time": "2024-11-06T21:21:10.189691600Z"
}
},
"outputs": [],
Expand All @@ -266,8 +274,8 @@
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:28.582991Z",
"start_time": "2024-09-22T21:23:28.567699400Z"
"end_time": "2024-11-06T21:21:10.235337900Z",
"start_time": "2024-11-06T21:21:10.221169700Z"
}
},
"outputs": [
Expand Down Expand Up @@ -299,8 +307,8 @@
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:28.814879800Z",
"start_time": "2024-09-22T21:23:28.583991Z"
"end_time": "2024-11-06T21:21:10.404184900Z",
"start_time": "2024-11-06T21:21:10.236337600Z"
}
},
"outputs": [
Expand Down Expand Up @@ -334,8 +342,8 @@
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T21:23:28.867661200Z",
"start_time": "2024-09-22T21:23:28.815905Z"
"end_time": "2024-11-06T21:21:10.456973200Z",
"start_time": "2024-11-06T21:21:10.406688900Z"
}
},
"outputs": [],
Expand Down
89 changes: 44 additions & 45 deletions examples/classification-regression/sentiment_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:10:57.538645900Z",
"start_time": "2024-09-22T23:10:55.233016Z"
"end_time": "2024-11-06T21:51:28.948615200Z",
"start_time": "2024-11-06T21:51:19.721136Z"
}
},
"outputs": [],
Expand All @@ -31,8 +31,8 @@
"import pandas as pd\n",
"\n",
"from neuralnetlib.model import Model\n",
"from neuralnetlib.layers import Input, Dense, Embedding, Flatten\n",
"from neuralnetlib.preprocessing import Tokenizer, pad_sequences, CountVectorizer\n",
"from neuralnetlib.layers import Input, Dense, Embedding, LSTM, Bidirectional, Attention, GlobalAveragePooling1D\n",
"from neuralnetlib.preprocessing import Tokenizer, pad_sequences\n",
"from neuralnetlib.metrics import accuracy_score\n",
"from neuralnetlib.utils import train_test_split\n",
"\n",
Expand All @@ -48,11 +48,11 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:13:42.739941500Z",
"start_time": "2024-09-22T23:13:41.184859600Z"
"end_time": "2024-11-06T21:51:30.589179800Z",
"start_time": "2024-11-06T21:51:28.950619500Z"
}
},
"outputs": [],
Expand All @@ -69,11 +69,11 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:13:43.449172100Z",
"start_time": "2024-09-22T23:13:43.200238700Z"
"end_time": "2024-11-06T21:51:30.871205900Z",
"start_time": "2024-11-06T21:51:30.590182500Z"
}
},
"outputs": [
Expand Down Expand Up @@ -122,7 +122,6 @@
"max_words = 10000\n",
"max_len = 200\n",
"\n",
"tokenizer = Tokenizer(num_words=max_words)\n",
"x_train = pad_sequences(x_train, max_length=max_len)\n",
"x_test = pad_sequences(x_test, max_length=max_len)\n",
"\n",
Expand All @@ -148,20 +147,20 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:13:48.701766500Z",
"start_time": "2024-09-22T23:13:48.692765600Z"
"end_time": "2024-11-06T21:51:30.899961500Z",
"start_time": "2024-11-06T21:51:30.871205900Z"
}
},
"outputs": [],
"source": [
"model = Model()\n",
"model.add(Input(input_shape=(max_len,)))\n",
"model.add(Embedding(max_words, 50, input_length=max_len))\n",
"model.add(Flatten())\n",
"model.add(Dense(10, activation='relu'))\n",
"model.add(Input(max_len))\n",
"model.add(Embedding(max_words, 100, weights_init='xavier'))\n",
"model.add(Bidirectional(LSTM(32, return_sequences=True)))\n",
"model.add(Attention())\n",
"model.add(Dense(1, activation='sigmoid'))"
]
},
Expand All @@ -174,11 +173,11 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:13:50.151043500Z",
"start_time": "2024-09-22T23:13:50.140043900Z"
"end_time": "2024-11-06T21:51:30.904961800Z",
"start_time": "2024-11-06T21:51:30.886456800Z"
}
},
"outputs": [
Expand All @@ -189,12 +188,11 @@
"Model\n",
"-------------------------------------------------\n",
"Layer 1: Input(input_shape=(200,))\n",
"Layer 2: Embedding(input_dim=10000, output_dim=50, input_length=200)\n",
"Layer 3: Flatten\n",
"Layer 4: Dense(units=10)\n",
"Layer 5: Activation(ReLU)\n",
"Layer 6: Dense(units=1)\n",
"Layer 7: Activation(Sigmoid)\n",
"Layer 2: Embedding(input_dim=10000, output_dim=100)\n",
"Layer 3: Bidirectional(layer=LSTM(units=32, return_sequences=True, return_state=False, random_state=None))\n",
"Layer 4: Attention(use_scale=True, score_mode=dot)\n",
"Layer 5: Dense(units=1)\n",
"Layer 6: Activation(Sigmoid)\n",
"-------------------------------------------------\n",
"Loss function: BinaryCrossentropy\n",
"Optimizer: Adam(learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-08)\n",
Expand All @@ -217,28 +215,29 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:15:06.090102500Z",
"start_time": "2024-09-22T23:13:52.475373900Z"
"end_time": "2024-11-06T22:17:05.632380200Z",
"start_time": "2024-11-06T22:17:05.625379900Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[==============================] 100% Epoch 1/10 - loss: 0.6922 - accuracy: 0.5208 - 7.01s - val_accuracy: 0.5466\n",
"[==============================] 100% Epoch 2/10 - loss: 0.6494 - accuracy: 0.6512 - 7.02s - val_accuracy: 0.5763\n",
"[==============================] 100% Epoch 3/10 - loss: 0.5619 - accuracy: 0.7295 - 6.99s - val_accuracy: 0.5831\n",
"[==============================] 100% Epoch 4/10 - loss: 0.4977 - accuracy: 0.7723 - 6.97s - val_accuracy: 0.5838\n",
"[==============================] 100% Epoch 5/10 - loss: 0.4506 - accuracy: 0.7991 - 7.05s - val_accuracy: 0.5842\n",
"[==============================] 100% Epoch 6/10 - loss: 0.4123 - accuracy: 0.8224 - 6.98s - val_accuracy: 0.5840\n",
"[==============================] 100% Epoch 7/10 - loss: 0.3792 - accuracy: 0.8418 - 7.01s - val_accuracy: 0.5838\n",
"[==============================] 100% Epoch 8/10 - loss: 0.3495 - accuracy: 0.8586 - 7.06s - val_accuracy: 0.5818\n",
"[==============================] 100% Epoch 9/10 - loss: 0.3219 - accuracy: 0.8752 - 6.99s - val_accuracy: 0.5793\n",
"[==============================] 100% Epoch 10/10 - loss: 0.2963 - accuracy: 0.8907 - 6.98s - val_accuracy: 0.5761\n"
"\n",
"[==============================] 100% Epoch 1/10 - loss: 0.6193 - accuracy: 0.7079 - 248.72s - val_accuracy: 0.8013\n",
"[==============================] 100% Epoch 2/10 - loss: 0.4215 - accuracy: 0.8477 - 264.70s - val_accuracy: 0.8504\n",
"[==============================] 100% Epoch 3/10 - loss: 0.3301 - accuracy: 0.8799 - 266.74s - val_accuracy: 0.8624\n",
"[==============================] 100% Epoch 4/10 - loss: 0.2835 - accuracy: 0.8954 - 255.44s - val_accuracy: 0.8677\n",
"[==============================] 100% Epoch 5/10 - loss: 0.2519 - accuracy: 0.9093 - 239.53s - val_accuracy: 0.8710\n",
"[==============================] 100% Epoch 6/10 - loss: 0.2283 - accuracy: 0.9183 - 239.53s - val_accuracy: 0.8728\n",
"[==============================] 100% Epoch 7/10 - loss: 0.2090 - accuracy: 0.9260 - 239.53s - val_accuracy: 0.8802\n",
"[==============================] 100% Epoch 8/10 - loss: 0.1926 - accuracy: 0.9320 - 239.53s - val_accuracy: 0.8884\n",
"[==============================] 100% Epoch 9/10 - loss: 0.1784 - accuracy: 0.9376 - 239.53s - val_accuracy: 0.8902\n",
"[==============================] 100% Epoch 10/10 - loss: 0.1660 - accuracy: 0.9423 - 239.53s - val_accuracy: 0.9000\n"
]
}
],
Expand All @@ -255,20 +254,20 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2024-09-22T23:15:08.910264900Z",
"start_time": "2024-09-22T23:15:08.577100100Z"
"end_time": "2024-11-06T22:17:25.754433600Z",
"start_time": "2024-11-06T22:17:14.398517800Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loss: 1.1821619656925941\n",
"Accuracy: 0.5884\n"
"Loss: 1.4010948021794365\n",
"Accuracy: 0.881\n"
]
}
],
Expand Down
Loading

0 comments on commit 4106a3d

Please sign in to comment.