From ee96439339e69ff5e6799481fd1cd680a2225070 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Fri, 16 Sep 2022 13:12:01 +0000 Subject: [PATCH] Some changes for easier jitting of safety checker: - Do not transpose (assumes a certain number of dimensions). The array should be prepared externally. - Create a method (which is not jitted) to get the final flags from the distances. --- stable_diffusion_jax/safety_checker.py | 43 +++++++++++++++++--------- 1 file changed, 28 insertions(+), 15 deletions(-) diff --git a/stable_diffusion_jax/safety_checker.py b/stable_diffusion_jax/safety_checker.py index c81c778..592092b 100644 --- a/stable_diffusion_jax/safety_checker.py +++ b/stable_diffusion_jax/safety_checker.py @@ -43,11 +43,15 @@ def __call__(self, clip_input, images=None): if images is None: return special_cos_dist, cos_dist + # I think this should not be invoked here + return self.filtered_with_scores(special_cos_dist, cos_dist, images) + + def filtered_with_scores(self, special_cos_dist, cos_dist, images): + batch_size = special_cos_dist.shape[0] special_cos_dist = np.asarray(special_cos_dist) cos_dist = np.asarray(cos_dist) result = [] - batch_size = image_embeds.shape[0] for i in range(batch_size): result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []} @@ -55,20 +59,20 @@ def __call__(self, clip_input, images=None): # at the cost of increasing the possibility of filtering benign image inputs adjustment = 0.0 - for concet_idx in range(len(special_cos_dist[0])): - concept_cos = special_cos_dist[i][concet_idx] - concept_threshold = self.special_care_embeds_weights[concet_idx].item() - result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["special_scores"][concet_idx] > 0: - result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]}) + for concept_idx in range(len(special_cos_dist[0])): + concept_cos = special_cos_dist[i][concept_idx] + concept_threshold = self.special_care_embeds_weights[concept_idx].item() + result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["special_scores"][concept_idx] > 0: + result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]}) adjustment = 0.01 - for concet_idx in range(len(cos_dist[0])): - concept_cos = cos_dist[i][concet_idx] - concept_threshold = self.concept_embeds_weights[concet_idx].item() - result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3) - if result_img["concept_scores"][concet_idx] > 0: - result_img["bad_concepts"].append(concet_idx) + for concept_idx in range(len(cos_dist[0])): + concept_cos = cos_dist[i][concept_idx] + concept_threshold = self.concept_embeds_weights[concept_idx].item() + result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3) + if result_img["concept_scores"][concept_idx] > 0: + result_img["bad_concepts"].append(concept_idx) result.append(result_img) @@ -129,11 +133,20 @@ def __call__( params: dict = None, images=None, ): - pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) + # pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1)) return self.module.apply( {"params": params or self.params}, - jnp.array(pixel_values, dtype=jnp.float32), + pixel_values, images, rngs={}, ) + + def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict=None): + return self.module.apply( + {"params": params} or self.params, + special_cos_dist, + cos_dist, + images, + method=self.module.filtered_with_scores, + )