Skip to content

Commit

Permalink
Fixed unnecessary importing for gsdmm and btm (#50)
Browse files (browse the repository at this point in the history)
  • Loading branch information
farinamhz committed Oct 8, 2022
1 parent 4f7789d commit 0959d94
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
5 changes: 2 additions & 3 deletions src/apl/NewsTopicExtraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,8 @@ def main(news_table):
total_news_topics = {}
for index, row in processed_docs.iterrows():

- if Params.tml['method'].lower() == 'btm':
- import bitermplus as btm
- topics = tm.doc2topics(tm_model, btm.get_vectorized_docs([' '.join(row[t_t])], dictionary), threshold=Params.evl['threshold'], just_one=Params.tml['justOne'], binary=Params.tml['binary'])
+ if Params.tml['method'].lower() == 'btm':
+ topics = tm.doc2topics(tm_model, row[t_t], threshold=Params.evl['threshold'], just_one=Params.tml['justOne'], binary=Params.tml['binary'], dic=dictionary)
else:
news_bow_corpus = dictionary.doc2bow(row[t_t])
topics = tm.doc2topics(tm_model, news_bow_corpus, threshold=Params.evl['threshold'], just_one=Params.tml['justOne'], binary=Params.tml['binary'])
Expand Down
6 changes: 4 additions & 2 deletions src/tml/TopicModeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
warnings.filterwarnings(action='ignore', category=UserWarning, module='gensim')
import gensim
from gensim.models.coherencemodel import CoherenceModel
- from gsdmm import MovieGroupProcess
import csv
#import pyLDAvis
#import pyLDAvis.gensim
Expand All @@ -25,6 +24,7 @@ def topic_modeling(processed_docs, method, num_topics, filter_extremes, path_2_s

c, cv = None, None
if method.lower() == "gsdmm":
+ from gsdmm import MovieGroupProcess
tm_model = MovieGroupProcess(K=Params.tml['numTopics'], alpha=0.1, beta=0.1, n_iters=30)
#output = tm_model.fit(bow_corpus, len(dictionary))
tm_model.fit(bow_corpus, len(dictionary))
Expand Down Expand Up @@ -167,8 +167,10 @@ def visualization(dictionary, bow_corpus, lda_model, num_topics, path_2_save_tml
return 'Visualization is finished'


- def doc2topics(lda_model, doc, threshold=0.2, just_one=True, binary=True):
+ def doc2topics(lda_model, doc, threshold=0.2, just_one=True, binary=True, dic=None):
if Params.tml['method'].lower() == "btm":
+ import bitermplus as btm
+ doc = btm.get_vectorized_docs([' '.join(doc)], dic)
doc_topic_vector = np.zeros((lda_model.topics_num_))
d2t_vector = lda_model.transform(doc)[0]
elif Params.tml['method'].lower() == "gsdmm":
Expand Down

0 comments on commit 0959d94

Please sign in to comment.