-
Notifications
You must be signed in to change notification settings · Fork 1
Home
Jason Lin edited this page Feb 9, 2022
·
2 revisions
这个仓库主要目的是基于tensorflow快速实现常用算法模型 (The main purpose of this repository is to quickly implement common algorithm models based on TensorFlow.)
pip install Fern2
# Train a word2vec model on the Shakespeare corpus with Fern's estimator API.
# NOTE(review): the original example used `tf` without importing it; the
# import below is required for `tf.keras.utils.get_file` to resolve.
import tensorflow as tf

from fern.estimate import word2vec_estimator

path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')

# Collect the non-empty lines of the corpus; the estimator expects a list
# of sentence strings. `:=` avoids stripping each line twice.
with open(path_to_file, 'r', encoding='utf-8') as f:
    data = [stripped for line in f if (stripped := line.strip())]

model, vocab, loss = word2vec_estimator(
    data=data,                       # training data: a list of sentences
    batch_size=128,                  # mini-batch size
    zh_segmentation=True,            # segment Chinese into words; otherwise split Chinese by character
    model_home='/tmp/Fern/model',    # where the trained model is saved
    ckpt_home='/tmp/Fern/ckpt',      # where checkpoints are saved
    log_home='/tmp/Fern/log',        # where training logs are saved
    epoch=200,                       # number of training epochs
    embedding_dim=128,               # dimensionality of the word vectors
    version=1,                       # model version number; affects the save path
    opt='adam',                      # optimizer name
    lr=1.0e-3,                       # learning rate
    lower=True,                      # lowercase the input (ignore case)
    win_size=2,                      # context window size for generating training pairs
    random_seed=42,                  # random seed, for reproducible training
    verbose=1)                       # 0: no logs, 1: per-epoch logs, 2: per-batch logs
print(loss)
# Load the saved word2vec model and look up embeddings for a few words.
from fern.models import Word2Vec
import numpy as np

words = ['friends', 'state', 'test']
model = Word2Vec.load('/tmp/Fern/model/word2vec/1')

# Turn the vocabulary list (from training) into a token -> row-index map.
vocab = dict((token, idx) for idx, token in enumerate(vocab))

# Fall back to the [UNK] id for out-of-vocabulary words.
unk_id = vocab.get('[UNK]')
data = [vocab.get(word, unk_id) for word in words]

data_vec = model(np.array(data))  # (3, embedding_dim)