From 382626934175308e41080298c0be7b52ef01ff28 Mon Sep 17 00:00:00 2001 From: Gareth Coles Date: Mon, 18 Mar 2024 10:10:46 +0000 Subject: [PATCH] PK: Configurable rate limiting support --- .../modules/extra/pluralkit/PKExtension.kt | 18 ++--- .../modules/extra/pluralkit/_Utils.kt | 14 +++- .../modules/extra/pluralkit/api/PluralKit.kt | 10 ++- .../extra/pluralkit/config/PKConfigBuilder.kt | 80 +++++++++++++++++++ 4 files changed, 109 insertions(+), 13 deletions(-) create mode 100644 extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/config/PKConfigBuilder.kt diff --git a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/PKExtension.kt b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/PKExtension.kt index e37c880e98..930e923d2f 100644 --- a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/PKExtension.kt +++ b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/PKExtension.kt @@ -24,6 +24,7 @@ import com.kotlindiscord.kord.extensions.extensions.Extension import com.kotlindiscord.kord.extensions.extensions.ephemeralSlashCommand import com.kotlindiscord.kord.extensions.extensions.event import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.api.PluralKit +import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.config.PKConfigBuilder import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.events.proxied import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.events.unproxied import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.storage.PKGuildConfig @@ -55,7 +56,7 @@ const val NEGATIVE_EMOTE = "❌" const val POSITIVE_EMOTE = "✅" @Suppress("StringLiteralDuplication") -class PKExtension : Extension() { +class PKExtension(val config: PKConfigBuilder) : Extension() { override val name: String = "ext-pluralkit" override val bundle: String = "kordex.pluralkit" @@ -433,17 +434,14 @@ class PKExtension : Extension() { checkTask = null } - private fun PKGuildConfig.api(): PluralKit { - var api = apiMap[apiUrl] - - if (api == null) { - api = PluralKit(apiUrl) - apiMap[apiUrl] = api + private fun PKGuildConfig.api(): PluralKit = + apiMap.getOrPut(apiUrl) { + PluralKit( + apiUrl, + config.getLimiter(apiUrl) + ) } - return api - } - private fun Boolean.emote() = if (this) { POSITIVE_EMOTE diff --git a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/_Utils.kt b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/_Utils.kt index 38177d89c1..4f13635be2 100644 --- a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/_Utils.kt +++ b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/_Utils.kt @@ -7,8 +7,18 @@ package com.kotlindiscord.kord.extensions.modules.extra.pluralkit import com.kotlindiscord.kord.extensions.builders.ExtensibleBotBuilder +import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.config.PKConfigBuilder -/** Set up and add the PluralKit extension to your bot. **/ +/** Set up and add the PluralKit extension to your bot, using the default configuration. **/ fun ExtensibleBotBuilder.ExtensionsBuilder.extPluralKit() { - add(::PKExtension) + add { PKExtension(PKConfigBuilder()) } +} + +/** Set up and add the PluralKit extension to your bot. **/ +fun ExtensibleBotBuilder.ExtensionsBuilder.extPluralKit(body: PKConfigBuilder.() -> Unit) { + val builder = PKConfigBuilder() + + body(builder) + + add { PKExtension(builder) } } diff --git a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/api/PluralKit.kt b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/api/PluralKit.kt index d36cf8abb4..ba853112a3 100644 --- a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/api/PluralKit.kt +++ b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/api/PluralKit.kt @@ -10,6 +10,7 @@ package com.kotlindiscord.kord.extensions.modules.extra.pluralkit.api import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.utils.LRUHashMap import dev.kord.common.entity.Snowflake +import dev.kord.common.ratelimit.IntervalRateLimiter import io.github.oshai.kotlinlogging.KotlinLogging import io.ktor.client.* import io.ktor.client.call.* @@ -19,10 +20,15 @@ import io.ktor.client.request.* import io.ktor.http.* import io.ktor.serialization.kotlinx.json.* import kotlinx.serialization.json.Json +import kotlin.time.Duration.Companion.seconds internal const val PK_API_VERSION = 2 -class PluralKit(private val baseUrl: String = "https://api.pluralkit.me", cacheSize: Int = 10_000) { +class PluralKit( + private val baseUrl: String = "https://api.pluralkit.me", + private val rateLimiter: IntervalRateLimiter? = IntervalRateLimiter(2, 1.seconds), + cacheSize: Int = 10_000 +) { private val logger = KotlinLogging.logger { } private val messageUrl: String = "${this.baseUrl}/v$PK_API_VERSION/messages/{id}" @@ -53,6 +59,8 @@ class PluralKit(private val baseUrl: String = "https://api.pluralkit.me", cacheS val url = messageUrl.replace("id" to id) try { + rateLimiter?.consume() + val result: PKMessage = client.get(url).body() messageCache[id] = result diff --git a/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/config/PKConfigBuilder.kt b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/config/PKConfigBuilder.kt new file mode 100644 index 0000000000..40d64e44f0 --- /dev/null +++ b/extra-modules/extra-pluralkit/src/main/kotlin/com/kotlindiscord/kord/extensions/modules/extra/pluralkit/config/PKConfigBuilder.kt @@ -0,0 +1,80 @@ +@file:Suppress("StringLiteralDuplication") + +package com.kotlindiscord.kord.extensions.modules.extra.pluralkit.config + +import dev.kord.common.ratelimit.IntervalRateLimiter +import kotlin.time.Duration +import kotlin.time.Duration.Companion.seconds + +class PKConfigBuilder { + /** + * A mapping of domain names to their corresponding rate limiters. + * + * Provide `null` for a domain to disable rate limiting. + */ + val domainRateLimiters: MutableMap = mutableMapOf( + "api.pluralkit.me" to IntervalRateLimiter(2, 1.seconds) + ) + + /** + * Rate limiter to use by default, when there's no domain-specific rate limiter. + * + * Provide `null` to disable rate limiting. + */ + var defaultRateLimiter: IntervalRateLimiter? = IntervalRateLimiter(2, 1.seconds) + + /** Replace the default rate limiter, using the specified settings. **/ + fun defaultLimit(limit: Int, interval: Duration) { + defaultRateLimiter = IntervalRateLimiter(limit, interval) + } + + /** Remove the default rate limiter, disabling rate limiting by default. **/ + fun unlimitByDefault() { + defaultRateLimiter = null + } + + /** Set a domain-specific rate limiter, using the specified settings. **/ + fun domainLimit(domain: String, limit: Int, interval: Duration) { + if ("/" in domain) { + error("URL provided as the `domain` parameter - please provide a domain instead.") + } + + domainRateLimiters[domain] = IntervalRateLimiter(limit, interval) + } + + /** Remove a domain-specific rate limiter, making it use the default rate limiter instead. **/ + fun defaultDomainLimit(domain: String) { + if ("/" in domain) { + error("URL provided as the `domain` parameter - please provide a domain instead.") + } + + domainRateLimiters.remove(domain) + } + + /** Disable rate limiting for the given domain. **/ + fun unlimitDomain(domain: String) { + if ("/" in domain) { + error("URL provided as the `domain` parameter - please provide a domain instead.") + } + + domainRateLimiters[domain] = null + } + + internal fun getLimiter(url: String): IntervalRateLimiter? { + var domain = url + + if ("://" in domain) { + domain = domain.split("://", limit = 2).first() + } + + if ("/" in domain) { + domain = domain.split("/", limit = 2).first() + } + + if (domain in domainRateLimiters) { + return domainRateLimiters[domain] + } + + return defaultRateLimiter + } +}