diff --git a/test/spmd/test_train_spmd_linear_model.py b/test/spmd/test_train_spmd_linear_model.py index 686178292ea..f9d0a006175 100644 --- a/test/spmd/test_train_spmd_linear_model.py +++ b/test/spmd/test_train_spmd_linear_model.py @@ -45,10 +45,10 @@ def __init__(self): super(SimpleLinear, self).__init__() self.fc1 = nn.Linear(FLAGS.input_dim, FLAGS.input_dim // 2) self.relu = nn.ReLU() - self.fc2 = nn.Linear(FLAGS.input_dim // 2, 1) + self.fc2 = nn.Linear(FLAGS.input_dim // 2, 3) # Add an additional 1x1 layer at the end to ensure the final layer # is not sharded. - self.fc3 = nn.Linear(1, 1) + self.fc3 = nn.Linear(3, 3) def forward(self, x): y = self.relu(self.fc1(x))