
fix dtype for load, use PartialState instead of accelerator to init process group, remove redundant wandb callback
winglian committed Mar 30, 2024
1 parent 947b010 commit ae109c8
Showing 4 changed files with 11 additions and 18 deletions.
4 changes: 0 additions & 4 deletions src/axolotl/core/trainer_builder.py
@@ -898,10 +898,6 @@ def get_callbacks(self):
         ):
             callbacks.append(SaveBetterTransformerModelCallback())

-        if self.cfg.use_wandb:
-            callbacks.append(
-                SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
-            )
         if self.cfg.use_mlflow and is_mlflow_available():
             from axolotl.utils.callbacks.mlflow_ import (
                 SaveAxolotlConfigtoMlflowCallback,
20 changes: 9 additions & 11 deletions src/axolotl/utils/distributed.py
@@ -4,27 +4,25 @@
 import os
 import pickle  # nosec
 from contextlib import contextmanager
+from datetime import timedelta

 import torch
 import torch.distributed as dist
-from accelerate import Accelerator
+from accelerate import PartialState

-accelerate = None  # pylint: disable=invalid-name
-
-
-def load_accelerate():
-    global accelerate  # pylint: disable=global-statement
-    accelerate = Accelerator()
+distributed_state = None  # pylint: disable=invalid-name


 def is_distributed():
     """
     Check if distributed training is initialized.
     """
-    global accelerate  # pylint: disable=global-statement
-    if not accelerate:
-        accelerate = Accelerator()
-    return dist.is_available() and dist.is_initialized()
+    global distributed_state  # pylint: disable=global-statement
+    if not distributed_state:
+        timeout = int(os.environ.get("AXOLOTL_NCCL_TIMEOUT", 1800))
+        distributed_state = PartialState(timeout=timedelta(seconds=timeout))
+
+    return distributed_state.use_distributed and distributed_state.initialized


 def barrier():
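
For reference, a minimal sketch (not part of the commit) of how the new PartialState-based helper behaves: the process group is created lazily on first use, and the NCCL timeout can be raised through the AXOLOTL_NCCL_TIMEOUT environment variable, which the helper defaults to 1800 seconds. The 3600-second override below is illustrative only.

import os
from datetime import timedelta

from accelerate import PartialState

# Illustrative override; the helper in distributed.py falls back to 1800s.
os.environ.setdefault("AXOLOTL_NCCL_TIMEOUT", "3600")

timeout = int(os.environ["AXOLOTL_NCCL_TIMEOUT"])
# PartialState forwards the timeout when it initializes the process group;
# for a single-process run no process group is created at all.
state = PartialState(timeout=timedelta(seconds=timeout))

if state.use_distributed and state.initialized:
    print(f"process group ready on rank {state.process_index}")
else:
    print("running single-process; no process group was created")
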
4 changes: 1 addition & 3 deletions src/axolotl/utils/model_shard_quant.py
@@ -152,7 +152,7 @@ def load_sharded_model(
             _attn_implementation=model_config._attn_implementation,  # pylint: disable=protected-access
             trust_remote_code=cfg.trust_remote_code,
         )
-        dtype = torch.bfloat16 if cfg.bf16 else None
+        dtype = torch_dtype if not cfg.float32 else None
         model.to(dtype=dtype, device="cpu" if low_memory else cfg.local_rank)
     else:
         with init_empty_weights():
@@ -161,8 +161,6 @@
                 torch_dtype=torch_dtype,
                 trust_remote_code=cfg.trust_remote_code,
             )
-    if cfg.bf16:
-        model.to(torch.bfloat16)
     return model


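The dtype change above replaces a hard-coded bfloat16 cast with the caller-supplied torch_dtype, only falling back to None when full fp32 is configured. A small sketch of that selection logic; DummyCfg and resolve_load_dtype are illustrative stand-ins, not axolotl's actual API.

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class DummyCfg:
    # Hypothetical stand-in for axolotl's cfg; only the fields used here.
    bf16: bool = False
    fp16: bool = False
    float32: bool = False


def resolve_load_dtype(cfg: DummyCfg, torch_dtype: torch.dtype) -> Optional[torch.dtype]:
    # Mirrors the new line in load_sharded_model:
    #     dtype = torch_dtype if not cfg.float32 else None
    return torch_dtype if not cfg.float32 else None


print(resolve_load_dtype(DummyCfg(bf16=True), torch.bfloat16))    # torch.bfloat16
print(resolve_load_dtype(DummyCfg(fp16=True), torch.float16))     # torch.float16
print(resolve_load_dtype(DummyCfg(float32=True), torch.float32))  # None, keep original precision
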
1 change: 1 addition & 0 deletions src/axolotl/utils/models.py
@@ -525,6 +525,7 @@ def load_model(
             base_model,
             model_config,
             cfg,
+            torch_dtype=cfg.torch_dtype,
         )
         skip_move_to_device = True
     elif qlora_fsdp and cfg.fsdp_config.fsdp_cpu_ram_efficient_loading:
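Taken together with the model_shard_quant.py change, load_model now threads the configured dtype into the sharded loader. A self-contained sketch of that hand-off; the DummyCfg fields, the stub loader, and the model name are all illustrative, and the real load_sharded_model lives in src/axolotl/utils/model_shard_quant.py.

from dataclasses import dataclass

import torch


@dataclass
class DummyCfg:
    # Hypothetical stand-in; cfg.torch_dtype is assumed to be resolved to a
    # torch.dtype earlier in config handling.
    torch_dtype: torch.dtype = torch.bfloat16
    float32: bool = False


def load_sharded_model_stub(base_model, model_config, cfg, torch_dtype=None):
    # Stub showing only the new keyword hand-off introduced by this commit.
    dtype = torch_dtype if not cfg.float32 else None
    print(f"would load {base_model} with dtype={dtype}")


cfg = DummyCfg()
load_sharded_model_stub("meta-llama/Llama-2-7b-hf", None, cfg, torch_dtype=cfg.torch_dtype)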
