Skip to content

Commit

Permalink
feat: Support for Stable Cascade (#376)
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 authored Feb 22, 2024
1 parent 77724c2 commit 93d71b9
Show file tree
Hide file tree
Showing 8 changed files with 25 additions and 5 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

# 4.31.3

* Added some restrictions for Stable Cascade

# 4.31.2

* Specialized some generic rcs
Expand Down
4 changes: 4 additions & 0 deletions README_return_codes.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,7 @@ The errors returned by the AI horde are always in this json format
| SpecialMissingUsername | Special models must always include the username, in the form of 'horde_special::user#id' |
| SpecialModelNeedsSpecialUser | Only special users can request a special model.", "SpecialModelNeedsSpecialUser|
| SpecialFieldNeedsSpecialUser | Only special users can send a special field |
| Img2ImgMismatch | Img2Img cannot be used in combination with this model |
| TilingMismatch | Tiling cannot be used in combination with this model |
| ControlNetMismatch | ControlNet cannot be used in combination with this model |
| HiResMismatch | HiRes fix cannot be used in combination with this model |
9 changes: 8 additions & 1 deletion horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,14 @@ def validate(self):
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("hires fix does not work with SDXL currently.", rc="HiResFixMismatch")
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with SDXL currently.", rc="ControlNetSDXLMismatch")
raise e.BadRequest("ControlNet does not work with SDXL currently.", rc="ControlNetMismatch")
if any(model_reference.get_model_baseline(model_name).startswith("stable_cascade") for model_name in self.args.models):
if self.args.source_image:
raise e.BadRequest("Img2Img does not work with Stable Cascade currently.", rc="Img2ImgMismatch")
if self.params.get("hires_fix", False) is True:
raise e.BadRequest("hires fix does not work with Stable Cascade currently.", rc="HiResFixMismatch")
if "control_type" in self.params:
raise e.BadRequest("ControlNet does not work with Stable Cascade currently.", rc="ControlNetMismatch")
if "loras" in self.params:
if len(self.params["loras"]) > 5:
raise e.BadRequest("You cannot request more than 5 loras per generation.", rc="TooManyLoras")
Expand Down
2 changes: 1 addition & 1 deletion horde/classes/stable/processing_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def get_details(self):

def get_gen_kudos(self):
# We have pre-calculated them as they don't change per worker
if model_reference.get_model_baseline(self.model) == "stable_diffusion_xl":
if model_reference.get_model_baseline(self.model) in ["stable_diffusion_xl", "stable_cascade"]:
return self.wp.kudos * 2
return self.wp.kudos

Expand Down
6 changes: 4 additions & 2 deletions horde/classes/stable/waiting_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,9 @@ def require_upfront_kudos(self, counted_totals, total_threads):
):
max_res = 768
# We allow everyone to use SDXL up to 1024
if max_res < 1024 and any(model_reference.get_model_baseline(mn) == "stable_diffusion_xl" for mn in model_names):
if max_res < 1024 and any(
model_reference.get_model_baseline(mn) in ["stable_diffusion_xl", "stable_cascade"] for mn in model_names
):
max_res = 1024
if max_res > 1024:
max_res = 1024
Expand Down Expand Up @@ -437,7 +439,7 @@ def extrapolate_dry_run_kudos(self):
model_name = self.models[0].model
else:
model_name = "SDXL 1.0"
if model_reference.get_model_baseline(model_name) == "stable_diffusion_xl":
if model_reference.get_model_baseline(model_name) in ["stable_diffusion_xl", "stable_cascade"]:
return (self.calculate_extra_kudos_burn(kudos) * self.n * 2) + 1
# The +1 is the extra kudos burn per request
return (self.calculate_extra_kudos_burn(kudos) * self.n) + 1
Expand Down
2 changes: 1 addition & 1 deletion horde/consts.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
HORDE_VERSION = "4.31.2"
HORDE_VERSION = "4.31.3"

WHITELISTED_SERVICE_IPS = {
"212.227.227.178", # Turing Bot
Expand Down
2 changes: 2 additions & 0 deletions horde/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@
"SpecialMissingUsername",
"SpecialModelNeedsSpecialUser",
"SpecialFieldNeedsSpecialUser",
"Img2ImgMismatch",
"TilingMismatch",
]


Expand Down
1 change: 1 addition & 0 deletions horde/model_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def call_function(self):
"stable diffusion 2",
"stable diffusion 2 512",
"stable_diffusion_xl",
"stable_cascade",
}:
self.stable_diffusion_names.add(model)
if self.reference[model].get("nsfw"):
Expand Down

0 comments on commit 93d71b9

Please sign in to comment.