Rewrite alpaca-lora
erikbern committed Dec 21, 2023
1 parent 98a802c commit f444aa9
Showing 1 changed file with 18 additions and 23 deletions.
06_gpu_and_ml/alpaca/alpaca_lora.py (18 additions, 23 deletions)
@@ -1,25 +1,13 @@
 import sys
 
-from modal import Image, Stub, method
+from modal import Image, Stub, build, enter, method
 
 # Define a function for downloading the models, that will run once on image build.
 # This allows the weights to be present inside the image for faster startup.
 
 base_model = "luodian/llama-7b-hf"
 lora_weights = "tloen/alpaca-lora-7b"
 
-
-def download_models():
-    from peft import PeftModel
-    from transformers import LlamaForCausalLM, LlamaTokenizer
-
-    model = LlamaForCausalLM.from_pretrained(
-        base_model,
-    )
-    PeftModel.from_pretrained(model, lora_weights)
-    LlamaTokenizer.from_pretrained(base_model)
-
-
 # Alpaca-LoRA is distributed as a public Github repository and the repository is not
 # installable by `pip`, so instead we install the repository by cloning it into our Modal
 # image.
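For context, the deleted module-level `download_models` function used to be attached to the image via `Image.run_function` (dropped in the next hunk), so the weights were fetched once at image-build time. A minimal sketch of that retired pattern, assuming Modal's 2023-era API; the `download` helper and its package list are illustrative, not from the diff:

    from modal import Image

    def download():
        # Runs once while the image is built; the downloaded tokenizer files
        # end up cached inside the resulting image.
        from transformers import LlamaTokenizer

        LlamaTokenizer.from_pretrained("luodian/llama-7b-hf")

    image = (
        Image.debian_slim()
        .pip_install("transformers", "sentencepiece")
        .run_function(download)
    )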
@@ -58,8 +46,14 @@ def download_models():
         "torchvision~=0.16",
         "sentencepiece==0.1.99",
     )
-    .run_function(download_models)
 )
+
+with image.imports():
+    import torch
+    from generate import generate_prompt
+    from peft import PeftModel
+    from transformers import GenerationConfig, LlamaForCausalLM, LlamaTokenizer
+
 stub = Stub(name="example-alpaca-lora", image=image)
 
 # The Alpaca-LoRA model is integrated into Modal as a Python class with an __enter__
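The new `with image.imports():` block defers these imports so they execute only inside containers built from the image, where `peft`, `transformers`, and the cloned `generate` module actually exist; the local client process never needs them installed. A self-contained sketch of the pattern, with a hypothetical `numpy` image and `norm` function that are not part of this example:

    from modal import Image, Stub

    image = Image.debian_slim().pip_install("numpy")

    with image.imports():
        # Deferred: executed only inside containers running this image.
        import numpy as np

    stub = Stub(name="imports-demo", image=image)

    @stub.function()
    def norm(values: list[float]) -> float:
        # np resolves here because this body runs in the remote container.
        return float(np.linalg.norm(values))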
@@ -73,16 +67,21 @@ def download_models():
 
 @stub.cls(gpu="A10G")
 class AlpacaLoRAModel:
-    def __enter__(self):
+    @build()
+    def download_models(self):
+        model = LlamaForCausalLM.from_pretrained(
+            base_model,
+        )
+        PeftModel.from_pretrained(model, lora_weights)
+        LlamaTokenizer.from_pretrained(base_model)
+
+    @enter()
+    def enter(self):
         """
         Container-lifecycle method for model setup. Code is taken from
         https://github.com/tloen/alpaca-lora/blob/main/generate.py and minor
         modifications are made to support usage in a Python class.
         """
-        import torch
-        from peft import PeftModel
-        from transformers import LlamaForCausalLM, LlamaTokenizer
-
         load_8bit = False
         device = "cuda" if torch.cuda.is_available() else "cpu"
 
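Splitting the old `__enter__` into `@build()` and `@enter()` separates two lifecycle stages: the `@build()` method runs once while the image is built, so the downloaded weights get snapshotted into the image, while the `@enter()` method runs on each container cold start to load state before any request is served. A hedged sketch of the shape, with illustrative class and method names:

    from modal import Stub, build, enter, method

    stub = Stub(name="lifecycle-demo")

    @stub.cls(gpu="A10G")
    class Model:
        @build()
        def download(self):
            # Image-build time: anything written to disk is baked into the image.
            ...

        @enter()
        def load(self):
            # Container start: runs once per cold start, before any @method call.
            self.ready = True

        @method()
        def predict(self, prompt: str) -> str:
            # Serves individual calls once the container is warm.
            return prompt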
@@ -146,10 +145,6 @@ def evaluate(
         max_new_tokens=128,
         **kwargs,
     ):
-        import torch
-        from generate import generate_prompt
-        from transformers import GenerationConfig
-
         prompt = generate_prompt(instruction, input)
         inputs = self.tokenizer(prompt, return_tensors="pt")
         input_ids = inputs["input_ids"].to(self.device)
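Because the `image.imports()` block now binds `torch`, `generate_prompt`, and `GenerationConfig` at module scope inside the container, the per-call imports in `evaluate` are redundant and are simply dropped.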
