-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathsample.py
149 lines (136 loc) · 4.71 KB
/
sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
from dataclasses import dataclass
import time
import dm_pix as pix
import einops
import jax
import jax.numpy as jnp
import numpy as np
from orbax.checkpoint import PyTreeCheckpointer
from PIL import Image, ImageDraw
import tyro
from genie import Genie
from utils.dataloader import get_dataloader
@dataclass
class Args:
    """Command-line options for sampling a video from a Genie checkpoint.

    The script parses this class with ``tyro.cli``. The tokenizer, LAM and
    dynamics fields are forwarded unchanged to the ``Genie`` constructor;
    presumably they have to match the values the checkpoint was trained
    with (TODO confirm against the training script).
    """

    # Experiment
    seed: int = 0  # seed for jax.random.PRNGKey
    seq_len: int = 16  # frames per video, conditioning frames included
    image_channels: int = 3
    image_resolution: int = 64  # used for both height and width
    data_dir: str = "data/coinrun_episodes"  # passed to get_dataloader
    checkpoint: str = ""  # path handed to PyTreeCheckpointer().restore
    # Sampling
    batch_size: int = 1
    maskgit_steps: int = 25  # forwarded to Genie.sample
    temperature: float = 1.0  # forwarded to Genie.sample
    sample_argmax: bool = True  # forwarded to Genie.sample
    start_frame: int = 0  # frames [0, start_frame] are kept as conditioning
    # Tokenizer checkpoint
    tokenizer_dim: int = 512
    latent_patch_dim: int = 32
    num_patch_latents: int = 1024
    patch_size: int = 4
    tokenizer_num_blocks: int = 8
    tokenizer_num_heads: int = 8
    # LAM checkpoint
    lam_dim: int = 512
    latent_action_dim: int = 32
    num_latent_actions: int = 6
    lam_patch_size: int = 16
    lam_num_blocks: int = 8
    lam_num_heads: int = 8
    # Dynamics checkpoint
    dyna_dim: int = 512
    dyna_num_blocks: int = 12
    dyna_num_heads: int = 8
# Parse the command-line flags and seed the PRNG stream that the rest of the
# script keeps splitting from.
args = tyro.cli(Args)
rng = jax.random.PRNGKey(args.seed)

# --- Load Genie checkpoint ---
# Constructor arguments are grouped per sub-model and taken straight from the CLI.
tokenizer_kwargs = dict(
    in_dim=args.image_channels,
    tokenizer_dim=args.tokenizer_dim,
    latent_patch_dim=args.latent_patch_dim,
    num_patch_latents=args.num_patch_latents,
    patch_size=args.patch_size,
    tokenizer_num_blocks=args.tokenizer_num_blocks,
    tokenizer_num_heads=args.tokenizer_num_heads,
)
lam_kwargs = dict(
    lam_dim=args.lam_dim,
    latent_action_dim=args.latent_action_dim,
    num_latent_actions=args.num_latent_actions,
    lam_patch_size=args.lam_patch_size,
    lam_num_blocks=args.lam_num_blocks,
    lam_num_heads=args.lam_num_heads,
)
dynamics_kwargs = dict(
    dyna_dim=args.dyna_dim,
    dyna_num_blocks=args.dyna_num_blocks,
    dyna_num_heads=args.dyna_num_heads,
)
genie = Genie(**tokenizer_kwargs, **lam_kwargs, **dynamics_kwargs)

# Initialise a parameter tree from all-zero placeholder videos; one PRNG key
# goes into the inputs as `mask_rng`, a second one drives `init` itself.
rng, mask_key = jax.random.split(rng)
frame_shape = (args.image_resolution, args.image_resolution, args.image_channels)
placeholder_videos = jnp.zeros(
    (args.batch_size, args.seq_len) + frame_shape, dtype=jnp.float32
)
dummy_inputs = {"videos": placeholder_videos, "mask_rng": mask_key}
rng, init_key = jax.random.split(rng)
params = genie.init(init_key, dummy_inputs)

# Merge the restored weights into the initialised tree in place; entries that
# share a key with the checkpoint are replaced by the checkpoint's value.
restored = PyTreeCheckpointer().restore(args.checkpoint)
ckpt = restored["model"]["params"]["params"]
params["params"].update(ckpt)
# --- Define autoregressive sampling loop ---
def _autoreg_sample(rng, video_batch, action_batch):
    """Generate the frames after the conditioning prefix, one frame at a time.

    Reads the module-level ``args``, ``genie`` and ``params``.

    Args:
        rng: JAX PRNG key; it is split once per generated frame.
        video_batch: Videos indexed as ``[batch, frame, ...]``. Only frames
            ``0 .. args.start_frame`` are read; they seed the generation.
        action_batch: Latent actions indexed as ``[batch, step, ...]``. At
            each step the first ``frame_idx`` actions are passed to the model.

    Returns:
        The conditioning frames followed by every sampled frame, concatenated
        along axis 1. Assuming ``Genie.sample`` returns one frame per call
        (TODO confirm), the result has ``args.seq_len`` frames.
    """
    vid = video_batch[:, : args.start_frame + 1]
    for frame_idx in range(args.start_frame + 1, args.seq_len):
        # --- Sample next frame ---
        print("Frame", frame_idx)
        rng, _rng = jax.random.split(rng)
        batch = dict(videos=vid, latent_actions=action_batch[:, :frame_idx], rng=_rng)
        new_frame = genie.apply(
            params,
            batch,
            args.maskgit_steps,
            args.temperature,
            args.sample_argmax,
            method=Genie.sample,
        )
        # Grow the context so the next step conditions on this frame as well.
        vid = jnp.concatenate([vid, new_frame], axis=1)
    return vid
# --- Get video + latent actions ---
# Only the first batch yielded by the dataloader is used.
dataloader = get_dataloader(args.data_dir, args.seq_len, args.batch_size)
video_batch = next(iter(dataloader))
# Latent actions are encoded from the first video only ...
first_video_actions = genie.apply(
    params, {"videos": video_batch[:1]}, False, method=Genie.vq_encode
)
# ... reshaped to one action per frame transition (seq_len - 1 of them) ...
first_video_actions = first_video_actions.reshape(1, args.seq_len - 1, 1)
# ... and then repeated so every video in the batch shares the same actions.
num_videos = video_batch.shape[0]
action_batch = jnp.repeat(first_video_actions, num_videos, axis=0)
# --- Sample + evaluate video ---
vid = _autoreg_sample(rng, video_batch, action_batch)
# Score only the generated frames. The conditioning frames must be dropped on
# the time axis (axis 1) *before* batch and time are flattened together:
# after the reshape axis 1 is image height, so slicing there would crop pixel
# rows and leave the copied ground-truth frames in the average.
eval_start = args.start_frame + 1
gt = video_batch[:, eval_start : vid.shape[1]].clip(0, 1).reshape(-1, *video_batch.shape[2:])
recon = vid[:, eval_start:].clip(0, 1).reshape(-1, *vid.shape[2:])
ssim = pix.ssim(gt, recon).mean()
print(f"SSIM: {ssim}")
# --- Construct video ---
# Clip to [0, 1] before scaling, as the SSIM evaluation does, so values
# outside that range cannot overflow when cast to uint8.
first_true = (video_batch[0:1].clip(0, 1) * 255).astype(np.uint8)
first_pred = (vid[0:1].clip(0, 1) * 255).astype(np.uint8)
# Slot 0 holds the ground truth of the first video, slot 1 its generation.
first_video_comparison = np.zeros((2, *vid.shape[1:5]), dtype=np.uint8)
first_video_comparison[0] = first_true[:, : vid.shape[1]]
first_video_comparison[1] = first_pred
# For other videos, only show generated video
other_preds = (vid[1:].clip(0, 1) * 255).astype(np.uint8)
all_frames = np.concatenate([first_video_comparison, other_preds], axis=0)
# Tile the n videos side by side along the width, keeping time as the leading axis.
flat_vid = einops.rearrange(all_frames, "n t h w c -> t h (n w) c")
# --- Save video ---
imgs = [Image.fromarray(img) for img in flat_vid]
# Write the first video's action on each frame from the second one onwards:
# frame t is labelled with the action at index t - 1, and frame 0 gets no label.
for img, action in zip(imgs[1:], action_batch[0, :, 0]):
    d = ImageDraw.Draw(img)
    d.text((2, 2), f"{action}", fill=255)
# Save an animated GIF in the working directory; the timestamp in the name
# keeps successive runs from overwriting each other.
imgs[0].save(
    f"generation_{time.time()}.gif",
    save_all=True,
    append_images=imgs[1:],
    duration=250,
    loop=0,
)