diff --git a/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py b/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py index 6e847d528..0cd5fc0ef 100644 --- a/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py +++ b/06_gpu_and_ml/stable_diffusion/stable_diffusion_cli.py @@ -122,7 +122,9 @@ def __enter__(self): __build__ = __enter__ @method() - def run_inference(self, prompt: str, steps: int = 20, batch_size: int = 4): + def run_inference( + self, prompt: str, steps: int = 20, batch_size: int = 4 + ) -> list[bytes]: with torch.inference_mode(): with torch.autocast("cuda"): images = self.pipe(