forked from databricks/databricks-ml-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path08_load_from_marketplace.py
113 lines (83 loc) · 3.79 KB
/
08_load_from_marketplace.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Databricks notebook source
# MAGIC %md
# MAGIC # Loading Llama 2 13B Chat model from Marketplace
# MAGIC
# MAGIC This example notebook demonstrates how to load the Llama 2 13B Chat model from a Databricks Marketplace's Catalog ([see announcement blog](https://www.databricks.com/blog/llama-2-foundation-models-available-databricks-lakehouse-ai)).
# MAGIC
# MAGIC Environment:
# MAGIC - MLR: 13.3 ML
# MAGIC - Instance: `g5.12xlarge` on AWS, `Standard_NV72ads_A10_v5` or `Standard_NC24ads_A100_v4` on Azure
# COMMAND ----------
# To access models in Unity Catalog, ensure that MLflow is up to date
# (the Unity Catalog model registry requires mlflow >= 2.4.1).
%pip install --upgrade "mlflow-skinny[databricks]>=2.4.1"
# Restart the Python process so the freshly installed package version is
# picked up by subsequent cells.
dbutils.library.restartPython()
# COMMAND ----------
import mlflow
# Resolve `models:/...` URIs against the Unity Catalog model registry
# instead of the legacy workspace registry.
mlflow.set_registry_uri("databricks-uc")
catalog_name = "databricks_llama_2_models" # Default catalog name when installing the model from Databricks Marketplace
version = 1  # Marketplace-installed model version to load
# Create a Spark UDF to generate the response to a prompt. The UDF wraps the
# registered pyfunc model; "string" is the declared return column type.
generate = mlflow.pyfunc.spark_udf(
    spark, f"models:/{catalog_name}.models.llama_2_13b_chat_hf/{version}", "string"
)
# COMMAND ----------
# MAGIC %md
# MAGIC The Spark UDF `generate` can run inference on Spark DataFrames.
# COMMAND ----------
import pandas as pd

# Keep the batch to a single input sequence: batching more sequences per
# inference call needs more GPU memory — swap to more powerful GPUs, or use
# Databricks Model Serving, to run larger batches.
example_questions = [
    "What is a large language model?",
    # "Write a short announcement of Llama 2 models in Databricks Marketplace.",
]
questions_df = spark.createDataFrame(pd.DataFrame({"text": example_questions}))
display(questions_df)

# Run the model over the "text" column and surface the output as
# "generated_text".
responses_df = questions_df.select(generate(questions_df.text).alias("generated_text"))
display(responses_df)
# COMMAND ----------
# MAGIC %md
# MAGIC We can also wrap the Spark UDF in a helper function that applies a system prompt and takes/returns lists of text strings.
# COMMAND ----------
# Default Llama 2 chat system prompt.
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
If a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information."""

INTRO_BLURB = "Below is an instruction that describes a task. Write a response that appropriately completes the request."

# Llama 2 chat prompt template with the system prompt already baked in.
# The doubled braces leave a literal `{instruction}` placeholder for later
# `.format(instruction=...)` calls.
PROMPT_FOR_GENERATION_FORMAT = f"""
<s>[INST]<<SYS>>
{DEFAULT_SYSTEM_PROMPT}
<</SYS>>
{{instruction}}
[/INST]
"""
# COMMAND ----------
from typing import List
import pandas as pd


def gen_text(instructions: List[str]) -> List[str]:
    """Generate a model response for each instruction string.

    Each instruction is wrapped in the Llama 2 chat template
    (``PROMPT_FOR_GENERATION_FORMAT``) and sent through the ``generate``
    Spark UDF defined earlier in this notebook.

    :param instructions: instruction texts to answer; an empty list yields
        an empty result.
    :return: one generated string per instruction, with the echoed prompt
        prefix stripped.
    """
    prompts = [
        PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
        for instruction in instructions
    ]
    # `generate` is a Spark UDF that takes a string column as input.
    df = spark.createDataFrame(pd.DataFrame({"text": pd.Series(prompts)}))
    generated_df = df.select(generate(df.text).alias("generated_text"))
    # Get the rows of the 'generated_text' column as a list, and truncate
    # the echoed instruction from each generation.
    return [
        _strip_prompt(str(row.generated_text)) for row in generated_df.collect()
    ]


def _strip_prompt(text: str) -> str:
    """Return the portion of *text* after the ``[/INST]`` delimiter line.

    Falls back to the full text when the delimiter is absent — the original
    ``split(...)[1]`` raised IndexError in that case.
    """
    _, sep, completion = text.partition("[/INST]\n")
    return completion if sep else text
# COMMAND ----------
# To have more than 1 input sequences in the same batch for inference, more GPU memory would be needed; swap to more powerful GPUs if needed, or use Databricks Model Serving
# Ask a single question; the notebook displays the returned list of
# generated strings as the cell result.
gen_text(
    [
        "What is a large language model?",
        # "Write a short announcement of Llama 2 models in Databricks Marketplace.",
    ]
)