diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml
index 1151837111..ffbfac4585 100644
--- a/.github/workflows/pr-gpu.yaml
+++ b/.github/workflows/pr-gpu.yaml
@@ -40,7 +40,7 @@ jobs:
if: github.repository_owner == 'mosaicml'
with:
container: ${{ matrix.container }}
- mcloud-timeout: 1200
+ mcloud-timeout: 1800
name: ${{ matrix.name }}
pytest-command: ${{ matrix.pytest_command }}
pytest-markers: ${{ matrix.markers }}
diff --git a/README.md b/README.md
index 04bad9c519..46074613e1 100644
--- a/README.md
+++ b/README.md
@@ -181,14 +181,14 @@ source llmfoundry-venv-amd/bin/activate
# installs
pip install cmake packaging torch
-pip install -e . # this installs some things which are not needed but they dont hurt
+pip install -e . # This installs some things that are not needed but they don't hurt
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm5.4.2
```
**Lastly**, install the ROCm enabled flash attention (instructions [here](https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm2#amd-gpurocm-support)).
Notes:
1. `attn_impl: triton` does not work.
-1. We don't yet have a docker img where everything works perfectly. You might need to up/down grade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.
+1. We don't yet have a docker img where everything works perfectly. You might need to up/downgrade some packages (in our case, we needed to downgrade to `numpy==1.23.5`) before everything works without issue.
# Quickstart
@@ -228,7 +228,7 @@ python inference/convert_composer_to_hf.py \
# --hf_repo_for_upload user-org/repo-name
# Evaluate the model on a subset of tasks
-python eval/eval.py \
+composer eval/eval.py \
eval/yamls/hf_eval.yaml \
icl_tasks=eval/yamls/copa.yaml \
model_name_or_path=mpt-125m-hf
diff --git a/TUTORIAL.md b/TUTORIAL.md
index 36993bc409..86bd9829e9 100644
--- a/TUTORIAL.md
+++ b/TUTORIAL.md
@@ -8,27 +8,42 @@ Forging LLMs can be quite complicated — you have to get your data prepared, se
This tutorial will provide a brief intro to the repo’s structure and underlying tools (all courtesy of MosaicML, of course), will go over a few example workflows and point you to the related resources within the repo, and will finally cover a number of FAQs that we have encountered since release.
+- [LLM Foundry Tutorial](#llm-foundry-tutorial)
- [Intro](#intro)
- [How this repo is structured](#how-this-repo-is-structured)
- [Key components](#key-components)
+ - [Composer](#composer)
+ - [StreamingDataset](#streamingdataset)
+ - [MCLI](#mcli)
- [How the YAMLs work](#how-the-yamls-work)
- [Example Workflows](#example-workflows)
- [Workflow 1: I want to play with a HF model like MPT-7B locally](#workflow-1-i-want-to-play-with-a-hf-model-like-mpt-7b-locally)
- [Workflow 2: I want to deploy an inference endpoint with a HF model like MPT-7B](#workflow-2-i-want-to-deploy-an-inference-endpoint-with-a-hf-model-like-mpt-7b)
- [Workflow 3: I want to finetune a HF model like MPT-7B](#workflow-3-i-want-to-finetune-a-hf-model-like-mpt-7b)
+ - [Supervised FineTuning and Instruction FineTuning](#supervised-finetuning-and-instruction-finetuning)
+ - [Domain Adaptation and Sequence Length Adaptation](#domain-adaptation-and-sequence-length-adaptation)
+ - [Data](#data)
+ - [Modeling](#modeling)
- [Workflow 4: I want to train a new HF model from scratch](#workflow-4-i-want-to-train-a-new-hf-model-from-scratch)
- [FAQs](#faqs)
- - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus)
- - [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do)
- - [What hardware can I train on?](#what-hardware-can-i-train-on)
- - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
- - [What is FSDP?](#what-is-fsdp)
- - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton-for-mpt-and-which-one-should-i-use)
- - [Can I finetune using PEFT / LORA?](#can-i-finetune-using-peft--lora)
- - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu)
- - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer)
- - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms)
- - [Common installation issues](#common-installation-issues)
+ - [Why is the script only using 1 out of N GPUs?](#why-is-the-script-only-using-1-out-of-n-gpus)
+ - [I’m running into an Out-Of-Memory (OOM) error. What do I do?](#im-running-into-an-out-of-memory-oom-error-what-do-i-do)
+ - [What hardware can I train on?](#what-hardware-can-i-train-on)
+ - [What hardware can I run eval on?](#what-hardware-can-i-run-eval-on)
+ - [What hardware can I run inference on?](#what-hardware-can-i-run-inference-on)
+ - [What is FSDP?](#what-is-fsdp)
+ - [What are the different attention options `torch` / `flash` / `triton` for MPT and which one should I use?](#what-are-the-different-attention-options-torch--flash--triton--for-mpt-and-which-one-should-i-use)
+ - [Limitations](#limitations)
+ - [What is `triton-pre-mlir`?](#what-is-triton-pre-mlir)
+ - [Known issue with sm86+ GPUs](#known-issue-with-sm86-gpus)
+ - [Support for FlashAttention-2](#support-for-flashattention-2)
+ - [What kinds of positional embeddings does LLM Foundry support?](#what-kinds-of-positional-embeddings-does-llm-foundry-support)
+ - [Can I finetune using PEFT / LoRA?](#can-i-finetune-using-peft--lora)
+ - [Can I quantize these models and/or run on CPU?](#can-i-quantize-these-models-andor-run-on-cpu)
+ - [How do I deploy with ONNX/FasterTransformer?](#how-do-i-deploy-with-onnxfastertransformer)
+ - [TransformerEngine and amp\_fp8 support](#transformerengine-and-amp_fp8-support)
+ - [How expensive is it to build LLMs?](#how-expensive-is-it-to-build-llms)
+ - [Common installation issues](#common-installation-issues)
Let’s get started!
@@ -68,7 +83,7 @@ The Trainer is a pytorch-native object that composes your model, dataset(s), opt
Spending some time understanding the Composer Trainer is a great way to form a deeper understanding of what the train and eval scripts are doing under the hood.
Composer also comes packaged with the `composer` launcher.
-If you go through our docs, you'll notice that we instruct you to launch the train script (`scripts/train/train.py`) and eval script (`scripts/eval/eval.py`) using the launcher, like so,
+If you go through our docs, you'll notice that we instruct you to launch the training script (`scripts/train/train.py`) and eval script (`scripts/eval/eval.py`) using the launcher, like so,
```bash
@@ -81,7 +96,7 @@ The `composer` launcher puts all your GPUs to work by launching the script on a
### StreamingDataset
The training script contains logic for building a few different types of dataloaders used for different training tasks.
-Each of these dataloaders are built to work with **streaming datasets**.
+Each of these dataloaders is built to work with **streaming datasets**.
There are a number of benefits that come from using streaming datasets, from fast, deterministic resumption to easily loading from a mixture of streams at once.
The scripts in `scripts/data_prep/` are your one-stop-shop for converting a local dataset or a dataset on the Hugging Face Hub to our streaming MDS format.
@@ -178,7 +193,7 @@ We address two possible versions of “finetuning” here. For both, you’ll wa
### Supervised FineTuning and Instruction FineTuning
-`scripts/train/` already includes some resources for supervised finetuning. If that’s what you’re interestested in check out
+`scripts/train/` already includes some resources for supervised finetuning. If that’s what you’re interested in, check out
1. [**LLM Finetuning from a Local Dataset: A Concrete Example**](https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/finetune_example/README.md)
2. [The YAML which should replicate the process of creating MPT-7B-Instruct from MPT-7b](https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/yamls/finetune/mpt-7b_dolly_sft.yaml) — You can point this at your own dataset by [following these instructions](https://github.com/mosaicml/llm-foundry/blob/main/scripts/train/README.md#Usage)
@@ -228,7 +243,7 @@ After you're done training, you probably want to convert your Composer checkpoin
> **Note**
> Pretraining for 10s of billions of tokens is a large job even for a smaller model; you’ll want multiple A100s for this example.
-It is conceivable that you would like to train a model *with the same architecture* as a model available in HuggingFace `transformers` but without using those same weights; for example, if you have a large amount of proprietary data, or want to change something about the model that is hard to change after the fact. So, as an example, let’s say you want a version of `gpt2` but with longer sequence length, say 2048. Using the MPT architecture would give us Flash Attention and ALiBi, allowing us to go much longer; but for this example we stick with 2048. And of course, let’s use 150 tokens/parameter, which is the ratio that MPT-7B used, getting us to 17.55B tokens for our 117M param model.
+It is conceivable that you would like to train a model *with the same architecture* as a model available in HuggingFace `transformers` but without using those same weights; for example, if you have a large amount of proprietary data, or want to change something about the model that is hard to change after the fact. So, as an example, let’s say you want a version of `gpt2` but with a longer sequence length, say 2048. Using the MPT architecture would give us Flash Attention and ALiBi, allowing us to go much longer; but for this example we stick with 2048. And of course, let’s use 150 tokens/parameter, which is the ratio that MPT-7B used, getting us to 17.55B tokens for our 117M param model.
The first step to training from scratch is to get your pretraining data prepared. Following [the data preparation README](https://github.com/mosaicml/llm-foundry/blob/main/scripts/data_prep/README.md), we convert C4 as follows:
@@ -294,25 +309,25 @@ The purpose of this section is probably pretty self-evident. You’ve got questi
- **Long answer:** In NLP, Softmax Attention operates on a sequence. It is an all to all graph operation where, during training, the memory complexity is quadratic with respect to the length of the sequence. Furthermore, on GPUs, naive implementations of Softmax Attention are bandwidth (BW) limited.
[Rabe et al. (2021)](https://arxiv.org/abs/2112.05682) and [Dao et al. (2022)](https://arxiv.org/abs/2205.14135) showed that fusing all operations in Softmax Attention can make the operation much less BW limited.
-Furthermore, integrating a recompuation schema decreases the sequence length memory complexity from *quadratic* to *linear*, thereby supporting much longer sequence lengths.
+Furthermore, integrating a recomputation schema decreases the sequence length memory complexity from *quadratic* to *linear*, thereby supporting much longer sequence lengths.
- Setting `attn_config.attn_impl=torch` enables a naive Softmax Attention written using base torch operations.
- Setting `attn_config.attn_impl=flash` enables Flash Attention [implemented by Dao et al in the HazyResearch repo using CUDA](https://github.com/HazyResearch/flash-attention). This will have linear memory complexity (enabling larger batch sizes) and will run much faster.
- - Setting `attn_config.attn_impl=triton` enables a Flash Attention [implemented using Triton](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py). In our experiance, `triton` is slightly faster than `flash`.
+ - Setting `attn_config.attn_impl=triton` enables a Flash Attention [implemented using Triton](https://github.com/mosaicml/llm-foundry/blob/main/llmfoundry/models/layers/flash_attn_triton.py). In our experience, `triton` is slightly faster than `flash`.
-
#### Limitations
- For training, `torch` uses a lot of memory and is slow.
-- `flash` and `triton` cannot return attention weights and therefore cannot be used with methods which require it.
-- `flash` cannot accept an attention bias and therefore cannot be used with methods which require it such as ALiBi.
+- `flash` and `triton` cannot return attention weights and therefore cannot be used with methods that require it.
+- `flash` cannot accept an attention bias and therefore cannot be used with methods that require it such as ALiBi.
#### What is `triton-pre-mlir`?
- Torch2 installs and requires a specific version of [Triton](https://openai.com/research/triton).
@@ -328,6 +343,18 @@ The majority of our training setups use `triton`. -->
Updating to LLVM14 (or LLVM15) cannot be done because there are breaking changes.
What is the result of this? Although sm89+ is not **formally** supported until LLVM15, our testing on H100 GPUs shows that `attn_impl=triton` still works well and still runs fast. The only issue is that when the network is starting to run, LLVM might throw a warning like: `'sm_90' is not a recognized processor for this target (ignoring processor)`. This warning does not seem to affect performance.
+#### Support for FlashAttention-2
+- [FlashAttention-2](https://arxiv.org/pdf/2307.08691.pdf) improves upon FlashAttention to get even faster attention computation. LLM Foundry supports FlashAttention-2. Please follow the instructions [here](https://github.com/mosaicml/llm-foundry/tree/main/scripts/train#flashattention).
+
+### What kinds of positional embeddings does LLM Foundry support?
+Currently we support [Learned Positional Embeddings](https://arxiv.org/pdf/1706.03762.pdf), [Attention with Linear Biases (ALiBi)](https://arxiv.org/pdf/2108.12409.pdf), and [Rotary Positional Embeddings (RoPE)](https://arxiv.org/pdf/2104.09864.pdf). There is also an option to switch off all of these embeddings to get [No Positional Embedding](https://arxiv.org/pdf/2203.16634.pdf).
+
+| Name | YAML Config | Training MFU on MPT-7B trained on 8 A100 80GB GPUs | Notes |
+|:-----------------------------------|:------------------------------------------------------------------|:---------------------------------------------------:|:----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| Learned Positional Embeddings |
model:
learned_pos_emb: True
| 65.7 | |
+| ALiBi | model:
attn_config:
alibi: True
| 64.5 | Requires Triton or Torch attention. |
+| RoPE (Dao-AILab Implementation) | model:
attn_config:
rope: True
rope_impl: dail
| 64.5 | Requires a CUDA GPU and the [flash-attn library](https://github.com/Dao-AILab/flash-attention) v2.0.1 or higher to be installed. Please see the instructions in the [paragraph above](#support-for-flashattention-2) on how to install flash-attn v2. Note that the attention implementation can still be `torch`, `triton`, or `flash`. |
+| RoPE (Hugging
Face Implementation) | model:
attn_config:
rope: True
rope_impl: hf
| 62.3 | |
### Can I finetune using PEFT / LoRA?
- The LLM Foundry codebase does not directly have examples of PEFT or LORA workflows. However, our MPT model is a subclass of HuggingFace `PretrainedModel`, and https://github.com/mosaicml/llm-foundry/pull/346 added required features to enable HuggingFace’s [PEFT](https://huggingface.co/docs/peft/index) / [LORA](https://huggingface.co/docs/peft/conceptual_guides/lora) workflows for MPT. MPT models with LoRA modules can be trained either using LLM Foundry or Hugging Face's [accelerate](https://huggingface.co/docs/accelerate/index). Within LLM Foundry, run (`scripts/train/train.py`), adding `lora` arguments to the config `.yaml`, like so:
@@ -370,7 +397,7 @@ model:
```
enables [TransformerEngine's LayerNormMLP](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html#transformer_engine.pytorch.LayerNormMLP) layer which enables sequence parallelism if configured correctly.
-WARNING: `state_dicts` generated with `ffn_type: te_ln_mlp` will NOT directly map to `state_dicts` generated using the default network configurations. We do not have control over how `te.LayerNormMLP` is implemented and therefore cannot reasily reconcile it with the default implementation (or any other implementation).
+WARNING: `state_dicts` generated with `ffn_type: te_ln_mlp` will NOT directly map to `state_dicts` generated using the default network configurations. We do not have control over how `te.LayerNormMLP` is implemented and therefore cannot readily reconcile it with the default implementation (or any other implementation).
### How expensive is it to build LLMs?
- Check out our blog post [GPT3-Quality for <$500k](https://www.mosaicml.com/blog/gpt-3-quality-for-500k) for guidance on LLM training times and costs.
diff --git a/llmfoundry/__init__.py b/llmfoundry/__init__.py
index 3bb9eed043..51fa67993a 100644
--- a/llmfoundry/__init__.py
+++ b/llmfoundry/__init__.py
@@ -4,6 +4,11 @@
import torch
try:
+ # Before importing any transformers models, we need to disable transformers flash attention if
+    # we are in an environment with flash attention version <2. Otherwise, transformers
+    # hard errors on an improperly gated import.
+ import transformers
+
from llmfoundry import optim, utils
from llmfoundry.data import (ConcatTokensDataset,
MixtureOfDenoisersCollator, NoConcatDataset,
@@ -14,8 +19,8 @@
ComposerHFT5)
from llmfoundry.models.layers.attention import (
MultiheadAttention, attn_bias_shape, build_alibi_bias, build_attn_bias,
- flash_attn_fn, scaled_multihead_dot_product_attention,
- triton_flash_attn_fn)
+ flash_attn_fn, is_flash_v1_installed,
+ scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import (FFN_CLASS_REGISTRY, MPTMLP,
build_ffn)
@@ -24,6 +29,8 @@
MPTForCausalLM, MPTModel,
MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
+ if is_flash_v1_installed():
+ transformers.utils.is_flash_attn_available = lambda: False
except ImportError as e:
try:
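The gating above assumes a helper that detects whether the installed `flash-attn` package is a v1.x release. A minimal sketch of that check (an assumption about the implementation, not the foundry source) could look like:

```python
# Sketch only: detect a flash-attn v1.x install so transformers' flash-attention
# code path can be disabled before any transformers model is imported.
from packaging import version


def is_flash_v1_installed() -> bool:
    try:
        import flash_attn  # noqa: F401
    except ImportError:
        return False
    return version.parse(flash_attn.__version__) < version.parse('2.0.0')


if is_flash_v1_installed():
    import transformers
    # Mirror the patch above: tell transformers that flash attention is unavailable.
    transformers.utils.is_flash_attn_available = lambda: False
```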
diff --git a/llmfoundry/callbacks/eval_gauntlet_callback.py b/llmfoundry/callbacks/eval_gauntlet_callback.py
index 78ccbb529b..7281a8d1fc 100644
--- a/llmfoundry/callbacks/eval_gauntlet_callback.py
+++ b/llmfoundry/callbacks/eval_gauntlet_callback.py
@@ -22,6 +22,32 @@ class Weighting(Enum):
LOG_SAMPLE_SZ = 3
+def calculate_named_averages(average_names: Dict[str, list],
+ category_scores: Dict[str, float]):
+    """Calculates the named averages based on the raw category scores.
+
+ For each named average, take a simple average of all the category scores associated with that named average.
+
+ Args:
+        average_names (dict[str, list]): A mapping from each named average to the list of categories that average should consist of.
+ category_scores (dict[str, float]): Contains the raw scores corresponding to each category.
+ """
+ average_scores = {}
+ for avg_name, category_list in average_names.items():
+ composite_subset = {
+ category: score
+ for category, score in category_scores.items()
+ if category in category_list
+ }
+ if len(composite_subset.values()) > 0:
+ average_scores[avg_name] = sum(composite_subset.values()) / len(
+ composite_subset.values())
+ else:
+ average_scores[avg_name] = 0
+
+ return average_scores
+
+
class EvalGauntlet(Callback):
"""The EvalGauntlet aggregates ICL eval results.
@@ -31,7 +57,7 @@ class EvalGauntlet(Callback):
Args:
logger_keys (list): These are the exact keys that the individual benchmark metrics will be
logged under in the logger after eval
- tasks (dict): This contains the list of categories, as well as the subtasks within them, the
+ categories (dict): This contains the list of categories, as well as the subtasks within them, the
random baseline accuracy of each subtask, and the number of fewshot examples
used for the task. See `llmfoundry/scripts/eval/yamls/eval_gauntlet.yaml` to see the structure.
weighting (Weighting): The weighting scheme used to balance different tasks within each category.
@@ -43,6 +69,7 @@ class EvalGauntlet(Callback):
rescale_accuracy (bool): Flag determining whether to rescale the accuracy on each benchmark
by (1-random_baseline_accuracy) before aggregating. Using this ensures that all benchmarks max out at 1.0.
benchmark_sizes (Optional[dict]): Optional data on benchmark sizes, used when not relying on equal weighting.
+        averages (Optional[dict]): Optional dictionary specifying a mapping from average names to lists of categories used to produce each named average.
"""
def __init__(self,
@@ -51,7 +78,8 @@ def __init__(self,
weighting: str = 'EQUAL',
subtract_random_baseline: bool = True,
rescale_accuracy: bool = True,
- benchmark_sizes: Optional[dict] = None):
+ benchmark_sizes: Optional[dict] = None,
+ averages: Optional[dict] = None):
if isinstance(logger_keys, dict):
raise ValueError(
'logger_keys now requires a list type as input, not a dict')
@@ -66,13 +94,12 @@ def __init__(self,
)
self.categories = categories
+ self.category_names = [conf.get('name') for conf in self.categories]
self.weighting = Weighting[weighting]
self.subtract_random_baseline = subtract_random_baseline
self.rescale_accuracy = rescale_accuracy
self.logger_keys = logger_keys
-
for category in self.categories:
-
for benchmark in category['benchmarks']:
bench_name = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
@@ -95,7 +122,20 @@ def __init__(self,
assert weight is not None
benchmark['weighting'] = weight
- def compute_averages(self, state: State) -> Dict[str, float]:
+ self.averages = {}
+ if averages is not None:
+ self.averages = averages
+ else:
+ # if no averages spec provided, simply average everything
+ self.averages['default_average'] = self.category_names
+
+ for avg_name in self.averages:
+ if avg_name in self.category_names:
+ raise ValueError(
+ f'Found average name `{avg_name}` used as category name. Average names and category names must be non-overlapping.'
+ )
+
+ def extract_metrics_from_state(self, state: State) -> Dict[str, float]:
results = {}
for key in self.logger_keys:
@@ -121,23 +161,22 @@ def compute_averages(self, state: State) -> Dict[str, float]:
return {k: sum(v) / len(v) for k, v in results.items()}
def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
- new_metrics = self.compute_averages(state)
- if len(new_metrics) == 0:
+ computed_metrics = self.extract_metrics_from_state(state)
+ if len(computed_metrics) == 0:
return {}
- composite_scores = {}
-
+ category_scores = {}
for category in self.categories:
missing_metrics = []
- composite_scores[category['name']] = []
+ category_scores[category['name']] = []
for benchmark in category['benchmarks']:
key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot"
- if key not in new_metrics:
+ if key not in computed_metrics:
log.warning(
f'Could not find results for benchmark: {benchmark}.')
missing_metrics.append(key)
else:
- score = new_metrics[key]
+ score = computed_metrics[key]
if self.subtract_random_baseline:
score -= benchmark['random_baseline']
@@ -145,7 +184,7 @@ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
if self.rescale_accuracy and self.subtract_random_baseline:
score /= 1.0 - benchmark['random_baseline']
- composite_scores[category['name']].append({
+ category_scores[category['name']].append({
'name': benchmark['name'],
'score': score,
'weighting': benchmark['weighting']
@@ -155,23 +194,22 @@ def eval_after_all(self, state: State, logger: Logger) -> Dict[str, float]:
log.warning(
f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}"
)
- del composite_scores[category['name']]
+ del category_scores[category['name']]
continue
total_weight = sum(
- k['weighting'] for k in composite_scores[category['name']])
- composite_scores[category['name']] = sum(
+ k['weighting'] for k in category_scores[category['name']])
+ category_scores[category['name']] = sum(
k['score'] * (k['weighting'] / total_weight)
- for k in composite_scores[category['name']])
+ for k in category_scores[category['name']])
- composite_scores = {
+ named_averages = calculate_named_averages(self.averages,
+ category_scores)
+ category_scores.update(named_averages)
+ category_scores = {
f'icl/metrics/eval_gauntlet/{k}': v
- for k, v in composite_scores.items()
+ for k, v in category_scores.items()
}
-
- composite_scores['icl/metrics/eval_gauntlet/average'] = sum(
- composite_scores.values()) / len(composite_scores.values()) if len(
- composite_scores.values()) > 0 else 0
if logger is not None:
- logger.log_metrics(composite_scores)
+ logger.log_metrics(category_scores)
- return composite_scores
+ return category_scores
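For intuition about the new `calculate_named_averages` helper, here is a small hypothetical example (category names and scores invented for illustration):

```python
from llmfoundry.callbacks.eval_gauntlet_callback import calculate_named_averages

# Invented category scores; in practice these come from the aggregated benchmark results.
category_scores = {
    'world_knowledge': 0.5,
    'commonsense_reasoning': 0.25,
    'reading_comprehension': 0.55,
}
# Each named average is the simple mean of the categories listed for it.
averages = {'core_average': ['world_knowledge', 'commonsense_reasoning']}

print(calculate_named_averages(averages, category_scores))
# Expected: {'core_average': 0.375}
```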
diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index aa3beda513..e02bf03693 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -4,18 +4,20 @@
import contextlib
import copy
import logging
+import math
import os
import tempfile
from pathlib import Path
from typing import Optional, Union
import torch
-from composer.core import Callback, Event, State, Time
+from composer.core import Callback, Event, State, Time, TimeUnit
from composer.core.state import fsdp_state_dict_type_context
from composer.loggers import Logger, MLFlowLogger
-from composer.loggers.remote_uploader_downloader import RemoteUploaderDownloader
from composer.models import HuggingFaceModel
-from composer.utils import dist, format_name_with_dist_and_time, parse_uri
+from composer.utils import (dist, format_name_with_dist_and_time,
+ maybe_create_remote_uploader_downloader_from_uri,
+ parse_uri)
from composer.utils.misc import create_interval_scheduler
from transformers import PreTrainedModel, PreTrainedTokenizerBase
@@ -52,12 +54,11 @@ def __init__(
save_interval: Union[str, int, Time],
huggingface_folder_name: str = 'ba{batch}',
precision: str = 'float32',
- overwrite: bool = False,
+ overwrite: bool = True,
mlflow_registered_model_name: Optional[str] = None,
mlflow_logging_config: Optional[dict] = None,
):
- self.backend, self.bucket_name, self.save_dir_format_str = parse_uri(
- save_folder)
+ _, _, self.save_dir_format_str = parse_uri(save_folder)
self.overwrite = overwrite
self.precision = precision
self.dtype = {
@@ -83,15 +84,20 @@ def __init__(
self.huggingface_folder_name_fstr = os.path.join(
'huggingface', huggingface_folder_name)
+
+ if isinstance(save_interval, str):
+ save_interval = Time.from_timestring(save_interval)
+ if isinstance(save_interval, int):
+ save_interval = Time(save_interval, TimeUnit.EPOCH)
+
+ self.save_interval = save_interval
self.check_interval = create_interval_scheduler(
save_interval, include_end_of_training=True)
- self.upload_to_object_store = (self.backend != '')
- if self.upload_to_object_store:
- self.remote_ud = RemoteUploaderDownloader(
- bucket_uri=f'{self.backend}://{self.bucket_name}',
- num_concurrent_uploads=4)
- else:
- self.remote_ud = None
+
+ self.remote_ud = maybe_create_remote_uploader_downloader_from_uri(
+ save_folder, loggers=[])
+ if self.remote_ud is not None:
+ self.remote_ud._num_concurrent_uploads = 4
self.last_checkpoint_batch: Optional[Time] = None
self.mlflow_loggers = []
@@ -107,7 +113,7 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
raise ValueError(
f'`HuggingFaceCheckpointer` is only compatible with `HuggingFaceModel`s. '
+ f'Got {type(state.model)} instead.')
- if self.upload_to_object_store and self.remote_ud is not None:
+ if self.remote_ud is not None:
self.remote_ud.init(state, logger)
state.callbacks.append(self.remote_ud)
@@ -128,6 +134,21 @@ def run_event(self, event: Event, state: State, logger: Logger) -> None:
mlflow.environment_variables.MLFLOW_HUGGINGFACE_MODEL_MAX_SHARD_SIZE.set(
'5GB')
+ def _is_last_batch(self, state: State):
+ elapsed_duration = state.get_elapsed_duration()
+ if elapsed_duration is not None and elapsed_duration >= 1.0:
+ return True
+
+ assert state.max_duration is not None # for pyright
+ # If the save interval is specified as 1dur, and the max duration is in epoch units
+ # we need a special case to identify we are on the last batch and should write the mlflow checkpoint
+ if self.save_interval.unit == TimeUnit.DURATION and self.save_interval.value == 1 and state.max_duration.unit == TimeUnit.EPOCH:
+ assert state.dataloader_len is not None # for pyright
+ return int(state.timestamp.batch) % math.ceil(
+ state.max_duration.value * state.dataloader_len) == 0
+
+ return False
+
def _save_checkpoint(self, state: State, logger: Logger):
del logger # unused
@@ -146,7 +167,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
self.huggingface_folder_name_fstr), state.run_name,
state.timestamp)
dir_context_mgr = tempfile.TemporaryDirectory(
- ) if self.upload_to_object_store else contextlib.nullcontext(
+ ) if self.remote_ud is not None else contextlib.nullcontext(
enter_result=save_dir)
with dir_context_mgr as temp_save_dir:
@@ -210,11 +231,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
log.debug('Editing MPT files for HuggingFace compatibility')
edit_files_for_hf_compatibility(temp_save_dir)
- if self.upload_to_object_store:
- assert self.remote_ud is not None
- log.info(
- f'Uploading HuggingFace formatted checkpoint to {self.backend}://{self.bucket_name}/{save_dir}'
- )
+ if self.remote_ud is not None:
+ log.info(f'Uploading HuggingFace formatted checkpoint')
for filename in os.listdir(temp_save_dir):
self.remote_ud.upload_file(
state=state,
@@ -224,8 +242,8 @@ def _save_checkpoint(self, state: State, logger: Logger):
overwrite=self.overwrite,
)
- elapsed_duration = state.get_elapsed_duration()
- if self.mlflow_registered_model_name is not None and elapsed_duration is not None and elapsed_duration >= 1.0:
+ if self.mlflow_registered_model_name and self._is_last_batch(
+ state):
components = {'model': new_model_instance}
if original_tokenizer is not None:
components['tokenizer'] = original_tokenizer
diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py
index c997c865dd..8da436b9b1 100644
--- a/llmfoundry/data/__init__.py
+++ b/llmfoundry/data/__init__.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0
from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
+from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
build_text_denoising_dataloader)
from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
@@ -18,4 +19,5 @@
'build_text_dataloader',
'NoConcatDataset',
'ConcatTokensDataset',
+ 'build_dataloader',
]
diff --git a/llmfoundry/data/dataloader.py b/llmfoundry/data/dataloader.py
new file mode 100644
index 0000000000..12741717be
--- /dev/null
+++ b/llmfoundry/data/dataloader.py
@@ -0,0 +1,44 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Dataloader builder utilities."""
+
+from composer import DataSpec
+from omegaconf import DictConfig
+from transformers import PreTrainedTokenizerBase
+
+from llmfoundry.data.denoising import build_text_denoising_dataloader
+from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
+from llmfoundry.data.text_data import build_text_dataloader
+
+
+def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
+ device_batch_size: int) -> DataSpec:
+ """Builds a dataloader from a config.
+
+ Args:
+ cfg (DictConfig): An omegaconf dictionary used to configure the loader.
+ tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use.
+ device_batch_size (int): The size of the batches (number of examples)
+ that the dataloader will produce.
+ """
+ if cfg.name == 'text':
+ return build_text_dataloader(
+ cfg,
+ tokenizer,
+ device_batch_size,
+ )
+ elif cfg.name == 'text_denoising':
+ return build_text_denoising_dataloader(
+ cfg,
+ tokenizer,
+ device_batch_size,
+ )
+ elif cfg.name == 'finetuning':
+ return build_finetuning_dataloader(
+ cfg,
+ tokenizer,
+ device_batch_size,
+ )
+ else:
+ raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
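For reference, a minimal sketch of calling the new `build_dataloader` entry point. The dataset path, split, and tokenizer below are assumptions modeled on the standard text-loader YAMLs, not values from this diff; the local MDS data would need to exist already (e.g., produced by the `scripts/data_prep/` converters).

```python
from omegaconf import OmegaConf
from transformers import AutoTokenizer

from llmfoundry.data import build_dataloader

# Hypothetical config mirroring the `train_loader` section of a text pretraining YAML.
cfg = OmegaConf.create({
    'name': 'text',
    'dataset': {
        'local': './my-copy-c4',     # assumed local MDS dataset
        'split': 'train_small',
        'shuffle': True,
        'max_seq_len': 2048,
        'shuffle_seed': 17,
    },
    'drop_last': True,
    'num_workers': 8,
})

tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
dataspec = build_dataloader(cfg, tokenizer, device_batch_size=8)
batch = next(iter(dataspec.dataloader))
print(batch['input_ids'].shape)  # expected: (8, 2048)
```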
diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index bc41945076..7d497b4efd 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -16,7 +16,7 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase
-from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.packing import BinPackCollator
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils
@@ -375,19 +375,25 @@ def build_text_denoising_dataloader(
cfg.dataset.max_seq_len (int): The maximum length of sequences
in the batch. See :class:`MixtureOfDenoisersCollator` docstring
for details.
- cfg.dataset.packing_ratio (float, optional): If provided, this invokes
+ cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
a collator wrapper that packs device_batch_size*packing_ratio
raw examples into device_batch_size packed examples. This helps
minimize padding while preserving sequence integrity.
This adds `sequence_id` to the batch, which indicates which unique
sequence each token belongs to.
+
+ If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
+ zero waste is selected.
+ In practice, this may result in > 0 waste because profiling is done on only a portion
+ of the dataset.
+
Note: Using this feature will not change device_batch_size but it
will determine the number of raw examples consumed by the dataloader
per batch. Some examples may be discarded if they do not fit when
packing.
Select packing_ratio **carefully** based on the dataset
statistics, max_seq_len, and tolerance for discarding samples!
- The packing code in `./packing.py` provides a script that can help
+ The script `scripts/misc/profile_packing.py` can help
you choose the best packing_ratio.
See :class:`StreamingTextDataset` for info on other standard config
options within `cfg.dataset`.
@@ -419,7 +425,7 @@ def build_text_denoising_dataloader(
that the dataloader will produce.
Note:
- You can run the script inside `./packing.py` to quickly test the
+ You can use the script `scripts/misc/profile_packing.py` to quickly test the
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
given a starting workload YAML.
"""
@@ -492,7 +498,7 @@ def build_text_denoising_dataloader(
raise NotImplementedError(
'On-the-fly packing is currently only supported for decoder-only formats.'
)
- collate_fn = BinPackWrapper(
+ collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
max_seq_len=cfg.dataset.max_seq_len,
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index 2dde563ac6..44d6d345f5 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -14,7 +14,7 @@
from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
-from llmfoundry.data.packing import BinPackWrapper
+from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
log = logging.getLogger(__name__)
@@ -74,20 +74,26 @@ def build_finetuning_dataloader(cfg: DictConfig,
cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow
the collator to trim padding. See :class:`Seq2SeqFinetuningCollator`
docstring for details. Default: ``False``.
- cfg.dataset.packing_ratio (float, optional): If provided, this invokes
- a collator wrapper that packs `device_batch_size*packing_ratio`
- raw examples into `device_batch_size` packed examples. This helps
+ cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
+ a collator wrapper that packs device_batch_size*packing_ratio
+ raw examples into device_batch_size packed examples. This helps
minimize padding while preserving sequence integrity.
This adds `sequence_id` to the batch, which indicates which unique
sequence each token belongs to.
+
+ If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
+ zero waste is selected.
+ In practice, this may result in > 0 waste because profiling is done on only a portion
+ of the dataset.
+
Note: Using this feature will not change device_batch_size but it
will determine the number of raw examples consumed by the dataloader
per batch. Some examples may be discarded if they do not fit when
packing.
- Select `packing_ratio` **carefully** based on the dataset
- statistics, `max_seq_len`, and tolerance for discarding samples!
- The packing code in `../packing.py` provides a script that can help
- you choose the best `packing_ratio`.
+ Select packing_ratio **carefully** based on the dataset
+ statistics, max_seq_len, and tolerance for discarding samples!
+ The script `scripts/misc/profile_packing.py` can help
+ you choose the best packing_ratio.
cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
___
See :class:`StreamingFinetuningDataset` for info on other standard config
@@ -106,7 +112,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
A pytorch dataloader
Note:
- You can run the script inside `../packing.py` to quickly test the
+        You can use the script `scripts/misc/profile_packing.py` to quickly test the
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
given a starting workload YAML.
"""
@@ -143,7 +149,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)
collate_fn, dataloader_batch_size = _build_collate_fn(
- cfg.dataset, tokenizer, device_batch_size)
+ cfg, tokenizer, device_batch_size)
dl = DataLoader(
dataset,
@@ -174,7 +180,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)
collate_fn, dataloader_batch_size = _build_collate_fn(
- cfg.dataset, tokenizer, device_batch_size)
+ cfg, tokenizer, device_batch_size)
if cfg.drop_last:
world_size = dist.get_world_size()
@@ -367,25 +373,40 @@ def _build_hf_dataset_from_remote(
def _build_collate_fn(
- dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
+ dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
-) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
+) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
+ dataset_cfg = dataloader_cfg.dataset
+ max_seq_len = dataset_cfg.max_seq_len
+
collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
- max_seq_len=dataset_cfg.max_seq_len,
+ max_seq_len=max_seq_len,
decoder_only_format=dataset_cfg.decoder_only_format,
allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
)
packing_ratio = dataset_cfg.get('packing_ratio')
+ max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep')
if packing_ratio is None:
- if dataset_cfg.get('max_leftover_bins_to_keep') is not None:
+ if max_leftover_bins_to_keep is not None:
raise ValueError(
'dataset.max_leftover_bins_to_keep has been defined, ' +\
'but dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')
return collate_fn, device_batch_size
+ if packing_ratio == 'auto':
+ packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer,
+ device_batch_size)
+
+ if isinstance(packing_ratio, str):
+ raise ValueError(
+ 'dataset.packing_ratio must be a float or "auto", but it was set to '
+ + f'{packing_ratio}.')
+
+ log.info(f'Using packing ratio {packing_ratio}')
+
if packing_ratio == 1.0:
return collate_fn, device_batch_size
elif packing_ratio < 1.0:
@@ -396,13 +417,13 @@ def _build_collate_fn(
'On-the-fly packing is currently only supported for decoder-only formats.'
)
- collate_fn = BinPackWrapper(
+ collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
- max_seq_len=dataset_cfg.max_seq_len,
+ max_seq_len=max_seq_len,
pad_token_id=tokenizer.pad_token_id,
padding_side=tokenizer.padding_side,
- max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'),
+ max_leftover_bins_to_keep=max_leftover_bins_to_keep,
)
n_examples_to_pack = int(device_batch_size * packing_ratio)
return collate_fn, n_examples_to_pack
diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py
index f2bd0239c8..6ba6ad96c8 100644
--- a/llmfoundry/data/finetuning/tasks.py
+++ b/llmfoundry/data/finetuning/tasks.py
@@ -38,6 +38,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
from typing import Any, Callable, Dict, List, Optional, Union
import datasets as hf_datasets
+from composer.utils import dist
from omegaconf import DictConfig
from streaming import StreamingDataset
from transformers import PreTrainedTokenizerBase
@@ -332,6 +333,16 @@ def build_from_hf(
preprocessing_fn = self.get_preprocessing_fn_from_str(
proto_preprocessing_fn, dataset_name)
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'
+
+ # Non local rank 0 ranks will wait here for local rank 0 to finish the data processing.
+ # Once local rank 0 is done, the datasets are all cached on disk, and all other ranks
+ # can just read them.
+ if dist.get_local_rank() != 0:
+ log.debug('Waiting for local_rank 0 to finish data prep')
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
+ pass
+
dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)
def dataset_mapper(example: Dict):
@@ -339,34 +350,59 @@ def dataset_mapper(example: Dict):
example = preprocessing_fn(example)
return _tokenize_formatted_example(example, tokenizer)
+ detected_cpu_count = os.cpu_count() or 1
+ detected_cpus_with_margin = detected_cpu_count - 8
+ num_cpus_to_use = max(1, detected_cpus_with_margin)
+
columns_to_remove = list(dataset[0].keys())
tokenized_dataset = dataset.map(
dataset_mapper,
batched=False,
remove_columns=columns_to_remove,
+ num_proc=num_cpus_to_use,
+ desc='Tokenizing dataset',
+ )
+
+ pad_token_id = tokenizer.pad_token_id
+
+ def filter_long_or_empty_examples(example: Dict) -> bool:
+ less_than_max_seq_len = len(example['input_ids']) < max_seq_len
+ non_empty_input = len(example['input_ids']) > 0
+ non_empty_labels = len(example['labels']) > 0
+ non_padding_response = any(
+ token_id != pad_token_id for token_id in example['labels'])
+ return (less_than_max_seq_len and non_empty_input and
+ non_empty_labels and non_padding_response)
+
+ filtered_dataset = tokenized_dataset.filter(
+ filter_long_or_empty_examples,
+ num_proc=num_cpus_to_use,
+ desc='Filtering out long prompts',
)
- prompt_length_filtered_dataset = tokenized_dataset.filter(
- lambda example: len(example['input_ids']) < max_seq_len)
- examples_removed = len(tokenized_dataset) - len(
- prompt_length_filtered_dataset)
+ examples_removed = len(tokenized_dataset) - len(filtered_dataset)
if examples_removed > 0:
warnings.warn(
- f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}.'
+ f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
+ +
+ 'the prompt or response was empty, or the response was all padding tokens.'
)
- empty_examples_dropped_dataset = prompt_length_filtered_dataset.filter(
- lambda example: len(example['input_ids']) > 0 and len(example[
- 'labels']) > 0 and any(token_id != tokenizer.pad_token_id
- for token_id in example['labels']))
- empty_examples_removed = len(prompt_length_filtered_dataset) - len(
- empty_examples_dropped_dataset)
- if empty_examples_removed > 0:
- warnings.warn(
- f'Dropped {empty_examples_removed} examples where the prompt or response was empty, '
- + 'or the response was only padding tokens.')
+ # Now local rank 0 indicates to the other ranks that it is done
+ if dist.get_local_rank() == 0:
+ log.debug('Local rank 0 finished data prep')
+ with open(signal_file_path, 'wb') as f:
+ f.write(b'local_rank0_completed_data_prep')
+
+ # All ranks sync up at this barrier, having completed data processing
+ dist.barrier()
+
+ # Last, local rank 0 cleans up the signal file
+ if dist.get_local_rank() == 0:
+ os.remove(signal_file_path)
- return empty_examples_dropped_dataset
+ log.debug('All ranks finished data prep')
+ return filtered_dataset
def build_from_streaming(self, *args: Any,
**kwargs: Any) -> StreamingFinetuningDataset:
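The signal-file dance added to `build_from_hf` follows a general "local rank 0 first" pattern. A generic, simplified sketch of that pattern (with a hypothetical `prepare_dataset` helper standing in for the load/tokenize/filter work) looks roughly like this:

```python
import os

from composer.utils import dist


def prepare_dataset():
    """Hypothetical stand-in for the dataset load/tokenize/filter work."""
    return ['example'] * 8


signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_data_prep_completed'

# Ranks other than local rank 0 block here until rank 0 finishes, so they can read
# the dataset from the on-disk cache instead of re-doing the preprocessing.
if dist.get_local_rank() != 0:
    with dist.local_rank_zero_download_and_wait(signal_file_path):
        pass

dataset = prepare_dataset()

# Local rank 0 signals completion, all ranks sync, then the signal file is cleaned up.
if dist.get_local_rank() == 0:
    with open(signal_file_path, 'wb') as f:
        f.write(b'local_rank0_completed_data_prep')

dist.barrier()

if dist.get_local_rank() == 0:
    os.remove(signal_file_path)
```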
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index 1532de276e..45322c9b2f 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -1,8 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
-import os
-from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
+from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple
import numpy as np
import torch
@@ -10,7 +9,7 @@
from transformers import PreTrainedTokenizerBase
-class BinPackWrapper:
+class BinPackCollator:
"""Utility collator for packing to reduce padding."""
def __init__(self,
@@ -33,13 +32,10 @@ def __init__(self,
if self.pad_token_id < 0:
raise ValueError(f'{pad_token_id=} must be >=0.')
- if max_leftover_bins_to_keep is None:
- self.max_leftover_bins_to_keep = int(10 * self.out_size)
- elif max_leftover_bins_to_keep < 0:
+ if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0:
raise ValueError(
f'{max_leftover_bins_to_keep=} must be >=0 or None.')
- else:
- self.max_leftover_bins_to_keep = int(max_leftover_bins_to_keep)
+ self.max_leftover_bins_to_keep = max_leftover_bins_to_keep
self.n_packed_tokens = 0
self.n_total_tokens = 0
@@ -60,7 +56,9 @@ def __call__(
self,
examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
batch = self.base_collator(examples)
+ return self.pack(batch)
+ def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
assert 'attention_mask' in batch
assert 'input_ids' in batch
@@ -75,12 +73,12 @@ def __call__(
# Cut everything down to size
sizes, trimmed_examples = [], []
for idx in range(batch['attention_mask'].shape[0]):
- size, trimmed_example = extract_trim_batch_idx(batch, idx)
+ size, trimmed_example = _extract_trim_batch_idx(batch, idx)
sizes.append(size)
trimmed_examples.append(trimmed_example)
# Apply our CS 101 bin packing algorithm.
- packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = first_fit_bin_packing(
+ packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(
sizes=sizes,
examples=trimmed_examples,
num_bins=self.out_size,
@@ -93,15 +91,15 @@ def __call__(
self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]
# Re-pad to max_seq_len and batch
- batch = repad(packed_examples,
- max_seq_len=self.max_seq_len,
- pad_token_id=self.pad_token_id,
- padding_side=self.padding_side)
+ batch = _repad(packed_examples,
+ max_seq_len=self.max_seq_len,
+ pad_token_id=self.pad_token_id,
+ padding_side=self.padding_side)
return batch
-def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
- idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
+def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
+ idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
example = {k: v[idx] for k, v in batch.items()}
keep = example['attention_mask'] == 1
@@ -112,7 +110,7 @@ def extract_trim_batch_idx(batch: Dict[str, torch.Tensor],
return size, trim_example
-def combine_in_place(
+def _combine_in_place(
example: Dict[str, torch.Tensor],
add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
if 'labels' in add_on:
@@ -129,7 +127,7 @@ def combine_in_place(
return example
-def first_fit_bin_packing(
+def _first_fit_bin_packing(
sizes: List[int], examples: List[Dict[str, torch.Tensor]], num_bins: int,
max_bin_size: int, existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]]
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[
@@ -194,7 +192,7 @@ def first_fit_bin_packing(
if bins[bidx][0] + size <= max_bin_size:
bin_size, packed_example = bins.pop(bidx)
bin_size = bin_size + size
- packed_example = combine_in_place(packed_example, example)
+ packed_example = _combine_in_place(packed_example, example)
bins.append((bin_size, packed_example))
added = True
break
@@ -225,8 +223,8 @@ def first_fit_bin_packing(
bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:]
-def repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
- pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
+def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int,
+ pad_token_id: int, padding_side: str) -> Dict[str, torch.Tensor]:
def pad_tensor(tensor: torch.Tensor, pad_value: int):
if len(tensor) == max_seq_len:
@@ -260,14 +258,169 @@ def pad_tensor(tensor: torch.Tensor, pad_value: int):
return batch
+def auto_packing_ratio(dataloader_cfg: DictConfig,
+ tokenizer: PreTrainedTokenizerBase,
+ device_batch_size: int,
+ num_packing_ratios: int = 20) -> float:
+ """Find a packing ratio that minimizes padding with zero waste.
+
+    By packing examples, we can increase training efficiency, training on more data with fewer batches.
+ However, in practice, the selected packing_ratio may produce some waste because profiling is done on only
+ a subset of the dataset.
+
+ We select a min_ratio of 1 and a max_ratio that is the max_seq_len / 100, and profile up to
+ num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive.
+ When a packing_ratio with non-zero waste is found, we stop and select the previous ratio,
+ which has zero waste.
+
+ Args:
+ dataloader_cfg (DictConfig): The dataloader configuration for profiling.
+ tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
+ device_batch_size (int): The size of the batches (number of examples) per device.
+        num_packing_ratios (int): The number of packing ratios to try.
+
+ Returns:
+ A packing ratio that minimizes padding while maintaining zero waste.
+ """
+ from composer.utils import dist, get_device, reproducibility
+
+ # Stash the rng state to restore later.
+ rng_state = reproducibility.get_rng_state()
+ # Set the seed so that auto packing is deterministic.
+ reproducibility.seed_all(0)
+
+ min_ratio = 1
+ max_ratio = dataloader_cfg.dataset.max_seq_len / 100
+ profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio,
+ max_ratio, num_packing_ratios,
+ device_batch_size)
+
+ # Obtain the maximum packing_ratio/minimum padding that has no waste.
+ # profiling_results are sorted from smallest to largest packing_ratio.
+ packing_ratio = 1
+ for packing_ratio_candidate, _, waste in profiling_results:
+ if waste > 0:
+ break
+ packing_ratio = packing_ratio_candidate
+
+ # Select the minimum packing ratio across all ranks.
+ if dist.is_available() and dist.is_initialized():
+ device = get_device(None)
+ packing_ratio_tensor = device.tensor_to_device(
+ torch.tensor(packing_ratio))
+ dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
+ packing_ratio = packing_ratio_tensor.item()
+
+ # Restore rng state.
+ reproducibility.load_rng_state(rng_state)
+
+ return packing_ratio
+
+
+def profile_packing(
+ dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
+ min_ratio: float, max_ratio: float, num_packing_ratios: int,
+ device_batch_size: int) -> Iterable[Tuple[float, float, float]]:
+ """Generator function that profiles example packing across packing ratios.
+
+ Args:
+ dataloader_cfg (DictConfig): The dataloader configuration for profiling.
+ tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
+ min_ratio (float): Smallest packing_ratio to test. Must be >=1.
+ max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
+ num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
+ device_batch_size (int): The size of the batches (number of examples) per device.
+
+ Returns:
+        An iterable of tuples of packing ratio, padding, and waste, sorted from smallest to largest packing ratio.
+ """
+ import copy
+
+ from llmfoundry.data.dataloader import build_dataloader
+
+ max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
+ max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
+ None)
+
+ # Turn off packing for the dataloader (we want raw, pre-packed examples)
+ dataloader_cfg = copy.deepcopy(dataloader_cfg)
+ dataloader_cfg.dataset.packing_ratio = None
+ dataloader_cfg.drop_last = False
+ dataloader_cfg.num_workers = 0
+ dataloader_cfg.prefetch_factor = None
+ dataloader_cfg.persistent_workers = False
+
+ # Determine the packing_ratio values we'll try
+ packing_ratios, raw_batch_sizes = [], []
+ for packing_ratio in np.linspace(min_ratio,
+ max_ratio,
+ num_packing_ratios,
+ endpoint=True):
+ packing_ratio = np.round(10 * packing_ratio) / 10
+ raw_batch_size = int(packing_ratio * device_batch_size)
+ if raw_batch_size not in raw_batch_sizes:
+ packing_ratios.append(packing_ratio)
+ raw_batch_sizes.append(raw_batch_size)
+
+ n_profile_examples = max(raw_batch_sizes) * 100
+
+ train_dataspec = build_dataloader(dataloader_cfg, tokenizer,
+ n_profile_examples)
+ train_dataloader = train_dataspec.dataloader
+
+ # Get a bunch of raw examples
+ big_batch = next(iter(train_dataloader))
+
+ def split_big_batch(raw_batch_size: int) -> List:
+ input_ids = big_batch['input_ids'].split(raw_batch_size)
+ batches = [{'input_ids': x} for x in input_ids]
+
+ for key in big_batch.keys():
+ if key == 'input_ids':
+ continue
+ for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
+ batches[idx].update({key: split})
+ return batches
+
+ def profile(raw_batch_size: int) -> Tuple[float, float]:
+ packer = BinPackCollator(
+ collator=lambda x: x,
+ target_batch_size=device_batch_size,
+ max_seq_len=max_seq_len,
+ pad_token_id=0, # <-- Doesn't need to be correct for profiling
+ padding_side='left', # <-- Doesn't need to be correct for profiling
+ max_leftover_bins_to_keep=max_leftovers_to_keep)
+
+ # Simulate feeding the packing collator a bunch of data
+ for batch in split_big_batch(raw_batch_size):
+ if batch['input_ids'].shape[0] < device_batch_size:
+ continue
+ _ = packer.pack(batch)
+
+ # Return the padding / waste stats over that bunch of data
+ padding_percent = 100 * (1 - packer.efficiency)
+ waste_percent = 100 * packer.waste
+ return padding_percent, waste_percent
+
+ for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
+ padding, waste = profile(raw_batch_size)
+ yield (packing_ratio, padding, waste)
+
+
if __name__ == '__main__':
+
+ import warnings
+
+ warnings.warn(
+ DeprecationWarning(
+            'Please use scripts/misc/profile_packing.py to profile packing. ' +
+ 'This script will be removed in later releases.'))
+
+ import os
from argparse import ArgumentParser, Namespace
from omegaconf import OmegaConf as om
- from llmfoundry import (build_finetuning_dataloader,
- build_text_denoising_dataloader)
- from llmfoundry.data import build_text_dataloader
from llmfoundry.utils import build_tokenizer
def parse_args() -> Namespace:
@@ -296,7 +449,7 @@ def parse_args() -> Namespace:
parser.add_argument(
'--num-packing-ratios',
type=int,
- default=10,
+ default=20,
help=
'Number of packing_ratio values (spaced between `min` and `max) to try.'
)
@@ -316,20 +469,6 @@ def parse_args() -> Namespace:
raise ValueError('`num_packing_ratios` must be a positive integer.')
return args
- def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
- device_batch_size: int):
- if cfg.name == 'text':
- return build_text_dataloader(cfg, tokenizer, device_batch_size)
- elif cfg.name == 'text_denoising':
- return build_text_denoising_dataloader(cfg, tokenizer,
- device_batch_size)
- elif cfg.name == 'finetuning':
- return build_finetuning_dataloader(cfg, tokenizer,
- device_batch_size)
- else:
- raise ValueError(
- f'Not sure how to build dataloader with config: {cfg}')
-
args = parse_args()
with open(args.yaml_path) as f:
@@ -339,26 +478,11 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
cfg = om.create(cfg)
device_batch_size = cfg.global_train_batch_size // args.num_devices
- # Determine the packing_ratio values we'll try
- packing_ratios, raw_batch_sizes = [], []
- for packing_ratio in np.linspace(args.min,
- args.max,
- args.num_packing_ratios,
- endpoint=True):
- packing_ratio = np.round(10 * packing_ratio) / 10
- raw_batch_size = int(packing_ratio * device_batch_size)
- if raw_batch_size not in raw_batch_sizes:
- packing_ratios.append(packing_ratio)
- raw_batch_sizes.append(raw_batch_size)
-
# Fetch a bunch of raw examples once, which we'll re-use
if 'train_loader' not in cfg:
raise ValueError('config must define train_loader')
dataloader_cfg = cfg.train_loader
- max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep',
- None)
-
# build tokenizer
if 'tokenizer' not in cfg:
raise ValueError('config must define tokenizer')
@@ -367,57 +491,19 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
if not isinstance(resolved_tokenizer_cfg, Dict):
raise ValueError(
'tokenizer config needs to be resolved by omegaconf into a Dict.')
- tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg
+ tokenizer_cfg = resolved_tokenizer_cfg
tokenizer_name = tokenizer_cfg['name']
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
- # Turn off packing for the dataloader (we want raw, pre-packed examples)
- dataloader_cfg.dataset.packing_ratio = None
- dataloader_cfg.dataset.max_leftovers_to_keep = None
- train_dataloader = build_dataloader(dataloader_cfg, tokenizer,
- max(raw_batch_sizes) * 100).dataloader
-
- # Get a bunch of raw examples
- big_batch = next(iter(train_dataloader))
-
- def split_big_batch(raw_batch_size: int) -> List:
- input_ids = big_batch['input_ids'].split(raw_batch_size)
- batches = [{'input_ids': x} for x in input_ids]
-
- for key in big_batch.keys():
- if key == 'input_ids':
- continue
- for idx, split in enumerate(big_batch[key].split(raw_batch_size)):
- batches[idx].update({key: split})
- return batches
-
- def profile_packing(raw_batch_size: int) -> Tuple[float, float]:
- packer = BinPackWrapper(
- collator=lambda x: x,
- target_batch_size=device_batch_size,
- max_seq_len=dataloader_cfg.dataset.max_seq_len,
- pad_token_id=0, # <-- Doesn't need to be correct for profiling
- padding_side='left', # <-- Doesn't need to be correct for profiling
- max_leftover_bins_to_keep=max_leftovers_to_keep)
-
- # Simulate feeding the packing collator a bunch of data
- for batch in split_big_batch(raw_batch_size):
- if batch['input_ids'].shape[0] < device_batch_size:
- continue
- _ = packer(batch)
-
- # Return the padding / waste stats over that bunch of data
- padding_percent = 100 * (1 - packer.efficiency)
- waste_percent = 100 * packer.waste
- return padding_percent, waste_percent
+ results = profile_packing(dataloader_cfg, tokenizer, args.min, args.max,
+ args.num_packing_ratios, device_batch_size)
header = '\n\n\n packing_ratio | % PADDING | % WASTE'
fstr = ' {:5.1f} | {:5.2f}% | {:6.2f}%'
print(header)
print('-' * len(header))
- for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
- padding, waste = profile_packing(raw_batch_size)
+ for packing_ratio, padding, waste in results:
print(fstr.format(packing_ratio, padding, waste))
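For reference, here is a minimal sketch of driving the new `profile_packing` generator directly, outside this now-deprecated script. The import path, dataset fields, and tokenizer name are illustrative assumptions, not taken from this diff:

```python
from omegaconf import OmegaConf as om

from llmfoundry.data.packing import profile_packing  # assumed location of the generator
from llmfoundry.utils import build_tokenizer

# Illustrative finetuning train_loader config; any config from the repo should work.
cfg = om.create("""
train_loader:
  name: finetuning
  dataset:
    hf_name: mosaicml/dolly_hhrlhf
    split: train
    max_seq_len: 2048
    packing_ratio: null
  drop_last: false
  num_workers: 8
""")

tokenizer = build_tokenizer('mosaicml/mpt-7b', {})
device_batch_size = 8

# The generator yields (packing_ratio, % padding, % waste) tuples.
for packing_ratio, padding, waste in profile_packing(cfg.train_loader, tokenizer,
                                                     1.0, 10.0, 20, device_batch_size):
    print(f'{packing_ratio:5.1f} | {padding:5.2f}% | {waste:6.2f}%')
```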
diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py
index 13857e9bb9..d52633a09b 100644
--- a/llmfoundry/models/hf/hf_causal_lm.py
+++ b/llmfoundry/models/hf/hf_causal_lm.py
@@ -5,6 +5,7 @@
import logging
import os
+import warnings
from typing import Mapping, Union
# required for loading a python model into composer
@@ -24,8 +25,7 @@
from llmfoundry.models.hf.hf_fsdp import hf_get_init_device
from llmfoundry.models.hf.model_wrapper import HuggingFaceModelWithZLoss
-from llmfoundry.models.layers.llama_attention_monkeypatch import \
- get_llama_attention_patch_fn
+from llmfoundry.models.layers.attention import is_flash_v2_installed
from llmfoundry.models.utils import init_empty_weights
try:
@@ -95,12 +95,28 @@ def __init__(self, om_model_config: Union[DictConfig,
# load the model config
trust_remote_code = om_model_config.get('trust_remote_code', True)
use_auth_token = om_model_config.get('use_auth_token', False)
+ use_flash_attention_2 = om_model_config.get('use_flash_attention_2',
+ False)
+ if use_flash_attention_2 and not is_flash_v2_installed():
+ raise ValueError(
+ 'use_flash_attention_2 is set to True, but flash-attention 2 is not installed. '
+ + 'Please install flash_attn==2.3.2.')
+
config = AutoConfig.from_pretrained(
om_model_config.pretrained_model_name_or_path,
trust_remote_code=trust_remote_code,
use_auth_token=use_auth_token,
)
+ # This is not how you are supposed to set this, but transformers currently only
+ # supports enabling flash attention 2 when using the from_pretrained API.
+ # We need to support it for both from_pretrained and from_config, so we have to
+ # set the private attribute here. This will just skip all of transformers'
+ # validation logic that it is ok to use flash attention 2, so we check
+ # whether it is installed above, and whether the chosen config supports it here.
+ # https://github.com/huggingface/transformers/issues/26878
+ config._flash_attn_2_enabled = use_flash_attention_2
+
# set config overrides
for k, v in om_model_config.get('config_overrides', {}).items():
if not hasattr(config, k):
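As a usage sketch (values are assumptions, not from this diff), the new flag is read straight off the model config, for example:

```python
from omegaconf import OmegaConf as om

# Keys mirror what the constructor above reads; the model name is illustrative.
model_cfg = om.create({
    'name': 'hf_causal_lm',
    'pretrained_model_name_or_path': 'meta-llama/Llama-2-7b-hf',
    'pretrained': True,
    'use_flash_attention_2': True,  # requires flash-attn v2 (see the check above)
})
```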
@@ -142,6 +158,24 @@ def __init__(self, om_model_config: Union[DictConfig,
if dist.get_local_rank() != 0 and init_device == 'mixed':
om_model_config.pretrained = False
+ # If the HuggingFace model is coming from a local folder, Hugging Face copies the modules into the
+ # transformers modules cache. On particular systems, this operation seems to cause contention between
+ # the different processes. To avoid this contention, we first create the model (on meta device) on local rank
+ # zero. This will set up the transformers model cache and avoid the future contention.
+ if dist.get_local_rank() == 0 and os.path.isdir(
+ om_model_config.pretrained_model_name_or_path):
+ with init_empty_weights(include_buffers=False):
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', UserWarning)
+ AutoModelForCausalLM.from_pretrained(
+ om_model_config.pretrained_model_name_or_path,
+ trust_remote_code=trust_remote_code,
+ use_auth_token=use_auth_token,
+ config=config,
+ )
+
+ dist.barrier()
+
# initialize the model on the correct device
if resolved_init_device == 'cpu':
if om_model_config.pretrained:
@@ -200,6 +234,9 @@ def __init__(self, om_model_config: Union[DictConfig,
)
from transformers.models.llama.modeling_llama import \
LlamaAttention
+
+ from llmfoundry.models.layers.llama_attention_monkeypatch import \
+ get_llama_attention_patch_fn
LlamaAttention.forward = get_llama_attention_patch_fn(
attention_patch_type)
model.config.use_cache = False
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 39fa7162ac..0503d6d75a 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -5,7 +5,7 @@
import math
import warnings
-from typing import Any, List, Optional, Tuple
+from typing import Any, Optional
import torch
import torch.nn as nn
@@ -17,12 +17,13 @@
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
-def is_flash_v2_installed():
+def is_flash_v2_installed(v2_version: str = '2.0.0'):
+ assert version.parse(v2_version) >= version.parse('2.0.0')
try:
import flash_attn as flash_attn
except:
return False
- return version.parse(flash_attn.__version__) >= version.parse('2.0.0')
+ return version.parse(flash_attn.__version__) >= version.parse(v2_version)
def is_flash_v1_installed():
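A small sketch of how the new version argument can be used by callers, mirroring the rope validation added later in this PR:

```python
from llmfoundry.models.layers.attention import is_flash_v2_installed

# Gate a code path on a minimum flash-attn 2.x version.
if not is_flash_v2_installed(v2_version='2.0.1'):
    raise ImportError('flash-attn>=2.0.1 is required for this code path.')
```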
@@ -33,6 +34,16 @@ def is_flash_v1_installed():
return version.parse(flash_attn.__version__) < version.parse('2.0.0')
+# Before importing any transformers models, we need to disable transformers flash attention if
+# we are in an environment with flash attention version <2. Transformers hard errors on a not properly
+# gated import otherwise.
+if is_flash_v1_installed():
+ import transformers
+ transformers.utils.is_flash_attn_available = lambda: False
+
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
+
+
def _reset_is_causal(num_query_tokens: int, num_key_tokens: int,
original_is_causal: bool) -> bool:
# disable causal when it is not needed
@@ -70,7 +81,7 @@ def scaled_multihead_dot_product_attention(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
- past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
@@ -79,7 +90,7 @@ def scaled_multihead_dot_product_attention(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
if multiquery:
@@ -183,7 +194,7 @@ def scaled_multihead_dot_product_attention(
def check_valid_inputs(*tensors: torch.Tensor,
- valid_dtypes: Optional[List[torch.dtype]] = None):
+ valid_dtypes: Optional[list[torch.dtype]] = None):
if valid_dtypes is None:
valid_dtypes = [torch.float16, torch.bfloat16]
for tensor in tensors:
@@ -199,7 +210,7 @@ def flash_attn_fn(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
- past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
@@ -208,7 +219,7 @@ def flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from flash_attn import bert_padding, flash_attn_interface # type: ignore # yapf: disable # isort: skip
@@ -337,7 +348,7 @@ def triton_flash_attn_fn(
value: torch.Tensor,
n_heads: int,
kv_n_heads: Optional[int] = None,
- past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
softmax_scale: Optional[float] = None,
attn_bias: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None,
@@ -346,7 +357,7 @@ def triton_flash_attn_fn(
training: bool = False,
needs_weights: bool = False,
multiquery: bool = False,
-) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor,
+) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor,
torch.Tensor]]]:
try:
from llmfoundry.models.layers.flash_attn_triton import flash_attn_func
@@ -552,12 +563,13 @@ def __init__(
def forward(
self,
x: torch.Tensor,
- past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+ past_key_value: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
+ rotary_emb_w_meta_info: Optional[dict] = None,
is_causal: bool = True,
needs_weights: bool = False,
- ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[
torch.Tensor, torch.Tensor]]]:
qkv = self.Wqkv(x)
@@ -581,6 +593,39 @@ def forward(
query = self.q_ln(query).to(dtype)
key = self.k_ln(key).to(dtype)
+ if rotary_emb_w_meta_info is not None:
+ rotary_emb = rotary_emb_w_meta_info['rotary_emb']
+ seq_len = rotary_emb_w_meta_info['seq_len']
+ offset_info = rotary_emb_w_meta_info['offset_info']
+ bsz, seqlen = query.shape[:2]
+ query = query.view(bsz, seqlen, -1, self.head_dim)
+ key = key.view(bsz, seqlen, -1, self.head_dim)
+
+ if rotary_emb_w_meta_info['impl'] == 'dail':
+ value = value.view(bsz, seqlen, -1, self.head_dim)
+
+ kv = torch.stack([key, value], dim=2)
+ query, kv = rotary_emb(query,
+ kv,
+ seqlen_offset=offset_info,
+ max_seqlen=seq_len)
+ [key, value] = torch.unbind(kv, dim=2)
+
+ value = value.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
+ elif rotary_emb_w_meta_info['impl'] == 'hf':
+ (cos, sin) = rotary_emb(value, seq_len)
+ # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+ query, key = apply_rotary_pos_emb(query, key, cos, sin,
+ offset_info)
+ # The following two transposes should be removed once the transformers library allows for the specification of the dimension for heads in the call to apply_rotary_pos_emb
+ query = query.transpose(1, 2)
+ key = key.transpose(1, 2)
+
+ query = query.view(bsz, seqlen, self.d_model)
+ key = key.view(bsz, seqlen, self.kv_n_heads * self.head_dim)
+
context, attn_weights, past_key_value = self.attn_fn(
query,
key,
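For clarity, a sketch of the metadata dict the new `rotary_emb_w_meta_info` argument expects, in the 'hf' layout that `MPTModel.forward` builds later in this PR. The dimensions and the transformers constructor used here are assumptions tied to the transformers version pinned at the time:

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding

head_dim, max_seq_len = 64, 2048   # d_model // n_heads, model context length
S, past_position = 16, 0           # current tokens and cached tokens

rotary_emb_w_meta_info = {
    'impl': 'hf',  # or 'dail' (offset_info is then an int offset instead of position ids)
    'rotary_emb': LlamaRotaryEmbedding(head_dim, max_position_embeddings=max_seq_len),
    'offset_info': torch.arange(past_position, S + past_position).unsqueeze(0),
    'seq_len': S + past_position,
}
```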
@@ -677,7 +722,7 @@ def __init__(
def attn_bias_shape(
attn_impl: str, n_heads: int, seq_len: int, alibi: bool,
prefix_lm: bool, causal: bool,
- use_sequence_id: bool) -> Optional[Tuple[int, int, int, int]]:
+ use_sequence_id: bool) -> Optional[tuple[int, int, int, int]]:
if attn_impl == 'flash':
return None
elif attn_impl in ['torch', 'triton']:
diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py
index a08ef6d77f..6605807c6b 100644
--- a/llmfoundry/models/layers/blocks.py
+++ b/llmfoundry/models/layers/blocks.py
@@ -12,6 +12,31 @@
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, build_ffn
from llmfoundry.models.layers.norm import NORM_CLASS_REGISTRY
+attn_config_defaults: Dict = {
+ 'attn_type': 'multihead_attention',
+ 'attn_pdrop': 0.0,
+ 'attn_impl': 'triton',
+ 'qk_ln': False,
+ 'clip_qkv': None,
+ 'softmax_scale': None,
+ 'prefix_lm': False,
+ 'attn_uses_sequence_id': False,
+ 'alibi': False,
+ 'alibi_bias_max': 8,
+ 'rope': False,
+ 'rope_theta': 10000,
+ 'rope_impl': 'dail',
+ 'rope_dail_config': {
+ 'type': 'original',
+ 'pos_idx_in_fp32': True,
+ 'xpos_scale_base': 512,
+ },
+ 'rope_hf_config': {
+ 'type': 'no_scaling',
+ 'factor': 1.0,
+ },
+}
+
class MPTBlock(nn.Module):
@@ -30,18 +55,7 @@ def __init__(
**kwargs: Any,
):
if attn_config is None:
- attn_config = {
- 'attn_type': 'multihead_attention',
- 'attn_pdrop': 0.0,
- 'attn_impl': 'triton',
- 'qk_ln': False,
- 'clip_qkv': None,
- 'softmax_scale': None,
- 'prefix_lm': False,
- 'attn_uses_sequence_id': False,
- 'alibi': False,
- 'alibi_bias_max': 8,
- }
+ attn_config = attn_config_defaults
if ffn_config is None:
ffn_config = {
@@ -58,7 +72,8 @@ def __init__(
# necessary to avoid passing extraneous args into attn_class while allowing the use of **kwargs
args_to_exclude_in_attn_class = {
'attn_type', 'prefix_lm', 'alibi', 'attn_uses_sequence_id',
- 'alibi_bias_max'
+ 'alibi_bias_max', 'rope', 'rope_theta', 'rope_impl',
+ 'rope_dail_config', 'rope_hf_config'
}
attn_config_subset_for_attn_class = {
k: v
@@ -94,6 +109,7 @@ def forward(
x: torch.Tensor,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attn_bias: Optional[torch.Tensor] = None,
+ rotary_emb_w_meta_info: Optional[Dict] = None,
attention_mask: Optional[torch.ByteTensor] = None,
is_causal: bool = True,
output_attentions: bool = False,
@@ -104,6 +120,7 @@ def forward(
a,
past_key_value=past_key_value,
attn_bias=attn_bias,
+ rotary_emb_w_meta_info=rotary_emb_w_meta_info,
attention_mask=attention_mask,
is_causal=is_causal,
needs_weights=output_attentions,
diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py
index 88f61e3fef..9ceeb0747e 100644
--- a/llmfoundry/models/layers/llama_attention_monkeypatch.py
+++ b/llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -78,6 +78,8 @@ def llama_attention_patch_torch(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
+ padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
@@ -186,6 +188,8 @@ def llama_attention_patch_triton(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
+ # Temporary fix for llama2 transformers compatibility, padding_mask will be deprecated in the next transformers release after 4.34.1.
+ padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if use_cache:
raise NotImplementedError(
diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py
index 251e4f5caf..c4ca68d733 100644
--- a/llmfoundry/models/mpt/configuration_mpt.py
+++ b/llmfoundry/models/mpt/configuration_mpt.py
@@ -8,18 +8,16 @@
from transformers import PretrainedConfig
-attn_config_defaults: Dict = {
- 'attn_type': 'multihead_attention',
- 'attn_pdrop': 0.0,
- 'attn_impl': 'triton',
- 'qk_ln': False,
- 'clip_qkv': None,
- 'softmax_scale': None,
- 'prefix_lm': False,
- 'attn_uses_sequence_id': False,
- 'alibi': False,
- 'alibi_bias_max': 8,
-}
+from llmfoundry.models.layers.attention import is_flash_v2_installed
+from llmfoundry.models.layers.blocks import attn_config_defaults
+
+# NOTE: All utils are imported directly even if unused so that
+# HuggingFace can detect all the needed files to copy into its modules folder.
+# Otherwise, certain modules are missing.
+# isort: off
+from llmfoundry.models.layers.fc import FC_CLASS_REGISTRY # type: ignore (see note)
+from llmfoundry.models.layers.norm import LPLayerNorm # type: ignore (see note)
+from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY # type: ignore (see note)
ffn_config_defaults: Dict = {
'ffn_type': 'mptmlp',
@@ -94,6 +92,16 @@ def __init__(
Defaults to ``False`` meaning any provided `sequence_id` will be ignored.
alibi (bool): Whether to use the alibi bias instead of position embeddings.
alibi_bias_max (int): The maximum value of the alibi bias.
+ rope (bool): Whether to use rotary positional embeddings.
+ rope_theta (int): The base frequency for rope.
+ rope_impl (str): The implementation of rope to use. One of 'hf' (to use the implementation from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py) or 'dail' (to use the implementation from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py).
+ rope_dail_config (Dict): The configuration for the dail implementation of rope.
+ type (str): The type of rotary position embedding to use. Options: 'original' (for https://arxiv.org/pdf/2104.09864.pdf), 'xpos' (for https://arxiv.org/pdf/2212.10554.pdf).
+ pos_idx_in_fp32 (bool): If True, the position indices [0, ..., seqlen - 1] are in fp32, otherwise they might be in lower precision. A consequence could be, for example, that bf16 rounds position 1995 to 2000, which leads to them having the same positional embedding.
+ xpos_scale_base (float): The scale base for XPos (if using XPos).
+ rope_hf_config (Dict): A dictionary used to configure rope's scaling behavior (when scaling beyond the training length).
+ type (str): Can be one of 'no_scaling', 'linear', or 'dynamic'. 'no_scaling' uses the default implementation for rotary embeddings, 'linear' uses linear scaling as proposed by the Reddit user /u/kaiokendev, and 'dynamic' uses Dynamic NTK scaling as proposed by the Reddit users /u/bloc97 and /u/emozilla.
+ factor (float): Scaling factor to use if using 'linear' or 'dynamic' as rope_scaling.type.
kv_n_heads (Optional[int]): For grouped_query_attention only, allow user to specify number of kv heads.
ffn_config (Dict): A dictionary used to configure the model's ffn module:
ffn_type (str): type of ffn to use. Options: mptmlp, te_ln_mlp
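A hedged configuration sketch (sizes and choices are illustrative) showing the new rope keys in use; any unspecified sub-keys fall back to the defaults via the recursive merge added below:

```python
from llmfoundry.models.mpt import MPTConfig

cfg = MPTConfig(
    d_model=768,
    n_heads=12,
    n_layers=12,
    max_seq_len=2048,
    attn_config={
        'attn_impl': 'triton',
        'rope': True,
        'rope_theta': 10000,
        'rope_impl': 'hf',  # 'dail' additionally requires flash-attn >= 2.0.1
        'rope_hf_config': {'type': 'linear', 'factor': 2.0},
    },
)
```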
@@ -150,10 +158,12 @@ def __init__(
del kwargs['name']
if 'loss_fn' in kwargs:
del kwargs['loss_fn']
- if self.attn_config.get('alibi', False):
+ if self.attn_config.get('alibi', False) or self.attn_config.get(
+ 'rope', False):
self.learned_pos_emb = False
warnings.warn(
- f'alibi is turned on, setting `learned_pos_emb` to `False.`')
+ f'alibi or rope is turned on, setting `learned_pos_emb` to `False`.'
+ )
super().__init__(**kwargs)
self._validate_config()
@@ -164,6 +174,10 @@ def _set_config_defaults(self, config: Dict[str, Any],
for k, v in config_defaults.items():
if k not in config:
config[k] = v
+ elif isinstance(v, dict):
+ # recursively set default values for any sub-dicts
+ config[k] = self._set_config_defaults(
+ config[k] if (config[k] is not None) else {}, v)
return config
def _validate_config(self) -> None:
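To make the recursion concrete, here is a standalone sketch of the same merge behavior (a plain-function mirror of `_set_config_defaults`, for illustration only):

```python
from typing import Any, Dict


def set_config_defaults(config: Dict[str, Any],
                        config_defaults: Dict[str, Any]) -> Dict[str, Any]:
    for k, v in config_defaults.items():
        if k not in config:
            config[k] = v
        elif isinstance(v, dict):
            # Recurse so partially specified sub-dicts still pick up defaults.
            config[k] = set_config_defaults(
                config[k] if config[k] is not None else {}, v)
    return config


defaults = {'rope': False, 'rope_hf_config': {'type': 'no_scaling', 'factor': 1.0}}
user_cfg = {'rope': True, 'rope_hf_config': {'type': 'linear'}}
print(set_config_defaults(user_cfg, defaults))
# {'rope': True, 'rope_hf_config': {'type': 'linear', 'factor': 1.0}}
```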
@@ -206,6 +220,31 @@ def _validate_config(self) -> None:
raise NotImplementedError(
'attn_uses_sequence_id only implemented with torch and triton attention.'
)
+ if self.attn_config['rope'] and (self.attn_config['rope_impl']
+ not in ['dail', 'hf']):
+ raise ValueError(
+ 'If rope is being used, then rope_impl should be either "dail" or "hf".'
+ )
+ if self.attn_config['rope'] and (
+ self.attn_config['rope_impl']
+ == 'hf') and self.attn_config['rope_hf_config']['type'] not in [
+ 'no_scaling', 'linear', 'dynamic'
+ ]:
+ raise ValueError(
+ 'If using hf implementation of rope, the type should be one of "no_scaling", "linear" or "dynamic".'
+ )
+ if self.attn_config['rope'] and (self.attn_config['rope_impl']
+ == 'dail'):
+ if self.attn_config['rope_dail_config']['type'] not in [
+ 'original', 'xpos'
+ ]:
+ raise ValueError(
+ 'If using the dail implementation of rope, the type should be one of "original" or "xpos".'
+ )
+ if not is_flash_v2_installed(v2_version='2.0.1'):
+ raise ImportError(
+ 'If using the dail implementation of rope, the flash_attn library v2.0.1 or higher must be installed. Please check the instructions at https://github.com/mosaicml/llm-foundry/blob/main/TUTORIAL.md#what-kinds-of-positional-embeddings-does-llm-foundry-support'
+ )
if self.embedding_fraction > 1 or self.embedding_fraction <= 0:
raise ValueError(
'model.embedding_fraction must be between 0 (exclusive) and 1 (inclusive)!'
@@ -217,9 +256,10 @@ def _validate_config(self) -> None:
)
if self.init_config.get('name', None) is None:
raise ValueError(f"{self.init_config=} 'name' needs to be set.")
- if not self.learned_pos_emb and not self.attn_config['alibi']:
+ if not (self.learned_pos_emb or self.attn_config['alibi'] or
+ self.attn_config['rope']):
warnings.warn(
- f'Positional information not being provided to the model using either learned_pos_emb or alibi.'
+ f'Positional information not being provided to the model using learned_pos_emb, alibi, or rope.'
)
if self.fc_type == 'te' or self.ffn_config['ffn_type'] == 'te_ln_mlp':
try:
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 4f4581b177..0cb3ebd56c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -23,11 +23,27 @@
from composer.metrics.nlp import LanguageCrossEntropy, LanguagePerplexity
from composer.models import HuggingFaceModel
from composer.utils import dist
+
+from llmfoundry.models.layers.attention import is_flash_v2_installed
+
+if is_flash_v2_installed():
+ try: # This try...except is needed because transformers requires it despite the 'if' statement above
+ from flash_attn.layers.rotary import \
+ RotaryEmbedding as DAILRotaryEmbedding
+ except Exception as e:
+ raise e
+
from omegaconf import DictConfig
from omegaconf import OmegaConf as om
from transformers import PreTrainedModel, PreTrainedTokenizerBase
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
+from transformers.models.llama.modeling_llama import \
+ LlamaDynamicNTKScalingRotaryEmbedding as HFDynamicNTKScalingRotaryEmbedding
+from transformers.models.llama.modeling_llama import \
+ LlamaLinearScalingRotaryEmbedding as HFLinearScalingRotaryEmbedding
+from transformers.models.llama.modeling_llama import \
+ LlamaRotaryEmbedding as HFRotaryEmbedding
from llmfoundry.models.layers.attention import attn_bias_shape, build_attn_bias
from llmfoundry.models.layers.blocks import MPTBlock
@@ -70,6 +86,50 @@
log = logging.getLogger(__name__)
+def gen_rotary_embedding(rope_head_dim: int, rope_impl: str, rope_theta: int,
+ rope_dail_config: dict, rope_hf_config: dict,
+ max_seq_len: int):
+ if rope_impl == 'dail':
+ return DAILRotaryEmbedding(
+ dim=rope_head_dim,
+ base=rope_theta,
+ interleaved=False,
+ scale_base=rope_dail_config['xpos_scale_base'] if
+ (rope_dail_config['type'] == 'xpos') else None,
+ pos_idx_in_fp32=rope_dail_config['pos_idx_in_fp32'],
+ device=
+ 'cpu', # FSDP does not materialize modules with meta buffers, hence device is set to cpu
+ )
+ elif rope_impl == 'hf':
+ if rope_hf_config['type'] == 'no_scaling':
+ return HFRotaryEmbedding(
+ rope_head_dim,
+ max_position_embeddings=max_seq_len,
+ base=rope_theta,
+ device=
+ 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu
+ )
+ elif rope_hf_config['type'] == 'linear':
+ return HFLinearScalingRotaryEmbedding(
+ rope_head_dim,
+ max_position_embeddings=max_seq_len,
+ base=rope_theta,
+ scaling_factor=rope_hf_config['factor'],
+ device=
+ 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu
+ )
+ elif rope_hf_config['type'] == 'dynamic':
+ return HFDynamicNTKScalingRotaryEmbedding(
+ rope_head_dim,
+ max_position_embeddings=max_seq_len,
+ base=rope_theta,
+ scaling_factor=rope_hf_config['factor'],
+ device=
+ 'cpu' # FSDP does not materialize modules with meta buffers, hence device is set to cpu
+ )
+ raise ValueError('rope_impl needs to be either dail or hf')
+
+
class MPTPreTrainedModel(PreTrainedModel):
config_class = MPTConfig
base_model_prefix = 'model'
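A usage sketch of the new helper; the arguments mirror the call `MPTModel.__init__` makes below, and the sizes are illustrative:

```python
from llmfoundry.models.mpt.modeling_mpt import gen_rotary_embedding

rotary_embedding = gen_rotary_embedding(
    rope_head_dim=4096 // 32,  # d_model // n_heads
    rope_impl='hf',
    rope_theta=10000,
    rope_dail_config={
        'type': 'original',
        'pos_idx_in_fp32': True,
        'xpos_scale_base': 512,
    },
    rope_hf_config={'type': 'no_scaling', 'factor': 1.0},
    max_seq_len=2048,
)
```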
@@ -123,6 +183,18 @@ def __init__(self, config: MPTConfig):
])
self.norm_f = norm_class(config.d_model, device=config.init_device)
+ self.rope = config.attn_config['rope']
+ self.rope_impl = None
+ if self.rope:
+ self.rope_impl = config.attn_config['rope_impl']
+ self.rotary_embedding = gen_rotary_embedding(
+ rope_head_dim=config.d_model // config.n_heads,
+ rope_impl=self.rope_impl,
+ rope_theta=config.attn_config['rope_theta'],
+ rope_dail_config=config.attn_config['rope_dail_config'],
+ rope_hf_config=config.attn_config['rope_hf_config'],
+ max_seq_len=self.config.max_seq_len)
+
if config.init_device != 'meta':
log.info(
f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.'
@@ -361,8 +433,9 @@ def forward(
S <= self.config.max_seq_len
), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'
- tok_emb = self.wte(input_ids)
- if self.learned_pos_emb:
+ rotary_emb_w_meta_info = None
+ x = self.wte(input_ids)
+ if self.learned_pos_emb or self.rope:
past_position = 0
if past_key_values is not None:
if len(past_key_values) != self.config.n_layers:
@@ -378,31 +451,44 @@ def forward(
if self.attn_impl == 'torch':
past_position = past_key_values[0][0].size(3)
- if S + past_position > self.config.max_seq_len:
+ if self.learned_pos_emb and (S + past_position >
+ self.config.max_seq_len):
raise ValueError(
f'Cannot forward input with past sequence length {past_position} and current sequence length '
+
f'{S + 1}, this model only supports total sequence length <= {self.config.max_seq_len}.'
)
- pos = torch.arange(
- past_position,
- S + past_position,
- dtype=torch.long,
- device=input_ids.device,
- ).unsqueeze(0)
- if attention_mask is not None:
- # adjust the position indices to account for padding tokens
- pos = torch.clamp(
- pos - torch.cumsum((~attention_mask).to(torch.int32),
- dim=1)[:, past_position:],
- min=0,
- )
- pos_emb = self.wpe(pos)
- x = tok_emb + pos_emb
- else:
- # ALiBi and NoPE use this path (RoPE will also use this path if / when enabled)
- x = tok_emb
+ if self.learned_pos_emb or (self.rope and self.rope_impl == 'hf'):
+ pos = torch.arange(
+ past_position,
+ S + past_position,
+ dtype=torch.long,
+ device=input_ids.device,
+ ).unsqueeze(0)
+ if attention_mask is not None:
+ # adjust the position indices to account for padding tokens
+ pos = torch.clamp(
+ pos - torch.cumsum((~attention_mask).to(torch.int32),
+ dim=1)[:, past_position:],
+ min=0,
+ )
+ if self.learned_pos_emb:
+ x = x + self.wpe(pos)
+ elif self.rope and self.rope_impl == 'hf':
+ rotary_emb_w_meta_info = {
+ 'impl': self.rope_impl,
+ 'rotary_emb': self.rotary_embedding,
+ 'offset_info': pos,
+ 'seq_len': S + past_position,
+ }
+ elif self.rope and self.rope_impl == 'dail':
+ rotary_emb_w_meta_info = {
+ 'impl': self.rope_impl,
+ 'rotary_emb': self.rotary_embedding,
+ 'offset_info': past_position,
+ 'seq_len': S + past_position,
+ }
if self.embedding_fraction == 1:
x = self.emb_drop(x)
@@ -439,6 +525,7 @@ def forward(
x,
past_key_value=past_key_value,
attn_bias=attn_bias,
+ rotary_emb_w_meta_info=rotary_emb_w_meta_info,
attention_mask=attention_mask,
is_causal=self.is_causal,
output_attentions=bool(output_attentions),
diff --git a/llmfoundry/models/utils/hf_prefixlm_converter.py b/llmfoundry/models/utils/hf_prefixlm_converter.py
index fb9477d909..692fab94c2 100644
--- a/llmfoundry/models/utils/hf_prefixlm_converter.py
+++ b/llmfoundry/models/utils/hf_prefixlm_converter.py
@@ -10,31 +10,14 @@
and treat the input prompt as the prefix in `generate`.
"""
-import math
-import warnings
from types import MethodType
from typing import Any, List, MutableMapping, Optional, Tuple, Union
import torch
-from transformers.models.bloom.modeling_bloom import (
- BaseModelOutputWithPastAndCrossAttentions, BloomForCausalLM, BloomModel,
- CausalLMOutputWithCrossAttentions, CrossEntropyLoss)
-from transformers.models.bloom.modeling_bloom import \
- _expand_mask as _expand_mask_bloom
-from transformers.models.bloom.modeling_bloom import \
- _make_causal_mask as _make_causal_mask_bloom
-from transformers.models.bloom.modeling_bloom import logging
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.gpt_neo.modeling_gpt_neo import GPTNeoForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import GPTNeoXForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM
-from transformers.models.opt.modeling_opt import OPTForCausalLM
-from transformers.models.opt.modeling_opt import \
- _expand_mask as _expand_mask_opt
-from transformers.models.opt.modeling_opt import \
- _make_causal_mask as _make_causal_mask_opt
-
-logger = logging.get_logger(__name__)
_SUPPORTED_GPT_MODELS = (
GPT2LMHeadModel,
@@ -223,583 +206,10 @@ def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
return model
-def _convert_bloom_causal_lm_to_prefix_lm(
- model: BloomForCausalLM) -> BloomForCausalLM:
- """Converts a BLOOM Causal LM to a Prefix LM.
-
- Supported HuggingFace model classes:
- - `BloomForCausalLM`
-
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
- """
- if hasattr(model, '_prefix_lm_converted'):
- return model
-
- assert isinstance(model, BloomForCausalLM)
- assert model.config.add_cross_attention == False, 'Only supports BLOOM decoder-only models'
-
- # Modified from transformers.models.bloom.modeling_bloom.BloomModel._prepare_attn_mask
- # https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/models/bloom/modeling_bloom.py#L648
- def _prepare_attn_mask(
- self: BloomModel,
- attention_mask: torch.Tensor,
- bidirectional_mask: Optional[torch.Tensor],
- input_shape: Tuple[int, int],
- past_key_values_length: int,
- ) -> torch.BoolTensor:
- # create causal mask
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
- combined_attention_mask = None
- device = attention_mask.device
- _, src_length = input_shape
-
- if src_length > 1:
- combined_attention_mask = _make_causal_mask_bloom(
- input_shape,
- device=device,
- past_key_values_length=past_key_values_length)
- # Make use of the batch-specific `bidirectional_mask` attribute set
- # by the parent module in its (new) `forward` method wrapper
- if bidirectional_mask is not None:
- # The two masks should have the same size
- assert attention_mask.shape == bidirectional_mask.shape
-
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
- expanded_bidirectional_mask = _expand_mask_bloom(
- bidirectional_mask, tgt_length=src_length)
- combined_attention_mask = torch.logical_and(
- combined_attention_mask, expanded_bidirectional_mask)
-
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
- expanded_attn_mask = _expand_mask_bloom(attention_mask,
- tgt_length=src_length)
- combined_attention_mask = (expanded_attn_mask
- if combined_attention_mask is None else
- expanded_attn_mask | combined_attention_mask)
-
- return combined_attention_mask
-
- # Modified from transformers.models.bloom.modeling_bloom._prepare_alibi_transformer
- # https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/models/bloom/modeling_bloom.py#L87
- def _build_alibi_tensor(
- self: BloomModel,
- batch_size: int,
- query_length: int,
- key_length: int,
- dtype: torch.dtype,
- device: torch.device,
- ) -> torch.Tensor:
- num_heads = self.config.n_head
-
- closest_power_of_2 = 2**math.floor(math.log2(num_heads))
- base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
- device=device,
- dtype=torch.float32)
- powers = torch.arange(1,
- 1 + closest_power_of_2,
- device=device,
- dtype=torch.int32)
- slopes = torch.pow(base, powers)
-
- if closest_power_of_2 != num_heads:
- extra_base = torch.tensor(
- 2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
- device=device,
- dtype=torch.float32)
- num_remaining_heads = min(closest_power_of_2,
- num_heads - closest_power_of_2)
- extra_powers = torch.arange(1,
- 1 + 2 * num_remaining_heads,
- 2,
- device=device,
- dtype=torch.int32)
- slopes = torch.cat(
- [slopes, torch.pow(extra_base, extra_powers)], dim=0)
-
- qa = torch.arange(query_length, device=device,
- dtype=torch.int32).view(-1, 1)
- ka = torch.arange(key_length, device=device,
- dtype=torch.int32).view(1, -1)
- diffs = qa - ka + key_length - query_length
- diffs = -diffs.abs()
- alibi = slopes.view(1, num_heads, 1, 1) * diffs.view(
- 1, 1, query_length, key_length)
- alibi = alibi.expand(batch_size, -1, -1,
- -1).reshape(-1, query_length, key_length)
- return alibi.to(dtype)
-
- # Modified from transformers.models.bloom.modeling_bloom.BloomModel.forward
- # Note: The modified code is surrounded with #### START/END #### comments
- # and one new argument (`bidirectional_mask`) is added to the signature.
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
-
- def transformer_forward(
- self: BloomModel,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- bidirectional_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.LongTensor] = None,
- inputs_embeds: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **deprecated_arguments: Any
- ) -> Union[Tuple[torch.Tensor, ...],
- BaseModelOutputWithPastAndCrossAttentions]:
- if deprecated_arguments.pop('position_ids', False) is not False:
- # `position_ids` could have been `torch.Tensor` or `None` so
- # defaulting pop to `False` allows to detect if users were
- # passing explicitly `None`
- warnings.warn(
- '`position_ids` have no functionality in BLOOM and will be removed in v5.0.0. ' +\
- 'You can safely ignore passing `position_ids`.',
- FutureWarning,
- )
- if len(deprecated_arguments) > 0:
- raise ValueError(
- f'Got unexpected arguments: {deprecated_arguments}')
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (output_hidden_states
- if output_hidden_states is not None else
- self.config.output_hidden_states)
- use_cache = use_cache if use_cache is not None else self.config.use_cache
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- if input_ids is not None and inputs_embeds is not None:
- raise ValueError(
- 'You cannot specify both input_ids and inputs_embeds at the same time'
- )
- elif input_ids is not None:
- batch_size, seq_length = input_ids.shape
- elif inputs_embeds is not None:
- batch_size, seq_length, _ = inputs_embeds.shape
- else:
- raise ValueError(
- 'You have to specify either input_ids or inputs_embeds')
-
- if past_key_values is None:
- past_key_values = tuple([None] * len(self.h)) # type: ignore
-
- # Prepare head mask if needed
- # 1.0 in head_mask indicate we keep the head
- # attention_probs has shape batch_size x num_heads x N x N
- # head_mask has shape n_layer x batch x num_heads x N x N
- head_mask = self.get_head_mask(head_mask, self.config.n_layer)
-
- if inputs_embeds is None:
- inputs_embeds = self.word_embeddings(input_ids)
-
- hidden_states = self.word_embeddings_layernorm(inputs_embeds)
-
- presents = () if use_cache else None
- all_self_attentions = () if output_attentions else None
- all_hidden_states = () if output_hidden_states else None
-
- # Compute alibi tensor: check build_alibi_tensor documentation
- seq_length_with_past = seq_length
- past_key_values_length = 0
- if past_key_values[0] is not None: # type: ignore
- tmp = past_key_values[0][0] # type: ignore
- past_key_values_length = tmp.shape[2] # type: ignore
- seq_length_with_past = seq_length_with_past + past_key_values_length
- if attention_mask is None:
- attention_mask = torch.ones((batch_size, seq_length_with_past),
- device=hidden_states.device)
- else:
- attention_mask = attention_mask.to(hidden_states.device)
-
- ##### ALL NON-SIGNATURE MODIFICATIONS ARE CONTAINED TO THIS BLOCK [STARTS HERE] #####
- alibi = self._build_alibi_tensor(
- batch_size=batch_size,
- query_length=seq_length,
- key_length=seq_length_with_past,
- dtype=hidden_states.dtype,
- device=hidden_states.device,
- )
-
- causal_mask = self._prepare_attn_mask(
- attention_mask,
- bidirectional_mask,
- input_shape=(batch_size, seq_length),
- past_key_values_length=past_key_values_length,
- )
- ##### ALL NON-SIGNATURE MODIFICATIONS ARE CONTAINED TO THIS BLOCK [ENDS HERE] #####
-
- for i, (block,
- layer_past) in enumerate(zip(self.h,
- past_key_values)): # type: ignore
-
- if output_hidden_states:
- hst = (hidden_states,)
- all_hidden_states = all_hidden_states + hst # type: ignore
-
- if self.gradient_checkpointing and self.training:
-
- if use_cache:
- logger.warning(
- '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
- )
- use_cache = False
-
- def create_custom_forward(module: torch.nn.Module):
-
- def custom_forward(*inputs: Any):
- # None for past_key_value
- return module(*inputs,
- use_cache=use_cache,
- output_attentions=output_attentions)
-
- return custom_forward
-
- outputs = torch.utils.checkpoint.checkpoint( # type: ignore
- create_custom_forward(block),
- hidden_states,
- alibi,
- causal_mask,
- head_mask[i], # type: ignore
- )
- else:
- outputs = block(
- hidden_states,
- layer_past=layer_past,
- attention_mask=causal_mask,
- head_mask=head_mask[i], # type: ignore
- use_cache=use_cache,
- output_attentions=output_attentions,
- alibi=alibi,
- )
-
- hidden_states = outputs[0]
- if use_cache is True:
- presents = presents + (outputs[1],) # type: ignore
-
- if output_attentions:
- oa = (outputs[2 if use_cache else 1],) # type: ignore
- all_self_attentions = all_self_attentions + oa # type: ignore
-
- # Add last hidden state
- hidden_states = self.ln_f(hidden_states)
-
- if output_hidden_states:
- hst = (hidden_states,)
- all_hidden_states = all_hidden_states + hst # type: ignore
-
- if not return_dict:
- return tuple(v for v in [
- hidden_states, presents, all_hidden_states, all_self_attentions
- ] if v is not None)
-
- return BaseModelOutputWithPastAndCrossAttentions(
- last_hidden_state=hidden_states,
- past_key_values=presents,
- hidden_states=all_hidden_states,
- attentions=all_self_attentions,
- )
-
- # Make it so model.transformer has the new helper methods and new
- # `forward` method
- setattr(model.transformer, '_prepare_attn_mask',
- MethodType(_prepare_attn_mask, model.transformer))
- setattr(model.transformer, '_build_alibi_tensor',
- MethodType(_build_alibi_tensor, model.transformer))
- setattr(model.transformer, 'forward',
- MethodType(transformer_forward, model.transformer))
-
- # In order to actually use the new argument we've added to
- # model.transformer, we need to update the parent module's `forward` to
- # accept/pass the same new argument.
- # We add 2 lines to handle that change.
- # Both lines are tagged with "# WE'RE ADDING A NEW ARGUMENT!"
- KeyValueT = Tuple[torch.Tensor, torch.Tensor]
-
- def forward(
- self: BloomForCausalLM,
- input_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- # WE'RE ADDING A NEW ARGUMENT! (Change 1/2)
- bidirectional_mask: Optional[torch.Tensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- inputs_embeds: Optional[torch.Tensor] = None,
- labels: Optional[torch.Tensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **deprecated_arguments: Any,
- ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
- """Replacement forward method for BloomCausalLM."""
- if deprecated_arguments.pop('position_ids', False) is not False:
- # `position_ids` could have been `torch.Tensor` or `None` so
- # defaulting pop to `False` allows to detect if users were passing
- # explicitly `None`
- warnings.warn(
- '`position_ids` have no functionality in BLOOM and will be removed ' +\
- 'in v5.0.0. You can safely ignore passing `position_ids`.',
- FutureWarning,
- )
- if len(deprecated_arguments) > 0:
- raise ValueError(
- f'Got unexpected arguments: {deprecated_arguments}')
-
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
- transformer_outputs = self.transformer(
- input_ids,
- past_key_values=past_key_values,
- attention_mask=attention_mask,
- # WE'RE ADDING A NEW ARGUMENT! (Change 2/2)
- bidirectional_mask=bidirectional_mask,
- head_mask=head_mask,
- inputs_embeds=inputs_embeds,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
- hidden_states = transformer_outputs[0]
-
- lm_logits = self.lm_head(hidden_states)
-
- loss = None
- if labels is not None:
- # Shift so that tokens < n predict n
- shift_logits = lm_logits[..., :-1, :].contiguous()
- shift_labels = labels[..., 1:].contiguous()
- batch_size, seq_length, vocab_size = shift_logits.shape
- # Flatten the tokens
- loss_fct = CrossEntropyLoss()
- loss = loss_fct(
- shift_logits.view(batch_size * seq_length, vocab_size),
- shift_labels.view(batch_size * seq_length))
-
- if not return_dict:
- output = (lm_logits,) + transformer_outputs[1:]
- return ((loss,) + output) if loss is not None else output
-
- return CausalLMOutputWithCrossAttentions(
- loss=loss,
- logits=lm_logits,
- past_key_values=transformer_outputs.past_key_values,
- hidden_states=transformer_outputs.hidden_states,
- attentions=transformer_outputs.attentions,
- )
-
- # To handle generation, re-write `prepare_inputs_for_generation` to
- # implement the bidirectional logic.
- def prepare_inputs_for_generation(self: BloomForCausalLM,
- input_ids: torch.LongTensor,
- past: Optional[torch.Tensor] = None,
- attention_mask: Optional[
- torch.Tensor] = None,
- **kwargs: Any) -> dict:
- del kwargs # unused
- # only last token for input_ids if past is not None
- if past:
- input_ids = input_ids[:, -1].unsqueeze(-1) # type: ignore
- # We can turn off bidirectional masking after the prefix
- # has been encoded into `past`
- bidirectional_mask = None
-
- # the cache may be in the standard format (e.g. in contrastive
- # search), convert to bloom's format if needed
- if past[0][0].shape[0] == input_ids.shape[0]:
- past = self._convert_to_bloom_cache(past)
-
- else:
- # If we're here, `input_ids` contains the prefix. Encode it with
- # bidirectional attention.
- bidirectional_mask = torch.ones_like(input_ids)
-
- return {
- 'input_ids': input_ids,
- 'past_key_values': past,
- # "use_cache": kwargs.get("use_cache"),
- # Requires this. TODO(Alex): Confirm this supports other decoding strategies.
- 'use_cache': True,
- 'attention_mask': attention_mask,
- 'bidirectional_mask': bidirectional_mask,
- }
-
- # Register the new `forward` and `prepare_inputs_for_generation` methods
- # with the model
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'prepare_inputs_for_generation',
- MethodType(prepare_inputs_for_generation, model))
-
- # Finally, tag the model so that this conversion cannot happen again.
- setattr(model, '_prefix_lm_converted', True)
- return model
-
-
-def _convert_opt_causal_lm_to_prefix_lm(
- model: OPTForCausalLM) -> OPTForCausalLM:
- """Converts an OPT Causal LM to a Prefix LM.
-
- Supported HuggingFace model classes:
- - `OPTForCausalLM`
-
- See `convert_hf_causal_lm_to_prefix_lm` for more details.
- """
- if hasattr(model, '_prefix_lm_converted'):
- return model
-
- assert isinstance(model, OPTForCausalLM)
- assert model.config.add_cross_attention == False, 'Only supports OPT decoder-only models'
-
- # Rename methods to allow:
- # - new `forward` to wrap original `forward`
- # - new `generate` to wrap original `generate`
- setattr(model, '_original_forward', getattr(model, 'forward'))
- setattr(model, '_original_generate', getattr(model, 'generate'))
-
- model.model.decoder.bidirectional_mask = None
-
- # Modified from transformers.models.bloom.modeling_opt.OPTDecoder._prepare_decoder_attn_mask
- # https://github.com/huggingface/transformers/blob/v4.25.1/src/transformers/models/opt/modeling_opt.py#L532
- def _prepare_decoder_attention_mask(self: torch.nn.Module,
- attention_mask: Optional[torch.Tensor],
- input_shape: Tuple[int, int],
- inputs_embeds: Optional[torch.Tensor],
- past_key_values_length: int):
- # create causal mask
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- combined_attention_mask = None
- if input_shape[-1] > 1:
- assert inputs_embeds is not None
- # 'g' indicates generation mode. Causal mask replaced with 0.
- if self.bidirectional_mask == 'g':
- bsz, src_length = input_shape
- combined_attention_mask = torch.zeros(
- (bsz, 1, src_length, src_length + past_key_values_length),
- dtype=inputs_embeds.dtype,
- device=inputs_embeds.device)
- else:
- combined_attention_mask = _make_causal_mask_opt(
- input_shape,
- inputs_embeds.dtype,
- past_key_values_length=past_key_values_length).to(
- inputs_embeds.device)
-
- # Make use of the batch-specific `bidirectional_mask` attribute
- # set by the parent module in its (new) `forward` method wrapper
- if self.bidirectional_mask is not None:
- assert attention_mask is not None
- # The two masks should have the same size
- assert attention_mask.shape == self.bidirectional_mask.shape
-
- # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
- expanded_bidirectional_mask = _expand_mask_opt(
- self.bidirectional_mask,
- inputs_embeds.dtype,
- tgt_len=input_shape[-1]).to(inputs_embeds.device)
- combined_attention_mask = torch.maximum(
- expanded_bidirectional_mask, combined_attention_mask)
-
- if attention_mask is not None:
- assert inputs_embeds is not None
- # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
- expanded_attn_mask = _expand_mask_opt(attention_mask,
- inputs_embeds.dtype,
- tgt_len=input_shape[-1]).to(
- inputs_embeds.device)
- combined_attention_mask = (expanded_attn_mask
- if combined_attention_mask is None else
- expanded_attn_mask +
- combined_attention_mask)
-
- return combined_attention_mask
-
- # Make it so model.model.decoder uses the above `_prepare_decoder_attn_mask`
- # in place of the original method
- setattr(model.model.decoder, '_prepare_decoder_attention_mask',
- MethodType(_prepare_decoder_attention_mask, model.model.decoder))
-
- def forward(
- self: OPTForCausalLM,
- input_ids: Optional[torch.LongTensor] = None,
- attention_mask: Optional[torch.Tensor] = None,
- bidirectional_mask: Optional[torch.ByteTensor] = None,
- head_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[List[torch.FloatTensor]] = None,
- inputs_embeds: Optional[torch.FloatTensor] = None,
- labels: Optional[torch.LongTensor] = None,
- use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- ):
-
- def call_og_forward():
- return self._original_forward(
- input_ids=input_ids,
- attention_mask=attention_mask,
- head_mask=head_mask,
- past_key_values=past_key_values,
- inputs_embeds=inputs_embeds,
- labels=labels,
- use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
- )
-
- if bidirectional_mask is None:
- # This wrapper is a no-op if bidirectional masks are not supplied
- return call_og_forward()
-
- # Temporarily set `bidirectional_mask` in the child module
- self.model.decoder.bidirectional_mask = bidirectional_mask
-
- # Apply the original forward method (the model will use the mask that
- # was just set)
- try:
- outputs = call_og_forward()
- except:
- self.model.decoder.bidirectional_mask = None
- raise
-
- # Reset the `bidirectional_mask` attribute to None
- self.model.decoder.bidirectional_mask = None
-
- # Return the outputs
- return outputs
-
- def generate(self: OPTForCausalLM, *args: tuple, **kwargs: Any):
- """Wraps original generate to enable PrefixLM-style attention."""
- # Flag the child module to use generation-style attention masking
- self.model.decoder.bidirectional_mask = 'g'
-
- # Collect outputs using the model's original forward method
- try:
- output = self._original_generate(*args, **kwargs)
- except:
- self.model.decoder.bidirectional_mask = None
- raise
-
- # Reset the `bidirectional_mask` attribute to None
- self.model.decoder.bidirectional_mask = None
-
- # Return the output
- return output
-
- # Replace `forward` and `generate` with the new wrappers
- setattr(model, 'forward', MethodType(forward, model))
- setattr(model, 'generate', MethodType(generate, model))
-
- # Finally, tag the model so that this conversion cannot happen again.
- setattr(model, '_prefix_lm_converted', True)
- return model
-
-
-_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS + (BloomForCausalLM,
- OPTForCausalLM)
+_SUPPORTED_HF_MODELS = _SUPPORTED_GPT_MODELS
CAUSAL_LM_TYPES = Union[GPT2LMHeadModel, GPTJForCausalLM, GPTNeoForCausalLM,
- GPTNeoXForCausalLM, BloomForCausalLM, OPTForCausalLM]
+ GPTNeoXForCausalLM]
def convert_hf_causal_lm_to_prefix_lm(
@@ -811,8 +221,6 @@ def convert_hf_causal_lm_to_prefix_lm(
- `GPTNeoForCausalLM`
- `GPTNeoXForCausalLM`
- `GPTJForCausalLM`
- - `BloomForCausalLM`
- - `OPTForCausalLM`
Conversion to a Prefix LM is done by modifying the `forward` method, and possibly also the
`generate` method and/or select underlying methods depending on the model class.
@@ -862,13 +270,6 @@ def convert_hf_causal_lm_to_prefix_lm(
"""
if isinstance(model, _SUPPORTED_GPT_MODELS):
return _convert_gpt_causal_lm_to_prefix_lm(model)
-
- elif isinstance(model, BloomForCausalLM):
- return _convert_bloom_causal_lm_to_prefix_lm(model)
-
- elif isinstance(model, OPTForCausalLM):
- return _convert_opt_causal_lm_to_prefix_lm(model)
-
else:
raise TypeError(
f'Cannot convert model to Prefix LM. ' +\
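With BLOOM and OPT support removed, a hedged usage sketch for the remaining GPT-style classes (the model choice is illustrative):

```python
from transformers import GPT2LMHeadModel

from llmfoundry.models.utils.hf_prefixlm_converter import \
    convert_hf_causal_lm_to_prefix_lm

model = GPT2LMHeadModel.from_pretrained('gpt2')
prefix_lm = convert_hf_causal_lm_to_prefix_lm(model)
# The converted model's forward/generate accept a `bidirectional_mask`
# marking the prefix tokens that may attend bidirectionally.
```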
diff --git a/llmfoundry/tokenizers/tiktoken.py b/llmfoundry/tokenizers/tiktoken.py
index 41518a582a..650d469ecf 100644
--- a/llmfoundry/tokenizers/tiktoken.py
+++ b/llmfoundry/tokenizers/tiktoken.py
@@ -1,6 +1,7 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
+import warnings
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
@@ -26,7 +27,7 @@ def __init__(self,
eos_token: Optional[str] = '<|endoftext|>',
bos_token: Optional[str] = '<|endoftext|>',
pad_token: Optional[str] = None,
- **kwargs: Dict[str, Any]):
+ **kwargs: Any):
"""Constructor creates a tiktoken tokenizer to use as the underlying.
tokenizer.
@@ -49,6 +50,23 @@ def __init__(self,
raise ImportError(
'You need to install tiktoken to use TiktokenTokenizerWrapper.')
+ # Workaround to make tiktokenizer picklable.
+ # https://github.com/huggingface/datasets/issues/5536#issuecomment-1682309347
+ # There is an open PR from HF to add this to tiktoken: https://github.com/openai/tiktoken/pull/181
+ import copyreg
+ import functools
+
+ from tiktoken import Encoding # type: ignore (thirdParty)
+
+ def pickle_Encoding(enc: Encoding):
+ return (functools.partial(Encoding,
+ enc.name,
+ pat_str=enc._pat_str,
+ mergeable_ranks=enc._mergeable_ranks,
+ special_tokens=enc._special_tokens), ())
+
+ copyreg.pickle(Encoding, pickle_Encoding)
+
if model_name is not None and encoding_name is not None:
raise ValueError(
'You need to specify either model_name or encoding_name, not both.'
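An illustrative check (not part of this diff) that the `copyreg` registration above makes the underlying `Encoding` picklable, which is what multiprocessing dataloaders need. It assumes `tiktoken` is installed and can fetch the `gpt-4` encoding:

```python
import pickle

from llmfoundry.tokenizers import TiktokenTokenizerWrapper

tok = TiktokenTokenizerWrapper(model_name='gpt-4')

# Round-trip the tiktoken Encoding object through pickle.
restored = pickle.loads(pickle.dumps(tok.encoding))
assert restored.encode('hello world') == tok.encoding.encode('hello world')
```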
@@ -90,7 +108,17 @@ def is_fast(self) -> bool:
return False
def get_vocab(self) -> Dict[str, int]:
- """Returns vocab as a dict."""
+ """Returns vocab as a dict.
+
+ Note: This function does not work properly due to differences in assumptions between tiktoken and Hugging Face tokenizers.
+ Most uses do not need to use get_vocab, so this is not a priority to fix.
+ """
+ warnings.warn(
+ 'get_vocab does not work properly with TiktokenTokenizerWrapper. Please do not rely on it being perfectly correct.'
+ +
+ ' It will be called once during init just to get the size of the vocab inside the base class.'
+ )
+
vocab = {}
for i in range(self.vocab_size):
try:
@@ -101,6 +129,24 @@ def get_vocab(self) -> Dict[str, int]:
except KeyError:
pass
+ # As far as I can tell, we don't require get_vocab to completely work,
+ # but when using additional_special_tokens, Hugging Face determines the next
+ # token index to add with len(self.get_vocab()) so we need the _size_ of this dictionary to be correct.
+ extra_id_index = 0
+ candidate_extra_id = f'<extra_id_{extra_id_index}>'
+ indices_to_fill_in = {i for i in range(self.vocab_size)} - set(
+ vocab.values())
+
+ # Add enough indices to make get_vocab() the right length
+ for index_to_add in indices_to_fill_in:
+ # Make sure we don't overwrite a token that already exists
+ while candidate_extra_id in vocab:
+ extra_id_index += 1
+ candidate_extra_id = f'<extra_id_{extra_id_index}>'
+
+ # Get an index to add and add the item
+ vocab[candidate_extra_id] = index_to_add
+
return vocab
def _tokenize(self, text: str) -> List[int]:
@@ -155,7 +201,7 @@ def convert_ids_to_tokens(
"""
if isinstance(ids, int):
if ids in self.added_tokens_decoder:
- return self.added_tokens_decoder[ids]
+ return str(self.added_tokens_decoder[ids])
return self._convert_id_to_token(ids)
@@ -171,7 +217,7 @@ def convert_ids_to_tokens(
if index in self.added_tokens_decoder:
tokens.append(self.encoding.decode(current_stream))
current_stream = []
- tokens.append(self.added_tokens_decoder[index])
+ tokens.append(str(self.added_tokens_decoder[index]))
else:
current_stream.append(index)
diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py
index 38cc562c9d..7abe4dcf75 100644
--- a/llmfoundry/utils/__init__.py
+++ b/llmfoundry/utils/__init__.py
@@ -11,6 +11,8 @@
from llmfoundry.utils.config_utils import (calculate_batch_size_info,
log_config, pop_config,
update_batch_size_info)
+ from llmfoundry.utils.model_download_utils import (
+ download_from_cache_server, download_from_hf_hub)
except ImportError as e:
raise ImportError(
'Please make sure to pip install . to get requirements for llm-foundry.'
@@ -26,6 +28,8 @@
'build_tokenizer',
'calculate_batch_size_info',
'convert_and_save_ft_weights',
+ 'download_from_cache_server',
+ 'download_from_hf_hub',
'get_hf_tokenizer_from_composer_state_dict',
'update_batch_size_info',
'log_config',
diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py
index b82e2581c7..f1ed608dd0 100644
--- a/llmfoundry/utils/builders.py
+++ b/llmfoundry/utils/builders.py
@@ -10,9 +10,9 @@
import datasets as hf_datasets
import json
from composer import algorithms
-from composer.callbacks import (EarlyStopper, Generate, LRMonitor,
- MemoryMonitor, OptimizerMonitor,
- RuntimeEstimator, SpeedMonitor)
+from composer.callbacks import (EarlyStopper, Generate, LRMonitor, MemoryMonitor,
+ OptimizerMonitor, RuntimeEstimator, EvalOutputLogging,
+ SpeedMonitor)
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import \
get_icl_task_dataloader
@@ -120,6 +120,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback:
return EarlyStopper(**kwargs)
elif name == 'hf_checkpointer':
return HuggingFaceCheckpointer(**kwargs)
+ elif name == 'eval_output_logging':
+ return EvalOutputLogging(**kwargs)
else:
raise ValueError(f'Not sure how to build callback: {name}')
@@ -190,6 +192,12 @@ def build_tokenizer(
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+ signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
+
+ # Make sure the tokenizer files are downloaded and cached first by local rank 0
+ with dist.local_rank_zero_download_and_wait(signal_file_path):
+ pass
+
if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
else:
@@ -204,6 +212,15 @@ def build_tokenizer(
int(1e30),
)
+ if dist.get_local_rank() == 0:
+ with open(signal_file_path, 'wb') as f:
+ f.write(b'local_rank0_completed_tokenizer_setup')
+
+ dist.barrier()
+
+ if dist.get_local_rank() == 0:
+ os.remove(signal_file_path)
+
return tokenizer
def prep_hf_dataset(icl_cfg: ListConfig):
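
The signal-file dance in `build_tokenizer` is easier to read in isolation. Below is a hedged sketch of the same local-rank-0-first pattern, assuming composer's `dist` helpers behave as they are used in the hunk above; `build_fn` and the file name are placeholders, not part of this PR.

```python
import os

from composer.utils import dist


def build_with_local_rank_zero_first(build_fn,
                                     signal_file_path: str = '.local_rank0_completed'):
    """Run build_fn on every rank, letting local rank 0 warm any shared cache first."""
    # Non-zero local ranks block here until local rank 0 writes the signal file
    # below; local rank 0 enters immediately.
    with dist.local_rank_zero_download_and_wait(signal_file_path):
        pass

    result = build_fn()  # On rank 0 this populates the cache; other ranks reuse it.

    if dist.get_local_rank() == 0:
        with open(signal_file_path, 'wb') as f:
            f.write(b'local_rank0_completed')

    dist.barrier()  # No rank may still be waiting on the file before cleanup.

    if dist.get_local_rank() == 0:
        os.remove(signal_file_path)

    return result
```
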
diff --git a/llmfoundry/utils/model_download_utils.py b/llmfoundry/utils/model_download_utils.py
new file mode 100644
index 0000000000..2104455e0f
--- /dev/null
+++ b/llmfoundry/utils/model_download_utils.py
@@ -0,0 +1,235 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Utility functions for downloading models."""
+import copy
+import logging
+import os
+import time
+import warnings
+from http import HTTPStatus
+from typing import Optional
+from urllib.parse import urljoin
+
+import huggingface_hub as hf_hub
+import requests
+import tenacity
+from bs4 import BeautifulSoup
+from requests.packages.urllib3.exceptions import InsecureRequestWarning
+from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME
+from transformers.utils import WEIGHTS_INDEX_NAME as PYTORCH_WEIGHTS_INDEX_NAME
+from transformers.utils import WEIGHTS_NAME as PYTORCH_WEIGHTS_NAME
+
+DEFAULT_IGNORE_PATTERNS = [
+ '*.ckpt',
+ '*.h5',
+ '*.msgpack',
+]
+PYTORCH_WEIGHTS_PATTERN = 'pytorch_model*.bin*'
+SAFE_WEIGHTS_PATTERN = 'model*.safetensors*'
+
+log = logging.getLogger(__name__)
+
+
+@tenacity.retry(retry=tenacity.retry_if_not_exception_type(
+ (ValueError, hf_hub.utils.RepositoryNotFoundError)),
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(min=1, max=10))
+def download_from_hf_hub(
+ repo_id: str,
+ save_dir: Optional[str] = None,
+ prefer_safetensors: bool = True,
+ token: Optional[str] = None,
+):
+ """Downloads model files from a Hugging Face Hub model repo.
+
+ Only supports models stored in Safetensors and PyTorch formats for now. If both formats are available, only the
+ Safetensors weights will be downloaded unless `prefer_safetensors` is set to False.
+
+ Args:
+ repo_id (str): The Hugging Face Hub repo ID.
+ save_dir (str, optional): The path to the directory where the model files will be downloaded. If `None`, reads
+ from the `HUGGINGFACE_HUB_CACHE` environment variable or uses the default Hugging Face Hub cache directory.
+ prefer_safetensors (bool): Whether to prefer Safetensors weights over PyTorch weights if both are
+ available. Defaults to True.
+ token (str, optional): The HuggingFace API token. If not provided, the token will be read from the
+ `HUGGING_FACE_HUB_TOKEN` environment variable.
+
+ Raises:
+ RepositoryNotFoundError: If the model repo doesn't exist or the token is unauthorized.
+ ValueError: If the model repo doesn't contain any supported model weights.
+ """
+ repo_files = set(hf_hub.list_repo_files(repo_id))
+
+ # Ignore TensorFlow, TensorFlow 2, and Flax weights as they are not supported by Composer.
+ ignore_patterns = copy.deepcopy(DEFAULT_IGNORE_PATTERNS)
+
+ safetensors_available = (SAFE_WEIGHTS_NAME in repo_files or
+ SAFE_WEIGHTS_INDEX_NAME in repo_files)
+ pytorch_available = (PYTORCH_WEIGHTS_NAME in repo_files or
+ PYTORCH_WEIGHTS_INDEX_NAME in repo_files)
+
+ if safetensors_available and pytorch_available:
+ if prefer_safetensors:
+ log.info(
+ 'Safetensors available and preferred. Excluding pytorch weights.'
+ )
+ ignore_patterns.append(PYTORCH_WEIGHTS_PATTERN)
+ else:
+ log.info(
+ 'Pytorch available and preferred. Excluding safetensors weights.'
+ )
+ ignore_patterns.append(SAFE_WEIGHTS_PATTERN)
+ elif safetensors_available:
+ log.info('Only safetensors available. Ignoring weights preference.')
+ elif pytorch_available:
+ log.info('Only pytorch available. Ignoring weights preference.')
+ else:
+ raise ValueError(
+ f'No supported model weights found in repo {repo_id}.' +
+ ' Please make sure the repo contains either safetensors or pytorch weights.'
+ )
+
+ download_start = time.time()
+ hf_hub.snapshot_download(repo_id,
+ cache_dir=save_dir,
+ ignore_patterns=ignore_patterns,
+ token=token)
+ download_duration = time.time() - download_start
+ log.info(
+ f'Downloaded model {repo_id} from Hugging Face Hub in {download_duration} seconds'
+ )
+
+
+def _extract_links_from_html(html: str):
+ """Extracts links from HTML content.
+
+ Args:
+ html (str): The HTML content
+
+ Returns:
+ list[str]: A list of links to download.
+ """
+ soup = BeautifulSoup(html, 'html.parser')
+ links = [a['href'] for a in soup.find_all('a')]
+ return links
+
+
+def _recursive_download(
+ session: requests.Session,
+ base_url: str,
+ path: str,
+ save_dir: str,
+ ignore_cert: bool = False,
+):
+ """Downloads all files/subdirectories from a directory on a remote server.
+
+ Args:
+ session: A requests.Session through which to make requests to the remote server.
+ base_url (str): The base URL where the files are located.
+ path (str): The path from the base URL to the files to download. The full URL for the download is equal to
+ '<base_url>/<path>'.
+ save_dir (str): The directory to save downloaded files to.
+ ignore_cert (bool): Whether or not to ignore the validity of the SSL certificate of the remote server.
+ Defaults to False.
+ WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
+
+ Raises:
+ PermissionError: If the remote server returns a 401 Unauthorized status code.
+ ValueError: If the remote server returns a 404 Not Found status code.
+ RuntimeError: If the remote server returns a status code other than 200 OK or 401 Unauthorized.
+ """
+ url = urljoin(base_url, path)
+ response = session.get(url, verify=(not ignore_cert))
+
+ if response.status_code == HTTPStatus.UNAUTHORIZED:
+ raise PermissionError(
+ f'Not authorized to download file from {url}. Received status code {response.status_code}. '
+ )
+ elif response.status_code == HTTPStatus.NOT_FOUND:
+ raise ValueError(
+ f'Could not find file at {url}. Received status code {response.status_code}'
+ )
+ elif response.status_code != HTTPStatus.OK:
+ raise RuntimeError(
+ f'Could not download file from {url}. Received unexpected status code {response.status_code}'
+ )
+
+ # Assume that the URL points to a file if it does not end with a slash.
+ if not path.endswith('/'):
+ save_path = os.path.join(save_dir, path)
+ parent_dir = os.path.dirname(save_path)
+ if not os.path.exists(parent_dir):
+ os.makedirs(parent_dir)
+
+ with open(save_path, 'wb') as f:
+ f.write(response.content)
+
+ log.info(f'Downloaded file {save_path}')
+ return
+
+ # If the URL is a directory, the response should be an HTML directory listing that we can parse for additional links
+ # to download.
+ child_links = _extract_links_from_html(response.content.decode())
+ for child_link in child_links:
+ _recursive_download(session,
+ base_url,
+ urljoin(path, child_link),
+ save_dir,
+ ignore_cert=ignore_cert)
+
+
+@tenacity.retry(retry=tenacity.retry_if_not_exception_type(
+ (PermissionError, ValueError)),
+ stop=tenacity.stop_after_attempt(3),
+ wait=tenacity.wait_exponential(min=1, max=10))
+def download_from_cache_server(
+ model_name: str,
+ cache_base_url: str,
+ save_dir: str,
+ token: Optional[str] = None,
+ ignore_cert: bool = False,
+):
+ """Downloads Hugging Face models from a mirror file server.
+
+ The file server is expected to store the files in the same structure as the Hugging Face cache
+ structure. See https://huggingface.co/docs/huggingface_hub/guides/manage-cache.
+
+ Args:
+ model_name: The name of the model to download. This should be the same as the repository ID in the Hugging Face
+ Hub.
+ cache_base_url: The base URL of the cache file server. This function will attempt to download all of the blob
+ files from `<cache_base_url>/<formatted_model_name>/blobs/`, where `formatted_model_name` is equal to
+ `models/<model_name>` with all slashes replaced with `--`.
+ save_dir: The directory to save the downloaded files to.
+ token: The Hugging Face API token. If not provided, the token will be read from the `HUGGING_FACE_HUB_TOKEN`
+ environment variable.
+ ignore_cert: Whether or not to ignore the validity of the SSL certificate of the remote server. Defaults to
+ False.
+ WARNING: Setting this to true is *not* secure, as no certificate verification will be performed.
+ """
+ formatted_model_name = f'models/{model_name}'.replace('/', '--')
+ with requests.Session() as session:
+ session.headers.update({'Authorization': f'Bearer {token}'})
+
+ download_start = time.time()
+
+ # Temporarily suppress noisy SSL certificate verification warnings if ignore_cert is set to True
+ with warnings.catch_warnings():
+ if ignore_cert:
+ warnings.simplefilter('ignore', category=InsecureRequestWarning)
+
+ # Only downloads the blobs in order to avoid downloading model files twice due to the
+ # symlinks in the Hugging Face cache structure:
+ _recursive_download(
+ session,
+ cache_base_url,
+ # Trailing slash to indicate directory
+ f'{formatted_model_name}/blobs/',
+ save_dir,
+ ignore_cert=ignore_cert,
+ )
+ download_duration = time.time() - download_start
+ log.info(
+ f'Downloaded model {model_name} from cache server in {download_duration} seconds'
+ )
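
For reference, a hypothetical usage sketch of the two new helpers in `llmfoundry/utils/model_download_utils.py`; the mirror URL, save directory, and token handling are placeholders rather than values from this PR.

```python
import os

from llmfoundry.utils.model_download_utils import (download_from_cache_server,
                                                   download_from_hf_hub)

HF_TOKEN = os.environ.get('HUGGING_FACE_HUB_TOKEN')

# Pull safetensors (preferred) or pytorch weights straight from the Hugging Face Hub.
download_from_hf_hub('mosaicml/mpt-7b-instruct',
                     save_dir='/tmp/hf-cache',
                     prefer_safetensors=True,
                     token=HF_TOKEN)

# Or fetch the same Hugging Face cache layout from an internal mirror file server.
download_from_cache_server(model_name='mosaicml/mpt-7b-instruct',
                           cache_base_url='https://my-mirror.example.com/hf-cache/',
                           save_dir='/tmp/hf-cache',
                           token=HF_TOKEN,
                           ignore_cert=False)
```
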
diff --git a/mcli/mcli-hf-eval.yaml b/mcli/mcli-hf-eval.yaml
index accff7d5c0..46aef69940 100644
--- a/mcli/mcli-hf-eval.yaml
+++ b/mcli/mcli-hf-eval.yaml
@@ -1,20 +1,22 @@
integrations:
- integration_type: git_repo
git_repo: mosaicml/llm-foundry
- git_branch: v0.3.0
+ git_branch: output_eval_logging
# git_commit: # OR use your commit hash
pip_install: -e ".[gpu]"
ssh_clone: false # Should be true if using a private repo
command: |
+ pip uninstall mosaicml -y
+ pip install git+https://github.com/bmosaicml/composer.git@error_logging_callback
cd llm-foundry/scripts
composer eval/eval.py /mnt/config/parameters.yaml
# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
-run_name: mpt-eval
+run_name: output-logger-test
gpu_num: 8
-# gpu_type:
-# cluster: # replace with your cluster here!
+gpu_type: a100_80gb
+cluster: r1z1 # replace with your cluster here!
image: mosaicml/llm-foundry:2.0.1_cu118-latest
@@ -31,13 +33,13 @@ parameters:
model_name: mosaicml/mpt-7b-instruct
# Tokenizer
tokenizer:
- name: EleutherAI/gpt-neox-20b
+ name: mosaicml/mpt-7b-instruct
kwargs:
model_max_length: ${max_seq_len}
model:
name: hf_causal_lm
- pretrained_model_name_or_path: mosaicml/mpt-7b-instruct
+ pretrained_model_name_or_path: mosaicml/mpt-7b-instruct
init_device: mixed
pretrained: true
use_auth_token: false
@@ -50,5 +52,17 @@ parameters:
limit_all_gathers: True
- icl_tasks: 'eval/yamls/tasks.yaml'
- eval_gauntlet: 'eval/yamls/eval_gauntlet.yaml'
+ icl_tasks:
+ -
+ label: jeopardy
+ dataset_uri: eval/local_data/world_knowledge/jeopardy_all.jsonl # ADD YOUR OWN DATASET URI
+ num_fewshot: [10]
+ icl_task_type: language_modeling
+ continuation_delimiter: "\nAnswer: " # this separates questions from answers
+ has_categories: true
+
+ callbacks:
+ eval_output_logging:
+ subset_sample: -1
+ output_directory: s3://mosaicml-internal-checkpoints-test/test_icl_output_logger_7b
+
diff --git a/mcli/mcli-llama2-finetune.yaml b/mcli/mcli-llama2-finetune.yaml
index ae8f57abb6..93d46f57e3 100644
--- a/mcli/mcli-llama2-finetune.yaml
+++ b/mcli/mcli-llama2-finetune.yaml
@@ -56,7 +56,10 @@ parameters:
allow_pad_trimming: false
decoder_only_format: true
shuffle: true
- # # Use `python llmfoundry/data/packing.py --yaml-path /path/to/this/yaml/ ...`
+ # # Use packing_ratio: 'auto' to automatically profile and select the highest observed packing ratio with
+ # # zero waste. In practice, this may result in > 0 waste because profiling is done on only a portion
+ # # of the dataset.
+ # # Or use `python llmfoundry/scripts/misc/profile_packing.py --yaml-path /path/to/this/yaml/ ...`
# # to profile this run's optimal packing_ratio as it depends on GPU count,
# # batch size, sequence length
# packing_ratio:
diff --git a/mcli/mcli-rlhf-eval.yaml b/mcli/mcli-rlhf-eval.yaml
new file mode 100644
index 0000000000..e28f53ac8b
--- /dev/null
+++ b/mcli/mcli-rlhf-eval.yaml
@@ -0,0 +1,68 @@
+integrations:
+- integration_type: git_repo
+ git_repo: mosaicml/llm-foundry
+ git_branch: output_eval_logging
+ # git_commit: # OR use your commit hash
+ pip_install: -e ".[gpu]"
+ ssh_clone: false # Should be true if using a private repo
+
+command: |
+ pip uninstall mosaicml -y
+ pip install git+https://github.com/bmosaicml/composer.git@error_logging_callback
+ cd llm-foundry/scripts
+ composer eval/eval.py /mnt/config/parameters.yaml
+
+# Mosaic Cloud will use run_name (with a unique suffix) to populate the env var $RUN_NAME
+run_name: output-logger-rlhf-prompts
+gpu_num: 8
+gpu_type: a100_80gb
+cluster: r1z1 # replace with your cluster here!
+
+image: mosaicml/llm-foundry:2.0.1_cu118-latest
+
+# The below is injected as a YAML file: /mnt/config/parameters.yaml
+parameters:
+ dist_timeout: 6000
+ seed: 1
+ max_seq_len: 1024
+ device_eval_batch_size: 1
+ precision: amp_fp16
+
+ models:
+ -
+ model_name: mosaicml/mpt-30b-instruct
+ # Tokenizer
+ tokenizer:
+ name: mosaicml/mpt-30b-instruct
+ kwargs:
+ model_max_length: ${max_seq_len}
+
+ model:
+ name: hf_causal_lm
+ pretrained_model_name_or_path: mosaicml/mpt-30b-instruct
+ init_device: mixed
+ pretrained: true
+ use_auth_token: false
+
+ # FSDP config for model sharding
+ fsdp_config:
+ sharding_strategy: FULL_SHARD
+ mixed_precision: FULL
+ forward_prefetch: True
+ limit_all_gathers: True
+
+
+ icl_tasks:
+ -
+ label: rlhf_prompts
+ dataset_uri: eval/local_data/rlhf_prompts/rlhf_prompts.jsonl # ADD YOUR OWN DATASET URI
+ num_fewshot: [0]
+ icl_task_type: question_answering
+ has_categories: true
+
+ callbacks:
+ eval_output_logging:
+ print_only_incorrect: false
+ subset_sample: -1
+ output_directory: s3://mosaicml-internal-checkpoints-test/30b_instruct_rlhf_prompts
+
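
The ICL task above points at a prompts-only JSONL. As a rough illustration of the expected schema (matching the records in `scripts/eval/local_data/rlhf_prompts/rlhf_prompts.jsonl` added below), such a file could be generated with a few lines of Python; the prompts, repeat count, and output path here are placeholders.

```python
import json

# Each record carries the prompt in "context", a "category" used when
# has_categories is true, a padded "answer" (the model generations are what the
# eval_output_logging callback records), and an empty "aliases" list.
prompts = [
    ('What aspects of design connote a feeling of calm?', 'chat_(generate)'),
    ('Can you explain what gradient clipping is?', 'coding'),
]

with open('rlhf_prompts.jsonl', 'w') as f:
    for context, category in prompts:
        record = {
            'context': context,
            'category': category,
            'answer': '<|endoftext|>' * 100,  # repeat count is illustrative
            'aliases': [],
        }
        f.write(json.dumps(record) + '\n')
```
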
diff --git a/pyproject.toml b/pyproject.toml
index a2fcec3eed..0b078120b3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -86,13 +86,6 @@ filterwarnings = [
'ignore::DeprecationWarning:tensorboard', # ignore tensorboard
]
-# Enable logging for pytest
-log_cli = true
-log_cli_level = "INFO"
-log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
-log_cli_date_format = "%Y-%m-%d %H:%M:%S"
-
-
# Yapf
[tool.yapf]
# Align closing bracket with visual indentation.
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index f07942ba10..7891a3ed96 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -7,7 +7,7 @@
import time
import warnings
from typing import Any, Dict, List, Optional, Union
-
+from composer.core.callback import Callback
import pandas as pd
import torch
from composer.loggers.logger_destination import LoggerDestination
@@ -21,7 +21,7 @@
from llmfoundry.models import MPTForCausalLM
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
-from llmfoundry.utils.builders import (build_icl_data_and_gauntlet,
+from llmfoundry.utils.builders import (build_icl_data_and_gauntlet, build_callback,
build_logger, build_tokenizer)
from llmfoundry.utils.config_utils import pop_config, process_init_device
@@ -107,6 +107,7 @@ def evaluate_model(
precision: str,
eval_gauntlet_df: Optional[pd.DataFrame],
icl_subset_num_batches: Optional[int],
+ callback_configs: Optional[Dict]
):
print(f'Evaluating model: {model_cfg.model_name}', flush=True)
@@ -122,7 +123,12 @@ def evaluate_model(
icl_tasks, eval_gauntlet_config, tokenizer, device_eval_batch_size,
max_seq_len, icl_subset_num_batches)
- callbacks = []
+ # Callbacks
+ callbacks: List[Callback] = [
+ build_callback(str(name), callback_cfg)
+ for name, callback_cfg in callback_configs.items()
+ ] if callback_configs else []
+
if eval_gauntlet_callback is not None:
callbacks.append(eval_gauntlet_callback)
@@ -145,7 +151,8 @@ def evaluate_model(
if eval_gauntlet_df is None and eval_gauntlet_callback is not None:
eval_gauntlet_df = pd.DataFrame(
- columns=['model_name', 'average'] +
+ columns=['model_name'] +
+ [avg for avg in eval_gauntlet_callback.averages] +
[t.name for t in eval_gauntlet_callback.categories])
load_path = model_cfg.get('load_path', None)
@@ -173,6 +180,7 @@ def evaluate_model(
dist_timeout=dist_timeout,
python_log_level=python_log_level,
)
+
if torch.cuda.is_available():
torch.cuda.synchronize()
@@ -251,7 +259,11 @@ def main(cfg: DictConfig):
default_value=None)
# Pop out interpolation variables.
pop_config(cfg, 'model_name_or_path', must_exist=False, default_value=None)
-
+ callback_configs: Optional[DictConfig] = pop_config(cfg,
+ 'callbacks',
+ must_exist=False,
+ default_value=None)
+
# Warn for unused parameters
for key in cfg:
warnings.warn(
@@ -290,7 +302,9 @@ def main(cfg: DictConfig):
python_log_level=python_log_level,
precision=precision,
eval_gauntlet_df=eval_gauntlet_df,
- icl_subset_num_batches=icl_subset_num_batches)
+ icl_subset_num_batches=icl_subset_num_batches,
+ callback_configs=callback_configs
+ )
if eval_gauntlet_callback is not None:
composite_scores = eval_gauntlet_callback.eval_after_all(
@@ -314,28 +328,25 @@ def main(cfg: DictConfig):
if eval_gauntlet_df is not None and eval_gauntlet_callback is not None:
assert composite_scores is not None
row = {'model_name': model_cfg['model_name']}
- row.update({
- t.name:
- composite_scores.get(f'icl/metrics/eval_gauntlet/{t.name}',
- None)
- for t in eval_gauntlet_callback.categories
- })
- row.update({
- 'average':
- composite_scores[f'icl/metrics/eval_gauntlet/average']
- })
+ row.update(
+ {k.split('/')[-1]: v for k, v in composite_scores.items()})
eval_gauntlet_df = pd.concat(
[eval_gauntlet_df, pd.DataFrame([row])], ignore_index=True)
print(f'Printing gauntlet results for all models')
+
print(
eval_gauntlet_df.sort_values(
- 'average', ascending=False).to_markdown(index=False))
+ list(eval_gauntlet_callback.averages.keys())[0],
+ ascending=False).to_markdown(index=False))
print(f'Printing complete results for all models')
assert models_df is not None
print(models_df.to_markdown(index=False))
+
+
+
def calculate_markdown_results(logger_keys: List[str], trainer: Trainer,
benchmark_to_taxonomy: Dict[str, str],
model_name: str):
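
To summarize the new callback plumbing in `scripts/eval/eval.py`, the sketch below mirrors how a `callbacks` section popped from the eval YAML becomes Composer callbacks via `build_callback`. It assumes the experimental `eval_output_logging` kwargs shown in the mcli YAMLs above; the output bucket is a placeholder.

```python
from typing import List

from composer.core.callback import Callback
from omegaconf import OmegaConf

from llmfoundry.utils.builders import build_callback

# Equivalent of the `callbacks:` block in mcli/mcli-hf-eval.yaml.
callback_configs = OmegaConf.create({
    'eval_output_logging': {
        'subset_sample': -1,
        'output_directory': 's3://my-bucket/icl-outputs',  # placeholder bucket
    },
})

# Same construction as in evaluate_model(): one Callback per named config entry.
callbacks: List[Callback] = [
    build_callback(str(name), callback_cfg)
    for name, callback_cfg in callback_configs.items()
] if callback_configs else []
```
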
diff --git a/scripts/eval/local_data/rlhf_prompts/rlhf_prompts.jsonl b/scripts/eval/local_data/rlhf_prompts/rlhf_prompts.jsonl
new file mode 100644
index 0000000000..87b4d68e46
--- /dev/null
+++ b/scripts/eval/local_data/rlhf_prompts/rlhf_prompts.jsonl
@@ -0,0 +1,187 @@
+{"context": "Imagine you go back in time. What is a clue you could give your past self that you are legitimately you?", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Generate a D&D statblock for a 6th level High Elf Druid named Quinestra. Make sure to come up with a special ability, possibly a unique spell or feat, that Quinestra can perform. Include a brief character bio that describes Quinestra's backstory.", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What kind of plushies would be the most impressive to have when inviting my friends over.", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What are virgos well known for?", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "what ability should i max first if i'm going tankmo.", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What's the better army to go with in Warhammer between Necrons and Tau and Ultramarines", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What aspects of design connote a feeling of calm?", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "what should I cook if i only have potatoes and leeks in my fridge? Keep in mind I have gluten intolerance and I am looking for something savory but appropriate for summer", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Analyze the given transcript of two people collaborating to solve a problem. First summarize the problem they are trying to solve with tag . Then, I want you to label with what collaborative skills they're using and then say exactly why it is useful in helping them achieve their goal using ", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "write a casual twitter thread (lower case, fun, playful) about the experience of living in bushwick in the summer", "category": "chat_(generate)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Here's an email thread, can you clarify the intent of the email?", "category": "Classification", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Here's a product description, can you help me classify what product this is?", "category": "Classification", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Here's a piece of text, can you help me classify how difficult the text is?", "category": "Classification", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Here's a description for an event, can you classify what type of an event it is?", "category": "Classification", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Score the positivity of the following product review on a scale of 1 (least positive) to 5 (most positive): The charger that came with this camera does not charge the battery in this camera, I haven't been able to use this camera once since purchasing", "category": "Classification", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you explain how stacks and queues work?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "I'm trying to prepare for a coding interview and I expect to get a question on data structures. Please generate a coding interview puzzle including any constraints I need to follow. I'll give you my best answer and then I'll ask you to grade it.", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you explain to me how quicksort works?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Design an efficient algorithm that given a directed graph G, outputs the set of all vertices v such that there is a cycle containing v. In other words, your algorithm should output all v such that there is a non-empty path from v to itself in G.", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you explain to me how quicksort works?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Design an efficient algorithm that given a directed graph G, outputs the set of all vertices v such that there is a cycle containing v. In other words, your algorithm should output all v such that there is a non-empty path from v to itself in G.", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Write python code to sum all of the primes from 5 to 10 million, except the primes that end in 3.", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What is the Python GIL?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Should I be using multithreading or multiprocessing in python?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "I have written the following code so far:\n \nHow can I speed up this code even more?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How can I debug this multithreading, multiprocessing code?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you help write unit tests to make sure the distributed version of gradient clipping works?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you explain what gradient clipping is?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How would I implement a thread-safe queue in C++?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Write C++ code to in-place sort a given array, without any extra memory", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Here is a segfault, stacktrace from C++.\n\nWhy is it segfaulting?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you write me an object oriented program (OOP) in scala to track a grocery store, their inventory, their employees, and their cashflow", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How could I check if a string is a palindrome in Java?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How do you create a hyperlink in HTML?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Write js code to query an API endpoint with the \"POST\" method, sending a json input, and process the output query?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What is a cursor and how do I use it?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "What is an index?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you explain to me the different joins?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How can you repartition the disk in linux", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How do you turn distracting sites black and white on your iphone using automation", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How do I remove all eos tokens from this string?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Consider the following 2 cells and suggest a completion that makes sense: Cell1: from pyspark.sql import types as T from pyspark.sql import functions as F reddit_comments_schema = T.StructType([ T.StructField(\"id\", T.StringType()), T.StructField(\"parent_id\", T.StringType()), T.StructField(\"author\", T.StringType()), T.StructField(\"link_id\", T.StringType()), T.StructField(\"subreddit\", T.StringType()), T.StructField(\"subreddit_id\", T.StringType()), T.StructField(\"edited\", T.BooleanType()), T.StructField(\"score\", T.LongType()), T.StructField(\"body\", T.StringType()), T.StructField(\"created_utc\", T.LongType()), T.StructField(\"retrieved_utc\", T.LongType()), T.StructField(\"retrieved_on\", T.LongType()), ]) reddit_df = ( spark.read.json( '/pushshift/decompression', schema=reddit_comments_schema, ) .withColumn( 'retrieved_on', F.when( F.col('retrieved_utc').isNotNull(), F.col('retrieved_utc') ).otherwise( F.col('retrieved_on') ) ) ) Cell2: from pyspark.sql.functions import col data = reddit_df # Extract only the parent-child relationships relationships = data.select(\"id\", \"parent_id\") # Initialize the transitive closure DataFrame with the direct relationships transitive_closure = relationships # You might need to iterate multiple times, depending on the depth of your comment chains. # For this example, I'll assume 10 iterations. Adjust as needed. for _ in range(10):", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you help me write a basic pipeline in spark for mapping, and filtering data?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you give a basic explanation of data frame for me?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Extract the graph path from the html code: Graph Generated Graph
", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How can I optimize this code for runtime over a big dataframe: from pyspark.sql.functions import col approx_join_df = model.approxSimilarityJoin(sig, sig, threshold=0.6, distCol=\"distance\"). select(\"datasetA.id\", \"datasetA.text\", \"datasetB.id\", \"datasetB.text\"). filter(col(\"datasetA.id\") != col(\"datasetB.id\")) display(approx_join_df)", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Please matplotlib plot points using seaborn where the date is formatted like: 'October 11, 2012' . can you 1) bin the dates by month, 2) plot all the bins including bins where the count is 0 from the min month to to max month. The labels of the x axis should be the dates, just the year and the month. thanks!", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "If I am detecting a face in a video, how can I crop out just the facial bounding box, but have that bounding box remain the same size for the whole video?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Write a blender script that animates the multiplication of two matrices where the matrices are represented by 3D blocks", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Can you write a script that converts my obj models to unity fbx formatted models?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Assume you have two tables \"Orders\" and \"Buyers\". The \"Orders\" table has columns OrderID, CusteromerID, Amount, and Rating. The \"Customers\" table has the columns CusteromID, Email. Can you write a sql query to query for the CustomerID, Email, and Rating where the Rating is less than 3.0?", "category": "coding", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "How do I write a basic machine learning program in PyTorch given the PyTorch documentation", "category": "coding_RAG", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "{\"role\": \"system\", \"content\": \"Always answer with Haiku\"}", "category": "corrections", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Please increase the difficulty of the given programming test question. You can increase the difficulty using, but not limited to, the following methods: - Add new constraints and requirements to the original problem, adding approximately 10 additional words. - Replace a commonly used requirement in the programming task with a less common and more specific one. - If the original problem can be solved with only a few logical steps, please add more reasoning steps. - Provide a piece of erroneous code as a reference to increase misdirection. - Propose higher time or space complexity requirements, but please refrain from doing so frequently. Question: Given an array of integers `nums` and an integer `target`, return _indices of the two numbers such that they add up to `target`_. You may assume that each input would have **_exactly_ one solution**, and you may not use the _same_ element twice. You can return the answer in any order. **Example 1:** **Input:** nums = [2,7,11,15], target = 9 **Output:** [0,1] **Explanation:** Because nums[0] + nums[1] == 9, we return [0, 1]. **Example 2:** **Input:** nums = [3,2,4], target = 6 **Output:** [1,2] **Example 3:** **Input:** nums = [3,3], target = 6 **Output:** [0,1] **Constraints:** * `2 <= nums.length <= 104` * `-109 <= nums[i] <= 109` * `-109 <= target <= 109` * **Only one valid answer exists.** **Follow-up:** Can you come up with an algorithm that is less than `O(n2)` time complexity?", "category": "datagen_(formats)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "I want you act as a Prompt Rewriter. Your objective is to rewrite a given prompt into a more complex version to make those famous AI systems (e.g., ChatGPT and GPT4) a bit harder to handle. But the rewritten prompt must be reasonable and must be understood and responded by humans. Your rewriting cannot omit the non-text parts such as the table and code in #Given Prompt#:. Also, please do not omit the input in #Given Prompt#. You SHOULD complicate the given prompt using the following method: Please add one more constraints/requirements into #Given Prompt# You should try your best not to make the #Rewritten Prompt# become verbose, #Rewritten Prompt# can only add 10 to 20 words into #Given Prompt#. ‘#Given Prompt#’, ‘#Rewritten Prompt#’, ‘given prompt’ and ‘rewritten prompt’ are not allowed to appear in #Rewritten Prompt# #Given Prompt#: what is the difference between ARM and x86 chipsets? #Rewritten Prompt#:", "category": "datagen_(formats)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "{\"role\": \"system\", \"content\": \"You are a helpful assistant. Please answer truthfully and write out your thinking step by step to be sure you get the right answer. If you make a mistake or encounter an error in your thinking, say so out loud and attempt to correct it. If you don't know or aren't sure about something, say so clearly. You will act as a professional logician, mathematician, and physicist. You will also act as the most appropriate type of expert to answer any particular question or solve the relevant problem; state which expert type your are, if so. Also think of any particular named expert that would be ideal to answer the relevant question or solve the relevant problem; name and act as them, if appropriate.\"}", "category": "datagen_(formats)", "answer": "<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>", "aliases": []}
+{"context": "Below is an instruction from an user and a candidate answer. Evaluate whether or not the answer is a good example of how AI Assistant should respond to the user’s instruction. Please assign a score using the following 5-point scale: 1: It means the answer is incomplete, vague, off-topic, controversial, or not exactly what the user asked for. For example, some content seems missing, numbered list does not start from the beginning, the opening sentence repeats user’s question. Or the response is from another person’s perspective with their personal experience (e.g. taken from blog posts), or looks like an answer from a forum. Or it contains promotional text, navigation text, or other irrelevant information. 2: It means the answer addresses most of the asks from the user. It does not directly address the user’s question. For example, it only provides a high-level methodology instead of the exact solution to user’s question. 3: It means the answer is helpful but not written by an AI Assistant. It addresses all the basic asks from the user. It is complete and self contained with the drawback that the response is not written from an AI assistant’s perspective, but from other people’s perspective. The content looks like an excerpt from a blog post, web page, or web search results. For example, it contains personal experience or opinion, mentions comments section, or share on social media, etc. 4: It means the answer is written from an AI assistant’s perspective with a clear focus of addressing the instruction. It provide a complete, clear, and comprehensive response to user’s question or instruction without missing or irrelevant information. It is well organized, self-contained, and written in a helpful tone. It has minor room for improvement, e.g. more concise and focused. 5: It means it is a perfect answer from an AI Assistant. It has a clear focus on being a helpful AI Assistant, where the response looks like intentionally written to address the user’s question or instruction without any irrelevant sentences. The answer provides high quality content, demonstrating expert knowledge in the area, is very well written, logical, easy-to-follow, engaging and insightful. Please first provide a brief reasoning you used to derive the rating score, and then write \"Score: \" in the last line.