Completed project with cleaned code and readme
voetberg committed May 16, 2023
1 parent 02ab95a commit f8bba26
Showing 29 changed files with 1,168 additions and 2,353 deletions.
Empty file added WavPool/__init__.py
3 changes: 3 additions & 0 deletions WavPool/data_generators/__init__.py
@@ -0,0 +1,3 @@
from WavPool.data_generators.cifar_generator import CIFARGenerator
from WavPool.data_generators.mnist_generator import MNISTGenerator
from WavPool.data_generators.fashion_mnist_generator import FashionMNISTGenerator
18 changes: 18 additions & 0 deletions WavPool/data_generators/cifar_generator.py
@@ -0,0 +1,18 @@
from WavPool.data_generators.data_generator import DataGenerator
from torchvision import transforms
from torchvision.datasets import CIFAR10


class CIFARGenerator(DataGenerator):
    def __init__(self):
        grayscale_transforms = transforms.Compose(
            [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
        )

        dataset = CIFAR10(
            root="wavNN/data/cifar10",
            download=True,
            train=True,
            transform=grayscale_transforms,
        )
        super().__init__(dataset=dataset)
59 changes: 59 additions & 0 deletions WavPool/data_generators/data_generator.py
@@ -0,0 +1,59 @@
from torch.utils.data import DataLoader, Subset
import numpy as np


class DataGenerator:
    def __init__(self, dataset):
        self.dataset = dataset

    def __call__(self, *args, **kwargs):
        sample_size = kwargs["sample_size"]
        split = kwargs.get("split", False)
        if isinstance(sample_size, list):
            split = True

        batch_size = kwargs.get("batch_size", 64)
        shuffle = kwargs.get("shuffle", False)

        if split:
            if isinstance(sample_size, int):
                sample_size = [sample_size for _ in range(3)]

            assert sum(sample_size) <= len(self.dataset), (
                f"Too many requested samples; "
                f"decrease your sample size to less "
                f"than {len(self.dataset)}"
            )

            assert len(sample_size) == 3, (
                "The training, validation, "
                "and test sample sizes must be individually specified"
            )

            # Quick and dirty: ideally the data would be shuffled
            # at load time rather than sliced sequentially.
            # This can be changed later.

            samples = np.cumsum(sample_size)
            training_data = Subset(self.dataset, list(range(0, samples[0])))
            val_data = Subset(self.dataset, list(range(samples[0], samples[1])))
            test_data = Subset(self.dataset, list(range(samples[1], samples[2])))

            training = DataLoader(training_data, batch_size=batch_size, shuffle=shuffle)
            validation = DataLoader(val_data, batch_size=batch_size, shuffle=shuffle)
            test = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)

        else:
            assert sample_size <= len(self.dataset), (
                f"Too many requested samples; decrease your"
                f" sample size to less than {len(self.dataset)}"
            )

            training_data = Subset(self.dataset, list(range(0, sample_size)))

            training = DataLoader(training_data, batch_size=batch_size, shuffle=shuffle)
            validation = None
            test = None

        return {"training": training, "validation": validation, "test": test}
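For context, a minimal usage sketch (hypothetical, not part of this commit): the generator is called with a sample_size that is either an int, for a single training loader, or a three-element list, for train/validation/test splits.

from WavPool.data_generators import MNISTGenerator

generator = MNISTGenerator()
# Three-way split: 600 training, 200 validation, 200 test samples
loaders = generator(sample_size=[600, 200, 200], batch_size=64, shuffle=True)
train_batches = loaders["training"]  # a torch DataLoader
val_batches = loaders["validation"]
test_batches = loaders["test"]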
14 changes: 14 additions & 0 deletions WavPool/data_generators/fashion_mnist_generator.py
@@ -0,0 +1,14 @@
from WavPool.data_generators.data_generator import DataGenerator
from torchvision.transforms import ToTensor
from torchvision.datasets import FashionMNIST


class FashionMNISTGenerator(DataGenerator):
    def __init__(self):
        dataset = FashionMNIST(
            root="wavNN/data/fashionmnist",
            download=True,
            train=True,
            transform=ToTensor(),
        )
        super().__init__(dataset=dataset)
11 changes: 11 additions & 0 deletions WavPool/data_generators/mnist_generator.py
@@ -0,0 +1,11 @@
from WavPool.data_generators.data_generator import DataGenerator
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST


class MNISTGenerator(DataGenerator):
    def __init__(self):
        dataset = MNIST(
            root="wavNN/data/mnist", download=True, train=True, transform=ToTensor()
        )
        super().__init__(dataset=dataset)
4 changes: 4 additions & 0 deletions WavPool/models/__init__.py
@@ -0,0 +1,4 @@
from WavPool.models.wavMLP import WavMLP
from WavPool.models.wavpool import WavPool
from WavPool.models.vanillaMLP import VanillaMLP
from WavPool.models.vanillaCNN import VanillaCNN
52 changes: 52 additions & 0 deletions WavPool/models/vanillaCNN.py
@@ -0,0 +1,52 @@
import torch.nn as nn
import math


class VanillaCNN(nn.Module):
    def __init__(
        self,
        in_channels: int,
        kernel_size: int,
        out_channels: int,
        hidden_channels_1: int = 1,
        hidden_channels_2: int = 1,
    ) -> None:
        # Note: in_channels is the spatial size (H = W) of the square input
        # image; the network itself takes single-channel input.
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=int(hidden_channels_1),
            kernel_size=int(kernel_size),
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm_hidden = nn.BatchNorm2d(int(hidden_channels_1))
        # Spatial size after conv1: in + 2*padding - kernel + 1
        conv1_out = math.ceil(in_channels + 2 - int(kernel_size) + 1)

        self.conv2 = nn.Conv2d(
            in_channels=int(hidden_channels_1),
            out_channels=int(hidden_channels_2),
            kernel_size=int(kernel_size),
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm_out = nn.BatchNorm2d(int(hidden_channels_2))
        conv2_out = math.ceil(conv1_out + 2 - int(kernel_size) + 1)

        self.dense_out = nn.Linear(
            in_features=(conv2_out**2) * int(hidden_channels_2),
            out_features=out_channels,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.batch_norm_hidden(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = self.batch_norm_out(x)
        x = nn.ReLU()(x)
        x = nn.Flatten()(x)
        x = self.dense_out(x)
        return x
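As a sanity check on the size bookkeeping (a hypothetical example, not in this commit): with 28x28 inputs and kernel_size=3, padding 1 and stride 1 preserve the spatial size (28 + 2 - 3 + 1 = 28), so the dense layer sees 28**2 * hidden_channels_2 features.

import torch

model = VanillaCNN(
    in_channels=28,  # spatial size of the square input, not a channel count
    kernel_size=3,
    out_channels=10,
    hidden_channels_1=4,
    hidden_channels_2=8,
)
logits = model(torch.randn(1, 1, 28, 28))  # shape: (1, 10)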
63 changes: 63 additions & 0 deletions WavPool/models/vanillaMLP.py
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn


class VanillaMLP(nn.Module):
    def __init__(
        self, in_channels: int, hidden_size: int, out_channels: int, tail: bool = False
    ):
        super().__init__()

        self.flatten_input = nn.Flatten()
        self.hidden_layer = nn.Linear(int(in_channels) ** 2, int(hidden_size))
        self.output_layer = nn.Linear(int(hidden_size), out_channels)

        if tail:
            self.tail = nn.Softmax(dim=0)

    def forward(self, x):
        x = self.flatten_input(x)
        x = self.hidden_layer(x)
        x = self.output_layer(x)

        if hasattr(self, "tail"):
            x = self.tail(x)

        return x


class BananaSplitMLP(nn.Module):
    def __init__(self, in_channels, hidden_size, out_channels, tail=False):
        super().__init__()

        self.flatten_input = nn.Flatten()

        self.hidden_layer_1 = nn.Linear(in_channels**2, hidden_size)
        self.hidden_layer_2 = nn.Linear(in_channels**2, hidden_size)
        self.hidden_layer_3 = nn.Linear(in_channels**2, hidden_size)

        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

        # Output layer over the flattened concat of the three tied channel layers
        self.output = nn.Linear(hidden_size * 3, out_channels)

        if tail:
            self.tail = nn.Softmax(dim=0)

    def forward(self, x):
        x = self.flatten_input(x)

        channel_1 = self.hidden_layer_1(x)
        channel_2 = self.hidden_layer_2(x)
        channel_3 = self.hidden_layer_3(x)

        concat = torch.stack([channel_1, channel_2, channel_3], dim=1)
        # Flatten for the output dense layer
        x = self.flatten(concat)
        x = self.output(x)

        if hasattr(self, "tail"):
            x = self.tail(x)

        return x
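A quick shape check for the two MLPs (hypothetical usage, not in this commit): both expect square inputs of side in_channels and return one logit vector per batch element.

import torch

mlp = VanillaMLP(in_channels=28, hidden_size=128, out_channels=10)
split_mlp = BananaSplitMLP(in_channels=28, hidden_size=128, out_channels=10)

x = torch.randn(8, 1, 28, 28)  # e.g. a batch of MNIST images
print(mlp(x).shape)        # torch.Size([8, 10])
print(split_mlp(x).shape)  # torch.Size([8, 10])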
54 changes: 54 additions & 0 deletions WavPool/models/wavMLP.py
@@ -0,0 +1,54 @@
import torch
import torch.nn as nn

from WavPool.utils.levels import Levels
from WavPool.models.wavelet_layer import MicroWav


class WavMLP(nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_size: int,
        out_channels: int,
        level: int,
        tail=False,
    ):
        """
        Simplest version: an MLP that takes a single level of the wavelet
        transform and stacks the input channel-wise.

        in_channels: Input size, int; the x, y size of the input image
        hidden_size: Size of each per-channel hidden layer
        out_channels: Output size of the MLP layer
        level: Level of the wavelet used in the MLP
        tail: Bool; whether to add an activation at the end of the network
            to produce a class output
        """
        super().__init__()
        assert level != 0, "Level 0 wavelet not supported"

        # Wavelet transform of input x at a user-defined level
        self.wav = MicroWav(
            level=int(level), in_channels=in_channels, hidden_size=int(hidden_size)
        )
        # Flatten for when these are stacked
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

        # Output of the tied 3-channel layers, over their flattened concat
        self.output = nn.Linear(int(hidden_size) * 3, out_channels)

        # Activation for a classifier
        if tail:
            self.tail = nn.Softmax(dim=0)

    def forward(self, x):
        x = self.wav(x)
        # Flatten for the output dense layer
        x = self.flatten(x)
        x = self.output(x)

        if hasattr(self, "tail"):
            x = self.tail(x)

        return x
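End-to-end, the intended call looks something like the following. This is a hypothetical sketch: it assumes Levels.find_output_size in WavPool/utils/levels.py, which is not shown in this diff, returns the pywt coefficient size matching the given level and image size.

import torch

model = WavMLP(in_channels=28, hidden_size=64, out_channels=10, level=1, tail=True)
scores = model(torch.randn(8, 1, 28, 28))  # shape: (8, 10)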
55 changes: 55 additions & 0 deletions WavPool/models/wavelet_layer.py
@@ -0,0 +1,55 @@
from typing import Union
import numpy as np
import torch
import pywt
from kymatio.torch import Scattering2D
from WavPool.utils.levels import Levels


class WaveletLayer:
    def __init__(
        self, level: int, input_size: Union[int, None] = None, backend: str = "pywt"
    ) -> None:
        layers = {
            "pywt": lambda x: torch.Tensor(np.array(pywt.wavedec2(x, "db1")[level])),
            "kymatio": lambda x: self.kymatio_layer(level=level, input_size=input_size)(
                x
            )[level],
        }

        assert backend in layers.keys()
        self.layer = layers[backend]

    def kymatio_layer(self, level, input_size):
        # Ref:
        # https://www.kymat.io/gallery_2d/plot_invert_scattering_torch.html#sphx-glr-gallery-2d-plot-invert-scattering-torch-py
        scattering = Scattering2D(J=2, shape=(input_size, input_size), max_order=level)
        return scattering

    def __call__(self, x):
        return self.layer(x)


class MicroWav(torch.nn.Module):
    def __init__(self, level: int, in_channels: int, hidden_size: int) -> None:
        super().__init__()
        self.wavelet = WaveletLayer(level=level)
        wav_in_channels = Levels.find_output_size(level, in_channels)
        hidden_size = int(hidden_size)
        self.flatten_wavelet = torch.nn.Flatten(start_dim=1, end_dim=-1)
        # One MLP for each of the 3 detail channels of the wavelet
        # (not including the downscaled original)
        self.channel_1_mlp = torch.nn.Linear(wav_in_channels**2, hidden_size)
        self.channel_2_mlp = torch.nn.Linear(wav_in_channels**2, hidden_size)
        self.channel_3_mlp = torch.nn.Linear(wav_in_channels**2, hidden_size)

    def forward(self, x):
        x = self.wavelet(x)
        # An MLP for each of the transformed channels
        channel_1 = self.channel_1_mlp(self.flatten_wavelet(x[0]))
        channel_2 = self.channel_2_mlp(self.flatten_wavelet(x[1]))
        channel_3 = self.channel_3_mlp(self.flatten_wavelet(x[2]))
        # Stack the outputs
        concat = torch.stack([channel_1, channel_2, channel_3], dim=-1)
        return concat
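For intuition on what MicroWav consumes (a hypothetical example, not in this commit): the pywt backend returns the three detail coefficient maps (horizontal, vertical, diagonal) at the chosen index of the pywt coefficient list, stacked along the first axis. Note that pywt orders detail levels from coarsest to finest, so level=1 selects the coarsest details.

import numpy as np

layer = WaveletLayer(level=1, backend="pywt")
coeffs = layer(np.random.rand(8, 28, 28))
# coeffs has shape (3, batch, h, w): one map per detail orientation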