Skip to content

Commit

Permalink
added ensemble with one shift
Browse files Browse the repository at this point in the history
  • Loading branch information
Anya Korsakova committed Oct 17, 2023
1 parent e84e220 commit dc6802a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/baskerville/seqnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,12 +219,12 @@ def build_embed(self, conv_layer_i: int, batch_norm: bool = True):

def build_ensemble(self, ensemble_rc: bool = False, ensemble_shifts=[0]):
"""Build ensemble of models computing on augmented input sequences."""
if ensemble_rc or len(ensemble_shifts) > 1:
if ensemble_rc or len(ensemble_shifts) > 1 or int(ensemble_shifts[0]) != 0:
# sequence input
sequence = tf.keras.Input(shape=(self.seq_length, 4), name="sequence")
sequences = [sequence]

if len(ensemble_shifts) > 1:
if len(ensemble_shifts) > 1 or int(ensemble_shifts[0]) != 0:
# generate shifted sequences
sequences = layers.EnsembleShift(ensemble_shifts)(sequences)

Expand Down

0 comments on commit dc6802a

Please sign in to comment.