diff --git a/fooocusapi/utils/lora_manager.py b/fooocusapi/utils/lora_manager.py index fcff29a..1053bd7 100644 --- a/fooocusapi/utils/lora_manager.py +++ b/fooocusapi/utils/lora_manager.py @@ -1,19 +1,12 @@ -""" -Manager loras from url - -@author: TechnikMax -@github: https://github.com/TechnikMax -""" import hashlib import os import requests - +import tarfile def _hash_url(url): """Generates a hash value for a given URL.""" return hashlib.md5(url.encode('utf-8')).hexdigest() - class LoraManager: """ Manager loras from url @@ -26,14 +19,14 @@ def __init__(self): def _download_lora(self, url): """ - Downloads a LoRa from a URL and saves it in the cache. + Downloads a LoRa from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file. """ url_hash = _hash_url(url) - filepath = os.path.join(self.cache_dir, f"{url_hash}.safetensors") - file_name = f"{url_hash}.safetensors" + file_ext = url.split('.')[-1] + filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}") if not os.path.exists(filepath): - print(f"start download for: {url}") + print(f"Start download for: {url}") try: response = requests.get(url, timeout=10, stream=True) @@ -41,14 +34,33 @@ def _download_lora(self, url): with open(filepath, 'wb') as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) - print(f"Download successfully, saved as {file_name}") + + if file_ext == "tar": + print("Extracting the tar file...") + with tarfile.open(filepath, 'r:*') as tar: + tar.extractall(path=self.cache_dir) + print("Extraction completed.") + return self._find_safetensors_file(self.cache_dir) + print(f"Download successfully, saved as {filepath}") except Exception as e: - raise Exception(f"error downloading {url}: {e}") from e + raise Exception(f"Error downloading {url}: {e}") from e else: print(f"LoRa already downloaded {url}") - return file_name + + return filepath + + def _find_safetensors_file(self, directory): + """ + Finds the first .safetensors file in the specified directory. + """ + print("Searching for .safetensors file.") + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith('.safetensors'): + return os.path.join(root, file) + raise FileNotFoundError("No .safetensors file found in the extracted files.") def check(self, urls): """Manages the specified LoRAs: downloads missing ones and returns their file names.""" diff --git a/fooocusapi/worker.py b/fooocusapi/worker.py index 991c0be..4bab45c 100644 --- a/fooocusapi/worker.py +++ b/fooocusapi/worker.py @@ -1013,6 +1013,7 @@ def callback(step, x0, x, total_steps, y): results += imgs except model_management.InterruptProcessingException as e: logger.std_warn("[Fooocus] User stopped") + results = [] results.append(ImageGenerationResult( im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.user_cancel)) async_task.set_result(results, True, str(e)) @@ -1020,6 +1021,7 @@ def callback(step, x0, x, total_steps, y): except Exception as e: logger.std_error(f'[Fooocus] Process error: {e}') logging.exception(e) + results = [] results.append(ImageGenerationResult( im=None, seed=task['task_seed'], finish_reason=GenerationFinishReason.error)) async_task.set_result(results, True, str(e))