test llama example with 8-t4
zzhhjjj committed May 2, 2024
1 parent: da7cf7a · commit: a3efab8
Showing 3 changed files with 33 additions and 6 deletions.
3 changes: 0 additions & 3 deletions .github/workflows/fa2_unit_tests.yaml
@@ -57,6 +57,3 @@ jobs:
# NOTE: -m fa2 will only run the unit tests that have the mark
# "fa2" (these are FA2-related tests)
run: pytest -m fa2 --color=yes --durations=0 --ignore tests/fp8 --verbose tests/
-
- - name: Run tiny Llama example
-   run: ./examples/train_tiny_llama.sh
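
For reference, the -m fa2 flag explained in the comment above selects tests via a standard pytest marker. A minimal sketch of how a test opts in (illustrative only; the real FA2 tests live under tests/, and the marker would normally be registered in the project's pytest configuration):

import pytest

@pytest.mark.fa2  # selected by "pytest -m fa2", excluded by "pytest -m 'not fa2'"
def test_fa2_kernel_matches_reference():
    ...  # hypothetical test body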
7 changes: 5 additions & 2 deletions .github/workflows/llama_tests.yaml
@@ -51,5 +51,8 @@ jobs:
- name: Show installed libraries and their versions
run: pip freeze | tee installed.txt

- - name: Run Llama loss tests
-   run: pytest -sv tests/test_train_llama.py
+ - name: Run Llama example
+   run: pytest --verbose tests/test_llama.py::test_tiny_llama
+
+ - name: Run Llama loss test
+   run: pytest --verbose tests/test_llama.py::test_train_llama
29 changes: 28 additions & 1 deletion tests/test_train_llama.py → tests/test_llama.py
@@ -12,6 +12,9 @@
TRAIN_SCRIPT = "run_train.py"
NUM_GPUS = 8

+ TINY_LLAMA_CONFIG_FILE = "examples/config_tiny_llama.yaml"
+ TINY_LLAMA_CREATE_CONFIG_FILE = "examples/config_tiny_llama.py"
+
## Experiment results:
## 100 steps: 3.28
## 160 steps: 2.83
@@ -42,7 +45,7 @@ def extract_loss(line):
    raise ValueError(f"Could not extract loss value from line: {line}")
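
Most of extract_loss is collapsed in this view. A plausible shape for it, assuming the trainer logs lines containing "loss: <float>" (the exact log format is an assumption):

import re

def extract_loss(line: str) -> float:
    # hypothetical pattern; the actual training-log format may differ
    match = re.search(r"loss:\s*([0-9]*\.?[0-9]+)", line)
    if match:
        return float(match.group(1))
    raise ValueError(f"Could not extract loss value from line: {line}")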


- def test_tiny_llama():
+ def test_train_llama():
    # create CONFIG_FILE
    cmd = f"python {CREATE_CONFIG_FILE}"
    subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
@@ -76,6 +79,30 @@ def test_tiny_llama():
assert process.returncode == 0


+ # Also run the tiny llama example; we only assert that it runs successfully.
+ def test_tiny_llama():
+     # create the config file; wait for it to be written before training starts
+     cmd = f"python {TINY_LLAMA_CREATE_CONFIG_FILE}"
+     subprocess.run(cmd, shell=True, check=True)
+
+     # run training
+     # set DISABLE_FLASH_ATTENTION=1 to replace the flash-attention implementations with non-flash equivalents
+     cmd = f'DISABLE_FLASH_ATTENTION=1 FI_PROVIDER="efa" CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={NUM_GPUS} --rdzv_endpoint=localhost:29800 {TRAIN_SCRIPT} --config-file {TINY_LLAMA_CONFIG_FILE}'
+     os.setpgrp()  # create a new process group and become its leader
+     atexit.register(exit_with_children)  # kill all child processes when this process exits
+
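+     # stream the child's combined stdout/stderr below so training progress shows up live in the CI log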
+     process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
+     while True:
+         line = process.stdout.readline()
+         if process.poll() is not None and line == b"":
+             break
+         if line:
+             print(line.decode("utf-8"), end="")
+
+     process.wait()  # wait for the process to finish
+     assert process.returncode == 0
+
+
if __name__ == "__main__":
    cmd = f"python {CREATE_CONFIG_FILE}"
    subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
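
Both tests duplicate the launch-and-stream loop shown above. A possible shared helper, distilled from that loop (a sketch only, not part of this commit; the name run_and_stream is illustrative):

def run_and_stream(cmd: str) -> int:
    # launch the command, echo its combined stdout/stderr live, and return the exit code
    process = subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
    while True:
        line = process.stdout.readline()
        if process.poll() is not None and line == b"":
            break
        if line:
            print(line.decode("utf-8"), end="")
    return process.wait()

Each test body would then reduce to: assert run_and_stream(cmd) == 0.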
