Minor cleanups (#858)
* nits

* logger

* add log

* lint
mvpatel2000 authored Jan 11, 2024
1 parent c694121 commit a7c36bc
Showing 3 changed files with 10 additions and 14 deletions.
4 changes: 2 additions & 2 deletions llmfoundry/models/mpt/modeling_mpt.py
@@ -330,12 +330,12 @@ def __init__(self, config: MPTConfig):
             for module in self.modules():
                 if hasattr(module, 'bias') and isinstance(
                         module.bias, nn.Parameter):
-                    log.info(f'Removing bias ({module.bias}) from {module}.')
+                    log.info(f'Removing bias from {module=}.')
                     module.register_parameter('bias', None)
 
                 # For transformer engine
                 if hasattr(module, 'use_bias'):
-                    log.info(f'Setting use_bias=False for {module}.')
+                    log.info(f'Setting use_bias=False for {module=}.')
                     module.use_bias = False
 
         log.debug(self)
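Note on the new log format: f'{module=}' uses the f-string debug specifier introduced in Python 3.8, which prints the expression text together with its value, so the message reads e.g. 'Removing bias from module=Linear(...)'. A minimal, self-contained sketch of the bias-removal pattern above, using a throwaway nn.Sequential instead of an MPT model:

import logging

import torch.nn as nn

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# Stand-in model; the real loop walks the MPT model's own modules.
model = nn.Sequential(nn.Linear(4, 4), nn.LayerNorm(4))

for module in model.modules():
    if hasattr(module, 'bias') and isinstance(module.bias, nn.Parameter):
        # '{module=}' expands to "module=<module repr>", naming the expression for free.
        log.info(f'Removing bias from {module=}.')
        module.register_parameter('bias', None)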
14 changes: 3 additions & 11 deletions llmfoundry/utils/config_utils.py
@@ -120,18 +120,10 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
         # Set defaults for mixed initialization
         fsdp_config.setdefault('use_orig_params', False)
         fsdp_config.setdefault('load_monolith_rank0_only', True)
-    # Always set `sync_module_states` to True when using hybrid sharding
-    if fsdp_config is not None and \
-        fsdp_config.get('sharding_strategy', 'FULL_SHARD') in ['HYBRID_SHARD', '_HYBRID_SHARD_ZERO2'] \
-        and not fsdp_config.get('sync_module_states', False):
-        warnings.warn(
-            ('Setting `sync_module_states = True` for FSDP. This is required '
-             'when using hybrid sharding.'))
-        fsdp_config['sync_module_states'] = True
 
-    # no mixed precision needed for weights when they're already 16 bits
+    # No mixed precision needed for weights when they're already 16 bits
     master_dtype = model_cfg.get('master_weights_dtype')
-    small_dtypes = ('bf16', 'f16', 'float16', 'bfloat16', 'amp_fp16',
+    small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16',
                     'amp_bf16')
     if fsdp_config and master_dtype in small_dtypes:
         reduce_dtype = None
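For context on the dtype check at the end of this hunk: when the master weights are already stored in a 16-bit dtype, there is nothing for FSDP mixed precision to down-cast, so the reduce dtype can be dropped. A rough sketch of that decision as a hypothetical standalone helper (not the actual process_init_device code):

from typing import Optional

# Mirrors the corrected tuple in the hunk above.
SMALL_DTYPES = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')


def pick_reduce_dtype(master_weights_dtype: Optional[str],
                      requested: Optional[str] = 'fp32') -> Optional[str]:
    """Hypothetical helper: skip a separate reduce dtype for 16-bit master weights."""
    if master_weights_dtype in SMALL_DTYPES:
        # Weights already live in 16 bits, so no extra mixed-precision cast is needed.
        return None
    return requested


assert pick_reduce_dtype('bf16') is None
assert pick_reduce_dtype('fp32') == 'fp32'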
6 changes: 5 additions & 1 deletion scripts/train/train.py
@@ -438,13 +438,17 @@ def main(cfg: DictConfig) -> Trainer:
             format=
             f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s'
         )
-        logging.getLogger('llmfoundry').setLevel(python_log_level.upper())
+        logging.getLogger('llmfoundry').setLevel(
+            python_log_level.upper())  # Foundry module
+        logging.getLogger(__name__).setLevel(
+            python_log_level.upper())  # Train script
 
     # Initialize context
     init_context = process_init_device(model_config, fsdp_config)
     logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
 
     # Build tokenizer
+    log.info('Building tokenizer...')
     tokenizer_name = tokenizer_config['name']
     tokenizer_kwargs = tokenizer_config.get('kwargs', {})
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
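The logging change above sets the level on two loggers because logging.getLogger('llmfoundry') only governs modules inside the llmfoundry package, while the train script itself logs under its own __name__ (typically '__main__'), which is outside that hierarchy. A standalone sketch of the pattern, with a hard-coded level standing in for the configured python_log_level:

import logging

logging.basicConfig(
    format='%(asctime)s: %(levelname)s: %(name)s: %(message)s')

python_log_level = 'debug'  # train.py reads this from the run config

# Package logger: covers every module named 'llmfoundry' or 'llmfoundry.*'.
logging.getLogger('llmfoundry').setLevel(python_log_level.upper())
# Script logger: the entry-point script is not under that package, so it needs
# its own level for its log.debug/log.info calls to be emitted.
logging.getLogger(__name__).setLevel(python_log_level.upper())

log = logging.getLogger(__name__)
log.debug('Visible because the script logger is now at DEBUG.')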
