Skip to content

Commit

Permalink
update the outdated SyncBatchNormalization
Browse files Browse the repository at this point in the history
  • Loading branch information
lruizcalico committed Sep 27, 2023
1 parent 66b1c24 commit b00b21f
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions src/baskerville/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ def conv_block(

# normalize
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(
Expand Down Expand Up @@ -221,8 +221,8 @@ def conv_dna(
else:
# normalize
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current)
Expand Down Expand Up @@ -303,8 +303,8 @@ def conv_nac(

# normalize
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current)
Expand Down Expand Up @@ -479,11 +479,11 @@ def fpn_unet(

# normalize
if norm_type == "batch-sync":
current1 = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current1 = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current1)
current2 = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current2 = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current2)
elif norm_type == "batch":
current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1)
Expand Down Expand Up @@ -570,8 +570,8 @@ def fpn1_unet(

# normalize
if norm_type == "batch-sync":
current1 = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current1 = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current1)
elif norm_type == "batch":
current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1)
Expand Down Expand Up @@ -648,11 +648,11 @@ def upsample_unet(

# normalize
if norm_type == "batch-sync":
current1 = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current1 = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current1)
current2 = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current2 = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current2)
elif norm_type == "batch":
current1 = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current1)
Expand Down Expand Up @@ -745,8 +745,8 @@ def tconv_nac(

# normalize
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(momentum=bn_momentum)(current)
Expand Down Expand Up @@ -824,8 +824,8 @@ def conv_block_2d(

# normalize
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(
Expand Down Expand Up @@ -1870,8 +1870,8 @@ def dense_block(
if norm_gamma is None:
norm_gamma = "zeros" if residual else "ones"
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(
Expand Down Expand Up @@ -1940,8 +1940,8 @@ def dense_nac(
if norm_gamma is None:
norm_gamma = "zeros" if residual else "ones"
if norm_type == "batch-sync":
current = tf.keras.layers.experimental.SyncBatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma
current = tf.keras.layers.BatchNormalization(
momentum=bn_momentum, gamma_initializer=norm_gamma, synchronized=True
)(current)
elif norm_type == "batch":
current = tf.keras.layers.BatchNormalization(
Expand Down

0 comments on commit b00b21f

Please sign in to comment.