Skip to content

Commit

Permalink
add gemma (#481)
Browse files (browse the repository at this point in the history)
  • Loading branch information
jyj0w0 authored Nov 18, 2024
1 parent d30999d commit 12a6b80
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions dags/inference/configs/trt_llm_inference_config.py
Original file line number | Diff line number | Diff line change
Expand Up
@@ -57,6 +57,7 @@ def get_trt_llm_gpu_config(
"gsutil -m cp -r gs://tohaowu/llama_3_8B_Instruct_HF_model .",
"gsutil -m cp -r gs://tohaowu/llama_3.1_70B_Instruct_HF_model .",
"gsutil -m cp -r gs://tohaowu/Mixtral-8x22B-Instruct-v0.1 .",
"gsutil -m cp -r gs://yijiaj/gemma/gemma-2-27b-it .",
"sudo apt-get update",
"sudo apt-get -y install git git-lfs",
"git clone https://github.com/NVIDIA/TensorRT-LLM.git",
Expand Down / Expand Up
@@ -103,10 +104,14 @@ def get_trt_llm_gpu_config(
"trtllm-build --checkpoint_dir /scratch/tllm_checkpoint_8gpu_tp8 --output_dir /scratch/llama/70B/trt_engines/fp16/8-gpu/ --gemm_plugin auto",
"python ../llama/convert_checkpoint.py --model_dir /scratch/Mixtral-8x22B-Instruct-v0.1 --output_dir /scratch/tllm_checkpoint_mixtral_8gpu --dtype float16 --tp_size 8 --moe_tp_size 2 --moe_ep_size 4",
"trtllm-build --checkpoint_dir /scratch/tllm_checkpoint_mixtral_8gpu --output_dir /scratch/trt_engines/mixtral/tp2ep4",
"cd ../gemma",
"python3 convert_checkpoint.py --ckpt-type hf --model-dir /scratch/gemma-2-27b-it/ --dtype bfloat16 --world-size 1 --output-model-dir /scratch/checkpoints/tmp_27b_it_tensorrt_llm/bf16/tp1/",
"trtllm-build --checkpoint_dir /scratch/checkpoints/tmp_27b_it_tensorrt_llm/bf16/tp1/ --gemm_plugin auto --max_batch_size 8 --max_input_len 3000 --max_seq_len 3100 --output_dir /scratch/gemma2/27b/bf16/1-gpu/",
"cd ../../benchmarks/python",
"python benchmark.py -m dec --engine_dir /scratch/llama/8B/trt_engines/fp16/1-gpu/ --csv",
"OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 mpirun -n 8 python benchmark.py -m dec --engine_dir /scratch/llama/70B/trt_engines/fp16/8-gpu/ --csv",
"OMPI_ALLOW_RUN_AS_ROOT=1 OMPI_ALLOW_RUN_AS_ROOT_CONFIRM=1 mpirun -n 8 python benchmark.py -m dec --engine_dir /scratch/trt_engines/mixtral/tp2ep4 --csv",
"python benchmark.py -m dec --engine_dir /scratch/gemma2/27b/bf16/1-gpu/ --dtype bfloat16 --csv",
make_jsonl_convert_cmd,
f"python jsonl_converter.py {jsonl_output_path}",
)
Expand Down

0 comments on commit 12a6b80

Please sign in to comment.