diff --git a/klue_re.ipynb b/klue_re.ipynb new file mode 100644 index 0000000..389e40e --- /dev/null +++ b/klue_re.ipynb @@ -0,0 +1,1597 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "impressive-sociology", + "metadata": {}, + "source": [ + "# BERT를 활용한 관계추출(Relation Extraction, RE)\n", + "본 Workspace에서는 klue/bert-base모델을 이용하여 KLUE 내의 8개 Task 중 관계추출(Relation Extraction) Task에 대해서 다룹니다.\n", + "\n", + "먼저 관계추출 Task란 문장에 있는 두 개체(entity)간의 관계가 무엇인지 분류하는 것입니다.\n", + "\n", + "따라서 문장과 문장에 있는 두 개체들이 입력으로 주어지면 두 개체들간의 관계가 출력으로 나옵니다.\n", + "\n", + "![Imgur](https://i.imgur.com/xeTQRVC.png)\n", + "\n", + "사진과 같이 개체는 subject entity와 object entity가 있는데, subject entity가 사람(person)이면 \":\" 앞에는 per, 기관(organization)이면 org이 됩니다.\n", + "\n", + "그리고 \":\" 뒤에 있는 것은 object entity가 subject entity와 무슨 관계인지를 나타냅니다.\n", + "\n", + "위의 사진과 같이 subject entity가 사람(person)이고, object entity가 subject entity의 출생지(origin)이므로 관계는 per:origin이 됩니다.\n", + "\n", + "KLUE 관계추출 Task에 있는 관계는 총 30개이고, 관계일부를 캡처한 사진은 다음과 같습니다.\n", + "\n", + "![Imgur](https://i.imgur.com/TnfSUPo.png)\n", + "\n", + "만약 전체 관계 목록을 보고 싶으시면 링크를 확인해보세요!" + ] + }, + { + "cell_type": "markdown", + "id": "august-garden", + "metadata": {}, + "source": [ + "# 필요한 라이브러리를 설치합니다.\n", + "datasets은 KLUE 데이터셋을 가져오기 위해, sklearn은 학습한 모델을 평가하기 위해 필요합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "collective-number", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: datasets in /opt/conda/lib/python3.7/site-packages (1.8.0)\n", + "Requirement already satisfied: huggingface-hub<0.1.0 in /opt/conda/lib/python3.7/site-packages (from datasets) (0.0.13)\n", + "Requirement already satisfied: requests>=2.19.0 in /opt/conda/lib/python3.7/site-packages (from datasets) (2.25.1)\n", + "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from datasets) (3.7.3)\n", + "Requirement already satisfied: pandas in /opt/conda/lib/python3.7/site-packages (from datasets) (1.2.5)\n", + "Requirement already satisfied: multiprocess in /opt/conda/lib/python3.7/site-packages (from datasets) (0.70.12.2)\n", + "Requirement already satisfied: numpy>=1.17 in /opt/conda/lib/python3.7/site-packages (from datasets) (1.18.5)\n", + "Requirement already satisfied: pyarrow<4.0.0,>=1.0.0 in /opt/conda/lib/python3.7/site-packages (from datasets) (3.0.0)\n", + "Requirement already satisfied: tqdm<4.50.0,>=4.27 in /opt/conda/lib/python3.7/site-packages (from datasets) (4.49.0)\n", + "Requirement already satisfied: xxhash in /opt/conda/lib/python3.7/site-packages (from datasets) (2.0.2)\n", + "Requirement already satisfied: fsspec in /opt/conda/lib/python3.7/site-packages (from datasets) (2021.6.1)\n", + "Requirement already satisfied: packaging in /opt/conda/lib/python3.7/site-packages (from datasets) (20.9)\n", + "Requirement already satisfied: dill in /opt/conda/lib/python3.7/site-packages (from datasets) (0.3.4)\n", + "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.7/site-packages (from huggingface-hub<0.1.0->datasets) (3.7.4.3)\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from huggingface-hub<0.1.0->datasets) (3.0.12)\n", + "Requirement already satisfied: pyparsing>=2.0.2 in /opt/conda/lib/python3.7/site-packages (from packaging->datasets) (2.4.7)\n", + "Requirement already satisfied: chardet<5,>=3.0.2 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) 
(4.0.0)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (2020.12.5)\n", + "Requirement already satisfied: idna<3,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (2.10)\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests>=2.19.0->datasets) (1.26.4)\n", + "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->datasets) (3.4.1)\n", + "Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets) (2.8.1)\n", + "Requirement already satisfied: pytz>=2017.3 in /opt/conda/lib/python3.7/site-packages (from pandas->datasets) (2021.1)\n", + "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.7/site-packages (from python-dateutil>=2.7.3->pandas->datasets) (1.15.0)\n", + "Requirement already satisfied: sklearn in /opt/conda/lib/python3.7/site-packages (0.0)\n", + "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.7/site-packages (from sklearn) (0.24.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (2.1.0)\n", + "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.0.1)\n", + "Requirement already satisfied: scipy>=0.19.1 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.4.1)\n", + "Requirement already satisfied: numpy>=1.13.3 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sklearn) (1.18.5)\n" + ] + } + ], + "source": [ + "!pip install datasets\n", + "!pip install sklearn" + ] + }, + { + "cell_type": "markdown", + "id": "organic-valley", + "metadata": {}, + "source": [ + "# 필요한 라이브러리를 import 합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "looking-relevance", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import sklearn.metrics\n", + "\n", + "from tqdm import tqdm\n", + "from datasets import load_dataset\n", + "from datasets.arrow_dataset import Dataset\n", + "from transformers import AutoTokenizer, AutoModel, AdamW\n", + "from transformers.models.bert.modeling_bert import BertPreTrainedModel, BertForSequenceClassification\n", + "from torch.utils.data import DataLoader" + ] + }, + { + "cell_type": "markdown", + "id": "smaller-truth", + "metadata": {}, + "source": [ + "# GPU 사용을 위해 device를 설정합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "musical-concord", + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")" + ] + }, + { + "cell_type": "markdown", + "id": "valued-delta", + "metadata": {}, + "source": [ + "# KLUE RE 데이터셋을 가져옵니다." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "induced-breast", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Reusing dataset klue (/workspace/.cache/huggingface/datasets/klue/re/1.0.0/55ff8f92b7a4b9842be6514ce0b4b5295b46d5e493f8bb5760da4be717018f90)\n" + ] + } + ], + "source": [ + "dataset = load_dataset('klue', 're')" + ] + }, + { + "cell_type": "markdown", + "id": "higher-compact", + "metadata": {}, + "source": [ + "# 데이터셋은 train과 validation 데이터로 구성되어 있습니다. 
\n", + "train 데이터들은 모델을 학습할 때 사용될 예정이고, validation 데이터들은 학습이 아닌 모델의 성능을 평가할 때 사용됩니다.\n", + "\n", + "train 데이터와 validation 데이터의 구성은 동일하므로 train 데이터의 구성만 살펴보도록 하겠습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "respective-labor", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DatasetDict({\n", + " train: Dataset({\n", + " features: ['guid', 'label', 'object_entity', 'sentence', 'source', 'subject_entity'],\n", + " num_rows: 32470\n", + " })\n", + " validation: Dataset({\n", + " features: ['guid', 'label', 'object_entity', 'sentence', 'source', 'subject_entity'],\n", + " num_rows: 7765\n", + " })\n", + "})" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "cell_type": "markdown", + "id": "increasing-clarity", + "metadata": {}, + "source": [ + "# 데이터 구성\n", + "각 데이터들은 다음과 같이 문장과 관계 추출에 사용될 2개의 개체(object_entity, subject_entity)와 두 개체간의 관계를 라벨로 가지고 있습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "assisted-brooklyn", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'guid': 'klue-re-v1_train_00000',\n", + " 'label': 0,\n", + " 'object_entity': {'word': '조지 해리슨',\n", + " 'start_idx': 13,\n", + " 'end_idx': 18,\n", + " 'type': 'PER'},\n", + " 'sentence': '〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.',\n", + " 'source': 'wikipedia',\n", + " 'subject_entity': {'word': '비틀즈',\n", + " 'start_idx': 24,\n", + " 'end_idx': 26,\n", + " 'type': 'ORG'}}" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset['train'][0]" + ] + }, + { + "cell_type": "markdown", + "id": "established-platform", + "metadata": {}, + "source": [ + "# train데이터에서 각 label의 수를 살펴보겠습니다.\n", + "출력 결과를 보시면 30개의 라벨에 대해서 불균형이 심한 것을 확인할 수 있습니다.\n", + "\n", + "따라서 모델의 성능을 측정할 평가지표로 단순하게 Accuracy를 사용한다면 정확한 평가가 이루어질 수 없기 때문에 KLUE 논문에 따르면 평가지표로 F1 score와 AUPRC를 이용하였습니다.\n", + "\n", + "\n", + "또한 관계 없음(0번 라벨) 데이터가 많은 비중을 차지하고 있는데, 모델이 \"관계 없음\"을 예측하는 데에 많은 초점이 맞춰지지 않도록\n", + "\n", + "관계 없음에 해당하는 데이터를 제외하고 관계가 있는 데이터들에 대해서만 F1 score을 계산합니다." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "english-royalty", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{0: 9534,\n", + " 1: 66,\n", + " 2: 450,\n", + " 3: 1195,\n", + " 4: 1320,\n", + " 5: 1866,\n", + " 6: 420,\n", + " 7: 98,\n", + " 8: 380,\n", + " 9: 155,\n", + " 10: 4284,\n", + " 11: 48,\n", + " 12: 1130,\n", + " 13: 418,\n", + " 14: 166,\n", + " 15: 40,\n", + " 16: 193,\n", + " 17: 1234,\n", + " 18: 3573,\n", + " 19: 82,\n", + " 20: 1001,\n", + " 21: 520,\n", + " 22: 304,\n", + " 23: 136,\n", + " 24: 795,\n", + " 25: 190,\n", + " 26: 534,\n", + " 27: 139,\n", + " 28: 96,\n", + " 29: 2103}" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "label_count = {}\n", + "\n", + "for data in dataset['train']:\n", + " label = data['label']\n", + " if label not in label_count:\n", + " label_count[label] = 1\n", + " else:\n", + " label_count[label] += 1\n", + "\n", + "label_count = dict(sorted(label_count.items(), key=lambda x: x[0]))\n", + "label_count" + ] + }, + { + "cell_type": "markdown", + "id": "accredited-weapon", + "metadata": {}, + "source": [ + "# 개체 양 끝에 special token을 추가합니다.\n", + "KLUE 논문에 따르면 object 개체의 양 끝에는 \\, \\을, subject 개체의 양 끝에는 \\, \\ 토큰을 추가하여 개체의 위치를 표시한 후에 모델의 입력으로 주어집니다.\n", + "\n", + "따라서 데이터에 있는 entity index를 이용해서 해당 토큰을 추가해줍니다.\n", + "\n", + "토큰 추가의 예시 사진은 다음과 같습니다. 사진에서 빨간색으로 표시된 토큰들이 추가가 되는 토큰들입니다.\n", + "\n", + "![Imgur](https://i.imgur.com/gWNeyLv.png)\n", + "\n", + "그리고 학습에 필요한 데이터는 문장과 라벨 정보이므로 해당 부분만 가져오도록 합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "surface-squad", + "metadata": {}, + "outputs": [], + "source": [ + "def add_entity_tokens(sentence, object_entity, subject_entity):\n", + " obj_start_idx, obj_end_idx = object_entity['start_idx'], object_entity['end_idx']\n", + " subj_start_idx, subj_end_idx = subject_entity['start_idx'], subject_entity['end_idx']\n", + " \n", + " if obj_start_idx < subj_start_idx:\n", + " new_sentence = sentence[:obj_start_idx] + '' + sentence[obj_start_idx:obj_end_idx+1] + '' + \\\n", + " sentence[obj_end_idx+1:subj_start_idx] + '' + sentence[subj_start_idx:subj_end_idx+1] + \\\n", + " '' + sentence[subj_end_idx+1:]\n", + " else:\n", + " new_sentence = sentence[:subj_start_idx] + '' + sentence[subj_start_idx:subj_end_idx+1] + '' + \\\n", + " sentence[subj_end_idx+1:obj_start_idx] + '' + sentence[obj_start_idx:obj_end_idx+1] + \\\n", + " '' + sentence[obj_end_idx+1:]\n", + " \n", + " return new_sentence\n", + "\n", + "\n", + "def read_klue_re(dataset):\n", + " sentences = []\n", + " labels = []\n", + " \n", + " if isinstance(dataset, Dataset):\n", + " for data in dataset:\n", + " sentence = add_entity_tokens(data['sentence'], data['object_entity'], data['subject_entity'])\n", + " sentences.append(sentence)\n", + " labels.append(data['label'])\n", + " \n", + " return sentences, labels" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "efficient-absorption", + "metadata": {}, + "outputs": [], + "source": [ + "# train, validation데이터셋에서 sentence와 label만 저장.\n", + "train_sentences, train_labels = read_klue_re(dataset['train'])\n", + "val_sentences, val_labels = read_klue_re(dataset['validation'])" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "exclusive-travel", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다. 
\n", + "\n", + "호남이 기반인 바른미래당·대안신당·민주평화당이 우여곡절 끝에 합당해 민생당(가칭)으로 재탄생한다. \n", + "\n", + "K리그2에서 성적 1위를 달리고 있는 광주FC는 지난 26일 한국프로축구연맹으로부터 관중 유치 성과와 마케팅 성과를 인정받아 ‘풀 스타디움상’과 ‘플러스 스타디움상’을 수상했다. \n", + "\n", + "균일가 생활용품점 (주)아성다이소(대표 박정부)는 코로나19 바이러스로 어려움을 겪고 있는 대구광역시에 행복박스를 전달했다고 10일 밝혔다. \n", + "\n", + "1967년 프로 야구 드래프트 1순위로 요미우리 자이언츠에게 입단하면서 등번호는 8번으로 배정되었다. \n", + "\n" + ] + } + ], + "source": [ + "# 개체 토큰이 정상적으로 잘 추가됐는지 확인하기 위해 train 문장 5개만 출력.\n", + "for i, sentence in enumerate(train_sentences[:5]):\n", + " print(sentence, '\\n')" + ] + }, + { + "cell_type": "markdown", + "id": "pointed-earthquake", + "metadata": {}, + "source": [ + "# klue/bert-base 모델을 사용할 예정이므로 모델에 맞는 tokenizer를 가져옵니다." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "southwest-retrieval", + "metadata": {}, + "outputs": [], + "source": [ + "model_name = 'klue/bert-base'" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "color-footwear", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer = AutoTokenizer.from_pretrained(model_name)" + ] + }, + { + "cell_type": "markdown", + "id": "recreational-given", + "metadata": {}, + "source": [ + "# tokenizer를 이용한 토큰화 결과가 어떻게 나오는지 살펴보도록 하겠습니다.\n", + "예시 문장으로는 첫 번째 train 데이터의 문장을 이용하도록 하겠습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "exact-senegal", + "metadata": {}, + "outputs": [], + "source": [ + "ex_sentence = dataset['train'][0]['sentence']" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "fatty-finance", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'〈Something〉는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《Abbey Road》에 담은 노래다.'" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ex_sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "electoral-witch", + "metadata": {}, + "outputs": [], + "source": [ + "ex_encoding = tokenizer(ex_sentence,\n", + " max_length=128,\n", + " padding='max_length',\n", + " truncation=True)" + ] + }, + { + "cell_type": "markdown", + "id": "metropolitan-bathroom", + "metadata": {}, + "source": [ + "토큰화 결과로 Bert모델의 입력으로 필요한 input_ids, token_type_ids, attention_mask가 나오는 것을 확인할 수 있습니다.\n", + "\n", + "3가지 값에 대한 설명은 링크에서 확인할 수 있습니다." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "pretty-luther", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'input_ids': [2, 168, 30985, 14451, 7088, 4586, 169, 793, 8373, 14113, 2234, 2052, 1363, 2088, 29830, 2116, 14879, 2440, 6711, 170, 21406, 26713, 2076, 25145, 5749, 171, 1421, 818, 2073, 4388, 2062, 18, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ex_encoding" + ] + }, + { + "cell_type": "markdown", + "id": "wrong-corpus", + "metadata": {}, + "source": [ + "토큰화 된 문장을 다시 디코딩 해봄으로써 원본 문장을 얻을 수 있는지 확인해봅니다.\n", + "\n", + "디코딩 결과를 살펴보면 원본 문장 뒤에 [PAD] 토큰을 통해 입력 토큰의 개수가 max_length가 되도록 맞춥니다." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "cellular-arrow", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'[CLS] 〈 Something 〉 는 조지 해리슨이 쓰고 비틀즈가 1969년 앨범 《 Abbey Road 》 에 담은 노래다. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenizer.decode(ex_encoding['input_ids'])" + ] + }, + { + "cell_type": "markdown", + "id": "variable-channels", + "metadata": {}, + "source": [ + "# Special token 추가\n", + "위에서 개체에 맞게 4개의 토큰(\\, \\, \\, \\)을 문장에 추가해주었는데, 해당 토큰들을 tokenizer에 special token이라고 알려주지 않으면 추가한 토큰들은 일반 문자로 인식되어서 토큰화될 수 있습니다. \n", + "\n", + "따라서 토큰화가 되지 않도록 추가한 4개의 토큰을 tokenizer에 special token으로 추가 해 줍니다." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "matched-investor", + "metadata": {}, + "outputs": [], + "source": [ + "entity_special_tokens = {'additional_special_tokens': ['', '', '', '']}\n", + "num_additional_special_tokens = tokenizer.add_special_tokens(entity_special_tokens)" + ] + }, + { + "cell_type": "markdown", + "id": "secondary-paris", + "metadata": {}, + "source": [ + "데이터로더 및 학습에 필요한 값들을 설정합니다." 
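For reference, here is a self-contained sketch of the two steps above: wrapping the object and subject spans in marker strings, and registering those markers as special tokens. The marker spellings `<obj>`, `</obj>`, `<subj>`, `</subj>` are an assumption; any unique strings work, as long as the same strings are used both when building the sentences and when calling `add_special_tokens`.

```python
# Wrap entity spans with marker strings and register the markers as special
# tokens so the tokenizer keeps each marker as a single token.
# NOTE: the marker spellings below are an assumption, not taken from this notebook.
OBJ_OPEN, OBJ_CLOSE = '<obj>', '</obj>'
SUBJ_OPEN, SUBJ_CLOSE = '<subj>', '</subj>'

def mark_entities(sentence, object_entity, subject_entity):
    spans = [
        (object_entity['start_idx'], object_entity['end_idx'], OBJ_OPEN, OBJ_CLOSE),
        (subject_entity['start_idx'], subject_entity['end_idx'], SUBJ_OPEN, SUBJ_CLOSE),
    ]
    # Insert from the rightmost span first so the earlier character indices stay valid.
    for start, end, open_tok, close_tok in sorted(spans, reverse=True):
        sentence = (sentence[:start] + open_tok
                    + sentence[start:end + 1] + close_tok
                    + sentence[end + 1:])
    return sentence

tokenizer.add_special_tokens(
    {'additional_special_tokens': [OBJ_OPEN, OBJ_CLOSE, SUBJ_OPEN, SUBJ_CLOSE]})

example = dataset['train'][0]
print(mark_entities(example['sentence'], example['object_entity'], example['subject_entity']))
```

Inserting from the rightmost span first keeps the earlier indices valid without any offset bookkeeping; with the markers registered, the dataloader and training settings below apply unchanged.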
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "developmental-intersection", + "metadata": {}, + "outputs": [], + "source": [ + "# For Dataloader\n", + "batch_size = 8\n", + "\n", + "# For model\n", + "num_labels = 30\n", + "\n", + "# For train\n", + "learning_rate = 1e-5\n", + "weight_decay = 0.0\n", + "epochs = 3" + ] + }, + { + "cell_type": "markdown", + "id": "literary-engagement", + "metadata": {}, + "source": [ + "# 학습에 이용할 데이터셋과 데이터로더를 만들어 줍니다." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "trying-consent", + "metadata": {}, + "outputs": [], + "source": [ + "class KlueReDataset(torch.utils.data.Dataset):\n", + " def __init__(self, tokenizer, sentences, labels, max_length=128):\n", + " self.encodings = tokenizer(sentences,\n", + " max_length=max_length,\n", + " padding='max_length',\n", + " truncation=True)\n", + " self.labels = labels\n", + " \n", + " def __getitem__(self, idx):\n", + " item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}\n", + " item['labels'] = self.labels[idx]\n", + " \n", + " return item\n", + " \n", + " def __len__(self):\n", + " return len(self.labels)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "thick-effectiveness", + "metadata": {}, + "outputs": [], + "source": [ + "train_dataset = KlueReDataset(tokenizer, train_sentences, train_labels)\n", + "val_dataset = KlueReDataset(tokenizer, val_sentences, val_labels)\n", + "\n", + "train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", + "val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)" + ] + }, + { + "cell_type": "markdown", + "id": "electronic-payday", + "metadata": {}, + "source": [ + "# klue/bert-base 모델을 로드합니다.\n", + "\n", + "저희가 다루는 관계추출 Task는 30개의 관계(클래스)를 분류하는 것이라고 할 수 있습니다.\n", + "\n", + "이를 위해 [CLS] 토큰의 벡터를 출력의 차원이 30인 1개의 Linear Layer에 통과시켜서 30개의 클래스로 분류하는 모델을 만들겠습니다.\n", + "\n", + "이를 간단히 사진으로 나타내면 다음과 같습니다.\n", + "\n", + "![Imgur](https://i.imgur.com/qaUObkV.png)\n", + "\n", + "모델을 로드할 때 Warning이 발생하는데 \"모델에서 가중치가 초기화되지 않은 부분이 있으니 예측을 하기 위해서는 모델을 학습시킨 후 사용해야된다.\"라는 내용을 담고 있습니다.\n", + "\n", + "저희는 이후에 모델을 fine-tuning할 것이기 때문에 해당 메세지를 신경쓰지 않아도 됩니다." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "rocky-justice", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Some weights of the model checkpoint at klue/bert-base were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.decoder.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']\n", + "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. 
initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", + "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", + "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at klue/bert-base and are newly initialized: ['classifier.weight', 'classifier.bias']\n", + "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" + ] + } + ], + "source": [ + "model = BertForSequenceClassification.from_pretrained(model_name, num_labels=num_labels).to(device)" + ] + }, + { + "cell_type": "markdown", + "id": "electronic-asian", + "metadata": {}, + "source": [ + "모델 구성의 마지막에 있는 classifier 부분을 보시면 출력 차원이 30으로 설정되어 있는 것을 확인할 수 있습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "attached-particle", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BertForSequenceClassification(\n", + " (bert): BertModel(\n", + " (embeddings): BertEmbeddings(\n", + " (word_embeddings): Embedding(32000, 768, padding_idx=0)\n", + " (position_embeddings): Embedding(512, 768)\n", + " (token_type_embeddings): Embedding(2, 768)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (encoder): BertEncoder(\n", + " (layer): ModuleList(\n", + " (0): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (1): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (2): BertLayer(\n", + " 
(attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (3): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (4): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (5): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): 
BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (6): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (7): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (8): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (9): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, 
out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (10): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (11): BertLayer(\n", + " (attention): BertAttention(\n", + " (self): BertSelfAttention(\n", + " (query): Linear(in_features=768, out_features=768, bias=True)\n", + " (key): Linear(in_features=768, out_features=768, bias=True)\n", + " (value): Linear(in_features=768, out_features=768, bias=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " (output): BertSelfOutput(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " (intermediate): BertIntermediate(\n", + " (dense): Linear(in_features=768, out_features=3072, bias=True)\n", + " )\n", + " (output): BertOutput(\n", + " (dense): Linear(in_features=3072, out_features=768, bias=True)\n", + " (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " )\n", + " )\n", + " )\n", + " )\n", + " (pooler): BertPooler(\n", + " (dense): Linear(in_features=768, out_features=768, bias=True)\n", + " (activation): Tanh()\n", + " )\n", + " )\n", + " (dropout): Dropout(p=0.1, inplace=False)\n", + " (classifier): Linear(in_features=768, out_features=30, bias=True)\n", + ")" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "markdown", + "id": "alternate-import", + "metadata": {}, + "source": [ + "# Bert Embedding Layer을 resize합니다.\n", + "\n", + "Bert에는 토큰들의 id에 따른 임베딩 값을 반환하는 Embedding Layer가 존재합니다.\n", + "\n", + "하지만 현재 Embedding Layer에는 위에서 추가한 4개의 토큰에 대한 정보가 반영되지 않았기 때문에 추가한 토큰들이 입력으로 주어질 경우 index error가 발생합니다. \n", + "\n", + "따라서 Bert의 Embedding resize해줍니다.\n", + "\n", + "Resize를 하게되면 Embedding Layer의 input 차원이 32000에서 32004로 4만큼 증가합니다." 
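A quick sanity check around the resize is to compare the tokenizer's vocabulary size with the number of rows in the input embedding matrix, as in this small sketch (reusing `model` and `tokenizer` from above).

```python
# The input embedding must have one row per tokenizer id.  Before resizing it
# still has the pretrained 32000 rows, while len(tokenizer) is 32004 after the
# 4 added special tokens.
print('tokenizer vocab size:', len(tokenizer))
print('embedding rows      :', model.get_input_embeddings().num_embeddings)

# After model.resize_token_embeddings(len(tokenizer)) in the next cell, both
# numbers should match; the 4 new rows are randomly initialized and are
# learned during fine-tuning.
```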
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "stunning-worthy", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Embedding(32004, 768)" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model.resize_token_embeddings(len(tokenizer))" + ] + }, + { + "cell_type": "markdown", + "id": "promising-yesterday", + "metadata": {}, + "source": [ + "학습 도중 Loss, Accuracy 계산 및 저장을 간단하게 하기 위해 AverageMeter를 클래스를 이용합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "hourly-bridges", + "metadata": {}, + "outputs": [], + "source": [ + "class AverageMeter():\n", + " def __init__(self):\n", + " self.val = 0\n", + " self.avg = 0\n", + " self.sum = 0\n", + " self.count = 0\n", + "\n", + " def update(self, val, n=1):\n", + " self.val = val\n", + " self.sum += val * n\n", + " self.count += n\n", + " self.avg = self.sum / self.count" + ] + }, + { + "cell_type": "markdown", + "id": "laughing-electricity", + "metadata": {}, + "source": [ + "# Model fine-tuning\n", + "klue/bert-base 모델을 fine-tuning합니다.\n", + "\n", + "학습동안 학습이 잘 진행되고 있는지 확인하기 위해 Loss와 Accuracy를 출력합니다.\n", + "\n", + "Tesla T4 기준으로 1 epoch 당 약 14분 정도가 소요됩니다." + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "silver-algeria", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + " 0%| | 0/4059 [00:00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 4059/4059 [14:11<00:00, 4.77it/s]\n", + "100%|██████████| 971/971 [01:07<00:00, 14.30it/s]\n", + " 0%| | 0/4059 [00:00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 4059/4059 [14:10<00:00, 4.77it/s]\n", + "100%|██████████| 971/971 [01:07<00:00, 14.30it/s]\n", + " 0%| | 0/4059 [00:00\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 4059/4059 [14:10<00:00, 4.77it/s]\n", + "100%|██████████| 971/971 [01:07<00:00, 14.32it/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "train_loss: 0.3608, train_acc: 87.74%, val_loss: 0.7518, val_acc: 75.54%\n", + "====================================================================================================\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "def train_epoch(data_loader, model, criterion, optimizer, train=True):\n", + " loss_save = AverageMeter()\n", + " acc_save = AverageMeter()\n", + " \n", + " loop = tqdm(enumerate(data_loader), total=len(data_loader))\n", + " for _, batch in loop:\n", + " inputs = {\n", + " 'input_ids': batch['input_ids'].to(device),\n", + " 'token_type_ids': batch['token_type_ids'].to(device),\n", + " 'attention_mask': batch['attention_mask'].to(device),\n", + " }\n", + " labels = batch['labels'].to(device)\n", + " \n", + " optimizer.zero_grad()\n", + " outputs = model(**inputs)\n", + " logits = outputs['logits']\n", + " \n", + " loss = criterion(logits, labels)\n", + " \n", + " if train:\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " preds = torch.argmax(logits, dim=1)\n", + " acc = ((preds == labels).sum().item() / labels.shape[0])\n", + " \n", + " loss_save.update(loss, labels.shape[0])\n", + " acc_save.update(acc, labels.shape[0])\n", + " \n", + " results = {\n", + " 'loss': loss_save.avg,\n", + " 'acc': acc_save.avg * 100,\n", + " }\n", + " \n", + " return results\n", + " \n", + " \n", + "# loss 
function, optimizer 설정\n", + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)\n", + "\n", + "for epoch in range(epochs):\n", + " print(f'< Epoch {epoch+1} / {epochs} >')\n", + " \n", + " # Train\n", + " model.train()\n", + " \n", + " train_results = train_epoch(train_loader, model, criterion, optimizer)\n", + " train_loss, train_acc = train_results['loss'], train_results['acc']\n", + " \n", + " # Validation\n", + " with torch.no_grad():\n", + " model.eval()\n", + " \n", + " val_results = train_epoch(val_loader, model, criterion, optimizer, False)\n", + " val_loss, val_acc = val_results['loss'], val_results['acc']\n", + " \n", + " \n", + " print(f'train_loss: {train_loss:.4f}, train_acc: {train_acc:.2f}%, val_loss: {val_loss:.4f}, val_acc: {val_acc:.2f}%')\n", + " print('=' * 100)" + ] + }, + { + "cell_type": "markdown", + "id": "ongoing-tooth", + "metadata": {}, + "source": [ + "# 학습된 모델을 저장합니다." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "consistent-cassette", + "metadata": {}, + "outputs": [], + "source": [ + "tokenizer.save_pretrained('./klue-bert-base-re')\n", + "model.save_pretrained('./klue-bert-base-re')" + ] + }, + { + "cell_type": "markdown", + "id": "nervous-positive", + "metadata": {}, + "source": [ + "# Validation 결과 확인\n", + "먼저 관계 없음에 해당하는 문장에 대한 예측 결과를 살펴보겠습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "small-planet", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"20대 남성 A(26)씨가 아버지 치료비를 위해 B(30)씨가 모아둔 돈을 훔쳐 인터넷 방송 BJ에게 '별풍선'으로 쏜 사실이 알려졌다.\"" + ] + }, + "execution_count": 28, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "val_sentence = val_sentences[0]\n", + "\n", + "val_sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "pediatric-wednesday", + "metadata": {}, + "outputs": [], + "source": [ + "val_encoding = tokenizer(val_sentence,\n", + " max_length=128,\n", + " padding='max_length',\n", + " truncation=True,\n", + " return_tensors='pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "assigned-slovak", + "metadata": {}, + "outputs": [], + "source": [ + "val_input = {\n", + " 'input_ids': val_encoding['input_ids'].to(device),\n", + " 'token_type_ids': val_encoding['token_type_ids'].to(device),\n", + " 'attention_mask': val_encoding['attention_mask'].to(device),\n", + "}\n", + "\n", + "model.eval()\n", + "output = model(**val_input)\n", + "label = torch.argmax(output['logits'], dim=1)" + ] + }, + { + "cell_type": "markdown", + "id": "prime-device", + "metadata": {}, + "source": [ + "0번 라벨은 \"관계 없음\"이므로 해당 문장에 대해서 예측이 정확하게 됐음을 알 수 있습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "indoor-stuff", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([0], device='cuda:0')" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "label" + ] + }, + { + "cell_type": "markdown", + "id": "leading-rebecca", + "metadata": {}, + "source": [ + "다음으로 관계가 있는 문장에 대한 예측 결과를 살펴보겠습니다." 
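Before that, the per-example steps used here (encode, forward pass, argmax) can be collected into a small helper so that any sentence with entity markers can be checked the same way. This is a sketch reusing `model`, `tokenizer`, `device`, and `dataset` from above; the relation-name lookup via `int2str` assumes the dataset's `label` column is a `ClassLabel`.

```python
# Helper: predict the relation id (and its name) for one already-marked sentence.
def predict_relation(marked_sentence, model, tokenizer, device, max_length=128):
    encoding = tokenizer(marked_sentence,
                         max_length=max_length,
                         padding='max_length',
                         truncation=True,
                         return_tensors='pt')
    inputs = {k: v.to(device) for k, v in encoding.items()}

    model.eval()
    with torch.no_grad():
        logits = model(**inputs)['logits']

    pred_id = torch.argmax(logits, dim=1).item()
    # Relation name lookup through the dataset's ClassLabel feature (assumption).
    pred_name = dataset['train'].features['label'].int2str(pred_id)
    return pred_id, pred_name

# The notebook observed label 0 ("no relation") for this validation example.
print(predict_relation(val_sentences[0], model, tokenizer, device))
```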
+ ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "architectural-engagement", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'서울교통공사서울 노원구 석계역 무빙워크에 고의로 침을 바른 남성이 신종 코로나바이러스 감염증(코로나19) 검사에서 음성 판정이 나왔다고 21일 밝혔다.'" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "val_sentence = val_sentences[13]\n", + "\n", + "val_sentence" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "enormous-shift", + "metadata": {}, + "outputs": [], + "source": [ + "val_encoding = tokenizer(val_sentence,\n", + " max_length=128,\n", + " padding='max_length',\n", + " truncation=True,\n", + " return_tensors='pt')" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "aware-savage", + "metadata": {}, + "outputs": [], + "source": [ + "val_input = {\n", + " 'input_ids': val_encoding['input_ids'].to(device),\n", + " 'token_type_ids': val_encoding['token_type_ids'].to(device),\n", + " 'attention_mask': val_encoding['attention_mask'].to(device),\n", + "}\n", + "\n", + "model.eval()\n", + "output = model(**val_input)\n", + "label = torch.argmax(output['logits'], dim=1)" + ] + }, + { + "cell_type": "markdown", + "id": "found-option", + "metadata": {}, + "source": [ + "3번 라벨은 \"org:place_of_headquarters\" 이므로 해당 문장에 대해서 예측이 정확하게 됐음을 알 수 있습니다." + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "advanced-creek", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([3], device='cuda:0')" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "label" + ] + }, + { + "cell_type": "markdown", + "id": "continuous-professor", + "metadata": {}, + "source": [ + "# 모델 평가\n", + "학습된 모델에 대해서 평가를 해보도록 하겠습니다.\n", + "\n", + "KLUE 논문에 따르면 평가 지표로는 F1 score, AUPRC가 사용되었지만 본 Workspace에서는 F1 score 성능만 측정해보도록 하겠습니다.\n", + "\n", + "이 때, 0번 라벨(관계 없음)을 제외한 관계가 있는 라벨들에 대해서만 F1 score를 계산합니다." 
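Only the micro-F1 is computed below. For reference, here is a sketch of an AUPRC computation in the spirit of the KLUE metric, assuming `probs` is an `(N, 30)` array of softmax probabilities collected from the validation loop and `labels` the matching gold label ids; the official KLUE baseline may define the metric slightly differently, so treat this as an approximation.

```python
import numpy as np
import sklearn.metrics

def calc_auprc(probs, labels, num_labels=30):
    """Average per-class area under the precision-recall curve.

    probs  : (N, num_labels) softmax probabilities.
    labels : (N,) gold label ids.
    One common reading of the KLUE AUPRC metric (assumption).
    """
    labels_onehot = np.eye(num_labels)[labels]          # (N, num_labels)
    scores = []
    for c in range(num_labels):
        if labels_onehot[:, c].sum() == 0:              # skip classes absent from this split
            continue
        scores.append(
            sklearn.metrics.average_precision_score(labels_onehot[:, c], probs[:, c]))
    return float(np.mean(scores)) * 100
```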
+ ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "surgical-hours", + "metadata": {}, + "outputs": [], + "source": [ + "def calc_f1_score(preds, labels):\n", + " \"\"\"\n", + " label이 0(관계 없음)이 아닌 예측 값에 대해서만 f1 score 계산.\n", + " \"\"\"\n", + " preds_relation = []\n", + " labels_relation = []\n", + " \n", + " for pred, label in zip(preds, labels):\n", + " if label != 0:\n", + " preds_relation.append(pred)\n", + " labels_relation.append(label)\n", + " \n", + " f1_score = sklearn.metrics.f1_score(labels_relation, preds_relation, average='micro', zero_division=1)\n", + " \n", + " return f1_score * 100" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "literary-granny", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 971/971 [01:07<00:00, 14.45it/s]\n" + ] + } + ], + "source": [ + "with torch.no_grad():\n", + " model.eval()\n", + " \n", + " label_all = []\n", + " pred_all = []\n", + " for batch in tqdm(val_loader):\n", + " inputs = {\n", + " 'input_ids': batch['input_ids'].to(device),\n", + " 'token_type_ids': batch['token_type_ids'].to(device),\n", + " 'attention_mask': batch['attention_mask'].to(device),\n", + " }\n", + " labels = batch['labels'].to(device)\n", + " \n", + " outputs = model(**inputs)\n", + " logits = outputs['logits']\n", + " \n", + " preds = torch.argmax(logits, dim=1)\n", + " \n", + " label_all.extend(labels.detach().cpu().numpy().tolist())\n", + " pred_all.extend(preds.detach().cpu().numpy().tolist())\n", + " \n", + " f1_score = calc_f1_score(label_all, pred_all)" + ] + }, + { + "cell_type": "markdown", + "id": "functional-cotton", + "metadata": {}, + "source": [ + "약 63.58의 F1 score를 기록했습니다. 하이퍼파라미터, 모델 구조 등을 변경시켜 더 높은 F1 score를 기록해보세요!" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "general-attitude", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "63.58118361153262" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f1_score" + ] + }, + { + "cell_type": "markdown", + "id": "bigger-action", + "metadata": {}, + "source": [ + "학습된 모델은 Huggingface Model Hub에 배포되어 있고, 언제든지 다운받아서 활용이 가능합니다.\n", + "\n", + "모델 로드 후 Inference 방법은 해당 링크에 있는 설명부분에 작성되어 있습니다." + ] + }, + { + "cell_type": "markdown", + "id": "capable-found", + "metadata": {}, + "source": [ + "# Reference" + ] + }, + { + "cell_type": "markdown", + "id": "noble-first", + "metadata": {}, + "source": [ + "### klue-transformers-tutorial\n", + "\n", + "### Fine-tuning with custom datasets" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}
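As a follow-up to the deployment note above: because the fine-tuned tokenizer and model were saved to `./klue-bert-base-re`, the same kind of inference can be reproduced in a fresh session by reloading from that directory (or from the Hub copy, once its model id is known). A minimal sketch:

```python
import torch
from transformers import AutoTokenizer, BertForSequenceClassification

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Reload the fine-tuned checkpoint saved earlier in this notebook.
tokenizer = AutoTokenizer.from_pretrained('./klue-bert-base-re')
model = BertForSequenceClassification.from_pretrained('./klue-bert-base-re').to(device)
model.eval()

# Supply a sentence that already contains the entity marker tokens.
sentence = '...'
encoding = tokenizer(sentence, max_length=128, padding='max_length',
                     truncation=True, return_tensors='pt').to(device)

with torch.no_grad():
    logits = model(**encoding)['logits']
print(torch.argmax(logits, dim=1))
```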