Rewrite the train scripts and add config support for ctranslate2 (#922)
* WANDB Test failure

* Rename DataDir.load to DataDir.read_text and allow for reading compressed files

* Add compress and decompress common utilities

* Use decompression utilities everywhere

* Re-work the marian-decoder fixture to correctly output nbest

* Rewrite translate.sh to python

* Add a requirements file for ctranslate2

* Add support for ctranslate2

* Add gpustats to the train requirements

* Add logging for translations

* Remove old translate scripts

* Handle review feedback
gregtatum authored Dec 20, 2024
1 parent cfbaf72 commit 8977fbf
Showing 38 changed files with 1,588 additions and 197 deletions.
11 changes: 5 additions & 6 deletions pipeline/alignments/align.py
@@ -33,6 +33,7 @@
from tqdm import tqdm

from pipeline.alignments.tokenizer import tokenize_moses
from pipeline.common.datasets import decompress
from pipeline.common.logging import get_logger

logger = get_logger("alignments")
@@ -60,8 +61,8 @@ def run(
tmp_dir = os.path.join(os.path.dirname(output_path), "tmp")
os.makedirs(tmp_dir, exist_ok=True)

corpus_src = decompress(corpus_src)
corpus_trg = decompress(corpus_trg)
corpus_src = maybe_decompress(corpus_src)
corpus_trg = maybe_decompress(corpus_trg)

if tokenization == Tokenization.moses:
tokenized_src = (
@@ -121,11 +122,9 @@ def run(
shutil.move(output_aln, output_path)


def decompress(file_path: str):
def maybe_decompress(file_path: str):
if file_path.endswith(".zst"):
logger.info(f"Decompressing file {file_path}")
subprocess.check_call(["zstdmt", "-d", "-f", "--rm", file_path])
return file_path[:-4]
return str(decompress(file_path, remove=True, logger=logger))
return file_path


38 changes: 37 additions & 1 deletion pipeline/common/command_runner.py
@@ -1,3 +1,4 @@
import os
import re
from shlex import join
import shlex
@@ -62,7 +63,7 @@ def run_command_pipeline(
it will log the constructed pipeline commands. Defaults to None.
Example:
python_scripts = run_pipeline(
python_scripts = run_command_pipeline(
[
["ls", "-l"],
["grep", ".py"],
@@ -94,3 +95,38 @@ def run_command_pipeline(
return subprocess.check_output(command_string, shell=True).decode("utf-8")

subprocess.check_call(command_string, shell=True)


def run_command(
command: list[str], capture=False, shell=False, logger=None, env=None
) -> str | None:
"""
Runs a command and outputs a nice representation of the command to a logger, if supplied.
Args:
command: The command arguments provided to subprocess.check_call
capture: If True, captures and returns the output of the command. If False,
output is printed to stdout.
logger: A logger instance used for logging the command execution. If provided,
it will log the command before it is run.
env: The environment variables to use for the subprocess.
Example:
directory_listing = run_command(
["ls", "-l"],
capture=True
)
"""
# Expand any environment variables.
command = [os.path.expandvars(part) for part in command]

if logger:
# Log out a nice representation of this command.
logger.info("Running:")
for line in _get_indented_command_string(command).split("\n"):
logger.info(line)

if capture:
return subprocess.check_output(command, env=env).decode("utf-8")

subprocess.check_call(command, env=env)
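
To make the new helpers concrete, here is a small usage sketch. It assumes run_command and run_command_pipeline are importable from pipeline.common.command_runner as added in this commit; the commands themselves are placeholder examples.

from pipeline.common.command_runner import run_command, run_command_pipeline
from pipeline.common.logging import get_logger

logger = get_logger(__file__)

# Run a single command and capture its stdout as a string.
listing = run_command(["ls", "-l"], capture=True, logger=logger)

# Chain commands through a shell pipeline and capture the final output.
python_files = run_command_pipeline(
    [
        ["ls", "-l"],
        ["grep", ".py"],
    ],
    capture=True,
    logger=logger,
)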
111 changes: 110 additions & 1 deletion pipeline/common/datasets.py
@@ -1,13 +1,15 @@
from collections.abc import Iterable
import hashlib
import json
from logging import Logger
import os
import subprocess
import tempfile
from dataclasses import dataclass
from io import TextIOWrapper
from pathlib import Path
from random import Random
from typing import Callable, Iterator, Optional, Set, Union
from typing import Callable, Iterator, Literal, Optional, Set, Union
from urllib.parse import urlparse
import unicodedata

@@ -444,3 +446,110 @@ def _hash_string(string: str) -> int:
"""
cleaned_line = unicodedata.normalize("NFC", string.strip())
return hash(cleaned_line)


def decompress(
source: Union[str, Path],
destination: Optional[Union[Path, str]] = None,
remove: bool = False,
logger: Optional[Logger] = None,
) -> Path:
"""
Decompresses a file using the appropriate command based on its file extension.
Args:
source: The path to the file to be decompressed.
destination: By default the file is decompressed next to the original. This argument
allows overriding the destination.
remove: If set to `True`, the original compressed file will be removed after decompression.
logger: If provided, log information about the decompression.
"""
if isinstance(source, str):
source = Path(source)
if not destination:
destination = source.parent / source.stem

if logger:
logger.info(f"[decompress] From: {source}")
logger.info(f"[decompress] To: {destination}")

if source.suffix == ".zst":
command = ["zstdmt", "--decompress", "--force", "-o", destination, source]
if remove:
command.append("--rm")

subprocess.check_call(command)
elif source.suffix == ".gz":
command = ["gzip", "-c", "-d", source]
with open(destination, "wb") as out_file:
subprocess.check_call(command, stdout=out_file)
if remove:
source.unlink()
else:
raise Exception(f"Unknown file type to decompress: {source}")

if remove and logger:
logger.info(f"[decompress] Removed: {source}")

return destination
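
# Example (hypothetical path): decompressing "corpus.en.zst" writes "corpus.en"
# next to the original and returns its Path:
#
#     plain = decompress("corpus.en.zst", remove=False)
#     assert plain.name == "corpus.en"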


def compress(
source: Union[str, Path],
destination: Optional[Union[Path, str]] = None,
remove: bool = False,
compression_type: Optional[Literal["zst", "gz"]] = None,
logger: Optional[Logger] = None,
) -> Path:
"""
Compress a file using the appropriate command based on its file extension.
Args:
source: The path to the file to be compressed
destination: By default the file is compressed next to the original. This argument
allows overriding the destination.
remove: If set to `True`, the original uncompressed file will be removed after compression.
compression_type: Defaults to "zst" and is implied by the destination's suffix, however it
can be set explicitly.
logger: Log information about the compression
"""
if isinstance(source, str):
source = Path(source)
if isinstance(destination, str):
destination = Path(destination)

# Ensure the compression type is valid and present
if compression_type and destination:
assert f".{type}" == destination.suffix, "The compression type and destination must match."

if not compression_type:
if destination:
compression_type = destination.suffix[1:]
else:
compression_type = "zst"

# Set default destination if not provided
if not destination:
destination = source.with_suffix(f"{source.suffix}.{compression_type}")

if logger:
logger.info(f"Compressing: {source}")
logger.info(f"Destination: {destination}")

if compression_type == "zst":
command = ["zstdmt", "--compress", "--force", "--quiet", source, "-o", destination]
if remove:
command.append("--rm")
subprocess.check_call(command)
elif compression_type == "gz":
with open(destination, "wb") as out_file:
subprocess.check_call(["gzip", "-c", "--force", source], stdout=out_file)
if remove:
source.unlink()
else:
raise ValueError(f"Unsupported compression type: {compression_type}")

if remove and logger:
logger.info(f"Removed {source}")

return destination
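
As an illustration of the compress/decompress pair, the sketch below round-trips a file. The path is hypothetical, and it assumes pipeline.common.datasets is importable and that zstdmt is available on the system.

from pipeline.common.datasets import compress, decompress
from pipeline.common.logging import get_logger

logger = get_logger(__file__)

# "corpus.en" -> "corpus.en.zst" ("zst" is the default compression type).
compressed = compress("corpus.en", remove=False, logger=logger)

# "corpus.en.zst" -> "corpus.en", removing the compressed copy afterwards.
restored = decompress(compressed, remove=True, logger=logger)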
13 changes: 13 additions & 0 deletions pipeline/common/downloads.py
@@ -520,6 +520,19 @@ def count_lines(path: Path | str) -> int:
return sum(1 for _ in lines)


def is_file_empty(path: Path | str) -> bool:
"""
Attempts to read a line to determine if a file is empty or not. Works on local or remote files
as well as compressed or uncompressed files.
"""
with read_lines(path) as lines:
try:
next(lines)
return False
except StopIteration:
return True
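
# Example (hypothetical path): works the same for local, remote, and compressed files.
#
#     if is_file_empty("corpus.en.zst"):
#         logger.warning("The corpus is empty.")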


def get_file_size(location: Union[Path, str]) -> int:
"""Get the size of a file, whether it is remote or local."""
if str(location).startswith("http://") or str(location).startswith("https://"):
104 changes: 104 additions & 0 deletions pipeline/common/logging.py
@@ -1,8 +1,14 @@
import logging
from pathlib import Path
import subprocess
import threading
import time

logging.basicConfig(level=logging.INFO, format="[%(name)s] %(message)s")

STOP_BYTE_COUNT_LOGGER = False
STOP_GPU_LOGGER = False


def get_logger(name: str):
"""
@@ -21,3 +27,101 @@ def get_logger(name: str):
logger = logging.getLogger(Path(name).stem)
logger.setLevel(logging.INFO)
return logger


def _log_gpu_stats(logger: logging.Logger, interval_seconds: int):
# Only load gpustat when it's needed.
import gpustat

global STOP_GPU_LOGGER
while True:
time.sleep(interval_seconds)
if STOP_GPU_LOGGER:
STOP_GPU_LOGGER = False
return
try:
logger.info("[gpu] Current GPU stats:")
gpustat.print_gpustat()
except subprocess.CalledProcessError as e:
logger.error(f"Failed to retrieve GPU stats: {e}")


def stop_gpu_logging():
global STOP_GPU_LOGGER
STOP_GPU_LOGGER = True


def start_gpu_logging(logger: logging.Logger, interval_seconds: int):
"""Logs GPU stats on an interval using gpustat in a background thread."""
assert not STOP_GPU_LOGGER, "A gpu logger should not already be running"

thread = threading.Thread(
target=_log_gpu_stats,
# Set as a daemon thread so it automatically is closed on shutdown.
daemon=True,
args=(logger, interval_seconds),
)
thread.start()
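
# Example: log GPU stats every 60 seconds while a training step runs (a sketch;
# it assumes the gpustat package is installed, per the train requirements).
#
#     start_gpu_logging(logger, interval_seconds=60)
#     ...  # run training
#     stop_gpu_logging()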


def _log_byte_rate(logger: logging.Logger, interval_seconds: int, file_path: Path):
global STOP_BYTE_COUNT_LOGGER
previous_byte_count = 0
previous_time = time.time()
is_zst = file_path.suffix == ".zst"

while True:
time.sleep(interval_seconds)
if STOP_BYTE_COUNT_LOGGER:
STOP_BYTE_COUNT_LOGGER = False
return

try:
if is_zst:
# This takes ~1 second to run on 5 million sentences.
current_byte_count = 0
cmd = ["zstd", "-dc", str(file_path)]
with subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE
) as process:
for chunk in iter(lambda: process.stdout.read(8192), b""):
current_byte_count += len(chunk)
else:
# This is pretty much instantaneous.
result = subprocess.run(
["wc", "-c", str(file_path)], capture_output=True, text=True, check=True
)
current_byte_count = int(result.stdout.split()[0])

bytes_added = current_byte_count - previous_byte_count

current_secs = time.time()
elapsed_secs = current_secs - previous_time
byte_rate = bytes_added / elapsed_secs if bytes_added > 0 else 0

logger.info(f"[bytes] Added: {bytes_added:,}")
logger.info(f"[bytes] Total: {current_byte_count:,}")
logger.info(f"[bytes] Rate: {byte_rate:,.2f} bytes/second")

previous_byte_count = current_byte_count
previous_time = time.time()
except Exception as e:
logger.error(f"Failed to monitor byte count: {e}")


def stop_byte_count_logger():
global STOP_BYTE_COUNT_LOGGER
STOP_BYTE_COUNT_LOGGER = True


def start_byte_count_logger(logger: logging.Logger, interval_seconds: int, file_path: Path):
"""
Monitors the rate of bytes being added to a file, logging the number of bytes
added per second over the interval.
"""

assert not STOP_BYTE_COUNT_LOGGER, "A byte count logger should not already be running"
thread = threading.Thread(
target=_log_byte_rate, args=(logger, interval_seconds, file_path), daemon=True
)
thread.start()
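
A short usage sketch for the byte-count monitor (the path is hypothetical, and the file is assumed to be a .zst stream that another process is writing to):

from pathlib import Path

from pipeline.common.logging import get_logger, start_byte_count_logger, stop_byte_count_logger

logger = get_logger(__file__)

# Report how quickly bytes are being added to the output file, every 60 seconds.
start_byte_count_logger(logger, interval_seconds=60, file_path=Path("translations.zst"))
# ... run the translation step ...
stop_byte_count_logger()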
