diff --git a/concept/__init__.py b/concept/__init__.py index f68baf7..7ce600c 100644 --- a/concept/__init__.py +++ b/concept/__init__.py @@ -1,6 +1,6 @@ from concept._model import ConceptModel -__version__ = "0.2.0" +__version__ = "0.2.1" __all__ = [ "ConceptModel", diff --git a/concept/_model.py b/concept/_model.py index a5d29b7..3417d45 100644 --- a/concept/_model.py +++ b/concept/_model.py @@ -5,7 +5,7 @@ import matplotlib.pyplot as plt from tqdm import tqdm -from typing import List, Mapping, Tuple +from typing import List, Mapping, Tuple, Union from PIL import Image from umap import UMAP from scipy.sparse.csr import csr_matrix @@ -166,7 +166,9 @@ def fit(self, self.fit_transform(images, image_names=image_names, image_embeddings=image_embeddings) return self - def transform(self, images, image_embeddings=None): + def transform(self, + images: Union[List[str], str], + image_embeddings: np.ndarray = None): """ After having fit a model, use transform to predict new instances Arguments: @@ -183,7 +185,9 @@ def transform(self, images, image_embeddings=None): new_concepts = concept_model.transform(new_images) ``` """ - if image_embeddings is not None: + if image_embeddings is None: + if isinstance(images, str): + images = [images] image_embeddings = self._embed_images(images) umap_embeddings = self.umap_model.transform(image_embeddings) @@ -207,21 +211,21 @@ def _embed_images(self, embeddings: The image embeddings """ # Prepare images - batch_size = 64 - images_to_embed = [Image.open(filepath) for filepath in images] - nr_iterations = int(np.ceil(len(images_to_embed) / batch_size)) + batch_size = 128 + nr_iterations = int(np.ceil(len(images) / batch_size)) # Embed images per batch embeddings = [] for i in tqdm(range(nr_iterations)): start_index = i * batch_size end_index = (i * batch_size) + batch_size - img_emb = self.embedding_model.encode(images_to_embed[start_index:end_index], - show_progress_bar=False) + + images_to_embed = [Image.open(filepath) for filepath in images[start_index:end_index]] + img_emb = self.embedding_model.encode(images_to_embed, show_progress_bar=False) embeddings.extend(img_emb.tolist()) - # If images within e - for image in images_to_embed[start_index:end_index]: + # Close images + for image in images_to_embed: image.close() return np.array(embeddings) @@ -366,21 +370,25 @@ def _cluster_representation(self, images: A list of paths to each image selected_exemplars: A selection of exemplar images for each concept cluster """ - pil_images = [Image.open(filepath) for filepath in images] - - sliced_exemplars = {cluster: [[pil_images[j] - for j in selected_exemplars[cluster][i:i + 3]] + # Find indices of exemplars per cluster + sliced_exemplars = {cluster: [[j for j in selected_exemplars[cluster][i:i + 3]] for i in range(0, len(selected_exemplars[cluster]), 3)] for cluster in self.cluster_labels[1:]} + + # combine exemplars into a single image + cluster_images = {} + for cluster in self.cluster_labels[1:]: + images_to_cluster = [[Image.open(images[index]) for index in sub_indices] for sub_indices in sliced_exemplars[cluster]] + cluster_image = get_concat_tile_resize(images_to_cluster) + cluster_images[cluster] = cluster_image + + # Make sure to properly close images + for image_list in images_to_cluster: + for image in image_list: + image.close() - cluster_images = {cluster: get_concat_tile_resize(sliced_exemplars[cluster]) - for cluster in self.cluster_labels[1:]} self.cluster_images = cluster_images - # Properly close images - for image in pil_images: - image.close() - def _extract_textual_representation(self, docs: List[str]): """ Extract textual representation of concepts by comparing with documents diff --git a/docs/changelog.md b/docs/changelog.md index a273009..09bf630 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,3 +1,10 @@ +## **Version 0.2.1** +*Release date: 5 November, 2021* + +* Fixed issue when loading in more than 40.000 images +* Fixed `transform` only working with pre-trained embeddings + + ## **Version 0.2.0** *Release date: 2 November, 2021* diff --git a/setup.py b/setup.py index ccc2071..552ab86 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( name="concept", packages=find_packages(exclude=["notebooks", "docs"]), - version="0.2.0", + version="0.2.1", author="Maarten P. Grootendorst", author_email="maartengrootendorst@gmail.com", description="Topic Model Images",