Commit 47fbb4f: fix imports
Anton Emelyanov committed Feb 11, 2021 · 1 parent 636905f
Showing 1 changed file with 8 additions and 8 deletions.

generate_samples.py
```diff
@@ -19,16 +19,16 @@
 import time
 
 import torch
-# from transformers.tokenization_gpt2 import GPT2Tokenizer
+from transformers.tokenization_gpt2 import GPT2Tokenizer
 
 from src import mpu
 from src.arguments import get_args
 from src.fp16 import FP16_Module
 from src.model import DistributedDataParallel as DDP
 from src.model import GPT3Model
-from src.pretrain_gpt3 import generate
-from src.pretrain_gpt3 import initialize_distributed
-from src.pretrain_gpt3 import set_random_seed
+from .pretrain_gpt3 import generate
+from .pretrain_gpt3 import initialize_distributed
+from .pretrain_gpt3 import set_random_seed
 from src.utils import Timers
 from src.utils import export_to_huggingface_model
 from src.utils import print_rank_0
```
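Two things in this hunk are worth flagging. The flat module path `transformers.tokenization_gpt2` only exists in transformers releases before 4.0; on 4.x the same class is importable from the package top level, so a version-tolerant variant (an assumption about the environment, not part of this commit) would look like:

```python
# Compatibility sketch (assumption, not part of this commit): the flat module
# path below exists in transformers < 4.0; newer releases expose the same
# class at the package top level.
try:
    from transformers.tokenization_gpt2 import GPT2Tokenizer  # transformers 3.x
except ImportError:
    from transformers import GPT2Tokenizer  # transformers >= 4.0
```

Note also that the new relative imports (`from .pretrain_gpt3 import ...`) only resolve when the file is executed as part of the `src` package, e.g. `python -m src.generate_samples`; run as a plain script, Python raises "ImportError: attempted relative import with no known parent package".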
```diff
@@ -105,11 +105,11 @@ def generate_samples(model, tokenizer, args):
                 context_length = len(context_tokens)
 
                 if context_length >= args.seq_length // 2:
-                    print("\nContext length", context_length, \
+                    print("\nContext length", context_length,
                           "\nPlease give smaller context (half of the sequence length)!")
                     continue
             else:
-                context_tokens = tokenizer("EMPTY TEXT")['input_ids']
+                _ = tokenizer("EMPTY TEXT")['input_ids']
 
             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
```
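The last two context lines show the pattern that keeps the workers in sync: one rank decides whether to stop, and the decision is broadcast so every model-parallel rank leaves the generation loop together. A minimal standalone sketch of the same stop-flag pattern (an illustration, not the commit's code; it assumes an initialized CUDA process group, and `src=0` stands in for `mpu.get_model_parallel_src_rank()`):

```python
# Stop-flag sketch (illustration, not the commit's code). Assumes
# torch.distributed has been initialized with a CUDA-capable backend.
import torch
import torch.distributed as dist

def should_terminate(local_decision: bool) -> bool:
    # Every rank allocates the tensor; broadcast then overwrites it on all
    # ranks with the value held by the source rank (0 here).
    flag = torch.cuda.LongTensor([int(local_decision)])
    dist.broadcast(flag, src=0)
    return flag.item() == 1
```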
```diff
@@ -140,7 +140,7 @@ def generate_samples(model, tokenizer, args):
 
 
 def prepare_tokenizer(args):
-    tokenizer = GPT3Tokenizer.from_pretrained(args.tokenizer_path)
+    tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_path)
     eod_token = tokenizer.encoder['<pad>']
     num_tokens = len(tokenizer)
 
```
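The old line called `GPT3Tokenizer`, a name never imported anywhere in this file (the only tokenizer import, commented out until this commit, is transformers' `GPT2Tokenizer`), so `prepare_tokenizer` would have failed with a NameError at call time. After the fix it loads a GPT-2-style tokenizer and uses the vocabulary's `<pad>` entry as the end-of-document id. A sketch of the same lookup (the checkpoint name is illustrative, and it assumes the vocabulary actually contains a `<pad>` token):

```python
# Sketch of the fixed lookup (illustrative checkpoint name; assumes the
# tokenizer's vocabulary contains a '<pad>' entry).
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("sberbank-ai/rugpt3large_based_on_gpt2")
eod_token = tokenizer.encoder.get('<pad>')  # .encoder is the slow tokenizer's vocab dict
num_tokens = len(tokenizer)
print(num_tokens, eod_token)
```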
```diff
@@ -166,7 +166,7 @@ def main():
     torch.backends.cudnn.enabled = False
 
     # Timer.
-    timers = Timers()
+    _ = Timers()
 
     # Arguments.
     args = get_args()
```
