Commit b889a16 (1 parent: 6b085cc)
Showing 4 changed files with 223 additions and 0 deletions.
@@ -0,0 +1,94 @@
import argparse
from util import load_datasets
from model import SentClassif
import time
import torch.nn as nn
from torch.optim import Adadelta
from torch.optim.lr_scheduler import ReduceLROnPlateau
from train import train_model, evaluate
import torch
import numpy as np
import random
from tensorboardX import SummaryWriter


# Fix the random seeds so runs are reproducible.
seed_num = 42
random.seed(seed_num)
torch.manual_seed(seed_num)
np.random.seed(seed_num)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Convolutional Neural Networks for Sentence Classification')
    parser.add_argument('--word_dim', type=int, default=300)
    parser.add_argument('--out_dim', type=int, default=100, help='number of output channels per convolution')
    parser.add_argument('--dropout', type=float, default=0.5)
    parser.add_argument('--savedir', default='data/model/')
    parser.add_argument('--batch_size', type=int, default=50)
    parser.add_argument('--epochs', type=int, default=50)
    parser.add_argument('--early_stop', type=int, default=10)
    parser.add_argument('--least_epoch', type=int, default=15)
    parser.add_argument('--optimizer', choices=['Adadelta', 'Adam', 'Sgd'], default='Adadelta')
    parser.add_argument('--lr', type=float, default=1.)
    parser.add_argument('--filters', type=int, nargs='+', default=[3, 4, 5])
    parser.add_argument('--pretrain', action='store_true')

    args = parser.parse_args()

    data_iters, text_vocab, label_vocab = load_datasets(args.batch_size, args.pretrain)
    # The label vocabulary includes the <unk> and <pad> specials, which never occur
    # as targets, so the classifier only needs len(label_vocab) - 2 output classes.
    label_vocab_size = len(label_vocab) - 2
    model = SentClassif(args.word_dim, args.out_dim, label_vocab_size, text_vocab, args.dropout, args.filters)
    criterion = nn.CrossEntropyLoss()
    # NOTE: only Adadelta is wired up here; the --optimizer choice is not consulted.
    optimizer = Adadelta(model.parameters(), lr=args.lr)

    # for name, param in model.named_parameters():
    #     print(name)
    #     print(param.size())
    #     print('*' * 100)

    best_acc = -1
    patience_count = 0
    model_name = args.savedir + 'best.pt'
    writer = SummaryWriter('log')

    train_begin = time.time()

    print('training begins, use pretrained word vectors:', args.pretrain, time.asctime(time.localtime(time.time())))
    print('*' * 100)
    print()

    for epoch in range(args.epochs):
        epoch_begin = time.time()

        print('train {}/{} epoch starting:'.format(epoch + 1, args.epochs))
        loss = train_model(data_iters['train_iter'], model, criterion, optimizer)
        acc = evaluate(data_iters['dev_iter'], model)
        print('acc:', acc)
        writer.add_scalar('dev_acc', acc, epoch)
        if acc > best_acc:
            # New best dev accuracy: reset the patience counter and checkpoint the model.
            patience_count = 0
            best_acc = acc
            print('new best_acc:', best_acc)
            torch.save(model.state_dict(), model_name)
        else:
            patience_count += 1

        epoch_end = time.time()
        cost_time = epoch_end - epoch_begin
        print('epoch {} took {:.2f}s'.format(epoch + 1, cost_time))
        print('-' * 100)
        print()
        writer.add_scalar('train_loss', loss, epoch)
        # Early stopping: give up after `early_stop` epochs without improvement,
        # but only once at least `least_epoch` epochs have run.
        if patience_count > args.early_stop and epoch + 1 > args.least_epoch:
            break

    train_end = time.time()

    train_cost = train_end - train_begin
    hour = int(train_cost / 3600)
    minute = int((train_cost % 3600) / 60)
    second = train_cost % 3600 % 60
    print('train total cost {}h {}m {:.2f}s'.format(hour, minute, second))
    # Reload the best dev checkpoint before measuring test accuracy.
    model.load_state_dict(torch.load(model_name))
    test_acc = evaluate(data_iters['test_iter'], model)
    print('The test accuracy:', test_acc)
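
After training, the best dev checkpoint can be reloaded for standalone evaluation. A minimal sketch, assuming the modules are named util.py, model.py and train.py to match the imports above, the default hyperparameters, and that data/model/best.pt was written by an earlier run:

import torch
from util import load_datasets
from model import SentClassif
from train import evaluate

# Rebuild the model with the training-time hyperparameters, then restore
# the weights saved at the best dev-accuracy epoch.
data_iters, text_vocab, label_vocab = load_datasets(50, False)
model = SentClassif(300, 100, len(label_vocab) - 2, text_vocab, 0.5, [3, 4, 5])
model.load_state_dict(torch.load('data/model/best.pt'))
print('test acc:', evaluate(data_iters['test_iter'], model))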
@@ -0,0 +1,47 @@
import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np


class SentClassif(nn.Module):
    def __init__(self, word_dim, out_dim, label_size, vocab, dropout, filters):
        super(SentClassif, self).__init__()

        self.embedding = nn.Embedding(len(vocab), word_dim)
        self.cnn_list = nn.ModuleList()
        # Initialise the embedding table from pretrained vectors when available,
        # otherwise from a uniform random distribution.
        if vocab.vectors is not None:
            self.embedding.weight.data.copy_(vocab.vectors)
        else:
            self.embedding.weight.data.copy_(torch.from_numpy(self.random_embedding(len(vocab), word_dim)))
        # One 1-D convolution per filter width.
        for kernel_size in filters:
            self.cnn_list.append(nn.Conv1d(word_dim, out_dim, kernel_size, padding=int((kernel_size - 1) / 2)))

        # The pooled outputs of all filter widths are concatenated before the classifier.
        self.out2tag = nn.Linear(out_dim * len(filters), label_size)

        self.drop = nn.Dropout(dropout)

    def random_embedding(self, vocab_size, word_dim):
        pretrain_emb = np.empty([vocab_size, word_dim])
        scale = np.sqrt(3.0 / word_dim)
        for index in range(vocab_size):
            pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, word_dim])
        return pretrain_emb

    def forward(self, word_input):
        # word_input: (batch_size, seq_len) token indices.
        batch_size = word_input.size(0)
        word_represent = self.embedding(word_input)      # (batch, seq_len, word_dim)
        word_represent = word_represent.transpose(1, 2)  # (batch, word_dim, seq_len) for Conv1d
        out = []
        for cnn in self.cnn_list:
            cnn_out = cnn(word_represent)
            cnn_out = F.relu(cnn_out)
            # Max-over-time pooling collapses the sequence dimension.
            cnn_out = F.max_pool1d(cnn_out, cnn_out.size(2)).view(batch_size, -1)
            out.append(cnn_out)

        cat_out = torch.cat(out, 1)                      # (batch, out_dim * len(filters))

        out = self.drop(cat_out)
        tag_score = self.out2tag(out)

        return tag_score
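
A quick shape check of the forward pass; a minimal sketch that assumes a hypothetical 1000-token vocabulary without pretrained vectors and the default hyperparameters from the training script, with label_size set to the 6 coarse TREC classes:

import torch

class DummyVocab:
    # Hypothetical stand-in vocabulary: 1000 entries, no pretrained vectors.
    vectors = None
    def __len__(self):
        return 1000

model = SentClassif(word_dim=300, out_dim=100, label_size=6,
                    vocab=DummyVocab(), dropout=0.5, filters=[3, 4, 5])
word_input = torch.randint(0, 1000, (50, 37))  # (batch_size, seq_len) token indices
scores = model(word_input)
print(scores.shape)  # torch.Size([50, 6]): one raw score per class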
@@ -0,0 +1,40 @@
import torch.nn.functional as F
import torch


def train_model(train_iter, model, criterion, optimizer):
    model.train()
    total_loss = 0.

    for batch in train_iter:
        model.zero_grad()
        word_input = batch.text
        # torchtext prepends <unk> and <pad> to the label vocabulary, so the real
        # classes start at index 2; shift them down to 0..C-1 for CrossEntropyLoss.
        target = batch.label - 2
        tag_score = model(word_input)
        tag_score = tag_score.view(-1, tag_score.size(1))
        loss = criterion(tag_score, target.view(-1))
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print('total_loss:', total_loss)
    return total_loss


def evaluate(val_or_test_iter, model):
    model.eval()
    correct_num = 0
    total_num = 0

    # No gradients are needed during evaluation.
    with torch.no_grad():
        for batch in val_or_test_iter:
            word_input = batch.text
            target = batch.label - 2
            target = target.view(-1)
            total_num += len(target)
            tag_score = model(word_input)
            # Softmax does not change the argmax; it only normalises the scores.
            tag_score = F.softmax(tag_score, dim=1)
            _, preds = torch.max(tag_score, 1)
            correct_num += torch.sum((preds == target)).item()

    acc = (correct_num / total_num) * 100
    return acc
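
A side note on the score handling above: nn.CrossEntropyLoss applies log-softmax internally, so train_model feeds raw scores to the criterion, and the softmax in evaluate never changes the argmax used for accuracy. A tiny illustration with made-up scores:

import torch
import torch.nn as nn
import torch.nn.functional as F

scores = torch.tensor([[2.0, 0.5, -1.0]])     # raw scores for one example
target = torch.tensor([0])
loss = nn.CrossEntropyLoss()(scores, target)  # log-softmax happens inside the loss
same = scores.argmax(1) == F.softmax(scores, dim=1).argmax(1)
print(loss.item(), same.item())               # softmax leaves the prediction unchanged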
@@ -0,0 +1,42 @@
from torchtext import data
from torchtext.datasets.trec import TREC
from torchtext.data import BucketIterator
from torchtext.vocab import Vectors
import re
import os

SPACE_NORMALIZER = re.compile(r"\s+")


def tokenize_line(line):
    # Collapse runs of whitespace, then split on single spaces.
    line = SPACE_NORMALIZER.sub(" ", line)
    line = line.strip()
    return line.split()


def load_datasets(batch_size, pretrain):
    text = data.Field(tokenize=tokenize_line, lower=True, batch_first=True)
    label = data.Field(tokenize=tokenize_line, batch_first=True)
    # TREC ships only train and test splits; carve a dev set out of the training data.
    train_dev_data, test_data = TREC.splits(text_field=text, label_field=label, root='data')
    train_data, dev_data = train_dev_data.split(split_ratio=0.9)
    if pretrain:
        print('use pretrained word vectors')
        cache = '.vector_cache'
        if not os.path.exists(cache):
            os.mkdir(cache)
        vectors = Vectors(name='data/glove/glove.6B.300d.txt', cache=cache)
        text.build_vocab(train_data, dev_data, test_data, vectors=vectors)
    else:
        text.build_vocab(train_data, dev_data, test_data)
    label.build_vocab(train_data)

    # Bucket sentences of similar length together to minimise padding within a batch.
    train_iter, dev_iter, test_iter = BucketIterator.splits(
        (train_data, dev_data, test_data),
        batch_sizes=(batch_size, batch_size, batch_size),
        sort_key=lambda x: len(x.text), sort_within_batch=True, repeat=False)
    data_iters = {'train_iter': train_iter, 'dev_iter': dev_iter, 'test_iter': test_iter}
    print('vocabulary size:', len(text.vocab))

    return data_iters, text.vocab, label.vocab


if __name__ == '__main__':
    load_datasets(50, False)
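
A brief inspection sketch of what load_datasets returns, assuming the TREC download succeeds under data/:

data_iters, text_vocab, label_vocab = load_datasets(4, False)
batch = next(iter(data_iters['train_iter']))
# Both Fields set batch_first=True, so the tensors are (batch, seq_len).
print(batch.text.shape)   # (4, longest_sentence_in_batch), padded token indices
print(batch.label.shape)  # (4, 1); label indices 0 and 1 are the <unk>/<pad> specials
print(len(text_vocab), len(label_vocab))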