
Commit

Clean up script
jerrychen109 committed Nov 5, 2023
1 parent 0c9c4b8 commit 85c29e9
Showing 2 changed files with 50 additions and 235 deletions.
55 changes: 28 additions & 27 deletions llmfoundry/utils/model_download_utils.py
@@ -20,9 +20,10 @@
     SAFE_WEIGHTS_INDEX_NAME,
 )
 
-PYTORCH_WEIGHTS_PATTERN = "pytorch_model*.bin*"
-SAFE_WEIGHTS_PATTERN = "model*.safetensors*"
+PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
+SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
 
+log = logging.getLogger(__name__)
 
 def download_from_hf_hub(
     repo_id: str,
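The substantive change in this hunk is the new module-level logger: instead of calling logging.info on the root logger, the module now routes messages through log = logging.getLogger(__name__). A minimal sketch of why that matters; only the logger name mirrors the file path above, the rest is illustrative:

import logging

# A named logger inherits the root handlers but can be filtered per module,
# which direct logging.info(...) calls on the root logger cannot.
log = logging.getLogger('llmfoundry.utils.model_download_utils')

logging.basicConfig(level=logging.INFO)

# Silence just the llmfoundry namespace without touching other libraries.
logging.getLogger('llmfoundry').setLevel(logging.WARNING)

log.info('filtered out by the WARNING threshold')
log.warning('still emitted')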
@@ -48,34 +49,34 @@ def download_from_hf_hub(
 
     # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer.
     ignore_patterns = [
-        "*.ckpt",
-        "*.h5",
-        "*.msgpack",
+        '*.ckpt',
+        '*.h5',
+        '*.msgpack',
     ]
 
     if (
         SAFE_WEIGHTS_NAME in repo_files or SAFE_WEIGHTS_INDEX_NAME in repo_files
     ) and prefer_safetensors:
-        logging.info("Safetensors found and preferred. Excluding pytorch files")
+        log.info('Safetensors found and preferred. Excluding pytorch files')
         ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN)
     elif PYTORCH_WEIGHTS_NAME in repo_files or PYTORCH_WEIGHTS_INDEX_NAME in repo_files:
-        logging.info(
-            "Safetensors not found or prefer_safetensors is False. Excluding safetensors files"
+        log.info(
+            'Safetensors not found or prefer_safetensors is False. Excluding safetensors files'
         )
         ignore_patterns.append(SAFE_WEIGHTS_PATTERN)
     else:
         raise ValueError(
-            f"No supported model weights found in repo {repo_id}."
-            + " Please make sure the repo contains either safetensors or pytorch weights."
+            f'No supported model weights found in repo {repo_id}.'
+            + ' Please make sure the repo contains either safetensors or pytorch weights.'
         )
 
     download_start = time.time()
     hf_hub.snapshot_download(
         repo_id, cache_dir=save_dir, ignore_patterns=ignore_patterns, token=token
     )
     download_duration = time.time() - download_start
-    logging.info(
-        f"Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds"
+    log.info(
+        f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds'
     )
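For context, the branching above only decides which glob patterns end up in ignore_patterns; the actual filtering is done by huggingface_hub.snapshot_download. A stripped-down sketch of that call, assuming safetensors are preferred; the repo id and cache directory are placeholders, not values from the commit:

import huggingface_hub as hf_hub

# Skip TF/Flax checkpoints plus the PyTorch .bin shards, keeping safetensors.
ignore_patterns = ['*.ckpt', '*.h5', '*.msgpack', 'pytorch_model*.bin*']

hf_hub.snapshot_download(
    'mosaicml/mpt-7b',          # placeholder repo id
    cache_dir='/tmp/hf-cache',  # placeholder save_dir
    ignore_patterns=ignore_patterns,
    token=None,                 # or a Hugging Face access token
)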


@@ -88,8 +89,8 @@ def _extract_links_from_html(html: str):
     Returns:
         list[str]: A list of links to download.
     """
-    soup = BeautifulSoup(html, "html.parser")
-    links = [a["href"] for a in soup.find_all("a")]
+    soup = BeautifulSoup(html, 'html.parser')
+    links = [a['href'] for a in soup.find_all('a')]
     return links
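_extract_links_from_html simply collects the href attribute of every anchor tag; on a cache server's directory listing those are the child names that _recursive_download appends to the current path. A tiny illustrative call, with a made-up HTML snippet:

from bs4 import BeautifulSoup

html = '<html><body><a href="config.json">config.json</a><a href="blobs/">blobs/</a></body></html>'

soup = BeautifulSoup(html, 'html.parser')
links = [a['href'] for a in soup.find_all('a')]
print(links)  # ['config.json', 'blobs/']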


@@ -115,37 +116,37 @@ def _recursive_download(
         PermissionError: If the remote server returns a 401 Unauthorized status code.
         RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
     """
-    url = f"{base_url}/{path}"
+    url = f'{base_url}/{path}'
     response = session.get(url, verify=(not ignore_cert))
 
     if response.status_code == HTTPStatus.UNAUTHORIZED:
         raise PermissionError(
-            f"Not authorized to download file from {url}. Received status code {response.status_code}. "
+            f'Not authorized to download file from {url}. Received status code {response.status_code}. '
         )
     elif response.status_code != HTTPStatus.OK:
         raise RuntimeError(
-            f"Could not download file from {url}. Received unexpected status code {response.status_code}"
+            f'Could not download file from {url}. Received unexpected status code {response.status_code}'
         )
 
     # Assume that the URL points to a file if it does not end with a slash.
-    if not path.endswith("/"):
-        save_path = f"{save_dir}/{path}"
+    if not path.endswith('/'):
+        save_path = f'{save_dir}/{path}'
         parent_dir = os.path.dirname(save_path)
         if not os.path.exists(parent_dir):
             os.makedirs(parent_dir)
 
-        with open(save_path, "wb") as f:
+        with open(save_path, 'wb') as f:
             f.write(response.content)
 
-        logging.info(f"Downloaded file {save_path}")
+        log.info(f'Downloaded file {save_path}')
         return
 
     # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links
     # to download.
     child_links = _extract_links_from_html(response.content.decode())
     for child_link in child_links:
         _recursive_download(
-            session, base_url, f"{path}/{child_link}", save_dir, ignore_cert=ignore_cert
+            session, base_url, f'{path}/{child_link}', save_dir, ignore_cert=ignore_cert
         )


@@ -177,19 +178,19 @@ def download_from_cache_server(
             False.
             WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
     """
-    formatted_model_name = f"models/{model_name}".replace("/", "--")
+    formatted_model_name = f'models/{model_name}'.replace('/', '--')
     with requests.Session() as session:
-        session.headers.update({"Authorization": f"Bearer {token}"})
+        session.headers.update({'Authorization': f'Bearer {token}'})
 
         download_start = time.time()
         _recursive_download(
             session,
             cache_base_url,
-            f"{formatted_model_name}/blobs/",  # Trailing slash to indicate directory
+            f'{formatted_model_name}/blobs/',  # Trailing slash to indicate directory
             save_dir,
             ignore_cert=ignore_cert,
         )
         download_duration = time.time() - download_start
-        logging.info(
-            f"Downloaded model {model_name} from cache server in {download_duration} seconds"
+        log.info(
+            f'Downloaded model {model_name} from cache server in {download_duration} seconds'
         )
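The cache-server path amounts to an authenticated HTTP walk: a requests.Session carries a Bearer token, a URL ending in a slash is treated as a directory and returns an HTML listing, and everything else is written to disk. A stripped-down sketch of that session setup; the base URL, token, and model name are placeholders and error handling is omitted:

import requests

cache_base_url = 'https://model-cache.example.com'  # placeholder
token = 'my-secret-token'                            # placeholder
formatted_model_name = 'models/mosaicml/mpt-7b'.replace('/', '--')

with requests.Session() as session:
    session.headers.update({'Authorization': f'Bearer {token}'})

    # Trailing slash marks a directory, so the server responds with an
    # HTML listing rather than file contents.
    listing = session.get(f'{cache_base_url}/{formatted_model_name}/blobs/')
    listing.raise_for_status()
    print(listing.text)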
