Skip to content

Commit

Permalink
modify export with pir
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 committed Dec 23, 2024
1 parent cd18221 commit 3c0a956
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
repos:
- repo: https://github.com/PaddlePaddle/mirrors-yapf.git
- repo: https://github.com/cuicheng01/mirrors-yapf.git
rev: 0d79c0c469bab64f7229c9aca2b1186ef47f0e37
hooks:
- id: yapf
Expand Down
48 changes: 18 additions & 30 deletions ppcls/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,37 +606,25 @@ def export(self,
model.base_model.quanter.save_quantized_model(model,
save_path + "_int8")
else:
paddle_version = version.parse(paddle.__version__)
if (paddle_version >= version.parse('3.0.0b2') or paddle_version ==
version.parse('0.0.0')) and os.environ.get(
"FLAGS_enable_pir_api", None) not in ["0", "False"]:
save_path = os.path.dirname(save_path)
for enable_pir in [True, False]:
if not enable_pir:
save_path_no_pir = os.path.join(save_path, "inference")
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] +
self.config["Global"]["image_shape"],
dtype='float32')
])
paddle.jit.save(model, save_path_no_pir)
else:
save_path_pir = os.path.join(
os.path.dirname(save_path),
f"{os.path.basename(save_path)}_pir", "inference")
paddle.jit.save(model, save_path_pir)
shutil.copy(
dst_path,
os.path.join(
os.path.dirname(save_path_pir),
os.path.basename(dst_path)), )
else:
if self.config["Global"].get("export_with_pir", False):
paddle_version = version.parse(paddle.__version__)
assert (paddle_version >= version.parse('3.0.0b2') or
paddle_version == version.parse('0.0.0')
) and os.environ.get("FLAGS_enable_pir_api",
None) not in ["0", "False"]
paddle.jit.save(model, save_path)
else:
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = paddle.jit.to_static(
model,
input_spec=[
paddle.static.InputSpec(
shape=[None] +
self.config["Global"]["image_shape"],
dtype='float32')
])
paddle.jit.save(model, save_path)
logger.info(
f"Export succeeded! The inference model exported has been saved in \"{save_path}\"."
)
Expand Down

0 comments on commit 3c0a956

Please sign in to comment.