Skip to content

Commit

Permalink
workaround for md5 variations (axolotl-ai-cloud#533)
Browse files Browse the repository at this point in the history
* workaround for md5 variations

* refactor the prepared hash too
  • Loading branch information
winglian authored Sep 8, 2023
1 parent 4b6cd09 commit b4dca17
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 13 deletions.
28 changes: 15 additions & 13 deletions src/axolotl/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import functools
import hashlib
import logging
from hashlib import md5
from pathlib import Path
from typing import Tuple, Union

Expand Down Expand Up @@ -52,6 +51,13 @@
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"


def md5(to_hash: str, encoding: str = "utf-8") -> str:
try:
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
except TypeError:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec


def prepare_dataset(cfg, tokenizer):
if not cfg.pretraining_dataset:
with zero_first(is_main_process()):
Expand Down Expand Up @@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5( # nosec
md5(
(
str(cfg.sequence_len)
+ "@"
Expand All @@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
)
+ "|"
+ tokenizer_name
).encode("utf-8")
).hexdigest()
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
Expand Down Expand Up @@ -374,7 +380,7 @@ def load_prepare_datasets(
# see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
ds_hash = str(
md5( # nosec
md5(
(
str(cfg.sequence_len)
+ "@"
Expand All @@ -385,8 +391,8 @@ def load_prepare_datasets(
)
+ "|"
+ tokenizer_name
).encode("utf-8")
).hexdigest()
)
)
)
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
Expand Down Expand Up @@ -500,12 +506,8 @@ def load_prepare_datasets(
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = hashlib.md5(
to_hash_train.encode(), usedforsecurity=False
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
train_fingerprint = md5(to_hash_train)
test_fingerprint = md5(to_hash_test)

with zero_first(is_main_process()):
dataset = dataset.train_test_split(
Expand Down
64 changes: 64 additions & 0 deletions tests/test_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
test module for the axolotl.utis.data module
"""
import unittest

from transformers import LlamaTokenizer

from axolotl.utils.data import encode_pretraining, md5


class TestEncodePretraining(unittest.TestCase):
"""
test class for encode pretraining and md5 helper
"""

def setUp(self):
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"eos_token": "</s>",
"bos_token": "<s>",
"unk_token": "<unk>",
"pad_token": "<pad>",
}
)
self.max_tokens = 15 # set a small number for easy inspection

def test_encode_pretraining(self):
examples = {
"text": [
"Hello, world!",
"Nice to meet you.",
"lorem ipsum dolor sit amet.",
"Nice to meet you again!.",
"hello, hello",
]
}
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)

self.assertEqual(len(result["input_ids"]), 3)

# Assert the length of input_ids and attention_mask is correct
self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)

# Assert EOS and PAD tokens are correctly added
# hello world! is 4 tokens
self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
# second part, 5 tokens
self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)

def test_md5(self):
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
self.assertEqual(
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
)


if __name__ == "__main__":
unittest.main()

0 comments on commit b4dca17

Please sign in to comment.