[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
pre-commit-ci[bot] committed Jul 3, 2023
1 parent 50a4939 commit 6740187
Showing 8 changed files with 126 additions and 61 deletions.
27 changes: 17 additions & 10 deletions mess.py
@@ -17,10 +17,14 @@
 from tqdm import trange

 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread

@@ -58,23 +62,26 @@ def logprior_fn(params):

 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )


 logdensity_fn(position)

 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )

 samples, states = sample_nuts(position, logdensity_fn)
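
The script being reformatted here composes its target log-density from a log-prior and a log-likelihood before handing it to the samplers. A minimal, self-contained sketch of that composition pattern, with toy Gaussian densities standing in for tweetopic's DMM terms (logprior_fn, loglikelihood_fn, and logdensity_fn mirror the names in the diff; the bodies are placeholders, not the library's code):

```python
import jax
import jax.numpy as jnp

def logprior_fn(params):
    # Toy prior: standard normal over the parameter vector.
    return -0.5 * jnp.sum(params["mu"] ** 2)

def loglikelihood_fn(params, data):
    # Toy likelihood: independent unit-variance Gaussians.
    return -0.5 * jnp.sum((data - params["mu"]) ** 2)

data = jnp.array([0.2, -0.1, 0.4])
logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(params, data)

params = {"mu": jnp.zeros(3)}
print(logdensity_fn(params))            # joint log density at the initial position
print(jax.grad(logdensity_fn)(params))  # gradient pytree, what NUTS/SGMCMC kernels consume
```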
43 changes: 32 additions & 11 deletions tweetopic/_btm.py
@@ -1,4 +1,4 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""

 import random
 from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@

 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -43,7 +44,7 @@ def doc_unique_biterms(

 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,25 +54,28 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):

 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts


 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
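
doc_unique_biterms and corpus_unique_biterms build the biterm counts that a BTM is fitted on: a dict keyed by unordered pairs of word ids. As a rough plain-Python illustration of that structure, here is a sketch that counts each co-occurring pair once per document; the Numba version above works on the padded unique-word/count arrays, and its exact weighting of repeated words may differ:

```python
from itertools import combinations

def doc_biterms(word_ids):
    # Unordered pairs of distinct words that co-occur in one document.
    counts = {}
    for pair in combinations(sorted(set(word_ids)), 2):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

print(doc_biterms([4, 1, 7]))  # {(1, 4): 1, (1, 7): 1, (4, 7): 1}
```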

@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )


@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count

@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
3 changes: 2 additions & 1 deletion tweetopic/_dmm.py
@@ -1,4 +1,5 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
 from __future__ import annotations

 from math import exp, log
2 changes: 1 addition & 1 deletion tweetopic/_doc.py
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
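
init_doc_words unpacks a sparse document-term matrix (a scipy.sparse lil_matrix, hence the .rows[0] access) into fixed-width arrays of word ids and counts, zero-padded up to max_unique_words. A sketch of the same transformation on a hypothetical two-document corpus:

```python
import numpy as np
from scipy.sparse import lil_matrix

doc_term_matrix = lil_matrix(np.array([
    [0, 2, 0, 1, 0],  # doc 0: word 1 twice, word 3 once
    [3, 0, 0, 0, 1],  # doc 1: word 0 three times, word 4 once
]))
max_unique_words = 2
n_docs, _ = doc_term_matrix.shape
doc_unique_words = np.zeros((n_docs, max_unique_words), dtype=np.uint32)
doc_unique_word_counts = np.zeros((n_docs, max_unique_words), dtype=np.uint32)
for i_doc in range(n_docs):
    ids = doc_term_matrix[i_doc].rows[0]  # column indices of the nonzero entries
    cts = doc_term_matrix[i_doc].data[0]  # the corresponding counts
    doc_unique_words[i_doc, : len(ids)] = ids
    doc_unique_word_counts[i_doc, : len(cts)] = cts
print(doc_unique_words)        # [[1 3] [0 4]]
print(doc_unique_word_counts)  # [[2 1] [3 1]]
```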
68 changes: 46 additions & 22 deletions tweetopic/bayesian/dmm.py
@@ -1,5 +1,5 @@
-"""JAX implementation of probability densities and parameter initialization
-for the Dirichlet Multinomial Mixture Model."""
+"""JAX implementation of probability densities and parameter initialization for
+the Dirichlet Multinomial Mixture Model."""
 from functools import partial

 import jax
@@ -22,12 +22,18 @@ def symmetric_dirichlet_multinomial_mean(alpha: float, n: int, K: int):


 def init_parameters(
-    n_docs: int, n_vocab: int, n_components: int, alpha: float, beta: float
+    n_docs: int,
+    n_vocab: int,
+    n_components: int,
+    alpha: float,
+    beta: float,
 ) -> dict:
     """Initializes the parameters of the dmm to the mean of the prior."""
     return dict(
         weights=symmetric_dirichlet_multinomial_mean(
-            alpha, n_docs, n_components
+            alpha,
+            n_docs,
+            n_components,
         ),
         components=np.broadcast_to(
             scipy.stats.dirichlet.mean(np.full(n_vocab, beta)),
@@ -41,13 +47,15 @@ def sparse_multinomial_logpdf(
     unique_words,
     unique_word_counts,
 ):
-    """Calculates joint multinomial probability of a sparse representation"""
+    """Calculates joint multinomial probability of a sparse representation."""
     unique_word_counts = jnp.float64(unique_word_counts)
     n_words = jnp.sum(unique_word_counts)
     n_factorial = jax.lax.lgamma(n_words + 1)
     word_count_factorial = jax.lax.lgamma(unique_word_counts + 1)
     word_count_factorial = jnp.where(
-        unique_word_counts != 0, word_count_factorial, 0
+        unique_word_counts != 0,
+        word_count_factorial,
+        0,
     )
     denominator = jnp.sum(word_count_factorial)
     probs = component[unique_words]
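
The function evaluates the multinomial log-pmf from the sparse representation alone: log n! - sum_i log c_i! + sum_i c_i * log p_i over the words that actually occur, with jnp.where guarding the padding entries. A NumPy/SciPy sketch checking that identity against the dense log-pmf (toy numbers, not from the library):

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import multinomial

component = np.array([0.5, 0.2, 0.2, 0.1])  # one component's word distribution
unique_words = np.array([0, 2])             # ids of the words present in the document
unique_word_counts = np.array([3.0, 1.0])

n_words = unique_word_counts.sum()
sparse_logpdf = (
    gammaln(n_words + 1)
    - gammaln(unique_word_counts + 1).sum()
    + (unique_word_counts * np.log(component[unique_words])).sum()
)
dense_counts = np.array([3, 0, 1, 0])  # the same document, densely
print(sparse_logpdf, multinomial.logpmf(dense_counts, n=4, p=component))  # equal
```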
@@ -84,18 +92,18 @@ def symmetric_dirichlet_multinomial_logpdf(x, n, alpha):


 def predict_doc(components, weights, unique_words, unique_word_counts):
-    """Predicts likelihood of a document belonging to
-    each cluster based on given parameters."""
+    """Predicts likelihood of a document belonging to each cluster based on
+    given parameters."""
     component_logpdf = partial(
         sparse_multinomial_logpdf,
         unique_words=unique_words,
         unique_word_counts=unique_word_counts,
     )
     component_logprobs = jax.lax.map(component_logpdf, components) + jnp.log(
-        weights
+        weights,
     )
     norm_probs = jnp.exp(
-        component_logprobs - jax.scipy.special.logsumexp(component_logprobs)
+        component_logprobs - jax.scipy.special.logsumexp(component_logprobs),
     )
     return norm_probs
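
predict_doc normalizes in log space: subtracting the logsumexp of the per-component scores before exponentiating yields a proper probability vector without underflow. The step in isolation, with made-up scores:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

component_logprobs = jnp.array([-4.2, -3.1, -5.0])  # hypothetical log p(doc, k)
norm_probs = jnp.exp(component_logprobs - logsumexp(component_logprobs))
print(norm_probs, norm_probs.sum())  # posterior over components; sums to 1.0
```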

@@ -106,24 +114,31 @@ def predict_one(unique_words, unique_word_counts, components, weights):
             predict_doc,
             unique_words=unique_words,
             unique_word_counts=unique_word_counts,
-        )
+        ),
     )(components, weights)


 def posterior_predictive(
-    doc_unique_words, doc_unique_word_counts, components, weights
+    doc_unique_words,
+    doc_unique_word_counts,
+    components,
+    weights,
 ):
-    """Predicts probability of a document belonging to each component
-    for all posterior samples.
-    """
+    """Predicts probability of a document belonging to each component for all
+    posterior samples."""
     predict_all = jax.vmap(
-        partial(predict_one, components=components, weights=weights)
+        partial(predict_one, components=components, weights=weights),
     )
     return predict_all(doc_unique_words, doc_unique_word_counts)
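
posterior_predictive maps a single-document predictor over the whole batch with jax.vmap, binding the model parameters via partial so that only the per-document arguments are vectorized. The shape mechanics on a toy scoring function (the body is a placeholder; only the vmap-plus-partial pattern is taken from the diff):

```python
from functools import partial

import jax
import jax.numpy as jnp

def predict_one(doc, components, weights):
    # Placeholder score: one value per component for a single document.
    return weights * jnp.dot(components, doc)

components = jnp.eye(3)
weights = jnp.array([0.5, 0.3, 0.2])
docs = jnp.ones((10, 3))  # a batch of 10 documents

predict_all = jax.vmap(partial(predict_one, components=components, weights=weights))
print(predict_all(docs).shape)  # (10, 3): one row of component scores per document
```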


 def dmm_loglikelihood(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)

@@ -135,7 +150,8 @@ def doc_likelihood(doc):
             unique_word_counts=unique_word_counts,
         )
         component_logprobs = jax.lax.map(
-            component_logpdf, components
+            component_logpdf,
+            components,
         ) + jnp.log(weights)
         return jax.scipy.special.logsumexp(component_logprobs)

@@ -146,17 +162,25 @@ def doc_likelihood(doc):
 def dmm_logprior(components, weights, alpha, beta, n_docs):
     components_prior = jnp.sum(
         jax.lax.map(
-            partial(symmetric_dirichlet_logpdf, alpha=alpha), components
-        )
+            partial(symmetric_dirichlet_logpdf, alpha=alpha),
+            components,
+        ),
     )
     weights_prior = symmetric_dirichlet_multinomial_logpdf(
-        weights, n=jnp.float64(n_docs), alpha=beta
+        weights,
+        n=jnp.float64(n_docs),
+        alpha=beta,
     )
     return components_prior + weights_prior


 def dmm_logpdf(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     """Calculates logdensity of the DMM at a given point in parameter space."""
     n_docs = doc_unique_words.shape[0]
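
For reference, the symmetric Dirichlet log-density appearing in the prior is log Gamma(K * alpha) - K * log Gamma(alpha) + (alpha - 1) * sum_i log x_i; judging by its name, this is what symmetric_dirichlet_logpdf computes. A quick NumPy check against scipy:

```python
import numpy as np
from scipy.special import gammaln
from scipy.stats import dirichlet

def sym_dirichlet_logpdf(x, alpha):
    # log Dirichlet(x; alpha, ..., alpha), written out.
    k = x.shape[0]
    return gammaln(k * alpha) - k * gammaln(alpha) + (alpha - 1) * np.log(x).sum()

x = np.array([0.2, 0.3, 0.5])
print(sym_dirichlet_logpdf(x, 2.0))          # 1.2809...
print(dirichlet.logpdf(x, np.full(3, 2.0)))  # same value from scipy
```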
