diff --git a/Dockerfile b/Dockerfile
index 477caff..1834db2 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -19,6 +19,7 @@ ENV PIP_DEFAULT_TIMEOUT=100 \
 RUN pip install "poetry==$POETRY_VERSION"
 
 COPY pyproject.toml poetry.lock README.md .
+RUN pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
 RUN poetry install --only main --no-root
 
 COPY india_forecast_app ./india_forecast_app
diff --git a/poetry.lock b/poetry.lock
index d8c3aba..29ead05 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -3726,146 +3726,6 @@ files = [
     {file = "numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010"},
 ]
 
-[[package]]
-name = "nvidia-cublas-cu12"
-version = "12.1.3.1"
-description = "CUBLAS native runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl", hash = "sha256:ee53ccca76a6fc08fb9701aa95b6ceb242cdaab118c3bb152af4e579af792728"},
-    {file = "nvidia_cublas_cu12-12.1.3.1-py3-none-win_amd64.whl", hash = "sha256:2b964d60e8cf11b5e1073d179d85fa340c120e99b3067558f3cf98dd69d02906"},
-]
-
-[[package]]
-name = "nvidia-cuda-cupti-cu12"
-version = "12.1.105"
-description = "CUDA profiling tools runtime libs."
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:e54fde3983165c624cb79254ae9818a456eb6e87a7fd4d56a2352c24ee542d7e"},
-    {file = "nvidia_cuda_cupti_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:bea8236d13a0ac7190bd2919c3e8e6ce1e402104276e6f9694479e48bb0eb2a4"},
-]
-
-[[package]]
-name = "nvidia-cuda-nvrtc-cu12"
-version = "12.1.105"
-description = "NVRTC native runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:339b385f50c309763ca65456ec75e17bbefcbbf2893f462cb8b90584cd27a1c2"},
-    {file = "nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:0a98a522d9ff138b96c010a65e145dc1b4850e9ecb75a0172371793752fd46ed"},
-]
-
-[[package]]
-name = "nvidia-cuda-runtime-cu12"
-version = "12.1.105"
-description = "CUDA Runtime native Libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:6e258468ddf5796e25f1dc591a31029fa317d97a0a94ed93468fc86301d61e40"},
-    {file = "nvidia_cuda_runtime_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:dfb46ef84d73fababab44cf03e3b83f80700d27ca300e537f85f636fac474344"},
-]
-
-[[package]]
-name = "nvidia-cudnn-cu12"
-version = "8.9.2.26"
-description = "cuDNN runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl", hash = "sha256:5ccb288774fdfb07a7e7025ffec286971c06d8d7b4fb162525334616d7629ff9"},
-]
-
-[package.dependencies]
-nvidia-cublas-cu12 = "*"
-
-[[package]]
-name = "nvidia-cufft-cu12"
-version = "11.0.2.54"
-description = "CUFFT native runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl", hash = "sha256:794e3948a1aa71fd817c3775866943936774d1c14e7628c74f6f7417224cdf56"},
-    {file = "nvidia_cufft_cu12-11.0.2.54-py3-none-win_amd64.whl", hash = "sha256:d9ac353f78ff89951da4af698f80870b1534ed69993f10a4cf1d96f21357e253"},
-]
-
-[[package]]
-name = "nvidia-curand-cu12"
-version = "10.3.2.106"
-description = "CURAND native runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:9d264c5036dde4e64f1de8c50ae753237c12e0b1348738169cd0f8a536c0e1e0"},
-    {file = "nvidia_curand_cu12-10.3.2.106-py3-none-win_amd64.whl", hash = "sha256:75b6b0c574c0037839121317e17fd01f8a69fd2ef8e25853d826fec30bdba74a"},
-]
-
-[[package]]
-name = "nvidia-cusolver-cu12"
-version = "11.4.5.107"
-description = "CUDA solver native runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd"},
-    {file = "nvidia_cusolver_cu12-11.4.5.107-py3-none-win_amd64.whl", hash = "sha256:74e0c3a24c78612192a74fcd90dd117f1cf21dea4822e66d89e8ea80e3cd2da5"},
-]
-
-[package.dependencies]
-nvidia-cublas-cu12 = "*"
-nvidia-cusparse-cu12 = "*"
-nvidia-nvjitlink-cu12 = "*"
-
-[[package]]
-name = "nvidia-cusparse-cu12"
-version = "12.1.0.106"
-description = "CUSPARSE native runtime libraries"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c"},
-    {file = "nvidia_cusparse_cu12-12.1.0.106-py3-none-win_amd64.whl", hash = "sha256:b798237e81b9719373e8fae8d4f091b70a0cf09d9d85c95a557e11df2d8e9a5a"},
-]
-
-[package.dependencies]
-nvidia-nvjitlink-cu12 = "*"
-
-[[package]]
-name = "nvidia-nccl-cu12"
-version = "2.19.3"
-description = "NVIDIA Collective Communication Library (NCCL) Runtime"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl", hash = "sha256:a9734707a2c96443331c1e48c717024aa6678a0e2a4cb66b2c364d18cee6b48d"},
-]
-
-[[package]]
-name = "nvidia-nvjitlink-cu12"
-version = "12.3.101"
-description = "Nvidia JIT LTO Library"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-manylinux1_x86_64.whl", hash = "sha256:64335a8088e2b9d196ae8665430bc6a2b7e6ef2eb877a9c735c804bd4ff6467c"},
-    {file = "nvidia_nvjitlink_cu12-12.3.101-py3-none-win_amd64.whl", hash = "sha256:1b2e317e437433753530792f13eece58f0aec21a2b05903be7bffe58a606cbd1"},
-]
-
-[[package]]
-name = "nvidia-nvtx-cu12"
-version = "12.1.105"
-description = "NVIDIA Tools Extension"
-optional = false
-python-versions = ">=3"
-files = [
-    {file = "nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl", hash = "sha256:dc21cf308ca5691e7c04d962e213f8a4aa9bbfa23d95412f452254c2caeb09e5"},
-    {file = "nvidia_nvtx_cu12-12.1.105-py3-none-win_amd64.whl", hash = "sha256:65f4d98982b31b60026e0e6de73fbdfc09d08a96f4656dd3665ca616a11e1e82"},
-]
 
 [[package]]
 name = "oauthlib"
@@ -6326,17 +6186,6 @@ filelock = "*"
 fsspec = "*"
 jinja2 = "*"
 networkx = "*"
-nvidia-cublas-cu12 = {version = "12.1.3.1", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-cupti-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-nvrtc-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cuda-runtime-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cudnn-cu12 = {version = "8.9.2.26", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cufft-cu12 = {version = "11.0.2.54", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-curand-cu12 = {version = "10.3.2.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusolver-cu12 = {version = "11.4.5.107", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-cusparse-cu12 = {version = "12.1.0.106", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nccl-cu12 = {version = "2.19.3", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
-nvidia-nvtx-cu12 = {version = "12.1.105", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\""}
 sympy = "*"
 triton = {version = "2.2.0", markers = "platform_system == \"Linux\" and platform_machine == \"x86_64\" and python_version < \"3.12\""}
 typing-extensions = ">=4.8.0"
@@ -6930,4 +6779,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 [metadata]
 lock-version = "2.0"
 python-versions = "^3.11"
-content-hash = "a89bd5f80669e9b7bcef65f97728d18e13213f02fc22ca575c8dd70c1d0326b9"
+content-hash = "0887be0e1b025bd69ba66727a955d979dd111fac45cdab34656c95305c84ef65"
diff --git a/pyproject.toml b/pyproject.toml
index cdff813..ba47f2e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,8 +13,9 @@ pandas = "1.5.3"
 pvnet = "3.0.5"
 pytz = "^2024.1"
 numpy = "^1.26.4"
-ocf-datapipes = "^3.2.9"
 huggingface-hub = "0.20.3"
+torch = "2.2.1"
+ocf-datapipes = "^3.2.9"
 
 [tool.poetry.group.dev.dependencies]
 pytest = "^7.4.4"