[Examples] Update with HFApi (huggingface#5393)
* update training examples to use HfApi.

* update training example.

* reflect the changes in the Korean version too.

* Empty-Commit
sayakpaul authored Oct 16, 2023
1 parent 0ea78f9 commit cc12f3e
Showing 6 changed files with 79 additions and 142 deletions.
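Every file below applies the same migration: the deprecated git-clone workflow (`get_full_repo_name` + `Repository` + `repo.push_to_hub`) is replaced with the HTTP-based Hub API. A minimal sketch of the new pattern, assuming a token cached via `huggingface-cli login`; the directory name and `hub_model_id` are hypothetical stand-ins for the scripts' config values:

```py
from pathlib import Path

from huggingface_hub import create_repo, upload_folder

output_dir = "ddpm-butterflies-128"  # hypothetical local training directory
hub_model_id = None  # optional "user/name" override, as in the scripts' configs

# Create (or reuse, with exist_ok=True) the repo over HTTP and keep its
# fully qualified "user/name" id; no local git clone is needed anymore.
repo_id = create_repo(
    repo_id=hub_model_id or Path(output_dir).name, exist_ok=True
).repo_id

# Upload the folder contents as a single commit; ignore_patterns replaces
# the .gitignore entries the old code wrote into the cloned repo.
upload_folder(
    repo_id=repo_id,
    folder_path=output_dir,
    commit_message="End of training",
    ignore_patterns=["step_*", "epoch_*"],
)
```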
29 changes: 12 additions & 17 deletions docs/source/en/tutorials/basic_training.md
@@ -284,22 +284,11 @@ Now you can wrap all these components together in a training loop with 🤗 Acce

```py
>>> from accelerate import Accelerator
->>> from huggingface_hub import HfFolder, Repository, whoami
+>>> from huggingface_hub import create_repo, upload_folder
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os

-
->>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
-...     if token is None:
-...         token = HfFolder.get_token()
-...     if organization is None:
-...         username = whoami(token)["name"]
-...         return f"{username}/{model_id}"
-...     else:
-...         return f"{organization}/{model_id}"
-
-
>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
...     # Initialize accelerator and tensorboard logging
...     accelerator = Accelerator(
@@ -309,11 +298,12 @@ Now you can wrap all these components together in a training loop with 🤗 Acce
...         project_dir=os.path.join(config.output_dir, "logs"),
...     )
...     if accelerator.is_main_process:
-...         if config.push_to_hub:
-...             repo_name = get_full_repo_name(Path(config.output_dir).name)
-...             repo = Repository(config.output_dir, clone_from=repo_name)
-...         elif config.output_dir is not None:
+...         if config.output_dir is not None:
...             os.makedirs(config.output_dir, exist_ok=True)
+...         if config.push_to_hub:
+...             repo_id = create_repo(
+...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
+...             ).repo_id
...         accelerator.init_trackers("train_example")

...     # Prepare everything
@@ -371,7 +361,12 @@ Now you can wrap all these components together in a training loop with 🤗 Acce

...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
...                 if config.push_to_hub:
-...                     repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
+...                     upload_folder(
+...                         repo_id=repo_id,
+...                         folder_path=config.output_dir,
+...                         commit_message=f"Epoch {epoch}",
+...                         ignore_patterns=["step_*", "epoch_*"],
+...                     )
...                 else:
...                     pipeline.save_pretrained(config.output_dir)
```
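The deleted `get_full_repo_name` helper existed only to prefix the model name with a username or organization. `create_repo` resolves that server-side: given a bare name it creates the repo under the authenticated token's namespace and returns a `RepoUrl` whose `.repo_id` is the fully qualified id. A sketch, assuming a logged-in session and hypothetical repo names:

```py
from huggingface_hub import create_repo

# A bare name lands in the namespace of the authenticated user, so the
# old username lookup via whoami() is no longer needed.
url = create_repo(repo_id="my-finetuned-unet", exist_ok=True)  # hypothetical name
print(url.repo_id)  # e.g. "your-username/my-finetuned-unet"

# An explicit "org/name" id still works and pins the organization.
url = create_repo(repo_id="my-org/my-finetuned-unet", exist_ok=True)
```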
32 changes: 14 additions & 18 deletions docs/source/ko/tutorials/basic_training.md
@@ -283,36 +283,27 @@ logging to TensorBoard, gradient accumulation, and mixed precision training

```py
>>> from accelerate import Accelerator
->>> from huggingface_hub import HfFolder, Repository, whoami
+>>> from huggingface_hub import create_repo, upload_folder
>>> from tqdm.auto import tqdm
>>> from pathlib import Path
>>> import os


->>> def get_full_repo_name(model_id: str, organization: str = None, token: str = None):
-...     if token is None:
-...         token = HfFolder.get_token()
-...     if organization is None:
-...         username = whoami(token)["name"]
-...         return f"{username}/{model_id}"
-...     else:
-...         return f"{organization}/{model_id}"
-
-
>>> def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):
-...     # accelerator와 tensorboard 로깅 초기화
+...     # Initialize accelerator and tensorboard logging
...     accelerator = Accelerator(
...         mixed_precision=config.mixed_precision,
...         gradient_accumulation_steps=config.gradient_accumulation_steps,
...         log_with="tensorboard",
-...         logging_dir=os.path.join(config.output_dir, "logs"),
+...         project_dir=os.path.join(config.output_dir, "logs"),
...     )
...     if accelerator.is_main_process:
-...         if config.push_to_hub:
-...             repo_name = get_full_repo_name(Path(config.output_dir).name)
-...             repo = Repository(config.output_dir, clone_from=repo_name)
-...         elif config.output_dir is not None:
+...         if config.output_dir is not None:
...             os.makedirs(config.output_dir, exist_ok=True)
+...         if config.push_to_hub:
+...             repo_id = create_repo(
+...                 repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True
+...             ).repo_id
...         accelerator.init_trackers("train_example")

...     # Everything is ready now.
@@ -369,7 +360,12 @@ logging to TensorBoard, gradient accumulation, and mixed precision training

...             if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:
...                 if config.push_to_hub:
-...                     repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=True)
+...                     upload_folder(
+...                         repo_id=repo_id,
+...                         folder_path=config.output_dir,
+...                         commit_message=f"Epoch {epoch}",
+...                         ignore_patterns=["step_*", "epoch_*"],
+...                     )
...                 else:
...                     pipeline.save_pretrained(config.output_dir)
```
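Alongside the Hub changes, the Korean snippet also picks up the `accelerate` rename of `logging_dir` to `project_dir`, which the English page had already received. A minimal sketch of the current call, assuming `accelerate` and `tensorboard` are installed; the directory name is hypothetical:

```py
import os

from accelerate import Accelerator

output_dir = "ddpm-butterflies-128"  # hypothetical output directory

# `logging_dir` was deprecated in accelerate in favor of `project_dir`;
# trackers such as tensorboard write their logs under this directory.
accelerator = Accelerator(
    log_with="tensorboard",
    project_dir=os.path.join(output_dir, "logs"),
)
accelerator.init_trackers("train_example")
accelerator.end_training()
```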
40 changes: 13 additions & 27 deletions examples/dreambooth/train_dreambooth_flax.py
@@ -4,7 +4,6 @@
import math
import os
from pathlib import Path
-from typing import Optional

import jax
import jax.numpy as jnp
@@ -16,7 +15,7 @@
from flax import jax_utils
from flax.training import train_state
from flax.training.common_utils import shard
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
from jax.experimental.compilation_cache import compilation_cache as cc
from PIL import Image
from torch.utils.data import Dataset
@@ -318,16 +317,6 @@ def __getitem__(self, index):
        return example


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
def get_params_to_save(params):
    return jax.device_get(jax.tree_util.tree_map(lambda x: x[0], params))

@@ -392,22 +381,14 @@ def main():

    # Handle the repository creation
    if jax.process_index() == 0:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -668,7 +649,12 @@ def checkpoint(step=None):

        if args.push_to_hub:
            message = f"checkpoint-{step}" if step is not None else "End of training"
-            repo.push_to_hub(commit_message=message, blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message=message,
+                ignore_patterns=["step_*", "epoch_*"],
+            )

    global_step = 0

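One behavioral difference worth noting here: the removed `repo.push_to_hub(..., blocking=False)` returned immediately while git pushed in the background, whereas `upload_folder` blocks until the commit is created. If fire-and-forget uploads matter, recent `huggingface_hub` releases (0.15+) accept `run_as_future=True` on `HfApi` methods; a sketch under that assumption, with the repo id and folder hypothetical:

```py
from huggingface_hub import HfApi

api = HfApi()

# Returns a concurrent.futures.Future right away; the upload runs in a
# background thread, roughly mirroring the old blocking=False push.
future = api.upload_folder(
    repo_id="your-username/dreambooth-model",  # hypothetical repo id
    folder_path="dreambooth-model",  # hypothetical local folder
    commit_message="checkpoint-500",
    ignore_patterns=["step_*", "epoch_*"],
    run_as_future=True,
)
commit_info = future.result()  # block only when the result is needed
```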
(file header for this changed file was not captured)
@@ -4,7 +4,7 @@
import os
import random
from pathlib import Path
-from typing import Iterable, Optional
+from typing import Iterable

import numpy as np
import PIL
@@ -13,7 +13,7 @@
import torch.utils.checkpoint
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration, set_seed
-from huggingface_hub import HfFolder, Repository, whoami
+from huggingface_hub import create_repo, upload_folder
from neural_compressor.utils import logger
from packaging import version
from PIL import Image
@@ -413,16 +413,6 @@ def __getitem__(self, i):
        return example


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
def freeze_params(params):
    for param in params:
        param.requires_grad = False
@@ -461,21 +451,14 @@ def main():

    # Handle the repository creation
    if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            repo = Repository(args.output_dir, clone_from=repo_name)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

    # Load the tokenizer and add the placeholder token as an additional special token
    if args.tokenizer_name:
        tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -982,7 +965,12 @@ def attention_fetcher(x):
        )

        if args.push_to_hub:
-            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
+            upload_folder(
+                repo_id=repo_id,
+                folder_path=args.output_dir,
+                commit_message="End of training",
+                ignore_patterns=["step_*", "epoch_*"],
+            )

    accelerator.end_training()

(file header for this changed file was not captured)
@@ -4,7 +4,6 @@
import math
import os
from pathlib import Path
-from typing import Optional

import accelerate
import datasets
@@ -14,7 +13,7 @@
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from datasets import load_dataset
-from huggingface_hub import HfFolder, Repository, create_repo, whoami
+from huggingface_hub import create_repo, upload_folder
from onnxruntime.training.optim.fp16_optimizer import FP16_Optimizer as ORT_FP16_Optimizer
from onnxruntime.training.ortmodule import ORTModule
from packaging import version
@@ -277,16 +276,6 @@ def parse_args():
    return args


-def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
-    if token is None:
-        token = HfFolder.get_token()
-    if organization is None:
-        username = whoami(token)["name"]
-        return f"{username}/{model_id}"
-    else:
-        return f"{organization}/{model_id}"
-
-
def main(args):
    logging_dir = os.path.join(args.output_dir, args.logging_dir)
    accelerator_project_config = ProjectConfiguration(
@@ -360,22 +349,14 @@ def load_model_hook(models, input_dir):

    # Handle the repository creation
    if accelerator.is_main_process:
-        if args.push_to_hub:
-            if args.hub_model_id is None:
-                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
-            else:
-                repo_name = args.hub_model_id
-            create_repo(repo_name, exist_ok=True, token=args.hub_token)
-            repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
-
-            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
-                if "step_*" not in gitignore:
-                    gitignore.write("step_*\n")
-                if "epoch_*" not in gitignore:
-                    gitignore.write("epoch_*\n")
-        elif args.output_dir is not None:
+        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)
+
+        if args.push_to_hub:
+            repo_id = create_repo(
+                repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
+            ).repo_id

    # Initialize the model
    if args.model_config_name_or_path is None:
        model = UNet2DModel(
@@ -691,7 +672,12 @@ def transform_images(examples):
                    ema_model.restore(unet.parameters())

                if args.push_to_hub:
-                    repo.push_to_hub(commit_message=f"Epoch {epoch}", blocking=False)
+                    upload_folder(
+                        repo_id=repo_id,
+                        folder_path=args.output_dir,
+                        commit_message=f"Epoch {epoch}",
+                        ignore_patterns=["step_*", "epoch_*"],
+                    )

    accelerator.end_training()

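A final note on authentication: the example scripts forward `args.hub_token` explicitly, while the tutorials omit `token=` and fall back to the cached login (what the removed `HfFolder.get_token` used to read). A sketch of both styles; the repo names and token value are hypothetical:

```py
from huggingface_hub import create_repo

# Script style: pass the token explicitly, e.g. from a --hub_token flag.
repo_id = create_repo(
    repo_id="unet-from-script",  # hypothetical name
    exist_ok=True,
    token="hf_xxx",  # hypothetical value of args.hub_token
).repo_id

# Tutorial style: omit token= and rely on the cached credentials from
# `huggingface-cli login` (or the HF_TOKEN environment variable).
repo_id = create_repo(repo_id="unet-from-tutorial", exist_ok=True).repo_id
```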