Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Nov 7, 2023
1 parent 7ed660e commit 1c65fa6
Showing 1 changed file with 16 additions and 8 deletions.
24 changes: 16 additions & 8 deletions mlc_llm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,17 @@ def validate_config(model_path: str):
), f"Model type {config['model_type']} not supported."


def get_cuda_sm_version():
major, minor = parse_compute_version(tvm.cuda(0).compute_version)

if major == 8:
sm = 80
else:
sm = 10 * major + minor

return sm


def mod_transform_before_build(
mod: tvm.IRModule,
param_manager: param_manager.ParamManager,
Expand Down Expand Up @@ -550,13 +561,7 @@ def mod_transform_before_build(
if len(patterns) > 0:
os.makedirs("./tmp", exist_ok=True)

major, minor = parse_compute_version(tvm.cuda(0).compute_version)

if major == 8:
sm = 80
else:
sm = 10 * major + minor

sm = get_cuda_sm_version()
options = {"cutlass": {"sm": sm, "find_first_valid": False}}

if hasattr(config, "rms_norm_eps"):
Expand Down Expand Up @@ -802,9 +807,12 @@ def build_model_from_args(args: argparse.Namespace):
if args.num_shards > 1 and use_ft_quant:
preprocessed = []
weight_preprocess_func = tvm.get_global_func("cutlass.ft_preprocess_weight")
is_int4 = args.quantization.name == "q4f16_ft"
sm = get_cuda_sm_version()

for p in params:
if p.dtype == "int8":
preprocessed.append(weight_preprocess_func(p, 80, True))
preprocessed.append(weight_preprocess_func(p, sm, is_int4))
else:
preprocessed.append(p)

Expand Down

0 comments on commit 1c65fa6

Please sign in to comment.