diff --git a/llmc/__main__.py b/llmc/__main__.py index 79b44a22..1c43ca12 100644 --- a/llmc/__main__.py +++ b/llmc/__main__.py @@ -9,6 +9,7 @@ import yaml from easydict import EasyDict from loguru import logger +from torch.distributed import destroy_process_group, init_process_group from llmc.compression.quantization import * from llmc.compression.sparsification import * @@ -111,12 +112,19 @@ def main(config): llmc_start_time = time.time() parser = argparse.ArgumentParser() parser.add_argument('--config', type=str, required=True) + parser.add_argument('--task_id', type=str, required=True) args = parser.parse_args() with open(args.config, 'r') as file: config = yaml.safe_load(file) config = EasyDict(config) + init_process_group(backend='nccl') + torch.cuda.set_device(int(os.environ['LOCAL_RANK'])) + + if int(os.environ['RANK']) != 0: + logger.remove() + check_config(config) logger.info(f'args: {args}') @@ -124,7 +132,9 @@ def main(config): print_important_package_version() - seed_all(config.base.seed) + logger.info(f'WORLD_SIZE : {int(os.environ["WORLD_SIZE"])}') + + seed_all(config.base.seed + int(os.environ['RANK'])) # mkdirs if 'save' in config: @@ -149,6 +159,8 @@ def main(config): main(config) + destroy_process_group() + llmc_end_time = time.time() llmc_duration_time = llmc_end_time - llmc_start_time logger.info(f'llmc_duration_time: {llmc_duration_time} s') diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py index 8a2b291e..6800135e 100644 --- a/llmc/compression/quantization/awq.py +++ b/llmc/compression/quantization/awq.py @@ -1,6 +1,8 @@ import gc +import os import torch +import torch.distributed as dist import torch.nn as nn from loguru import logger @@ -136,6 +138,8 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs) best_error = loss_mean best_scales = scales_mean best_scales = best_scales.view(-1) + dist.all_reduce(best_scales, op=dist.ReduceOp.SUM) + best_scales /= 
int(os.environ['WORLD_SIZE']) del org_out_dict gc.collect() torch.cuda.empty_cache() diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py index 8d0e025a..d5f6b1f1 100644 --- a/llmc/compression/quantization/base_blockwise_quantization.py +++ b/llmc/compression/quantization/base_blockwise_quantization.py @@ -1,10 +1,12 @@ import functools import gc import json +import os from collections import defaultdict from functools import partial import torch +import torch.distributed as dist import torch.nn as nn from loguru import logger @@ -487,6 +489,12 @@ def auto_clip(self, block, input_feat, n_sample_token): n_sample_token=n_sample_token, ) + dist.all_reduce(max_val, op=dist.ReduceOp.SUM) + max_val /= int(os.environ['WORLD_SIZE']) + + dist.all_reduce(min_val, op=dist.ReduceOp.SUM) + min_val /= int(os.environ['WORLD_SIZE']) + self.apply_clip(m, min_val, max_val, n) @torch.no_grad() @@ -802,6 +810,8 @@ def contiguous_params(self): @torch.no_grad() def save_model(self, path): + if int(os.environ['RANK']) != 0: + return if self.online_rotate: self.contiguous_params() if self.config.model.type == 'Llava': diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py index 8cfb36a1..3dd2e7a0 100644 --- a/llmc/data/dataset/base_dataset.py +++ b/llmc/data/dataset/base_dataset.py @@ -1,3 +1,4 @@ +import os from abc import ABCMeta import torch @@ -84,6 +85,10 @@ def get_calib_samples(self): def get_calib_dataset(self): samples = self.get_calib_samples() + logger.info(f'len(samples) all : {len(samples)}') + assert len(samples) % int(os.environ['WORLD_SIZE']) == 0 + samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])] + logger.info(f'len(samples) rank : {len(samples)}') calib_samples = [] if self.calib_bs < 0: batch = torch.cat(samples, dim=0) diff --git a/scripts/export_rtn_llama.sh b/scripts/export_rtn_llama.sh deleted file mode 100644 index 
7d332604..00000000 --- a/scripts/export_rtn_llama.sh +++ /dev/null @@ -1,9 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -python -m llmc --config ../configs/quantization/RTN/rtn_w4a16.yml diff --git a/scripts/run_adadim_llama.sh b/scripts/run_adadim_llama.sh deleted file mode 100644 index 28e2a4ba..00000000 --- a/scripts/run_adadim_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/AdaDim/adadim_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_awq_llama.sh b/scripts/run_awq_llama.sh deleted file mode 100644 index 3d638583..00000000 --- a/scripts/run_awq_llama.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/Awq/awq_w4a16_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid - diff --git a/scripts/run_dgq_llama.sh b/scripts/run_dgq_llama.sh deleted file mode 100644 index aa3c109b..00000000 --- a/scripts/run_dgq_llama.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/DGQ/dgq_w4a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! 
> ${task_name}.pid - diff --git a/scripts/run_gptq_llama.sh b/scripts/run_gptq_llama.sh deleted file mode 100644 index 3d9e7cbe..00000000 --- a/scripts/run_gptq_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/GPTQ/gptq_w4a16_fakequant_eval_general.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_gptq_owq_llama.sh b/scripts/run_gptq_owq_llama.sh deleted file mode 100644 index 7e0f6d22..00000000 --- a/scripts/run_gptq_owq_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/GPTQ/gptq_owq_w4a16_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_hqq_llama.sh b/scripts/run_hqq_llama.sh deleted file mode 100644 index 7f995c9a..00000000 --- a/scripts/run_hqq_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/HQQ/hqq_w4a16_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_in_tmux_sequence.sh b/scripts/run_in_tmux_sequence.sh deleted file mode 100644 index 6534e1ae..00000000 --- a/scripts/run_in_tmux_sequence.sh +++ /dev/null @@ -1,25 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - - -task_name=rtn_w8a8_fakequant_eval -echo "${task_name} running..." 
-python -m llmc --config ../configs/quantization/RTN/rtn_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 - - -task_name=smoothquant_llama_w8a8_fakequant_eval_general -echo "${task_name} running..." -python -m llmc --config ../configs/quantization/SmoothQuant/smoothquant_llama_w8a8_fakequant_eval_general.yml \ -> ${task_name}.log 2>&1 - - -task_name=osplus_llama_w8a8_fakequant_eval_general -echo "${task_name} running..." -python -m llmc --config ../configs/quantization/OsPlus/osplus_llama_w8a8_fakequant_eval_general.yml \ -> ${task_name}.log 2>&1 diff --git a/scripts/run_llmc.sh b/scripts/run_llmc.sh new file mode 100644 index 00000000..c2a1f792 --- /dev/null +++ b/scripts/run_llmc.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +# export CUDA_VISIBLE_DEVICES=0,1 + +llmc=llmc_path +export PYTHONPATH=$llmc:$PYTHONPATH + +task_name=awq_w4a16_fakequant_eval +config=${llmc}/configs/quantization/Awq/awq_w4a16_fakequant_eval.yml + +nnodes=1 +nproc_per_node=1 + + +MASTER_ADDR=127.0.0.1 +MASTER_PORT=$((10000 + RANDOM % 20000)) + +# NOTE: never assign to RANDOM (that only reseeds bash's generator); use the UUID itself as the unique task id. +task_id=$(python -c 'import uuid; print(uuid.uuid4())') + +nohup \ +torchrun \ +--nnodes $nnodes \ +--nproc_per_node $nproc_per_node \ +--rdzv_id $task_id \ +--rdzv_backend c10d \ +--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \ +${llmc}/llmc/__main__.py --config $config --task_id $task_id \ +> ${task_name}.log 2>&1 & + +sleep 2 +ps aux | grep '__main__.py' | grep $task_id | awk '{print $2}' > ${task_name}.pid + +# You can kill this program by +# xargs kill -9 < xxx.pid +# xxx.pid is ${task_name}.pid file diff --git a/scripts/run_llmint8_llama.sh b/scripts/run_llmint8_llama.sh deleted file mode 100644 index a4261cb6..00000000 --- a/scripts/run_llmint8_llama.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/LlmInt8/llmint8_w8a8_fakequant_eval.yml \ -> 
${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid - diff --git a/scripts/run_ntweak_llama.sh b/scripts/run_ntweak_llama.sh deleted file mode 100644 index b94e260a..00000000 --- a/scripts/run_ntweak_llama.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/NormTweaking/ntweak_llama_w4a16_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid - diff --git a/scripts/run_omniq_llama.sh b/scripts/run_omniq_llama.sh deleted file mode 100644 index 5f7241a6..00000000 --- a/scripts/run_omniq_llama.sh +++ /dev/null @@ -1,16 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/OmniQuant/omniq_llama_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid - diff --git a/scripts/run_omniq_mistral.sh b/scripts/run_omniq_mistral.sh deleted file mode 100644 index 0164521a..00000000 --- a/scripts/run_omniq_mistral.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/OmniQuant/omniq_mistral_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! 
> ${task_name}.pid diff --git a/scripts/run_omniq_opt.sh b/scripts/run_omniq_opt.sh deleted file mode 100644 index 2e0da4b4..00000000 --- a/scripts/run_omniq_opt.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/OmniQuant/omniq_opt_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_osplus_llama.sh b/scripts/run_osplus_llama.sh deleted file mode 100644 index 98336462..00000000 --- a/scripts/run_osplus_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/OsPlus/osplus_llama_w8a8_fakequant_eval_general.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid diff --git a/scripts/run_osplus_opt.sh b/scripts/run_osplus_opt.sh deleted file mode 100644 index 37f66615..00000000 --- a/scripts/run_osplus_opt.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/OsPlus/osplus_opt_w8a8_fakequant_eval_general.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid diff --git a/scripts/run_quarot_llama.sh b/scripts/run_quarot_llama.sh deleted file mode 100644 index 5b00ede6..00000000 --- a/scripts/run_quarot_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/QuaRot/quarot_w4a4.yml \ -> ${task_name}.log 2>&1 & - -echo $! 
> ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_quik_llama.sh b/scripts/run_quik_llama.sh deleted file mode 100644 index 818069d8..00000000 --- a/scripts/run_quik_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/QUIK/quik_w4a4_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_rtn_llama.sh b/scripts/run_rtn_llama.sh deleted file mode 100644 index 8d328a7f..00000000 --- a/scripts/run_rtn_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/RTN/rtn_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_rtn_llama_static.sh b/scripts/run_rtn_llama_static.sh deleted file mode 100644 index cc7e62da..00000000 --- a/scripts/run_rtn_llama_static.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/RTN/rtn_w8a8_pertensor_static.yml \ -> ${task_name}.log 2>&1 & - -echo $! 
> ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_shortgpt_llama.sh b/scripts/run_shortgpt_llama.sh deleted file mode 100644 index f56c090a..00000000 --- a/scripts/run_shortgpt_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/sparsification/ShortGPT/shortgpt.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_smoothquant_llama.sh b/scripts/run_smoothquant_llama.sh deleted file mode 100644 index 6715d68e..00000000 --- a/scripts/run_smoothquant_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/SmoothQuant/smoothquant_llama_w8a8_fakequant_eval_general.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid diff --git a/scripts/run_smoothquant_opt.sh b/scripts/run_smoothquant_opt.sh deleted file mode 100644 index 38f7b616..00000000 --- a/scripts/run_smoothquant_opt.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/SmoothQuant/smoothquant_opt_w8a8_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! 
> ${task_name}.pid diff --git a/scripts/run_spqr_llama.sh b/scripts/run_spqr_llama.sh deleted file mode 100644 index 270c6161..00000000 --- a/scripts/run_spqr_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/quantization/SpQR/spqr_w4a16_fakequant_eval.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file diff --git a/scripts/run_wanda_llama.sh b/scripts/run_wanda_llama.sh deleted file mode 100644 index 96b31c51..00000000 --- a/scripts/run_wanda_llama.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/bash - -gpu_id=0 -export CUDA_VISIBLE_DEVICES=$gpu_id - -llmc=llmc_path -export PYTHONPATH=$llmc:$PYTHONPATH - -task_name=llm_quant_exp - -nohup \ -python -m llmc --config ../configs/sparsification/Wand/wanda.yml \ -> ${task_name}.log 2>&1 & - -echo $! > ${task_name}.pid \ No newline at end of file