diff --git a/gptq/gptj.py b/gptq/gptj.py
index 5ebf7b3e..23f4f96f 100644
--- a/gptq/gptj.py
+++ b/gptq/gptj.py
@@ -535,6 +535,16 @@ def sync():
         quantizers = gptj_sequential(model, dataloader, DEV)
         print(time.time() - tick)
 
+    if args.benchmark:
+        gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
+        if len(gpus) > 1:
+            gptj_multigpu(model, gpus)
+        else:
+            model = model.to(DEV)
+        if args.benchmark:
+            input_ids = next(iter(dataloader))[0][:, : args.benchmark]
+            benchmark(model, input_ids, check=args.check)
+
     if args.eval:
         datasets = ["wikitext2", "ptb", "c4"]
         if args.new_eval:
@@ -550,6 +560,9 @@ def sync():
             print(dataset)
             gptj_eval(model, testloader, DEV)
 
+    if args.load:
+        exit()
+
     if args.save:
         gptj_pack(model, quantizers, args.wbits, args.groupsize)
         torch.save(model.state_dict(), args.save)
@@ -559,13 +572,3 @@ def sync():
 
         from safetensors.torch import save_file as safe_save
         safe_save(model.state_dict(), args.save_safetensors)
-
-    if args.benchmark:
-        gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
-        if len(gpus) > 1:
-            gptj_multigpu(model, gpus)
-        else:
-            model = model.to(DEV)
-        if args.benchmark:
-            input_ids = next(iter(dataloader))[0][:, : args.benchmark]
-            benchmark(model, input_ids, check=args.check)
diff --git a/gptq/gptneox.py b/gptq/gptneox.py
index 69192b19..ada4c731 100644
--- a/gptq/gptneox.py
+++ b/gptq/gptneox.py
@@ -545,9 +545,6 @@ def sync():
             input_ids = next(iter(dataloader))[0][:, : args.benchmark]
             benchmark(model, input_ids, check=args.check)
 
-    if args.load:
-        exit()
-
     if args.eval:
         datasets = ["wikitext2", "ptb", "c4"]
         if args.new_eval:
@@ -559,6 +556,9 @@ def sync():
             print(dataset)
             gptneox_eval(model, testloader, DEV)
 
+    if args.load:
+        exit()
+
     if args.save:
         gptneox_pack(model, quantizers, args.wbits, args.groupsize)
         torch.save(model.state_dict(), args.save)
diff --git a/gptq/llama.py b/gptq/llama.py
index 30a7dc50..a2c64fd5 100644
--- a/gptq/llama.py
+++ b/gptq/llama.py
@@ -481,9 +481,6 @@ def sync():
         if args.benchmark:
             input_ids = next(iter(dataloader))[0][:, :args.benchmark]
             benchmark(model, input_ids, check=args.check)
-
-    if args.load:
-        exit()
 
     if args.eval:
         datasets = ['wikitext2', 'ptb', 'c4']
@@ -496,6 +493,9 @@ def sync():
             print(dataset)
             llama_eval(model, testloader, DEV)
 
+    if args.load:
+        exit()
+
     if args.save:
         llama_pack(model, quantizers, args.wbits, args.groupsize)
         torch.save(model.state_dict(), args.save)
diff --git a/gptq/mpt.py b/gptq/mpt.py
index b67514df..ec7eae2a 100644
--- a/gptq/mpt.py
+++ b/gptq/mpt.py
@@ -519,6 +519,16 @@ def sync():
         quantizers = mpt_sequential(model, dataloader, DEV)
         print(time.time() - tick)
 
+    if args.benchmark:
+        gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
+        if len(gpus) > 1:
+            mpt_multigpu(model, gpus)
+        else:
+            model = model.to(DEV)
+        if args.benchmark:
+            input_ids = next(iter(dataloader))[0][:, : args.benchmark]
+            benchmark(model, input_ids, check=args.check)
+
     if args.eval:
         datasets = ["wikitext2", "ptb", "c4"]
         if args.new_eval:
@@ -530,6 +540,9 @@ def sync():
             print(dataset)
             mpt_eval(model, testloader, DEV)
 
+    if args.load:
+        exit()
+
     if args.save:
         mpt_pack(model, quantizers, args.wbits, args.groupsize)
         torch.save(model.state_dict(), args.save)
@@ -539,15 +552,3 @@ def sync():
 
         from safetensors.torch import save_file as safe_save
         safe_save(model.state_dict(), args.save_safetensors)
-
-    if args.benchmark:
-        gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
-        if len(gpus) > 1:
-            mpt_multigpu(model, gpus)
-        else:
-            model = model.to(DEV)
-        if args.benchmark:
-            input_ids = next(iter(dataloader))[0][:, : args.benchmark]
-            benchmark(model, input_ids, check=args.check)
-    if args.load:
-        exit()
diff --git a/gptq/opt.py b/gptq/opt.py
index 5751ae7d..ae9948e5 100644
--- a/gptq/opt.py
+++ b/gptq/opt.py
@@ -484,9 +484,6 @@ def sync():
             input_ids = next(iter(dataloader))[0][:, :args.benchmark]
             benchmark(model, input_ids, check=args.check)
 
-    if args.load:
-        exit()
-
     if args.eval:
         datasets = ['wikitext2', 'ptb', 'c4']
         if args.new_eval:
@@ -498,6 +495,9 @@ def sync():
             print(dataset)
             opt_eval(model, testloader, DEV)
 
+    if args.load:
+        exit()
+
     if args.save:
         opt_pack(model, quantizers, args.wbits, args.groupsize)
         torch.save(model.state_dict(), args.save)
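After this patch, all five launchers share the same post-quantization flow: benchmark, then eval, then the --load early exit, and only then pack/save. Below is a minimal sketch of that shared ordering, not code from the patch itself: model_multigpu, model_eval, and model_pack are hypothetical stand-ins for the per-model helpers (gptj_*, mpt_*, llama_*, ...), and args, DEV, dataloader, datasets, and testloader are assumed to exist as they do in each script's __main__ block.

    if args.benchmark:
        gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())]
        if len(gpus) > 1:
            model_multigpu(model, gpus)  # shard layers across every visible GPU
        else:
            model = model.to(DEV)
        input_ids = next(iter(dataloader))[0][:, : args.benchmark]
        benchmark(model, input_ids, check=args.check)

    if args.eval:
        for dataset in datasets:
            model_eval(model, testloader, DEV)  # perplexity on wikitext2/ptb/c4

    if args.load:
        # A pre-quantized checkpoint was loaded, so `quantizers` was never
        # assigned; exiting here keeps the pack/save path from touching it.
        exit()

    if args.save:
        model_pack(model, quantizers, args.wbits, args.groupsize)
        torch.save(model.state_dict(), args.save)

Moving the exit after eval (rather than before it, as gptneox/llama/opt previously had it) lets a checkpoint loaded with --load be both benchmarked and evaluated, while still stopping before the save path, which needs the quantizers dict that only exists when quantization ran in-process.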