Skip to content

Commit

Permalink
Merge pull request apache#1193 from moazreyad/dev-postgresql
Browse files Browse the repository at this point in the history
Add the Sum Error Loss for Synflow
  • Loading branch information
chrishkchris authored Aug 5, 2024
2 parents d917f09 + 20ffa82 commit 8735f49
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions examples/cnn_ms/msmlp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
from singa import tensor
from singa import opt
from singa import device
from singa.autograd import Operator
from singa.layer import Layer
from singa import singa_wrap as singa
import argparse
import numpy as np

Expand All @@ -30,10 +33,12 @@
singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}


class MLP(model.Model):
#### self-defined loss end

class MSMLP(model.Model):

def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
super(MLP, self).__init__()
super(MSMLP, self).__init__()
self.num_classes = num_classes
self.dimension = 2

Expand All @@ -42,18 +47,20 @@ def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
self.linear2 = layer.Linear(num_classes)
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()

self.sum_error = SumErrorLayer()

def forward(self, inputs):
    """Compute the network output for a batch.

    Applies the two-layer MLP: linear1 -> ReLU -> linear2, and
    returns the raw (pre-softmax) class scores.
    """
    hidden = self.relu(self.linear1(inputs))
    return self.linear2(hidden)

def train_one_batch(self, x, y, dist_option, spars):
def train_one_batch(self, x, y, synflow_flag, dist_option, spars):
out = self.forward(x)
loss = self.softmax_cross_entropy(out, y)

if dist_option == 'plain':
self.optimizer(loss)
pn_p_g_list = self.optimizer(loss)
elif dist_option == 'half':
self.optimizer.backward_and_update_half(loss)
elif dist_option == 'partialUpdate':
Expand All @@ -66,7 +73,7 @@ def train_one_batch(self, x, y, dist_option, spars):
self.optimizer.backward_and_sparse_update(loss,
topK=False,
spars=spars)
return out, loss
return pn_p_g_list, out, loss

def set_optimizer(self, optimizer):
    """Store the optimizer instance used later by train_one_batch."""
    self.optimizer = optimizer
Expand All @@ -80,7 +87,7 @@ def create_model(pretrained=False, **kwargs):
Returns:
The created CNN model.
"""
model = MLP(**kwargs)
model = MSMLP(**kwargs)

return model

Expand Down

0 comments on commit 8735f49

Please sign in to comment.