Switch decompression zstd python package
manuelburger committed Jul 1, 2024
1 parent dcaf79b commit 1dd8036
Showing 5 changed files with 31 additions and 9 deletions.
4 changes: 2 additions & 2 deletions petagraph/configs/config_petagraph_multi_node.yaml
@@ -1,6 +1,6 @@
checkpoints:
checkpoint_interval: 500
-checkpoints_path: /users/burgerm/petagraph/logs/wgs_fungi/multi/checkpoints
+checkpoints_path: /users/burgerm/petagraph/logs/wgs_fungi/multi-2/checkpoints
checkpoints_path_is_shared_file_system: true
resume_checkpoint_path: null
save_initial_state: false
@@ -73,7 +73,7 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
-dp: 16 # 32
+dp: 32
expert_parallel_size: 1
pp: 1
pp_engine: 1f1b
2 changes: 1 addition & 1 deletion petagraph/scripts/run_multi_node.sh
@@ -1,6 +1,6 @@
#!/bin/bash
#SBATCH --job-name=petagraph # create a short name for your job
-#SBATCH --nodes=4 # total number of nodes
+#SBATCH --nodes=8 # total number of nodes
#SBATCH --ntasks-per-node=1 # total number of tasks per node
#SBATCH --gpus-per-task=4
#SBATCH --time=12:00:00
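
The doubled node count presumably pairs with the dp: 32 change in the config above: 8 nodes x 4 GPUs per task (one task per node) = 32 GPUs, i.e. 32 data-parallel ranks, whereas the previous setup was 4 x 4 = 16. This pairing is an inference from the two diffs, not stated in the commit itself.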
26 changes: 21 additions & 5 deletions src/nanotron/data/petagraph_dataset.py
@@ -13,10 +13,12 @@
import torch
import random
from tqdm import tqdm
-import zstd
import numpy as np
from typing import Dict, Optional, Tuple

+# import zstd
+import zstandard
+
from pathlib import Path
from Bio import SeqIO
from io import StringIO
@@ -84,13 +86,13 @@ def __init__(self,
# are required. For detail, please check our tutorial in:
# https://pytorch.org/data/main/tutorial.html#working-with-dataloader
dp_s3_urls = IterableWrapper(url_list) # .list_files_by_s3()
-sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter()
+sharded_s3_urls = dp_s3_urls.shuffle().sharding_filter().cycle()

# opened_files = S3FileLoader(sharded_s3_urls)
opened_files = FSSpecFileOpener(sharded_s3_urls, mode="rb")

else:
-files_names = IterableWrapper(url_list).shuffle().sharding_filter()
+files_names = IterableWrapper(url_list).shuffle().sharding_filter().cycle()
opened_files = FileOpener(files_names, mode="rb")

decoded_files = StreamReader(opened_files)
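
The new .cycle() calls make both the S3 and local-file datapipes restart automatically once every shard has been yielded, so a worker never permanently exhausts its stream. A minimal sketch of that behaviour with torchdata's functional datapipe API (the toy url_list is hypothetical):

# Minimal sketch: .cycle() turns a finite, sharded datapipe into an endless one.
from itertools import islice
from torchdata.datapipes.iter import IterableWrapper

url_list = ["s3://bucket/a.fa.zst", "s3://bucket/b.fa.zst", "s3://bucket/c.fa.zst"]  # hypothetical

finite_pipe = IterableWrapper(url_list).shuffle().sharding_filter()
endless_pipe = IterableWrapper(url_list).shuffle().sharding_filter().cycle()

print(len(list(finite_pipe)))               # 3: one pass, then the pipe is exhausted
print(list(islice(iter(endless_pipe), 7)))  # keeps yielding; the pipe restarts itself

With the pipelines cycling, the StopIteration branch in generate() further down is no longer expected during normal training, which is presumably why its warning drops the "restarting from the beginning" wording.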
@@ -151,7 +153,15 @@ def decompression_func(self, input_data):
# if self.debug:
# self.logging_func(f"[{self.__class__.__name__}] Decompressing {path}")

-decompressed_data = zstd.decompress(data)
+# decompressed_data = zstd.decompress(data)
+
+try:
+    dctx = zstandard.ZstdDecompressor()
+    decompressed_data = dctx.decompress(data)
+except Exception as e:
+    self.logger.warning(f"[PetaGraphStreamDataset] Error decompressing {path}: {e}")
+    return path, None
+
# if self.debug:
# num_mb_compressed = len(data) / 1024 / 1024
# num_mb_decompressed = len(decompressed_data) / 1024 / 1024
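
For context on the package swap above: the old zstd module exposes a module-level decompress(), while the zstandard package works through a ZstdDecompressor object. A small self-contained sketch (round-tripping illustrative bytes; the dataset only performs the decompression half):

# Minimal sketch of decompression with the zstandard package (replacing zstd.decompress).
import zstandard

payload = b">seq1\nACGTACGT\n" * 1000                 # illustrative FASTA-like bytes
compressed = zstandard.ZstdCompressor().compress(payload)

dctx = zstandard.ZstdDecompressor()
try:
    data = dctx.decompress(compressed)                # works when the frame embeds its content size
except zstandard.ZstdError:
    # frames written without a content size need an explicit cap or a streaming reader
    data = dctx.decompress(compressed, max_output_size=1 << 30)

assert data == payload

Wrapping the call in try/except, as the hunk above does, also covers truncated or corrupt .zst files pulled from object storage.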
@@ -162,6 +172,9 @@ def decompression_func(self, input_data)
def fasta_parsing_func(self, input_data):
path, data = input_data

+if data is None:
+    return [[]]
+
# if self.debug:
# self.logging_func(f"[{self.__class__.__name__}] Parsing {path}")

@@ -224,6 +237,9 @@ def generate(self):

try:
source_path, text_raw = next(self.iterable_dataset)
+if text_raw is None or len(text_raw) == 0:
+    continue
+
if self.log_directory is not None:
if source_path not in self.consumed_files:
out_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"
@@ -232,7 +248,7 @@
self.consumed_files.add(source_path)

except StopIteration:
-self.logger.warning(f"Reached end of dataset, restarting from the beginning")
+self.logger.warning(f"Reached end of dataset")

if not self.packed:

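
Taken together, the hunks in this file thread a skip-this-file signal through the pipeline: a failed decompression returns (path, None), the FASTA parser maps None to an empty record list, and generate() silently drops empty items. A stripped-down, runnable sketch of that pattern (placeholder names, not the dataset's actual structure):

# Stripped-down sketch of the skip-on-error pattern used above (placeholder names).
from typing import Iterator, List, Optional, Tuple
import zstandard

def decompress(item: Tuple[str, bytes]) -> Tuple[str, Optional[bytes]]:
    path, blob = item
    try:
        return path, zstandard.ZstdDecompressor().decompress(blob)
    except Exception:
        return path, None                              # signal "skip this file" downstream

def parse_records(item: Tuple[str, Optional[bytes]]) -> List[List[str]]:
    path, data = item
    if data is None:
        return [[]]                                    # keep the return type, carry no sequences
    return [[line for line in data.decode().splitlines() if not line.startswith(">")]]

def generate(parsed: Iterator[List[List[str]]]) -> Iterator[List[str]]:
    for records in parsed:
        for seqs in records:
            if seqs is None or len(seqs) == 0:
                continue                               # drop placeholders produced by failures
            yield seqs

good = zstandard.ZstdCompressor().compress(b">seq1\nACGT\n")
items = [("ok.fa.zst", good), ("corrupt.fa.zst", b"not a zstd frame")]
print(list(generate(parse_records(decompress(i)) for i in items)))  # only the good file's sequences remain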
3 changes: 2 additions & 1 deletion src/nanotron/scaling/parametrization.py
@@ -49,7 +49,7 @@ def __init__(self, config: ModelArgs):
self.std = config.init_method.std
self.num_layers = config.model_config.num_hidden_layers

-if hasattr(config.init_method, "truncated_normal_bound"):
+if hasattr(config.init_method, "truncated_normal_bound") and config.init_method.truncated_normal_bound is not None:
self.truncated_normal = True
self.trunc_bound = config.init_method.truncated_normal_bound
log_rank(
@@ -60,6 +60,7 @@ )
)
else:
self.truncated_normal = False
+self.trunc_bound = None

def _parametrize_column_linear(self, param_name: str, module: nn.Module):
assert param_name in ["weight", "bias"]
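
The tightened condition matters when the attribute exists on the config but is set to None (an optional field left unset), in which case hasattr alone is True while the bound is unusable. A minimal illustration with a stand-in dataclass (not nanotron's actual config classes):

# Why hasattr alone is not enough: the attribute can exist and still be None.
from dataclasses import dataclass
from typing import Optional

@dataclass
class InitMethodStub:                                   # stand-in for the real init-method config
    std: float = 0.02
    truncated_normal_bound: Optional[float] = None      # present, but unset

cfg = InitMethodStub()

print(hasattr(cfg, "truncated_normal_bound"))           # True: would wrongly enable truncation
print(hasattr(cfg, "truncated_normal_bound") and cfg.truncated_normal_bound is not None)  # False

truncated_normal = hasattr(cfg, "truncated_normal_bound") and cfg.truncated_normal_bound is not None
trunc_bound = cfg.truncated_normal_bound if truncated_normal else None   # mirrors the hunk above

Assigning self.trunc_bound = None in the else branch also keeps the attribute defined on both paths, so later code can read it unconditionally.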
5 changes: 5 additions & 0 deletions src/nanotron/trainer.py
@@ -599,11 +599,16 @@ def train_step_logs(
current_dataset = dataloaders[list(dataloaders.keys())[0]].dataset
else:
current_dataset = dataloaders.dataset

+if hasattr(current_dataset, "consumed_seq_len_queue"):
+    consumed_seq_lens = np.array(list(current_dataset.consumed_seq_len_queue), dtype=np.int64)
+    log_entries.append(LogItem("consumed_seq_lens_median", np.median(consumed_seq_lens), "human_format"))
+    log_entries.append(LogItem("consumed_seq_lens_max", np.max(consumed_seq_lens), "human_format"))

if hasattr(current_dataset, "consumed_files"):
num_consumed_files = len(current_dataset.consumed_files)
log_entries.append(LogItem("num_consumed_files", num_consumed_files, "human_format"))

if self.config.optimizer.clip_grad is not None:
log_entries.append(LogItem("grad_norm", self.grad_norm_unclipped.item(), "human_format")) # , ".3f"))

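
The new train-step log entries read two attributes off the active dataset: consumed_seq_len_queue (assumed here to be a bounded collections.deque of recent sequence lengths) and the existing consumed_files set. A hedged sketch of the dataset-side bookkeeping and of how the statistics are computed, outside of nanotron's LogItem machinery:

# Sketch of the dataset-side bookkeeping the new log entries read from.
# consumed_seq_len_queue is assumed to be a bounded deque of recent sequence lengths;
# the real PetaGraphStreamDataset may maintain it differently.
from collections import deque
import numpy as np

class DatasetStats:
    def __init__(self, window: int = 1024):
        self.consumed_seq_len_queue = deque(maxlen=window)   # recent sequence lengths
        self.consumed_files = set()                           # unique source files seen so far

    def record(self, source_path: str, seq_len: int) -> None:
        self.consumed_seq_len_queue.append(seq_len)
        self.consumed_files.add(source_path)

stats = DatasetStats()
for path, length in [("a.fa.zst", 512), ("b.fa.zst", 2048), ("a.fa.zst", 731)]:
    stats.record(path, length)

seq_lens = np.array(list(stats.consumed_seq_len_queue), dtype=np.int64)
print("consumed_seq_lens_median:", np.median(seq_lens))   # 731.0
print("consumed_seq_lens_max:", np.max(seq_lens))          # 2048
print("num_consumed_files:", len(stats.consumed_files))    # 2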
