diff --git a/test/benchmarks/test_benchmark_experiment.py b/test/benchmarks/test_benchmark_experiment.py index 06c004bcf1e..461df56687f 100644 --- a/test/benchmarks/test_benchmark_experiment.py +++ b/test/benchmarks/test_benchmark_experiment.py @@ -6,15 +6,16 @@ class BenchmarkExperimentTest(unittest.TestCase): def test_to_dict(self): - be = BenchmarkExperiment("cpu", "PJRT", "some xla_flags", "openxla", + be = BenchmarkExperiment("cpu", "PJRT", "some xla_flags", "openxla", None, "train", "123") actual = be.to_dict() - self.assertEqual(7, len(actual)) + self.assertEqual(8, len(actual)) self.assertEqual("cpu", actual["accelerator"]) self.assertTrue("accelerator_model" in actual) self.assertEqual("PJRT", actual["xla"]) self.assertEqual("some xla_flags", actual["xla_flags"]) self.assertEqual("openxla", actual["dynamo"]) + self.assertEqual(None, actual["torch_xla2"]) self.assertEqual("train", actual["test"]) self.assertEqual("123", actual["batch_size"])