Update llm_diffusion_serving_app, fix linter issues
ravi9 committed Nov 23, 2024
1 parent ba4a2fe commit 28e1b53
Showing 10 changed files with 302 additions and 206 deletions.
18 changes: 9 additions & 9 deletions examples/usecases/llm_diffusion_serving_app/docker/client_app.py
@@ -160,10 +160,10 @@ def sd_response_postprocess(response):

def preprocess_llm_input(user_prompt, num_images=2):
template = """ Below is an instruction that describes a task. Write a response that appropriately completes the request.
-Generate {} unique prompts similar to '{}' by changing the context, keeping the core theme intact.
+Generate {} unique prompts similar to '{}' by changing the context, keeping the core theme intact.
Give the output in square brackets seperated by semicolon.
Do not generate text beyond the specified output format. Do not explain your response.
-### Response:
+### Response:
"""

prompt_template_with_user_input = template.format(num_images, user_prompt)
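
A hypothetical, self-contained illustration of how this template is filled in and what shape of reply the client later parses (the user prompt and the sample reply are invented):

# Illustration only: abbreviated copy of the instruction template shown above.
template = (
    "Generate {} unique prompts similar to '{}' by changing the context, "
    "keeping the core theme intact. Give the output in square brackets separated by semicolon."
)
user_prompt = "a cat wearing sunglasses"  # made-up user input
num_images = 2

prompt_template_with_user_input = template.format(num_images, user_prompt)
print(prompt_template_with_user_input)
# A well-behaved LLM reply would then look like:
# [a cat wearing sunglasses on a beach; a cat wearing sunglasses in a neon-lit city]
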
@@ -242,8 +242,8 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
st.markdown(
"""
### Multi-Image Generation App with TorchServe and OpenVINO
-Welcome to the Multi-Image Generation Client App. This app allows you to generate multiple images
-from a single text prompt. Simply input your prompt, and the app will enhance it using a LLM (Llama) and
+Welcome to the Multi-Image Generation Client App. This app allows you to generate multiple images
+from a single text prompt. Simply input your prompt, and the app will enhance it using a LLM (Llama) and
generate images in parallel using the Stable Diffusion with latent-consistency/lcm-sdxl model.
See [GitHub](https://github.com/pytorch/serve/tree/master/examples/usecases/llm_diffusion_serving_app) for details.
""",
@@ -252,7 +252,7 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt):
st.image("./img/workflow-2.png")

st.markdown(
-"""<div style='background-color: #232628; font-size: 14px; padding: 10px;
+"""<div style='background-color: #232628; font-size: 14px; padding: 10px;
border: 1px solid #ddd; border-radius: 5px;'>
NOTE: Initial image generation may take longer due to model warm-up. Subsequent generations will be faster !
</div>""",
@@ -274,7 +274,7 @@ def display_images_in_grid(images, captions):


def display_prompts():
-prompt_container.write(f"Generated Prompts:")
+prompt_container.write("Generated Prompts:")
prompt_list = ""
for i, pr in enumerate(st.session_state.llm_prompts, 1):
prompt_list += f"{i}. {pr}\n"
@@ -304,18 +304,18 @@ def display_prompts():

if not st.session_state.llm_prompts:
prompt_container.write(
-f"Enter Image Generation Prompt and Click Generate Prompts !"
+"Enter Image Generation Prompt and Click Generate Prompts !"
)
elif len(st.session_state.llm_prompts) < num_images:
prompt_container.warning(
-f"""Insufficient prompts. Regenerate prompts !
+f"""Insufficient prompts. Regenerate prompts !
Num Images Requested: {num_images}, Prompts Generated: {len(st.session_state.llm_prompts)}
{f"Consider increasing the max_new_tokens parameter !" if num_images > 4 else ""}""",
icon="⚠️",
)
else:
st.success(
-f"""{len(st.session_state.llm_prompts)} Prompts ready.
+f"""{len(st.session_state.llm_prompts)} Prompts ready.
Proceed with image generation or regenerate if needed.""",
icon="⬇️",
)
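
The app description above mentions that images are generated in parallel through TorchServe. A rough sketch of that pattern, assuming TorchServe's standard inference endpoint POST /predictions/<model_name>; the host, model name, and payload format are placeholders, not taken from this repository:

# Hypothetical sketch: fan several Stable Diffusion prompts out to a TorchServe
# endpoint concurrently. URL, model name, and payload shape are assumptions.
import asyncio
import aiohttp

TS_URL = "http://localhost:8080/predictions/stable-diffusion"  # assumed endpoint

async def generate_image(session, prompt):
    # Each request returns the raw response body (e.g., image bytes).
    async with session.post(TS_URL, data=prompt) as resp:
        return await resp.read()

async def generate_all(prompts):
    async with aiohttp.ClientSession() as session:
        return await asyncio.gather(*(generate_image(session, p) for p in prompts))

# Example: images = asyncio.run(generate_all(["prompt 1", "prompt 2"]))
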
@@ -10,7 +10,7 @@ def dir_path(path_str):
if not os.path.isdir(path_str):
os.makedirs(path_str)
print(f"{path_str} did not exist, created the directory.")
-print(f"\nDownload might take a moment to start.. ")
+print("\nDownload will take a few moments to start.. ")
return path_str
except Exception as e:
raise NotADirectoryError(f"Failed to create directory {path_str}: {e}")
@@ -4,10 +4,9 @@
import logging
import time
import torch
-import openvino.torch
+import openvino.torch  # noqa: F401 # Import to enable optimizations from OpenVINO
from transformers import AutoModelForCausalLM, AutoTokenizer

-from pathlib import Path
from ts.handler_utils.timer import timed
from ts.torch_handler.base_handler import BaseHandler

@@ -31,7 +30,6 @@ def __init__(self):
def initialize(self, ctx):
self.context = ctx
self.manifest = ctx.manifest
-properties = ctx.system_properties

model_store_dir = ctx.model_yaml_config["handler"]["model_store_dir"]
model_name_llm = os.environ["MODEL_NAME_LLM"].replace("/", "---")
@@ -50,11 +48,16 @@ def initialize(self, ctx):
self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
self.model = AutoModelForCausalLM.from_pretrained(model_dir)

-# Get backend for model-confil.yaml. Defaults to "inductor"
-backend = ctx.model_yaml_config.get("pt2", {}).get("backend", "inductor")
+# Get backend from model-config.yaml. Defaults to "openvino"
+compile_options = {}
+pt2_config = ctx.model_yaml_config.get("pt2", {})
+compile_options = {
+    "backend": pt2_config.get("backend", "openvino"),
+    "options": pt2_config.get("options", {}),
+}
+logger.info(f"Loading LLM model with PT2 compiler options: {compile_options}")

-logger.info(f"Compiling model with {backend} backend.")
-self.model = torch.compile(self.model, backend=backend)
+self.model = torch.compile(self.model, **compile_options)

self.model.to(self.device)
self.model.eval()
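
For context, a minimal standalone sketch of the compile path the handler now uses (not the repository's code): importing openvino.torch registers the "openvino" torch.compile backend, and the backend/options values mirror what the handler reads from the pt2 section of model-config.yaml. The model ID and the options dict below are placeholders:

# Minimal sketch; model ID and compile options are assumed values.
import torch
import openvino.torch  # noqa: F401  # registers the "openvino" torch.compile backend
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("some-org/some-causal-lm")  # placeholder

pt2_config = {"backend": "openvino", "options": {"device": "CPU"}}  # e.g. from model-config.yaml
model = torch.compile(
    model,
    backend=pt2_config.get("backend", "openvino"),
    options=pt2_config.get("options", {}),
)
model.eval()
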
@@ -67,7 +70,6 @@ def preprocess(self, requests):
assert len(requests) == 1, "Llama currently only supported with batch_size=1"

req_data = requests[0]

input_data = req_data.get("data") or req_data.get("body")

if isinstance(input_data, (bytes, bytearray)):
@@ -82,7 +84,6 @@ def preprocess(self, requests):
self.device
)

-# self.prompt_length = encoded_prompt.size(0)
input_data["encoded_prompt"] = encoded_prompt

return input_data
@@ -119,7 +120,7 @@ def postprocess(self, generated_text):
# Initialize with user prompt
prompt_list = [self.user_prompt]
try:
-logger.info(f"Parsing LLM Generated Output to extract prompts within []...")
+logger.info("Parsing LLM Generated Output to extract prompts within []...")
response_match = re.search(r"\[(.*?)\]", generated_text)
# Extract the result if match is found
if response_match:
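
To make the parsing step above concrete, a small standalone sketch of one way to extract semicolon-separated prompts from a reply wrapped in square brackets (the sample text is invented; only the re.search pattern is taken from the handler):

import re

generated_text = "Sure! [a red fox in the snow; a red fox in a summer meadow]"  # made-up LLM output
prompt_list = ["a red fox"]  # seeded with the original user prompt, as in postprocess()

response_match = re.search(r"\[(.*?)\]", generated_text)
if response_match:
    prompt_list += [p.strip() for p in response_match.group(1).split(";") if p.strip()]

print(prompt_list)
# ['a red fox', 'a red fox in the snow', 'a red fox in a summer meadow']
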
@@ -9,6 +9,3 @@ pt2:
handler:
profile: true
model_store_dir: "/home/model-server/model-store/"
-max_new_tokens: 40
-compile: true
-fx_graph_cache: true
@@ -1,7 +1,8 @@
--extra-index-url https://download.pytorch.org/whl/cpu
transformers
streamlit>=1.26.0
requests_futures
asyncio
aiohttp
accelerate
tabulate
torch>=2.5.1
