Skip to content

Commit

Permalink
Including simplex dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
cemoody committed May 31, 2016
1 parent a38f5e8 commit 93f2777
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 11 deletions.
4 changes: 2 additions & 2 deletions examples/twenty_newsgroups/lda2vec/lda2vec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class LDA2Vec(Chain):
def __init__(self, n_documents=100, n_document_topics=10,
n_units=256, n_vocab=1000, dropout_ratio=0.5, train=True,
counts=None, n_samples=15, word_dropout_ratio=0.0,
power=0.75):
power=0.75, temperature=1.0):
em = EmbedMixture(n_documents, n_document_topics, n_units,
dropout_ratio=dropout_ratio)
dropout_ratio=dropout_ratio, temperature=temperature)
kwargs = {}
kwargs['mixture'] = em
kwargs['sampler'] = L.NegativeSampling(n_units, counts, n_samples,
Expand Down
15 changes: 9 additions & 6 deletions examples/twenty_newsgroups/lda2vec/lda2vec_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@
power = float(os.getenv('power', 0.75))
# Initialize with pretrained word vectors
pretrained = bool(int(os.getenv('pretrained', True)))
# Sampling temperature
temperature = float(os.getenv('temperature', 1.0))
# Number of dimensions in a single word vector
n_units = int(os.getenv('n_units', 300))
# Get the string representation for every compact key
Expand All @@ -69,7 +71,7 @@

model = LDA2Vec(n_documents=n_docs, n_document_topics=n_topics,
n_units=n_units, n_vocab=n_vocab, counts=term_frequency,
n_samples=15, power=power)
n_samples=15, power=power, temperature=temperature)
if os.path.exists('lda2vec.hdf5'):
print "Reloading from saved"
serializers.load_hdf5("lda2vec.hdf5", model)
Expand All @@ -91,11 +93,12 @@
cuda.to_cpu(model.sampler.W.data).copy(),
words)
top_words = print_top_words_per_topic(data)
coherence = topic_coherence(top_words)
for j in range(n_topics):
print j, coherence[(j, 'cv')]
kw = dict(top_words=top_words, coherence=coherence, epoch=epoch)
progress[str(epoch)] = pickle.dumps(kw)
if j % 100 == 0 and j > 100:
coherence = topic_coherence(top_words)
for j in range(n_topics):
print j, coherence[(j, 'cv')]
kw = dict(top_words=top_words, coherence=coherence, epoch=epoch)
progress[str(epoch)] = pickle.dumps(kw)
data['doc_lengths'] = doc_lengths
data['term_frequency'] = term_frequency
np.savez('topics.pyldavis', **data)
Expand Down
Binary file modified examples/twenty_newsgroups/lda2vec/topics.pyldavis.npz
Binary file not shown.
17 changes: 14 additions & 3 deletions lda2vec/embed_mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import chainer
import chainer.links as L
import chainer.functions as F
from chainer import Variable


def _orthogonal_matrix(shape):
Expand Down Expand Up @@ -60,7 +61,8 @@ class EmbedMixture(chainer.Chain):
.. seealso:: :func:`lda2vec.dirichlet_likelihood`
"""

def __init__(self, n_documents, n_topics, n_dim, dropout_ratio=0.2):
def __init__(self, n_documents, n_topics, n_dim, dropout_ratio=0.2,
temperature=1.0):
self.n_documents = n_documents
self.n_topics = n_topics
self.n_dim = n_dim
Expand All @@ -70,6 +72,7 @@ def __init__(self, n_documents, n_topics, n_dim, dropout_ratio=0.2):
super(EmbedMixture, self).__init__(
weights=L.EmbedID(n_documents, n_topics),
factors=L.Parameter(factors))
self.temperature = temperature
self.weights.W.data[...] /= np.sqrt(n_documents + n_topics)

def __call__(self, doc_ids, update_only_docs=False):
Expand Down Expand Up @@ -102,5 +105,13 @@ def proportions(self, doc_ids, softmax=False):
doc_weights : chainer.Variable
Two dimensional topic weights of each document.
"""
w = F.dropout(self.weights(doc_ids), ratio=self.dropout_ratio)
return F.softmax(w) if softmax else w
w = self.weights(doc_ids)
if softmax:
size = w.data.shape
mask = self.xp.random.random_integers(0, 1, size=size)
y = (F.softmax(w * self.temperature) *
Variable(mask.astype('float32')))
norm, y = F.broadcast(F.expand_dims(F.sum(y, axis=1), 1), y)
return y / (norm + 1e-7)
else:
return w

0 comments on commit 93f2777

Please sign in to comment.