Fix the vocab training script for Taskcluster (#326)
gregtatum authored Dec 22, 2023
1 parent e7cfab9 commit e48440f
Showing 8 changed files with 300 additions and 58 deletions.
12 changes: 10 additions & 2 deletions pipeline/train/spm-vocab.sh
@@ -25,7 +25,11 @@
 set -x
 set -euo pipefail

-test -v MARIAN
+if [[ -z "${MARIAN}" ]]; then
+  echo "Error: The MARIAN environment variable was not provided. This is required as"
+  echo "the path to the spm_train binary."
+  exit 1
+fi

 # The name of the source corpus, e.g. "fetches/corpus.en.zst".
 merged_corpus_src=$1
@@ -38,7 +38,11 @@ sample_size=$4
 # The thread count, either "auto" or an int.
 num_threads=$5
 # The size of the final vocab. Defaults to 32000.
-vocab_size="${6:-32000}"
+vocab_size=${6:-None}
+
+if [ "$vocab_size" == "None" ]; then
+  vocab_size=32000
+fi

 if (( vocab_size % 8 != 0 )); then
   echo "Error: vocab_size must be a multiple of 8 (https://github.com/mozilla/firefox-translations-training/issues/249)"
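For context only (this sketch is not part of the commit): Taskcluster appears to pass the literal string "None" when the optional vocab size parameter is omitted, which is why the script now treats "None" as "use the default". The same resolve-and-validate logic, mirrored in Python with a hypothetical function name:

def resolve_vocab_size(raw: str) -> int:
    # Treat the literal string "None" as unset and fall back to the default.
    size = 32000 if raw == "None" else int(raw)
    # Enforce the multiple-of-8 requirement from issue #249.
    if size % 8 != 0:
        raise ValueError("vocab_size must be a multiple of 8")
    return size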
87 changes: 69 additions & 18 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -41,6 +41,7 @@ opustrainer = {git = "https://github.com/hplt-project/OpusTrainer.git", rev="913
 pytest-clarity = "^1.0.1"
 requests-mock = "^1.11.0"
 sh = "^2.0.6"
+zstandard = "^0.22.0"

 [tool.black]
 extend-exclude= "/3rd_party"
73 changes: 73 additions & 0 deletions tests/fixtures/__init__.py
@@ -0,0 +1,73 @@
import os
import shutil

import zstandard as zstd

FIXTURES_PATH = os.path.dirname(os.path.abspath(__file__))
DATA_PATH = os.path.abspath(os.path.join(FIXTURES_PATH, "../../data"))
TESTS_DATA = os.path.join(DATA_PATH, "tests_data")


en_sample = """The little girl, seeing she had lost one of her pretty shoes, grew angry, and said to the Witch, “Give me back my shoe!”
“I will not,” retorted the Witch, “for it is now my shoe, and not yours.”
“You are a wicked creature!” cried Dorothy. “You have no right to take my shoe from me.”
“I shall keep it, just the same,” said the Witch, laughing at her, “and someday I shall get the other one from you, too.”
This made Dorothy so very angry that she picked up the bucket of water that stood near and dashed it over the Witch, wetting her from head to foot.
Instantly the wicked woman gave a loud cry of fear, and then, as Dorothy looked at her in wonder, the Witch began to shrink and fall away.
“See what you have done!” she screamed. “In a minute I shall melt away.”
“I’m very sorry, indeed,” said Dorothy, who was truly frightened to see the Witch actually melting away like brown sugar before her very eyes.
"""

ca_sample = """La nena, en veure que havia perdut una de les seves boniques sabates, es va enfadar i va dir a la bruixa: "Torna'm la sabata!"
"No ho faré", va replicar la Bruixa, "perquè ara és la meva sabata, i no la teva".
"Ets una criatura dolenta!" va cridar la Dorothy. "No tens dret a treure'm la sabata".
"Me'l guardaré, igualment", va dir la Bruixa, rient-se d'ella, "i algun dia t'agafaré l'altre també".
Això va fer enfadar tant la Dorothy que va agafar la galleda d'aigua que hi havia a prop i la va llançar sobre la Bruixa, mullant-la de cap a peus.
A l'instant, la malvada dona va fer un fort crit de por, i aleshores, mentre la Dorothy la mirava meravellada, la Bruixa va començar a encongir-se i a caure.
"Mira què has fet!" ella va cridar. "D'aquí a un minut em fondreré".
"Ho sento molt, de veritat", va dir la Dorothy, que es va espantar veritablement de veure que la Bruixa es va desfer com el sucre moreno davant els seus mateixos ulls.
"""


class DataDir:
    """
    Creates a persistent data directory in data/tests_data/{dir_name} that will be
    cleaned out before a test run. This should help in persisting artifacts between test
    runs to manually verify the results.
    """

    def __init__(self, dir_name: str) -> None:
        self.path = os.path.join(TESTS_DATA, dir_name)

        # Ensure the base /data directory exists.
        os.makedirs(TESTS_DATA, exist_ok=True)

        # Clean up a previous run if this exists.
        if os.path.exists(self.path):
            shutil.rmtree(self.path)

        os.makedirs(self.path)
        print("Tests are using the subdirectory:", self.path)

    def join(self, name: str):
        return os.path.join(self.path, name)

    def create_zst(self, name: str, contents: str) -> str:
        """
        Creates a compressed zst file and returns the path to it.
        """
        zst_path = os.path.join(self.path, name)
        if not os.path.exists(self.path):
            raise Exception(f"Directory for the compressed file does not exist: {self.path}")
        if os.path.exists(zst_path):
            raise Exception(f"A file already exists and would be overwritten: {zst_path}")

        # Create the compressed file.
        cctx = zstd.ZstdCompressor()
        compressed_data = cctx.compress(contents.encode("utf-8"))

        print("Writing a compressed file to: ", zst_path)
        with open(zst_path, "wb") as file:
            file.write(compressed_data)

        return zst_path
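As a usage sketch only (the directory name "test_example" is hypothetical, not from the commit), a test can create a DataDir, write a compressed fixture corpus with create_zst, and read it back with the zstandard package added in pyproject.toml above:

import zstandard as zstd
from fixtures import DataDir, en_sample

# Creates a fresh data/tests_data/test_example directory.
data_dir = DataDir("test_example")
zst_path = data_dir.create_zst("corpus.en.zst", en_sample)

# Decompress and verify the round trip.
with open(zst_path, "rb") as f:
    assert zstd.ZstdDecompressor().decompress(f.read()).decode("utf-8") == en_sample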
26 changes: 26 additions & 0 deletions tests/fixtures/spm_train
@@ -0,0 +1,26 @@
#!/usr/bin/env python3
import os
import sys

"""
spm_train test fixture
Do not rely on spm_train in tests, instead just capture what arguments are passed to it
and save that as a vocab.model.
"""
arguments = sys.argv[1:]

# Default to None so a missing argument raises the explicit error below
# instead of an unhandled StopIteration.
model_prefix_arg = next((arg for arg in arguments if "--model_prefix=" in arg), None)
model_prefix = model_prefix_arg.split("=")[1] if model_prefix_arg else None

if not model_prefix:
    raise Exception("Could not find the model prefix argument")

vocab_path = model_prefix + ".model"
data_directory = os.path.dirname(vocab_path)

if not os.path.exists(data_directory):
    raise Exception("The data directory could not be found.")

with open(vocab_path, "w") as file:
    file.write("\n".join(arguments))
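A sketch of how this fixture behaves when invoked (the /tmp/vocab prefix here is hypothetical; the tests would put this script in place of Marian's real spm_train binary):

import subprocess
import sys

# The fixture writes its arguments to <model_prefix>.model instead of training.
subprocess.run(
    [sys.executable, "tests/fixtures/spm_train", "--model_prefix=/tmp/vocab", "--vocab_size=32000"],
    check=True,
)
with open("/tmp/vocab.model") as f:
    print(f.read())  # The captured arguments, one per line.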
22 changes: 9 additions & 13 deletions tests/test_data_importer.py
@@ -1,8 +1,8 @@
 import gzip
 import os
-import shutil

 import pytest
+from fixtures import DataDir

 SRC = "ru"
 TRG = "en"
@@ -15,7 +15,6 @@

 from pipeline.data.dataset_importer import run_import

-OUTPUT_DIR = "data/tests_data"
 # the augmentation is probabilistic, here is a range for 0.1 probability
 AUG_MAX_RATE = 0.3
 AUG_MIN_RATE = 0.01
@@ -43,11 +42,8 @@ def get_aug_rate(file, check_func):


 @pytest.fixture(scope="function")
-def output_dir():
-    if os.path.exists(OUTPUT_DIR):
-        shutil.rmtree(OUTPUT_DIR)
-    os.makedirs(OUTPUT_DIR)
-    return os.path.abspath(OUTPUT_DIR)
+def data_dir():
+    return DataDir("test_data_importer")


 @pytest.mark.parametrize(
@@ -59,8 +55,8 @@ def output_dir():
         "sacrebleu_wmt19",
     ],
 )
-def test_basic_corpus_import(dataset, output_dir):
-    prefix = os.path.join(output_dir, dataset)
+def test_basic_corpus_import(dataset, data_dir):
+    prefix = data_dir.join(dataset)
     output_src = f"{prefix}.{SRC}.{ARTIFACT_EXT}"
     output_trg = f"{prefix}.{TRG}.{ARTIFACT_EXT}"

@@ -81,9 +77,9 @@ def test_basic_corpus_import(dataset, output_dir):
         ("sacrebleu_aug-title-strict_wmt19", is_title_case, 1.0, 1.0),
     ],
 )
-def test_specific_augmentation(params, output_dir):
+def test_specific_augmentation(params, data_dir):
     dataset, check_func, min_rate, max_rate = params
-    prefix = os.path.join(output_dir, dataset)
+    prefix = data_dir.join(dataset)
     output_src = f"{prefix}.{SRC}.{ARTIFACT_EXT}"
     output_trg = f"{prefix}.{TRG}.{ARTIFACT_EXT}"

@@ -98,9 +94,9 @@ def test_specific_augmentation(params, output_dir):
     assert rate <= max_rate


-def test_augmentation_mix(output_dir):
+def test_augmentation_mix(data_dir):
     dataset = "sacrebleu_aug-mix_wmt19"
-    prefix = os.path.join(output_dir, dataset)
+    prefix = data_dir.join(dataset)
     output_src = f"{prefix}.{SRC}.{ARTIFACT_EXT}"
     output_trg = f"{prefix}.{TRG}.{ARTIFACT_EXT}"

48 changes: 23 additions & 25 deletions tests/test_split_collect.py
@@ -7,19 +7,16 @@

 import pytest
 import sh
+from fixtures import DataDir

 from pipeline.translate.splitter import main as split_file

 COMPRESSION_CMD = "zstdmt"

-OUTPUT_DIR = "data/tests_data"
-

 @pytest.fixture(scope="function")
-def clean():
-    if os.path.exists(OUTPUT_DIR):
-        shutil.rmtree(OUTPUT_DIR)
-    os.makedirs(OUTPUT_DIR, exist_ok=True)
+def data_dir():
+    return DataDir("test_split_collect")


 def generate_dataset(length, path):
@@ -54,57 +51,58 @@ def read_file(path):
     return f.read()


-def test_split_collect_mono(clean):
+def test_split_collect_mono(data_dir):
     os.environ["COMPRESSION_CMD"] = COMPRESSION_CMD
     length = 1234
-    path = os.path.join(OUTPUT_DIR, "mono.in")
-    output = os.path.join(OUTPUT_DIR, "mono.output")
+    path = data_dir.join("mono.in")
+    output = data_dir.join("mono.output")
     output_compressed = f"{output}.zst"
     generate_dataset(length, path)

     split_file(
         [
-            f"--output_dir={OUTPUT_DIR}",
+            f"--output_dir={data_dir.path}",
             "--num_parts=10",
             f"--compression_cmd={COMPRESSION_CMD}",
             f"{path}.zst",
         ]
     )

     # file.1.zst, file.2.zst ... file.10.zst
-    expected_files = set([f"{OUTPUT_DIR}/file.{i}.zst" for i in range(1, 11)])
-    assert set(glob.glob(f"{OUTPUT_DIR}/file.*.zst")) == expected_files
+    expected_files = set([data_dir.join(f"file.{i}.zst") for i in range(1, 11)])
+    assert set(glob.glob(data_dir.join("file.*.zst"))) == expected_files

-    imitate_translate(OUTPUT_DIR, suffix=".out")
+    imitate_translate(data_dir.path, suffix=".out")
     subprocess.run(
-        ["pipeline/translate/collect.sh", OUTPUT_DIR, output_compressed, f"{path}.zst"], check=True
+        ["pipeline/translate/collect.sh", data_dir.path, output_compressed, f"{path}.zst"],
+        check=True,
     )

     decompress(output_compressed)
     assert read_file(path) == read_file(output)


-def test_split_collect_corpus(clean):
+def test_split_collect_corpus(data_dir):
     os.environ["COMPRESSION_CMD"] = COMPRESSION_CMD
     length = 1234
-    path_src = os.path.join(OUTPUT_DIR, "corpus.src.in")
-    path_trg = os.path.join(OUTPUT_DIR, "corpus.trg.in")
-    output = os.path.join(OUTPUT_DIR, "corpus.src.output")
+    path_src = data_dir.join("corpus.src.in")
+    path_trg = data_dir.join("corpus.trg.in")
+    output = data_dir.join("corpus.src.output")
     output_compressed = f"{output}.zst"
     generate_dataset(length, path_src)
     generate_dataset(length, path_trg)

     split_file(
         [
-            f"--output_dir={OUTPUT_DIR}",
+            f"--output_dir={data_dir.path}",
             "--num_parts=10",
             f"--compression_cmd={COMPRESSION_CMD}",
             f"{path_src}.zst",
         ]
     )
     split_file(
         [
-            f"--output_dir={OUTPUT_DIR}",
+            f"--output_dir={data_dir.path}",
             "--num_parts=10",
             f"--compression_cmd={COMPRESSION_CMD}",
             "--output_suffix=.ref",
@@ -114,14 +112,14 @@ def test_split_collect_corpus(clean):

     # file.1.zst, file.2.zst ... file.10.zst
     # file.1.ref.zst, file.2.ref.zst ... file.10.ref.zst
-    expected_files = set([f"{OUTPUT_DIR}/file.{i}.zst" for i in range(1, 11)]) | set(
-        [f"{OUTPUT_DIR}/file.{i}.ref.zst" for i in range(1, 11)]
+    expected_files = set([data_dir.join(f"file.{i}.zst") for i in range(1, 11)]) | set(
+        [data_dir.join(f"file.{i}.ref.zst") for i in range(1, 11)]
     )
-    assert set(glob.glob(f"{OUTPUT_DIR}/file.*.zst")) == expected_files
+    assert set(glob.glob(data_dir.join("file.*.zst"))) == expected_files

-    imitate_translate(OUTPUT_DIR, suffix=".nbest.out")
+    imitate_translate(data_dir.path, suffix=".nbest.out")
     subprocess.run(
-        ["pipeline/translate/collect.sh", OUTPUT_DIR, output_compressed, f"{path_src}.zst"],
+        ["pipeline/translate/collect.sh", data_dir.path, output_compressed, f"{path_src}.zst"],
         check=True,
     )