diff --git a/pipeline/alignments/align.py b/pipeline/alignments/align.py index ea631915f..bfc5c2ac1 100755 --- a/pipeline/alignments/align.py +++ b/pipeline/alignments/align.py @@ -33,6 +33,7 @@ from tqdm import tqdm from pipeline.alignments.tokenizer import tokenize_moses +from pipeline.common.datasets import decompress from pipeline.common.logging import get_logger logger = get_logger("alignments") @@ -60,8 +61,8 @@ def run( tmp_dir = os.path.join(os.path.dirname(output_path), "tmp") os.makedirs(tmp_dir, exist_ok=True) - corpus_src = decompress(corpus_src) - corpus_trg = decompress(corpus_trg) + corpus_src = maybe_decompress(corpus_src) + corpus_trg = maybe_decompress(corpus_trg) if tokenization == Tokenization.moses: tokenized_src = ( @@ -121,11 +122,9 @@ def run( shutil.move(output_aln, output_path) -def decompress(file_path: str): +def maybe_decompress(file_path: str): if file_path.endswith(".zst"): - logger.info(f"Decompressing file {file_path}") - subprocess.check_call(["zstdmt", "-d", "-f", "--rm", file_path]) - return file_path[:-4] + return str(decompress(file_path, remove=True, logger=logger)) return file_path diff --git a/pipeline/common/command_runner.py b/pipeline/common/command_runner.py index b5e7a13de..38acd46b0 100644 --- a/pipeline/common/command_runner.py +++ b/pipeline/common/command_runner.py @@ -1,3 +1,4 @@ +import os import re from shlex import join import shlex @@ -62,7 +63,7 @@ def run_command_pipeline( it will log the constructed pipeline commands. Defaults to None. Example: - python_scripts = run_pipeline( + python_scripts = run_command_pipeline( [ ["ls", "-l"], ["grep", ".py"], @@ -94,3 +95,38 @@ def run_command_pipeline( return subprocess.check_output(command_string, shell=True).decode("utf-8") subprocess.check_call(command_string, shell=True) + + +def run_command( + command: list[str], capture=False, shell=False, logger=None, env=None +) -> str | None: + """ + Runs a command and outputs a nice representation of the command to a logger, if supplied. + + Args: + command: The command arguments provided to subprocess.check_call + capture: If True, captures and returns the output of the final command in the + pipeline. If False, output is printed to stdout. + logger: A logger instance used for logging the command execution. If provided, + it will log the pipeline commands. + env: The environment object. + + Example: + directory_listing = run_command( + ["ls", "-l"], + capture=True + ) + """ + # Expand any environment variables. + command = [os.path.expandvars(part) for part in command] + + if logger: + # Log out a nice representation of this command. 
+ logger.info("Running:") + for line in _get_indented_command_string(command).split("\n"): + logger.info(line) + + if capture: + return subprocess.check_output(command).decode("utf-8") + + subprocess.check_call(command, env=env) diff --git a/pipeline/common/datasets.py b/pipeline/common/datasets.py index 4f4758935..09221c2fb 100644 --- a/pipeline/common/datasets.py +++ b/pipeline/common/datasets.py @@ -1,13 +1,15 @@ from collections.abc import Iterable import hashlib import json +from logging import Logger import os +import subprocess import tempfile from dataclasses import dataclass from io import TextIOWrapper from pathlib import Path from random import Random -from typing import Callable, Iterator, Optional, Set, Union +from typing import Callable, Iterator, Literal, Optional, Set, Union from urllib.parse import urlparse import unicodedata @@ -444,3 +446,110 @@ def _hash_string(string: str) -> int: """ cleaned_line = unicodedata.normalize("NFC", string.strip()) return hash(cleaned_line) + + +def decompress( + source: Union[str, Path], + destination: Optional[Union[Path, str]] = None, + remove: bool = False, + logger: Optional[Logger] = None, +) -> Path: + """ + Decompresses a file using the appropriate command based on its file extension. + + Args: + file_path: The path to the file to be decompressed + remove: If set to `True`, the original compressed file will be removed after decompression. + destination: Be default the file will be decompressed next to the original. This arguments + allows for overriding the destination. + logger: Log information about the decompression + """ + if isinstance(source, str): + source = Path(source) + if not destination: + destination = source.parent / source.stem + + if logger: + logger.info(f"[decompress] From: {source}") + logger.info(f"[decompress] To: {destination}") + + if source.suffix == ".zst": + command = ["zstdmt", "--decompress", "--force", "-o", destination, source] + if remove: + command.append("--rm") + + subprocess.check_call(command) + elif source.suffix == ".gz": + command = ["gzip", "-c", "-d", source] + with open(destination, "wb") as out_file: + subprocess.check_call(command, stdout=out_file) + if remove: + source.unlink() + else: + raise Exception(f"Unknown file type to decompress: {source}") + + if remove: + logger.info(f"[decompress] Removed: {source}") + + return destination + + +def compress( + source: Union[str, Path], + destination: Optional[Union[Path, str]] = None, + remove: bool = False, + compression_type: Union[Literal["zst"], Literal["gz"]] = None, + logger: Optional[Logger] = None, +) -> Path: + """ + Compress a file using the appropriate command based on its file extension. + + Args: + source: The path to the file to be compressed + destination: Be default the file will be compressed next to the original. This arguments + allows for overriding the destination. + remove: If set to `True`, the original decompressed file will be removed. + type: The type defaults to "zst", and is implied by the destination, however it can + be explicitly set. + logger: Log information about the compression + """ + if isinstance(source, str): + source = Path(source) + if isinstance(destination, str): + destination = Path(destination) + + # Ensure the compression type is valid and present + if compression_type and destination: + assert f".{type}" == destination.suffix, "The compression type and destination must match." 
+ + if not compression_type: + if destination: + compression_type = destination.suffix[1:] + else: + compression_type = "zst" + + # Set default destination if not provided + if not destination: + destination = source.with_suffix(f"{source.suffix}.{compression_type}") + + if logger: + logger.info(f"Compressing: {source}") + logger.info(f"Destination: {destination}") + + if compression_type == "zst": + command = ["zstdmt", "--compress", "--force", "--quiet", source, "-o", destination] + if remove: + command.append("--rm") + subprocess.check_call(command) + elif compression_type == "gz": + with open(destination, "wb") as out_file: + subprocess.check_call(["gzip", "-c", "--force", source], stdout=out_file) + if remove: + source.unlink() + else: + raise ValueError(f"Unsupported compression type: {compression_type}") + + if remove: + logger.info(f"Removed {source}") + + return destination diff --git a/pipeline/common/downloads.py b/pipeline/common/downloads.py index 94c39064b..79f303e9b 100644 --- a/pipeline/common/downloads.py +++ b/pipeline/common/downloads.py @@ -520,6 +520,19 @@ def count_lines(path: Path | str) -> int: return sum(1 for _ in lines) +def is_file_empty(path: Path | str) -> bool: + """ + Attempts to read a line to determine if a file is empty or not. Works on local or remote files + as well as compressed or uncompressed files. + """ + with read_lines(path) as lines: + try: + next(lines) + return False + except StopIteration: + return True + + def get_file_size(location: Union[Path, str]) -> int: """Get the size of a file, whether it is remote or local.""" if str(location).startswith("http://") or str(location).startswith("https://"): diff --git a/pipeline/common/logging.py b/pipeline/common/logging.py index f07bf0025..026ef0685 100644 --- a/pipeline/common/logging.py +++ b/pipeline/common/logging.py @@ -1,8 +1,14 @@ import logging from pathlib import Path +import subprocess +import threading +import time logging.basicConfig(level=logging.INFO, format="[%(name)s] %(message)s") +STOP_BYTE_COUNT_LOGGER = False +STOP_GPU_LOGGER = False + def get_logger(name: str): """ @@ -21,3 +27,101 @@ def get_logger(name: str): logger = logging.getLogger(Path(name).stem) logger.setLevel(logging.INFO) return logger + + +def _log_gpu_stats(logger: logging.Logger, interval_seconds: int): + # Only load gpustat when it's needed. + import gpustat + + global STOP_GPU_LOGGER + while True: + time.sleep(interval_seconds) + if STOP_GPU_LOGGER: + STOP_GPU_LOGGER = False + return + try: + logger.info("[gpu] Current GPU stats:") + gpustat.print_gpustat() + except subprocess.CalledProcessError as e: + logger.error(f"Failed to retrieve GPU stats: {e}") + + +def stop_gpu_logging(): + global STOP_GPU_LOGGER + STOP_GPU_LOGGER = True + + +def start_gpu_logging(logger: logging.Logger, interval_seconds: int): + """Logs GPU stats on an interval using gpustat in a background thread.""" + assert not STOP_GPU_LOGGER, "A gpu logger should not already be running" + + thread = threading.Thread( + target=_log_gpu_stats, + # Set as a daemon thread so it automatically is closed on shutdown. 
+ daemon=True, + args=(logger, interval_seconds), + ) + thread.start() + + +def _log_byte_rate(logger: logging.Logger, interval_seconds: int, file_path: Path): + global STOP_BYTE_COUNT_LOGGER + previous_byte_count = 0 + previous_time = time.time() + is_zst = file_path.suffix == ".zst" + + while True: + time.sleep(interval_seconds) + if STOP_BYTE_COUNT_LOGGER: + STOP_BYTE_COUNT_LOGGER = False + return + + try: + if is_zst: + # This takes ~1 second to run on 5 million sentences. + current_byte_count = 0 + cmd = ["zstd", "-dc", str(file_path)] + with subprocess.Popen( + cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) as process: + for chunk in iter(lambda: process.stdout.read(8192), b""): + current_byte_count += len(chunk) + else: + # This is pretty much instantaneous. + result = subprocess.run( + ["wc", "-c", str(file_path)], capture_output=True, text=True, check=True + ) + current_byte_count = int(result.stdout.split()[0]) + + bytes_added = current_byte_count - previous_byte_count + + current_secs = time.time() + elapsed_secs = current_secs - previous_time + byte_rate = bytes_added / elapsed_secs if bytes_added > 0 else 0 + + logger.info(f"[bytes] Added: {bytes_added:,}") + logger.info(f"[bytes] Total: {current_byte_count:,}") + logger.info(f"[bytes] Rate: {byte_rate:,.2f} bytes/second") + + previous_byte_count = current_byte_count + previous_time = time.time() + except Exception as e: + logger.error(f"Failed to monitor byte count: {e}") + + +def stop_byte_count_logger(): + global STOP_BYTE_COUNT_LOGGER + STOP_BYTE_COUNT_LOGGER = True + + +def start_byte_count_logger(logger: logging.Logger, interval_seconds: int, file_path: Path): + """ + Monitors the rate of bytes being added to a file, logging the number of bytes + added per second over the interval. + """ + + assert not STOP_BYTE_COUNT_LOGGER, "A line count logger should not already be running" + thread = threading.Thread( + target=_log_byte_rate, args=(logger, interval_seconds, file_path), daemon=True + ) + thread.start() diff --git a/pipeline/common/marian.py b/pipeline/common/marian.py new file mode 100644 index 000000000..69febc561 --- /dev/null +++ b/pipeline/common/marian.py @@ -0,0 +1,53 @@ +""" +Common utilities related to working with Marian. +""" + +from pathlib import Path +from typing import Union + +import yaml + + +def get_combined_config(config_path: Path, extra_marian_args: list[str]) -> dict[str, any]: + """ + Frequently we combine a Marian yml config with extra marian args when running + training. To get the final value, add both here. + """ + return { + **yaml.safe_load(config_path.open()), + **marian_args_to_dict(extra_marian_args), + } + + +def marian_args_to_dict(extra_marian_args: list[str]) -> dict[str, Union[str, bool, list[str]]]: + """ + Converts marian args, to the dict format. This will combine a decoder.yml + and extra marian args. + + e.g. 
`--precision float16` becomes {"precision": "float16"} + """ + decoder_config = {} + if extra_marian_args and extra_marian_args[0] == "--": + extra_marian_args = extra_marian_args[1:] + + previous_key = None + for arg in extra_marian_args: + if arg.startswith("--"): + previous_key = arg[2:] + decoder_config[previous_key] = True + continue + + if not previous_key: + raise Exception( + f"Expected to have a previous key when converting marian args to a dict: {extra_marian_args}" + ) + + prev_value = decoder_config.get(previous_key) + if prev_value is True: + decoder_config[previous_key] = arg + elif isinstance(prev_value, list): + prev_value.append(arg) + else: + decoder_config[previous_key] = [prev_value, arg] + + return decoder_config diff --git a/pipeline/eval/eval.py b/pipeline/eval/eval.py index 1435060cd..8bf3dcf15 100755 --- a/pipeline/eval/eval.py +++ b/pipeline/eval/eval.py @@ -65,7 +65,7 @@ list_existing_group_logs_metrics, ) - WANDB_AVAILABLE = True + WANDB_AVAILABLE = "TASKCLUSTER_PROXY_URL" in os.environ except ImportError as e: print(f"Failed to import tracking module: {e}") WANDB_AVAILABLE = False diff --git a/pipeline/train/requirements/train.in b/pipeline/train/requirements/train.in index 833b10020..5f461a0c2 100644 --- a/pipeline/train/requirements/train.in +++ b/pipeline/train/requirements/train.in @@ -1,3 +1,3 @@ # use the latest main, switch to PyPi when released git+https://github.com/hplt-project/OpusTrainer.git@c966d7b353d6b3c6a09d9573f1ab6ba3221c1d21 - +gpustat==1.1.1 diff --git a/pipeline/train/requirements/train.txt b/pipeline/train/requirements/train.txt index d7d301010..34ad968ad 100644 --- a/pipeline/train/requirements/train.txt +++ b/pipeline/train/requirements/train.txt @@ -1,15 +1,23 @@ # -# This file is autogenerated by pip-compile with Python 3.11 +# This file is autogenerated by pip-compile with Python 3.10 # by the following command: # -# pip-compile pipeline/train/requirements/train.in +# pip-compile --allow-unsafe pipeline/train/requirements/train.in # +blessed==1.20.0 + # via gpustat click==8.1.7 # via sacremoses +gpustat==1.1.1 + # via -r pipeline/train/requirements/train.in joblib==1.3.2 # via sacremoses +nvidia-ml-py==12.560.30 + # via gpustat opustrainer @ git+https://github.com/hplt-project/OpusTrainer.git@c966d7b353d6b3c6a09d9573f1ab6ba3221c1d21 # via -r pipeline/train/requirements/train.in +psutil==6.1.0 + # via gpustat pyyaml==6.0.1 # via opustrainer regex==2023.10.3 @@ -18,7 +26,11 @@ sacremoses==0.1.1 # via opustrainer sentencepiece==0.1.99 # via opustrainer +six==1.17.0 + # via blessed tqdm==4.66.1 # via sacremoses typo==0.1.5 # via opustrainer +wcwidth==0.2.13 + # via blessed diff --git a/pipeline/translate/requirements/translate-ctranslate2.in b/pipeline/translate/requirements/translate-ctranslate2.in new file mode 100644 index 000000000..0d266b389 --- /dev/null +++ b/pipeline/translate/requirements/translate-ctranslate2.in @@ -0,0 +1,3 @@ +ctranslate2==4.3.1 +sentencepiece==0.2.0 +gpustat==1.1.1 diff --git a/pipeline/translate/requirements/translate-ctranslate2.txt b/pipeline/translate/requirements/translate-ctranslate2.txt new file mode 100644 index 000000000..49d6cb81c --- /dev/null +++ b/pipeline/translate/requirements/translate-ctranslate2.txt @@ -0,0 +1,232 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --allow-unsafe --generate-hashes pipeline/translate/requirements/translate-ctranslate2.in +# +blessed==1.20.0 \ + 
--hash=sha256:0c542922586a265e699188e52d5f5ac5ec0dd517e5a1041d90d2bbf23f906058 \ + --hash=sha256:2cdd67f8746e048f00df47a2880f4d6acbcdb399031b604e34ba8f71d5787680 + # via gpustat +ctranslate2==4.3.1 \ + --hash=sha256:08626f115d5a39c56a666680735d6eebfc4d8a215288896d4d8afc14cfcdcffe \ + --hash=sha256:30c02fcd5a7be93bf42a8adf81a9ac4f394e23bd639192907b2e11feae589971 \ + --hash=sha256:343b24fe3d8a5b6a7c8082332415767bef7ceaf15bb43d0cec7e83665108c51e \ + --hash=sha256:45c5b352783bd3806f0c9f5dcbfa49d89c0dde71cb7d1b1c527c525e85af3ded \ + --hash=sha256:49a0d9136d577b667c1bb450267248d9cf205b5eb28b89b3f70c296ec5285da8 \ + --hash=sha256:4bca2ce519c497bc2f79e567093609d7bdfaff3313220e0d831797288803f3aa \ + --hash=sha256:60bc176dd2e0ee6ddd33682401440f7626d115fed4f1e5e6816d9f7f213d1a62 \ + --hash=sha256:68301fbc5fb7daa609eb12ca6c2ed8aa29852c20f962532317762d1889e751d9 \ + --hash=sha256:6f49834b63848f17dfdc1b2b8c632c31932ad69e130ce0f7b1e2505aa3923e6c \ + --hash=sha256:7d394367fe472b6540489e3b081fc7e17cea2264075b074fb28eca30ff63463f \ + --hash=sha256:7d95ecb440e4985cad4623a1fe7bb91406bab4aa55b00aa89a0c16eb5939d640 \ + --hash=sha256:8c202011fa2ebb8129ba98a65df48df075f0ef53f905f2b13b8cd00f31c7ccff \ + --hash=sha256:9f1fd426d9019198d0fd8f37a18bf9c486241f711d597686956c58cd7676d564 \ + --hash=sha256:a06043910a7dee91ea03634be2cff2e1338a9f87bb51e062c03bae69e2c826b6 \ + --hash=sha256:a49dc5d339e2f4ed016553db0d0e6cbd369742697c87c6cc0cc15a47c7c72d00 \ + --hash=sha256:d8679354547260db999c2bcc6f11a31dad828c3d896d6120045bd0333940732f \ + --hash=sha256:de05e33790d72492a76101a0357c3d87d97ad53af84417c78f45e85df76d39e8 \ + --hash=sha256:def98f6f8900470b2cec9408e5b0402af75f40f771391ebacd2b60666b8d75b9 \ + --hash=sha256:e40d43c5f7d25f40d31cca0541cf21c2846f89509b99189d340fdee595391196 \ + --hash=sha256:e962c9dc3ddfacf60f2467bea5f91f75239c3d9c17656e4b0c569d956d662b99 \ + --hash=sha256:ef812a4129e877f64f8ca2438b6247060af0f053a56b438dbfa81dae9ca12675 \ + --hash=sha256:f352bcb802ab9ff1b94a25b4915c4f9f97cdd230993cf45ea290592d8997c2e2 \ + --hash=sha256:f63f779f1d4518acdc694b1938887d4f28613ac2dfe507ccc2c0d56dd8c95b40 \ + --hash=sha256:fcf649d976070ddd33cdda00a7a60fde6f1fbe27d65d2c6141dd95153f965f01 \ + --hash=sha256:febf7cf0fb641c76035cdece58e97d27f4e8950a5e32fc480f9afa1bcbbb856c + # via -r pipeline/translate/requirements/translate-ctranslate2.in +gpustat==1.1.1 \ + --hash=sha256:c18d3ed5518fc16300c42d694debc70aebb3be55cae91f1db64d63b5fa8af9d8 + # via -r pipeline/translate/requirements/translate-ctranslate2.in +numpy==2.0.1 \ + --hash=sha256:08458fbf403bff5e2b45f08eda195d4b0c9b35682311da5a5a0a0925b11b9bd8 \ + --hash=sha256:0fbb536eac80e27a2793ffd787895242b7f18ef792563d742c2d673bfcb75134 \ + --hash=sha256:12f5d865d60fb9734e60a60f1d5afa6d962d8d4467c120a1c0cda6eb2964437d \ + --hash=sha256:15eb4eca47d36ec3f78cde0a3a2ee24cf05ca7396ef808dda2c0ddad7c2bde67 \ + --hash=sha256:173a00b9995f73b79eb0191129f2455f1e34c203f559dd118636858cc452a1bf \ + --hash=sha256:1b902ce0e0a5bb7704556a217c4f63a7974f8f43e090aff03fcf262e0b135e02 \ + --hash=sha256:1f682ea61a88479d9498bf2091fdcd722b090724b08b31d63e022adc063bad59 \ + --hash=sha256:1f87fec1f9bc1efd23f4227becff04bd0e979e23ca50cc92ec88b38489db3b55 \ + --hash=sha256:24a0e1befbfa14615b49ba9659d3d8818a0f4d8a1c5822af8696706fbda7310c \ + --hash=sha256:2c3a346ae20cfd80b6cfd3e60dc179963ef2ea58da5ec074fd3d9e7a1e7ba97f \ + --hash=sha256:36d3a9405fd7c511804dc56fc32974fa5533bdeb3cd1604d6b8ff1d292b819c4 \ + --hash=sha256:3fdabe3e2a52bc4eff8dc7a5044342f8bd9f11ef0934fcd3289a788c0eb10018 \ + 
--hash=sha256:4127d4303b9ac9f94ca0441138acead39928938660ca58329fe156f84b9f3015 \ + --hash=sha256:4658c398d65d1b25e1760de3157011a80375da861709abd7cef3bad65d6543f9 \ + --hash=sha256:485b87235796410c3519a699cfe1faab097e509e90ebb05dcd098db2ae87e7b3 \ + --hash=sha256:529af13c5f4b7a932fb0e1911d3a75da204eff023ee5e0e79c1751564221a5c8 \ + --hash=sha256:5a3d94942c331dd4e0e1147f7a8699a4aa47dffc11bf8a1523c12af8b2e91bbe \ + --hash=sha256:5daab361be6ddeb299a918a7c0864fa8618af66019138263247af405018b04e1 \ + --hash=sha256:61728fba1e464f789b11deb78a57805c70b2ed02343560456190d0501ba37b0f \ + --hash=sha256:6790654cb13eab303d8402354fabd47472b24635700f631f041bd0b65e37298a \ + --hash=sha256:69ff563d43c69b1baba77af455dd0a839df8d25e8590e79c90fcbe1499ebde42 \ + --hash=sha256:6bf4e6f4a2a2e26655717a1983ef6324f2664d7011f6ef7482e8c0b3d51e82ac \ + --hash=sha256:6e4eeb6eb2fced786e32e6d8df9e755ce5be920d17f7ce00bc38fcde8ccdbf9e \ + --hash=sha256:72dc22e9ec8f6eaa206deb1b1355eb2e253899d7347f5e2fae5f0af613741d06 \ + --hash=sha256:75b4e316c5902d8163ef9d423b1c3f2f6252226d1aa5cd8a0a03a7d01ffc6268 \ + --hash=sha256:7b9853803278db3bdcc6cd5beca37815b133e9e77ff3d4733c247414e78eb8d1 \ + --hash=sha256:7d6fddc5fe258d3328cd8e3d7d3e02234c5d70e01ebe377a6ab92adb14039cb4 \ + --hash=sha256:81b0893a39bc5b865b8bf89e9ad7807e16717f19868e9d234bdaf9b1f1393868 \ + --hash=sha256:8efc84f01c1cd7e34b3fb310183e72fcdf55293ee736d679b6d35b35d80bba26 \ + --hash=sha256:8fae4ebbf95a179c1156fab0b142b74e4ba4204c87bde8d3d8b6f9c34c5825ef \ + --hash=sha256:99d0d92a5e3613c33a5f01db206a33f8fdf3d71f2912b0de1739894668b7a93b \ + --hash=sha256:9adbd9bb520c866e1bfd7e10e1880a1f7749f1f6e5017686a5fbb9b72cf69f82 \ + --hash=sha256:a1e01dcaab205fbece13c1410253a9eea1b1c9b61d237b6fa59bcc46e8e89343 \ + --hash=sha256:a8fc2de81ad835d999113ddf87d1ea2b0f4704cbd947c948d2f5513deafe5a7b \ + --hash=sha256:b83e16a5511d1b1f8a88cbabb1a6f6a499f82c062a4251892d9ad5d609863fb7 \ + --hash=sha256:bb2124fdc6e62baae159ebcfa368708867eb56806804d005860b6007388df171 \ + --hash=sha256:bfc085b28d62ff4009364e7ca34b80a9a080cbd97c2c0630bb5f7f770dae9414 \ + --hash=sha256:cbab9fc9c391700e3e1287666dfd82d8666d10e69a6c4a09ab97574c0b7ee0a7 \ + --hash=sha256:e5eeca8067ad04bc8a2a8731183d51d7cbaac66d86085d5f4766ee6bf19c7f87 \ + --hash=sha256:e9e81fa9017eaa416c056e5d9e71be93d05e2c3c2ab308d23307a8bc4443c368 \ + --hash=sha256:ea2326a4dca88e4a274ba3a4405eb6c6467d3ffbd8c7d38632502eaae3820587 \ + --hash=sha256:eacf3291e263d5a67d8c1a581a8ebbcfd6447204ef58828caf69a5e3e8c75990 \ + --hash=sha256:ec87f5f8aca726117a1c9b7083e7656a9d0d606eec7299cc067bb83d26f16e0c \ + --hash=sha256:f1659887361a7151f89e79b276ed8dff3d75877df906328f14d8bb40bb4f5101 \ + --hash=sha256:f9cf5ea551aec449206954b075db819f52adc1638d46a6738253a712d553c7b4 + # via ctranslate2 +nvidia-ml-py==12.560.30 \ + --hash=sha256:f0254dc7400647680a072ee02509bfd46102b60bdfeca321576d4d4817e7fe97 \ + --hash=sha256:fea371c94d63e38a611c17bbb85fe400e9c8ddb9e8684a9cd0e47786a4bc3c73 + # via gpustat +psutil==6.1.0 \ + --hash=sha256:000d1d1ebd634b4efb383f4034437384e44a6d455260aaee2eca1e9c1b55f047 \ + --hash=sha256:045f00a43c737f960d273a83973b2511430d61f283a44c96bf13a6e829ba8fdc \ + --hash=sha256:0895b8414afafc526712c498bd9de2b063deaac4021a3b3c34566283464aff8e \ + --hash=sha256:1209036fbd0421afde505a4879dee3b2fd7b1e14fee81c0069807adcbbcca747 \ + --hash=sha256:1ad45a1f5d0b608253b11508f80940985d1d0c8f6111b5cb637533a0e6ddc13e \ + --hash=sha256:353815f59a7f64cdaca1c0307ee13558a0512f6db064e92fe833784f08539c7a \ + 
--hash=sha256:498c6979f9c6637ebc3a73b3f87f9eb1ec24e1ce53a7c5173b8508981614a90b \ + --hash=sha256:5cd2bcdc75b452ba2e10f0e8ecc0b57b827dd5d7aaffbc6821b2a9a242823a76 \ + --hash=sha256:6d3fbbc8d23fcdcb500d2c9f94e07b1342df8ed71b948a2649b5cb060a7c94ca \ + --hash=sha256:6e2dcd475ce8b80522e51d923d10c7871e45f20918e027ab682f94f1c6351688 \ + --hash=sha256:9118f27452b70bb1d9ab3198c1f626c2499384935aaf55388211ad982611407e \ + --hash=sha256:9dcbfce5d89f1d1f2546a2090f4fcf87c7f669d1d90aacb7d7582addece9fb38 \ + --hash=sha256:a8506f6119cff7015678e2bce904a4da21025cc70ad283a53b099e7620061d85 \ + --hash=sha256:a8fb3752b491d246034fa4d279ff076501588ce8cbcdbb62c32fd7a377d996be \ + --hash=sha256:c0e0c00aa18ca2d3b2b991643b799a15fc8f0563d2ebb6040f64ce8dc027b942 \ + --hash=sha256:d905186d647b16755a800e7263d43df08b790d709d575105d419f8b6ef65423a \ + --hash=sha256:ff34df86226c0227c52f38b919213157588a678d049688eded74c76c8ba4a5d0 + # via gpustat +pyyaml==6.0.1 \ + --hash=sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5 \ + --hash=sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc \ + --hash=sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df \ + --hash=sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741 \ + --hash=sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206 \ + --hash=sha256:18aeb1bf9a78867dc38b259769503436b7c72f7a1f1f4c93ff9a17de54319b27 \ + --hash=sha256:1d4c7e777c441b20e32f52bd377e0c409713e8bb1386e1099c2415f26e479595 \ + --hash=sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62 \ + --hash=sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98 \ + --hash=sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696 \ + --hash=sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290 \ + --hash=sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9 \ + --hash=sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d \ + --hash=sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6 \ + --hash=sha256:4fb147e7a67ef577a588a0e2c17b6db51dda102c71de36f8549b6816a96e1867 \ + --hash=sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47 \ + --hash=sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486 \ + --hash=sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6 \ + --hash=sha256:596106435fa6ad000c2991a98fa58eeb8656ef2325d7e158344fb33864ed87e3 \ + --hash=sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007 \ + --hash=sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938 \ + --hash=sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0 \ + --hash=sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c \ + --hash=sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735 \ + --hash=sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d \ + --hash=sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28 \ + --hash=sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4 \ + --hash=sha256:9046c58c4395dff28dd494285c82ba00b546adfc7ef001486fbf0324bc174fba \ + --hash=sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8 \ + --hash=sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef \ + --hash=sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5 \ + 
--hash=sha256:afd7e57eddb1a54f0f1a974bc4391af8bcce0b444685d936840f125cf046d5bd \ + --hash=sha256:b1275ad35a5d18c62a7220633c913e1b42d44b46ee12554e5fd39c70a243d6a3 \ + --hash=sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0 \ + --hash=sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515 \ + --hash=sha256:baa90d3f661d43131ca170712d903e6295d1f7a0f595074f151c0aed377c9b9c \ + --hash=sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c \ + --hash=sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924 \ + --hash=sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34 \ + --hash=sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43 \ + --hash=sha256:c8098ddcc2a85b61647b2590f825f3db38891662cfc2fc776415143f599bb859 \ + --hash=sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673 \ + --hash=sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54 \ + --hash=sha256:d858aa552c999bc8a8d57426ed01e40bef403cd8ccdd0fc5f6f04a00414cac2a \ + --hash=sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b \ + --hash=sha256:f003ed9ad21d6a4713f0a9b5a7a0a79e08dd0f221aff4525a2be4c346ee60aab \ + --hash=sha256:f22ac1c3cac4dbc50079e965eba2c1058622631e526bd9afd45fedd49ba781fa \ + --hash=sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c \ + --hash=sha256:fca0e3a251908a499833aa292323f32437106001d436eca0e6e7833256674585 \ + --hash=sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d \ + --hash=sha256:fd66fc5d0da6d9815ba2cebeb4205f95818ff4b79c3ebe268e75d961704af52f + # via ctranslate2 +sentencepiece==0.2.0 \ + --hash=sha256:0461324897735512a32d222e3d886e24ad6a499761952b6bda2a9ee6e4313ea5 \ + --hash=sha256:0993dbc665f4113017892f1b87c3904a44d0640eda510abcacdfb07f74286d36 \ + --hash=sha256:0a91aaa3c769b52440df56fafda683b3aa48e3f2169cf7ee5b8c8454a7f3ae9b \ + --hash=sha256:0f67eae0dbe6f2d7d6ba50a354623d787c99965f068b81e145d53240198021b0 \ + --hash=sha256:1380ce6540a368de2ef6d7e6ba14ba8f3258df650d39ba7d833b79ee68a52040 \ + --hash=sha256:17982700c4f6dbb55fa3594f3d7e5dd1c8659a274af3738e33c987d2a27c9d5c \ + --hash=sha256:188779e1298a1c8b8253c7d3ad729cb0a9891e5cef5e5d07ce4592c54869e227 \ + --hash=sha256:1e0f9c4d0a6b0af59b613175f019916e28ade076e21242fd5be24340d8a2f64a \ + --hash=sha256:20813a68d4c221b1849c62c30e1281ea81687894d894b8d4a0f4677d9311e0f5 \ + --hash=sha256:22e37bac44dd6603388cb598c64ff7a76e41ca774646f21c23aadfbf5a2228ab \ + --hash=sha256:27f90c55a65013cbb8f4d7aab0599bf925cde4adc67ae43a0d323677b5a1c6cb \ + --hash=sha256:298f21cc1366eb60311aedba3169d30f885c363ddbf44214b0a587d2908141ad \ + --hash=sha256:2a3149e3066c2a75e0d68a43eb632d7ae728c7925b517f4c05c40f6f7280ce08 \ + --hash=sha256:2fde4b08cfe237be4484c6c7c2e2c75fb862cfeab6bd5449ce4caeafd97b767a \ + --hash=sha256:3212121805afc58d8b00ab4e7dd1f8f76c203ddb9dc94aa4079618a31cf5da0f \ + --hash=sha256:38aed822fb76435fa1f12185f10465a94ab9e51d5e8a9159e9a540ce926f0ffd \ + --hash=sha256:3f1ec95aa1e5dab11f37ac7eff190493fd87770f7a8b81ebc9dd768d1a3c8704 \ + --hash=sha256:4547683f330289ec4f093027bfeb87f9ef023b2eb6f879fdc4a8187c7e0ffb90 \ + --hash=sha256:4c378492056202d1c48a4979650981635fd97875a00eabb1f00c6a236b013b5e \ + --hash=sha256:536b934e244829e3fe6c4f198652cd82da48adb9aa145c9f00889542726dee3d \ + --hash=sha256:632f3594d3e7ac8b367bca204cb3fd05a01d5b21455acd097ea4c0e30e2f63d7 \ + --hash=sha256:6cf333625234f247ab357b0bd9836638405ea9082e1543d5b8408f014979dcbf \ + 
--hash=sha256:7140d9e5a74a0908493bb4a13f1f16a401297bd755ada4c707e842fbf6f0f5bf \ + --hash=sha256:787e480ca4c1d08c9985a7eb1eae4345c107729c99e9b5a9a00f2575fc7d4b4b \ + --hash=sha256:7a673a72aab81fef5ebe755c6e0cc60087d1f3a4700835d40537183c1703a45f \ + --hash=sha256:7b06b70af54daa4b4904cbb90b4eb6d35c9f3252fdc86c9c32d5afd4d30118d8 \ + --hash=sha256:7c867012c0e8bcd5bdad0f791609101cb5c66acb303ab3270218d6debc68a65e \ + --hash=sha256:7cd6175f7eaec7142d2bf6f6597ce7db4c9ac89acf93fcdb17410c3a8b781eeb \ + --hash=sha256:7fd6071249c74f779c5b27183295b9202f8dedb68034e716784364443879eaa6 \ + --hash=sha256:859ba1acde782609a0910a26a60e16c191a82bf39b5621107552c0cd79fad00f \ + --hash=sha256:89f65f69636b7e9c015b79dff9c9985a9bc7d19ded6f79ef9f1ec920fdd73ecf \ + --hash=sha256:926ef920ae2e8182db31d3f5d081ada57804e3e1d3a8c4ef8b117f9d9fb5a945 \ + --hash=sha256:98501e075f35dd1a1d5a20f65be26839fcb1938752ec61539af008a5aa6f510b \ + --hash=sha256:a1151d6a6dd4b43e552394aed0edfe9292820272f0194bd56c7c1660a0c06c3d \ + --hash=sha256:a52c19171daaf2e697dc6cbe67684e0fa341b1248966f6aebb541de654d15843 \ + --hash=sha256:b293734059ef656dcd65be62ff771507bea8fed0a711b6733976e1ed3add4553 \ + --hash=sha256:b99a308a2e5e569031ab164b74e6fab0b6f37dfb493c32f7816225f4d411a6dd \ + --hash=sha256:bcbbef6cc277f8f18f36959e305f10b1c620442d75addc79c21d7073ae581b50 \ + --hash=sha256:bed9cf85b296fa2b76fc2547b9cbb691a523864cebaee86304c43a7b4cb1b452 \ + --hash=sha256:c581258cf346b327c62c4f1cebd32691826306f6a41d8c4bec43b010dee08e75 \ + --hash=sha256:cdb701eec783d3ec86b7cd4c763adad8eaf6b46db37ee1c36e5e6c44b3fe1b5f \ + --hash=sha256:d0cb51f53b6aae3c36bafe41e86167c71af8370a039f542c43b0cce5ef24a68c \ + --hash=sha256:d1e5ca43013e8935f25457a4fca47e315780172c3e821b4b13a890668911c792 \ + --hash=sha256:d490142b0521ef22bc1085f061d922a2a6666175bb6b42e588ff95c0db6819b2 \ + --hash=sha256:d7b67e724bead13f18db6e1d10b6bbdc454af574d70efbb36f27d90387be1ca3 \ + --hash=sha256:d8cf876516548b5a1d6ac4745d8b554f5c07891d55da557925e5c13ff0b4e6ad \ + --hash=sha256:e3d1d2cc4882e8d6a1adf9d5927d7716f80617fc693385661caff21888972269 \ + --hash=sha256:e58b47f933aca74c6a60a79dcb21d5b9e47416256c795c2d58d55cec27f9551d \ + --hash=sha256:ea5f536e32ea8ec96086ee00d7a4a131ce583a1b18d130711707c10e69601cb2 \ + --hash=sha256:f295105c6bdbb05bd5e1b0cafbd78ff95036f5d3641e7949455a3f4e5e7c3109 \ + --hash=sha256:f4d158189eb2ecffea3a51edf6d25e110b3678ec47f1a40f2d541eafbd8f6250 \ + --hash=sha256:fb89f811e5efd18bab141afc3fea3de141c3f69f3fe9e898f710ae7fe3aab251 \ + --hash=sha256:ff88712338b01031910e8e61e7239aff3ce8869ee31a47df63cb38aadd591bea + # via -r pipeline/translate/requirements/translate-ctranslate2.in +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via blessed +wcwidth==0.2.13 \ + --hash=sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859 \ + --hash=sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5 + # via blessed + +# The following packages are considered to be unsafe in a requirements file: +setuptools==72.1.0 \ + --hash=sha256:5a03e1860cf56bb6ef48ce186b0e557fdba433237481a9a625176c2831be15d1 \ + --hash=sha256:8d243eff56d095e5817f796ede6ae32941278f542e0f941867cc05ae52b162ec + # via ctranslate2 diff --git a/pipeline/translate/translate-nbest.sh b/pipeline/translate/translate-nbest.sh deleted file mode 100755 index c9c38d73c..000000000 --- a/pipeline/translate/translate-nbest.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -## -# 
Translates files generating n-best lists as output -# - -set -x -set -euo pipefail - -test -v GPUS -test -v MARIAN -test -v WORKSPACE - -input=$1 -vocab=$2 -models=( "${@:3}" ) - -output="${input}.nbest" - -cd "$(dirname "${0}")" - -"${MARIAN}/marian-decoder" \ - --config decoder.yml \ - --models "${models[@]}" \ - --vocabs "${vocab}" "${vocab}" \ - --input "${input}" \ - --output "${output}" \ - --log "${input}.log" \ - --n-best \ - --devices ${GPUS} \ - --workspace "${WORKSPACE}" - -# Test that the input and output have the same number of sentences. -test "$(wc -l <"${output}")" -eq "$(( $(wc -l <"${input}") * 8 ))" diff --git a/pipeline/translate/translate.py b/pipeline/translate/translate.py new file mode 100644 index 000000000..e12e56d78 --- /dev/null +++ b/pipeline/translate/translate.py @@ -0,0 +1,256 @@ +""" +Translate a corpus using either Marian or CTranslate2. +""" + +import argparse +from enum import Enum +from glob import glob +import os +from pathlib import Path +import tempfile + +from pipeline.common.command_runner import apply_command_args, run_command +from pipeline.common.datasets import compress, decompress +from pipeline.common.downloads import count_lines, is_file_empty, write_lines +from pipeline.common.logging import ( + get_logger, + start_gpu_logging, + start_byte_count_logger, + stop_gpu_logging, + stop_byte_count_logger, +) +from pipeline.common.marian import get_combined_config +from pipeline.translate.translate_ctranslate2 import translate_with_ctranslate2 + +logger = get_logger(__file__) + +DECODER_CONFIG_PATH = Path(__file__).parent / "decoder.yml" + + +class Decoder(Enum): + marian = "marian" + ctranslate2 = "ctranslate2" + + +class Device(Enum): + cpu = "cpu" + gpu = "gpu" + + +def get_beam_size(extra_marian_args: list[str]): + return get_combined_config(DECODER_CONFIG_PATH, extra_marian_args)["beam-size"] + + +def run_marian( + marian_dir: Path, + models: list[Path], + vocab: str, + input: Path, + output: Path, + gpus: list[str], + workspace: int, + is_nbest: bool, + extra_args: list[str], +): + config = Path(__file__).parent / "decoder.yml" + marian_bin = str(marian_dir / "marian-decoder") + log = input.parent / f"{input.name}.log" + if is_nbest: + extra_args = ["--n-best", *extra_args] + + logger.info("Starting Marian to translate") + + run_command( + [ + marian_bin, + *apply_command_args( + { + "config": config, + "models": models, + "vocabs": [vocab, vocab], + "input": input, + "output": output, + "log": log, + "devices": gpus, + "workspace": workspace, + } + ), + *extra_args, + ], + logger=logger, + env={**os.environ}, + ) + + +def main() -> None: + parser = argparse.ArgumentParser( + description=__doc__, + # Preserves whitespace in the help text. + formatter_class=argparse.RawTextHelpFormatter, + ) + + parser.add_argument( + "--input", type=Path, required=True, help="The path to the text to translate." + ) + parser.add_argument( + "--models_glob", + type=str, + required=True, + nargs="+", + help="A glob pattern to the Marian model(s)", + ) + parser.add_argument( + "--artifacts", type=Path, required=True, help="Output path to the artifacts." + ) + parser.add_argument("--nbest", action="store_true", help="Whether to use the nbest") + parser.add_argument( + "--marian_dir", type=Path, required=True, help="The path the Marian binaries" + ) + parser.add_argument("--vocab", type=Path, help="Path to vocab file") + parser.add_argument( + "--gpus", + type=str, + required=True, + help='The indexes of the GPUs to use on a system, e.g. 
--gpus "0 1 2 3"', + ) + parser.add_argument( + "--workspace", + type=str, + required=True, + help="The amount of Marian memory (in MB) to preallocate", + ) + parser.add_argument( + "--decoder", + type=Decoder, + default=Decoder.marian, + help="Either use the normal marian decoder, or opt for CTranslate2.", + ) + parser.add_argument( + "--device", + type=Device, + default=Device.gpu, + help="Either use the normal marian decoder, or opt for CTranslate2.", + ) + parser.add_argument( + "extra_marian_args", + nargs=argparse.REMAINDER, + help="Additional parameters for the training script", + ) + + args = parser.parse_args() + + # Provide the types for the arguments. + marian_dir: Path = args.marian_dir + input_zst: Path = args.input + artifacts: Path = args.artifacts + models_globs: list[str] = args.models_glob + models: list[Path] = [] + for models_glob in models_globs: + for path in glob(models_glob): + models.append(Path(path)) + postfix = "nbest" if args.nbest else "out" + output_zst = artifacts / f"{input_zst.stem}.{postfix}.zst" + vocab: Path = args.vocab + gpus: list[str] = args.gpus.split(" ") + extra_marian_args: list[str] = args.extra_marian_args + decoder: Decoder = args.decoder + is_nbest: bool = args.nbest + device: Device = args.device + + # Do some light validation of the arguments. + assert input_zst.exists(), f"The input file exists: {input_zst}" + assert vocab.exists(), f"The vocab file exists: {vocab}" + if not artifacts.exists(): + artifacts.mkdir() + for gpu_index in gpus: + assert gpu_index.isdigit(), f'GPUs must be list of numbers: "{gpu_index}"' + assert models, "There must be at least one model" + for model in models: + assert model.exists(), f"The model file exists {model}" + if extra_marian_args and extra_marian_args[0] != "--": + logger.error(" ".join(extra_marian_args)) + raise Exception("Expected the extra marian args to be after a --") + + logger.info(f"Input file: {input_zst}") + logger.info(f"Output file: {output_zst}") + + # Taskcluster can produce empty input files when chunking out translation for + # parallelization. In this case skip translating, and write out an empty file. + if is_file_empty(input_zst): + logger.info(f"The input is empty, create a blank output: {output_zst}") + with write_lines(output_zst) as _outfile: + # Nothing to write, just create the file. + pass + return + + if decoder == Decoder.ctranslate2: + translate_with_ctranslate2( + input_zst=input_zst, + artifacts=artifacts, + extra_marian_args=extra_marian_args, + models_globs=models_globs, + is_nbest=is_nbest, + vocab=[str(vocab)], + device=device.value, + device_index=[int(n) for n in gpus], + ) + return + + # The device flag is for use with CTranslate, but add some assertions here so that + # we can be consistent in usage. + if device == Device.cpu: + assert ( + "--cpu-threads" in extra_marian_args + ), "Marian's cpu should be controlled with the flag --cpu-threads" + else: + assert ( + "--cpu-threads" not in extra_marian_args + ), "Requested a GPU device, but --cpu-threads was provided" + + # Run the training. 
+ with tempfile.TemporaryDirectory() as temp_dir_str: + temp_dir = Path(temp_dir_str) + input_txt = temp_dir / input_zst.stem + output_txt = temp_dir / output_zst.stem + + decompress(input_zst, destination=input_txt, remove=True, logger=logger) + + five_minutes = 300 + if device == Device.gpu: + start_gpu_logging(logger, five_minutes) + start_byte_count_logger(logger, five_minutes, output_txt) + + run_marian( + marian_dir=marian_dir, + models=models, + vocab=vocab, + input=input_txt, + output=output_txt, + gpus=gpus, + workspace=args.workspace, + is_nbest=is_nbest, + # Take off the initial "--" + extra_args=extra_marian_args[1:], + ) + + stop_gpu_logging() + stop_byte_count_logger() + + compress(output_txt, destination=output_zst, remove=True, logger=logger) + + input_count = count_lines(input_txt) + output_count = count_lines(output_zst) + if is_nbest: + beam_size = get_beam_size(extra_marian_args) + expected_output = input_count * beam_size + assert ( + expected_output == output_count + ), f"The nbest output had {beam_size}x as many lines ({expected_output} vs {output_count})" + else: + assert ( + input_count == output_count + ), f"The input ({input_count} and output ({output_count}) had the same number of lines" + + +if __name__ == "__main__": + main() diff --git a/pipeline/translate/translate.sh b/pipeline/translate/translate.sh deleted file mode 100755 index 46ca3b103..000000000 --- a/pipeline/translate/translate.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/bin/bash -## -# Translates input dataset -# - -set -x -set -euo pipefail - -test -v GPUS -test -v MARIAN -test -v WORKSPACE - -input=$1 -vocab=$2 -models=( "${@:3}" ) -output="${input}.out" - -cd "$(dirname "${0}")" - -"${MARIAN}/marian-decoder" \ - --config decoder.yml \ - --models "${models[@]}" \ - --vocabs "${vocab}" "${vocab}" \ - --input "${input}" \ - --output "${output}" \ - --log "${input}.log" \ - --devices ${GPUS} \ - --workspace "${WORKSPACE}" - -# Test that the input and output have the same number of sentences. -test "$(wc -l <"${input}")" == "$(wc -l <"${output}")" diff --git a/pipeline/translate/translate_ctranslate2.py b/pipeline/translate/translate_ctranslate2.py new file mode 100644 index 000000000..9f83af37b --- /dev/null +++ b/pipeline/translate/translate_ctranslate2.py @@ -0,0 +1,198 @@ +""" +Translate a corpus with a teacher model (transformer-based) using CTranslate2. This is useful +to quickly synthesize training data for student distillation as CTranslate2 is ~2 times faster +than Marian. For a more detailed analysis see: https://github.com/mozilla/translations/issues/931 + +https://github.com/OpenNMT/CTranslate2/ +""" + +from typing import Any, TextIO +from enum import Enum +from glob import glob +from pathlib import Path + +import ctranslate2 +import sentencepiece as spm +from ctranslate2.converters.marian import MarianConverter + +from pipeline.common.downloads import read_lines, write_lines +from pipeline.common.logging import ( + get_logger, + start_gpu_logging, + start_byte_count_logger, + stop_gpu_logging, + stop_byte_count_logger, +) +from pipeline.common.marian import get_combined_config + + +def load_vocab(path: str): + logger.info("Loading vocab:") + logger.info(path) + sp = spm.SentencePieceProcessor(path) + + return [sp.id_to_piece(i) for i in range(sp.vocab_size())] + + +# The vocab expects a .yml file. Instead directly load the vocab .spm file via a monkey patch. 
+if not ctranslate2.converters.marian.load_vocab: + raise Exception("Expected to be able to monkey patch the load_vocab function") +ctranslate2.converters.marian.load_vocab = load_vocab + +logger = get_logger(__file__) + + +class Device(Enum): + gpu = "gpu" + cpu = "cpu" + + +class MaxiBatchSort(Enum): + src = "src" + none = "none" + + +def get_model(models_globs: list[str]) -> Path: + models: list[Path] = [] + for models_glob in models_globs: + for path in glob(models_glob): + models.append(Path(path)) + if not models: + raise ValueError(f'No model was found with the glob "{models_glob}"') + if len(models) != 1: + logger.info(f"Found models {models}") + raise ValueError("Ensemble training is not supported in CTranslate2") + return Path(models[0]) + + +class DecoderConfig: + def __init__(self, extra_marian_args: list[str]) -> None: + super().__init__() + # Combine the two configs. + self.config = get_combined_config(Path(__file__).parent / "decoder.yml", extra_marian_args) + + self.mini_batch_words: int = self.get_from_config("mini-batch-words", int) + self.beam_size: int = self.get_from_config("beam-size", int) + self.precision = self.get_from_config("precision", str) + + def get_from_config(self, key: str, type: any): + value = self.config.get(key, None) + if value is None: + raise ValueError(f'"{key}" could not be found in the decoder.yml config') + if isinstance(value, type): + return value + if type == int and isinstance(value, str): + return int(value) + raise ValueError(f'Expected "{key}" to be of a type "{type}" in the decoder.yml config') + + +def write_single_translation( + _index: int, tokenizer_trg: spm.SentencePieceProcessor, result: Any, outfile: TextIO +): + """ + Just write each single translation to a new line. If beam search was used all the other + beam results are discarded. + """ + line = tokenizer_trg.decode(result.hypotheses[0]) + outfile.write(line) + outfile.write("\n") + + +def write_nbest_translations( + index: int, tokenizer_trg: spm.SentencePieceProcessor, result: Any, outfile: TextIO +): + """ + Match Marian's way of writing out nbest translations. For example, with a beam-size of 2 and + collection nbest translations: + + 0 ||| Translation attempt + 0 ||| An attempt at translation + 1 ||| The quick brown fox jumped + 1 ||| The brown fox quickly jumped + ... 
+ """ + for hypothesis in result.hypotheses: + line = tokenizer_trg.decode(hypothesis) + outfile.write(f"{index} ||| {line}\n") + + +def translate_with_ctranslate2( + input_zst: Path, + artifacts: Path, + extra_marian_args: list[str], + models_globs: list[str], + is_nbest: bool, + vocab: list[str], + device: str, + device_index: list[int], +) -> None: + model = get_model(models_globs) + postfix = "nbest" if is_nbest else "out" + + tokenizer_src = spm.SentencePieceProcessor(vocab[0]) + if len(vocab) == 1: + tokenizer_trg = tokenizer_src + else: + tokenizer_trg = spm.SentencePieceProcessor(vocab[1]) + + if extra_marian_args and extra_marian_args[0] != "--": + logger.error(" ".join(extra_marian_args)) + raise Exception("Expected the extra marian args to be after a --") + + decoder_config = DecoderConfig(extra_marian_args[1:]) + + ctranslate2_model_dir = model.parent / f"{Path(model).stem}" + logger.info("Converting the Marian model to Ctranslate2:") + logger.info(model) + logger.info("Outputing model to:") + logger.info(ctranslate2_model_dir) + + converter = MarianConverter(model, vocab) + converter.convert(ctranslate2_model_dir, quantization=decoder_config.precision) + + if device == "gpu": + translator = ctranslate2.Translator( + str(ctranslate2_model_dir), device="cuda", device_index=device_index + ) + else: + translator = ctranslate2.Translator(str(ctranslate2_model_dir), device="cpu") + + logger.info("Loading model") + translator.load_model() + logger.info("Model loaded") + + output_zst = artifacts / f"{input_zst.stem}.{postfix}.zst" + + num_hypotheses = 1 + write_translation = write_single_translation + if is_nbest: + num_hypotheses = decoder_config.beam_size + write_translation = write_nbest_translations + + def tokenize(line): + return tokenizer_src.Encode(line.strip(), out_type=str) + + five_minutes = 300 + if device == "gpu": + start_gpu_logging(logger, five_minutes) + start_byte_count_logger(logger, five_minutes, output_zst) + + index = 0 + with write_lines(output_zst) as outfile, read_lines(input_zst) as lines: + for result in translator.translate_iterable( + # Options for "translate_iterable": + # https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_iterable + map(tokenize, lines), + max_batch_size=decoder_config.mini_batch_words, + batch_type="tokens", + # Options for "translate_batch": + # https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch + beam_size=decoder_config.beam_size, + return_scores=False, + num_hypotheses=num_hypotheses, + ): + write_translation(index, tokenizer_trg, result, outfile) + index += 1 + + stop_gpu_logging() + stop_byte_count_logger() diff --git a/taskcluster/configs/config.ci.yml b/taskcluster/configs/config.ci.yml index 25610dbdd..759e0529a 100644 --- a/taskcluster/configs/config.ci.yml +++ b/taskcluster/configs/config.ci.yml @@ -23,6 +23,7 @@ experiment: use-opuscleaner: "true" opuscleaner-mode: "custom" teacher-mode: "two-stage" + teacher-decoder: marian corpus-max-sentences: 1000 student-model: "tiny" diff --git a/taskcluster/configs/config.prod.yml b/taskcluster/configs/config.prod.yml index 09eae49ea..b4e32d79a 100644 --- a/taskcluster/configs/config.prod.yml +++ b/taskcluster/configs/config.prod.yml @@ -74,6 +74,8 @@ experiment: # Switch to "one-stage" training if back-translations are produced by a high quality model or # the model stops too early on the fine-tuning stage teacher-mode: "two-stage" + # Translate with either Marian, or CTranslate2. 
+ teacher-decoder: marian # Two student training configurations from Bergamot are supported: "tiny" and "base" # "base" model is twice slower and larger but adds ~2 COMET points in quality (see https://github.com/mozilla/translations/issues/174) student-model: "tiny" diff --git a/taskcluster/kinds/translate-corpus/kind.yml b/taskcluster/kinds/translate-corpus/kind.yml index 5e488f139..452264a75 100644 --- a/taskcluster/kinds/translate-corpus/kind.yml +++ b/taskcluster/kinds/translate-corpus/kind.yml @@ -33,11 +33,11 @@ tasks: cache: type: translate-corpus resources: - - pipeline/translate/translate-nbest.sh - - taskcluster/scripts/pipeline/translate-taskcluster.sh + - pipeline/translate/translate.py from-parameters: split_chunks: training_config.taskcluster.split-chunks marian_args: training_config.marian-args.decoding-teacher + teacher_decoder: training_config.experiment.teacher-decoder # This job is split into `split-chunks` chunk: @@ -76,6 +76,7 @@ tasks: trg_locale: training_config.experiment.trg best_model: training_config.experiment.best-model split_chunks: training_config.taskcluster.split-chunks + teacher_decoder: training_config.experiment.teacher-decoder substitution-fields: - description - worker.env @@ -100,6 +101,7 @@ tasks: env: CUDA_DIR: fetches/cuda-toolkit CUDNN_DIR: fetches/cuda-toolkit + MARIAN: $MOZ_FETCHES_DIR # 128 happens when cloning this repository fails retry-exit-status: [128] @@ -111,13 +113,19 @@ tasks: # double curly braces are used for the chunk substitutions because # this must first be formatted by task-context to get src and trg locale - >- - export MARIAN=$MOZ_FETCHES_DIR && - $VCS_PATH/taskcluster/scripts/pipeline/translate-taskcluster.sh - $MOZ_FETCHES_DIR/file.{{this_chunk}}.zst - artifacts - nbest - $MOZ_FETCHES_DIR/vocab.spm - $MOZ_FETCHES_DIR/model*/*.npz + pip3 install -r $VCS_PATH/pipeline/translate/requirements/translate-ctranslate2.txt && + export PYTHONPATH=$PYTHONPATH:$VCS_PATH && + python3 $VCS_PATH/pipeline/translate/translate.py + --input "$MOZ_FETCHES_DIR/file.{{this_chunk}}.zst" + --models_glob "$MOZ_FETCHES_DIR/*.npz" "$MOZ_FETCHES_DIR/model*/*.npz" + --artifacts "$TASK_WORKDIR/artifacts" + --vocab "$MOZ_FETCHES_DIR/vocab.spm" + --marian_dir "$MARIAN" + --gpus "$GPUS" + --workspace "$WORKSPACE" + --decoder "{teacher_decoder}" + --nbest + -- {marian_args} fetches: diff --git a/taskcluster/kinds/translate-mono-src/kind.yml b/taskcluster/kinds/translate-mono-src/kind.yml index 9f28aec1a..7480e4986 100644 --- a/taskcluster/kinds/translate-mono-src/kind.yml +++ b/taskcluster/kinds/translate-mono-src/kind.yml @@ -5,6 +5,7 @@ loader: taskgraph.loader.transform:loader transforms: + - translations_taskgraph.transforms.marian_args:transforms - translations_taskgraph.transforms.worker_selection - taskgraph.transforms.task_context - translations_taskgraph.transforms.cast_to @@ -37,10 +38,13 @@ tasks: cache: type: translate-mono-src resources: - - pipeline/translate/translate.sh - - taskcluster/scripts/pipeline/translate-taskcluster.sh + - pipeline/translate/translate.py + - pipeline/translate/translate_ctranslate2.py + - pipeline/translate/requirements/translate-ctranslate2.txt from-parameters: split_chunks: training_config.taskcluster.split-chunks + marian_args: training_config.marian-args.decoding-teacher + teacher_decoder: training_config.experiment.teacher-decoder task-context: from-parameters: @@ -49,12 +53,14 @@ tasks: best_model: training_config.experiment.best-model locale: training_config.experiment.src split_chunks: 
training_config.taskcluster.split-chunks + teacher_decoder: training_config.experiment.teacher-decoder substitution-fields: - chunk.total-chunks - description - label - worker.env - attributes + - run.command cast-to: int: @@ -95,9 +101,13 @@ tasks: env: CUDA_DIR: fetches/cuda-toolkit CUDNN_DIR: fetches/cuda-toolkit + MARIAN: $MOZ_FETCHES_DIR # 128 happens when cloning this repository fails retry-exit-status: [128] + marian-args: + from-parameters: training_config.marian-args.decoding-teacher + # Don't run unless explicitly scheduled run-on-tasks-for: [] @@ -107,13 +117,19 @@ tasks: - bash - -xc - >- - export MARIAN=$MOZ_FETCHES_DIR && - $VCS_PATH/taskcluster/scripts/pipeline/translate-taskcluster.sh - $MOZ_FETCHES_DIR/file.{this_chunk}.zst - artifacts - plain - $MOZ_FETCHES_DIR/vocab.spm - $MOZ_FETCHES_DIR/model*/*.npz + pip3 install -r $VCS_PATH/pipeline/translate/requirements/translate-ctranslate2.txt && + export PYTHONPATH=$PYTHONPATH:$VCS_PATH && + python3 $VCS_PATH/pipeline/translate/translate.py + --input "$MOZ_FETCHES_DIR/file.{{this_chunk}}.zst" + --models_glob "$MOZ_FETCHES_DIR/*.npz" "$MOZ_FETCHES_DIR/model*/*.npz" + --artifacts "$TASK_WORKDIR/artifacts" + --vocab "$MOZ_FETCHES_DIR/vocab.spm" + --marian_dir "$MARIAN" + --gpus "$GPUS" + --workspace "$WORKSPACE" + --decoder "{teacher_decoder}" + -- + {marian_args} fetches: toolchain: diff --git a/taskcluster/kinds/translate-mono-trg/kind.yml b/taskcluster/kinds/translate-mono-trg/kind.yml index 041ab0136..ca6e7c9d3 100644 --- a/taskcluster/kinds/translate-mono-trg/kind.yml +++ b/taskcluster/kinds/translate-mono-trg/kind.yml @@ -36,11 +36,13 @@ tasks: cache: type: translate-mono-trg resources: - - pipeline/translate/translate.sh - - taskcluster/scripts/pipeline/translate-taskcluster.sh + - pipeline/translate/translate.py + - pipeline/translate/translate_ctranslate2.py + - pipeline/translate/requirements/translate-ctranslate2.txt from-parameters: split_chunks: training_config.taskcluster.split-chunks marian_args: training_config.marian-args.decoding-backward + teacher_decoder: training_config.experiment.teacher-decoder task-context: from-parameters: @@ -49,15 +51,16 @@ tasks: best_model: training_config.experiment.best-model locale: training_config.experiment.trg split_chunks: training_config.taskcluster.split-chunks + teacher_decoder: training_config.experiment.teacher-decoder substitution-fields: + - chunk.total-chunks - description - - fetches.train-backwards - - dependencies + - label - worker.env - attributes - - label - run.command - - chunk.total-chunks + - fetches.train-backwards + - dependencies cast-to: int: @@ -89,7 +92,6 @@ tasks: marian-args: from-parameters: training_config.marian-args.decoding-backward - worker-type: b-largegpu worker: max-run-time: 2592000 @@ -100,6 +102,7 @@ tasks: env: CUDA_DIR: fetches/cuda-toolkit CUDNN_DIR: fetches/cuda-toolkit + MARIAN: $MOZ_FETCHES_DIR # 128 happens when cloning this repository fails retry-exit-status: [128] @@ -114,11 +117,16 @@ tasks: # double curly braces are used for the chunk substitutions because # this must first be formatted by task-context to get src and trg locale - >- - export MARIAN=$MOZ_FETCHES_DIR && - $VCS_PATH/taskcluster/scripts/pipeline/translate-taskcluster.sh - $MOZ_FETCHES_DIR/file.{{this_chunk}}.zst - artifacts - plain - $MOZ_FETCHES_DIR/vocab.spm - $MOZ_FETCHES_DIR/*.npz + pip3 install -r $VCS_PATH/pipeline/translate/requirements/translate-ctranslate2.txt && + export PYTHONPATH=$PYTHONPATH:$VCS_PATH && + python3 $VCS_PATH/pipeline/translate/translate.py 
+ --input "$MOZ_FETCHES_DIR/file.{{this_chunk}}.zst" + --models_glob "$MOZ_FETCHES_DIR/*.npz" "$MOZ_FETCHES_DIR/model*/*.npz" + --artifacts "$TASK_WORKDIR/artifacts" + --vocab "$MOZ_FETCHES_DIR/vocab.spm" + --marian_dir "$MARIAN" + --gpus "$GPUS" + --workspace "$WORKSPACE" + --decoder "marian" + -- {marian_args} diff --git a/taskcluster/scripts/pipeline/translate-taskcluster.sh b/taskcluster/scripts/pipeline/translate-taskcluster.sh deleted file mode 100755 index 79133ea4e..000000000 --- a/taskcluster/scripts/pipeline/translate-taskcluster.sh +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash - -set -x -set -euo pipefail - -input=$1 -output_dir=$2 -type=$3 -other_args=( "${@:4}" ) - -pushd `dirname $0`/../../.. &>/dev/null -VCS_ROOT=$(pwd) -popd &>/dev/null - -mkdir -p "${output_dir}" - -zstd -d --rm "${input}" -input="${input%.zst}" - -outfile="${input}.out" -if [ "${type}" = "nbest" ]; then - outfile="${input}.nbest" -fi - -# In Taskcluster, we always parallelize this step N ways. In rare cases, there -# may not be enough input files to feed all of these jobs. If we received an -# empty input file we have nothing to do other than copying the empty file -# to the output file, simulating successfully completion. -if [ -s "${input}" ]; then - if [ "${type}" = "plain" ]; then - ${VCS_ROOT}/pipeline/translate/translate.sh "${input}" "${other_args[@]}" - elif [ "${type}" = "nbest" ]; then - ${VCS_ROOT}/pipeline/translate/translate-nbest.sh "${input}" "${other_args[@]}" - fi -else - cp "${input}" "${outfile}" -fi - -zstd --rm "${outfile}" -cp "${outfile}.zst" "${output_dir}" diff --git a/taskcluster/test/params/large-lt-en.yml b/taskcluster/test/params/large-lt-en.yml index 7dca94c8f..5af8cddc2 100644 --- a/taskcluster/test/params/large-lt-en.yml +++ b/taskcluster/test/params/large-lt-en.yml @@ -150,6 +150,7 @@ training_config: src: lt teacher-ensemble: 2 teacher-mode: 'two-stage' + teacher-decoder: marian student-model: 'base' trg: en use-opuscleaner: 'false' diff --git a/taskcluster/test/params/small-ru-en.yml b/taskcluster/test/params/small-ru-en.yml index 9916c1ad3..930ad51b1 100644 --- a/taskcluster/test/params/small-ru-en.yml +++ b/taskcluster/test/params/small-ru-en.yml @@ -62,6 +62,7 @@ training_config: src: ru teacher-ensemble: 1 teacher-mode: 'two-stage' + teacher-decoder: marian student-model: 'tiny' trg: en use-opuscleaner: 'true' diff --git a/taskcluster/translations_taskgraph/actions/train.py b/taskcluster/translations_taskgraph/actions/train.py index 5a246a294..9b5bf862d 100644 --- a/taskcluster/translations_taskgraph/actions/train.py +++ b/taskcluster/translations_taskgraph/actions/train.py @@ -131,6 +131,12 @@ def validate_pretrained_models(params): "enum": ["one-stage", "two-stage"], "default": "two-stage", }, + "teacher-decoder": { + "type": "string", + "description": "Translate with either Marian or CTranslate2", + "enum": ["marian", "ctranslate2"], + "default": "marian", + }, "student-model": { "type": "string", "description": "Student model configuration", diff --git a/taskcluster/translations_taskgraph/parameters.py b/taskcluster/translations_taskgraph/parameters.py index fea1f1f11..2f2168ea1 100644 --- a/taskcluster/translations_taskgraph/parameters.py +++ b/taskcluster/translations_taskgraph/parameters.py @@ -37,6 +37,7 @@ def get_ci_training_config(_=None) -> dict: Required("trg"): str, Required("teacher-ensemble"): int, Required("teacher-mode"): str, + Required("teacher-decoder"): str, Required("student-model"): str, Optional("corpus-max-sentences"): int, 
Required("mono-max-sentences-trg"): { diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py index 3b219f5cd..dfc0effeb 100644 --- a/tests/fixtures/__init__.py +++ b/tests/fixtures/__init__.py @@ -13,6 +13,7 @@ import zstandard as zstd +from pipeline.common.downloads import read_lines from utils.preflight_check import get_taskgraph_parameters, run_taskgraph FIXTURES_PATH = os.path.dirname(os.path.abspath(__file__)) @@ -78,10 +79,13 @@ def join(self, *paths: str): """Create a folder or file name by joining it to the test directory.""" return os.path.join(self.path, *paths) - def load(self, name: str): - """Load a text file""" - with open(self.join(name), "r") as file: - return file.read() + def read_text(self, name: str): + """Load the text from a file. It can be a txt file or a compressed file.""" + text = "" + with read_lines(self.join(name)) as lines: + for line in lines: + text += line + return text def create_zst(self, name: str, contents: str) -> str: """ @@ -133,8 +137,9 @@ def run_task( work_dir: Optional[str] = None, fetches_dir: Optional[str] = None, env: dict[str, str] = {}, + extra_flags: List[str] = None, extra_args: List[str] = None, - replace_args: List[str] = None, + replace_args: List[Tuple[str, str]] = None, config: Optional[str] = None, ): """ @@ -154,10 +159,14 @@ def run_task( env - Any environment variable overrides. - extra_args - Extra Marian arguments + extra_flags - Place extra flags in the command before the "--" that is used to apply + marian args. - config - A path to a Taskcluster config file + extra_args - Place extra arguments at the end of the command. + + replace_args - A list of Tuples where an argument is replaced by word match. + config - A path to a Taskcluster config file """ command_parts, requirements, task_env = get_task_command_and_env(task_name, config=config) @@ -177,6 +186,13 @@ def run_task( fetches_dir = self.path for command_parts_split in split_on_ampersands_operator(command_parts): + if extra_flags: + index = command_parts_split.index("--") + command_parts_split = [ # noqa: PLW2901 + *command_parts_split[:index], + *extra_flags, + *command_parts_split[index:], + ] if extra_args: command_parts_split.extend(extra_args) @@ -503,7 +519,7 @@ def get_task_command_and_env( ] # The python binary will be picked by the run_task abstraction. - if command_parts[0] == "python" or command_parts[0] == "python3": + if requirements and (command_parts[0] == "python" or command_parts[0] == "python3"): command_parts = command_parts[1:] # Return the full command. diff --git a/tests/fixtures/config.pytest.yml b/tests/fixtures/config.pytest.yml index ff1eb0f6d..93c544b01 100644 --- a/tests/fixtures/config.pytest.yml +++ b/tests/fixtures/config.pytest.yml @@ -21,6 +21,7 @@ experiment: spm-sample-size: 10_000_000 teacher-ensemble: 2 teacher-mode: "two-stage" + teacher-decoder: marian student-model: "tiny" backward-model: NOT-YET-SUPPORTED vocab: NOT-YET-SUPPORTED diff --git a/tests/fixtures/marian-decoder b/tests/fixtures/marian-decoder index f2e2dc647..4a74cad92 100755 --- a/tests/fixtures/marian-decoder +++ b/tests/fixtures/marian-decoder @@ -2,27 +2,103 @@ """ marian-decoder test fixture -Do not rely on marian-decoder in tests. This mocks marian-decoder by uppercasing the -source sentences, and saving the arguments to marian-decoder.args.txt. +Use this mock for tests that do not need the real marian-decoder in tests. This mocks +marian-decoder by uppercasing the source sentences, and saving the arguments to +marian-decoder.args.txt. 
+
+It supports the behavior of --n-best with multiple sentences as well.
 """
 import json
 import os
+from pathlib import Path
 import sys
 
-artifacts_dir = os.environ.get("TEST_ARTIFACTS")
+# Relative imports require the source directory to be on the PYTHONPATH.
+src_dir = str(Path(__file__).parent / "../..")
+if src_dir not in sys.path:
+    sys.path.append(src_dir)
+
+from pipeline.common.marian import get_combined_config
+
+
+def output_to_file(is_nbest: bool, beam_size: int):
+    """
+    Output to a file when the input was a file.
+    """
+    input_path = Path(sys.argv[sys.argv.index("--input") + 1])
+    output_path = Path(sys.argv[sys.argv.index("--output") + 1])
+    assert input_path.exists(), "The input file exists"
+
+    print(f"[marian-decoder] open {input_path}")
+    print(f"[marian-decoder] write out uppercase lines to {output_path}")
+    if is_nbest:
+        print(f"[marian-decoder] outputting nbest with a beam size of: {beam_size}")
+
+    with input_path.open("rt") as input:
+        with output_path.open("wt") as outfile:
+            for line_index, line in enumerate(input):
+                if is_nbest:
+                    for beam_index in range(beam_size):
+                        outfile.write(f"{line_index} ||| {line.upper().strip()} {beam_index}\n")
+                else:
+                    outfile.write(line.upper())
+
+
+def output_to_stdout(is_nbest: bool, beam_size: int):
+    """
+    Output to stdout when no input file was provided.
+    """
+    # The input is being provided as stdin.
+    for line_index, line in enumerate(sys.stdin):
+        if is_nbest:
+            for beam_index in range(beam_size):
+                print(f"{line_index} ||| {line.upper().strip()} {beam_index}")
+        else:
+            print(line.upper(), end="")
+
+
+def write_arguments_to_disk():
+    """
+    This allows tests to make assertions against the arguments provided.
+    """
+    artifacts_dir = os.environ.get("TEST_ARTIFACTS")
+
+    if not artifacts_dir:
+        raise Exception("TEST_ARTIFACTS was not set.")
+
+    if not os.path.exists(artifacts_dir):
+        raise Exception("The TEST_ARTIFACTS directory did not exist")
+
+    # Write the arguments to disk
+    with open(os.path.join(artifacts_dir, "marian-decoder.args.txt"), "w") as outfile:
+        json.dump(sys.argv[1:], outfile)
+
+
+def determine_marian_config():
+    """
+    If --n-best is set, the lines are written out differently. Determine the n-best and beam_size
+    configuration.
+    """
+    for config_index, arg in enumerate(sys.argv):
+        if arg in ("-c", "--config"):
+            config_path = Path(sys.argv[config_index + 1])
+            break
+    config_dict = get_combined_config(config_path, sys.argv[1:])
+    is_nbest = "--n-best" in sys.argv
+    beam_size = int(config_dict.get("beam-size", 0))
+
+    return is_nbest, beam_size
-if not artifacts_dir:
-    raise Exception("TEST_ARTIFACTS was not set.")
-if not os.path.exists(artifacts_dir):
-    raise Exception("The TEST_ARTIFACTS directory did not exist")
+def main():
+    write_arguments_to_disk()
+    is_nbest, beam_size = determine_marian_config()
+    try:
+        output_to_file(is_nbest, beam_size)
+    except ValueError:
+        output_to_stdout(is_nbest, beam_size)
-# Write the arguments to disk
-arguments = sys.argv[1:]
-with open(os.path.join(artifacts_dir, "marian-decoder.args.txt"), "w") as file:
-    json.dump(arguments, file)
-# Output the input but uppercase.
-for line in sys.stdin: - print(line.upper(), end="") +if __name__ == "__main__": + main() diff --git a/tests/test_bicleaner.py b/tests/test_bicleaner.py index ef92929fa..2560c4068 100644 --- a/tests/test_bicleaner.py +++ b/tests/test_bicleaner.py @@ -3,13 +3,13 @@ from subprocess import CompletedProcess import pytest -import sh import yaml from fixtures import DataDir from pytest import fixture from pipeline.bicleaner import download_pack from pipeline.bicleaner.download_pack import main as download_model +from pipeline.common.datasets import decompress @pytest.fixture(scope="function") @@ -35,9 +35,9 @@ def _fake_download(src, trg, dir): download_pack._run_download = _fake_download -def decompress(path): - sh.zstd("-d", path) - with tarfile.open(path[:-4]) as tar: +def decompress_tar(path): + tar_path = decompress(path) + with tarfile.open(tar_path) as tar: tar.extractall(os.path.dirname(path)) @@ -58,7 +58,7 @@ def test_model_download(src, trg, model_src, model_trg, init, data_dir): download_model([f"--src={src}", f"--trg={trg}", target_path]) assert os.path.isfile(target_path) - decompress(target_path) + decompress_tar(target_path) assert os.path.isdir(decompressed_path) with open(meta_path) as f: metadata = yaml.safe_load(f) diff --git a/tests/test_common_datasets.py b/tests/test_common_datasets.py index 6a2e6e5d6..2431ee46c 100644 --- a/tests/test_common_datasets.py +++ b/tests/test_common_datasets.py @@ -1,16 +1,42 @@ import io +import logging +from pathlib import Path from typing import Iterator import pytest from fixtures import DataDir -from pipeline.common.datasets import WeakStringSet, shuffle_in_temp_files, shuffle_with_max_lines +from pipeline.common.logging import get_logger +from pipeline.common.datasets import ( + WeakStringSet, + compress, + decompress, + shuffle_in_temp_files, + shuffle_with_max_lines, +) +from pipeline.common.downloads import read_lines, write_lines ITEMS = 100_000 # ITEMS = 1_000 PERCENTAGE = 0.2 MAX_LINES = int(ITEMS * PERCENTAGE) +line_fixtures = [ + "line 1\n", + "line 2\n", + "line 3\n", + "line 4\n", + "line 5\n", +] +line_fixtures_bytes = "".join(line_fixtures).encode("utf-8") + + +def write_test_content(output_path: str) -> str: + with write_lines(output_path) as outfile: + for line in line_fixtures: + outfile.write(line) + return output_path + def get_total_byte_size(lines: list[str]) -> int: total_byte_size = 0 @@ -190,3 +216,61 @@ def test_weak_string_set(): assert "string b" in unique_strings2 assert "string c" not in unique_strings2 assert len(unique_strings2) == 2 + + +@pytest.mark.parametrize("suffix", ["zst", "gz"]) +@pytest.mark.parametrize("remove_or_keep", ["remove", "keep"]) +def test_compress(suffix: str, remove_or_keep: str): + data_dir = DataDir("test_common_datasets") + source = Path(data_dir.join("lines.txt")) + destination = Path(data_dir.join(f"lines.txt.{suffix}")) + logger = get_logger(__file__) + logger.setLevel(logging.INFO) + + write_test_content(source) + assert source.exists() + assert not destination.exists() + + with read_lines(source) as lines: + assert list(lines) == line_fixtures + + remove = remove_or_keep == "remove" + compress(source, destination, logger=logger, remove=remove) + + if remove: + assert not source.exists(), "The source file was removed." + else: + assert source.exists(), "The source file was kept." 
+ + with read_lines(destination) as lines: + assert list(lines) == line_fixtures + + +@pytest.mark.parametrize("suffix", ["zst", "gz"]) +@pytest.mark.parametrize("remove_or_keep", ["remove", "keep"]) +def test_decompress(suffix: str, remove_or_keep: str): + data_dir = DataDir("test_common_datasets") + source = Path(data_dir.join(f"lines.txt.{suffix}")) + destination = Path(data_dir.join("lines.txt")) + logger = get_logger(__file__) + logger.setLevel(logging.INFO) + + write_test_content(source) + assert source.exists() + assert not destination.exists() + + with read_lines(source) as lines: + assert list(lines) == line_fixtures + + remove = remove_or_keep == "remove" + decompress(source, destination, remove=remove, logger=logger) + + if remove: + assert not source.exists(), "The source file was removed." + else: + assert source.exists(), "The source file was kept." + + assert destination.exists() + + with read_lines(destination) as lines: + assert list(lines) == line_fixtures diff --git a/tests/test_common_marian.py b/tests/test_common_marian.py new file mode 100644 index 000000000..5ac7d245f --- /dev/null +++ b/tests/test_common_marian.py @@ -0,0 +1,18 @@ +import pytest +from pipeline.common.marian import marian_args_to_dict + + +@pytest.mark.parametrize( + "marian_args,dict_value", + [ + # + (["--input", "file.txt"], {"input": "file.txt"}), + (["--vocab", "en.spm", "fr.spm"], {"vocab": ["en.spm", "fr.spm"]}), + ( + ["--", "--input", "file.in", "--output", "file.out"], + {"input": "file.in", "output": "file.out"}, + ), + ], +) +def test_marian_args_to_dict(marian_args: list[str], dict_value: dict): + assert marian_args_to_dict(marian_args) == dict_value diff --git a/tests/test_ctranslate2.py b/tests/test_ctranslate2.py new file mode 100644 index 000000000..dbf8e054e --- /dev/null +++ b/tests/test_ctranslate2.py @@ -0,0 +1,93 @@ +import shutil +import pytest +from pathlib import Path +from fixtures import DataDir +from pipeline.common.downloads import stream_download_to_file + + +text = """La màfia no va recuperar el seu poder fins al cap de la rendició d'Itàlia en la Segona Guerra Mundial. +En els vuitanta i noranta, una sèrie de disputes internes van portar a la mort a molts membres destacats de la màfia. +Després del final de la Segona Guerra Mundial, la màfia es va convertir en un Estat dins de l'Estat. +Els seus tentacles ja no abastaven només a Sicília, sinó gairebé a tota l'estructura econòmica d'Itàlia, i d'usar escopetes de canons retallats, va passar a disposar d'armament més expeditiu: revòlvers del calibre .357 Magnum, fusells llança-granades, bazookas, i explosius. +La màfia i altres societats secretes del crim organitzat van formar un sistema de vasos comunicants. +En la lògia maçònica P-2, representada pel gran maestre Lici Gelli, hi havia ministres, parlamentaris, generals, jutges, policies, banquers, aristòcrates i fins i tot mafiosos. +En 1992, la màfia siciliana va assassinar al jutge italià Giovanni Falcone fent esclatar mil quilograms d'explosius col·locats sota l'autopista que uneix Palerm amb l'aeroport ara anomenat Giovanni Falcone. +Van morir ell, la seva esposa Francesca Morvilio i tres escortes. +En 1993, cinc ex-presidents de Govern, moltíssims ministres i més de 3000 polítics i empresaris van ser acusats, processaments o condemnats per corrupció i associació amb la màfia. +Es tractava d'un missatge de la màfia al vell Andreotti, ex-president del Govern, per no aturar l'enpresonament masiu dels seus membres. 
+La màfia no perdona mai, com ja no podran testificar els banquers Michele Sindona i Roberto Calvi, dos mags de les finances del Vaticà, la màfia i altres institucions d'Itàlia. +Van ser assassinats per un rampell de cobdícia, ja que van voler apropiar-se dels diners de la màfia. +El capo di tutti capi és el major rang que pot haver-hi en la Cosa Nostra. +Es tracta del cap d'una família que, en ser més poderós o per haver assassinat als altres caps de les altres famílies, s'ha convertit en el més poderós membre de la màfia. +Un exemple d'això va ser Salvatore Maranzano , qui va ser traït per Lucky Luciano, qui finalment li va cedir el lloc ―en ser extradit per problemes amb la justícia nord-americana― a la seva mà dreta i conseller, Frank Costello. +El don és el cap d'una família. +""" + + +@pytest.fixture +def data_dir(): + data_dir = DataDir("test_translate") + shutil.copyfile("tests/data/vocab.spm", data_dir.join("vocab.spm")) + return data_dir + + +def download_and_cache(data_dir: DataDir, url: str, cached_filename: str, data_dir_name: str): + """ + Download remote language model resources and cache them in the data directory. + """ + src_dir = Path(__file__).parent.parent + cached_file = src_dir / "data/tests" / cached_filename + cached_file.parent.mkdir(parents=True, exist_ok=True) + if not cached_file.exists(): + stream_download_to_file(url, cached_file) + shutil.copy(cached_file, data_dir.join(data_dir_name)) + + +def test_ctranslate2(): + data_dir = DataDir("test_ctranslate2") + data_dir.mkdir("model1") + data_dir.create_zst("file.1.zst", text) + + # Download the teacher models. + download_and_cache( + data_dir, + "https://storage.googleapis.com/releng-translations-dev/models/ca-en/dev/teacher-finetuned1/final.model.npz.best-chrf.npz", + cached_filename="en-ca-teacher-1.npz", + data_dir_name="model1/final.model.npz.best-chrf.npz", + ) + + # Download the vocab. 
+ download_and_cache( + data_dir, + "https://storage.googleapis.com/releng-translations-dev/models/ca-en/dev/vocab/vocab.spm", + cached_filename="en-ca-vocab.spm", + data_dir_name="vocab.spm", + ) + + data_dir.run_task( + "translate-mono-src-en-ru-1/10", + env={"USE_CPU": "true"}, + # Applied before the "--" + extra_flags=["--decoder", "ctranslate2", "--device", "cpu"], + ) + data_dir.print_tree() + + out_lines = data_dir.read_text("artifacts/file.1.out.zst").strip().split("\n") + assert out_lines == [ + "The Mafia did not regain its power until the end of World War II.", + "In the 1990s, a series of internal scandals led to the death of many prominent members of the Mafia.", + "After World War II, the Mafia became a state.", + "The Italians, however, did not concentrate more heavily on the use of steel, but only in the case of the most expensive and costly weaponry: the Babylonian cartridges, with more than 3,500 rifles, were eroded, all of them eroded, including the cryptanalysts, the cylindrical rifles, the cylindrical rifles and the cylindrical rifles.", + "The Mafia and other secret societies formed a system of organized crime.", + "In the Ptolemy II, the Giulio Giulio Giulio Giulio, which included a number of magistrates, magistrates, magistrates, magistrates, magistrates, magistrates, magistrates, magistrates, ministers, even the police, even the police, the police, the police, the police, the police, the police, the police and the police.", + "In 1992, the Italian dictator Giovanni Falcone shot down the 600-year-old paratrooper Giovanni Falcone with a parachute carrying a paratrooper named Giovanni Falcone.", + "He was succeeded by his wife, Francesco Morgas, and three sisters.", + "In 1993, more than three hundred ministers, lawyers and government officials were accused and accused of fraud and corruption.", + "It was a message from Mr Andreotti, the former prime minister, to stop the government's membership.", + "The bankers will not be able to imagine, as Michel Salmond and Pablo Guerrero have said, the bankers of the two banks of the Vatican and the Vatican.", + "They were murdered by a mafia because they wanted to steal money from the mafia.", + "The Capricorn is the largest rank that can be in our heads.", + "It is the head of a family, or more powerful than the head of another family, who has become the most powerful Mafia leader.", + "This was a tragedy that Luciano Margo, who was later persuaded by Frank Prigogore, gave up for Lucca, who was incarcerated by Luca Cortino, who was later incarcerated by Margo.", + "The head is not the head of a family.", + ] diff --git a/tests/test_data_importer.py b/tests/test_data_importer.py index 232e9bc87..47f033f84 100644 --- a/tests/test_data_importer.py +++ b/tests/test_data_importer.py @@ -260,7 +260,7 @@ def test_mono_hplt(language, data_dir: DataDir): assert max_len <= max_characters assert max_len > max_characters - 50 assert ( - json.loads(data_dir.load(f"artifacts/{dataset}.{language}.stats.json")) + json.loads(data_dir.read_text(f"artifacts/{dataset}.{language}.stats.json")) == hplt_stats[language] ) assert [l[:-1] for l in lines[:10]] == hplt_expected[language].split("\n") diff --git a/tests/test_eval.py b/tests/test_eval.py index 8b53d3f02..bfd02ee80 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -76,6 +76,7 @@ def run_eval_test(params) -> None: data_dir = DataDir("test_eval") data_dir.create_zst("wmt09.en.zst", en_sample) data_dir.create_zst("wmt09.ru.zst", ru_sample) + 
data_dir.create_file("final.model.npz.best-chrf.npz.decoder.yml", "{}") model_path = os.path.join(root_path, "data/models") os.makedirs(model_path, exist_ok=True) @@ -116,20 +117,20 @@ def run_eval_test(params) -> None: # Test that the data files are properly written out. if "backward" in task_name: # Backwards evaluation. - assert data_dir.load("artifacts/wmt09.ru") == ru_sample - assert data_dir.load("artifacts/wmt09.en.ref") == en_sample - assert data_dir.load("artifacts/wmt09.en") == en_fake_translated + assert data_dir.read_text("artifacts/wmt09.ru") == ru_sample + assert data_dir.read_text("artifacts/wmt09.en.ref") == en_sample + assert data_dir.read_text("artifacts/wmt09.en") == en_fake_translated else: # Forwards evaluation. - assert data_dir.load("artifacts/wmt09.en") == en_sample - assert data_dir.load("artifacts/wmt09.ru.ref") == ru_sample - assert data_dir.load("artifacts/wmt09.ru") == ru_fake_translated + assert data_dir.read_text("artifacts/wmt09.en") == en_sample + assert data_dir.read_text("artifacts/wmt09.ru.ref") == ru_sample + assert data_dir.read_text("artifacts/wmt09.ru") == ru_fake_translated # Test that text metrics get properly generated. - assert f"{bleu}\n{chrf}\n{comet}\n" in data_dir.load("artifacts/wmt09.metrics") + assert f"{bleu}\n{chrf}\n{comet}\n" in data_dir.read_text("artifacts/wmt09.metrics") # Test that the JSON metrics get properly generated. - metrics_json = json.loads(data_dir.load("artifacts/wmt09.metrics.json")) + metrics_json = json.loads(data_dir.read_text("artifacts/wmt09.metrics.json")) assert metrics_json["bleu"]["details"]["name"] == "BLEU" assert metrics_json["bleu"]["details"]["score"] == bleu @@ -144,5 +145,5 @@ def run_eval_test(params) -> None: assert metrics_json["comet"]["score"] == comet # Test that marian is given the proper arguments. - marian_decoder_args = json.loads(data_dir.load("marian-decoder.args.txt")) + marian_decoder_args = json.loads(data_dir.read_text("marian-decoder.args.txt")) assert marian_decoder_args == expected_marian_args, "The marian arguments matched." diff --git a/tests/test_merge_corpus.py b/tests/test_merge_corpus.py index ce9d3b704..65408ad05 100644 --- a/tests/test_merge_corpus.py +++ b/tests/test_merge_corpus.py @@ -67,7 +67,7 @@ def assert_dataset(data_dir: DataDir, path: str, sorted_lines: list[str]): "name", ["corpus", "devset"], ) -def test_merge_corpus(data_dir, name): +def test_merge_corpus(data_dir: DataDir, name): data_dir.run_task( # Tasks merge-corpus-en-ru, and merge-devset-en-ru. f"merge-{name}-en-ru", @@ -121,7 +121,7 @@ def test_merge_corpus(data_dir, name): ], ) - assert json.loads(data_dir.load(f"artifacts/{name}.stats.json")) == { + assert json.loads(data_dir.read_text(f"artifacts/{name}.stats.json")) == { "parallel_corpus": { "description": "The parallel corpora are merged and deduplicated", "filtered": 4, @@ -151,7 +151,7 @@ def test_merge_corpus(data_dir, name): "name", ["corpus", "devset"], ) -def test_merge_devset_trimmed(data_dir, name): +def test_merge_devset_trimmed(data_dir: DataDir, name: str): data_dir.run_task( # Tasks merge-corpus-en-ru, and merge-devset-en-ru. 
f"merge-{name}-en-ru", @@ -193,7 +193,7 @@ def test_merge_devset_trimmed(data_dir, name): ], ) - assert json.loads(data_dir.load(f"artifacts/{name}.stats.json")) == { + assert json.loads(data_dir.read_text(f"artifacts/{name}.stats.json")) == { "parallel_corpus": { "description": "The parallel corpora are merged and deduplicated", "filtered": 4, diff --git a/tests/test_merge_mono.py b/tests/test_merge_mono.py index 4778d65f4..481f0c2f8 100644 --- a/tests/test_merge_mono.py +++ b/tests/test_merge_mono.py @@ -50,7 +50,7 @@ def test_merge_mono(task: str): data_dir.print_tree() - assert json.loads(data_dir.load(f"artifacts/mono.{locale}.stats.json")) == { + assert json.loads(data_dir.read_text(f"artifacts/mono.{locale}.stats.json")) == { "final_truncated_monolingual_lines": { "description": "After truncation via the config's `experiment.mono-max-sentences-src.total`, how many lines are left.", "value": 10, diff --git a/tests/test_split_collect.py b/tests/test_split_collect.py index b4e48debc..5b3f94ed8 100644 --- a/tests/test_split_collect.py +++ b/tests/test_split_collect.py @@ -8,6 +8,7 @@ import sh from fixtures import DataDir +from pipeline.common.datasets import decompress from pipeline.translate.splitter import main as split_file @@ -32,15 +33,11 @@ def generate_dataset(length, path): sh.zstdmt(path) -def decompress(path): - sh.zstdmt("-d", path) - - def imitate_translate(dir, suffix): for file in glob.glob(f"{dir}/file.?.zst") + glob.glob(f"{dir}/file.??.zst"): print(file) - decompress(file) - shutil.copy(file[:-4], file[:-4] + suffix) + uncompressed_file = decompress(file) + shutil.copy(uncompressed_file, str(uncompressed_file) + suffix) def read_file(path): diff --git a/tests/test_translate.py b/tests/test_translate.py new file mode 100644 index 000000000..2121d2fc8 --- /dev/null +++ b/tests/test_translate.py @@ -0,0 +1,150 @@ +import json +from pathlib import Path +import shutil + +import pytest +from fixtures import DataDir, en_sample +from pipeline.common.marian import marian_args_to_dict + +fixtures_path = Path(__file__).parent / "fixtures" + + +@pytest.fixture +def data_dir(): + data_dir = DataDir("test_translate") + shutil.copyfile("tests/data/vocab.spm", data_dir.join("vocab.spm")) + return data_dir + + +def sanitize_marian_args(args_list: list[str]): + """ + Marian args can have details that reflect the host machine or are unique per run. + Sanitize those here. 
+ """ + base_dir = str((Path(__file__).parent / "..").resolve()) + args_dict = marian_args_to_dict(args_list) + for key, value in args_dict.items(): + if isinstance(value, list): + for index, value_inner in enumerate(value): + if isinstance(value_inner, str): + if value_inner.startswith("/tmp"): + value[index] = "/" + Path(value_inner).name + if value_inner.startswith(base_dir): + value[index] = value_inner.replace(base_dir, "") + elif isinstance(value, str): + if value.startswith("/tmp"): + args_dict[key] = "/" + Path(value).name + if value.startswith(base_dir): + args_dict[key] = value.replace(base_dir, "") + + return args_dict + + +def test_translate_corpus(data_dir: DataDir): + data_dir.create_zst("file.1.zst", en_sample) + data_dir.create_file("fake-model.npz", "") + data_dir.run_task( + "translate-corpus-en-ru-1/10", + env={ + "MARIAN": str(fixtures_path), + "TEST_ARTIFACTS": data_dir.path, + }, + ) + data_dir.print_tree() + + output = data_dir.read_text("artifacts/file.1.nbest.zst") + for pseudo_translated in en_sample.upper().split("\n"): + assert pseudo_translated in output + + args = json.loads(data_dir.read_text("marian-decoder.args.txt")) + assert sanitize_marian_args(args) == { + "config": "/pipeline/translate/decoder.yml", + "vocabs": [ + "/data/tests_data/test_translate/vocab.spm", + "/data/tests_data/test_translate/vocab.spm", + ], + "input": "/file.1", + "output": "/file.1.nbest", + "n-best": True, + "log": "/file.1.log", + "devices": ["0", "1", "2", "3"], + "workspace": "12000", + "mini-batch-words": "4000", + "precision": "float16", + "models": "/data/tests_data/test_translate/fake-model.npz", + } + + +def test_translate_corpus_empty(data_dir: DataDir): + """ + Test the case of an empty file. + """ + data_dir.create_zst("file.1.zst", "") + data_dir.create_file("fake-model.npz", "") + data_dir.run_task( + "translate-corpus-en-ru-1/10", + env={ + "MARIAN": str(fixtures_path), + "TEST_ARTIFACTS": data_dir.path, + }, + ) + + data_dir.print_tree() + + assert data_dir.read_text("artifacts/file.1.nbest.zst") == "", "The text is empty" + + +mono_args = { + "src": { + "config": "/pipeline/translate/decoder.yml", + "vocabs": [ + "/data/tests_data/test_translate/vocab.spm", + "/data/tests_data/test_translate/vocab.spm", + ], + "input": "/file.1", + "output": "/file.1.out", + "log": "/file.1.log", + "devices": ["0", "1", "2", "3"], + "workspace": "12000", + "mini-batch-words": "4000", + "precision": "float16", + "models": "/data/tests_data/test_translate/fake-model.npz", + }, + "trg": { + "beam-size": "12", + "config": "/pipeline/translate/decoder.yml", + "vocabs": [ + "/data/tests_data/test_translate/vocab.spm", + "/data/tests_data/test_translate/vocab.spm", + ], + "input": "/file.1", + "output": "/file.1.out", + "log": "/file.1.log", + "devices": ["0", "1", "2", "3"], + "workspace": "12000", + "mini-batch-words": "2000", + "models": "/data/tests_data/test_translate/fake-model.npz", + }, +} + + +@pytest.mark.parametrize("direction", ["src", "trg"]) +def test_translate_mono(direction: str, data_dir: DataDir): + data_dir.create_zst("file.1.zst", en_sample) + data_dir.create_file("fake-model.npz", "") + data_dir.print_tree() + data_dir.run_task( + f"translate-mono-{direction}-en-ru-1/10", + env={ + "MARIAN": str(fixtures_path), + "TEST_ARTIFACTS": data_dir.path, + }, + ) + data_dir.print_tree() + + assert ( + data_dir.read_text("artifacts/file.1.out.zst") == en_sample.upper() + ), "The text is pseudo-translated" + + args = json.loads(data_dir.read_text("marian-decoder.args.txt")) + 
assert sanitize_marian_args(args) == mono_args[direction]
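
Note on pipeline/common/marian.py: the patch references marian_args_to_dict and get_combined_config (in tests/fixtures/marian-decoder, tests/test_common_marian.py, and tests/test_translate.py) but does not include that module itself. The sketch below is only an illustration inferred from the expectations in those tests; the names, signatures, YAML handling, and the lack of short-option support are assumptions, not the shipped implementation.

# Hypothetical sketch of pipeline/common/marian.py, inferred from the tests above.
from pathlib import Path
from typing import Any, Union

import yaml


def marian_args_to_dict(args: list[str]) -> dict[str, Any]:
    """
    Convert a marian-style argument list into a dict, e.g.
    ["--n-best", "--vocabs", "a.spm", "b.spm"] -> {"n-best": True, "vocabs": ["a.spm", "b.spm"]}.
    A bare "--" separator is skipped, flags without values become True, repeated values are
    collected into lists, and values stay as strings. Short options such as "-c" are not
    handled in this sketch.
    """
    result: dict[str, Any] = {}
    key = None
    for arg in args:
        if arg == "--":
            continue
        if arg.startswith("--"):
            key = arg[2:]
            result[key] = True
            continue
        assert key is not None, f"A value was provided without a flag: {arg}"
        if result[key] is True:
            # The first value replaces the boolean placeholder.
            result[key] = arg
        elif isinstance(result[key], list):
            result[key].append(arg)
        else:
            # A second value promotes the entry to a list.
            result[key] = [result[key], arg]
    return result


def get_combined_config(config_path: Union[Path, str], args: list[str]) -> dict[str, Any]:
    """Merge a decoder.yml config with command line arguments, the arguments taking precedence."""
    config: dict[str, Any] = yaml.safe_load(Path(config_path).read_text()) or {}
    config.update(marian_args_to_dict(args))
    return config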