Skip to content

Commit

Permalink
Try to limit
Browse files Browse the repository at this point in the history
Signed-off-by: shedaniel <[email protected]>
  • Loading branch information
shedaniel committed Jun 7, 2022
1 parent 09e6084 commit c5784be
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 69 deletions.
4 changes: 2 additions & 2 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ plugins {
}

group "me.shedaniel"
sourceCompatibility = targetCompatibility = 1.8
sourceCompatibility = targetCompatibility = 17

license {
include "**/*.kt"
Expand Down Expand Up @@ -131,7 +131,7 @@ build.finalizedBy mainJar

compileKotlin {
kotlinOptions.suppressWarnings = true
kotlinOptions.jvmTarget = "1.8"
kotlinOptions.jvmTarget = "17"
kotlinOptions {
freeCompilerArgs = ["-Xopt-in=kotlin.RequiresOptIn", "-Xinline-classes"]
languageVersion = "1.4"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ import me.shedaniel.linkie.discord.utils.event
import me.shedaniel.linkie.discord.utils.replyComplex
import me.shedaniel.linkie.discord.utils.sendEmbedMessage
import java.time.Duration
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionException
import java.util.concurrent.ExecutionException
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
Expand All @@ -56,47 +58,53 @@ class CommandHandler(
val user = event.message.author.orElse(null)?.takeUnless { it.isBot } ?: return
val message: String = event.message.content
val prefix = commandAcceptor.getPrefix(event)
try {
executor.submit {
if (message.lowercase().startsWith(prefix)) {
val content = message.substring(prefix.length)
if (!rateLimiter.allow(user.id.asLong())) {
CompletableFuture.runAsync({
if (message.lowercase().startsWith(prefix)) {
val content = message.substring(prefix.length)
val split = content.splitArgs()
if (split.isNotEmpty()) {
val cmd = split[0].lowercase()
val args = split.drop(1).toMutableList()
if (!rateLimiter.allow(user, cmd, mutableMapOf("args" to args))) {
throwableHandler.generateErrorMessage(event.message, RateLimitException(rateLimiter.maxRequestPer10Sec), channel, user)
return@submit
return@runAsync
}
val split = content.splitArgs()
if (split.isNotEmpty()) {
val cmd = split[0].lowercase()
val ctx = MessageBasedCommandContext(event, prefix, cmd, channel)
val args = split.drop(1).toMutableList()
try {
runBlocking {
commandAcceptor.execute(event, ctx, args)
}
} catch (throwable: Throwable) {
if (throwableHandler.shouldError(throwable)) {
try {
ctx.message.replyComplex {
layout { dismissButton() }
embed { throwableHandler.generateThrowable(this, throwable, user) }
}
} catch (throwable2: Exception) {
throwable2.addSuppressed(throwable)
throwable2.printStackTrace()
val ctx = MessageBasedCommandContext(event, prefix, cmd, channel)
try {
runBlocking {
commandAcceptor.execute(event, ctx, args)
}
} catch (throwable: Throwable) {
if (throwableHandler.shouldError(throwable)) {
try {
ctx.message.replyComplex {
layout { dismissButton() }
embed { throwableHandler.generateThrowable(this, throwable, user) }
}
} catch (throwable2: Exception) {
throwable2.addSuppressed(throwable)
throwable2.printStackTrace()
}
}
}
}
}.get(10, TimeUnit.SECONDS)
} catch (throwable: TimeoutException) {
val newThrowable = TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.")
throwableHandler.generateErrorMessage(event.message, newThrowable, channel, user)
} catch (throwable: ExecutionException) {
throwableHandler.generateErrorMessage(event.message, throwable.cause ?: throwable, channel, user)
} catch (throwable: Throwable) {
throwableHandler.generateErrorMessage(event.message, throwable, channel, user)
}
}
}, executor).orTimeout(10, TimeUnit.SECONDS)
.exceptionally {
it.let { throwable ->
if (throwable is TimeoutException) {
TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.")
} else if (throwable is CompletionException) {
throwable.cause ?: throwable
} else {
throwable
}
}.also { throwable ->
throwableHandler.generateErrorMessage(event.message, throwable, channel, user)
}
null
}
.join()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,25 @@

package me.shedaniel.linkie.discord.handler

import discord4j.core.`object`.entity.User
import java.util.*

class RateLimiter(val maxRequestPer10Sec: Int) {
open class RateLimiter(val maxRequestPer10Sec: Int) {
data class Entry(
val time: Long,
val userId: Long,
)

private val log: Queue<Entry> = LinkedList()

fun allow(userId: Long): Boolean {
open fun allow(user: User, cmd: String, args: Map<String, Any>): Boolean {
val curTime = System.currentTimeMillis()
val boundary = curTime - 10000
synchronized(log) {
while (!log.isEmpty() && log.element().time <= boundary) {
log.poll()
}
val userId = user.id.asLong()
log.add(Entry(curTime, userId))
return log.count { it.userId == userId } <= maxRequestPer10Sec
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import discord4j.core.`object`.entity.User
import discord4j.core.`object`.entity.channel.Channel
import discord4j.core.event.domain.interaction.ChatInputAutoCompleteEvent
import discord4j.core.event.domain.interaction.ChatInputInteractionEvent
import discord4j.core.`object`.command.ApplicationCommandOption
import discord4j.discordjson.json.ApplicationCommandData
import discord4j.discordjson.json.ApplicationCommandOptionChoiceData
import discord4j.discordjson.json.ApplicationCommandRequest
Expand All @@ -40,6 +41,8 @@ import me.shedaniel.linkie.discord.utils.extensions.getOrNull
import me.shedaniel.linkie.discord.utils.replyComplex
import me.shedaniel.linkie.discord.utils.user
import reactor.core.publisher.Mono
import java.util.concurrent.CompletableFuture
import java.util.concurrent.CompletionException
import java.util.concurrent.ExecutionException
import java.util.concurrent.ExecutorService
import java.util.concurrent.Executors
Expand Down Expand Up @@ -256,7 +259,34 @@ class SlashCommands(
sentAny = true
}
}
if (!rateLimiter.allow(event.user.id.asLong())) {

fun ApplicationCommandInteractionOption.collectOptions(): MutableMap<String, Any> {
val map = mutableMapOf<String, Any>()
for (option in options) {
if (!map.containsKey(option.name)) {
if (option.value.isPresent) {
map[option.name] = option.value.get().raw
}
map.putAll(option.collectOptions())
}
}
return map
}

fun ChatInputInteractionEvent.collectOptions(): MutableMap<String, Any> {
val map = mutableMapOf<String, Any>()
for (option in options) {
if (!map.containsKey(option.name)) {
if (option.value.isPresent) {
map[option.name] = option.value.get().raw
}
map.putAll(option.collectOptions())
}
}
return map
}

if (!rateLimiter.allow(event.user, cmd, event.collectOptions())) {
val exception = RateLimitException(rateLimiter.maxRequestPer10Sec)
if (throwableHandler.shouldError(exception)) {
try {
Expand All @@ -272,32 +302,37 @@ class SlashCommands(
return@SlashCommandHandler
}
val optionsGetter = OptionsGetter.of(command, ctx, event)
runCatching {
try {
executor.submit {
if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) {
}
}.get(3, TimeUnit.SECONDS)
} catch (throwable: TimeoutException) {
throw TimeoutException("The command took too long to execute, the maximum execution time is 3 seconds.")
} catch (throwable: ExecutionException) {
throw throwable.cause ?: throwable
}
if (!sentAny) {
throw IllegalStateException("Command was not resolved!")
CompletableFuture.runAsync({
if (!executeOptions(command, ctx, optionsGetter, command.options, event.options) && !command.execute(command, ctx, optionsGetter)) {
}
}.exceptionOrNull()?.also { throwable ->
if (throwableHandler.shouldError(throwable)) {
try {
ctx.message.replyComplex {
layout { dismissButton() }
embed { throwableHandler.generateThrowable(this, throwable, ctx.user) }
}, executor).orTimeout(10, TimeUnit.SECONDS)
.exceptionally {
it.let { throwable ->
if (throwable is TimeoutException) {
TimeoutException("The command took too long to execute, the maximum execution time is 10 seconds.")
} else if (throwable is CompletionException) {
throwable.cause ?: throwable
} else {
throwable
}
}.also { throwable ->
if (throwableHandler.shouldError(throwable)) {
try {
ctx.message.replyComplex {
layout { dismissButton() }
embed { throwableHandler.generateThrowable(this, throwable, ctx.user) }
}
} catch (throwable2: Exception) {
throwable2.addSuppressed(throwable)
throwable2.printStackTrace()
}
}
} catch (throwable2: Exception) {
throwable2.addSuppressed(throwable)
throwable2.printStackTrace()
}
null
}
.join()
if (!sentAny) {
throw IllegalStateException("Command was not resolved!")
}
}, autoCompleter = { event ->
val optionsGetter = WeakOptionsGetter.of(command, event).asSuggestion(event.commandName)
Expand Down Expand Up @@ -433,7 +468,10 @@ interface SlashCommandOptionSuggestionSink {
.value(value)
.build()

fun choice(value: Any): ApplicationCommandOptionChoiceData = choice(value.toString(), value)
fun choice(value: Any): ApplicationCommandOptionChoiceData = choice(value.toString().take(99), value.let {
if (it is String) it.take(99)
else it
})
}

interface SlashCommand : SlashCommandExecutor, SlashCommandSuggester, CommandOptionProperties {
Expand Down
20 changes: 16 additions & 4 deletions src/main/kotlin/me/shedaniel/linkie/discord/LinkieBot.kt
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import com.soywiz.klock.seconds
import discord4j.core.`object`.entity.channel.ThreadChannel
import discord4j.core.event.domain.message.MessageCreateEvent
import discord4j.core.event.domain.thread.ThreadMembersUpdateEvent
import discord4j.core.`object`.entity.User
import discord4j.core.spec.EmbedCreateSpec
import discord4j.discordjson.json.gateway.ThreadMembersUpdate
import io.ktor.application.*
Expand All @@ -32,6 +33,7 @@ import io.ktor.response.*
import io.ktor.routing.*
import io.ktor.server.engine.*
import io.ktor.server.netty.*
import io.ktor.utils.io.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.GlobalScope
Expand Down Expand Up @@ -132,7 +134,15 @@ fun main() {
)
)
) {
val rateLimiter = RateLimiter(3)
val rateLimiter = object : RateLimiter(3) {
override fun allow(user: User, cmd: String, args: Map<String, Any>): Boolean {
info("Handling command /$cmd ${
args.entries.joinToString(" ")
{ "${it.key}: ${it.value}" }
} from ${user.discriminatedName}")
return super.allow(user, cmd, args)
}
}
val slashCommands = SlashCommands(this, LinkieThrowableHandler, ::warn, debug = isDebug, rateLimiter = rateLimiter)
TricksManager.listen(slashCommands)
val commandManager = object : CommandManager(if (isDebug) "@" else "!") {
Expand Down Expand Up @@ -161,9 +171,11 @@ fun main() {
gateway.getChannelById(event.threadId).subscribe { channel ->
if (channel is ThreadChannel) {
channel.sendMessage {
it.embeds(EmbedCreateSpec.create()
.withTitle("Linked has entered the thread")
.withDescription("Thanks for having me here! This message is sent when Linkie is brought into a thread.\nThread support in Linkie is still experimental, please report any issues found on our issue tracker! ٭(•﹏•)٭"))
it.embeds(
EmbedCreateSpec.create()
.withTitle("Linked has entered the thread")
.withDescription("Thanks for having me here! This message is sent when Linkie is brought into a thread.\nThread support in Linkie is still experimental, please report any issues found on our issue tracker! ٭(•﹏•)٭")
)
}.subscribe()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import kotlinx.coroutines.runBlocking
import me.shedaniel.linkie.Class
import me.shedaniel.linkie.MappingsContainer
import me.shedaniel.linkie.MappingsEntryType
import me.shedaniel.linkie.MappingsMember
import me.shedaniel.linkie.MappingsProvider
import me.shedaniel.linkie.Method
import me.shedaniel.linkie.Namespace
import me.shedaniel.linkie.Namespaces
import me.shedaniel.linkie.discord.Command
import me.shedaniel.linkie.discord.MappingsQueryUtils
import me.shedaniel.linkie.discord.MappingsQueryUtils.query
import me.shedaniel.linkie.discord.scommands.CommandOptionMeta
import me.shedaniel.linkie.discord.scommands.OptionsGetter
Expand Down Expand Up @@ -107,11 +107,19 @@ open class QueryMappingsCommand(
suggest { _, options, sink ->
runBlocking {
val rawValue = options.optNullable(this@string) ?: ""
if (rawValue.length >= 100) {
sink.suggest(listOf())
return@runBlocking
}
val value = rawValue.replace('.', '/').replace('#', '/')
val namespace = weakNamespaceGetter(options.cmd, options) ?: return@runBlocking
val provider = options.optNullable(version, VersionNamespaceConfig(namespace)) ?: namespace.getDefaultProvider()
val mappings = provider.get()
val result = query(mappings, value, *types)
val result = try {
query(mappings, value, *types)
} catch (ignored: NullPointerException) {
MappingsQueryUtils.Result(mutableListOf(), false)
}
val suggestions = result.results.asSequence().sortedByDescending { it.score }.map { (value, _) ->
when {
value is Class -> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,10 @@ class QueryTranslateMappingsCommand(
suggest { _, options, sink ->
runBlocking {
val rawValue = options.optNullable(this@string) ?: ""
if (rawValue.length >= 100) {
sink.suggest(listOf())
return@runBlocking
}
val value = rawValue.replace('.', '/').replace('#', '/')
val src = weakSrcNamespaceGetter(options.cmd, options) ?: return@runBlocking
val dst = weakDstNamespaceGetter(options.cmd, options) ?: return@runBlocking
Expand All @@ -144,7 +148,11 @@ class QueryTranslateMappingsCommand(

val provider = options.optNullable(version, VersionNamespaceConfig(src, defaultVersion) { allVersions }) ?: src.getProvider(defaultVersion)
val mappings = provider.get()
val result = MappingsQueryUtils.query(mappings, value, *types)
val result = try {
MappingsQueryUtils.query(mappings, value, *types)
} catch (ignored: NullPointerException) {
MappingsQueryUtils.Result(mutableListOf(), false)
}
val suggestions = result.results.asSequence().sortedByDescending { it.score }.map { (value, _) ->
when {
value is Class -> {
Expand Down

0 comments on commit c5784be

Please sign in to comment.