Commit: Completed project with cleaned code and readme
Showing 29 changed files with 1,168 additions and 2,353 deletions.
One of the added files is empty.
WavPool/data_generators/__init__.py
@@ -0,0 +1,3 @@
from WavPool.data_generators.cifar_generator import CIFARGenerator
from WavPool.data_generators.mnist_generator import MNISTGenerator
from WavPool.data_generators.fashion_mnist_generator import FashionMNISTGenerator
WavPool/data_generators/cifar_generator.py
@@ -0,0 +1,18 @@
from WavPool.data_generators.data_generator import DataGenerator
from torchvision import transforms
from torchvision.datasets import CIFAR10


class CIFARGenerator(DataGenerator):
    def __init__(self):
        grayscale_transforms = transforms.Compose(
            [transforms.Grayscale(num_output_channels=1), transforms.ToTensor()]
        )

        dataset = CIFAR10(
            root="wavNN/data/cifar10",
            download=True,
            train=True,
            transform=grayscale_transforms,
        )
        super().__init__(dataset=dataset)
WavPool/data_generators/data_generator.py
@@ -0,0 +1,59 @@
from torch.utils.data import DataLoader, Subset
import numpy as np


class DataGenerator:
    def __init__(self, dataset):
        self.dataset = dataset

    def __call__(self, *args, **kwargs):
        sample_size = kwargs["sample_size"]
        split = kwargs.get("split", False)
        if isinstance(sample_size, list):
            split = True

        batch_size = kwargs.get("batch_size", 64)
        shuffle = kwargs.get("shuffle", False)

        if split:
            if isinstance(sample_size, int):
                sample_size = [sample_size for _ in range(3)]

            assert sum(sample_size) <= len(self.dataset), (
                f"Too many requested samples; decrease your "
                f"sample size to less than {len(self.dataset)}"
            )

            assert len(sample_size) == 3, (
                "The sample sizes of training, validation, "
                "and test must be individually specified"
            )

            # Quick and dirty: the split is sequential rather than shuffled
            # at load time; this can be improved later.

            samples = np.cumsum(sample_size)
            training_data = Subset(self.dataset, list(range(0, samples[0])))
            val_data = Subset(self.dataset, list(range(samples[0], samples[1])))
            test_data = Subset(self.dataset, list(range(samples[1], samples[2])))

            training = DataLoader(training_data, batch_size=batch_size, shuffle=shuffle)
            validation = DataLoader(val_data, batch_size=batch_size, shuffle=shuffle)
            test = DataLoader(test_data, batch_size=batch_size, shuffle=shuffle)

        else:
            assert sample_size <= len(self.dataset), (
                f"Too many requested samples; decrease your "
                f"sample size to less than {len(self.dataset)}"
            )

            training_data = Subset(self.dataset, list(range(0, sample_size)))

            training = DataLoader(training_data, batch_size=batch_size, shuffle=shuffle)
            validation = None
            test = None

        return {"training": training, "validation": validation, "test": test}
WavPool/data_generators/fashion_mnist_generator.py
@@ -0,0 +1,14 @@
from WavPool.data_generators.data_generator import DataGenerator
from torchvision.transforms import ToTensor
from torchvision.datasets import FashionMNIST


class FashionMNISTGenerator(DataGenerator):
    def __init__(self):
        dataset = FashionMNIST(
            root="wavNN/data/fashionmnist",
            download=True,
            train=True,
            transform=ToTensor(),
        )
        super().__init__(dataset=dataset)
WavPool/data_generators/mnist_generator.py
@@ -0,0 +1,11 @@
from WavPool.data_generators.data_generator import DataGenerator
from torchvision.transforms import ToTensor
from torchvision.datasets import MNIST


class MNISTGenerator(DataGenerator):
    def __init__(self):
        dataset = MNIST(
            root="wavNN/data/mnist", download=True, train=True, transform=ToTensor()
        )
        super().__init__(dataset=dataset)
WavPool/models/__init__.py
@@ -0,0 +1,4 @@
from WavPool.models.wavMLP import WavMLP
from WavPool.models.wavpool import WavPool
from WavPool.models.vanillaMLP import VanillaMLP
from WavPool.models.vanillaCNN import VanillaCNN
WavPool/models/vanillaCNN.py
@@ -0,0 +1,52 @@
import torch.nn as nn
import math


class VanillaCNN(nn.Module):
    def __init__(
        self,
        in_channels: int,  # spatial extent (height/width) of the input image
        kernel_size: int,
        out_channels: int,
        hidden_channels_1: int = 1,
        hidden_channels_2: int = 1,
    ) -> None:
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels=1,
            out_channels=int(hidden_channels_1),
            kernel_size=int(kernel_size),
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm_hidden = nn.BatchNorm2d(int(hidden_channels_1))
        # Spatial size after conv1 (padding=1, stride=1): in + 2 - kernel + 1
        conv1_out = math.ceil(in_channels + 2 - int(kernel_size) + 1)

        self.conv2 = nn.Conv2d(
            in_channels=int(hidden_channels_1),
            out_channels=int(hidden_channels_2),
            kernel_size=int(kernel_size),
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm_out = nn.BatchNorm2d(int(hidden_channels_2))
        conv2_out = math.ceil(conv1_out + 2 - int(kernel_size) + 1)

        # Flattened conv2 feature map -> class logits
        self.dense_out = nn.Linear(
            in_features=(conv2_out**2) * int(hidden_channels_2),
            out_features=out_channels,
        )

    def forward(self, x):
        x = self.conv1(x)
        x = self.batch_norm_hidden(x)
        x = nn.ReLU()(x)
        x = self.conv2(x)
        x = self.batch_norm_out(x)
        x = nn.ReLU()(x)
        x = nn.Flatten()(x)
        x = self.dense_out(x)
        return x
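A minimal shape check (a sketch, not from the commit) for VanillaCNN as defined above. Note that the in_channels argument is the spatial extent of the single-channel input image; with padding=1 and stride=1 each convolution keeps the spatial size at in + 2 - kernel_size + 1, e.g. 28 + 2 - 3 + 1 = 28 for a 28x28 image and a 3x3 kernel:

    import torch

    model = VanillaCNN(
        in_channels=28,       # image height/width, not channel count
        kernel_size=3,
        out_channels=10,      # number of classes
        hidden_channels_1=4,
        hidden_channels_2=8,
    )
    x = torch.randn(16, 1, 28, 28)  # batch of 16 single-channel images
    print(model(x).shape)           # torch.Size([16, 10])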
WavPool/models/vanillaMLP.py
@@ -0,0 +1,63 @@
import torch
import torch.nn as nn


class VanillaMLP(nn.Module):
    def __init__(
        self, in_channels: int, hidden_size: int, out_channels: int, tail: bool = False
    ):
        super().__init__()

        self.flatten_input = nn.Flatten()
        self.hidden_layer = nn.Linear(int(in_channels) ** 2, int(hidden_size))
        self.output_layer = nn.Linear(int(hidden_size), out_channels)

        if tail:
            self.tail = nn.Softmax(dim=-1)  # normalize over classes, not the batch

    def forward(self, x):
        x = self.flatten_input(x)
        x = self.hidden_layer(x)
        x = self.output_layer(x)

        if hasattr(self, "tail"):
            x = self.tail(x)

        return x


class BananaSplitMLP(nn.Module):
    def __init__(self, in_channels, hidden_size, out_channels, tail=False):
        super().__init__()

        self.flatten_input = nn.Flatten()

        self.hidden_layer_1 = nn.Linear(in_channels**2, hidden_size)
        self.hidden_layer_2 = nn.Linear(in_channels**2, hidden_size)
        self.hidden_layer_3 = nn.Linear(in_channels**2, hidden_size)

        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

        # Output layer over the three parallel channels, concatenated and flattened
        self.output = nn.Linear(hidden_size * 3, out_channels)

        if tail:
            self.tail = nn.Softmax(dim=-1)  # normalize over classes, not the batch

    def forward(self, x):
        x = self.flatten_input(x)

        channel_1 = self.hidden_layer_1(x)
        channel_2 = self.hidden_layer_2(x)
        channel_3 = self.hidden_layer_3(x)

        concat = torch.stack([channel_1, channel_2, channel_3], dim=1)
        # Flatten for the output dense layer
        x = self.flatten(concat)
        x = self.output(x)

        if hasattr(self, "tail"):
            x = self.tail(x)

        return x
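Both MLPs above accept flattened square images; a quick sketch (assuming 28x28 inputs and 10 classes) showing that VanillaMLP and the three-branch BananaSplitMLP produce the same output shape:

    import torch

    mlp = VanillaMLP(in_channels=28, hidden_size=64, out_channels=10)
    split = BananaSplitMLP(in_channels=28, hidden_size=64, out_channels=10)

    x = torch.randn(8, 1, 28, 28)
    print(mlp(x).shape)    # torch.Size([8, 10])
    print(split(x).shape)  # torch.Size([8, 10])
    # BananaSplitMLP: three Linear(784, 64) branches -> stack to (8, 3, 64)
    # -> flatten to (8, 192) -> Linear(192, 10)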
WavPool/models/wavMLP.py
@@ -0,0 +1,54 @@
import torch.nn as nn

from WavPool.models.wavelet_layer import MicroWav


class WavMLP(nn.Module):
    def __init__(
        self,
        in_channels: int,
        hidden_size: int,
        out_channels: int,
        level: int,
        tail=False,
    ):
        """
        Simplest version: an MLP that takes a single level of the wavelet
        transform and stacks the input channel-wise.

        in_channels: input size, int; the x, y extent of the input image
        hidden_size: size of the hidden layer applied to each wavelet channel
        out_channels: output size of the MLP layer
        level: level of the wavelet transform used in the MLP
        tail: bool; if True, add an activation at the end of the network
            to produce a class output
        """
        super().__init__()
        assert level != 0, "Level 0 wavelet not supported"

        # Wavelet transform of input x at a level defined by the user
        self.wav = MicroWav(
            level=int(level), in_channels=in_channels, hidden_size=int(hidden_size)
        )
        # Flatten for when these are stacked
        self.flatten = nn.Flatten(start_dim=1, end_dim=-1)

        # Output layer over the three tied wavelet channels, concatenated and flattened
        self.output = nn.Linear(int(hidden_size) * 3, out_channels)

        # Activation for a classifier
        if tail:
            self.tail = nn.Softmax(dim=-1)  # normalize over classes, not the batch

    def forward(self, x):
        x = self.wav(x)
        # Flatten for the output dense layer
        x = self.flatten(x)
        x = self.output(x)

        if hasattr(self, "tail"):
            x = self.tail(x)

        return x
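A usage sketch for WavMLP, hedged on Levels.find_output_size (not shown in this diff): the example assumes it returns the spatial size of the detail coefficients that WaveletLayer selects (2 for level=1 on a 28-pixel image), so that the per-channel Linear layers in MicroWav line up with the wavelet output:

    import torch

    # 28x28 inputs, 64 hidden units per wavelet channel, 10 classes
    model = WavMLP(in_channels=28, hidden_size=64, out_channels=10, level=1, tail=True)
    x = torch.randn(32, 28, 28)  # no channel dim; pywt transforms the raw 2D images
    print(model(x).shape)        # torch.Size([32, 10])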
WavPool/models/wavelet_layer.py
@@ -0,0 +1,55 @@
from typing import Union
import numpy as np
import torch
import pywt
from kymatio.torch import Scattering2D
from WavPool.utils.levels import Levels


class WaveletLayer:
    def __init__(
        self, level: int, input_size: Union[int, None] = None, backend: str = "pywt"
    ) -> None:
        layers = {
            "pywt": lambda x: torch.Tensor(np.array(pywt.wavedec2(x, "db1")[level])),
            "kymatio": lambda x: self.kymatio_layer(level=level, input_size=input_size)(
                x
            )[level],
        }

        assert backend in layers
        self.layer = layers[backend]

    def kymatio_layer(self, level, input_size):
        # Ref:
        # https://www.kymat.io/gallery_2d/plot_invert_scattering_torch.html#sphx-glr-gallery-2d-plot-invert-scattering-torch-py
        scattering = Scattering2D(J=2, shape=(input_size, input_size), max_order=level)
        return scattering

    def __call__(self, x):
        return self.layer(x)


class MicroWav(torch.nn.Module):
    def __init__(self, level: int, in_channels: int, hidden_size: int) -> None:
        super().__init__()
        self.wavelet = WaveletLayer(level=level)
        wav_in_channels = Levels.find_output_size(level, in_channels)
        hidden_size = int(hidden_size)
        self.flatten_wavelet = torch.nn.Flatten(start_dim=1, end_dim=-1)
        # One MLP for each of the 3 detail channels of the wavelet
        # (not including the downscaled original)
        self.channel_1_mlp = torch.nn.Linear(wav_in_channels**2, hidden_size)
        self.channel_2_mlp = torch.nn.Linear(wav_in_channels**2, hidden_size)
        self.channel_3_mlp = torch.nn.Linear(wav_in_channels**2, hidden_size)

    def forward(self, x):
        x = self.wavelet(x)
        # An MLP for each of the transformed channels
        channel_1 = self.channel_1_mlp(self.flatten_wavelet(x[0]))
        channel_2 = self.channel_2_mlp(self.flatten_wavelet(x[1]))
        channel_3 = self.channel_3_mlp(self.flatten_wavelet(x[2]))
        # Stack the outputs
        concat = torch.stack([channel_1, channel_2, channel_3], dim=-1)
        return concat
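The pywt backend above indexes the coefficient list from pywt.wavedec2 directly; a short demonstration (a sketch, not from the commit) of the structure MicroWav relies on, for a 28x28 Haar ("db1") decomposition:

    import numpy as np
    import pywt
    import torch

    x = np.random.rand(28, 28)
    coeffs = pywt.wavedec2(x, "db1")
    print(len(coeffs))                # 5: [cA4, details_4, details_3, details_2, details_1]
    print(np.array(coeffs[1]).shape)  # (3, 2, 2): coarsest (H, V, D) detail channels
    print(np.array(coeffs[4]).shape)  # (3, 14, 14): finest detail channels

    # WaveletLayer(level=1) therefore returns a (3, 2, 2) tensor whose three
    # slices feed the three per-channel MLPs in MicroWav
    out = torch.Tensor(np.array(coeffs[1]))
    print(out.shape)                  # torch.Size([3, 2, 2])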