diff --git a/llmfoundry/callbacks/__init__.py b/llmfoundry/callbacks/__init__.py
index d9bb3c24a7..dc3ee707ac 100644
--- a/llmfoundry/callbacks/__init__.py
+++ b/llmfoundry/callbacks/__init__.py
@@ -11,6 +11,8 @@
 from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
 from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
 from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
+from llmfoundry.callbacks.log_mbmoe_tok_per_expert_callback import \
+    MegaBlocksMoE_TokPerExpert
 from llmfoundry.callbacks.monolithic_ckpt_callback import \
     MonolithicCheckpointSaver
 from llmfoundry.callbacks.resumption_callbacks import (GlobalLRScaling,
@@ -34,6 +36,7 @@
 callbacks.register('scheduled_gc', func=ScheduledGarbageCollector)
 callbacks.register('oom_observer', func=OOMObserver)
 callbacks.register('eval_output_logging', func=EvalOutputLogging)
+callbacks.register('mbmoe_tok_per_expert', func=MegaBlocksMoE_TokPerExpert)
 
 callbacks_with_config.register('async_eval', func=AsyncEval)
 callbacks_with_config.register('curriculum_learning', func=CurriculumLearning)
@@ -46,6 +49,7 @@
     'ScheduledGarbageCollector',
     'EvalGauntlet',
     'HuggingFaceCheckpointer',
+    'MegaBlocksMoE_TokPerExpert',
     'AsyncEval',
     'CurriculumLearning',
 ]
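
As a usage sketch (not applied by this patch): the newly registered callback can also be attached directly to a Composer `Trainer`; the constructor arguments below mirror the defaults defined in `log_mbmoe_tok_per_expert_callback.py`.

```python
from composer import Trainer

from llmfoundry.callbacks import MegaBlocksMoE_TokPerExpert

# Log tokens-per-expert statistics every 10 batches, including per-layer stats.
tok_per_expert_cb = MegaBlocksMoE_TokPerExpert(
    log_interval=10,
    log_every_layer=True,
)
# trainer = Trainer(model=..., train_dataloader=..., callbacks=[tok_per_expert_cb])
```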
diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index b7d80bd5f8..baa72a7f66 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -15,6 +15,7 @@
 from typing import Any, Dict, List, Optional, Sequence, Union
 
 import torch
+import torch.nn as nn
 from composer.core import Callback, Event, State, Time, TimeUnit
 from composer.core.state import fsdp_state_dict_type_context
 from composer.loggers import Logger, MLFlowLogger
@@ -24,6 +25,7 @@
                             parse_uri)
 from composer.utils.misc import create_interval_scheduler
 from mlflow.transformers import _fetch_model_card, _write_license_information
+from packaging import version
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from llmfoundry.models.mpt import MPTConfig, MPTForCausalLM
@@ -312,28 +314,72 @@ def _save_checkpoint(self, state: State, logger: Logger):
             state_dict_model = state.model.model
             original_tokenizer = state.model.tokenizer
 
-        state_dict_context = fsdp_state_dict_type_context(
-            original_model,
-            state_dict_type='full') if ((not state.is_model_ddp) and isinstance(
-                state_dict_model, FSDP)) else contextlib.nullcontext()
-
-        with state_dict_context:
-            state_dict = state_dict_model.state_dict()
-
-            # convert the state dict to the requested precision
-            for k, v in state_dict.items():
-                if isinstance(v, torch.Tensor):
-                    state_dict[k] = v.to(dtype=self.dtype)
+        if version.parse(torch.__version__) > version.parse('2.2.9'):
+            from torch.distributed._tensor import DTensor
+            from torch.distributed.checkpoint.state_dict import (
+                StateDictOptions, get_model_state_dict)
+            cpu_offload = True
+
+            # Add a dtensor->cpu tensor hook to avoid CUDA OOM
+            def dtensor_to_tensor_hook(
+                module: nn.Module,
+                state_dict: Dict[str, Any],
+                prefix: str,
+                *args: Any,
+            ) -> Dict[str, Any]:
+                dtensor_fqns = []
+                for fqn in state_dict.keys():
+                    tensor = state_dict[fqn]
+                    if isinstance(tensor, DTensor):
+                        dtensor_fqns.append(fqn)
+                        tensor = tensor.full_tensor()  # type: ignore
+                        if dist.get_global_rank() == 0:
+                            if cpu_offload:
+                                tensor = tensor.cpu()
+                            state_dict[fqn] = tensor
+                if dist.get_global_rank() != 0:
+                    for fqn in dtensor_fqns:
+                        del state_dict[fqn]
+                return state_dict
+
+            hooks = []
+            for _, module in state_dict_model.named_modules():
+                if isinstance(module, FSDP):
+                    hooks.append(
+                        module._register_state_dict_hook(
+                            dtensor_to_tensor_hook))
+
+            state_dict = get_model_state_dict(state_dict_model,
+                                              options=StateDictOptions(
+                                                  full_state_dict=True,
+                                                  cpu_offload=cpu_offload))
+            for hook in hooks:
+                hook.remove()
+        else:
+            state_dict_context = fsdp_state_dict_type_context(
+                original_model, state_dict_type='full') if (
+                    (not state.is_model_ddp) and isinstance(
+                        state_dict_model, FSDP)) else contextlib.nullcontext()
+            with state_dict_context:
+                state_dict = state_dict_model.state_dict()
+
+        # Convert the state dict to the requested precision
+        for k, v in state_dict.items():
+            if isinstance(v, torch.Tensor):
+                state_dict[k] = v.to(dtype=self.dtype)
 
         new_model_instance = None  # Need this for pyright because variable could be unbound
 
         if dist.get_global_rank() == 0:
             log.debug('Saving Hugging Face checkpoint in global rank 0')
 
+            # Edit HF config before building 2nd model copy
             copied_config = copy.deepcopy(original_model.config)
             if copied_config.model_type == 'mpt':
                 copied_config.attn_config['attn_impl'] = 'torch'
                 copied_config.init_device = 'cpu'
+                if 'moe_world_size' in getattr(copied_config, 'ffn_config', {}):
+                    copied_config.ffn_config['moe_world_size'] = 1
 
             log.debug(f'Creating new model instance')
 
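
For reference, a minimal sketch of the new torch>=2.3 branch above (assumes `torch.distributed.checkpoint.state_dict` is available; the DTensor hook in the patch additionally drops non-rank-0 copies to avoid CUDA OOM):

```python
import torch
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                      get_model_state_dict)


def gather_full_state_dict(model: torch.nn.Module) -> dict:
    """Gather a full, CPU-offloaded state dict.

    With full_state_dict=True and cpu_offload=True, non-zero ranks may receive
    an empty dict; rank 0 holds the materialized checkpoint.
    """
    return get_model_state_dict(
        model,
        options=StateDictOptions(full_state_dict=True, cpu_offload=True),
    )
```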
diff --git a/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py
new file mode 100644
index 0000000000..fc906e0d87
--- /dev/null
+++ b/llmfoundry/callbacks/log_mbmoe_tok_per_expert_callback.py
@@ -0,0 +1,140 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Log tokens per expert for MegaBlocks MoE."""
+from __future__ import annotations
+
+import torch
+from composer.core import Callback, State
+from composer.loggers import Logger
+from composer.utils import dist
+
+
+class MegaBlocksMoE_TokPerExpert(Callback):
+    """Log tokens per expert for MegaBlocks MoE.
+
+    To compute the load balancing loss, MegaBlocks caches information including `tokens_per_expert`
+    (tpe). At the :attr:`.Event.BATCH_END` event, this callback retrieves the cached load balancing
+    loss from MegaBlocks to obtain `tokens_per_expert`, then logs statistics (<STAT>) of the number
+    of tokens assigned to experts for each layer index (l_idx) under ``mb_moe/layer<l_idx>_<STAT>_tpe``.
+
+
+    The tokens_per_expert statistics are logged by the :class:`.Logger` to the following keys as
+    described below.
+
+    +----------------------------------+-----------------------------------------------------------+
+    | Key                              | Logged data                                               |
+    +==================================+===========================================================+
+    | `mb_moe/all_layers_min_tpe`      | Minimum tokens per expert across all layers               |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/all_layers_max_tpe`      | Maximum tokens per expert across all layers               |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/all_layers_median_tpe`   | Median tokens per expert across all layers                |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/all_layers_std_tpe`      | Standard deviation of tokens per expert across all layers |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/layer<l_idx>_min_tpe`    | Minimum tokens per expert at l_idx layer                  |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/layer<l_idx>_max_tpe`    | Maximum tokens per expert at l_idx layer                  |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/layer<l_idx>_median_tpe` | Median tokens per expert at l_idx layer                   |
+    +----------------------------------+-----------------------------------------------------------+
+    | `mb_moe/layer<l_idx>_std_tpe`    | Standard deviation of tokens per expert at l_idx layer    |
+    +----------------------------------+-----------------------------------------------------------+
+
+    Args:
+        log_interval (int, optional): The interval on which to log (Default: 10).
+        log_every_layer (bool, optional): Enable logging every layer's statistics (True) or log
+            only aggregate statistics (Default: False).
+        all_reduce_stats (bool, optional): Enable aggregating statistics across GPUs (True) or log
+            statistics for GPU 0 only (Default: False).
+        normalize (bool, optional): Normalize token counts by total tokens (Default: True) or output
+            raw token count (False). When normalize is True, the callback displays the fraction of
+            unique tokens routed to each expert. When normalize is False, the callback displays the
+            total number of tokens routed to each expert.
+    """
+
+    def __init__(
+        self,
+        log_interval: int = 10,
+        log_every_layer: bool = False,
+        all_reduce_stats: bool = False,
+        normalize: bool = True,
+    ):
+        self.log_interval = log_interval
+        self.log_every_layer = log_every_layer
+        self.all_reduce_stats = all_reduce_stats
+        self.normalize = normalize
+
+        self.topk = None
+
+    def fit_start(self, state: State, logger: Logger) -> None:
+        if self.topk is None and self.normalize:
+            try:
+                from megablocks.layers.dmoe import dMoE
+                from megablocks.layers.moe import MoE
+            except:
+                raise RuntimeError(
+                    'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
+                )
+            for module in state.model.modules():
+                if isinstance(module, (MoE, dMoE)):
+                    self.topk = module.experts.args.moe_top_k
+                    return
+
+            raise RuntimeError(
+                f'Callback not initialized correctly; self.topk not instantiated.'
+            )
+
+    def batch_end(self, state: State, logger: Logger) -> None:
+        if state.timestamp.batch.value % self.log_interval == 0:
+            try:
+                from megablocks.layers.moe import get_load_balancing_loss
+            except:
+                raise RuntimeError(
+                    'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
+                )
+            tokens_per_expert, _ = zip(*get_load_balancing_loss())
+
+            tokens_per_expert = [
+                tpe.clone().detach() for tpe in tokens_per_expert
+            ]
+            if self.all_reduce_stats:
+                for tpe in tokens_per_expert:
+                    dist.all_reduce(tpe)
+
+            if self.normalize:
+                tokens_per_expert = [
+                    tpe / (tpe.sum() / self.topk) for tpe in tokens_per_expert
+                ]
+
+            all_tokens_per_expert = torch.concat(tokens_per_expert)
+
+            min_tpe = all_tokens_per_expert.min().item()
+            max_tpe = all_tokens_per_expert.max().item()
+            median_tpe = all_tokens_per_expert.median().item()
+            std_tpe = all_tokens_per_expert.float().std().item()
+
+            log_info = {
+                f'mb_moe/all_layers_min_tpe': min_tpe,
+                f'mb_moe/all_layers_max_tpe': max_tpe,
+                f'mb_moe/all_layers_median_tpe': median_tpe,
+                f'mb_moe/all_layers_std_tpe': std_tpe,
+            }
+
+            if self.log_every_layer:
+                for l_idx, tpe_layer in enumerate(tokens_per_expert):
+
+                    min_tpe = tpe_layer.min().item()
+                    max_tpe = tpe_layer.max().item()
+                    median_tpe = tpe_layer.median().item()
+                    std_tpe = tpe_layer.float().std().item()
+
+                    log_info.update({
+                        f'mb_moe/layer{l_idx}_min_tpe': min_tpe,
+                        f'mb_moe/layer{l_idx}_max_tpe': max_tpe,
+                        f'mb_moe/layer{l_idx}_median_tpe': median_tpe,
+                        f'mb_moe/layer{l_idx}_std_tpe': std_tpe,
+                    })
+
+            logger.log_metrics(log_info)
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 38c9673a14..1d8711d280 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -170,6 +170,8 @@ def build_finetuning_dataloader(cfg: DictConfig,
             sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
             batching_method=cfg.dataset.get('batching_method', 'random'),
             max_seq_len=cfg.dataset.max_seq_len,
+            allow_unsafe_types=cfg.dataset.get('allow_unsafe_types', False),
+            replication=cfg.dataset.get('replication', None),
         )
 
     else:
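
The two new keys are read straight off `cfg.dataset`, so enabling them only requires extra config entries. A hedged sketch (paths and values illustrative):

```python
from omegaconf import OmegaConf

dataloader_cfg = OmegaConf.create({
    'name': 'finetuning',
    'dataset': {
        'remote': 's3://my-bucket/my-ft-dataset',  # hypothetical location
        'split': 'train',
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'allow_unsafe_types': False,  # keep the safe default: error on Pickle shards
        'replication': 2,  # two consecutive devices receive identical samples
    },
    'drop_last': True,
    'num_workers': 8,
})
```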
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index 4ca15e8d1f..4906cea151 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -77,7 +77,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
 _ALLOWED_MESSAGES_KEYS = {'messages'}
 _ALLOWED_ROLE_KEYS = {'role'}
 _ALLOWED_CONTENT_KEYS = {'content'}
-_ALLOWED_ROLES = {'user', 'assistant', 'system'}
+_ALLOWED_ROLES = {'user', 'assistant', 'system', 'tool'}
 _ALLOWED_LAST_MESSAGE_ROLES = {'assistant'}
 DOWNLOADED_FT_DATASETS_DIRPATH = os.path.abspath(
     os.path.join(os.path.realpath(__file__), os.pardir, os.pardir, os.pardir,
@@ -217,7 +217,7 @@ def slice_out_last_turn(
         if conversation_through_previous_turn != prompt_with_history[:len(
                 conversation_through_previous_turn)]:
             raise ValueError(
-                f'The prompt_with_histry must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
+                f'The prompt_with_history must start with the conversation through the previous turn. {conversation_through_previous_turn=}, {prompt_with_history=}'
             )
         if prompt_with_history != full_conversation[:len(prompt_with_history)]:
             raise ValueError(
@@ -490,6 +490,12 @@ class StreamingFinetuningDataset(StreamingDataset):
             Defaults to ``1``.
         batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
             ``per_stream``. Defaults to ``random``.
+        allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
+            execution during deserialization, whether to keep going (``True``) or raise an error
+            (``False``). Defaults to ``False``.
+        replication (int, optional): Determines how many consecutive devices will receive the same
+            samples. Useful for training with tensor or sequence parallelism, where multiple
+            devices need to see the same partition of the dataset. Defaults to ``None``.
     """
 
     def __init__(self,
@@ -516,6 +522,8 @@ def __init__(self,
                  sampling_granularity: int = 1,
                  batching_method: str = 'random',
                  max_seq_len: int = 2048,
+                 allow_unsafe_types: bool = False,
+                 replication: Optional[int] = None,
                  **kwargs: Any):
 
         if len(kwargs) > 0:
@@ -552,6 +560,8 @@ def __init__(self,
             sampling_method=sampling_method,
             sampling_granularity=sampling_granularity,
             batching_method=batching_method,
+            allow_unsafe_types=allow_unsafe_types,
+            replication=replication,
         )
 
         self.tokenizer = tokenizer
@@ -614,9 +624,8 @@ def print_registered_tasks(self) -> None:
         log.info('\n'.join(tasks))
 
     def get_preprocessing_fn_from_dict(
-            self,
-            mapping: Dict[str,
-                          str]) -> Callable[[Dict[str, Any]], Dict[str, str]]:
+            self, mapping: Dict[str,
+                                str]) -> Callable[[Dict[str, Any]], Example]:
         """Get a preprocessing function from a dictionary.
 
         The dictionary maps column names in the dataset to "prompt" and "response".
@@ -652,7 +661,7 @@ def get_preprocessing_fn_from_str(
         self,
         preprocessor: Optional[str],
         dataset_name: Optional[str] = None
-    ) -> Optional[Callable[[Dict[str, Any]], Dict[str, str]]]:
+    ) -> Optional[Callable[[Dict[str, Any]], Example]]:
         """Get a preprocessing function from a string.
 
         String can be either a registered function or an import path.
@@ -700,7 +709,7 @@ def get_preprocessing_fn_from_str(
 
     def build_from_hf(
         self, dataset_name: str, split: str, safe_load: bool, max_seq_len: int,
-        preprocessing_fn: Optional[Callable[[dict[str, Any]], dict[str, str]]],
+        preprocessing_fn: Optional[Callable[[dict[str, Any]], Example]],
         tokenizer: PreTrainedTokenizerBase, target_prompts: str,
         target_responses: str, decoder_only_format: bool, hf_kwargs: Dict[str,
                                                                           Any]
@@ -783,7 +792,8 @@ def build_from_hf(
 
             def dataset_mapper(example: Dict):
                 if preprocessing_fn is not None:
-                    example = preprocessing_fn(example)
+                    return tokenize_formatted_example(preprocessing_fn(example),
+                                                      tokenizer)
                 return tokenize_formatted_example(example, tokenizer)
 
             detected_cpu_count = os.cpu_count() or 1
@@ -847,7 +857,7 @@ def build_from_streaming(self, *args: Any,
 
 
 @dataset_constructor.register('tatsu-lab/alpaca')
-def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def alpaca_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Split out prompt/response from text."""
     try:
         prompt, response = inp['text'].split('### Response:')
@@ -859,7 +869,7 @@ def alpaca_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('HuggingFaceH4/databricks_dolly_15k')
-def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def dolly_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the text string."""
     PROMPT_FORMAT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n'
     try:
@@ -875,7 +885,7 @@ def dolly_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 
 @dataset_constructor.register('bigscience/P3')
-def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
+def p3_preprocessing_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     return {
         'prompt': inp['inputs'] + ':',
@@ -885,7 +895,7 @@ def p3_preprocessing_function(inp: Dict) -> Dict[str, str]:
 
 # Muennighoff's P3 and flan datasets share a similar convention
 @dataset_constructor.register('Muennighoff/P3', 'Muennighoff/flan')
-def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
+def muennighoff_tokenize_function(inp: Dict) -> PromptResponseDict:
     """Format the already-split example."""
     try:
         prompt: str = inp['inputs']
@@ -898,3 +908,22 @@ def muennighoff_tokenize_function(inp: Dict) -> Dict[str, str]:
     except Exception as e:
         raise UnableToProcessPromptResponseError(inp) from e
     return {'prompt': prompt, 'response': response}
+
+
+@dataset_constructor.register('teknium/OpenHermes-2.5')
+def shareGPT_format_preprocessor(inp: Dict) -> ChatFormattedDict:
+    """Convert from ShareGPT format to our chat format."""
+    role_map = {
+        'human': 'user',
+        'gpt': 'assistant',
+    }
+    try:
+        conversation = inp['conversations']
+        messages: List[Dict[str, str]] = []
+        for message in conversation:
+            role: str = role_map.get(message['from'], message['from'])
+            content: str = message['value']
+            messages.append({'role': role, 'content': content})
+    except Exception as e:
+        raise UnableToProcessPromptResponseError(inp) from e
+    return {'messages': messages}
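
A small worked example for the new ShareGPT preprocessor (the sample record is made up):

```python
from llmfoundry.data.finetuning.tasks import shareGPT_format_preprocessor

example = {
    'conversations': [
        {'from': 'human', 'value': 'What is 2 + 2?'},
        {'from': 'gpt', 'value': '2 + 2 = 4.'},
    ],
}
print(shareGPT_format_preprocessor(example))
# {'messages': [{'role': 'user', 'content': 'What is 2 + 2?'},
#               {'role': 'assistant', 'content': '2 + 2 = 4.'}]}
```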
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index e85968543c..fc31b890b0 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -83,6 +83,12 @@ class StreamingTextDataset(StreamingDataset):
             Defaults to ``1``.
         batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
             ``per_stream``. Defaults to ``random``.
+        allow_unsafe_types (bool): If a shard contains Pickle, which allows arbitrary code
+            execution during deserialization, whether to keep going (``True``) or raise an error
+            (``False``). Defaults to ``False``.
+        replication (int, optional): Determines how many consecutive devices will receive the same
+            samples. Useful for training with tensor or sequence parallelism, where multiple
+            devices need to see the same partition of the dataset. Defaults to ``None``.
     """
 
     def __init__(self,
@@ -109,6 +115,8 @@ def __init__(self,
                  sampling_method: str = 'balanced',
                  sampling_granularity: int = 1,
                  batching_method: str = 'random',
+                 allow_unsafe_types: bool = False,
+                 replication: Optional[int] = None,
                  **kwargs: Any):
 
         if len(kwargs) > 0:
@@ -151,6 +159,8 @@ def __init__(self,
             sampling_method=sampling_method,
             sampling_granularity=sampling_granularity,
             batching_method=batching_method,
+            allow_unsafe_types=allow_unsafe_types,
+            replication=replication,
         )
         self.tokenizer = tokenizer
         self.max_seq_len = max_seq_len
diff --git a/llmfoundry/eval/datasets/in_context_learning_evaluation.py b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
index bd5c7dc30c..8f317f60b8 100644
--- a/llmfoundry/eval/datasets/in_context_learning_evaluation.py
+++ b/llmfoundry/eval/datasets/in_context_learning_evaluation.py
@@ -478,14 +478,15 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
         batch['attention_mask'] = ~(batch['input_ids'] == self.pad_tok_id)
         return batch
 
-    def split_batch(self, batch: Any, microbatch_size: int) -> Sequence:
+    def split_batch(self, batch: Any,
+                    microbatch_size: Union[int, float]) -> Sequence[Any]:
         """Handling for certain specialty columns that must be split into.
 
         batches in different formats.
 
         Args:
             batch (Dict): Batch of data
-            microbatch_size (int): Size of microbatches
+            microbatch_size (int | float): Size of microbatches
 
         Returns:
             List: List of chunked batches
@@ -493,6 +494,9 @@ def split_batch(self, batch: Any, microbatch_size: int) -> Sequence:
         # Don't split kwargs that don't change
         # Normally split torch tensors
         # List split lists of strings
+        if isinstance(microbatch_size, float):
+            raise ValueError(
+                'split_batch does not support floating point microbatch_size.')
         chunked = {}
         for k, v in batch.items():
             if k in self.static_keys:
@@ -905,7 +909,8 @@ def collate_fn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
     def get_num_samples_in_batch(self, batch: Dict[str, torch.Tensor]) -> int:
         return batch['input_ids'].shape[0] // self.num_choices
 
-    def split_batch(self, batch: Any, microbatch_size: int) -> Sequence:
+    def split_batch(self, batch: Any,
+                    microbatch_size: Union[int, float]) -> Sequence[Any]:
         """Split batch while ensuring all continuations are in the same.
 
         microbatch.
@@ -917,11 +922,14 @@ def split_batch(self, batch: Any, microbatch_size: int) -> Sequence:
         microbatch_size and real attributes by microbatch_size * num_choices.
         Args:
             batch (Dict): Batch of data
-            microbatch_size (int): Size of microbatches
+            microbatch_size (int | float): Size of microbatches
 
         Returns:
             list: List of chunked batches
         """
+        if isinstance(microbatch_size, float):
+            raise ValueError(
+                'split_batch does not support floating point microbatch_size.')
         chunked = {}
         for k, v in batch.items():
             if k in self.static_keys:
diff --git a/llmfoundry/layers_registry.py b/llmfoundry/layers_registry.py
index 9c7dabe128..3c0a7ebd6e 100644
--- a/llmfoundry/layers_registry.py
+++ b/llmfoundry/layers_registry.py
@@ -1,20 +1,53 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Type
+from typing import Callable, Type
 
 import torch
 
 from llmfoundry.utils.registry_utils import create_registry
 
-# Layers
-_norm_description = """The norms registry is used to register classes that implement normalization layers."""
+_norm_description = (
+    'The norms registry is used to register classes that implement normalization layers.'
+)
 norms = create_registry('llmfoundry',
                         'norms',
                         generic_type=Type[torch.nn.Module],
                         entry_points=True,
                         description=_norm_description)
+_fc_description = (
+    'The fully connected layers registry is used to register classes that implement fully connected layers (i.e. torch.nn.Linear). '
+    +
+    'These classes should take in_features and out_features as args, at a minimum.'
+)
+fcs = create_registry('llmfoundry',
+                      'fcs',
+                      generic_type=Type[torch.nn.Module],
+                      entry_points=True,
+                      description=_fc_description)
+
+_attention_classes_description = (
+    'The attention_classes registry is used to register classes that implement attention layers. See '
+    + 'attention.py for expected constructor signature.')
+attention_classes = create_registry('llmfoundry',
+                                    'attention_classes',
+                                    generic_type=Type[torch.nn.Module],
+                                    entry_points=True,
+                                    description=_attention_classes_description)
+
+_attention_implementations_description = (
+    'The attention_implementations registry is used to register functions that implement the attention operation. '
+    + 'See attention.py for expected function signature.')
+attention_implementations = create_registry(
+    'llmfoundry',
+    'attention_implementations',
+    generic_type=Callable,
+    entry_points=True,
+    description=_attention_implementations_description)
 
 __all__ = [
     'norms',
+    'attention_classes',
+    'attention_implementations',
+    'fcs',
 ]
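
A hedged sketch of extending the new `fcs` registry with a custom fully connected layer (the `my_linear` name and class are illustrative); it mirrors how `fc.py` below registers `nn.Linear` under `'torch'`:

```python
import torch.nn as nn

from llmfoundry.layers_registry import fcs


class MyLinear(nn.Linear):
    """A drop-in nn.Linear variant, e.g. with custom init or quantization."""


fcs.register('my_linear', func=MyLinear)
```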
diff --git a/llmfoundry/models/layers/__init__.py b/llmfoundry/models/layers/__init__.py
index 262f190b47..5784fcd7e9 100644
--- a/llmfoundry/models/layers/__init__.py
+++ b/llmfoundry/models/layers/__init__.py
@@ -2,12 +2,12 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from llmfoundry.models.layers.attention import (
-    ATTN_CLASS_REGISTRY, GroupedQueryAttention, MultiheadAttention,
-    MultiQueryAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
-    flash_attn_fn, scaled_multihead_dot_product_attention)
+    GroupedQueryAttention, MultiheadAttention, MultiQueryAttention,
+    attn_bias_shape, build_alibi_bias, build_attn_bias, flash_attn_fn,
+    scaled_multihead_dot_product_attention)
 from llmfoundry.models.layers.blocks import MPTBlock
 from llmfoundry.models.layers.custom_embedding import SharedEmbedding
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.models.layers.fc import *
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
 from llmfoundry.models.layers.norm import LPLayerNorm
 
@@ -20,11 +20,9 @@
     'attn_bias_shape',
     'build_attn_bias',
     'build_alibi_bias',
-    'ATTN_CLASS_REGISTRY',
     'MPTMLP',
     'MPTBlock',
     'LPLayerNorm',
-    'FC_CLASS_REGISTRY',
     'SharedEmbedding',
     'FFN_CLASS_REGISTRY',
     'build_ffn',
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index c24b3d4afa..6614d5d161 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -14,8 +14,9 @@
 from packaging import version
 from torch import nn
 
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
-from llmfoundry.models.layers.layer_builders import build_norm
+from llmfoundry.layers_registry import (attention_classes,
+                                        attention_implementations)
+from llmfoundry.models.layers.layer_builders import build_fc, build_norm
 
 
 def is_flash_v2_installed(v2_version: str = '2.0.0'):
@@ -341,6 +342,7 @@ def flash_attn_fn(
     return output, None, past_key_value
 
 
+@attention_classes.register_class('grouped_query_attention')
 class GroupedQueryAttention(nn.Module):
     """Grouped Query Attention (GQA) is a generalization of Multi-head (MHA).
 
@@ -406,10 +408,11 @@ def __init__(
             'bias': bias,
         }
         fc_kwargs['device'] = device
-        self.Wqkv = FC_CLASS_REGISTRY[fc_type](
-            self.d_model,
-            self.d_model + 2 * self.kv_n_heads * self.head_dim,
-            **fc_kwargs,
+        self.Wqkv = build_fc(
+            name=fc_type,
+            in_features=self.d_model,
+            out_features=self.d_model + 2 * self.kv_n_heads * self.head_dim,
+            fc_kwargs=fc_kwargs,
         )
         # for param init fn; enables shape based init of fused layers
         fuse_splits = [
@@ -433,17 +436,13 @@ def __init__(
                 device=device,
             )
 
-        if self.attn_impl == 'flash':
-            self.attn_fn = flash_attn_fn
-        elif self.attn_impl == 'torch':
-            self.attn_fn = scaled_multihead_dot_product_attention
-        else:
-            raise ValueError(f'{attn_impl=} is an invalid setting.')
+        self.attn_fn = attention_implementations.get(self.attn_impl)
 
-        self.out_proj = FC_CLASS_REGISTRY[fc_type](
-            self.d_model,
-            self.d_model,
-            **fc_kwargs,
+        self.out_proj = build_fc(
+            name=fc_type,
+            in_features=self.d_model,
+            out_features=self.d_model,
+            fc_kwargs=fc_kwargs,
         )
         self.out_proj._is_residual = True
 
@@ -572,6 +571,7 @@ def forward(
         return self.out_proj(context), attn_weights, past_key_value
 
 
+@attention_classes.register_class('multihead_attention')
 class MultiheadAttention(GroupedQueryAttention):
     """Multi-head self attention.
 
@@ -612,6 +612,7 @@ def __init__(
         )
 
 
+@attention_classes.register_class('multiquery_attention')
 class MultiQueryAttention(GroupedQueryAttention):
     """Multi-Query self attention.
 
@@ -740,8 +741,6 @@ def build_alibi_bias(
     return alibi_bias.to(dtype=dtype)
 
 
-ATTN_CLASS_REGISTRY = {
-    'multihead_attention': MultiheadAttention,
-    'multiquery_attention': MultiQueryAttention,
-    'grouped_query_attention': GroupedQueryAttention
-}
+attention_implementations.register('flash', func=flash_attn_fn)
+attention_implementations.register('torch',
+                                   func=scaled_multihead_dot_product_attention)
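
With `ATTN_CLASS_REGISTRY` removed, lookups go through the registries instead; a minimal sketch of resolving the same components by name:

```python
from llmfoundry.layers_registry import (attention_classes,
                                        attention_implementations)

# Replaces the old ATTN_CLASS_REGISTRY['grouped_query_attention'] lookup.
gqa_cls = attention_classes.get('grouped_query_attention')

# Replaces the old if/elif dispatch on attn_impl inside GroupedQueryAttention.
flash_fn = attention_implementations.get('flash')
torch_fn = attention_implementations.get('torch')
```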
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index 42feb983d4..1ad9ec954f 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -8,9 +8,9 @@
 import torch
 import torch.nn as nn
 
-from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
-from llmfoundry.models.layers.layer_builders import build_norm
+from llmfoundry.models.layers.layer_builders import (build_attention_layer,
+                                                     build_norm)
 
 try:
     from flash_attn.bert_padding import unpad_input, pad_input  # type: ignore # yapf: disable # isort: skip
@@ -68,12 +68,148 @@ def __init__(
             ffn_config = {
                 'ffn_type': 'mptmlp',
             }
+        self.fuse_norm_attn_norm = kwargs.get('fuse_norm_attn_norm', False)
 
         del kwargs  # unused, just to capture any extra args from the config
         super().__init__()
 
+        if self.fuse_norm_attn_norm:
+            self.norm_attn_norm = FusedNormAttentionNorm(
+                d_model=d_model,
+                n_heads=n_heads,
+                attn_config=attn_config,
+                ffn_config=ffn_config,
+                fc_type=fc_type,
+                resid_pdrop=resid_pdrop,
+                norm_type=norm_type,
+                device=device,
+                no_bias=no_bias,
+            )
+        else:
+            assert isinstance(attn_config['attn_type'], str)
+            # Necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
+            args_to_exclude_in_attn_class = {
+                'attn_type', 'alibi', 'attn_uses_sequence_id', 'alibi_bias_max',
+                'rope', 'rope_theta', 'rope_impl', 'rope_dail_config',
+                'rope_hf_config'
+            }
+            attn_config_subset_for_attn_class = {
+                k: v
+                for k, v in attn_config.items()
+                if k not in args_to_exclude_in_attn_class
+            }
+
+            self.norm_1 = build_norm(
+                name=norm_type.lower(),
+                normalized_shape=d_model,
+                device=device,
+            )
+            self.attn = build_attention_layer(
+                name=attn_config['attn_type'],
+                attn_kwargs={
+                    'd_model': d_model,
+                    'n_heads': n_heads,
+                    'fc_type': fc_type,
+                    'device': device,
+                    'bias': not no_bias,
+                    **attn_config_subset_for_attn_class
+                },
+            )
+            self.norm_2 = None
+            if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']],
+                           '_has_norm', False):
+                self.norm_2 = build_norm(
+                    name=norm_type.lower(),
+                    normalized_shape=d_model,
+                    device=device,
+                )
+
+        self.ffn = build_ffn(
+            d_model=d_model,
+            expansion_ratio=expansion_ratio,
+            device=device,
+            bias=not no_bias,
+            **ffn_config,
+        )
+        self.resid_attn_dropout = nn.Dropout(resid_pdrop)
+        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
+        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attn_bias: Optional[torch.Tensor] = None,
+        rotary_emb_w_meta_info: Optional[Dict] = None,
+        attention_mask: Optional[torch.ByteTensor] = None,
+        is_causal: bool = True,
+        output_attentions: bool = False,
+        alibi_slopes: Optional[torch.Tensor] = None,
+        flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
+            torch.Tensor, torch.Tensor]]]:
+        if self.fuse_norm_attn_norm:
+            x, m, attn_weights, past_key_value = self.norm_attn_norm(
+                x,
+                past_key_value=past_key_value,
+                attn_bias=attn_bias,
+                rotary_emb_w_meta_info=rotary_emb_w_meta_info,
+                attention_mask=attention_mask,
+                is_causal=is_causal,
+                output_attentions=output_attentions,
+                alibi_slopes=alibi_slopes,
+                flash_attn_padding_info=flash_attn_padding_info,
+            )
+        else:
+            a = self.norm_1(x)
+            b, attn_weights, past_key_value = self.attn(
+                a,
+                past_key_value=past_key_value,
+                attn_bias=attn_bias,
+                rotary_emb_w_meta_info=rotary_emb_w_meta_info,
+                attention_mask=attention_mask,
+                is_causal=is_causal,
+                needs_weights=output_attentions,
+                alibi_slopes=alibi_slopes,
+                flash_attn_padding_info=flash_attn_padding_info,
+            )
+            x = x + self.resid_attn_dropout(b)
+            m = x
+            if self.norm_2 is not None:
+                m = self.norm_2(x)
+
+        batch_size, seq_len = m.size()[:2]
+        indices = None
+        if not self.use_pad_tok_in_ffn:
+            assert unpad_input is not None
+            m, indices, _, _ = unpad_input(m, attention_mask)
+        n = self.ffn(m)
+        if not self.use_pad_tok_in_ffn:
+            assert pad_input is not None
+            n = pad_input(n, indices, batch_size, seq_len)
+        x = x + self.resid_ffn_dropout(n)
+        return x, attn_weights, past_key_value
+
+
+class FusedNormAttentionNorm(nn.Module):
+
+    def __init__(
+        self,
+        d_model: int,
+        n_heads: int,
+        attn_config: Optional[Dict] = None,
+        ffn_config: Optional[Dict] = None,
+        fc_type: str = 'torch',
+        resid_pdrop: float = 0.0,
+        norm_type: str = 'low_precision_layernorm',
+        device: Optional[str] = None,
+        no_bias: bool = False,
+        **kwargs: Any,
+    ):
+        super().__init__()
+        assert attn_config is not None
+        assert ffn_config is not None
         assert isinstance(attn_config['attn_type'], str)
-        attn_class = ATTN_CLASS_REGISTRY[attn_config['attn_type']]
 
         # necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
         args_to_exclude_in_attn_class = {
@@ -86,19 +222,21 @@ def __init__(
             for k, v in attn_config.items()
             if k not in args_to_exclude_in_attn_class
         }
-
         self.norm_1 = build_norm(
             name=norm_type.lower(),
             normalized_shape=d_model,
             device=device,
         )
-        self.attn = attn_class(
-            d_model=d_model,
-            n_heads=n_heads,
-            fc_type=fc_type,
-            device=device,
-            **attn_config_subset_for_attn_class,
-            bias=not no_bias,
+        self.attn = build_attention_layer(
+            name=attn_config['attn_type'],
+            attn_kwargs={
+                'd_model': d_model,
+                'n_heads': n_heads,
+                'fc_type': fc_type,
+                'device': device,
+                'bias': not no_bias,
+                **attn_config_subset_for_attn_class
+            },
         )
         self.norm_2 = None
         if not getattr(FFN_CLASS_REGISTRY[ffn_config['ffn_type']], '_has_norm',
@@ -108,17 +246,7 @@ def __init__(
                 normalized_shape=d_model,
                 device=device,
             )
-        self.ffn = build_ffn(
-            d_model=d_model,
-            expansion_ratio=expansion_ratio,
-            device=device,
-            bias=not no_bias,
-            **ffn_config,
-        )
         self.resid_attn_dropout = nn.Dropout(resid_pdrop)
-        self.resid_ffn_dropout = nn.Dropout(resid_pdrop)
-
-        self.use_pad_tok_in_ffn = use_pad_tok_in_ffn
 
     def forward(
         self,
@@ -131,8 +259,8 @@ def forward(
         output_attentions: bool = False,
         alibi_slopes: Optional[torch.Tensor] = None,
         flash_attn_padding_info: Optional[dict[str, torch.Tensor]] = None,
-    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
-            torch.Tensor, torch.Tensor]]]:
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
+               Optional[Tuple[torch.Tensor, torch.Tensor]]]:
         a = self.norm_1(x)
         b, attn_weights, past_key_value = self.attn(
             a,
@@ -149,14 +277,5 @@ def forward(
         m = x
         if self.norm_2 is not None:
             m = self.norm_2(x)
-        batch_size, seq_len = m.size()[:2]
-        indices = None
-        if not self.use_pad_tok_in_ffn:
-            assert unpad_input is not None
-            m, indices, _, _ = unpad_input(m, attention_mask)
-        n = self.ffn(m)
-        if not self.use_pad_tok_in_ffn:
-            assert pad_input is not None
-            n = pad_input(n, indices, batch_size, seq_len)
-        x = x + self.resid_ffn_dropout(n)
-        return x, attn_weights, past_key_value
+
+        return x, m, attn_weights, past_key_value
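
A hedged construction sketch for the new fused path: `fuse_norm_attn_norm` is consumed from `**kwargs`, so it can be passed directly when building a block (default attention/FFN configs assumed):

```python
from llmfoundry.models.layers.blocks import MPTBlock

block = MPTBlock(
    d_model=768,
    n_heads=12,
    expansion_ratio=4,
    fuse_norm_attn_norm=True,  # build FusedNormAttentionNorm instead of separate norm_1/attn/norm_2
    device='cpu',
)
assert hasattr(block, 'norm_attn_norm')
```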
diff --git a/llmfoundry/models/layers/dmoe.py b/llmfoundry/models/layers/dmoe.py
new file mode 100644
index 0000000000..1a981b61c5
--- /dev/null
+++ b/llmfoundry/models/layers/dmoe.py
@@ -0,0 +1,246 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable
+
+import torch
+
+
+# Add an option to route tokens uniformly across experts. We use a custom
+# autograd op so the router's backward pass is still run for benchmarking.
+class _UniformExpertAssignment(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+            ctx,  # pyright: ignore[reportMissingParameterType]
+            x: torch.Tensor,
+            num_experts: int):
+        out = torch.arange(x.numel(), dtype=x.dtype, device=x.device)
+        out = torch.remainder(out, num_experts)
+        return out.view(x.shape)
+
+
+class LearnedRouter(torch.nn.Module):
+
+    def __init__(self, hidden_size: int, moe_num_experts: int, moe_top_k: int,
+                 moe_jitter_eps: float, moe_normalize_expert_weights: bool,
+                 uniform_expert_assignment: bool, device: torch.device) -> None:
+        super().__init__()
+        self.hidden_size: int = hidden_size
+        self.moe_num_experts: int = moe_num_experts
+        self.moe_top_k: int = moe_top_k
+        self.moe_jitter_eps: float = moe_jitter_eps
+        self.moe_normalize_expert_weights: bool = moe_normalize_expert_weights
+        self.uniform_expert_assignment: bool = uniform_expert_assignment
+
+        self.layer: torch.nn.Module = torch.nn.Linear(
+            hidden_size,
+            moe_num_experts,
+            bias=False,
+            device=device,
+        )
+
+    def jitter(self, x: torch.Tensor) -> torch.Tensor:
+        low: float = 1.0 - self.moe_jitter_eps
+        high: float = 1.0 + self.moe_jitter_eps
+        noise: torch.Tensor = torch.rand(x.size(),
+                                         dtype=x.dtype,
+                                         device=x.device)
+        return low + noise * (high - low)
+
+    def _top_k(self, scores: torch.Tensor) -> torch.Tensor:
+        if self.moe_top_k == 1:
+            return scores.max(
+                dim=-1)  # pyright: ignore[reportGeneralTypeIssues]
+        return torch.topk(scores, self.moe_top_k,
+                          dim=-1)  # pyright: ignore[reportGeneralTypeIssues]
+
+    def forward(self, x: torch.Tensor):
+        if self.training and self.moe_jitter_eps is not None:
+            x = x * self.jitter(x)
+
+        scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1)
+        expert_weights, top_experts = self._top_k(scores)
+        if self.moe_normalize_expert_weights:
+            expert_weights = expert_weights / torch.norm(
+                expert_weights,
+                p=self.moe_normalize_expert_weights,
+                dim=-1,
+                keepdim=True)
+
+        top_experts = (_UniformExpertAssignment.apply(top_experts,
+                                                      self.moe_num_experts)
+                       if self.uniform_expert_assignment else top_experts)
+        scores = scores.to(x.dtype)
+        expert_weights = expert_weights.to(x.dtype)
+        return scores, expert_weights, top_experts
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(
+        self,
+        hidden_size: int,
+        ffn_hidden_size: int,
+        moe_num_experts: int,
+        activation_fn: Callable,
+        device: torch.device,
+    ) -> None:
+        super().__init__()
+
+        self.moe_num_experts: int = moe_num_experts
+        self.ffn_hidden_size: int = ffn_hidden_size
+        self.hidden_size: int = hidden_size
+        self.activation_fn: Callable = activation_fn
+
+        self.w1 = torch.nn.Parameter(
+            torch.rand(moe_num_experts * ffn_hidden_size,
+                       hidden_size,
+                       device=device))
+        self.w2 = torch.nn.Parameter(
+            torch.rand(moe_num_experts * ffn_hidden_size,
+                       hidden_size,
+                       device=device))
+        self.activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
+        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+
+        before_activation = x @ expert_w1.t()
+        layer_1_output = self.activation_fn(before_activation)
+        output = layer_1_output @ expert_w2
+        return output
+
+
+class GLU(torch.nn.Module):
+
+    def __init__(self, hidden_size: int, ffn_hidden_size: int,
+                 moe_num_experts: int, activation_fn: Callable,
+                 device: torch.device):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.ffn_hidden_size = ffn_hidden_size
+        self.moe_num_experts = moe_num_experts
+
+        self.w1 = torch.nn.Parameter(
+            torch.rand(moe_num_experts * ffn_hidden_size,
+                       hidden_size,
+                       device=device))
+        self.v1 = torch.nn.Parameter(
+            torch.rand(moe_num_experts * ffn_hidden_size,
+                       hidden_size,
+                       device=device))
+        self.w2 = torch.nn.Parameter(
+            torch.rand(moe_num_experts * ffn_hidden_size,
+                       hidden_size,
+                       device=device))
+        self.activation_fn = activation_fn
+
+    def forward(self, x: torch.Tensor, expert_idx: torch.Tensor):
+        expert_w1 = self.w1.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+        expert_v1 = self.v1.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+        expert_w2 = self.w2.view(self.moe_num_experts, self.ffn_hidden_size,
+                                 self.hidden_size)[expert_idx]
+
+        x1 = x.matmul(expert_w1.t())
+        x2 = x.matmul(expert_v1.t())
+        x1 = self.activation_fn(x1)
+        x1 = x1 * x2
+        x1 = x1.matmul(expert_w2)
+        return x1
+
+
+class DroplessMLP(torch.nn.Module):
+
+    def __init__(self, hidden_size: int, ffn_hidden_size: int, mlp_type: str,
+                 moe_num_experts: int, activation_fn: Callable, bias: bool,
+                 device: torch.device):
+        super().__init__()
+        self.moe_num_experts = moe_num_experts
+
+        if mlp_type == 'mlp':
+            self.mlp = MLP(hidden_size=hidden_size,
+                           ffn_hidden_size=ffn_hidden_size,
+                           moe_num_experts=moe_num_experts,
+                           activation_fn=activation_fn,
+                           device=device)
+        elif mlp_type == 'glu':
+            self.mlp = GLU(hidden_size=hidden_size,
+                           ffn_hidden_size=ffn_hidden_size,
+                           moe_num_experts=moe_num_experts,
+                           activation_fn=activation_fn,
+                           device=device)
+        else:
+            raise ValueError(f'Received unknown {mlp_type=}')
+
+    def forward(self, x: torch.Tensor, scores: torch.Tensor,
+                expert_weights: torch.Tensor, top_experts: torch.Tensor):
+        in_shape = x.shape
+        hidden_size = in_shape[-1]
+
+        x = x.view(-1, hidden_size)
+        out = torch.zeros_like(x)
+
+        expert_mask = torch.nn.functional.one_hot(
+            top_experts, num_classes=self.moe_num_experts).permute(2, 1, 0)
+        for expert_idx in range(0, self.moe_num_experts):
+            topk_idx, token_idx = torch.where(expert_mask[expert_idx])
+            if token_idx.shape[0] == 0:
+                continue
+            # In torch it is faster to index using lists than torch tensors
+            token_list = token_idx.tolist()
+            topk_list = topk_idx.tolist()
+
+            expert_tokens = x[None, token_list].reshape(-1, hidden_size)
+            mlp_output = self.mlp(expert_tokens, expert_idx)
+            expert_out = mlp_output * expert_weights[token_list, topk_list,
+                                                     None]
+
+            out.index_add_(0, token_idx, expert_out)
+
+        out = out.view(in_shape)
+        return out
+
+
+class dMoE(torch.nn.Module):
+
+    def __init__(self, hidden_size: int, ffn_hidden_size: int,
+                 moe_num_experts: int, moe_top_k: int, mlp_type: str,
+                 activation_fn: Callable, moe_jitter_eps: float,
+                 moe_normalize_expert_weights: bool,
+                 uniform_expert_assignment: bool, bias: bool,
+                 device: torch.device):
+        super().__init__()
+
+        # Token router.
+        self.router = LearnedRouter(
+            hidden_size,
+            moe_num_experts=moe_num_experts,
+            moe_top_k=moe_top_k,
+            moe_jitter_eps=moe_jitter_eps,
+            moe_normalize_expert_weights=moe_normalize_expert_weights,
+            uniform_expert_assignment=uniform_expert_assignment,
+            device=device,
+        )
+
+        # Expert computation helper.
+        self.experts = DroplessMLP(
+            hidden_size=hidden_size,
+            ffn_hidden_size=ffn_hidden_size,
+            mlp_type=mlp_type,
+            moe_num_experts=moe_num_experts,
+            activation_fn=activation_fn,
+            bias=bias,
+            device=device,
+        )
+
+    def forward(self, x: torch.Tensor):
+        # Compute the expert scores and assignments.
+        scores, expert_weights, top_experts = self.router(x)
+        # Compute the experts.
+        return self.experts(x, scores, expert_weights, top_experts)
diff --git a/llmfoundry/models/layers/fc.py b/llmfoundry/models/layers/fc.py
index b85bc133bd..8650e4966f 100644
--- a/llmfoundry/models/layers/fc.py
+++ b/llmfoundry/models/layers/fc.py
@@ -3,12 +3,12 @@
 
 from torch import nn
 
-FC_CLASS_REGISTRY = {
-    'torch': nn.Linear,
-}
+from llmfoundry.layers_registry import fcs
+
+fcs.register('torch', func=nn.Linear)
 
 try:
     import transformer_engine.pytorch as te
-    FC_CLASS_REGISTRY['te'] = te.Linear
+    fcs.register('te', func=te.Linear)
 except:
     pass
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index 9389cf385f..f0b499875a 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -6,17 +6,26 @@
 import logging
 from copy import deepcopy
 from functools import partial
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, List, Optional, Union
 
 import torch
 import torch.nn as nn
+from torch.distributed._tensor import DeviceMesh, DTensor, Placement, Shard
 
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.models.layers.dmoe import dMoE
+from llmfoundry.models.layers.layer_builders import build_fc
 
 try:
     import transformer_engine.pytorch as te
-except:
-    te = None
+    is_te_imported = True
+except ModuleNotFoundError:
+    is_te_imported = False
+
+try:
+    import megablocks
+    is_megablocks_imported = True
+except ModuleNotFoundError:
+    is_megablocks_imported = False
 
 log = logging.getLogger(__name__)
 
@@ -43,7 +52,7 @@ def resolve_ffn_act_fn(
     config = deepcopy(config)
     name = config.pop('name')
     if not hasattr(torch.nn.functional, name):
-        raise ValueError(f'Unrecognised activation function name ({name}).')
+        raise ValueError(f'Unrecognized activation function name ({name}).')
     act = getattr(torch.nn.functional, name)
     return partial(act, **config)
 
@@ -79,6 +88,18 @@ def resolve_ffn_hidden_size(
     return ffn_hidden_size
 
 
+def dtensorify_param(param: nn.Parameter, mesh: DeviceMesh,
+                     placements: List[Placement]):
+    """Construct a DTensor from an already sharded local parameter."""
+    param_dtensor = DTensor.from_local(
+        param.data,
+        device_mesh=mesh,
+        placements=placements,
+        run_check=False,
+    )
+    return nn.Parameter(param_dtensor)
+
+
 class MPTMLP(nn.Module):
 
     def __init__(
@@ -100,16 +121,18 @@ def __init__(
 
         self.fc_kwargs['device'] = device
 
-        self.up_proj = FC_CLASS_REGISTRY[fc_type](
-            d_model,
-            ffn_hidden_size,
-            **self.fc_kwargs,
+        self.up_proj = build_fc(
+            name=fc_type,
+            in_features=d_model,
+            out_features=ffn_hidden_size,
+            fc_kwargs=self.fc_kwargs,
         )
         self.act = act_fn
-        self.down_proj = FC_CLASS_REGISTRY[fc_type](
-            ffn_hidden_size,
-            d_model,
-            **self.fc_kwargs,
+        self.down_proj = build_fc(
+            name=fc_type,
+            in_features=ffn_hidden_size,
+            out_features=d_model,
+            fc_kwargs=self.fc_kwargs,
         )
         self.down_proj._is_residual = True
 
@@ -138,13 +161,13 @@ def __init__(
             device=device,
             bias=bias,
         )
-        self.gate_proj = FC_CLASS_REGISTRY[fc_type](
-            d_model,
-            self.up_proj.out_features,
-            **self.fc_kwargs,
+        self.gate_proj = build_fc(
+            name=fc_type,
+            in_features=d_model,
+            out_features=self.up_proj.out_features,
+            fc_kwargs=self.fc_kwargs,
         )
 
-    @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
 
@@ -152,12 +175,20 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 FFN_CLASS_REGISTRY = {
     'mptmlp': MPTMLP,
     'mptglu': MPTGLU,
+    'torch_dmoe': dMoE,
 }
 
-if te is not None:
+if is_te_imported:
+    import transformer_engine.pytorch as te
     te.LayerNormMLP._has_norm = True
     FFN_CLASS_REGISTRY['te_ln_mlp'] = te.LayerNormMLP
 
+if is_megablocks_imported:
+    import megablocks
+
+    FFN_CLASS_REGISTRY['mb_moe'] = megablocks.layers.moe.MoE
+    FFN_CLASS_REGISTRY['mb_dmoe'] = megablocks.layers.dmoe.dMoE
+
 
 def build_ffn(
     d_model: int,
@@ -185,7 +216,10 @@ def build_ffn(
             bias=bias,
         )
     elif ffn_type == 'te_ln_mlp':
-        assert te is not None
+        if not is_te_imported:
+            raise RuntimeError(
+                'Requirements for TransformerEngine not installed; see install instructions in `README.md`.'
+            )
         ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
                                                   ffn_hidden_size)
         if ffn_act_fn is not None:
@@ -198,5 +232,99 @@ def build_ffn(
             bias=bias,
             **kwargs,
         )
+    elif ffn_type in ('mb_moe', 'mb_dmoe'):
+        if not is_megablocks_imported:
+            raise RuntimeError(
+                'Requirements for megablocks not installed; see install instructions in `README.md`.'
+            )
+        args = kwargs['args']
+        args.bias = bias
+        args.hidden_size = d_model
+        args.device = device
+
+        ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio,
+                                                  ffn_hidden_size)
+        args.ffn_hidden_size = ffn_hidden_size
+
+        if ffn_act_fn is not None:
+            args.activation_fn = resolve_ffn_act_fn(ffn_act_fn)
+
+        moe_world_size = 1
+        expert_parallel_group = args.expert_parallel_group
+        if expert_parallel_group is not None:
+            moe_world_size = expert_parallel_group.size()
+        if kwargs.get('moe_world_size') != moe_world_size:
+            raise RuntimeError(
+                'MoE expert_parallel_group configured with incorrect world size: '
+                f'expected {kwargs.get("moe_world_size")}, got {moe_world_size}.')
+
+        if ffn_type == 'mb_moe':
+            ffn = megablocks.layers.moe.MoE(args)
+
+            # Fused initialization setup
+            # For param_init_fn, enables shape based init of stacked layers
+            ffn.experts.mlp._stack_dim = 0
+        elif ffn_type == 'mb_dmoe':
+            ffn = megablocks.layers.dmoe.dMoE(args)
+
+            # Fused initialization setup
+            # For param_init_fn, enables shape based init of fused layers
+            n_exp = max(1, args.moe_num_experts // moe_world_size)
+            ffn.experts.mlp._fused = (0, [
+                (n + 1) * args.ffn_hidden_size for n in range(n_exp - 1)
+            ])
+        else:
+            raise RuntimeError(f'Invalid ffn_type option: {ffn_type}.')
+
+        # Attach args to MLP directly for use in param_init_fn
+        ffn.experts.mlp.hidden_size = args.ffn_hidden_size
+        ffn.experts.mlp.expert_parallel_group = expert_parallel_group
+        ffn.experts.mlp.weight_parallel_group = args.weight_parallel_group
+
+        if moe_world_size > 1:
+            device_mesh = kwargs['device_mesh']
+
+            expert_mesh = device_mesh['expert_parallel']
+            expert_placements: List[Placement] = [Shard(0)]
+            # Register in two loops as you cannot overwrite parameters while iterating over named_parameters()
+            dtensorified_params = [
+                (name,
+                 dtensorify_param(param=parameter,
+                                  mesh=expert_mesh,
+                                  placements=expert_placements))
+                for name, parameter in ffn.experts.mlp.named_parameters()
+            ]
+            for name, dtensorified_param in dtensorified_params:
+                ffn.experts.mlp.register_parameter(name, dtensorified_param)
+
+            device_mesh = kwargs['device_mesh']
+            if device_mesh.mesh.ndim == 2:
+                submesh = device_mesh['weight_parallel']
+            elif device_mesh.mesh.ndim == 3:
+                raise RuntimeError('HSDP + MoE is not supported.')
+            else:
+                raise ValueError(
+                    f'{device_mesh.mesh.ndim=} not supported for MoE.')
+
+            ffn.experts._fsdp_kwargs_dict = {
+                'device_mesh': submesh,
+            }
+        return ffn
+    elif ffn_type == 'torch_dmoe':
+        return dMoE(
+            hidden_size=d_model,
+            ffn_hidden_size=resolve_ffn_hidden_size(d_model, expansion_ratio,
+                                                    ffn_hidden_size),
+            moe_num_experts=kwargs.pop('moe_num_experts'),
+            moe_top_k=kwargs.pop('moe_top_k'),
+            mlp_type=kwargs.pop('mlp_type'),
+            bias=bias,
+            moe_jitter_eps=kwargs.pop('moe_jitter_eps'),
+            activation_fn=resolve_ffn_act_fn(ffn_act_fn),
+            moe_normalize_expert_weights=kwargs.pop(
+                'moe_normalize_expert_weights'),
+            uniform_expert_assignment=kwargs.pop('uniform_expert_assignment'),
+            device=device,  # pyright: ignore[reportGeneralTypeIssues]
+        )
 
     raise ValueError(f'{ffn_type=} not recognized.')
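
As an aside, a minimal sketch (not part of the patch) of what `dtensorify_param` does for the expert weights above: it wraps an already-sharded local tensor in a `DTensor` over the expert-parallel mesh, so the global (logical) shape is visible while each rank keeps only its shard. The mesh name and sizes below are illustrative; in `build_ffn` the mesh actually comes from `kwargs['device_mesh']['expert_parallel']`.

# Sketch only; launch with `torchrun --nproc_per_node=2 <file>.py` on torch >= 2.2.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._tensor.device_mesh import init_device_mesh

dist.init_process_group(backend='gloo')
world_size = dist.get_world_size()

# 1D mesh standing in for the 'expert_parallel' sub-mesh used in build_ffn.
mesh = init_device_mesh('cpu', (world_size,),
                        mesh_dim_names=('expert_parallel',))

# Each rank holds its local shard of the fused expert weight, sharded on dim 0.
local_weight = torch.randn(4, 8)
dtensor_weight = DTensor.from_local(local_weight,
                                    device_mesh=mesh,
                                    placements=[Shard(0)],
                                    run_check=False)

# Logical shape is (world_size * 4, 8); the local shard stays (4, 8).
print(dtensor_weight.shape, dtensor_weight.to_local().shape)
dist.destroy_process_group()
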
diff --git a/llmfoundry/models/layers/layer_builders.py b/llmfoundry/models/layers/layer_builders.py
index 23f5b89668..6a725d469a 100644
--- a/llmfoundry/models/layers/layer_builders.py
+++ b/llmfoundry/models/layers/layer_builders.py
@@ -1,11 +1,11 @@
 # Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union
 
 import torch
 
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import attention_classes, fcs, norms
 from llmfoundry.utils.registry_utils import construct_from_registry
 
 
@@ -23,3 +23,31 @@ def build_norm(
                                    registry=norms,
                                    pre_validation_function=torch.nn.Module,
                                    kwargs=kwargs)
+
+
+def build_attention_layer(
+    name: str,
+    attn_kwargs: Dict[str, Any],
+):
+    return construct_from_registry(name=name,
+                                   registry=attention_classes,
+                                   pre_validation_function=torch.nn.Module,
+                                   kwargs=attn_kwargs)
+
+
+def build_fc(
+    name: str,
+    in_features: int,
+    out_features: int,
+    fc_kwargs: Dict[str, Any],
+):
+    kwargs = {
+        'in_features': in_features,
+        'out_features': out_features,
+        **fc_kwargs,
+    }
+
+    return construct_from_registry(name=name,
+                                   registry=fcs,
+                                   pre_validation_function=torch.nn.Module,
+                                   kwargs=kwargs)
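
A minimal sketch of how the new `build_fc` helper is used by `MPTMLP` above. It assumes the `fcs` registry contains a `'torch'` entry backed by `torch.nn.Linear`, mirroring the old `FC_CLASS_REGISTRY['torch']`; if your install registers different names, substitute accordingly.

from llmfoundry.models.layers.layer_builders import build_fc

# Build an up-projection the same way MPTMLP now does, via the registry.
up_proj = build_fc(
    name='torch',  # assumed registry key for nn.Linear
    in_features=512,
    out_features=2048,
    fc_kwargs={'bias': True, 'device': 'cpu'},
)
print(type(up_proj))  # expected: torch.nn.Linear under the assumption above
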
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 2f58ea312e..4b98fa611d 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -16,10 +16,10 @@
 # HuggingFace can detect all the needed files to copy into its modules folder.
 # Otherwise, certain modules are missing.
 # isort: off
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY  # type: ignore (see note)
 from llmfoundry.models.layers.norm import LPLayerNorm  # type: ignore (see note)
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY  # type: ignore (see note)
-from llmfoundry.models.layers.layer_builders import build_norm  # type: ignore (see note)
+from llmfoundry.models.layers.layer_builders import build_norm, build_fc  # type: ignore (see note)
+from llmfoundry.models.layers.dmoe import dMoE  # type: ignore (see note)
 from llmfoundry.layers_registry import norms  # type: ignore (see note)
 from llmfoundry.utils.registry_utils import construct_from_registry  # type: ignore (see note)
 
@@ -290,6 +290,8 @@ def _validate_config(self) -> None:
             )
         elif self.ffn_config['ffn_type'] in ['mptmlp', 'mptglu']:
             self.ffn_config['fc_type'] = self.fc_type
+        elif self.ffn_config['ffn_type'] in ['mb_moe', 'mb_dmoe']:
+            self.ffn_config['return_bias'] = False
         elif self.ffn_config['ffn_type'] == 'te_ln_mlp':
             self.ffn_config['bias'] = not self.no_bias
             if 'ffn_act_fn' in self.ffn_config.keys():
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index d54b797269..4a8f3943af 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -10,6 +10,7 @@
 
 import math
 import warnings
+from functools import cached_property
 from typing import (Any, Dict, List, Mapping, MutableMapping, Optional, Tuple,
                     Union)
 
@@ -49,6 +50,9 @@
 from llmfoundry.models.layers.ffn import build_ffn as build_ffn
 from llmfoundry.models.layers.layer_builders import build_norm
 from llmfoundry.models.mpt.configuration_mpt import MPTConfig
+from llmfoundry.models.utils.config_moe_args import config_moe_args
+from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params,
+                                                     mpt_get_total_params)
 
 # NOTE: All utils are imported directly even if unused so that
 # HuggingFace can detect all the needed files to copy into its modules folder.
@@ -276,6 +280,8 @@ def _fsdp_wrap_fn(
     module: nn.Module,
 ) -> bool:
     # FSDP Wrap function for MPT Models
+    if hasattr(module, '_fsdp_kwargs_dict'):
+        return module._fsdp_kwargs_dict
     return isinstance(module, MPTBlock)
 
 
@@ -316,10 +322,20 @@ def __init__(self, config: MPTConfig):
                                           config.d_model,
                                           device=config.init_device)
         self.emb_drop = nn.Dropout(config.emb_pdrop)
+        self.mb_args = None
+        block_args = config.to_dict()
+        if block_args['ffn_config']['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+            block_args['ffn_config'] = config_moe_args(
+                block_args['ffn_config'],
+                config.d_model,
+                config.expansion_ratio,
+                config.n_layers,
+            )
+            self.mb_args = block_args['ffn_config'].get('args')
         self.blocks = nn.ModuleList([
             MPTBlock(
                 device=config.init_device,
-                **config.to_dict(),
+                **block_args,
             ) for _ in range(config.n_layers)
         ])
 
@@ -980,8 +996,6 @@ def __init__(
             allow_embedding_resizing=True,
         )
 
-        self.n_active_params = sum(p.numel() for p in self.parameters())
-
         loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
         if loss_fn_config == 'fused_crossentropy':
             try:
@@ -1012,6 +1026,15 @@ def get_targets(self, batch: Mapping) -> torch.Tensor:
         return targets
 
     def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
+        if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+            # Clear MegaBlocks MoE load balancing loss cache
+            try:  # Add try/catch to avoid transformers complaining and raising errors
+                from megablocks.layers.moe import clear_load_balancing_loss
+            except:
+                raise RuntimeError(
+                    'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
+                )
+            clear_load_balancing_loss()
         return self.model(
             input_ids=batch.get('input_ids', None),
             attention_mask=batch.get('attention_mask', None),
@@ -1020,7 +1043,7 @@ def forward(self, batch: MutableMapping) -> CausalLMOutputWithPast:
         )
 
     def loss(self, outputs: CausalLMOutputWithPast,
-             batch: Mapping) -> torch.Tensor:
+             batch: Mapping) -> Union[dict, torch.Tensor]:
         targets = self.get_targets(batch)
         losses = self.loss_fn(outputs.logits.view(-1, outputs.logits.size(-1)),
                               targets.view(-1))
@@ -1030,18 +1053,40 @@ def loss(self, outputs: CausalLMOutputWithPast,
         else:
             loss = losses.sum() / (targets != self.loss_fn.ignore_index).sum()
 
+        if self.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+            # MegaBlocks MoE load balancing loss
+            try:  # Add try/catch to avoid transformers complaining and raising errors
+                from megablocks.layers.moe import batched_load_balancing_loss
+            except:
+                raise RuntimeError(
+                    'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
+                )
+            lbl = batched_load_balancing_loss(self.model.transformer.mb_args)
+            return {
+                'total': loss + lbl,
+                'loss': loss,
+                'lbl': lbl,
+            }
+
         return loss
 
-    def flops_per_batch(self, batch: Mapping) -> int:
+    @cached_property
+    def n_total_params(self):
+        """Gets the total number of parameters in the model."""
+        return mpt_get_total_params(self)
+
+    @cached_property
+    def n_active_params(self):
+        """Gets the total number of active parameters in the model."""
+        return mpt_get_active_params(self)
+
+    def flops_per_batch(self, batch: Mapping):
         # Note: this computation does not take into account padding, and assumes
         # that the dataset has been constructed without padding. Additionally, we
         # assume the backward pass is approximately 2x the forward pass
 
         bs, msl = batch['input_ids'].shape[0:2]
         params = self.n_active_params
-        if not self.model.transformer.config.tie_word_embeddings:
-            # embedding layers are lookup tables, therefore are not counted in the FLOP computation
-            params -= self.model.transformer.wte.weight.numel()
         params_flops_per_token = 2 * params
         params_flops_per_seq = params_flops_per_token * msl
         attn_flops_per_seq = (self.model.config.n_layers * 2 * 2 *
diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py
index 7c808ff449..41313b8729 100644
--- a/llmfoundry/models/utils/__init__.py
+++ b/llmfoundry/models/utils/__init__.py
@@ -1,8 +1,11 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from llmfoundry.models.utils.config_moe_args import config_moe_args
 from llmfoundry.models.utils.meta_init_context import (init_empty_weights,
                                                        init_on_device)
+from llmfoundry.models.utils.mpt_param_count import (mpt_get_active_params,
+                                                     mpt_get_total_params)
 from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
                                                     generic_param_init_fn_)
 
@@ -11,4 +14,7 @@
     'init_on_device',
     'generic_param_init_fn_',
     'MODEL_INIT_REGISTRY',
+    'config_moe_args',
+    'mpt_get_active_params',
+    'mpt_get_total_params',
 ]
diff --git a/llmfoundry/models/utils/act_ckpt.py b/llmfoundry/models/utils/act_ckpt.py
index bde7c92bd7..fea68492c1 100644
--- a/llmfoundry/models/utils/act_ckpt.py
+++ b/llmfoundry/models/utils/act_ckpt.py
@@ -5,9 +5,8 @@
 
 import torch
 
-from llmfoundry.layers_registry import norms
-from llmfoundry.models.layers.attention import ATTN_CLASS_REGISTRY
-from llmfoundry.models.layers.blocks import MPTBlock
+from llmfoundry.layers_registry import attention_classes, norms
+from llmfoundry.models.layers.blocks import FusedNormAttentionNorm, MPTBlock
 from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY
 
 
@@ -25,16 +24,19 @@ def get_act_ckpt_module(mod_name: str) -> Any:
     """Get the module type from the module name."""
     if mod_name.lower() == 'mptblock':
         mod_type = MPTBlock
-    elif mod_name in ATTN_CLASS_REGISTRY:
-        mod_type = ATTN_CLASS_REGISTRY[mod_name]
+    elif mod_name in attention_classes:
+        mod_type = attention_classes.get(mod_name)
+    elif mod_name.lower() == 'norm_attn_norm':
+        mod_type = FusedNormAttentionNorm
     elif mod_name in FFN_CLASS_REGISTRY:
         mod_type = FFN_CLASS_REGISTRY[mod_name]
     elif mod_name in norms:
         mod_type = norms.get(mod_name)
     else:
         msg = ', '.join(
-            list(ATTN_CLASS_REGISTRY.keys()) + list(FFN_CLASS_REGISTRY.keys()) +
-            list(norms.get_all()) + ['MPTBlock'])
+            list(attention_classes.get_all()) +
+            list(FFN_CLASS_REGISTRY.keys()) + list(norms.get_all()) +
+            ['MPTBlock'])
         raise ValueError(
             f'{mod_name} (specified in activation_checkpointing_target) is not a recognized option out of available options {msg}.'
         )
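
For reference, a small sketch of how the name resolution above is exercised when `activation_checkpointing_target` is set in a model config: `'mptblock'` and the new `'norm_attn_norm'` option resolve to module classes, and unknown names raise a `ValueError` listing the available options.

from llmfoundry.models.utils.act_ckpt import get_act_ckpt_module

block_cls = get_act_ckpt_module('mptblock')        # -> MPTBlock
fused_cls = get_act_ckpt_module('norm_attn_norm')  # -> FusedNormAttentionNorm
print(block_cls.__name__, fused_cls.__name__)
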
diff --git a/llmfoundry/models/utils/config_moe_args.py b/llmfoundry/models/utils/config_moe_args.py
new file mode 100644
index 0000000000..1f7132c281
--- /dev/null
+++ b/llmfoundry/models/utils/config_moe_args.py
@@ -0,0 +1,188 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Helper function to configure MPT with MoEs."""
+
+from typing import Union
+
+import torch
+from packaging import version
+from torch import distributed
+
+from llmfoundry.models.layers.ffn import resolve_ffn_hidden_size
+
+
+def create_process_group_ranks(ranks: tuple[int]):
+    """Creates a new distributed group.
+
+    Used in create_set_process_group and create_mod_process_group methods below.
+
+    This function is an alternative to `distributed.new_group(ranks)`.
+
+    Args:
+        ranks (tuple[int]): Tuple of ranks of group members.
+
+    Returns:
+        A handle of distributed group that can be given to collective calls.
+    """
+    ranks_gather_list = [None for _ in range(distributed.get_world_size())]
+    distributed.all_gather_object(ranks_gather_list, ranks)
+    ranks_per_subgroup = list(set(ranks_gather_list))
+    group, _ = distributed.distributed_c10d.new_subgroups_by_enumeration(
+        ranks_per_subgroup)
+    return group
+
+
+def create_set_process_group(k: int):
+    """Creates a new distributed group using sets of k GPUs.
+
+    For example, if you have 16 GPUs and input k=4, the resulting process groups
+    will have ranks:
+        process group 0 ranks: [ 0,  1,  2,  3]
+        process group 1 ranks: [ 4,  5,  6,  7]
+        process group 2 ranks: [ 8,  9, 10, 11]
+        process group 3 ranks: [12, 13, 14, 15]
+
+    Args:
+        k (int): Number of GPUs to use in set size.
+
+    Returns:
+        A handle of distributed group that can be given to collective calls.
+    """
+    world_size = distributed.get_world_size()
+    if world_size % k != 0:
+        raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
+    start = distributed.get_rank() // k * k
+    ranks = tuple(range(start, start + k))
+    return create_process_group_ranks(ranks)
+
+
+def config_megablocks_moe_args(
+    ffn_config: dict,
+    d_model: int,
+    expansion_ratio: Union[int, float],
+    n_layers: int,
+) -> dict:
+    """Configures `ffn_config` for MegaBlocks MoE.
+
+    We prepare all necessary arguments for `megablocks.layers.arguments.Arguments` so that process
+    groups can be initialized and shared across all blocks in the network.
+
+    Args:
+        ffn_config (dict): FFN configuration before the MegaBlocks MoE is configured.
+        d_model (int): Hidden size of the network.
+        expansion_ratio (Union[int, float]): Expansion ratio in FFN.
+        n_layers (int): Number of blocks used in the network.
+
+    Returns:
+        ffn_config (dict): FFN configuration with MegaBlocks MoE configured.
+    """
+    try:
+        import megablocks
+    except:
+        raise RuntimeError(
+            'Requirements for MegaBlocks not installed; see install instructions in `README.md`.'
+        )
+
+    ffn_config.setdefault('fp16', False)
+    ffn_config.setdefault('bf16', False)
+    ffn_config['num_layers'] = n_layers
+
+    ffn_type = ffn_config.pop('ffn_type')
+    fc_type = ffn_config.pop('fc_type')
+    ffn_act_fn = ffn_config.pop('ffn_act_fn', None)
+
+    # Config for MegaBlocks MoE world size and device mesh
+    world_size = 1  # default
+    moe_world_size = ffn_config.pop('moe_world_size')
+    device_mesh = None
+    device_mesh_cfg = ffn_config.pop('device_mesh', None)
+    if moe_world_size > 1:
+        if version.parse(torch.__version__.split('.dev')[0]) < version.parse(
+                '2.2.0'):  # type: ignore
+            raise RuntimeError(
+                f'MoE world size > 1 is not supported in torch version {torch.__version__}<2.2.'
+            )
+
+        from torch.distributed._tensor.device_mesh import init_device_mesh
+
+        world_size = distributed.get_world_size()
+        if world_size < moe_world_size or world_size % moe_world_size:
+            raise ValueError(
+                f'Invalid world size configuration: {world_size=} and {moe_world_size=}'
+            )
+
+        # FSDP
+        if device_mesh_cfg is None or len(device_mesh_cfg) == 1:
+            if device_mesh_cfg is not None:
+                world_size = device_mesh_cfg[0]
+            sharding_group_dim = world_size // moe_world_size
+            device_mesh = init_device_mesh(
+                'cuda',
+                (sharding_group_dim, moe_world_size),
+                mesh_dim_names=('weight_parallel', 'expert_parallel'),
+            )
+        else:
+            raise ValueError(f'{device_mesh_cfg=} must be length 1')
+
+        ffn_config['moe_expert_model_parallelism'] = True
+        ffn_config['expert_parallel_group'] = device_mesh[
+            'expert_parallel'].get_group(0)  # type: ignore
+
+    lbl_process_group = ffn_config.get('lbl_process_group', None)
+    if lbl_process_group is not None:
+        if lbl_process_group == 'expert_group':
+            lbl_process_group = ffn_config['expert_parallel_group']
+        elif lbl_process_group == 'global_group':
+            lbl_process_group = distributed.group.WORLD
+        elif isinstance(lbl_process_group, int):
+            lbl_process_group = create_set_process_group(lbl_process_group)
+        else:
+            raise ValueError(
+                f'Unknown {lbl_process_group=}. Options are: none | expert_group | global_group | <GROUP_SIZE>.'
+            )
+        ffn_config['lbl_process_group'] = lbl_process_group
+
+    ffn_hidden_size = resolve_ffn_hidden_size(d_model, expansion_ratio)
+    ffn_config.setdefault('ffn_hidden_size', ffn_hidden_size)
+
+    args = megablocks.layers.arguments.Arguments(
+        hidden_size=d_model,
+        **ffn_config,
+    )
+    ffn_config['args'] = args
+    ffn_config['device_mesh'] = device_mesh
+    ffn_config['moe_world_size'] = moe_world_size
+    ffn_config['ffn_type'] = ffn_type
+    ffn_config['fc_type'] = fc_type
+    ffn_config['ffn_act_fn'] = ffn_act_fn
+
+    return ffn_config
+
+
+def config_moe_args(
+    ffn_config: dict,
+    d_model: int,
+    expansion_ratio: Union[int, float],
+    n_layers: int,
+) -> dict:
+    """Configures `ffn_config` for MoE.
+
+    Args:
+        ffn_config (dict): FFN configuration before the MoE is configured.
+        d_model (int): Hidden size of the network.
+        expansion_ratio (Union[int, float]): Expansion ratio in FFN.
+        n_layers (int): Number of blocks used in the network.
+
+    Returns:
+        ffn_config (dict): FFN configuration with MoE configured.
+    """
+    if ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+        return config_megablocks_moe_args(
+            ffn_config=ffn_config,
+            d_model=d_model,
+            expansion_ratio=expansion_ratio,
+            n_layers=n_layers,
+        )
+    else:
+        raise ValueError(f'Invalid ffn_type ({ffn_config["ffn_type"]}).')
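
A quick sketch of the rank grouping that `create_set_process_group(k)` builds, computed here without initializing `torch.distributed`; the helper itself wraps this arithmetic in `new_subgroups_by_enumeration`.

def set_group_ranks(world_size: int, k: int) -> list[tuple[int, ...]]:
    # Same divisibility check as create_set_process_group.
    if world_size % k != 0:
        raise RuntimeError(f'{world_size=} must be divisible by {k=}.')
    return [tuple(range(start, start + k)) for start in range(0, world_size, k)]

# Matches the docstring example: 16 GPUs, k=4 -> four consecutive sets of ranks.
print(set_group_ranks(16, 4))
# [(0, 1, 2, 3), (4, 5, 6, 7), (8, 9, 10, 11), (12, 13, 14, 15)]
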
diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py
index c22c226c28..d72a289a73 100644
--- a/llmfoundry/models/utils/meta_init_context.py
+++ b/llmfoundry/models/utils/meta_init_context.py
@@ -21,6 +21,7 @@
 
 import torch
 import torch.nn as nn
+from torch.distributed._tensor import DTensor
 
 
 @contextmanager
@@ -86,11 +87,13 @@ def register_empty_parameter(self: torch.nn.Module, name: str,
         if param is not None:
             parameter = self._parameters[name]
             assert parameter is not None
-
-            param_cls = type(parameter)
-            kwargs = parameter.__dict__
-
-            self._parameters[name] = param_cls(parameter.to(device), **kwargs)
+            if isinstance(parameter, DTensor):
+                self._parameters[name] = parameter.to(device)  # type: ignore
+            else:
+                param_cls = type(parameter)
+                kwargs = parameter.__dict__
+                self._parameters[name] = param_cls(parameter.to(device),
+                                                   **kwargs)
 
     def register_empty_buffer(self: torch.nn.Module,
                               name: str,
diff --git a/llmfoundry/models/utils/mpt_param_count.py b/llmfoundry/models/utils/mpt_param_count.py
new file mode 100644
index 0000000000..d90929713b
--- /dev/null
+++ b/llmfoundry/models/utils/mpt_param_count.py
@@ -0,0 +1,167 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Helper functions for computing parameter counts for MPT model.
+
+Use these when the generic `sum(p.numel() for p in self.parameters())`
+style computation does not account for MoE parameter sharding.
+The helper functions in this file include MoE parameter sharding
+in the parameter count calculation, computing both the total
+parameter count and the active parameter count.
+Note: the MPT Composer model exposes both n_total_params and n_active_params.
+"""
+
+from typing import Union
+
+from torch import Tensor, nn
+from torch.distributed._tensor import DTensor
+
+
+def module_n_params(module: nn.Module) -> int:
+    """Gets the number of parameters in this module excluding child modules.
+
+    Args:
+        module (nn.Module): Module of which we get the number of parameters.
+
+    Returns:
+        An int for the number of parameters in this module.
+    """
+    n_params = 0
+    for p in module.parameters(recurse=False):
+        n_params += p.numel()
+    return n_params
+
+
+def _dtensor_safe_check_numel(tensor: Union[Tensor, DTensor]) -> int:
+    if isinstance(tensor, DTensor):
+        tensor = tensor._local_tensor
+    return tensor.numel()
+
+
+def megablocks_n_total_params(mpt_model) -> int:  # type: ignore
+    """Calculates the number of parameters in a MegaBlocks enabled MPT model.
+
+    MoE experts are sharded across workers. This function scans for MegaBlocks
+    modules, then multiplies each expert parameter count by the MoE world size.
+
+    Args:
+        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
+            parameters is calculated.
+
+    Returns:
+        An int for the total number of parameters in this MPT model.
+    """
+    import megablocks
+
+    moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
+
+    if mpt_model.config.ffn_config.get('moe_weight_parallelism', False):
+        # If MegaBlocks shards experts, the total sharding world size
+        # must be increased by the degree to which MegaBlocks shards the
+        # experts.
+        mb_args = mpt_model.model.transformer.mb_args
+        moe_world_size *= mb_args.weight_parallel_group.size()
+
+    n_total_params = 0
+    for module in mpt_model.modules():
+        if isinstance(
+                module,
+            (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
+            n_w1 = _dtensor_safe_check_numel(module.w1)
+            n_total_params += n_w1 * moe_world_size
+            n_w2 = _dtensor_safe_check_numel(module.w2)
+            n_total_params += n_w2 * moe_world_size
+
+            # GLU has an extra weight
+            if hasattr(module, 'v1'):
+                n_v1 = _dtensor_safe_check_numel(module.v1)
+                n_total_params += n_v1 * moe_world_size
+        else:
+            n_total_params += module_n_params(module)
+
+    return n_total_params
+
+
+def megablocks_n_active_params(mpt_model) -> int:  # type: ignore
+    """Calculates the number of active parameters in a MegaBlocks enabled MPT.
+
+    This requires calculating the number of parameters per expert and
+    multiplying by moe_top_k.
+
+    Args:
+        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
+            active parameters is calculated.
+
+    Returns:
+        An int for the active number of parameters in this MPT model.
+    """
+    import megablocks
+
+    moe_num_experts = mpt_model.config.ffn_config.get('moe_num_experts', 1)
+    moe_world_size = mpt_model.config.ffn_config.get('moe_world_size')
+
+    local_experts = moe_num_experts / moe_world_size  # if local_experts is < 1, then the expert is sharded
+    if mpt_model.config.ffn_config.get('moe_weight_parallelism', False):
+        mb_args = mpt_model.model.transformer.mb_args
+        local_experts /= mb_args.weight_parallel_group.size()
+
+    moe_top_k = mpt_model.config.ffn_config.get('moe_top_k', 1)
+    n_active_params = 0
+    for module in mpt_model.modules():
+        if isinstance(
+                module,
+            (megablocks.layers.mlp.SparseMLP, megablocks.layers.mlp.MLP)):
+            n_w1 = _dtensor_safe_check_numel(module.w1)
+            n_active_params += int(n_w1 / local_experts * moe_top_k)
+            n_w2 = _dtensor_safe_check_numel(module.w2)
+            n_active_params += int(n_w2 / local_experts * moe_top_k)
+
+            # GLU has an extra weight
+            if hasattr(module, 'v1'):
+                n_v1 = _dtensor_safe_check_numel(module.v1)
+                n_active_params += int(n_v1 / local_experts * moe_top_k)
+        else:
+            n_active_params += module_n_params(module)
+
+    return n_active_params
+
+
+def mpt_get_total_params(mpt_model) -> int:  # type: ignore
+    """Calculates the total paramter count of an MPT model.
+
+    Note: Must be called before model parameters are sharded by FSDP.
+
+    Args:
+        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
+            total parameters is calculated.
+
+    Returns:
+        An int for the total number of parameters in this MPT model.
+    """
+    if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+        return megablocks_n_total_params(mpt_model)
+    else:
+        return sum(p.numel() for p in mpt_model.parameters())
+
+
+def mpt_get_active_params(mpt_model) -> int:  # type: ignore
+    """Calculates the total paramter count of an MPT model.
+
+    Note: Must be called before model parameters are sharded by FSDP.
+
+    Args:
+        mpt_model (ComposerMPTCausalLM): MPT model of which the number of
+            active parameters is calculated.
+
+    Returns:
+        An int for the active number of parameters in this MPT model.
+    """
+    if mpt_model.config.ffn_config['ffn_type'] in ('mb_moe', 'mb_dmoe'):
+        params = megablocks_n_active_params(mpt_model)
+    else:
+        params = sum(p.numel() for p in mpt_model.parameters())
+    if not mpt_model.model.transformer.config.tie_word_embeddings:
+        # Embedding layers are lookup tables, therefore are not counted in the FLOP computation
+        params -= _dtensor_safe_check_numel(
+            mpt_model.model.transformer.wte.weight)
+    return params
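
As a back-of-envelope check of the counting rules above (not library code), consider a single MegaBlocks SparseMLP-style layer with assumed sizes: the total count scales the local expert shards back up by the MoE world size, while the active count depends only on `moe_top_k`.

d_model = 1024
ffn_hidden_size = 4096
moe_num_experts = 8
moe_top_k = 2
moe_world_size = 2  # experts sharded across 2 ranks

per_expert = ffn_hidden_size * d_model  # elements in one expert's w1 (or w2)
local_experts = moe_num_experts // moe_world_size

# Total: local w1/w2 shards, scaled back up by the MoE world size.
n_local = 2 * local_experts * per_expert  # w1 + w2 on this rank
n_total = n_local * moe_world_size        # 67,108,864

# Active: only moe_top_k experts run per token, regardless of sharding.
n_active = 2 * per_expert * moe_top_k     # 16,777,216

print(n_total, n_active)
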
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index 35dc88a408..bd409dee36 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -4,20 +4,27 @@
 import math
 import warnings
 from collections.abc import Sequence
+from copy import deepcopy
 from functools import partial
 from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 from torch import nn
+from torch.distributed._tensor import DTensor
 
-from llmfoundry.layers_registry import norms
-from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY
+from llmfoundry.layers_registry import fcs, norms
+from llmfoundry.models.layers.dmoe import GLU, MLP
 
 try:
     import transformer_engine.pytorch as te
 except:
     te = None
 
+try:
+    import megablocks
+except:
+    megablocks = None
+
 
 def torch_default_param_init_fn_(
     module: nn.Module,
@@ -30,27 +37,114 @@ def torch_default_param_init_fn_(
         module.reset_parameters()
 
 
-def fused_init_helper_(module: nn.Module, init_fn_: Callable) -> None:
-    # parameter initialization is often based on the parameters shape.
-    # If a layer is fused, initialization should be based on the shapes
-    # of the original tensor instead of the shape of the fused tensor.
-    # Layers which are fused should have the _fused attribute defined.
-    # The first element of _fused is the dimension along which the tensor is fused.
-    # This is followed by an iterable of split indices."
-
+def fused_init_helper_(
+    module: nn.Module,
+    init_fn_: Callable,
+    name_param: str = 'weight',
+):
+    """Initializes parameters which have been fused for efficiency purposes.
+
+    Parameter initialization is often based on the parameter's shape. If a layer is fused,
+    initialization should be based on the shapes of the original tensors instead of the
+    shape of the fused tensor. Fused layers should define the _fused attribute. The first
+    element of _fused is the dimension along which the tensor is fused; the second element
+    is an iterable of split indices.
+
+    Args:
+        module (nn.Module): The module to initialize.
+        init_fn_ (Callable): Initialization method.
+        name_param (str): Name of parameter to initialize within the module.
+    """
     _fused = getattr(module, '_fused', None)
-
     if _fused is None:
         raise RuntimeError(f'Internal logic error')
 
-    assert isinstance(module.weight, torch.Tensor)
+    fused_param_init_helper(getattr(module, name_param), init_fn_, _fused)
+
 
-    dim, splits = _fused
-    splits = (0, *splits, module.weight.size(dim))
+def fused_param_init_helper(
+    param: torch.Tensor,
+    init_fn_: Callable,
+    fused_parameters: tuple[int, list[int]],
+):
+    """Initializes parameters that are fused together.
+
+    Args:
+        param (torch.Tensor): Tensor to initialize.
+        init_fn_ (Callable): Initialization method.
+        fused_parameters (tuple[int, list[int]]): The first element is the dimension
+            along which the tensor is fused; the second element is an iterable of split indices.
+    """
+    p_ndims = param.ndim
+    dim, splits = fused_parameters
+    splits = (0, *splits, param.size(dim))  # type: ignore
     for s, e in zip(splits[:-1], splits[1:]):
-        slice_indices = [slice(None)] * module.weight.ndim
+        slice_indices = [slice(None)] * p_ndims  # type: ignore
         slice_indices[dim] = slice(s, e)
-        init_fn_(module.weight[slice_indices])
+        init_fn_(param[slice_indices])  # type: ignore
+
+
+def stacked_init_helper_(
+    module: nn.Module,
+    init_fn_: Callable,
+    name_param: str = 'weight',
+):
+    """Initializes parameters stacked along a new dimention.
+
+    Parameter initialization is often based on the parameters shape. If a layer is stacked,
+    initialization should be based on the shapes of the original tensor instead of the
+    shape of the stacked tensor. Layers which are fused should have the _stacked_dim
+    attribute defining the new dimension along which they are stacked.
+
+    Args:
+        module (nn.Module): The module to initialize.
+        init_fn_ (Callable): Initialization method.
+        name_param (str): Name of parameter to initialize within the module.
+    """
+    stack_dim = getattr(module, '_stack_dim', None)
+    if stack_dim is None:
+        raise RuntimeError('Internal logic error')
+
+    stacked_param_init_helper(getattr(module, name_param), init_fn_, stack_dim)
+
+
+def stacked_param_init_helper(
+    param: torch.Tensor,
+    init_fn_: Callable,
+    stack_dim: int,
+):
+    """Initialize parameters stacked along a new dimention.
+
+    Args:
+        param (torch.Tensor): Tensor to initialize.
+        init_fn_ (Callable): Initialization method.
+        stack_dim (int): Dimension along which parameters are stacked.
+    """
+    p_ndims = param.ndim
+
+    for idx in range(param.size(stack_dim)):
+        slice_indices = [slice(None)] * p_ndims  # type: ignore
+        slice_indices[stack_dim] = idx  # type: ignore
+        init_fn_(param[slice_indices])  # type: ignore
+
+
+def _flip_fan_mode(init_fn_: Callable):
+    """Changes the mode of an init_fn_.
+
+    init_fn_'s "mode" is set to operate on standard torch modules eg torch.nn.Linear.
+    If a custom layer transposes its weights before they are allied such that it is
+    opposite pytorch's conventions, we must flip the fan mode, from fan_in to fan_out.
+
+    Args:
+        init_fn_ (Callable): Initialization method.
+    """
+    _init_fn_ = deepcopy(init_fn_)
+    if 'mode' in _init_fn_.keywords:
+        if _init_fn_.keywords['mode'] == 'fan_in':
+            _init_fn_.keywords['mode'] = 'fan_out'
+        elif _init_fn_.keywords['mode'] == 'fan_out':
+            _init_fn_.keywords['mode'] = 'fan_in'
+    return _init_fn_
 
 
 def generic_param_init_fn_(
@@ -87,7 +181,7 @@ def generic_param_init_fn_(
             f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}'
         )
 
-    if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))):
+    if isinstance(module, tuple(set([fcs.get(n) for n in fcs.get_all()]))):
         # Linear
         if hasattr(module, '_fused'):
             fused_init_helper_(module, init_fn_)
@@ -191,6 +285,35 @@ def generic_param_init_fn_(
         with torch.no_grad():
             module.fc2_weight.div_(div_is_residual)  # type: ignore
 
+    elif megablocks is not None and isinstance(module, (
+            megablocks.layers.moe.MoE,
+            megablocks.layers.dmoe.dMoE,
+            megablocks.layers.moe.ParallelMLP,
+            megablocks.layers.dmoe.ParallelDroplessMLP,
+    )):
+        if hasattr(module, 'bias') and module.bias is not None:
+            # Initialize bias to 0
+            torch.nn.init.zeros_(module.bias)  # type: ignore
+    elif megablocks is not None and isinstance(module,
+                                               megablocks.layers.glu.SparseGLU):
+        _megablocks_sparse_glu_generic_param_init_fn_(
+            module, init_fn_, bool(init_div_is_residual), div_is_residual)
+    elif megablocks is not None and isinstance(module,
+                                               megablocks.layers.mlp.SparseMLP):
+        _megablocks_sparse_mlp_generic_param_init_fn_(
+            module, init_fn_, bool(init_div_is_residual), div_is_residual)
+    elif megablocks is not None and isinstance(module,
+                                               megablocks.layers.mlp.MLP):
+        _megablocks_mlp_generic_param_init_fn_(module, init_fn_,
+                                               bool(init_div_is_residual),
+                                               div_is_residual)
+    elif isinstance(module, GLU):
+        init_fn_(module.w1)
+        init_fn_(module.v1)
+        init_fn_(module.w2)
+    elif isinstance(module, MLP):
+        init_fn_(module.w1)
+        init_fn_(module.w2)
     else:
         for _ in module.parameters(recurse=False):
             # raise error if uninitialized module has any parameters
@@ -199,7 +322,197 @@ def generic_param_init_fn_(
             )
 
 
-def _normal_init_(std: float, mean: float = 0.0) -> Callable:
+def _megablocks_sparse_mlp_generic_param_init_fn_(
+    module: nn.Module,
+    init_fn_: Callable,
+    init_div_is_residual: bool = False,
+    div_is_residual: float = 1.0,
+):
+    """Initializes MegaBlocks MLP.
+
+    To enable elastic deterministic initialization, this method creates the entire
+    weight matrix then slice into the weight tensors such that the sampled weights
+    should not vary between moe world size for the same random seed.
+
+    Args:
+        module (nn.Module): The module to initialize.
+        init_fn_ (Callable): Initialization method.
+        init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual
+            flag to be divided by div_is_residual.
+        div_is_residual (float): The value by which parameter initialization is divided
+            if init_div_is_residual flag is enabled.
+    """
+    expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0
+    if module.expert_parallel_group is not None:
+        expert_process_group_size = int(
+            module.expert_parallel_group.size())  # type: ignore
+        rank = int(module.expert_parallel_group.rank())  # type: ignore
+    if module.weight_parallel_group is not None:
+        weight_parallel_group_size = int(
+            module.weight_parallel_group.size())  # type: ignore
+        weight_parallel_group_rank = int(
+            module.weight_parallel_group.rank())  # type: ignore
+
+    hidden_size = int(module.hidden_size)  # type: ignore
+
+    # Initialize w1
+    w1 = module.w1
+    if isinstance(w1, DTensor):
+        w1 = w1._local_tensor
+    w1_size = list(w1.shape)  # type: ignore
+    w1_size[
+        0] = w1_size[0] * expert_process_group_size * weight_parallel_group_size
+
+    n_exp = w1_size[0] // hidden_size
+    _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)])
+
+    _w1 = w1.new_empty(w1_size)  # type: ignore
+    fused_param_init_helper(_w1, init_fn_, _fused)
+    _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
+    _w1_local_slice = _w1_local.chunk(weight_parallel_group_size,
+                                      dim=0)[weight_parallel_group_rank]
+    with torch.no_grad():
+        w1.copy_(_w1_local_slice)  # type: ignore
+
+    # Initialize w2
+    w2 = module.w2
+    if isinstance(w2, DTensor):
+        w2 = w2._local_tensor
+    w2_size = list(w2.shape)  # type: ignore
+    w2_size[
+        0] = w2_size[0] * expert_process_group_size * weight_parallel_group_size
+    _w2 = w2.new_empty(w2_size)  # type: ignore
+    # MegaBlocks operates on w2 as x @ w2, so needs flipped fan mode
+    fused_param_init_helper(_w2, _flip_fan_mode(init_fn_), _fused)
+    _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
+    _w2_local_slice = _w2_local.chunk(weight_parallel_group_size,
+                                      dim=0)[weight_parallel_group_rank]
+    with torch.no_grad():
+        w2.copy_(_w2_local_slice)  # type: ignore
+    if init_div_is_residual is not False:
+        with torch.no_grad():
+            w2.div_(div_is_residual)  # type: ignore
+
+
+def _megablocks_sparse_glu_generic_param_init_fn_(
+    module: nn.Module,
+    init_fn_: Callable,
+    init_div_is_residual: bool = False,
+    div_is_residual: float = 1.0,
+):
+    """Initializes MegaBlocks Sparse GLU.
+
+    Extends the MegaBlocks SparseMLP case with the additional GLU weight v1.
+    This additional weight v1 follows the same initialization procedure as w1.
+
+    Args:
+        module (nn.Module): The module to initialize.
+        init_fn_ (Callable): Initialization method.
+        init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual
+            flag to be divided by div_is_residual.
+        div_is_residual (float): The value by which parameter initialization is divided
+            if init_div_is_residual flag is enabled.
+    """
+    # Init for w1 and w2 matrices
+    _megablocks_sparse_mlp_generic_param_init_fn_(
+        module=module,
+        init_fn_=init_fn_,
+        init_div_is_residual=init_div_is_residual,
+        div_is_residual=div_is_residual)
+
+    # Init ported from _megablocks_sparse_mlp_generic_param_init_fn_ for v1
+    expert_process_group_size, rank, weight_parallel_group_size, weight_parallel_group_rank = 1, 0, 1, 0
+    if module.expert_parallel_group is not None:
+        expert_process_group_size = int(
+            module.expert_parallel_group.size())  # type: ignore
+        rank = int(module.expert_parallel_group.rank())  # type: ignore
+    if module.weight_parallel_group is not None:
+        weight_parallel_group_size = int(
+            module.weight_parallel_group.size())  # type: ignore
+        weight_parallel_group_rank = int(
+            module.weight_parallel_group.rank())  # type: ignore
+
+    hidden_size = int(module.hidden_size)  # type: ignore
+
+    # Separately initialize v1
+    v1 = module.v1
+    if isinstance(v1, DTensor):
+        v1 = v1._local_tensor
+    v1_size = list(v1.shape)  # type: ignore
+    v1_size[
+        0] = v1_size[0] * expert_process_group_size * weight_parallel_group_size
+
+    n_exp = v1_size[0] // hidden_size
+    _fused = (0, [(n + 1) * hidden_size for n in range(n_exp - 1)])
+
+    _v1 = v1.new_empty(v1_size)  # type: ignore
+    fused_param_init_helper(_v1, init_fn_, _fused)
+    _v1_local = _v1.chunk(expert_process_group_size, dim=0)[rank]
+    _v1_local_slice = _v1_local.chunk(weight_parallel_group_size,
+                                      dim=0)[weight_parallel_group_rank]
+    with torch.no_grad():
+        v1.copy_(_v1_local_slice)  # type: ignore
+
+
+def _megablocks_mlp_generic_param_init_fn_(
+    module: nn.Module,
+    init_fn_: Callable,
+    init_div_is_residual: bool = False,
+    div_is_residual: float = 1.0,
+):
+    """Initializes MegaBlocks' MLP.
+
+    To enable elastic deterministic initialization, this method creates the entire
+    weight matrix, then slices out the local weight tensors, so that the sampled
+    weights do not vary with MoE world size for the same random seed.
+
+    Args:
+        module (nn.Module): The module to initialize.
+        init_fn_ (Callable): Initialization method.
+        init_div_is_residual (bool): Flag enabling parameters tagged with _is_residual
+            flag to be divided by div_is_residual.
+        div_is_residual (float): The value by which parameter initialization is divided
+            if init_div_is_residual flag is enabled.
+    """
+    expert_process_group_size, rank, weight_parallel_group_size, w_rank = 1, 0, 1, 0
+    if module.expert_parallel_group is not None:
+        expert_process_group_size = int(
+            module.expert_parallel_group.size())  # type: ignore
+        rank = int(module.expert_parallel_group.rank())  # type: ignore
+    if module.weight_parallel_group is not None:
+        weight_parallel_group_size = int(
+            module.weight_parallel_group.size())  # type: ignore
+        w_rank = int(module.weight_parallel_group.rank())  # type: ignore
+
+    _init_fn_ = _flip_fan_mode(init_fn_)
+
+    # Initialize w1
+    w1_size = list(module.w1.shape)  # type: ignore
+    w1_size[0] = w1_size[0] * expert_process_group_size
+    w1_size[1] = w1_size[1] * weight_parallel_group_size
+    _w1 = module.w1.new_empty(w1_size)  # type: ignore
+    stacked_param_init_helper(_w1, _init_fn_, module._stack_dim)  # type: ignore
+    _w1_local = _w1.chunk(expert_process_group_size, dim=0)[rank]
+    _w1_local_slice = _w1_local.chunk(weight_parallel_group_size, dim=1)[w_rank]
+    with torch.no_grad():
+        module.w1.copy_(_w1_local_slice)  # type: ignore
+
+    # Initialize w2
+    w2_size = list(module.w2.shape)  # type: ignore
+    w2_size[0] = w2_size[0] * expert_process_group_size
+    w2_size[1] = w2_size[1] * weight_parallel_group_size
+    _w2 = module.w2.new_empty(w2_size)  # type: ignore
+    stacked_param_init_helper(_w2, _init_fn_, module._stack_dim)  # type: ignore
+    _w2_local = _w2.chunk(expert_process_group_size, dim=0)[rank]
+    _w2_local_slice = _w2_local.chunk(weight_parallel_group_size, dim=1)[w_rank]
+    with torch.no_grad():
+        module.w2.copy_(_w2_local_slice)  # type: ignore
+    if init_div_is_residual is not False:
+        with torch.no_grad():
+            module.w2.div_(div_is_residual)  # type: ignore
+
+
+def _normal_init_(std: float, mean: float = 0.0):
     return partial(torch.nn.init.normal_, mean=mean, std=std)
 
 
@@ -263,8 +576,8 @@ def small_param_init_fn_(
     **kwargs: Any,
 ) -> None:
     del kwargs  # unused, just to capture any extra args from the config
-    # very close to kaiming normal
-    # from Transformers without Tears (2019) - Nguyen & Salazar
+    # Very close to kaiming normal
+    # From Transformers without Tears (2019) - Nguyen & Salazar
     std = math.sqrt(2 / (5 * d_model))
     _normal_param_init_fn_(
         module=module,
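
To make the fused-init convention concrete, here is a standalone sketch (mirroring `fused_param_init_helper`) where two experts of shape (4, 8) are fused along dim 0 into one (8, 8) weight and each slice is initialized with the statistics of the unfused shape.

from functools import partial

import torch

init_fn_ = partial(torch.nn.init.kaiming_normal_, mode='fan_in')

fused = torch.empty(8, 8)
dim, split_indices = 0, [4]  # the `_fused` convention: (dim, split indices)
splits = (0, *split_indices, fused.size(dim))
for s, e in zip(splits[:-1], splits[1:]):
    idx = [slice(None)] * fused.ndim
    idx[dim] = slice(s, e)
    init_fn_(fused[tuple(idx)])  # each expert slice gets fan computed from (4, 8)
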
diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
index 424075da3b..ef6629be10 100644
--- a/llmfoundry/registry.py
+++ b/llmfoundry/registry.py
@@ -12,7 +12,8 @@
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.interfaces import CallbackWithConfig
-from llmfoundry.layers_registry import norms
+from llmfoundry.layers_registry import (attention_classes,
+                                        attention_implementations, fcs, norms)
 from llmfoundry.utils.registry_utils import create_registry
 
 _loggers_description = (
@@ -85,17 +86,24 @@
                              entry_points=True,
                              description=_schedulers_description)
 
-_models_description = """The models registry is used to register classes that implement the ComposerModel interface. The model
-constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`.
-Note: This will soon be updated to take in named kwargs instead of a config directly."""
+_models_description = (
+    'The models registry is used to register classes that implement the ComposerModel interface. '
+    +
+    'The model constructor should accept two arguments: an omegaconf DictConfig named `om_model_config` and a PreTrainedTokenizerBase named `tokenizer`. '
+    +
+    'Note: This will soon be updated to take in named kwargs instead of a config directly.'
+)
 models = create_registry('llmfoundry',
                          'models',
                          generic_type=Type[ComposerModel],
                          entry_points=True,
                          description=_models_description)
 
-_dataloaders_description = """The dataloaders registry is used to register functions that create a DataSpec. The function should take
-a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec."""
+_dataloaders_description = (
+    'The dataloaders registry is used to register functions that create a DataSpec. The function should take '
+    +
+    'a DictConfig, a PreTrainedTokenizerBase, and an int as arguments, and return a DataSpec.'
+)
 dataloaders = create_registry(
     'llmfoundry',
     'dataloaders',
@@ -103,7 +111,9 @@
     entry_points=True,
     description=_dataloaders_description)
 
-_metrics_description = """The metrics registry is used to register classes that implement the torchmetrics.Metric interface."""
+_metrics_description = (
+    'The metrics registry is used to register classes that implement the torchmetrics.Metric interface.'
+)
 metrics = create_registry('llmfoundry',
                           'metrics',
                           generic_type=Type[Metric],
@@ -121,4 +131,7 @@
     'metrics',
     'dataloaders',
     'norms',
+    'attention_classes',
+    'attention_implementations',
+    'fcs',
 ]
diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py
index 1ef2a8cacf..298e1bc984 100644
--- a/llmfoundry/tokenizers/tiktoken.py
+++ b/llmfoundry/tokenizers/tiktoken.py
@@ -1,5 +1,6 @@
-# Copyright 2022 MosaicML LLM Foundry authors
+# Copyright 2024 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple
 
diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py
index 0edbae80a5..d2c3b733c0 100644
--- a/llmfoundry/utils/config_utils.py
+++ b/llmfoundry/utils/config_utils.py
@@ -129,6 +129,17 @@ def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
             fsdp_config.setdefault('use_orig_params', False)
             fsdp_config.setdefault('load_monolith_rank0_only', True)
 
+    # Set ffn_config.device_mesh to fsdp_config.device_mesh
+    if fsdp_config is not None and 'device_mesh' in fsdp_config and 'ffn_config' in model_cfg and model_cfg[
+            'ffn_config'].get('ffn_type', None) in {'mb_moe', 'mb_dmoe'}:
+        # Raise ValueError if not using device mesh with MoE expert parallelism
+        if fsdp_config['device_mesh'] is None and model_cfg['ffn_config'].get(
+                'moe_world_size', 1) > 1:
+            raise ValueError(
+                'device_mesh must be specified in fsdp_config when using MoE with moe_world_size > 1.'
+            )
+        model_cfg.ffn_config.device_mesh = fsdp_config['device_mesh']
+
     # No mixed precision needed for weights when they're already 16 bits
     master_dtype = model_cfg.get('master_weights_dtype')
     small_dtypes = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16',
diff --git a/llmfoundry/utils/registry_utils.py b/llmfoundry/utils/registry_utils.py
index 0901ea198a..d9c23e6f26 100644
--- a/llmfoundry/utils/registry_utils.py
+++ b/llmfoundry/utils/registry_utils.py
@@ -15,6 +15,7 @@
 
 T = TypeVar('T')
 TypeBoundT = TypeVar('TypeBoundT', bound=Type)
+CallableBoundT = TypeVar('CallableBoundT', bound=Callable[..., Any])
 
 
 class TypedRegistry(catalogue.Registry, Generic[T]):
diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py
index 594e4f778f..e78e76a912 100644
--- a/scripts/data_prep/convert_finetuning_dataset.py
+++ b/scripts/data_prep/convert_finetuning_dataset.py
@@ -269,6 +269,7 @@ def main(args: Namespace) -> None:
             examples_removed = 0
             for sample in tqdm(samples, desc=split_name):
                 formatted_sample = preprocessing_fn(sample)
+                assert isinstance(formatted_sample, dict)
 
                 # Use the _get_example_type utility to confirm that the formatted sample
                 # can be interpreted by the tokenization code
@@ -300,13 +301,15 @@ def main(args: Namespace) -> None:
                     out.write(sample_to_write)
                 else:
                     if example_type == 'prompt_response':
-                        encoded_sample = {
-                            key: formatted_sample[key].encode('utf-8')
-                            for key in ['prompt', 'response']
-                        }
+                        encoded_sample = {}
+                        for key in ['prompt', 'response']:
+                            value = formatted_sample[key]
+                            assert isinstance(value, str)
+                            encoded_sample[key] = value.encode('utf-8')
+                        out.write(encoded_sample)
                     else:
-                        encoded_sample = formatted_sample
-                    out.write(encoded_sample)
+                        out.write(formatted_sample)
+
         if tokenizer is not None and examples_removed > 0:
             warnings.warn(
                 f'Dropped {examples_removed} examples where the prompt was longer than {args.max_seq_len}, '
diff --git a/scripts/data_prep/convert_text_to_mds.py b/scripts/data_prep/convert_text_to_mds.py
index df39e38a90..be986fc24d 100644
--- a/scripts/data_prep/convert_text_to_mds.py
+++ b/scripts/data_prep/convert_text_to_mds.py
@@ -114,6 +114,13 @@ def parse_args() -> Namespace:
         help='If true, reprocess the input_folder to mds format. Otherwise, ' +
         'only reprocess upon changes to the input folder or dataset creation parameters.',
     )
+    parser.add_argument(
+        '--trust-remote-code',
+        type=bool,
+        required=False,
+        default=False,
+        help='If true, allows custom code to be executed to load the tokenizer',
+    )
 
     parsed = parser.parse_args()
 
@@ -124,7 +131,8 @@ def parse_args() -> Namespace:
             parser.error(
                 'Cannot set --eos_text with --use_tokenizer_eos. Please specify one.'
             )
-        tokenizer = AutoTokenizer.from_pretrained(parsed.tokenizer)
+        tokenizer = AutoTokenizer.from_pretrained(
+            parsed.tokenizer, trust_remote_code=parsed.trust_remote_code)
         parsed.eos_text = tokenizer.eos_token
 
     # now that we have validated them, change BOS/EOS to strings
@@ -171,6 +179,7 @@ def get_task_args(
     bos_text: str,
     no_wrap: bool,
     compression: str,
+    trust_remote_code: bool,
 ) -> Iterable:
     """Get download_and_convert arguments split across n_groups.
 
@@ -187,6 +196,7 @@ def get_task_args(
         bos_text (str): Text to prepend to each example to separate concatenated samples
         no_wrap: (bool): Whether to let text examples wrap across multiple training examples
         compression (str): The compression algorithm to use for MDS writing
+        trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
     """
     num_objects = len(object_names)
     objs_per_group = math.ceil(num_objects / n_groups)
@@ -202,6 +212,7 @@ def get_task_args(
             bos_text,
             no_wrap,
             compression,
+            trust_remote_code,
         )
 
 
@@ -223,6 +234,7 @@ def download_and_convert(
     bos_text: str,
     no_wrap: bool,
     compression: str,
+    trust_remote_code: bool,
 ):
     """Downloads and converts text fies to MDS format.
 
@@ -236,6 +248,7 @@ def download_and_convert(
         bos_text (str): Text to prepend to each example to separate concatenated samples
         no_wrap (bool): Whether to let text examples wrap across multiple training examples
         compression (str): The compression algorithm to use for MDS writing
+        trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
     """
     object_store = maybe_create_object_store_from_uri(input_folder)
 
@@ -244,7 +257,8 @@ def download_and_convert(
         downloading_iter = DownloadingIterable(object_names=file_names,
                                                output_folder=tmp_dir,
                                                object_store=object_store)
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name, trust_remote_code=trust_remote_code)
         tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
 
         # Use the ConcatTokensDataset from LLM-foundry to concatenate sequences of tokens up
@@ -353,6 +367,7 @@ def convert_text_to_mds(
     processes: int,
     args_str: str,
     reprocess: bool,
+    trust_remote_code: bool,
 ):
     """Convert a folder of text files to MDS format.
 
@@ -368,6 +383,7 @@ def convert_text_to_mds(
         processes (int): The number of processes to use.
         args_str (str): String representation of the arguments
         reprocess (bool): Whether to always reprocess the given folder of text files
+        trust_remote_code (bool): If true, allows custom code to be executed to load the tokenizer
     """
     is_remote_output = is_remote_path(output_folder)
 
@@ -396,7 +412,7 @@ def convert_text_to_mds(
         # Download and convert the text files in parallel
         args = get_task_args(object_names, local_output_folder, input_folder,
                              processes, tokenizer_name, concat_tokens, eos_text,
-                             bos_text, no_wrap, compression)
+                             bos_text, no_wrap, compression, trust_remote_code)
         with ProcessPoolExecutor(max_workers=processes) as executor:
             list(executor.map(download_and_convert_starargs, args))
 
@@ -405,7 +421,7 @@ def convert_text_to_mds(
     else:
         download_and_convert(object_names, local_output_folder, input_folder,
                              tokenizer_name, concat_tokens, eos_text, bos_text,
-                             no_wrap, compression)
+                             no_wrap, compression, trust_remote_code)
 
     # Write a done file with the args and object names
     write_done_file(local_output_folder, args_str, object_names)
@@ -462,6 +478,7 @@ def _args_str(original_args: Namespace) -> str:
                             compression=args.compression,
                             processes=args.processes,
                             reprocess=args.reprocess,
+                            trust_remote_code=args.trust_remote_code,
                             args_str=_args_str(args))
     except Exception as e:
         if mosaicml_logger is not None:
diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py
index 6ac645e5b7..57193136ec 100644
--- a/scripts/inference/hf_generate.py
+++ b/scripts/inference/hf_generate.py
@@ -206,7 +206,6 @@ def main(args: Namespace) -> None:
         if device is not None:
             print(f'Placing model on {device=}...')
             model.to(device)
-        model.to(model_dtype)
     except Exception as e:
         raise RuntimeError(
             'Unable to load HF model. ' +
diff --git a/scripts/train/train.py b/scripts/train/train.py
index f0a6038dc8..76156d4577 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -80,16 +80,16 @@ def validate_config(cfg: DictConfig):
         fsdp_config = cfg.get('fsdp_config', None)
         act_ckpt = fsdp_config.get('activation_checkpointing', False)
         act_ckpt_reentrant = fsdp_config.get(
-            'activation_checkpointing_reentrant', True)
-        if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == False:
+            'activation_checkpointing_reentrant', False)
+        if fsdp_config is not None and act_ckpt == True and act_ckpt_reentrant == True:
             warnings.warn(
                 '`te.Linear` layers do not support activation_checkpointing with '
-                + '`activation_checkpointing_reentrant = False`. ' +
-                'Setting cfg.fsdp_config.activation_checkpointing_reentrant=True.'
+                + '`activation_checkpointing_reentrant = True`. ' +
+                'Setting cfg.fsdp_config.activation_checkpointing_reentrant=False.'
             )
-            cfg.fsdp_config.activation_checkpointing_reentrant = True
+            cfg.fsdp_config.activation_checkpointing_reentrant = False
 
-    if 'te' in cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp'):
+    if cfg.model.get('ffn_config', {}).get('ffn_type', 'mptmlp') == 'te_ln_mlp':
         warnings.warn(
             '`te.LayerNormMLP` has issues with torch._dynamo. ' +
             'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.'
@@ -101,6 +101,17 @@ def validate_config(cfg: DictConfig):
             '`load_in_8bit` is only supported for evaluation rather than training.'
         )
 
+    if cfg.model.get('ffn_config', {}).get('ffn_type',
+                                           'mptmlp') in ('mb_moe', 'mb_dmoe'):
+        moe_world_size = cfg.model.get('ffn_config',
+                                       {}).get('moe_world_size', 1)
+        use_orig_params = cfg.get('fsdp_config',
+                                  {}).get('use_orig_params', True)
+        if moe_world_size > 1 and not use_orig_params:
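+            # Expert-parallel MoE weights are handled as DTensors, which FSDP
+            # only supports when use_orig_params=True.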
+            raise ValueError(
+                f'MoEs with expert parallelism (moe_world_size {moe_world_size} > 1) require `use_orig_params=True`.'
+            )
+
 
 def main(cfg: DictConfig) -> Trainer:
     # Run user provided code if specified
@@ -323,6 +334,10 @@ def main(cfg: DictConfig) -> Trainer:
                                                        'load_ignore_keys',
                                                        must_exist=False,
                                                        default_value=None)
+    save_ignore_keys: Optional[List[str]] = pop_config(cfg,
+                                                       'save_ignore_keys',
+                                                       must_exist=False,
+                                                       default_value=None)
     compile_config: Optional[Dict[str, Any]] = pop_config(cfg,
                                                           'compile_config',
                                                           must_exist=False,
@@ -520,11 +535,20 @@ def main(cfg: DictConfig) -> Trainer:
     )
 
     # Log number of parameters
-    n_params = sum(p.numel() for p in model.parameters())
-    n_trainable_params = sum(
-        p.numel() for p in model.parameters() if p.requires_grad)
+    if hasattr(model, 'n_total_params'):
+        n_params = model.n_total_params
+        n_trainable_params = n_params  # TODO: we currently assume all parameters are trainable.
+    else:
+        n_params = sum(p.numel() for p in model.parameters())
+        n_trainable_params = sum(
+            p.numel() for p in model.parameters() if p.requires_grad)
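+    # MoE models report n_active_params (the parameters used per token);
+    # dense models fall back to the full parameter count.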
+    if hasattr(model, 'n_active_params'):
+        n_active_params = model.n_active_params
+    else:
+        n_active_params = n_params
     logged_cfg.update({
         'n_params': n_params,
+        'n_active_params': n_active_params,
         'n_trainable_params': n_trainable_params,
     })
 
@@ -580,6 +604,7 @@ def main(cfg: DictConfig) -> Trainer:
         load_weights_only=load_weights_only,
         load_strict_model_weights=load_strict_model_weights,
         load_ignore_keys=load_ignore_keys,
+        save_ignore_keys=save_ignore_keys,
         autoresume=autoresume,
         python_log_level=python_log_level,
         dist_timeout=dist_timeout,
diff --git a/scripts/train/yamls/finetune/dbrx-full-ft.yaml b/scripts/train/yamls/finetune/dbrx-full-ft.yaml
index a0e2504787..9cb53e40fd 100644
--- a/scripts/train/yamls/finetune/dbrx-full-ft.yaml
+++ b/scripts/train/yamls/finetune/dbrx-full-ft.yaml
@@ -86,7 +86,6 @@ seed: 17
 device_train_microbatch_size: 1
 device_eval_batch_size: 1
 precision: amp_bf16
-autoresume: true
 dist_timeout: 3600
 expandable_segments: true
 
diff --git a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
index 7fb921ae16..06e8f1d6f0 100644
--- a/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
+++ b/scripts/train/yamls/finetune/dbrx-lora-ft.yaml
@@ -94,7 +94,6 @@ seed: 17
 device_train_microbatch_size: 1
 device_eval_batch_size: 1
 precision: amp_bf16
-autoresume: true
 dist_timeout: 3600
 expandable_segments: true
 
diff --git a/scripts/train/yamls/pretrain/testing-moe.yaml b/scripts/train/yamls/pretrain/testing-moe.yaml
new file mode 100644
index 0000000000..eea2b999b7
--- /dev/null
+++ b/scripts/train/yamls/pretrain/testing-moe.yaml
@@ -0,0 +1,117 @@
+data_local: ./my-copy-c4
+data_remote:  # If blank, files must be present in data_local
+max_seq_len: 128
+global_seed: 17
+
+# Run Name
+run_name:  # If left blank, will be read from env var $RUN_NAME
+
+# Model
+model:
+  name: mpt_causal_lm
+  init_device: meta
+  d_model: 128
+  ffn_config:
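+    # MegaBlocks dropless MoE: 4 experts, each token routed to the top 2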
+    ffn_type: mb_dmoe
+    memory_optimized_mlp: true
+    moe_lbl_in_fp32: false
+    moe_loss_weight: 0.01
+    moe_num_experts: 4
+    moe_top_k: 2
+    moe_world_size: 1
+    moe_weight_parallelism: false
+    uniform_expert_assignment: false
+  n_heads: 2
+  n_layers: 2
+  expansion_ratio: 1
+  max_seq_len: ${max_seq_len}
+  vocab_size: 50368
+  attn_config:
+    attn_impl: torch
+  loss_fn: torch_crossentropy
+
+# Tokenizer
+tokenizer:
+  name: EleutherAI/gpt-neox-20b
+  kwargs:
+    model_max_length: ${max_seq_len}
+
+# Dataloaders
+train_loader:
+  name: text
+  dataset:
+    local: ${data_local}
+    remote: ${data_remote}
+    split: train
+    shuffle: true
+    max_seq_len: ${max_seq_len}
+    shuffle_seed: ${global_seed}
+  drop_last: true
+  num_workers: 8
+
+eval_loader:
+  name: text
+  dataset:
+    local: ${data_local}
+    remote: ${data_remote}
+    split: val
+    shuffle: false
+    max_seq_len: ${max_seq_len}
+    shuffle_seed: ${global_seed}
+  drop_last: false
+  num_workers: 8
+
+# Optimization
+scheduler:
+  name: cosine_with_warmup
+  t_warmup: 100ba
+  alpha_f: 0.1
+
+optimizer:
+  name: decoupled_adamw
+  lr: 6.0e-4
+  betas:
+  - 0.9
+  - 0.95
+  eps: 1.0e-08
+  weight_decay: 0.0
+
+algorithms:
+  gradient_clipping:
+    clipping_type: norm
+    clipping_threshold: 1.0
+
+max_duration: 200ba
+eval_interval: 100ba
+eval_first: false
+eval_subset_num_batches: -1
+global_train_batch_size: 256
+
+# System
+seed: ${global_seed}
+device_eval_batch_size: 16
+device_train_microbatch_size: 16
+# device_train_microbatch_size: auto
+precision: amp_bf16
+
+# FSDP
+fsdp_config:
+  sharding_strategy: FULL_SHARD
+  mixed_precision: PURE
+  activation_checkpointing: false
+  activation_checkpointing_reentrant: false
+  activation_cpu_offload: false
+  limit_all_gathers: true
+  verbose: false
+
+# Logging
+progress_bar: false
+log_to_console: true
+console_log_interval: 1ba
+
+callbacks:
+  speed_monitor:
+    window_size: 10
+  lr_monitor: {}
+  memory_monitor: {}
+  runtime_estimator: {}
diff --git a/setup.py b/setup.py
index 79511eeca3..086e759384 100644
--- a/setup.py
+++ b/setup.py
@@ -55,7 +55,7 @@
     'mlflow>=2.10,<3',
     'accelerate>=0.25,<0.26',  # for HF inference `device_map`
     'transformers>=4.39.3,<4.40',
-    'mosaicml-streaming>=0.7.4,<0.8',
+    'mosaicml-streaming>=0.7.5,<0.8',
     'torch>=2.2.1,<2.3',
     'datasets>=2.16,<2.17',
     'fsspec==2023.6.0',  # newer version results in a bug in datasets that duplicates data
@@ -117,8 +117,15 @@
     'openai==1.3.8',
     'tiktoken==0.4.0',
 ]
-extra_deps['all-cpu'] = set(
-    dep for key, deps in extra_deps.items() for dep in deps if 'gpu' not in key)
+
+extra_deps['megablocks'] = [
+    'megablocks==0.5.1',
+    'grouped-gemm==0.1.4',
+]
+
+extra_deps['all-cpu'] = set(dep for key, deps in extra_deps.items()
+                            for dep in deps
+                            if 'gpu' not in key and 'megablocks' not in key)
 extra_deps['all'] = set(dep for key, deps in extra_deps.items() for dep in deps
                         if key not in {'gpu-flash2', 'all-cpu'})
 extra_deps['all-flash2'] = set(dep for key, deps in extra_deps.items()
diff --git a/tests/a_scripts/data_prep/test_convert_text_to_mds.py b/tests/a_scripts/data_prep/test_convert_text_to_mds.py
index e458cb1dfc..bd96de695c 100644
--- a/tests/a_scripts/data_prep/test_convert_text_to_mds.py
+++ b/tests/a_scripts/data_prep/test_convert_text_to_mds.py
@@ -106,6 +106,7 @@ def call_convert_text_to_mds() -> None:
             processes=processes,
             args_str='Namespace()',
             reprocess=False,
+            trust_remote_code=False,
         )
 
     call_convert_text_to_mds()
@@ -195,6 +196,7 @@ def call_convert_text_to_mds(reprocess: bool):
             processes=1,
             args_str='Namespace()',
             reprocess=reprocess,
+            trust_remote_code=False,
         )
 
     # Create input text data
@@ -234,6 +236,7 @@ def test_input_folder_not_exist(tmp_path: pathlib.Path):
             processes=1,
             args_str='Namespace()',
             reprocess=False,
+            trust_remote_code=False,
         )
 
 
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 7b4ef1e058..061227d8a4 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -18,14 +18,17 @@
 from composer.utils import dist, get_device
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
+from torch.distributed._tensor.api import DTensor
 from torch.utils.data import DataLoader
 from transformers import PreTrainedModel, PreTrainedTokenizerBase
 
 from llmfoundry.callbacks import HuggingFaceCheckpointer
 from llmfoundry.callbacks.hf_checkpointer import _maybe_get_license_filename
 from llmfoundry.data.finetuning import build_finetuning_dataloader
+from llmfoundry.models.mpt import MPTConfig
 from llmfoundry.utils.builders import (build_composer_model, build_optimizer,
                                        build_tokenizer)
+from llmfoundry.utils.config_utils import process_init_device
 from scripts.inference.convert_composer_to_hf import convert_composer_to_hf
 from tests.data_utils import make_tiny_ft_dataset
 
@@ -191,9 +194,18 @@ def check_hf_tokenizer_equivalence(tokenizer1: PreTrainedTokenizerBase,
     assert tokenizer1.__dict__ == tokenizer2.__dict__
 
 
+def remove_moe_world_size(config: MPTConfig):
+    if hasattr(config, 'ffn_config'):
+        if 'moe_world_size' in config.ffn_config:
+            config.ffn_config.pop('moe_world_size')
+
+
 def check_hf_model_equivalence(model1: PreTrainedModel,
                                model2: PreTrainedModel,
                                just_lora: bool = False):
+    remove_moe_world_size(model1.config)
+    remove_moe_world_size(model2.config)
+
     expected_model_config_dict = model1.config.to_dict()
     new_model_config_dict = model2.config.to_dict()
 
@@ -225,6 +237,7 @@ def check_hf_model_equivalence(model1: PreTrainedModel,
             assert torch.equal(p1.cpu(), p2.cpu())
 
 
+# TODO(GRT-2435): Change to fixture
 def delete_transformers_cache():
     # Only delete the files on local rank 0, otherwise race conditions are created
     if not dist.get_local_rank() == 0:
@@ -421,6 +434,35 @@ def _get_model_and_tokenizer(model: str, max_seq_len: int,
             'tie_word_embeddings': tie_word_embeddings,
         }
         tokenizer_name = 'EleutherAI/gpt-neox-20b'
+    elif model == 'mptmoe':
+        # Test export on moe_world_size 1
+        model_cfg = {
+            'name': 'mpt_causal_lm',
+            'init_device': 'cpu',
+            'd_model': 128,
+            'n_heads': 2,
+            'n_layers': 2,
+            'expansion_ratio': 1,
+            'ffn_config': {
+                'ffn_type': 'mb_dmoe',
+                'memory_optimized_mlp': True,
+                'moe_lbl_in_fp32': False,
+                'moe_loss_weight': 0.01,
+                'moe_num_experts': 4,
+                'moe_top_k': 2,
+                'moe_world_size': 1,
+                'moe_weight_parallelism': False,
+                'uniform_expert_assignment': False,
+            },
+            'max_seq_len': max_seq_len,
+            'vocab_size': 50368,
+            'attn_config': {
+                'attn_impl': 'torch',
+            },
+            'loss_fn': 'torch_crossentropy',
+            'no_bias': True,
+        }
+        tokenizer_name = 'EleutherAI/gpt-neox-20b'
     elif model == 'neo':
         assert tie_word_embeddings is None
         model_cfg = {
@@ -645,6 +687,7 @@ def _assert_checkpoint_equivalence(tmp_path: pathlib.Path,
     [
         ('mpt', True, None),
         ('mpt', False, None),
+        ('mptmoe', None, None),
         ('neo', None, None),
         ('llama2', None, None),
         ('llama2', None, {
@@ -680,6 +723,8 @@ def test_huggingface_conversion_callback(
     expected_normal_checkpoints: int,
     peft_config: Optional[dict],
 ):
+    if model == 'mptmoe' and fsdp_state_dict_type is None:
+        pytest.skip('mptmoe requires FSDP')
     delete_transformers_cache()
 
     dist.initialize_dist(get_device('gpu'))
@@ -697,7 +742,7 @@ def test_huggingface_conversion_callback(
         precision=precision_str,
         mlflow_registered_model_name='dummy-registered-name')
 
-    # get small version of each model
+    # Get small version of each model
     model_cfg, tokenizer_name = _get_model_and_tokenizer(
         model, max_seq_len, tie_word_embeddings)
     assert model_cfg is not None
@@ -781,9 +826,12 @@ def test_huggingface_conversion_callback(
     delete_transformers_cache()
 
 
+# TODO(GRT-2431): Refactor as enums
 @pytest.mark.parametrize(
     'model,tie_word_embeddings',
-    [('mpt', True), ('mpt', False), ('neo', None), ('llama2', None)],
+    [('mpt', True), ('mpt', False),
+     pytest.param('mptmoe', None, marks=pytest.mark.gpu), ('neo', None),
+     ('llama2', None)],
 )
 def test_convert_and_generate(model: str, tie_word_embeddings: bool,
                               tmp_path: pathlib.Path):
@@ -794,6 +842,9 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
         om_cfg = get_config(
             conf_path='scripts/train/yamls/pretrain/testing.yaml')
         om_cfg['tie_word_embeddings'] = tie_word_embeddings
+    elif model == 'mptmoe':
+        om_cfg = get_config(
+            conf_path='scripts/train/yamls/pretrain/testing-moe.yaml')
     elif model == 'neo':
         assert tie_word_embeddings is None
         om_cfg = get_config(
@@ -824,7 +875,8 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
         cfg=om_cfg['model'],
         tokenizer=tokenizer,
     )
-    trainer = Trainer(model=original_model, device='cpu')
+    trainer = Trainer(model=original_model,
+                      device='cpu' if model != 'mptmoe' else 'gpu')
     trainer.save_checkpoint(os.path.join(tmp_path, 'checkpoint.pt'))
 
     args = Namespace(composer_path=os.path.join(tmp_path, 'checkpoint.pt'),
@@ -845,8 +897,15 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         os.path.join(tmp_path, 'hf-output-folder'), trust_remote_code=True)
 
-    output = loaded_model.generate(tokenizer('hello',
-                                             return_tensors='pt')['input_ids'],
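+    # MegaBlocks MoE layers require a GPU, so run the MoE model on CUDA in
+    # bf16; all other models stay on CPU in fp32.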
+    device = 'cuda' if model == 'mptmoe' else 'cpu'
+    precision = torch.bfloat16 if model == 'mptmoe' else torch.float32
+    original_model.to(device)
+    original_model.to(precision)
+    loaded_model.to(device)
+    loaded_model.to(precision)
+
+    output = loaded_model.generate(tokenizer(
+        'hello', return_tensors='pt')['input_ids'].to(device),
                                    max_new_tokens=1)
     assert output.shape == (1, 2 + (1 if model == 'llama2' else 0))
 
@@ -863,16 +922,21 @@ def test_convert_and_generate(model: str, tie_word_embeddings: bool,
     delete_transformers_cache()
 
 
+@pytest.mark.parametrize('conf_path', [
+    'scripts/train/yamls/pretrain/testing.yaml',
+    pytest.param('scripts/train/yamls/pretrain/testing-moe.yaml',
+                 marks=pytest.mark.gpu),
+])
 @pytest.mark.parametrize('tie_word_embeddings', [True, False])
 def test_convert_and_generate_meta(tie_word_embeddings: str,
-                                   tmp_path: pathlib.Path):
+                                   tmp_path: pathlib.Path, conf_path: str):
     delete_transformers_cache()
 
     from composer.utils import dist
     gathered_paths = dist.all_gather_object(tmp_path)
     tmp_path_gathered = gathered_paths[0]
 
-    om_cfg = get_config(conf_path='scripts/train/yamls/pretrain/testing.yaml')
+    om_cfg = get_config(conf_path=conf_path)
 
     om_cfg['model']['init_device'] = 'cpu'
     om_cfg['tie_word_embeddings'] = tie_word_embeddings
@@ -883,7 +947,8 @@ def test_convert_and_generate_meta(tie_word_embeddings: str,
         cfg=om_cfg['model'],
         tokenizer=tokenizer,
     )
-    trainer = Trainer(model=original_model, device='cpu')
+    trainer = Trainer(model=original_model,
+                      device='cpu' if 'moe' not in conf_path else 'gpu')
     trainer.save_checkpoint(os.path.join(tmp_path_gathered, 'checkpoint.pt'))
 
     # patch in the meta device for testing
@@ -915,8 +980,15 @@ def test_convert_and_generate_meta(tie_word_embeddings: str,
         os.path.join(tmp_path_gathered, 'hf-output-folder'),
         trust_remote_code=True)
 
-    output = loaded_model.generate(tokenizer('hello',
-                                             return_tensors='pt')['input_ids'],
+    device = 'cuda' if 'moe' in conf_path else 'cpu'
+    precision = torch.bfloat16 if 'moe' in conf_path else torch.float32
+    original_model.to(device)
+    original_model.to(precision)
+    loaded_model.to(device)
+    loaded_model.to(precision)
+
+    output = loaded_model.generate(tokenizer(
+        'hello', return_tensors='pt')['input_ids'].to(device),
                                    max_new_tokens=1)
     assert output.shape == (1, 2)
 
@@ -933,6 +1005,253 @@ def test_convert_and_generate_meta(tie_word_embeddings: str,
     delete_transformers_cache()
 
 
+@pytest.mark.world_size(4)
+@pytest.mark.gpu
+@pytest.mark.parametrize('num_experts', [2, 4, 8])
+@pytest.mark.parametrize('sharding_strategy', ['FULL_SHARD', 'HYBRID_SHARD'])
+def test_mptmoe_huggingface_conversion_callback(
+    tmp_path: pathlib.Path,
+    num_experts: int,
+    sharding_strategy: str,
+    hf_save_interval: str = '1ba',
+    save_interval: str = '1ba',
+    max_duration: str = '1ba',
+    expected_hf_checkpoints: int = 1,
+    expected_normal_checkpoints: int = 1,
+):
+
+    delete_transformers_cache()
+
+    dist.initialize_dist(get_device('gpu'))
+    if dist.get_world_size() != 4:
+        pytest.skip('This test requires 4 GPUs')
+
+    max_seq_len = 16
+    device_batch_size = 1
+    dataset_size = 2
+    precision_str = 'float32'
+    precision = torch.float32
+    batches_per_epoch = math.ceil(dataset_size / (device_batch_size * 2))
+
+    checkpointer_callback = HuggingFaceCheckpointer(
+        save_folder=os.path.join(tmp_path, 'checkpoints'),
+        save_interval=hf_save_interval,
+        precision=precision_str,
+    )
+
+    # Get small version of each model
+    model_cfg = None
+    tokenizer_name = None
+
+    # Test export with expert parallelism (moe_world_size 2)
+    model_cfg = {
+        'name': 'mpt_causal_lm',
+        'init_device': 'cpu',
+        'd_model': 128,
+        'n_heads': 2,
+        'n_layers': 2,
+        'expansion_ratio': 1,
+        'ffn_config': {
+            'ffn_type': 'mb_dmoe',
+            'memory_optimized_mlp': True,
+            'moe_lbl_in_fp32': False,
+            'moe_loss_weight': 0.01,
+            'moe_num_experts': num_experts,
+            'moe_top_k': 2,
+            'moe_world_size': 2,
+            'moe_weight_parallelism': False,
+            'uniform_expert_assignment': True,
+            'mlp_impl': 'grouped',
+            'mlp_type': 'glu',
+            'device_mesh':
+                [1, 2] if sharding_strategy == 'HYBRID_SHARD' else [2],
+        },
+        'precision': 'amp_bf16',
+        'max_seq_len': max_seq_len,
+        'vocab_size': 50368,
+        'attn_config': {
+            'attn_impl': 'torch',
+        },
+        'loss_fn': 'torch_crossentropy',
+        'no_bias': True,
+    }
+    tokenizer_name = 'EleutherAI/gpt-neox-20b'
+    assert model_cfg is not None
+    assert tokenizer_name is not None
+    model_cfg = om.create(model_cfg)
+
+    fsdp_config = {
+        'sharding_strategy': sharding_strategy,
+        'mixed_precision': 'PURE',
+        'activation_checkpointing': False,
+        'activation_checkpointing_reentrant': False,
+        'activation_cpu_offload': False,
+        'limit_all_gathers': True,
+        'device_mesh': [1, 4] if sharding_strategy == 'HYBRID_SHARD' else [4],
+        'use_orig_params': True,
+    }
+
+    tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
+    tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
+    if dist.get_global_rank() == 0:
+        make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)
+
+    dataloader_cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': tiny_dataset_folder_path,
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 0,
+        'pin_memory': False,
+        'prefetch_factor': None,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    dataloader_cfg = om.create(dataloader_cfg)
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+
+    train_dataloader = build_finetuning_dataloader(
+        dataloader_cfg,
+        tokenizer,
+        device_batch_size,
+    )
+
+    optimizer_config = {
+        'name': 'decoupled_adamw',
+        'lr': 6e-4,
+        'betas': [0.9, 0.95],
+        'eps': 1e-8,
+        'weight_decay': 0.0,
+    }
+    optimizer_name = optimizer_config.pop('name')
+
+    init_context = process_init_device(model_cfg, fsdp_config)
+    original_model = build_composer_model(
+        name=model_cfg.name,
+        cfg=model_cfg,
+        tokenizer=tokenizer,
+        init_context=init_context,
+    )
+
+    optimizer = build_optimizer(original_model, optimizer_name,
+                                optimizer_config)
+    trainer = Trainer(
+        model=original_model,
+        device='gpu',
+        fsdp_config=fsdp_config,
+        train_dataloader=train_dataloader,
+        save_folder=os.path.join(tmp_path, 'checkpoints'),
+        save_interval=save_interval,
+        max_duration=max_duration,
+        callbacks=[checkpointer_callback],
+        optimizers=optimizer,
+        save_latest_filename=None,
+        precision=model_cfg.pop('precision', None),
+        save_weights_only=True,
+    )
+    trainer.fit()
+    # Run a forward pass on the final batch to capture reference logits
+    batch = trainer.state.batch
+    model_output_logits = trainer.state.model(batch).logits
+
+    # summon full params to check equivalence
+    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+    with FSDP.summon_full_params(trainer.state.model,
+                                 writeback=False,
+                                 recurse=True):
+        loaded_model = None
+        loaded_tokenizer = None
+        # Only rank zero is saving the huggingface checkpoints, so only check
+        # for equivalence on rank zero
+        if dist.get_global_rank() == 0:
+            normal_checkpoints = [
+                name
+                for name in os.listdir(os.path.join(tmp_path, 'checkpoints'))
+                if name != 'huggingface'
+            ]
+            huggingface_checkpoints = [
+                name for name in os.listdir(
+                    os.path.join(tmp_path, 'checkpoints', 'huggingface'))
+            ]
+            assert len(normal_checkpoints) == expected_normal_checkpoints
+            assert len(huggingface_checkpoints) == expected_hf_checkpoints
+
+            # Patch flash_attn package to be empty to simulate loading the model in
+            # an environment without flash attention installed
+            with patch.dict('sys.modules', {'flash_attn': None}):
+                # Load the last huggingface checkpoint
+                loaded_model = transformers.AutoModelForCausalLM.from_pretrained(
+                    os.path.join(tmp_path, 'checkpoints', 'huggingface',
+                                 f'ba1'),
+                    trust_remote_code=True,
+                )
+
+            # Check that the loaded model has the correct precision, and then set it back
+            # to the original for the equivalence check
+            assert loaded_model.config.torch_dtype == precision
+            loaded_model.config.torch_dtype = original_model.model.config.torch_dtype
+
+            loaded_tokenizer = transformers.AutoTokenizer.from_pretrained(
+                os.path.join(tmp_path, 'checkpoints', 'huggingface',
+                             f'ba{batches_per_epoch}'),
+                trust_remote_code=True,
+            )
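+        # Gather each sharded DTensor parameter into a full tensor on all
+        # ranks (full_tensor is a collective) so rank 0 can compare the
+        # trained model against the saved HF checkpoint.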
+        for n, p in trainer.state.model.model.named_parameters():
+            if isinstance(p, DTensor):
+                submodule_name, param_name = '.'.join(
+                    n.split('.')[:-1]), n.split('.')[-1]
+                submodule = trainer.state.model.model.get_submodule(
+                    submodule_name)
+                param_tensor = p.full_tensor()
+                param = torch.nn.Parameter(param_tensor)
+                submodule.register_parameter(param_name, param)
+
+        if dist.get_global_rank() == 0:
+            check_hf_model_equivalence(trainer.state.model.model, loaded_model)
+            check_hf_tokenizer_equivalence(tokenizer, loaded_tokenizer)
+
+            # Check output equivalence
+            loaded_model = loaded_model.cuda().bfloat16()  # type: ignore
+            loaded_model_logits = loaded_model(
+                input_ids=batch.get('input_ids', None),
+                attention_mask=batch.get('attention_mask', None),
+                prefix_mask=batch.get('bidirectional_mask', None),
+                sequence_id=batch.get('sequence_id', None),
+                inputs_embeds=batch.get('inputs_embeds', None),
+            ).logits
+            assert torch.equal(loaded_model_logits, model_output_logits)
+
+    dist.barrier()
+
+    delete_transformers_cache()
+
+
 @pytest.mark.parametrize(
     'license_file_name',
     ['LICENSE', 'LICENSE.txt', 'license', 'license.md', None])
diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py
index 68ed9d421c..ff885ac735 100644
--- a/tests/a_scripts/train/test_train.py
+++ b/tests/a_scripts/train/test_train.py
@@ -1,6 +1,8 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
+
 import copy
+import os
 import pathlib
 from typing import Optional
 
@@ -9,9 +11,10 @@
 from omegaconf import DictConfig, ListConfig
 from omegaconf import OmegaConf as om
 
-from scripts.train.train import main  # noqa: E402
+from scripts.train.train import main, validate_config  # noqa: E402
 from tests.data_utils import (create_arxiv_dataset, create_c4_dataset_xxsmall,
                               gpt_tiny_cfg)
+from tests.fixtures.autouse import REPO_DIR
 
 
 @pytest.mark.parametrize('averages', [{
@@ -144,6 +147,23 @@ def test_train_multi_eval(tmp_path: pathlib.Path):
         tuple)
 
 
+def test_validate_config():
+    conf_path: str = os.path.join(
+        REPO_DIR,
+        'scripts/train/yamls/pretrain/testing-moe.yaml',
+    )
+    with open(conf_path) as f:
+        test_cfg: DictConfig = om.load(f)  # type: ignore
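+    # Force an invalid combo: expert parallelism with use_orig_params=False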
+    test_cfg.model.ffn_config.moe_world_size = 4
+    test_cfg.fsdp_config.use_orig_params = False
+    with pytest.raises(
+            ValueError,
+            match=
+            'MoEs with expert parallelism (.*) require `use_orig_params=True`.'
+    ):
+        validate_config(test_cfg)
+
+
 def test_eval_metrics_with_no_train_metrics(tmp_path: pathlib.Path):
     """Test using use_train_metrics=False does not disable eval metrics."""
     c4_dataset_name = create_c4_dataset_xxsmall(tmp_path)
diff --git a/tests/callbacks/test_mbmoe_tok_per_expert_callback.py b/tests/callbacks/test_mbmoe_tok_per_expert_callback.py
new file mode 100644
index 0000000000..79a625b4e4
--- /dev/null
+++ b/tests/callbacks/test_mbmoe_tok_per_expert_callback.py
@@ -0,0 +1,11 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.utils.builders import build_callback
+
+
+def test_mbmoe_tok_per_expert_builds():
+    """Test that the callback can be built."""
+    callback = build_callback('mbmoe_tok_per_expert')
+    assert callback is not None
+    assert callback.__class__.__name__ == 'MegaBlocksMoE_TokPerExpert'
diff --git a/tests/data/test_dataloader.py b/tests/data/test_dataloader.py
index f5c2631fa7..c99ae6baf2 100644
--- a/tests/data/test_dataloader.py
+++ b/tests/data/test_dataloader.py
@@ -1093,3 +1093,87 @@ def test_build_unknown_dataloader():
     tokenizer = MagicMock()
     with pytest.raises(catalogue.RegistryError):
         _ = build_dataloader(cfg, tokenizer, 2)
+
+
+invalid_conversation_params_sharegpt = [
+    'add_invalid_last_chat_message', 'add_invalid_content_type',
+    'add_invalid_role', 'add_not_alternating_roles'
+]
+
+
+@pytest.mark.parametrize(
+    ','.join(invalid_conversation_params_sharegpt),
+    generate_exclusive_test_params(invalid_conversation_params_sharegpt))
+def test_sharegpt_format(tmp_path: pathlib.Path,
+                         add_invalid_last_chat_message: bool,
+                         add_invalid_content_type: bool, add_invalid_role: bool,
+                         add_not_alternating_roles: bool):
+    tokenizer_name = 'mosaicml/mpt-7b'
+    max_seq_len = 2048
+    dataset_size = 5
+    device_batch_size = 5
+    tiny_dataset_folder_path = tmp_path
+    tiny_dataset_path = str(tiny_dataset_folder_path / 'train.jsonl')
+
+    tokenizer = build_tokenizer(
+        tokenizer_name=tokenizer_name,
+        tokenizer_kwargs={'model_max_length': max_seq_len},
+    )
+    tokenizer.add_special_tokens({
+        'pad_token': '<pad>',
+        'bos_token': '<bos>',
+        'eos_token': '<eos>',
+    })
+
+    if dist.get_global_rank() == 0:
+        make_tiny_conversation_ft_dataset(
+            path=tiny_dataset_path,
+            size=dataset_size,
+            add_invalid_last_chat_message=add_invalid_last_chat_message,
+            add_invalid_message_key_quantity=False,
+            add_invalid_content_type=add_invalid_content_type,
+            add_invalid_role=add_invalid_role,
+            add_not_alternating_roles=add_not_alternating_roles,
+            use_messages_format=False,
+        )
+
+    cfg = {
+        'name': 'finetuning',
+        'dataset': {
+            'hf_name': str(tiny_dataset_folder_path),
+            'preprocessing_fn': 'teknium/OpenHermes-2.5',
+            'split': 'train',
+            'max_seq_len': max_seq_len,
+            'decoder_only_format': True,
+            'allow_pad_trimming': False,
+            'packing_ratio': None,
+            'shuffle': True,
+        },
+        'drop_last': False,
+        'num_workers': 0,
+        'prefetch_factor': None,
+        'pin_memory': False,
+        'persistent_workers': False,
+        'timeout': 0
+    }
+
+    cfg = om.create(cfg)
+
+    error_context = contextlib.nullcontext()
+    if add_invalid_last_chat_message:
+        error_context = pytest.raises(InvalidLastChatMessageRoleError,
+                                      match='Invalid last message role:')
+    if add_invalid_content_type:
+        error_context = pytest.raises(InvalidContentTypeError,
+                                      match='Expected content to be')
+    if add_invalid_role:
+        error_context = pytest.raises(InvalidRoleError,
+                                      match='Expected role to be one of')
+
+    if add_not_alternating_roles:
+        error_context = pytest.raises(ConsecutiveRepeatedChatRolesError,
+                                      match='Conversation roles must alternate')
+
+    with error_context:
+        build_finetuning_dataloader(cfg, tokenizer,
+                                    device_batch_size).dataloader
diff --git a/tests/data/test_template_tokenization.py b/tests/data/test_template_tokenization.py
index a45c4d8f0d..632a79dac9 100644
--- a/tests/data/test_template_tokenization.py
+++ b/tests/data/test_template_tokenization.py
@@ -9,6 +9,7 @@
 from llmfoundry.data.finetuning.tasks import (_ALLOWED_PROMPT_KEYS,
                                               _ALLOWED_RESPONSE_KEYS,
                                               _slice_chat_formatted_example,
+                                              dataset_constructor,
                                               tokenize_formatted_example)
 from llmfoundry.utils.builders import build_tokenizer
 
@@ -178,34 +179,67 @@ def test_tokenize_instruct_example_well_formed():
 @pytest.mark.parametrize(
     'tokenizer_name',
     ['EleutherAI/gpt-neox-20b', 'HuggingFaceH4/zephyr-7b-beta', 't5-base'])
-def test_multi_turn_chat_slicing(tokenizer_name: str):
-    convo = [
-        {
-            'role': 'system',
-            'content': 'everyone thinks you are so cool'
-        },
-        {
-            'role': 'user',
-            'content': 'hiiii'
-        },
-        {
-            'role': 'assistant',
-            'content': 'yassss'
-        },
-        {
-            'role': 'user',
-            'content': 'HIIIIII!!!'
-        },
-        {
-            'role': 'assistant',
-            'content': 'YASSSSSS'
-        },
-    ]
+@pytest.mark.parametrize('messages_format', [True, False])
+def test_multi_turn_chat_slicing(tokenizer_name: str, messages_format: bool):
+    if messages_format:
+        convo = [
+            {
+                'role': 'system',
+                'content': 'everyone thinks you are so cool'
+            },
+            {
+                'role': 'user',
+                'content': 'hiiii'
+            },
+            {
+                'role': 'assistant',
+                'content': 'yassss'
+            },
+            {
+                'role': 'user',
+                'content': 'HIIIIII!!!'
+            },
+            {
+                'role': 'assistant',
+                'content': 'YASSSSSS'
+            },
+        ]
+    else:
+        convo = [
+            {
+                'from': 'system',
+                'value': 'everyone thinks you are so cool'
+            },
+            {
+                'from': 'human',
+                'value': 'hiiii'
+            },
+            {
+                'from': 'gpt',
+                'value': 'yassss'
+            },
+            {
+                'from': 'tool',
+                'value': 'HIIIIII!!!'
+            },
+            {
+                'from': 'gpt',
+                'value': 'YASSSSSS'
+            },
+        ]
+        tmp = {'conversations': convo}
+        preprocessor = dataset_constructor.get_preprocessing_fn_from_str(
+            'teknium/OpenHermes-2.5')
+        assert preprocessor is not None
+        convo = preprocessor(tmp)['messages']
+        assert isinstance(convo, list)
+
+    example = {'messages': convo}
 
     tok = transformers.AutoTokenizer.from_pretrained(tokenizer_name)
 
     templated_prompt_response_turns = _slice_chat_formatted_example(
-        {'messages': convo}, tok)
+        example, tok)
 
     reconstructed_chat = ''
     for prompt, response in templated_prompt_response_turns:
diff --git a/tests/data_utils.py b/tests/data_utils.py
index 3c077b5e71..fd24d4cbbf 100644
--- a/tests/data_utils.py
+++ b/tests/data_utils.py
@@ -6,7 +6,7 @@
 import shutil
 from argparse import Namespace
 from pathlib import Path
-from typing import Optional
+from typing import Dict, List, Optional
 
 from omegaconf import DictConfig
 from omegaconf import OmegaConf as om
@@ -99,6 +99,7 @@ def make_tiny_conversation_ft_dataset(
     add_invalid_content_type: bool = False,
     add_invalid_role: bool = False,
     add_not_alternating_roles: bool = False,
+    use_messages_format: bool = True,
 ):
     if Path(path).suffix != '.jsonl':
         raise ValueError(f'Path {path} must be a jsonl file.')
@@ -198,6 +199,24 @@ def make_tiny_conversation_ft_dataset(
             }]
         })
 
+    def messages_to_conversation(sample: Dict):
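+        # Convert a 'messages' (role/content) sample into the ShareGPT-style
+        # 'conversations' (from/value) format.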
+        assert 'messages' in sample
+        messages = sample['messages']
+
+        role_map = {
+            'user': 'human',
+            'assistant': 'gpt',
+        }
+        conversations: List[Dict[str, str]] = []
+        for message in messages:
+            role: str = role_map.get(message['role'], message['role'])
+            content: str = message['content']
+            conversations.append({'from': role, 'value': content})
+        return {'conversations': conversations}
+
+    if not use_messages_format:
+        samples = [messages_to_conversation(sample) for sample in samples]
+
     os.makedirs(os.path.dirname(path), exist_ok=True)
     with open(path, 'w') as _f:
         for sample in samples:
diff --git a/tests/models/layers/test_dmoe.py b/tests/models/layers/test_dmoe.py
new file mode 100644
index 0000000000..9c15745793
--- /dev/null
+++ b/tests/models/layers/test_dmoe.py
@@ -0,0 +1,263 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import copy
+from contextlib import nullcontext
+from functools import partial
+from typing import List, Optional
+
+import pytest
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.distributed._tensor import DTensor, Placement, Replicate, Shard
+from torch.distributed._tensor.device_mesh import init_device_mesh
+from torch.distributed.checkpoint.state_dict import (StateDictOptions,
+                                                     get_model_state_dict)
+from torch.distributed.tensor.parallel.ddp import _pre_dp_module_transform
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from llmfoundry.models.layers.dmoe import dMoE
+from llmfoundry.models.layers.ffn import dtensorify_param
+from llmfoundry.models.mpt.configuration_mpt import MPTConfig
+from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM
+
+try:
+    import megablocks
+    is_megablocks_imported = True
+except ModuleNotFoundError:
+    is_megablocks_imported = False
+
+
+def _get_all_inputs(
+    input_shape: List[int],
+    dtype: Optional[torch.dtype],
+):
+    world_size: int = dist.get_world_size()
+    rank: int = dist.get_rank()
+    device: torch.device = torch.device(f'cuda:{rank}')
+    all_inputs = []
+    for _ in range(world_size):
+        all_inputs.append(torch.rand(
+            input_shape,
+            device=device,
+            dtype=dtype,
+        ))
+    return all_inputs
+
+
+def _get_torch_dtype(fp16: bool, bf16: bool) -> Optional[torch.dtype]:
+    if fp16:
+        return torch.float16
+    elif bf16:
+        return torch.bfloat16
+    return None
+
+
+@pytest.mark.skipif(not is_megablocks_imported,
+                    reason='This test needs megablocks module')
+@pytest.mark.gpu
+@pytest.mark.world_size(2)
+@pytest.mark.parametrize('moe_num_experts', [8])
+@pytest.mark.parametrize('mlp_type', ['glu', 'mlp'])
+@pytest.mark.parametrize('moe_world_size', [1, 2])
+@pytest.mark.parametrize('two_d_input', [True, False])
+def test_dmoe(moe_num_experts: int, mlp_type: str, moe_world_size: int,
+              two_d_input: bool):
+    # Generate inputs
+    rank = dist.get_rank()
+    batch_size = 2
+    seq_len = 3
+    hidden_size = 128
+    if two_d_input:
+        input_shape = [batch_size * seq_len, hidden_size]
+    else:
+        input_shape = [batch_size, seq_len, hidden_size]
+    fp16 = False
+    bf16 = True
+    dtype = _get_torch_dtype(fp16, bf16)
+    x = _get_all_inputs(input_shape, dtype)[rank]
+
+    # Construct DDP torch dMoE
+    device = torch.device(f'cuda:{dist.get_rank()}')
+    common_args = {
+        'hidden_size': hidden_size,
+        'ffn_hidden_size': hidden_size,
+        'moe_top_k': 2,
+        'activation_fn': partial(F.gelu, approximate='none'),
+        'moe_jitter_eps': 0.0,  # Disable randomization
+        'moe_normalize_expert_weights': 1,
+        'uniform_expert_assignment': False,
+        'bias': False,
+        'device': device,
+        'moe_num_experts': moe_num_experts,
+        'mlp_type': mlp_type,
+    }
+
+    torch_dmoe = dMoE(**common_args).to(device, dtype=dtype)
+    torch_dmoe = DDP(
+        torch_dmoe,
+        device_ids=[rank],
+    )
+    torch_dmoe_optimizer = optim.SGD(torch_dmoe.parameters(), lr=0.1)
+
+    # Construct TP MB dMoE
+    mp_dmoe_args = copy.deepcopy(common_args)
+    extra_args = {
+        'fp16': fp16,
+        'bf16': bf16,
+        'init_method': partial(torch.nn.init.uniform_, a=-1.0, b=1.0),
+    }
+    device_mesh = None
+    if moe_world_size > 1:
+        world_size = dist.get_world_size()
+        assert world_size % moe_world_size == 0
+        moe_dp_dim = world_size // moe_world_size
+        device_mesh = init_device_mesh(
+            'cuda',
+            (moe_dp_dim, moe_world_size),
+            mesh_dim_names=('weight_parallel', 'expert_parallel'),
+        )
+        expert_parallel_group = device_mesh['expert_parallel'].get_group(0)
+        extra_args.update(
+            {
+                'moe_expert_model_parallelism': True,
+                'expert_parallel_group': expert_parallel_group,
+            },)
+    mp_dmoe_args.update(extra_args)
+    args = megablocks.layers.arguments.Arguments(**mp_dmoe_args,)
+    mb_dmoe = megablocks.layers.dmoe.dMoE(args).to(device)
+    mb_dmoe.router = DDP(mb_dmoe.router, device_ids=[rank])
+
+    if moe_world_size > 1:
+        assert device_mesh is not None
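+        # Replicate expert weights across the weight-parallel mesh dim and
+        # shard them along dim 0 across the expert-parallel dim.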
+        two_d_placements: List[Placement] = [Replicate(), Shard(0)]
+        dtensorified_params = [(
+            name,
+            dtensorify_param(
+                param=parameter,
+                mesh=device_mesh,
+                placements=two_d_placements,
+            ),
+        ) for name, parameter in mb_dmoe.experts.mlp.named_parameters()]
+        tp_names = []
+        for name, dtensorified_param in dtensorified_params:
+            mb_dmoe.experts.mlp.register_parameter(name, dtensorified_param)
+            tp_names.append('experts.mlp.' + name)
+
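+        # Prepare the DTensor-backed expert MLP to interoperate with DDP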
+        _pre_dp_module_transform(mb_dmoe.experts.mlp)
+
+        dp_pg = device_mesh['weight_parallel'].get_group(0)
+        mb_dmoe.experts = DDP(mb_dmoe.experts, process_group=dp_pg)
+
+        # Copy mb_dmoe's parameters to torch_dmoe
+        mb_dmoe_state_dict = get_model_state_dict(mb_dmoe,
+                                                  options=StateDictOptions(
+                                                      full_state_dict=True,))
+        for key, t in mb_dmoe_state_dict.items():
+            if key in tp_names:
+                dtensor_full = DTensor.from_local(
+                    t,  # pyright: ignore[reportGeneralTypeIssues]
+                    device_mesh=device_mesh,
+                    placements=two_d_placements,
+                ).full_tensor()
+
+                mb_dmoe_state_dict[key] = dtensor_full
+    else:
+        mb_dmoe.experts = DDP(mb_dmoe.experts, device_ids=[rank])
+        mb_dmoe_state_dict = get_model_state_dict(mb_dmoe,
+                                                  options=StateDictOptions(
+                                                      full_state_dict=True,))
+    mb_dmoe_optimizer = optim.SGD(mb_dmoe.parameters(), lr=0.1)
+
+    # Load mb_dmoe state dict to torch dmoe
+    torch_dmoe.module.load_state_dict(mb_dmoe_state_dict, strict=True)
+
+    # Run train_step check
+    torch_y = torch_dmoe(x)
+    mb_y = mb_dmoe(x)
+
+    torch_y.sum().backward()
+    mb_y.sum().backward()
+    torch_dmoe_optimizer.step()
+    mb_dmoe_optimizer.step()
+
+    torch_y = torch_dmoe(x)
+    mb_y = mb_dmoe(x)
+    torch.testing.assert_close(torch_y, mb_y)
+
+
+@pytest.mark.skipif(not is_megablocks_imported,
+                    reason='This test needs megablocks module')
+@pytest.mark.gpu
+@pytest.mark.parametrize('seqlen', [512])
+@pytest.mark.parametrize('mlp_type', ['glu', 'mlp'])
+@pytest.mark.parametrize('precision', ['bf16', 'fp32'])
+def test_fwd_equal_dmoe(seqlen: int, precision: str, mlp_type: str):
+    mb_dmoe_config = MPTConfig(d_model=1024,
+                               n_heads=32,
+                               n_layers=1,
+                               learned_pos_emb=False,
+                               max_seq_len=2048,
+                               vocab_size=100,
+                               no_bias=True,
+                               fuse_norm_attn_norm=True,
+                               tie_word_embeddings=False,
+                               attn_config=dict(
+                                   attn_type='grouped_query_attention',
+                                   attn_impl='torch',
+                                   attn_pdrop=0.0,
+                                   clip_qkv=8.0,
+                                   kv_n_heads=8,
+                                   rope=True,
+                                   rope_theta=10000.0,
+                               ),
+                               ffn_config=dict(
+                                   ffn_type='mb_dmoe',
+                                   fc_type='torch',
+                                   mlp_type=mlp_type,
+                                   moe_world_size=1,
+                                   ffn_act_fn={'name': 'silu'},
+                                   ffn_hidden_size=1792,
+                                   moe_num_experts=16,
+                                   moe_top_k=4,
+                                   moe_jitter_eps=0.0,
+                                   moe_loss_weight=0.05,
+                                   moe_normalize_expert_weights=1.0,
+                                   uniform_expert_assignment=False,
+                               ))
+    device = 'cuda:0'
+    if precision == 'fp32':
+        dtype = torch.float32
+        context = nullcontext()
+    elif precision == 'bf16':
+        dtype = torch.bfloat16
+        context = torch.autocast('cuda', torch.bfloat16)
+    else:
+        raise ValueError(f'Invalid {precision=}')
+
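+    # Clone the config, swapping in the pure-PyTorch dMoE as the reference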
+    torch_dmoe_config = copy.deepcopy(mb_dmoe_config)
+    torch_dmoe_config.ffn_config['ffn_type'] = 'torch_dmoe'
+
+    mb_dmoe_model = MPTForCausalLM(mb_dmoe_config).to(device=device,
+                                                      dtype=dtype)
+    torch_dmoe_model = MPTForCausalLM(torch_dmoe_config).to(device=device,
+                                                            dtype=dtype)
+
+    # set same state dicts
+    torch_dmoe_model.load_state_dict(mb_dmoe_model.state_dict())
+
+    # tokens
+    token_ids = torch.randint(
+        0,
+        mb_dmoe_config.vocab_size,
+        (1, seqlen),
+        device=device,
+        dtype=torch.long,
+    )
+
+    with context:
+        mpt_logits = mb_dmoe_model(token_ids).logits
+        db_logits = torch_dmoe_model(token_ids).logits
+        assert torch.allclose(mpt_logits, db_logits, rtol=0.01, atol=0.01)
diff --git a/tests/models/layers/test_flash_torch.py b/tests/models/layers/test_flash_torch.py
index c0e9f4b3b5..f212665c93 100644
--- a/tests/models/layers/test_flash_torch.py
+++ b/tests/models/layers/test_flash_torch.py
@@ -8,6 +8,7 @@
 from llmfoundry.models.layers import attention
 from llmfoundry.models.layers.attention import (check_alibi_support, gen_slopes,
                                                 is_flash_v2_installed)
+from llmfoundry.models.layers.layer_builders import build_attention_layer
 from llmfoundry.models.mpt.modeling_mpt import (apply_sequence_id,
                                                 gen_attention_mask_in_length,
                                                 gen_flash_attn_padding_info,
@@ -120,9 +121,15 @@ def test_attn_impl(attn_impl_0: str,
                                        ]).to(device=device)
 
     cfg.attn_impl = attn_impl_0
-    attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
+    attn0 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(cfg),  # type: ignore
+    ).to(device)
     cfg.attn_impl = attn_impl_1
-    attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
+    attn1 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(cfg),  # type: ignore
+    ).to(device)
 
     attn1.load_state_dict(attn0.state_dict())
 
diff --git a/tests/models/test_model.py b/tests/models/test_model.py
index 7bd8292151..402698cb27 100644
--- a/tests/models/test_model.py
+++ b/tests/models/test_model.py
@@ -53,7 +53,8 @@ def _load_tokenizer_cfg(cfg: DictConfig) -> Dict:
     return config
 
 
-def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
+def _get_objs(request: pytest.FixtureRequest,
+              conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
     warnings.filterwarnings(
         action='ignore',
         message='Torchmetrics v0.9 introduced a new argument class property')
@@ -64,16 +65,19 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
     fsdp_config = om.to_container(fsdp_config,
                                   resolve=True) if fsdp_config else None
 
+    # Check if we are running on GPU
+    is_gpu = False
+    for item in request.session.items:
+        is_gpu |= item.get_closest_marker('gpu') is not None
+
     # Build Model
     # For fast initialization, use `meta` device
     print('Initializing model...')
-    device = 'cpu'
-    test_cfg.precision = 'fp32'
+    device = 'cuda' if is_gpu else 'cpu'
+    test_cfg.precision = 'amp_bf16' if is_gpu else 'fp32'
     test_cfg.model.attn_config = {
         'attn_impl': 'torch',
     }
-    # device = 'cuda'
-    # test_cfg.precision = 'amp'
     test_cfg.model.init_device = device
     test_cfg.device = device
 
@@ -151,9 +155,13 @@ def gen_random_enc_dec_batch(batch_size: int, vocab_size: int, max_seq_len: int,
     return batch
 
 
-def test_full_forward_and_backward(batch_size: int = 2):
-    test_cfg, model, optimizer = get_objs(
-        conf_path='scripts/train/yamls/pretrain/testing.yaml')
+@pytest.mark.parametrize('conf_path', [
+    'scripts/train/yamls/pretrain/testing.yaml',
+])
+def test_full_forward_and_backward(request: pytest.FixtureRequest,
+                                   conf_path: str,
+                                   batch_size: int = 2):
+    test_cfg, model, optimizer = _get_objs(request=request, conf_path=conf_path)
 
     batch = gen_random_batch(batch_size, test_cfg)
 
@@ -170,9 +178,10 @@ def test_full_forward_and_backward(batch_size: int = 2):
     assert not torch.equal(original_params, updated_params)
 
 
-def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
-    test_cfg, model, optimizer = get_objs(
-        conf_path='scripts/train/yamls/pretrain/testing.yaml')
+def test_full_forward_and_backward_with_inputs_embeds(
+        request: pytest.FixtureRequest, batch_size: int = 2):
+    test_cfg, model, optimizer = _get_objs(
+        request=request, conf_path='scripts/train/yamls/pretrain/testing.yaml')
 
     batch = gen_random_batch(batch_size, test_cfg, inputs=['inputs_embeds'])
 
@@ -188,9 +197,10 @@ def test_full_forward_and_backward_with_inputs_embeds(batch_size: int = 2):
 
 
 @pytest.mark.parametrize('inputs', [[], ['input_ids', 'inputs_embeds']])
-def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]):
-    test_cfg, model, _ = get_objs(
-        conf_path='scripts/train/yamls/pretrain/testing.yaml')
+def test_invalid_inputs_embeds_input_ids_combinations(
+        request: pytest.FixtureRequest, inputs: List[str]):
+    test_cfg, model, _ = _get_objs(
+        request=request, conf_path='scripts/train/yamls/pretrain/testing.yaml')
 
     batch = gen_random_batch(2, test_cfg, inputs=inputs)
 
@@ -199,9 +209,15 @@ def test_invalid_inputs_embeds_input_ids_combinations(inputs: List[str]):
         _ = model(batch)
 
 
-def test_attention_mechanism(batch_size: int = 2):
-    test_cfg, model, _ = get_objs(
-        conf_path='scripts/train/yamls/pretrain/testing.yaml')
+@pytest.mark.parametrize('conf_path', [
+    'scripts/train/yamls/pretrain/testing.yaml',
+    pytest.param('scripts/train/yamls/pretrain/testing-moe.yaml',
+                 marks=pytest.mark.gpu),
+])
+def test_attention_mechanism(request: pytest.FixtureRequest,
+                             conf_path: str,
+                             batch_size: int = 2):
+    test_cfg, model, _ = _get_objs(request=request, conf_path=conf_path)
 
     batch = gen_random_batch(batch_size, test_cfg)
 
@@ -217,43 +233,45 @@ def test_attention_mechanism(batch_size: int = 2):
     pos = torch.arange(0, S, dtype=torch.long,
                        device=input_ids.device).unsqueeze(0)
 
-    tok_emb = model.model.transformer.wte(input_ids)
-    pos_emb = model.model.transformer.wpe(pos)
-    x = model.model.transformer.emb_drop(tok_emb + pos_emb)
-
-    # basically the attention mask should be a tensor shape (bsz, seqlen, seqlen)
-    # wih -inf along the upper triangle as well as wherever there are any pad tokens
-    # and with 0 everywhere else
-    expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len)\
-        .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len)
-    expected_zerod_weights = torch.isneginf(
-        torch.cat(batch_size * [expected_zerod_weights]))
-    torch_key_padding = torch.cat(  # type: ignore
-        test_cfg.max_seq_len *
-        [(~attention_mask).reshape(batch_size, 1, test_cfg.max_seq_len)],
-        axis=1)
-    expected_zerod_weights |= torch_key_padding
-
-    attn_bias, attention_mask = model.model.transformer._attn_bias(
-        device=x.device, dtype=x.dtype, attention_mask=attention_mask)
-
-    for block in model.model.transformer.blocks:
-        a = block.norm_1(x)
-        b, attention_weights, _ = block.attn(
-            a,
-            past_key_value=None,
-            attn_bias=attn_bias,
-            attention_mask=attention_mask,
-            is_causal=model.model.transformer.is_causal,
-            needs_weights=True)
-
-        zerod_weights = (attention_weights == 0)
-        assert torch.equal(expected_zerod_weights.expand(*zerod_weights.shape),
-                           zerod_weights)
-        x = x + block.resid_attn_dropout(b)
-        m = block.norm_2(x)
-        n = block.ffn(m)
-        x = x + block.resid_ffn_dropout(n)
+    with get_precision_context(test_cfg.precision):
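+        # Recompute the embeddings and per-block attention by hand under the model's
+        # precision (amp_bf16 on GPU, fp32 on CPU) so dtypes match the model forward pass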
+        tok_emb = model.model.transformer.wte(input_ids)
+        pos_emb = model.model.transformer.wpe(pos)
+        x = model.model.transformer.emb_drop(tok_emb + pos_emb)
+
+        # the attention mask should be a tensor of shape (bsz, seqlen, seqlen)
+        # with -inf along the upper triangle as well as wherever there are pad tokens,
+        # and 0 everywhere else
+        expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len, device=test_cfg.device)\
+            .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len)
+        expected_zerod_weights = torch.isneginf(
+            torch.cat(batch_size * [expected_zerod_weights]))
+        torch_key_padding = torch.cat(  # type: ignore
+            test_cfg.max_seq_len *
+            [(~attention_mask).reshape(batch_size, 1, test_cfg.max_seq_len)],
+            axis=1)
+        expected_zerod_weights |= torch_key_padding
+
+        attn_bias, attention_mask = model.model.transformer._attn_bias(
+            device=x.device, dtype=x.dtype, attention_mask=attention_mask)
+
+        for block in model.model.transformer.blocks:
+            a = block.norm_1(x)
+            b, attention_weights, _ = block.attn(
+                a,
+                past_key_value=None,
+                attn_bias=attn_bias,
+                attention_mask=attention_mask,
+                is_causal=model.model.transformer.is_causal,
+                needs_weights=True)
+
+            zerod_weights = (attention_weights == 0)
+            assert torch.equal(
+                expected_zerod_weights.expand(*zerod_weights.shape),
+                zerod_weights)
+            x = x + block.resid_attn_dropout(b)
+            m = block.norm_2(x)
+            n = block.ffn(m)
+            x = x + block.resid_ffn_dropout(n)
 
 
 def test_full_forward_and_backward_gpt2_small(batch_size: int = 2):
@@ -424,7 +442,6 @@ def test_determinism(attn_impl: str, precision: torch.dtype, ffn_type: str,
             output_2 = model_2(batch)
             assert output_1.logits.allclose(output_2.logits, rtol=0.0,
                                             atol=0.0), f'differed at step {i}'
-
             loss_1 = model_1.loss(output_1, batch)
             loss_2 = model_2.loss(output_2, batch)
             assert isinstance(loss_1, torch.Tensor)
diff --git a/tests/models/test_mpt_gen.py b/tests/models/test_mpt_gen.py
index 35f130cd46..00c6a1c7a8 100644
--- a/tests/models/test_mpt_gen.py
+++ b/tests/models/test_mpt_gen.py
@@ -142,6 +142,50 @@ def test_mpt_generate_callback(attn_impl: str, use_alibi: bool,
     trainer.logger.log_table.assert_called_once()
 
 
+@pytest.mark.gpu
+@pytest.mark.parametrize('device', ['cpu', 'gpu'])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
+def test_gen_mpt_moe(
+    device: str,
+    attn_impl: str,
+    build_tiny_mpt: Callable[..., ComposerMPTCausalLM],
+    mpt_tokenizer: PreTrainedTokenizerBase,
+):
+    if device == 'cpu':
+        pytest.skip('MegaBlocks MoE is only implemented on GPU.')
+    composer_device = get_device(device)
+
+    model = build_tiny_mpt(
+        attn_config={
+            'attn_impl': attn_impl,
+            'attn_uses_sequence_id': False,
+        },
+        expansion_ratio=1,
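+        # MegaBlocks dropless-MoE FFN: 4 experts with top-2 routing on a single expert-parallel rank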
+        ffn_config={
+            'ffn_type': 'mb_dmoe',
+            'memory_optimized_mlp': True,
+            'moe_lbl_in_fp32': False,
+            'moe_loss_weight': 0.01,
+            'moe_num_experts': 4,
+            'moe_top_k': 2,
+            'moe_world_size': 1,
+            'moe_weight_parallelism': False,
+            'uniform_expert_assignment': False,
+        },
+    )
+    model = composer_device.module_to_device(model)
+
+    model.eval()
+
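+    # Smoke-test generation of a few tokens under bf16 autocast to exercise the MoE forward path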
+    with get_precision_context('amp_bf16' if composer_device.name ==
+                               'gpu' else 'fp32'):
+        _ = model.generate(
+            composer_device.tensor_to_device(
+                mpt_tokenizer('hello', return_tensors='pt')['input_ids']),
+            max_new_tokens=10,
+        )
+
+
 @pytest.mark.gpu
 @pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
 @pytest.mark.parametrize('use_alibi', [True, False])
diff --git a/tests/models/test_rope_dail_vs_hf.py b/tests/models/test_rope_dail_vs_hf.py
index 33c3d3c052..b9ab90357a 100644
--- a/tests/models/test_rope_dail_vs_hf.py
+++ b/tests/models/test_rope_dail_vs_hf.py
@@ -7,6 +7,7 @@
 from omegaconf import OmegaConf as om
 
 from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.layer_builders import build_attention_layer
 from llmfoundry.models.mpt.modeling_mpt import (gen_flash_attn_padding_info,
                                                 gen_rotary_embedding)
 
@@ -21,8 +22,6 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     if not is_flash_v2_installed():
         pytest.skip('dail implementation of rope requires flash attention 2.')
 
-    from llmfoundry.models.layers import attention
-
     cfg = om.create({
         'attn_impl': 'flash',
         'd_model': 128,
@@ -37,8 +36,16 @@ def test_rope_dail_vs_hf(attn_type: str, seq_len: int, device: str = 'cuda'):
     if attn_type == 'grouped_query_attention':
         cfg.kv_n_heads = 2
 
-    attn0 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
-    attn1 = attention.ATTN_CLASS_REGISTRY[attn_type](**cfg).to(device)
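+    # Build two identical attention modules; weights are synced below so only the rope implementation differs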
+    attn0 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(
+            cfg),  # type: ignore (to_container returns a broad type)
+    ).to(device)
+    attn1 = build_attention_layer(
+        name=attn_type,
+        attn_kwargs=om.to_container(
+            cfg),  # type: ignore (to_container returns a broad type)
+    ).to(device)
 
     attn1.load_state_dict(attn0.state_dict())
     x0 = torch.randn(batch_size, seq_len, cfg.d_model).to(device)
diff --git a/tests/test_registry.py b/tests/test_registry.py
index c93c7c9749..29d8e137f3 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -31,6 +31,9 @@ def test_expected_registries_exist():
         'metrics',
         'models',
         'norms',
+        'attention_classes',
+        'attention_implementations',
+        'fcs',
     }
 
     assert existing_registries == expected_registry_names