Skip to content

Commit

Permalink
Merge pull request #690 from PINTO0309/fix_batchnorm
Browse files Browse the repository at this point in the history
Improved the conversion stability of `BatchNormalization`.
  • Loading branch information
PINTO0309 authored Sep 12, 2024
2 parents 5baf18d + 924464d commit 8d0541a
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 70 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -299,15 +299,15 @@ Video speed is adjusted approximately 50 times slower than actual speed.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
ghcr.io/pinto0309/onnx2tf:1.25.10
ghcr.io/pinto0309/onnx2tf:1.25.11

or

# Authentication is not required for pulls from Docker Hub.
docker run --rm -it \
-v `pwd`:/workdir \
-w /workdir \
docker.io/pinto0309/onnx2tf:1.25.10
docker.io/pinto0309/onnx2tf:1.25.11

or

Expand Down
2 changes: 1 addition & 1 deletion onnx2tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from onnx2tf.onnx2tf import convert, main

__version__ = '1.25.10'
__version__ = '1.25.11'
141 changes: 74 additions & 67 deletions onnx2tf/ops/BatchNormalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,17 @@ def make_node(

# Automatic correction of accuracy degradation
min_abs_err = sys.maxsize
min_abs_err_perm_1: List[int] = [idx for idx in range(len(mean.shape))]
min_abs_err_perm_1: List[int] = []
check_length = 0
if input_tensor.shape is not None and mean.shape is not None and len(input_tensor.shape) >= len(mean.shape):
check_length = len(input_tensor.shape)
else:
check_length = len(mean.shape)
min_abs_err_perm_1: List[int] = [idx for idx in range(check_length)]

if not disable_strict_mode:
if onnx_tensor_infos is not None and validation_data is not None:
tensor_1_candidate_for_transpositions = list(itertools.permutations(range(len(mean.shape))))
tensor_1_candidate_for_transpositions = list(itertools.permutations(range(check_length)))
# Search for the axis with the smallest error
for tensor_1_candidate_for_transposition in tensor_1_candidate_for_transpositions:
try:
Expand Down Expand Up @@ -470,71 +476,72 @@ def make_node(
except Exception as ex:
pass

tf_layers_dict[Y.name]['tf_node'] = \
tf.nn.batch_normalization(
x=input_tensor,
mean=\
transpose_with_flexing_deterrence(
input_tensor=mean,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(mean, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(mean),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
variance=\
transpose_with_flexing_deterrence(
input_tensor=var,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(var, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(var),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
offset=\
transpose_with_flexing_deterrence(
input_tensor=offset,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(offset, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(offset),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
scale=\
transpose_with_flexing_deterrence(
input_tensor=scale,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(scale, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(scale),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
variance_epsilon=epsilon,
)
if min_abs_err_perm_1 != [idx for idx in range(check_length)]:
tf_layers_dict[Y.name]['tf_node'] = \
tf.nn.batch_normalization(
x=input_tensor,
mean=\
transpose_with_flexing_deterrence(
input_tensor=mean,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(mean, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(mean),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
variance=\
transpose_with_flexing_deterrence(
input_tensor=var,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(var, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(var),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
offset=\
transpose_with_flexing_deterrence(
input_tensor=offset,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(offset, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(offset),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
scale=\
transpose_with_flexing_deterrence(
input_tensor=scale,
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
) if not isinstance(scale, np.ndarray) else \
transpose_with_flexing_deterrence(
input_tensor=tf.convert_to_tensor(scale),
perm=min_abs_err_perm_1,
output_shape=Y.shape \
if None not in Y.shape and Y.shape != [] else None,
**kwargs,
),
variance_epsilon=epsilon,
)
tf_type = tf.nn.batch_normalization

# Post-process transpose
Expand Down

0 comments on commit 8d0541a

Please sign in to comment.