diff --git a/KD_Lib/KD/vision/attention/loss_metric.py b/KD_Lib/KD/vision/attention/loss_metric.py
index fc0c84d..44a3bf4 100755
--- a/KD_Lib/KD/vision/attention/loss_metric.py
+++ b/KD_Lib/KD/vision/attention/loss_metric.py
@@ -21,17 +21,20 @@ def forward(self, teacher_output, student_output):
         :param student_output (torch.FloatTensor): Prediction made by the student model
         """
 
-        A_t = teacher_output[1:]
-        A_s = student_output[1:]
+        A_t = teacher_output  # [1:]
+        A_s = student_output  # [1:]
+
         loss = 0.0
         for (layerT, layerS) in zip(A_t, A_s):
+
             xT = self.single_at_loss(layerT)
             xS = self.single_at_loss(layerS)
             loss += (xS - xT).pow(self.p).mean()
+
         return loss
 
     def single_at_loss(self, activation):
         """
         Function for calculating single attention loss
         """
-        return F.normalize(activation.pow(self.p).mean(1).view(activation.size(0), -1))
+        return F.normalize(activation.pow(self.p).view(activation.size(0), -1))
diff --git a/KD_Lib/__init__.py b/KD_Lib/__init__.py
index 8e31082..53e59db 100755
--- a/KD_Lib/__init__.py
+++ b/KD_Lib/__init__.py
@@ -4,4 +4,4 @@
 
 __author__ = """Het Shah"""
 __email__ = "divhet163@gmail.com"
-__version__ = "__version__ = '0.0.31'"
+__version__ = "__version__ = '0.0.32'"
diff --git a/setup.cfg b/setup.cfg
index 02d803c..d204f9c 100755
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.0.31
+current_version = 0.0.32
 commit = True
 tag = True
 
diff --git a/setup.py b/setup.py
index 7a61bb1..e8efc9d 100644
--- a/setup.py
+++ b/setup.py
@@ -82,6 +82,6 @@
     test_suite="tests",
     tests_require=test_requirements,
     url="https://github.com/SforAiDL/KD_Lib",
-    version="0.0.31",
+    version="0.0.32",
     zip_safe=False,
 )
diff --git a/tests/test_kd.py b/tests/test_kd.py
index 7b297c3..bfc1e6b 100644
--- a/tests/test_kd.py
+++ b/tests/test_kd.py
@@ -10,7 +10,6 @@
     RCO,
     TAKD,
     Attention,
-    BaseClass,
     LabelSmoothReg,
     MeanTeacher,
     MessyCollab,
@@ -29,7 +28,7 @@
 img_size = (32, 32)
 img_channels = 3
 n_classes = 10
-len_dataset = 4
+len_dataset = 8
 batch_size = 2
 
 train_loader = test_loader = DataLoader(
@@ -86,7 +85,7 @@ def test_TAKD():
 
     assistant_train_order = [[-1], [-1, 0]]
 
-    distil = TAKD(
+    distiller = TAKD(
         teacher,
         assistants,
         student,
@@ -98,32 +97,32 @@
         student_optimizer,
     )
 
-    distil.train_teacher(epochs=1, plot_losses=False, save_model=False)
-    distil.train_assistants(epochs=1, plot_losses=False, save_model=False)
-    distil.train_student(epochs=1, plot_losses=False, save_model=False)
-    distil.get_parameters()
+    distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
+    distiller.train_assistants(epochs=1, plot_losses=False, save_model=False)
+    distiller.train_student(epochs=1, plot_losses=False, save_model=False)
+    distiller.get_parameters()
 
 
-# def test_attention():
+def test_Attention():
 
-#     att = Attention(
-#         teacher,
-#         student,
-#         train_loader,
-#         test_loader,
-#         t_optimizer,
-#         s_optimizer,
-#     )
+    distiller = Attention(
+        teacher,
+        student,
+        train_loader,
+        test_loader,
+        t_optimizer,
+        s_optimizer,
+    )
 
-#     att.train_teacher(epochs=1, plot_losses=False, save_model=False)
-#     att.train_student(epochs=1, plot_losses=False, save_model=False)
-#     att.evaluate(teacher=False)
-#     att.get_parameters()
+    distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
+    distiller.train_student(epochs=1, plot_losses=False, save_model=False)
+    distiller.evaluate(teacher=False)
+    distiller.get_parameters()
 
 
 def test_NoisyTeacher():
 
-    experiment = NoisyTeacher(
+    distiller = NoisyTeacher(
         teacher,
         student,
         train_loader,
@@ -135,10 +134,10 @@ def test_NoisyTeacher():
         device="cpu",
     )
 
-    experiment.train_teacher(epochs=1, plot_losses=False, save_model=False)
-    experiment.train_student(epochs=1, plot_losses=False, save_model=False)
-    experiment.evaluate(teacher=False)
-    experiment.get_parameters()
+    distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
+    distiller.train_student(epochs=1, plot_losses=False, save_model=False)
+    distiller.evaluate(teacher=False)
+    distiller.get_parameters()
 
 
 def test_VirtualTeacher():
@@ -158,21 +157,21 @@ def test_SelfTraining():
     distiller.get_parameters()
 
 
-# def test_mean_teacher():
+# def test_MeanTeacher():
 
-#     mt = MeanTeacher(
-#         teacher_model,
-#         student_model,
+#     distiller = MeanTeacher(
+#         teacher,
+#         student,
 #         train_loader,
 #         test_loader,
 #         t_optimizer,
 #         s_optimizer,
 #     )
 
-#     mt.train_teacher(epochs=1, plot_losses=False, save_model=False)
-#     mt.train_student(epochs=1, plot_losses=False, save_model=False)
-#     mt.evaluate()
-#     mt.get_parameters()
+#     distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
+#     distiller.train_student(epochs=1, plot_losses=False, save_model=False)
+#     distiller.evaluate()
+#     distiller.get_parameters()
 
 
 def test_RCO():
@@ -192,15 +191,15 @@ def test_RCO():
     distiller.get_parameters()
 
 
-# def test_BANN():
+def test_BANN():
 
-#     model = deepcopy(mock_vision_model)
-#     optimizer = optim.SGD(model.parameters(), 0.01)
+    model = deepcopy(mock_vision_model)
+    optimizer = optim.SGD(model.parameters(), 0.01)
 
-#     distiller = BANN(model, train_loader, test_loader, optimizer, num_gen=2)
+    distiller = BANN(model, train_loader, test_loader, optimizer, num_gen=2)
 
-#     distiller.train_student(epochs=1, plot_losses=False, save_model=False)
-#     distiller.evaluate()
+    distiller.train_student(epochs=1, plot_losses=False, save_model=False)
+    # distiller.evaluate()
 
 
 def test_PS():
@@ -237,7 +236,7 @@ def test_LSR():
     distiller.get_parameters()
 
 
-def test_soft_random():
+def test_SoftRandom():
 
     distiller = SoftRandom(
         teacher,
@@ -254,7 +253,7 @@ def test_soft_random():
     distiller.get_parameters()
 
 
-def test_messy_collab():
+def test_MessyCollab():
 
     distiller = MessyCollab(
         teacher,
@@ -271,21 +270,6 @@
     distiller.get_parameters()
 
 
-# def test_bert2lstm():
-#     student_model = LSTMNet(
-#         input_dim=len(text_field.vocab), num_classes=2, dropout_prob=0.5
-#     )
-#     optimizer = optim.Adam(student_model.parameters())
-#
-#     experiment = BERT2LSTM(
-#         student_model, bert2lstm_train_loader, bert2lstm_train_loader, optimizer, train_df, val_df
-#     )
-#     # experiment.train_teacher(epochs=1, plot_losses=False, save_model=False)
-#     experiment.train_student(epochs=1, plot_losses=False, save_model=False)
-#     experiment.evaluate_student()
-#     experiment.evaluate_teacher()
-
-
 def test_DML():
 
     student_1 = deepcopy(mock_vision_model)
@@ -312,3 +296,18 @@
     )
     distiller.evaluate()
     distiller.get_parameters()
+
+
+# def test_BERT2LSTM():
+#     student_model = LSTMNet(
+#         input_dim=len(text_field.vocab), num_classes=2, dropout_prob=0.5
+#     )
+#     optimizer = optim.Adam(student_model.parameters())
+#
+#     distiller = BERT2LSTM(
+#         student_model, bert2lstm_train_loader, bert2lstm_train_loader, optimizer, train_df, val_df
+#     )
+#     # distiller.train_teacher(epochs=1, plot_losses=False, save_model=False)
+#     distiller.train_student(epochs=1, plot_losses=False, save_model=False)
+#     distiller.evaluate_student()
+#     distiller.evaluate_teacher()
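
Note: for reference, below is a minimal, runnable sketch of the attention-transfer loss as patched in loss_metric.py. The ATLossSketch name, the p=2 default, and the dummy activation shapes are illustrative assumptions, not KD_Lib's API; the two method bodies mirror the "+" lines of the first hunk.

import torch
import torch.nn.functional as F


class ATLossSketch:
    """Mirrors the patched forward()/single_at_loss() pair above."""

    def __init__(self, p=2):
        self.p = p

    def single_at_loss(self, activation):
        # Raise each map to the p-th power, flatten per sample, L2-normalize.
        # The patch drops the old .mean(1) over channels, so teacher and
        # student maps must now agree in channel count, not just H x W.
        return F.normalize(activation.pow(self.p).view(activation.size(0), -1))

    def forward(self, teacher_output, student_output):
        # Sum the mean p-th-power differences of the normalized maps over
        # paired layers. The "# [1:]" comments record that the outputs are no
        # longer sliced, so callers presumably pass the activation maps alone.
        loss = 0.0
        for layerT, layerS in zip(teacher_output, student_output):
            xT = self.single_at_loss(layerT)
            xS = self.single_at_loss(layerS)
            loss += (xS - xT).pow(self.p).mean()
        return loss


# Dummy check: two paired layers, batch of 2, matching channel counts.
teacher_acts = [torch.randn(2, 8, 4, 4), torch.randn(2, 16, 2, 2)]
student_acts = [torch.randn(2, 8, 4, 4), torch.randn(2, 16, 2, 2)]
print(ATLossSketch().forward(teacher_acts, student_acts))  # scalar tensor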