From 851a07bd69360bebdfa99da336f7853471b81ade Mon Sep 17 00:00:00 2001 From: drisspg Date: Mon, 24 Jun 2024 15:55:15 -0700 Subject: [PATCH] make testing better on amd --- float8_experimental/config.py | 13 ++++++++++++- test/test_base.py | 2 +- test/test_everything.sh | 10 ++++++++++ 3 files changed, 23 insertions(+), 2 deletions(-) diff --git a/float8_experimental/config.py b/float8_experimental/config.py index 99574c0..f389e92 100644 --- a/float8_experimental/config.py +++ b/float8_experimental/config.py @@ -3,6 +3,17 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. +import os + + +def get_bool_env(var_name: str, default) -> bool: + value = os.environ.get(var_name, "").lower() + if value in ("true", "1", "yes"): + return True + elif value in ("false", "0", "no"): + return False + return default + # If True, on the first iteration of Float8Linear the amaxes will be # initialized with the incoming data. As of 2023-12-30, this doesn't work @@ -22,7 +33,7 @@ # If True, use 'fnuz' float8 types for calculations. # Currently, ROCm only supports fnuz variants. -use_fnuz_dtype = False +use_fnuz_dtype = get_bool_env("USE_FNUZ_DTYPE", False) # If True, then prior to performing the fp8 scaled mamtmul we will pad the # inner dimension of a (dim 1) and b (dim 2) with 0s. This is needed for matmuls diff --git a/test/test_base.py b/test/test_base.py index b688ccb..f48b613 100644 --- a/test/test_base.py +++ b/test/test_base.py @@ -398,7 +398,7 @@ def test_merge_configs(self): @pytest.mark.parametrize("use_fast_accum", [True, False]) def test_pad_inner_dim(self, base_dtype, use_fast_accum): torch.manual_seed(42) - input_dtype = torch.float8_e4m3fn + input_dtype = e4m3_dtype compare_type = torch.float32 a = torch.randn(16, 41, device="cuda", dtype=base_dtype) diff --git a/test/test_everything.sh b/test/test_everything.sh index b989393..6258409 100755 --- a/test/test_everything.sh +++ b/test/test_everything.sh @@ -4,6 +4,16 @@ set -e IS_ROCM=$(rocm-smi --version || true) + +# Set USE_FNUZ_DTYPE environment variable if IS_ROCM is not empty +if [ -n "$IS_ROCM" ]; then + export USE_FNUZ_DTYPE=true + echo "ROCm detected. Set USE_FNUZ_DTYPE=true" +else + export USE_FNUZ_DTYPE=false + echo "ROCm not detected. Set USE_FNUZ_DTYPE=false" +fi + pytest test/test_base.py pytest test/test_sam.py pytest test/test_compile.py