Skip to content

Commit

Permalink
polish fuser
Browse files Browse the repository at this point in the history
  • Loading branch information
cjld committed Jan 11, 2022
1 parent 5b4576c commit 6d1b5e4
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/jittor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************

__version__ = '1.3.1.34'
__version__ = '1.3.1.35'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
Expand Down
13 changes: 13 additions & 0 deletions python/jittor/src/opt/pass/loop_var_analyze_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,19 @@ void LoopVarAnalyzePass::run() {
ir->replace({{"op"+S(i)+"_outputshape0", "1"}});
}
}

// fix index op stride not found
replace_vars.clear();
for (int i=0; i<this->op->ops.size(); i++) {
auto op = this->op->ops[i];
if (op->type() == OpType::element &&
op->name() == string("index")) {
for (int j=1; j<op->outputs().size(); i++)
replace_vars.push_back({"op"+S(i)+"_x"+S(j)+"stride", "op"+S(i)+"_x0stride"});
}
}
if (replace_vars.size())
ir->replace(replace_vars);
LOGvvvv << "KernelIR after replace\n" >> ir->to_string(0, true);
// move define
ir->move_loop_back();
Expand Down
7 changes: 7 additions & 0 deletions python/jittor/test/test_index_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,12 @@ def test_vary_shape_dep2(self):
def test_doc(self):
assert "Index Operator" in jt.index.__doc__

def test_wrong_fuse(self):
a,b = jt.index([10,10])
c = jt.zeros([10,10])
c = c.reindex([b+1,a])
x = b.clone()
jt.sync([c, x])

if __name__ == "__main__":
unittest.main()

0 comments on commit 6d1b5e4

Please sign in to comment.