From f1362fab195a458290081967fbea2bbcba9114e0 Mon Sep 17 00:00:00 2001
From: shiyang-chen
Date: Wed, 20 Nov 2024 10:57:33 -0800
Subject: [PATCH] Add a new version called V3_TIKTOKEN; other edits based on
 suggestions.

---
 ...fuji-1B-v3-tiktoken-flash-single-host.txt} | 0
 ...1B-v3-tiktoken-flash-single-host_init.txt} | 0
 ...iktoken-flash-single-host_regularizer.txt} | 0
 ...oken.txt => fuji-1B-v3-tiktoken-flash.txt} | 0
 ...txt => fuji-1B-v3-tiktoken-flash_init.txt} | 0
 ...fuji-1B-v3-tiktoken-flash_regularizer.txt} | 0
 ...fuji-3B-v3-tiktoken-flash-single-host.txt} | 0
 ...3B-v3-tiktoken-flash-single-host_init.txt} | 0
 ...iktoken-flash-single-host_regularizer.txt} | 0
 ...oken.txt => fuji-3B-v3-tiktoken-flash.txt} | 0
 ...txt => fuji-3B-v3-tiktoken-flash_init.txt} | 0
 ...fuji-3B-v3-tiktoken-flash_regularizer.txt} | 0
 .../fuji-70B-v1-flash-single-host.txt | 311 ---------------
 .../fuji-70B-v1-flash-single-host_init.txt | 10 -
 .../fuji-70B-v1-single-host.txt | 276 --------------
 .../fuji-70B-v1-single-host_init.txt | 10 -
 .../fuji-70B-v2-flash-single-host.txt | 312 ---------------
 .../fuji-70B-v2-flash-single-host_init.txt | 10 -
 .../fuji-70B-v2-single-host.txt | 277 --------------
 .../fuji-70B-v2-single-host_init.txt | 10 -
 .../fuji-70B-v3-flash-single-host.txt | 312 ---------------
 ...i-70B-v3-flash-single-host_regularizer.txt | 11 -
 .../fuji-70B-v3-flash-tiktoken_init.txt | 10 -
 ...fuji-70B-v3-flash-tiktoken_regularizer.txt | 11 -
 .../fuji-70B-v3-single-host_init.txt | 10 -
 .../fuji-70B-v3-single-host_regularizer.txt | 11 -
 ...ken.txt => fuji-70B-v3-tiktoken-flash.txt} | 0
 ...xt => fuji-70B-v3-tiktoken-flash_init.txt} | 0
 ...uji-70B-v3-tiktoken-flash_regularizer.txt} | 0
 .../fuji-7B-v3-flash-single-host.txt | 1 +
 .../fuji-7B-v3-flash.txt | 1 +
 .../fuji-7B-v3-single-host.txt | 1 +
 .../fuji-7B-v3.txt | 1 +
 .../fuji-8B-v3-flash-single-host.txt | 356 ------------------
 .../fuji-8B-v3-flash-single-host_init.txt | 10 -
 ...ji-8B-v3-flash-single-host_regularizer.txt | 11 -
 ...flash-tiktoken-single-host_regularizer.txt | 11 -
 .../fuji-8B-v3-flash-tiktoken_regularizer.txt | 11 -
 .../fuji-8B-v3-flash.txt | 356 ------------------
 .../fuji-8B-v3-flash_init.txt | 10 -
 .../fuji-8B-v3-flash_regularizer.txt | 11 -
 .../fuji-8B-v3-single-host.txt | 321 ----------------
 .../fuji-8B-v3-single-host_init.txt | 10 -
 .../fuji-8B-v3-single-host_regularizer.txt | 11 -
 ...fuji-8B-v3-tiktoken-flash-single-host.txt} | 0
 ...8B-v3-tiktoken-flash-single-host_init.txt} | 0
 ...iktoken-flash-single-host_regularizer.txt} | 0
 ...oken.txt => fuji-8B-v3-tiktoken-flash.txt} | 0
 ...txt => fuji-8B-v3-tiktoken-flash_init.txt} | 0
 ...fuji-8B-v3-tiktoken-flash_regularizer.txt} | 0
 .../fuji-8B-v3.txt | 321 ----------------
 .../fuji-8B-v3_init.txt | 10 -
 .../fuji-8B-v3_regularizer.txt | 11 -
 ...t => fuji-golden-run-test-v3-tiktoken.txt} | 97 ++---
 ...fuji-golden-run-test-v3-tiktoken_init.txt} | 0
 ...lden-run-test-v3-tiktoken_regularizer.txt} | 0
 ...en.txt => fuji-test-v3-tiktoken-flash.txt} | 0
 .../fuji-test-v3-tiktoken-flash_init.txt | 9 +
 ...ji-test-v3-tiktoken-flash_regularizer.txt} | 1 -
 axlearn/experiments/text/gpt/c4_trainer.py | 10 +-
 axlearn/experiments/text/gpt/common.py | 4 +-
 axlearn/experiments/text/gpt/fuji.py | 37 +-
 .../text/gpt/param_converter_test.py | 4 +
 .../text/gpt/vocabulary_fuji_v3.py | 11 +-
 64 files changed, 87 insertions(+), 3131 deletions(-)
 rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-1B-v3-flash-tiktoken-single-host.txt => fuji-1B-v3-tiktoken-flash-single-host.txt} (100%) rename
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-1B-v3-flash-tiktoken-single-host_init.txt => fuji-1B-v3-tiktoken-flash-single-host_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt => fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-1B-v3-flash-tiktoken.txt => fuji-1B-v3-tiktoken-flash.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-1B-v3-flash-tiktoken_init.txt => fuji-1B-v3-tiktoken-flash_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-1B-v3-flash-tiktoken_regularizer.txt => fuji-1B-v3-tiktoken-flash_regularizer.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-3B-v3-flash-tiktoken-single-host.txt => fuji-3B-v3-tiktoken-flash-single-host.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-3B-v3-flash-tiktoken-single-host_init.txt => fuji-3B-v3-tiktoken-flash-single-host_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt => fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-3B-v3-flash-tiktoken.txt => fuji-3B-v3-tiktoken-flash.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-3B-v3-flash-tiktoken_init.txt => fuji-3B-v3-tiktoken-flash_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-3B-v3-flash-tiktoken_regularizer.txt => fuji-3B-v3-tiktoken-flash_regularizer.txt} (100%) delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt delete mode 100644 
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v3-flash-tiktoken.txt => fuji-70B-v3-tiktoken-flash.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v3-flash-single-host_init.txt => fuji-70B-v3-tiktoken-flash_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v1-flash-single-host_regularizer.txt => fuji-70B-v3-tiktoken-flash_regularizer.txt} (100%) delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_regularizer.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_regularizer.txt rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-8B-v3-flash-tiktoken-single-host.txt => fuji-8B-v3-tiktoken-flash-single-host.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-8B-v3-flash-tiktoken-single-host_init.txt => fuji-8B-v3-tiktoken-flash-single-host_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v1-single-host_regularizer.txt => fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-8B-v3-flash-tiktoken.txt => fuji-8B-v3-tiktoken-flash.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-8B-v3-flash-tiktoken_init.txt => fuji-8B-v3-tiktoken-flash_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v2-flash-single-host_regularizer.txt => fuji-8B-v3-tiktoken-flash_regularizer.txt} (100%) delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt delete mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_regularizer.txt rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v3-single-host.txt => fuji-golden-run-test-v3-tiktoken.txt} (81%) rename 
axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-test-v3-flash-tiktoken_init.txt => fuji-golden-run-test-v3-tiktoken_init.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-test-v3-flash-tiktoken_regularizer.txt => fuji-golden-run-test-v3-tiktoken_regularizer.txt} (100%) rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-test-v3-flash-tiktoken.txt => fuji-test-v3-tiktoken-flash.txt} (100%) create mode 100644 axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt rename axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/{fuji-70B-v2-single-host_regularizer.txt => fuji-test-v3-tiktoken-flash_regularizer.txt} (95%) diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken-single-host_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash-single-host_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt similarity index 100% 
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-flash-tiktoken_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-1B-v3-tiktoken-flash_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken-single-host_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash-single-host_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-flash-tiktoken_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-3B-v3-tiktoken-flash_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt deleted file mode 100644 index 4f3e7862..00000000 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host.txt +++ /dev/null @@ -1,311 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 367001 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' -checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 367001 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 32 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 2048 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 367001 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 32 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['validation'].input.source.is_training: False 
-evalers['validation'].input.source.max_sequence_length: 2048 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 32 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' -input.source.data_mixture_components[0]['info']: '' -input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' -input.source.max_sequence_length: 2048 -input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' -input.source.preprocessor.max_padding_fraction: 0.5 -input.source.preprocessor.shuffle_buffer_size: 8192 -input.source.preprocessor.window_size: 128 -input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -klass: 'axlearn.common.trainer.SpmdTrainer' -learner.ema.fn: 'axlearn.common.optimizers.param_ema' -learner.enable_per_variable_summaries: False -learner.klass: 'axlearn.common.learner.Learner' -learner.optimizer.args[0].eps: 1e-08 -learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' -learner.optimizer.args[0].max_norm: 1 -learner.optimizer.args[1].b1: 0.9 -learner.optimizer.args[1].b2: 0.95 -learner.optimizer.args[1].eps: 1e-08 -learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' -learner.optimizer.args[1].learning_rate: 0.00015 -learner.optimizer.args[1].update_schedule.alpha: 0.1 -learner.optimizer.args[1].update_schedule.begin_value: 0.0 -learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 367001 -learner.optimizer.args[1].update_schedule.peak_lr: 1.0 -learner.optimizer.args[1].update_schedule.warmup_steps: 2000 -learner.optimizer.args[1].weight_decay: 0.1 -learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 367001 -mesh_axis_names[0]: 'pipeline' -mesh_axis_names[1]: 'data' -mesh_axis_names[2]: 'expert' -mesh_axis_names[3]: 'fsdp' -mesh_axis_names[4]: 'seq' -mesh_axis_names[5]: 'model' -mesh_rules[0][0]: 'tpu-v5litepod-256-4' -mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' 
-mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 -mesh_shape[0]: 1 -mesh_shape[1]: 1 -mesh_shape[2]: 1 -mesh_shape[3]: -1 -mesh_shape[4]: 1 -mesh_shape[5]: 1 -model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'expert' -model.batch_axis_names[2]: 'fsdp' -model.decoder.attention_mask: None -model.decoder.dim: 8192 -model.decoder.dropout_rate: 0.0 -model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' -model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.decoder.emb.token_emb.param_partition_spec[0]: None -model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 -model.decoder.klass: 'axlearn.common.decoder.Decoder' -model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' -model.decoder.lm_head.param_partition_spec[0]: None -model.decoder.lm_head.param_partition_spec[1]: 'model' -model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'expert' -model.decoder.logits_partition_spec[0][2]: 'fsdp' -model.decoder.logits_partition_spec[1]: 'seq' -model.decoder.logits_partition_spec[2]: 'model' -model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.output_norm.eps: 1e-05 -model.decoder.output_norm.forward_dtype: None -model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 -model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' -model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' -model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' -model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' 
-model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 -model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 -model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' -model.decoder.transformer.layer.feed_forward.linear1.bias: False -model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.bias: False -model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' -model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 -model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None -model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 -model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.feed_forward.structure: 'prenorm' -model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' -model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' 
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' -model.decoder.transformer.layer.self_attention.attention.causal: True -model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 -model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False -model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' -model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None 
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.num_heads: 64 -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False -model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' -model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 -model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' -model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 
-model.decoder.transformer.layer.self_attention.norm.forward_dtype: None -model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 -model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' -model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' -model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 32768 -model.dtype: 'jax.numpy.float32' -model.klass: 'axlearn.common.causal_lm.Model' -model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' -model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.seq_axis_names[0]: 'seq' -model.z_loss_scale: 0.0 -name: 'gpt_trainer' -prune_empty_state_updates: True -save_input_iterator: False -start_trace_process_indices[0]: 0 -summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -summary_writer.max_queue: 1000 -summary_writer.write_every_n_steps: 100 -train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt deleted file mode 100644 index ab71a133..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt 
deleted file mode 100644 index 074855a4..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host.txt +++ /dev/null @@ -1,276 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 367001 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' -checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 367001 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 32 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 2048 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 367001 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 32 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' 
-evalers['validation'].input.source.is_training: False -evalers['validation'].input.source.max_sequence_length: 2048 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 32 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' -input.source.data_mixture_components[0]['info']: '' -input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' -input.source.max_sequence_length: 2048 -input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' -input.source.preprocessor.max_padding_fraction: 0.5 -input.source.preprocessor.shuffle_buffer_size: 8192 -input.source.preprocessor.window_size: 128 -input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -klass: 'axlearn.common.trainer.SpmdTrainer' -learner.ema.fn: 'axlearn.common.optimizers.param_ema' -learner.enable_per_variable_summaries: False -learner.klass: 'axlearn.common.learner.Learner' -learner.optimizer.args[0].eps: 1e-08 -learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' -learner.optimizer.args[0].max_norm: 1 -learner.optimizer.args[1].b1: 0.9 -learner.optimizer.args[1].b2: 0.95 -learner.optimizer.args[1].eps: 1e-08 -learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' -learner.optimizer.args[1].learning_rate: 0.00015 -learner.optimizer.args[1].update_schedule.alpha: 0.1 -learner.optimizer.args[1].update_schedule.begin_value: 0.0 -learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 367001 -learner.optimizer.args[1].update_schedule.peak_lr: 1.0 -learner.optimizer.args[1].update_schedule.warmup_steps: 2000 -learner.optimizer.args[1].weight_decay: 0.1 -learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 367001 -mesh_axis_names[0]: 'pipeline' -mesh_axis_names[1]: 'data' -mesh_axis_names[2]: 'expert' -mesh_axis_names[3]: 'fsdp' -mesh_axis_names[4]: 'seq' -mesh_axis_names[5]: 'model' -mesh_rules[0][0]: 'tpu-v5litepod-256-4' -mesh_rules[0][1].config_modifiers[0].klass: 
'axlearn.common.trainer_config_modifier.MeshShapeModifier' -mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 -mesh_shape[0]: 1 -mesh_shape[1]: 1 -mesh_shape[2]: 1 -mesh_shape[3]: -1 -mesh_shape[4]: 1 -mesh_shape[5]: 1 -model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'expert' -model.batch_axis_names[2]: 'fsdp' -model.decoder.attention_mask: None -model.decoder.dim: 8192 -model.decoder.dropout_rate: 0.0 -model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' -model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.decoder.emb.token_emb.param_partition_spec[0]: None -model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 -model.decoder.klass: 'axlearn.common.decoder.Decoder' -model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' -model.decoder.lm_head.param_partition_spec[0]: None -model.decoder.lm_head.param_partition_spec[1]: 'model' -model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'expert' -model.decoder.logits_partition_spec[0][2]: 'fsdp' -model.decoder.logits_partition_spec[1]: 'seq' -model.decoder.logits_partition_spec[2]: 'model' -model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.output_norm.eps: 1e-05 -model.decoder.output_norm.forward_dtype: None -model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 -model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' -model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' -model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' -model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' 
-model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' -model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 -model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 -model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' -model.decoder.transformer.layer.feed_forward.linear1.bias: False -model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.bias: False -model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' -model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 -model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None -model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 -model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.feed_forward.structure: 'prenorm' -model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' -model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'MultiheadAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'MultiheadAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'MultiheadAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'MultiheadAttention.context' 
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'MultiheadAttention.o_proj' -model.decoder.transformer.layer.self_attention.attention.causal: True -model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 -model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False -model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' -model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.MultiheadAttention' -model.decoder.transformer.layer.self_attention.attention.num_heads: 64 -model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False -model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' -model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' -model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 -model.decoder.transformer.layer.self_attention.norm.forward_dtype: None -model.decoder.transformer.layer.self_attention.norm.klass: 
'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 -model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' -model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' -model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 32768 -model.dtype: 'jax.numpy.float32' -model.klass: 'axlearn.common.causal_lm.Model' -model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' -model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.seq_axis_names[0]: 'seq' -model.z_loss_scale: 0.0 -name: 'gpt_trainer' -prune_empty_state_updates: True -save_input_iterator: False -start_trace_process_indices[0]: 0 -summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -summary_writer.max_queue: 1000 -summary_writer.write_every_n_steps: 100 -train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt deleted file mode 100644 index ab71a133..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt deleted file mode 100644 index 857d879f..00000000 --- 
a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host.txt +++ /dev/null @@ -1,312 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 524288 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' -checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 524288 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 16 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 4096 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 524288 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 16 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['validation'].input.source.is_training: False 
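[Note: the checkpointer/eval cadence configured above (save_policy.n: 5000, min_step: 1, max_step: 524288) amounts to a step predicate like the sketch below; should_save and its arguments are hypothetical names for illustration, not axlearn's every_n_steps_and_last_policy API.

def should_save(step, n=5000, min_step=1, max_step=524288):
    # Save every n steps within [min_step, max_step], plus the final step.
    return min_step <= step <= max_step and (step % n == 0 or step == max_step)
]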
-evalers['validation'].input.source.max_sequence_length: 4096 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 16 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' -input.source.data_mixture_components[0]['info']: '' -input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' -input.source.max_sequence_length: 4096 -input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' -input.source.preprocessor.max_padding_fraction: 0.5 -input.source.preprocessor.shuffle_buffer_size: 8192 -input.source.preprocessor.window_size: 128 -input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -klass: 'axlearn.common.trainer.SpmdTrainer' -learner.ema.fn: 'axlearn.common.optimizers.param_ema' -learner.enable_per_variable_summaries: False -learner.klass: 'axlearn.common.learner.Learner' -learner.optimizer.args[0].eps: 1e-08 -learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' -learner.optimizer.args[0].max_norm: 1 -learner.optimizer.args[1].b1: 0.9 -learner.optimizer.args[1].b2: 0.95 -learner.optimizer.args[1].eps: 1e-08 -learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' -learner.optimizer.args[1].learning_rate: 0.00015 -learner.optimizer.args[1].update_schedule.alpha: 0.1 -learner.optimizer.args[1].update_schedule.begin_value: 0.0 -learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 524288 -learner.optimizer.args[1].update_schedule.peak_lr: 1.0 -learner.optimizer.args[1].update_schedule.warmup_steps: 2000 -learner.optimizer.args[1].weight_decay: 0.1 -learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 524288 -mesh_axis_names[0]: 'pipeline' -mesh_axis_names[1]: 'data' -mesh_axis_names[2]: 'expert' -mesh_axis_names[3]: 'fsdp' -mesh_axis_names[4]: 'seq' -mesh_axis_names[5]: 'model' -mesh_rules[0][0]: 'tpu-v5litepod-256-4' -mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' 
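[Note: the learner above chains global-norm clipping with a decoupled AdamW under a cosine schedule with linear warmup. A rough optax equivalent follows, assuming peak_lr (1.0) multiplies learning_rate (0.00015) and alpha (0.1) scales the end value; axlearn's adamw_decoupled_optimizer decouples weight decay from the update schedule differently than optax.adamw, so this is an approximation, not the exact update rule.

import optax

schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,             # update_schedule.begin_value
    peak_value=1.0 * 0.00015,   # peak_lr * learning_rate
    warmup_steps=2000,          # update_schedule.warmup_steps
    decay_steps=524288,         # update_schedule.max_step
    end_value=0.1 * 0.00015,    # alpha * peak
)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),  # args[0].max_norm
    optax.adamw(schedule, b1=0.9, b2=0.95, eps=1e-8, weight_decay=0.1),
)
]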
-mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 -mesh_shape[0]: 1 -mesh_shape[1]: 1 -mesh_shape[2]: 1 -mesh_shape[3]: -1 -mesh_shape[4]: 1 -mesh_shape[5]: 1 -model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'expert' -model.batch_axis_names[2]: 'fsdp' -model.decoder.attention_mask: None -model.decoder.dim: 8192 -model.decoder.dropout_rate: 0.0 -model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' -model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.decoder.emb.token_emb.param_partition_spec[0]: None -model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 -model.decoder.klass: 'axlearn.common.decoder.Decoder' -model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' -model.decoder.lm_head.param_partition_spec[0]: None -model.decoder.lm_head.param_partition_spec[1]: 'model' -model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'expert' -model.decoder.logits_partition_spec[0][2]: 'fsdp' -model.decoder.logits_partition_spec[1]: 'seq' -model.decoder.logits_partition_spec[2]: 'model' -model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.output_norm.eps: 1e-05 -model.decoder.output_norm.forward_dtype: None -model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 -model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' -model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' -model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' -model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' 
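[Note: the default mesh_shape (1, 1, 1, -1, 1, 1) over mesh_axis_names ('pipeline', 'data', 'expert', 'fsdp', 'seq', 'model') puts all devices on the 'fsdp' axis, with -1 meaning "whatever remains". A minimal sketch of materializing such a mesh with jax.sharding.Mesh; axlearn builds it through its own trainer utilities, so the resolution logic here is illustrative.

import math
import jax
import numpy as np
from jax.sharding import Mesh

axis_names = ("pipeline", "data", "expert", "fsdp", "seq", "model")
requested = (1, 1, 1, -1, 1, 1)  # -1: absorb all remaining devices (fsdp)
n = jax.device_count()
known = math.prod(d for d in requested if d != -1)
shape = tuple(n // known if d == -1 else d for d in requested)
mesh = Mesh(np.array(jax.devices()).reshape(shape), axis_names)
]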
-model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 -model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 -model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' -model.decoder.transformer.layer.feed_forward.linear1.bias: False -model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.bias: False -model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' -model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 -model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None -model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 -model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.feed_forward.structure: 'prenorm' -model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' -model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' 
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' -model.decoder.transformer.layer.self_attention.attention.causal: True -model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 -model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 -model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False -model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' -model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' 
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.num_heads: 64 -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False -model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' -model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 -model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' 
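[Note: written as jax PartitionSpecs, the mha_dim_to_partition_spec entries above shard the batch dimension over the composite ('data', 'expert', 'fsdp') axes and the head dimension over ('seq', 'model'), replicating the rest; mha_specs is just an illustrative name for the transcription.

from jax.sharding import PartitionSpec as P

mha_specs = {
    "btnh": P(("data", "expert", "fsdp"), None, ("seq", "model"), None),
    "bsnh": P(("data", "expert", "fsdp"), None, ("seq", "model"), None),
    "bnts": P(("data", "expert", "fsdp"), None, None, None),
}
]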
-model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 -model.decoder.transformer.layer.self_attention.norm.forward_dtype: None -model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 -model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' -model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' -model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 32768 -model.dtype: 'jax.numpy.float32' -model.klass: 'axlearn.common.causal_lm.Model' -model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' -model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.seq_axis_names[0]: 'seq' -model.z_loss_scale: 0.0 -name: 'gpt_trainer' -prune_empty_state_updates: True -save_input_iterator: False -start_trace_process_indices[0]: 0 -summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -summary_writer.max_queue: 1000 -summary_writer.write_every_n_steps: 100 -train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt deleted file mode 100644 index 2f13215e..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt 
b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt deleted file mode 100644 index b410021d..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host.txt +++ /dev/null @@ -1,277 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 524288 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' -checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 524288 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 16 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 4096 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 524288 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 16 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' 
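[Note: a worked check of the init shapes above: with FusedGroupedQKVLinear, the fused projection stacks 64 query heads plus 8 key and 8 value heads (num_kv_heads: 8), and the feed-forward hidden dim comes from scale 3.5 rounded up to a multiple of 256.

model_dim, num_heads, num_kv_heads = 8192, 64, 8
head_dim = model_dim // num_heads             # 128
fused_heads = num_heads + 2 * num_kv_heads    # 64 q + 8 k + 8 v = 80
assert (model_dim, fused_heads, head_dim) == (8192, 80, 128)  # qkv_proj shape

hidden_dim = -(-int(3.5 * model_dim) // 256) * 256  # ceil to multiple of 256
assert hidden_dim == 28672                    # matches the (8192, 28672) linear1 shapes
]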
-evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['validation'].input.source.is_training: False -evalers['validation'].input.source.max_sequence_length: 4096 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 16 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' -input.source.data_mixture_components[0]['info']: '' -input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' -input.source.max_sequence_length: 4096 -input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' -input.source.preprocessor.max_padding_fraction: 0.5 -input.source.preprocessor.shuffle_buffer_size: 8192 -input.source.preprocessor.window_size: 128 -input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_32k_c4.model' -klass: 'axlearn.common.trainer.SpmdTrainer' -learner.ema.fn: 'axlearn.common.optimizers.param_ema' -learner.enable_per_variable_summaries: False -learner.klass: 'axlearn.common.learner.Learner' -learner.optimizer.args[0].eps: 1e-08 -learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' -learner.optimizer.args[0].max_norm: 1 -learner.optimizer.args[1].b1: 0.9 -learner.optimizer.args[1].b2: 0.95 -learner.optimizer.args[1].eps: 1e-08 -learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' -learner.optimizer.args[1].learning_rate: 0.00015 -learner.optimizer.args[1].update_schedule.alpha: 0.1 -learner.optimizer.args[1].update_schedule.begin_value: 0.0 -learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 524288 -learner.optimizer.args[1].update_schedule.peak_lr: 1.0 -learner.optimizer.args[1].update_schedule.warmup_steps: 2000 -learner.optimizer.args[1].weight_decay: 0.1 -learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 524288 -mesh_axis_names[0]: 'pipeline' -mesh_axis_names[1]: 'data' -mesh_axis_names[2]: 'expert' -mesh_axis_names[3]: 'fsdp' -mesh_axis_names[4]: 'seq' -mesh_axis_names[5]: 'model' -mesh_rules[0][0]: 
'tpu-v5litepod-256-4' -mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' -mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 -mesh_shape[0]: 1 -mesh_shape[1]: 1 -mesh_shape[2]: 1 -mesh_shape[3]: -1 -mesh_shape[4]: 1 -mesh_shape[5]: 1 -model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'expert' -model.batch_axis_names[2]: 'fsdp' -model.decoder.attention_mask: None -model.decoder.dim: 8192 -model.decoder.dropout_rate: 0.0 -model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' -model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.decoder.emb.token_emb.param_partition_spec[0]: None -model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 -model.decoder.klass: 'axlearn.common.decoder.Decoder' -model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' -model.decoder.lm_head.param_partition_spec[0]: None -model.decoder.lm_head.param_partition_spec[1]: 'model' -model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'expert' -model.decoder.logits_partition_spec[0][2]: 'fsdp' -model.decoder.logits_partition_spec[1]: 'seq' -model.decoder.logits_partition_spec[2]: 'model' -model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.output_norm.eps: 1e-05 -model.decoder.output_norm.forward_dtype: None -model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 -model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' -model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' -model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' -model.decoder.transformer.layer.feed_forward.dropout.klass: 
'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' -model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 -model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 -model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' -model.decoder.transformer.layer.feed_forward.linear1.bias: False -model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.bias: False -model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' -model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 -model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None -model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 -model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.feed_forward.structure: 'prenorm' -model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' -model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 
'GroupedQueryAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj' -model.decoder.transformer.layer.self_attention.attention.causal: True -model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 -model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 10000.0 -model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False -model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' -model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention' -model.decoder.transformer.layer.self_attention.attention.num_heads: 64 -model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False -model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' -model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' -model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 
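[Note: the remat_spec above saves only the five tagged attention tensors from the forward pass and rematerializes everything else during backprop. A runnable toy of the same mechanism with jax.checkpoint; toy_layer and its tensors are made up for illustration.

import functools
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Only tensors tagged with these names are saved for the backward pass.
policy = jax.checkpoint_policies.save_only_these_names(
    "GroupedQueryAttention.q_proj",
    "GroupedQueryAttention.k_proj",
    "GroupedQueryAttention.v_proj",
    "GroupedQueryAttention.context",
    "GroupedQueryAttention.o_proj",
)

@functools.partial(jax.checkpoint, policy=policy, prevent_cse=False)
def toy_layer(x, w):
    q = checkpoint_name(x @ w, "GroupedQueryAttention.q_proj")            # saved
    ctx = checkpoint_name(jnp.tanh(q), "GroupedQueryAttention.context")   # saved
    return (ctx @ w.T).sum()  # untagged intermediates are rematerialized

grads = jax.grad(toy_layer)(jnp.ones((4, 8)), jnp.ones((8, 8)))
]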
-model.decoder.transformer.layer.self_attention.norm.forward_dtype: None -model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 -model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' -model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' -model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 32768 -model.dtype: 'jax.numpy.float32' -model.klass: 'axlearn.common.causal_lm.Model' -model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' -model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.seq_axis_names[0]: 'seq' -model.z_loss_scale: 0.0 -name: 'gpt_trainer' -prune_empty_state_updates: True -save_input_iterator: False -start_trace_process_indices[0]: 0 -summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -summary_writer.max_queue: 1000 -summary_writer.write_every_n_steps: 100 -train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt deleted file mode 100644 index 2f13215e..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32768, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(32768, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt deleted 
file mode 100644 index 03eb4153..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host.txt +++ /dev/null @@ -1,312 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' -checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 8 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 8192 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 8 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' 
-evalers['validation'].input.source.is_training: False -evalers['validation'].input.source.max_sequence_length: 8192 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 8 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' -input.source.data_mixture_components[0]['info']: '' -input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' -input.source.max_sequence_length: 8192 -input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' -input.source.preprocessor.max_padding_fraction: 0.5 -input.source.preprocessor.shuffle_buffer_size: 8192 -input.source.preprocessor.window_size: 128 -input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json' -klass: 'axlearn.common.trainer.SpmdTrainer' -learner.ema.fn: 'axlearn.common.optimizers.param_ema' -learner.enable_per_variable_summaries: False -learner.klass: 'axlearn.common.learner.Learner' -learner.optimizer.args[0].eps: 1e-08 -learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' -learner.optimizer.args[0].max_norm: 1 -learner.optimizer.args[1].b1: 0.9 -learner.optimizer.args[1].b2: 0.95 -learner.optimizer.args[1].eps: 1e-08 -learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' -learner.optimizer.args[1].learning_rate: 0.00015 -learner.optimizer.args[1].update_schedule.alpha: 0.1 -learner.optimizer.args[1].update_schedule.begin_value: 0.0 -learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 -learner.optimizer.args[1].update_schedule.peak_lr: 1.0 -learner.optimizer.args[1].update_schedule.warmup_steps: 2000 -learner.optimizer.args[1].weight_decay: 0.1 -learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 -mesh_axis_names[0]: 'pipeline' -mesh_axis_names[1]: 'data' -mesh_axis_names[2]: 'expert' -mesh_axis_names[3]: 'fsdp' -mesh_axis_names[4]: 'seq' -mesh_axis_names[5]: 'model' -mesh_rules[0][0]: 'tpu-v5litepod-256-4' -mesh_rules[0][1].config_modifiers[0].klass: 
'axlearn.common.trainer_config_modifier.MeshShapeModifier' -mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)' -mesh_rules[1][1][0]: 1 -mesh_rules[1][1][1]: -1 -mesh_rules[1][1][2]: 1 -mesh_rules[1][1][3]: 128 -mesh_rules[1][1][4]: 1 -mesh_rules[1][1][5]: 1 -mesh_shape[0]: 1 -mesh_shape[1]: 1 -mesh_shape[2]: 1 -mesh_shape[3]: -1 -mesh_shape[4]: 1 -mesh_shape[5]: 1 -model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'expert' -model.batch_axis_names[2]: 'fsdp' -model.decoder.attention_mask: None -model.decoder.dim: 8192 -model.decoder.dropout_rate: 0.0 -model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' -model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.decoder.emb.token_emb.param_partition_spec[0]: None -model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 128001 -model.decoder.klass: 'axlearn.common.decoder.Decoder' -model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' -model.decoder.lm_head.param_partition_spec[0]: None -model.decoder.lm_head.param_partition_spec[1]: 'model' -model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'expert' -model.decoder.logits_partition_spec[0][2]: 'fsdp' -model.decoder.logits_partition_spec[1]: 'seq' -model.decoder.logits_partition_spec[2]: 'model' -model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.output_norm.eps: 1e-05 -model.decoder.output_norm.forward_dtype: None -model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 128004 -model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' -model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' -model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' -model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' 
-model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' -model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 -model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 -model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' -model.decoder.transformer.layer.feed_forward.linear1.bias: False -model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.bias: False -model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' -model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 -model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None -model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 -model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.feed_forward.structure: 'prenorm' -model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' -model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' 
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' -model.decoder.transformer.layer.self_attention.attention.causal: True -model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 -model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 -model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False -model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' -model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' 
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.num_heads: 64 -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False -model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' -model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 -model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' 
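
Note: the attention block above is grouped-query attention: FusedGroupedQKVLinear with num_kv_heads=8 under num_heads=64. A quick head-count check (hypothetical helper, not an axlearn API):

    def fused_qkv_heads(num_heads=64, num_kv_heads=8):
        # 64 query heads + 8 key heads + 8 value heads are fused into one projection.
        return num_heads + 2 * num_kv_heads

    assert fused_qkv_heads() == 80   # matches qkv_proj weight shape (8192, 80, 128) below
    assert 8192 // 64 == 128         # per-head dim = model dim / num_heads
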
-model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' -model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 -model.decoder.transformer.layer.self_attention.norm.forward_dtype: None -model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 80 -model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' -model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' -model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 128256 -model.dtype: 'jax.numpy.float32' -model.klass: 'axlearn.common.causal_lm.Model' -model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' -model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.seq_axis_names[0]: 'seq' -model.z_loss_scale: 0.0 -name: 'gpt_trainer' -prune_empty_state_updates: True -save_input_iterator: False -start_trace_process_indices[0]: 0 -summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -summary_writer.max_queue: 1000 -summary_writer.write_every_n_steps: 100 -train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt deleted file mode 100644 index 65733fb7..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_regularizer.txt +++ /dev/null @@ -1,11 +0,0 @@ -====================weight_decay_scale root.optimizer==================== -decoder/emb/token_emb/weight: 1 -decoder/lm_head/weight: 1 -decoder/output_norm/scale: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 -decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt deleted file mode 100644 index f0e1c9fe..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) 
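
Note: two arithmetic checks against the _init golden that starts above, assuming normal(0, 1.0 / fan) denotes the variance and reconstructing scale_fn from its configured parameters (scale=3.5, round_up_to_multiples_of=256):

    import math

    # Embedding init: shape [128256, 8192] with out_axis=-1 gives fan_out = 8192.
    emb_std = math.sqrt(1.0 / 8192)   # ~0.0110

    def scale_fn(model_dim, scale=3.5, multiple=256):
        # Scale the model dim, then round up to the nearest multiple.
        hidden = int(model_dim * scale)
        return ((hidden + multiple - 1) // multiple) * multiple

    assert scale_fn(8192) == 28672    # 70B linear1_* shapes below
    assert scale_fn(4096) == 14336    # 8B linear1_* shapes later in this patch
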
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt deleted file mode 100644 index 65733fb7..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken_regularizer.txt +++ /dev/null @@ -1,11 +0,0 @@ -====================weight_decay_scale root.optimizer==================== -decoder/emb/token_emb/weight: 1 -decoder/lm_head/weight: 1 -decoder/output_norm/scale: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 -decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt deleted file mode 100644 index f0e1c9fe..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[128256, 8192], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 80, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8192, 64, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8192, 28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8192, 
28672), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(28672, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(128256, 8192), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt deleted file mode 100644 index 65733fb7..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host_regularizer.txt +++ /dev/null @@ -1,11 +0,0 @@ -====================weight_decay_scale root.optimizer==================== -decoder/emb/token_emb/weight: 1 -decoder/lm_head/weight: 1 -decoder/output_norm/scale: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 -decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-tiktoken.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_init.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-flash-single-host_init.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_init.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_regularizer.txt similarity index 100% rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-flash-single-host_regularizer.txt rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-tiktoken-flash_regularizer.txt diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt index 32982ee4..999b3b91 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash-single-host.txt @@ -188,6 +188,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' 
model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt index 74483207..d37b97d3 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-flash.txt @@ -188,6 +188,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt index 99be5fdb..b2a5fc5d 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3-single-host.txt @@ -188,6 +188,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt index 4d32be7c..5a0737a4 100644 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt +++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-7B-v3.txt @@ -188,6 +188,7 @@ model.batch_axis_names[0]: 'data' model.batch_axis_names[1]: 'expert' model.batch_axis_names[2]: 'fsdp' model.decoder.attention_mask: None +model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer' model.decoder.dim: 4096 model.decoder.dropout_rate: 0.0 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt deleted file mode 100644 index 69666d00..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host.txt +++ /dev/null @@ -1,356 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' 
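
Note: the checkpointer settings in the deleted 8B config read as: save every 5000 steps and at the last step, then garbage-collect down to the last 3 checkpoints while keeping every 50000th permanently. A sketch of the save predicate under assumed every_n_steps_and_last_policy semantics:

    def should_save(step, n=5000, min_step=1, max_step=3932160):
        # Save on every n-th step within [min_step, max_step], plus the final step.
        return step == max_step or (min_step <= step <= max_step and step % n == 0)
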
-checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 16 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 8192 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 16 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['validation'].input.source.is_training: False -evalers['validation'].input.source.max_sequence_length: 8192 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' 
-evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 16 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' -input.source.data_mixture_components[0]['info']: '' -input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source' -input.source.max_sequence_length: 8192 -input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor' -input.source.preprocessor.max_padding_fraction: 0.5 -input.source.preprocessor.shuffle_buffer_size: 8192 -input.source.preprocessor.window_size: 128 -input.source.replace_newlines_with: '' -input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' -klass: 'axlearn.common.trainer.SpmdTrainer' -learner.ema.fn: 'axlearn.common.optimizers.param_ema' -learner.enable_per_variable_summaries: False -learner.klass: 'axlearn.common.learner.Learner' -learner.optimizer.args[0].eps: 1e-08 -learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm' -learner.optimizer.args[0].max_norm: 1 -learner.optimizer.args[1].b1: 0.9 -learner.optimizer.args[1].b2: 0.95 -learner.optimizer.args[1].eps: 1e-08 -learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer' -learner.optimizer.args[1].learning_rate: 0.0003 -learner.optimizer.args[1].update_schedule.alpha: 0.1 -learner.optimizer.args[1].update_schedule.begin_value: 0.0 -learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup' -learner.optimizer.args[1].update_schedule.max_step: 3932160 -learner.optimizer.args[1].update_schedule.peak_lr: 1.0 -learner.optimizer.args[1].update_schedule.warmup_steps: 2000 -learner.optimizer.args[1].weight_decay: 0.1 -learner.optimizer.fn: 'axlearn.common.optimizers.chain' -max_step: 3932160 -mesh_axis_names[0]: 'pipeline' -mesh_axis_names[1]: 'data' -mesh_axis_names[2]: 'expert' -mesh_axis_names[3]: 'fsdp' -mesh_axis_names[4]: 'seq' -mesh_axis_names[5]: 'model' -mesh_rules[0][0]: 'tpu-v4-(1024|2048)' -mesh_rules[0][1][0]: 1 -mesh_rules[0][1][1]: -1 -mesh_rules[0][1][2]: 1 -mesh_rules[0][1][3]: 16 -mesh_rules[0][1][4]: 1 -mesh_rules[0][1][5]: 1 -mesh_rules[1][0]: 'tpu-v5litepod-256' -mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' -mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: 
True -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4 -mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier' -mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[2][0]: 'tpu-v5litepod-256-2' -mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' -mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host' -mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device' -mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[3][0]: 'tpu-v5litepod-256-4' -mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier' -mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1 -mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1 -mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1 -mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256 -mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1 -mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1 -mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier' -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True -mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable' -mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier' -mesh_rules[4][0]: 'tpu-v5p-.*' -mesh_rules[4][1][0]: 1 -mesh_rules[4][1][1]: -1 -mesh_rules[4][1][2]: 1 -mesh_rules[4][1][3]: 8 -mesh_rules[4][1][4]: 1 -mesh_rules[4][1][5]: 1 -mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)' -mesh_rules[5][1][0]: 1 -mesh_rules[5][1][1]: -1 -mesh_rules[5][1][2]: 1 -mesh_rules[5][1][3]: 8 -mesh_rules[5][1][4]: 1 -mesh_rules[5][1][5]: 1 -mesh_shape[0]: 1 -mesh_shape[1]: -1 -mesh_shape[2]: 1 -mesh_shape[3]: 8 -mesh_shape[4]: 1 -mesh_shape[5]: 1 -model.batch_axis_names[0]: 'data' -model.batch_axis_names[1]: 'expert' -model.batch_axis_names[2]: 'fsdp' -model.decoder.attention_mask: None -model.decoder.decoding.klass: 
'axlearn.common.decoder.DecodingLayer' -model.decoder.dim: 4096 -model.decoder.dropout_rate: 0.0 -model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings' -model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.decoder.emb.token_emb.param_partition_spec[0]: None -model.decoder.emb.token_emb.param_partition_spec[1]: 'model' -model.decoder.eos_token_id: 1 -model.decoder.klass: 'axlearn.common.decoder.Decoder' -model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead' -model.decoder.lm_head.param_partition_spec[0]: None -model.decoder.lm_head.param_partition_spec[1]: 'model' -model.decoder.logits_partition_spec[0][0]: 'data' -model.decoder.logits_partition_spec[0][1]: 'expert' -model.decoder.logits_partition_spec[0][2]: 'fsdp' -model.decoder.logits_partition_spec[1]: 'seq' -model.decoder.logits_partition_spec[2]: 'model' -model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.output_norm.eps: 1e-05 -model.decoder.output_norm.forward_dtype: None -model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.pad_token_id: 0 -model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer' -model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu' -model.decoder.transformer.layer.feed_forward.activation[1]: 'linear' -model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn' -model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256 -model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5 -model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer' -model.decoder.transformer.layer.feed_forward.linear1.bias: False -model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.bias: False -model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear' 
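
Note: activation=('nn.silu', 'linear') with paired linear1_0/linear1_1 weights describes a SwiGLU-style gated feed-forward. A minimal JAX sketch of that reading (hypothetical function, not the axlearn layer):

    import jax.numpy as jnp
    from jax.nn import silu

    def gated_ffn(x, w1_0, w1_1, w2):
        # The silu branch gates the linear branch; linear2 projects back down.
        return (silu(x @ w1_0) * (x @ w1_1)) @ w2

    x = jnp.ones((1, 4096))
    w1_0 = w1_1 = jnp.zeros((4096, 14336))
    w2 = jnp.zeros((14336, 4096))
    assert gated_ffn(x, w1_0, w1_1, w2).shape == (1, 4096)
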
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq' -model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp' -model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq' -model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05 -model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None -model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.feed_forward.residual_weight: 1.0 -model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.feed_forward.structure: 'prenorm' -model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer' -model.decoder.transformer.layer.remat_spec['prevent_cse']: False -model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation' -model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2' -model.decoder.transformer.layer.self_attention.attention.causal: True -model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq' 
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8 -model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding' -model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0 -model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False -model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey' -model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None -model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.num_heads: 32 -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp' 
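
Note: rope_pos_emb_layer.theta is 500000 rather than the classic 10000, the usual long-context choice for Llama-3-style models. Standard rotary frequencies for reference (textbook formula, not axlearn code):

    def rope_inv_freq(head_dim=128, theta=500000.0):
        # One rotation frequency per pair of feature dims; a larger theta slows
        # the low-frequency rotations so distant positions stay distinguishable.
        return [theta ** (-2.0 * i / head_dim) for i in range(head_dim // 2)]
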
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None -model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False -model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model' -model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None -model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery' -model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512 -model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout' -model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer' -model.decoder.transformer.layer.self_attention.norm.eps: 1e-05 -model.decoder.transformer.layer.self_attention.norm.forward_dtype: None -model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm' -model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth' -model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row' -model.decoder.transformer.layer.self_attention.structure: 'prenorm' -model.decoder.transformer.num_layers: 32 -model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex' -model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*' -model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat' -model.decoder.vocab_size: 131072 -model.dtype: 'jax.numpy.float32' -model.klass: 'axlearn.common.causal_lm.Model' -model.param_init.init_by_param_name['.*weight$'].distribution: 'normal' -model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in' -model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer' -model.param_init.init_by_param_name['.*weight$'].scale: 1.0 -model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer' -model.seq_axis_names[0]: 'seq' -model.z_loss_scale: 0.0 -name: 'gpt_trainer' -prune_empty_state_updates: True -save_input_iterator: False -start_trace_process_indices[0]: 0 
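
Note: the net tokenizer swap visible in this patch, with values read directly off the deleted and surviving configs:

    OLD = dict(vocab="bpe_128k_c4.model",      vocab_size=131072, eos_token_id=1,      pad_token_id=0)
    NEW = dict(vocab="Llama-3-tokenizer.json", vocab_size=128256, eos_token_id=128001, pad_token_id=128004)
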
-summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -summary_writer.max_queue: 1000 -summary_writer.write_every_n_steps: 100 -train_dtype: 'jax.numpy.bfloat16' \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt deleted file mode 100644 index 311e12ed..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_init.txt +++ /dev/null @@ -1,10 +0,0 @@ -decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=()) -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0) -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) -decoder/output_norm/scale: constant(1.0) -decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()) \ No newline at end of file diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_regularizer.txt deleted file mode 100644 index 65733fb7..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-single-host_regularizer.txt +++ /dev/null @@ -1,11 +0,0 @@ -====================weight_decay_scale root.optimizer==================== -decoder/emb/token_emb/weight: 1 -decoder/lm_head/weight: 1 -decoder/output_norm/scale: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 -decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt deleted file mode 100644 index 65733fb7..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_regularizer.txt +++ /dev/null @@ -1,11 +0,0 @@ 
-====================weight_decay_scale root.optimizer==================== -decoder/emb/token_emb/weight: 1 -decoder/lm_head/weight: 1 -decoder/output_norm/scale: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 -decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt deleted file mode 100644 index 65733fb7..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_regularizer.txt +++ /dev/null @@ -1,11 +0,0 @@ -====================weight_decay_scale root.optimizer==================== -decoder/emb/token_emb/weight: 1 -decoder/lm_head/weight: 1 -decoder/output_norm/scale: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1 -decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1 -decoder/transformer/repeat/layer/feed_forward/norm/scale: 1 -decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1 -decoder/transformer/repeat/layer/self_attention/norm/scale: 1 diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt deleted file mode 100644 index d60b1007..00000000 --- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash.txt +++ /dev/null @@ -1,356 +0,0 @@ -batch_axis_names[0]: 'data' -batch_axis_names[1]: 'expert' -batch_axis_names[2]: 'fsdp' -batch_axis_names[3]: 'seq' -checkpointer.gc_loop_interval_seconds: 60 -checkpointer.keep_every_n_steps: 50000 -checkpointer.keep_last_n: 3 -checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer' -checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy' -checkpointer.save_policy.max_step: 3932160 -checkpointer.save_policy.min_step: 1 -checkpointer.save_policy.n: 5000 -checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage' -checkpointer.storage.timeout_secs: 3600 -evalers['train'].eval_dtype: 'jax.numpy.bfloat16' -evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['train'].eval_policy.max_step: 3932160 -evalers['train'].eval_policy.min_step: 1 -evalers['train'].eval_policy.n: 5000 -evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['train'].input.batcher.global_batch_size: 512 -evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['train'].input.batcher.prefetch_buffer_size: -1 -evalers['train'].input.is_training: False -evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['train'].input.source.dataset_name: 'c4/en:3.0.1' 
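
Note: the *_regularizer goldens record weight_decay_scale = 1 for every parameter, norm scales and embeddings included, so the decoupled weight_decay of 0.1 applies uniformly. A sketch of the per-step effect under assumed decoupled-decay semantics:

    def decayed(w, lr_t, weight_decay=0.1, weight_decay_scale=1.0):
        # Applied outside the Adam update: each weight shrinks by lr_t * 0.1 per step.
        return w - lr_t * weight_decay * weight_decay_scale * w
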
-evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['train'].input.source.is_training: False -evalers['train'].input.source.max_sequence_length: 8192 -evalers['train'].input.source.replace_newlines_with: '\n' -evalers['train'].input.source.split: 'train[:8192]' -evalers['train'].input.source.train_shuffle_buffer_size: 16384 -evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' -evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['train'].metric_calculator.model_method: 'forward' -evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['train'].summary_writer.write_every_n_steps: 1 -evalers['validation'].eval_dtype: 'jax.numpy.bfloat16' -evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy' -evalers['validation'].eval_policy.max_step: 3932160 -evalers['validation'].eval_policy.min_step: 1 -evalers['validation'].eval_policy.n: 5000 -evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch' -evalers['validation'].input.batcher.global_batch_size: 512 -evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -evalers['validation'].input.batcher.prefetch_buffer_size: -1 -evalers['validation'].input.is_training: False -evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input' -evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity' -evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1' -evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input' -evalers['validation'].input.source.is_training: False -evalers['validation'].input.source.max_sequence_length: 8192 -evalers['validation'].input.source.replace_newlines_with: '\n' -evalers['validation'].input.source.split: 'validation' -evalers['validation'].input.source.train_shuffle_buffer_size: 16384 -evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab' -evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model' -evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler' -evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator' -evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator' -evalers['validation'].metric_calculator.model_method: 'forward' -evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter' -evalers['validation'].summary_writer.write_every_n_steps: 1 -input.batcher.fn: 'axlearn.common.input_tf_data.batch' -input.batcher.global_batch_size: 512 -input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn' -input.batcher.prefetch_buffer_size: -1 -input.is_training: True -input.klass: 'axlearn.common.input_tf_data.Input' -input.processor.fn: 'axlearn.common.input_tf_data.identity' -input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1' -input.source.data_mixture_components[0]['weight']: 1.0 -input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192 -input.source.data_mixture_components[0]['split']: 'train' 
-input.source.data_mixture_components[0]['info']: ''
-input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
-input.source.max_sequence_length: 8192
-input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
-input.source.preprocessor.max_padding_fraction: 0.5
-input.source.preprocessor.shuffle_buffer_size: 8192
-input.source.preprocessor.window_size: 128
-input.source.replace_newlines_with: ''
-input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-klass: 'axlearn.common.trainer.SpmdTrainer'
-learner.ema.fn: 'axlearn.common.optimizers.param_ema'
-learner.enable_per_variable_summaries: False
-learner.klass: 'axlearn.common.learner.Learner'
-learner.optimizer.args[0].eps: 1e-08
-learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
-learner.optimizer.args[0].max_norm: 1
-learner.optimizer.args[1].b1: 0.9
-learner.optimizer.args[1].b2: 0.95
-learner.optimizer.args[1].eps: 1e-08
-learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
-learner.optimizer.args[1].learning_rate: 0.0003
-learner.optimizer.args[1].update_schedule.alpha: 0.1
-learner.optimizer.args[1].update_schedule.begin_value: 0.0
-learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
-learner.optimizer.args[1].update_schedule.max_step: 3932160
-learner.optimizer.args[1].update_schedule.peak_lr: 1.0
-learner.optimizer.args[1].update_schedule.warmup_steps: 2000
-learner.optimizer.args[1].weight_decay: 0.1
-learner.optimizer.fn: 'axlearn.common.optimizers.chain'
-max_step: 3932160
-mesh_axis_names[0]: 'pipeline'
-mesh_axis_names[1]: 'data'
-mesh_axis_names[2]: 'expert'
-mesh_axis_names[3]: 'fsdp'
-mesh_axis_names[4]: 'seq'
-mesh_axis_names[5]: 'model'
-mesh_rules[0][0]: 'tpu-v4-(1024|2048)'
-mesh_rules[0][1][0]: 1
-mesh_rules[0][1][1]: -1
-mesh_rules[0][1][2]: 1
-mesh_rules[0][1][3]: 16
-mesh_rules[0][1][4]: 1
-mesh_rules[0][1][5]: 1
-mesh_rules[1][0]: 'tpu-v5litepod-256'
-mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
-mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
-mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
-mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
-mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[2][0]: 'tpu-v5litepod-256-2'
-mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
-mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[3][0]: 'tpu-v5litepod-256-4'
-mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable'
-mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[4][0]: 'tpu-v5p-.*'
-mesh_rules[4][1][0]: 1
-mesh_rules[4][1][1]: -1
-mesh_rules[4][1][2]: 1
-mesh_rules[4][1][3]: 8
-mesh_rules[4][1][4]: 1
-mesh_rules[4][1][5]: 1
-mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)'
-mesh_rules[5][1][0]: 1
-mesh_rules[5][1][1]: -1
-mesh_rules[5][1][2]: 1
-mesh_rules[5][1][3]: 8
-mesh_rules[5][1][4]: 1
-mesh_rules[5][1][5]: 1
-mesh_shape[0]: 1
-mesh_shape[1]: -1
-mesh_shape[2]: 1
-mesh_shape[3]: 8
-mesh_shape[4]: 1
-mesh_shape[5]: 1
-model.batch_axis_names[0]: 'data'
-model.batch_axis_names[1]: 'expert'
-model.batch_axis_names[2]: 'fsdp'
-model.decoder.attention_mask: None
-model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
-model.decoder.dim: 4096
-model.decoder.dropout_rate: 0.0
-model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
-model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
-model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
-model.decoder.emb.token_emb.param_partition_spec[0]: None
-model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
-model.decoder.klass: 'axlearn.common.decoder.Decoder'
-model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
-model.decoder.lm_head.param_partition_spec[0]: None
-model.decoder.lm_head.param_partition_spec[1]: 'model'
-model.decoder.logits_partition_spec[0][0]: 'data'
-model.decoder.logits_partition_spec[0][1]: 'expert'
-model.decoder.logits_partition_spec[0][2]: 'fsdp'
-model.decoder.logits_partition_spec[1]: 'seq'
-model.decoder.logits_partition_spec[2]: 'model'
-model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.output_norm.eps: 1e-05
-model.decoder.output_norm.forward_dtype: None
-model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
-model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
-model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
-model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
-model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
-model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
-model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
-model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
-model.decoder.transformer.layer.feed_forward.linear1.bias: False
-model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.bias: False
-model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
-model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
-model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
-model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
-model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
-model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
-model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
-model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
-model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'FlashAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'FlashAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'FlashAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'FlashAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'FlashAttention.o_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[5]: 'TransformerFeedForwardLayer.activation'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[6]: 'TransformerFeedForwardLayer.linear2'
-model.decoder.transformer.layer.self_attention.attention.causal: True
-model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
-model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
-model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0
-model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
-model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
-model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.flash_attention.layer.FlashAttention'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][0]: 'data'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][1]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][1]: None
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][0]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][2][1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['btnh'][3]: None
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][0]: 'data'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][1]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][0][2]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][1]: None
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][0]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][2][1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bsnh'][3]: None
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][0]: 'data'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][1]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][0][2]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][1]: None
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][2]: None
-model.decoder.transformer.layer.self_attention.attention.mha_dim_to_partition_spec['bnts'][3]: None
-model.decoder.transformer.layer.self_attention.attention.num_heads: 32
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][0]: 'data'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][1]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][0][2]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][1]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][2]: 'model'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['btnh'][3]: None
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][0]: 'data'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][1]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][0][2]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.output_dim_to_partition_spec['bnts'][3]: None
-model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
-model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
-model.decoder.transformer.layer.self_attention.attention.tpu_block_size: 512
-model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
-model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
-model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
-model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
-model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
-model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 32
-model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
-model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
-model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 131072
-model.dtype: 'jax.numpy.float32'
-model.klass: 'axlearn.common.causal_lm.Model'
-model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
-model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
-model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
-model.param_init.init_by_param_name['.*weight$'].scale: 1.0
-model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
-model.seq_axis_names[0]: 'seq'
-model.z_loss_scale: 0.0
-name: 'gpt_trainer'
-prune_empty_state_updates: True
-save_input_iterator: False
-start_trace_process_indices[0]: 0
-summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-summary_writer.max_queue: 1000
-summary_writer.write_every_n_steps: 100
-train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt
deleted file mode 100644
index 311e12ed..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_init.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
-decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
-decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/output_norm/scale: constant(1.0)
-decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_regularizer.txt
deleted file mode 100644
index 65733fb7..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash_regularizer.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-====================weight_decay_scale root.optimizer====================
-decoder/emb/token_emb/weight: 1
-decoder/lm_head/weight: 1
-decoder/output_norm/scale: 1
-decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
-decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
-decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
-decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
-decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
-decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
deleted file mode 100644
index 056640ca..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host.txt
+++ /dev/null
@@ -1,321 +0,0 @@
-batch_axis_names[0]: 'data'
-batch_axis_names[1]: 'expert'
-batch_axis_names[2]: 'fsdp'
-batch_axis_names[3]: 'seq'
-checkpointer.gc_loop_interval_seconds: 60
-checkpointer.keep_every_n_steps: 50000
-checkpointer.keep_last_n: 3
-checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
-checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
-checkpointer.save_policy.max_step: 3932160
-checkpointer.save_policy.min_step: 1
-checkpointer.save_policy.n: 5000
-checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
-checkpointer.storage.timeout_secs: 3600
-evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
-evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
-evalers['train'].eval_policy.max_step: 3932160
-evalers['train'].eval_policy.min_step: 1
-evalers['train'].eval_policy.n: 5000
-evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-evalers['train'].input.batcher.global_batch_size: 16
-evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
-evalers['train'].input.batcher.prefetch_buffer_size: -1
-evalers['train'].input.is_training: False
-evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
-evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
-evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
-evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
-evalers['train'].input.source.is_training: False
-evalers['train'].input.source.max_sequence_length: 8192
-evalers['train'].input.source.replace_newlines_with: '\n'
-evalers['train'].input.source.split: 'train[:8192]'
-evalers['train'].input.source.train_shuffle_buffer_size: 16384
-evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
-evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
-evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
-evalers['train'].metric_calculator.model_method: 'forward'
-evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-evalers['train'].summary_writer.write_every_n_steps: 1
-evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
-evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
-evalers['validation'].eval_policy.max_step: 3932160
-evalers['validation'].eval_policy.min_step: 1
-evalers['validation'].eval_policy.n: 5000
-evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-evalers['validation'].input.batcher.global_batch_size: 16
-evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
-evalers['validation'].input.batcher.prefetch_buffer_size: -1
-evalers['validation'].input.is_training: False
-evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
-evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
-evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
-evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
-evalers['validation'].input.source.is_training: False
-evalers['validation'].input.source.max_sequence_length: 8192
-evalers['validation'].input.source.replace_newlines_with: '\n'
-evalers['validation'].input.source.split: 'validation'
-evalers['validation'].input.source.train_shuffle_buffer_size: 16384
-evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
-evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
-evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
-evalers['validation'].metric_calculator.model_method: 'forward'
-evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-evalers['validation'].summary_writer.write_every_n_steps: 1
-input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-input.batcher.global_batch_size: 16
-input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
-input.batcher.prefetch_buffer_size: -1
-input.is_training: True
-input.klass: 'axlearn.common.input_tf_data.Input'
-input.processor.fn: 'axlearn.common.input_tf_data.identity'
-input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
-input.source.data_mixture_components[0]['weight']: 1.0
-input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
-input.source.data_mixture_components[0]['split']: 'train'
-input.source.data_mixture_components[0]['info']: ''
-input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
-input.source.max_sequence_length: 8192
-input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
-input.source.preprocessor.max_padding_fraction: 0.5
-input.source.preprocessor.shuffle_buffer_size: 8192
-input.source.preprocessor.window_size: 128
-input.source.replace_newlines_with: ''
-input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-klass: 'axlearn.common.trainer.SpmdTrainer'
-learner.ema.fn: 'axlearn.common.optimizers.param_ema'
-learner.enable_per_variable_summaries: False
-learner.klass: 'axlearn.common.learner.Learner'
-learner.optimizer.args[0].eps: 1e-08
-learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
-learner.optimizer.args[0].max_norm: 1
-learner.optimizer.args[1].b1: 0.9
-learner.optimizer.args[1].b2: 0.95
-learner.optimizer.args[1].eps: 1e-08
-learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
-learner.optimizer.args[1].learning_rate: 0.0003
-learner.optimizer.args[1].update_schedule.alpha: 0.1
-learner.optimizer.args[1].update_schedule.begin_value: 0.0
-learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
-learner.optimizer.args[1].update_schedule.max_step: 3932160
-learner.optimizer.args[1].update_schedule.peak_lr: 1.0
-learner.optimizer.args[1].update_schedule.warmup_steps: 2000
-learner.optimizer.args[1].weight_decay: 0.1
-learner.optimizer.fn: 'axlearn.common.optimizers.chain'
-max_step: 3932160
-mesh_axis_names[0]: 'pipeline'
-mesh_axis_names[1]: 'data'
-mesh_axis_names[2]: 'expert'
-mesh_axis_names[3]: 'fsdp'
-mesh_axis_names[4]: 'seq'
-mesh_axis_names[5]: 'model'
-mesh_rules[0][0]: 'tpu-v4-(1024|2048)'
-mesh_rules[0][1][0]: 1
-mesh_rules[0][1][1]: -1
-mesh_rules[0][1][2]: 1
-mesh_rules[0][1][3]: 16
-mesh_rules[0][1][4]: 1
-mesh_rules[0][1][5]: 1
-mesh_rules[1][0]: 'tpu-v5litepod-256'
-mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
-mesh_rules[1][1].config_modifiers[2].grad_acc_steps: 4
-mesh_rules[1][1].config_modifiers[2].klass: 'axlearn.common.trainer_config_modifier.GradientAccumulationModifier'
-mesh_rules[1][1].config_modifiers[2].metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
-mesh_rules[1][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[2][0]: 'tpu-v5litepod-256-2'
-mesh_rules[2][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[2][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[2][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[2][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[2][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[2][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
-mesh_rules[2][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[3][0]: 'tpu-v5litepod-256-4'
-mesh_rules[3][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[3][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[3][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[3][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[3][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[3][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy']: 'jax._src.ad_checkpoint.dots_saveable'
-mesh_rules[3][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[4][0]: 'tpu-v5p-.*'
-mesh_rules[4][1][0]: 1
-mesh_rules[4][1][1]: -1
-mesh_rules[4][1][2]: 1
-mesh_rules[4][1][3]: 8
-mesh_rules[4][1][4]: 1
-mesh_rules[4][1][5]: 1
-mesh_rules[5][0]: 'gpu-(p5.48xlarge|p4de.24xlarge|a3-highgpu-8g)-(256|512|1024)'
-mesh_rules[5][1][0]: 1
-mesh_rules[5][1][1]: -1
-mesh_rules[5][1][2]: 1
-mesh_rules[5][1][3]: 8
-mesh_rules[5][1][4]: 1
-mesh_rules[5][1][5]: 1
-mesh_shape[0]: 1
-mesh_shape[1]: -1
-mesh_shape[2]: 1
-mesh_shape[3]: 8
-mesh_shape[4]: 1
-mesh_shape[5]: 1
-model.batch_axis_names[0]: 'data'
-model.batch_axis_names[1]: 'expert'
-model.batch_axis_names[2]: 'fsdp'
-model.decoder.attention_mask: None
-model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
-model.decoder.dim: 4096
-model.decoder.dropout_rate: 0.0
-model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
-model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
-model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
-model.decoder.emb.token_emb.param_partition_spec[0]: None
-model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
-model.decoder.klass: 'axlearn.common.decoder.Decoder'
-model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
-model.decoder.lm_head.param_partition_spec[0]: None
-model.decoder.lm_head.param_partition_spec[1]: 'model'
-model.decoder.logits_partition_spec[0][0]: 'data'
-model.decoder.logits_partition_spec[0][1]: 'expert'
-model.decoder.logits_partition_spec[0][2]: 'fsdp'
-model.decoder.logits_partition_spec[1]: 'seq'
-model.decoder.logits_partition_spec[2]: 'model'
-model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.output_norm.eps: 1e-05
-model.decoder.output_norm.forward_dtype: None
-model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
-model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
-model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
-model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
-model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
-model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
-model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
-model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
-model.decoder.transformer.layer.feed_forward.linear1.bias: False
-model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.bias: False
-model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
-model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
-model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
-model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
-model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
-model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
-model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
-model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
-model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
-model.decoder.transformer.layer.self_attention.attention.causal: True
-model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
-model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
-model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0
-model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
-model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
-model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention'
-model.decoder.transformer.layer.self_attention.attention.num_heads: 32
-model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
-model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
-model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
-model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
-model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
-model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
-model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
-model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 32
-model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
-model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
-model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 131072
-model.dtype: 'jax.numpy.float32'
-model.klass: 'axlearn.common.causal_lm.Model'
-model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
-model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
-model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
-model.param_init.init_by_param_name['.*weight$'].scale: 1.0
-model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
-model.seq_axis_names[0]: 'seq'
-model.z_loss_scale: 0.0
-name: 'gpt_trainer'
-prune_empty_state_updates: True
-save_input_iterator: False
-start_trace_process_indices[0]: 0
-summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-summary_writer.max_queue: 1000
-summary_writer.write_every_n_steps: 100
-train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt
deleted file mode 100644
index 311e12ed..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_init.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
-decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
-decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/output_norm/scale: constant(1.0)
-decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_regularizer.txt
deleted file mode 100644
index 65733fb7..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-single-host_regularizer.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-====================weight_decay_scale root.optimizer====================
-decoder/emb/token_emb/weight: 1
-decoder/lm_head/weight: 1
-decoder/output_norm/scale: 1
-decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
-decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
-decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
-decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
-decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
-decoder/transformer/repeat/layer/self_attention/norm/scale: 1
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_init.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken-single-host_init.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_init.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v1-single-host_regularizer.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash-single-host_regularizer.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_init.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-flash-tiktoken_init.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_init.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_regularizer.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-flash-single-host_regularizer.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3-tiktoken-flash_regularizer.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
deleted file mode 100644
index 7d394aee..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3.txt
+++ /dev/null
@@ -1,321 +0,0 @@
-batch_axis_names[0]: 'data'
-batch_axis_names[1]: 'expert'
-batch_axis_names[2]: 'fsdp'
-batch_axis_names[3]: 'seq'
-checkpointer.gc_loop_interval_seconds: 60
-checkpointer.keep_every_n_steps: 50000
-checkpointer.keep_last_n: 3
-checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
-checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
-checkpointer.save_policy.max_step: 3932160
-checkpointer.save_policy.min_step: 1
-checkpointer.save_policy.n: 5000
-checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
-checkpointer.storage.timeout_secs: 3600
-evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
-evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
-evalers['train'].eval_policy.max_step: 3932160
-evalers['train'].eval_policy.min_step: 1
-evalers['train'].eval_policy.n: 5000
-evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-evalers['train'].input.batcher.global_batch_size: 512
-evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
-evalers['train'].input.batcher.prefetch_buffer_size: -1
-evalers['train'].input.is_training: False
-evalers['train'].input.klass: 'axlearn.common.input_tf_data.Input'
-evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
-evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
-evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
-evalers['train'].input.source.is_training: False
-evalers['train'].input.source.max_sequence_length: 8192
-evalers['train'].input.source.replace_newlines_with: '\n'
-evalers['train'].input.source.split: 'train[:8192]'
-evalers['train'].input.source.train_shuffle_buffer_size: 16384
-evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
-evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
-evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
-evalers['train'].metric_calculator.model_method: 'forward'
-evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-evalers['train'].summary_writer.write_every_n_steps: 1
-evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
-evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
-evalers['validation'].eval_policy.max_step: 3932160
-evalers['validation'].eval_policy.min_step: 1
-evalers['validation'].eval_policy.n: 5000
-evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-evalers['validation'].input.batcher.global_batch_size: 512
-evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
-evalers['validation'].input.batcher.prefetch_buffer_size: -1
-evalers['validation'].input.is_training: False
-evalers['validation'].input.klass: 'axlearn.common.input_tf_data.Input'
-evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
-evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
-evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
-evalers['validation'].input.source.is_training: False
-evalers['validation'].input.source.max_sequence_length: 8192
-evalers['validation'].input.source.replace_newlines_with: '\n'
-evalers['validation'].input.source.split: 'validation'
-evalers['validation'].input.source.train_shuffle_buffer_size: 16384
-evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
-evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
-evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
-evalers['validation'].metric_calculator.model_method: 'forward'
-evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-evalers['validation'].summary_writer.write_every_n_steps: 1
-input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-input.batcher.global_batch_size: 512
-input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
-input.batcher.prefetch_buffer_size: -1
-input.is_training: True
-input.klass: 'axlearn.common.input_tf_data.Input'
-input.processor.fn: 'axlearn.common.input_tf_data.identity'
-input.source.data_mixture_components[0]['name']: 'c4/en:3.0.1'
-input.source.data_mixture_components[0]['weight']: 1.0
-input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
-input.source.data_mixture_components[0]['split']: 'train'
-input.source.data_mixture_components[0]['info']: ''
-input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
-input.source.max_sequence_length: 8192
-input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
-input.source.preprocessor.max_padding_fraction: 0.5
-input.source.preprocessor.shuffle_buffer_size: 8192
-input.source.preprocessor.window_size: 128
-input.source.replace_newlines_with: ''
-input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-input.source.vocab_cfg.sentencepiece_model_name: 'bpe_128k_c4.model'
-klass: 'axlearn.common.trainer.SpmdTrainer'
-learner.ema.fn: 'axlearn.common.optimizers.param_ema'
-learner.enable_per_variable_summaries: False
-learner.klass: 'axlearn.common.learner.Learner'
-learner.optimizer.args[0].eps: 1e-08
-learner.optimizer.args[0].fn: 'axlearn.common.optimizers.clip_by_global_norm'
-learner.optimizer.args[0].max_norm: 1
-learner.optimizer.args[1].b1: 0.9
-learner.optimizer.args[1].b2: 0.95
-learner.optimizer.args[1].eps: 1e-08
-learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
-learner.optimizer.args[1].learning_rate: 0.0003
-learner.optimizer.args[1].update_schedule.alpha: 0.1
-learner.optimizer.args[1].update_schedule.begin_value: 0.0
-learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
-learner.optimizer.args[1].update_schedule.max_step: 3932160
-learner.optimizer.args[1].update_schedule.peak_lr: 1.0
-learner.optimizer.args[1].update_schedule.warmup_steps: 2000
-learner.optimizer.args[1].weight_decay: 0.1
-learner.optimizer.fn: 'axlearn.common.optimizers.chain'
-max_step: 3932160
-mesh_axis_names[0]: 'pipeline'
-mesh_axis_names[1]: 'data'
-mesh_axis_names[2]: 'expert'
-mesh_axis_names[3]: 'fsdp'
-mesh_axis_names[4]: 'seq'
-mesh_axis_names[5]: 'model'
-mesh_rules[0][0]: 'tpu-v4-(1024|2048)'
-mesh_rules[0][1][0]: 1
-mesh_rules[0][1][1]: -1
-mesh_rules[0][1][2]: 1
-mesh_rules[0][1][3]: 16
-mesh_rules[0][1][4]: 1
-mesh_rules[0][1][5]: 1
-mesh_rules[1][0]: 'tpu-v5litepod-256'
-mesh_rules[1][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[1][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[1][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[1][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[1][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[1][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-model.batch_axis_names[0]: 'data'
-model.batch_axis_names[1]: 'expert'
-model.batch_axis_names[2]: 'fsdp'
-model.decoder.attention_mask: None
-model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
-model.decoder.dim: 4096
-model.decoder.dropout_rate: 0.0
-model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
-model.decoder.emb.token_emb.klass: 'axlearn.common.layers.Embedding'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].fan: 'fan_out'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
-model.decoder.emb.token_emb.param_init.init_by_param_name['.*weight$'].scale: 1.0
-model.decoder.emb.token_emb.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
-model.decoder.emb.token_emb.param_partition_spec[0]: None
-model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
-model.decoder.eos_token_id: 1
-model.decoder.klass: 'axlearn.common.decoder.Decoder'
-model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
-model.decoder.lm_head.param_partition_spec[0]: None
-model.decoder.lm_head.param_partition_spec[1]: 'model'
-model.decoder.logits_partition_spec[0][0]: 'data'
-model.decoder.logits_partition_spec[0][1]: 'expert'
-model.decoder.logits_partition_spec[0][2]: 'fsdp'
-model.decoder.logits_partition_spec[1]: 'seq'
-model.decoder.logits_partition_spec[2]: 'model'
-model.decoder.output_dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.output_norm.eps: 1e-05
-model.decoder.output_norm.forward_dtype: None
-model.decoder.output_norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.pad_token_id: 0
-model.decoder.transformer.klass: 'axlearn.common.attention.RepeatedTransformerLayer'
-model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
-model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
-model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
-model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
-model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
-model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
-model.decoder.transformer.layer.feed_forward.linear1.bias: False
-model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][0]: 'data'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][1]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[0][2]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[1]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear1.output_partition_spec[2]: 'model'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear1.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.bias: False
-model.decoder.transformer.layer.feed_forward.linear2.klass: 'axlearn.common.layers.Linear'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][0]: 'data'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][1]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[0][2]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[1]: 'seq'
-model.decoder.transformer.layer.feed_forward.linear2.output_partition_spec[2]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[0]: 'model'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][0]: 'expert'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][1]: 'fsdp'
-model.decoder.transformer.layer.feed_forward.linear2.param_partition_spec[1][2]: 'seq'
-model.decoder.transformer.layer.feed_forward.norm.eps: 1e-05
-model.decoder.transformer.layer.feed_forward.norm.forward_dtype: None
-model.decoder.transformer.layer.feed_forward.norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.transformer.layer.feed_forward.residual_weight: 1.0
-model.decoder.transformer.layer.feed_forward.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
-model.decoder.transformer.layer.feed_forward.stochastic_depth.mode: 'row'
-model.decoder.transformer.layer.feed_forward.structure: 'prenorm'
-model.decoder.transformer.layer.klass: 'axlearn.common.attention.TransformerLayer'
-model.decoder.transformer.layer.remat_spec['prevent_cse']: False
-model.decoder.transformer.layer.remat_spec['policy'].fn: 'jax._src.ad_checkpoint.save_only_these_names'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[0]: 'GroupedQueryAttention.q_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[1]: 'GroupedQueryAttention.k_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[2]: 'GroupedQueryAttention.v_proj'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[3]: 'GroupedQueryAttention.context'
-model.decoder.transformer.layer.remat_spec['policy'].names_which_can_be_saved[4]: 'GroupedQueryAttention.o_proj'
-model.decoder.transformer.layer.self_attention.attention.causal: True
-model.decoder.transformer.layer.self_attention.attention.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.self_attention.attention.input_linear.cache_dtype: 'jax.numpy.bfloat16'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.cache_dtype: 'jax.numpy.bfloat16'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.klass: 'axlearn.common.attention.FusedGroupedQKVLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.bias: False
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.klass: 'axlearn.common.attention.MultiheadInputLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
-model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
-model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
-model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0
-model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
-model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
-model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention'
-model.decoder.transformer.layer.self_attention.attention.num_heads: 32
-model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
-model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][1]: 'fsdp'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][2]: 'seq'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[1]: 'model'
-model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.query_scale.klass: 'axlearn.common.attention.ScaleQuery'
-model.decoder.transformer.layer.self_attention.dropout.klass: 'axlearn.common.layers.Dropout'
-model.decoder.transformer.layer.self_attention.klass: 'axlearn.common.attention.TransformerAttentionLayer'
-model.decoder.transformer.layer.self_attention.norm.eps: 1e-05
-model.decoder.transformer.layer.self_attention.norm.forward_dtype: None
-model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layers.RMSNorm'
-model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
-model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
-model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 32
-model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
-model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
-model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 131072
-model.dtype: 'jax.numpy.float32'
-model.klass: 'axlearn.common.causal_lm.Model'
-model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
-model.param_init.init_by_param_name['.*weight$'].fan: 'fan_in'
-model.param_init.init_by_param_name['.*weight$'].klass: 'axlearn.common.param_init.WeightInitializer'
-model.param_init.init_by_param_name['.*weight$'].scale: 1.0
-model.param_init.klass: 'axlearn.common.param_init.DefaultInitializer'
-model.seq_axis_names[0]: 'seq'
-model.z_loss_scale: 0.0
-name: 'gpt_trainer'
-prune_empty_state_updates: True
-save_input_iterator: False
-start_trace_process_indices[0]: 0
-summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
-summary_writer.max_queue: 1000
-summary_writer.write_every_n_steps: 100
-train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
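Note: the remat_spec deleted above keeps only the named attention activations and recomputes everything else during the backward pass. Roughly, it is the config-level spelling of the following JAX idiom (a sketch; the trainer wires the policy in through remat_spec rather than calling jax.checkpoint directly):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w):
    # Tag the value so the policy below saves it across the backward pass;
    # untagged intermediates are rematerialized instead of stored.
    q = checkpoint_name(x @ w, "GroupedQueryAttention.q_proj")
    return jnp.tanh(q).sum()

policy = jax.checkpoint_policies.save_only_these_names("GroupedQueryAttention.q_proj")
layer_remat = jax.checkpoint(layer, policy=policy, prevent_cse=False)
print(jax.grad(layer_remat)(jnp.ones((4, 4)), jnp.eye(4)))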
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt
deleted file mode 100644
index 311e12ed..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_init.txt
+++ /dev/null
@@ -1,10 +0,0 @@
-decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[131072, 4096], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 48, 128), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
-decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(4096, 32, 128), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
-decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(4096, 14336), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(14336, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
-decoder/output_norm/scale: constant(1.0)
-decoder/lm_head/weight: normal(0, 1.0 / fan_in), shape=(131072, 4096), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
\ No newline at end of file
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_regularizer.txt
deleted file mode 100644
index 65733fb7..00000000
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-8B-v3_regularizer.txt
+++ /dev/null
@@ -1,11 +0,0 @@
-====================weight_decay_scale root.optimizer====================
-decoder/emb/token_emb/weight: 1
-decoder/lm_head/weight: 1
-decoder/output_norm/scale: 1
-decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
-decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
-decoder/transformer/repeat/layer/feed_forward/linear2/weight: 1
-decoder/transformer/repeat/layer/feed_forward/norm/scale: 1
-decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: 1
-decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: 1
-decoder/transformer/repeat/layer/self_attention/norm/scale: 1
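Note: the normal(0, 1.0 / fan_in) entries in these init files describe a normal initializer whose spread shrinks with the fan of the weight. Assuming the usual convention std = scale / sqrt(fan), the numbers check out as follows (a back-of-the-envelope sketch, not the WeightInitializer implementation):

import math

def init_std(scale: float, fan: int) -> float:
    return scale / math.sqrt(fan)

# qkv_proj above: shape=(4096, 48, 128) with in_axis=0, so fan_in = 4096.
print(init_std(1.0, 4096))  # 0.015625
# token_emb uses fan='fan_out': shape=[131072, 4096], out_axis=-1 also gives 4096.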
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken.txt
similarity index 81%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken.txt
index 11d87ce4..d05e6283 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v3-single-host.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken.txt
@@ -3,22 +3,22 @@ batch_axis_names[1]: 'expert'
 batch_axis_names[2]: 'fsdp'
 batch_axis_names[3]: 'seq'
 checkpointer.gc_loop_interval_seconds: 60
-checkpointer.keep_every_n_steps: 50000
+checkpointer.keep_every_n_steps: 3000
 checkpointer.keep_last_n: 3
 checkpointer.klass: 'axlearn.common.checkpointer.Checkpointer'
 checkpointer.save_policy.fn: 'axlearn.common.checkpointer.every_n_steps_and_last_policy'
-checkpointer.save_policy.max_step: 3932160
+checkpointer.save_policy.max_step: 3000
 checkpointer.save_policy.min_step: 1
-checkpointer.save_policy.n: 5000
+checkpointer.save_policy.n: 500
 checkpointer.storage.klass: 'axlearn.common.checkpointer.TensorStoreStateStorage'
 checkpointer.storage.timeout_secs: 3600
 evalers['train'].eval_dtype: 'jax.numpy.bfloat16'
 evalers['train'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
-evalers['train'].eval_policy.max_step: 3932160
+evalers['train'].eval_policy.max_step: 3000
 evalers['train'].eval_policy.min_step: 1
-evalers['train'].eval_policy.n: 5000
+evalers['train'].eval_policy.n: 1500
 evalers['train'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-evalers['train'].input.batcher.global_batch_size: 8
+evalers['train'].input.batcher.global_batch_size: 32
 evalers['train'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
 evalers['train'].input.batcher.prefetch_buffer_size: -1
 evalers['train'].input.is_training: False
@@ -27,12 +27,12 @@ evalers['train'].input.processor.fn: 'axlearn.common.input_tf_data.identity'
 evalers['train'].input.source.dataset_name: 'c4/en:3.0.1'
 evalers['train'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
 evalers['train'].input.source.is_training: False
-evalers['train'].input.source.max_sequence_length: 8192
+evalers['train'].input.source.max_sequence_length: 64
 evalers['train'].input.source.replace_newlines_with: '\n'
 evalers['train'].input.source.split: 'train[:8192]'
 evalers['train'].input.source.train_shuffle_buffer_size: 16384
-evalers['train'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-evalers['train'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json'
+evalers['train'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json'
+evalers['train'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary'
 evalers['train'].klass: 'axlearn.common.evaler.SpmdEvaler'
 evalers['train'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
 evalers['train'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -41,11 +41,11 @@ evalers['train'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWri
 evalers['train'].summary_writer.write_every_n_steps: 1
 evalers['validation'].eval_dtype: 'jax.numpy.bfloat16'
 evalers['validation'].eval_policy.fn: 'axlearn.common.evaler.every_n_steps_policy'
-evalers['validation'].eval_policy.max_step: 3932160
+evalers['validation'].eval_policy.max_step: 3000
 evalers['validation'].eval_policy.min_step: 1
-evalers['validation'].eval_policy.n: 5000
+evalers['validation'].eval_policy.n: 1500
 evalers['validation'].input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-evalers['validation'].input.batcher.global_batch_size: 8
+evalers['validation'].input.batcher.global_batch_size: 32
 evalers['validation'].input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
 evalers['validation'].input.batcher.prefetch_buffer_size: -1
 evalers['validation'].input.is_training: False
@@ -54,12 +54,12 @@ evalers['validation'].input.processor.fn: 'axlearn.common.input_tf_data.identity
 evalers['validation'].input.source.dataset_name: 'c4/en:3.0.1'
 evalers['validation'].input.source.fn: 'axlearn.experiments.text.gpt.common.tfds_input'
 evalers['validation'].input.source.is_training: False
-evalers['validation'].input.source.max_sequence_length: 8192
+evalers['validation'].input.source.max_sequence_length: 64
 evalers['validation'].input.source.replace_newlines_with: '\n'
 evalers['validation'].input.source.split: 'validation'
 evalers['validation'].input.source.train_shuffle_buffer_size: 16384
-evalers['validation'].input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-evalers['validation'].input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json'
+evalers['validation'].input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json'
+evalers['validation'].input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary'
 evalers['validation'].klass: 'axlearn.common.evaler.SpmdEvaler'
 evalers['validation'].metric_calculator.klass: 'axlearn.common.evaler.ModelSummaryAccumulator'
 evalers['validation'].metric_calculator.metric_accumulator.klass: 'axlearn.common.metrics.MetricAccumulator'
@@ -67,7 +67,7 @@ evalers['validation'].metric_calculator.model_method: 'forward'
 evalers['validation'].summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
 evalers['validation'].summary_writer.write_every_n_steps: 1
 input.batcher.fn: 'axlearn.common.input_tf_data.batch'
-input.batcher.global_batch_size: 8
+input.batcher.global_batch_size: 32
 input.batcher.pad_example_fn: 'axlearn.common.input_tf_data.default_pad_example_fn'
 input.batcher.prefetch_buffer_size: -1
 input.is_training: True
@@ -79,14 +79,14 @@ input.source.data_mixture_components[0]['shuffle_buffer_size']: 8192
 input.source.data_mixture_components[0]['split']: 'train'
 input.source.data_mixture_components[0]['info']: ''
 input.source.fn: 'axlearn.experiments.text.gpt.common.mixture_train_input_source'
-input.source.max_sequence_length: 8192
+input.source.max_sequence_length: 64
 input.source.preprocessor.fn: 'axlearn.common.input_lm.lm_text_preprocessor'
 input.source.preprocessor.max_padding_fraction: 0.5
 input.source.preprocessor.shuffle_buffer_size: 8192
 input.source.preprocessor.window_size: 128
 input.source.replace_newlines_with: ''
-input.source.vocab_cfg.fn: 'axlearn.experiments.text.common.vocab'
-input.source.vocab_cfg.sentencepiece_model_name: 'Llama-3-tokenizer.json'
+input.source.vocab_cfg.filename: 'Llama-3-tokenizer.json'
+input.source.vocab_cfg.klass: 'axlearn.experiments.text.gpt.vocabulary_fuji_v3.FujiV3Vocabulary'
 klass: 'axlearn.common.trainer.SpmdTrainer'
 learner.ema.fn: 'axlearn.common.optimizers.param_ema'
 learner.enable_per_variable_summaries: False
@@ -98,54 +98,29 @@ learner.optimizer.args[1].b1: 0.9
 learner.optimizer.args[1].b2: 0.95
 learner.optimizer.args[1].eps: 1e-08
 learner.optimizer.args[1].fn: 'axlearn.common.optimizers.adamw_decoupled_optimizer'
-learner.optimizer.args[1].learning_rate: 0.00015
-learner.optimizer.args[1].update_schedule.alpha: 0.1
-learner.optimizer.args[1].update_schedule.begin_value: 0.0
-learner.optimizer.args[1].update_schedule.fn: 'axlearn.common.schedule.cosine_with_linear_warmup'
-learner.optimizer.args[1].update_schedule.max_step: 3932160
-learner.optimizer.args[1].update_schedule.peak_lr: 1.0
-learner.optimizer.args[1].update_schedule.warmup_steps: 2000
-learner.optimizer.args[1].weight_decay: 0.1
+learner.optimizer.args[1].learning_rate: 0.3
+learner.optimizer.args[1].update_schedule: 1
+learner.optimizer.args[1].weight_decay: 0.01
 learner.optimizer.fn: 'axlearn.common.optimizers.chain'
-max_step: 3932160
+max_step: 5
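Note: the golden-run test pins update_schedule to a constant 1 so a five-step run is cheap and deterministic. The schedule it replaces (decoupled AdamW under cosine decay with linear warmup) is roughly the following optax chain. This is a sketch of an equivalent optimizer, not the axlearn.common.optimizers implementation; in particular, adamw_decoupled_optimizer decouples weight decay from the adaptive update, which optax.adamw only approximates:

import optax

max_step, warmup_steps, peak_lr, alpha = 3932160, 2000, 1.0, 0.1
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,             # begin_value
    peak_value=peak_lr,         # peak_lr
    warmup_steps=warmup_steps,
    decay_steps=max_step,       # max_step
    end_value=alpha * peak_lr,  # alpha: decay floor at 10% of peak
)
optimizer = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(
        # learning_rate (0.00015 in the removed 70B config) scales the schedule.
        learning_rate=lambda step: 0.00015 * schedule(step),
        b1=0.9, b2=0.95, eps=1e-8, weight_decay=0.1,
    ),
)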
 mesh_axis_names[0]: 'pipeline'
 mesh_axis_names[1]: 'data'
 mesh_axis_names[2]: 'expert'
 mesh_axis_names[3]: 'fsdp'
 mesh_axis_names[4]: 'seq'
 mesh_axis_names[5]: 'model'
-mesh_rules[0][0]: 'tpu-v5litepod-256-4'
-mesh_rules[0][1].config_modifiers[0].klass: 'axlearn.common.trainer_config_modifier.MeshShapeModifier'
-mesh_rules[0][1].config_modifiers[0].mesh_shape[0]: 1
-mesh_rules[0][1].config_modifiers[0].mesh_shape[1]: -1
-mesh_rules[0][1].config_modifiers[0].mesh_shape[2]: 1
-mesh_rules[0][1].config_modifiers[0].mesh_shape[3]: 256
-mesh_rules[0][1].config_modifiers[0].mesh_shape[4]: 1
-mesh_rules[0][1].config_modifiers[0].mesh_shape[5]: 1
-mesh_rules[0][1].config_modifiers[1].klass: 'axlearn.common.trainer_config_modifier.RematSpecModifier'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['prevent_cse']: True
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].fn: 'axlearn.common.utils.offload_dots_saveable'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_dst: 'pinned_host'
-mesh_rules[0][1].config_modifiers[1].remat_policies['model.decoder.transformer.layer']['policy'].offload_src: 'device'
-mesh_rules[0][1].klass: 'axlearn.common.trainer_config_modifier.ChainConfigModifier'
-mesh_rules[1][0]: 'gpu-(p5.48xlarge|p4de.24xlarge)-(512|1024)'
-mesh_rules[1][1][0]: 1
-mesh_rules[1][1][1]: -1
-mesh_rules[1][1][2]: 1
-mesh_rules[1][1][3]: 128
-mesh_rules[1][1][4]: 1
-mesh_rules[1][1][5]: 1
 mesh_shape[0]: 1
-mesh_shape[1]: 1
+mesh_shape[1]: -1
 mesh_shape[2]: 1
-mesh_shape[3]: -1
+mesh_shape[3]: 1
 mesh_shape[4]: 1
 mesh_shape[5]: 1
 model.batch_axis_names[0]: 'data'
 model.batch_axis_names[1]: 'expert'
 model.batch_axis_names[2]: 'fsdp'
 model.decoder.attention_mask: None
-model.decoder.dim: 8192
+model.decoder.decoding.klass: 'axlearn.common.decoder.DecodingLayer'
+model.decoder.dim: 8
 model.decoder.dropout_rate: 0.0
 model.decoder.emb.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.emb.klass: 'axlearn.common.embedding.TransformerTextEmbeddings'
@@ -159,9 +134,6 @@ model.decoder.emb.token_emb.param_partition_spec[0]: None
 model.decoder.emb.token_emb.param_partition_spec[1]: 'model'
 model.decoder.eos_token_id: 128001
 model.decoder.klass: 'axlearn.common.decoder.Decoder'
-model.decoder.lm_head.klass: 'axlearn.common.decoder.LmHead'
-model.decoder.lm_head.param_partition_spec[0]: None
-model.decoder.lm_head.param_partition_spec[1]: 'model'
 model.decoder.logits_partition_spec[0][0]: 'data'
 model.decoder.logits_partition_spec[0][1]: 'expert'
 model.decoder.logits_partition_spec[0][2]: 'fsdp'
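Note: dropping the lm_head block here means the test decoder has no separate output projection, so logits are computed against the transposed token embedding (weight tying). A minimal sketch of that path, assuming this is what Decoder does when lm_head is unset:

import jax.numpy as jnp

def tied_logits(hidden, token_emb):
    # hidden: [batch, seq, dim]; token_emb: [vocab, dim]. Reuse the embedding
    # table as the output projection instead of a dedicated LmHead weight.
    return jnp.einsum("btd,vd->btv", hidden, token_emb)

h = jnp.ones((2, 3, 8))   # dim=8, as in this test config
emb = jnp.ones((32, 8))   # vocab_size=32
assert tied_logits(h, emb).shape == (2, 3, 32)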
@@ -177,8 +149,8 @@ model.decoder.transformer.layer.feed_forward.activation[0]: 'nn.silu'
 model.decoder.transformer.layer.feed_forward.activation[1]: 'linear'
 model.decoder.transformer.layer.feed_forward.dropout.klass: 'axlearn.common.layers.Dropout'
 model.decoder.transformer.layer.feed_forward.hidden_dim.fn: 'axlearn.experiments.text.gpt.common.scale_fn'
-model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 256
-model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 3.5
+model.decoder.transformer.layer.feed_forward.hidden_dim.round_up_to_multiples_of: 16
+model.decoder.transformer.layer.feed_forward.hidden_dim.scale: 2.6666666666666665
 model.decoder.transformer.layer.feed_forward.klass: 'axlearn.common.attention.TransformerFeedForwardLayer'
 model.decoder.transformer.layer.feed_forward.linear1.bias: False
 model.decoder.transformer.layer.feed_forward.linear1.klass: 'axlearn.common.layers.Linear'
@@ -229,14 +201,14 @@ model.decoder.transformer.layer.self_attention.attention.input_linear.input_line
 model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[0][2]: 'seq'
 model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[1]: 'model'
 model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.layer.param_partition_spec[2]: None
-model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 8
+model.decoder.transformer.layer.self_attention.attention.input_linear.input_linear.num_kv_heads: 2
 model.decoder.transformer.layer.self_attention.attention.input_linear.klass: 'axlearn.common.attention.RoFormerQKVLinear'
 model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.klass: 'axlearn.common.attention.RoFormerSinusoidalPositionalEmbedding'
 model.decoder.transformer.layer.self_attention.attention.input_linear.rope_pos_emb_layer.theta: 500000.0
 model.decoder.transformer.layer.self_attention.attention.input_linear.rotary_value: False
 model.decoder.transformer.layer.self_attention.attention.key_scale.klass: 'axlearn.common.attention.ScaleKey'
 model.decoder.transformer.layer.self_attention.attention.klass: 'axlearn.common.attention.GroupedQueryAttention'
-model.decoder.transformer.layer.self_attention.attention.num_heads: 64
+model.decoder.transformer.layer.self_attention.attention.num_heads: 4
 model.decoder.transformer.layer.self_attention.attention.output_linear.bias: False
 model.decoder.transformer.layer.self_attention.attention.output_linear.klass: 'axlearn.common.attention.MultiheadOutputLinear'
 model.decoder.transformer.layer.self_attention.attention.output_linear.param_partition_spec[0][0]: 'expert'
@@ -253,11 +225,11 @@ model.decoder.transformer.layer.self_attention.norm.klass: 'axlearn.common.layer
 model.decoder.transformer.layer.self_attention.stochastic_depth.klass: 'axlearn.common.layers.StochasticDepth'
 model.decoder.transformer.layer.self_attention.stochastic_depth.mode: 'row'
 model.decoder.transformer.layer.self_attention.structure: 'prenorm'
-model.decoder.transformer.num_layers: 80
+model.decoder.transformer.num_layers: 4
 model.decoder.transformer.repeat.drop_output.fn: 'axlearn.common.repeat._drop_by_regex'
 model.decoder.transformer.repeat.drop_output.rules[0]: 'module_outputs.*'
 model.decoder.transformer.repeat.klass: 'axlearn.common.attention._TransformerRepeat'
-model.decoder.vocab_size: 128256
+model.decoder.vocab_size: 32
 model.dtype: 'jax.numpy.float32'
 model.klass: 'axlearn.common.causal_lm.Model'
 model.param_init.init_by_param_name['.*weight$'].distribution: 'normal'
@@ -274,4 +246,5 @@ start_trace_process_indices[0]: 0
 summary_writer.klass: 'axlearn.common.summary_writer.SummaryWriter'
 summary_writer.max_queue: 1000
 summary_writer.write_every_n_steps: 100
-train_dtype: 'jax.numpy.bfloat16'
\ No newline at end of file
+train_dtype: 'jax.numpy.bfloat16'
+vlog: 1
\ No newline at end of file
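Note: the hidden_dim override reproduces Llama-3-style FFN sizing at toy scale: scale 2.6666666666666665 is 8/3, and rounding up to multiples of 16 (instead of 256) keeps the tiny model divisible. Plain arithmetic mirroring what scale_fn presumably computes, checked against the init file below:

def ffn_hidden_dim(model_dim: int, scale: float, multiple: int) -> int:
    raw = int(model_dim * scale)
    return ((raw + multiple - 1) // multiple) * multiple  # round up

# Test config: dim=8, scale=8/3 -> 21, rounded up to 32 (linear1 shape (8, 32)).
assert ffn_hidden_dim(8, 8 / 3, 16) == 32
# The deleted 8B config: dim=4096, scale=3.5 -> exactly 14336.
assert ffn_hidden_dim(4096, 3.5, 256) == 14336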
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_init.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_init.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken_regularizer.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-golden-run-test-v3-tiktoken_regularizer.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash.txt
similarity index 100%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-flash-tiktoken.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash.txt
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt
new file mode 100644
index 00000000..61615aa5
--- /dev/null
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_init.txt
@@ -0,0 +1,9 @@
+decoder/emb/token_emb/weight: normal(0, 1.0 / fan_out), shape=[32, 8], axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/self_attention/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/self_attention/attention/i_proj/i_proj/qkv_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 8, 2), axes=FanAxes(in_axis=0, out_axis=(1, 2), batch_axis=())
+decoder/transformer/repeat/layer/self_attention/attention/o_proj/weight: normal(0, 1.0 / fan_in), shape=(8, 4, 2), axes=FanAxes(in_axis=(1, 2), out_axis=0, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/norm/scale: constant(1.0)
+decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: normal(0, 1.0 / fan_in), shape=(8, 32), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/transformer/repeat/layer/feed_forward/linear2/weight: normal(0, 1.0 / fan_in), shape=(32, 8), axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=())
+decoder/output_norm/scale: constant(1.0)
\ No newline at end of file
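Note: the attention shapes in the new init file follow from grouped-query attention with a fused QKV projection: per_head_dim = dim / num_heads, and the fused weight stacks the query heads plus one key and one value head per KV group. A quick consistency check (assuming this layout for FusedGroupedQKVLinear):

dim, num_heads, num_kv_heads = 8, 4, 2
per_head_dim = dim // num_heads              # 2
fused_heads = num_heads + 2 * num_kv_heads   # q + k + v = 8
assert (dim, fused_heads, per_head_dim) == (8, 8, 2)  # qkv_proj weight above
assert (dim, num_heads, per_head_dim) == (8, 4, 2)    # o_proj weight above
# Same bookkeeping for the deleted 8B config: 32 + 2 * 8 = 48 fused heads and
# 4096 // 32 = 128 per-head dims give the qkv_proj shape (4096, 48, 128).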
diff --git a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt
similarity index 95%
rename from axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt
rename to axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt
index 65733fb7..03fb7437 100644
--- a/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-70B-v2-single-host_regularizer.txt
+++ b/axlearn/experiments/testdata/axlearn.experiments.text.gpt.c4_trainer/fuji-test-v3-tiktoken-flash_regularizer.txt
@@ -1,6 +1,5 @@
 ====================weight_decay_scale root.optimizer====================
 decoder/emb/token_emb/weight: 1
-decoder/lm_head/weight: 1
 decoder/output_norm/scale: 1
 decoder/transformer/repeat/layer/feed_forward/linear1_0/weight: 1
 decoder/transformer/repeat/layer/feed_forward/linear1_1/weight: 1
diff --git a/axlearn/experiments/text/gpt/c4_trainer.py b/axlearn/experiments/text/gpt/c4_trainer.py
index a9d403b7..f48af19f 100644
--- a/axlearn/experiments/text/gpt/c4_trainer.py
+++ b/axlearn/experiments/text/gpt/c4_trainer.py
@@ -51,17 +51,17 @@ from axlearn.experiments.trainer_config_utils import TrainerConfigFn
-def _vocab_cfg(size: int):
-    if size == 32 * 1024:
+def _vocab_cfg(vocab_size: int):
+    if vocab_size == 32 * 1024:
         # Sentencepiece vocabs generated from c4/en:3.0.1.
         # See bpe_{32k,128k}.json for the sentencepiece settings.
         return config_for_function(vocab).set(sentencepiece_model_name="bpe_32k_c4.model")
-    if size == 128 * 1024:
+    if vocab_size == 128 * 1024:
         return config_for_function(vocab).set(sentencepiece_model_name="bpe_128k_c4.model")
-    if size == 128256:
+    if vocab_size == 128256:
         # TikToken.
         return config_for_class(FujiV3Vocabulary).set(filename="Llama-3-tokenizer.json")
-    raise ValueError(f"size {size} tokenizer does not exist.")
+    raise ValueError(f"Tokenizer with vocab size {vocab_size} does not exist.")
 
 
 _train_data_mixture_components = [
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index e91fd1c9..09b159d8 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -305,9 +305,9 @@ def model_config(
         lm_head=lm_head_cfg,
         dropout_rate=dropout_rate,
     )
-    if pad_token_id:
+    if pad_token_id is not None:
         decoder_cfg.set(pad_token_id=pad_token_id)
-    if eos_token_id:
+    if eos_token_id is not None:
         decoder_cfg.set(eos_token_id=eos_token_id)
     # Model.
     model_param_init = DefaultInitializer.default_config().set(
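Note: the common.py change fixes the classic falsy-zero pitfall: token ID 0 is a legitimate pad ID (the SentencePiece configs above use pad_token_id 0), and `if pad_token_id:` silently skips it. A short illustration:

pad_token_id = 0  # a valid token ID
if pad_token_id:              # buggy: 0 is falsy, so the branch never runs
    print("never reached")
if pad_token_id is not None:  # fixed: only skips when genuinely unset
    print("pad token configured")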
diff --git a/axlearn/experiments/text/gpt/fuji.py b/axlearn/experiments/text/gpt/fuji.py
index 9e145e53..786bc618 100644
--- a/axlearn/experiments/text/gpt/fuji.py
+++ b/axlearn/experiments/text/gpt/fuji.py
@@ -61,6 +61,7 @@ class Version(enum.Enum):
     V1 = 1
     V2 = 2
     V3 = 3
+    V3_TIKTOKEN = "3-tiktoken"
 
 
 # Mapping from Fuji versions to vocab sizes.
@@ -68,6 +69,7 @@ class Version(enum.Enum):
     Version.V1: 32 * 1024,
     Version.V2: 32 * 1024,
     Version.V3: 128 * 1024,
+    Version.V3_TIKTOKEN: 128256,
 }
@@ -76,6 +78,7 @@ class Version(enum.Enum):
     Version.V1: 2048,
     Version.V2: 4096,
     Version.V3: 8192,
+    Version.V3_TIKTOKEN: 8192,
 }
@@ -83,6 +86,7 @@ class Version(enum.Enum):
     Version.V1: 1e4,
     Version.V2: 1e4,
     Version.V3: 5e5,
+    Version.V3_TIKTOKEN: 5e5,
 }
@@ -99,6 +103,13 @@ class Version(enum.Enum):
         "70B": 2 * (1024**4),  # 2T tokens
     },
     Version.V3: {
+        "test": 15 * (1024**4),  # 15T tokens
+        "1B": 15 * (1024**4),  # 15T tokens
+        "3B": 15 * (1024**4),  # 15T tokens
+        "7B": 15 * (1024**4),  # 15T tokens
+        "70B": 15 * (1024**4),  # 15T tokens
+    },
+    Version.V3_TIKTOKEN: {
         "test": 15 * (1024**4),  # 15T tokens
         "1B": 15 * (1024**4),  # 15T tokens
         "3B": 15 * (1024**4),  # 15T tokens
         "7B": 15 * (1024**4),  # 15T tokens
         "70B": 15 * (1024**4),  # 15T tokens
     },
@@ -123,7 +134,7 @@ def get_trainer_kwargs(
     # Whether to use grouped query attention.
     num_kv_heads = None
-    if version == Version.V3:
+    if version in (Version.V3, Version.V3_TIKTOKEN):
         num_kv_heads = 8
     rope_theta = ROPE_THETA[version]
@@ -421,7 +432,7 @@ def get_trainer_kwargs(
         raise NotImplementedError(f"Unknown model size {model_size}.")
     model_kwargs = trainer_kwargs.pop("model_kwargs")
     model_kwargs.setdefault("vocab_size", vocab_size)
-    if version == Version.V3 and vocab_size == 128256:  # tiktoken tokenizer
+    if version == Version.V3_TIKTOKEN:  # tiktoken tokenizer
         model_kwargs["pad_token_id"] = 128004
         model_kwargs["eos_token_id"] = 128001
     trainer_kwargs["model_cfg"] = model_config(**model_kwargs)
@@ -466,6 +477,8 @@ def model_config(
         flash_attention: Whether to enable flash attention.
         stack_cfg: The transformer stack config.
             If None, defaults to a RepeatedTransformerLayer.
+        pad_token_id: ID of the padding token, whose positions are masked in self-attention.
+        eos_token_id: ID of the end-of-sequence token.
 
     Returns:
         A causal LM config.
@@ -520,20 +533,17 @@ def trainer_configs(
     """
     arch = "fuji"
     config_map = {}
-    for version, model_size, flash_attention, tiktoken in itertools.product(
-        Version, MODEL_SIZES, [True, False], [True, False]
+    for version, model_size, flash_attention in itertools.product(
+        Version, MODEL_SIZES, [True, False]
     ):
         if model_size not in TOTAL_TOKENS[version]:  # This combination does not exist.
             continue
-        if version != Version.V3 and tiktoken:  # Only V3 has TikToken option.
-            continue
-        suffix = "-flash" if flash_attention else ""
         vocab_size = VOCAB_SIZE[version]
-        if tiktoken:
-            suffix += "-tiktoken"
-            vocab_size = 128256
         config_name = make_config_name(
-            arch=arch, model_size=model_size, version=f"v{version.value}", suffix=suffix
+            arch=arch,
+            model_size=model_size,
+            version=f"v{version.value}",
+            suffix="-flash" if flash_attention else "",
         )
         kwargs = get_trainer_kwargs(
             model_size, vocab_size=vocab_size, version=version, flash_attention=flash_attention
         )
@@ -546,10 +556,7 @@
                 max_sequence_length=max_sequence_length,
             ),
             evalers=evaler_config_dict(
-                eval_input_sources(
-                    vocab_size=vocab_size,
-                    max_sequence_length=max_sequence_length,
-                ),
+                eval_input_sources(vocab_size=vocab_size, max_sequence_length=max_sequence_length),
             ),
             **kwargs,
         )
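Note: giving V3_TIKTOKEN the string value "3-tiktoken" is what moves the tokenizer tag from the suffix into the version segment of the config name, which is why the testdata files are renamed from ...-v3-flash-tiktoken... to ...-v3-tiktoken-flash.... A sketch of the naming effect, with a hypothetical stand-in for make_config_name that simply joins the parts:

import enum

class Version(enum.Enum):
    V3 = 3
    V3_TIKTOKEN = "3-tiktoken"

def config_name(arch, model_size, version, suffix):
    return "-".join([arch, model_size, f"v{version.value}"]) + suffix

# Previously the tag rode along in the suffix: fuji-1B-v3-flash-tiktoken.
assert config_name("fuji", "1B", Version.V3_TIKTOKEN, "-flash") == "fuji-1B-v3-tiktoken-flash"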
diff --git a/axlearn/experiments/text/gpt/param_converter_test.py b/axlearn/experiments/text/gpt/param_converter_test.py
index 9d67d0fc..3d8bd2c2 100644
--- a/axlearn/experiments/text/gpt/param_converter_test.py
+++ b/axlearn/experiments/text/gpt/param_converter_test.py
@@ -27,6 +27,7 @@
 # Use cpu for the test.
 jax.config.update("jax_platform_name", "cpu")
 
+# Parameters are based on https://huggingface.co/meta-llama/Llama-3.2-1B/blob/main/config.json
 config_dict_1b = {
     "vocab_size": 128256,
     "hidden_size": 2048,
@@ -57,7 +58,9 @@
     "torch_dtype": "bfloat16",
     "architectures": ["LlamaForCausalLM"],
 }
+# Parameters are based on https://huggingface.co/meta-llama/Llama-3.2-3B/blob/main/config.json
 config_dict_3b = {"hidden_size": 3072, "num_attention_heads": 24, "num_hidden_layers": 28}
+# Parameters are based on https://huggingface.co/meta-llama/Llama-3.1-8B/blob/main/config.json
 config_dict_8b = {
     "hidden_size": 4096,
     "intermediate_size": 14336,
@@ -71,6 +74,7 @@
     },
     "tie_word_embeddings": False,
 }
+# Parameters are based on https://huggingface.co/meta-llama/Llama-3.1-70B/blob/main/config.json
 config_dict_70b = {
     "hidden_size": 8192,
     "intermediate_size": 28672,
diff --git a/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py b/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py
index f60e6f70..7d097fa1 100644
--- a/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py
+++ b/axlearn/experiments/text/gpt/vocabulary_fuji_v3.py
@@ -102,11 +102,12 @@ def __init__(self, filename: str):
         filename = os.path.join(data_dir, "tokenizers", "hf", filename)
         if filename.startswith("gs:") or filename.startswith("s3:"):
             # Create a different file for each usage.
-            tmp = tempfile.mkdtemp()
-            path = os.path.join(tmp, "tokenizer.json")
-            fs.copy(filename, path)
-            filename = path
-        self._tokenizer = Tokenizer.from_file(filename)
+            with tempfile.TemporaryDirectory() as tmpdir:
+                path = os.path.join(tmpdir, "tokenizer.json")
+                fs.copy(filename, path)
+                self._tokenizer = Tokenizer.from_file(path)
+        else:
+            self._tokenizer = Tokenizer.from_file(filename)
         self.vocab = self._tokenizer.get_vocab()
         self.tokenizer = FujiInnerTokenizer(self._tokenizer)
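Note: one subtlety of the vocabulary_fuji_v3.py fix is lifetime: tempfile.TemporaryDirectory removes the directory when the with-block exits, so Tokenizer.from_file must run inside the block (it reads the file eagerly, after which the temp copy is no longer needed). The old mkdtemp version also worked but leaked one directory per instantiation. A minimal stdlib illustration of the rule:

import os
import tempfile

with tempfile.TemporaryDirectory() as tmpdir:
    path = os.path.join(tmpdir, "tokenizer.json")
    with open(path, "w") as f:
        f.write("{}")
    data = open(path).read()  # OK: the directory still exists here

assert not os.path.exists(path)  # cleaned up after the block
print(data)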