diff --git a/src/axolotl/utils/data/__init__.py b/src/axolotl/utils/data/__init__.py
new file mode 100644
index 0000000000..1015c8370a
--- /dev/null
+++ b/src/axolotl/utils/data/__init__.py
@@ -0,0 +1,15 @@
+"""
+Data processing modules
+"""
+from axolotl.utils.data.dpo import load_prepare_dpo_datasets  # noqa: F401
+from axolotl.utils.data.pretraining import (  # noqa: F401
+    encode_pretraining,
+    wrap_pretraining_dataset,
+)
+from axolotl.utils.data.sft import (  # noqa: F401
+    get_dataset_wrapper,
+    load_prepare_datasets,
+    load_tokenized_prepared_datasets,
+    prepare_dataset,
+)
+from axolotl.utils.data.utils import md5  # noqa: F401
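The package `__init__` re-exports every public helper the old flat `data.py` exposed, so existing `from axolotl.utils.data import ...` call sites keep working unchanged; the `noqa: F401` markers silence the "imported but unused" lint that pure re-exports otherwise trigger. A minimal sketch of the effect, assuming this branch is installed as `axolotl`:

```python
# Both spellings resolve to the same function object: the flat import goes
# through the package __init__ re-export, the nested one hits the submodule.
from axolotl.utils.data import md5
from axolotl.utils.data.utils import md5 as md5_direct

assert md5 is md5_direct
```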
diff --git a/src/axolotl/utils/data/dpo.py b/src/axolotl/utils/data/dpo.py
new file mode 100644
index 0000000000..765a3fc374
--- /dev/null
+++ b/src/axolotl/utils/data/dpo.py
@@ -0,0 +1,114 @@
+"""data handling specific to DPO"""
+
+import logging
+from pathlib import Path
+from typing import Any, List
+
+import yaml
+from datasets import concatenate_datasets, load_dataset, load_from_disk
+
+from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+from axolotl.prompt_strategies.dpo import load as load_dpo
+from axolotl.utils.data.utils import md5
+from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import is_main_process, zero_first
+
+LOG = logging.getLogger("axolotl")
+
+
+def _get_path(ds_hash, cfg):
+    prepared_ds_path = (
+        Path(cfg.dataset_prepared_path) / ds_hash
+        if cfg.dataset_prepared_path
+        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
+    )
+
+    return prepared_ds_path
+
+
+def _load_preprocessed_ds(cfg, sub_cfg):
+    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+    prepared_ds_path = _get_path(ds_hash, cfg)
+    dataset = None
+
+    # pylint: disable=duplicate-code
+    if (
+        cfg.dataset_prepared_path
+        and any(prepared_ds_path.glob("*"))
+        and not cfg.is_preprocess
+    ):
+        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
+        dataset = load_from_disk(str(prepared_ds_path))
+
+    return dataset
+
+
+def _save_preprocessed_ds(cfg, sub_cfg, dataset):
+    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
+    prepared_ds_path = _get_path(ds_hash, cfg)
+
+    if cfg.is_preprocess and is_main_process():
+        LOG.info(f"Saving preprocessed dataset to disk at {prepared_ds_path}...")
+        dataset.save_to_disk(str(prepared_ds_path))
+
+
+def load_prepare_dpo_datasets(cfg):
+    def load_split(dataset_cfgs, _cfg):
+        split_datasets: List[Any] = []
+        for i, ds_cfg in enumerate(dataset_cfgs):
+            if ds_cfg["ds_type"] == "json":
+                for data_file in ds_cfg["data_files"]:
+                    data_files = {ds_cfg["split"]: data_file}
+                    ds = load_dataset(  # pylint: disable=invalid-name
+                        "json",
+                        data_files=data_files,
+                        split=ds_cfg["split"],
+                    )
+                    split_datasets.insert(i, ds)
+            else:
+                ds = load_dataset(  # pylint: disable=invalid-name
+                    ds_cfg["path"],
+                    split=ds_cfg["split"],
+                )
+                split_datasets.insert(i, ds)
+
+        for i, data_set in enumerate(split_datasets):
+            _type = dataset_cfgs[i]["type"]
+            if _type:
+                if isinstance(_type, DictDefault):
+                    _type = "user_defined.default"
+                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
+                split_datasets[i] = data_set.map(
+                    ds_transform_fn,
+                    desc="Mapping RL Dataset",
+                )
+            else:
+                # If no `type` is provided, assume the dataset is already in the expected format with
+                # "prompt", "chosen" and "rejected" already preprocessed
+                split_datasets[i] = data_set
+
+        return concatenate_datasets(split_datasets)
+
+    with zero_first(is_main_process()):
+        train_is_preprocessed = False
+        eval_is_preprocessed = False
+        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
+            train_is_preprocessed = True
+        else:
+            train_dataset = load_split(cfg.datasets, cfg)
+
+        eval_dataset = None
+        if cfg.test_datasets:
+            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
+                eval_is_preprocessed = True
+            else:
+                eval_dataset = load_split(cfg.test_datasets, cfg)
+        if not eval_dataset:
+            eval_dataset = None
+
+        if not train_is_preprocessed:
+            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
+        if eval_dataset and not eval_is_preprocessed:
+            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
+
+    return train_dataset, eval_dataset
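`_get_path`, `_load_preprocessed_ds`, and `_save_preprocessed_ds` key the on-disk cache on the md5 of the YAML-dumped datasets sub-config, so any edit to the dataset list yields a new hash and a fresh prepared-dataset directory. A minimal sketch of that keying scheme; the config values are hypothetical, and `last_run_prepared` stands in for `DEFAULT_DATASET_PREPARED_PATH`:

```python
import hashlib
from pathlib import Path

import yaml

# Hypothetical DPO datasets sub-config; changing any field below changes the
# hash, which invalidates the cached copy on disk.
datasets_cfg = [{"path": "org/dpo-pairs", "split": "train", "type": "chatml.intel"}]

ds_hash = hashlib.md5(
    yaml.dump(datasets_cfg, Dumper=yaml.Dumper).encode("utf-8")
).hexdigest()
print(Path("last_run_prepared") / ds_hash)
```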
diff --git a/src/axolotl/utils/data/pretraining.py b/src/axolotl/utils/data/pretraining.py
new file mode 100644
index 0000000000..544ed13162
--- /dev/null
+++ b/src/axolotl/utils/data/pretraining.py
@@ -0,0 +1,232 @@
+"""data handling specific to pretraining"""
+
+import functools
+import logging
+from collections import defaultdict
+from typing import Callable, Dict, List, Optional
+
+import torch
+from datasets import Dataset
+from torch.utils.data import RandomSampler
+from transformers import PreTrainedTokenizerBase
+
+from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
+from axolotl.utils.trainer import process_pretraining_datasets_for_packing
+
+LOG = logging.getLogger("axolotl")
+
+
+def encode_pretraining(
+    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
+) -> Dict[str, List]:
+    res = tokenizer(
+        examples,
+        truncation=True,
+        max_length=max_tokens - 2,
+        add_special_tokens=True,
+    )
+    # Convert to PyTorch tensors
+    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    new_input_ids = []
+    new_attention_mask = []
+    # Append EOS and PAD tokens to input_ids, and correct attention_mask
+    for i, _ in enumerate(input_ids):
+        input_ids[i] = torch.cat(
+            (
+                input_ids[i],
+                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+            ),
+            dim=0,
+        )
+        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
+
+    # Concatenate tokens so that their lengths are less than max_tokens
+    buffer_input_ids = torch.tensor([], dtype=torch.long)
+    buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+    for ids, mask in zip(input_ids, attention_mask):
+        if buffer_input_ids.numel() == max_tokens:
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        else:
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+
+    if buffer_input_ids.numel() > 0:  # for any leftover tokens
+        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+        new_input_ids.append(buffer_input_ids)
+        new_attention_mask.append(buffer_attention_mask)
+
+    ret = {
+        "input_ids": [seq.tolist() for seq in new_input_ids],
+        "labels": [seq.tolist() for seq in new_input_ids],
+        "attention_mask": [seq.tolist() for seq in new_attention_mask],
+    }
+
+    LOG.debug(len(ret["input_ids"]))
+    return ret
+
+
+def wrap_pretraining_dataset(
+    dataset,
+    tokenizer,
+    cfg,
+    ds_wrapper_fn,
+    max_tokens=2048,
+    batch_size=1,
+    seed=42,
+    buffer_size=10_000,
+):
+    if cfg.sample_packing:
+        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
+            tokenizer,
+            return_tensors="pt",
+            padding=True,
+            pad_to_multiple_of=max_tokens * batch_size,
+            multipack_attn=cfg.pretrain_multipack_attn,
+        )
+        encode = functools.partial(
+            encode_packed_pretraining,
+            collate_fn,
+            ds_wrapper_fn,
+            max_seq_length=max_tokens,
+            batch_size=batch_size,
+            multipack_attn=cfg.pretrain_multipack_attn,
+        )
+        # set this to 1 so downstream data_loader doesn't try to increase the batch again
+        cfg.micro_batch_size = 1
+    else:
+        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+
+    if cfg.shuffle_merged_datasets:
+        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
+    else:
+        LOG.debug("NOT shuffling merged pretraining datasets")
+
+    # remove all the existing columns after mapping since they end up having
+    # a different length than the encoded/tokenized column
+    # this is empty during streaming/pretraining
+    remove_columns = []
+    if dataset.features is None:
+        for first_row in dataset:
+            remove_columns = first_row.keys()
+            break
+    else:
+        remove_columns = dataset.features.keys()
+
+    dataset = dataset.map(
+        encode,
+        batched=True,
+        batch_size=buffer_size,
+        # input_columns="text",
+        remove_columns=remove_columns,
+    )
+    return dataset
+
+
+def encode_packed_pretraining(
+    collate_fn,
+    ds_wrapper: Callable,
+    examples: Dict[str, List],
+    max_seq_length: int = 2048,
+    batch_size: int = 4,
+    multipack_attn: Optional[bool] = False,
+) -> Dict[str, List]:
+    # pylint: disable=duplicate-code
+    # tokenize all the examples
+    # rows get split with stride (overlap)
+    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
+
+    train_dataset = process_pretraining_datasets_for_packing(
+        train_dataset,
+        max_seq_length,
+        skip_position_ids=not multipack_attn,
+    )
+
+    sampler = MultipackBatchSampler(
+        RandomSampler(train_dataset),
+        batch_size=1,
+        drop_last=True,
+        batch_max_len=batch_size * max_seq_length,
+        lengths=get_dataset_lengths(train_dataset),
+    )
+
+    chunked_data = defaultdict(list)
+
+    for batch in sampler:
+        for data in batch:
+            features = train_dataset[data]
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "num_truncated_tokens" in features:
+                del features["num_truncated_tokens"]
+            if "overflow_to_sample_mapping" in features:
+                del features["overflow_to_sample_mapping"]
+            if "labels" not in features:
+                features["labels"] = features["input_ids"].copy()
+            collated_features = collate_fn(features)
+
+            for feature in features.keys():
+                if feature == "length":
+                    continue
+                chunked_data[feature].append(collated_features[feature].squeeze(0))
+
+    return chunked_data
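For the non-packed path, `encode_pretraining` appends EOS+PAD to every example, greedily concatenates examples into `max_tokens`-sized buffers, and pads the remainder, so every returned row has the same length. A toy check of that contract (not part of the diff; the gpt2 tokenizer is just an assumption to make the sketch runnable, and it downloads on first use):

```python
from transformers import AutoTokenizer

from axolotl.utils.data.pretraining import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

batch = encode_pretraining(tokenizer, 64, ["hello world"] * 8)

# Every packed row comes back exactly max_tokens long, with labels mirroring
# input_ids and zeros in the attention mask over the padded tail.
assert all(len(seq) == 64 for seq in batch["input_ids"])
assert batch["labels"] == batch["input_ids"]
```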
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data/sft.py
similarity index 67%
rename from src/axolotl/utils/data.py
rename to src/axolotl/utils/data/sft.py
index 5f13a2a63f..9423b9a6f3 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data/sft.py
@@ -1,14 +1,10 @@
-"""Module containing data utilities"""
+"""data handling specific to SFT"""
 
 import functools
-import hashlib
 import logging
-from collections import defaultdict
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
-import torch
-import yaml
 from datasets import (
     Dataset,
     DatasetDict,
@@ -18,13 +14,11 @@
 )
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import HFValidationError
-from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase
 
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
 from axolotl.datasets import TokenizedPromptDataset
 from axolotl.prompt_strategies import load
-from axolotl.prompt_strategies.dpo import load as load_dpo
 from axolotl.prompt_tokenizers import (
     AlpacaMultipleChoicePromptTokenizingStrategy,
     AlpacaPromptTokenizingStrategy,
@@ -45,26 +39,18 @@
     SummarizeTLDRPrompter,
     UnsupportedPrompter,
 )
-from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
+from axolotl.utils.data.pretraining import wrap_pretraining_dataset
+from axolotl.utils.data.utils import md5
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.distributed import is_main_process, zero_first
-from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 from axolotl.utils.trainer import (
     calculate_total_num_steps,
     process_datasets_for_packing,
-    process_pretraining_datasets_for_packing,
 )
 
 LOG = logging.getLogger("axolotl")
 
 
-def md5(to_hash: str, encoding: str = "utf-8") -> str:
-    try:
-        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
-    except TypeError:
-        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
-
-
 def prepare_dataset(cfg, tokenizer):
     prompters = []
     if not cfg.pretraining_dataset:
@@ -182,6 +168,7 @@ def load_tokenized_prepared_datasets(
     except Exception:  # pylint: disable=broad-except  # nosec
         pass
 
+    # pylint: disable=duplicate-code
     if dataset:
         ...
     elif (
@@ -691,315 +678,3 @@ def get_dataset_wrapper(
     )
 
     return dataset_wrapper, dataset_prompter
-
-
-def encode_pretraining(
-    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
-) -> Dict[str, List]:
-    res = tokenizer(
-        examples,
-        truncation=True,
-        max_length=max_tokens - 2,
-        add_special_tokens=True,
-    )
-    # Convert to PyTorch tensors
-    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
-    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
-    new_input_ids = []
-    new_attention_mask = []
-    # Append EOS and PAD tokens to input_ids, and correct attention_mask
-    for i, _ in enumerate(input_ids):
-        input_ids[i] = torch.cat(
-            (
-                input_ids[i],
-                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
-            ),
-            dim=0,
-        )
-        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
-
-    # Concatenate tokens so that their lengths are less than max_tokens
-    buffer_input_ids = torch.tensor([], dtype=torch.long)
-    buffer_attention_mask = torch.tensor([], dtype=torch.long)
-
-    for ids, mask in zip(input_ids, attention_mask):
-        if buffer_input_ids.numel() == max_tokens:
-            new_input_ids.append(buffer_input_ids)
-            new_attention_mask.append(buffer_attention_mask)
-            buffer_input_ids = torch.tensor([], dtype=torch.long)
-            buffer_attention_mask = torch.tensor([], dtype=torch.long)
-            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
-        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
-            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
-        else:
-            buffer_input_ids = torch.cat(
-                (
-                    buffer_input_ids,
-                    torch.full(
-                        (max_tokens - buffer_input_ids.numel(),),
-                        tokenizer.pad_token_id,
-                        dtype=torch.long,
-                    ),
-                ),
-                dim=0,
-            )
-            buffer_attention_mask = torch.cat(
-                (
-                    buffer_attention_mask,
-                    torch.full(
-                        (max_tokens - buffer_attention_mask.numel(),),
-                        0,
-                        dtype=torch.long,
-                    ),
-                ),
-                dim=0,
-            )
-            new_input_ids.append(buffer_input_ids)
-            new_attention_mask.append(buffer_attention_mask)
-            buffer_input_ids = torch.tensor([], dtype=torch.long)
-            buffer_attention_mask = torch.tensor([], dtype=torch.long)
-
-            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
-            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
-
-    if buffer_input_ids.numel() > 0:  # for any leftover tokens
-        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
-            buffer_input_ids = torch.cat(
-                (
-                    buffer_input_ids,
-                    torch.full(
-                        (max_tokens - buffer_input_ids.numel(),),
-                        tokenizer.pad_token_id,
-                        dtype=torch.long,
-                    ),
-                ),
-                dim=0,
-            )
-            buffer_attention_mask = torch.cat(
-                (
-                    buffer_attention_mask,
-                    torch.full(
-                        (max_tokens - buffer_attention_mask.numel(),),
-                        0,
-                        dtype=torch.long,
-                    ),
-                ),
-                dim=0,
-            )
-        new_input_ids.append(buffer_input_ids)
-        new_attention_mask.append(buffer_attention_mask)
-
-    ret = {
-        "input_ids": [seq.tolist() for seq in new_input_ids],
-        "labels": [seq.tolist() for seq in new_input_ids],
-        "attention_mask": [seq.tolist() for seq in new_attention_mask],
-    }
-
-    LOG.debug(len(ret["input_ids"]))
-    return ret
-
-
-def wrap_pretraining_dataset(
-    dataset,
-    tokenizer,
-    cfg,
-    ds_wrapper_fn,
-    max_tokens=2048,
-    batch_size=1,
-    seed=42,
-    buffer_size=10_000,
-):
-    if cfg.sample_packing:
-        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
-            tokenizer,
-            return_tensors="pt",
-            padding=True,
-            pad_to_multiple_of=max_tokens * batch_size,
-            multipack_attn=cfg.pretrain_multipack_attn,
-        )
-        encode = functools.partial(
-            encode_packed_pretraining,
-            collate_fn,
-            ds_wrapper_fn,
-            max_seq_length=max_tokens,
-            batch_size=batch_size,
-            multipack_attn=cfg.pretrain_multipack_attn,
-        )
-        # set this to 1 so downstream data_loader doesn't try to increase the batch again
-        cfg.micro_batch_size = 1
-    else:
-        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
-
-    if cfg.shuffle_merged_datasets:
-        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
-    else:
-        LOG.debug("NOT shuffling merged pretraining datasets")
-
-    # remove all the existing columns after mapping since they end up having
-    # a different length than the encoded/tokenized column
-    # this is empty during streaming/pretraining
-    remove_columns = []
-    if dataset.features is None:
-        for first_row in dataset:
-            remove_columns = first_row.keys()
-            break
-    else:
-        remove_columns = dataset.features.keys()
-
-    dataset = dataset.map(
-        encode,
-        batched=True,
-        batch_size=buffer_size,
-        # input_columns="text",
-        remove_columns=remove_columns,
-    )
-    return dataset
-
-
-def encode_packed_pretraining(
-    collate_fn,
-    ds_wrapper: Callable,
-    examples: Dict[str, List],
-    max_seq_length: int = 2048,
-    batch_size: int = 4,
-    multipack_attn: Optional[bool] = False,
-) -> Dict[str, List]:
-    # pylint: disable=duplicate-code
-    # tokenize all the examples
-    # rows get split with stride (overlap)
-    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
-
-    train_dataset = process_pretraining_datasets_for_packing(
-        train_dataset,
-        max_seq_length,
-        skip_position_ids=not multipack_attn,
-    )
-
-    sampler = MultipackBatchSampler(
-        RandomSampler(train_dataset),
-        batch_size=1,
-        drop_last=True,
-        batch_max_len=batch_size * max_seq_length,
-        lengths=get_dataset_lengths(train_dataset),
-    )
-
-    chunked_data = defaultdict(list)
-
-    for batch in sampler:
-        for data in batch:
-            features = train_dataset[data]
-            if "num_truncated_tokens" in features:
-                del features["num_truncated_tokens"]
-            if "num_truncated_tokens" in features:
-                del features["num_truncated_tokens"]
-            if "overflow_to_sample_mapping" in features:
-                del features["overflow_to_sample_mapping"]
-            if "labels" not in features:
-                features["labels"] = features["input_ids"].copy()
-            collated_features = collate_fn(features)
-
-            for feature in features.keys():
-                if feature == "length":
-                    continue
-                chunked_data[feature].append(collated_features[feature].squeeze(0))
-
-    return chunked_data
-
-
-def _get_path(ds_hash, cfg):
-    prepared_ds_path = (
-        Path(cfg.dataset_prepared_path) / ds_hash
-        if cfg.dataset_prepared_path
-        else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
-    )
-
-    return prepared_ds_path
-
-
-def _load_preprocessed_ds(cfg, sub_cfg):
-    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
-    prepared_ds_path = _get_path(ds_hash, cfg)
-    dataset = None
-
-    if (
-        cfg.dataset_prepared_path
-        and any(prepared_ds_path.glob("*"))
-        and not cfg.is_preprocess
-    ):
-        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
-        dataset = load_from_disk(str(prepared_ds_path))
-
-    return dataset
-
-
-def _save_preprocessed_ds(cfg, sub_cfg, dataset):
-    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
-    prepared_ds_path = _get_path(ds_hash, cfg)
-
-    if cfg.is_preprocess and is_main_process():
-        LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
-        dataset.save_to_disk(str(prepared_ds_path))
-
-
-def load_prepare_dpo_datasets(cfg):
-    def load_split(dataset_cfgs, _cfg):
-        split_datasets: List[Any] = []
-        for i, ds_cfg in enumerate(dataset_cfgs):
-            if ds_cfg["ds_type"] == "json":
-                for data_file in ds_cfg["data_files"]:
-                    data_files = {ds_cfg["split"]: data_file}
-                    ds = load_dataset(  # pylint: disable=invalid-name
-                        "json",
-                        data_files=data_files,
-                        split=ds_cfg["split"],
-                    )
-                    split_datasets.insert(i, ds)
-            else:
-                ds = load_dataset(  # pylint: disable=invalid-name
-                    ds_cfg["path"],
-                    split=ds_cfg["split"],
-                )
-                split_datasets.insert(i, ds)
-
-        for i, data_set in enumerate(split_datasets):
-            _type = dataset_cfgs[i]["type"]
-            if _type:
-                if isinstance(_type, DictDefault):
-                    _type = "user_defined.default"
-                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
-                split_datasets[i] = data_set.map(
-                    ds_transform_fn,
-                    desc="Mapping RL Dataset",
-                )
-            else:
-                # If no `type` is provided, assume the dataset is already in the expected format with
-                # "prompt", "chosen" and "rejected" already preprocessed
-                split_datasets[i] = data_set
-
-        return concatenate_datasets(split_datasets)
-
-    with zero_first(is_main_process()):
-        train_is_preprocessed = False
-        eval_is_preprocessed = False
-        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
-            train_is_preprocessed = True
-        else:
-            train_dataset = load_split(cfg.datasets, cfg)
-
-        eval_dataset = None
-        if cfg.test_datasets:
-            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
-                eval_is_preprocessed = True
-            else:
-                eval_dataset = load_split(cfg.test_datasets, cfg)
-        if not eval_dataset:
-            eval_dataset = None
-
-        if not train_is_preprocessed:
-            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
-        if eval_dataset and not eval_is_preprocessed:
-            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
-
-    return train_dataset, eval_dataset
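The rename leaves `sft.py` as the package's largest module, so a quick smoke check that its entry points survived the move is cheap. This sketch (again assuming the branch is installed) verifies the package root and the submodule expose the same objects:

```python
import axolotl.utils.data as data_pkg
import axolotl.utils.data.sft as sft

# The __init__ re-exports should be the very objects defined in sft.py.
for name in (
    "prepare_dataset",
    "load_prepare_datasets",
    "load_tokenized_prepared_datasets",
    "get_dataset_wrapper",
):
    assert getattr(data_pkg, name) is getattr(sft, name)
```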
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
new file mode 100644
index 0000000000..e05701e7b0
--- /dev/null
+++ b/src/axolotl/utils/data/utils.py
@@ -0,0 +1,10 @@
+"""data handling helpers"""
+
+import hashlib
+
+
+def md5(to_hash: str, encoding: str = "utf-8") -> str:
+    try:
+        return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
+    except TypeError:
+        return hashlib.md5(to_hash.encode(encoding)).hexdigest()  # nosec
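Finally, a contract check for the relocated `md5` helper (a sketch; assumes the branch is installed). Passing `usedforsecurity=False` keeps the call legal on FIPS-enabled Python 3.9+ builds, while the `TypeError` fallback covers interpreters whose `hashlib.md5` does not accept that keyword:

```python
from axolotl.utils.data.utils import md5

# Well-known digest of "hello"; the helper hashes the UTF-8 encoding by default.
assert md5("hello") == "5d41402abc4b2a76b9719d911017c592"
```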