diff --git a/klue_re.ipynb b/klue_re.ipynb
new file mode 100644
index 0000000..389e40e
--- /dev/null
+++ b/klue_re.ipynb
@@ -0,0 +1,1597 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "impressive-sociology",
+ "metadata": {},
+ "source": [
+ "# BERT를 활용한 관계추출(Relation Extraction, RE)\n",
+ "본 Workspace에서는 klue/bert-base모델을 이용하여 KLUE 내의 8개 Task 중 관계추출(Relation Extraction) Task에 대해서 다룹니다.\n",
+ "\n",
+ "먼저 관계추출 Task란 문장에 있는 두 개체(entity)간의 관계가 무엇인지 분류하는 것입니다.\n",
+ "\n",
+ "따라서 문장과 문장에 있는 두 개체들이 입력으로 주어지면 두 개체들간의 관계가 출력으로 나옵니다.\n",
+ "\n",
+ "![Imgur](https://i.imgur.com/xeTQRVC.png)\n",
+ "\n",
+ "사진과 같이 개체는 subject entity와 object entity가 있는데, subject entity가 사람(person)이면 \":\" 앞에는 per, 기관(organization)이면 org이 됩니다.\n",
+ "\n",
+ "그리고 \":\" 뒤에 있는 것은 object entity가 subject entity와 무슨 관계인지를 나타냅니다.\n",
+ "\n",
+ "위의 사진과 같이 subject entity가 사람(person)이고, object entity가 subject entity의 출생지(origin)이므로 관계는 per:origin이 됩니다.\n",
+ "\n",
+ "KLUE 관계추출 Task에 있는 관계는 총 30개이고, 관계일부를 캡처한 사진은 다음과 같습니다.\n",
+ "\n",
+ "![Imgur](https://i.imgur.com/TnfSUPo.png)\n",
+ "\n",
+ "만약 전체 관계 목록을 보고 싶으시면 링크를 확인해보세요!"
+ ]
+ },
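+ {
+ "cell_type": "markdown",
+ "id": "relation-list-sketch",
+ "metadata": {},
+ "source": [
+ "The full list of the 30 relation names can also be read directly from the dataset that is loaded later in this notebook. The following is only a sketch; it assumes the `label` column of the Hugging Face `klue`/`re` dataset is a `ClassLabel` feature.\n",
+ "\n",
+ "```python\n",
+ "from datasets import load_dataset\n",
+ "\n",
+ "# Sketch: print every relation label id together with its name.\n",
+ "dataset = load_dataset('klue', 're')\n",
+ "for idx, name in enumerate(dataset['train'].features['label'].names):\n",
+ "    print(idx, name)\n",
+ "```"
+ ]
+ },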
+ {
+ "cell_type": "markdown",
+ "id": "august-garden",
+ "metadata": {},
+ "source": [
+ "# 필요한 라이브러리를 설치합니다.\n",
+ "datasets은 KLUE 데이터셋을 가져오기 위해, sklearn은 학습한 모델을 평가하기 위해 필요합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "collective-number",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Requirement already satisfied: datasets in /opt/conda/lib/python3.7/site-packages (1.8.0)\n",
+ "Requirement already satisfied: huggingface-hub<0.1.0 in /opt/conda/lib/python3.7/site-packages (from datasets) (0.0.13)\n",
+ "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.7/site-packages (from datasets) (2.25.1)\n",
+ "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from datasets) (3.7.3)\n",
+ "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from datasets) (1.2.5)\n",
+ "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.7/site-packages (from datasets) (0.70.12.2)\n",
+ "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.7/site-packages (from datasets) (1.18.5)\n",
+ "Requirement already satisfied: pyarrow<4.0.0,>=1.0.0 in /opt/conda/lib/python3.7/site-packages (from datasets) (3.0.0)\n",
+ "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /opt/conda/lib/python3.7/site-packages (from datasets) (4.49.0)\n",
+ "Requirement already satisfied: xxhash in /opt/conda/lib/python3.7/site-packages (from datasets) (2.0.2)\n",
+ "Requirement already satisfied: fsspec in /opt/conda/lib/python3.7/site-packages (from datasets) (2021.6.1)\n",
+ "Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from datasets) (20.9)\n",
+ "Requirement already satisfied: dill in /opt/conda/lib/python3.7/site-packages (from datasets) (0.3.4)\n",
+ "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from huggingface-hub<0.1.0->datasets) (3.7.4.3)\n",
+ "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from huggingface-hub<0.1.0->datasets) (3.0.12)\n",
+ "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->datasets) (2.4.7)\n",
+ "Requirement already satisfied: chardet<5,>=3.0.2 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (4.0.0)\n",
+ "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (2020.12.5)\n",
+ "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (2.10)\n",
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (1.26.4)\n",
+ "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->datasets) (3.4.1)\n",
+ "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets) (2.8.1)\n",
+ "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets) (2021.1)\n",
+ "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n",
+ "Requirement already satisfied: sklearn in /opt/conda/lib/python3.7/site-packages (0.0)\n",
+ "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.7/site-packages (from sklearn) (0.24.2)\n",
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (2.1.0)\n",
+ "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.0.1)\n",
+ "Requirement already satisfied: scipy>=0.19.1 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.4.1)\n",
+ "Requirement already satisfied: numpy>=1.13.3 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.18.5)\n"
+ ]
+ }
+ ],
+ "source": [
+ "!pip install datasets\n",
+ "!pip install sklearn"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "organic-valley",
+ "metadata": {},
+ "source": [
+ "# 필요한 라이브러리를 import 합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "looking-relevance",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import sklearn.metrics\n",
+ "\n",
+ "from tqdm import tqdm\n",
+ "from datasets import load_dataset\n",
+ "from datasets.arrow_dataset import Dataset\n",
+ "from transformers import AutoTokenizer, AutoModel, AdamW\n",
+ "from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification\n",
+ "from torch.utils.data import DataLoader"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "smaller-truth",
+ "metadata": {},
+ "source": [
+ "# GPU 사용을 위해 device를 설정합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "musical-concord",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "valued-delta",
+ "metadata": {},
+ "source": [
+ "# KLUE RE 데이터셋을 가져옵니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "induced-breast",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Reusing dataset klue (/workspace/.cache/huggingface/datasets/klue/re/1.0.0/55ff8f92b7a4b9842be6514ce0b4b5295b46d5e493f8bb5760da4be717018f90)\n"
+ ]
+ }
+ ],
+ "source": [
+ "dataset = load_dataset('klue', 're')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "higher-compact",
+ "metadata": {},
+ "source": [
+ "# 데이터셋은 train과 validation 데이터로 구성되어 있습니다. \n",
+ "train 데이터들은 모델을 학습할 때 사용될 예정이고, validation 데이터들은 학습이 아닌 모델의 성능을 평가할 때 사용됩니다.\n",
+ "\n",
+ "train 데이터와 validation 데이터의 구성은 동일하므로 train 데이터의 구성만 살펴보도록 하겠습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "respective-labor",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "DatasetDict({\n",
+ " train: Dataset({\n",
+ " features: ['guid', 'label', 'object_entity', 'sentence', 'source', 'subject_entity'],\n",
+ " num_rows: 32470\n",
+ " })\n",
+ " validation: Dataset({\n",
+ " features: ['guid', 'label', 'object_entity', 'sentence', 'source', 'subject_entity'],\n",
+ " num_rows: 7765\n",
+ " })\n",
+ "})"
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "increasing-clarity",
+ "metadata": {},
+ "source": [
+ "# 데이터 구성\n",
+ "각 데이터들은 다음과 같이 문장과 관계 추출에 사용될 2개의 개체(object_entity, subject_entity)와 두 개체간의 관계를 라벨로 가지고 있습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "assisted-brooklyn",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'guid': 'klue-re-v1_train_00000',\n",
+ " 'label': 0,\n",
+ " 'object_entity': {'word': '조지 해리슨',\n",
+ " 'start_idx': 13,\n",
+ " 'end_idx': 18,\n",
+ " 'type': 'PER'},\n",
+ " 'sentence': '〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.',\n",
+ " 'source': 'wikipedia',\n",
+ " 'subject_entity': {'word': '비틀즈',\n",
+ " 'start_idx': 24,\n",
+ " 'end_idx': 26,\n",
+ " 'type': 'ORG'}}"
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset['train'][0]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "established-platform",
+ "metadata": {},
+ "source": [
+ "# train데이터에서 각 label의 수를 살펴보겠습니다.\n",
+ "출력 결과를 보시면 30개의 라벨에 대해서 불균형이 심한 것을 확인할 수 있습니다.\n",
+ "\n",
+ "따라서 모델의 성능을 측정할 평가지표로 단순하게 Accuracy를 사용한다면 정확한 평가가 이루어질 수 없기 때문에 KLUE 논문에 따르면 평가지표로 F1 score와 AUPRC를 이용하였습니다.\n",
+ "\n",
+ "\n",
+ "또한 관계 없음(0번 라벨) 데이터가 많은 비중을 차지하고 있는데, 모델이 \"관계 없음\"을 예측하는 데에 많은 초점이 맞춰지지 않도록\n",
+ "\n",
+ "관계 없음에 해당하는 데이터를 제외하고 관계가 있는 데이터들에 대해서만 F1 score을 계산합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "english-royalty",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{0: 9534,\n",
+ " 1: 66,\n",
+ " 2: 450,\n",
+ " 3: 1195,\n",
+ " 4: 1320,\n",
+ " 5: 1866,\n",
+ " 6: 420,\n",
+ " 7: 98,\n",
+ " 8: 380,\n",
+ " 9: 155,\n",
+ " 10: 4284,\n",
+ " 11: 48,\n",
+ " 12: 1130,\n",
+ " 13: 418,\n",
+ " 14: 166,\n",
+ " 15: 40,\n",
+ " 16: 193,\n",
+ " 17: 1234,\n",
+ " 18: 3573,\n",
+ " 19: 82,\n",
+ " 20: 1001,\n",
+ " 21: 520,\n",
+ " 22: 304,\n",
+ " 23: 136,\n",
+ " 24: 795,\n",
+ " 25: 190,\n",
+ " 26: 534,\n",
+ " 27: 139,\n",
+ " 28: 96,\n",
+ " 29: 2103}"
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "label_count = {}\n",
+ "\n",
+ "for data in dataset['train']:\n",
+ " label = data['label']\n",
+ " if label not in label_count:\n",
+ " label_count[label] = 1\n",
+ " else:\n",
+ " label_count[label] += 1\n",
+ "\n",
+ "label_count = dict(sorted(label_count.items(), key=lambda x: x[0]))\n",
+ "label_count"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "accredited-weapon",
+ "metadata": {},
+ "source": [
+ "# 개체 양 끝에 special token을 추가합니다.\n",
+ "KLUE 논문에 따르면 object 개체의 양 끝에는 \\, \\을, subject 개체의 양 끝에는 \\, \\ 토큰을 추가하여 개체의 위치를 표시한 후에 모델의 입력으로 주어집니다.\n",
+ "\n",
+ "따라서 데이터에 있는 entity index를 이용해서 해당 토큰을 추가해줍니다.\n",
+ "\n",
+ "토큰 추가의 예시 사진은 다음과 같습니다. 사진에서 빨간색으로 표시된 토큰들이 추가가 되는 토큰들입니다.\n",
+ "\n",
+ "![Imgur](https://i.imgur.com/gWNeyLv.png)\n",
+ "\n",
+ "그리고 학습에 필요한 데이터는 문장과 라벨 정보이므로 해당 부분만 가져오도록 합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "surface-squad",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def add_entity_tokens(sentence, object_entity, subject_entity):\n",
+ " obj_start_idx, obj_end_idx = object_entity['start_idx'], object_entity['end_idx']\n",
+ " subj_start_idx, subj_end_idx = subject_entity['start_idx'], subject_entity['end_idx']\n",
+ " \n",
+ " if obj_start_idx < subj_start_idx:\n",
+ " new_sentence = sentence[:obj_start_idx] + '' + sentence[obj_start_idx:obj_end_idx+1] + '' + \\\n",
+ " sentence[obj_end_idx+1:subj_start_idx] + '' + sentence[subj_start_idx:subj_end_idx+1] + \\\n",
+ " '' + sentence[subj_end_idx+1:]\n",
+ " else:\n",
+ " new_sentence = sentence[:subj_start_idx] + '' + sentence[subj_start_idx:subj_end_idx+1] + '' + \\\n",
+ " sentence[subj_end_idx+1:obj_start_idx] + '' + sentence[obj_start_idx:obj_end_idx+1] + \\\n",
+ " '' + sentence[obj_end_idx+1:]\n",
+ " \n",
+ " return new_sentence\n",
+ "\n",
+ "\n",
+ "def read_klue_re(dataset):\n",
+ " sentences = []\n",
+ " labels = []\n",
+ " \n",
+ " if isinstance(dataset, Dataset):\n",
+ " for data in dataset:\n",
+ " sentence = add_entity_tokens(data['sentence'], data['object_entity'], data['subject_entity'])\n",
+ " sentences.append(sentence)\n",
+ " labels.append(data['label'])\n",
+ " \n",
+ " return sentences, labels"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "efficient-absorption",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# train, validation데이터셋에서 sentence와 label만 저장.\n",
+ "train_sentences, train_labels = read_klue_re(dataset['train'])\n",
+ "val_sentences, val_labels = read_klue_re(dataset['validation'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "exclusive-travel",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다. \n",
+ "\n",
+ "호남이 기반인 바른미래당·대안신당·민주평화당이 우여곡절 끝에 합당해 민생당(가칭)으로 재탄생한다. \n",
+ "\n",
+ "K리그2에서 성적 1위를 달리고 있는 광주FC는 지난 26일 한국프로축구연맹으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다. \n",
+ "\n",
+ "균일가 생활용품점 (주)아성다이소(대표 박정부)는 코로나19 바이러스로 어려움을 겪고 있는 대구광역시에 행복박스를 전달했다고 10일 밝혔다. \n",
+ "\n",
+ "1967년 프로 야구 드래프트 1순위로 요미우리 자이언츠에게 입단하면서 등번호는 8번으로 배정되었다. \n",
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "# 개체 토큰이 정상적으로 잘 추가됐는지 확인하기 위해 train 문장 5개만 출력.\n",
+ "for i, sentence in enumerate(train_sentences[:5]):\n",
+ " print(sentence, '\\n')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "pointed-earthquake",
+ "metadata": {},
+ "source": [
+ "# klue/bert-base 모델을 사용할 예정이므로 모델에 맞는 tokenizer를 가져옵니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "southwest-retrieval",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "model_name = 'klue/bert-base'"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "color-footwear",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(model_name)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "recreational-given",
+ "metadata": {},
+ "source": [
+ "# tokenizer를 이용한 토큰화 결과가 어떻게 나오는지 살펴보도록 하겠습니다.\n",
+ "예시 문장으로는 첫 번째 train 데이터의 문장을 이용하도록 하겠습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "exact-senegal",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ex_sentence = dataset['train'][0]['sentence']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "fatty-finance",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.'"
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ex_sentence"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "electoral-witch",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ex_encoding = tokenizer(ex_sentence,\n",
+ " max_length=128,\n",
+ " padding='max_length',\n",
+ " truncation=True)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "metropolitan-bathroom",
+ "metadata": {},
+ "source": [
+ "토큰화 결과로 Bert모델의 입력으로 필요한 input_ids, token_type_ids, attention_mask가 나오는 것을 확인할 수 있습니다.\n",
+ "\n",
+ "3가지 값에 대한 설명은 링크에서 확인할 수 있습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "pretty-luther",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'input_ids': [2, 168, 30985, 14451, 7088, 4586, 169, 793, 8373, 14113, 2234, 2052, 1363, 2088, 29830, 2116, 14879, 2440, 6711, 170, 21406, 26713, 2076, 25145, 5749, 171, 1421, 818, 2073, 4388, 2062, 18, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}"
+ ]
+ },
+ "execution_count": 16,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ex_encoding"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "wrong-corpus",
+ "metadata": {},
+ "source": [
+ "토큰화 된 문장을 다시 디코딩 해봄으로써 원본 문장을 얻을 수 있는지 확인해봅니다.\n",
+ "\n",
+ "디코딩 결과를 살펴보면 원본 문장 뒤에 [PAD] 토큰을 통해 입력 토큰의 개수가 max_length가 되도록 맞춥니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "cellular-arrow",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'[CLS] 〈 Something 〉 는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《 Abbey Road 》 에 담은 노래다. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'"
+ ]
+ },
+ "execution_count": 17,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "tokenizer.decode(ex_encoding['input_ids'])"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "variable-channels",
+ "metadata": {},
+ "source": [
+ "# Special token 추가\n",
+ "위에서 개체에 맞게 4개의 토큰(\\, \\, \\, \\)을 문장에 추가해주었는데, 해당 토큰들을 tokenizer에 special token이라고 알려주지 않으면 추가한 토큰들은 일반 문자로 인식되어서 토큰화될 수 있습니다. \n",
+ "\n",
+ "따라서 토큰화가 되지 않도록 추가한 4개의 토큰을 tokenizer에 special token으로 추가 해 줍니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "id": "matched-investor",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "entity_special_tokens = {'additional_special_tokens': ['', '', '', '']}\n",
+ "num_additional_special_tokens = tokenizer.add_special_tokens(entity_special_tokens)"
+ ]
+ },
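+ {
+ "cell_type": "markdown",
+ "id": "marker-token-check",
+ "metadata": {},
+ "source": [
+ "As a quick check, the tokenizer should now keep the entity markers as single tokens and map them to the four newly added ids. This is only a sketch and uses the marker spellings assumed above.\n",
+ "\n",
+ "```python\n",
+ "# Sketch: the entity markers should no longer be split into subword pieces.\n",
+ "print(tokenizer.tokenize('<subj>비틀즈</subj>가 부른 노래'))\n",
+ "print(tokenizer.convert_tokens_to_ids(['<obj>', '</obj>', '<subj>', '</subj>']))\n",
+ "```"
+ ]
+ },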
+ {
+ "cell_type": "markdown",
+ "id": "secondary-paris",
+ "metadata": {},
+ "source": [
+ "데이터로더 및 학습에 필요한 값들을 설정합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "id": "developmental-intersection",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# For Dataloader\n",
+ "batch_size = 8\n",
+ "\n",
+ "# For model\n",
+ "num_labels = 30\n",
+ "\n",
+ "# For train\n",
+ "learning_rate = 1e-5\n",
+ "weight_decay = 0.0\n",
+ "epochs = 3"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "literary-engagement",
+ "metadata": {},
+ "source": [
+ "# 학습에 이용할 데이터셋과 데이터로더를 만들어 줍니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "id": "trying-consent",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class KlueReDataset(torch.utils.data.Dataset):\n",
+ " def __init__(self, tokenizer, sentences, labels, max_length=128):\n",
+ " self.encodings = tokenizer(sentences,\n",
+ " max_length=max_length,\n",
+ " padding='max_length',\n",
+ " truncation=True)\n",
+ " self.labels = labels\n",
+ " \n",
+ " def __getitem__(self, idx):\n",
+ " item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}\n",
+ " item['labels'] = self.labels[idx]\n",
+ " \n",
+ " return item\n",
+ " \n",
+ " def __len__(self):\n",
+ " return len(self.labels)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "id": "thick-effectiveness",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "train_dataset = KlueReDataset(tokenizer, train_sentences, train_labels)\n",
+ "val_dataset = KlueReDataset(tokenizer, val_sentences, val_labels)\n",
+ "\n",
+ "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
+ "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)"
+ ]
+ },
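+ {
+ "cell_type": "markdown",
+ "id": "batch-shape-check",
+ "metadata": {},
+ "source": [
+ "As a sanity check, one batch from the train dataloader could be inspected as in the sketch below (not executed here); with the settings above, each input tensor should have shape (8, 128).\n",
+ "\n",
+ "```python\n",
+ "# Sketch: look at one collated batch from the dataloader.\n",
+ "batch = next(iter(train_loader))\n",
+ "print(batch['input_ids'].shape)       # expected: torch.Size([8, 128])\n",
+ "print(batch['attention_mask'].shape)  # expected: torch.Size([8, 128])\n",
+ "print(batch['labels'][:4])            # the first four gold label ids in the batch\n",
+ "```"
+ ]
+ },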
+ {
+ "cell_type": "markdown",
+ "id": "electronic-payday",
+ "metadata": {},
+ "source": [
+ "# klue/bert-base 모델을 로드합니다.\n",
+ "\n",
+ "저희가 다루는 관계추출 Task는 30개의 관계(클래스)를 분류하는 것이라고 할 수 있습니다.\n",
+ "\n",
+ "이를 위해 [CLS] 토큰의 벡터를 출력의 차원이 30인 1개의 Linear Layer에 통과시켜서 30개의 클래스로 분류하는 모델을 만들겠습니다.\n",
+ "\n",
+ "이를 간단히 사진으로 나타내면 다음과 같습니다.\n",
+ "\n",
+ "![Imgur](https://i.imgur.com/qaUObkV.png)\n",
+ "\n",
+ "모델을 로드할 때 Warning이 발생하는데 \"모델에서 가중치가 초기화되지 않은 부분이 있으니 예측을 하기 위해서는 모델을 학습시킨 후 사용해야된다.\"라는 내용을 담고 있습니다.\n",
+ "\n",
+ "저희는 이후에 모델을 fine-tuning할 것이기 때문에 해당 메세지를 신경쓰지 않아도 됩니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "id": "rocky-justice",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n",
+ "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
+ "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
+ ]
+ }
+ ],
+ "source": [
+ "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "electronic-asian",
+ "metadata": {},
+ "source": [
+ "모델 구성의 마지막에 있는 classifier 부분을 보시면 출력 차원이 30으로 설정되어 있는 것을 확인할 수 있습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "id": "attached-particle",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BertForSequenceClassification(\n",
+ " (bert): BertModel(\n",
+ " (embeddings): BertEmbeddings(\n",
+ " (word_embeddings): Embedding(32000, 768, padding_idx=0)\n",
+ " (position_embeddings): Embedding(512, 768)\n",
+ " (token_type_embeddings): Embedding(2, 768)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (encoder): BertEncoder(\n",
+ " (layer): ModuleList(\n",
+ " (0): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (1): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (2): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (3): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (4): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (5): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (6): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (7): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (8): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (9): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (10): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (11): BertLayer(\n",
+ " (attention): BertAttention(\n",
+ " (self): BertSelfAttention(\n",
+ " (query): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (key): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (value): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " (output): BertSelfOutput(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " (intermediate): BertIntermediate(\n",
+ " (dense): Linear(in_features=768, out_features=3072, bias=True)\n",
+ " )\n",
+ " (output): BertOutput(\n",
+ " (dense): Linear(in_features=3072, out_features=768, bias=True)\n",
+ " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " )\n",
+ " (pooler): BertPooler(\n",
+ " (dense): Linear(in_features=768, out_features=768, bias=True)\n",
+ " (activation): Tanh()\n",
+ " )\n",
+ " )\n",
+ " (dropout): Dropout(p=0.1, inplace=False)\n",
+ " (classifier): Linear(in_features=768, out_features=30, bias=True)\n",
+ ")"
+ ]
+ },
+ "execution_count": 23,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "alternate-import",
+ "metadata": {},
+ "source": [
+ "# Bert Embedding Layer을 resize합니다.\n",
+ "\n",
+ "Bert에는 토큰들의 id에 따른 임베딩 값을 반환하는 Embedding Layer가 존재합니다.\n",
+ "\n",
+ "하지만 현재 Embedding Layer에는 위에서 추가한 4개의 토큰에 대한 정보가 반영되지 않았기 때문에 추가한 토큰들이 입력으로 주어질 경우 index error가 발생합니다. \n",
+ "\n",
+ "따라서 Bert의 Embedding resize해줍니다.\n",
+ "\n",
+ "Resize를 하게되면 Embedding Layer의 input 차원이 32000에서 32004로 4만큼 증가합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "id": "stunning-worthy",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "Embedding(32004, 768)"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "model.resize_token_embeddings(len(tokenizer))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "promising-yesterday",
+ "metadata": {},
+ "source": [
+ "학습 도중 Loss, Accuracy 계산 및 저장을 간단하게 하기 위해 AverageMeter를 클래스를 이용합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "id": "hourly-bridges",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "class AverageMeter():\n",
+ " def __init__(self):\n",
+ " self.val = 0\n",
+ " self.avg = 0\n",
+ " self.sum = 0\n",
+ " self.count = 0\n",
+ "\n",
+ " def update(self, val, n=1):\n",
+ " self.val = val\n",
+ " self.sum += val * n\n",
+ " self.count += n\n",
+ " self.avg = self.sum / self.count"
+ ]
+ },
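+ {
+ "cell_type": "markdown",
+ "id": "averagemeter-usage",
+ "metadata": {},
+ "source": [
+ "A minimal usage sketch of AverageMeter with made-up numbers:\n",
+ "\n",
+ "```python\n",
+ "meter = AverageMeter()\n",
+ "meter.update(0.9, n=8)  # e.g. a batch loss of 0.9 averaged over 8 examples\n",
+ "meter.update(0.5, n=8)\n",
+ "print(meter.avg)        # 0.7, the example-weighted running average\n",
+ "```"
+ ]
+ },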
+ {
+ "cell_type": "markdown",
+ "id": "laughing-electricity",
+ "metadata": {},
+ "source": [
+ "# Model fine-tuning\n",
+ "klue/bert-base 모델을 fine-tuning합니다.\n",
+ "\n",
+ "학습동안 학습이 잘 진행되고 있는지 확인하기 위해 Loss와 Accuracy를 출력합니다.\n",
+ "\n",
+ "Tesla T4 기준으로 1 epoch 당 약 14분 정도가 소요됩니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "id": "silver-algeria",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ " 0%| | 0/4059 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "< Epoch 1 / 3 >\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 4059/4059 [14:11<00:00, 4.77it/s]\n",
+ "100%|██████████| 971/971 [01:07<00:00, 14.30it/s]\n",
+ " 0%| | 0/4059 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "train_loss: 1.0293, train_acc: 69.64%, val_loss: 0.8416, val_acc: 70.05%\n",
+ "====================================================================================================\n",
+ "< Epoch 2 / 3 >\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 4059/4059 [14:10<00:00, 4.77it/s]\n",
+ "100%|██████████| 971/971 [01:07<00:00, 14.30it/s]\n",
+ " 0%| | 0/4059 [00:00, ?it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "train_loss: 0.5086, train_acc: 82.65%, val_loss: 0.7744, val_acc: 72.38%\n",
+ "====================================================================================================\n",
+ "< Epoch 3 / 3 >\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 4059/4059 [14:10<00:00, 4.77it/s]\n",
+ "100%|██████████| 971/971 [01:07<00:00, 14.32it/s]"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "train_loss: 0.3608, train_acc: 87.74%, val_loss: 0.7518, val_acc: 75.54%\n",
+ "====================================================================================================\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "\n"
+ ]
+ }
+ ],
+ "source": [
+ "def train_epoch(data_loader, model, criterion, optimizer, train=True):\n",
+ " loss_save = AverageMeter()\n",
+ " acc_save = AverageMeter()\n",
+ " \n",
+ " loop = tqdm(enumerate(data_loader), total=len(data_loader))\n",
+ " for _, batch in loop:\n",
+ " inputs = {\n",
+ " 'input_ids': batch['input_ids'].to(device),\n",
+ " 'token_type_ids': batch['token_type_ids'].to(device),\n",
+ " 'attention_mask': batch['attention_mask'].to(device),\n",
+ " }\n",
+ " labels = batch['labels'].to(device)\n",
+ " \n",
+ " optimizer.zero_grad()\n",
+ " outputs = model(**inputs)\n",
+ " logits = outputs['logits']\n",
+ " \n",
+ " loss = criterion(logits, labels)\n",
+ " \n",
+ " if train:\n",
+ " loss.backward()\n",
+ " optimizer.step()\n",
+ " \n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " acc = ((preds == labels).sum().item() / labels.shape[0])\n",
+ " \n",
+ " loss_save.update(loss, labels.shape[0])\n",
+ " acc_save.update(acc, labels.shape[0])\n",
+ " \n",
+ " results = {\n",
+ " 'loss': loss_save.avg,\n",
+ " 'acc': acc_save.avg * 100,\n",
+ " }\n",
+ " \n",
+ " return results\n",
+ " \n",
+ " \n",
+ "# loss function, optimizer 설정\n",
+ "criterion = nn.CrossEntropyLoss()\n",
+ "optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n",
+ "\n",
+ "for epoch in range(epochs):\n",
+ " print(f'< Epoch {epoch+1} / {epochs} >')\n",
+ " \n",
+ " # Train\n",
+ " model.train()\n",
+ " \n",
+ " train_results = train_epoch(train_loader, model, criterion, optimizer)\n",
+ " train_loss, train_acc = train_results['loss'], train_results['acc']\n",
+ " \n",
+ " # Validation\n",
+ " with torch.no_grad():\n",
+ " model.eval()\n",
+ " \n",
+ " val_results = train_epoch(val_loader, model, criterion, optimizer, False)\n",
+ " val_loss, val_acc = val_results['loss'], val_results['acc']\n",
+ " \n",
+ " \n",
+ " print(f'train_loss: {train_loss:.4f}, train_acc: {train_acc:.2f}%, val_loss: {val_loss:.4f}, val_acc: {val_acc:.2f}%')\n",
+ " print('=' * 100)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ongoing-tooth",
+ "metadata": {},
+ "source": [
+ "# 학습된 모델을 저장합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "id": "consistent-cassette",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "tokenizer.save_pretrained('./klue-bert-base-re')\n",
+ "model.save_pretrained('./klue-bert-base-re')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "nervous-positive",
+ "metadata": {},
+ "source": [
+ "# Validation 결과 확인\n",
+ "먼저 관계 없음에 해당하는 문장에 대한 예측 결과를 살펴보겠습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "id": "small-planet",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "\"20대 남성 A(26)씨가 아버지 치료비를 위해 B(30)씨가 모아둔 돈을 훔쳐 인터넷 방송 BJ에게 '별풍선'으로 쏜 사실이 알려졌다.\""
+ ]
+ },
+ "execution_count": 28,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "val_sentence = val_sentences[0]\n",
+ "\n",
+ "val_sentence"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "id": "pediatric-wednesday",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "val_encoding = tokenizer(val_sentence,\n",
+ " max_length=128,\n",
+ " padding='max_length',\n",
+ " truncation=True,\n",
+ " return_tensors='pt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "id": "assigned-slovak",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "val_input = {\n",
+ " 'input_ids': val_encoding['input_ids'].to(device),\n",
+ " 'token_type_ids': val_encoding['token_type_ids'].to(device),\n",
+ " 'attention_mask': val_encoding['attention_mask'].to(device),\n",
+ "}\n",
+ "\n",
+ "model.eval()\n",
+ "output = model(**val_input)\n",
+ "label = torch.argmax(output['logits'], dim=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "prime-device",
+ "metadata": {},
+ "source": [
+ "0번 라벨은 \"관계 없음\"이므로 해당 문장에 대해서 예측이 정확하게 됐음을 알 수 있습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 31,
+ "id": "indoor-stuff",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([0], device='cuda:0')"
+ ]
+ },
+ "execution_count": 31,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "label"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "leading-rebecca",
+ "metadata": {},
+ "source": [
+ "다음으로 관계가 있는 문장에 대한 예측 결과를 살펴보겠습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "id": "architectural-engagement",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'서울교통공사는 서울 노원구 석계역 무빙워크에 고의로 침을 바른 남성이 신종 코로나바이러스 감염증(코로나19) 검사에서 음성 판정이 나왔다고 21일 밝혔다.'"
+ ]
+ },
+ "execution_count": 32,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "val_sentence = val_sentences[13]\n",
+ "\n",
+ "val_sentence"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "id": "enormous-shift",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "val_encoding = tokenizer(val_sentence,\n",
+ " max_length=128,\n",
+ " padding='max_length',\n",
+ " truncation=True,\n",
+ " return_tensors='pt')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "id": "aware-savage",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "val_input = {\n",
+ " 'input_ids': val_encoding['input_ids'].to(device),\n",
+ " 'token_type_ids': val_encoding['token_type_ids'].to(device),\n",
+ " 'attention_mask': val_encoding['attention_mask'].to(device),\n",
+ "}\n",
+ "\n",
+ "model.eval()\n",
+ "output = model(**val_input)\n",
+ "label = torch.argmax(output['logits'], dim=1)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "found-option",
+ "metadata": {},
+ "source": [
+ "3번 라벨은 \"org:place_of_headquarters\" 이므로 해당 문장에 대해서 예측이 정확하게 됐음을 알 수 있습니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 35,
+ "id": "advanced-creek",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensor([3], device='cuda:0')"
+ ]
+ },
+ "execution_count": 35,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "label"
+ ]
+ },
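+ {
+ "cell_type": "markdown",
+ "id": "label-name-lookup",
+ "metadata": {},
+ "source": [
+ "Rather than looking the id up by hand, a predicted label id can be mapped back to its relation name through the dataset's label feature. This is a sketch and assumes the `label` column is a `ClassLabel` feature.\n",
+ "\n",
+ "```python\n",
+ "# Sketch: map a predicted label id back to its relation name.\n",
+ "label_feature = dataset['train'].features['label']\n",
+ "print(label_feature.int2str(label.item()))  # e.g. 'org:place_of_headquarters' for id 3\n",
+ "```"
+ ]
+ },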
+ {
+ "cell_type": "markdown",
+ "id": "continuous-professor",
+ "metadata": {},
+ "source": [
+ "# 모델 평가\n",
+ "학습된 모델에 대해서 평가를 해보도록 하겠습니다.\n",
+ "\n",
+ "KLUE 논문에 따르면 평가 지표로는 F1 score, AUPRC가 사용되었지만 본 Workspace에서는 F1 score 성능만 측정해보도록 하겠습니다.\n",
+ "\n",
+ "이 때, 0번 라벨(관계 없음)을 제외한 관계가 있는 라벨들에 대해서만 F1 score를 계산합니다."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "id": "surgical-hours",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def calc_f1_score(preds, labels):\n",
+ " \"\"\"\n",
+ " label이 0(관계 없음)이 아닌 예측 값에 대해서만 f1 score 계산.\n",
+ " \"\"\"\n",
+ " preds_relation = []\n",
+ " labels_relation = []\n",
+ " \n",
+ " for pred, label in zip(preds, labels):\n",
+ " if label != 0:\n",
+ " preds_relation.append(pred)\n",
+ " labels_relation.append(label)\n",
+ " \n",
+ " f1_score = sklearn.metrics.f1_score(labels_relation, preds_relation, average='micro', zero_division=1)\n",
+ " \n",
+ " return f1_score * 100"
+ ]
+ },
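+ {
+ "cell_type": "markdown",
+ "id": "auprc-sketch",
+ "metadata": {},
+ "source": [
+ "For reference, the AUPRC metric also reported in the KLUE paper is not computed in this notebook. A rough sketch of how it could be computed is shown below; it is an assumption rather than the reference implementation, and it requires the softmax probabilities for every validation example instead of only the argmax predictions.\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "import sklearn.metrics\n",
+ "\n",
+ "def calc_auprc(probs, labels, num_labels=30):\n",
+ "    # probs: (N, num_labels) softmax probabilities, labels: (N,) gold label ids.\n",
+ "    # Average the area under the precision-recall curve over all labels.\n",
+ "    labels_onehot = np.eye(num_labels)[labels]\n",
+ "    scores = []\n",
+ "    for c in range(num_labels):\n",
+ "        precision, recall, _ = sklearn.metrics.precision_recall_curve(labels_onehot[:, c], probs[:, c])\n",
+ "        scores.append(sklearn.metrics.auc(recall, precision))\n",
+ "    return np.mean(scores) * 100\n",
+ "```"
+ ]
+ },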
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "id": "literary-granny",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 971/971 [01:07<00:00, 14.45it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "with torch.no_grad():\n",
+ " model.eval()\n",
+ " \n",
+ " label_all = []\n",
+ " pred_all = []\n",
+ " for batch in tqdm(val_loader):\n",
+ " inputs = {\n",
+ " 'input_ids': batch['input_ids'].to(device),\n",
+ " 'token_type_ids': batch['token_type_ids'].to(device),\n",
+ " 'attention_mask': batch['attention_mask'].to(device),\n",
+ " }\n",
+ " labels = batch['labels'].to(device)\n",
+ " \n",
+ " outputs = model(**inputs)\n",
+ " logits = outputs['logits']\n",
+ " \n",
+ " preds = torch.argmax(logits, dim=1)\n",
+ " \n",
+ " label_all.extend(labels.detach().cpu().numpy().tolist())\n",
+ " pred_all.extend(preds.detach().cpu().numpy().tolist())\n",
+ " \n",
+ " f1_score = calc_f1_score(label_all, pred_all)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "functional-cotton",
+ "metadata": {},
+ "source": [
+ "약 63.58의 F1 score를 기록했습니다. 하이퍼파라미터, 모델 구조 등을 변경시켜 더 높은 F1 score를 기록해보세요!"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "id": "general-attitude",
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "63.58118361153262"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "f1_score"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bigger-action",
+ "metadata": {},
+ "source": [
+ "학습된 모델은 Huggingface Model Hub에 배포되어 있고, 언제든지 다운받아서 활용이 가능합니다.\n",
+ "\n",
+ "모델 로드 후 Inference 방법은 해당 링크에 있는 설명부분에 작성되어 있습니다."
+ ]
+ },
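+ {
+ "cell_type": "markdown",
+ "id": "reload-sketch",
+ "metadata": {},
+ "source": [
+ "Since the linked description is not reproduced here, the following is a minimal sketch of reloading the saved model (the local directory from above, or the Model Hub copy) for inference.\n",
+ "\n",
+ "```python\n",
+ "from transformers import AutoTokenizer, BertForSequenceClassification\n",
+ "\n",
+ "# Sketch: reload the fine-tuned tokenizer and model saved earlier in this notebook.\n",
+ "tokenizer = AutoTokenizer.from_pretrained('./klue-bert-base-re')\n",
+ "model = BertForSequenceClassification.from_pretrained('./klue-bert-base-re')\n",
+ "model.eval()\n",
+ "```"
+ ]
+ },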
+ {
+ "cell_type": "markdown",
+ "id": "capable-found",
+ "metadata": {},
+ "source": [
+ "# Reference"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "noble-first",
+ "metadata": {},
+ "source": [
+ "### klue-transformers-tutorial\n",
+ "\n",
+ "### Fine-tuning with custom datasets"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.7.9"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}