From 07d1b9b27763ab68147b96125d7be6c4c761b1d0 Mon Sep 17 00:00:00 2001 From: Luan Pham Date: Tue, 1 Dec 2020 22:59:12 +0700 Subject: [PATCH] Add torch no grad, reduce GPU mem significantly --- u2net_portrait_test.py | 5 +++-- u2net_test.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/u2net_portrait_test.py b/u2net_portrait_test.py index 7e103dd3..f6fddb51 100644 --- a/u2net_portrait_test.py +++ b/u2net_portrait_test.py @@ -101,8 +101,9 @@ def main(): inputs_test = Variable(inputs_test.cuda()) else: inputs_test = Variable(inputs_test) - - d1,d2,d3,d4,d5,d6,d7= net(inputs_test) + + with torch.no_grad(): + d1,d2,d3,d4,d5,d6,d7= net(inputs_test) # normalization pred = 1.0 - d1[:,0,:,:] diff --git a/u2net_test.py b/u2net_test.py index 02f03001..cd092c2e 100644 --- a/u2net_test.py +++ b/u2net_test.py @@ -101,8 +101,9 @@ def main(): inputs_test = Variable(inputs_test.cuda()) else: inputs_test = Variable(inputs_test) - - d1,d2,d3,d4,d5,d6,d7= net(inputs_test) + + with torch.no_grad(): + d1,d2,d3,d4,d5,d6,d7= net(inputs_test) # normalization pred = d1[:,0,:,:]