Skip to content

Commit

Permalink
Modify batch type checking
Browse files Browse the repository at this point in the history
  • Loading branch information
yourplatanus committed Apr 25, 2022
1 parent 1a20458 commit 4dfb52c
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/textbrewer/distiller_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ def move_to_device(batch, device):
return batch

def get_outputs_from_batch(batch, device, model_T, model_S, args, no_teacher_forward=False):
if type(batch) is dict:
if isinstance(batch, abc.Mapping):
if 'teacher' in batch and 'student' in batch:
teacher_batch = batch['teacher']
student_batch = batch['student']
Expand All @@ -252,7 +252,7 @@ def get_outputs_from_batch(batch, device, model_T, model_S, args, no_teacher_for
results_T = auto_forward(model_T,teacher_batch,args)
#student outputs
student_batch = move_to_device(student_batch, device)
if type(student_batch) is dict:
if isinstance(student_batch, abc.Mapping):
results_S = model_S(**student_batch, **args)
else:
results_S = model_S(*student_batch, **args)
Expand All @@ -278,7 +278,7 @@ def get_outputs_from_batch(batch, device, model_T, model_S, args, no_teacher_for
return (teacher_batch,results_T), (student_batch,results_S)

def auto_forward(model,batch,args):
if type(batch) is dict:
if isinstance(batch, abc.Mapping):
if isinstance(model,(list,tuple)):
results = [v(**batch, **args) for v in model]
elif isinstance(model,dict):
Expand All @@ -292,4 +292,4 @@ def auto_forward(model,batch,args):
results = {k:v(*batch, **args) for k,v in model.items()}
else:
results = model(*batch, **args)
return results
return results

0 comments on commit 4dfb52c

Please sign in to comment.