Skip to content

Commit

Permalink
add where(cond,x,y) alias
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Apr 5, 2022
1 parent 9c74699 commit 2901e57
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.2.2'
__version__ = '1.3.2.3'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
6 changes: 6 additions & 0 deletions python/jittor/src/ops/where_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ WhereOp::WhereOp(Var* cond, NanoString dtype) : cond(cond) {
for (uint i=0; i<ndim; i++)
outs[i] = create_output(nullptr, dtype);
}
static auto make_ternary = get_op_info("ternary")
.get_constructor<VarPtr, Var*, Var*, Var*>();
WhereOp::WhereOp(Var* cond, Var* x, Var* y) {
forward(make_ternary(cond, x, y));
return;
}

void WhereOp::infer_shape() {
auto ndim = cond->shape.size();
Expand Down
4 changes: 4 additions & 0 deletions python/jittor/src/ops/where_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ struct WhereOp : Op {
*/
// @attrs(multiple_outputs)
WhereOp(Var* cond, NanoString dtype=ns_int32);
/**
* Condition operator, perform cond ? x : y
* */
WhereOp(Var* cond, Var* x, Var* y);
void infer_shape() override;

const char* name() const override { return "where"; }
Expand Down
12 changes: 12 additions & 0 deletions python/jittor/test/test_ternary_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ def test_with_np(self):
assert (jda.data==(a>b)*1).all()
assert (jdb.data==1-(a>b)).all()

def test_where(self):
np.random.seed(0)
a = np.random.rand(5,10).astype("float32")
b = np.random.rand(5,10).astype("float32")
ja = jt.array(a)
jb = jt.array(b)
jc = jt.where(ja>jb, ja, jb)
assert (jc.data==np.maximum(a,b)).all(), f"\n{jc.data}\n{np.maximum(a,b)}\n{a}\n{b}"
jda, jdb = jt.grad(jc, [ja, jb])
assert (jda.data==(a>b)*1).all()
assert (jdb.data==1-(a>b)).all()

def test_min(self):
np.random.seed(1)
a = np.random.rand(5,10).astype("float32")
Expand Down

0 comments on commit 2901e57

Please sign in to comment.