From 649b8e68235a5fa3e29eaaa2a7a7a75a4eba6187 Mon Sep 17 00:00:00 2001
From: Zhao HG <853842+CyberZHG@users.noreply.github.com>
Date: Sat, 22 Jun 2019 00:30:03 +0800
Subject: [PATCH] Add datasets (#93)

* Add datasets

* Update demos

* Update README
---
 README.md                             | 12 +++++++
 README.zh-CN.md                       | 12 +++++++
 demo/load_model/load_and_extract.py   | 23 ++++++-------
 demo/load_model/load_and_pool.py      | 23 ++++++-------
 demo/load_model/load_and_predict.py   | 23 ++++++-------
 keras_bert/__init__.py                |  1 +
 keras_bert/datasets/__init__.py       |  1 +
 keras_bert/datasets/pretrained.py     | 47 +++++++++++++++++++++++++++
 setup.py                              |  2 +-
 tests/datasets/__init__.py            |  0
 tests/datasets/test_get_pretrained.py | 10 ++++++
 11 files changed, 114 insertions(+), 40 deletions(-)
 create mode 100644 keras_bert/datasets/__init__.py
 create mode 100644 keras_bert/datasets/pretrained.py
 create mode 100644 tests/datasets/__init__.py
 create mode 100644 tests/datasets/test_get_pretrained.py

diff --git a/README.md b/README.md
index d7d3916..e4d6eb2 100644
--- a/README.md
+++ b/README.md
@@ -160,6 +160,18 @@ total_steps, warmup_steps = calc_train_steps(
 optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5)
 ```
 
+### Download Pretrained Checkpoints
+
+Several download URLs have been added. You can get the path of a downloaded and uncompressed checkpoint with:
+
+```python
+from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths
+
+model_path = get_pretrained(PretrainedList.multi_cased_base)
+paths = get_checkpoint_paths(model_path)
+print(paths.config, paths.checkpoint, paths.vocab)
+```
+
 ### Extract Features
 
 You can use helper function `extract_embeddings` if the features of tokens or sentences (without further tuning) are what you need. To extract the features of all tokens:
diff --git a/README.zh-CN.md b/README.zh-CN.md
index da17cdc..5e73c57 100644
--- a/README.zh-CN.md
+++ b/README.zh-CN.md
@@ -173,6 +173,18 @@ optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5)
 ```
 
 When `training` is `True`, the input contains three parts: token indices, segment indices, and the template of masked tokens. When `training` is `False`, the input contains only the first two. Position indices are fixed and generated inside the model, so they do not need to be fed in manually. The masked-token template is 1 at positions where the input token is masked and 0 otherwise.
 
+### Download Pretrained Checkpoints
+
+The library records the download URLs of several pretrained models. The path of the uncompressed checkpoint can be obtained as follows:
+
+```python
+from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths
+
+model_path = get_pretrained(PretrainedList.multi_cased_base)
+paths = get_checkpoint_paths(model_path)
+print(paths.config, paths.checkpoint, paths.vocab)
+```
+
 ### Extract Features
 
 If you don't need fine-tuning and only want to extract token or sentence features, you can use `extract_embeddings` to simplify the process. For example, to extract the features of all tokens in each sentence:
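The new README section sits directly above `### Extract Features`, and the two compose: `extract_embeddings` accepts a checkpoint directory, so the result of `get_pretrained` can be passed straight through. A minimal sketch of that round trip; the sample texts are placeholders:

```python
from keras_bert import extract_embeddings
from keras_bert.datasets import get_pretrained, PretrainedList

# Download and unzip on first use; keras.utils.get_file caches the archive,
# so repeated calls resolve to the already-extracted directory.
model_path = get_pretrained(PretrainedList.multi_cased_base)

texts = ['all work and no play', 'makes jack a dull boy']  # placeholder inputs
embeddings = extract_embeddings(model_path, texts)
print([e.shape for e in embeddings])  # one (tokens, hidden_size) array per text
```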
diff --git a/demo/load_model/load_and_extract.py b/demo/load_model/load_and_extract.py
index f1b3e44..7a631b0 100644
--- a/demo/load_model/load_and_extract.py
+++ b/demo/load_model/load_and_extract.py
@@ -1,24 +1,21 @@
-import os
 import sys
 
 import numpy as np
-from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
-
-
-if len(sys.argv) != 2:
-    print('python load_model.py UNZIPPED_MODEL_PATH')
-    sys.exit(-1)
+from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
 
 print('This demo demonstrates how to load the pre-trained model and extract word embeddings')
 
-model_path = sys.argv[1]
-config_path = os.path.join(model_path, 'bert_config.json')
-checkpoint_path = os.path.join(model_path, 'bert_model.ckpt')
-dict_path = os.path.join(model_path, 'vocab.txt')
+if len(sys.argv) == 2:
+    model_path = sys.argv[1]
+else:
+    from keras_bert.datasets import get_pretrained, PretrainedList
+    model_path = get_pretrained(PretrainedList.chinese_base)
+
+paths = get_checkpoint_paths(model_path)
 
-model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=10)
+model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)
 model.summary(line_length=120)
 
-token_dict = load_vocabulary(dict_path)
+token_dict = load_vocabulary(paths.vocab)
 tokenizer = Tokenizer(token_dict)
 
 text = '语言模型'
diff --git a/demo/load_model/load_and_pool.py b/demo/load_model/load_and_pool.py
index fb56145..41082e2 100644
--- a/demo/load_model/load_and_pool.py
+++ b/demo/load_model/load_and_pool.py
@@ -1,28 +1,25 @@
-import os
 import sys
 
 import numpy as np
 import keras
-from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
+from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
 from keras_bert.layers import MaskedGlobalMaxPool1D
 
-
-if len(sys.argv) != 2:
-    print('python load_model.py UNZIPPED_MODEL_PATH')
-    sys.exit(-1)
-
 print('This demo demonstrates how to load the pre-trained model and extract the sentence embedding with pooling.')
 
-model_path = sys.argv[1]
-config_path = os.path.join(model_path, 'bert_config.json')
-checkpoint_path = os.path.join(model_path, 'bert_model.ckpt')
-dict_path = os.path.join(model_path, 'vocab.txt')
+if len(sys.argv) == 2:
+    model_path = sys.argv[1]
+else:
+    from keras_bert.datasets import get_pretrained, PretrainedList
+    model_path = get_pretrained(PretrainedList.chinese_base)
+
+paths = get_checkpoint_paths(model_path)
 
-model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=10)
+model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)
 pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
 model = keras.models.Model(inputs=model.inputs, outputs=pool_layer)
 model.summary(line_length=120)
 
-token_dict = load_vocabulary(dict_path)
+token_dict = load_vocabulary(paths.vocab)
 tokenizer = Tokenizer(token_dict)
 
 text = '语言模型'
diff --git a/demo/load_model/load_and_predict.py b/demo/load_model/load_and_predict.py
index 2b8975b..e28a9ea 100644
--- a/demo/load_model/load_and_predict.py
+++ b/demo/load_model/load_and_predict.py
@@ -1,24 +1,21 @@
-import os
 import sys
 
 import numpy as np
-from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
-
-
-if len(sys.argv) != 2:
-    print('python load_model.py UNZIPPED_MODEL_PATH')
-    sys.exit(-1)
+from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
 
 print('This demo demonstrates how to load the pre-trained model and check whether the two sentences are continuous')
 
-model_path = sys.argv[1]
-config_path = os.path.join(model_path, 'bert_config.json')
-checkpoint_path = os.path.join(model_path, 'bert_model.ckpt')
-dict_path = os.path.join(model_path, 'vocab.txt')
+if len(sys.argv) == 2:
+    model_path = sys.argv[1]
+else:
+    from keras_bert.datasets import get_pretrained, PretrainedList
+    model_path = get_pretrained(PretrainedList.chinese_base)
+
+paths = get_checkpoint_paths(model_path)
 
-model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True, seq_len=None)
+model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, training=True, seq_len=None)
 model.summary(line_length=120)
 
-token_dict = load_vocabulary(dict_path)
+token_dict = load_vocabulary(paths.vocab)
 token_dict_inv = {v: k for k, v in token_dict.items()}
 tokenizer = Tokenizer(token_dict)
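All three demos now share the same fallback: an explicit model directory from `sys.argv`, otherwise an automatic download of `chinese_base`. Condensed into a standalone script, the pattern looks roughly like this; the prediction tail mirrors the part of `load_and_extract.py` that the patch leaves unchanged:

```python
import sys

import numpy as np
from keras_bert import (Tokenizer, get_checkpoint_paths, load_vocabulary,
                        load_trained_model_from_checkpoint)

# An explicit path wins; otherwise fall back to the bundled download URL.
if len(sys.argv) == 2:
    model_path = sys.argv[1]
else:
    from keras_bert.datasets import get_pretrained, PretrainedList
    model_path = get_pretrained(PretrainedList.chinese_base)

paths = get_checkpoint_paths(model_path)
model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)

tokenizer = Tokenizer(load_vocabulary(paths.vocab))
indices, segments = tokenizer.encode(first='语言模型', max_len=10)

# One (seq_len, hidden_size) matrix of contextual embeddings for the sentence.
embeddings = model.predict([np.array([indices]), np.array([segments])])[0]
print(embeddings.shape)
```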
diff --git a/keras_bert/__init__.py b/keras_bert/__init__.py
index 393f20e..fd9a15d 100644
--- a/keras_bert/__init__.py
+++ b/keras_bert/__init__.py
@@ -3,3 +3,4 @@
 from .tokenizer import Tokenizer
 from .optimizers import *
 from .util import *
+from .datasets import *
diff --git a/keras_bert/datasets/__init__.py b/keras_bert/datasets/__init__.py
new file mode 100644
index 0000000..09c264f
--- /dev/null
+++ b/keras_bert/datasets/__init__.py
@@ -0,0 +1 @@
+from .pretrained import *
diff --git a/keras_bert/datasets/pretrained.py b/keras_bert/datasets/pretrained.py
new file mode 100644
index 0000000..67d2b32
--- /dev/null
+++ b/keras_bert/datasets/pretrained.py
@@ -0,0 +1,47 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+import os
+import shutil
+from collections import namedtuple
+from keras_bert.backend import keras
+
+__all__ = ['PretrainedInfo', 'PretrainedList', 'get_pretrained']
+
+
+PretrainedInfo = namedtuple('PretrainedInfo', ['url', 'extract_name', 'target_name'])
+
+
+class PretrainedList(object):
+
+    __test__ = PretrainedInfo(
+        'https://github.com/CyberZHG/keras-bert/archive/master.zip',
+        'keras-bert-master',
+        'keras-bert',
+    )
+
+    multi_cased_base = 'https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip'
+    chinese_base = 'https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip'
+    wwm_uncased_large = 'https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip'
+    wwm_cased_large = 'https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip'
+    chinese_wwm_base = PretrainedInfo(
+        'https://storage.googleapis.com/hfl-rc/chinese-bert/chinese_wwm_L-12_H-768_A-12.zip',
+        'publish',
+        'chinese_wwm_L-12_H-768_A-12',
+    )
+
+
+def get_pretrained(info):
+    path = info
+    if isinstance(info, PretrainedInfo):
+        path = info.url
+    path = keras.utils.get_file(fname=os.path.split(path)[-1], origin=path, extract=True)
+    base_part, file_part = os.path.split(path)
+    file_part = file_part.split('.')[0]
+    if isinstance(info, PretrainedInfo):
+        extract_path = os.path.join(base_part, info.extract_name)
+        target_path = os.path.join(base_part, info.target_name)
+        if not os.path.exists(target_path):
+            shutil.move(extract_path, target_path)
+        file_part = info.target_name
+    return os.path.join(base_part, file_part)
diff --git a/setup.py b/setup.py
index 7405f0f..8d930e8 100644
--- a/setup.py
+++ b/setup.py
@@ -11,7 +11,7 @@
 setup(
     name='keras-bert',
-    version='0.64.0',
+    version='0.65.0',
     packages=find_packages(),
     url='https://github.com/CyberZHG/keras-bert',
     license='MIT',
diff --git a/tests/datasets/__init__.py b/tests/datasets/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/datasets/test_get_pretrained.py b/tests/datasets/test_get_pretrained.py
new file mode 100644
index 0000000..6958cc5
--- /dev/null
+++ b/tests/datasets/test_get_pretrained.py
@@ -0,0 +1,10 @@
+import os
+from unittest import TestCase
+from keras_bert.datasets import get_pretrained, PretrainedList
+
+
+class TestGetPretrained(TestCase):
+
+    def test_get_pretrained(self):
+        path = get_pretrained(PretrainedList.__test__)
+        self.assertTrue(os.path.exists(os.path.join(path, 'README.md')))
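`get_pretrained` accepts either a bare URL string, in which case the directory name is derived from the archive name, or a `PretrainedInfo` for archives that extract under a different folder name, as `chinese_wwm_base` does with `publish`. Renaming after extraction keeps the cached directory name stable regardless of how the archive was packed. A sketch of a custom entry; the URL and names here are hypothetical:

```python
from keras_bert import get_checkpoint_paths
from keras_bert.datasets import PretrainedInfo, get_pretrained

# Hypothetical checkpoint whose zip extracts to a folder named 'publish';
# get_pretrained moves that folder to target_name after extraction.
my_bert = PretrainedInfo(
    url='https://example.com/my_bert_L-12_H-768_A-12.zip',  # placeholder URL
    extract_name='publish',
    target_name='my_bert_L-12_H-768_A-12',
)

model_path = get_pretrained(my_bert)      # ends with .../my_bert_L-12_H-768_A-12
paths = get_checkpoint_paths(model_path)  # bert_config.json, bert_model.ckpt, vocab.txt
```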