Skip to content

Commit

Permalink
Merge pull request #371 from DrHazemAli/main
Browse files Browse the repository at this point in the history
Enhance LoraManager to Support .tar Files
  • Loading branch information
mrhan1993 authored Jul 1, 2024
2 parents 9668537 + 224e848 commit ca56962
Showing 1 changed file with 27 additions and 15 deletions.
42 changes: 27 additions & 15 deletions fooocusapi/utils/lora_manager.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,12 @@
"""
Manager loras from url
@author: TechnikMax
@github: https://github.com/TechnikMax
"""
import hashlib
import os
import requests

import tarfile

def _hash_url(url):
"""Generates a hash value for a given URL."""
return hashlib.md5(url.encode('utf-8')).hexdigest()


class LoraManager:
"""
Manager loras from url
Expand All @@ -26,29 +19,48 @@ def __init__(self):

def _download_lora(self, url):
"""
Downloads a LoRa from a URL and saves it in the cache.
Downloads a LoRa from a URL, saves it in the cache, and if it's a .tar file, extracts it and returns the .safetensors file.
"""
url_hash = _hash_url(url)
filepath = os.path.join(self.cache_dir, f"{url_hash}.safetensors")
file_name = f"{url_hash}.safetensors"
file_ext = url.split('.')[-1]
filepath = os.path.join(self.cache_dir, f"{url_hash}.{file_ext}")

if not os.path.exists(filepath):
print(f"start download for: {url}")
print(f"Start download for: {url}")

try:
response = requests.get(url, timeout=10, stream=True)
response.raise_for_status()
with open(filepath, 'wb') as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
print(f"Download successfully, saved as {file_name}")

if file_ext == "tar":
print("Extracting the tar file...")
with tarfile.open(filepath, 'r:*') as tar:
tar.extractall(path=self.cache_dir)
print("Extraction completed.")
return self._find_safetensors_file(self.cache_dir)

print(f"Download successfully, saved as {filepath}")
except Exception as e:
raise Exception(f"error downloading {url}: {e}") from e
raise Exception(f"Error downloading {url}: {e}") from e

else:
print(f"LoRa already downloaded {url}")
return file_name

return filepath

def _find_safetensors_file(self, directory):
"""
Finds the first .safetensors file in the specified directory.
"""
print("Searching for .safetensors file.")
for root, dirs, files in os.walk(directory):
for file in files:
if file.endswith('.safetensors'):
return os.path.join(root, file)
raise FileNotFoundError("No .safetensors file found in the extracted files.")

def check(self, urls):
"""Manages the specified LoRAs: downloads missing ones and returns their file names."""
Expand Down

0 comments on commit ca56962

Please sign in to comment.