Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ddp #26

Merged
merged 1 commit into from
Aug 11, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion llmc/__main__.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
import yaml
from easydict import EasyDict
from loguru import logger
from torch.distributed import destroy_process_group, init_process_group

from llmc.compression.quantization import *
from llmc.compression.sparsification import *
@@ -111,20 +112,29 @@ def main(config):
llmc_start_time = time.time()
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, required=True)
parser.add_argument('--task_id', type=str, required=True)
args = parser.parse_args()

with open(args.config, 'r') as file:
config = yaml.safe_load(file)
config = EasyDict(config)

init_process_group(backend='nccl')
torch.cuda.set_device(int(os.environ['LOCAL_RANK']))

if int(os.environ['RANK']) != 0:
logger.remove()

check_config(config)

logger.info(f'args: {args}')
logger.info(f'config:\n{json.dumps(config, ensure_ascii=False, indent=4)}')

print_important_package_version()

seed_all(config.base.seed)
logger.info(f'WORLD_SIZE : {int(os.environ["WORLD_SIZE"])}')

seed_all(config.base.seed + int(os.environ['RANK']))

# mkdirs
if 'save' in config:
@@ -149,6 +159,8 @@ def main(config):

main(config)

destroy_process_group()

llmc_end_time = time.time()
llmc_duration_time = llmc_end_time - llmc_start_time
logger.info(f'llmc_duration_time: {llmc_duration_time} s')
4 changes: 4 additions & 0 deletions llmc/compression/quantization/awq.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import gc
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from loguru import logger

@@ -136,6 +138,8 @@ def search_scale_subset(self, layers_dict, input, inspect_module, subset_kwargs)
best_error = loss_mean
best_scales = scales_mean
best_scales = best_scales.view(-1)
dist.all_reduce(best_scales, op=dist.ReduceOp.SUM)
best_scales /= int(os.environ['WORLD_SIZE'])
del org_out_dict
gc.collect()
torch.cuda.empty_cache()
10 changes: 10 additions & 0 deletions llmc/compression/quantization/base_blockwise_quantization.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import functools
import gc
import json
import os
from collections import defaultdict
from functools import partial

import torch
import torch.distributed as dist
import torch.nn as nn
from loguru import logger

@@ -487,6 +489,12 @@ def auto_clip(self, block, input_feat, n_sample_token):
n_sample_token=n_sample_token,
)

dist.all_reduce(max_val, op=dist.ReduceOp.SUM)
max_val /= int(os.environ['WORLD_SIZE'])

dist.all_reduce(min_val, op=dist.ReduceOp.SUM)
min_val /= int(os.environ['WORLD_SIZE'])

self.apply_clip(m, min_val, max_val, n)

@torch.no_grad()
@@ -802,6 +810,8 @@ def contiguous_params(self):

@torch.no_grad()
def save_model(self, path):
if int(os.environ['RANK']) != 0:
return
if self.online_rotate:
self.contiguous_params()
if self.config.model.type == 'Llava':
5 changes: 5 additions & 0 deletions llmc/data/dataset/base_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
from abc import ABCMeta

import torch
@@ -84,6 +85,10 @@ def get_calib_samples(self):

def get_calib_dataset(self):
samples = self.get_calib_samples()
logger.info(f'len(samples) all : {len(samples)}')
assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
logger.info(f'len(samples) rank : {len(samples)}')
calib_samples = []
if self.calib_bs < 0:
batch = torch.cat(samples, dim=0)
9 changes: 0 additions & 9 deletions scripts/export_rtn_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_adadim_llama.sh

This file was deleted.

16 changes: 0 additions & 16 deletions scripts/run_awq_llama.sh

This file was deleted.

16 changes: 0 additions & 16 deletions scripts/run_dgq_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_gptq_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_gptq_owq_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_hqq_llama.sh

This file was deleted.

25 changes: 0 additions & 25 deletions scripts/run_in_tmux_sequence.sh

This file was deleted.

36 changes: 36 additions & 0 deletions scripts/run_llmc.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#!/bin/bash
# Launch an LLMC quantization run under torchrun (DDP) in the background,
# logging to ${task_name}.log and recording worker PIDs in ${task_name}.pid.

# export CUDA_VISIBLE_DEVICES=0,1

llmc=llmc_path
export PYTHONPATH=$llmc:$PYTHONPATH

task_name=awq_w4a16_fakequant_eval
config=${llmc}/configs/quantization/Awq/awq_w4a16_fakequant_eval.yml

nnodes=1
nproc_per_node=1


MASTER_ADDR=127.0.0.1
# Random rendezvous port in [10000, 29999] so concurrent launches on the
# same host don't collide.
MASTER_PORT=$((10000 + RANDOM % 20000))

# Unique id for this run, used both as the torchrun rendezvous id and to
# locate the spawned processes below.
# FIX: the original did `RANDOM=$(python -c ...)` then `task_id=$RANDOM`;
# assigning to bash's special RANDOM variable only re-seeds its RNG (a
# non-numeric string seeds it with 0), so task_id became a small
# deterministic pseudo-random integer instead of the UUID. Assign the
# UUID to task_id directly.
task_id=$(python -c 'import uuid; print(uuid.uuid4())')

nohup \
torchrun \
--nnodes $nnodes \
--nproc_per_node $nproc_per_node \
--rdzv_id $task_id \
--rdzv_backend c10d \
--rdzv_endpoint $MASTER_ADDR:$MASTER_PORT \
${llmc}/llmc/__main__.py --config $config --task_id $task_id \
> ${task_name}.log 2>&1 &

# Give torchrun a moment to spawn workers before collecting their PIDs.
sleep 2
ps aux | grep '__main__.py' | grep "$task_id" | awk '{print $2}' > ${task_name}.pid

# You can kill this program by
# xargs kill -9 < xxx.pid
# xxx.pid is ${task_name}.pid file
16 changes: 0 additions & 16 deletions scripts/run_llmint8_llama.sh

This file was deleted.

16 changes: 0 additions & 16 deletions scripts/run_ntweak_llama.sh

This file was deleted.

16 changes: 0 additions & 16 deletions scripts/run_omniq_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_omniq_mistral.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_omniq_opt.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_osplus_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_osplus_opt.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_quarot_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_quik_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_rtn_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_rtn_llama_static.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_shortgpt_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_smoothquant_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_smoothquant_opt.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_spqr_llama.sh

This file was deleted.

15 changes: 0 additions & 15 deletions scripts/run_wanda_llama.sh

This file was deleted.