Various user2worker options (#402)
* feat: allow disabling targeted worker without priority
* feat: option to require upfront kudos on worker allow/deny lists
* feat: maintenance worker can pick up priority users
* feat: 0 kudos on targeted gens when option set
db0 authored Apr 14, 2024
1 parent d4e7ca7 commit d67fa6f
Showing 7 changed files with 102 additions and 35 deletions.
5 changes: 5 additions & 0 deletions .env_template
@@ -10,6 +10,11 @@ HORDE_MARKDOWN_INDEX="index_stable.md"
HORDE_IMAGE_COMPVIS_REFERENCE="https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/stable_diffusion.json"
HORDE_IMAGE_DIFFUSERS_REFERENCE="https://raw.githubusercontent.com/Haidra-Org/AI-Horde-image-model-reference/main/diffusers.json"
HORDE_IMAGE_LLM_REFERENCE="https://raw.githubusercontent.com/db0/AI-Horde-text-model-reference/main/db.json"
# Set to 1 to prevent users from targeting a worker that doesn't have them as a priority user
# When enabled, targeted worker requests also cost no kudos beyond a small horde tax
HORDE_REQUIRE_MATCHED_TARGETING=0
# Set to 1 to make specifying a worker allow/denylist require upfront kudos
HORDE_UPFRONT_KUDOS_ON_WORKERLIST=0
# Google Oauth2
GOOGLE_CLIENT_ID=""
GLOOGLE_CLIENT_SECRET=""
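
For context, a minimal sketch of how these two new flags are consumed by the changes below. The helper names (is_targeted, targeted_request_cost, worker_list_needs_upfront_kudos) are hypothetical and only illustrate the behaviour; the env-variable names and the flat 0.1-kudos charge come from this commit's diffs:

import os


def is_targeted(wp_workers: list) -> bool:
    # A request counts as "targeted" when the user attached a worker allow/deny list.
    return len(wp_workers) > 0


def targeted_request_cost(wp_workers: list, normal_cost: float) -> float:
    # With HORDE_REQUIRE_MATCHED_TARGETING=1, targeted generations are charged
    # only a flat 0.1 kudos "horde tax" instead of the normal kudos cost.
    if os.getenv("HORDE_REQUIRE_MATCHED_TARGETING", "0") == "1" and is_targeted(wp_workers):
        return 0.1
    return normal_cost


def worker_list_needs_upfront_kudos(wp_workers: list) -> bool:
    # With HORDE_UPFRONT_KUDOS_ON_WORKERLIST=1, attaching a worker allow/deny list
    # makes the request require upfront kudos (and disables downgrading as a way out).
    return os.getenv("HORDE_UPFRONT_KUDOS_ON_WORKERLIST", "0") == "1" and is_targeted(wp_workers)
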
6 changes: 3 additions & 3 deletions horde/apis/v2/kobold.py
@@ -106,8 +106,9 @@ def initiate_waiting_prompt(self):
if model_multiplier > highest_multiplier:
highest_multiplier = model_multiplier
required_kudos = round(self.wp.max_length * highest_multiplier / 21, 2) * self.wp.n
needs_kudos, tokens, disable_downgrade = self.wp.require_upfront_kudos(database.retrieve_totals(), total_threads)
if self.sharedkey and self.sharedkey.kudos != -1 and required_kudos > self.sharedkey.kudos:
if self.args.allow_downgrade:
if self.args.allow_downgrade and not disable_downgrade:
self.downgrade_wp_priority = True
else:
self.wp.delete()
@@ -118,10 +119,9 @@ def initiate_waiting_prompt(self):
f"to fulfill this reques ({required_kudos}).",
rc="SharedKeyInsufficientKudos",
)
needs_kudos, tokens = self.wp.require_upfront_kudos(database.retrieve_totals(), total_threads)
if needs_kudos:
if required_kudos > self.user.kudos:
if self.args.allow_downgrade:
if self.args.allow_downgrade and not disable_downgrade:
self.wp.downgrade(tokens)
else:
self.wp.delete()
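
For readability, a condensed and partly hypothetical sketch of the handler logic above, showing the new third return value (disable_downgrade) and how it combines with the client's allow_downgrade flag (the real handler returns an API error instead of raising):

def resolve_upfront_kudos(wp, user, args, counted_totals, total_threads, required_kudos):
    # require_upfront_kudos() now returns a third value that forbids downgrading.
    needs_kudos, tokens, disable_downgrade = wp.require_upfront_kudos(counted_totals, total_threads)
    if needs_kudos and required_kudos > user.kudos:
        if args.allow_downgrade and not disable_downgrade:
            # Shrink the request to fit the upfront-kudos-free budget instead of rejecting it.
            wp.downgrade(tokens)
        else:
            # Downgrading cannot help (e.g. a worker allow/deny list with
            # HORDE_UPFRONT_KUDOS_ON_WORKERLIST=1), so the request is refused.
            wp.delete()
            raise ValueError("KudosUpfront: not enough kudos to fulfil this request")
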
6 changes: 3 additions & 3 deletions horde/apis/v2/stable.py
@@ -274,12 +274,12 @@ def initiate_waiting_prompt(self):
webhook=self.args.webhook,
)
_, total_threads = database.count_active_workers("image")
needs_kudos, resolution = self.wp.require_upfront_kudos(database.retrieve_totals(), total_threads)
needs_kudos, resolution, disable_downgrade = self.wp.require_upfront_kudos(database.retrieve_totals(), total_threads)
required_kudos = 0
if (self.sharedkey and self.sharedkey.kudos != -1) or needs_kudos:
required_kudos = self.wp.extrapolate_dry_run_kudos()
if self.sharedkey and self.sharedkey.kudos != -1 and required_kudos > self.sharedkey.kudos:
if self.args.allow_downgrade:
if self.args.allow_downgrade and not disable_downgrade:
self.downgrade_wp_priority = True
else:
self.wp.delete()
@@ -292,7 +292,7 @@ def initiate_waiting_prompt(self):
)
if needs_kudos is True:
if required_kudos > self.user.kudos:
if self.args.allow_downgrade:
if self.args.allow_downgrade and not disable_downgrade:
self.wp.downgrade(resolution)
else:
self.wp.delete()
3 changes: 3 additions & 0 deletions horde/classes/kobold/processing_generation.py
@@ -1,4 +1,5 @@
import math
import os

from horde import vars as hv
from horde.classes.base.processing_generation import ProcessingGeneration
@@ -30,6 +31,8 @@ def get_details(self):
return ret_dict

def get_gen_kudos(self):
if os.getenv("HORDE_REQUIRE_MATCHED_TARGETING", "0") == "1" and len(self.wp.workers) > 0:
return 0.1
# This formula creates an exponential increase in kudos consumption, based on the context requested
# 1024 context is considered the base.
# The reason is that higher context has exponential VRAM requirements
19 changes: 14 additions & 5 deletions horde/classes/kobold/waiting_prompt.py
@@ -1,4 +1,5 @@
import math
import os

from sqlalchemy.sql import expression

@@ -90,21 +91,26 @@ def record_usage(self, raw_things, kudos, usage_type="text", avoid_burn=False):
super().record_usage(raw_things, kudos, usage_type)

def require_upfront_kudos(self, counted_totals, total_threads):
"""Returns True if this wp requires that the user already has the required kudos to fulfil it
else returns False
"""Returns A tuple
First entry in the tuple is True if this wp requires that the user already has the required kudos to fulfil it
else is False
Second entry in the tuple is the max tokens that can be used without upfront kudos
Third entry in the tuple is whether the upfront kudos requirement prevents downgrading to resolve this.
"""
queue = counted_totals["queued_text_requests"]
max_tokens = 512 + (total_threads * 5) - round(queue * 0.9)
# logger.debug([queue,max_tokens])
if not self.slow_workers:
return (True, max_tokens)
return (True, max_tokens, False)
if max_tokens < 256:
max_tokens = 256
if max_tokens > 512:
max_tokens = 512
if self.max_length > max_tokens:
return (True, max_tokens)
return (False, max_tokens)
return (True, max_tokens, False)
if os.getenv("HORDE_UPFRONT_KUDOS_ON_WORKERLIST", "0") == "1" and len(self.workers) > 0:
return (True, max_tokens, True)
return (False, max_tokens, False)

def downgrade(self, max_tokens):
"""Ensures this WP requirements are not exceeding upfront kudos requirements"""
@@ -117,6 +123,9 @@ def downgrade(self, max_tokens):
db.session.commit()

def calculate_kudos(self):
if os.getenv("HORDE_REQUIRE_MATCHED_TARGETING", "0") == "1" and len(self.workers) > 0:
self.kudos = 0.1
return self.kudos
# Slimmed down version of procgen.get_gen_kudos()
# As we don't know the worker's trusted status.
# It exists here in order to allow us to calculate dry_runs
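
As a worked example of the token budget above (values are illustrative): with 40 total threads and 300 queued text requests, max_tokens = 512 + 40 * 5 - round(300 * 0.9) = 442, which stays inside the [256, 512] clamp. A standalone sketch of that calculation:

def text_upfront_kudos_budget(total_threads: int, queued_text_requests: int) -> int:
    # Standalone restatement of the max_tokens budget computed in require_upfront_kudos().
    max_tokens = 512 + (total_threads * 5) - round(queued_text_requests * 0.9)
    return min(max(max_tokens, 256), 512)


# 512 + 200 - 270 = 442 tokens; requests above this need upfront kudos or get downgraded.
assert text_upfront_kudos_budget(40, 300) == 442
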
29 changes: 20 additions & 9 deletions horde/classes/stable/waiting_prompt.py
@@ -1,4 +1,5 @@
import copy
import os
import random

from sqlalchemy.sql import expression
@@ -287,6 +288,10 @@ def calculate_kudos(self):
#
# Legacy calculation
#
if os.getenv("HORDE_REQUIRE_MATCHED_TARGETING", "0") == "1" and len(self.workers) > 0:
self.kudos = 0.1
db.session.commit()
return self.kudos
legacy_kudos_cost = 0
result = pow(
(self.params.get("width", 512) * self.params.get("height", 512)) - (64 * 64),
@@ -333,13 +338,16 @@ def calculate_kudos(self):
return self.kudos

def require_upfront_kudos(self, counted_totals, total_threads):
"""Returns True if this wp requires that the user already has the required kudos to fulfil it
else returns False
"""Returns A tuple
First entry in the tuple is True if this wp requires that the user already has the required kudos to fulfil it
else is False
Second entry in the tuple is the max resolution that can be used without upfront kudos
Third entry in the tuple is whether the upfront kudos requirement prevents downgrading to resolve this.
"""
queue = counted_totals["queued_requests"]
max_res = 1024 + (total_threads * 10) - round(queue * 0.9)
if not self.slow_workers:
return (True, max_res)
return (True, max_res, False)
if max_res < 576:
max_res = 576
model_names = self.get_model_names()
@@ -359,20 +367,23 @@ def require_upfront_kudos(self, counted_totals, total_threads):
max_res = 1024
# Using more than 10 steps with LCM requires upfront kudos
if self.is_using_lcm() and self.get_accurate_steps() > 10:
return (True, max_res)
return (True, max_res, False)
# Stable Cascade doesn't need so many steps, so we limit it a bit to prevent abuse.
if any(model_reference.get_model_baseline(mn) in ["stable_cascade"] for mn in model_names) and self.get_accurate_steps() > 30:
return (True, max_res)
return (True, max_res, False)
if self.get_accurate_steps() > 50:
return (True, max_res)
return (True, max_res, False)
if self.width * self.height > max_res * max_res:
return (True, max_res)
return (True, max_res, False)
if self.params.get("control_type") and self.get_accurate_steps() > 20:
return (True, max_res)
return (True, max_res, False)
# haven't decided yet if this is a good idea.
# if 'RealESRGAN_x4plus' in self.gen_payload.get('post_processing', []):
# return(True,max_res)
return (False, max_res)
# if HORDE_UPFRONT_KUDOS_ON_WORKERLIST is set to 1, then specifying a worker allow/deny list requires upfront kudos
if os.getenv("HORDE_UPFRONT_KUDOS_ON_WORKERLIST", "0") == "1" and len(self.workers) > 0:
return (True, max_res, True)
return (False, max_res, False)

def downgrade(self, max_resolution):
"""Ensures this WP requirements are not exceeding upfront kudos requirements"""
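
A condensed plain-Python restatement of the image-side decision above, illustrative only: the slow-workers, generic step-count, resolution and worker-list checks are shown, while the LCM, Stable Cascade and control-type special cases are omitted. Only the new worker-list branch sets disable_downgrade:

import os


def image_needs_upfront_kudos(width, height, steps, has_worker_list, slow_workers, max_res):
    # Returns (needs_upfront_kudos, disable_downgrade), condensing the checks above.
    if not slow_workers:
        return True, False
    if steps > 50:
        return True, False
    if width * height > max_res * max_res:
        return True, False
    if os.getenv("HORDE_UPFRONT_KUDOS_ON_WORKERLIST", "0") == "1" and has_worker_list:
        # New in this commit: a worker allow/deny list forces upfront kudos,
        # and downgrading the request cannot lift that requirement.
        return True, True
    return False, False
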
69 changes: 54 additions & 15 deletions horde/database/functions.py
@@ -1,4 +1,5 @@
import json
import os
import time
import urllib.parse
import uuid
@@ -783,17 +784,6 @@ def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, p
"SDXL_beta::stability.ai#6901" not in models_list,
),
),
or_(
WPAllowedWorkers.id.is_(None),
and_(
ImageWaitingPrompt.worker_blacklist.is_(False),
WPAllowedWorkers.worker_id == worker.id,
),
and_(
ImageWaitingPrompt.worker_blacklist.is_(True),
WPAllowedWorkers.worker_id != worker.id,
),
),
or_(
ImageWaitingPrompt.source_image == None, # noqa E712
worker.allow_img2img == True, # noqa E712
@@ -814,10 +804,6 @@ def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, p
ImageWaitingPrompt.nsfw == False, # noqa E712
worker.nsfw == True, # noqa E712
),
or_(
worker.maintenance == False, # noqa E712
ImageWaitingPrompt.user_id == worker.user_id,
),
or_(
check_bridge_capability("r2", worker.bridge_agent),
ImageWaitingPrompt.r2 == False, # noqa E712
@@ -856,6 +842,59 @@ def get_sorted_wp_filtered_to_worker(worker, models_list=None, blacklist=None, p
# logger.debug(final_wp_list)
if priority_user_ids:
final_wp_list = final_wp_list.filter(ImageWaitingPrompt.user_id.in_(priority_user_ids))
final_wp_list = final_wp_list.filter(
# Workers in maintenance can still pick up requests from their owner or their priority users
or_(
worker.maintenance == False, # noqa E712
ImageWaitingPrompt.user_id.in_(priority_user_ids),
),
or_(
WPAllowedWorkers.id.is_(None),
and_(
ImageWaitingPrompt.worker_blacklist.is_(False),
WPAllowedWorkers.worker_id == worker.id,
),
and_(
ImageWaitingPrompt.worker_blacklist.is_(True),
WPAllowedWorkers.worker_id != worker.id,
),
),
)
else:
final_wp_list = final_wp_list.filter(
or_(
worker.maintenance == False, # noqa E712
ImageWaitingPrompt.user_id == worker.user_id,
),
)
# If HORDE_REQUIRE_MATCHED_TARGETING is set to 1, we disable using WPAllowedWorkers here.
# Targeted requests will only be picked up via the condition above, which includes the
# filter ensuring the worker also has that user as a priority user
if os.getenv("HORDE_REQUIRE_MATCHED_TARGETING", "0") == "1":
final_wp_list = final_wp_list.filter(
or_(
WPAllowedWorkers.id.is_(None),
and_(
ImageWaitingPrompt.worker_blacklist.is_(True),
WPAllowedWorkers.worker_id != worker.id,
),
),
)
else:
final_wp_list = final_wp_list.filter(
or_(
WPAllowedWorkers.id.is_(None),
and_(
ImageWaitingPrompt.worker_blacklist.is_(False),
WPAllowedWorkers.worker_id == worker.id,
),
and_(
ImageWaitingPrompt.worker_blacklist.is_(True),
WPAllowedWorkers.worker_id != worker.id,
),
),
)

# logger.debug(final_wp_list)
final_wp_list = (
final_wp_list.order_by(ImageWaitingPrompt.extra_priority.desc(), ImageWaitingPrompt.created.asc())
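
Roughly, the SQLAlchemy filters above amount to the following per-prompt eligibility check. This is a plain-Python paraphrase for readability only; wp.allowed_worker_ids and the argument names are simplifications of the real columns and joins:

import os


def worker_may_pick_up(wp, worker, priority_user_ids=None):
    # Maintenance workers may now also serve their priority users, not just their owner.
    if worker.maintenance:
        allowed_users = {worker.user_id} | set(priority_user_ids or [])
        if wp.user_id not in allowed_users:
            return False
    if not wp.allowed_worker_ids:
        return True
    if wp.worker_blacklist:
        # Deny lists keep working in both modes.
        return worker.id not in wp.allowed_worker_ids
    # With HORDE_REQUIRE_MATCHED_TARGETING=1, allow-lists only apply on the
    # priority-user pass, i.e. when the worker has the requesting user as a priority.
    if os.getenv("HORDE_REQUIRE_MATCHED_TARGETING", "0") == "1":
        if not priority_user_ids or wp.user_id not in priority_user_ids:
            return False
    return worker.id in wp.allowed_worker_ids
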
