diff --git a/u2net_portrait_test.py b/u2net_portrait_test.py index 7e103dd3..f6fddb51 100644 --- a/u2net_portrait_test.py +++ b/u2net_portrait_test.py @@ -101,8 +101,9 @@ def main(): inputs_test = Variable(inputs_test.cuda()) else: inputs_test = Variable(inputs_test) - - d1,d2,d3,d4,d5,d6,d7= net(inputs_test) + + with torch.no_grad(): + d1,d2,d3,d4,d5,d6,d7= net(inputs_test) # normalization pred = 1.0 - d1[:,0,:,:] diff --git a/u2net_test.py b/u2net_test.py index 02f03001..cd092c2e 100644 --- a/u2net_test.py +++ b/u2net_test.py @@ -101,8 +101,9 @@ def main(): inputs_test = Variable(inputs_test.cuda()) else: inputs_test = Variable(inputs_test) - - d1,d2,d3,d4,d5,d6,d7= net(inputs_test) + + with torch.no_grad(): + d1,d2,d3,d4,d5,d6,d7= net(inputs_test) # normalization pred = d1[:,0,:,:]