diff --git a/torchsparsegradutils/tests/spd_forward_solve_notebook.ipynb b/torchsparsegradutils/tests/spd_forward_solve_notebook.ipynb
new file mode 100644
index 0000000..6b4a747
--- /dev/null
+++ b/torchsparsegradutils/tests/spd_forward_solve_notebook.ipynb
@@ -0,0 +1,389 @@
+{
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "authorship_tag": "ABX9TyN0JFpiKDeXK9tTd/a7xn9o",
+ "include_colab_link": true
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "language_info": {
+ "name": "python"
+ },
+ "accelerator": "GPU",
+ "gpuClass": "standard"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "view-in-github",
+ "colab_type": "text"
+ },
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# A basic notebook to compare different solvers for a SPD matrix\n",
+ "\n",
+ "First import all necessary modules."
+ ],
+ "metadata": {
+ "id": "xvuQc5HyjX4t"
+ }
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "3hHTa-xCjIQv",
+ "outputId": "cf255d05-b0ab-4a79-9765-1ed23fbff784"
+ },
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Running PyTorch version: 1.13.0+cu116\n",
+ "Default GPU is Tesla T4\n",
+ "Running on cuda\n",
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
+ "Collecting git+https://github.com/cai4cai/torchsparsegradutils\n",
+ " Cloning https://github.com/cai4cai/torchsparsegradutils to /tmp/pip-req-build-plwawj9f\n",
+ " Running command git clone -q https://github.com/cai4cai/torchsparsegradutils /tmp/pip-req-build-plwawj9f\n",
+ " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n",
+ " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n",
+ " Preparing wheel metadata ... \u001b[?25l\u001b[?25hdone\n",
+ "Requirement already satisfied: torch>=1.13 in /usr/local/lib/python3.8/dist-packages (from torchsparsegradutils==0.0.1) (1.13.0+cu116)\n",
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.8/dist-packages (from torch>=1.13->torchsparsegradutils==0.0.1) (4.4.0)\n"
+ ]
+ }
+ ],
+ "source": [
+ "import torch\n",
+ "print(f'Running PyTorch version: {torch.__version__}')\n",
+ "\n",
+ "torchdevice = torch.device('cpu')\n",
+ "if torch.cuda.is_available():\n",
+ " torchdevice = torch.device('cuda')\n",
+ " print('Default GPU is ' + torch.cuda.get_device_name(torch.device('cuda')))\n",
+ "print('Running on ' + str(torchdevice))\n",
+ "\n",
+ "import numpy as np\n",
+ "import scipy.io\n",
+ "import scipy.sparse.linalg\n",
+ "\n",
+ "import cupyx\n",
+ "import cupyx.scipy.sparse as csp\n",
+ "\n",
+ "import jax\n",
+ "from jax.config import config\n",
+ "config.update(\"jax_enable_x64\", True)\n",
+ "\n",
+ "!pip install git+https://github.com/cai4cai/torchsparsegradutils\n",
+ "import torchsparsegradutils as tsgu\n",
+ "import torchsparsegradutils.utils\n",
+ "import torchsparsegradutils.cupy as tsgucupy\n",
+ "import torchsparsegradutils.jax as tsgujax\n",
+ "\n",
+ "import time\n",
+ "import urllib\n",
+ "import os.path\n",
+ "import tarfile"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Now load an example SPD matrix and create a random RHS vector"
+ ],
+ "metadata": {
+ "id": "Vp-qCHEDkQbm"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def load_mat_from_suitesparse_collection(dirname,matname):\n",
+ " base_url = 'https://suitesparse-collection-website.herokuapp.com/MM/'\n",
+ " url = base_url + dirname + '/' + matname + '.tar.gz'\n",
+ " compressedlocalfile = matname + '.tar.gz'\n",
+ " if not os.path.exists(compressedlocalfile):\n",
+ " print(f'Downloading {url}')\n",
+ " urllib.request.urlretrieve(url, filename=compressedlocalfile)\n",
+ "\n",
+ " localfile = './' + matname + '/' + matname + '.mtx'\n",
+ " if not os.path.exists(localfile):\n",
+ " print(f'untarring {compressedlocalfile}')\n",
+ " srctarfile = tarfile.open(compressedlocalfile)\n",
+ " srctarfile.extractall('./')\n",
+ " srctarfile.close()\n",
+ "\n",
+ " A_np_coo = scipy.io.mmread(localfile)\n",
+ " print(f'Loaded suitesparse matrix {dirname}/{matname}: type={type(A_np_coo)}, shape={A_np_coo.shape}')\n",
+ " return A_np_coo\n",
+ "\n",
+ "A_np_coo = load_mat_from_suitesparse_collection('Rothberg','cfd2')\n",
+ "A_np_csr = scipy.sparse.csr_matrix(A_np_coo)\n",
+ "\n",
+ "b_np = np.random.randn(A_np_coo.shape[1])\n",
+ "print(f'Created random RHS with shape={b_np.shape}')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "0CzeubR_kUVA",
+ "outputId": "dc3ff1db-1425-4a82-cf54-0ad60f81963c"
+ },
+ "execution_count": 2,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Loaded suitesparse matrix Rothberg/cfd2: type=, shape=(123440, 123440)\n",
+ "Created random RHS with shape=(123440,)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Define some helper functions to run the tests"
+ ],
+ "metadata": {
+ "id": "MUacEhsqoYxZ"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def scipy_test(A, b, tested_solver, print_string):\n",
+ " t = time.time()\n",
+ " x = tested_solver(A, b)\n",
+ " elapsed = time.time() - t\n",
+ " resnorm = scipy.linalg.norm(A @ x - b)\n",
+ " print(f'{print_string} took {elapsed:.2f} seconds - resnorm={resnorm:.2e}')\n",
+ " return elapsed, resnorm"
+ ],
+ "metadata": {
+ "id": "1jYXGRDlnmYA"
+ },
+ "execution_count": 3,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Run the tests with the scipy routines (it can take a while)"
+ ],
+ "metadata": {
+ "id": "Th8gF8MnoeGR"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "run_scipy = True\n",
+ "if run_scipy:\n",
+ " t, n = scipy_test(A_np_coo, b_np, scipy.sparse.linalg.spsolve, 'scipy.spsolve COO')\n",
+ " t, n = scipy_test(A_np_csr, b_np, scipy.sparse.linalg.spsolve, 'scipy.spsolve CSR')\n",
+ " t, n = scipy_test(A_np_coo, b_np, lambda A, b: scipy.sparse.linalg.cg(A,b)[0], 'scipy.cg COO')\n",
+ " t, n = scipy_test(A_np_csr, b_np, lambda A, b: scipy.sparse.linalg.cg(A,b)[0], 'scipy.cg CSR')\n",
+ " t, n = scipy_test(A_np_coo, b_np, lambda A, b: scipy.sparse.linalg.bicgstab(A,b)[0], 'scipy.bicgstab COO')\n",
+ " t, n = scipy_test(A_np_csr, b_np, lambda A, b: scipy.sparse.linalg.bicgstab(A,b)[0], 'scipy.bicgstab CSR')\n",
+ " t, n = scipy_test(A_np_coo, b_np, lambda A, b: scipy.sparse.linalg.minres(A,b)[0], 'scipy.minres COO')\n",
+ " t, n = scipy_test(A_np_csr, b_np, lambda A, b: scipy.sparse.linalg.minres(A,b)[0], 'scipy.minres CSR')\n",
+ "else:\n",
+ " print('Skipping scipy tests')"
+ ],
+ "metadata": {
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "id": "4KaNyZpDogS4",
+ "outputId": "34e0dc0b-392b-4485-ac55-1233dde26aaf"
+ },
+ "execution_count": 4,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/scipy/sparse/linalg/dsolve/linsolve.py:144: SparseEfficiencyWarning: spsolve requires A be CSC or CSR matrix format\n",
+ " warn('spsolve requires A be CSC or CSR matrix format',\n"
+ ]
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "scipy.spsolve COO took 108.62 seconds - resnorm=2.15e-10\n",
+ "scipy.spsolve CSR took 98.29 seconds - resnorm=2.49e-10\n",
+ "scipy.cg COO took 72.09 seconds - resnorm=3.50e-03\n",
+ "scipy.cg CSR took 55.59 seconds - resnorm=3.50e-03\n",
+ "scipy.bicgstab COO took 112.65 seconds - resnorm=2.31e-03\n",
+ "scipy.bicgstab CSR took 88.79 seconds - resnorm=2.31e-03\n",
+ "scipy.minres COO took 0.54 seconds - resnorm=3.05e+01\n",
+ "scipy.minres CSR took 0.38 seconds - resnorm=3.05e+01\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Create corresponding PyTorch tensors"
+ ],
+ "metadata": {
+ "id": "BqOzcOXYrMY0"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "A_tcpu_csr = tsgucupy.c2t_csr(A_np_csr)\n",
+ "b_tcpu = torch.from_numpy(b_np)\n",
+ "\n",
+ "if torch.cuda.is_available():\n",
+ " A_tgpu_csr = A_tcpu_csr.to(torchdevice)\n",
+ " b_tgpu = b_tcpu.to(torchdevice)"
+ ],
+ "metadata": {
+ "id": "DHqwjtVOrRyH",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "307faf94-ee4a-4943-d483-ad084246ee0d"
+ },
+ "execution_count": 5,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": [
+ "/usr/local/lib/python3.8/dist-packages/torchsparsegradutils/cupy/cupy_bindings.py:52: UserWarning: Sparse CSR tensor support is in beta state. If you miss a functionality in the sparse tensor support, please submit a feature request to https://github.com/pytorch/pytorch/issues. (Triggered internally at ../aten/src/ATen/SparseCsrTensorImpl.cpp:54.)\n",
+ " x_torch = torch.sparse_csr_tensor(ind_ptr_t, idices_t, data_t, x_cupy.shape)\n"
+ ]
+ }
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Define some helper function to run the tests with pytorch"
+ ],
+ "metadata": {
+ "id": "zRZ7yXiZrjJ3"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def torch_test(A, b, tested_solver, print_string):\n",
+ " t = time.time()\n",
+ " x = tested_solver(A, b)\n",
+ " elapsed = time.time() - t\n",
+ " resnorm = torch.norm(A @ x - b).cpu().numpy()\n",
+ " print(f'{print_string} took {elapsed:.2f} seconds - resnorm={resnorm:.2e}')\n",
+ " print(f'GPU memory allocated: {torch.cuda.memory_allocated(device=torchdevice)/10**9:.2f}Gb'\n",
+ " f' - max allocated: {torch.cuda.max_memory_allocated(device=torchdevice)/10**9:.2f}Gb')\n",
+ " #print(torch.cuda.memory_summary(abbreviated=True))\n",
+ " return elapsed, resnorm"
+ ],
+ "metadata": {
+ "id": "jE0EFZiJrw71"
+ },
+ "execution_count": 6,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Run the tests with our pytorch routines (this can take a while)"
+ ],
+ "metadata": {
+ "id": "-nSPnGeosFGr"
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print(f'GPU memory allocated: {torch.cuda.memory_allocated(device=torchdevice)/10**9:.2f}Gb'\n",
+ " f' - max allocated: {torch.cuda.max_memory_allocated(device=torchdevice)/10**9:.2f}Gb')\n",
+ "#print(torch.cuda.memory_summary(abbreviated=True))\n",
+ "\n",
+ "#t, n = torch_test(A_tcpu_csr, b_tcpu, tsgu.sparse_generic_solve, 'tsgu.sparse_generic_solve CPU CSR')\n",
+ "if torch.cuda.is_available():\n",
+ " t, n = torch_test(A_tgpu_csr, b_tgpu, tsgu.sparse_generic_solve, 'tsgu.sparse_generic_solve GPU CSR')\n",
+ "\n",
+ "#t, n = torch_test(A_tcpu_csr, b_tcpu, tsgu.sparse_generic_lstsq, 'tsgu.sparse_generic_lstsq CPU CSR')\n",
+ "#if torch.cuda.is_available():\n",
+ "# t, n = torch_test(A_tgpu_csr, b_tgpu, tsgu.sparse_generic_lstsq, 'tsgu.sparse_generic_lstsq GPU CSR')\n",
+ "\n",
+ "mysolver = lambda A, b: tsgu.sparse_generic_solve(A,b,solve=tsgu.utils.minres)\n",
+ "if torch.cuda.is_available():\n",
+ " t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_generic_solve minres GPU CSR')\n",
+ "\n",
+ "mysolver = lambda A, b: tsgu.sparse_generic_solve(A,b,solve=tsgu.utils.linear_cg)\n",
+ "if torch.cuda.is_available():\n",
+ " t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_generic_solve cg GPU CSR')\n",
+ "\n",
+ "mysolver = lambda A, b: tsgu.sparse_generic_solve(A,b,solve=tsgu.utils.bicgstab)\n",
+ "if torch.cuda.is_available():\n",
+ " t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_generic_solve bicgstab GPU CSR')\n",
+ "\n",
+ "jaxsolver = None\n",
+ "mysolver = lambda A, b: tsgujax.sparse_solve_j4t(A,b)\n",
+ "if torch.cuda.is_available():\n",
+ " t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_solve_j4t cg GPU CSR')\n",
+ "\n",
+ "#cpsolver = lambda AA, BB: csp.linalg.cg(AA,BB)[0]\n",
+ "#mysolver = lambda A, b: tsgucupy.sparse_solve_c4t(A,b,solve=cpsolver)\n",
+ "#mysolver = lambda A, b: tsgucupy.sparse_solve_c4t(A,b)\n",
+ "#t, n = torch_test(A_tcpu_csr, b_tcpu, mysolver, 'tsgu.sparse_solve_c4t cg CPU CSR')\n",
+ "#if torch.cuda.is_available():\n",
+ "# t, n = torch_test(A_tgpu_csr, b_tgpu, mysolver, 'tsgu.sparse_solve_c4t cg GPU CSR')"
+ ],
+ "metadata": {
+ "id": "3HH-nGDSsIXy",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ },
+ "outputId": "7a8dacff-efaa-416a-f6d5-9a87a2b84157"
+ },
+ "execution_count": 7,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "GPU memory allocated: 0.04Gb - max allocated: 0.04Gb\n",
+ "tsgu.sparse_generic_solve GPU CSR took 0.62 seconds - resnorm=1.93e+00\n",
+ "GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n",
+ "tsgu.sparse_generic_solve minres GPU CSR took 0.46 seconds - resnorm=1.93e+00\n",
+ "GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n",
+ "tsgu.sparse_generic_solve cg GPU CSR took 0.06 seconds - resnorm=2.43e+02\n",
+ "GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n",
+ "tsgu.sparse_generic_solve bicgstab GPU CSR took 9.35 seconds - resnorm=2.60e-04\n",
+ "GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n",
+ "tsgu.sparse_solve_j4t cg GPU CSR took 6.39 seconds - resnorm=2.74e-03\n",
+ "GPU memory allocated: 0.04Gb - max allocated: 0.05Gb\n"
+ ]
+ }
+ ]
+ }
+ ]
+}
\ No newline at end of file