Skip to content

Commit

Permalink
Allow HF token also from an environment variable
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Nov 26, 2024
1 parent d26b511 commit ef91ba9
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 3 deletions.
3 changes: 2 additions & 1 deletion docs/src/getting-started/checkpoints.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ a file path.
mtt export https://my.url.com/model.ckpt --output model.pt
Downloading private HuggingFace models is also supported, by specifying the
corresponding API token with the ``--huggingface_api_token`` flag.
corresponding API token with the ``--huggingface_api_token`` flag or the
``HUGGINGFACE_METATRAIN_TOKEN`` environment variable.

Keep in mind that a checkpoint (``.ckpt``) is only a temporary file, which can have
several dependencies and may become unusable if the corresponding architecture is
Expand Down
19 changes: 17 additions & 2 deletions src/metatrain/utils/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,22 @@ def load_model(
)

# Download from HuggingFace with a private token
if kwargs.get("huggingface_api_token"):
if (
kwargs.get("huggingface_api_token") # token from CLI
or os.environ.get("HUGGINGFACE_METATRAIN_TOKEN") # token from env variable
) and "huggingface.co" in str(path):
cli_token = kwargs.get("huggingface_api_token")
env_token = os.environ.get("HUGGINGFACE_METATRAIN_TOKEN")
if cli_token and env_token:
logging.info(
"Both CLI and environment variable tokens are set for "
"HuggingFace. Using the CLI token."
)
hf_token = cli_token
if cli_token:
hf_token = cli_token
if env_token:
hf_token = env_token
try:
from huggingface_hub import hf_hub_download
except ImportError:
Expand Down Expand Up @@ -135,7 +150,7 @@ def load_model(
"'main' branch."
)
filename = filename[10:]
path = hf_hub_download(repo_id, filename, token=kwargs["huggingface_api_token"])
path = hf_hub_download(repo_id, filename, token=hf_token)
# make sure to copy the checkpoint to the current directory
basename = os.path.basename(path)
shutil.copy(path, Path.cwd() / basename)
Expand Down
18 changes: 18 additions & 0 deletions tests/cli/test_export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,21 @@ def test_private_huggingface(monkeypatch, tmp_path):

# Test that the model can be loaded
load_model(output, extensions_directory="extensions/")

# also test with the token in the environment variable
os.environ["HUGGINGFACE_METATRAIN_TOKEN"] = HF_TOKEN

# remove output file and extensions
os.remove(output)
os.rmdir("extensions/")

command = command[:-1] # remove the token from the command line
subprocess.check_call(command)
assert Path(output).is_file()

# Test if extensions are saved
extensions_glob = glob.glob("extensions/")
assert len(extensions_glob) == 1

# Test that the model can be loaded
load_model(output, extensions_directory="extensions/")

0 comments on commit ef91ba9

Please sign in to comment.