Skip to content
This repository has been archived by the owner on Jan 19, 2022. It is now read-only.

Add support for maxConcurrency parameter in SqsListener #380

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ protected void initialize() {

for (QueueMessageHandler.MappingInformation mappingInformation : this.messageHandler.getHandlerMethods().keySet()) {
for (String queue : mappingInformation.getLogicalResourceIds()) {
QueueAttributes queueAttributes = queueAttributes(queue, mappingInformation.getDeletionPolicy());
QueueAttributes queueAttributes = queueAttributes(queue, mappingInformation.getDeletionPolicy(), mappingInformation.getMaxConcurrency());

if (queueAttributes != null) {
this.registeredQueues.put(queue, queueAttributes);
Expand All @@ -296,7 +296,7 @@ public void start() {
doStart();
}

private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy deletionPolicy) {
private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy deletionPolicy, Integer maxConcurrency) {
String destinationUrl;
try {
destinationUrl = getDestinationResolver().resolveDestination(queue);
Expand All @@ -313,7 +313,8 @@ private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy d
.withAttributeNames(QueueAttributeName.RedrivePolicy));
boolean hasRedrivePolicy = queueAttributes.getAttributes().containsKey(QueueAttributeName.RedrivePolicy.toString());

return new QueueAttributes(hasRedrivePolicy, deletionPolicy, destinationUrl, getMaxNumberOfMessages(), getVisibilityTimeout(), getWaitTimeOut());
return new QueueAttributes(hasRedrivePolicy, deletionPolicy, maxConcurrency, destinationUrl,
getMaxNumberOfMessages(), getVisibilityTimeout(), getWaitTimeOut());
}

@Override
Expand Down Expand Up @@ -354,15 +355,17 @@ protected static class QueueAttributes {

private final boolean hasRedrivePolicy;
private final SqsMessageDeletionPolicy deletionPolicy;
private final Integer maxConcurrency;
private final String destinationUrl;
private final Integer maxNumberOfMessages;
private final Integer visibilityTimeout;
private final Integer waitTimeOut;

public QueueAttributes(boolean hasRedrivePolicy, SqsMessageDeletionPolicy deletionPolicy, String destinationUrl,
Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut) {
public QueueAttributes(boolean hasRedrivePolicy, SqsMessageDeletionPolicy deletionPolicy, Integer maxConcurrency,
String destinationUrl, Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut) {
this.hasRedrivePolicy = hasRedrivePolicy;
this.deletionPolicy = deletionPolicy;
this.maxConcurrency = maxConcurrency;
this.destinationUrl = destinationUrl;
this.maxNumberOfMessages = maxNumberOfMessages;
this.visibilityTimeout = visibilityTimeout;
Expand All @@ -376,13 +379,8 @@ public boolean hasRedrivePolicy() {
public ReceiveMessageRequest getReceiveMessageRequest() {
ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(this.destinationUrl).
withAttributeNames(RECEIVING_ATTRIBUTES).
withMessageAttributeNames(RECEIVING_MESSAGE_ATTRIBUTES);

if (this.maxNumberOfMessages != null) {
receiveMessageRequest.withMaxNumberOfMessages(this.maxNumberOfMessages);
} else {
receiveMessageRequest.withMaxNumberOfMessages(DEFAULT_MAX_NUMBER_OF_MESSAGES);
}
withMessageAttributeNames(RECEIVING_MESSAGE_ATTRIBUTES).
withMaxNumberOfMessages(getMaxNumberOfMessages());

if (this.visibilityTimeout != null) {
receiveMessageRequest.withVisibilityTimeout(this.visibilityTimeout);
Expand All @@ -398,5 +396,13 @@ public ReceiveMessageRequest getReceiveMessageRequest() {
public SqsMessageDeletionPolicy getDeletionPolicy() {
return this.deletionPolicy;
}

public Integer getMaxConcurrency() {
return this.maxConcurrency;
}

public int getMaxNumberOfMessages() {
return this.maxNumberOfMessages != null ? this.maxNumberOfMessages : DEFAULT_MAX_NUMBER_OF_MESSAGES;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,17 @@ protected MappingInformation getMappingForMethod(Method method, Class<?> handler
this.logger.warn("Listener method '" + method.getName() + "' in type '" + method.getDeclaringClass().getName() +
"' has deletion policy 'NEVER' but does not have a parameter of type Acknowledgment.");
}
return new MappingInformation(resolveDestinationNames(sqsListenerAnnotation.value()), sqsListenerAnnotation.deletionPolicy());
Integer maxConcurrency = sqsListenerAnnotation.maxConcurrency() > 0 ? sqsListenerAnnotation.maxConcurrency() : null;
return new MappingInformation(resolveDestinationNames(sqsListenerAnnotation.value()),
sqsListenerAnnotation.deletionPolicy(),
maxConcurrency);
}

MessageMapping messageMappingAnnotation = AnnotationUtils.findAnnotation(method, MessageMapping.class);
if (messageMappingAnnotation != null && messageMappingAnnotation.value().length > 0) {
return new MappingInformation(resolveDestinationNames(messageMappingAnnotation.value()), SqsMessageDeletionPolicy.ALWAYS);
return new MappingInformation(resolveDestinationNames(messageMappingAnnotation.value()),
SqsMessageDeletionPolicy.ALWAYS,
null); // maxConcurrency
}

return null;
Expand Down Expand Up @@ -233,9 +238,14 @@ protected static class MappingInformation implements Comparable<MappingInformati

private final SqsMessageDeletionPolicy deletionPolicy;

public MappingInformation(Set<String> logicalResourceIds, SqsMessageDeletionPolicy deletionPolicy) {
private final Integer maxConcurrency;

public MappingInformation(Set<String> logicalResourceIds,
SqsMessageDeletionPolicy deletionPolicy,
Integer maxConcurrency) {
this.logicalResourceIds = Collections.unmodifiableSet(logicalResourceIds);
this.deletionPolicy = deletionPolicy;
this.maxConcurrency = maxConcurrency;
}

public Set<String> getLogicalResourceIds() {
Expand All @@ -246,6 +256,10 @@ public SqsMessageDeletionPolicy getDeletionPolicy() {
return this.deletionPolicy;
}

public Integer getMaxConcurrency() {
return this.maxConcurrency;
}

@SuppressWarnings("NullableProblems")
@Override
public int compareTo(MappingInformation o) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,10 @@
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

Expand All @@ -43,7 +44,6 @@
*/
public class SimpleMessageListenerContainer extends AbstractMessageListenerContainer {

private static final int DEFAULT_WORKER_THREADS = 2;
private static final String DEFAULT_THREAD_NAME_PREFIX =
ClassUtils.getShortName(SimpleMessageListenerContainer.class) + "-";

Expand Down Expand Up @@ -181,14 +181,28 @@ protected AsyncTaskExecutor createDefaultTaskExecutor() {
int spinningThreads = this.getRegisteredQueues().size();

if (spinningThreads > 0) {
threadPoolTaskExecutor.setCorePoolSize(spinningThreads * DEFAULT_WORKER_THREADS);

int maxNumberOfMessagePerBatch = getMaxNumberOfMessages() != null ? getMaxNumberOfMessages() : DEFAULT_WORKER_THREADS;
threadPoolTaskExecutor.setMaxPoolSize(spinningThreads * (maxNumberOfMessagePerBatch + 1));
int poolSize = 0;
int bufferSize = 0;
for (QueueAttributes queueAttributes : this.getRegisteredQueues().values()) {
int queueMaxNumberOfMessages = queueAttributes.getMaxNumberOfMessages();
// Each queue needs 1 polling thread plus n handler threads
// where n is determined by the queue concurrency or batch size
poolSize += 1 + (queueAttributes.getMaxConcurrency() != null
? queueAttributes.getMaxConcurrency()
: queueMaxNumberOfMessages);
bufferSize += queueMaxNumberOfMessages;
}
threadPoolTaskExecutor.setCorePoolSize(poolSize);
threadPoolTaskExecutor.setMaxPoolSize(poolSize);
// Ideally we would like to set the queue capacity to 0 to avoid
// messages waiting too long in memory, but due to a race condition
// we may encounter a transitional state where a task finishes
// executing but the thread is not ready to accept a new task for a
// few milliseconds, resulting in rejected task exceptions. To
// prevent this issue we allow a small amount of buffering.
threadPoolTaskExecutor.setQueueCapacity(bufferSize);
threadPoolTaskExecutor.setAllowCoreThreadTimeOut(true);
}

// No use of a thread pool executor queue to avoid retaining message to long in memory
threadPoolTaskExecutor.setQueueCapacity(0);
threadPoolTaskExecutor.afterPropertiesSet();

return threadPoolTaskExecutor;
Expand Down Expand Up @@ -276,23 +290,48 @@ private AsynchronousMessageListener(String logicalQueueName, QueueAttributes que

@Override
public void run() {
final int maxMessages = queueAttributes.getMaxNumberOfMessages();
final int maxConcurrency = queueAttributes.getMaxConcurrency() != null ? queueAttributes.getMaxConcurrency() : maxMessages;
// Semaphore used to limit the number of messages being handled concurrently.
final Semaphore semaphore = new Semaphore(maxConcurrency);
while (isQueueRunning()) {
// How many semaphore permits this thread currently holds.
int currentPermits = 0;
try {
// Wait for sufficient threads available before requesting
// additional messages.
currentPermits += acquirePermits(semaphore, Math.min(maxConcurrency, maxMessages));
if (!isQueueRunning()) {
break;
}

ReceiveMessageResult receiveMessageResult = getAmazonSqs().receiveMessage(this.queueAttributes.getReceiveMessageRequest());
CountDownLatch messageBatchLatch = new CountDownLatch(receiveMessageResult.getMessages().size());
for (Message message : receiveMessageResult.getMessages()) {
// If maxConcurrency < maxMessages we might have more
// messages than threads available. In that case we
// wait to acquire a worker thread permit before
// submitting each message to the task executor.
if (currentPermits == 0) {
currentPermits += acquirePermits(semaphore, 1);
}
if (isQueueRunning()) {
MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, this.queueAttributes);
getTaskExecutor().execute(new SignalExecutingRunnable(messageBatchLatch, messageExecutor));
} else {
messageBatchLatch.countDown();
if (currentPermits > 0) {
getTaskExecutor().execute(new SignalExecutingRunnable(semaphore, messageExecutor));
// After the task is submitted to the executor,
// it's the SignalExecutingRunnable's job to
// release the permit, so we can decrement our
// permit count for this thread.
currentPermits -= 1;
} else {
// We failed to acquire a permit due to being
// interrupted. We don't want the worker to
// release a permit that wasn't acquired, so
// don't use SignalExecutingRunnable.
getTaskExecutor().execute(messageExecutor);
}
}
}
try {
messageBatchLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
} catch (Exception e) {
getLogger().warn("An Exception occurred while polling queue '{}'. The failing operation will be " +
"retried in {} milliseconds", this.logicalQueueName, getBackOffTime(), e);
Expand All @@ -302,9 +341,14 @@ public void run() {
} catch (InterruptedException ie) {
Thread.currentThread().interrupt();
}
} finally {
semaphore.release(currentPermits);
}
}

// Wait for all tasks to complete before terminating
acquirePermits(semaphore, maxConcurrency);

SimpleMessageListenerContainer.this.scheduledFutureByQueue.remove(this.logicalQueueName);
}

Expand All @@ -316,6 +360,16 @@ private boolean isQueueRunning() {
return false;
}
}

private int acquirePermits(Semaphore semaphore, int permits) {
try {
semaphore.acquire(permits);
return permits;
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
return 0;
}
}
}

private class MessageExecutor implements Runnable {
Expand Down Expand Up @@ -383,11 +437,11 @@ private org.springframework.messaging.Message<String> getMessageForExecution() {

private static class SignalExecutingRunnable implements Runnable {

private final CountDownLatch countDownLatch;
private final Semaphore semaphore;
private final Runnable runnable;

private SignalExecutingRunnable(CountDownLatch endSignal, Runnable runnable) {
this.countDownLatch = endSignal;
private SignalExecutingRunnable(Semaphore semaphore, Runnable runnable) {
this.semaphore = semaphore;
this.runnable = runnable;
}

Expand All @@ -396,7 +450,7 @@ public void run() {
try {
this.runnable.run();
} finally {
this.countDownLatch.countDown();
this.semaphore.release();
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,9 @@
*/
SqsMessageDeletionPolicy deletionPolicy() default SqsMessageDeletionPolicy.NO_REDRIVE;

/**
* Defines the maximum number of messages to process concurrently.
*/
int maxConcurrency() default 0;

}
Loading