pre-commit & still needs to be re-formatted into the current code-base
Signed-off-by: lawrence-cj <[email protected]>
lawrence-cj committed Nov 27, 2024
1 parent 9ff9f66 commit 16b7ad5
Showing 3 changed files with 113 additions and 96 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -222,6 +222,7 @@ bash train_scripts/train.sh \
```

Local training with bucketing and VAE embedding caching:

```bash
# Prepare buckets and cache VAE embeds
python train_scripts/make_buckets.py \
@@ -235,11 +236,10 @@ bash train_scripts/train_local.sh \
--data.buckets_file=buckets.json \
--train.train_batch_size=30
```

Using the AdamW optimizer, training with a batch size of 30 at 1024x1024 resolution consumes ~48GB of VRAM on an NVIDIA A6000 GPU.
Each training iteration takes ~7.5 seconds.



# 💻 4. Metric toolkit

Refer to [Toolkit Manual](asset/docs/metrics_toolkit.md).
116 changes: 63 additions & 53 deletions train_scripts/make_buckets.py
@@ -1,17 +1,20 @@
import json
import math
import os
import os.path as osp
from itertools import chain

import pyrallis
import torch
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.transforms import InterpolationMode
from tqdm import tqdm

from diffusion.model.builder import get_vae, vae_encode
from diffusion.utils.config import SanaConfig


@pyrallis.wrap()
def main(config: SanaConfig) -> None:
@@ -22,16 +25,16 @@ def main(config: SanaConfig) -> None:
    step = 32

    ratios_array = []
    while min_size != max_size:
        width = int(preferred_pixel_count / min_size)
        if width % step != 0:
            mod = width % step
            if mod < step // 2:
                width -= mod
            else:
                width += step - mod

        ratio = min_size / width

        ratios_array.append((ratio, (int(min_size), width)))
        min_size += step
@@ -43,25 +46,31 @@ def get_closest_ratio(height: float, width: float):

    def get_preffered_size(height: float, width: float):
        pixel_count = height * width

        scale = math.sqrt(pixel_count / preferred_pixel_count)
        return height / scale, width / scale

    class BucketsDataset(torch.utils.data.Dataset):
        def __init__(self, data_dir, skip_files):
            valid_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"}
            self.files = [
                osp.join(data_dir, f)
                for f in os.listdir(data_dir)
                if osp.isfile(osp.join(data_dir, f))
                and osp.splitext(f)[1].lower() in valid_extensions
                and osp.join(data_dir, f) not in skip_files
            ]

            self.transform = T.Compose(
                [
                    T.ToTensor(),
                    T.Normalize([0.5], [0.5]),
                ]
            )

        def __len__(self):
            return len(self.files)

        def __getitem__(self, idx):
            path = self.files[idx]
            img = Image.open(path).convert("RGB")
@@ -70,26 +79,28 @@ def __getitem__(self, idx):

            crop = T.Resize(ratio[1], interpolation=InterpolationMode.BICUBIC)
            return {
                "img": self.transform(crop(img)),
                "size": torch.tensor([ratio[1][0], ratio[1][1]]),
                "prefsize": torch.tensor([prefsize[0], prefsize[1]]),
                "ratio": ratio[0],
                "path": path,
            }

    vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, "cuda").to(torch.float16)

    def encode_images(batch, vae):
        with torch.no_grad():
            z = vae_encode(
                config.vae.vae_type,
                vae,
                batch,
                sample_posterior=config.vae.sample_posterior,  # Adjust as necessary
                device="cuda",
            )
        return z

    if os.path.exists(config.data.buckets_file):
        with open(config.data.buckets_file) as json_file:
            buckets = json.load(json_file)
        existings_images = set(chain.from_iterable(buckets.values()))
    else:
@@ -101,36 +112,35 @@ def add_to_list(key, item):
            buckets[key].append(item)
        else:
            buckets[key] = [item]

    for path in config.data.data_dir:
        print(f"Processing {path}")
        dataset = BucketsDataset(path, existings_images)
        dataloader = DataLoader(dataset, batch_size=1)
        for batch in tqdm(dataloader):
            img = batch["img"]
            size = batch["size"]
            ratio = batch["ratio"]
            image_path = batch["path"]
            prefsize = batch["prefsize"]

            encoded = encode_images(img.to(torch.half), vae)

            for i in range(0, len(encoded)):
                filename_wo_ext = os.path.splitext(os.path.basename(image_path[i]))[0]
                add_to_list(str(ratio[i].item()), image_path[i])

                torch.save(
                    {"img": encoded[i].detach().clone(), "size": size[i], "prefsize": prefsize[i], "ratio": ratio[i]},
                    f"{path}/{filename_wo_ext}_img.npz",
                )

    with open(config.data.buckets_file, "w") as json_file:
        json.dump(buckets, json_file, indent=4)

    for ratio in buckets.keys():
        print(f"{float(ratio):.2f}: {len(buckets[ratio])}")


if __name__ == "__main__":
    main()
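
The caches written by `make_buckets.py` are ordinary `torch.save` files plus a `buckets.json` index mapping each ratio (stored as a string) to the image paths assigned to that bucket. As a point of reference, they can be read back with `torch.load`; the sketch below is not part of this commit and only assumes the file layout produced by the script above (a `<name>_img.npz` file saved next to each source image).

```python
import json
import os.path as osp

import torch

# Assumed path: whatever was passed as --data.buckets_file when running make_buckets.py.
buckets_file = "buckets.json"

with open(buckets_file) as f:
    buckets = json.load(f)  # {ratio (str): [image paths assigned to that bucket]}

# Inspect the first cached entry of one bucket.
ratio, paths = next(iter(buckets.items()))
image_path = paths[0]

# make_buckets.py saves the cache next to the source image as "<name>_img.npz"
# (a torch.save checkpoint despite the .npz extension, so numpy.load would not read it).
cache_path = osp.join(osp.dirname(image_path), osp.splitext(osp.basename(image_path))[0] + "_img.npz")
cached = torch.load(cache_path, map_location="cpu")

latent = cached["img"]          # VAE latent of the bucket-resized image
height, width = cached["size"]  # bucket resolution the image was resized to
print(ratio, tuple(latent.shape), int(height), int(width))
```

How `train_local.sh` consumes these caches is outside this diff, so treat the key names and layout here as read off the script above rather than a stable API.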