feat(kan_gpt/dataset.py): mnist support
AdityaNG committed Sep 9, 2024
1 parent 535665a commit 6a38e12
Showing 9 changed files with 164 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -134,3 +134,4 @@ dmypy.json
datasets/
wandb/
weights/
data/
5 changes: 3 additions & 2 deletions README.md
@@ -74,8 +74,9 @@ cd kan-gpt
git pull

# Download Dataset
-./scripts/download_webtext.sh
-./scripts/download_tinyshakespeare.sh
python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
python3 -m kan_gpt.download_dataset --dataset mnist
python3 -m kan_gpt.download_dataset --dataset webtext

# Install dependencies for development
pip install -r requirements.txt
2 changes: 1 addition & 1 deletion kan_gpt/VERSION
@@ -1 +1 @@
-1.1.0
1.2.0
62 changes: 62 additions & 0 deletions kan_gpt/dataset.py
@@ -4,6 +4,7 @@
import pandas as pd
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
from tqdm import tqdm
from transformers import GPT2Tokenizer

@@ -188,3 +189,64 @@ def __getitem__(self, idx):
# y = y.unsqueeze(0)

return x, y


class MNISTDataset(Dataset):
"""
MNIST Dataset for Transformer (GPT-style) processing
"""

    def __init__(self, split, model_type, block_size=784):  # 28*28 = 784 pixels; the appended label makes each raw sample 785 tokens
assert split in {"train", "test"}

self.split = split
self.block_size = block_size
self.model_type = model_type

# Load MNIST dataset
dataset = datasets.MNIST(
root="./data",
train=(split == "train"),
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
# transforms.Normalize((0.1307,), (0.3081,)),
]
),
)

self.data = []

for img, label in dataset:
# Flatten the image
flattened_img = img.view(-1)
# Convert to integer values (0-255)
flattened_img = (flattened_img * 255).long()

# Append label to the end of flattened image
sample = torch.cat([flattened_img, torch.tensor([label])])

# Pad with zeros to reach block_size
if len(sample) < self.block_size:
padding = torch.zeros(
self.block_size - len(sample), dtype=torch.long
)
sample = torch.cat([sample, padding])

self.data.append(sample)

def __len__(self):
return len(self.data)

def get_vocab_size(self):
        return 256  # pixel intensities 0-255; the 0-9 labels and zero padding reuse these ids

def get_block_size(self):
return self.block_size

def __getitem__(self, idx):
x = self.data[idx][:-1] # Input: all but last token
y = self.data[idx][1:] # Target: all but first token

return x, y
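
A quick editorial sanity check for the new class (this snippet is not part of the commit; it assumes the torchvision download into ./data succeeds):

from kan_gpt.dataset import MNISTDataset

ds = MNISTDataset("train", "gpt2")
x, y = ds[0]

# Each 28x28 image becomes 784 pixel tokens with the class label
# appended as a 785th token; x and y are the usual one-step-shifted
# views, so each is 784 tokens long.
print(len(ds))           # 60000 training samples
print(x.shape, y.shape)  # torch.Size([784]) torch.Size([784])
print(int(y[-1]))        # the final target token is the 0-9 class label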
76 changes: 67 additions & 9 deletions kan_gpt/download_dataset.py
@@ -2,6 +2,7 @@
from typing import List

import requests
from tqdm import tqdm

CHUNK_SIZE = 8192

@@ -13,15 +14,55 @@ def download_webtext(
):
os.makedirs(download_path, exist_ok=True)
for split in splits:
-        response = requests.get(
-            f"{base_url}/webtext.{split}.jsonl", stream=True
-        )
file_path = f"{download_path}/webtext.{split}.jsonl"
file_url = f"{base_url}/webtext.{split}.jsonl"

# Check if file exists and get its size
initial_pos = 0
if os.path.exists(file_path):
initial_pos = os.path.getsize(file_path)
print(
f"Resuming download of webtext.{split}.jsonl from {initial_pos} bytes" # noqa
)

# Set up headers for resuming download
headers = {"Range": f"bytes={initial_pos}-"}

response = requests.get(file_url, stream=True, headers=headers)

# If the server doesn't support range requests, start over
if response.status_code == 416:
print(
f"Cannot resume download for webtext.{split}.jsonl. Starting from beginning." # noqa
)
initial_pos = 0
headers = {}
response = requests.get(file_url, stream=True)

response.raise_for_status() # Raise HTTP errors
total_size = (
int(response.headers.get("content-length", 0)) + initial_pos
)

-        # Open a local file for writing in binary mode
-        with open(f"{download_path}/webtext.{split}.jsonl", "wb") as file:
# Open the local file for writing in binary mode, appending if resuming
mode = "ab" if initial_pos > 0 else "wb"
with open(file_path, mode) as file, tqdm(
desc=f"Downloading webtext.{split}.jsonl",
initial=initial_pos,
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
-                file.write(chunk)
size = file.write(chunk)
progress_bar.update(size)

# Verify file size after download
if os.path.getsize(file_path) != total_size:
print(
f"Warning: Downloaded file size does not match expected size for webtext.{split}.jsonl" # noqa
)


def download_tinyshakespeare(
@@ -32,18 +73,35 @@ def download_tinyshakespeare(

response = requests.get(f"{base_url}/input.txt", stream=True)
response.raise_for_status() # Raise HTTP errors
total_size = int(response.headers.get("content-length", 0))

# Open a local file for writing in binary mode
with open(f"{download_path}/input.txt", "wb") as file:
with open(f"{download_path}/input.txt", "wb") as file, tqdm(
desc=f"Downloading {download_path}/input.txt",
total=total_size,
unit="iB",
unit_scale=True,
unit_divisor=1024,
) as progress_bar:
for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
-            file.write(chunk)
size = file.write(chunk)
progress_bar.update(size)


def download_mnist():
from .dataset import MNISTDataset

MNISTDataset("train", "gpt2")
MNISTDataset("test", "gpt2")


def main(args):
if args.dataset == "webtext":
download_webtext()
elif args.dataset == "tinyshakespeare":
download_tinyshakespeare()
elif args.dataset == "mnist":
download_mnist()
else:
raise NotImplementedError

@@ -54,7 +112,7 @@ def main(args):
parser = argparse.ArgumentParser("KAN-GPT Trainer")
parser.add_argument(
"--dataset",
choices=["webtext", "tinyshakespeare"],
choices=["webtext", "tinyshakespeare", "mnist"],
default="tinyshakespeare",
)
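
Review note: the webtext changes implement the usual HTTP Range resume pattern, but they append whenever a partial file exists; a server that ignores the Range header replies 200 with the full body, which would duplicate bytes. A stripped-down editorial sketch of the guarded pattern (placeholder logic, not the committed code; it appends only on 206 Partial Content):

import os
import requests

def resumable_download(url: str, path: str, chunk_size: int = 8192) -> None:
    # Resume from however many bytes are already on disk.
    start = os.path.getsize(path) if os.path.exists(path) else 0
    resp = requests.get(url, stream=True, headers={"Range": f"bytes={start}-"})
    if resp.status_code == 416:  # range not satisfiable: restart from zero
        resp = requests.get(url, stream=True)
    resp.raise_for_status()
    # Append only if the server honored the Range header (206 Partial
    # Content); a plain 200 means it re-sent the whole file.
    with open(path, "ab" if resp.status_code == 206 else "wb") as f:
        for chunk in resp.iter_content(chunk_size=chunk_size):
            f.write(chunk)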

2 changes: 1 addition & 1 deletion kan_gpt/sweep.py
@@ -1,6 +1,6 @@
import torch
-import wandb

import wandb
from kan_gpt.train import main


20 changes: 12 additions & 8 deletions kan_gpt/train.py
@@ -4,13 +4,17 @@

import numpy as np
import torch
-import wandb
from torch.nn import functional as F
from torch.utils.data.dataloader import DataLoader
from wandb.sdk.lib import RunDisabled
from wandb.sdk.wandb_run import Run

-from kan_gpt.dataset import TinyShakespeareDataset, WebTextDataset
import wandb
from kan_gpt.dataset import (
MNISTDataset,
TinyShakespeareDataset,
WebTextDataset,
)
from kan_gpt.mingpt.model import GPT as MLP_GPT
from kan_gpt.mingpt.trainer import Trainer
from kan_gpt.model import GPT as KAN_GPT
@@ -83,13 +87,11 @@ def eval_split(
x = x.to(trainer.device)
y = y.to(trainer.device)

-        block_size = y.shape[1]
-
logits, loss = model(x, y)

probs = F.softmax(logits, dim=-1)

-        _, y_pred = torch.topk(probs, k=block_size, dim=-1)
# _, y_pred = torch.topk(probs, k=block_size, dim=-1)

perplexity, f1, precision, recall, cross_entropy = metrics(
y=y.cpu().numpy(), y_pred=probs.cpu().numpy()
@@ -148,6 +150,8 @@ def main(args, run=None):
Dataset = WebTextDataset
elif args.dataset == "tinyshakespeare":
Dataset = TinyShakespeareDataset
elif args.dataset == "mnist":
Dataset = MNISTDataset

# print an example instance of the dataset
if args.dummy_dataset:
@@ -287,13 +291,13 @@ def batch_end_callback(trainer):
parser.add_argument("--model_type", default="gpt-mini")
parser.add_argument("--dummy_dataset", action="store_true")
parser.add_argument("--learning_rate", default=5e-3)
parser.add_argument("--max_iters", default=2000)
parser.add_argument("--max_iters", default=32000)
parser.add_argument("--num_workers", default=0)
parser.add_argument("--batch_size", default=64)
parser.add_argument("--batch_size", default=2)

parser.add_argument(
"--dataset",
choices=["webtext", "tinyshakespeare"],
choices=["webtext", "tinyshakespeare", "mnist"],
default="tinyshakespeare",
)
parser.add_argument(
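
For tracing how the new dataset feeds the model, a minimal editorial sketch (the config wiring in train.py is collapsed above, so none of this is taken from the diff):

from torch.utils.data.dataloader import DataLoader
from kan_gpt.dataset import MNISTDataset

train_ds = MNISTDataset("train", "gpt2")
block_size = train_ds.get_block_size()  # 784-token context window
vocab_size = train_ds.get_vocab_size()  # 256 pixel-intensity ids

# With the new --batch_size default of 2, each training step sees two
# 784-token sequences whose final target token is the class label.
loader = DataLoader(train_ds, batch_size=2, shuffle=True)
x, y = next(iter(loader))  # x.shape == y.shape == torch.Size([2, 784])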
1 change: 1 addition & 0 deletions requirements.txt
@@ -9,3 +9,4 @@ pandas>=2.0.3
requests>=2.31.0
transformers>=4.40.1
wandb>=0.16.6
torchvision
16 changes: 16 additions & 0 deletions scripts/train.sh
@@ -0,0 +1,16 @@
#!/bin/bash

# Download Dataset
python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
python3 -m kan_gpt.download_dataset --dataset mnist
python3 -m kan_gpt.download_dataset --dataset webtext

# Train
python3 -m kan_gpt.train --dataset mnist --architecture MLP
python3 -m kan_gpt.train --dataset mnist --architecture KAN

python3 -m kan_gpt.train --dataset tinyshakespeare --architecture MLP
python3 -m kan_gpt.train --dataset tinyshakespeare --architecture KAN

python3 -m kan_gpt.train --dataset webtext --architecture MLP
python3 -m kan_gpt.train --dataset webtext --architecture KAN
