Merge pull request apache#1170 from dcslin/feature/ms_model_mlp
update msmlp
chrishkchris authored May 20, 2024
2 parents e16a203 + da18ab9 commit f9027d2
Showing 1 changed file with 14 additions and 4 deletions.
18 changes: 14 additions & 4 deletions examples/ms_model_mlp/model.py
@@ -83,21 +83,30 @@ def forward(self, x):
 
 class MSMLP(model.Model):
 
-    def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
+    def __init__(self, data_size=10, perceptron_size=100, num_classes=10, layer_hidden_list=[10,10,10,10]):
         super(MSMLP, self).__init__()
         self.num_classes = num_classes
         self.dimension = 2
 
         self.relu = layer.ReLU()
-        self.linear1 = layer.Linear(perceptron_size)
-        self.linear2 = layer.Linear(num_classes)
+        self.linear1 = layer.Linear(layer_hidden_list[0])
+        self.linear2 = layer.Linear(layer_hidden_list[1])
+        self.linear3 = layer.Linear(layer_hidden_list[2])
+        self.linear4 = layer.Linear(layer_hidden_list[3])
+        self.linear5 = layer.Linear(num_classes)
         self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
         self.sum_error = SumErrorLayer()
 
     def forward(self, inputs):
         y = self.linear1(inputs)
         y = self.relu(y)
         y = self.linear2(y)
+        y = self.relu(y)
+        y = self.linear3(y)
+        y = self.relu(y)
+        y = self.linear4(y)
+        y = self.relu(y)
+        y = self.linear5(y)
         return y
 
     def train_one_batch(self, x, y, dist_option, spars, synflow_flag):
@@ -144,6 +153,7 @@ def set_optimizer(self, optimizer):
 
 def create_model(pretrained=False, **kwargs):
     """Constructs a CNN model.
     Args:
         pretrained (bool): If True, returns a pre-trained model.
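For context, a minimal usage sketch (not part of the commit): it assumes examples/ms_model_mlp/model.py is importable as `model`, and the hidden widths 32/32/32/32 are illustrative placeholders rather than values from the diff.

    # Minimal sketch, not from the commit: instantiate the updated MSMLP with an
    # explicit layer_hidden_list. Assumes examples/ms_model_mlp/model.py is on
    # the import path; the hidden widths below are placeholder values.
    from model import MSMLP

    # Four hidden Linear layers sized from layer_hidden_list, followed by a
    # final num_classes-wide output layer, matching the forward pass above.
    net = MSMLP(data_size=10, num_classes=10, layer_hidden_list=[32, 32, 32, 32])

Note that the default `layer_hidden_list=[10,10,10,10]` is a mutable list used as a default argument; the constructor only reads from it, so this is harmless here, but a tuple default (or `None` with a fallback) is the more defensive Python idiom.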
