Skip to content

Commit

Permalink
Merge pull request #84 from furiosa-ai/bert_qlv4_save_compact
Browse files Browse the repository at this point in the history
qlv4_save 수정: handle compact_causal_mask and change default values
  • Loading branch information
BeomGeunCho authored Jul 5, 2024
2 parents 5c89e8c + 6bb724a commit 42cd3ce
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
8 changes: 4 additions & 4 deletions language/bert/qlv4_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def get_args():
choices=[
"huggingface_rngd_gelu",
"mlperf_submission",
"experimental_huggingface_unsplit_packed",
"compact_causal_mask",
],
help="choose model source",
)
Expand Down Expand Up @@ -151,12 +151,12 @@ def qlv4_save():
sut = None
args.backend = "pytorch"
args.max_examples = 1
args.recalibrate = True
args.recalibrate = False
args.use_mcp = True
args.accuracy = True
args.torch_optim = "none"
args.model_script_path = (
"./quantization/model_script/Qlevel4_RGDA0-W8A8KV8-PTQ.yaml"
"./quantization/model_script/Qlevel4_RGDA0-W8A8KV8-PTQ_submission.yaml"
)

from pytorch_SUT import get_pytorch_sut
Expand All @@ -174,7 +174,7 @@ def qlv4_save():
output_path=args.output_path,
)

if args.model_source =="mlperf_submission":
if args.model_source =="mlperf_submission" or args.model_source =="compact_causal_mask" :
model = sut.model.model
else:
model= sut.model
Expand Down
10 changes: 6 additions & 4 deletions language/gpt-j/qlv4_save.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,18 @@
import joblib
import argparse

version='v3.12.1'
model_source="mlperf_submission"
def get_args():
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", default="./model/", help="")
parser.add_argument("--model_config", default="./ci_test_file/config.json", help="")
parser.add_argument("--model_script_path", default="./quantization/model_script/Qlevel4_RGDA0-W8A8KV8-PTQ-SMQ-rope_lm-headint8.yaml", help="")
parser.add_argument("--model_source", type = str, default = "mlperf_submission", help="the type of GPTJForCausalLM to use")
parser.add_argument('--qformat_path', type = str, default="./quantization/output/qformat_Qlevel4_RGDA0-W8A8KV8-PTQ-SMQ-mlperf_submission.yaml", help="")
parser.add_argument('--qparam_path', type = str, default="./quantization/output/qparam_Qlevel4_RGDA0-W8A8KV8-PTQ-SMQ-mlperf_submission.npy", help="")
parser.add_argument('--qlv4_prefill_out_path', type = str, default='./quantization/model_script/prefill.bin', help="")
parser.add_argument('--qlv4_decode_out_path', type = str, default='./quantization/model_script/decode.bin', help="")
parser.add_argument('--qformat_path', type = str, default=f'./quantization/output/{version}/{model_source}/qformat.yaml', help="")
parser.add_argument('--qparam_path', type = str, default=f'./quantization/output/{version}/{model_source}/qparam.npy', help="")
parser.add_argument('--qlv4_prefill_out_path', type = str, default=f'./quantization/output/{version}/{model_source}/prefill.bin', help="")
parser.add_argument('--qlv4_decode_out_path', type = str, default=f'./quantization/output/{version}/{model_source}/decode.bin', help="")
args = parser.parse_args()
return args

Expand Down

0 comments on commit 42cd3ce

Please sign in to comment.