diff --git a/.gitignore b/.gitignore
index e4fded8..a19a31d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -134,3 +134,4 @@ dmypy.json
 datasets/
 wandb/
 weights/
+data/
diff --git a/README.md b/README.md
index 78b3998..1191232 100644
--- a/README.md
+++ b/README.md
@@ -74,8 +74,9 @@ cd kan-gpt
 git pull
 
 # Download Dataset
-./scripts/download_webtext.sh
-./scripts/download_tinyshakespeare.sh
+python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
+python3 -m kan_gpt.download_dataset --dataset mnist
+python3 -m kan_gpt.download_dataset --dataset webtext
 
 # Install dependencies for development
 pip install -r requirements.txt
diff --git a/kan_gpt/VERSION b/kan_gpt/VERSION
index 9084fa2..26aaba0 100644
--- a/kan_gpt/VERSION
+++ b/kan_gpt/VERSION
@@ -1 +1 @@
-1.1.0
+1.2.0
diff --git a/kan_gpt/dataset.py b/kan_gpt/dataset.py
index 40cb3ee..30e5b1c 100644
--- a/kan_gpt/dataset.py
+++ b/kan_gpt/dataset.py
@@ -4,6 +4,7 @@
 import pandas as pd
 import torch
 from torch.utils.data import Dataset
+from torchvision import datasets, transforms
 from tqdm import tqdm
 from transformers import GPT2Tokenizer
 
@@ -188,3 +189,64 @@ def __getitem__(self, idx):
         #     y = y.unsqueeze(0)
 
         return x, y
+
+
+class MNISTDataset(Dataset):
+    """
+    MNIST Dataset for Transformer (GPT-style) processing
+    """
+
+    def __init__(self, split, model_type, block_size=784):  # sample = 784 pixels + 1 label
+        assert split in {"train", "test"}
+
+        self.split = split
+        self.block_size = block_size
+        self.model_type = model_type
+
+        # Load MNIST dataset
+        dataset = datasets.MNIST(
+            root="./data",
+            train=(split == "train"),
+            download=True,
+            transform=transforms.Compose(
+                [
+                    transforms.ToTensor(),
+                    # transforms.Normalize((0.1307,), (0.3081,)),
+                ]
+            ),
+        )
+
+        self.data = []
+
+        for img, label in dataset:
+            # Flatten the image
+            flattened_img = img.view(-1)
+            # Convert to integer values (0-255)
+            flattened_img = (flattened_img * 255).long()
+
+            # Append the label to the end of the flattened image
+            sample = torch.cat([flattened_img, torch.tensor([label])])
+
+            # Pad with zeros to reach block_size
+            if len(sample) < self.block_size:
+                padding = torch.zeros(
+                    self.block_size - len(sample), dtype=torch.long
+                )
+                sample = torch.cat([sample, padding])
+
+            self.data.append(sample)
+
+    def __len__(self):
+        return len(self.data)
+
+    def get_vocab_size(self):
+        return 256  # pixel values 0-255; labels 0-9 and padding reuse the low token ids
+
+    def get_block_size(self):
+        return self.block_size
+
+    def __getitem__(self, idx):
+        x = self.data[idx][:-1]  # Input: all but the last token
+        y = self.data[idx][1:]  # Target: all but the first token
+
+        return x, y
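
A note on the framing used by MNISTDataset above: classification is recast as
next-token prediction. Each image becomes 784 pixel tokens followed by one
label token, and the usual one-position shift yields the (x, y) pair. A
minimal sketch of that framing (illustrative only, not part of the diff; the
pixel and label tensors are stand-ins):

    import torch

    block_size = 784  # matches MNISTDataset's default
    pixels = torch.randint(0, 256, (784,))  # stand-in for a flattened 28x28 image
    label = torch.tensor([3])  # stand-in class label

    sample = torch.cat([pixels, label])  # 785 tokens: pixels, then label
    x, y = sample[:-1], sample[1:]  # shift by one; each is block_size long

    assert x.shape == y.shape == (block_size,)
    # The label appears only as the final target token, so the model is
    # trained to emit the digit after reading all 784 pixel tokens.
    print(int(y[-1]))  # 3
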
diff --git a/kan_gpt/download_dataset.py b/kan_gpt/download_dataset.py
index 8ebfe93..34bf682 100644
--- a/kan_gpt/download_dataset.py
+++ b/kan_gpt/download_dataset.py
@@ -2,6 +2,7 @@
 from typing import List
 
 import requests
+from tqdm import tqdm
 
 CHUNK_SIZE = 8192
 
@@ -13,15 +14,55 @@ def download_webtext(
 ):
     os.makedirs(download_path, exist_ok=True)
     for split in splits:
-        response = requests.get(
-            f"{base_url}/webtext.{split}.jsonl", stream=True
-        )
+        file_path = f"{download_path}/webtext.{split}.jsonl"
+        file_url = f"{base_url}/webtext.{split}.jsonl"
+
+        # Check if the file already exists and get its size
+        initial_pos = 0
+        if os.path.exists(file_path):
+            initial_pos = os.path.getsize(file_path)
+            print(
+                f"Resuming download of webtext.{split}.jsonl from {initial_pos} bytes"  # noqa
+            )
+
+        # Set up headers for resuming the download
+        headers = {"Range": f"bytes={initial_pos}-"}
+
+        response = requests.get(file_url, stream=True, headers=headers)
+
+        # If the server doesn't support range requests, start over
+        if response.status_code == 416:
+            print(
+                f"Cannot resume download for webtext.{split}.jsonl. Starting from the beginning."  # noqa
+            )
+            initial_pos = 0
+            headers = {}
+            response = requests.get(file_url, stream=True)
 
         response.raise_for_status()  # Raise HTTP errors
 
+        total_size = (
+            int(response.headers.get("content-length", 0)) + initial_pos
+        )
 
-        # Open a local file for writing in binary mode
-        with open(f"{download_path}/webtext.{split}.jsonl", "wb") as file:
+        # Open the local file for writing in binary mode, appending if resuming
+        mode = "ab" if initial_pos > 0 else "wb"
+        with open(file_path, mode) as file, tqdm(
+            desc=f"Downloading webtext.{split}.jsonl",
+            initial=initial_pos,
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
             for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
-                file.write(chunk)
+                size = file.write(chunk)
+                progress_bar.update(size)
+
+        # Verify file size after download
+        if os.path.getsize(file_path) != total_size:
+            print(
+                f"Warning: Downloaded file size does not match expected size for webtext.{split}.jsonl"  # noqa
+            )
 
 
 def download_tinyshakespeare(
@@ -32,11 +73,26 @@ def download_tinyshakespeare(
     response = requests.get(f"{base_url}/input.txt", stream=True)
     response.raise_for_status()  # Raise HTTP errors
 
+    total_size = int(response.headers.get("content-length", 0))
+
     # Open a local file for writing in binary mode
-    with open(f"{download_path}/input.txt", "wb") as file:
+    with open(f"{download_path}/input.txt", "wb") as file, tqdm(
+        desc=f"Downloading {download_path}/input.txt",
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as progress_bar:
         for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
-            file.write(chunk)
+            size = file.write(chunk)
+            progress_bar.update(size)
+
+
+def download_mnist():
+    from .dataset import MNISTDataset
+
+    MNISTDataset("train", "gpt2")
+    MNISTDataset("test", "gpt2")
 
 
 def main(args):
@@ -44,6 +100,8 @@ def main(args):
         download_webtext()
     elif args.dataset == "tinyshakespeare":
         download_tinyshakespeare()
+    elif args.dataset == "mnist":
+        download_mnist()
     else:
         raise NotImplementedError
 
@@ -54,7 +112,7 @@
     parser = argparse.ArgumentParser("KAN-GPT Trainer")
     parser.add_argument(
         "--dataset",
-        choices=["webtext", "tinyshakespeare"],
+        choices=["webtext", "tinyshakespeare", "mnist"],
         default="tinyshakespeare",
     )
diff --git a/kan_gpt/sweep.py b/kan_gpt/sweep.py
index 2a43766..bc2ef1e 100644
--- a/kan_gpt/sweep.py
+++ b/kan_gpt/sweep.py
@@ -1,6 +1,6 @@
 import torch
 
-import wandb
+import wandb
 from kan_gpt.train import main
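
For readers unfamiliar with HTTP range requests, the resume logic added to
download_webtext reduces to a small pattern: request the bytes after the
current partial-file size, append them to the file, and fall back to a full
download if the server answers 416 (Range Not Satisfiable). A condensed,
self-contained sketch of that pattern (the url and path arguments are
placeholders, not values from the diff):

    import os

    import requests

    def resume_download(url: str, path: str, chunk_size: int = 8192) -> None:
        # Resume from the size of any partial file already on disk
        initial_pos = os.path.getsize(path) if os.path.exists(path) else 0
        headers = {"Range": f"bytes={initial_pos}-"} if initial_pos else {}

        response = requests.get(url, stream=True, headers=headers)
        if response.status_code == 416:  # range not satisfiable: restart
            initial_pos, headers = 0, {}
            response = requests.get(url, stream=True)
        response.raise_for_status()

        # Append when resuming, truncate when starting fresh
        with open(path, "ab" if initial_pos else "wb") as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                f.write(chunk)

One caveat worth noting: a server that ignores the Range header replies 200
with the full body, so production code may also want to confirm a 206 status
before appending.
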
diff --git a/kan_gpt/train.py b/kan_gpt/train.py
index adbd96a..0f985d0 100644
--- a/kan_gpt/train.py
+++ b/kan_gpt/train.py
@@ -4,13 +4,17 @@
 import numpy as np
 import torch
-import wandb
 from torch.nn import functional as F
 from torch.utils.data.dataloader import DataLoader
 from wandb.sdk.lib import RunDisabled
 from wandb.sdk.wandb_run import Run
 
-from kan_gpt.dataset import TinyShakespeareDataset, WebTextDataset
+import wandb
+from kan_gpt.dataset import (
+    MNISTDataset,
+    TinyShakespeareDataset,
+    WebTextDataset,
+)
 from kan_gpt.mingpt.model import GPT as MLP_GPT
 from kan_gpt.mingpt.trainer import Trainer
 from kan_gpt.model import GPT as KAN_GPT
@@ -83,13 +87,11 @@ def eval_split(
         x = x.to(trainer.device)
         y = y.to(trainer.device)
 
-        block_size = y.shape[1]
-
         logits, loss = model(x, y)
 
         probs = F.softmax(logits, dim=-1)
 
-        _, y_pred = torch.topk(probs, k=block_size, dim=-1)
+        # _, y_pred = torch.topk(probs, k=block_size, dim=-1)
 
         perplexity, f1, precision, recall, cross_entropy = metrics(
             y=y.cpu().numpy(), y_pred=probs.cpu().numpy()
@@ -148,6 +150,8 @@ def main(args, run=None):
         Dataset = WebTextDataset
     elif args.dataset == "tinyshakespeare":
         Dataset = TinyShakespeareDataset
+    elif args.dataset == "mnist":
+        Dataset = MNISTDataset
 
     # print an example instance of the dataset
     if args.dummy_dataset:
@@ -287,13 +291,13 @@
     parser.add_argument("--model_type", default="gpt-mini")
     parser.add_argument("--dummy_dataset", action="store_true")
     parser.add_argument("--learning_rate", default=5e-3)
-    parser.add_argument("--max_iters", default=2000)
+    parser.add_argument("--max_iters", default=32000)
     parser.add_argument("--num_workers", default=0)
-    parser.add_argument("--batch_size", default=64)
+    parser.add_argument("--batch_size", default=2)
     parser.add_argument(
         "--dataset",
-        choices=["webtext", "tinyshakespeare"],
+        choices=["webtext", "tinyshakespeare", "mnist"],
         default="tinyshakespeare",
     )
     parser.add_argument(
diff --git a/requirements.txt b/requirements.txt
index 369c0e5..51f2981 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -9,3 +9,4 @@ pandas>=2.0.3
 requests>=2.31.0
 transformers>=4.40.1
 wandb>=0.16.6
+torchvision
diff --git a/scripts/train.sh b/scripts/train.sh
new file mode 100755
index 0000000..0981bfb
--- /dev/null
+++ b/scripts/train.sh
@@ -0,0 +1,16 @@
+#!/bin/bash
+
+# Download Dataset
+python3 -m kan_gpt.download_dataset --dataset tinyshakespeare
+python3 -m kan_gpt.download_dataset --dataset mnist
+python3 -m kan_gpt.download_dataset --dataset webtext
+
+# Train
+python3 -m kan_gpt.train --dataset mnist --architecture MLP
+python3 -m kan_gpt.train --dataset mnist --architecture KAN
+
+python3 -m kan_gpt.train --dataset tinyshakespeare --architecture MLP
+python3 -m kan_gpt.train --dataset tinyshakespeare --architecture KAN
+
+python3 -m kan_gpt.train --dataset webtext --architecture MLP
+python3 -m kan_gpt.train --dataset webtext --architecture KAN
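
Finally, a hypothetical helper (not part of the diff) showing how a digit
prediction could be read out of a model trained with this setup. It assumes a
minGPT-style forward pass returning a (logits, loss) tuple, which matches the
`logits, loss = model(x, y)` call in train.py above; the `model` and `sample`
names are placeholders:

    import torch

    @torch.no_grad()
    def predict_digit(model, sample: torch.Tensor) -> int:
        x = sample[:784].unsqueeze(0)  # the 784 pixel tokens, shape (1, 784)
        logits, _ = model(x)  # loss is None when no targets are given
        # The predicted digit is the model's next-token choice after the
        # last pixel position (ideally a token in 0-9).
        return int(logits[0, -1].argmax())
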