diff --git a/examples/custom_generator.py b/examples/custom_generator.py
new file mode 100644
index 00000000..4f73d484
--- /dev/null
+++ b/examples/custom_generator.py
@@ -0,0 +1,87 @@
+# encoding: utf-8
+
+# author: BrikerMan
+# contact: eliyar917@gmail.com
+# blog: https://eliyar.biz
+
+# file: custom_generator.py
+# time: 4:13 PM
+
+import os
+import linecache
+from tensorflow.keras.utils import get_file
+from kashgari.generators import ABCGenerator
+
+
+def download_data(duplicate=1000):
+    url_list = [
+        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.train.w-intent.iob',
+        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis-2.dev.w-intent.iob',
+        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.test.w-intent.iob',
+        'https://raw.githubusercontent.com/BrikerMan/JointSLU/master/data/atis.train.w-intent.iob'
+    ]
+    files = []
+    for url in url_list:
+        files.append(get_file(url.split('/')[-1], url))
+
+    return files * duplicate
+
+
+class ClassificationGenerator:
+    def __init__(self, files):
+        self.files = files
+        self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files)
+
+    @property
+    def steps(self) -> int:
+        return self._line_count
+
+    def __iter__(self):
+        for file in self.files:
+            with open(file, 'r') as f:
+                for line in f:
+                    rows = line.split('\t')
+                    x = rows[0].strip().split(' ')[1:-1]
+                    y = rows[1].strip().split(' ')[-1]
+                    yield x, y
+
+
+class LabelingGenerator(ABCGenerator):
+    def __init__(self, files):
+        self.files = files
+        self._line_count = sum(sum(1 for line in open(file, 'r')) for file in files)
+
+    @property
+    def steps(self) -> int:
+        return self._line_count
+
+    def __iter__(self):
+        for file in self.files:
+            with open(file, 'r') as f:
+                for line in f:
+                    rows = line.split('\t')
+                    x = rows[0].strip().split(' ')[1:-1]
+                    y = rows[1].strip().split(' ')[1:-1]
+                    yield x, y
+
+
+def run_classification_model():
+    from kashgari.tasks.classification import BiGRU_Model
+    files = download_data()
+    gen = ClassificationGenerator(files)
+
+    model = BiGRU_Model()
+    model.fit_generator(gen)
+
+
+def run_labeling_model():
+    from kashgari.tasks.labeling import BiGRU_Model
+    files = download_data()
+    gen = LabelingGenerator(files)
+
+    model = BiGRU_Model()
+    model.fit_generator(gen)
+
+
+if __name__ == "__main__":
+    run_classification_model()
diff --git a/kashgari/embeddings/abc_embedding.py b/kashgari/embeddings/abc_embedding.py
index 7e28165e..c09f3126 100644
--- a/kashgari/embeddings/abc_embedding.py
+++ b/kashgari/embeddings/abc_embedding.py
@@ -7,6 +7,7 @@
 # file: abc_embedding.py
 # time: 2:43 PM
 
+import tqdm
 import json
 import pydoc
 import logging
@@ -73,6 +74,8 @@ def __init__(self,
         self.segment = False  # no segment input is needed by default
         self.kwargs = kwargs
 
+        self.embedding_size = None
+
     def set_sequence_length(self, length: int):
         self.sequence_length = length
         if self.embed_model is not None:
@@ -80,6 +83,15 @@ def set_sequence_length(self, length: int):
             self.embed_model = None
             self.build_embedding_model()
 
+    def calculate_sequence_length_if_needs(self, corpus_gen: CorpusGenerator, cover_rate: float = 0.95):
+        if self.sequence_length is None:
+            seq_lens = []
+            for sentence, _ in tqdm.tqdm(corpus_gen, total=corpus_gen.steps,
+                                         desc="Calculating sequence length"):
+                seq_lens.append(len(sentence))
+            self.sequence_length = sorted(seq_lens)[int(cover_rate * len(seq_lens))]
+            logging.warning(f'Calculated sequence length = {self.sequence_length}')
+
     def build(self, x_data: TextSamplesVar, y_data: LabelSamplesVar):
         gen = CorpusGenerator(x_data=x_data, y_data=y_data)
         self.build_with_generator(gen)
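
# A minimal sketch (not part of the patch) of the cover_rate logic that
# calculate_sequence_length_if_needs applies above: sort the observed sequence
# lengths and index at cover_rate * len, which picks a length covering roughly
# that fraction of the corpus. The toy lengths below are hypothetical.

def pick_sequence_length(seq_lens, cover_rate=0.95):
    # ascending lengths; the entry at the cover_rate percentile becomes the fixed length
    return sorted(seq_lens)[int(cover_rate * len(seq_lens))]

assert pick_sequence_length(list(range(1, 101))) == 96
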
diff --git a/kashgari/embeddings/transformer_embedding.py b/kashgari/embeddings/transformer_embedding.py
index 09e460f9..233c5d8e 100644
--- a/kashgari/embeddings/transformer_embedding.py
+++ b/kashgari/embeddings/transformer_embedding.py
@@ -59,6 +59,7 @@ def __init__(self,
         self.segment = True
 
         self.vocab_list = []
+        self.max_sequence_length = None
 
     def build_text_vocab(self, gen: CorpusGenerator = None, force=False):
         if not self.text_processor.is_vocab_build:
@@ -78,14 +79,13 @@ def build_text_vocab(self, gen: CorpusGenerator = None, force=False):
 
     def build_embedding_model(self):
         if self.embed_model is None:
-            kwargs = {}
             config_path = self.config_path
             config = json.load(open(config_path))
-            if self.sequence_length:
-                if self.sequence_length > config.get('max_position_embeddings'):
-                    self.sequence_length = config.get('max_position_embeddings')
-                    logging.warning(f"Max seq length is {self.sequence_length}")
+            if 'max_position' in config:
+                self.max_sequence_length = config['max_position']
+            else:
+                self.max_sequence_length = config.get('max_position_embeddings')
 
             bert_model = build_transformer_model(config_path=self.config_path,
                                                  checkpoint_path=self.checkpoint_path,
@@ -94,6 +94,7 @@ def build_embedding_model(self):
                                                  return_keras_model=True)
 
             self.embed_model = bert_model
+            self.embedding_size = bert_model.output.shape[-1]
 
 
 if __name__ == "__main__":
diff --git a/kashgari/generators.py b/kashgari/generators.py
index febaf532..7d806d89 100644
--- a/kashgari/generators.py
+++ b/kashgari/generators.py
@@ -7,44 +7,43 @@
 # file: generator.py
 # time: 4:53 PM
 
+from abc import ABC
 import random
 from typing import List
+from typing import Iterable
 
 
-class CorpusGenerator:
+class ABCGenerator(Iterable, ABC):
 
-    def __init__(self, x_data: List, y_data: List):
+    @property
+    def steps(self) -> int:
+        raise NotImplementedError
+
+    def __iter__(self):
+        raise NotImplementedError
+
+
+class CorpusGenerator(ABCGenerator):
+
+    def __init__(self, x_data: List, y_data: List, shuffle=True):
         self.x_data = x_data
         self.y_data = y_data
         self._index_list = list(range(len(self.x_data)))
-        self._current_index = 0
-        random.shuffle(self._index_list)
+        if shuffle:
+            random.shuffle(self._index_list)
 
-    def reset(self):
-        self._current_index = 0
+    def __iter__(self):
+        for i in self._index_list:
+            yield self.x_data[i], self.y_data[i]
 
     @property
     def steps(self) -> int:
         return len(self.x_data)
 
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        self._current_index += 1
-        if self._current_index >= len(self.x_data) - 1:
-            raise StopIteration()
-
-        sample_index = self._index_list[self._current_index]
-        return self.x_data[sample_index], self.y_data[sample_index]
-
-    def __call__(self, *args, **kwargs):
-        return self
-
 
-class BatchDataGenerator:
+class BatchDataGenerator(Iterable):
 
     def __init__(self,
                  corpus,
                  text_processor,
@@ -66,27 +65,21 @@ def steps(self) -> int:
         return self.corpus.steps // self.batch_size
 
     def __iter__(self):
-        return self
-
-    def __next__(self):
-        x_set = []
-        y_set = []
-        for i in range(self.batch_size):
-            try:
-                x, y = next(self.corpus)
-            except StopIteration:
-                self.corpus.reset()
-                x, y = next(self.corpus)
+        x_set, y_set = [], []
+        for x, y in self.corpus:
             x_set.append(x)
             y_set.append(y)
+            if len(x_set) == self.batch_size:
+                x_tensor = self.text_processor.numerize_samples(x_set, seq_length=self.seq_length, segment=self.segment)
+                y_tensor = self.label_processor.numerize_samples(y_set, seq_length=self.seq_length, one_hot=True)
+                yield x_tensor, y_tensor
+                x_set, y_set = [], []
+        # flush the final partial batch
+        if x_set:
+            x_tensor = self.text_processor.numerize_samples(x_set, seq_length=self.seq_length, segment=self.segment)
+            y_tensor = self.label_processor.numerize_samples(y_set, seq_length=self.seq_length, one_hot=True)
+            yield x_tensor, y_tensor
 
-        x_tensor = self.text_processor.numerize_samples(x_set, seq_length=self.seq_length, segment=self.segment)
-        y_tensor = self.label_processor.numerize_samples(y_set, seq_length=self.seq_length, one_hot=True)
-        return x_tensor, y_tensor
-
-    def __call__(self, *args, **kwargs):
-        return self
-
-
-if __name__ == "__main__":
-    pass
+    def generator(self):
+        for item in self:
+            yield item
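
# A quick sketch (not part of the patch) of the new iteration protocol:
# CorpusGenerator is now a plain re-iterable, so consumers can make several
# full passes without the old reset() calls. Toy data, shuffle disabled.

from kashgari.generators import CorpusGenerator

corpus = CorpusGenerator(x_data=[['hello'], ['world']],
                         y_data=['greeting', 'noun'],
                         shuffle=False)
assert corpus.steps == 2
assert list(corpus) == [(['hello'], 'greeting'), (['world'], 'noun')]
assert list(corpus) == list(corpus)  # a second pass needs no reset()
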
diff --git a/kashgari/processors/abc_processor.py b/kashgari/processors/abc_processor.py
index 2735450a..9d053b9b 100644
--- a/kashgari/processors/abc_processor.py
+++ b/kashgari/processors/abc_processor.py
@@ -25,8 +25,6 @@ def __init__(self, **kwargs):
         self.vocab2idx = kwargs.get('vocab2idx', {})
         self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()])
 
-        self.corpus_sequence_length = kwargs.get('corpus_sequence_length', None)
-
     @property
     def vocab_size(self) -> int:
         return len(self.vocab2idx)
@@ -35,7 +33,7 @@ def vocab_size(self) -> int:
     def is_vocab_build(self) -> bool:
         return self.vocab_size != 0
 
-    def build_vocab_dict_if_needs(self, generator: Generator, min_count: int = 3):
+    def build_vocab_dict_if_needs(self, generator: Generator):
         raise NotImplementedError
diff --git a/kashgari/processors/class_processor.py b/kashgari/processors/class_processor.py
index 7c2fe253..665c6a95 100644
--- a/kashgari/processors/class_processor.py
+++ b/kashgari/processors/class_processor.py
@@ -22,7 +22,6 @@ class ClassificationProcessor(ABCProcessor):
 
     def build_vocab_dict_if_needs(self, generator: CorpusGenerator):
-        generator.reset()
         if not self.vocab2idx:
             vocab2idx = {}
diff --git a/kashgari/processors/sequence_processor.py b/kashgari/processors/sequence_processor.py
index 9b6f0028..a3147109 100644
--- a/kashgari/processors/sequence_processor.py
+++ b/kashgari/processors/sequence_processor.py
@@ -14,6 +14,7 @@
 import tqdm
 import numpy as np
 from typing import Dict, List
+
 from tensorflow.keras.preprocessing.sequence import pad_sequences
 from tensorflow.keras.utils import to_categorical
 
@@ -73,19 +74,19 @@ def __init__(self,
         else:
             self._initial_vocab_dic = {}
 
+        self._showed_seq_len_warning = False
+
     def build_vocab_dict_if_needs(self, generator: CorpusGenerator):
         if not self.vocab2idx:
             vocab2idx = self._initial_vocab_dic
             token2count = {}
-            seq_lens = []
-            generator.reset()
+
             for sentence, label in tqdm.tqdm(generator, total=generator.steps,
                                              desc="Preparing text vocab dict"):
                 if self.vocab_dict_type == 'text':
                     target = sentence
                 else:
                     target = label
-                seq_lens.append(len(target))
                 for token in target:
                     count = token2count.get(token, 0)
                     token2count[token] = count + 1
@@ -101,20 +102,10 @@ def build_vocab_dict_if_needs(self, generator: CorpusGenerator):
             self.vocab2idx = vocab2idx
             self.idx2vocab = dict([(v, k) for k, v in self.vocab2idx.items()])
 
-            if self.corpus_sequence_length is None:
-                self.corpus_sequence_length = sorted(seq_lens)[int(0.95 * len(seq_lens))]
-
             logging.info("------ Build vocab dict finished, Top 10 token ------")
             for token, index in list(self.vocab2idx.items())[:10]:
                 logging.info(f"Token: {token:8s} -> {index}")
             logging.info("------ Build vocab dict finished, Top 10 token ------")
-        else:
-            if self.corpus_sequence_length is None:
-                seq_lens = []
-                generator.reset()
-                for sentence, _ in generator:
-                    seq_lens.append(len(sentence))
-                self.corpus_sequence_length = sorted(seq_lens)[int(0.95 * len(seq_lens))]
 
     def numerize_samples(self,
                          samples: TextSamplesVar,
@@ -124,8 +115,10 @@ def numerize_samples(self,
                          **kwargs) -> np.ndarray:
         if seq_length is None:
             seq_length = max([len(i) for i in samples])
-            logging.warning(
-                f'Sequence length is None, will use the max length of the samples, which is {seq_length}')
+            if not self._showed_seq_len_warning:
+                logging.warning(
+                    f'Sequence length is None, will use the max length of the samples, which is {seq_length}')
+                self._showed_seq_len_warning = True
 
         numerized_samples = []
         for seq in samples:
@@ -173,6 +166,3 @@ def reverse_numerize(self,
 
     p.build_vocab_dict_if_needs(gen)
     print(p.vocab2idx)
-
-    p2 = SequenceProcessor()
-    p2.build_vocab_dict_if_needs(gen)
-    print(p2.vocab2idx)
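
# A short sketch (not part of the patch) of the simplified processor API:
# build_vocab_dict_if_needs now takes only a generator, and because the
# generator is re-iterable, two processors can consume it back to back.
# Toy corpus; mirrors tests/test_generator.py at the end of this patch.

from kashgari.generators import CorpusGenerator
from kashgari.processors import SequenceProcessor

x_set = [['token_a', 'token_b'], ['token_a']]
y_set = [['B', 'O'], ['O']]
gen = CorpusGenerator(x_set, y_set, shuffle=False)

text_processor = SequenceProcessor(min_count=1)
text_processor.build_vocab_dict_if_needs(gen)    # first full pass
label_processor = SequenceProcessor(vocab_dict_type='labeling', min_count=1)
label_processor.build_vocab_dict_if_needs(gen)   # second pass, no reset() needed
assert 'token_a' in text_processor.vocab2idx
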
diff --git a/kashgari/tasks/abs_task_model.py b/kashgari/tasks/abs_task_model.py
index 3173d96f..410328c6 100644
--- a/kashgari/tasks/abs_task_model.py
+++ b/kashgari/tasks/abs_task_model.py
@@ -102,6 +102,7 @@ def build_model(self,
                 raise ValueError('Need to set default_labeling_processor')
             self.embedding.label_processor = self.default_labeling_processor
         self.embedding.build_with_generator(train_gen)
+        self.embedding.calculate_sequence_length_if_needs(train_gen)
         if self.tf_model is None:
             self.build_model_arc()
             self.compile_model()
diff --git a/kashgari/tasks/classification/abc_model.py b/kashgari/tasks/classification/abc_model.py
index 2fb10cf5..c6fbd150 100644
--- a/kashgari/tasks/classification/abc_model.py
+++ b/kashgari/tasks/classification/abc_model.py
@@ -131,13 +131,13 @@ def fit_generator(self,
                                           segment=self.embedding.segment,
                                           seq_length=self.embedding.sequence_length,
                                           batch_size=batch_size)
-            fit_kwargs['validation_data'] = valid_gen
+            fit_kwargs['validation_data'] = valid_gen.generator()
             fit_kwargs['validation_steps'] = valid_gen.steps
 
         if callbacks:
             fit_kwargs['callbacks'] = callbacks
 
-        return self.tf_model.fit(train_gen,
+        return self.tf_model.fit(train_gen.generator(),
                                  steps_per_epoch=train_gen.steps,
                                  epochs=epochs,
-                                 callbacks=callbacks)
+                                 **fit_kwargs)
diff --git a/kashgari/tasks/labeling/__init__.py b/kashgari/tasks/labeling/__init__.py
index 927d1378..d05830f4 100644
--- a/kashgari/tasks/labeling/__init__.py
+++ b/kashgari/tasks/labeling/__init__.py
@@ -7,6 +7,7 @@
 # file: __init__.py
 # time: 4:30 PM
 
+from .bi_gru_model import BiGRU_Model
 from .bi_lstm_model import BiLSTM_Model
 
 if __name__ == "__main__":
diff --git a/kashgari/tasks/labeling/abc_model.py b/kashgari/tasks/labeling/abc_model.py
index 6fdd8ad3..9e2f75ac 100644
--- a/kashgari/tasks/labeling/abc_model.py
+++ b/kashgari/tasks/labeling/abc_model.py
@@ -128,18 +128,12 @@ def fit_generator(self,
                                           segment=self.embedding.segment,
                                           seq_length=self.embedding.sequence_length,
                                           batch_size=batch_size)
-            fit_kwargs['validation_data'] = valid_gen
+            fit_kwargs['validation_data'] = valid_gen.generator()
             fit_kwargs['validation_steps'] = valid_gen.steps
 
         if callbacks:
            fit_kwargs['callbacks'] = callbacks
 
-        (x0, x1), y = next(train_gen)
-        print('-'*20)
-        print(x0.shape)
-        print(x1.shape)
-        print(y.shape)
-
-        return self.tf_model.fit(train_gen,
+        return self.tf_model.fit(train_gen.generator(),
                                  steps_per_epoch=train_gen.steps,
                                  epochs=epochs,
-                                 callbacks=callbacks)
+                                 **fit_kwargs)
diff --git a/kashgari/tasks/labeling/bi_gru_model.py b/kashgari/tasks/labeling/bi_gru_model.py
new file mode 100644
index 00000000..bcebb264
--- /dev/null
+++ b/kashgari/tasks/labeling/bi_gru_model.py
@@ -0,0 +1,56 @@
+# encoding: utf-8
+
+# author: BrikerMan
+# contact: eliyar917@gmail.com
+# blog: https://eliyar.biz
+
+# file: bi_gru_model.py
+# time: 5:01 PM
+
+from typing import Dict, Any
+
+from tensorflow import keras
+
+from kashgari.layers import L
+from kashgari.tasks.labeling.abc_model import ABCLabelingModel
+
+
+class BiGRU_Model(ABCLabelingModel):
+    @classmethod
+    def default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
+        return {
+            'layer_bgru': {
+                'units': 128,
+                'return_sequences': True
+            },
+            'layer_dropout': {
+                'rate': 0.4
+            },
+            'layer_time_distributed': {},
+            'layer_activation': {
+                'activation': 'softmax'
+            }
+        }
+
+    def build_model_arc(self):
+        output_dim = self.label_processor.vocab_size
+
+        config = self.hyper_parameters
+        embed_model = self.embedding.embed_model
+
+        layer_stack = [
+            L.Bidirectional(L.GRU(**config['layer_bgru']), name='layer_bgru'),
+            L.Dropout(**config['layer_dropout'], name='layer_dropout'),
+            L.TimeDistributed(L.Dense(output_dim, **config['layer_time_distributed']), name='layer_time_distributed'),
+            L.Activation(**config['layer_activation'])
+        ]
+
+        tensor = embed_model.output
+        for layer in layer_stack:
+            tensor = layer(tensor)
+
+        self.tf_model = keras.Model(embed_model.inputs, tensor)
+
+
+if __name__ == "__main__":
+    pass
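
# A minimal usage sketch (not part of the patch) for the BiGRU_Model added
# above, via the plain fit() path; build_model() also auto-calculates the
# sequence length from this corpus. The two-sentence corpus is hypothetical.

from kashgari.tasks.labeling import BiGRU_Model

x = [['I', 'fly', 'to', 'Boston'], ['Book', 'a', 'flight']]
y = [['O', 'O', 'O', 'B-city'], ['O', 'O', 'O']]

model = BiGRU_Model()
model.fit(x, y, epochs=1)
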
diff --git a/tests/test_embeddings/test_bare_embedding.py b/tests/test_embeddings/test_bare_embedding.py
index 5907e3c8..76936c8d 100644
--- a/tests/test_embeddings/test_bare_embedding.py
+++ b/tests/test_embeddings/test_bare_embedding.py
@@ -18,18 +18,19 @@ def test_base_cases(self):
         x, y = SMP2018ECDTCorpus.load_data()
         embedding = BareEmbedding()
         embedding.build(x, y)
-        res = embedding.embed(x[:2])
-        assert res.shape == (2, 15, 100)
+        res = embedding.embed(x[:10])
+        max_len = max([len(i) for i in x[:10]])
+        assert res.shape == (10, max_len, 100)
 
         embedding.set_sequence_length(30)
         res = embedding.embed(x[:2])
         assert res.shape == (2, 30, 100)
 
         x, y = ChineseDailyNerCorpus.load_data()
-        embedding2 = BareEmbedding(sequence_length=30, embedding_size=32)
+        embedding2 = BareEmbedding(sequence_length=25, embedding_size=32)
         embedding2.build(x, y)
         res = embedding2.embed(x[:2])
-        assert res.shape == (2, 30, 32)
+        assert res.shape == (2, 25, 32)
 
 
 if __name__ == "__main__":
diff --git a/tests/test_embeddings/test_transformer_embedding.py b/tests/test_embeddings/test_transformer_embedding.py
index 26d22854..ceeca867 100644
--- a/tests/test_embeddings/test_transformer_embedding.py
+++ b/tests/test_embeddings/test_transformer_embedding.py
@@ -16,24 +16,32 @@
 from tests.test_macros import TestMacros
 
+bert_path = get_file('bert_sample_model',
+                     "http://s3.bmio.net/kashgari/bert_sample_model.tar.bz2",
+                     cache_dir=DATA_PATH,
+                     untar=True)
+
 
 class TestTransferEmbedding(unittest.TestCase):
 
     def test_bert_embedding(self):
-        bert_path = get_file('bert_sample_model',
-                             "http://s3.bmio.net/kashgari/bert_sample_model.tar.bz2",
-                             cache_dir=DATA_PATH,
-                             untar=True)
-        embedding = BertEmbedding(model_folder=bert_path, sequence_length=12)
-
-        # --- classification ----
+        sequence_length = 12
+        embedding = BertEmbedding(model_folder=bert_path,
+                                  sequence_length=sequence_length)
         x, y = TestMacros.load_classification_corpus()
-        embedding.embed(x)
+        res = embedding.embed(x[:10])
+        assert res.shape == (10, sequence_length, embedding.embedding_size)
 
+    def test_classification_task(self):
+        embedding = BertEmbedding(model_folder=bert_path, sequence_length=12)
         model = Classification_BiLSTM_Model(embedding=embedding)
+
+        x, y = TestMacros.load_classification_corpus()
         model.fit(x, y, epochs=1)
 
+    def test_label_task(self):
         # ------ labeling -------
+        embedding = BertEmbedding(model_folder=bert_path, sequence_length=12)
         x, y = TestMacros.load_labeling_corpus()
         model = BiLSTM_Model(embedding=embedding)
diff --git a/tests/test_embeddings/test_word_embedding.py b/tests/test_embeddings/test_word_embedding.py
index e9404a2a..f0abee4d 100644
--- a/tests/test_embeddings/test_word_embedding.py
+++ b/tests/test_embeddings/test_word_embedding.py
@@ -31,8 +31,9 @@ def test_base_cases(self):
         x, y = SMP2018ECDTCorpus.load_data()
         embedding = WordEmbedding(self.w2v_path)
         embedding.build(x, y)
-        res = embedding.embed(x[:2])
-        assert res.shape == (2, 15, 100)
+        res = embedding.embed(x[:10])
+        max_len = max([len(i) for i in x[:10]])
+        assert res.shape == (10, max_len, 100)
 
         embedding.set_sequence_length(30)
         res = embedding.embed(x[:2])
diff --git a/tests/test_generator.py b/tests/test_generator.py
new file mode 100644
index 00000000..81a686cd
--- /dev/null
+++ b/tests/test_generator.py
@@ -0,0 +1,64 @@
+# encoding: utf-8
+
+# author: BrikerMan
+# contact: eliyar917@gmail.com
+# blog: https://eliyar.biz
+
+# file: test_generator.py
+# time: 5:46 PM
+
+import unittest
+from tests.test_macros import TestMacros
+
+from kashgari.generators import CorpusGenerator, BatchDataGenerator
+
+
+class TestGenerator(unittest.TestCase):
+    def test_corpus_generator(self):
+        x_set, y_set = TestMacros.load_labeling_corpus('custom_1')
+        corpus_gen = CorpusGenerator(x_set, y_set)
+
+        for x, y in corpus_gen:
+            print(x, y)
+
+    def test_batch_generator(self):
+        from kashgari.processors import SequenceProcessor
+        x_set, y_set = [], []
+        for i in range(22):
+            x_set.append([f'x{i}'] * 4)
+            y_set.append([f'y{i}'] * 4)
+        corpus_gen = CorpusGenerator(x_set, y_set, shuffle=False)
+
+        a = []
+        for x, y in corpus_gen:
+            print(x, y)
+            a.append(x[0])
+
+        print(sorted(a))
+
+        p1 = SequenceProcessor(min_count=1)
+        p1.build_vocab_dict_if_needs(corpus_gen)
+        p2 = SequenceProcessor(vocab_dict_type='labeling', min_count=1)
+        p2.build_vocab_dict_if_needs(corpus_gen)
+
+        batch_gen = BatchDataGenerator(corpus_gen,
+                                       text_processor=p1,
+                                       label_processor=p2,
+                                       seq_length=5,
+                                       batch_size=4)
+        print('------ Iterator --------')
+        for i in batch_gen:
+            x, y = i
+            print(x)
+
+        print('------ Generator --------')
+        gen = batch_gen.generator()
+        try:
+            while True:
+                x, y = next(gen)
+                print(x)
+        except StopIteration:
+            pass
+
+
+if __name__ == '__main__':
+    unittest.main()
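
# A final sketch (not part of the patch) pinning down one subtlety of the
# rewritten BatchDataGenerator that test_batch_generator above exercises:
# steps floors to full batches, while __iter__ also yields a trailing
# partial batch. With 22 samples and batch_size 4:

batch_size, n_samples = 4, 22
assert n_samples // batch_size == 5                                   # value reported by .steps
sizes = [min(batch_size, n_samples - i) for i in range(0, n_samples, batch_size)]
assert sizes == [4, 4, 4, 4, 4, 2]                                    # batches actually yielded
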