diff --git a/mess.py b/mess.py
index 3d544f4..490631e 100644
--- a/mess.py
+++ b/mess.py
@@ -17,10 +17,14 @@
 from tqdm import trange
 from tweetopic._doc import init_doc_words
-from tweetopic.bayesian.dmm import (BayesianDMM, posterior_predictive,
-                                    predict_doc, sparse_multinomial_logpdf,
-                                    symmetric_dirichlet_logpdf,
-                                    symmetric_dirichlet_multinomial_logpdf)
+from tweetopic.bayesian.dmm import (
+    BayesianDMM,
+    posterior_predictive,
+    predict_doc,
+    sparse_multinomial_logpdf,
+    symmetric_dirichlet_logpdf,
+    symmetric_dirichlet_multinomial_logpdf,
+)
 from tweetopic.bayesian.sampling import batch_data, sample_nuts
 from tweetopic.func import spread
@@ -58,23 +62,26 @@ def logprior_fn(params):
 def loglikelihood_fn(params, data):
     doc_likelihood = jax.vmap(
-        partial(sparse_multinomial_logpdf, component=params["component"])
+        partial(sparse_multinomial_logpdf, component=params["component"]),
     )
     return jnp.sum(
         doc_likelihood(
             unique_words=data["doc_unique_words"],
             unique_word_counts=data["doc_unique_word_counts"],
-        )
+        ),
     )


 logdensity_fn(position)
 logdensity_fn = lambda params: logprior_fn(params) + loglikelihood_fn(
-    params, data
+    params,
+    data,
 )
 grad_estimator = blackjax.sgmcmc.gradients.grad_estimator(
-    logprior_fn, loglikelihood_fn, data_size=n_documents
+    logprior_fn,
+    loglikelihood_fn,
+    data_size=n_documents,
 )
 rng_key = jax.random.PRNGKey(0)
 batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
@@ -88,8 +95,8 @@ def loglikelihood_fn(params, data):
 )
 position = dict(
     component=jnp.array(
-        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha)))
-    )
+        transform(stats.dirichlet.mean(alpha=np.full(n_features, alpha))),
+    ),
 )
 samples, states = sample_nuts(position, logdensity_fn)
diff --git a/tweetopic/_btm.py b/tweetopic/_btm.py
index b485fec..c336f38 100644
--- a/tweetopic/_btm.py
+++ b/tweetopic/_btm.py
@@ -1,4 +1,4 @@
-"""Module for utility functions for fitting BTMs"""
+"""Module for utility functions for fitting BTMs."""
 import random
 from typing import Dict, Tuple, TypeVar
@@ -12,7 +12,8 @@
 @njit
 def doc_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     (n_max_unique_words,) = doc_unique_words.shape
     biterm_counts = dict()
@@ -43,7 +44,7 @@ def doc_unique_biterms(
 @njit
 def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
-    """Adds one counter dict to another in place with Numba"""
+    """Adds one counter dict to another in place with Numba."""
     for key in source:
         if key in dest:
             dest[key] += source[key]
@@ -53,17 +54,20 @@ def nb_add_counter(dest: Dict[T, int], source: Dict[T, int]):
 @njit
 def corpus_unique_biterms(
-    doc_unique_words: np.ndarray, doc_unique_word_counts: np.ndarray
+    doc_unique_words: np.ndarray,
+    doc_unique_word_counts: np.ndarray,
 ) -> Dict[Tuple[int, int], int]:
     n_documents, _ = doc_unique_words.shape
     biterm_counts = doc_unique_biterms(
-        doc_unique_words[0], doc_unique_word_counts[0]
+        doc_unique_words[0],
+        doc_unique_word_counts[0],
     )
     for i_doc in range(1, n_documents):
         doc_unique_words_i = doc_unique_words[i_doc]
         doc_unique_word_counts_i = doc_unique_word_counts[i_doc]
         doc_biterms = doc_unique_biterms(
-            doc_unique_words_i, doc_unique_word_counts_i
+            doc_unique_words_i,
+            doc_unique_word_counts_i,
         )
         nb_add_counter(biterm_counts, doc_biterms)
     return biterm_counts
@@ -71,7 +75,7 @@ def corpus_unique_biterms(
 @njit
 def compute_biterm_set(
-    biterm_counts: Dict[Tuple[int, int], int]
+    biterm_counts: Dict[Tuple[int, int], int],
 ) -> np.ndarray:
     return np.array(list(biterm_counts.keys()))
@@ -116,7 +120,12 @@ def add_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        True, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        True,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
@@ -129,7 +138,12 @@ def remove_biterm(
     topic_biterm_count: np.ndarray,
 ) -> None:
     add_remove_biterm(
-        False, i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+        False,
+        i_biterm,
+        i_topic,
+        biterms,
+        topic_word_count,
+        topic_biterm_count,
     )
@@ -147,7 +161,11 @@ def init_components(
         i_topic = random.randint(0, n_components - 1)
         biterm_topic_assignments[i_biterm] = i_topic
         add_biterm(
-            i_biterm, i_topic, biterms, topic_word_count, topic_biterm_count
+            i_biterm,
+            i_topic,
+            biterms,
+            topic_word_count,
+            topic_biterm_count,
         )
     return biterm_topic_assignments, topic_word_count, topic_biterm_count
@@ -448,7 +466,10 @@ def predict_docs(
         )
         biterms = doc_unique_biterms(words, word_counts)
         prob_topic_given_document(
-            pred, biterms, topic_distribution, topic_word_distribution
+            pred,
+            biterms,
+            topic_distribution,
+            topic_word_distribution,
         )
         predictions[i_doc, :] = pred
     return predictions
diff --git a/tweetopic/_dmm.py b/tweetopic/_dmm.py
index 21f59f6..7d5abc6 100644
--- a/tweetopic/_dmm.py
+++ b/tweetopic/_dmm.py
@@ -1,4 +1,5 @@
-"""Module containing tools for fitting a Dirichlet Multinomial Mixture Model."""
+"""Module containing tools for fitting a Dirichlet Multinomial Mixture
+Model."""
 from __future__ import annotations
 from math import exp, log
diff --git a/tweetopic/_doc.py b/tweetopic/_doc.py
index 657c6dc..e66e65a 100644
--- a/tweetopic/_doc.py
+++ b/tweetopic/_doc.py
@@ -11,7 +11,7 @@ def init_doc_words(
     n_docs, _ = doc_term_matrix.shape
     doc_unique_words = np.zeros((n_docs, max_unique_words)).astype(np.uint32)
     doc_unique_word_counts = np.zeros((n_docs, max_unique_words)).astype(
-        np.uint32
+        np.uint32,
     )
     for i_doc in range(n_docs):
         unique_words = doc_term_matrix[i_doc].rows[0]  # type: ignore
diff --git a/tweetopic/bayesian/dmm.py b/tweetopic/bayesian/dmm.py
index 10cb4a5..4d961fc 100644
--- a/tweetopic/bayesian/dmm.py
+++ b/tweetopic/bayesian/dmm.py
@@ -1,5 +1,5 @@
-"""JAX implementation of probability densities and parameter initialization
-for the Dirichlet Multinomial Mixture Model."""
+"""JAX implementation of probability densities and parameter initialization for
+the Dirichlet Multinomial Mixture Model."""
 from functools import partial
 import jax
@@ -22,12 +22,18 @@ def symmetric_dirichlet_multinomial_mean(alpha: float, n: int, K: int):
 def init_parameters(
-    n_docs: int, n_vocab: int, n_components: int, alpha: float, beta: float
+    n_docs: int,
+    n_vocab: int,
+    n_components: int,
+    alpha: float,
+    beta: float,
 ) -> dict:
     """Initializes the parameters of the dmm to the mean of the prior."""
     return dict(
         weights=symmetric_dirichlet_multinomial_mean(
-            alpha, n_docs, n_components
+            alpha,
+            n_docs,
+            n_components,
         ),
         components=np.broadcast_to(
             scipy.stats.dirichlet.mean(np.full(n_vocab, beta)),
@@ -41,13 +47,15 @@ def sparse_multinomial_logpdf(
     unique_words,
     unique_word_counts,
 ):
-    """Calculates joint multinomial probability of a sparse representation"""
+    """Calculates joint multinomial probability of a sparse representation."""
     unique_word_counts = jnp.float64(unique_word_counts)
     n_words = jnp.sum(unique_word_counts)
     n_factorial = jax.lax.lgamma(n_words + 1)
     word_count_factorial = jax.lax.lgamma(unique_word_counts + 1)
     word_count_factorial = jnp.where(
-        unique_word_counts != 0, word_count_factorial, 0
+        unique_word_counts != 0,
+        word_count_factorial,
+        0,
     )
     denominator = jnp.sum(word_count_factorial)
     probs = component[unique_words]
@@ -84,18 +92,18 @@ def symmetric_dirichlet_multinomial_logpdf(x, n, alpha):
 def predict_doc(components, weights, unique_words, unique_word_counts):
-    """Predicts likelihood of a document belonging to
-    each cluster based on given parameters."""
+    """Predicts likelihood of a document belonging to each cluster based on
+    given parameters."""
     component_logpdf = partial(
         sparse_multinomial_logpdf,
         unique_words=unique_words,
         unique_word_counts=unique_word_counts,
     )
     component_logprobs = jax.lax.map(component_logpdf, components) + jnp.log(
-        weights
+        weights,
     )
     norm_probs = jnp.exp(
-        component_logprobs - jax.scipy.special.logsumexp(component_logprobs)
+        component_logprobs - jax.scipy.special.logsumexp(component_logprobs),
     )
     return norm_probs
@@ -106,24 +114,31 @@ def predict_one(unique_words, unique_word_counts, components, weights):
             predict_doc,
             unique_words=unique_words,
             unique_word_counts=unique_word_counts,
-        )
+        ),
     )(components, weights)


 def posterior_predictive(
-    doc_unique_words, doc_unique_word_counts, components, weights
+    doc_unique_words,
+    doc_unique_word_counts,
+    components,
+    weights,
 ):
-    """Predicts probability of a document belonging to each component
-    for all posterior samples.
-    """
+    """Predicts probability of a document belonging to each component for all
+    posterior samples."""
     predict_all = jax.vmap(
-        partial(predict_one, components=components, weights=weights)
+        partial(predict_one, components=components, weights=weights),
     )
     return predict_all(doc_unique_words, doc_unique_word_counts)


 def dmm_loglikelihood(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     docs = jnp.stack((doc_unique_words, doc_unique_word_counts), axis=1)
@@ -135,7 +150,8 @@ def doc_likelihood(doc):
             unique_word_counts=unique_word_counts,
         )
         component_logprobs = jax.lax.map(
-            component_logpdf, components
+            component_logpdf,
+            components,
         ) + jnp.log(weights)
         return jax.scipy.special.logsumexp(component_logprobs)
@@ -146,17 +162,25 @@ def doc_likelihood(doc):
 def dmm_logprior(components, weights, alpha, beta, n_docs):
     components_prior = jnp.sum(
         jax.lax.map(
-            partial(symmetric_dirichlet_logpdf, alpha=alpha), components
-        )
+            partial(symmetric_dirichlet_logpdf, alpha=alpha),
+            components,
+        ),
     )
     weights_prior = symmetric_dirichlet_multinomial_logpdf(
-        weights, n=jnp.float64(n_docs), alpha=beta
+        weights,
+        n=jnp.float64(n_docs),
+        alpha=beta,
     )
     return components_prior + weights_prior


 def dmm_logpdf(
-    components, weights, doc_unique_words, doc_unique_word_counts, alpha, beta
+    components,
+    weights,
+    doc_unique_words,
+    doc_unique_word_counts,
+    alpha,
+    beta,
 ):
     """Calculates logdensity of the DMM at a given point in parameter space."""
     n_docs = doc_unique_words.shape[0]
diff --git a/tweetopic/bayesian/sampling.py b/tweetopic/bayesian/sampling.py
index 126ba84..ed28a2d 100644
--- a/tweetopic/bayesian/sampling.py
+++ b/tweetopic/bayesian/sampling.py
@@ -53,7 +53,9 @@ def sample_nuts(
     print("Warmup, window adaptation")
     warmup = blackjax.window_adaptation(blackjax.nuts, logdensity_fn)
     (state, parameters), _ = warmup.run(
-        warmup_key, initial_position, num_steps=n_warmup
+        warmup_key,
+        initial_position,
+        num_steps=n_warmup,
     )
     kernel = jax.jit(blackjax.nuts(logdensity_fn, **parameters).step)
     states = []
@@ -75,8 +77,8 @@ def sample_sgld(
     decay: float = 2.5,
     data_axis: int = 0,
 ) -> tuple[list[PyTree], None]:
-    """Stochastic Gradient Langevin Dynamics sampling loop with decaying step size
-    for any logdensity function that is differentiable with JAX.
+    """Stochastic Gradient Langevin Dynamics sampling loop with decaying step
+    size for any logdensity function that is differentiable with JAX.

     Since there is no adaptation step, you have to manually discard the
     samples before the convergence of the Markov chain.
@@ -134,7 +136,9 @@ def sample_pathfinder(
     pathfinder = blackjax.pathfinder(logdensity_fn)
     print("Optimizing normal approximations.")
     state, _ = jax.jit(pathfinder.approximate)(
-        rng_key=optim_key, position=initial_position, ftol=1e-4
+        rng_key=optim_key,
+        position=initial_position,
+        ftol=1e-4,
     )
     print("Sampling approximate normals.")
     samples = pathfinder.sample(sampling_key, state, n_samples)
@@ -157,7 +161,9 @@ def sample_meanfield_vi(
     optim_key, sampling_key = jax.random.split(rng_key)
     optimizer = adam(learning_rate)
     mfvi = blackjax.meanfield_vi(
-        logdensity_fn, optimizer=optimizer, num_samples=n_optim_samples
+        logdensity_fn,
+        optimizer=optimizer,
+        num_samples=n_optim_samples,
     )
     states = []
     state = mfvi.init(initial_position)
@@ -174,7 +180,9 @@ def batch_data(rng_key, batch_size: int, data_size: int):
     while True:
         _, rng_key = jax.random.split(rng_key)
         idx = jax.random.choice(
-            key=rng_key, a=jnp.arange(data_size), shape=(batch_size,)
+            key=rng_key,
+            a=jnp.arange(data_size),
+            shape=(batch_size,),
         )
         yield idx
@@ -201,15 +209,20 @@ def sample_minibatch_hmc(
     rng_key = jax.random.PRNGKey(seed)
     batch_key, warmup_key, sampling_key = jax.random.split(rng_key, 3)
     batches = batch_data(
-        batch_key, batch_size, data_size=len(data[list(data.keys())[0]])
+        batch_key,
+        batch_size,
+        data_size=len(data[list(data.keys())[0]]),
     )
     print("Warmup, window adaptation")
     warmup_batch = get_batch(next(batches), data, data_axis=data_axis)
     warmup = blackjax.window_adaptation(
-        blackjax.hmc, partial(logdensity_fn, **warmup_batch)
+        blackjax.hmc,
+        partial(logdensity_fn, **warmup_batch),
     )
     (state, parameters), _ = warmup.run(
-        warmup_key, initial_position, num_steps=n_warmup
+        warmup_key,
+        initial_position,
+        num_steps=n_warmup,
     )
     sghmc = blackjax.sghmc()
     kernel = jax.jit(blackjax.nuts(logdensity_fn, **parameters).step)
diff --git a/tweetopic/btm.py b/tweetopic/btm.py
index 51df7de..33017bb 100644
--- a/tweetopic/btm.py
+++ b/tweetopic/btm.py
@@ -19,8 +19,7 @@
 class BTM(sklearn.base.TransformerMixin, sklearn.base.BaseEstimator):
-    """Implementation of the Biterm Topic Model with Gibbs Sampling
-    solver.
+    """Implementation of the Biterm Topic Model with Gibbs Sampling solver.

     Parameters
     ----------
@@ -141,7 +140,8 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
             max_unique_words=max_unique_words,
         )
         biterms = corpus_unique_biterms(
-            doc_unique_words, doc_unique_word_counts
+            doc_unique_words,
+            doc_unique_word_counts,
         )
         biterm_set = compute_biterm_set(biterms)
         self.topic_distribution, self.components_ = fit_model(
@@ -157,8 +157,7 @@ def fit(self, X: Union[spr.spmatrix, ArrayLike], y: None = None):
     # TODO: Something goes terribly wrong here, fix this
     def transform(self, X: Union[spr.spmatrix, ArrayLike]) -> np.ndarray:
-        """Predicts probabilities for each document belonging to each
-        topic.
+        """Predicts probabilities for each document belonging to each topic.

         Parameters
         ----------
diff --git a/tweetopic/func.py b/tweetopic/func.py
index cfd2029..b4ad791 100644
--- a/tweetopic/func.py
+++ b/tweetopic/func.py
@@ -4,8 +4,8 @@
 def spread(fn: Callable):
-    """Creates a new function from the given function so that it takes one
-    dict (PyTree) and spreads the arguments."""
+    """Creates a new function from the given function so that it takes one dict
+    (PyTree) and spreads the arguments."""

     @wraps(fn)
     def inner(kwargs):