diff --git a/.gitignore b/.gitignore
deleted file mode 100644
index 02c381e..0000000
--- a/.gitignore
+++ /dev/null
@@ -1,145 +0,0 @@
-### Data ###
-data/
-erc/
-
-### Python ###
-# Byte-compiled / optimized / DLL files
-__pycache__/
-*.py[cod]
-*$py.class
-
-# C extensions
-*.so
-
-# Distribution / packaging
-.Python
-build/
-develop-eggs/
-dist/
-downloads/
-eggs/
-.eggs/
-lib/
-lib64/
-parts/
-sdist/
-var/
-wheels/
-share/python-wheels/
-*.egg-info/
-.installed.cfg
-*.egg
-MANIFEST
-
-# PyInstaller
-# Usually these files are written by a python script from a template
-# before PyInstaller builds the exe, so as to inject date/other infos into it.
-*.manifest
-*.spec
-
-# Installer logs
-pip-log.txt
-pip-delete-this-directory.txt
-
-# Unit test / coverage reports
-htmlcov/
-.tox/
-.nox/
-.coverage
-.coverage.*
-.cache
-nosetests.xml
-coverage.xml
-*.cover
-*.py,cover
-.hypothesis/
-.pytest_cache/
-cover/
-
-# Translations
-*.mo
-*.pot
-
-# Django stuff:
-*.log
-local_settings.py
-db.sqlite3
-db.sqlite3-journal
-
-# Flask stuff:
-instance/
-.webassets-cache
-
-# Scrapy stuff:
-.scrapy
-
-# Sphinx documentation
-docs/_build/
-
-# PyBuilder
-.pybuilder/
-target/
-
-# Jupyter Notebook
-.ipynb_checkpoints
-
-# IPython
-profile_default/
-ipython_config.py
-
-# pyenv
-# For a library or package, you might want to ignore these files since the code is
-# intended to run in multiple environments; otherwise, check them in:
-# .python-version
-
-# pipenv
-# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
-# However, in case of collaboration, if having platform-specific dependencies or dependencies
-# having no cross-platform support, pipenv may install dependencies that don't work, or not
-# install all needed dependencies.
-#Pipfile.lock
-
-# PEP 582; used by e.g. github.com/David-OConnor/pyflow
-__pypackages__/
-
-# Celery stuff
-celerybeat-schedule
-celerybeat.pid
-
-# SageMath parsed files
-*.sage.py
-
-# Environments
-.env
-.venv
-env/
-venv/
-ENV/
-env.bak/
-venv.bak/
-
-# Spyder project settings
-.spyderproject
-.spyproject
-
-# Rope project settings
-.ropeproject
-
-# mkdocs documentation
-/site
-
-# mypy
-.mypy_cache/
-.dmypy.json
-dmypy.json
-
-# Pyre type checker
-.pyre/
-
-# pytype static type analyzer
-.pytype/
-
-# Cython debug symbols
-cython_debug/
-
-# End of https://www.toptal.com/developers/gitignore/api/python
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
deleted file mode 100644
index f37218e..0000000
--- a/LICENSE
+++ /dev/null
@@ -1,21 +0,0 @@
-MIT License
-
-Copyright (c) 2021 Chung_es
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
diff --git a/README.md b/README.md
deleted file mode 100644
index 89ecc5a..0000000
--- a/README.md
+++ /dev/null
@@ -1 +0,0 @@
-# Time Series Generation
\ No newline at end of file
diff --git a/Timeseries_clustering.ipynb b/Timeseries_clustering.ipynb
deleted file mode 100644
index 7f59b46..0000000
--- a/Timeseries_clustering.ipynb
+++ /dev/null
@@ -1,1635 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "## Timeseries clustering\n",
- "\n",
- "Time series clustering is to partition time series data into groups based on similarity or distance, so that time series in the same cluster are similar.\n",
- "\n",
- "Methodology followed:\n",
- "* Use Variational Recurrent AutoEncoder (VRAE) for dimensionality reduction of the timeseries\n",
- "* To visualize the clusters, PCA and t-sne are used\n",
- "\n",
- "Paper:\n",
- "https://arxiv.org/pdf/1412.6581.pdf"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Contents\n",
- "\n",
- "0. [Load data and preprocess](#Load-data-and-preprocess)\n",
- "1. [Initialize VRAE object](#Initialize-VRAE-object)\n",
- "2. [Fit the model onto dataset](#Fit-the-model-onto-dataset)\n",
- "3. [Transform the input timeseries to encoded latent vectors](#Transform-the-input-timeseries-to-encoded-latent-vectors)\n",
- "4. [Save the model to be fetched later](#Save-the-model-to-be-fetched-later)\n",
- "5. [Visualize using PCA and tSNE](#Visualize-using-PCA-and-tSNE)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Import required modules"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/html": [
- " \n",
- " "
- ]
- },
- "metadata": {},
- "output_type": "display_data"
- }
- ],
- "source": [
- "from model.org_vrae import VRAE\n",
- "from model.utils import *\n",
- "import numpy as np\n",
- "import torch\n",
- "\n",
- "import plotly\n",
- "from torch.utils.data import DataLoader, TensorDataset\n",
- "plotly.offline.init_notebook_mode()"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Input parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "dload = './model_dir' #download directory"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Hyper parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "hidden_size = 90\n",
- "hidden_layer_depth = 1\n",
- "latent_length = 20\n",
- "batch_size = 32\n",
- "learning_rate = 0.0005\n",
- "n_epochs = 1\n",
- "dropout_rate = 0.2\n",
- "optimizer = 'Adam' # options: ADAM, SGD\n",
- "cuda = True # options: True, False\n",
- "print_every=30\n",
- "clip = True # options: True, False\n",
- "max_grad_norm=5\n",
- "loss = 'MSELoss' # options: SmoothL1Loss, MSELoss\n",
- "block = 'LSTM' # options: LSTM, GRU"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Load data and preprocess"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "X_train, X_val, y_train, y_val = open_data('data', ratio_train=0.9)\n",
- "\n",
- "num_classes = len(np.unique(y_train))\n",
- "base = np.min(y_train) # Check if data is 0-based\n",
- "if base != 0:\n",
- " y_train -= base\n",
- "y_val -= base"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(8549, 140, 1)\n"
- ]
- }
- ],
- "source": [
- "print(X_train.shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_dataset = TensorDataset(torch.from_numpy(X_train))\n",
- "test_dataset = TensorDataset(torch.from_numpy(X_val))"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Fetch `sequence_length` from dataset**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "140"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "sequence_length = X_train.shape[1]\n",
- "sequence_length"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Fetch `number_of_features` from dataset**\n",
- "\n",
- "This config corresponds to number of input features"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "1"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "number_of_features = X_train.shape[2]\n",
- "number_of_features"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Initialize VRAE object\n",
- "\n",
- "VRAE inherits from `sklearn.base.BaseEstimator` and overrides `fit`, `transform` and `fit_transform` functions, similar to sklearn modules"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py:65: UserWarning:\n",
- "\n",
- "dropout option adds dropout after all but last recurrent layer, so non-zero dropout expects num_layers greater than 1, but got dropout=0.2 and num_layers=1\n",
- "\n",
- "/usr/local/lib/python3.6/dist-packages/torch/nn/_reduction.py:42: UserWarning:\n",
- "\n",
- "size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
- "\n"
- ]
- }
- ],
- "source": [
- "vrae = VRAE(sequence_length=sequence_length,\n",
- " number_of_features = number_of_features,\n",
- " hidden_size = hidden_size, \n",
- " hidden_layer_depth = hidden_layer_depth,\n",
- " latent_length = latent_length,\n",
- " batch_size = batch_size,\n",
- " learning_rate = learning_rate,\n",
- " n_epochs = n_epochs,\n",
- " dropout_rate = dropout_rate,\n",
- " optimizer = optimizer, \n",
- " cuda = cuda,\n",
- " print_every=print_every, \n",
- " clip=clip, \n",
- " max_grad_norm=max_grad_norm,\n",
- " loss = loss,\n",
- " block = block,\n",
- " dload = dload)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Fit the model onto dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "fit result\n",
- "torch.Size([32, 140, 1])\n",
- "Epoch: 0\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 30, loss = 4126.4561, recon_loss = 4126.3960, kl_loss = 0.0602\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 60, loss = 3002.3113, recon_loss = 2999.1250, kl_loss = 3.1863\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 90, loss = 2697.9019, recon_loss = 2694.5222, kl_loss = 3.3796\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 120, loss = 2783.7349, recon_loss = 2780.4021, kl_loss = 3.3327\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 150, loss = 2799.6096, recon_loss = 2796.6382, kl_loss = 2.9714\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 180, loss = 2885.5581, recon_loss = 2882.9241, kl_loss = 2.6341\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 210, loss = 2366.8254, recon_loss = 2364.6108, kl_loss = 2.2146\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Batch 240, loss = 2797.3643, recon_loss = 2795.4651, kl_loss = 1.8991\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "Average loss: 2893.1905\n"
- ]
- }
- ],
- "source": [
- "vrae.fit(train_dataset)\n",
- "\n",
- "#If the model has to be saved, with the learnt parameters use:\n",
- "# vrae.fit(dataset, save = True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Transform the input timeseries to encoded latent vectors"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([140, 32, 1])\n",
- "--------------------------\n"
- ]
- }
- ],
- "source": [
- "z_run = vrae.transform(test_dataset)\n",
- "\n",
- "#If the latent vectors have to be saved, pass the parameter `save`\n",
- "# z_run = vrae.transform(dataset, save = True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Save the model to be fetched later"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {},
- "outputs": [],
- "source": [
- "vrae.save('vrae.pth')\n",
- "\n",
- "# To load a presaved model, execute:\n",
- "# vrae.load('vrae.pth')"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git "a/VRAE_\354\230\210\354\213\234.ipynb" "b/VRAE_\354\230\210\354\213\234.ipynb"
deleted file mode 100644
index c88052c..0000000
--- "a/VRAE_\354\230\210\354\213\234.ipynb"
+++ /dev/null
@@ -1,656 +0,0 @@
-{
- "cells": [
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "#### Contents\n",
- "\n",
- "0. [Load data and preprocess](#Load-data-and-preprocess)\n",
- "1. [Initialize VRAE object](#Initialize-VRAE-object)\n",
- "2. [Fit the model onto dataset](#Fit-the-model-onto-dataset)\n",
- "3. [Transform the input timeseries to encoded latent vectors](#Transform-the-input-timeseries-to-encoded-latent-vectors)\n",
- "4. [Save the model to be fetched later](#Save-the-model-to-be-fetched-later)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Import required modules"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 22,
- "metadata": {},
- "outputs": [],
- "source": [
- "%load_ext autoreload\n",
- "%autoreload 2\n",
- "\n",
- "from model.vrae import VRAE\n",
- "from model.utils import *\n",
- "import numpy as np\n",
- "import torch\n",
- "from torch.utils.data import DataLoader, Dataset\n",
- "from tqdm.notebook import trange\n",
- "import tqdm"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Input parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "dload = './saved_model' #download directory"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Hyper parameters"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Load data and preprocess\n",
- "- `folder` : data location\n",
- "- `cols_to_remove` : generation 수행하지 않을 column 제거"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**TODO : 해당 변수에 대한 처리를 어떻게 해줘야하는가 확인 작업이 필요함**\n",
- "\n",
- "~~~\n",
- "YYYYMMDD : 년월일\n",
- "HHMMSS : 시분초\n",
- "MNG_NO : 장비번호\n",
- "IF_IDX : 회선 index\n",
- "~~~\n",
- "\n",
- "- 현재는 분석의 편의를 위해 ['YYYYMMDD', 'HHMMSS']만 제거해줌"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "(23195128, 56)\n"
- ]
- }
- ],
- "source": [
- "# params\n",
- "folder = 'data'\n",
- "cols_to_remove = ['YYYYMMDD', 'HHMMSS']\n",
- "\n",
- "# load data\n",
- "df_total = load_data(folder, cols_to_remove)\n",
- "\n",
- "# shape\n",
- "print(df_total.shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "class HamonDataset(Dataset):\n",
- " def __init__(self, data, window, stride):\n",
- " self.data = torch.Tensor(data)\n",
- " self.window = window\n",
- " \n",
- " def __len__(self):\n",
- " return len(self.data) - self.window \n",
- " \n",
- " def __getitem__(self, index):\n",
- " x_index = index*self.window\n",
- " x = self.data[x_index:x_index+self.window]\n",
- " return x"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "metadata": {},
- "outputs": [],
- "source": [
- "data = df_total\n",
- "stride = 10\n",
- "window = 100"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 6,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "<__main__.HamonDataset at 0x7f7cbaa3f940>"
- ]
- },
- "execution_count": 6,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_dataset = HamonDataset(data, window, stride)\n",
- "train_dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 7,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([100, 56])"
- ]
- },
- "execution_count": 7,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_dataset[0].shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Fetch `sequence_length` from dataset**"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 8,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "100"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "sequence_length = train_dataset[0].shape[0]\n",
- "sequence_length"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "**Fetch `number_of_features` from dataset**\n",
- "\n",
- "This config corresponds to number of input features"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "56"
- ]
- },
- "execution_count": 9,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "number_of_features = train_dataset[0].shape[1]\n",
- "number_of_features"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Parameters"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "metadata": {},
- "outputs": [],
- "source": [
- "n_epochs = 1\n",
- "hidden_size = 90\n",
- "hidden_layer_depth = 1\n",
- "latent_length = 20\n",
- "batch_size = 32\n",
- "learning_rate = 0.0005\n",
- "dropout_rate = 0.2\n",
- "optimizer = 'Adam' # options: ADAM, SGD\n",
- "cuda = True # options: True, False\n",
- "print_every=30\n",
- "clip = True # options: True, False\n",
- "max_grad_norm=5\n",
- "loss = 'MSELoss' # options: SmoothL1Loss, MSELoss\n",
- "block = 'LSTM' # options: LSTM, GRU"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 11,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_loader = DataLoader(dataset = train_dataset,\n",
- " batch_size = batch_size,\n",
- " shuffle = False,\n",
- " drop_last=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 12,
- "metadata": {
- "scrolled": true
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[[2.7220e+03, 1.2400e+02, 1.8431e+05, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.7220e+03, 1.2400e+02, 3.8349e+05, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.7220e+03, 1.2400e+02, 2.3519e+05, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " ...,\n",
- " [2.8500e+03, 1.2400e+02, 2.3200e+02, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00]],\n",
- "\n",
- " [[2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " ...,\n",
- " [2.8630e+03, 1.2400e+02, 1.8664e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8630e+03, 1.2400e+02, 1.9056e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8630e+03, 1.2400e+02, 1.8104e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00]],\n",
- "\n",
- " [[2.8630e+03, 1.2400e+02, 1.8096e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8630e+03, 1.2400e+02, 1.8640e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8630e+03, 1.2400e+02, 1.9448e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " ...,\n",
- " [2.8730e+03, 1.2400e+02, 3.3920e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8730e+03, 1.2400e+02, 3.4480e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [2.8730e+03, 1.2400e+02, 3.3840e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00]],\n",
- "\n",
- " ...,\n",
- "\n",
- " [[3.7730e+03, 1.2400e+02, 2.0880e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7730e+03, 1.2400e+02, 1.9360e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7730e+03, 1.2400e+02, 1.9840e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " ...,\n",
- " [3.7810e+03, 1.2400e+02, 1.7760e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7810e+03, 1.2400e+02, 1.6800e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7810e+03, 1.2400e+02, 1.7600e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00]],\n",
- "\n",
- " [[3.7820e+03, 1.2400e+02, 5.1096e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7820e+03, 1.2400e+02, 1.2566e+06, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7820e+03, 1.2400e+02, 5.2016e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " ...,\n",
- " [3.7900e+03, 1.2400e+02, 1.6496e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7900e+03, 1.2400e+02, 1.6416e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7900e+03, 1.2400e+02, 1.6776e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00]],\n",
- "\n",
- " [[3.7900e+03, 1.2400e+02, 1.6032e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7900e+03, 1.2400e+02, 1.6528e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7900e+03, 1.2400e+02, 1.7032e+04, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " ...,\n",
- " [3.7980e+03, 1.2400e+02, 6.1760e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7980e+03, 1.2400e+02, 6.1920e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00],\n",
- " [3.7980e+03, 1.2400e+02, 5.9840e+03, ..., 0.0000e+00,\n",
- " 0.0000e+00, 0.0000e+00]]])"
- ]
- },
- "execution_count": 12,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "X = iter(train_loader).next()\n",
- "X"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 13,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "torch.Size([32, 100, 56])"
- ]
- },
- "execution_count": 13,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "X.shape"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Initialize VRAE object\n",
- "\n",
- "VRAE inherits from `sklearn.base.BaseEstimator` and overrides `fit`, `transform` and `fit_transform` functions, similar to sklearn modules"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 23,
- "metadata": {},
- "outputs": [],
- "source": [
- "vrae = VRAE(sequence_length=sequence_length,\n",
- " number_of_features = number_of_features,\n",
- " hidden_size = hidden_size, \n",
- " hidden_layer_depth = hidden_layer_depth,\n",
- " latent_length = latent_length,\n",
- " batch_size = batch_size,\n",
- " learning_rate = learning_rate,\n",
- " n_epochs = n_epochs,\n",
- " dropout_rate = dropout_rate,\n",
- " optimizer = optimizer, \n",
- " cuda = cuda,\n",
- " print_every=print_every, \n",
- " clip=clip, \n",
- " max_grad_norm=max_grad_norm,\n",
- " loss = loss,\n",
- " block = block,\n",
- " dload = dload)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Fit the model onto dataset"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 24,
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "tensor([[2.7220e+03, 1.2400e+02, 1.8431e+05, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.7220e+03, 1.2400e+02, 3.8349e+05, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.7220e+03, 1.2400e+02, 2.3519e+05, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " ...,\n",
- " [2.8500e+03, 1.2400e+02, 2.3200e+02, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00]])"
- ]
- },
- "execution_count": 24,
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "train_dataset[0]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 28,
- "metadata": {},
- "outputs": [],
- "source": [
- "train_loader_test = DataLoader(dataset = train_dataset[0],\n",
- " batch_size = 32,\n",
- " shuffle = False,\n",
- " drop_last=True)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 29,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "torch.Size([32, 56])\n"
- ]
- }
- ],
- "source": [
- "tmp = iter(train_loader_test).next()\n",
- "print(tmp.shape)"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 30,
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "fit result\n",
- "<__main__.HamonDataset object at 0x7f7cbaa3f940>\n",
- "tensor([[2.7220e+03, 1.2400e+02, 1.8431e+05, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.7220e+03, 1.2400e+02, 3.8349e+05, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.7220e+03, 1.2400e+02, 2.3519e+05, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " ...,\n",
- " [2.8500e+03, 1.2400e+02, 2.3200e+02, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00],\n",
- " [2.8500e+03, 1.2400e+02, 2.4000e+02, ..., 0.0000e+00, 0.0000e+00,\n",
- " 0.0000e+00]])\n",
- "torch.Size([32, 100, 56])\n",
- "Epoch: 0\n",
- "--------------------------\n",
- "DEBUGGING\n",
- "torch.Size([32, 100, 56])\n",
- "--------------------------\n"
- ]
- },
- {
- "ename": "RuntimeError",
- "evalue": "Expected hidden[0] size (1, 32, 90), got [1, 100, 90]",
- "output_type": "error",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mvrae\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;31m#If the model has to be saved, with the learnt parameters use:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;31m# vrae.fit(dataset, save = True)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/repo/projects/timeseries-generation/model/vrae.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, dataset, save)\u001b[0m\n\u001b[1;32m 352\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'Epoch: %s'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 353\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 354\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_train\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 355\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 356\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_fitted\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/repo/projects/timeseries-generation/model/vrae.py\u001b[0m in \u001b[0;36m_train\u001b[0;34m(self, train_loader)\u001b[0m\n\u001b[1;32m 308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 310\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecon_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkl_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompute_loss\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 311\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/repo/projects/timeseries-generation/model/vrae.py\u001b[0m in \u001b[0;36mcompute_loss\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 287\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mVariable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdtype\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrequires_grad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 289\u001b[0;31m \u001b[0mx_decoded\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 290\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrecon_loss\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkl_loss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_rec\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_decoded\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdetach\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloss_fn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 291\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/repo/projects/timeseries-generation/model/vrae.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m 257\u001b[0m \u001b[0mcell_output\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mencoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 258\u001b[0m \u001b[0mlatent\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlmbd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcell_output\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 259\u001b[0;31m \u001b[0mx_decoded\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlatent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 260\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 261\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx_decoded\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlatent\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/repo/projects/timeseries-generation/model/vrae.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, latent)\u001b[0m\n\u001b[1;32m 139\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mLSTM\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 140\u001b[0m \u001b[0mh_0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh_state\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhidden_layer_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 141\u001b[0;31m \u001b[0mdecoder_output\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdecoder_inputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mh_0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 142\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mGRU\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 143\u001b[0m \u001b[0mh_0\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mh_state\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0m_\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhidden_layer_depth\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1100\u001b[0m if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks\n\u001b[1;32m 1101\u001b[0m or _global_forward_hooks or _global_forward_pre_hooks):\n\u001b[0;32m-> 1102\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mforward_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1103\u001b[0m \u001b[0;31m# Do not call functions when jit is used\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1104\u001b[0m \u001b[0mfull_backward_hooks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_full_backward_hooks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input, hx)\u001b[0m\n\u001b[1;32m 687\u001b[0m \u001b[0mhx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpermute_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msorted_indices\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 688\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 689\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_forward_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 690\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mbatch_sizes\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 691\u001b[0m result = _VF.lstm(input, hx, self._flat_weights, self.bias, self.num_layers,\n",
- "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mcheck_forward_args\u001b[0;34m(self, input, hidden, batch_sizes)\u001b[0m\n\u001b[1;32m 632\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcheck_input\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 633\u001b[0m self.check_hidden_size(hidden[0], self.get_expected_hidden_size(input, batch_sizes),\n\u001b[0;32m--> 634\u001b[0;31m 'Expected hidden[0] size {}, got {}')\n\u001b[0m\u001b[1;32m 635\u001b[0m self.check_hidden_size(hidden[1], self.get_expected_cell_size(input, batch_sizes),\n\u001b[1;32m 636\u001b[0m 'Expected hidden[1] size {}, got {}')\n",
- "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/torch/nn/modules/rnn.py\u001b[0m in \u001b[0;36mcheck_hidden_size\u001b[0;34m(self, hx, expected_hidden_size, msg)\u001b[0m\n\u001b[1;32m 224\u001b[0m msg: str = 'Expected hidden size {}, got {}') -> None:\n\u001b[1;32m 225\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mhx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m!=\u001b[0m \u001b[0mexpected_hidden_size\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 226\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mRuntimeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mexpected_hidden_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhx\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 227\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 228\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mcheck_forward_args\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhidden\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch_sizes\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mOptional\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mTensor\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mRuntimeError\u001b[0m: Expected hidden[0] size (1, 32, 90), got [1, 100, 90]"
- ]
- }
- ],
- "source": [
- "vrae.fit(train_dataset)\n",
- "\n",
- "#If the model has to be saved, with the learnt parameters use:\n",
- "# vrae.fit(dataset, save = True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Transform the input timeseries to encoded latent vectors"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "z_run = vrae.transform(test_dataset)\n",
- "\n",
- "#If the latent vectors have to be saved, pass the parameter `save`\n",
- "# z_run = vrae.transform(dataset, save = True)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Save the model to be fetched later"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "vrae.save('vrae.pth')\n",
- "\n",
- "# To load a presaved model, execute:\n",
- "# vrae.load('vrae.pth')"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "### Visualize using PCA and tSNE"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {},
- "outputs": [],
- "source": [
- "plot_clustering(z_run, y_val, engine='matplotlib', download = False)\n",
- "\n",
- "# If plotly to be used as rendering engine, uncomment below line\n",
- "#plot_clustering(z_run, y_val, engine='plotly', download = False)"
- ]
- }
- ],
- "metadata": {
- "kernelspec": {
- "display_name": "Python 3",
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.6.9"
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-}
diff --git a/params.txt b/params.txt
deleted file mode 100644
index 6272edd..0000000
--- a/params.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-hidden_size = 90
-hidden_layer_depth = 1
-latent_length = 20
-batch_size = 100
-learning_rate = 0.0005
-n_epochs = 50
-dropout_rate = 0.2
-optimizer = 'Adam'
-cuda = True
-print_every=30
-clip = True
-max_grad_norm=5
-loss = 'MSELoss'
-block = 'LSTM'
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index fe1979c..0000000
--- a/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-matplotlib==3.3.4
-numpy==1.19.5
-plotly==5.4.0
-scikit-learn==0.24.2
-scipy==1.5.4
-torch==1.10.0
-torchvision==0.11.1
\ No newline at end of file