Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IGNITE-24224 Fixed NIO server broken message serialization with SSL enabled. #11819

Merged
merged 6 commits into from
Jan 22, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -3321,7 +3321,7 @@ private static final class WriteRequestSystemImpl implements SessionWriteRequest
/**
*
*/
private static final class WriteRequestImpl implements SessionWriteRequest, SessionChangeRequest {
static final class WriteRequestImpl implements SessionWriteRequest, SessionChangeRequest {
/** */
private GridNioSession ses;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,15 @@ public void enabledProtocols(String... enabledProtos) {
}

try {
GridNioSslHandler hnd = new GridNioSslHandler(this,
GridNioSslHandler hnd = new GridNioSslHandler(
this,
ses,
engine,
directBuf,
order,
log,
handshake,
sslMeta.encodedBuffer());
sslMeta);

sslMeta.handler(hnd);

Expand All @@ -257,10 +258,7 @@ public void enabledProtocols(String... enabledProtos) {

hnd.handshake();

ByteBuffer alreadyDecoded = sslMeta.decodedBuffer();

if (alreadyDecoded != null)
proceedMessageReceived(ses, alreadyDecoded);
processApplicationBuffer(ses, hnd.getApplicationBuffer());
}
catch (SSLException e) {
onSessionOpenedException = e;
Expand Down Expand Up @@ -400,14 +398,7 @@ public ByteBuffer encrypt(GridNioSession ses, ByteBuffer input) throws SSLExcept
if (hnd.isHandshakeFinished())
hnd.flushDeferredWrites();

ByteBuffer appBuf = hnd.getApplicationBuffer();

appBuf.flip();

if (appBuf.hasRemaining())
proceedMessageReceived(ses, appBuf);

appBuf.compact();
processApplicationBuffer(ses, hnd.getApplicationBuffer());

if (hnd.isInboundDone() && !hnd.isOutboundDone()) {
if (log.isDebugEnabled())
Expand All @@ -424,6 +415,16 @@ public ByteBuffer encrypt(GridNioSession ses, ByteBuffer input) throws SSLExcept
}
}

/** */
private void processApplicationBuffer(GridNioSession ses, ByteBuffer appBuffer) throws IgniteCheckedException {
appBuffer.flip();

if (appBuffer.hasRemaining())
proceedMessageReceived(ses, appBuffer);

appBuffer.compact();
}

/** {@inheritDoc} */
@Override public GridNioFuture<Boolean> onSessionClose(GridNioSession ses) throws IgniteCheckedException {
GridNioSslHandler hnd = sslHandler(ses);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import org.apache.ignite.internal.util.nio.GridNioSession;
import org.apache.ignite.internal.util.typedef.internal.U;
import org.apache.ignite.lang.IgniteInClosure;
import org.jetbrains.annotations.Nullable;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
Expand Down Expand Up @@ -109,17 +110,19 @@ class GridNioSslHandler extends ReentrantLock {
* @param directBuf Direct buffer flag.
* @param order Byte order.
* @param handshake is handshake required.
* @param encBuf encoded buffer to be used.
* @param sslMeta SSL meta.
* @throws SSLException If exception occurred when starting SSL handshake.
*/
GridNioSslHandler(GridNioSslFilter parent,
GridNioSslHandler(
GridNioSslFilter parent,
GridNioSession ses,
SSLEngine engine,
boolean directBuf,
ByteOrder order,
IgniteLogger log,
boolean handshake,
ByteBuffer encBuf) throws SSLException {
GridSslMeta sslMeta
) throws SSLException {
assert parent != null;
assert ses != null;
assert engine != null;
Expand All @@ -145,32 +148,21 @@ class GridNioSslHandler extends ReentrantLock {
// Allocate a little bit more so SSL engine would not return buffer overflow status.
int netBufSize = sslEngine.getSession().getPacketBufferSize() + 50;

outNetBuf = directBuf ? ByteBuffer.allocateDirect(netBufSize) : ByteBuffer.allocate(netBufSize);

outNetBuf.order(order);

inNetBuf = directBuf ? ByteBuffer.allocateDirect(netBufSize) : ByteBuffer.allocate(netBufSize);

inNetBuf.order(order);

if (encBuf != null) {
encBuf.flip();

inNetBuf.put(encBuf); // Buffer contains bytes read but not handled by sslEngine at BlockingSslHandler.
}
outNetBuf = createBuffer(netBufSize, null);

// Initially buffer is empty.
outNetBuf.position(0);
outNetBuf.limit(0);

int appBufSize = Math.max(sslEngine.getSession().getApplicationBufferSize() + 50, netBufSize * 2);

appBuf = directBuf ? ByteBuffer.allocateDirect(appBufSize) : ByteBuffer.allocate(appBufSize);
inNetBuf = createBuffer(netBufSize, sslMeta.encodedBuffer());

appBuf.order(order);
appBuf = createBuffer(
Math.max(sslEngine.getSession().getApplicationBufferSize() + 50, netBufSize * 2),
sslMeta.decodedBuffer()
);

if (log.isDebugEnabled())
log.debug("Started SSL session [netBufSize=" + netBufSize + ", appBufSize=" + appBufSize + ']');
log.debug("Started SSL session [netBufSize=" + outNetBuf.capacity() + ", appBufSize=" + appBuf.capacity() + ']');
}

/**
Expand Down Expand Up @@ -682,6 +674,21 @@ private ByteBuffer copy(ByteBuffer original) {
return cp;
}

/** */
private ByteBuffer createBuffer(int size, @Nullable ByteBuffer init) {
if (init != null)
size = Math.max(size, init.remaining());

ByteBuffer buf = directBuf ? ByteBuffer.allocateDirect(size) : ByteBuffer.allocate(size);

buf.order(order);

if (init != null)
buf.put(init);

return buf;
}

/**
* Write request for cases while handshake is not finished yet.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public long tcpHandshake(
BlockingTransport transport = stateProvider.isSslEnabled() ?
new SslTransport(sslMeta, ch, directBuffer, log) : new TcpTransport(ch);

ByteBuffer buf = transport.recieveNodeId();
ByteBuffer buf = transport.receiveNodeId();

if (buf == null)
return NEED_WAIT;
Expand All @@ -98,7 +98,7 @@ else if (log.isDebugEnabled())

transport.sendHandshake(msg);

buf = transport.recieveAcknowledge();
buf = transport.receiveAcknowledge();

long rcvCnt = buf.getLong(DIRECT_TYPE_SIZE);

Expand All @@ -125,7 +125,7 @@ private abstract static class BlockingTransport {
* @return Buffer with {@link NodeIdMessage}.
* @throws IgniteCheckedException If failed.
*/
ByteBuffer recieveNodeId() throws IgniteCheckedException {
ByteBuffer receiveNodeId() throws IgniteCheckedException {
ByteBuffer buf = ByteBuffer.allocate(NodeIdMessage.MESSAGE_FULL_SIZE)
.order(ByteOrder.LITTLE_ENDIAN);

Expand Down Expand Up @@ -171,7 +171,7 @@ void sendHandshake(HandshakeMessage msg) throws IgniteCheckedException {
* @return Buffer with message.
* @throws IgniteCheckedException If failed.
*/
ByteBuffer recieveAcknowledge() throws IgniteCheckedException {
ByteBuffer receiveAcknowledge() throws IgniteCheckedException {
ByteBuffer buf = ByteBuffer.allocate(RecoveryLastReceivedMessage.MESSAGE_FULL_SIZE)
.order(ByteOrder.LITTLE_ENDIAN);

Expand Down Expand Up @@ -333,8 +333,11 @@ private static class SslTransport extends BlockingTransport {

ByteBuffer inBuf = handler.inputBuffer();

if (inBuf.position() > 0)
if (inBuf.position() > 0) {
inBuf.flip();

sslMeta.encodedBuffer(inBuf);
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,184 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.ignite.internal.util.nio;

import java.util.Arrays;
import java.util.Collections;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;
import org.apache.ignite.cluster.ClusterNode;
import org.apache.ignite.internal.util.nio.ssl.BlockingSslHandler;
import org.apache.ignite.internal.util.typedef.internal.U;
import org.apache.ignite.lang.IgniteRunnable;
import org.apache.ignite.plugin.extensions.communication.Message;
import org.apache.ignite.spi.communication.CommunicationListener;
import org.apache.ignite.spi.communication.CommunicationSpi;
import org.apache.ignite.spi.communication.GridAbstractCommunicationSelfTest;
import org.apache.ignite.spi.communication.TestVolatilePayloadMessage;
import org.apache.ignite.spi.communication.tcp.TcpCommunicationSpi;
import org.apache.ignite.spi.communication.tcp.internal.GridNioServerWrapper;
import org.apache.ignite.spi.communication.tcp.messages.RecoveryLastReceivedMessage;
import org.apache.ignite.testframework.GridTestUtils;
import org.junit.Test;

import static org.apache.ignite.testframework.GridTestUtils.waitForCondition;

/**
* Tests the case when regular communications messages are sent along with the last handshake messages and SSL is enabled.
* It asserts that if not all received by network bytes are processed by {@link BlockingSslHandler} during the handshake
* phase, then all remaining bytes are properly copied to {@code GridNioSslHandler}, which replaces
* {@link BlockingSslHandler} after the handshake phase.
* The steps that can lead to mentioned above conditions:
* <p>
* 1. Node B sends a MESSAGE to Node A and stores it in the local recovery descriptor until an acknowledgment is received
* from Node A.
* <p>
* 2. Node A, for whatever reason, reestablishes connection with node B and starts handshake negotiation.
* <p>
* 3. Node B during the final phase of handshake sends {@link RecoveryLastReceivedMessage} and resends not acknowledged
* MESSAGE from step 1. But all sent bytes are divided into two network packets. Let's assume that the first packet
* contains all bytes related to {@link RecoveryLastReceivedMessage} and only half of the MESSAGE bytes.
* <p>
* 4. Node A decodes {@link RecoveryLastReceivedMessage} from the received network packet and finishes the handshake.
* But the MESSAGE cannot be processed because not enough bytes were received. So we must save remaining bytes from the
* first network packet, wait for the next network packet and finish MESSAGE deserialization.
*
*/
public class TcpCommunicationSpiSslVolatilePayloadTest extends GridAbstractCommunicationSelfTest<CommunicationSpi<Message>> {
/** */
private static final int TEST_ITERATION_CNT = 1000;

/** The number of messages intended to fill the network buffer during last handshake message sending. */
private static final int RECOVERY_DESCRIPTOR_QUEUE_MESSAGE_CNT = 50;

/** */
private static final AtomicInteger msgCreatedCntr = new AtomicInteger();

/** */
private static final AtomicInteger msgReceivedCntr = new AtomicInteger();

/** */
private static final Map<Integer, TestVolatilePayloadMessage> messages = new ConcurrentHashMap<>();

/** {@inheritDoc} */
@Override protected CommunicationSpi<Message> getSpi(int idx) {
return new TcpCommunicationSpi().setLocalPort(GridTestUtils.getNextCommPort(getClass()))
.setIdleConnectionTimeout(2000)
.setTcpNoDelay(true);
}

/** {@inheritDoc} */
@Override protected CommunicationListener<Message> createMessageListener(UUID nodeId) {
return new TestCommunicationListener();
}

/** {@inheritDoc} */
@Override protected Map<Short, Supplier<Message>> customMessageTypes() {
return Collections.singletonMap(TestVolatilePayloadMessage.DIRECT_TYPE, TestVolatilePayloadMessage::new);
}

/** {@inheritDoc} */
@Override protected boolean isSslEnabled() {
return true;
}

/** */
@Test
public void test() throws Exception {
ClusterNode from = nodes.get(0);
ClusterNode to = nodes.get(1);

for (int i = 0; i < TEST_ITERATION_CNT; i++) {
// Force connection to be established.
sendMessage(from, to, createMessage());

GridNioRecoveryDescriptor fromDesc = extractRecoveryDescriptor(from);
GridNioRecoveryDescriptor toDesc = extractRecoveryDescriptor(to);

// Stores multiple dummy messages in a recovery descriptor. When the connection is restored, they will be
// written to the network buffer along with the last handshake message.
// See TcpHandshakeExecutor#receiveAcknowledge
for (int j = 0; j < RECOVERY_DESCRIPTOR_QUEUE_MESSAGE_CNT; j++)
toDesc.add(new GridNioServer.WriteRequestImpl(toDesc.session(), createMessage(), false, null));

// Close connection to re-initiate handshake between nodes.
if (fromDesc.session() != null)
fromDesc.session().close();
}

assertTrue(waitForCondition(() -> msgCreatedCntr.get() == msgReceivedCntr.get(), 5000));
}

/** */
public GridNioRecoveryDescriptor extractRecoveryDescriptor(ClusterNode node) throws Exception {
CommunicationSpi<Message> spi = spis.get(node.id());

GridNioServerWrapper wrapper = U.field(spi, "nioSrvWrapper");

assertTrue(waitForCondition(() -> !wrapper.recoveryDescs().values().isEmpty(), getTestTimeout()));

return wrapper.recoveryDescs().values().stream().findFirst().get();
}

/** */
private Message createMessage() {
byte[] payload = new byte[ThreadLocalRandom.current().nextInt(10, 1024)];

ThreadLocalRandom.current().nextBytes(payload);

TestVolatilePayloadMessage msg = new TestVolatilePayloadMessage(msgCreatedCntr.getAndIncrement(), payload);

messages.put(msg.index(), msg);

return msg;
}

/** */
private void sendMessage(ClusterNode from, ClusterNode to, Message msg) {
spis.get(from.id()).sendMessage(to, msg);
}

/** */
private static class TestCommunicationListener implements CommunicationListener<Message> {
/** {@inheritDoc} */
@Override public void onMessage(UUID nodeId, Message msg, IgniteRunnable msgC) {
msgC.run();

if (msg instanceof TestVolatilePayloadMessage) {
TestVolatilePayloadMessage testMsg = (TestVolatilePayloadMessage)msg;

TestVolatilePayloadMessage expMsg = messages.get(testMsg.index());

assertNotNull(expMsg);

assertTrue(Arrays.equals(expMsg.payload(), testMsg.payload()));

msgReceivedCntr.incrementAndGet();
}
}

/** {@inheritDoc} */
@Override public void onDisconnected(UUID nodeId) {
// No-op.
}
}
}
Loading
Loading