Skip to content

Commit

Permalink
Merge pull request apache#1194 from GY-GitCode/24-8-6-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
lzjpaul authored Aug 7, 2024
2 parents 8735f49 + c01551e commit e9d1cc9
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions examples/cnn_ms/msmlp/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,35 @@

singa_dtype = {"float16": tensor.float16, "float32": tensor.float32}

#### self-defined loss begin

### from autograd.py
class SumError(Operator):

def __init__(self):
super(SumError, self).__init__()
# self.t = t.data

def forward(self, x):
# self.err = singa.__sub__(x, self.t)
self.data_x = x
# sqr = singa.Square(self.err)
# loss = singa.SumAll(sqr)
loss = singa.SumAll(x)
# self.n = 1
# for s in x.shape():
# self.n *= s
# loss /= self.n
return loss

def backward(self, dy=1.0):
# dx = self.err
dev = device.get_default_device()
dx = tensor.Tensor(self.data_x.shape, dev, singa_dtype['float32'])
dx.copy_from_numpy(np.ones(self.data_x.shape))
# dx *= float(2 / self.n)
dx *= dy
return dx

#### self-defined loss end

Expand Down

0 comments on commit e9d1cc9

Please sign in to comment.