checkpoint on v6e
qihqi committed Dec 20, 2024
1 parent da24e94 commit 6891f29
Showing 5 changed files with 53 additions and 27 deletions.
23 changes: 23 additions & 0 deletions experimental/torch_xla2/examples/train_llama_torchtitan/README.md
@@ -1,7 +1,30 @@
Training based on torchtitan llama model
====================================

## Install dependencies

```bash
pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
pip install optax fire tensorflow tensorboard-plugin-profile
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu

cd ~
git clone https://github.com/pytorch/torchtitan.git
cd torchtitan
pip install -r requirements.txt
pip install .

cd ~
git clone https://github.com/pytorch/xla.git
cd xla/experimental/torch_xla2
pip install -e .
```

## Run the train script

```bash
export LIBTPU_INIT_ARGS="--xla_tpu_use_minor_sharding_for_major_trivial_input=true --xla_tpu_relayout_group_size_threshold_for_reduce_scatter=1 --xla_tpu_scoped_vmem_limit_kib=98304 --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"

python train_llama.py
```

43 changes: 22 additions & 21 deletions experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py
@@ -4,18 +4,6 @@
from typing import Tuple
from collections import defaultdict
import functools


def _setup_default_env():
os.environ.setdefault('GRPC_VERBOSITY', 'ERROR')
os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1')
# only need for tpu v4
# os.environ.setdefault('TPU_MEGACORE', 'megacore_dense')
tpu_args = "--xla_tpu_enable_async_collective_fusion_fuse_all_gather=true --xla_tpu_megacore_fusion_allow_ags=false --xla_enable_async_collective_permute=true --xla_tpu_enable_ag_backward_pipelining=true --xla_tpu_enable_data_parallel_all_reduce_opt=true --xla_tpu_data_parallel_opt_different_sized_ops=true --xla_tpu_enable_async_collective_fusion=true --xla_tpu_enable_async_collective_fusion_multiple_steps=true --xla_tpu_overlap_compute_collective_tc=true --xla_enable_async_all_gather=true"
os.environ.setdefault('LIBTPU_INIT_ARGS', tpu_args)

_setup_default_env()

import torch
import torch.nn.functional
from torch.utils import _pytree as pytree
@@ -35,29 +23,31 @@ def _setup_default_env():

P = jax.sharding.PartitionSpec

global_axis: Tuple[str, str] = ('fsdp', )
num_global_devices = jax.device_count()
num_local_devices = jax.local_device_count()
num_partitions = (num_global_devices, )


def sharded_device_put(tensor, sharding):
def sharded_device_put(tensor: jax.Array, sharding) -> jax.Array:
if isinstance(tensor, tuple):
return tuple(sharded_device_put(t, sharding) for t in tensor)

if num_global_devices == num_local_devices:
return jax.device_put(tensor, sharding)

# NOTE: here, num_global_devices != num_local_devices,
# meaning we are in a multi-host setup. Each host runs the same process,
# and each process only needs to handle the devices accessible to it.
shape = tensor.shape
x_split = [jax.device_put(tensor[i], device) for device, i in sharding.addressable_devices_indices_map(shape).items()]
x_split = [jax.device_put(tensor[i], device)
for device, i in sharding.addressable_devices_indices_map(shape).items()]
return jax.make_array_from_single_device_arrays(shape, sharding, x_split)


class Trainer:

def __init__(self, mesh):
self.mesh = mesh
self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(global_axis))
self.x_sharding = jax.sharding.NamedSharding(self.mesh, P('fsdp'))
self.replicated = jax.sharding.NamedSharding(self.mesh, P())

def fit(self, model, loss_fn, data_loader):
@@ -78,7 +68,7 @@ def model_fn(weights, buffers, args):

train_step = torch_xla2.train.make_train_step(
model_fn, loss_fn, jax_optimizer,
remat_policy=jax.checkpoint_policies.nothing_saveable,
remat_policy=jax.checkpoint_policies.offload_dot_with_no_batch_dims('device', 'pinned_host'),
mark_fsdp_sharding_axis='fsdp')

print('Beginning training')
@@ -93,8 +83,8 @@ def model_fn(weights, buffers, args):
labels = labels.to('jax')

# Shard them on batch dim for fsdp
inputs.apply_(sharded_device_put, self.x_sharding)
labels.apply_(sharded_device_put, self.x_sharding)
inputs.apply_jax_(sharded_device_put, self.x_sharding)
labels.apply_jax_(sharded_device_put, self.x_sharding)

print('INPUT shape', inputs.shape)
step_start = time.perf_counter()
@@ -147,10 +137,11 @@ def main(
use_scan = True,
):
torch_xla2.enable_globally()
torch_xla2.enable_performance_mode()
#logging.getLogger("jax").setLevel(logging.DEBUG)
print(f"Running with parameters {locals()}")

mesh = jax.make_mesh((len(jax.local_devices()), ), ('fsdp', ))
mesh = jax.make_mesh((num_global_devices, ), ('fsdp', ))
if use_scan:
# when using scan, the individual weights will have shape (num_layers, w, h)
sharding = jax.sharding.NamedSharding(mesh, P(None, 'fsdp'))
@@ -165,6 +156,7 @@
args = llama3_configs[model_type]
# Note: torchtitan's upstream config did not specify this value
args.vocab_size = 128256
args.max_seq_len = seqlen
if override_num_layers > 0:
args.n_layers = override_num_layers

@@ -174,11 +166,20 @@
torch.set_default_dtype(torch.bfloat16)
with torch.device('meta'):
gpt = titan.Transformer(args)

with torch.device('cpu'):
# need actual value for freqs_cis
freqs_cis = gpt._precompute_freqs_cis()

if use_scan:
gpt = TransfomerWithScan(gpt)

state_dict = dict(gpt.state_dict())
state_dict.pop('freqs_cis') # don't shard freqs_cis
state_dict = create_sharded_weights(gpt.state_dict(), sharding)
replicated = jax.sharding.NamedSharding(mesh, P())

state_dict['freqs_cis'] = freqs_cis.to('jax').apply_jax(jax.device_put, replicated)
gpt.load_state_dict(state_dict, assign=True)

train_loader = fake_dataloader(10, seqlen, batch_size)
4 changes: 4 additions & 0 deletions experimental/torch_xla2/torch_xla2/interop.py
@@ -203,3 +203,7 @@ def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None):
def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None):
return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad,
kwargs_for_jax=kwargs_for_value_and_grad)

def gradient_checkpoint(torch_function, kwargs=None):
return wrap_jax_jit(torch_function, jax_jit_func=jax.checkpoint,
kwargs_for_jax=kwargs)
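
The new `gradient_checkpoint` helper is a thin wrapper: it routes a torch-level function through `jax.checkpoint`, forwarding `kwargs` (for example a `policy`). A minimal JAX-only sketch of the mechanism it delegates to, with illustrative names and shapes rather than code from this repository:

```python
# Sketch of what the wrapper delegates to: jax.checkpoint marks a function
# for rematerialization, and `policy` decides which intermediates are saved.
import jax
import jax.numpy as jnp

def f(w, x):
    return jnp.sum(jnp.tanh(x @ w))

# Save nothing: intermediates are recomputed during the backward pass,
# trading extra compute for lower activation memory.
f_remat = jax.checkpoint(f, policy=jax.checkpoint_policies.nothing_saveable)
print(jax.grad(f_remat)(jnp.ones((8, 8)), jnp.ones((4, 8))).shape)  # (8, 8)
```
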
4 changes: 2 additions & 2 deletions experimental/torch_xla2/torch_xla2/tensor.py
@@ -155,12 +155,12 @@ def device(self):
def jax_device(self):
return self._elem.device

def apply(self, jax_function, *args, **kwargs):
def apply_jax(self, jax_function, *args, **kwargs):
# Call a jax function on _elem
res = jax_function(self._elem, *args, **kwargs)
return self._env.j2t_iso(res)

def apply_(self, jax_function, *args, **kwargs):
def apply_jax_(self, jax_function, *args, **kwargs):
self._elem = jax_function(self._elem, *args, **kwargs)

def tolist(self):
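
For reference, the renamed methods behave as before: `apply_jax` runs a JAX function on the wrapped array and returns a new tensor, while `apply_jax_` mutates the wrapped array in place. A rough usage sketch, assuming a torch_xla2 environment with available devices; the mesh and sharding here are illustrative:

```python
# Rough usage sketch; assumes torch_xla2 is installed and JAX devices exist.
import jax
import torch
import torch_xla2

torch_xla2.enable_globally()
mesh = jax.make_mesh((jax.device_count(),), ('fsdp',))
replicated = jax.sharding.NamedSharding(mesh, jax.sharding.PartitionSpec())

t = torch.ones(8, 128).to('jax')
t2 = t.apply_jax(jax.device_put, replicated)   # returns a new wrapped tensor
t.apply_jax_(jax.device_put, replicated)       # updates t in place
```
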
6 changes: 2 additions & 4 deletions experimental/torch_xla2/torch_xla2/train.py
@@ -30,15 +30,12 @@ def make_train_step(model_fn,
loss is loaded from the dataloader.
optax_optimizer: the optimizer from the optax library, for example optax.adam.
remat_policy: One of jax.ad_checkpoint.checkpoint_policies, specifies how
to do gradient checkpointing. If None, then it means no checkpointing.
to do gradient checkpointing. If None, then it means checkpoint everything.
mark_fsdp_sharding_axis: str. A string name for marking sharding for
fsdp. It must be an axis that exists in the current mesh.
If None, then no sharding is specified (i.e. for a single device).
"""
env = torch_xla2.default_env()
@functools.partial(
remat,
policy=remat_policy)
def loss(weights, buffers, args, label): # inputs are XLATensor
with env, jax.named_scope('compute_loss'):
if mark_fsdp_sharding_axis is not None:
@@ -52,6 +49,7 @@ def loss(weights, buffers, args, label): # inputs are XLATensor
l = loss_fn(res, label)
return l

loss = interop.gradient_checkpoint(loss, kwargs={'policy': remat_policy})
grad_fn = interop.jax_value_and_grad(loss)

def step(weights, buffers, opt_state, args, label): #inputs are array
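
The change above (wrapping `loss` with `interop.gradient_checkpoint` before `interop.jax_value_and_grad`) mirrors the plain JAX pattern of checkpointing a loss function and then differentiating it. A standalone sketch of that pattern with made-up shapes, not this module's code:

```python
# Standalone illustration of the checkpoint-then-differentiate pattern.
import jax
import jax.numpy as jnp

def loss(w, x, y):
    pred = jnp.tanh(x @ w)
    return jnp.mean((pred - y) ** 2)

loss = jax.checkpoint(loss, policy=jax.checkpoint_policies.nothing_saveable)
value_and_grad = jax.value_and_grad(loss)

l, g = value_and_grad(jnp.zeros((8, 2)), jnp.ones((4, 8)), jnp.ones((4, 2)))
print(l, g.shape)  # scalar loss and gradient with the shape of w: (8, 2)
```
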
