Skip to content
Jason Lin edited this page Feb 9, 2022 · 2 revisions

Fern Wiki

这个仓库主要目的是基于tensorflow快速实现常用算法模型

安装

pip install Fern2

常用模型使用方法

Word2Vec

训练
from fern.estimate import word2vec_estimator

path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
with open(path_to_file, 'r') as f:
    data = []
    for line in f:
        if line.strip():
            data.append(line.strip())
model, vocab, loss = word2vec_estimator(
    data=data,                      # 训练数据,句子列表 
    batch_size=128,                 # 批训练大小
    zh_segmentation=True,           # 是否中文分词,否则中文按字分词
    model_home='/tmp/Fern/model',   # 模型保存位置
    ckpt_home='/tmp/Fern/ckpt',     # 模型检查点保存位置
    log_home='/tmp/Fern/log',       # 训练日志保存位置
    epoch=200,                      # 数据训练多少个轮回
    embedding_dim=128,              # 词向量大小 
    version=1,                      # 模型版本号,影响模型保存路径
    opt='adam',                     # 训练的优化器名字
    lr=1.0e-3,                      # 训练的学习率
    lower=True,                     # 是否忽略大小写
    win_size=2,                     # 生成训练数据的窗口大小,影响上下文距离
    random_seed=42,                 # 随机种子,用于训练复现
    verbose=1)                      # 0不打印日志,1打印epoch日志,2打印batch日志
print(loss)
使用
from fern.models import Word2Vec
import numpy as np

data = ['friends', 'state', 'test']
model = Word2Vec.load('/tmp/Fern/model/word2vec/1')
vocab = {word: index for index, word in enumerate(vocab)}

data = [vocab.get(da, vocab.get('[UNK]')) for da in data]

data_vec = model(np.array(data))  # (3, embedding_dim)
Clone this wiki locally