Skip to content

Commit

Permalink
Format code style
Browse files: browse the repository at this point in the history
  • Loading branch information
XuehaiPan committed Dec 6, 2024
1 parent 8fcfeb0 commit b4f99e8
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 19 deletions.
10 changes: 8 additions & 2 deletions tests/models/encoder_decoder/test_modeling_encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,10 @@ def check_encoder_decoder_model_from_pretrained_using_model_paths(
**kwargs,
):
encoder_model, decoder_model = self.get_encoder_decoder_model(config, decoder_config)
with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
with (
tempfile.TemporaryDirectory() as encoder_tmp_dirname,
tempfile.TemporaryDirectory() as decoder_tmp_dirname,
):
encoder_model.save_pretrained(encoder_tmp_dirname)
decoder_model.save_pretrained(decoder_tmp_dirname)
model_kwargs = {"encoder_hidden_dropout_prob": 0.0}
Expand Down Expand Up @@ -306,7 +309,10 @@ def check_save_and_load_encoder_decoder_model(
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0

with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
with (
tempfile.TemporaryDirectory() as encoder_tmp_dirname,
tempfile.TemporaryDirectory() as decoder_tmp_dirname,
):
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
enc_dec_model = EncoderDecoderModel.from_encoder_decoder_pretrained(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,10 @@ def check_save_and_load_encoder_decoder_model(
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0

with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
with (
tempfile.TemporaryDirectory() as encoder_tmp_dirname,
tempfile.TemporaryDirectory() as decoder_tmp_dirname,
):
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
SpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,10 @@ def check_save_and_load_encoder_decoder_model(
out_2 = outputs[0].cpu().numpy()
out_2[np.isnan(out_2)] = 0

with tempfile.TemporaryDirectory() as encoder_tmp_dirname, tempfile.TemporaryDirectory() as decoder_tmp_dirname:
with (
tempfile.TemporaryDirectory() as encoder_tmp_dirname,
tempfile.TemporaryDirectory() as decoder_tmp_dirname,
):
enc_dec_model.encoder.save_pretrained(encoder_tmp_dirname)
enc_dec_model.decoder.save_pretrained(decoder_tmp_dirname)
VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
Expand Down
7 changes: 4 additions & 3 deletions tests/models/wav2vec2/test_modeling_flax_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,9 +623,10 @@ def test_wav2vec2_with_lm_pool(self):
self.assertEqual(transcription[0], "bien y qué regalo vas a abrir primero")

# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
with (
CaptureLogger(processing_wav2vec2_with_lm.logger) as cl,
multiprocessing.get_context("fork").Pool(2) as pool,
):
transcription = processor.batch_decode(np.array(logits), pool, num_processes=2).text

self.assertIn("num_process", cl.out)
Expand Down
7 changes: 4 additions & 3 deletions tests/models/wav2vec2/test_modeling_tf_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -827,9 +827,10 @@ def test_wav2vec2_with_lm_pool(self):
self.assertEqual(transcription[0], "el libro ha sido escrito por cervantes")

# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
with (
CaptureLogger(processing_wav2vec2_with_lm.logger) as cl,
multiprocessing.get_context("fork").Pool(2) as pool,
):
transcription = processor.batch_decode(logits.numpy(), pool, num_processes=2).text

self.assertIn("num_process", cl.out)
Expand Down
7 changes: 4 additions & 3 deletions tests/models/wav2vec2/test_modeling_wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1889,9 +1889,10 @@ def test_wav2vec2_with_lm_pool(self):
self.assertEqual(transcription[0], "habitan aguas poco profundas y rocosas")

# user-managed pool + num_processes should trigger a warning
with CaptureLogger(processing_wav2vec2_with_lm.logger) as cl, multiprocessing.get_context("fork").Pool(
2
) as pool:
with (
CaptureLogger(processing_wav2vec2_with_lm.logger) as cl,
multiprocessing.get_context("fork").Pool(2) as pool,
):
transcription = processor.batch_decode(logits.cpu().numpy(), pool, num_processes=2).text

self.assertIn("num_process", cl.out)
Expand Down
15 changes: 9 additions & 6 deletions utils/download_glue_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,11 @@ def format_mrpc(data_dir, path_to_data):
for row in ids_fh:
dev_ids.append(row.strip().split("\t"))

with open(mrpc_train_file, encoding="utf8") as data_fh, open(
os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8"
) as train_fh, open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh:
with (
open(mrpc_train_file, encoding="utf8") as data_fh,
open(os.path.join(mrpc_dir, "train.tsv"), "w", encoding="utf8") as train_fh,
open(os.path.join(mrpc_dir, "dev.tsv"), "w", encoding="utf8") as dev_fh,
):
header = data_fh.readline()
train_fh.write(header)
dev_fh.write(header)
Expand All @@ -92,9 +94,10 @@ def format_mrpc(data_dir, path_to_data):
else:
train_fh.write("%s\t%s\t%s\t%s\t%s\n" % (label, id1, id2, s1, s2))

with open(mrpc_test_file, encoding="utf8") as data_fh, open(
os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8"
) as test_fh:
with (
open(mrpc_test_file, encoding="utf8") as data_fh,
open(os.path.join(mrpc_dir, "test.tsv"), "w", encoding="utf8") as test_fh,
):
header = data_fh.readline()
test_fh.write("index\t#1 ID\t#2 ID\t#1 String\t#2 String\n")
for idx, row in enumerate(data_fh):
Expand Down

0 comments on commit b4f99e8

Please sign in to comment.