Skip to content

Commit

Permalink
fix: regression bug on regex filter due to styles
Browse files Browse the repository at this point in the history
  • Loading branch information
db0 committed Nov 25, 2024
1 parent ba9b734 commit 4ec3f7b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 20 deletions.
28 changes: 14 additions & 14 deletions horde/apis/v2/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,40 +258,40 @@ def validate(self):
if ip_timeout:
raise e.TimeoutIP(self.user_ip, ip_timeout)
# logger.warning(datetime.utcnow())
prompt_suspicion, _ = prompt_checker(self.args.prompt)
prompt_suspicion, _ = prompt_checker(self.prompt)
# logger.warning(datetime.utcnow())
prompt_replaced = False
if prompt_suspicion >= 2 and self.gentype != "text":
# if replacement filter mode is enabled AND prompt is short enough, do that instead
if self.args.replacement_filter or self.user.education:
if not prompt_checker.check_prompt_replacement_length(self.args.prompt):
if not prompt_checker.check_prompt_replacement_length(self.prompt):
raise e.BadRequest("Prompt has to be below 7000 chars when replacement filter is on")
self.args.prompt = prompt_checker.apply_replacement_filter(self.args.prompt)
self.prompt = prompt_checker.apply_replacement_filter(self.prompt)
# If it returns None, it means it replaced everything with an empty string
if self.args.prompt is not None:
if self.prompt is not None:
prompt_replaced = True
if not prompt_replaced:
# Moderators do not get ip blocked to allow for experiments
if not self.user.moderator:
prompt_dict = {
"prompt": self.args.prompt,
"prompt": self.prompt,
"user": self.username,
"type": "regex",
}
upload_prompt(prompt_dict)
self.user.report_suspicion(1, Suspicions.CORRUPT_PROMPT)
CounterMeasures.report_suspicion(self.user_ip)
raise e.CorruptPrompt(self.username, self.user_ip, self.args.prompt)
if_nsfw_model = prompt_checker.check_nsfw_model_block(self.args.prompt, self.models)
raise e.CorruptPrompt(self.username, self.user_ip, self.prompt)
if_nsfw_model = prompt_checker.check_nsfw_model_block(self.prompt, self.models)
if if_nsfw_model or self.user.flagged:
# For NSFW models and flagged users, we always do replacements
# This is to avoid someone using the NSFW models to figure out the regex since they don't have an IP timeout
self.args.prompt = prompt_checker.nsfw_model_prompt_replace(
self.args.prompt,
self.prompt = prompt_checker.nsfw_model_prompt_replace(
self.prompt,
self.models,
already_replaced=prompt_replaced,
)
if self.args.prompt is None:
if self.prompt is None:
prompt_replaced = False
elif prompt_replaced is False:
prompt_replaced = True
Expand All @@ -303,16 +303,16 @@ def validate(self):
)
if self.user.flagged and not if_nsfw_model:
msg = "To prevent generation of unethical images, we cannot allow this prompt."
raise e.CorruptPrompt(self.username, self.user_ip, self.args.prompt, message=msg)
raise e.CorruptPrompt(self.username, self.user_ip, self.prompt, message=msg)
# Disabling as this is handled by the worker-csam-filter now
# If I re-enable it, also make it use the prompt replacement
# if not prompt_replaced:
# csam_trigger_check = prompt_checker.check_csam_triggers(self.args.prompt)
# csam_trigger_check = prompt_checker.check_csam_triggers(self.prompt)
# if csam_trigger_check is not False and self.gentype != "text":
# raise e.CorruptPrompt(
# self.username,
# self.user_ip,
# self.args.prompt,
# self.prompt,
# message = (f"The trigger '{csam_trigger_check}' has been detected to generate "
# "unethical images on its own and as such has had to be prevented from use. "
# "Thank you for understanding.")
Expand Down Expand Up @@ -1599,7 +1599,7 @@ def get(self):
skname = f": {sk.name}"
user_details["username"] = user_details["username"] + f" (Shared Key{skname})"
if hr.horde_r:
hr.horde_r_setex_json(cache_name, timedelta(seconds=300), user_details)
hr.horde_r_setex_json(cache_name, timedelta(seconds=30), user_details)
return (user_details, 200)


Expand Down
4 changes: 2 additions & 2 deletions horde/apis/v2/kobold.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def validate(self):
self.prompt = self.args.prompt
self.apply_style()
super().validate()
param_validator = ParamValidator(self.args.prompt, self.args.models, self.params, self.user)
param_validator = ParamValidator(self.prompt, self.args.models, self.params, self.user)
self.warnings = param_validator.validate_text_params()
if self.args.extra_source_images is not None and len(self.args.extra_source_images) > 0:
raise e.BadRequest("This request type does not accept extra source images.", rc="InvalidExtraSourceImages.")
Expand Down Expand Up @@ -202,7 +202,7 @@ def apply_style(self):
self.models = self.existing_style.get_model_names()
# We need to use defaultdict to avoid getting keyerrors in case the style author added
# Erroneous keys in the string
self.prompt = self.existing_style.prompt.format_map(defaultdict(str, p=self.args.prompt))
self.prompt = self.existing_style.prompt.format_map(defaultdict(str, p=self.prompt))
requested_n = self.params.get("n", 1)
self.params = self.existing_style.params
self.params["n"] = requested_n
Expand Down
8 changes: 4 additions & 4 deletions horde/apis/v2/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def validate(self):
self.prompt = self.args.prompt
self.apply_style()
super().validate()
param_validator = ParamValidator(prompt=self.args.prompt, models=self.args.models, params=self.params, user=self.user)
param_validator = ParamValidator(prompt=self.prompt, models=self.args.models, params=self.params, user=self.user)
self.warnings = param_validator.validate_image_params()
param_validator.check_for_special()
# During raids, we prevent VPNs
Expand Down Expand Up @@ -377,13 +377,13 @@ def apply_style(self):
self.existing_style.use_count += 1
self.models = self.existing_style.get_model_names()
self.negprompt = ""
if "###" in self.args.prompt:
self.prompt, self.negprompt = self.args.prompt.split("###", 1)
if "###" in self.prompt:
self.prompt, self.negprompt = self.prompt.split("###", 1)
if "###" not in self.existing_style.prompt and self.negprompt != "" and "###" not in self.negprompt:
self.negprompt = "###" + self.negprompt
# We need to use defaultdict to avoid getting keyerrors in case the style author added
# Erroneous keys in the string
self.prompt = self.existing_style.prompt.format_map(defaultdict(str, p=self.args.prompt, np=self.negprompt))
self.prompt = self.existing_style.prompt.format_map(defaultdict(str, p=self.prompt, np=self.negprompt))
requested_n = self.params.get("n", 1)
self.params = self.existing_style.params
self.params["n"] = requested_n
Expand Down

0 comments on commit 4ec3f7b

Please sign in to comment.