From 712855a646cc035501aadb836fed49326b599889 Mon Sep 17 00:00:00 2001 From: Manfei <41607353+ManfeiBai@users.noreply.github.com> Date: Wed, 11 Dec 2024 09:45:35 -0800 Subject: [PATCH] Update test_train_spmd_linear_model.py --- test/spmd/test_train_spmd_linear_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index f9d0a006175..9e8d54daf30 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -46,7 +46,7 @@ def __init__(self): self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2) self.relu = nn.ReLU() self.fc2 = nn.Linear(FLAGS.input_dim // 2, 3) - # Add an additional 1x1 layer at the end to ensure the final layer + # Add an additional 3x3 layer at the end to ensure the final layer # is not sharded. self.fc3 = nn.Linear(3, 3)