From c5784beae59cbfb44b570099ffcfc92263115450 Mon Sep 17 00:00:00 2001 From: shedaniel Date: Tue, 7 Jun 2022 17:37:17 +0800 Subject: [PATCH] Try to limit Signed-off-by: shedaniel --- build.gradle | 4 +- .../linkie/discord/handler/CommandHandler.kt | 76 ++++++++-------- .../linkie/discord/handler/RateLimiter.kt | 6 +- .../linkie/discord/scommands/SlashCommands.kt | 86 +++++++++++++------ .../me/shedaniel/linkie/discord/LinkieBot.kt | 20 ++++- .../discord/commands/QueryMappingsCommand.kt | 12 ++- .../commands/QueryTranslateMappingsCommand.kt | 10 ++- 7 files changed, 145 insertions(+), 69 deletions(-) diff --git a/build.gradle b/build.gradle index 18eac9d..2f59b50 100644 --- a/build.gradle +++ b/build.gradle @@ -8,7 +8,7 @@ plugins { } group "me.shedaniel" -sourceCompatibility = targetCompatibility = 1.8 +sourceCompatibility = targetCompatibility = 17 license { include "**/*.kt" @@ -131,7 +131,7 @@ build.finalizedBy mainJar compileKotlin { kotlinOptions.suppressWarnings = true - kotlinOptions.jvmTarget = "1.8" + kotlinOptions.jvmTarget = "17" kotlinOptions { freeCompilerArgs = ["-Xopt-in=kotlin.RequiresOptIn", "-Xinline-classes"] languageVersion = "1.4" diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt index b231c4c..3e0e70c 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/CommandHandler.kt @@ -34,6 +34,8 @@ import me.shedaniel.linkie.discord.utils.event import me.shedaniel.linkie.discord.utils.replyComplex import me.shedaniel.linkie.discord.utils.sendEmbedMessage import java.time.Duration +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionException import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutorService import java.util.concurrent.Executors @@ -56,47 +58,53 @@ class CommandHandler( val user = event.message.author.orElse(null)?.takeUnless { it.isBot } ?: return val message: String = event.message.content val prefix = commandAcceptor.getPrefix(event) - try { - executor.submit { - if (message.lowercase().startsWith(prefix)) { - val content = message.substring(prefix.length) - if (!rateLimiter.allow(user.id.asLong())) { + CompletableFuture.runAsync({ + if (message.lowercase().startsWith(prefix)) { + val content = message.substring(prefix.length) + val split = content.splitArgs() + if (split.isNotEmpty()) { + val cmd = split[0].lowercase() + val args = split.drop(1).toMutableList() + if (!rateLimiter.allow(user, cmd, mutableMapOf("args" to args))) { throwableHandler.generateErrorMessage(event.message, RateLimitException(rateLimiter.maxRequestPer10Sec), channel, user) - return@submit + return@runAsync } - val split = content.splitArgs() - if (split.isNotEmpty()) { - val cmd = split[0].lowercase() - val ctx = MessageBasedCommandContext(event, prefix, cmd, channel) - val args = split.drop(1).toMutableList() - try { - runBlocking { - commandAcceptor.execute(event, ctx, args) - } - } catch (throwable: Throwable) { - if (throwableHandler.shouldError(throwable)) { - try { - ctx.message.replyComplex { - layout { dismissButton() } - embed { throwableHandler.generateThrowable(this, throwable, user) } - } - } catch (throwable2: Exception) { - throwable2.addSuppressed(throwable) - throwable2.printStackTrace() + val ctx = MessageBasedCommandContext(event, prefix, cmd, channel) + try { + runBlocking { + commandAcceptor.execute(event, ctx, args) + } + } catch (throwable: Throwable) { + if (throwableHandler.shouldError(throwable)) { + try { + ctx.message.replyComplex { + layout { dismissButton() } + embed { throwableHandler.generateThrowable(this, throwable, user) } } + } catch (throwable2: Exception) { + throwable2.addSuppressed(throwable) + throwable2.printStackTrace() } } } } - }.get(10, TimeUnit.SECONDS) - } catch (throwable: TimeoutException) { - val newThrowable = TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.") - throwableHandler.generateErrorMessage(event.message, newThrowable, channel, user) - } catch (throwable: ExecutionException) { - throwableHandler.generateErrorMessage(event.message, throwable.cause ?: throwable, channel, user) - } catch (throwable: Throwable) { - throwableHandler.generateErrorMessage(event.message, throwable, channel, user) - } + } + }, executor).orTimeout(10, TimeUnit.SECONDS) + .exceptionally { + it.let { throwable -> + if (throwable is TimeoutException) { + TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.") + } else if (throwable is CompletionException) { + throwable.cause ?: throwable + } else { + throwable + } + }.also { throwable -> + throwableHandler.generateErrorMessage(event.message, throwable, channel, user) + } + null + } + .join() } } diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt index 319e78f..cbbedbe 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/handler/RateLimiter.kt @@ -16,9 +16,10 @@ package me.shedaniel.linkie.discord.handler +import discord4j.core.`object`.entity.User import java.util.* -class RateLimiter(val maxRequestPer10Sec: Int) { +open class RateLimiter(val maxRequestPer10Sec: Int) { data class Entry( val time: Long, val userId: Long, @@ -26,13 +27,14 @@ class RateLimiter(val maxRequestPer10Sec: Int) { private val log: Queue = LinkedList() - fun allow(userId: Long): Boolean { + open fun allow(user: User, cmd: String, args: Map): Boolean { val curTime = System.currentTimeMillis() val boundary = curTime - 10000 synchronized(log) { while (!log.isEmpty() && log.element().time <= boundary) { log.poll() } + val userId = user.id.asLong() log.add(Entry(curTime, userId)) return log.count { it.userId == userId } <= maxRequestPer10Sec } diff --git a/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt b/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt index dbf127d..464f2a4 100644 --- a/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt +++ b/src/discord_api/kotlin/me/shedaniel/linkie/discord/scommands/SlashCommands.kt @@ -25,6 +25,7 @@ import discord4j.core.`object`.entity.User import discord4j.core.`object`.entity.channel.Channel import discord4j.core.event.domain.interaction.ChatInputAutoCompleteEvent import discord4j.core.event.domain.interaction.ChatInputInteractionEvent +import discord4j.core.`object`.command.ApplicationCommandOption import discord4j.discordjson.json.ApplicationCommandData import discord4j.discordjson.json.ApplicationCommandOptionChoiceData import discord4j.discordjson.json.ApplicationCommandRequest @@ -40,6 +41,8 @@ import me.shedaniel.linkie.discord.utils.extensions.getOrNull import me.shedaniel.linkie.discord.utils.replyComplex import me.shedaniel.linkie.discord.utils.user import reactor.core.publisher.Mono +import java.util.concurrent.CompletableFuture +import java.util.concurrent.CompletionException import java.util.concurrent.ExecutionException import java.util.concurrent.ExecutorService import java.util.concurrent.Executors @@ -256,7 +259,34 @@ class SlashCommands( sentAny = true } } - if (!rateLimiter.allow(event.user.id.asLong())) { + + fun ApplicationCommandInteractionOption.collectOptions(): MutableMap { + val map = mutableMapOf() + for (option in options) { + if (!map.containsKey(option.name)) { + if (option.value.isPresent) { + map[option.name] = option.value.get().raw + } + map.putAll(option.collectOptions()) + } + } + return map + } + + fun ChatInputInteractionEvent.collectOptions(): MutableMap { + val map = mutableMapOf() + for (option in options) { + if (!map.containsKey(option.name)) { + if (option.value.isPresent) { + map[option.name] = option.value.get().raw + } + map.putAll(option.collectOptions()) + } + } + return map + } + + if (!rateLimiter.allow(event.user, cmd, event.collectOptions())) { val exception = RateLimitException(rateLimiter.maxRequestPer10Sec) if (throwableHandler.shouldError(exception)) { try { @@ -272,32 +302,37 @@ class SlashCommands( return@SlashCommandHandler } val optionsGetter = OptionsGetter.of(command, ctx, event) - runCatching { - try { - executor.submit { - if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) { - } - }.get(3, TimeUnit.SECONDS) - } catch (throwable: TimeoutException) { - throw TimeoutException("The command took too long to execute, the maximum execution time is 3 seconds.") - } catch (throwable: ExecutionException) { - throw throwable.cause ?: throwable - } - if (!sentAny) { - throw IllegalStateException("Command was not resolved!") + CompletableFuture.runAsync({ + if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) { } - }.exceptionOrNull()?.also { throwable -> - if (throwableHandler.shouldError(throwable)) { - try { - ctx.message.replyComplex { - layout { dismissButton() } - embed { throwableHandler.generateThrowable(this, throwable, ctx.user) } + }, executor).orTimeout(10, TimeUnit.SECONDS) + .exceptionally { + it.let { throwable -> + if (throwable is TimeoutException) { + TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.") + } else if (throwable is CompletionException) { + throwable.cause ?: throwable + } else { + throwable + } + }.also { throwable -> + if (throwableHandler.shouldError(throwable)) { + try { + ctx.message.replyComplex { + layout { dismissButton() } + embed { throwableHandler.generateThrowable(this, throwable, ctx.user) } + } + } catch (throwable2: Exception) { + throwable2.addSuppressed(throwable) + throwable2.printStackTrace() + } } - } catch (throwable2: Exception) { - throwable2.addSuppressed(throwable) - throwable2.printStackTrace() } + null } + .join() + if (!sentAny) { + throw IllegalStateException("Command was not resolved!") } }, autoCompleter = { event -> val optionsGetter = WeakOptionsGetter.of(command, event).asSuggestion(event.commandName) @@ -433,7 +468,10 @@ interface SlashCommandOptionSuggestionSink { .value(value) .build() - fun choice(value: Any): ApplicationCommandOptionChoiceData = choice(value.toString(), value) + fun choice(value: Any): ApplicationCommandOptionChoiceData = choice(value.toString().take(99), value.let { + if (it is String) it.take(99) + else it + }) } interface SlashCommand : SlashCommandExecutor, SlashCommandSuggester, CommandOptionProperties { diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt b/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt index c0b39a9..21f67eb 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt @@ -24,6 +24,7 @@ import com.soywiz.klock.seconds import discord4j.core.`object`.entity.channel.ThreadChannel import discord4j.core.event.domain.message.MessageCreateEvent import discord4j.core.event.domain.thread.ThreadMembersUpdateEvent +import discord4j.core.`object`.entity.User import discord4j.core.spec.EmbedCreateSpec import discord4j.discordjson.json.gateway.ThreadMembersUpdate import io.ktor.application.* @@ -32,6 +33,7 @@ import io.ktor.response.* import io.ktor.routing.* import io.ktor.server.engine.* import io.ktor.server.netty.* +import io.ktor.utils.io.* import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.GlobalScope @@ -132,7 +134,15 @@ fun main() { ) ) ) { - val rateLimiter = RateLimiter(3) + val rateLimiter = object : RateLimiter(3) { + override fun allow(user: User, cmd: String, args: Map): Boolean { + info("Handling command /$cmd ${ + args.entries.joinToString(" ") + { "${it.key}: ${it.value}" } + } from ${user.discriminatedName}") + return super.allow(user, cmd, args) + } + } val slashCommands = SlashCommands(this, LinkieThrowableHandler, ::warn, debug = isDebug, rateLimiter = rateLimiter) TricksManager.listen(slashCommands) val commandManager = object : CommandManager(if (isDebug) "@" else "!") { @@ -161,9 +171,11 @@ fun main() { gateway.getChannelById(event.threadId).subscribe { channel -> if (channel is ThreadChannel) { channel.sendMessage { - it.embeds(EmbedCreateSpec.create() - .withTitle("Linked has entered the thread") - .withDescription("Thanks for having me here! This message is sent when Linkie is brought into a thread.\nThread support in Linkie is still experimental, please report any issues found on our issue tracker! ٭(•﹏•)٭")) + it.embeds( + EmbedCreateSpec.create() + .withTitle("Linked has entered the thread") + .withDescription("Thanks for having me here! This message is sent when Linkie is brought into a thread.\nThread support in Linkie is still experimental, please report any issues found on our issue tracker! ٭(•﹏•)٭") + ) }.subscribe() } } diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt index fde1a8e..c2246f0 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryMappingsCommand.kt @@ -21,12 +21,12 @@ import kotlinx.coroutines.runBlocking import me.shedaniel.linkie.Class import me.shedaniel.linkie.MappingsContainer import me.shedaniel.linkie.MappingsEntryType -import me.shedaniel.linkie.MappingsMember import me.shedaniel.linkie.MappingsProvider import me.shedaniel.linkie.Method import me.shedaniel.linkie.Namespace import me.shedaniel.linkie.Namespaces import me.shedaniel.linkie.discord.Command +import me.shedaniel.linkie.discord.MappingsQueryUtils import me.shedaniel.linkie.discord.MappingsQueryUtils.query import me.shedaniel.linkie.discord.scommands.CommandOptionMeta import me.shedaniel.linkie.discord.scommands.OptionsGetter @@ -107,11 +107,19 @@ open class QueryMappingsCommand( suggest { _, options, sink -> runBlocking { val rawValue = options.optNullable(this@string) ?: "" + if (rawValue.length >= 100) { + sink.suggest(listOf()) + return@runBlocking + } val value = rawValue.replace('.', '/').replace('#', '/') val namespace = weakNamespaceGetter(options.cmd, options) ?: return@runBlocking val provider = options.optNullable(version, VersionNamespaceConfig(namespace)) ?: namespace.getDefaultProvider() val mappings = provider.get() - val result = query(mappings, value, *types) + val result = try { + query(mappings, value, *types) + } catch (ignored: NullPointerException) { + MappingsQueryUtils.Result(mutableListOf(), false) + } val suggestions = result.results.asSequence().sortedByDescending { it.score }.map { (value, _) -> when { value is Class -> { diff --git a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt index ec2be6a..a907a2c 100644 --- a/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt +++ b/src/main/kotlin/me/shedaniel/linkie/discord/commands/QueryTranslateMappingsCommand.kt @@ -133,6 +133,10 @@ class QueryTranslateMappingsCommand( suggest { _, options, sink -> runBlocking { val rawValue = options.optNullable(this@string) ?: "" + if (rawValue.length >= 100) { + sink.suggest(listOf()) + return@runBlocking + } val value = rawValue.replace('.', '/').replace('#', '/') val src = weakSrcNamespaceGetter(options.cmd, options) ?: return@runBlocking val dst = weakDstNamespaceGetter(options.cmd, options) ?: return@runBlocking @@ -144,7 +148,11 @@ class QueryTranslateMappingsCommand( val provider = options.optNullable(version, VersionNamespaceConfig(src, defaultVersion) { allVersions }) ?: src.getProvider(defaultVersion) val mappings = provider.get() - val result = MappingsQueryUtils.query(mappings, value, *types) + val result = try { + MappingsQueryUtils.query(mappings, value, *types) + } catch (ignored: NullPointerException) { + MappingsQueryUtils.Result(mutableListOf(), false) + } val suggestions = result.results.asSequence().sortedByDescending { it.score }.map { (value, _) -> when { value is Class -> {