-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathREADME.Rmd
259 lines (185 loc) · 8.55 KB
/
README.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
---
output: github_document
---
<!-- README.md is generated from README.Rmd. Please edit that file -->
```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
```
# tidylda <img src='man/figures/logo.png' align="right" height="136.5" />
<!-- badges: start -->
[![DOI](https://joss.theoj.org/papers/10.21105/joss.06800/status.svg)](https://doi.org/10.21105/joss.06800)
[![Codecov test coverage](https://codecov.io/gh/TommyJones/tidylda/branch/main/graph/badge.svg)](https://app.codecov.io/gh/tommyjones/tidylda/branch/main)
[![R-CMD-check](https://GitHub.com/TommyJones/tidylda/actions/workflows/R-CMD-check.yaml/badge.svg)](https://GitHub.com/TommyJones/tidylda/actions/workflows/R-CMD-check.yaml)
[![Lifecycle: stable](https://img.shields.io/badge/lifecycle-stable-brightgreen.svg)](https://lifecycle.r-lib.org/articles/stages.html#stable)
<!-- badges: end -->
Latent Dirichlet Allocation Using 'tidyverse' Conventions
`tidylda` implements an algorithm for Latent Dirichlet Allocation using style conventions from the [tidyverse](https://style.tidyverse.org/) and [tidymodels](https://tidymodels.GitHub.io/model-implementation-principles/).
In addition this implementation of LDA allows you to:
* use asymmetric prior parameters alpha and eta
* use a matrix prior parameter, eta, to seed topics into a model
* use a previously-trained model as a prior for a new model
* apply LDA in a transfer-learning paradigm, updating a model's parameters with additional data (or additional iterations)
## Installation
You can install the latest CRAN release with:
``` r
install("tidylda")
```
You can install the development version from [GitHub](https://GitHub.com/) with:
``` r
install.packages("remotes")
remotes::install_GitHub("tommyjones/tidylda")
```
For a list of dependencies see the DESCRIPTION file.
# Getting started
This package is still in its early stages of development. However, some basic functionality is below. Here, we will use the `tidytext` package to create a document term matrix, fit a topic model, predict topics of unseen documents, and update the model with those new documents.
`tidylda` uses the following naming conventions for topic models:
* `theta` is a matrix whose rows are distributions of topics over documents, or P(topic|document)
* `beta` is a matrix whose rows are distributions of tokens over topics, or P(token|topic)
* `lambda` is a matrix whose rows are distributions of topics over tokens, or P(topic|token)
`lambda` is useful for making predictions with a computationally-simple and efficient dot product and it may be interesting to analyze in its own right.
* `alpha` is the prior that tunes `theta`
* `eta` is the prior that tunes `beta`
## Example
```{r example}
library(tidytext)
library(dplyr)
library(ggplot2)
library(tidyr)
library(tidylda)
library(Matrix)
### Initial set up ---
# load some documents
docs <- nih_sample
# tokenize using tidytext's unnest_tokens
tidy_docs <- docs %>%
select(APPLICATION_ID, ABSTRACT_TEXT) %>%
unnest_tokens(output = word,
input = ABSTRACT_TEXT,
stopwords = stop_words$word,
token = "ngrams",
n_min = 1, n = 2) %>%
count(APPLICATION_ID, word) %>%
filter(n>1) #Filtering for words/bigrams per document, rather than per corpus
tidy_docs <- tidy_docs %>% # filter words that are just numbers
filter(! stringr::str_detect(tidy_docs$word, "^[0-9]+$"))
# append observation level data
colnames(tidy_docs)[1:2] <- c("document", "term")
# turn a tidy tbl into a sparse dgCMatrix
# note tidylda has support for several document term matrix formats
d <- tidy_docs %>%
cast_sparse(document, term, n)
# let's split the documents into two groups to demonstrate predictions and updates
d1 <- d[1:50, ]
d2 <- d[51:nrow(d), ]
# make sure we have different vocabulary for each data set to simulate the "real world"
# where you get new tokens coming in over time
d1 <- d1[, colSums(d1) > 0]
d2 <- d2[, colSums(d2) > 0]
### fit an intial model and inspect it ----
set.seed(123)
lda <- tidylda(
data = d1,
k = 10,
iterations = 200,
burnin = 175,
alpha = 0.1, # also accepts vector inputs
eta = 0.05, # also accepts vector or matrix inputs
optimize_alpha = FALSE, # experimental
calc_likelihood = TRUE,
calc_r2 = TRUE, # see https://arxiv.org/abs/1911.11061
return_data = FALSE
)
# did the model converge?
# there are actual test stats for this, but should look like "yes"
qplot(x = iteration, y = log_likelihood, data = lda$log_likelihood, geom = "line") +
ggtitle("Checking model convergence")
# look at the model overall
glance(lda)
print(lda)
# it comes with its own summary matrix that's printed out with print(), above
lda$summary
# inspect the individual matrices
tidy_theta <- tidy(lda, matrix = "theta")
tidy_theta
tidy_beta <- tidy(lda, matrix = "beta")
tidy_beta
tidy_lambda <- tidy(lda, matrix = "lambda")
tidy_lambda
# append observation-level data
augmented_docs <- augment(lda, data = tidy_docs)
augmented_docs
### predictions on held out data ---
# two methods: gibbs is cleaner and more technically correct in the bayesian sense
p_gibbs <- predict(lda, new_data = d2[1, ], iterations = 100, burnin = 75)
# dot is faster, less prone to error (e.g. underflow), noisier, and frequentist
p_dot <- predict(lda, new_data = d2[1, ], method = "dot")
# pull both together into a plot to compare
tibble(topic = 1:ncol(p_gibbs), gibbs = p_gibbs[1,], dot = p_dot[1, ]) %>%
pivot_longer(cols = gibbs:dot, names_to = "type") %>%
ggplot() +
geom_bar(mapping = aes(x = topic, y = value, group = type, fill = type),
stat = "identity", position="dodge") +
scale_x_continuous(breaks = 1:10, labels = 1:10) +
ggtitle("Gibbs predictions vs. dot product predictions")
### Augment as an implicit prediction using the 'dot' method ----
# Aggregating over terms results in a distribution of topics over documents
# roughly equivalent to using the "dot" method of predictions.
augment_predict <-
augment(lda, tidy_docs, "prob") %>%
group_by(document) %>%
select(-c(document, term)) %>%
summarise_all(function(x) sum(x, na.rm = T))
# reformat for easy plotting
augment_predict <-
as_tibble(t(augment_predict[, -c(1,2)]), .name_repair = "minimal")
colnames(augment_predict) <- unique(tidy_docs$document)
augment_predict$topic <- 1:nrow(augment_predict) %>% as.factor()
compare_mat <-
augment_predict %>%
select(
topic,
augment = matches(rownames(d2)[1])
) %>%
mutate(
augment = augment / sum(augment), # normalize to sum to 1
dot = p_dot[1, ]
) %>%
pivot_longer(cols = c(augment, dot))
ggplot(compare_mat) +
geom_bar(aes(y = value, x = topic, group = name, fill = name),
stat = "identity", position = "dodge") +
labs(title = "Prediction using 'augment' vs 'predict(..., method = \"dot\")'")
# Not shown: aggregating over documents results in recovering the "tidy" lambda.
### updating the model ----
# now that you have new documents, maybe you want to fold them into the model?
lda2 <- refit(
object = lda,
new_data = d, # save me the trouble of manually-combining these by just using d
iterations = 200,
burnin = 175,
calc_likelihood = TRUE,
calc_r2 = TRUE
)
# we can do similar analyses
# did the model converge?
qplot(x = iteration, y = log_likelihood, data = lda2$log_likelihood, geom = "line") +
ggtitle("Checking model convergence")
# look at the model overall
glance(lda2)
print(lda2)
# how does that compare to the old model?
print(lda)
```
There are several vignettes available in [/vignettes](https://GitHub.com/TommyJones/tidylda/tree/main/vignettes). They can be compiled using `knitr` or you can view [PDF versions on CRAN](https://CRAN.R-project.org/package=tidylda).
See NEWS.md for a changelog, including changes from the CRAN release to the development version on GitHub.
See the "Issues" tab on GitHub to see planned features as well as bug fixes.
# Contributions
If you would like to contribute to this package, please start by opening an issue on GitHub. Direct contributions via merge requests are discouraged unless invited to do so.
If you have any suggestions for additional functionality, changes to functionality, changes to arguments or other aspects of the API please let me know by opening an issue on GitHub or sending me an email: jones.thos.w at gmail.com.
## Code of Conduct
Please note that the tidylda project is released with a [Contributor Code of Conduct](https://contributor-covenant.org/version/2/1/CODE_OF_CONDUCT.html). By contributing to this project, you agree to abide by its terms.