From 85c338d17661a2d0df38c4872e9e8d5bf77ea39c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=88=98=E6=B5=A9=E4=B9=9F=E8=A6=81=E4=B8=80=E4=B8=AA?= =?UTF-8?q?=E5=8F=AF=E7=88=B1=E7=9A=84=E5=90=8D=E5=AD=97=E5=B9=B6=E4=B8=94?= =?UTF-8?q?=E6=83=B3=E8=AF=95=E8=AF=95=E8=BF=99=E4=B8=AA=E5=90=8D=E5=AD=97?= =?UTF-8?q?=E6=9C=80=E5=A4=9A=E5=8F=AF=E4=BB=A5=E6=9C=89=E5=A4=9A=E9=95=BF?= =?UTF-8?q?=E6=89=80=E4=BB=A5=E5=A4=9A=E6=89=93=E4=BA=86=E5=BE=88=E5=A4=9A?= =?UTF-8?q?=E5=BE=88=E5=A4=9A=E5=BE=88=E5=A4=9A=E5=BE=88=E5=A4=9A=E5=BE=88?= =?UTF-8?q?=E5=A4=9A=E5=AD=97?= Date: Wed, 21 Apr 2021 16:04:25 +0800 Subject: [PATCH] rewrite with torch in matrix format, support GPU, change the output format --- em.py | 93 +++++++++++++++++++++++++---------------------------------- 1 file changed, 39 insertions(+), 54 deletions(-) diff --git a/em.py b/em.py index 74d0443..083b333 100644 --- a/em.py +++ b/em.py @@ -1,13 +1,18 @@ import argparse import os import numpy as np +import torch +os.environ["CUDA_VISIBLE_DEVICES"] = "1" +device = 'cuda' if torch.cuda.is_available() else 'cpu' def read(path): word_map = {} # id to word - freq_map = {} with open(os.path.join(path, "20news.vocab"), "r") as f: - for line in f: + flines=f.readlines() + W = len(flines) + freq_map = torch.zeros([W]) + for line in flines: l = line.strip().split() i = int(l[0]) word = l[1] @@ -15,20 +20,18 @@ def read(path): word_map[i] = word freq_map[i] = f - documents = [] - word_to_doc = [[] for _ in range(len(word_map))] with open(os.path.join(path, "20news.libsvm"), "r") as f: - for i, line in enumerate(f): - doc = {} + flines=f.readlines() + D = len(flines) + T = torch.zeros([D,W]) + for i, line in enumerate(flines): l = line.strip().split() for pair in l[1:]: l2 = pair.split(":") word, freq = int(l2[0]), int(l2[1]) - doc[word] = freq - word_to_doc[word].append(i) - documents.append(doc) + T[i][word] = freq - return word_map, freq_map, documents, word_to_doc + return D, W , word_map, T parser = argparse.ArgumentParser() @@ 
-38,51 +41,33 @@ def read(path): help="Data directory") args = parser.parse_args() -word_map, freq_map, documents, word_to_doc = read(args.data) -W = len(word_map) -D = len(documents) -K = args.K -# for K in [10, 20, 50, 100]: +D,W,word_map, T = read(args.data) -# initialization -pi = np.random.random([K]) -mu = np.random.random([K, W]) -for k in range(K): - pi[k] /= np.sum(pi[k]) -for k in range(K): - mu[k] /= np.sum(mu[k]) +T=T.to(device) -step = 0 -eps = 1e-10 -pi_old = np.zeros([K], dtype=float) -while np.linalg.norm(pi_old - pi) > 1e-3: - # E step - gamma = np.zeros([D, K], dtype=float) - for d in range(D): - for k in range(K): - gamma[d][k] = np.log(pi[k]) - for w in documents[d]: - gamma[d][k] += documents[d][w] * np.log(mu[k][w] + eps) - maxn = max(gamma[d]) - for k in range(K): - gamma[d][k] = np.exp(gamma[d][k] - maxn) - gamma[d] = gamma[d] / np.sum(gamma[d]) +for K in [10, 20, 30, 50]: + pi = torch.softmax(torch.randn([K]),dim=0).to(device) + mu = torch.softmax(torch.randn([W,K]),dim=1).to(device) + step = 0 + eps = 1e-10 + pi_old = torch.zeros([K]).to(device) - # M step - pi_old = pi - pi = np.sum(gamma, axis=0) / np.sum(gamma) - for k in range(K): - for w in range(W): - mu[k][w] = 0 - for d in word_to_doc[w]: - mu[k][w] += gamma[d][k] * documents[d][w] - mu[k] = mu[k] / np.sum(mu[k]) - - print("K=%d, step=%d, norm-diff=%f" % (K, step, np.linalg.norm(pi_old - pi))) - step += 1 + while torch.norm(pi_old - pi) > 1e-3: + # E step + gamma = torch.softmax( T.mm(torch.log(mu+eps)) + torch.log(pi).t() ,dim =1 ) -for k in range(K): - print("Topic %d:" % k) - topics = np.argsort(mu[k])[::-1] - for i in range(min(10, len(topics))): - print(" %s: %f" % (word_map[topics[i]], mu[k][topics[i]])) + # M step + pi_old = pi + pi = torch.mean(gamma,dim=0) + mu = T.t().mm(gamma) + mu = mu/torch.sum(mu,dim=0) + + print("K=%d, step=%d, onestep-diff=%f" % (K, step, torch.norm(pi_old - pi))) + step += 1 + + for k in range(K): + print("Topic %d:" % k,end='') + topics = 
np.argsort(mu.cpu().numpy()[:,k])[::-1] + for i in range(min(10, len(topics))): + print(" %s"% (word_map[topics[i]]),end=',') + print("\n") \ No newline at end of file