Skip to content

Commit

Permalink
remove cupy require for cumprod
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Apr 23, 2022
1 parent 9e58fac commit 5e559cf
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.3.4'
__version__ = '1.3.3.5'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
11 changes: 4 additions & 7 deletions python/jittor/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,22 +661,19 @@ def _prod(x,dim=0):

def numpy_cumsum(x, dim=None):
def cumsum_forward(np, data):
dim = data['inputs'][1].item()
a = data['inputs'][0]
b = data['outputs'][0]
np.cumsum(a, axis=dim, out=b)

def cumsum_backward(np, data):
dim = data['inputs'][1].item()
dout = data['dout']
out = data['outputs'][0]
np.cumsum(dout[..., ::-1], axis=dim, out=out)
np.copyto(out, out[..., ::-1])
np.cumsum(np.flip(dout, dim), axis=dim, out=out)
np.copyto(out, np.flip(out, dim))
if (dim == None):
dim = -1
assert(dim >= -1 and dim < len(x.shape))
dim_var = jt.array([dim],dtype=int)
return jt.numpy_code(x.shape, x.dtype, [x, dim_var.detach()], cumsum_forward, [cumsum_backward])
return jt.numpy_code(x.shape, x.dtype, [x], cumsum_forward, [cumsum_backward])

def cub_cumsum(x, dim=None):
if (dim == None):
Expand Down Expand Up @@ -1040,7 +1037,7 @@ def auto_parallel(n, src, **kw):
return new_src


def cumprod(a, dim):
def numpy_cumprod(a, dim):
class CumprodFunc(jt.Function):
def forward_code(self, np, data):
a = data["inputs"][0]
Expand Down
3 changes: 2 additions & 1 deletion python/jittor/test/test_cumprod_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class TestCumprod(unittest.TestCase):
def test_cumprod_cpu(self):
for i in range(1,6):
for j in range(i):
print("test", i, j)
x = np.random.rand(*((10,)*i))
x_jt = jt.array(x)
y_jt = jt.cumprod(x_jt, j).sqr()
Expand All @@ -34,7 +35,7 @@ def test_cumprod_cpu(self):
y_tc.sum().backward()
g_tc = x_tc.grad
assert np.allclose(y_jt.numpy(), y_tc.data)
assert np.allclose(g_jt.numpy(), g_tc.data)
np.testing.assert_allclose(g_jt.numpy(), g_tc.data, atol=1e-5)

@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
Expand Down

0 comments on commit 5e559cf

Please sign in to comment.