diff --git a/src/one/nio/rpc/RpcClient.java b/src/one/nio/rpc/RpcClient.java index 9182147..60a7ea1 100755 --- a/src/one/nio/rpc/RpcClient.java +++ b/src/one/nio/rpc/RpcClient.java @@ -19,6 +19,7 @@ import one.nio.net.ConnectionString; import one.nio.net.Socket; import one.nio.pool.SocketPool; +import one.nio.rpc.stream.RpcStreamImpl; import one.nio.serial.CalcSizeStream; import one.nio.serial.DataStream; import one.nio.serial.DeserializeStream; @@ -30,7 +31,7 @@ import java.io.IOException; import java.lang.reflect.InvocationHandler; import java.lang.reflect.Method; -import java.net.SocketException; +import java.net.SocketTimeoutException; public class RpcClient extends SocketPool implements InvocationHandler { protected static final byte[][] uidLocks = new byte[64][0]; @@ -44,12 +45,16 @@ public Object invoke(Object request) throws Exception { } public Object invoke(Object request, int timeout) throws Exception { - byte[] buffer = invokeRaw(request, timeout); + Object rawResponse = invokeRaw(request, timeout); + + while (true) { + if (!(rawResponse instanceof byte[])) { + return rawResponse; + } - for (;;) { Object response; try { - response = new DeserializeStream(buffer).readObject(); + response = new DeserializeStream((byte[]) rawResponse).readObject(); } catch (SerializerNotFoundException e) { long uid = e.getUid(); synchronized (uidLockFor(uid)) { @@ -65,7 +70,7 @@ public Object invoke(Object request, int timeout) throws Exception { } else if (response instanceof SerializerNotFoundException) { long uid = ((SerializerNotFoundException) response).getUid(); provideSerializer(Repository.requestSerializer(uid)); - buffer = invokeRaw(request, readTimeout); + rawResponse = invokeRaw(request, readTimeout); } else { throw (Exception) response; } @@ -74,6 +79,10 @@ public Object invoke(Object request, int timeout) throws Exception { @Override public Object invoke(Object proxy, Method method, Object... args) throws Exception { + if (method.getDeclaringClass() == Object.class) { + // toString(), hashCode() etc. are not remote methods + return method.invoke(this, args); + } return invoke(new RemoteCall(method, args)); } @@ -90,7 +99,7 @@ protected Serializer requestSerializer(long uid) throws Exception { } protected Object invokeServiceRequest(Object request) throws Exception { - byte[] rawResponse = invokeRaw(request, readTimeout); + byte[] rawResponse = (byte[]) invokeRaw(request, readTimeout); Object response = new DeserializeStream(rawResponse).readObject(); if (response instanceof Exception) { throw (Exception) response; @@ -98,27 +107,50 @@ protected Object invokeServiceRequest(Object request) throws Exception { return response; } - private byte[] invokeRaw(Object request, int timeout) throws Exception { + private Object invokeRaw(Object request, int timeout) throws Exception { byte[] buffer = serialize(request); Socket socket = borrowObject(); try { try { sendRequest(socket, buffer, timeout); - } catch (SocketException e) { + } catch (SocketTimeoutException e) { + throw e; + } catch (IOException e) { // Stale connection? Retry on a fresh socket destroyObject(socket); socket = createObject(); sendRequest(socket, buffer, timeout); } - int responseSize = RpcPacket.getSize(buffer, socket); + int responseSize = RpcPacket.getSize(buffer); + if (responseSize == RpcPacket.STREAM_HEADER) { + return new RpcStreamImpl(socket) { + { + socket.setTos(Socket.IPTOS_THROUGHPUT); + } + + @Override + public void close() { + super.close(); + + if (error) { + invalidateObject(socket); + } else { + socket.setTos(0); + returnObject(socket); + } + } + }; + } + + RpcPacket.checkSize(responseSize, socket); if (responseSize > 4) buffer = new byte[responseSize]; socket.readFully(buffer, 0, responseSize); returnObject(socket); return buffer; - } catch (Exception e) { + } catch (Throwable e) { invalidateObject(socket); throw e; } diff --git a/src/one/nio/rpc/RpcPacket.java b/src/one/nio/rpc/RpcPacket.java index 14ef831..d4ed7ca 100755 --- a/src/one/nio/rpc/RpcPacket.java +++ b/src/one/nio/rpc/RpcPacket.java @@ -29,15 +29,31 @@ class RpcPacket { private static final int WARN_PACKET_SIZE = 4 * 1024 * 1024; private static final int ERROR_PACKET_SIZE = 128 * 1024 * 1024; - static int getSize(byte[] buffer, Socket socket) throws IOException { - int size = buffer[0] << 24 | (buffer[1] & 0xff) << 16 | (buffer[2] & 0xff) << 8 | (buffer[3] & 0xff); + static final int STREAM_HEADER = 0xEDAEDA03; + static final byte[] STREAM_HEADER_ARRAY = { + (byte) (STREAM_HEADER >>> 24), + (byte) (STREAM_HEADER >>> 16), + (byte) (STREAM_HEADER >>> 8), + (byte) STREAM_HEADER + }; + + static final int HTTP_GET = 'G' << 24 | 'E' << 16 | 'T' << 8 | ' '; + static final int HTTP_POST = 'P' << 24 | 'O' << 16 | 'S' << 8 | 'T'; + static final int HTTP_HEAD = 'H' << 24 | 'E' << 16 | 'A' << 8 | 'D'; + + static boolean isHttpHeader(int header) { + return header == HTTP_GET || header == HTTP_POST || header == HTTP_HEAD; + } + + static int getSize(byte[] buffer) { + return buffer[0] << 24 | (buffer[1] & 0xff) << 16 | (buffer[2] & 0xff) << 8 | (buffer[3] & 0xff); + } - if (size < 0 || size >= ERROR_PACKET_SIZE) { + static void checkSize(int size, Socket socket) throws IOException { + if (size <= 0 || size >= ERROR_PACKET_SIZE) { throw new IOException("Invalid RPC packet from " + socket.getRemoteAddress()); } else if (size >= WARN_PACKET_SIZE) { log.warn("RPC packet from " + socket.getRemoteAddress() + " is too large: " + size); } - - return size; } } diff --git a/src/one/nio/rpc/RpcSession.java b/src/one/nio/rpc/RpcSession.java index 9f25824..af38971 100755 --- a/src/one/nio/rpc/RpcSession.java +++ b/src/one/nio/rpc/RpcSession.java @@ -16,20 +16,26 @@ package one.nio.rpc; +import one.nio.http.Request; import one.nio.net.Session; import one.nio.net.Socket; +import one.nio.rpc.stream.RpcStreamImpl; +import one.nio.rpc.stream.StreamProxy; import one.nio.serial.CalcSizeStream; import one.nio.serial.DataStream; import one.nio.serial.DeserializeStream; +import one.nio.serial.Repository; import one.nio.serial.SerializeStream; import one.nio.serial.SerializerNotFoundException; +import one.nio.util.Utf8; import java.io.IOException; import java.net.InetSocketAddress; import java.util.concurrent.RejectedExecutionException; public class RpcSession extends Session { - private static final int BUFFER_SIZE = 8000; + protected static final int BUFFER_SIZE = 8000; + protected static final byte HTTP_REQUEST_UID = (byte) Repository.get(Request.class).uid(); protected final RpcServer server; protected final InetSocketAddress peer; @@ -48,29 +54,36 @@ public RpcSession(Socket socket, RpcServer server) { @Override protected void processRead(byte[] unusedBuffer) throws Exception { byte[] buffer = this.buffer; - int bytesRead = this.bytesRead; int requestSize = this.requestSize; // Read 4-bytes header if (requestSize == 0) { - bytesRead += super.read(buffer, bytesRead, 4 - bytesRead); - if (bytesRead < 4) { - this.bytesRead = bytesRead; + if (bytesRead < 4 && (bytesRead += super.read(buffer, bytesRead, 4 - bytesRead)) < 4) { return; } - bytesRead = 0; - requestSize = this.requestSize = RpcPacket.getSize(buffer, socket); + requestSize = RpcPacket.getSize(buffer); + if (requestSize >= RpcPacket.HTTP_GET && RpcPacket.isHttpHeader(requestSize)) { + // Looks like HTTP request - try to parse as HTTP + if ((requestSize = readHttpHeader()) < 0) { + // HTTP headers not yet complete + return; + } + } else { + bytesRead = 0; + } + + RpcPacket.checkSize(requestSize, socket); if (requestSize > buffer.length) { - buffer = this.buffer = new byte[requestSize]; + this.buffer = buffer = expandBuffer(requestSize); } + + this.requestSize = requestSize; this.requestStartTime = selector.lastWakeupTime(); } // Read request - bytesRead += super.read(buffer, bytesRead, requestSize - bytesRead); - if (bytesRead < requestSize) { - this.bytesRead = bytesRead; + if ((bytesRead += super.read(buffer, bytesRead, requestSize - bytesRead)) < requestSize) { return; } @@ -96,7 +109,7 @@ protected void processRead(byte[] unusedBuffer) throws Exception { } // Perform the invocation - if (server.getWorkersUsed()) { + if (isAsyncRequest(request)) { try { server.asyncExecute(new AsyncRequest(request, meta)); server.incRequestsProcessed(); @@ -110,6 +123,55 @@ protected void processRead(byte[] unusedBuffer) throws Exception { } } + private byte[] expandBuffer(int requestSize) { + byte[] newBuffer = new byte[requestSize]; + System.arraycopy(buffer, 0, newBuffer, 0, bytesRead); + return newBuffer; + } + + private int readHttpHeader() throws IOException { + byte[] buffer = this.buffer; + int bytesRead = this.bytesRead; + + bytesRead += super.read(buffer, bytesRead, BUFFER_SIZE - bytesRead); + this.bytesRead = bytesRead; + + int contentLength = 0; + int lineStart = 4; + for (int i = 4; i < bytesRead; i++) { + // Parse line by line + if (buffer[i] == '\n') { + if (buffer[i - 1] == '\n' || buffer[i - 1] == '\r' && buffer[i - 2] == '\n') { + // Make HTTP request deserializable with the standard DeserializeStream + buffer[0] = HTTP_REQUEST_UID; + return i + 1 + contentLength; + } else if (i - lineStart > 16 && startsWith(buffer, lineStart, "content-length: ")) { + int end = buffer[i - 1] == '\r' ? i - 1 : i; + contentLength = (int) Utf8.parseLong(buffer, lineStart + 16, end - (lineStart + 16)); + } + lineStart = i + 1; + } + } + + // The headers are not yet complete. Return error if the buffer is already full. + return bytesRead < BUFFER_SIZE ? -1 : Integer.MAX_VALUE; + } + + private static boolean startsWith(byte[] buffer, int from, String s) { + int length = s.length(); + for (int i = 0; i < length; i++) { + // Make letters case-insensitive + if ((buffer[from + i] | 32) != s.charAt(i)) { + return false; + } + } + return true; + } + + protected boolean isAsyncRequest(Object request) { + return server.getWorkersUsed(); + } + // To be overridden protected M onRequestRead() { return null; @@ -129,10 +191,38 @@ protected int writeResponse(Object response) throws IOException { return responseSize; } + @SuppressWarnings("unchecked") + protected void streamCommunicate(StreamProxy streamProxy) throws IOException { + selector.disable(this); + socket.setBlocking(true); + socket.setTos(Socket.IPTOS_THROUGHPUT); + socket.writeFully(RpcPacket.STREAM_HEADER_ARRAY, 0, 4); + + try (RpcStreamImpl stream = new RpcStreamImpl(socket)) { + streamProxy.handler.communicate(stream); + streamProxy.bytesRead = stream.getBytesRead(); + streamProxy.bytesWritten = stream.getBytesWritten(); + } catch (ClassNotFoundException e) { + close(); + throw new IOException(e); + } catch (Throwable e) { + close(); + throw e; + } + + socket.setTos(0); + socket.setBlocking(false); + selector.enable(this); + } + protected void invoke(Object request, M meta) throws Exception { RemoteCall remoteCall = (RemoteCall) request; Object response = remoteCall.method().invoke(server.service, remoteCall.args()); - writeResponse(response); + if (response instanceof StreamProxy) { + streamCommunicate((StreamProxy) response); + } else { + writeResponse(response); + } } protected void handleDeserializationException(Exception e) throws IOException { diff --git a/test/one/nio/rpc/RpcTest.java b/test/one/nio/rpc/RpcTest.java new file mode 100644 index 0000000..b4dac71 --- /dev/null +++ b/test/one/nio/rpc/RpcTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2018 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.rpc; + +import one.nio.config.ConfigParser; +import one.nio.net.ConnectionString; +import one.nio.rpc.stream.BidiStream; +import one.nio.serial.sample.Message; +import one.nio.serial.sample.Sample; +import one.nio.server.ServerConfig; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.Serializable; +import java.lang.reflect.Proxy; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +public class RpcTest { + private static RpcServer server; + private static TestService client; + + @BeforeClass + public static void setup() throws Exception { + ServerConfig config = ConfigParser.parse("acceptors:\n - port: 8888", ServerConfig.class); + server = new RpcServer<>(config, new TestServiceImpl()); + server.start(); + + client = (TestService) Proxy.newProxyInstance( + RpcTest.class.getClassLoader(), + new Class[]{TestService.class}, + new RpcClient(new ConnectionString("127.0.0.1:8888"))); + } + + @AfterClass + public static void destroy() { + server.stop(); + } + + @Test + public void testSimpleMethod() { + Set ids = new HashSet<>(Arrays.asList(111L, 222L, 333L, 444L, 555L)); + Map messages = client.getMessagesByIds(ids); + assertEquals(5, messages.size()); + assertEquals("Second message", messages.get(222L).text); + assertEquals(2, messages.get(444L).attachments.size()); + } + + @Test + public void testNumberStream() throws Exception { + Stats stats = testNumberStream(1); + assertEquals(-100, stats.min); + assertEquals(300, stats.max); + assertEquals(BigDecimal.valueOf(100), stats.avg); + assertEquals(BigInteger.valueOf(400), stats.sum); + + stats = testNumberStream(10000); + assertEquals(-100, stats.min); + assertEquals(300, stats.max); + assertEquals(BigDecimal.valueOf(100), stats.avg); + assertEquals(BigInteger.valueOf(4000000), stats.sum); + } + + private Stats testNumberStream(int iterations) throws Exception { + try (BidiStream stream = client.openNumberStream()) { + for (int i = 0; i < iterations; i++) { + stream.send(200L); + stream.send(-100); + stream.send(BigInteger.ZERO); + stream.send(300.0); + } + return stream.sendAndGet(null); + } + } + + + interface TestService { + Map getMessagesByIds(Set ids); + BidiStream openNumberStream(); + } + + static class TestServiceImpl implements TestService { + + @Override + public Map getMessagesByIds(Set ids) { + Map map = new HashMap<>(); + for (Message message : Sample.createChat().messages) { + if (ids.contains(message.id)) { + map.put(message.id, message); + } + } + return map; + } + + @Override + public BidiStream openNumberStream() { + return BidiStream.create(stream -> { + BigInteger sum = BigInteger.ZERO; + long count = 0; + long min = Long.MAX_VALUE; + long max = Long.MIN_VALUE; + + for (Number number; (number = stream.receive()) != null; ) { + long n = number.longValue(); + sum = sum.add(BigInteger.valueOf(n)); + count++; + min = Math.min(min, n); + max = Math.max(max, n); + } + + BigDecimal avg = new BigDecimal(sum).divide(BigDecimal.valueOf(count), 0); + stream.send(new Stats(sum, avg, min, max)); + }); + } + } + + static class Stats implements Serializable { + final BigInteger sum; + final BigDecimal avg; + final long min; + final long max; + + Stats(BigInteger sum, BigDecimal avg, long min, long max) { + this.sum = sum; + this.avg = avg; + this.min = min; + this.max = max; + } + } +}