From b4d40d4da8c724cdbb2eac5e48abccff4ef6a660 Mon Sep 17 00:00:00 2001
From: MysteryVaibhav
Date: Wed, 23 Jan 2019 02:45:09 -0500
Subject: [PATCH 1/4] Adding pytorch version for 05-cnn

---
 05-cnn-pytorch/cnn-activation.py | 147 +++++++++++++++++++++++++++++++
 05-cnn-pytorch/cnn-class.py      | 108 +++++++++++++++++++++++
 2 files changed, 255 insertions(+)
 create mode 100644 05-cnn-pytorch/cnn-activation.py
 create mode 100644 05-cnn-pytorch/cnn-class.py

diff --git a/05-cnn-pytorch/cnn-activation.py b/05-cnn-pytorch/cnn-activation.py
new file mode 100644
index 0000000..55985ee
--- /dev/null
+++ b/05-cnn-pytorch/cnn-activation.py
@@ -0,0 +1,147 @@
+from collections import defaultdict
+import time
+import random
+import torch
+import numpy as np
+
+
+class CNNclass(torch.nn.Module):
+    def __init__(self, nwords, emb_size, num_filters, window_size, ntags):
+        super(CNNclass, self).__init__()
+
+        """ layers """
+        self.embedding = torch.nn.Embedding(nwords, emb_size)
+        # uniform initialization
+        torch.nn.init.uniform_(self.embedding.weight, -0.25, 0.25)
+        # Conv 1d
+        self.conv_1d = torch.nn.Conv1d(in_channels=emb_size, out_channels=num_filters, kernel_size=window_size,
+                                       stride=1, padding=0, dilation=1, groups=1, bias=True)
+        self.relu = torch.nn.ReLU()
+        self.projection_layer = torch.nn.Linear(in_features=num_filters, out_features=ntags, bias=True)
+        # Initializing the projection layer
+        torch.nn.init.xavier_uniform_(self.projection_layer.weight)
+
+    def forward(self, words, return_activations=False):
+        emb = self.embedding(words)                 # nwords x emb_size
+        emb = emb.unsqueeze(0).permute(0, 2, 1)     # 1 x emb_size x nwords
+        h = self.conv_1d(emb)                       # 1 x num_filters x nwords
+        activations = h.squeeze().max(dim=1)[1]     # argmax along length of the sentence
+        # Do max pooling
+        h = h.max(dim=2)[0]                         # 1 x num_filters
+        h = self.relu(h)
+        features = h.squeeze()
+        out = self.projection_layer(h)              # size(out) = 1 x ntags
+        if return_activations:
+            return out, activations.data.cpu().numpy(), features.data.cpu().numpy()
+        return out
+
+
+np.set_printoptions(linewidth=np.nan, threshold=np.nan)
+
+# Functions to read in the corpus
+w2i = defaultdict(lambda: len(w2i))
+UNK = w2i["<unk>"]
+def read_dataset(filename):
+    with open(filename, "r") as f:
+        for line in f:
+            tag, words = line.lower().strip().split(" ||| ")
+            words = words.split(" ")
+            yield (words, [w2i[x] for x in words], int(tag))
+
+# Read in the data
+train = list(read_dataset("../data/classes/train.txt"))[:50]
+w2i = defaultdict(lambda: UNK, w2i)
+dev = list(read_dataset("../data/classes/test.txt"))[:10]
+nwords = len(w2i)
+ntags = 5
+
+# Define the model
+EMB_SIZE = 10
+WIN_SIZE = 3
+FILTER_SIZE = 8
+
+# initialize the model
+model = CNNclass(nwords, EMB_SIZE, FILTER_SIZE, WIN_SIZE, ntags)
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters())
+
+type = torch.LongTensor
+use_cuda = torch.cuda.is_available()
+
+if use_cuda:
+    type = torch.cuda.LongTensor
+    model.cuda()
+
+
+def calc_predict_and_activations(wids, tag, words):
+    if len(wids) < WIN_SIZE:
+        wids += [0] * (WIN_SIZE-len(wids))
+    words_tensor = torch.tensor(wids).type(type)
+    scores, activations, features = model(words_tensor, return_activations=True)
+    scores = scores.squeeze().cpu().data.numpy()
+    print('%d ||| %s' % (tag, ' '.join(words)))
+    predict = np.argmax(scores)
+    print(display_activations(words, activations))
+    W = model.projection_layer.weight.data.cpu().numpy()
+    bias = model.projection_layer.bias.data.cpu().numpy()
+    print('scores=%s, predict: %d' % (scores, predict))
+    print('  bias=%s' % bias)
+    contributions = W * features
+    print(' very bad (%.4f): %s' % (scores[0], contributions[0]))
+    print('      bad (%.4f): %s' % (scores[1], contributions[1]))
+    print('  neutral (%.4f): %s' % (scores[2], contributions[2]))
+    print('     good (%.4f): %s' % (scores[3], contributions[3]))
+    print('very good (%.4f): %s' % (scores[4], contributions[4]))
+
+
+def display_activations(words, activations):
+    pad_begin = (WIN_SIZE - 1) / 2
+    pad_end = WIN_SIZE - 1 - pad_begin
+    words_padded = ['pad' for _ in range(int(pad_begin))] + words + ['pad' for _ in range(int(pad_end))]
+
+    ngrams = []
+    for act in activations:
+        ngrams.append('[' + ', '.join(words_padded[act:act+WIN_SIZE]) + ']')
+
+    return ngrams
+
+for ITER in range(10):
+    # Perform training
+    random.shuffle(train)
+    train_loss = 0.0
+    train_correct = 0.0
+    start = time.time()
+    for _, wids, tag in train:
+        # Padding (can be done in the conv layer as well)
+        if len(wids) < WIN_SIZE:
+            wids += [0] * (WIN_SIZE - len(wids))
+        words_tensor = torch.tensor(wids).type(type)
+        tag_tensor = torch.tensor([tag]).type(type)
+        scores = model(words_tensor)
+        predict = scores[0].argmax().item()
+        if predict == tag:
+            train_correct += 1
+
+        my_loss = criterion(scores, tag_tensor)
+        train_loss += my_loss.item()
+        # Do back-prop
+        optimizer.zero_grad()
+        my_loss.backward()
+        optimizer.step()
+    print("iter %r: train loss/sent=%.4f, acc=%.4f, time=%.2fs" % (ITER, train_loss/len(train), train_correct/len(train), time.time()-start))
+    # Perform testing
+    test_correct = 0.0
+    for _, wids, tag in dev:
+        # Padding (can be done in the conv layer as well)
+        if len(wids) < WIN_SIZE:
+            wids += [0] * (WIN_SIZE - len(wids))
+        words_tensor = torch.tensor(wids).type(type)
+        scores = model(words_tensor)
+        predict = scores[0].argmax().item()
+        if predict == tag:
+            test_correct += 1
+    print("iter %r: test acc=%.4f" % (ITER, test_correct/len(dev)))
+
+
+for words, wids, tag in dev:
+    calc_predict_and_activations(wids, tag, words)
\ No newline at end of file
diff --git a/05-cnn-pytorch/cnn-class.py b/05-cnn-pytorch/cnn-class.py
new file mode 100644
index 0000000..82da2d3
--- /dev/null
+++ b/05-cnn-pytorch/cnn-class.py
@@ -0,0 +1,108 @@
+from collections import defaultdict
+import time
+import random
+import torch
+
+
+class CNNclass(torch.nn.Module):
+    def __init__(self, nwords, emb_size, num_filters, window_size, ntags):
+        super(CNNclass, self).__init__()
+
+        """ layers """
+        self.embedding = torch.nn.Embedding(nwords, emb_size)
+        # uniform initialization
+        torch.nn.init.uniform_(self.embedding.weight, -0.25, 0.25)
+        # Conv 1d
+        self.conv_1d = torch.nn.Conv1d(in_channels=emb_size, out_channels=num_filters, kernel_size=window_size,
+                                       stride=1, padding=0, dilation=1, groups=1, bias=True)
+        self.relu = torch.nn.ReLU()
+        self.projection_layer = torch.nn.Linear(in_features=num_filters, out_features=ntags, bias=True)
+        # Initializing the projection layer
+        torch.nn.init.xavier_uniform_(self.projection_layer.weight)
+
+    def forward(self, words):
+        emb = self.embedding(words)                 # nwords x emb_size
+        emb = emb.unsqueeze(0).permute(0, 2, 1)     # 1 x emb_size x nwords
+        h = self.conv_1d(emb)                       # 1 x num_filters x nwords
+        # Do max pooling
+        h = h.max(dim=2)[0]                         # 1 x num_filters
+        h = self.relu(h)
+        out = self.projection_layer(h)              # size(out) = 1 x ntags
+        return out
+
+
+# Functions to read in the corpus
+w2i = defaultdict(lambda: len(w2i))
+t2i = defaultdict(lambda: len(t2i))
+UNK = w2i["<unk>"]
+
+
+def read_dataset(filename):
+    with open(filename, "r") as f:
+        for line in f:
+            tag, words = line.lower().strip().split(" ||| ")
+            yield ([w2i[x] for x in words.split(" ")], t2i[tag])
+
+
+# Read in the data
+train = list(read_dataset("../data/classes/train.txt"))
+w2i = defaultdict(lambda: UNK, w2i)
+dev = list(read_dataset("../data/classes/test.txt"))
+nwords = len(w2i)
+ntags = len(t2i)
+
+# Define the model
+EMB_SIZE = 64
+WIN_SIZE = 3
+FILTER_SIZE = 64
+
+# initialize the model
+model = CNNclass(nwords, EMB_SIZE, FILTER_SIZE, WIN_SIZE, ntags)
+criterion = torch.nn.CrossEntropyLoss()
+optimizer = torch.optim.Adam(model.parameters())
+
+type = torch.LongTensor
+use_cuda = torch.cuda.is_available()
+
+if use_cuda:
+    type = torch.cuda.LongTensor
+    model.cuda()
+
+
+for ITER in range(100):
+    # Perform training
+    random.shuffle(train)
+    train_loss = 0.0
+    train_correct = 0.0
+    start = time.time()
+    for words, tag in train:
+        # Padding (can be done in the conv layer as well)
+        if len(words) < WIN_SIZE:
+            words += [0] * (WIN_SIZE - len(words))
+        words_tensor = torch.tensor(words).type(type)
+        tag_tensor = torch.tensor([tag]).type(type)
+        scores = model(words_tensor)
+        predict = scores[0].argmax().item()
+        if predict == tag:
+            train_correct += 1
+
+        my_loss = criterion(scores, tag_tensor)
+        train_loss += my_loss.item()
+        # Do back-prop
+        optimizer.zero_grad()
+        my_loss.backward()
+        optimizer.step()
+    print("iter %r: train loss/sent=%.4f, acc=%.4f, time=%.2fs" % (
+        ITER, train_loss / len(train), train_correct / len(train), time.time() - start))
+    # Perform testing
+    test_correct = 0.0
+    for words, tag in dev:
+        # Padding (can be done in the conv layer as well)
+        if len(words) < WIN_SIZE:
+            words += [0] * (WIN_SIZE - len(words))
+        words_tensor = torch.tensor(words).type(type)
+        scores = model(words_tensor)[0]
+        predict = scores.argmax().item()
+        if predict == tag:
+            test_correct += 1
+    print("iter %r: test acc=%.4f" % (ITER, test_correct / len(dev)))

From 166b2561c9abb9b9b81bf6e4273b1d00d17fa3f4 Mon Sep 17 00:00:00 2001
From: MysteryVaibhav
Date: Wed, 23 Jan 2019 02:59:55 -0500
Subject: [PATCH 2/4] changes for activation

---
 05-cnn-pytorch/cnn-activation.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/05-cnn-pytorch/cnn-activation.py b/05-cnn-pytorch/cnn-activation.py
index 55985ee..a52c2d3 100644
--- a/05-cnn-pytorch/cnn-activation.py
+++ b/05-cnn-pytorch/cnn-activation.py
@@ -48,10 +48,11 @@ def read_dataset(filename):
             words = words.split(" ")
             yield (words, [w2i[x] for x in words], int(tag))
 
+
 # Read in the data
-train = list(read_dataset("../data/classes/train.txt"))[:50]
+train = list(read_dataset("../data/classes/train.txt"))
 w2i = defaultdict(lambda: UNK, w2i)
-dev = list(read_dataset("../data/classes/test.txt"))[:10]
+dev = list(read_dataset("../data/classes/test.txt"))
 nwords = len(w2i)
 ntags = 5
 
@@ -105,6 +106,7 @@ def display_activations(words, activations):
 
     return ngrams
 
+
 for ITER in range(10):
     # Perform training
     random.shuffle(train)

From 4c653318efdbd437636fcff9b98f32bc3b9e8c74 Mon Sep 17 00:00:00 2001
From: MysteryVaibhav
Date: Wed, 6 Feb 2019 15:37:41 -0500
Subject: [PATCH 3/4] Suggested changes

---
 05-cnn-pytorch/cnn-activation.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/05-cnn-pytorch/cnn-activation.py b/05-cnn-pytorch/cnn-activation.py
index a52c2d3..55bf8de 100644
--- a/05-cnn-pytorch/cnn-activation.py
+++ b/05-cnn-pytorch/cnn-activation.py
@@ -25,7 +25,7 @@ def forward(self, words, return_activations=False):
         emb = self.embedding(words)                 # nwords x emb_size
         emb = emb.unsqueeze(0).permute(0, 2, 1)     # 1 x emb_size x nwords
         h = self.conv_1d(emb)                       # 1 x num_filters x nwords
-        activations = h.squeeze().max(dim=1)[1]     # argmax along length of the sentence
+        activations = h.squeeze(0).max(dim=1)[1]    # argmax along length of the sentence
         # Do max pooling
         h = h.max(dim=2)[0]                         # 1 x num_filters
         h = self.relu(h)
@@ -146,4 +146,5 @@ def display_activations(words, activations):
 
 
 for words, wids, tag in dev:
-    calc_predict_and_activations(wids, tag, words)
\ No newline at end of file
+    calc_predict_and_activations(wids, tag, words)
+    input()

From ff32018e2d6353621509aa0ad193b388f4d81c82 Mon Sep 17 00:00:00 2001
From: MysteryVaibhav
Date: Wed, 6 Feb 2019 15:45:35 -0500
Subject: [PATCH 4/4] Adding explicit squeeze dimension

---
 05-cnn-pytorch/cnn-activation.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/05-cnn-pytorch/cnn-activation.py b/05-cnn-pytorch/cnn-activation.py
index 55bf8de..1e86d8e 100644
--- a/05-cnn-pytorch/cnn-activation.py
+++ b/05-cnn-pytorch/cnn-activation.py
@@ -29,7 +29,7 @@ def forward(self, words, return_activations=False):
         # Do max pooling
         h = h.max(dim=2)[0]                         # 1 x num_filters
         h = self.relu(h)
-        features = h.squeeze()
+        features = h.squeeze(0)
         out = self.projection_layer(h)              # size(out) = 1 x ntags
         if return_activations:
             return out, activations.data.cpu().numpy(), features.data.cpu().numpy()
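
Editor's note on the shapes in forward: because the Conv1d is applied with padding=0, the convolved length is nwords - window_size + 1 rather than nwords as the inline comments suggest, which is also why both scripts pad sentences shorter than WIN_SIZE before calling the model. A minimal standalone sketch tracing one sentence through the same embedding -> Conv1d -> max-pool pipeline (hypothetical sizes mirroring cnn-activation.py's EMB_SIZE=10, FILTER_SIZE=8, WIN_SIZE=3; not part of the patches):

import torch

VOCAB, EMB_SIZE, FILTER_SIZE, WIN_SIZE, SENT_LEN = 100, 10, 8, 3, 7
embedding = torch.nn.Embedding(VOCAB, EMB_SIZE)
conv_1d = torch.nn.Conv1d(in_channels=EMB_SIZE, out_channels=FILTER_SIZE, kernel_size=WIN_SIZE)

words = torch.randint(0, VOCAB, (SENT_LEN,))  # token ids for one sentence
emb = embedding(words)                        # SENT_LEN x EMB_SIZE
emb = emb.unsqueeze(0).permute(0, 2, 1)       # 1 x EMB_SIZE x SENT_LEN
h = conv_1d(emb)                              # 1 x FILTER_SIZE x (SENT_LEN - WIN_SIZE + 1)
pooled = h.max(dim=2)[0]                      # 1 x FILTER_SIZE, max over positions
print(emb.shape, h.shape, pooled.shape)       # (1, 10, 7) (1, 8, 5) (1, 8)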
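
Similarly, the return_activations path added in patch 1 keeps, for each filter, the argmax position of its response before pooling, and display_activations maps that position back to an n-gram of the (pad-extended) sentence. A standalone sketch of the position-to-n-gram mapping, with hypothetical words and random responses in place of real filter outputs:

import torch

WIN_SIZE = 3
words = ['this', 'movie', 'was', 'really', 'quite', 'great', 'fun']

# Hypothetical responses: 1 x num_filters x positions, positions = len(words) - WIN_SIZE + 1
h = torch.randn(1, 8, len(words) - WIN_SIZE + 1)
activations = h.squeeze(0).max(dim=1)[1]  # per-filter argmax over positions, as in patch 3

for f, pos in enumerate(activations.tolist()):
    # With no conv padding, position pos corresponds to the window words[pos:pos+WIN_SIZE]
    print('filter %d fired hardest on [%s]' % (f, ', '.join(words[pos:pos + WIN_SIZE])))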