From 78742ba9654f3e52e2699a74e21c72ad5fc9cfa7 Mon Sep 17 00:00:00 2001 From: Nick Gaya Date: Thu, 4 Oct 2018 02:06:38 -0700 Subject: [PATCH 1/4] Add support for maxConcurrency parameter in SqsListener --- .../AbstractMessageListenerContainer.java | 30 +-- .../listener/QueueMessageHandler.java | 20 +- .../SimpleMessageListenerContainer.java | 73 +++++-- .../listener/annotation/SqsListener.java | 5 + .../SimpleMessageListenerContainerTest.java | 187 +++++++++++++++++- 5 files changed, 283 insertions(+), 32 deletions(-) diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/AbstractMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/AbstractMessageListenerContainer.java index ddb2f5ade..967fb5e01 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/AbstractMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/AbstractMessageListenerContainer.java @@ -273,7 +273,7 @@ protected void initialize() { for (QueueMessageHandler.MappingInformation mappingInformation : this.messageHandler.getHandlerMethods().keySet()) { for (String queue : mappingInformation.getLogicalResourceIds()) { - QueueAttributes queueAttributes = queueAttributes(queue, mappingInformation.getDeletionPolicy()); + QueueAttributes queueAttributes = queueAttributes(queue, mappingInformation.getDeletionPolicy(), mappingInformation.getMaxConcurrency()); if (queueAttributes != null) { this.registeredQueues.put(queue, queueAttributes); @@ -296,7 +296,7 @@ public void start() { doStart(); } - private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy deletionPolicy) { + private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy deletionPolicy, Integer maxConcurrency) { String destinationUrl; try { destinationUrl = getDestinationResolver().resolveDestination(queue); @@ -313,7 +313,8 @@ private QueueAttributes queueAttributes(String queue, SqsMessageDeletionPolicy d .withAttributeNames(QueueAttributeName.RedrivePolicy)); boolean hasRedrivePolicy = queueAttributes.getAttributes().containsKey(QueueAttributeName.RedrivePolicy.toString()); - return new QueueAttributes(hasRedrivePolicy, deletionPolicy, destinationUrl, getMaxNumberOfMessages(), getVisibilityTimeout(), getWaitTimeOut()); + return new QueueAttributes(hasRedrivePolicy, deletionPolicy, maxConcurrency, destinationUrl, + getMaxNumberOfMessages(), getVisibilityTimeout(), getWaitTimeOut()); } @Override @@ -354,15 +355,17 @@ protected static class QueueAttributes { private final boolean hasRedrivePolicy; private final SqsMessageDeletionPolicy deletionPolicy; + private final Integer maxConcurrency; private final String destinationUrl; private final Integer maxNumberOfMessages; private final Integer visibilityTimeout; private final Integer waitTimeOut; - public QueueAttributes(boolean hasRedrivePolicy, SqsMessageDeletionPolicy deletionPolicy, String destinationUrl, - Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut) { + public QueueAttributes(boolean hasRedrivePolicy, SqsMessageDeletionPolicy deletionPolicy, Integer maxConcurrency, + String destinationUrl, Integer maxNumberOfMessages, Integer visibilityTimeout, Integer waitTimeOut) { this.hasRedrivePolicy = hasRedrivePolicy; this.deletionPolicy = deletionPolicy; + this.maxConcurrency = maxConcurrency; this.destinationUrl = destinationUrl; this.maxNumberOfMessages = maxNumberOfMessages; this.visibilityTimeout = visibilityTimeout; @@ -376,13 +379,8 @@ public boolean hasRedrivePolicy() { public ReceiveMessageRequest getReceiveMessageRequest() { ReceiveMessageRequest receiveMessageRequest = new ReceiveMessageRequest(this.destinationUrl). withAttributeNames(RECEIVING_ATTRIBUTES). - withMessageAttributeNames(RECEIVING_MESSAGE_ATTRIBUTES); - - if (this.maxNumberOfMessages != null) { - receiveMessageRequest.withMaxNumberOfMessages(this.maxNumberOfMessages); - } else { - receiveMessageRequest.withMaxNumberOfMessages(DEFAULT_MAX_NUMBER_OF_MESSAGES); - } + withMessageAttributeNames(RECEIVING_MESSAGE_ATTRIBUTES). + withMaxNumberOfMessages(getMaxNumberOfMessages()); if (this.visibilityTimeout != null) { receiveMessageRequest.withVisibilityTimeout(this.visibilityTimeout); @@ -398,5 +396,13 @@ public ReceiveMessageRequest getReceiveMessageRequest() { public SqsMessageDeletionPolicy getDeletionPolicy() { return this.deletionPolicy; } + + public Integer getMaxConcurrency() { + return this.maxConcurrency; + } + + public int getMaxNumberOfMessages() { + return this.maxNumberOfMessages != null ? this.maxNumberOfMessages : DEFAULT_MAX_NUMBER_OF_MESSAGES; + } } } diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/QueueMessageHandler.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/QueueMessageHandler.java index 48131d74c..cf79836a3 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/QueueMessageHandler.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/QueueMessageHandler.java @@ -115,12 +115,17 @@ protected MappingInformation getMappingForMethod(Method method, Class handler this.logger.warn("Listener method '" + method.getName() + "' in type '" + method.getDeclaringClass().getName() + "' has deletion policy 'NEVER' but does not have a parameter of type Acknowledgment."); } - return new MappingInformation(resolveDestinationNames(sqsListenerAnnotation.value()), sqsListenerAnnotation.deletionPolicy()); + Integer maxConcurrency = sqsListenerAnnotation.maxConcurrency() > 0 ? sqsListenerAnnotation.maxConcurrency() : null; + return new MappingInformation(resolveDestinationNames(sqsListenerAnnotation.value()), + sqsListenerAnnotation.deletionPolicy(), + maxConcurrency); } MessageMapping messageMappingAnnotation = AnnotationUtils.findAnnotation(method, MessageMapping.class); if (messageMappingAnnotation != null && messageMappingAnnotation.value().length > 0) { - return new MappingInformation(resolveDestinationNames(messageMappingAnnotation.value()), SqsMessageDeletionPolicy.ALWAYS); + return new MappingInformation(resolveDestinationNames(messageMappingAnnotation.value()), + SqsMessageDeletionPolicy.ALWAYS, + null); // maxConcurrency } return null; @@ -233,9 +238,14 @@ protected static class MappingInformation implements Comparable logicalResourceIds, SqsMessageDeletionPolicy deletionPolicy) { + private final Integer maxConcurrency; + + public MappingInformation(Set logicalResourceIds, + SqsMessageDeletionPolicy deletionPolicy, + Integer maxConcurrency) { this.logicalResourceIds = Collections.unmodifiableSet(logicalResourceIds); this.deletionPolicy = deletionPolicy; + this.maxConcurrency = maxConcurrency; } public Set getLogicalResourceIds() { @@ -246,6 +256,10 @@ public SqsMessageDeletionPolicy getDeletionPolicy() { return this.deletionPolicy; } + public Integer getMaxConcurrency() { + return this.maxConcurrency; + } + @SuppressWarnings("NullableProblems") @Override public int compareTo(MappingInformation o) { diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java index 2f0910c7c..caf452064 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java @@ -28,9 +28,9 @@ import java.util.HashMap; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -183,8 +183,16 @@ protected AsyncTaskExecutor createDefaultTaskExecutor() { if (spinningThreads > 0) { threadPoolTaskExecutor.setCorePoolSize(spinningThreads * DEFAULT_WORKER_THREADS); - int maxNumberOfMessagePerBatch = getMaxNumberOfMessages() != null ? getMaxNumberOfMessages() : DEFAULT_WORKER_THREADS; - threadPoolTaskExecutor.setMaxPoolSize(spinningThreads * (maxNumberOfMessagePerBatch + 1)); + int maxPoolSize = 0; + for (QueueAttributes queueAttributes : this.getRegisteredQueues().values()) { + int queueMaxNumberOfMessages = queueAttributes.getMaxNumberOfMessages(); + // Each queue needs 1 polling thread plus n handler threads + // where n is determined by the queue concurrency or batch size + maxPoolSize += 1 + (queueAttributes.getMaxConcurrency() != null + ? queueAttributes.getMaxConcurrency() + : queueMaxNumberOfMessages); + } + threadPoolTaskExecutor.setMaxPoolSize(maxPoolSize); } // No use of a thread pool executor queue to avoid retaining message to long in memory @@ -276,23 +284,41 @@ private AsynchronousMessageListener(String logicalQueueName, QueueAttributes que @Override public void run() { + final int maxMessages = queueAttributes.getMaxNumberOfMessages(); + final int maxConcurrency = queueAttributes.getMaxConcurrency() != null ? queueAttributes.getMaxConcurrency() : maxMessages; + // Semaphore used to limit the number of messages being handled concurrently. + final Semaphore semaphore = new Semaphore(maxConcurrency); while (isQueueRunning()) { + // How many semaphore permits this thread currently holds. + int currentPermits = 0; try { + // Wait for sufficient threads available before requesting + // additional messages. + currentPermits += acquirePermits(semaphore, Math.min(maxConcurrency, maxMessages)); + if (!isQueueRunning()) { + break; + } + ReceiveMessageResult receiveMessageResult = getAmazonSqs().receiveMessage(this.queueAttributes.getReceiveMessageRequest()); - CountDownLatch messageBatchLatch = new CountDownLatch(receiveMessageResult.getMessages().size()); for (Message message : receiveMessageResult.getMessages()) { + // If maxConcurrency < maxMessages we might have more + // messages than threads available. In that case we + // wait to acquire a worker thread permit before + // submitting each message to the task executor. + if (currentPermits == 0) { + currentPermits += acquirePermits(semaphore, 1); + } if (isQueueRunning()) { MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, this.queueAttributes); - getTaskExecutor().execute(new SignalExecutingRunnable(messageBatchLatch, messageExecutor)); - } else { - messageBatchLatch.countDown(); + getTaskExecutor().execute(new SignalExecutingRunnable(semaphore, messageExecutor)); + if (currentPermits > 0) { + // After submitting the task to the executor, it's + // the SignalExecutingRunnable's job to release the + // permit, so we can decrement + currentPermits -= 1; + } } } - try { - messageBatchLatch.await(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } } catch (Exception e) { getLogger().warn("An Exception occurred while polling queue '{}'. The failing operation will be " + "retried in {} milliseconds", this.logicalQueueName, getBackOffTime(), e); @@ -302,9 +328,14 @@ public void run() { } catch (InterruptedException ie) { Thread.currentThread().interrupt(); } + } finally { + semaphore.release(currentPermits); } } + // Wait for all tasks to complete before terminating + acquirePermits(semaphore, maxConcurrency); + SimpleMessageListenerContainer.this.scheduledFutureByQueue.remove(this.logicalQueueName); } @@ -316,6 +347,16 @@ private boolean isQueueRunning() { return false; } } + + private int acquirePermits(Semaphore semaphore, int permits) { + try { + semaphore.acquire(permits); + return permits; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + return 0; + } + } } private class MessageExecutor implements Runnable { @@ -383,11 +424,11 @@ private org.springframework.messaging.Message getMessageForExecution() { private static class SignalExecutingRunnable implements Runnable { - private final CountDownLatch countDownLatch; + private final Semaphore semaphore; private final Runnable runnable; - private SignalExecutingRunnable(CountDownLatch endSignal, Runnable runnable) { - this.countDownLatch = endSignal; + private SignalExecutingRunnable(Semaphore semaphore, Runnable runnable) { + this.semaphore = semaphore; this.runnable = runnable; } @@ -396,7 +437,7 @@ public void run() { try { this.runnable.run(); } finally { - this.countDownLatch.countDown(); + this.semaphore.release(); } } } diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/annotation/SqsListener.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/annotation/SqsListener.java index 5a90f15e5..5ca37c3bb 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/annotation/SqsListener.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/annotation/SqsListener.java @@ -75,4 +75,9 @@ */ SqsMessageDeletionPolicy deletionPolicy() default SqsMessageDeletionPolicy.NO_REDRIVE; + /** + * Defines the maximum number of messages to process concurrently. + */ + int maxConcurrency() default 0; + } diff --git a/spring-cloud-aws-messaging/src/test/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainerTest.java b/spring-cloud-aws-messaging/src/test/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainerTest.java index 191dbe54e..c2e4cbcaf 100644 --- a/spring-cloud-aws-messaging/src/test/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainerTest.java +++ b/spring-cloud-aws-messaging/src/test/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainerTest.java @@ -133,7 +133,7 @@ public void testWithDefaultTaskExecutorAndOneHandler() throws Exception { Map messageHandlerMethods = Collections.singletonMap( new QueueMessageHandler.MappingInformation(Collections.singleton("testQueue"), - SqsMessageDeletionPolicy.ALWAYS), null); + SqsMessageDeletionPolicy.ALWAYS, null), null); SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); @@ -157,6 +157,37 @@ public void testWithDefaultTaskExecutorAndOneHandler() throws Exception { assertEquals(expectedPoolMaxSize, taskExecutor.getMaxPoolSize()); } + @Test + public void testWithDefaultTaskExecutorAndOneHandlerMaxConcurrency() throws Exception { + final int testedMaxNumberOfMessages = 10; + final int testedMaxConcurrency = 15; + + Map messageHandlerMethods = Collections.singletonMap( + new QueueMessageHandler.MappingInformation(Collections.singleton("testQueue"), + SqsMessageDeletionPolicy.ALWAYS, testedMaxConcurrency), null); + + SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); + + QueueMessageHandler mockedHandler = mock(QueueMessageHandler.class); + AmazonSQSAsync mockedSqs = mock(AmazonSQSAsync.class, withSettings().stubOnly()); + + when(mockedSqs.getQueueAttributes(any(GetQueueAttributesRequest.class))).thenReturn(new GetQueueAttributesResult()); + when(mockedSqs.getQueueUrl(any(GetQueueUrlRequest.class))).thenReturn(new GetQueueUrlResult().withQueueUrl("testQueueUrl")); + when(mockedHandler.getHandlerMethods()).thenReturn(messageHandlerMethods); + + container.setMaxNumberOfMessages(testedMaxNumberOfMessages); + container.setAmazonSqs(mockedSqs); + container.setMessageHandler(mockedHandler); + + container.afterPropertiesSet(); + + int expectedPoolMaxSize = messageHandlerMethods.size() * (testedMaxConcurrency + 1); + + ThreadPoolTaskExecutor taskExecutor = (ThreadPoolTaskExecutor) container.getTaskExecutor(); + assertNotNull(taskExecutor); + assertEquals(expectedPoolMaxSize, taskExecutor.getMaxPoolSize()); + } + @Test public void testCustomTaskExecutor() throws Exception { SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); @@ -800,6 +831,95 @@ private static void setLogLevel(Level level) { logContext.getLogger(SimpleMessageListenerContainer.class).setLevel(level); } + @Test + public void testReceiveMessageMaxConcurrencyLessThanMaxNumberOfMessages() throws Exception { + SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); + int maxNumberOfMessages = 2; + container.setMaxNumberOfMessages(maxNumberOfMessages); + + AmazonSQSAsync sqs = mock(AmazonSQSAsync.class, withSettings().stubOnly()); + container.setAmazonSqs(sqs); + + QueueMessageHandler messageHandler = new QueueMessageHandler(); + container.setMessageHandler(messageHandler); + StaticApplicationContext applicationContext = new StaticApplicationContext(); + applicationContext.registerSingleton("testMessageListener", TestListenerLowMaxConcurrency.class); + messageHandler.setApplicationContext(applicationContext); + container.setBeanName("testContainerName"); + messageHandler.afterPropertiesSet(); + + String queueUrl = "http://lowMaxConcurrency.amazonaws.com"; + mockGetQueueUrl(sqs, "lowMaxConcurrency", queueUrl); + mockGetQueueAttributesWithEmptyResult(sqs, queueUrl); + + container.afterPropertiesSet(); + + when(sqs.receiveMessage(new ReceiveMessageRequest(queueUrl).withAttributeNames("All") + .withMessageAttributeNames("All") + .withMaxNumberOfMessages(maxNumberOfMessages) + .withWaitTimeSeconds(20))) + .thenReturn(new ReceiveMessageResult().withMessages(new Message().withBody("message1-1"), + new Message().withBody("message1-2"))) + .thenReturn(new ReceiveMessageResult()); + when(sqs.getQueueAttributes(any(GetQueueAttributesRequest.class))).thenReturn(new GetQueueAttributesResult()); + + container.start(); + + final TestListenerLowMaxConcurrency listener = applicationContext.getBean(TestListenerLowMaxConcurrency.class); + assertTrue(listener.getCountDownLatch().await(1, TimeUnit.SECONDS)); + + container.stop(); + + // Verify that the messages were processed one at a time + assertEquals(1, listener.getObservedMaxConcurrency()); + } + + @Test + public void testReceiveMessageMaxConcurrencyMoreThanMaxNumberOfMessages() throws Exception { + SimpleMessageListenerContainer container = new SimpleMessageListenerContainer(); + int maxNumberOfMessages = 2; + container.setMaxNumberOfMessages(maxNumberOfMessages); + + AmazonSQSAsync sqs = mock(AmazonSQSAsync.class, withSettings().stubOnly()); + container.setAmazonSqs(sqs); + + QueueMessageHandler messageHandler = new QueueMessageHandler(); + container.setMessageHandler(messageHandler); + StaticApplicationContext applicationContext = new StaticApplicationContext(); + applicationContext.registerSingleton("testMessageListener", TestListenerHighMaxConcurrency.class); + messageHandler.setApplicationContext(applicationContext); + container.setBeanName("testContainerName"); + messageHandler.afterPropertiesSet(); + + String queueUrl = "http://highMaxConcurrency.amazonaws.com"; + mockGetQueueUrl(sqs, "highMaxConcurrency", queueUrl); + mockGetQueueAttributesWithEmptyResult(sqs, queueUrl); + + container.afterPropertiesSet(); + + when(sqs.receiveMessage(new ReceiveMessageRequest(queueUrl).withAttributeNames("All") + .withMessageAttributeNames("All") + .withMaxNumberOfMessages(maxNumberOfMessages) + .withWaitTimeSeconds(20))) + .thenReturn(new ReceiveMessageResult().withMessages(new Message().withBody("message1-1"), + new Message().withBody("message1-2"))) + .thenReturn(new ReceiveMessageResult().withMessages(new Message().withBody("message2-1"), + new Message().withBody("message2-2"))) + .thenReturn(new ReceiveMessageResult().withMessages(new Message().withBody("message3-1"))) + .thenReturn(new ReceiveMessageResult()); + when(sqs.getQueueAttributes(any(GetQueueAttributesRequest.class))).thenReturn(new GetQueueAttributesResult()); + + container.start(); + + final TestListenerHighMaxConcurrency listener = applicationContext.getBean(TestListenerHighMaxConcurrency.class); + assertTrue(listener.getCountDownLatch().await(1, TimeUnit.SECONDS)); + + container.stop(); + + // Verify that 4 messages (i.e. 2 batches) were processed at once + assertEquals(4, listener.getObservedMaxConcurrency()); + } + @Test public void stop_withALogicalQueueName_mustStopOnlyTheSpecifiedQueue() throws Exception { // Arrange @@ -1364,4 +1484,69 @@ public CountDownLatch getCountDownLatch() { } } + private static class MaxConcurrencyMetric { + private int concurrency; + private int maxConcurrency; + + public synchronized void enter() { + concurrency++; + maxConcurrency = Math.max(concurrency, maxConcurrency); + } + + public synchronized void exit() { + concurrency--; + } + + public synchronized int getMaxConcurrency() { + return maxConcurrency; + } + } + + private static class TestListenerLowMaxConcurrency { + + private final CountDownLatch countDownLatch = new CountDownLatch(2); + + private final MaxConcurrencyMetric maxConcurrencyMetric = new MaxConcurrencyMetric(); + + @RuntimeUse + @SqsListener(value = "lowMaxConcurrency", maxConcurrency = 1) + private void handleMessage(String message) throws InterruptedException { + maxConcurrencyMetric.enter(); + Thread.sleep(200); + maxConcurrencyMetric.exit(); + this.countDownLatch.countDown(); + } + + public CountDownLatch getCountDownLatch() { + return countDownLatch; + } + + public int getObservedMaxConcurrency() { + return maxConcurrencyMetric.getMaxConcurrency(); + } + } + + private static class TestListenerHighMaxConcurrency { + + private final CountDownLatch countDownLatch = new CountDownLatch(5); + + private final MaxConcurrencyMetric maxConcurrencyMetric = new MaxConcurrencyMetric(); + + @RuntimeUse + @SqsListener(value = "highMaxConcurrency", maxConcurrency = 4) + private void handleMessage(String message) throws InterruptedException { + maxConcurrencyMetric.enter(); + Thread.sleep(200); + maxConcurrencyMetric.exit(); + this.countDownLatch.countDown(); + } + + public CountDownLatch getCountDownLatch() { + return countDownLatch; + } + + public int getObservedMaxConcurrency() { + return maxConcurrencyMetric.getMaxConcurrency(); + } + } } From 98ed7e4a82c6a52f0295278bde4f0c2d1e1a3734 Mon Sep 17 00:00:00 2001 From: Nick Gaya Date: Thu, 4 Oct 2018 14:10:34 -0700 Subject: [PATCH 2/4] Adjust handling of interrupt case --- .../listener/SimpleMessageListenerContainer.java | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java index caf452064..8df9ea559 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java @@ -310,12 +310,18 @@ public void run() { } if (isQueueRunning()) { MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, this.queueAttributes); - getTaskExecutor().execute(new SignalExecutingRunnable(semaphore, messageExecutor)); if (currentPermits > 0) { - // After submitting the task to the executor, it's - // the SignalExecutingRunnable's job to release the - // permit, so we can decrement + getTaskExecutor().execute(new SignalExecutingRunnable(semaphore, messageExecutor)); + // After submitting the task to the executor, + // it's the SignalExecutingRunnable's job to + // release the permit, so we can decrement our + // permit count for this thread. currentPermits -= 1; + } else { + // We failed to acquire a permit due to being interrupted. + // Don't use a SignalExecutingRunnable since the worker + // should not release an extra permit. + getTaskExecutor().execute(messageExecutor); } } } From 71889ad8e01a9f0de008ba45b7d179723ad8f4a0 Mon Sep 17 00:00:00 2001 From: Nick Gaya Date: Thu, 4 Oct 2018 14:24:34 -0700 Subject: [PATCH 3/4] Add workaround for race condition --- .../SimpleMessageListenerContainer.java | 36 +++++++++++++++---- 1 file changed, 30 insertions(+), 6 deletions(-) diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java index 8df9ea559..2df4a3405 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java @@ -30,6 +30,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; +import java.util.concurrent.RejectedExecutionException; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; @@ -311,17 +312,18 @@ public void run() { if (isQueueRunning()) { MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, this.queueAttributes); if (currentPermits > 0) { - getTaskExecutor().execute(new SignalExecutingRunnable(semaphore, messageExecutor)); - // After submitting the task to the executor, + executeTask(new SignalExecutingRunnable(semaphore, messageExecutor)); + // After the task is submitted to the executor, // it's the SignalExecutingRunnable's job to // release the permit, so we can decrement our // permit count for this thread. currentPermits -= 1; } else { - // We failed to acquire a permit due to being interrupted. - // Don't use a SignalExecutingRunnable since the worker - // should not release an extra permit. - getTaskExecutor().execute(messageExecutor); + // We failed to acquire a permit due to being + // interrupted. We don't want the worker to + // release a permit that wasn't acquired, so + // don't use SignalExecutingRunnable. + executeTask(messageExecutor); } } } @@ -363,6 +365,28 @@ private int acquirePermits(Semaphore semaphore, int permits) { return 0; } } + + private void executeTask(Runnable runnable) { + // There is a potential race condition between the time when the + // semaphore is released by the SignalExecutingRunnable, and the + // time when the thread pool thread actually becomes available to + // accept another task. + // + // As a workaround, we retry a single time after a + // RejectedExecutionException. + try { + getTaskExecutor().execute(runnable); + return; + } catch (RejectedExecutionException e) { + // Sleep for a moment and try again + try { + Thread.sleep(100); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + } + getTaskExecutor().execute(runnable); + } } private class MessageExecutor implements Runnable { From fc68aa04be53371ea8ea5f70e06ef3f73d3165ea Mon Sep 17 00:00:00 2001 From: Nick Gaya Date: Mon, 8 Oct 2018 10:21:33 -0700 Subject: [PATCH 4/4] Better fix for semaphore rejected task issue --- .../SimpleMessageListenerContainer.java | 53 +++++++------------ 1 file changed, 18 insertions(+), 35 deletions(-) diff --git a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java index 2df4a3405..7049d8e46 100644 --- a/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java +++ b/spring-cloud-aws-messaging/src/main/java/org/springframework/cloud/aws/messaging/listener/SimpleMessageListenerContainer.java @@ -44,7 +44,6 @@ */ public class SimpleMessageListenerContainer extends AbstractMessageListenerContainer { - private static final int DEFAULT_WORKER_THREADS = 2; private static final String DEFAULT_THREAD_NAME_PREFIX = ClassUtils.getShortName(SimpleMessageListenerContainer.class) + "-"; @@ -182,22 +181,28 @@ protected AsyncTaskExecutor createDefaultTaskExecutor() { int spinningThreads = this.getRegisteredQueues().size(); if (spinningThreads > 0) { - threadPoolTaskExecutor.setCorePoolSize(spinningThreads * DEFAULT_WORKER_THREADS); - - int maxPoolSize = 0; + int poolSize = 0; + int bufferSize = 0; for (QueueAttributes queueAttributes : this.getRegisteredQueues().values()) { int queueMaxNumberOfMessages = queueAttributes.getMaxNumberOfMessages(); // Each queue needs 1 polling thread plus n handler threads // where n is determined by the queue concurrency or batch size - maxPoolSize += 1 + (queueAttributes.getMaxConcurrency() != null - ? queueAttributes.getMaxConcurrency() - : queueMaxNumberOfMessages); + poolSize += 1 + (queueAttributes.getMaxConcurrency() != null + ? queueAttributes.getMaxConcurrency() + : queueMaxNumberOfMessages); + bufferSize += queueMaxNumberOfMessages; } - threadPoolTaskExecutor.setMaxPoolSize(maxPoolSize); + threadPoolTaskExecutor.setCorePoolSize(poolSize); + threadPoolTaskExecutor.setMaxPoolSize(poolSize); + // Ideally we would like to set the queue capacity to 0 to avoid + // messages waiting too long in memory, but due to a race condition + // we may encounter a transitional state where a task finishes + // executing but the thread is not ready to accept a new task for a + // few milliseconds, resulting in rejected task exceptions. To + // prevent this issue we allow a small amount of buffering. + threadPoolTaskExecutor.setQueueCapacity(bufferSize); + threadPoolTaskExecutor.setAllowCoreThreadTimeOut(true); } - - // No use of a thread pool executor queue to avoid retaining message to long in memory - threadPoolTaskExecutor.setQueueCapacity(0); threadPoolTaskExecutor.afterPropertiesSet(); return threadPoolTaskExecutor; @@ -312,7 +317,7 @@ public void run() { if (isQueueRunning()) { MessageExecutor messageExecutor = new MessageExecutor(this.logicalQueueName, message, this.queueAttributes); if (currentPermits > 0) { - executeTask(new SignalExecutingRunnable(semaphore, messageExecutor)); + getTaskExecutor().execute(new SignalExecutingRunnable(semaphore, messageExecutor)); // After the task is submitted to the executor, // it's the SignalExecutingRunnable's job to // release the permit, so we can decrement our @@ -323,7 +328,7 @@ public void run() { // interrupted. We don't want the worker to // release a permit that wasn't acquired, so // don't use SignalExecutingRunnable. - executeTask(messageExecutor); + getTaskExecutor().execute(messageExecutor); } } } @@ -365,28 +370,6 @@ private int acquirePermits(Semaphore semaphore, int permits) { return 0; } } - - private void executeTask(Runnable runnable) { - // There is a potential race condition between the time when the - // semaphore is released by the SignalExecutingRunnable, and the - // time when the thread pool thread actually becomes available to - // accept another task. - // - // As a workaround, we retry a single time after a - // RejectedExecutionException. - try { - getTaskExecutor().execute(runnable); - return; - } catch (RejectedExecutionException e) { - // Sleep for a moment and try again - try { - Thread.sleep(100); - } catch (InterruptedException ie) { - Thread.currentThread().interrupt(); - } - } - getTaskExecutor().execute(runnable); - } } private class MessageExecutor implements Runnable {