
Commit

[Example] Add comments to example codes (#36)
In this PR, we add comments explaining VeScale APIs in the nanoGPT
example.
lichen225 authored May 24, 2024
1 parent 2a072bf commit 55c7f8a
Showing 1 changed file with 21 additions and 2 deletions.
23 changes: 21 additions & 2 deletions examples/nanogpt_4D_finetune/finetune_4D.py
@@ -115,9 +115,10 @@ def main():
device = f"cuda:{rank}"
torch.cuda.set_device(device)
init_process_group(backend=backend, world_size=world_size, rank=rank)

# + + + VeScale API below
VESCALE_DEVICE_MESH.init_device_mesh(device, (dp_size, tp_size), mesh_dim_names=["DP", "TP"])
mesh = VESCALE_DEVICE_MESH.get()
# + + + VeScale API above
ddp_rank = get_rank() // tp_size
else:
rank = 0
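
An editor's note on the mesh init above: init_device_mesh arranges the job's ranks into a 2-D grid with named dimensions ("DP", "TP"), which is why the DP coordinate falls out of integer division in ddp_rank = get_rank() // tp_size. A minimal sketch of the layout, with sizes chosen for illustration (world_size = 4, dp_size = tp_size = 2 are assumptions, not values from the diff):

dp_size, tp_size = 2, 2
world_size = dp_size * tp_size
# Row-major layout: row index = DP coordinate, column index = TP coordinate.
mesh = [[dp * tp_size + tp for tp in range(tp_size)] for dp in range(dp_size)]
# mesh == [[0, 1], [2, 3]]
for rank in range(world_size):
    print(f"rank {rank} -> DP {rank // tp_size}, TP {rank % tp_size}")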
@@ -137,7 +138,9 @@ def main():
if master_process:
os.makedirs(out_dir, exist_ok=True)
torch.manual_seed(1337)
# + + + VeScale API below
manual_seed(1337, mesh)
# + + + VeScale API above
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
device_type = "cuda" if "cuda" in device else "cpu" # for later use in torch.autocast
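
The second seeding call above is veScale's mesh-aware counterpart of torch.manual_seed. An editor's gloss; the import path below is an assumption based on veScale's package layout (it mirrors torch.distributed._tensor.random) and is not shown in the diff:

# Assumed import (not part of the diff):
# from vescale.dtensor.random import manual_seed
#
# torch.manual_seed(1337) seeds only the local process. manual_seed(1337, mesh)
# coordinates RNG state across the device mesh, so ranks that must draw identical
# random numbers (e.g. for tensors replicated over TP) stay in sync.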
@@ -147,7 +150,13 @@ def main():
# poor man's data loader
data_dir = os.path.join("data", dataset)

# + + + Support a larger batch size when running evaluation; only the master process does the random sampling
"""
Deterministic data loader for loss matching:
This data loader ensures that mini-batch sampling behaves identically no matter how many GPUs are used.
In particular, at each training iteration, every rank samples a batch of indices under the same RNG state.
Each Data Parallel (DP) rank then takes its own subset of those indices and fetches the corresponding sequences from the dataset.
"""

def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
# We recreate np.memmap every batch to avoid a memory leak, as per
# https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
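
The scheme the docstring above describes can be sketched as follows. This is an editor's illustration with an invented helper name (sample_shard_indices), not the elided body of get_batch:

import torch

def sample_shard_indices(data_len, block_size, bsz, lbsz, ddp_rank):
    # Every rank draws the SAME global batch of start indices, because the RNG
    # state was synchronized across ranks (see manual_seed above).
    ix = torch.randint(data_len - block_size, (bsz,))
    # Each DP rank then keeps only its own contiguous slice of the global batch.
    return ix[ddp_rank * lbsz : (ddp_rank + 1) * lbsz]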
@@ -166,9 +175,11 @@ def get_batch(split, bsz=batch_size, lbsz=local_batch_size):
x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
else:
x, y = x.to(device), y.to(device)
# + + + VeScale API below
if ddp:
x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Replicate()])
y = distribute_tensor(y, VESCALE_DEVICE_MESH["TP"], [Replicate()])
# + + + VeScale API above
return x, y
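
For readers new to DTensor placements: distribute_tensor with [Replicate()] turns the local torch.Tensor into a DTensor that every rank of the "TP" submesh holds in full, so all tensor-parallel workers see the same inputs. A hypothetical contrast, reusing the names above; Shard is assumed to be available alongside Replicate, as in PyTorch's DTensor:

# Replicate(): every TP rank holds the full tensor -- what the diff does for x and y.
x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Replicate()])
# Shard(0) would instead split dim 0 across TP ranks (hypothetical alternative):
# x = distribute_tensor(x, VESCALE_DEVICE_MESH["TP"], [Shard(0)])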

# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
@@ -335,9 +346,11 @@ def get_lr(it):
wandb.init(project=wandb_project, name=wandb_run_name, config=config)

# Load checkpoint
# + + + VeScale API below
if load_checkpoint_path:
checkpoint_state = {"model": model, "optimizer": optimizer}
vescale.checkpoint.load(load_checkpoint_path, checkpoint_state)
# + + + VeScale API above
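
An editor's gloss on the load call: vescale.checkpoint.load restores state in place into the live model and optimizer passed through the dict (unlike torch.load, nothing is returned to re-assign). The same dict shape is reused by the save call later in this diff; see the resume sketch there.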
# training loop
X, Y = get_batch("train") # fetch the very first batch
t0 = time.time()
@@ -369,14 +382,18 @@ def get_lr(it):
if iter_num > 0:
# When iter_num == 0, training has not started, so the optimizer state is empty;
# don't save a checkpoint
# + + + VeScale API below
checkpoint_state = {"model": model, "optimizer": optimizer}
vescale.checkpoint.save(os.path.join(save_checkpoint_path, f"iter_{iter_num}"), checkpoint_state)
# + + + VeScale API above
if iter_num == 0 and eval_only:
break
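
Because each save above lands in its own per-iteration directory, resuming later just points the loader at the matching directory. A small sketch with a hypothetical iteration number, reusing the checkpoint dict shape from the diff:

# Hypothetical resume from step 2000 (directory produced by the save call above):
resume_dir = os.path.join(save_checkpoint_path, "iter_2000")
checkpoint_state = {"model": model, "optimizer": optimizer}
vescale.checkpoint.load(resume_dir, checkpoint_state)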

# forward backward update, with optional gradient accumulation to simulate larger batch size
# + + + VeScale API below
if ddp:
model.zero_grad_buffer()
# + + + VeScale API above
for micro_step in range(gradient_accumulation_steps):
# with ctx:
logits, loss = model(X, Y)
Expand All @@ -385,8 +402,10 @@ def get_lr(it):
X, Y = get_batch("train")
# backward pass
loss.backward()
# + + + VeScale API below
if ddp:
model.finish_grad_sync()
# + + + VeScale API above
optimizer.step()
# flush the gradients as soon as we can, no need for this memory anymore
optimizer.zero_grad(set_to_none=True)
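
Taken together, the veScale DDP hooks in this hunk bracket the standard gradient-accumulation loop. An editor's condensed restatement, with indentation inferred and loss scaling/prefetch elided:

model.zero_grad_buffer()                 # clear the flat gradient buffer before accumulating
for micro_step in range(gradient_accumulation_steps):
    logits, loss = model(X, Y)           # forward on one micro-batch
    loss.backward()                      # gradients accumulate into the DDP buffer
model.finish_grad_sync()                 # complete cross-DP-rank gradient reduction
optimizer.step()                         # apply the synchronized gradients
optimizer.zero_grad(set_to_none=True)    # release gradient memory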
