Support RedPajama-INCITE-Chat-3B-v1 (octoml#175)
junrushao authored May 18, 2023
1 parent 6808bcf commit 5184eb8
Showing 32 changed files with 7,328 additions and 327 deletions.
35 changes: 19 additions & 16 deletions build.py
@@ -1,12 +1,12 @@
 import argparse
 import json
 import os
 import pickle
-import json
-from typing import List
+from typing import Any, Dict, List
 
 import tvm
 import tvm.testing
+from tvm import meta_schedule as ms
 from tvm import relax
 
 import mlc_llm
@@ -53,7 +53,6 @@ def _parse_args():
     )
     args.add_argument("--debug-dump", action="store_true", default=False)
     args.add_argument("--debug-load-script", action="store_true", default=False)
-
     args.add_argument(
         "--llvm-mingw",
         type=str,
@@ -71,7 +70,12 @@ def _parse_args():
     parsed = _setup_model_path(parsed)
 
     parsed.db_path = parsed.db_path or os.path.join("log_db", parsed.model)
-
+    if os.path.exists(parsed.db_path):
+        ms.database.create(work_dir=parsed.db_path)
+    else:
+        print(
+            f"WARNING: --db-path does not point to a valid database: {parsed.db_path}"
+        )
     utils.parse_target(parsed)
     utils.argparse_postproc_common(parsed)
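
For context on the new guard: ms.database.create(work_dir=...) loads the meta-schedule tuning database that later supplies scheduled kernels. A minimal sketch of the same pattern in isolation, assuming a hypothetical log_db/RedPajama-INCITE-Chat-3B-v1 directory (TVM's JSON database layout keeps database_workload.json and database_tuning_record.json under work_dir):

    import os

    from tvm import meta_schedule as ms

    # Hypothetical tuning-log directory produced by meta-schedule tuning.
    db_path = os.path.join("log_db", "RedPajama-INCITE-Chat-3B-v1")

    if os.path.exists(db_path):
        # Loading the database here surfaces a bad --db-path at
        # argument-parsing time rather than deep inside build().
        db = ms.database.create(work_dir=db_path)
    else:
        print(f"WARNING: --db-path does not point to a valid database: {db_path}")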

@@ -129,7 +133,7 @@ def _setup_model_path(args):
     if args.model == "auto":
         raise ValueError(f"Please specify either the model_path or the hf_path.")
 
-    print(f"Using model path {args.model_path}")
+    print(f'Using path "{args.model_path}" for model "{args.model}"')
     return args
 
 
@@ -139,9 +143,7 @@ def validate_config(model_path: str):
     ), "Model path must contain valid config file."
     with open(os.path.join(model_path, "config.json")) as f:
         config = json.load(f)
-    assert ("model_type" in config) and (
-        "_name_or_path" in config
-    ), "Invalid config format."
+    assert "model_type" in config, "Invalid config format."
     assert (
         config["model_type"] in utils.supported_model_types
     ), f"Model type {config['model_type']} not supported."
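
With the relaxed assertion, config.json only needs a supported model_type; the "_name_or_path" key (absent from some HF configs) is no longer required. An illustrative check under that assumption, with made-up field values:

    import json

    # Illustrative minimal config.json for a GPT-NeoX-family checkpoint;
    # "_name_or_path" is no longer required.
    config = json.loads('{"model_type": "gpt_neox", "vocab_size": 50432}')

    # Stand-in for utils.supported_model_types, which build.py consults.
    supported_model_types = {"llama", "gpt_neox", "moss"}

    assert "model_type" in config, "Invalid config format."
    assert config["model_type"] in supported_model_types, (
        f"Model type {config['model_type']} not supported."
    )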
@@ -225,6 +227,7 @@ def mod_transform_before_build(
 def dump_default_mlc_chat_config(args):
     params_path = os.path.join(args.artifact_path, "params")
-    config = dict()
+    config: Dict[str, Any] = {}
     config["model_lib"] = f"{args.model}-{args.quantization.name}"
     config["local_id"] = f"{args.model}-{args.quantization.name}"
     config["conv_template"] = args.conv_template
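
The Dict[str, Any] annotation is what the new typing import at the top of the file is for: the chat config mixes string and numeric values, and a type checker cannot infer a value type for an empty dict. A short sketch of the pattern, with hypothetical values:

    from typing import Any, Dict

    # mypy rejects a bare `config = dict()` with "need type annotation";
    # Dict[str, Any] fixes that and documents the heterogeneous values.
    config: Dict[str, Any] = {}
    config["model_lib"] = "RedPajama-INCITE-Chat-3B-v1-q4f16_0"  # hypothetical name
    config["conv_template"] = "redpajama_chat"                   # hypothetical template
    config["temperature"] = 0.7                                  # float, not str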
@@ -244,8 +247,6 @@ def build(mod_deploy: tvm.IRModule, args: argparse.Namespace) -> None:
     target_kind = args.target_kind
     debug_dump_script(mod_deploy, "mod_before_build.py", args)
     if target_kind != "cpu":
-        from tvm import meta_schedule as ms
-
         if os.path.exists(args.db_path):
             db = ms.database.create(work_dir=args.db_path)
         else:
@@ -297,8 +298,7 @@ def dump_split_tir(mod: tvm.IRModule):
         o_f.write(template.format(content=mod_dynamic.script()))
 
 
-if __name__ == "__main__":
-    ARGS = _parse_args()
+def main():
     os.makedirs(ARGS.artifact_path, exist_ok=True)
     os.makedirs(os.path.join(ARGS.artifact_path, "debug"), exist_ok=True)
     cache_path = os.path.join(
@@ -311,9 +311,7 @@ def dump_split_tir(mod: tvm.IRModule):
     if ARGS.model_category == "llama":
         mod, params = llama.get_model(ARGS, config)
     elif ARGS.model_category == "gpt_neox":
-        mod, params = gpt_neox.get_model(
-            ARGS.model, ARGS.model_path, ARGS.quantization.model_dtype, config
-        )
+        mod, params = gpt_neox.get_model(ARGS, config)
     elif ARGS.model_category == "moss":
         mod, params = moss.get_model(ARGS, config)
     else:
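
The call-site change reflects a uniform loader interface: each model family's get_model now takes the whole argparse namespace plus the parsed HF config instead of individual fields. A hypothetical sketch of the shape change (names illustrative, not the module's actual internals):

    import argparse
    from typing import Any, Dict

    def get_model_old(model: str, model_path: str, model_dtype: str,
                      config: Dict[str, Any]):
        ...  # every new build option used to widen this signature

    def get_model_new(args: argparse.Namespace, config: Dict[str, Any]):
        # The same inputs are read off the namespace instead, so adding a
        # build option no longer touches every loader's signature.
        model, model_path = args.model, args.model_path
        model_dtype = args.quantization.model_dtype
        ...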
@@ -332,3 +330,8 @@ def dump_split_tir(mod: tvm.IRModule):
     dump_split_tir(mod)
     build(mod, ARGS)
     dump_default_mlc_chat_config(ARGS)
+
+
+if __name__ == "__main__":
+    ARGS = _parse_args()
+    main()
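
The entry-point refactor keeps ARGS as a module-level global (helpers such as build() and main() read it directly) while the work moves into main(). A minimal runnable sketch of the pattern, independent of build.py:

    import argparse

    def _parse_args() -> argparse.Namespace:
        args = argparse.ArgumentParser()
        args.add_argument("--model", default="auto")
        return args.parse_args()

    def main() -> None:
        # ARGS is resolved at call time, so this works as long as it is
        # bound before main() runs -- the __main__ guard below guarantees it.
        print(f"Building model: {ARGS.model}")

    if __name__ == "__main__":
        ARGS = _parse_args()
        main()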