Skip to content

Commit

Permalink
feat: Shared keys embedded into styles (#476)
Browse files Browse the repository at this point in the history
* feat: assign shared keys to styles

* doc: changelog

* tests: add test for styles
  • Loading branch information
db0 authored Dec 7, 2024
1 parent adf2d13 commit 418070d
Show file tree
Hide file tree
Showing 14 changed files with 271 additions and 27 deletions.
12 changes: 12 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,18 @@ SPDX-License-Identifier: AGPL-3.0-or-later

# Changelog

# 4.45.0

* Can now assign shared keys to styles. When a shared key is assigned to a style, if it is still valid (i.e. not expired and has kudos)
when that style is applied, it will use that shared key instead of the api key provided by the user.
This can allow someone to share a simple style name with their friends and allow them to use their higher priority.

* Shared keys assigned to styles cannot be used in isolation. They can only be used as part of that style.
* A single shared key can be assigned to more than 1 style
* The shared key assigned to a style is always visible in the style description and therefore considered public information.
* People can now transfer kudos to shared keys. This works on both the web interface and the API. Simply provide the shared key ID instead of a username.
When a kudos is transferred to a shared key, the kudos is transferred to the shared key owner, and the shared key kudos is increased by the same amount if it isn't unlimited (-1)

# 4.44.3

* Fix image validation warnings being sent to the wrong requests
Expand Down
17 changes: 17 additions & 0 deletions horde/apis/models/stable_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,13 @@ def __init__(self, api):
"models": fields.List(
fields.String(description="The models to use with this style.", min_length=1, example="stable_diffusion"),
),
"sharedkey": fields.String(
required=False,
min_length=36,
max_length=36,
description="The UUID of a shared key which will be used to fulfil this style when active.",
example="00000000-0000-0000-0000-000000000000",
),
},
)
self.patch_model_style = api.model(
Expand Down Expand Up @@ -1099,6 +1106,13 @@ def __init__(self, api):
"models": fields.List(
fields.String(description="The models to use with this style.", min_length=1, example="stable_diffusion"),
),
"sharedkey": fields.String(
required=False,
min_length=36,
max_length=36,
description="The UUID of a shared key which will be used to fulfil this style when active.",
example="00000000-0000-0000-0000-000000000000",
),
},
)
self.input_model_style_example_post = api.model(
Expand Down Expand Up @@ -1147,10 +1161,13 @@ def __init__(self, api):
{
"id": fields.String(
description="The UUID of the style. Use this to use the style or retrieve its information in the future.",
min_length=36,
max_length=36,
example="00000000-0000-0000-0000-000000000000",
),
"use_count": fields.Integer(description="The amount of times this style has been used in generations."),
"creator": fields.String(description="The alias of the user to whom this style belongs to.", example="db0#1"),
"examples": fields.List(fields.Nested(self.response_model_style_example, skip_none=True)),
"shared_key": fields.Nested(self.response_model_sharedkey_details),
},
)
14 changes: 14 additions & 0 deletions horde/apis/models/v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,13 @@ def __init__(self):
help="Tags describing this style. Can be used for style discovery.",
location="json",
)
self.style_parser.add_argument(
"sharedkey",
type=str,
required=False,
help="The UUID of a shared key which will be used to generate with this style if active.",
location="json",
)
self.style_parser_patch = reqparse.RequestParser()
self.style_parser_patch.add_argument(
"apikey",
Expand Down Expand Up @@ -428,6 +435,13 @@ def __init__(self):
help="Tags describing this style. Can be used for style discovery.",
location="json",
)
self.style_parser.add_argument(
"sharedkey",
type=str,
required=False,
help="The UUID of a shared key which will be used to generate with this style if active.",
location="json",
)


class Models:
Expand Down
50 changes: 44 additions & 6 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ def post(self):
if self.args.workers:
self.workers = self.args.workers
self.user = None
self.apikey = None
self.sharedkey = None
self.user_ip = request.remote_addr
# For now this is checked on validate()
self.safe_ip = True
Expand Down Expand Up @@ -182,19 +184,27 @@ def validate(self):
if self.args.webhook and not self.args.webhook.startswith("https://"):
raise e.BadRequest("webhooks need to point to an https endpoint.")
with HORDE.app_context(): # TODO DOUBLE CHECK THIS
# logger.warning(datetime.utcnow())
if self.args.apikey:
self.sharedkey = database.find_sharedkey(self.args.apikey)
# If this is set, it means we've already found an active shared key through an applied style.
if self.sharedkey:
self.user = self.sharedkey.user
logger.debug(f"Using style-specified shared key {self.sharedkey.id} from user #{self.user.id}")
elif self.apikey:
self.sharedkey = database.find_sharedkey(self.apikey)
if self.sharedkey:
is_valid, error_msg, rc = self.sharedkey.is_valid()
if not is_valid:
if rc == "SharedKeyEmpty" and self.args.allow_downgrade:
self.downgrade_wp_priority = True
else:
raise e.Forbidden(message=error_msg, rc=rc)
if not self.sharedkey.is_adhoc():
raise e.Forbidden(
message="This shared key cannot be used as it has been assigned to specific styles only",
rc="SharedKeyAssignedStyles",
)
self.user = self.sharedkey.user
if not self.user:
self.user = database.find_user_by_api_key(self.args.apikey)
self.user = database.find_user_by_api_key(self.apikey)
# logger.warning(datetime.utcnow())
if not self.user:
raise e.InvalidAPIKey("generation")
Expand Down Expand Up @@ -356,6 +366,17 @@ def activate_waiting_prompt(self):
kudos_adjustment=2 if self.style_kudos is True else 0,
)

def apply_style(self):
# If it reaches this method, we've already made sure self.args.style isn't empty.
self.existing_style = database.get_style_by_uuid(self.args.style)
if not self.existing_style:
self.existing_style = database.get_style_by_name(self.args.style)
if not self.existing_style:
raise e.ThingNotFound("Style", self.args.style)
# If there's an attached shared key to the style, and it's not empty or expired, we use it.
if self.existing_style.sharedkey and self.existing_style.sharedkey.is_valid()[0] is True:
self.sharedkey = self.existing_style.sharedkey


class SyncGenerate(GenerateTemplate):
# @api.expect(parsers.generate_parser, models.input_model_request_generation, validate=True)
Expand Down Expand Up @@ -3172,7 +3193,14 @@ def post(self):
return

def validate(self):
pass
self.sharedkey = None
if self.args.sharedkey:
self.sharedkey = database.find_sharedkey(self.args.sharedkey)
if self.sharedkey is None:
raise e.BadRequest("This shared key does not exist", "SharedKeyInvalid")
shared_key_validity = self.sharedkey.is_valid()
if shared_key_validity[0] is False:
raise e.BadRequest(shared_key_validity[1], shared_key_validity[2])


class SingleStyleTemplateGet(Resource):
Expand Down Expand Up @@ -3250,6 +3278,9 @@ def patch(self, style_id):
style_modified = True
if len(self.tags) > 0:
style_modified = True
if self.sharedkey is not None:
self.existing_style.sharedkey_id = self.sharedkey.id
style_modified = True
if not style_modified:
return {
"id": self.existing_style.id,
Expand All @@ -3265,7 +3296,14 @@ def patch(self, style_id):
}, 200

def validate(self):
pass
self.sharedkey = None
if self.args.sharedkey:
self.shared_key = database.find_sharedkey(self.args.sharedkey)
if self.sharedkey is None:
raise e.BadRequest("This shared key does not exist", "SharedKeyInvalid")
shared_key_validity = self.sharedkey.is_valid()
if shared_key_validity[0] is False:
raise e.BadRequest(shared_key_validity[1], shared_key_validity[2])

def delete(self, style_id):
self.args = parsers.apikey_parser.parse_args()
Expand Down
22 changes: 14 additions & 8 deletions horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def initiate_waiting_prompt(self):
ipaddr=self.user_ip,
safe_ip=True,
client_agent=self.args["Client-Agent"],
sharedkey_id=self.args.apikey if self.sharedkey else None,
sharedkey_id=self.sharedkey.id if self.sharedkey else None,
proxied_account=self.args["proxied_account"],
webhook=self.args.webhook,
)
Expand Down Expand Up @@ -157,8 +157,11 @@ def initiate_waiting_prompt(self):
text_tokens=self.wp.max_length,
)
if not is_in_limit:
self.wp.delete()
raise e.BadRequest(fail_message)
# If we are using the shared key assigned to a style, then we bypass the shared key requirements
# since its owner explicitly allowed to be used with a style exceeding them
if not (self.existing_style and self.existing_style.sharedkey and self.existing_style.sharedkey.id == self.sharedkey.id):
self.wp.delete()
raise e.BadRequest(fail_message)

def get_size_too_big_message(self):
return (
Expand All @@ -168,6 +171,7 @@ def get_size_too_big_message(self):

def validate(self):
self.prompt = self.args.prompt
self.apikey = self.args.apikey
self.apply_style()
super().validate()
param_validator = ParamValidator(self.prompt, self.args.models, self.params, self.user)
Expand All @@ -187,11 +191,8 @@ def get_hashed_params_dict(self):
def apply_style(self):
if self.args.style is None:
return
self.existing_style = database.get_style_by_uuid(self.args.style)
if not self.existing_style:
self.existing_style = database.get_style_by_name(self.args.style)
if not self.existing_style:
raise e.ThingNotFound("Style", self.args.style)
# The super() ensures the common parts of applying a style
super().apply_style()
if self.existing_style.style_type != "text":
raise e.BadRequest("Image styles cannot be used on image requests", "StyleMismatch")
if isinstance(self.existing_style, StyleCollection):
Expand Down Expand Up @@ -495,6 +496,7 @@ class TextStyle(StyleTemplate):
code=200,
description="Lists text styles information",
as_list=True,
skip_none=True,
)
def get(self):
"""Retrieves information about all text styles
Expand Down Expand Up @@ -562,6 +564,7 @@ def post(self):
nsfw=self.args.nsfw,
prompt=self.args.prompt,
params=self.args.params if self.args.params is not None else {},
sharedkey_id=self.sharedkey.id if self.sharedkey else None,
)
new_style.create()
new_style.set_models(self.models)
Expand All @@ -573,6 +576,7 @@ def post(self):
}, 200

def validate(self):
super().validate()
if database.get_style_by_name(f"{self.user.get_unique_alias()}::style::{self.style_name}"):
raise e.BadRequest(
(
Expand All @@ -596,6 +600,7 @@ class SingleTextStyle(SingleStyleTemplate):
code=200,
description="Lists text styles information",
as_list=False,
skip_none=True,
)
def get(self, style_id):
"""Displays information about a single text style."""
Expand Down Expand Up @@ -666,6 +671,7 @@ class SingleImageStyleByName(SingleStyleTemplateGet):
code=200,
description="Lists text style information by name",
as_list=False,
skip_none=True,
)
def get(self, style_name):
"""Seeks a text style by name and displays its information."""
Expand Down
23 changes: 15 additions & 8 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ def get_size_too_big_message(self):

def validate(self):
self.prompt = self.args.prompt
self.apikey = self.args.apikey
self.apply_style()
super().validate()
param_validator = ParamValidator(prompt=self.prompt, models=self.args.models, params=self.params, user=self.user)
Expand Down Expand Up @@ -244,7 +245,7 @@ def initiate_waiting_prompt(self):
r2=self.args.r2,
shared=shared,
client_agent=self.args["Client-Agent"],
sharedkey_id=self.args.apikey if self.sharedkey else None,
sharedkey_id=self.sharedkey.id if self.sharedkey else None,
proxied_account=self.args["proxied_account"],
disable_batching=self.args["disable_batching"],
webhook=self.args.webhook,
Expand Down Expand Up @@ -301,8 +302,11 @@ def initiate_waiting_prompt(self):
image_steps=requested_steps,
)
if not is_in_limit:
self.wp.delete()
raise e.BadRequest(fail_message)
# If we are using the shared key assigned to a style, then we bypass the shared key requirements
# since its owner explicitly allowed to be used with a style exceeding them
if not (self.existing_style and self.existing_style.sharedkey and self.existing_style.sharedkey.id == self.sharedkey.id):
self.wp.delete()
raise e.BadRequest(fail_message)

def extrapolate_dry_run_kudos(self):
self.wp.source_image = self.args.source_image
Expand Down Expand Up @@ -363,11 +367,8 @@ def activate_waiting_prompt(self):
def apply_style(self):
if self.args.style is None:
return
self.existing_style = database.get_style_by_uuid(self.args.style)
if not self.existing_style:
self.existing_style = database.get_style_by_name(self.args.style)
if not self.existing_style:
raise e.ThingNotFound("Style", self.args.style)
# The super() ensures the common parts of applying a style
super().apply_style()
if self.existing_style.style_type != "image":
raise e.BadRequest("Text styles cannot be used on image requests", "StyleMismatch")
if isinstance(self.existing_style, StyleCollection):
Expand Down Expand Up @@ -1374,6 +1375,7 @@ class ImageStyle(StyleTemplate):
code=200,
description="Lists image styles information",
as_list=True,
skip_none=True,
)
def get(self):
"""Retrieves information about all image styles
Expand Down Expand Up @@ -1441,6 +1443,7 @@ def post(self):
nsfw=self.args.nsfw,
prompt=self.args.prompt,
params=self.args.params if self.args.params is not None else {},
sharedkey_id=self.sharedkey.id if self.sharedkey else None,
)
new_style.create()
new_style.set_models(self.models)
Expand All @@ -1452,6 +1455,7 @@ def post(self):
}, 200

def validate(self):
super().validate()
if database.get_style_by_name(f"{self.user.get_unique_alias()}::style::{self.style_name}"):
raise e.BadRequest(
(
Expand All @@ -1475,6 +1479,7 @@ class SingleImageStyle(SingleStyleTemplate):
code=200,
description="Lists image styles information",
as_list=False,
skip_none=True,
)
def get(self, style_id):
"""Displays information about an image style."""
Expand Down Expand Up @@ -1502,6 +1507,7 @@ def patch(self, style_id):
return super().patch(style_id)

def validate(self):
super().validate()
if (
self.style_name is not None
and database.get_style_by_name(f"{self.user.get_unique_alias()}::style::{self.style_name}")
Expand Down Expand Up @@ -1545,6 +1551,7 @@ class SingleImageStyleByName(SingleStyleTemplateGet):
code=200,
description="Lists image style information by name",
as_list=False,
skip_none=True,
)
def get(self, style_name):
"""Seeks an image style by name and displays its information."""
Expand Down
3 changes: 3 additions & 0 deletions horde/classes/base/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ class Style(db.Model):

user_id = db.Column(db.Integer, db.ForeignKey("users.id", ondelete="CASCADE"), nullable=False)
user = db.relationship("User", back_populates="styles")
sharedkey_id = db.Column(uuid_column_type(), db.ForeignKey("user_sharedkeys.id"), nullable=True)
sharedkey = db.relationship("UserSharedKey", back_populates="styles")
collections: Mapped[list[StyleCollection]] = db.relationship(secondary="style_collection_mapping", back_populates="styles")
models = db.relationship("StyleModel", back_populates="style", cascade="all, delete-orphan")
tags = db.relationship("StyleTag", back_populates="style", cascade="all, delete-orphan")
Expand Down Expand Up @@ -206,6 +208,7 @@ def get_details(self, details_privilege=0):
"use_count": self.use_count,
"public": self.public,
"nsfw": self.nsfw,
"shared_key": self.sharedkey.get_details() if self.sharedkey else None,
}
return ret_dict

Expand Down
Loading

0 comments on commit 418070d

Please sign in to comment.