diff --git a/src/axolotl/core/tokenizer_utils.py b/src/axolotl/core/tokenizer_utils.py
deleted file mode 100644
index 1b86b9c49..000000000
--- a/src/axolotl/core/tokenizer_utils.py
+++ /dev/null
@@ -1,272 +0,0 @@
-"""
-helper functions for fixing the embeddings/tokenizer
-"""
-
-# Copyright 2023-present Daniel Han-Chen & the Unsloth team. All rights reserved.
-# GNU LESSER GENERAL PUBLIC LICENSE
-# Version 3, 29 June 2007
-#
-# Copyright (C) 2007 Free Software Foundation, Inc.
-# Everyone is permitted to copy and distribute verbatim copies
-# of this license document, but changing it is not allowed.
-
-import gc
-import itertools
-import logging
-from collections import Counter
-
-import datasets
-import numpy as np
-import torch
-
-LOG = logging.getLogger("axolotl.core.tokenizer_utils")
-
-
-@torch.inference_mode()
-def fix_untrained_tokens( # pylint: disable=too-many-return-statements
- model, tokenizer, train_dataset, ignored_tokenizer_names=None, eps=1e-16
-):
- """
-    Llama-3, for example, has untrained vectors in the base model.
-    These include <|eot_id|>, <|start_header_id|>, and <|end_header_id|>.
-    We reset them to the mean of the rest of the tokens.
- """
- # Code licensed under LGPL
- embedding_matrix = model.get_input_embeddings().weight
- lm_head_matrix = model.get_output_embeddings().weight
- chat_template = getattr(tokenizer, "chat_template", None)
- tokenizer = tokenizer.tokenizer if hasattr(tokenizer, "tokenizer") else tokenizer
-
- # Ignore some model checks for now
- if not ignored_tokenizer_names:
- ignored_tokenizer_names = []
- if (
- model.config._name_or_path # pylint: disable=protected-access
- in ignored_tokenizer_names
- ):
- return
-
-    # Sometimes the hidden sizes can differ, e.g. in vision models,
-    # where the input embeddings and the output head have different widths
- min_size = min(embedding_matrix.shape[1], lm_head_matrix.shape[1])
- embedding_matrix = embedding_matrix[:, :min_size]
- lm_head_matrix = lm_head_matrix[:, :min_size]
-
-    # Get untrained tokens: embedding rows whose maximum entry is at most eps
-    indicator_untrained1 = torch.amax(embedding_matrix, axis=1) <= eps
-    # Check lm_head as well
-    # (this check alone does NOT work for Llama 3.1!)
-    indicator_untrained2 = torch.amax(lm_head_matrix, axis=1) <= eps
-
- # We instead check for repeated vectors
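-    # Rows of lm_head (rounded to 3 decimals) that appear more than once among
-    # the candidates flagged above are also treated as untrained.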
- lm_head_where = torch.where(indicator_untrained1)[0]
- lm_head_bad = lm_head_matrix[lm_head_where]
- lm_head_bad = lm_head_bad.cpu().float().numpy().round(3)
- counter = Counter()
- for row in lm_head_bad:
- counter[hash(row.data.tobytes())] += 1
- counter = Counter({k: c for k, c in counter.items() if c >= 2})
-
- lm_head_where = lm_head_where.cpu().numpy()
- final_bad_lm_head = []
- for j, row in enumerate(lm_head_bad):
- if hash(row.data.tobytes()) in counter:
- final_bad_lm_head.append(lm_head_where[j])
-    # OR with zeros to materialise a fresh boolean tensor before the in-place update below
-    indicator_untrained2 = indicator_untrained2 | torch.zeros_like(indicator_untrained2)
- indicator_untrained2[final_bad_lm_head] = True
-
- # Combine both checks
- indicator_untrained = indicator_untrained1 & indicator_untrained2
-
-    # Never flag the pad token as untrained
- if hasattr(tokenizer, "pad_token_id"):
- pad_token_id = tokenizer.pad_token_id
- if pad_token_id is not None and pad_token_id < indicator_untrained.shape[0]:
- indicator_untrained[pad_token_id] = False
-
- where_untrained = torch.where(indicator_untrained)[0]
- n_untrained = where_untrained.shape[0]
- n_trained = embedding_matrix.shape[0] - n_untrained
-
-    # Bail out early if there are no untrained tokens
- where_untrained = where_untrained.tolist()
- if len(where_untrained) == 0:
- return
-
-    # Build a lookup set and the corresponding token strings
- where_untrained_set = frozenset(where_untrained)
- actual_bad_tokens = tokenizer.convert_ids_to_tokens(where_untrained)
- # Remove None items in actual_bad_tokens
- actual_bad_tokens = [x for x in actual_bad_tokens if x is not None]
-
- # Check if tokenizer and training datasets have bad tokens
- if_bad_first = False
- if_bad_second = False
- # Check tokenizer's chat template for any untrained tokens
- if chat_template is not None:
- if_bad_first = any(x in chat_template for x in actual_bad_tokens)
-
- if isinstance(train_dataset, datasets.IterableDataset):
- # Skip the check, since the code below assumes
- # an indexable dataset
- return
-
- # Check the first 250, last 250 input_ids
- size_dataset = len(train_dataset)
- size = min(size_dataset, 250)
- for j in range(size):
- input_ids = train_dataset[j]
- if "input_ids" in input_ids:
- input_ids = input_ids["input_ids"]
- if_bad = any(item in where_untrained_set for item in input_ids)
- if if_bad:
- if_bad_second = True
- break
-
- # Check last 250
- if not if_bad_second:
- left = max(size_dataset - 250, 0)
- for j in range(left, size_dataset):
- input_ids = train_dataset[j]
- if "input_ids" in input_ids:
- input_ids = input_ids["input_ids"]
- if_bad = any(item in where_untrained_set for item in input_ids)
- if if_bad:
- if_bad_second = True
- break
-
-    # Check if bad tokens exist
- if not if_bad_first and not if_bad_second:
- return
-
-    # Check whether lm_head / embed_tokens are trainable
- bad_not_trainable = False
- if not embedding_matrix.requires_grad:
- bad_not_trainable = True
- if not lm_head_matrix.requires_grad:
- bad_not_trainable = True
-
- if bad_not_trainable: # pylint: disable=too-many-nested-blocks
- final_bad_items = []
-
- # Re-check the first 250, last 250 input_ids
- size_dataset = len(train_dataset)
- size = min(size_dataset, 250)
- for j in range(size):
- input_ids = train_dataset[j]
- if "input_ids" in input_ids:
- input_ids = input_ids["input_ids"]
- for item in input_ids:
- if item in where_untrained_set:
- final_bad_items.append(item)
-
- # Re-check last 250
- left = max(size_dataset - 250, 0)
- for j in range(left, size_dataset):
- input_ids = train_dataset[j]
- if "input_ids" in input_ids:
- input_ids = input_ids["input_ids"]
- for item in input_ids:
- if item in where_untrained_set:
- final_bad_items.append(item)
-
-        # If no bad tokens were found, the chat template itself may be the issue
- if len(final_bad_items) == 0:
-            # Re-check the first 2000 and last 2000 items
- size_dataset = len(train_dataset)
- size = min(size_dataset, 2000)
- for j in range(size):
- input_ids = train_dataset[j]
- if "input_ids" in input_ids:
- input_ids = input_ids["input_ids"]
- for item in input_ids:
- if item in where_untrained_set:
- final_bad_items.append(item)
-
- # Re-check last 2000
- left = max(size_dataset - 2000, 0)
- for j in range(left, size_dataset):
- input_ids = train_dataset[j]
- if "input_ids" in input_ids:
- input_ids = input_ids["input_ids"]
- for item in input_ids:
- if item in where_untrained_set:
- final_bad_items.append(item)
-
-        # Most likely a false signal
- if len(final_bad_items) == 0:
- return
-
-        raise ValueError(
-            f"Untrained tokens {list(set(final_bad_items))} found, but embed_tokens & "
-            "lm_head are not trainable, which will cause NaNs."
-        )
-
-    # Count occurrences of every token id in the training set
- final_counts = np.zeros(
- max(len(tokenizer), embedding_matrix.shape[0]), dtype=np.int64
- )
-
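-    # `np.add.at` accumulates in place, so a single batched pass over the
-    # dataset produces a frequency count for every token id.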
- def mapping(examples):
- input_ids = examples["input_ids"]
- counter = np.fromiter(itertools.chain.from_iterable(input_ids), dtype=np.int32)
- np.add.at(final_counts, counter, 1)
-
- train_dataset.map(mapping, batched=True, desc="Counting untrained tokens")
-
- # Get counts for untrained tokens
- counts_untrained = final_counts[where_untrained]
- # Identify untrained tokens seen in train_dataset
- indices_seen_in_train = np.where(counts_untrained > 0)[0]
- tokens_to_update = [where_untrained[i] for i in indices_seen_in_train]
-
- if len(tokens_to_update) == 0:
- LOG.info(
- "No untrained tokens found in train_dataset. No embeddings were modified."
- )
- return
-
- # Log the token IDs that are being rescaled
- LOG.info(
- f"Rescaling embeddings for tokens seen in train_dataset: {tokens_to_update}"
- )
-
- # Get sum of all items
- sum_embedding = torch.sum(embedding_matrix, dtype=torch.float32, axis=0)
- sum_lm_head = torch.sum(lm_head_matrix, dtype=torch.float32, axis=0)
-
- # Remove bad tokens
- sum_embedding -= torch.sum(
- embedding_matrix[where_untrained], dtype=torch.float32, axis=0
- )
- sum_lm_head -= torch.sum(
- lm_head_matrix[where_untrained], dtype=torch.float32, axis=0
- )
-
-    # Find the correct average by dividing by the number of trained tokens
- mean_embedding = sum_embedding / n_trained
- mean_lm_head = sum_lm_head / n_trained
-
-    # Compute scaling for tokens to update: each token's count relative to the
-    # most frequent token in the dataset
- scaling = counts_untrained[indices_seen_in_train] / max(final_counts.max(), 1)
- scaling = torch.tensor(scaling, device=mean_embedding.device).unsqueeze(1)
-
- # Prepare mean embeddings for tokens to update
- mean_embedding_repeated = (
- mean_embedding.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
- )
- mean_lm_head_repeated = (
- mean_lm_head.unsqueeze(0).repeat(len(tokens_to_update), 1) * scaling
- )
-
- # Update embeddings only for tokens seen in train_dataset
- embedding_matrix[tokens_to_update] = mean_embedding_repeated.to(
- embedding_matrix.dtype
- )
- lm_head_matrix[tokens_to_update] = mean_lm_head_repeated.to(lm_head_matrix.dtype)
-
- # Clean up
- for _ in range(3):
- gc.collect()
- torch.cuda.empty_cache()
- return
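
For reference, the core technique the removed helper implemented boils down to: flag embedding rows whose maximum entry is at most a tiny epsilon (i.e. likely never trained) and overwrite them with the mean of the remaining rows, which the deleted code warns can otherwise cause NaNs when those rows are not trainable. The sketch below is a minimal, standalone illustration on a toy tensor; the function name and toy data are invented for the example, and the removed code additionally cross-checked the lm_head, the chat template, and token frequencies in the training set before modifying anything.

import torch

def reset_untrained_rows(embedding: torch.Tensor, eps: float = 1e-16) -> torch.Tensor:
    """Return a copy of `embedding` with near-zero rows reset to the mean of the others."""
    untrained = torch.amax(embedding, dim=1) <= eps      # candidate untrained rows
    if not untrained.any():
        return embedding.clone()
    trained_mean = embedding[~untrained].mean(dim=0)     # mean over trained rows only
    fixed = embedding.clone()
    fixed[untrained] = trained_mean.to(embedding.dtype)  # broadcast the mean into each flagged row
    return fixed

# Toy example: four "tokens" with 3-dim embeddings; token 2 is all zeros (untrained).
emb = torch.tensor([[0.2, -0.1, 0.3],
                    [0.5, 0.4, -0.2],
                    [0.0, 0.0, 0.0],
                    [-0.3, 0.1, 0.6]])
print(reset_untrained_rows(emb))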