diff --git a/experimental/torch_xla2/examples/basic_training_jax.py b/experimental/torch_xla2/examples/basic_training_jax.py index ae6efdf4856..5ca14398fd2 100644 --- a/experimental/torch_xla2/examples/basic_training_jax.py +++ b/experimental/torch_xla2/examples/basic_training_jax.py @@ -3,6 +3,8 @@ https://pytorch.org/tutorials/beginner/introyt/trainingyt.html """ +import functools +from torch_xla2 import train, interop import torch from torch.utils import _pytree as pytree import torchvision @@ -17,6 +19,8 @@ from torch.utils.tensorboard import SummaryWriter from datetime import datetime +env = torch_xla2.enable_globally() + transform = transforms.Compose( [transforms.ToTensor(), @@ -38,29 +42,7 @@ print('Training set has {} instances'.format(len(training_set))) print('Validation set has {} instances'.format(len(validation_set))) -import matplotlib.pyplot as plt import numpy as np - -# Helper function for inline image display -def matplotlib_imshow(img, one_channel=False): - if one_channel: - img = img.mean(dim=0) - img = img / 2 + 0.5 # unnormalize - npimg = img.numpy() - if one_channel: - plt.imshow(npimg, cmap="Greys") - else: - plt.imshow(np.transpose(npimg, (1, 2, 0))) - -dataiter = iter(training_loader) -images, labels = next(dataiter) - -# Create a grid from the images and show them -img_grid = torchvision.utils.make_grid(images) -matplotlib_imshow(img_grid, one_channel=True) -print(' '.join(classes[labels[j]] for j in range(4))) - - import torch.nn as nn import torch.nn.functional as F @@ -83,35 +65,30 @@ def forward(self, x): model = GarmentClassifier() loss_fn = torch.nn.CrossEntropyLoss() -jax_weights, jax_func = torch_xla2.extract_jax(model) -jax_func = jax.jit(jax_func, inline=True) jax_optimizer = optax.adam(0.01) -opt_state = jax_optimizer.init(jax_weights) +model.to('jax') # move the model to jax device +model_jittable = interop.JittableModule(model) +weights = model_jittable.params # these are trainable parameters +buffers = model_jittable.buffers # these are non-trainable parameters -def jax_loss(weights, data, label): - pred = jax_func(weights, data) - loss = torch_xla2.interop.call_torch(loss_fn, pred, label) - return loss +opt_state = interop.call_jax(jax_optimizer.init, weights) +model_fn = functools.partial(model_jittable.functional_call, 'forward') -grad_fn = jax.jit(jax.value_and_grad(jax_loss)) +train_step = train.make_train_step(model_fn, loss_fn, jax_optimizer) +train_step = interop.jax_jit(train_step, kwargs_for_jax_jit={'donate_argnums': (0, 2)}) # NB: Loss functions expect data in batches, so we're creating batches of 4 # Represents the model's confidence in each of the 10 classes for a given input -dummy_outputs = torch.rand(4, 10) +dummy_inputs = torch.rand(4, 28, 28).to('jax') +dummy_outputs = torch.rand(4, 10).to('jax') # Represents the correct class among the 10 being tested -dummy_labels = torch.tensor([1, 5, 3, 7]) - -print(dummy_outputs) -print(dummy_labels) - -loss = loss_fn(dummy_outputs, dummy_labels) -print('Total loss for this batch: {}'.format(loss.item())) - +dummy_labels = torch.tensor([1, 5, 3, 7]).to('jax') -def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): +# test train_step +def train_one_epoch(weights, buffers, opt_state, epoch_index, tb_writer): running_loss = 0. last_loss = 0. 
@@ -119,18 +96,16 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): # iter(training_loader) so that we can track the batch # index and do some intra-epoch reporting for i, data in enumerate(training_loader): - # Every data instance is an input + label pair - # NEW: Move model to XLA device - data = pytree.tree_map_only(torch.Tensor, - torch_xla2.tensor.t2j, data) inputs, labels = data - val, grads = grad_fn(jax_weights, (inputs, ), labels) - updates, opt_state = jax_optimizer.update(grads, opt_state) - jax_weights = optax.apply_updates(jax_weights, updates) + inputs = inputs.to('jax') + labels = labels.to('jax') + + loss, weights, opt_state = train_step( + weights, buffers, opt_state, inputs, labels) # Gather data and report - running_loss += val.item() + running_loss += loss.item() if i % 1000 == 999: last_loss = running_loss / 1000 # loss per batch print(' batch {} loss: {}'.format(i + 1, last_loss)) @@ -138,7 +113,7 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): tb_writer.add_scalar('Loss/train', last_loss, tb_x) running_loss = 0. - return last_loss, opt_state + return last_loss, weights, opt_state @@ -152,39 +127,5 @@ def train_one_epoch(jax_weights, opt_state, epoch_index, tb_writer): for epoch in range(EPOCHS): print('EPOCH {}:'.format(epoch_number + 1)) - # Make sure gradient tracking is on, and do a pass over the data - model.train(True) - - avg_loss, opt_state = train_one_epoch(jax_weights, opt_state, epoch_number, writer) - - running_vloss = 0.0 - # Set the model to evaluation mode, disabling dropout and using population - # statistics for batch normalization. - model.eval() - - # Disable gradient computation and reduce memory consumption. - with torch.no_grad(): - for i, vdata in enumerate(validation_loader): - - vinputs, vlabels = pytree.tree_map_only(torch.Tensor, torch_xla2.tensor.t2j, vdata) - voutputs = jax_func(jax_weights, (vinputs, )) # call model's forward - vloss = torch_xla2.interop.call_torch(loss_fn, voutputs, vlabels) - running_vloss += vloss - - avg_vloss = running_vloss / (i + 1) - print('LOSS train {} valid {}'.format(avg_loss, avg_vloss)) - - # Log the running loss averaged per batch - # for both training and validation - writer.add_scalars('Training vs. 
Validation Loss', - { 'Training' : np.asarray(avg_loss), 'Validation' : np.asarray(avg_vloss) }, - epoch_number + 1) - writer.flush() - - # Track best performance, and save the model's state - if avg_vloss < best_vloss: - best_vloss = avg_vloss - model_path = 'model_{}_{}'.format(timestamp, epoch_number) - torch.save(model.state_dict(), model_path) - - epoch_number += 1 \ No newline at end of file + avg_loss, weights, opt_state = train_one_epoch(weights, buffers, opt_state, epoch_number, writer) + print(avg_loss) diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile b/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile new file mode 100644 index 00000000000..dd7e74024f4 --- /dev/null +++ b/experimental/torch_xla2/examples/train_llama_torchtitan/Dockerfile @@ -0,0 +1,35 @@ +# syntax=docker/dockerfile:experimental +# Use Python 3.10 as the base image +FROM python:3.10-slim-bullseye + +# Install system dependencies +RUN apt-get update && apt-get upgrade -y +RUN apt-get update && apt-get install -y curl gnupg + +# Add the Google Cloud SDK package repository +RUN echo "deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main" | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list +RUN curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add - + +# Install the Google Cloud SDK +RUN apt-get update && apt-get install -y google-cloud-sdk git + +# Set the default Python version to 3.10 +RUN update-alternatives --install /usr/bin/python3 python3 /usr/local/bin/python3.10 1 +RUN pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html +RUN pip install optax fire tensorflow tensorboard-plugin-profile +RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + +WORKDIR / +RUN git clone https://github.com/pytorch/torchtitan.git +WORKDIR /torchtitan +RUN pip install -r requirements.txt +RUN pip install . + +WORKDIR / +RUN git clone https://github.com/pytorch/xla.git +WORKDIR xla/experimental/torch_xla2 +RUN git checkout hanq_hybrid_mesh +RUN pip install -e . 
+ +ENTRYPOINT ["python", "examples/train_llama_torchtitan/train_llama.py"] +CMD ["--batch_size=8", "--seqlen=2048"] \ No newline at end of file diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/README.md b/experimental/torch_xla2/examples/train_llama_torchtitan/README.md new file mode 100644 index 00000000000..9519eaa9dba --- /dev/null +++ b/experimental/torch_xla2/examples/train_llama_torchtitan/README.md @@ -0,0 +1,15 @@ +Training based on torchtitan llama model +==================================== + +```bash +python train_llama.py +``` + + + +## Detailed numbers + +### v5p-8 + +seqlen = 8192 +bs = 8 diff --git a/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py b/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py index b7bdfcc7615..37462c7dd83 100644 --- a/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py +++ b/experimental/torch_xla2/examples/train_llama_torchtitan/train_llama.py @@ -5,8 +5,8 @@ from collections import defaultdict import functools + def _setup_default_env(): - os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '1') os.environ.setdefault('GRPC_VERBOSITY', 'ERROR') os.environ.setdefault('ALLOW_MULTIPLE_LIBTPU_LOAD', '1') # only need for tpu v4 @@ -22,6 +22,7 @@ def _setup_default_env(): import torch_xla2 import torch_xla2.interop +import torch_xla2.train from torch_xla2.interop import jax_view, torch_view, JittableModule import jax import jax.numpy as jnp @@ -34,10 +35,6 @@ def _setup_default_env(): P = jax.sharding.PartitionSpec - - -SEQLEN = 8192 -BATCH = 8 global_axis: Tuple[str, str] = ('fsdp', ) num_global_devices = jax.device_count() num_local_devices = jax.local_device_count() @@ -56,170 +53,82 @@ def sharded_device_put(tensor, sharding): return jax.make_array_from_single_device_arrays(shape, sharding, x_split) -class FSDPv2(torch.nn.Module): - - def __init__(self, mod): - super().__init__() - self.mod = mod - self.mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(num_partitions), - axis_names=global_axis, - ) - self.sharding = jax.sharding.NamedSharding(self.mesh, P(*global_axis)) - - def forward(self, *args): - args = list(args) - args[0] = self.shard(args[0]) - res = self.mod(*args) - return self.shard(res) - - def shard(self, x): - return torch_xla2.interop.call_jax( - jax.lax.with_sharding_constraint, - x, - self.sharding, - ) - -def print_shapes(pyt): - for p in pytree.tree_flatten(pyt)[0]: - if hasattr(p, 'shape'): - print(p.shape, p.dtype) - - -class Module(torch.nn.Module): - - def __init__(self, inner): - super().__init__() - self.inner = FSDPv2(inner) - - def training_step(self, data, batch_id): - x, y = data - logits = self.inner(x) - num_tokens = logits.shape[-1] - logits = logits.reshape(-1, num_tokens) - y = y.reshape(-1) - return torch.nn.functional.cross_entropy( - logits, y) - - class Trainer: - def __init__(self): - self.mesh = jax.sharding.Mesh( - mesh_utils.create_device_mesh(num_partitions), - axis_names=global_axis, - ) + def __init__(self, mesh): + self.mesh = mesh self.x_sharding = jax.sharding.NamedSharding(self.mesh, P(global_axis)) self.replicated = jax.sharding.NamedSharding(self.mesh, P()) - def _shard_fsdp_style(self, state_dict, sharding=None): - if sharding is None: - sharding = self.x_sharding - def move_one_tensor(x): - jval = torch_xla2.tensor.t2j(x) - return sharded_device_put(jval, sharding) - - if isinstance(state_dict, torch.Tensor): - return move_one_tensor(state_dict) - res = {} - for k, v in sorted(state_dict.items()): - res[k] = 
move_one_tensor(v) - return res - - def fit(self, lightning_mod, data_loader): + def fit(self, model, loss_fn, data_loader): xla_env = torch_xla2.default_env() jax.config.update('jax_enable_x64', False) xla_env._mesh = self.mesh xla_env.use_flash_attention = True - jittable_mod = JittableModule(lightning_mod) - jax_params = self._shard_fsdp_style(jittable_mod.params) - jax_buffers = self._shard_fsdp_style(jittable_mod.buffers) - - @jax.checkpoint - def lightning_mod_loss( - weights: jax.Array, buffers: jax.Array, data: jax.Array, batch_id): - """returns loss""" - with jax.named_scope("Computing_loss"): - weights, buffers, data = torch_view((weights, buffers, data)) - # NOTE: these is needed because the original model - # did not register those as persistent buffer - with xla_env: - loss = jittable_mod.functional_call( - 'training_step', - weights, buffers, data, batch_id) - return jax_view(loss) - - jax_optimizer = optax.adamw(0.001) - - opt_state = jax_optimizer.init(jax_params) - grad_fn = jax.value_and_grad(lightning_mod_loss) + + model.to('jax') + jittable_mod = JittableModule(model) - opt_state_sharding = jax.tree_util.tree_map(lambda p : p.sharding, opt_state) + # Note: the params are already sharded across devices (see create_sharded_weights) - print('Begining training') + def model_fn(weights, buffers, args): + return jittable_mod.functional_call('forward', weights, buffers, args) - @functools.partial( - jax.jit, - donate_argnums=(0, 2), - ) - def step(jax_weights, jax_buffers, optimizer_state, xla_data, bid): - print('Tracing inside of step') - with jax.named_scope("Computing_loss_and_grad"): - loss, grads = grad_fn(jax_weights, jax_buffers, xla_data, bid) - with jax.named_scope("optimizer_updates"): - updates, opt_state = jax_optimizer.update( - grads, optimizer_state, jax_weights) - jax_weights = optax.apply_updates(jax_weights, updates) - return loss, jax_weights, opt_state - - total_param_size = 0 - for k, v in jax_params.items(): - total_param_size += v.size - - print('Total number of params: ', total_param_size) - - print('Start compiling') - start = time.perf_counter() - lowered = step.lower( - jax_params, jax_buffers, opt_state, - (jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding), - jax.ShapeDtypeStruct((BATCH, SEQLEN), jnp.dtype('int32'), sharding=self.x_sharding)), - 0 - ) - # print(lowered.as_text()) - print('program size:', len(lowered.as_text()) / 1e6, 'm chars') - step_compiled = lowered.compile() - end = time.perf_counter() - compile_time = end - start - print('End compiling', compile_time) + jax_optimizer = optax.adamw(0.001) + opt_state = torch_xla2.interop.call_jax(jax_optimizer.init, jittable_mod.params) - for co in step_compiled.cost_analysis(): - print('flops counter:', co['flops']) + train_step = torch_xla2.train.make_train_step( + model_fn, loss_fn, jax_optimizer, + remat_policy=jax.checkpoint_policies.nothing_saveable, + mark_fsdp_sharding_axis='fsdp') + print('Beginning training') s = time.perf_counter() jax.profiler.start_trace('/tmp/tensorboard') print('start training') min_loop_time = 10000 for i, item in enumerate(data_loader): - inputs, labels = sharded_device_put(jax_view(xla_env.to_xla(item)), - self.x_sharding) - print('INPUT shape', inputs.shape) + inputs, labels = item + # Move them to the jax device + inputs = inputs.to('jax') + labels = labels.to('jax') + + # Shard them on the batch dim for fsdp + inputs.apply_(sharded_device_put, self.x_sharding) + labels.apply_(sharded_device_put, self.x_sharding) + print('INPUT shape', inputs.shape) step_start =
time.perf_counter() - loss, jax_params, opt_state = step_compiled( - jax_params, jax_buffers, opt_state, (inputs, labels), 0) - jax.block_until_ready((loss, jax_params)) + loss, jittable_mod.params, opt_state = train_step( + jittable_mod.params, jittable_mod.buffers, opt_state, inputs, labels) + # wait for the iteration to finish so we can measure step time + jax.block_until_ready((loss, jittable_mod.params)) step_end = time.perf_counter() print(i, 'loss', loss, 'step latency: ', step_end - step_start) loop_time = step_end - step_start min_loop_time = min(min_loop_time, loop_time) print('======') - if i >= 2: + if i >= 3: break jax.profiler.stop_trace() - return min_loop_time, compile_time + return min_loop_time + +def create_sharded_weights(state_dict, sharding): + res = {} + env = torch_xla2.default_env() + for name, weight_meta in state_dict.items(): + with jax.default_device(jax.devices('cpu')[0]): + weight_torch = torch.randn( + weight_meta.shape, + dtype=weight_meta.dtype) + # weight_jax is a jax array + weight_jax = env.to_xla(weight_torch).jax() + res[name] = env.j2t_iso(jax.make_array_from_callback( + weight_jax.shape, sharding, lambda a: weight_jax[a] + )) + return res def fake_dataloader(size, seqlen, batch_size): @@ -232,33 +141,51 @@ def main( model_type='8B', batch_size=8, seqlen=2048, - mode='regular', + override_num_layers=-1, ): - logging.getLogger("jax").setLevel(logging.DEBUG) + torch_xla2.enable_globally() + #logging.getLogger("jax").setLevel(logging.DEBUG) print(f"Running with parameters {locals()}") - global SEQLEN - global BATCH - SEQLEN = seqlen - BATCH = batch_size mesh = jax.make_mesh((len(jax.local_devices()), ), ('fsdp', )) + sharding = jax.sharding.NamedSharding(mesh, P('fsdp')) + env = torch_xla2.default_env() - env.config.use_tpu_flash_attention = use_flash_attention - env.config.shmap_flash_attention = use_flash_attention + env.config.use_tpu_flash_attention = True + env.config.shmap_flash_attention = True + env._mesh = mesh # this is the mesh used by the flash attention pallas kernel args = llama3_configs[model_type] - #with torch.device('meta'): - gpt = titan.Transformer(args) - - light_mod = Module(gpt) - light_mod.to(torch.bfloat16) - + # Note: torchtitan's upstream config did not specify this value + args.vocab_size = 128256 + if override_num_layers > 0: + args.n_layers = override_num_layers + + # Note: because a single device doesn't have enough HBM memory + # nor enough CPU memory to hold the parameters.
We instantiate + # the model on the meta device, then manually initialize and shard each parameter. + torch.set_default_dtype(torch.bfloat16) + with torch.device('meta'): + gpt = titan.Transformer(args) + gpt.to(torch.bfloat16) + + state_dict = create_sharded_weights(gpt.state_dict(), sharding) + gpt.load_state_dict(state_dict, assign=True) + train_loader = fake_dataloader(10, seqlen, batch_size) + def loss_fn(logits, y): + num_tokens = logits.shape[-1] + logits = logits.reshape(-1, num_tokens) + y = y.reshape(-1) + return torch.nn.functional.cross_entropy( + logits, y) + with mesh: trainer = Trainer() return trainer.fit( - light_mod, + gpt, + loss_fn, train_loader ) diff --git a/experimental/torch_xla2/torch_xla2/interop.py b/experimental/torch_xla2/torch_xla2/interop.py index d75c450d0ed..42aafaef776 100644 --- a/experimental/torch_xla2/torch_xla2/interop.py +++ b/experimental/torch_xla2/torch_xla2/interop.py @@ -198,3 +198,8 @@ def jax_jit(torch_function, kwargs_for_jax_jit=None): def jax_shard_map(torch_function, kwargs_for_jax_shard_map=None): return wrap_jax_jit(torch_function, jax_jit_func=shard_map, kwargs_for_jax=kwargs_for_jax_shard_map) + + +def jax_value_and_grad(torch_function, kwargs_for_value_and_grad=None): + return wrap_jax_jit(torch_function, jax_jit_func=jax.value_and_grad, + kwargs_for_jax=kwargs_for_value_and_grad) diff --git a/experimental/torch_xla2/torch_xla2/tensor.py b/experimental/torch_xla2/torch_xla2/tensor.py index 35d69eb7326..ec609d93ac7 100644 --- a/experimental/torch_xla2/torch_xla2/tensor.py +++ b/experimental/torch_xla2/torch_xla2/tensor.py @@ -155,6 +155,14 @@ def device(self): def jax_device(self): return self._elem.device + def apply(self, jax_function, *args, **kwargs): + # Call a jax function on _elem and wrap the result back into an env tensor + res = jax_function(self._elem, *args, **kwargs) + return self._env.j2t_iso(res) + + def apply_(self, jax_function, *args, **kwargs): + # In-place variant: replace _elem with the jax function's result + self._elem = jax_function(self._elem, *args, **kwargs) + def tolist(self): return self._elem.tolist() @@ -294,7 +302,7 @@ def get_as_jax_device(self, device: Any): if not self.config.treat_cuda_as_jax_device and device.startswith('cuda'): return None - if device in ('jax_cpu', 'cpu'): + if device == 'cpu': return jax.devices('cpu')[0] return jax.devices()[0] diff --git a/experimental/torch_xla2/torch_xla2/train.py b/experimental/torch_xla2/torch_xla2/train.py new file mode 100644 index 00000000000..7dd378795a2 --- /dev/null +++ b/experimental/torch_xla2/torch_xla2/train.py @@ -0,0 +1,67 @@ +import functools +import torch +import jax +import torch_xla2 +from torch_xla2 import interop +from torch_xla2.interop import torch_view, jax_view +import optax + + +remat = torch_view(jax.remat) +mark_sharding = torch_view(jax.lax.with_sharding_constraint) + + +def make_train_step(model_fn, + loss_fn, optax_optimizer, + remat_policy=None, + mark_fsdp_sharding_axis=None): + """Make a function that does one train step given a model and a loss function. + + model_fn: a function representing the model's forward: + i.e. has signature Callable[weights, buffers, args] -> result, where + weights is a pytree of trainable parameters + buffers is a pytree of non-trainable parameters / constants + args is the input data loaded from the dataset + result is the return value of the model + loss_fn: a function to compute loss. + i.e. it has the signature Callable[result, label] -> loss + where result is what model_fn returned and + label is loaded from the dataloader. + optax_optimizer: the optimizer from the optax library.
For example, optax.adam. + remat_policy: one of jax.ad_checkpoint.checkpoint_policies; specifies how + to do gradient checkpointing. If None, no checkpointing is done. + mark_fsdp_sharding_axis: str. The name of the mesh axis to shard over for + FSDP. It must be an axis that exists in the current mesh. + If None, no sharding is specified (i.e. for a single device). + """ + env = torch_xla2.default_env() + @functools.partial( + remat, + policy=remat_policy) + def loss(weights, buffers, args, label): # inputs are XLATensors + with env, jax.named_scope('compute_loss'): + if mark_fsdp_sharding_axis is not None: + args = mark_sharding( + args, + jax.sharding.PartitionSpec(mark_fsdp_sharding_axis)) + res = model_fn(weights, buffers, args) + if mark_fsdp_sharding_axis is not None: + res = mark_sharding(res, jax.sharding.PartitionSpec(mark_fsdp_sharding_axis)) + label = mark_sharding(label, jax.sharding.PartitionSpec(mark_fsdp_sharding_axis)) + l = loss_fn(res, label) + return l + + grad_fn = interop.jax_value_and_grad(loss) + + def step(weights, buffers, opt_state, args, label): # inputs are jax arrays + with jax.named_scope('compute_gradient'): + loss, gradient = grad_fn(weights, buffers, args, label) + + with jax.named_scope("optimizer_updates"): + updates, opt_state = interop.call_jax( + optax_optimizer.update, + gradient, opt_state, weights) + weights = interop.call_jax(optax.apply_updates, weights, updates) + return loss, weights, opt_state + + return interop.jax_jit(step, {'donate_argnums': (0, 2)}) \ No newline at end of file
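For reference, the new training-loop API introduced above (`interop.JittableModule` plus `torch_xla2.train.make_train_step`) can be exercised end to end as in the sketch below. This is a minimal sketch mirroring the updated `examples/basic_training_jax.py`; the tiny two-layer model and the random batch are illustrative placeholders, not part of this change.

```python
# Minimal sketch of the new torch_xla2 training-loop API. The model and the
# synthetic batch are stand-ins; real usage would plug in a dataset and a
# full model, as examples/basic_training_jax.py does.
import functools

import optax
import torch
import torch_xla2
from torch_xla2 import train, interop

torch_xla2.enable_globally()

model = torch.nn.Sequential(
    torch.nn.Flatten(),
    torch.nn.Linear(28 * 28, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
)
loss_fn = torch.nn.CrossEntropyLoss()

model.to('jax')                      # move the weights to the jax device
jittable = interop.JittableModule(model)
weights = jittable.params            # trainable parameters
buffers = jittable.buffers           # non-trainable parameters

optimizer = optax.adam(0.01)
opt_state = interop.call_jax(optimizer.init, weights)

model_fn = functools.partial(jittable.functional_call, 'forward')
train_step = train.make_train_step(model_fn, loss_fn, optimizer)

# One optimization step on a synthetic batch of four 28x28 "images".
inputs = torch.rand(4, 28, 28).to('jax')
labels = torch.randint(0, 10, (4,)).to('jax')
loss, weights, opt_state = train_step(weights, buffers, opt_state, inputs, labels)
print('loss:', loss.item())
```

Note that `make_train_step` already jits the returned step function with `donate_argnums=(0, 2)`, so the weight and optimizer-state buffers are donated between iterations.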