Skip to content

Commit

Permalink
PK: Configurable rate limiting support
Browse files Browse the repository at this point in the history
  • Loading branch information
gdude2002 committed Mar 18, 2024
1 parent 487370b commit 3826269
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.kotlindiscord.kord.extensions.extensions.Extension
import com.kotlindiscord.kord.extensions.extensions.ephemeralSlashCommand
import com.kotlindiscord.kord.extensions.extensions.event
import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.api.PluralKit
import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.config.PKConfigBuilder
import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.events.proxied
import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.events.unproxied
import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.storage.PKGuildConfig
Expand Down Expand Up @@ -55,7 +56,7 @@ const val NEGATIVE_EMOTE = "❌"
const val POSITIVE_EMOTE = ""

@Suppress("StringLiteralDuplication")
class PKExtension : Extension() {
class PKExtension(val config: PKConfigBuilder) : Extension() {
override val name: String = "ext-pluralkit"
override val bundle: String = "kordex.pluralkit"

Expand Down Expand Up @@ -433,17 +434,14 @@ class PKExtension : Extension() {
checkTask = null
}

private fun PKGuildConfig.api(): PluralKit {
var api = apiMap[apiUrl]

if (api == null) {
api = PluralKit(apiUrl)
apiMap[apiUrl] = api
private fun PKGuildConfig.api(): PluralKit =
apiMap.getOrPut(apiUrl) {
PluralKit(
apiUrl,
config.getLimiter(apiUrl)
)
}

return api
}

private fun Boolean.emote() =
if (this) {
POSITIVE_EMOTE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,18 @@
package com.kotlindiscord.kord.extensions.modules.extra.pluralkit

import com.kotlindiscord.kord.extensions.builders.ExtensibleBotBuilder
import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.config.PKConfigBuilder

/** Set up and add the PluralKit extension to your bot. **/
/** Set up and add the PluralKit extension to your bot, using the default configuration. **/
fun ExtensibleBotBuilder.ExtensionsBuilder.extPluralKit() {
add(::PKExtension)
add { PKExtension(PKConfigBuilder()) }
}

/** Set up and add the PluralKit extension to your bot. **/
fun ExtensibleBotBuilder.ExtensionsBuilder.extPluralKit(body: PKConfigBuilder.() -> Unit) {
val builder = PKConfigBuilder()

body(builder)

add { PKExtension(builder) }
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ package com.kotlindiscord.kord.extensions.modules.extra.pluralkit.api

import com.kotlindiscord.kord.extensions.modules.extra.pluralkit.utils.LRUHashMap
import dev.kord.common.entity.Snowflake
import dev.kord.common.ratelimit.IntervalRateLimiter
import io.github.oshai.kotlinlogging.KotlinLogging
import io.ktor.client.*
import io.ktor.client.call.*
Expand All @@ -19,10 +20,15 @@ import io.ktor.client.request.*
import io.ktor.http.*
import io.ktor.serialization.kotlinx.json.*
import kotlinx.serialization.json.Json
import kotlin.time.Duration.Companion.seconds

internal const val PK_API_VERSION = 2

class PluralKit(private val baseUrl: String = "https://api.pluralkit.me", cacheSize: Int = 10_000) {
class PluralKit(
private val baseUrl: String = "https://api.pluralkit.me",
private val rateLimiter: IntervalRateLimiter? = IntervalRateLimiter(2, 1.seconds),
cacheSize: Int = 10_000
) {
private val logger = KotlinLogging.logger { }

private val messageUrl: String = "${this.baseUrl}/v$PK_API_VERSION/messages/{id}"
Expand Down Expand Up @@ -53,6 +59,8 @@ class PluralKit(private val baseUrl: String = "https://api.pluralkit.me", cacheS
val url = messageUrl.replace("id" to id)

try {
rateLimiter?.consume()

val result: PKMessage = client.get(url).body()
messageCache[id] = result

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
@file:Suppress("StringLiteralDuplication")

package com.kotlindiscord.kord.extensions.modules.extra.pluralkit.config

import dev.kord.common.ratelimit.IntervalRateLimiter
import kotlin.time.Duration
import kotlin.time.Duration.Companion.seconds

class PKConfigBuilder {
/**
* A mapping of domain names to their corresponding rate limiters.
*
* Provide `null` for a domain to disable rate limiting.
*/
val domainRateLimiters: MutableMap<String, IntervalRateLimiter?> = mutableMapOf(
"api.pluralkit.me" to IntervalRateLimiter(2, 1.seconds)
)

/**
* Rate limiter to use by default, when there's no domain-specific rate limiter.
*
* Provide `null` to disable rate limiting.
*/
var defaultRateLimiter: IntervalRateLimiter? = IntervalRateLimiter(2, 1.seconds)

/** Replace the default rate limiter, using the specified settings. **/
fun defaultLimit(limit: Int, interval: Duration) {
defaultRateLimiter = IntervalRateLimiter(limit, interval)
}

/** Remove the default rate limiter, disabling rate limiting by default. **/
fun unlimitByDefault() {
defaultRateLimiter = null
}

/** Set a domain-specific rate limiter, using the specified settings. **/
fun domainLimit(domain: String, limit: Int, interval: Duration) {
if ("/" in domain) {
error("URL provided as the `domain` parameter - please provide a domain instead.")
}

domainRateLimiters[domain] = IntervalRateLimiter(limit, interval)
}

/** Remove a domain-specific rate limiter, making it use the default rate limiter instead. **/
fun defaultDomainLimit(domain: String) {
if ("/" in domain) {
error("URL provided as the `domain` parameter - please provide a domain instead.")
}

domainRateLimiters.remove(domain)
}

/** Disable rate limiting for the given domain. **/
fun unlimitDomain(domain: String) {
if ("/" in domain) {
error("URL provided as the `domain` parameter - please provide a domain instead.")
}

domainRateLimiters[domain] = null
}

internal fun getLimiter(url: String): IntervalRateLimiter? {
var domain = url

if ("://" in domain) {
domain = domain.split("://", limit = 2).first()
}

if ("/" in domain) {
domain = domain.split("/", limit = 2).first()
}

if (domain in domainRateLimiters) {
return domainRateLimiters[domain]
}

return defaultRateLimiter
}
}

0 comments on commit 3826269

Please sign in to comment.