diff --git a/06_gpu_and_ml/tgi_mixtral.py b/06_gpu_and_ml/tgi_mixtral.py index fd0d42784..7d61767e3 100644 --- a/06_gpu_and_ml/tgi_mixtral.py +++ b/06_gpu_and_ml/tgi_mixtral.py @@ -25,10 +25,13 @@ GPU_CONFIG = gpu.A100(memory=40, count=4) MODEL_ID = "mistralai/Mixtral-8x7B-Instruct-v0.1" +MODEL_REVISION = "f1ca00645f0b1565c7f9a1c863d2be6ebf896b04" # Add `["--quantize", "gptq"]` for TheBloke GPTQ models. LAUNCH_FLAGS = [ "--model-id", MODEL_ID, + "--revision", + MODEL_REVISION, "--port", "8000", ] @@ -52,6 +55,8 @@ def download_model(): "text-generation-server", "download-weights", MODEL_ID, + "--revision", + MODEL_REVISION, ] )