Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Safety Checker jitting #11

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions stable_diffusion_jax/safety_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,32 +43,36 @@ def __call__(self, clip_input, images=None):
if images is None:
return special_cos_dist, cos_dist

# I think this should not be invoked here
return self.filtered_with_scores(special_cos_dist, cos_dist, images)

def filtered_with_scores(self, special_cos_dist, cos_dist, images):
batch_size = special_cos_dist.shape[0]
special_cos_dist = np.asarray(special_cos_dist)
cos_dist = np.asarray(cos_dist)

result = []
batch_size = image_embeds.shape[0]
for i in range(batch_size):
result_img = {"special_scores": {}, "special_care": [], "concept_scores": {}, "bad_concepts": []}

# increase this value to create a stronger `nfsw` filter
# at the cost of increasing the possibility of filtering benign image inputs
adjustment = 0.0

for concet_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concet_idx]
concept_threshold = self.special_care_embeds_weights[concet_idx].item()
result_img["special_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concet_idx] > 0:
result_img["special_care"].append({concet_idx, result_img["special_scores"][concet_idx]})
for concept_idx in range(len(special_cos_dist[0])):
concept_cos = special_cos_dist[i][concept_idx]
concept_threshold = self.special_care_embeds_weights[concept_idx].item()
result_img["special_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["special_scores"][concept_idx] > 0:
result_img["special_care"].append({concept_idx, result_img["special_scores"][concept_idx]})
adjustment = 0.01

for concet_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concet_idx]
concept_threshold = self.concept_embeds_weights[concet_idx].item()
result_img["concept_scores"][concet_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concet_idx] > 0:
result_img["bad_concepts"].append(concet_idx)
for concept_idx in range(len(cos_dist[0])):
concept_cos = cos_dist[i][concept_idx]
concept_threshold = self.concept_embeds_weights[concept_idx].item()
result_img["concept_scores"][concept_idx] = round(concept_cos - concept_threshold + adjustment, 3)
if result_img["concept_scores"][concept_idx] > 0:
result_img["bad_concepts"].append(concept_idx)

result.append(result_img)

Expand Down Expand Up @@ -129,11 +133,20 @@ def __call__(
params: dict = None,
images=None,
):
pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))
# pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

return self.module.apply(
{"params": params or self.params},
jnp.array(pixel_values, dtype=jnp.float32),
pixel_values,
images,
rngs={},
)

def filtered_with_scores(self, special_cos_dist, cos_dist, images, params: dict=None):
return self.module.apply(
{"params": params} or self.params,
special_cos_dist,
cos_dist,
images,
method=self.module.filtered_with_scores,
)