Skip to content

Commit

Permalink
fix epoch callback
Browse files Browse the repository at this point in the history
  • Loading branch information
charlie-becker committed Feb 29, 2024
1 parent 9497cad commit 68546fa
Showing 1 changed file with 10 additions and 17 deletions.
27 changes: 10 additions & 17 deletions mlguess/keras/callbacks.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.callbacks import (
from keras import backend as K
from keras.callbacks import (
Callback,
ModelCheckpoint,
CSVLogger,
EarlyStopping,
)
import tensorflow as tf
from tensorflow.python.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler
from keras.callbacks import ReduceLROnPlateau, LearningRateScheduler
from typing import List, Dict
import logging
from functools import partial
import math
import os
import keras
import keras.ops as ops

logger = logging.getLogger(__name__)


def get_callbacks(config: Dict[str, str], path_extend=False) -> List[Callback]:
callbacks = []

Expand Down Expand Up @@ -73,17 +73,10 @@ def on_epoch_end(self, epoch: int, logs: Dict[str, float] = None) -> None:
logs["lr"] = K.get_value(self.model.optimizer.lr)


class ReportEpoch(tf.keras.callbacks.Callback):
def __init__(self, annealing_coef, this_epoch_num):
super(ReportEpoch, self).__init__()
self.this_epoch = 0
self.annealing_coef = annealing_coef
self.this_epoch_num = this_epoch_num
class ReportEpoch(keras.callbacks.Callback):
def __init__(self, epoch_var):
self.epoch_var = epoch_var

def on_epoch_begin(self, epoch, logs=None):
if logs is None:
logs = {}
self.this_epoch += 1
K.set_value(
self.this_epoch_num, self.this_epoch
)
self.epoch_var.assign_add(1)

0 comments on commit 68546fa

Please sign in to comment.