From ce2a59dc4021c6e21a57e949979798165f5938cb Mon Sep 17 00:00:00 2001 From: fcogidi <41602287+fcogidi@users.noreply.github.com> Date: Tue, 24 Sep 2024 12:23:02 -0400 Subject: [PATCH] Add support for tf32 and set precision to bf16-mixed if available --- mmlearn/cli/run.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/mmlearn/cli/run.py b/mmlearn/cli/run.py index e1501e0..f342458 100644 --- a/mmlearn/cli/run.py +++ b/mmlearn/cli/run.py @@ -12,6 +12,7 @@ from omegaconf import OmegaConf from pytorch_lightning.utilities import rank_zero_only from torch.utils.data import DataLoader +from transformers.utils.import_utils import is_torch_tf32_available from mmlearn.cli._instantiators import ( instantiate_callbacks, @@ -41,7 +42,11 @@ def main(cfg: MMLearnConf) -> None: # noqa: PLR0912 cfg_copy = copy.deepcopy(cfg) # copy of the config for logging L.seed_everything(cfg.seed, workers=True) - torch.set_float32_matmul_precision("high") + + if is_torch_tf32_available(): + torch.backends.cuda.matmul.allow_tf32 = True + if "16-mixed" in cfg.trainer.precision: + cfg.trainer.precision = "bf16-mixed" # setup trainer first so that we can get some variables for distributed training callbacks = instantiate_callbacks(cfg.trainer.get("callbacks"))