From 28e1b53ac9159e42b07f39041cd9506427b1f855 Mon Sep 17 00:00:00 2001 From: Ravi Panchumarthy Date: Sat, 23 Nov 2024 06:27:01 +0000 Subject: [PATCH] Update llm_diffusion_serving_app, fix linter issues --- .../docker/client_app.py | 18 +- .../docker/llm/download_model_llm.py | 2 +- .../docker/llm/llm_handler.py | 21 +- .../docker/llm/model-config.yaml | 3 - .../docker/requirements.txt | 3 +- .../docker/sd-benchmark.py | 366 ++++++++++++------ .../docker/sd/download_model_sd.py | 2 +- .../docker/sd/model-config.yaml | 12 +- .../docker/sd/stable_diffusion_handler.py | 49 +-- .../docker/server_app.py | 32 +- 10 files changed, 302 insertions(+), 206 deletions(-) diff --git a/examples/usecases/llm_diffusion_serving_app/docker/client_app.py b/examples/usecases/llm_diffusion_serving_app/docker/client_app.py index 29613fa723..616bc40522 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/client_app.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/client_app.py @@ -160,10 +160,10 @@ def sd_response_postprocess(response): def preprocess_llm_input(user_prompt, num_images=2): template = """ Below is an instruction that describes a task. Write a response that appropriately completes the request. - Generate {} unique prompts similar to '{}' by changing the context, keeping the core theme intact. + Generate {} unique prompts similar to '{}' by changing the context, keeping the core theme intact. Give the output in square brackets seperated by semicolon. Do not generate text beyond the specified output format. Do not explain your response. - ### Response: + ### Response: """ prompt_template_with_user_input = template.format(num_images, user_prompt) @@ -242,8 +242,8 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt): st.markdown( """ ### Multi-Image Generation App with TorchServe and OpenVINO - Welcome to the Multi-Image Generation Client App. This app allows you to generate multiple images - from a single text prompt. Simply input your prompt, and the app will enhance it using a LLM (Llama) and + Welcome to the Multi-Image Generation Client App. This app allows you to generate multiple images + from a single text prompt. Simply input your prompt, and the app will enhance it using a LLM (Llama) and generate images in parallel using the Stable Diffusion with latent-consistency/lcm-sdxl model. See [GitHub](https://github.com/pytorch/serve/tree/master/examples/usecases/llm_diffusion_serving_app) for details. """, @@ -252,7 +252,7 @@ def generate_llm_model_response(prompt_template_with_user_input, user_prompt): st.image("./img/workflow-2.png") st.markdown( - """
NOTE: Initial image generation may take longer due to model warm-up. Subsequent generations will be faster!
""", @@ -274,7 +274,7 @@ def display_images_in_grid(images, captions): def display_prompts(): - prompt_container.write(f"Generated Prompts:") + prompt_container.write("Generated Prompts:") prompt_list = "" for i, pr in enumerate(st.session_state.llm_prompts, 1): prompt_list += f"{i}. {pr}\n" @@ -304,18 +304,18 @@ def display_prompts(): if not st.session_state.llm_prompts: prompt_container.write( - f"Enter Image Generation Prompt and Click Generate Prompts !" + "Enter Image Generation Prompt and Click Generate Prompts !" ) elif len(st.session_state.llm_prompts) < num_images: prompt_container.warning( - f"""Insufficient prompts. Regenerate prompts ! + f"""Insufficient prompts. Regenerate prompts ! Num Images Requested: {num_images}, Prompts Generated: {len(st.session_state.llm_prompts)} {f"Consider increasing the max_new_tokens parameter !" if num_images > 4 else ""}""", icon="⚠️", ) else: st.success( - f"""{len(st.session_state.llm_prompts)} Prompts ready. + f"""{len(st.session_state.llm_prompts)} Prompts ready. Proceed with image generation or regenerate if needed.""", icon="⬇️", ) diff --git a/examples/usecases/llm_diffusion_serving_app/docker/llm/download_model_llm.py b/examples/usecases/llm_diffusion_serving_app/docker/llm/download_model_llm.py index 42406a9e6c..5148c8df66 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/llm/download_model_llm.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/llm/download_model_llm.py @@ -10,7 +10,7 @@ def dir_path(path_str): if not os.path.isdir(path_str): os.makedirs(path_str) print(f"{path_str} did not exist, created the directory.") - print(f"\nDownload might take a moment to start.. ") + print("\nDownload will take few moments to start.. ") return path_str except Exception as e: raise NotADirectoryError(f"Failed to create directory {path_str}: {e}") diff --git a/examples/usecases/llm_diffusion_serving_app/docker/llm/llm_handler.py b/examples/usecases/llm_diffusion_serving_app/docker/llm/llm_handler.py index 8edb288782..c0033101a9 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/llm/llm_handler.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/llm/llm_handler.py @@ -4,10 +4,9 @@ import logging import time import torch -import openvino.torch +import openvino.torch # noqa: F401 # Import to enable optimizations from OpenVINO from transformers import AutoModelForCausalLM, AutoTokenizer -from pathlib import Path from ts.handler_utils.timer import timed from ts.torch_handler.base_handler import BaseHandler @@ -31,7 +30,6 @@ def __init__(self): def initialize(self, ctx): self.context = ctx self.manifest = ctx.manifest - properties = ctx.system_properties model_store_dir = ctx.model_yaml_config["handler"]["model_store_dir"] model_name_llm = os.environ["MODEL_NAME_LLM"].replace("/", "---") @@ -50,11 +48,16 @@ def initialize(self, ctx): self.tokenizer = AutoTokenizer.from_pretrained(model_dir) self.model = AutoModelForCausalLM.from_pretrained(model_dir) - # Get backend for model-confil.yaml. Defaults to "inductor" - backend = ctx.model_yaml_config.get("pt2", {}).get("backend", "inductor") + # Get backend for model-confil.yaml. 
Defaults to "openvino" + compile_options = {} + pt2_config = ctx.model_yaml_config.get("pt2", {}) + compile_options = { + "backend": pt2_config.get("backend", "openvino"), + "options": pt2_config.get("options", {}), + } + logger.info(f"Loading LLM model with PT2 compiler options: {compile_options}") - logger.info(f"Compiling model with {backend} backend.") - self.model = torch.compile(self.model, backend=backend) + self.model = torch.compile(self.model, **compile_options) self.model.to(self.device) self.model.eval() @@ -67,7 +70,6 @@ def preprocess(self, requests): assert len(requests) == 1, "Llama currently only supported with batch_size=1" req_data = requests[0] - input_data = req_data.get("data") or req_data.get("body") if isinstance(input_data, (bytes, bytearray)): @@ -82,7 +84,6 @@ def preprocess(self, requests): self.device ) - # self.prompt_length = encoded_prompt.size(0) input_data["encoded_prompt"] = encoded_prompt return input_data @@ -119,7 +120,7 @@ def postprocess(self, generated_text): # Initialize with user prompt prompt_list = [self.user_prompt] try: - logger.info(f"Parsing LLM Generated Output to extract prompts within []...") + logger.info("Parsing LLM Generated Output to extract prompts within []...") response_match = re.search(r"\[(.*?)\]", generated_text) # Extract the result if match is found if response_match: diff --git a/examples/usecases/llm_diffusion_serving_app/docker/llm/model-config.yaml b/examples/usecases/llm_diffusion_serving_app/docker/llm/model-config.yaml index 2e13179a56..88fbc10b58 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/llm/model-config.yaml +++ b/examples/usecases/llm_diffusion_serving_app/docker/llm/model-config.yaml @@ -9,6 +9,3 @@ pt2: handler: profile: true model_store_dir: "/home/model-server/model-store/" - max_new_tokens: 40 - compile: true - fx_graph_cache: true diff --git a/examples/usecases/llm_diffusion_serving_app/docker/requirements.txt b/examples/usecases/llm_diffusion_serving_app/docker/requirements.txt index 90ab5f271c..1bbea6c403 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/requirements.txt +++ b/examples/usecases/llm_diffusion_serving_app/docker/requirements.txt @@ -1,7 +1,8 @@ +--extra-index-url https://download.pytorch.org/whl/cpu transformers streamlit>=1.26.0 requests_futures -asyncio aiohttp accelerate tabulate +torch>=2.5.1 \ No newline at end of file diff --git a/examples/usecases/llm_diffusion_serving_app/docker/sd-benchmark.py b/examples/usecases/llm_diffusion_serving_app/docker/sd-benchmark.py index 2aad51e6b8..34ea23272c 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/sd-benchmark.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/sd-benchmark.py @@ -1,3 +1,25 @@ +""" +Stable Diffusion Benchmark Script. 
+ +Prerequisites: +- See https://github.com/pytorch/serve/tree/master/examples/usecases/llm_diffusion_serving_app/README.md +- This script assumes models are available in the mounted volume at + /home/model-server/model-store/stabilityai---stable-diffusion-xl-base-1.0/model + +This script benchmarks Stable Diffusion model inference across different execution modes: +- Eager mode (standard PyTorch) +- Torch.compile with Inductor backend +- Torch.compile with OpenVINO backend + +Results are saved in a timestamped directory including: +- JSON file with complete benchmark data +- Generated images for each mode +- Profiling data when enabled + +Usage: + python sd-benchmark.py [--num_iter N] [--run_profiling] +""" + import argparse import importlib.metadata import json @@ -12,123 +34,155 @@ from typing import Dict, Tuple, List import torch -import openvino.torch +from torch.profiler import profile, record_function, ProfilerActivity +import openvino.torch # noqa: F401 # Import to enable optimizations from OpenVINO from PIL import Image from diffusers import UNet2DConditionModel, DiffusionPipeline, LCMScheduler + class RunMode(Enum): EAGER = "eager" TC_INDUCTOR = "tc_inductor" TC_OPENVINO = "tc_openvino" + def setup_pipeline(run_mode: str, ckpt: str, dtype=torch.float16) -> DiffusionPipeline: - """ - Setup the diffusion pipeline based on run mode configuration - - Args: - run_mode: One of 'eager', 'tc_inductor', or 'tc_openvino' - ckpt: Path to the model checkpoint - dtype: Model dtype - """ + """Setup function remains unchanged""" print(f"\nInitializing pipeline with mode: {run_mode}") - - # Set compile options based on run mode + if run_mode == RunMode.TC_OPENVINO.value: compile_options = { - 'backend': 'openvino', - 'options': {'device': 'CPU', 'config': {'PERFORMANCE_HINT': 'LATENCY'}} + "backend": "openvino", + "options": {"device": "CPU", "config": {"PERFORMANCE_HINT": "LATENCY"}}, } print(f"Using OpenVINO backend with options: {compile_options}") elif run_mode == RunMode.TC_INDUCTOR.value: - compile_options = {'backend': 'inductor', 'options': {}} + compile_options = {"backend": "inductor", "options": {}} print(f"Using Inductor backend with options: {compile_options}") else: # eager mode compile_options = {} print("Using eager mode (no compilation)") - - # Initialize models + unet = UNet2DConditionModel.from_pretrained(f"{ckpt}/lcm/", torch_dtype=dtype) pipe = DiffusionPipeline.from_pretrained(ckpt, unet=unet, torch_dtype=dtype) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) - - # Apply compilation if using torch.compile + if run_mode != RunMode.EAGER.value: - print("Compiling models...") + print("Torch.Compiling models...") pipe.text_encoder = torch.compile(pipe.text_encoder, **compile_options) pipe.unet = torch.compile(pipe.unet, **compile_options) pipe.vae.decode = torch.compile(pipe.vae.decode, **compile_options) - + pipe.to("cpu") return pipe -def run_inference(pipe: DiffusionPipeline, params: Dict, iteration: int = 0) -> Tuple[Image.Image, float]: - """Run inference and measure time""" - start_time = time.time() - image = pipe( - params["prompt"], - num_inference_steps=params["num_inference_steps"], - guidance_scale=params["guidance_scale"], - height=params["height"], - width=params["width"], - ).images[0] - end_time = time.time() - + +def run_inference( + pipe: DiffusionPipeline, + params: Dict, + iteration: int = 0, + enable_profiling: bool = False, +) -> Tuple[Image.Image, float, Dict]: + """Run inference and measure time, with optional profiling""" + 
profiler_output = None + + if enable_profiling: + print(f"\nRunning inference with profiling for iteration: {iteration}") + with profile( + activities=[ProfilerActivity.CPU], profile_memory=True, record_shapes=True + ) as prof: + with record_function("model_inference"): + start_time = time.time() + image = pipe( + params["prompt"], + num_inference_steps=params["num_inference_steps"], + guidance_scale=params["guidance_scale"], + height=params["height"], + width=params["width"], + ).images[0] + end_time = time.time() + + # Process profiler data + profiler_output = { + "cpu_time": prof.key_averages().table( + sort_by="cpu_time_total", row_limit=20 + ), + "memory": prof.key_averages().table( + sort_by="cpu_memory_usage", row_limit=20 + ), + } + else: + print(f"\nRunning inference for iteration: {iteration}") + start_time = time.time() + image = pipe( + params["prompt"], + num_inference_steps=params["num_inference_steps"], + guidance_scale=params["guidance_scale"], + height=params["height"], + width=params["width"], + ).images[0] + end_time = time.time() + execution_time = end_time - start_time print(f"Iteration {iteration} execution time: {execution_time:.2f} seconds") - - return image, execution_time -def run_benchmark(run_mode: str, params: Dict, num_iter: int) -> Dict: + return image, execution_time, profiler_output + + +def run_benchmark( + run_mode: str, params: Dict, num_iter: int, enable_profiling: bool = False +) -> Dict: """Run a single benchmark configuration with multiple iterations""" - out_dir = "/home/model-server/model-store/" - + try: - pipe = setup_pipeline( - run_mode, - params["ckpt"], - params["dtype"] - ) - + pipe = setup_pipeline(run_mode, params["ckpt"]) + # Warm-up run print("\nPerforming warm-up run...") - warmup_image, warmup_time = run_inference(pipe, params, iteration=0) - + warmup_image, warmup_time, _ = run_inference(pipe, params, iteration=0) + # Benchmark iterations print(f"\nRunning {num_iter} benchmark iterations...") iteration_times = [] final_image = None - - for i in range(num_iter): - image, exec_time = run_inference(pipe, params, iteration=i+1) + profiler_data = None + + for i in range(1, num_iter + 1): + image, exec_time, profiler_data = run_inference( + pipe, + params, + iteration=i, + enable_profiling=( + enable_profiling and i == 1 + ), # if profile is enabled, run for 1 iteration only + ) + iteration_times.append(exec_time) - if i == num_iter - 1: + + if i == num_iter: # Save final image from the last iteration final_image = image - + # Calculate statistics stats = { "mean": float(np.mean(iteration_times)), "std": float(np.std(iteration_times)), - "all_iterations": iteration_times + "all_iterations": iteration_times, } - - # Save images - final_image_filename = f"image-{run_mode}-final.png" - final_image.save(os.path.join(out_dir, final_image_filename)) - - return { + + benchmark_results = { "run_mode": run_mode, "warmup_time": warmup_time, "statistics": stats, - "final_image": final_image_filename, - "status": "success" + "final_image": final_image, + "profiler_data": profiler_data if profiler_data else None, + "status": "success", } + + return benchmark_results except Exception as e: print(f"Error during benchmark: {str(e)}") - return { - "run_mode": run_mode, - "status": "failed", - "error": str(e) - } + return {"run_mode": run_mode, "status": "failed", "error": str(e)} + def get_hw_config(): output = subprocess.check_output(["lscpu"]).decode("utf-8") @@ -147,19 +201,22 @@ def get_hw_config(): elif line.startswith("Socket(s):"): socket_count = 
line.split("Socket(s):")[1].strip() - output = subprocess.check_output(["head", "-n", "1", "/proc/meminfo"]).decode("utf-8") - total_memory = int(output.split()[1]) / (1024.0 ** 2) + output = subprocess.check_output(["head", "-n", "1", "/proc/meminfo"]).decode( + "utf-8" + ) + total_memory = int(output.split()[1]) / (1024.0**2) total_memory_str = f"{total_memory:.2f} GB" - + return { "cpu_model": cpu_model, "cpu_count": cpu_count, "threads_per_core": threads_per_core, "cores_per_socket": cores_per_socket, "socket_count": socket_count, - "total_memory": total_memory_str + "total_memory": total_memory_str, } - + + def get_sw_versions(): sw_versions = {} packages = [ @@ -167,112 +224,177 @@ def get_sw_versions(): ("OpenVINO", "openvino"), ("PyTorch", "torch"), ("Transformers", "transformers"), - ("Diffusers", "diffusers") + ("Diffusers", "diffusers"), ] sw_versions["Python"] = sys.version.split()[0] - + for name, package in packages: try: version = importlib.metadata.version(package) sw_versions[name] = version except Exception as e: sw_versions[name] = "Not installed" + print(f"Exception trying to get {package} version. Error: {e}") return sw_versions -def save_results(results: List[Dict], hw_config: List[Dict], sw_versions: List[Dict]): - """Save benchmark results to a JSON file""" + +def save_results_1(results: List[Dict], hw_config: Dict, sw_versions: Dict): + """ + Save benchmark results to a timestamped directory + + Args: + results: List of benchmark results for different run modes + hw_config: Dictionary containing hardware configuration details + sw_versions: Dictionary containing software version information + """ out_dir = "/home/model-server/model-store/" - filename = f"sd_benchmark_results_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + + # Create results directory with timestamp + results_dir = os.path.join(out_dir, f"benchmark_results_{timestamp}") + os.makedirs(results_dir, exist_ok=True) + + # Save main results JSON + results_file = os.path.join(results_dir, "benchmark_results.json") + benchmark_data = [ + {k: result.get(k) for k in ["run_mode", "warmup_time", "statistics"]} + for result in results + ] + + full_results = { + "timestamp": datetime.now().isoformat(), + "hardware_config": hw_config, + "software_versions": sw_versions, + "benchmark_results": benchmark_data, + } + + with open(results_file, "w") as f: + json.dump(full_results, f, indent=2) + + # Copy images and profiler data for each run mode + for result in results: + if result["status"] == "success": + run_mode = result["run_mode"] + + # Save the final image + if result.get("final_image"): + image_filename = f"image-{run_mode}-final.png" + result.get("final_image").save( + os.path.join(results_dir, image_filename) + ) + + # Save profiler data + if result.get("profiler_data"): + profiler_data = result.get("profiler_data") + profiler_filename = f"profile-{run_mode}.txt" + with open(os.path.join(results_dir, profiler_filename), "w") as f: + f.write( + "CPU Time Profile (sort_by='cpu_time_total', row_limit=20):\n" + ) + f.write(profiler_data["cpu_time"]) + f.write( + "\n\nMemory Usage Profile (sort_by='cpu_memory_usage', row_limit=20):\n" + ) + f.write(profiler_data["memory"]) + + print(f"\nResults saved in directory: {results_dir}") + print(f"Files in the {results_dir} directory:") + for file in sorted(os.listdir(results_dir)): + print(file) - data = {"hw_config": hw_config, "sw_versions": sw_versions, "results": results} - with open(os.path.join(out_dir, 
filename), 'w') as f: - json.dump(data, f, indent=2) - - print(f"\nResults saved to {os.path.join(out_dir, filename)}") - def main(): # Parse command-line args - parser = argparse.ArgumentParser(description='Stable Diffusion Benchmark script') - parser.add_argument('-ni', '--num_iter', type=int, default=3, help='Number of benchmark iterations') + parser = argparse.ArgumentParser(description="Stable Diffusion Benchmark script") + parser.add_argument( + "-ni", "--num_iter", type=int, default=3, help="Number of benchmark iterations" + ) + parser.add_argument( + "-rp", + "--run_profiling", + action="store_true", + help="Run benchmark with profiling", + ) args = parser.parse_args() - + # Number of benchmark iterations - num_iter = args.num_iter + num_iter = 1 if args.run_profiling else args.num_iter + out_dir = "/home/model-server/model-store/" - - # Run modes to test + run_modes = [ RunMode.EAGER.value, RunMode.TC_INDUCTOR.value, - RunMode.TC_OPENVINO.value + RunMode.TC_OPENVINO.value, ] - - # Parameters + params = { "ckpt": "/home/model-server/model-store/stabilityai---stable-diffusion-xl-base-1.0/model", "guidance_scale": 5.0, "num_inference_steps": 4, "height": 768, "width": 768, - "prompt": "a close-up picture of an old man standing in the rain", - "dtype": torch.float16 + "prompt": "A close-up HD shot of a vibrant macaw parrot perched on a branch in a lush jungle ", + "dtype": torch.float16, } - - # Run benchmarks + # params["prompt"] = "A close-up of a blooming cherry blossom tree in full bloom" + results = [] for mode in run_modes: - print("\n" + "="*50) - print(f"Running benchmark with run mode: {mode}") - print(f"Number of iterations: {num_iter}") - print("="*50) - - result = run_benchmark(mode, params, num_iter) + print("\n" + "=" * 80) + print( + f"Running benchmark with run mode: {mode}, num_iter: {num_iter}, run_profiling: {args.run_profiling}" + ) + print("=" * 80) + result = run_benchmark(mode, params, num_iter, args.run_profiling) results.append(result) - - print("-"*50) + print("-" * 80) # Hardware and Software Info print("\nHardware Info:") - print("-"*50) + print("-" * 80) hw_config = get_hw_config() for key, value in hw_config.items(): print(f"{key}: {value}") - + print("\nSoftware Versions:") sw_versions = get_sw_versions() - print("-"*50) + print("-" * 80) for name, version in sw_versions.items(): print(f"{name}: {version}") - + # Print summary print("\nBenchmark Summary:") - print("-"*50) + print("-" * 80) table_data = [] for result in results: if result["status"] == "success": - table_data.append([ - result['run_mode'], + row = [ + result["run_mode"], f"{result['warmup_time']:.2f} seconds", f"{result['statistics']['mean']:.2f} +/- {result['statistics']['std']:.2f} seconds", - result['final_image'] - ]) + ] else: - table_data.append([ - result['run_mode'], + row = [ + result["run_mode"], "Failed", - result['error'], - "-" - ]) - - headers = ["Run Mode", "Warm-up Time", f"Average Time for {num_iter} iter", "Image Saved as "] + result["error"], + ] + table_data.append(row) + + headers = ["Run Mode", "Warm-up Time", f"Average Time for {num_iter} iter"] print(tabulate(table_data, headers=headers, tablefmt="grid")) - + # Save results - save_results(results, hw_config, sw_versions) - print(f"\nResults and Images saved at {out_dir} which is a Docker container mount, corresponds to 'serve/model-store-local/' on the host machine.\n") - + save_results_1(results, hw_config, sw_versions) + if args.run_profiling: + print("\nnum_iter is set to 1 as run_profiling flag is enabled !") + 
print( + f"\nResults saved at {out_dir} which is a Docker container mount, corresponds to 'serve/model-store-local/' on the host machine.\n" + ) + + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/examples/usecases/llm_diffusion_serving_app/docker/sd/download_model_sd.py b/examples/usecases/llm_diffusion_serving_app/docker/sd/download_model_sd.py index 8123f168cf..5374caa267 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/sd/download_model_sd.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/sd/download_model_sd.py @@ -10,7 +10,7 @@ def dir_path(path_str): if not os.path.isdir(path_str): os.makedirs(path_str) print(f"{path_str} did not exist, created the directory.") - print(f"\nDownload might take a moment to start.. ") + print("\nDownload will take few moments to start.. ") return path_str except Exception as e: raise NotADirectoryError(f"Failed to create directory {path_str}: {e}") diff --git a/examples/usecases/llm_diffusion_serving_app/docker/sd/model-config.yaml b/examples/usecases/llm_diffusion_serving_app/docker/sd/model-config.yaml index 50cdea6d1d..aa35bc28dc 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/sd/model-config.yaml +++ b/examples/usecases/llm_diffusion_serving_app/docker/sd/model-config.yaml @@ -13,15 +13,5 @@ pt2: handler: profile: true model_path: "model" - num_inference_steps: 5 - compile_unet: true - compile_mode: "max-autotune" - compile_vae: true - enable_fused_projections: true - do_quant: false - change_comp_config: false - no_bf16: true - no_sdpa: false - upcast_vae: false - is_xl: true + is_xl: false is_lcm: true diff --git a/examples/usecases/llm_diffusion_serving_app/docker/sd/stable_diffusion_handler.py b/examples/usecases/llm_diffusion_serving_app/docker/sd/stable_diffusion_handler.py index d480a1128a..f12f2d7ddf 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/sd/stable_diffusion_handler.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/sd/stable_diffusion_handler.py @@ -1,10 +1,11 @@ import logging +import time import os from pathlib import Path import numpy as np import json import torch -import openvino.torch +import openvino.torch # noqa: F401 # Import to enable optimizations from OpenVINO from diffusers import ( DiffusionPipeline, StableDiffusionXLPipeline, @@ -40,19 +41,12 @@ def initialize(self, ctx): model_dir = properties.get("model_dir") self.device = ctx.model_yaml_config["deviceType"] - self.num_inference_steps = ctx.model_yaml_config["handler"][ - "num_inference_steps" - ] logger.info(f"SD ctx.model_yaml_config is {ctx.model_yaml_config}") logger.info(f"SD ctx.system_properties is {ctx.system_properties}") logger.info(f"SD device={self.device}") # Parameters for the model - compile_unet = ctx.model_yaml_config["handler"]["compile_unet"] - compile_vae = ctx.model_yaml_config["handler"]["compile_vae"] - compile_mode = ctx.model_yaml_config["handler"]["compile_mode"] - change_comp_config = ctx.model_yaml_config["handler"]["change_comp_config"] is_xl = ctx.model_yaml_config["handler"]["is_xl"] is_lcm = ctx.model_yaml_config["handler"]["is_lcm"] @@ -69,10 +63,9 @@ def initialize(self, ctx): ckpt = os.path.join(model_dir, model_path) """Loads the SDXL LCM pipeline.""" - dtype = torch.float16 logger.info(f"Loading the SDXL LCM pipeline using dtype: {dtype}") - + t0 = time.time() if is_lcm: unet = UNet2DConditionModel.from_pretrained( f"{ckpt}/lcm/", torch_dtype=dtype @@ -80,6 +73,8 @@ def initialize(self, ctx): pipe = 
DiffusionPipeline.from_pretrained(ckpt, unet=unet, torch_dtype=dtype) pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config) pipe.text_encoder = torch.compile(pipe.text_encoder, **compile_options) + pipe.unet = torch.compile(pipe.unet, **compile_options) + pipe.vae.decode = torch.compile(pipe.vae.decode, **compile_options) elif is_xl: pipe = StableDiffusionXLPipeline.from_pretrained( @@ -90,33 +85,15 @@ def initialize(self, ctx): ckpt, torch_dtype=dtype, use_safetensors=True, safety_checker=None ) - if compile_unet: - logger.info("Compiling UNet.") - if compile_mode == "max-autotune" and change_comp_config: - pipe.unet.to(memory_format=torch.channels_last) - torch._inductor.config.conv_1x1_as_mm = True - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.epilogue_fusion = False - torch._inductor.config.coordinate_descent_check_all_directions = True - - pipe.unet = torch.compile(pipe.unet, **compile_options) - - if compile_vae: - logger.info("Compiling VAE.") - if compile_mode == "max-autotune" and change_comp_config: - pipe.vae.to(memory_format=torch.channels_last) - torch._inductor.config.conv_1x1_as_mm = True - torch._inductor.config.coordinate_descent_tuning = True - torch._inductor.config.epilogue_fusion = False - torch._inductor.config.coordinate_descent_check_all_directions = True - - pipe.vae.decode = torch.compile(pipe.vae.decode, **compile_options) - - logger.info(f"Compiled {ckpt} model with {compile_options}") + logger.info( + f"Compiled {ckpt} model with PT2 compiler options: {compile_options}" + ) pipe.set_progress_bar_config(disable=True) self.pipeline = pipe - logger.info(f"Stable Diffusion model loaded successfully: {ckpt}") + logger.info( + f"Time to load Stable Diffusion model: {ckpt}: {time.time() - t0:.02f} seconds" + ) self.initialized = True return pipe @@ -155,8 +132,8 @@ def inference(self, model_inputs): list : It returns a list of the generate images for the input text """ # Handling inference for sequence_classification. 
- guidance_scale = model_inputs.get("guidance_scale") or 5.0 - num_inference_steps = model_inputs.get("num_inference_steps") or 5 + guidance_scale = model_inputs.get("guidance_scale") or 4.0 + num_inference_steps = model_inputs.get("num_inference_steps") or 4 height = model_inputs.get("height") or 768 width = model_inputs.get("width") or 768 inferences = self.pipeline( diff --git a/examples/usecases/llm_diffusion_serving_app/docker/server_app.py b/examples/usecases/llm_diffusion_serving_app/docker/server_app.py index 11b36d6457..7422ef2eb7 100644 --- a/examples/usecases/llm_diffusion_serving_app/docker/server_app.py +++ b/examples/usecases/llm_diffusion_serving_app/docker/server_app.py @@ -24,7 +24,7 @@ # Init Session State variables st.session_state.started = st.session_state.get("started", False) -st.session_state.stopped = st.session_state.get("stopped", True) +st.session_state.stopped = st.session_state.get("stopped", True) st.session_state.registered = st.session_state.get( "registered", { @@ -33,6 +33,7 @@ }, ) + def is_server_running(): """Check if the TorchServe server is running.""" try: @@ -41,6 +42,7 @@ def is_server_running(): except requests.exceptions.ConnectionError: return False + def init_model_registrations(): for model_name in [MODEL_NAME_LLM, MODEL_NAME_SD]: url = f"http://localhost:8081/models/{model_name}" @@ -54,15 +56,17 @@ def init_model_registrations(): logger.info(f"Error checking model registration: {e}") st.session_state.registered[model_name] = False + # Update Session State variables if is_server_running(): st.session_state.started = True - st.session_state.stopped = False + st.session_state.stopped = False init_model_registrations() def start_torchserve_server(): """Starts the TorchServe server if it's not already running.""" + def launch_server(): """Launch the TorchServe server with the specified configurations.""" subprocess.run( @@ -240,19 +244,22 @@ def get_hw_config_output(): elif line.startswith("Socket(s):"): socket_count = line.split("Socket(s):")[1].strip() - output = subprocess.check_output(["head", "-n", "1", "/proc/meminfo"]).decode("utf-8") - total_memory = int(output.split()[1]) / (1024.0 ** 2) + output = subprocess.check_output(["head", "-n", "1", "/proc/meminfo"]).decode( + "utf-8" + ) + total_memory = int(output.split()[1]) / (1024.0**2) total_memory_str = f"{total_memory:.2f} GB" - + return { "cpu_model": cpu_model, "cpu_count": cpu_count, "threads_per_core": threads_per_core, "cores_per_socket": cores_per_socket, "socket_count": socket_count, - "total_memory": total_memory_str + "total_memory": total_memory_str, } + def get_sw_versions(): sw_versions = {} packages = [ @@ -260,17 +267,18 @@ def get_sw_versions(): ("OpenVINO", "openvino"), ("PyTorch", "torch"), ("Transformers", "transformers"), - ("Diffusers", "diffusers") + ("Diffusers", "diffusers"), ] sw_versions["Python"] = sys.version.split()[0] - + for name, package in packages: try: version = importlib.metadata.version(package) sw_versions[name] = version except Exception as e: sw_versions[name] = "Not installed" + print(f"Exception trying to get {package} version. 
Error: {e}") return sw_versions @@ -282,7 +290,7 @@ def get_sw_versions(): st.button("Start TorchServe Server", on_click=start_torchserve_server) st.button("Stop TorchServe Server", on_click=stop_server) st.button( - f"Register Models", + "Register Models", on_click=register_models, args=([MODEL_NAME_LLM, MODEL_NAME_SD],), ) @@ -323,15 +331,15 @@ def get_sw_versions(): st.markdown( """ ### Multi-Image Generation App Control Center - Manage the Multi-Image Generation App workflow with this administrative interface. - Use this app to Start/stop TorchServe, load/register models, scale up/down workers, + Manage the Multi-Image Generation App workflow with this administrative interface. + Use this app to Start/stop TorchServe, load/register models, scale up/down workers, and review TorchServe Server and Model info. See [GitHub](https://github.com/pytorch/serve/tree/master/examples/usecases/llm_diffusion_serving_app) for details. """, unsafe_allow_html=True, ) st.markdown( - """
NOTE: After starting TorchServe and registering the models, proceed to the Client App running on port 8085.
""",