Add datasets (CyberZHG#93)
* Add datasets

* Update demos

* Update README
CyberZHG authored Jun 21, 2019
1 parent 81c6775 commit 649b8e6
Showing 11 changed files with 114 additions and 40 deletions.
12 changes: 12 additions & 0 deletions README.md
@@ -160,6 +160,18 @@ total_steps, warmup_steps = calc_train_steps(
optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5)
```

### Download Pretrained Checkpoints

Several download URLs have been added. You can get the path of a downloaded and uncompressed checkpoint with:

```python
from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths

model_path = get_pretrained(PretrainedList.multi_cased_base)
paths = get_checkpoint_paths(model_path)
print(paths.config, paths.checkpoint, paths.vocab)
```
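
The returned paths feed straight into the existing loading helpers, as the updated demos in this commit do, for example:

```python
from keras_bert import load_trained_model_from_checkpoint, load_vocabulary

model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)
token_dict = load_vocabulary(paths.vocab)
```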

### Extract Features

You can use the helper function `extract_embeddings` if the features of tokens or sentences (without further fine-tuning) are what you need. To extract the features of all tokens:
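
A minimal sketch of such a call (the texts are made-up examples, and the model path here reuses the pretrained download from the previous section):

```python
from keras_bert import extract_embeddings, get_pretrained, PretrainedList

model_path = get_pretrained(PretrainedList.multi_cased_base)
texts = ['all work and no play', 'makes jack a dull boy']

embeddings = extract_embeddings(model_path, texts)
```
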
12 changes: 12 additions & 0 deletions README.zh-CN.md
@@ -173,6 +173,18 @@ optimizer = AdamWarmup(total_steps, warmup_steps, lr=1e-3, min_lr=1e-5)

When `training` is `True`, the input consists of three parts: token indices, segment indices, and the masked-token template. When `training` is `False`, the input contains only the first two. Since position indices are fixed, they are generated inside the model and do not need to be supplied manually. The masked-token template has the value 1 at positions where the input token is masked and 0 otherwise.
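
As a minimal sketch of that input layout (the id values below are made up for illustration, and `model` is assumed to have been built with `training=True`, as in `demo/load_model/load_and_predict.py`):

```python
import numpy as np

# One toy sequence of length 6; the values are illustrative only.
token_ids = np.array([[101, 2769, 7231, 103, 3221, 102]])  # token indices
segment_ids = np.zeros_like(token_ids)                      # every token in segment 0
masked_template = np.array([[0, 0, 0, 1, 0, 0]])            # 1 where a token is masked

# With training=True all three inputs are required;
# with training=False only token_ids and segment_ids are passed.
outputs = model.predict([token_ids, segment_ids, masked_template])
```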

### Download Pretrained Checkpoints

The library records download URLs for several pretrained models; the path of the extracted checkpoint can be obtained as follows:

```python
from keras_bert import get_pretrained, PretrainedList, get_checkpoint_paths

model_path = get_pretrained(PretrainedList.multi_cased_base)
paths = get_checkpoint_paths(model_path)
print(paths.config, paths.checkpoint, paths.vocab)
```

### Extract Features

If you only want token/sentence features without fine-tuning, `extract_embeddings` simplifies the process. For example, to extract the features of all tokens in each sentence:
23 changes: 10 additions & 13 deletions demo/load_model/load_and_extract.py
@@ -1,24 +1,21 @@
-import os
 import sys
 import numpy as np
-from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
-
-
-if len(sys.argv) != 2:
-    print('python load_model.py UNZIPPED_MODEL_PATH')
-    sys.exit(-1)
+from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
 
 print('This demo demonstrates how to load the pre-trained model and extract word embeddings')
 
-model_path = sys.argv[1]
-config_path = os.path.join(model_path, 'bert_config.json')
-checkpoint_path = os.path.join(model_path, 'bert_model.ckpt')
-dict_path = os.path.join(model_path, 'vocab.txt')
+if len(sys.argv) == 2:
+    model_path = sys.argv[1]
+else:
+    from keras_bert.datasets import get_pretrained, PretrainedList
+    model_path = get_pretrained(PretrainedList.chinese_base)
+
+paths = get_checkpoint_paths(model_path)
 
-model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=10)
+model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)
 model.summary(line_length=120)
 
-token_dict = load_vocabulary(dict_path)
+token_dict = load_vocabulary(paths.vocab)
 
 tokenizer = Tokenizer(token_dict)
 text = '语言模型'
23 changes: 10 additions & 13 deletions demo/load_model/load_and_pool.py
@@ -1,28 +1,25 @@
-import os
 import sys
 import numpy as np
 import keras
-from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
+from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
 from keras_bert.layers import MaskedGlobalMaxPool1D
-
-
-if len(sys.argv) != 2:
-    print('python load_model.py UNZIPPED_MODEL_PATH')
-    sys.exit(-1)
 
 print('This demo demonstrates how to load the pre-trained model and extract the sentence embedding with pooling.')
 
-model_path = sys.argv[1]
-config_path = os.path.join(model_path, 'bert_config.json')
-checkpoint_path = os.path.join(model_path, 'bert_model.ckpt')
-dict_path = os.path.join(model_path, 'vocab.txt')
+if len(sys.argv) == 2:
+    model_path = sys.argv[1]
+else:
+    from keras_bert.datasets import get_pretrained, PretrainedList
+    model_path = get_pretrained(PretrainedList.chinese_base)
+
+paths = get_checkpoint_paths(model_path)
 
-model = load_trained_model_from_checkpoint(config_path, checkpoint_path, seq_len=10)
+model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)
 pool_layer = MaskedGlobalMaxPool1D(name='Pooling')(model.output)
 model = keras.models.Model(inputs=model.inputs, outputs=pool_layer)
 model.summary(line_length=120)
 
-token_dict = load_vocabulary(dict_path)
+token_dict = load_vocabulary(paths.vocab)
 
 tokenizer = Tokenizer(token_dict)
 text = '语言模型'
23 changes: 10 additions & 13 deletions demo/load_model/load_and_predict.py
@@ -1,24 +1,21 @@
-import os
 import sys
 import numpy as np
-from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer
-
-
-if len(sys.argv) != 2:
-    print('python load_model.py UNZIPPED_MODEL_PATH')
-    sys.exit(-1)
+from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
 
 print('This demo demonstrates how to load the pre-trained model and check whether the two sentences are continuous')
 
-model_path = sys.argv[1]
-config_path = os.path.join(model_path, 'bert_config.json')
-checkpoint_path = os.path.join(model_path, 'bert_model.ckpt')
-dict_path = os.path.join(model_path, 'vocab.txt')
+if len(sys.argv) == 2:
+    model_path = sys.argv[1]
+else:
+    from keras_bert.datasets import get_pretrained, PretrainedList
+    model_path = get_pretrained(PretrainedList.chinese_base)
+
+paths = get_checkpoint_paths(model_path)
 
-model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True, seq_len=None)
+model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, training=True, seq_len=None)
 model.summary(line_length=120)
 
-token_dict = load_vocabulary(dict_path)
+token_dict = load_vocabulary(paths.vocab)
 token_dict_inv = {v: k for k, v in token_dict.items()}
 
 tokenizer = Tokenizer(token_dict)
1 change: 1 addition & 0 deletions keras_bert/__init__.py
@@ -3,3 +3,4 @@
from .tokenizer import Tokenizer
from .optimizers import *
from .util import *
from .datasets import *
1 change: 1 addition & 0 deletions keras_bert/datasets/__init__.py
@@ -0,0 +1 @@
from .pretrained import *
47 changes: 47 additions & 0 deletions keras_bert/datasets/pretrained.py
@@ -0,0 +1,47 @@
# coding=utf-8
from __future__ import unicode_literals

import os
import shutil
from collections import namedtuple
from keras_bert.backend import keras

__all__ = ['PretrainedInfo', 'PretrainedList', 'get_pretrained']


PretrainedInfo = namedtuple('PretrainedInfo', ['url', 'extract_name', 'target_name'])


class PretrainedList(object):

    __test__ = PretrainedInfo(
        'https://github.com/CyberZHG/keras-bert/archive/master.zip',
        'keras-bert-master',
        'keras-bert',
    )

    multi_cased_base = 'https://storage.googleapis.com/bert_models/2018_11_23/multi_cased_L-12_H-768_A-12.zip'
    chinese_base = 'https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip'
    wwm_uncased_large = 'https://storage.googleapis.com/bert_models/2019_05_30/wwm_uncased_L-24_H-1024_A-16.zip'
    wwm_cased_large = 'https://storage.googleapis.com/bert_models/2019_05_30/wwm_cased_L-24_H-1024_A-16.zip'
    chinese_wwm_base = PretrainedInfo(
        'https://storage.googleapis.com/hfl-rc/chinese-bert/chinese_wwm_L-12_H-768_A-12.zip',
        'publish',
        'chinese_wwm_L-12_H-768_A-12',
    )


def get_pretrained(info):
    path = info
    if isinstance(info, PretrainedInfo):
        path = info.url
    path = keras.utils.get_file(fname=os.path.split(path)[-1], origin=path, extract=True)
    base_part, file_part = os.path.split(path)
    file_part = file_part.split('.')[0]
    if isinstance(info, PretrainedInfo):
        extract_path = os.path.join(base_part, info.extract_name)
        target_path = os.path.join(base_part, info.target_name)
        if not os.path.exists(target_path):
            shutil.move(extract_path, target_path)
        file_part = info.target_name
    return os.path.join(base_part, file_part)
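
Since `keras_bert/__init__.py` now re-exports the datasets module, these helpers are importable from the top-level package as well. A small usage sketch (archives are cached by `keras.utils.get_file`, typically under `~/.keras/datasets`):

```python
from keras_bert import get_pretrained, PretrainedList

# A plain URL entry returns the directory named after the archive.
base_path = get_pretrained(PretrainedList.chinese_base)

# A PretrainedInfo entry renames the extracted 'publish' directory to
# 'chinese_wwm_L-12_H-768_A-12' before returning its path.
wwm_path = get_pretrained(PretrainedList.chinese_wwm_base)
```
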
2 changes: 1 addition & 1 deletion setup.py
@@ -11,7 +11,7 @@

 setup(
     name='keras-bert',
-    version='0.64.0',
+    version='0.65.0',
     packages=find_packages(),
     url='https://github.com/CyberZHG/keras-bert',
     license='MIT',
Empty file added tests/datasets/__init__.py
10 changes: 10 additions & 0 deletions tests/datasets/test_get_pretrained.py
@@ -0,0 +1,10 @@
import os
from unittest import TestCase
from keras_bert.datasets import get_pretrained, PretrainedList


class TestGetPretrained(TestCase):

    def test_get_pretrained(self):
        path = get_pretrained(PretrainedList.__test__)
        self.assertTrue(os.path.exists(os.path.join(path, 'README.md')))
