From 21f953cda8ab3d60ed483d363f08ad55563a0302 Mon Sep 17 00:00:00 2001 From: Oleg Larionov Date: Thu, 19 Sep 2024 18:37:27 +0300 Subject: [PATCH] Upgrade to OpenSSL 3 (#85) Breaking changes: The native part of one-nio now links and works only with OpenSSL 3 (tested on OpenSSL versions 3.0 and 3.2). The OpenSSL library file must have name libssl.so.3. Major changes: Added support for Kernel TLS, including sendfile Added support for TLS early data (TLS1.3 0-RTT) Added ability to export encryption keys (SSL keylog) Added support for SSL certificate compression (RFC8879) Added ability to use an external cache for SSL sessions Added ability to constrain selectors used by an acceptor Added ability to use a single acceptor thread for all sockets Added support for server-sent events in one.nio.http.HttpClient Also various fixes and optimizations. --- build.xml | 68 --- pom.xml | 7 +- src/one/nio/http/EventSource.java | 84 +++ src/one/nio/http/EventSourceResponse.java | 66 +++ src/one/nio/http/HttpClient.java | 413 ++++++++++++++- src/one/nio/http/HttpSession.java | 13 +- src/one/nio/http/Request.java | 47 ++ src/one/nio/lz4/LZ4.java | 22 +- src/one/nio/mem/LongHashSet.java | 10 + src/one/nio/mem/MallocMT.java | 3 + src/one/nio/mem/OffheapBitSet.java | 15 +- src/one/nio/net/JavaServerSocket.java | 36 +- src/one/nio/net/JavaSslClientContext.java | 162 ++++++ src/one/nio/net/JavaSslClientSocket.java | 355 +++++++++++++ src/one/nio/net/KeylogHolder.java | 45 ++ src/one/nio/net/NativeSocket.java | 20 +- src/one/nio/net/NativeSslContext.java | 155 +++++- src/one/nio/net/NativeSslSocket.java | 48 +- src/one/nio/net/Socket.java | 27 +- src/one/nio/net/SslClientContextFactory.java | 37 ++ src/one/nio/net/SslConfig.java | 11 +- src/one/nio/net/SslContext.java | 48 +- src/one/nio/net/SslOption.java | 33 +- src/one/nio/net/SslSessionCache.java | 116 ++++ src/one/nio/net/native/socket.c | 33 +- src/one/nio/net/native/ssl.c | 499 ++++++++++++++++-- src/one/nio/os/Cpus.java | 2 +- src/one/nio/os/Proc.java | 6 + src/one/nio/os/bpf/Bpf.java | 2 + src/one/nio/os/bpf/BpfProg.java | 29 + src/one/nio/os/native/bpf.c | 66 ++- src/one/nio/pool/SocketPool.java | 11 +- src/one/nio/rpc/RpcClient.java | 33 +- src/one/nio/rpc/RpcSession.java | 25 +- src/one/nio/serial/DataStream.java | 39 +- src/one/nio/serial/DeserializeStream.java | 2 +- src/one/nio/serial/PersistStream.java | 2 +- src/one/nio/serial/SerializationContext.java | 1 + src/one/nio/server/AcceptorConfig.java | 3 + src/one/nio/server/SelectorThread.java | 4 +- src/one/nio/server/Server.java | 170 +++--- src/one/nio/server/ServerConfig.java | 23 + src/one/nio/server/acceptor/Acceptor.java | 39 ++ .../nio/server/acceptor/AcceptorFactory.java | 56 ++ .../nio/server/acceptor/AcceptorSupport.java | 66 +++ .../server/{ => acceptor}/AcceptorThread.java | 66 +-- .../nio/server/acceptor/DefaultAcceptor.java | 124 +++++ .../server/acceptor/DefaultAcceptorGroup.java | 143 +++++ .../server/acceptor/MultiAcceptSession.java | 52 ++ .../nio/server/acceptor/MultiAcceptor.java | 134 +++++ .../server/acceptor/MultiAcceptorGroup.java | 112 ++++ .../server/acceptor/MultiAcceptorThread.java | 95 ++++ src/one/nio/util/JavaFeatures.java | 5 +- test/one/nio/config/ConfigParserTest.java | 4 +- test/one/nio/http/ChunkedEventReaderTest.java | 261 +++++++++ test/one/nio/http/HttpHeaderTest.java | 21 + test/one/nio/mem/LongHashSetFuncTest.java | 34 ++ .../nio/mem/SharedMemoryStringMapTest.java | 32 ++ test/one/nio/net/SocketTest.java | 12 +- test/one/nio/ssl/TLSCurveTest.java | 183 +++++++ test_data/ssl/ca.crt | 19 + test_data/ssl/ca.key | 28 + test_data/ssl/certificate.crt | 19 + test_data/ssl/certificate.key | 28 + test_data/ssl/generate.sh | 29 + test_data/ssl/trustore.jks | Bin 0 -> 1142 bytes 66 files changed, 3968 insertions(+), 385 deletions(-) delete mode 100755 build.xml create mode 100644 src/one/nio/http/EventSource.java create mode 100644 src/one/nio/http/EventSourceResponse.java create mode 100644 src/one/nio/net/JavaSslClientContext.java create mode 100644 src/one/nio/net/JavaSslClientSocket.java create mode 100644 src/one/nio/net/KeylogHolder.java create mode 100644 src/one/nio/net/SslClientContextFactory.java create mode 100644 src/one/nio/net/SslSessionCache.java create mode 100644 src/one/nio/server/acceptor/Acceptor.java create mode 100644 src/one/nio/server/acceptor/AcceptorFactory.java create mode 100644 src/one/nio/server/acceptor/AcceptorSupport.java rename src/one/nio/server/{ => acceptor}/AcceptorThread.java (57%) create mode 100644 src/one/nio/server/acceptor/DefaultAcceptor.java create mode 100644 src/one/nio/server/acceptor/DefaultAcceptorGroup.java create mode 100644 src/one/nio/server/acceptor/MultiAcceptSession.java create mode 100644 src/one/nio/server/acceptor/MultiAcceptor.java create mode 100644 src/one/nio/server/acceptor/MultiAcceptorGroup.java create mode 100644 src/one/nio/server/acceptor/MultiAcceptorThread.java create mode 100644 test/one/nio/http/ChunkedEventReaderTest.java create mode 100644 test/one/nio/mem/LongHashSetFuncTest.java create mode 100644 test/one/nio/mem/SharedMemoryStringMapTest.java create mode 100644 test/one/nio/ssl/TLSCurveTest.java create mode 100644 test_data/ssl/ca.crt create mode 100644 test_data/ssl/ca.key create mode 100644 test_data/ssl/certificate.crt create mode 100644 test_data/ssl/certificate.key create mode 100755 test_data/ssl/generate.sh create mode 100644 test_data/ssl/trustore.jks diff --git a/build.xml b/build.xml deleted file mode 100755 index 8a300ab..0000000 --- a/build.xml +++ /dev/null @@ -1,68 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/pom.xml b/pom.xml index 3dcc0fc..661749e 100644 --- a/pom.xml +++ b/pom.xml @@ -61,6 +61,11 @@ install src test + + + test_data + + org.apache.maven.plugins @@ -85,7 +90,7 @@ - HTML Standard: 9.2 Server-sent events + * @see HttpClient#openEvents(Request, int) + */ +public interface EventSource extends Closeable +{ + /** + * Waits for the next SSE and returns an event. The method can block for a long time ( determined by server ). + * It is essential to check for null and {@link Event#isEmpty()} before processing + * + * @return the next event from the stream or null, if stream was closed by either party + * @throws IOException an I/O exception occurred + * @throws HttpException an incorrect HTTP request received + */ + Event poll( ) throws IOException, HttpException; + + /** + * A single Server Sent Event, received from peer. + */ + interface Event + { + /** + * No name, id and data in event ( only comment ) + * + * @return true, if the event has no name, id and data. false, otherwise + */ + boolean isEmpty(); + + /** + * an SSE event name + * + * @return the event name + */ + String name(); + + /** + * an SSE event id. this can be used to request events starting from specified + * + * @return the event id + * @see HTML Standard: 9.2.4 The `Last-Event-ID` header + */ + String id(); + + /** + * an SSE "data" line. + * + * @return the event data + */ + D data(); + + + /** + * an SSE comment concatenated + * + * @return the event comment + */ + String comment(); + + } + +} diff --git a/src/one/nio/http/EventSourceResponse.java b/src/one/nio/http/EventSourceResponse.java new file mode 100644 index 0000000..1778ca0 --- /dev/null +++ b/src/one/nio/http/EventSourceResponse.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.http; + +import java.io.IOException; + +/** + * A Response, which can poll for server emitted events. + * Unlike regular {@link Response} this object must be close'd to + * prevent resource leak. + *

+ * The usage flow is as follows: + *

    + *
  1. Call {@link HttpClient#openEvents(Request, int)}
  2. + *
  3. Inspect the result code, if it is not OK process the error
  4. + *
  5. Inspect the content-type, it must be text/event-stream; if it is not - process the response body - there will be no events
  6. + *
  7. while ( ( event = poll() ) != null ) process( event )
  8. + *
  9. call {@link #close()}
  10. + *
  11. call {@link HttpClient#reopenEvents(Request, String, int)} with last processed {@link Event#id()} and go to p.2
  12. + *
+ * + * @see HTML Standard: 9.2 Server-sent events + * @see HttpClient#openEvents(Request, int) + */ +public class EventSourceResponse extends Response implements EventSource +{ + private EventSource eventSource; + + public EventSourceResponse( String resultCode ) + { + super( resultCode ); + } + + @Override + public Event poll() throws IOException, HttpException + { + return eventSource == null ? null : eventSource.poll(); + } + + void setEventSource( EventSource es ) { + this.eventSource = es; + } + + @Override + public void close() throws IOException + { + if ( eventSource != null ) { + eventSource.close(); + eventSource = null; + } + } +} diff --git a/src/one/nio/http/HttpClient.java b/src/one/nio/http/HttpClient.java index 1cddd08..422cb07 100755 --- a/src/one/nio/http/HttpClient.java +++ b/src/one/nio/http/HttpClient.java @@ -16,11 +16,11 @@ package one.nio.http; +import one.nio.net.SslClientContextFactory; import one.nio.net.ConnectionString; import one.nio.net.HttpProxy; import one.nio.net.Socket; import one.nio.net.SocketClosedException; -import one.nio.net.SslContext; import one.nio.pool.PoolException; import one.nio.pool.SocketPool; import one.nio.util.Utf8; @@ -28,10 +28,12 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.Closeable; import java.io.IOException; import java.net.SocketTimeoutException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Iterator; import java.util.List; public class HttpClient extends SocketPool { @@ -55,7 +57,7 @@ public HttpClient(ConnectionString conn, String... permanentHeaders) { protected void setProperties(ConnectionString conn) { boolean https = "https".equals(conn.getProtocol()); if (https) { - sslContext = SslContext.getDefault(); + sslContext = SslClientContextFactory.create(); } if (port == 0) { port = https ? 443 : 80; @@ -119,6 +121,11 @@ public Response get(String uri, String... headers) return invoke(createRequest(Request.METHOD_GET, uri, headers)); } + public EventSourceResponse openEvents(String uri, String... headers) + throws InterruptedException, PoolException, IOException, HttpException { + return openEvents( createRequest( Request.METHOD_GET, uri, headers ), readTimeout ); + } + public Response delete(String uri, String... headers) throws InterruptedException, PoolException, IOException, HttpException { return invoke(createRequest(Request.METHOD_DELETE, uri, headers)); @@ -189,6 +196,51 @@ public Response connect(String uri, String... headers) return invoke(createRequest(Request.METHOD_CONNECT, uri, headers)); } + @SuppressWarnings ( "resource") + public EventSourceResponse openEvents( Request request, int timeout ) throws InterruptedException, PoolException, IOException, HttpException + { + request.addHeader( "Accept: text/event-stream" ); + + int method = request.getMethod(); + byte[] rawRequest = request.toBytes(); + ServerSentEventsReader sseReader; + + Socket socket = borrowObject(); + boolean open = false; + + try { + try { + socket.setTimeout( timeout == 0 ? readTimeout : timeout ); + socket.writeFully( rawRequest, 0, rawRequest.length ); + sseReader = new ServerSentEventsReader( socket, bufferSize ); + } catch (SocketTimeoutException e) { + throw e; + } catch (IOException e) { + // Stale connection? Retry on a fresh socket + destroyObject(socket); + socket = createObject(); + socket.setTimeout( timeout == 0 ? readTimeout : timeout ); + socket.writeFully(rawRequest, 0, rawRequest.length); + sseReader = new ServerSentEventsReader( socket, bufferSize ); + } + + EventSourceResponse response = sseReader.readResponse(method); + open = true; + return response; + } finally { + if (!open) { + invalidateObject(socket); + } + } + } + + public EventSourceResponse reopenEvents( Request request, String lastId, int timeout ) throws InterruptedException, PoolException, IOException, HttpException + { + request.addHeader( "Last-Event-ID: " + lastId ); + + return openEvents( request, timeout ); + } + public Request createRequest(int method, String uri, String... headers) { Request request = new Request(method, uri, true); for (String header : permanentHeaders) { @@ -213,16 +265,30 @@ static class ResponseReader { } Response readResponse(int method) throws IOException, HttpException { + Response response = new Response( readResultCode() ); + readResponseHeaders( response ); + readResponseBody( method, response ); + return response; + } + + String readResultCode() throws IOException, HttpException + { String responseHeader = readLine(); if (responseHeader.length() <= 9) { throw new HttpException("Invalid response header: " + responseHeader); } + return responseHeader.substring(9); + } - Response response = new Response(responseHeader.substring(9)); + void readResponseHeaders(Response response) throws IOException, HttpException + { for (String header; !(header = readLine()).isEmpty(); ) { response.addHeader(header); } + } + void readResponseBody( int method, Response response ) throws IOException, HttpException + { if (method != Request.METHOD_HEAD && mayHaveBody(response.getStatus())) { if ("chunked".equalsIgnoreCase(response.getHeader("Transfer-Encoding:"))) { response.setBody(readChunkedBody()); @@ -238,8 +304,6 @@ Response readResponse(int method) throws IOException, HttpException { } } } - - return response; } String readLine() throws IOException, HttpException { @@ -345,4 +409,343 @@ private static boolean mayHaveBody(int status) { return status >= 200 && status != 204 && status != 304; } } + + class ChunkedLineReader extends ResponseReader implements Iterator, Closeable { + + private byte[] ch; + private int chPos, chLen; + + private boolean hasNext; + + + ChunkedLineReader( Socket socket, int bufferSize ) throws IOException + { + super( socket, bufferSize ); + this.ch = buf; + this.chPos = 0; + this.chLen = 0; + this.hasNext = true; + } + + private boolean nextChunk() throws IOException, HttpException { + + // the very first chunk header is written without empty line at the start, like: + // 999\n\r + // all subsequent chunk headers start with empty line, like: + // \n\r999\n\r + String l = readLine(); + int chunkSize = Integer.parseInt( l.isEmpty() ? readLine() : l, 16 ); + if (chunkSize == 0) { + readLine(); + this.chPos = 0; + this.chLen = 0; + this.hasNext = false; + return false; + } + + if ( chunkSize > ch.length ) { + // initially ch points to buf and reallocates to separate only if chunk size is greater than buf + ch = new byte[ chunkSizeFor( chunkSize ) ]; + } + + int contentBytes = length - pos; + if (contentBytes < chunkSize) { + System.arraycopy(buf, pos, ch, 0, contentBytes); + socket.readFully(ch, contentBytes, chunkSize - contentBytes); + pos = 0; + length = 0; + chPos = 0; + } else { + if ( ch != buf ) { + System.arraycopy(buf, pos, ch, 0, chunkSize); + chPos = 0; + } else { + chPos = pos; + } + pos += chunkSize; + } + chLen = chunkSize; + + return true; + + } + + private int chunkSizeFor( int cap ) + { + int n = -1 >>> Integer.numberOfLeadingZeros( cap - 1 ); + return n + 1; + } + + @Override + public boolean hasNext() + { + return hasNext; + } + + @Override + public String next() + { + try { + return readChunkedLine(); + } catch ( IOException | HttpException e ) { + log.debug("Event stream is closed by server"); + close(); + } + + return null; + } + + private String readChunkedLine() throws IOException, HttpException { + // whole line is found within current chunk + int end = findLineEnd( ch, chPos, chLen ); + if ( end >= 0 ) { + int lineLen = end - chPos; + String line = Utf8.read( ch, chPos, lineLen ); + lineLen++; // skip over \n + chLen -= lineLen; + chPos += lineLen; + return line; + } + + ArrayList chunks = new ArrayList<>(); + int lineLen = 0; + + do { + chunks.add(Arrays.copyOfRange(ch, this.chPos, this.chPos + this.chLen)); + lineLen += this.chLen; + this.ch = this.buf; + this.chPos = 0; + this.chLen = 0; + + if ( !nextChunk() ) { + // end of stream detected + end = 0; + break; + } + + end = findLineEnd( ch, chPos, chLen ); + } while ( end < 0 ); + + lineLen += Math.max( end - chPos, 0 ); + if ( lineLen == 0 ) + return ""; + + byte[] lineBytes = new byte[ lineLen ]; + int linePos = 0; + for ( byte[] b : chunks ) { + System.arraycopy( b, 0, lineBytes, linePos, b.length ); + linePos += b.length; + } + + // ch has last piece of line, if end > 0 || lineLen > linePos + if ( end > 0 ) { + System.arraycopy( ch, this.chPos, lineBytes, linePos, end - this.chPos ); + linePos += end - this.chPos; + chLen -= end - chPos + 1; + chPos = end + 1; // skip over \n + } + + assert linePos == lineLen; + + String line = Utf8.read( lineBytes, 0, lineBytes.length ); + return line; + } + + private int findLineEnd( byte[] b, int start, int len ) { + int end = start + len; + for ( ; start < end && b[start] != '\n'; start++ ) ; + + return start >= end ? -1 : start ; + } + + @Override + public void close() + { + if ( socket == null ) + return; + + invalidateObject(socket); + this.hasNext = false; + this.socket = null; + } + + } + + class ServerSentEventsReader extends ChunkedLineReader implements EventSource { + + private boolean keepAlive; + + ServerSentEventsReader( Socket socket, int bufferSize ) throws IOException + { + super( socket, bufferSize ); + } + + EventSourceResponse readResponse(int method) throws IOException, HttpException { + EventSourceResponse response = new EventSourceResponse( readResultCode() ); + readResponseHeaders( response ); + + if ( response.getHeader( "Content-Type: text/event-stream" ) == null ) { + try { + readResponseBody( method, response ); + keepAlive = !"close".equalsIgnoreCase(response.getHeader("Connection:")); + return response; + } finally { + close(); + } + } + + if ( !"chunked".equalsIgnoreCase( response.getHeader( "Transfer-Encoding:" ) ) ) { + throw new UnsupportedOperationException( "Only chunked transfer encoding is supported for text/event-stream" ); + } + + response.setEventSource( this ); + + return response; + } + + @Override + public Event poll( ) + { + if ( !hasNext() ) + return null; + + String line = next(); + return line == null || line.isEmpty() ? null : readEvent( line ); + + } + + private EventImpl readEvent( String line ) + { + EventImpl eimpl = new EventImpl(); + + StringBuilder databuf = new StringBuilder( line.length() ); + String field=":"; // impossible value + try { + do { + int cpos = line.indexOf( ':' ); + String f; + if ( cpos == 0 ) { + // comment. sometimes used alone as keep alive messages + f=""; + cpos++; + } else if ( cpos < 0 ) { + // no colon - whole line is field name as per spec + f = line; + cpos = line.length(); + } else { + // field name separated from data by colon with optional + // single space char after colon, like field-name: data + f = line.substring( 0, cpos ); + cpos++; + if ( cpos < line.length() && line.charAt( cpos )==' ') + cpos++; + } + + if ( !field.equals( f ) ) { + + eimpl.with( field, databuf ); + + field = f; + databuf.setLength( 0 ); + } else { + // multiple lines of the same field name concatenate data with newline + // a:b + // a:c + // a="b\nc" + databuf.append('\n'); + } + + databuf.append( line, cpos, line.length() ); + + line = next(); + if (line == null) { + // EOF + return null; + } + } while ( !line.isEmpty() ); + + if ( databuf.length() > 0 ) + eimpl.with( field, databuf ); + + } catch ( RuntimeException e ) { + log.error( "Cannot parse line: {}", line, e ); + throw e; + } + + log.debug( "Read event from stream: {}", eimpl ); + + return eimpl; + } + + @Override + public void close() + { + if ( socket != null && keepAlive) { + returnObject(socket); + socket = null; + } else { + super.close(); + } + } + + } + + static class EventImpl implements EventSource.Event { + + private String id, name, data, comment; + + @Override + public String name() + { + return name; + } + + @Override + public String id() + { + return id; + } + + @Override + public String data() + { + return data; + } + + @Override + public String comment() + { + return comment; + } + + boolean with( String field, StringBuilder databuf ) { + switch ( field ) { + case "id": + id = databuf.toString(); + break; + case "event": + name = databuf.toString(); + break; + case "data": + data = databuf.toString(); + break; + case "": + comment = databuf.toString(); + break; + default: + return false; + } + return true; + } + + public boolean isEmpty() { + return id == null && name == null && data == null; + } + + @Override + public String toString() + { + return isEmpty() ? "empty" : name + ":" + id; + } + } } diff --git a/src/one/nio/http/HttpSession.java b/src/one/nio/http/HttpSession.java index 425b926..697a11c 100755 --- a/src/one/nio/http/HttpSession.java +++ b/src/one/nio/http/HttpSession.java @@ -19,6 +19,7 @@ import one.nio.net.Session; import one.nio.net.Socket; import one.nio.net.SocketClosedException; +import one.nio.net.SslOption; import one.nio.util.Utf8; import java.io.IOException; @@ -93,8 +94,10 @@ protected void processRead(byte[] buffer) throws IOException { } protected void handleSocketClosed() { - // Unsubscribe from read events - listen(queueHead == null ? 0 : WRITEABLE); + if (selector != null) { + // Unsubscribe from read events + listen(queueHead == null ? 0 : WRITEABLE); + } if (handling == null) { scheduleClose(); @@ -176,6 +179,12 @@ protected int processHttpBuffer(byte[] buffer, int length) throws IOException, H if (parsing == null) { parsing = parseRequest(buffer, lineStart, lineLength); + if (isSsl()) { + boolean earlyDataAccepted = socket.getSslOption(SslOption.SESSION_EARLYDATA_ACCEPTED); + boolean handshakeDone = socket.getSslOption(SslOption.SESSION_HANDSHAKE_DONE); + parsing.setEarlyData(earlyDataAccepted && !handshakeDone); + } + } else if (lineLength > 0) { if (parsing.getHeaderCount() < MAX_HEADERS) { parsing.addHeader(Utf8.read(buffer, lineStart, lineLength)); diff --git a/src/one/nio/http/Request.java b/src/one/nio/http/Request.java index 5120095..c946018 100755 --- a/src/one/nio/http/Request.java +++ b/src/one/nio/http/Request.java @@ -113,6 +113,16 @@ public boolean isHttp11() { return http11; } + void setEarlyData(boolean earlyData) { + if (earlyData) { + addHeader("Early-Data: 1"); + } + } + + public boolean isEarlyData() { + return "1".equals(getHeader("Early-Data:")); + } + public String getPath() { return params >= 0 ? uri.substring(0, params) : uri; } @@ -262,6 +272,43 @@ public void consumeHeaders(String prefix, Consumer suffixConsumer) { } } + /** + * Returns trimmed header value after ':' delimiter + * + * @param key header name without ':' + * @return trimmed value after key: + */ + public String getHeaderValue(String key) { + int keyLength = key.length(); + for (int i = 0; i < headerCount; i++) { + String header = headers[i]; + if (header.length() > keyLength + && header.charAt(keyLength) == ':' + && header.regionMatches(true, 0, key, 0, keyLength)) { + return trim(header, keyLength + 1); + } + } + return null; + } + + /** + * Consume trimmed header value after ':' delimiter + + * @param key header name without ':' + * @param suffixConsumer a function for processing the header value + */ + public void consumeHeaderValues(String key, Consumer suffixConsumer) { + int keyLength = key.length(); + for (int i = 0; i < headerCount; i++) { + String header = headers[i]; + if (header.length() > keyLength + && header.charAt(keyLength) == ':' + && header.regionMatches(true, 0, key, 0, keyLength)) { + suffixConsumer.accept(trim(header, keyLength + 1)); + } + } + } + public String getHeader(String key, String defaultValue) { String value = getHeader(key); return value != null ? value : defaultValue; diff --git a/src/one/nio/lz4/LZ4.java b/src/one/nio/lz4/LZ4.java index c349aad..d4eb93e 100644 --- a/src/one/nio/lz4/LZ4.java +++ b/src/one/nio/lz4/LZ4.java @@ -133,7 +133,7 @@ public static int decompress(byte[] src, int srcOffset, byte[] dst, int dstOffse } if (result < 0) { - throw new IllegalArgumentException("Malformed input"); + throw new IllegalArgumentException("Malformed input or destination buffer overflow"); } return result; } @@ -147,7 +147,7 @@ public static int decompress(ByteBuffer src, ByteBuffer dst) { } if (result < 0) { - throw new IllegalArgumentException("Malformed input"); + throw new IllegalArgumentException("Malformed input or destination buffer overflow"); } src.position(src.limit()); @@ -541,13 +541,15 @@ private static int decompress(final Object src, final long srcOffset, s = unsafe.getByte(src, ip++) & 0xff; length += s; } while (ip < srcEnd - RUN_MASK && s == 255); - if (length < 0) return -1; // Error: overflow + if (length < 0) + return -1; // Error: overflow } // Copy literals long cpy = op + length; if (cpy > dstEnd - MFLIMIT || ip + length > srcEnd - (2 + 1 + LASTLITERALS)) { - if (ip + length != srcEnd || cpy > dstEnd) return -1; // Error: input must be consumed + if (ip + length != srcEnd || cpy > dstEnd) + return -1; // Error: input must be consumed unsafe.copyMemory(src, ip, dst, op, length); op += length; return (int) (op - dstOffset); @@ -559,18 +561,21 @@ private static int decompress(final Object src, final long srcOffset, // Get offset long match = cpy - (unsafe.getShort(src, ip) & 0xffff); ip += 2; - if (match < dstOffset) return -1; // Error: offset outside destination buffer + if (match < dstOffset) + return -1; // Error: offset outside destination buffer // Get matchlength length = token & ML_MASK; if (length == ML_MASK) { int s; do { - if (ip > srcEnd - LASTLITERALS) return -1; + if (ip > srcEnd - LASTLITERALS) + return -1; s = unsafe.getByte(src, ip++) & 0xff; length += s; } while (s == 255); - if (length < 0) return -1; // Error: overflow + if (length < 0) + return -1; // Error: overflow } length += MINMATCH; @@ -593,7 +598,8 @@ private static int decompress(final Object src, final long srcOffset, } if (cpy > dstEnd - 12) { - if (cpy > dstEnd - LASTLITERALS) return -1; // Error: last LASTLITERALS bytes must be literals + if (cpy > dstEnd - LASTLITERALS) + return -1; // Error: last LASTLITERALS bytes must be literals if (op < dstEnd - 8) { wildCopy(dst, match, dst, op, dstEnd - 8); match += (dstEnd - 8) - op; diff --git a/src/one/nio/mem/LongHashSet.java b/src/one/nio/mem/LongHashSet.java index fd17fe9..1e8e839 100755 --- a/src/one/nio/mem/LongHashSet.java +++ b/src/one/nio/mem/LongHashSet.java @@ -116,8 +116,18 @@ public final void setKeyAt(int index, long value) { unsafe.putOrderedLong(null, keys + (long) index * 8, value); } + /** + * This method is not atomic and must not be invoked concurrently with other modification methods (e.g., {@link LongHashSet#putKey} or {@link LongHashSet#removeKey}) + */ public void clear() { + int sizeBefore = size; unsafe.setMemory(keys, (long) capacity * 8, (byte) 0); + for (;;) { + int current = size; + if (unsafe.compareAndSwapInt(this, sizeOffset, current, Math.max(0, current - sizeBefore))) { + return; + } + } } protected void incrementSize() { diff --git a/src/one/nio/mem/MallocMT.java b/src/one/nio/mem/MallocMT.java index 1d76848..ab71373 100755 --- a/src/one/nio/mem/MallocMT.java +++ b/src/one/nio/mem/MallocMT.java @@ -69,6 +69,9 @@ public Malloc segment(int index) { /** * Deterministically get one of the segments by some {@code long} value + * + * @param n an index of the segment to return + * @return the {@link Malloc} instance for the specified segment */ public Malloc segmentFor(long n) { return segments[(int) n & (segments.length - 1)]; diff --git a/src/one/nio/mem/OffheapBitSet.java b/src/one/nio/mem/OffheapBitSet.java index 12d037c..f2bdf1c 100644 --- a/src/one/nio/mem/OffheapBitSet.java +++ b/src/one/nio/mem/OffheapBitSet.java @@ -47,7 +47,8 @@ public OffheapBitSet(long address, long sizeBytes) { } /** - * returns the number of 64 bit words it would take to hold numBits + * @param numBits a number of bits to hold + * @return the number of 64 bit words it would take to hold numBits */ public static long bits2words(long numBits) { return (((numBits - 1) >>> 6) + 1); @@ -69,7 +70,9 @@ public long capacity() { * Returns true or false for the specified bit index. The index should be * less than the capacity. * - * @throws IndexOutOfBoundsException + * @param index the bit index + * @return the value of the bit with the specified index + * @throws IndexOutOfBoundsException if the index is out of range */ public boolean get(long index) { return unsafeGet(checkBounds(index)); @@ -80,6 +83,9 @@ public boolean get(long index) { * bounds. This allows to make it few ticks faster in exchange to seg fault * possibility. Use when going out of capacity is ensured by other means * outside of this method + * + * @param index a bit index + * @return the value of the bit with the specified index */ public boolean unsafeGet(long index) { long word = index >> 6; // div 8 and round to long word @@ -91,6 +97,9 @@ public boolean unsafeGet(long index) { /** * Sets the bit at the specified index. The index should be less than the * capacity. + * + * @param index a bit index + * @throws IndexOutOfBoundsException if the index is out of range */ public void set(long index) { unsafeSet(checkBounds(index)); @@ -106,6 +115,8 @@ public void unsafeSet(long index) { /** * clears the bit. The index should be less than the capacity. + * + * @param index a bit index */ public void clear(long index) { unsafeClear(checkBounds(index)); diff --git a/src/one/nio/net/JavaServerSocket.java b/src/one/nio/net/JavaServerSocket.java index b6b32ce..f6704dc 100755 --- a/src/one/nio/net/JavaServerSocket.java +++ b/src/one/nio/net/JavaServerSocket.java @@ -18,15 +18,22 @@ import java.io.IOException; import java.io.RandomAccessFile; +import java.lang.reflect.Field; import java.net.InetAddress; import java.net.InetSocketAddress; import java.net.SocketException; +import java.net.SocketOption; import java.net.StandardSocketOptions; import java.nio.ByteBuffer; import java.nio.channels.SelectableChannel; import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; + +import one.nio.util.JavaInternals; final class JavaServerSocket extends SelectableJavaSocket { + private static final SocketOption SO_REUSEPORT_COMPAT = findReusePortOption(); + final ServerSocketChannel ch; JavaServerSocket() throws IOException { @@ -49,7 +56,8 @@ public final void close() { @Override public final JavaSocket accept() throws IOException { - return new JavaSocket(ch.accept()); + SocketChannel accepted = ch.accept(); + return accepted != null ? new JavaSocket(accepted) : null; } @Override @@ -208,7 +216,9 @@ public boolean getDeferAccept() { public final void setReuseAddr(boolean reuseAddr, boolean reusePort) { try { ch.setOption(StandardSocketOptions.SO_REUSEADDR, reuseAddr); - // todo: java 9+ SO_REUSEPORT + if (SO_REUSEPORT_COMPAT != null && ch.supportedOptions().contains(SO_REUSEPORT_COMPAT)) { + ch.setOption(SO_REUSEPORT_COMPAT, reusePort); + } } catch (IOException e) { // Ignore } @@ -225,7 +235,13 @@ public boolean getReuseAddr() { @Override public boolean getReusePort() { - return false; + try { + return SO_REUSEPORT_COMPAT != null && ch.supportedOptions().contains(SO_REUSEPORT_COMPAT) + ? ch.getOption(SO_REUSEPORT_COMPAT) + : false; + } catch (IOException e) { + return false; + } } @Override @@ -266,11 +282,12 @@ public int getTos() { @Override public final void setSendBuffer(int sendBuf) { - // Ignore + // See sun.nio.ch.ServerSocketChannelImpl.supportedOptions } @Override public int getSendBuffer() { + // See sun.nio.ch.ServerSocketChannelImpl.supportedOptions return 0; } @@ -322,4 +339,15 @@ public T getSslOption(SslOption option) { public SelectableChannel getSelectableChannel() { return ch; } + + private static SocketOption findReusePortOption() { + try { + Field reusePortField = JavaInternals.findField(StandardSocketOptions.class, "SO_REUSEPORT"); + if (reusePortField != null) { + return (SocketOption) reusePortField.get(null); + } + } catch (Throwable ignored) { + } + return null; + } } diff --git a/src/one/nio/net/JavaSslClientContext.java b/src/one/nio/net/JavaSslClientContext.java new file mode 100644 index 0000000..ede9167 --- /dev/null +++ b/src/one/nio/net/JavaSslClientContext.java @@ -0,0 +1,162 @@ +/* + * Copyright 2015 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.net; + +import java.io.IOException; +import java.security.NoSuchAlgorithmException; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSocket; + +// Should be used only for development purposes only on local machine +public class JavaSslClientContext extends SslContext { + private final SSLParameters parameters; + private final SSLContext sslContext; + + public JavaSslClientContext() throws NoSuchAlgorithmException, IOException { + sslContext = SSLContext.getDefault(); + parameters = sslContext.getDefaultSSLParameters(); + } + + public JavaSslClientContext(SSLContext sslContext) { + this.sslContext = sslContext; + this.parameters = sslContext.getDefaultSSLParameters(); + } + + @Override + public void setDebug(boolean debug) { + // Ignore + } + + @Override + public boolean getDebug() { + throw new UnsupportedOperationException(); + } + + @Override + public void setRdrand(boolean rdrand) throws SSLException { + // Ignore + } + + @Override + public void setProtocols(String protocols) throws SSLException { + parameters.setProtocols(protocols.split("\\+")); + } + + @Override + public void setCiphers(String ciphers) throws SSLException { + parameters.setCipherSuites(ciphers.split(":")); + } + + @Override + public void setCurve(String curve) throws SSLException { + // Ignore + } + + @Override + public void setCertificate(String certFile) throws SSLException { + // Ignore + } + + @Override + public void setPrivateKey(String privateKeyFile) throws SSLException { + // Ignore + } + + @Override + public void setPassphrase(byte[] passphrase) throws SSLException { + // Ignore + } + + @Override + public void setCA(String caFile) throws SSLException { + // Ignore + } + + @Override + public void setVerify(int verifyMode) throws SSLException { + // Ignore + } + + @Override + public void setTicketKeys(byte[] keys) throws SSLException { + // Ignore + } + + @Override + public void setSessionCache(String mode, int size) throws SSLException { + // Ignore + } + + @Override + public void setTimeout(long timeout) throws SSLException { + // Ignore + } + + @Override + public void setSessionId(byte[] sessionId) throws SSLException { + // Ignore + } + + @Override + public void setApplicationProtocols(String[] protocols) throws SSLException { + parameters.setApplicationProtocols(protocols); + } + + @Override + public void setOCSP(byte[] response) throws SSLException { + // Ignore + } + + @Override + public void setSNI(SslConfig[] sni) throws IOException { + // Ignore + } + + @Override + public void setMaxEarlyData(int size) throws SSLException { + // Ignore + } + + @Override + public void setKernelTlsEnabled(boolean kernelTlsEnabled) throws SSLException { + // Ignore + } + + @Override + public void setCompressionAlgorithms(String[] algorithms) throws SSLException { + // Ignore + } + + @Override + public void setAntiReplayEnabled(boolean antiReplayEnabled) throws SSLException { + // Ignore + } + + @Override + public void setKeylog(boolean keylog) { + // Ignore + } + + public SSLSocket createSocket() throws IOException { + SSLSocket socket = (SSLSocket) sslContext.getSocketFactory().createSocket(); + socket.setSSLParameters(parameters); + return socket; + } +} diff --git a/src/one/nio/net/JavaSslClientSocket.java b/src/one/nio/net/JavaSslClientSocket.java new file mode 100644 index 0000000..3f7e269 --- /dev/null +++ b/src/one/nio/net/JavaSslClientSocket.java @@ -0,0 +1,355 @@ +/* + * Copyright 2015 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.net; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.io.RandomAccessFile; +import java.io.UncheckedIOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.channels.WritableByteChannel; + +import javax.net.ssl.SSLSocket; + +import one.nio.mem.DirectMemory; + +public final class JavaSslClientSocket extends Socket { + private final SSLSocket socket; + private final JavaSslClientContext sslContext; + private volatile WritableByteChannel outCh; + private volatile ReadableByteChannel inCh; + private volatile OutputStream outputStream; + private volatile InputStream inputStream; + + public JavaSslClientSocket(JavaSslClientContext sslContext) { + try { + this.sslContext = sslContext; + this.socket = this.sslContext.createSocket(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public boolean isOpen() { + return !socket.isClosed(); + } + + @Override + public void close() { + try { + socket.close(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public Socket accept() throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void connect(InetAddress address, int port) throws IOException { + this.socket.connect(new InetSocketAddress(address, port)); + this.outputStream = socket.getOutputStream(); + this.outCh = Channels.newChannel(outputStream); + this.inputStream = socket.getInputStream(); + this.inCh = Channels.newChannel(inputStream); + } + + @Override + public void bind(InetAddress address, int port, int backlog) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int writeRaw(long buf, int count, int flags) throws IOException { + return outCh.write(DirectMemory.wrap(buf, count)); + } + + @Override + public int write(byte[] data, int offset, int count, int flags) throws IOException { + return outCh.write(ByteBuffer.wrap(data, offset, count)); + } + + @Override + public void writeFully(byte[] data, int offset, int count) throws IOException { + outputStream.write(data, offset, count); + } + + @Override + public int send(ByteBuffer src, int flags, InetAddress address, int port) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int readRaw(long buf, int count, int flags) throws IOException { + return inCh.read(DirectMemory.wrap(buf, count)); + } + + @Override + public int read(byte[] data, int offset, int count, int flags) throws IOException { + return inCh.read(ByteBuffer.wrap(data, offset, count)); + } + + @Override + public void readFully(byte[] data, int offset, int count) throws IOException { + while (count > 0) { + int bytes = inputStream.read(data, offset, count); + if (bytes < 0) { + throw new SocketClosedException(); + } + offset += bytes; + count -= bytes; + } + } + + @Override + public long sendFile(RandomAccessFile file, long offset, long count) throws IOException { + return file.getChannel().transferTo(offset, count, outCh); + } + + @Override + public InetSocketAddress recv(ByteBuffer dst, int flags) { + throw new UnsupportedOperationException(); + } + + @Override + public int sendMsg(Msg msg, int flags) { + throw new UnsupportedOperationException(); + } + + @Override + public int recvMsg(Msg msg, int flags) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public void setBlocking(boolean blocking) { + // Ignore + } + + @Override + public boolean isBlocking() { + return true; + } + + @Override + public void setTimeout(int timeout) { + try { + socket.setSoTimeout(timeout); + } catch (SocketException e) { + // Ignore + } + } + + @Override + public int getTimeout() { + try { + return socket.getSoTimeout(); + } catch (SocketException e) { + return 0; + } + } + + @Override + public void setKeepAlive(boolean keepAlive) { + try { + socket.setKeepAlive(keepAlive); + } catch (SocketException e) { + // Ignore + } + } + + @Override + public boolean getKeepAlive() { + try { + return socket.getKeepAlive(); + } catch (SocketException e) { + return false; + } + } + + @Override + public void setNoDelay(boolean noDelay) { + try { + socket.setTcpNoDelay(noDelay); + } catch (SocketException e) { + // Ignore + } + } + + @Override + public boolean getNoDelay() { + try { + return socket.getTcpNoDelay(); + } catch (SocketException e) { + return false; + } + } + + @Override + public void setTcpFastOpen(boolean tcpFastOpen) { + // Ignore + } + + @Override + public boolean getTcpFastOpen() { + return false; + } + + @Override + public void setDeferAccept(boolean deferAccept) { + // Ignore + } + + @Override + public boolean getDeferAccept() { + return false; + } + + @Override + public void setReuseAddr(boolean reuseAddr, boolean reusePort) { + try { + socket.setReuseAddress(reuseAddr); + } catch (SocketException e) { + // Ignore + } + } + + @Override + public boolean getReuseAddr() { + try { + return socket.getReuseAddress(); + } catch (SocketException e) { + return false; + } + } + + @Override + public boolean getReusePort() { + return false; + } + + @Override + public void setRecvBuffer(int recvBuf) { + try { + socket.setReceiveBufferSize(recvBuf); + } catch (SocketException e) { + // Ignore + } + } + + @Override + public int getRecvBuffer() { + try { + return socket.getReceiveBufferSize(); + } catch (SocketException e) { + return 0; + } + } + + @Override + public void setSendBuffer(int sendBuf) { + try { + socket.setSendBufferSize(sendBuf); + } catch (SocketException e) { + // Ignore + } + } + + @Override + public int getSendBuffer() { + try { + return socket.getSendBufferSize(); + } catch (SocketException e) { + return 0; + } + } + + @Override + public void setTos(int tos) { + // Ignore + } + + @Override + public int getTos() { + return 0; + } + + @Override + public byte[] getOption(int level, int option) { + return new byte[0]; + } + + @Override + public boolean setOption(int level, int option, byte[] value) { + return false; + } + + @Override + public InetSocketAddress getLocalAddress() { + return new InetSocketAddress(socket.getLocalAddress(), socket.getPort()); + } + + @Override + public InetSocketAddress getRemoteAddress() { + return (InetSocketAddress) socket.getRemoteSocketAddress(); + } + + @Override + public Socket sslWrap(SslContext context) { + return this; + } + + @Override + public int read(ByteBuffer dst) throws IOException { + return inCh.read(dst); + } + + @Override + public int write(ByteBuffer src) throws IOException { + return outCh.write(src); + } + + @Override + public Socket sslUnwrap() { + return this; + } + + @Override + public SslContext getSslContext() { + return sslContext; + } + + @Override + public T getSslOption(SslOption option) { + return null; + } + + @Override + public void listen(int backlog) throws IOException { + throw new UnsupportedOperationException(); + } +} diff --git a/src/one/nio/net/KeylogHolder.java b/src/one/nio/net/KeylogHolder.java new file mode 100644 index 0000000..1d6a2d6 --- /dev/null +++ b/src/one/nio/net/KeylogHolder.java @@ -0,0 +1,45 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.net; + +import java.net.InetSocketAddress; +import java.util.Objects; +import java.util.function.BiConsumer; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KeylogHolder { + private static final Logger log = LoggerFactory.getLogger(KeylogHolder.class); + + public static final BiConsumer DEFAULT_CONSUMER = (keyLine, addr) -> log.info(keyLine); + public static final BiConsumer NOP_CONSUMER = (s, bytes) -> {}; + + private static volatile BiConsumer CONSUMER = DEFAULT_CONSUMER; + + public static void setConsumer(BiConsumer consumer) { + CONSUMER = Objects.requireNonNull(consumer); + } + + public static void log(String keyLine, InetSocketAddress addr) { + try { + CONSUMER.accept(keyLine, addr); + } catch (Exception e) { + // Ignore + } + } +} \ No newline at end of file diff --git a/src/one/nio/net/NativeSocket.java b/src/one/nio/net/NativeSocket.java index 80a6573..e358c9a 100755 --- a/src/one/nio/net/NativeSocket.java +++ b/src/one/nio/net/NativeSocket.java @@ -50,12 +50,14 @@ public final boolean isOpen() { @Override public NativeSocket accept() throws IOException { - return new NativeSocket(accept0(false)); + int fd = accept0(false); + return fd >= 0 ? new NativeSocket(fd) : null; } @Override public NativeSocket acceptNonBlocking() throws IOException { - return new NativeSocket(accept0(true)); + int fd = accept0(true); + return fd >= 0 ? new NativeSocket(fd) : null; } @Override @@ -298,6 +300,18 @@ public int recvMsg(Msg msg, int flags) throws IOException { @Override public final native int getTos(); + @Override + public final native void setNotsentLowat(int lowat); + + @Override + public final native int getNotsentLowat(); + + @Override + public final native void setThinLinearTimeouts(boolean thinLto); + + @Override + public final native boolean getThinLinearTimeouts(); + @Override public native byte[] getOption(int level, int option); @@ -315,7 +329,7 @@ static Object toNativeAddr(String host, int port) throws UnknownHostException { final native void connect0(Object address, int port) throws IOException; final native void bind0(Object address, int port) throws IOException; final native int accept0(boolean nonblock) throws IOException; - final native long sendFile0(int sourceFD, long offset, long count) throws IOException; + native long sendFile0(int sourceFD, long offset, long count) throws IOException; final native int sendTo0(byte[] data, int offset, int size, int flags, Object address, int port) throws IOException; final native int sendTo1(long buf, int size, int flags, Object address, int port) throws IOException; final native int recvFrom0(byte[] data, int offset, int maxSize, int flags, AddressHolder holder) throws IOException; diff --git a/src/one/nio/net/NativeSslContext.java b/src/one/nio/net/NativeSslContext.java index 74d7ae1..a8a15f3 100755 --- a/src/one/nio/net/NativeSslContext.java +++ b/src/one/nio/net/NativeSslContext.java @@ -16,23 +16,59 @@ package one.nio.net; -import one.nio.mgt.Management; -import one.nio.util.ByteArrayBuilder; -import one.nio.util.Utf8; - -import javax.net.ssl.SSLException; import java.io.IOException; import java.util.ServiceConfigurationError; import java.util.StringTokenizer; import java.util.concurrent.atomic.AtomicInteger; +import javax.net.ssl.SSLException; + +import one.nio.mgt.Management; +import one.nio.util.ByteArrayBuilder; +import one.nio.util.Utf8; + class NativeSslContext extends SslContext { private static final AtomicInteger counter = new AtomicInteger(); + + private static class CompressionAlgorithms { + // Possible compression values from RFC8879 (Refer to openssl/tls1.h) + public static int ZLIB = 1; + public static int BROTLI = 2; + public static int ZSTD = 3; + } + + private static class SslOption { + public static long ENABLE_KTLS = 1L << 3; + public static long NO_COMPRESSION = 1L << 17; + public static long NO_SSLv2 = 0; // as of OpenSSL 1.0.2g the SSL_OP_NO_SSLv2 option is set by default. + public static long NO_ANTI_REPLAY = 1L << 24; + public static long NO_SSLv3 = 1L << 25; + public static long NO_TLSv1 = 1L << 26; + public static long NO_TLSv1_2 = 1L << 27; + public static long NO_TLSv1_1 = 1L << 28; + public static long NO_TLSv1_3 = 1L << 29; + public static long NO_TX_CERT_COMPRESSION = 1L << 32; + } + + private static class CacheMode { + public static int NONE = 0; + public static int INTERNAL = 1; + public static int EXTERNAL = 2; + } + + private static final long ALL_DISABLED = SslOption.NO_COMPRESSION + | SslOption.NO_SSLv2 + | SslOption.NO_SSLv3 + | SslOption.NO_TLSv1 + | SslOption.NO_TLSv1_1 + | SslOption.NO_TLSv1_2 + | SslOption.NO_TLSv1_3; final int id; long ctx; NativeSslContext[] subcontexts; + NativeSslContext() throws SSLException { this.id = counter.incrementAndGet(); this.ctx = ctxNew(); @@ -93,32 +129,68 @@ public void setProtocols(String protocols) { String protocol = st.nextToken(); switch (protocol) { case "compression": - enabled |= 0x00020000; + enabled |= SslOption.NO_COMPRESSION; break; case "sslv2": - enabled |= 0x01000000; + enabled |= SslOption.NO_SSLv2; break; case "sslv3": - enabled |= 0x02000000; + enabled |= SslOption.NO_SSLv3; break; case "tlsv1": - enabled |= 0x04000000; + enabled |= SslOption.NO_TLSv1; break; case "tlsv1.1": - enabled |= 0x10000000; + enabled |= SslOption.NO_TLSv1_1; break; case "tlsv1.2": - enabled |= 0x08000000; + enabled |= SslOption.NO_TLSv1_2; break; case "tlsv1.3": - enabled |= 0x20000000; + enabled |= SslOption.NO_TLSv1_3; break; } } - int all = 0x00020000 + 0x01000000 + 0x02000000 + 0x04000000 + 0x08000000 + 0x10000000 + 0x20000000; clearOptions(enabled); - setOptions(all - enabled); + setOptions(ALL_DISABLED & ~enabled); + } + + @Override + public void setKernelTlsEnabled(boolean kernelTlsEnabled) throws SSLException { + if (kernelTlsEnabled) { + setOptions(SslOption.ENABLE_KTLS); + } else { + clearOptions(SslOption.ENABLE_KTLS); + } + } + + @Override + public void setAntiReplayEnabled(boolean antiReplayEnabled) throws SSLException { + if (antiReplayEnabled) { + clearOptions(SslOption.NO_ANTI_REPLAY); + } else { + setOptions(SslOption.NO_ANTI_REPLAY); + } + } + + @Override + public void setSessionCache(String mode, int size) throws SSLException { + switch (mode) { + case "none": + setCacheMode(CacheMode.NONE); + break; + case "internal": + setCacheMode(CacheMode.INTERNAL); + setInternalCacheSize(size); + break; + case "external": + setCacheMode(CacheMode.EXTERNAL); + SslSessionCache.Singleton.setCapacity(size); + break; + default: + throw new SSLException("Unsupported session cache mode: " + mode); + } } @Override @@ -127,6 +199,15 @@ public void setProtocols(String protocols) { @Override public native void setCiphers(String ciphers) throws SSLException; + /** + * Sets the curve used for ECDH temporary keys used during key exchange. + * Use openssl ecparam -list_curves to get list of supported curves. + * @param curve short name of the curve, if null - all curves built into the OpenSSL library will be allowed + * @throws SSLException + */ + @Override + public native void setCurve(String curve) throws SSLException; + @Override public native void setCertificate(String certFile) throws SSLException; @@ -145,9 +226,6 @@ public void setProtocols(String protocols) { @Override public native void setTicketKeys(byte[] keys) throws SSLException; - @Override - public native void setCacheSize(int size) throws SSLException; - @Override public native void setTimeout(long timeout) throws SSLException; @@ -157,6 +235,9 @@ public void setProtocols(String protocols) { @Override public native void setOCSP(byte[] response) throws SSLException; + @Override + public native void setMaxEarlyData(int size) throws SSLException; + @Override public void setApplicationProtocols(String[] protocols) throws SSLException { ByteArrayBuilder builder = new ByteArrayBuilder(); @@ -194,14 +275,50 @@ public void setSNI(SslConfig[] sni) throws IOException { setSNI0(names.toBytes(), contexts); } + @Override + public void setCompressionAlgorithms(String[] compressionAlgorithms) throws SSLException { + if (compressionAlgorithms == null || compressionAlgorithms.length == 0) { + setOptions(SslOption.NO_TX_CERT_COMPRESSION); + return; + } + + int[] algorithms = new int[compressionAlgorithms.length]; + for (int i = 0; i < compressionAlgorithms.length; i++) { + String algorithm = compressionAlgorithms[i]; + switch (algorithm) { + case "zlib": + algorithms[i] = CompressionAlgorithms.ZLIB; + break; + case "brotli": + algorithms[i] = CompressionAlgorithms.BROTLI; + break; + case "zstd": + algorithms[i] = CompressionAlgorithms.ZSTD; + break; + default: + throw new SSLException("Unsupported cert compression algorithm: " + algorithm); + } + } + clearOptions(SslOption.NO_TX_CERT_COMPRESSION); + setCompressionAlgorithms0(algorithms); + } + + private native void setCompressionAlgorithms0(int[] algorithms) throws SSLException; + private native void setSNI0(byte[] names, long[] contexts) throws SSLException; - private native void setOptions(int options); - private native void clearOptions(int options); + @Override + public native void setKeylog(boolean keylog); + + private native void setOptions(long options); + private native void clearOptions(long options); private native long getSessionCounter(int key); private native long[] getSessionCounters(int keysBitmap); + private native void setInternalCacheSize(int size) throws SSLException; + private native void setCacheMode(int mode) throws SSLException; + private static native void init(); private static native long ctxNew() throws SSLException; private static native void ctxFree(long ctx); diff --git a/src/one/nio/net/NativeSslSocket.java b/src/one/nio/net/NativeSslSocket.java index c1da41e..c460e01 100755 --- a/src/one/nio/net/NativeSslSocket.java +++ b/src/one/nio/net/NativeSslSocket.java @@ -23,6 +23,9 @@ class NativeSslSocket extends NativeSocket { NativeSslContext context; long ssl; + private volatile boolean isEarlyDataAccepted = false; + private volatile boolean isHandshakeDone = false; + NativeSslSocket(int fd, NativeSslContext context, boolean serverMode) throws IOException { super(fd); context.refresh(); @@ -41,12 +44,14 @@ public synchronized void close() { @Override public NativeSocket accept() throws IOException { - return new NativeSslSocket(accept0(false), context, true); + int fd = accept0(false); + return fd >= 0 ? new NativeSslSocket(fd, context, true) : null; } @Override public NativeSocket acceptNonBlocking() throws IOException { - return new NativeSslSocket(accept0(true), context, true); + int fd = accept0(true); + return fd >= 0 ? new NativeSslSocket(fd, context, true) : null; } @Override @@ -63,31 +68,29 @@ public SslContext getSslContext() { @SuppressWarnings("unchecked") public Object getSslOption(SslOption option) { switch (option.id) { - case 1: + case SslOption.PEER_CERTIFICATE_ID: return sslPeerCertificate(); - case 2: + case SslOption.PEER_CERTIFICATE_CHAIN_ID: return sslPeerCertificateChain(); - case 3: + case SslOption.PEER_SUBJECT_ID: return sslCertName(0); - case 4: + case SslOption.PEER_ISSUER_ID: return sslCertName(1); - case 5: + case SslOption.VERIFY_RESULT_ID: return sslVerifyResult(); - case 6: + case SslOption.SESSION_REUSED_ID: return sslSessionReused(); - case 7: + case SslOption.SESSION_TICKET_ID: return sslSessionTicket(); - case 8: + case SslOption.CURRENT_CIPHER_ID: return sslCurrentCipher(); + case SslOption.SESSION_EARLYDATA_ACCEPTED_ID: + return sslSessionEarlyDataAccepted(); + case SslOption.SESSION_HANDSHAKE_DONE_ID: + return sslHandshakeDone(); } return null; } - - @Override - public long sendFile(RandomAccessFile file, long offset, long count) throws IOException { - throw new IOException("Cannot use sendFile with SSL"); - } - @Override public synchronized native void handshake() throws IOException; @@ -109,6 +112,19 @@ public long sendFile(RandomAccessFile file, long offset, long count) throws IOEx @Override public synchronized native void readFully(byte[] data, int offset, int count) throws IOException; + @Override + synchronized native long sendFile0(int sourceFD, long offset, long count) throws IOException; + + private boolean sslSessionEarlyDataAccepted() { + // the value is updated by native code during IO operations + return isEarlyDataAccepted; + } + + private boolean sslHandshakeDone() { + // the value is updated by native code during IO operations + return isHandshakeDone; + } + private synchronized native byte[] sslPeerCertificate(); private synchronized native Object[] sslPeerCertificateChain(); private synchronized native String sslCertName(int which); diff --git a/src/one/nio/net/Socket.java b/src/one/nio/net/Socket.java index df5b6ca..30621cf 100755 --- a/src/one/nio/net/Socket.java +++ b/src/one/nio/net/Socket.java @@ -141,6 +141,10 @@ public abstract class Socket implements ByteChannel { public abstract int getSendBuffer(); public abstract void setTos(int tos); public abstract int getTos(); + public void setNotsentLowat(int lowat) {} + public int getNotsentLowat() {return 0;} + public void setThinLinearTimeouts(boolean thinLto) {} + public boolean getThinLinearTimeouts(){return false;} public abstract byte[] getOption(int level, int option); public abstract boolean setOption(int level, int option, byte[] value); public abstract InetSocketAddress getLocalAddress(); @@ -152,7 +156,9 @@ public abstract class Socket implements ByteChannel { public Socket acceptNonBlocking() throws IOException { Socket s = accept(); - s.setBlocking(false); + if (s != null) { + s.setBlocking(false); + } return s; } @@ -180,8 +186,23 @@ public int read(byte[] data, int offset, int count) throws IOException { return read(data, offset, count, 0); } + @Deprecated public static Socket create() throws IOException { - return NativeLibrary.IS_SUPPORTED ? new NativeSocket(0, SOCK_STREAM) : new JavaSocket(); + return createClientSocket(null); + } + + public static Socket createClientSocket() throws IOException { + return createClientSocket(null); + } + + public static Socket createClientSocket(SslContext sslContext) throws IOException { + Socket socket; + if (NativeLibrary.IS_SUPPORTED) { + socket = new NativeSocket(0, SOCK_STREAM); + } else { + socket = sslContext == null ? new JavaSocket() : new JavaSslClientSocket((JavaSslClientContext) sslContext); + } + return socket; } public static Socket createServerSocket() throws IOException { @@ -201,7 +222,7 @@ public static Socket createUnixSocket(int type) throws IOException { } public static Socket connectInet(InetAddress address, int port) throws IOException { - Socket sock = create(); + Socket sock = createClientSocket(); sock.connect(address, port); return sock; } diff --git a/src/one/nio/net/SslClientContextFactory.java b/src/one/nio/net/SslClientContextFactory.java new file mode 100644 index 0000000..837a86c --- /dev/null +++ b/src/one/nio/net/SslClientContextFactory.java @@ -0,0 +1,37 @@ +/* + * Copyright 2015 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.net; + +import java.io.IOException; +import java.security.NoSuchAlgorithmException; + +import one.nio.os.NativeLibrary; + +public class SslClientContextFactory { + + public static SslContext create() { + if (NativeLibrary.IS_SUPPORTED) { + return SslContext.getDefault(); + } else { + try { + return new JavaSslClientContext(); + } catch (NoSuchAlgorithmException | IOException e) { + throw new RuntimeException(e); + } + } + } +} diff --git a/src/one/nio/net/SslConfig.java b/src/one/nio/net/SslConfig.java index 4226319..7056621 100644 --- a/src/one/nio/net/SslConfig.java +++ b/src/one/nio/net/SslConfig.java @@ -25,6 +25,7 @@ public class SslConfig { // Conservative ciphersuite according to https://wiki.mozilla.org/Security/Server_Side_TLS static final String DEFAULT_CIPHERS = "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:ECDHE-ECDSA-AES256-GCM-SHA384:ECDHE-RSA-AES256-GCM-SHA384:ECDHE-ECDSA-CHACHA20-POLY1305:ECDHE-RSA-CHACHA20-POLY1305:DHE-RSA-AES128-GCM-SHA256:DHE-RSA-AES256-GCM-SHA384:DHE-RSA-CHACHA20-POLY1305:ECDHE-ECDSA-AES128-SHA256:ECDHE-RSA-AES128-SHA256:ECDHE-ECDSA-AES128-SHA:ECDHE-RSA-AES128-SHA:ECDHE-ECDSA-AES256-SHA384:ECDHE-RSA-AES256-SHA384:ECDHE-ECDSA-AES256-SHA:ECDHE-RSA-AES256-SHA:DHE-RSA-AES128-SHA256:DHE-RSA-AES256-SHA256:AES128-GCM-SHA256:AES256-GCM-SHA384:AES128-SHA256:AES256-SHA256:AES128-SHA:AES256-SHA:DES-CBC3-SHA"; + static final String DEFAULT_CACHE_MODE = "internal"; static final int DEFAULT_CACHE_SIZE = 262144; static final long DEFAULT_TIMEOUT_SEC = 300; static final long DEFAULT_REFRESH_INTERVAL = 300_000; @@ -33,6 +34,7 @@ public class SslConfig { public boolean rdrand; public String protocols; public String ciphers; + public String curve; public String[] certFile; public String[] privateKeyFile; public String passphrase; @@ -40,7 +42,8 @@ public class SslConfig { public String ticketKeyFile; public String ticketDir; public int verifyMode; - public int cacheSize; + public String cacheMode = DEFAULT_CACHE_MODE; // "none", "internal", "external" + public int cacheSize = DEFAULT_CACHE_SIZE; @Converter(method = "longTime") public long timeout; @Converter(method = "longTime") @@ -48,6 +51,11 @@ public class SslConfig { public String sessionId; public String[] applicationProtocols; public String ocspFile; + public String[] compressionAlgorithms; + public int maxEarlyDataSize = 0; // zero value disables 0-RTT feature + public boolean kernelTlsEnabled = false; + public boolean antiReplayEnabled = true; // flag is relevant only if early-data used + public boolean keylog; // The following fields should not be updated by SslContext.inherit() String hostName; @@ -57,6 +65,7 @@ public static SslConfig from(Properties props) { SslConfig config = new SslConfig(); config.protocols = props.getProperty("one.nio.ssl.protocols"); config.ciphers = props.getProperty("one.nio.ssl.ciphers"); + config.curve = props.getProperty("one.nio.ssl.curve"); config.certFile = toArray(props.getProperty("one.nio.ssl.certFile")); config.privateKeyFile = toArray(props.getProperty("one.nio.ssl.privateKeyFile")); config.passphrase = props.getProperty("one.nio.ssl.passphrase"); diff --git a/src/one/nio/net/SslContext.java b/src/one/nio/net/SslContext.java index 775bf26..f50e6c7 100755 --- a/src/one/nio/net/SslContext.java +++ b/src/one/nio/net/SslContext.java @@ -16,14 +16,6 @@ package one.nio.net; -import one.nio.os.NativeLibrary; -import one.nio.util.ByteArrayBuilder; -import one.nio.util.Utf8; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import javax.net.ssl.SSLException; import java.io.File; import java.io.IOException; import java.lang.reflect.Field; @@ -34,6 +26,15 @@ import java.util.Date; import java.util.concurrent.atomic.AtomicLong; +import javax.net.ssl.SSLException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import one.nio.os.NativeLibrary; +import one.nio.util.ByteArrayBuilder; +import one.nio.util.Utf8; + public abstract class SslContext { private static final Logger log = LoggerFactory.getLogger(SslContext.class); @@ -84,6 +85,9 @@ public synchronized SslContext configure(SslConfig config) throws IOException { setCiphers(config.ciphers != null ? config.ciphers : SslConfig.DEFAULT_CIPHERS); + // with null the curve will be auto-selected by openssl + setCurve(config.curve); + if (changed(config.passphrase, currentConfig.passphrase)) { setPassphrase(Utf8.toBytes(getPassphrase(config.passphrase))); } @@ -116,9 +120,19 @@ public synchronized SslContext configure(SslConfig config) throws IOException { } setVerify(config.verifyMode); - setCacheSize(config.cacheSize != 0 ? config.cacheSize : SslConfig.DEFAULT_CACHE_SIZE); + + setSessionCache(config.cacheMode, config.cacheSize != 0 ? config.cacheSize : SslConfig.DEFAULT_CACHE_SIZE); + setTimeout(config.timeout != 0 ? config.timeout / 1000 : SslConfig.DEFAULT_TIMEOUT_SEC); + setMaxEarlyData(config.maxEarlyDataSize); + + if (config.maxEarlyDataSize > 0) { + setAntiReplayEnabled(config.antiReplayEnabled); + } + + setKernelTlsEnabled(config.kernelTlsEnabled); + if (changed(config.sessionId, currentConfig.sessionId)) { setSessionId(Utf8.toBytes(config.sessionId)); } @@ -138,6 +152,14 @@ public synchronized SslContext configure(SslConfig config) throws IOException { setSNI(config.sni); } + if (config.compressionAlgorithms != currentConfig.compressionAlgorithms) { + setCompressionAlgorithms(config.compressionAlgorithms); + } + + if (config.keylog != currentConfig.keylog) { + setKeylog(config.keylog); + } + this.currentConfig = config; return this; } @@ -285,16 +307,22 @@ void refresh() { public abstract void setRdrand(boolean rdrand) throws SSLException; public abstract void setProtocols(String protocols) throws SSLException; public abstract void setCiphers(String ciphers) throws SSLException; + public abstract void setCurve(String curve) throws SSLException; public abstract void setCertificate(String certFile) throws SSLException; public abstract void setPrivateKey(String privateKeyFile) throws SSLException; public abstract void setPassphrase(byte[] passphrase) throws SSLException; public abstract void setCA(String caFile) throws SSLException; public abstract void setVerify(int verifyMode) throws SSLException; public abstract void setTicketKeys(byte[] keys) throws SSLException; - public abstract void setCacheSize(int size) throws SSLException; + public abstract void setSessionCache(String mode, int size) throws SSLException; public abstract void setTimeout(long timeout) throws SSLException; public abstract void setSessionId(byte[] sessionId) throws SSLException; public abstract void setApplicationProtocols(String[] protocols) throws SSLException; public abstract void setOCSP(byte[] response) throws SSLException; public abstract void setSNI(SslConfig[] sni) throws IOException; + public abstract void setMaxEarlyData(int size) throws SSLException; + public abstract void setKernelTlsEnabled(boolean kernelTlsEnabled) throws SSLException; + public abstract void setCompressionAlgorithms(String[] algorithms) throws SSLException; + public abstract void setAntiReplayEnabled(boolean antiReplayEnabled) throws SSLException; + public abstract void setKeylog(boolean keylog); } diff --git a/src/one/nio/net/SslOption.java b/src/one/nio/net/SslOption.java index e374637..7ad8c0e 100644 --- a/src/one/nio/net/SslOption.java +++ b/src/one/nio/net/SslOption.java @@ -17,16 +17,29 @@ package one.nio.net; public class SslOption { - public static final SslOption PEER_CERTIFICATE = new SslOption<>(1, byte[].class); - public static final SslOption PEER_CERTIFICATE_CHAIN = new SslOption<>(2, Object[].class); - public static final SslOption PEER_SUBJECT = new SslOption<>(3, String.class); - public static final SslOption PEER_ISSUER = new SslOption<>(4, String.class); - public static final SslOption VERIFY_RESULT = new SslOption<>(5, String.class); - - public static final SslOption SESSION_REUSED = new SslOption<>(6, Boolean.class); - public static final SslOption SESSION_TICKET = new SslOption<>(7, Integer.class); - - public static final SslOption CURRENT_CIPHER = new SslOption<>(8, String.class); + static final int PEER_CERTIFICATE_ID = 1; + static final int PEER_CERTIFICATE_CHAIN_ID = 2; + static final int PEER_SUBJECT_ID = 3; + static final int PEER_ISSUER_ID = 4; + static final int VERIFY_RESULT_ID = 5; + static final int SESSION_REUSED_ID = 6; + static final int SESSION_TICKET_ID = 7; + static final int CURRENT_CIPHER_ID = 8; + static final int SESSION_EARLYDATA_ACCEPTED_ID = 9; + static final int SESSION_HANDSHAKE_DONE_ID = 10; + + public static final SslOption PEER_CERTIFICATE = new SslOption<>(PEER_CERTIFICATE_ID, byte[].class); + public static final SslOption PEER_CERTIFICATE_CHAIN = new SslOption<>(PEER_CERTIFICATE_CHAIN_ID, Object[].class); + public static final SslOption PEER_SUBJECT = new SslOption<>(PEER_SUBJECT_ID, String.class); + public static final SslOption PEER_ISSUER = new SslOption<>(PEER_ISSUER_ID, String.class); + public static final SslOption VERIFY_RESULT = new SslOption<>(VERIFY_RESULT_ID, String.class); + + public static final SslOption SESSION_REUSED = new SslOption<>(SESSION_REUSED_ID, Boolean.class); + public static final SslOption SESSION_TICKET = new SslOption<>(SESSION_TICKET_ID, Integer.class); + + public static final SslOption CURRENT_CIPHER = new SslOption<>(CURRENT_CIPHER_ID, String.class); + public static final SslOption SESSION_EARLYDATA_ACCEPTED = new SslOption<>(SESSION_EARLYDATA_ACCEPTED_ID, Boolean.class); + public static final SslOption SESSION_HANDSHAKE_DONE = new SslOption<>(SESSION_HANDSHAKE_DONE_ID, Boolean.class); final int id; final Class type; diff --git a/src/one/nio/net/SslSessionCache.java b/src/one/nio/net/SslSessionCache.java new file mode 100644 index 0000000..079dd9e --- /dev/null +++ b/src/one/nio/net/SslSessionCache.java @@ -0,0 +1,116 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.net; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import sun.security.util.Cache; + +import java.util.Objects; + +public interface SslSessionCache { + Logger log = LoggerFactory.getLogger(SslSessionCache.class); + + void resize(int maxSize); + + void addSession(byte[] sessionId, byte[] session); + + byte[] getSession(byte[] sessionId); + + void removeSession(byte[] sessionId); + + + class Singleton { + private static volatile SslSessionCache INSTANCE; + private static Singleton.Factory FACTORY = Default::new; + private static int CAPACITY = Default.CAPACITY; + + public interface Factory { + SslSessionCache create(int size); + } + + public synchronized static void setFactory(Factory factory) { + if (INSTANCE != null) { + throw new IllegalStateException("Unable to change factory after lazy instantiation is done"); + } + Singleton.FACTORY = Objects.requireNonNull(factory); + } + + public synchronized static void setCapacity(int capacity) { + if (capacity < 0) { + throw new IllegalArgumentException("Capacity must be positive"); + } + if (INSTANCE != null && CAPACITY != capacity) { + INSTANCE.resize(capacity); + } + Singleton.CAPACITY = capacity; + } + + public static SslSessionCache getInstance() { + if (INSTANCE == null) { + synchronized (Singleton.class) { + if (INSTANCE == null) { + INSTANCE = FACTORY.create(CAPACITY); + } + } + } + return INSTANCE; + } + + private synchronized static void clearInstance() { + Singleton.INSTANCE = null; + } + } + + class Default implements SslSessionCache { + private final Cache cache; + static int CAPACITY = 1024; + + private static Cache.EqualByteArray toKey(byte[] bytes) { + return new Cache.EqualByteArray(bytes); + } + + public Default(int maxSize) { + this.cache = Cache.newSoftMemoryCache(maxSize); + } + + public Default() { + this(Default.CAPACITY); + } + + @Override + public void resize(int maxSize) { + cache.setCapacity(maxSize); + } + + @Override + public void addSession(byte[] sessionId, byte[] session) { + cache.put(toKey(sessionId), session); + } + + @Override + public byte[] getSession(byte[] sessionId) { + return cache.get(toKey(sessionId)); + } + + @Override + public void removeSession(byte[] sessionId) { + cache.remove(toKey(sessionId)); + } + } +} diff --git a/src/one/nio/net/native/socket.c b/src/one/nio/net/native/socket.c index 4ca28e9..6f74fd6 100755 --- a/src/one/nio/net/native/socket.c +++ b/src/one/nio/net/native/socket.c @@ -56,6 +56,9 @@ static pthread_t* fd_table; #define TCP_FASTOPEN 23 #endif +#ifndef TCP_NOTSENT_LOWAT +#define TCP_NOTSENT_LOWAT 25 +#endif static socklen_t sockaddr_from_java(JNIEnv* env, jobject address, jint port, struct sockaddr_storage* sa) { // AF_UNIX @@ -96,7 +99,7 @@ static socklen_t sockaddr_from_java(JNIEnv* env, jobject address, jint port, str } } -static jobject sockaddr_to_java(JNIEnv* env, struct sockaddr_storage* sa, socklen_t len) { +jobject sockaddr_to_java(JNIEnv* env, struct sockaddr_storage* sa, socklen_t len) { if (sa->ss_family == AF_INET) { struct sockaddr_in* sin = (struct sockaddr_in*)sa; int ip = ntohl(sin->sin_addr.s_addr); @@ -252,6 +255,9 @@ Java_one_nio_net_NativeSocket_accept0(JNIEnv* env, jobject self, jboolean nonblo end_blocking_call(fd_lock); if (result == -1) { + if (errno == EWOULDBLOCK) { + return -1; + } throw_io_exception(env); } return result; @@ -882,6 +888,31 @@ Java_one_nio_net_NativeSocket_getTos(JNIEnv* env, jobject self) { return get_int_socket_opt(fd, IPPROTO_IP, IP_TOS); } +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSocket_setNotsentLowat(JNIEnv* env, jobject self, jint lowat) { + int fd = (*env)->GetIntField(env, self, f_fd); + setsockopt(fd, SOL_TCP, TCP_NOTSENT_LOWAT, &lowat, sizeof(lowat)); +} + +JNIEXPORT jint JNICALL +Java_one_nio_net_NativeSocket_getNotsentLowat(JNIEnv* env, jobject self) { + int fd = (*env)->GetIntField(env, self, f_fd); + return get_int_socket_opt(fd, SOL_TCP, TCP_NOTSENT_LOWAT); +} + +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSocket_setThinLinearTimeouts(JNIEnv* env, jobject self, jboolean thinLto) { + int fd = (*env)->GetIntField(env, self, f_fd); + int value = (int) thinLto; + setsockopt(fd, SOL_TCP, TCP_THIN_LINEAR_TIMEOUTS, &value, sizeof(value)); +} + +JNIEXPORT jboolean JNICALL +Java_one_nio_net_NativeSocket_getThinLinearTimeouts(JNIEnv* env, jobject self) { + int fd = (*env)->GetIntField(env, self, f_fd); + return get_bool_socket_opt(fd, SOL_TCP, TCP_THIN_LINEAR_TIMEOUTS); +} + JNIEXPORT jbyteArray JNICALL Java_one_nio_net_NativeSocket_getOption(JNIEnv* env, jobject self, jint level, jint option) { int fd = (*env)->GetIntField(env, self, f_fd); diff --git a/src/one/nio/net/native/ssl.c b/src/one/nio/net/native/ssl.c index 4bba9bf..434d5d0 100755 --- a/src/one/nio/net/native/ssl.c +++ b/src/one/nio/net/native/ssl.c @@ -25,21 +25,20 @@ #include #include #include +#include #include #include #include #include #include #include +#include #include #include #include -#include -#include #include #include #include "jni_util.h" -#include "sslcompat.h" #define MAX_COUNTERS 32 @@ -49,7 +48,15 @@ enum SSLFlags { SF_HANDSHAKED = 2, SF_HAS_TICKET = 4, SF_HAS_OLD_TICKET = 8, - SF_NEW_TICKET = 12 + SF_NEW_TICKET = SF_HAS_TICKET | SF_HAS_OLD_TICKET, + SF_EARLY_DATA_ENABLED = 16, + SF_EARLY_DATA_FINISHED = 32, +}; + +enum SSLCacheMode { + CACHE_MODE_NONE = 0, + CACHE_MODE_INTERNAL = 1, + CACHE_MODE_EXTERNAL = 2, }; typedef struct { @@ -90,8 +97,25 @@ typedef struct { static jfieldID f_ctx; static jfieldID f_ssl; +static jfieldID f_isEarlyDataAccepted; +static jfieldID f_isHandshakeDone; static int preclosed_socket; +static jfieldID f_sslSessionCache; + +static JavaVM* global_vm; +static jclass c_KeylogHolder; +static jmethodID m_log; + +static jclass c_SslSessionCacheSingleton; +static jclass c_SslSessionCache; +static jmethodID m_getInstance; +static jmethodID m_clearInstance; + +static jmethodID m_addSession; +static jmethodID m_getSession; +static jmethodID m_removeSession; + // openssl dhparam -C 2048 static unsigned char dh2048_p[] = { 0xF5, 0x03, 0x6F, 0xFC, 0xA7, 0xFD, 0xC7, 0xD2, 0x69, 0xD8, 0xED, 0x73, 0x7D, 0x4D, 0x2A, 0x05, @@ -115,6 +139,7 @@ static unsigned char dh2048_g[] = { 0x02 }; extern void throw_socket_closed_cached(JNIEnv* env); +extern jobject sockaddr_to_java(JNIEnv* env, struct sockaddr_storage* sa, socklen_t len); static void throw_ssl_exception(JNIEnv* env) { char buf[256]; @@ -134,7 +159,9 @@ static int check_ssl_error(JNIEnv* env, SSL* ssl, int ret) { throw_socket_closed_cached(env); return 0; case SSL_ERROR_SYSCALL: - if (ERR_peek_error()) { + { + unsigned long e = ERR_peek_error(); + if (e && !ERR_SYSTEM_ERROR(e)) { throw_ssl_exception(env); } else if (ret == 0 || errno == 0) { // OpenSSL 1.0 and 1.1 return different error code in case of "dirty" connection close @@ -143,7 +170,16 @@ static int check_ssl_error(JNIEnv* env, SSL* ssl, int ret) { throw_io_exception(env); } return 0; + } case SSL_ERROR_SSL: + // workaround for SSL_sendfile() OpenSSL issue #23722 [ https://github.com/openssl/openssl/issues/23722 ] + { + int reason = ERR_GET_REASON(ERR_peek_error()); + if ((errno == EPIPE || errno == ECONNRESET) && reason == SSL_R_UNINITIALIZED) { + throw_io_exception(env); + return 0; + } + } throw_ssl_exception(env); return 0; case SSL_ERROR_WANT_READ: @@ -159,7 +195,7 @@ static int check_ssl_error(JNIEnv* env, SSL* ssl, int ret) { return err; } default: - sprintf(buf, "Unexpected SSL error code (%d)", err); + snprintf(buf, sizeof(buf), "Unexpected SSL error code (%d)", err); throw_by_name(env, "javax/net/ssl/SSLException", buf); return 0; } @@ -194,6 +230,7 @@ static void ssl_debug(const SSL* ssl, const char* fmt, ...) { char buf[128]; printf("ssl_debug [%s]: %s\n", ssl_get_peer_ip(ssl, buf, sizeof(buf)), message); + fflush(stdout); } static long get_session_counter(SSL_CTX* ctx, int key) { @@ -238,14 +275,6 @@ static void setup_dh_params(SSL_CTX* ctx) { } } -static void setup_ecdh_params(SSL_CTX* ctx) { - EC_KEY* ecdh = EC_KEY_new_by_curve_name(NID_X9_62_prime256v1); - if (ecdh != NULL) { - SSL_CTX_set_tmp_ecdh(ctx, ecdh); - EC_KEY_free(ecdh); - } -} - static AppData* create_app_data() { AppData* appData = calloc(1, sizeof(AppData)); if (appData != NULL) { @@ -307,6 +336,7 @@ static int ticket_key_callback(SSL* ssl, unsigned char key_name[16], unsigned ch TicketArray* tickets = &appData->tickets; Ticket* ticket = tickets->data; + intptr_t ssl_flags = (intptr_t)SSL_get_app_data(ssl); if (ticket == NULL) { // No ticket keys set } else if (new_session) { @@ -314,7 +344,7 @@ static int ticket_key_callback(SSL* ssl, unsigned char key_name[16], unsigned ch memcpy(key_name, ticket->name, 16); EVP_EncryptInit_ex(evp_ctx, EVP_aes_128_cbc(), NULL, ticket->aes_key, iv); HMAC_Init_ex(hmac_ctx, ticket->hmac_key, 16, EVP_sha256(), NULL); - SSL_set_app_data(ssl, (char*)(SF_SERVER | SF_NEW_TICKET)); + SSL_set_app_data(ssl, (char*)(ssl_flags | SF_NEW_TICKET)); result = 1; } } else { @@ -323,8 +353,8 @@ static int ticket_key_callback(SSL* ssl, unsigned char key_name[16], unsigned ch if (memcmp(key_name, ticket->name, 16) == 0) { HMAC_Init_ex(hmac_ctx, ticket->hmac_key, 16, EVP_sha256(), NULL); EVP_DecryptInit_ex(evp_ctx, EVP_aes_128_cbc(), NULL, ticket->aes_key, iv); - intptr_t ticket_options = i == 0 ? SF_SERVER | SF_HAS_TICKET : SF_SERVER | SF_HAS_OLD_TICKET; - SSL_set_app_data(ssl, (char*)ticket_options); + intptr_t ticket_options = i == 0 ? SF_HAS_TICKET : SF_HAS_OLD_TICKET; + SSL_set_app_data(ssl, (char*)(ssl_flags | ticket_options)); result = i == 0 ? 1 : 2; break; } @@ -375,7 +405,7 @@ static int ocsp_callback(SSL* ssl, void* arg) { if (appData->debug) { ssl_debug(ssl, "ocsp_callback: result=%d", result); } - + pthread_rwlock_unlock(&appData->lock); return result; } @@ -418,11 +448,13 @@ static int sni_callback(SSL* ssl, int* unused, void* arg) { static void ssl_info_callback(const SSL* ssl, int cb, int ret) { if (cb == SSL_CB_HANDSHAKE_START) { +#ifndef SSL_OP_NO_RENEGOTIATION // Reject any renegotiation by replacing actual socket with a dummy intptr_t flags = (intptr_t)SSL_get_app_data(ssl); if (flags & SF_HANDSHAKED) { SSL_set_fd((SSL*)ssl, preclosed_socket); } +#endif } else if (cb == SSL_CB_HANDSHAKE_DONE) { intptr_t flags = (intptr_t)SSL_get_app_data(ssl); if (flags & SF_SERVER) { @@ -431,6 +463,10 @@ static void ssl_info_callback(const SSL* ssl, int cb, int ret) { } } +static void update_NativeSslSocket_isHandshakeDone_field(JNIEnv* env, jobject self, SSL* ssl); +static void update_NativeSslSocket_isEarlyDataAccepted_field(JNIEnv* env, jobject self, SSL* ssl); + + static jbyteArray X509_cert_to_jbyteArray(JNIEnv* env, X509* cert) { jbyteArray result = NULL; @@ -449,11 +485,8 @@ static jbyteArray X509_cert_to_jbyteArray(JNIEnv* env, X509* cert) { JNIEXPORT void JNICALL Java_one_nio_net_NativeSslContext_init(JNIEnv* env, jclass cls) { - if (dlopen("libssl.so.3", RTLD_LAZY | RTLD_GLOBAL) == NULL && - dlopen("libssl.so", RTLD_LAZY | RTLD_GLOBAL) == NULL && - dlopen("libssl.so.1.0.0", RTLD_LAZY | RTLD_GLOBAL) == NULL && - dlopen("libssl.so.10", RTLD_LAZY | RTLD_GLOBAL) == NULL) { - throw_by_name(env, "java/lang/UnsupportedOperationException", "Failed to load libssl.so"); + if (dlopen("libssl.so.3", RTLD_LAZY | RTLD_GLOBAL) == NULL) { + throw_by_name(env, "java/lang/UnsupportedOperationException", "Failed to load libssl.so.3"); return; } @@ -464,8 +497,113 @@ Java_one_nio_net_NativeSslContext_init(JNIEnv* env, jclass cls) { f_ctx = cache_field(env, "one/nio/net/NativeSslContext", "ctx", "J"); f_ssl = cache_field(env, "one/nio/net/NativeSslSocket", "ssl", "J"); + f_isEarlyDataAccepted = cache_field(env, "one/nio/net/NativeSslSocket", "isEarlyDataAccepted", "Z"); + f_isHandshakeDone = cache_field(env, "one/nio/net/NativeSslSocket", "isHandshakeDone", "Z"); preclosed_socket = socket(PF_INET, SOCK_STREAM, 0); + + (*env)->GetJavaVM(env, &global_vm); + c_KeylogHolder = (*env)->NewGlobalRef(env, (*env)->FindClass(env, "one/nio/net/KeylogHolder")); + m_log = (*env)->GetStaticMethodID(env, c_KeylogHolder, "log", "(Ljava/lang/String;Ljava/net/InetSocketAddress;)V"); + + c_SslSessionCacheSingleton = (*env)->NewGlobalRef(env, (*env)->FindClass(env, "one/nio/net/SslSessionCache$Singleton")); + c_SslSessionCache = (*env)->NewGlobalRef(env, (*env)->FindClass(env, "one/nio/net/SslSessionCache")); + + m_getInstance = (*env)->GetStaticMethodID(env, c_SslSessionCacheSingleton, "getInstance", "()Lone/nio/net/SslSessionCache;"); + m_clearInstance = (*env)->GetStaticMethodID(env, c_SslSessionCacheSingleton, "clearInstance", "()V"); + m_addSession = (*env)->GetMethodID(env, c_SslSessionCache, "addSession", "([B[B)V"); + m_getSession = (*env)->GetMethodID(env, c_SslSessionCache, "getSession", "([B)[B"); + m_removeSession = (*env)->GetMethodID(env, c_SslSessionCache, "removeSession", "([B)V"); +} + +static int new_session_cb(SSL* ssl, SSL_SESSION* ssl_session) { + JNIEnv* env; + if (JNI_OK != (*global_vm)->GetEnv(global_vm, (void**)&env, JNI_VERSION_1_8)) { + return 0; + } + jobject sslSessionCache = (*env)->CallStaticObjectMethod(env, c_SslSessionCacheSingleton, m_getInstance); + if (sslSessionCache == NULL) { + return 0; + } + + int session_id_len; + const char* session_id = SSL_SESSION_get_id(ssl_session, &session_id_len); + jbyteArray sessionId = (*env)->NewByteArray(env, session_id_len); + if (sessionId == NULL) { + return 0; + } + (*env)->SetByteArrayRegion(env, sessionId, 0, session_id_len, (jbyte*)session_id); + + int session_len = i2d_SSL_SESSION(ssl_session, NULL); + if (session_len == 0) { + return 0; + } + jbyteArray session = (*env)->NewByteArray(env, session_len); + if (session == NULL) { + return 0; + } + + jbyte* b_session = (*env)->GetByteArrayElements(env, session, NULL); + unsigned char* ptr = (unsigned char*)b_session; + i2d_SSL_SESSION(ssl_session, &ptr); + (*env)->ReleaseByteArrayElements(env, session, b_session, 0); + + + (*env)->CallObjectMethod(env, sslSessionCache, m_addSession, sessionId, session); + return 0; +} + +static SSL_SESSION* get_session_cb(SSL* ssl, const unsigned char* session_id, int session_id_len, int* copy) { + *copy = 0; + + JNIEnv* env; + if (JNI_OK != (*global_vm)->GetEnv(global_vm, (void**)&env, JNI_VERSION_1_8)) { + return NULL; + } + jobject sslSessionCache = (*env)->CallStaticObjectMethod(env, c_SslSessionCacheSingleton, m_getInstance); + if (sslSessionCache == NULL) { + return NULL; + } + + jbyteArray sessionId = (*env)->NewByteArray(env, session_id_len); + if (sessionId == NULL) { + return NULL; + } + (*env)->SetByteArrayRegion(env, sessionId, 0, session_id_len, (jbyte*)session_id); + + jbyteArray session = (*env)->CallObjectMethod(env, sslSessionCache, m_getSession, sessionId); + if (session == NULL) { + return NULL; + } + jbyte* b_session = (*env)->GetByteArrayElements(env, session, NULL); + if (b_session == NULL) { + return NULL; + } + + int session_len = (*env)->GetArrayLength(env, session); + const unsigned char* ptr = (const unsigned char*)b_session; + SSL_SESSION* ssl_session = d2i_SSL_SESSION(NULL, &ptr, session_len); + (*env)->ReleaseByteArrayElements(env, session, b_session, JNI_ABORT); + return ssl_session; +} + +static void remove_session_cb(SSL_CTX* ssl, SSL_SESSION* ssl_session) { + JNIEnv* env; + if (JNI_OK != (*global_vm)->GetEnv(global_vm, (void**)&env, JNI_VERSION_1_8)) { + return; + } + jobject sslSessionCache = (*env)->CallStaticObjectMethod(env, c_SslSessionCacheSingleton, m_getInstance); + if (sslSessionCache == NULL) { + return; + } + + int session_id_len; + const char* session_id = SSL_SESSION_get_id(ssl_session, &session_id_len); + jbyteArray sessionId = (*env)->NewByteArray(env, session_id_len); + if (sessionId != NULL) { + (*env)->SetByteArrayRegion(env, sessionId, 0, session_id_len, (jbyte*)session_id); + (*env)->CallObjectMethod(env, sslSessionCache, m_removeSession, sessionId); + } } JNIEXPORT jlong JNICALL @@ -494,7 +632,6 @@ Java_one_nio_net_NativeSslContext_ctxNew(JNIEnv* env, jclass cls) { SSL_CTX_set_app_data(ctx, appData); setup_dh_params(ctx); - setup_ecdh_params(ctx); return (jlong)(intptr_t)ctx; } @@ -523,13 +660,13 @@ Java_one_nio_net_NativeSslContext_getDebug(JNIEnv* env, jobject self) { } JNIEXPORT void JNICALL -Java_one_nio_net_NativeSslContext_setOptions(JNIEnv* env, jobject self, jint options) { +Java_one_nio_net_NativeSslContext_setOptions(JNIEnv* env, jobject self, jlong options) { SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); SSL_CTX_set_options(ctx, options); } JNIEXPORT void JNICALL -Java_one_nio_net_NativeSslContext_clearOptions(JNIEnv* env, jobject self, jint options) { +Java_one_nio_net_NativeSslContext_clearOptions(JNIEnv* env, jobject self, jlong options) { SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); SSL_CTX_clear_options(ctx, options); } @@ -539,14 +676,20 @@ Java_one_nio_net_NativeSslContext_setRdrand(JNIEnv* env, jobject self, jboolean if (rdrand) { OPENSSL_init_crypto(/* OPENSSL_INIT_ENGINE_RDRAND */ 0x200L, NULL); ENGINE* e = ENGINE_by_id("rdrand"); - if (e == NULL || !ENGINE_init(e) || !ENGINE_set_default_RAND(e)) { - throw_ssl_exception(env); + if (e != NULL) { + if (ENGINE_init(e) && ENGINE_set_default_RAND(e)) { + RAND_set_rand_method(ENGINE_get_RAND(e)); + ENGINE_free(e); + return; + } + ENGINE_free(e); } - RAND_set_rand_method(ENGINE_get_RAND(e)); + throw_ssl_exception(env); } else { ENGINE* e = ENGINE_by_id("rdrand"); if (e != NULL) { ENGINE_unregister_RAND(e); + ENGINE_free(e); } ERR_clear_error(); } @@ -566,6 +709,19 @@ Java_one_nio_net_NativeSslContext_setCiphers(JNIEnv* env, jobject self, jstring } } +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSslContext_setCurve(JNIEnv* env, jobject self, jstring curve) { + SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); + if (curve != NULL) { + const char* value = (*env)->GetStringUTFChars(env, curve, NULL); + int result = SSL_CTX_set1_curves_list(ctx, value); + (*env)->ReleaseStringUTFChars(env, curve, value); + if (result == 0) { + throw_ssl_exception(env); + } + } +} + JNIEXPORT void JNICALL Java_one_nio_net_NativeSslContext_setCertificate(JNIEnv* env, jobject self, jstring certFile) { SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); @@ -783,7 +939,7 @@ Java_one_nio_net_NativeSslContext_setSNI0(JNIEnv* env, jobject self, jbyteArray contexts = (jlong*)(names + names_len); (*env)->GetLongArrayRegion(env, sniContexts, 0, contexts_len, contexts); } - + if (pthread_rwlock_wrlock(&appData->lock) != 0) { throw_by_name(env, "javax/net/ssl/SSLException", "Invalid state of appData lock"); free(names); @@ -793,7 +949,7 @@ Java_one_nio_net_NativeSslContext_setSNI0(JNIEnv* env, jobject self, jbyteArray free(sni->names); sni->names = names; sni->contexts = contexts; - + if (names != NULL) { SSL_CTX_set_tlsext_servername_callback(ctx, sni_callback); } @@ -801,6 +957,25 @@ Java_one_nio_net_NativeSslContext_setSNI0(JNIEnv* env, jobject self, jbyteArray pthread_rwlock_unlock(&appData->lock); } +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSslContext_setCompressionAlgorithms0(JNIEnv* env, jobject self, jintArray algorithms) { +#ifdef TLSEXT_comp_cert_limit + SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); + if (algorithms != NULL) { + jint len = (*env)->GetArrayLength(env, algorithms); + jint* algs = (*env)->GetIntArrayElements(env, algorithms, NULL); + + int result = SSL_CTX_set1_cert_comp_preference(ctx, (int*)algs, len); + (*env)->ReleaseIntArrayElements(env, algorithms, algs, JNI_ABORT); + if (result == 0) { + throw_by_name(env, "javax/net/ssl/SSLException", "Cannot set certificate compression algorithm"); + return; + } + SSL_CTX_compress_certs(ctx, 0); + } +#endif +} + JNIEXPORT void JNICALL Java_one_nio_net_NativeSslContext_setSessionId(JNIEnv* env, jobject self, jbyteArray sessionId) { SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); @@ -816,7 +991,32 @@ Java_one_nio_net_NativeSslContext_setSessionId(JNIEnv* env, jobject self, jbyteA } JNIEXPORT void JNICALL -Java_one_nio_net_NativeSslContext_setCacheSize(JNIEnv* env, jobject self, jint size) { +Java_one_nio_net_NativeSslContext_setCacheMode(JNIEnv* env, jobject self, jint mode) { + SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); + + SSL_CTX_sess_set_get_cb(ctx, mode == CACHE_MODE_EXTERNAL ? get_session_cb : NULL); + SSL_CTX_sess_set_new_cb(ctx, mode == CACHE_MODE_EXTERNAL ? new_session_cb : NULL); + SSL_CTX_sess_set_remove_cb(ctx, mode == CACHE_MODE_EXTERNAL ? remove_session_cb : NULL); + + switch (mode) { + case CACHE_MODE_NONE: + (*env)->CallStaticObjectMethod(env, c_SslSessionCacheSingleton, m_clearInstance); + SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF); + break; + case CACHE_MODE_INTERNAL: + (*env)->CallStaticObjectMethod(env, c_SslSessionCacheSingleton, m_clearInstance); + SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_SERVER); + break; + case CACHE_MODE_EXTERNAL: + SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_SERVER | SSL_SESS_CACHE_NO_INTERNAL_LOOKUP); + break; + default: + throw_illegal_argument_msg(env, "Unknown cache mode value"); + } +} + +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSslContext_setInternalCacheSize(JNIEnv* env, jobject self, jint size) { SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); SSL_CTX_sess_set_cache_size(ctx, size); } @@ -852,16 +1052,32 @@ Java_one_nio_net_NativeSslContext_getSessionCounters(JNIEnv* env, jobject self, return values; } +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSslContext_setMaxEarlyData(JNIEnv* env, jobject self, jint size) { +#if (OPENSSL_VERSION_MAJOR >= 3) + SSL_CTX* ctx = (SSL_CTX*)(intptr_t)(*env)->GetLongField(env, self, f_ctx); + SSL_CTX_set_max_early_data(ctx, size); +#endif +} + JNIEXPORT jlong JNICALL Java_one_nio_net_NativeSslSocket_sslNew(JNIEnv* env, jclass cls, jint fd, jlong ctx, jboolean serverMode) { SSL* ssl = SSL_new((SSL_CTX*)(intptr_t)ctx); if (ssl != NULL && SSL_set_fd(ssl, fd)) { if (serverMode) { SSL_set_accept_state(ssl); - SSL_set_app_data(ssl, (char*)SF_SERVER); + + intptr_t flags = SF_SERVER; +#if (OPENSSL_VERSION_MAJOR >= 3) + flags |= SSL_CTX_get_max_early_data((SSL_CTX*)ctx) > 0 ? SF_EARLY_DATA_ENABLED : 0; +#endif + SSL_set_app_data(ssl, (char*)flags); } else { SSL_set_connect_state(ssl); } +#ifdef SSL_OP_NO_RENEGOTIATION + SSL_set_options(ssl, SSL_OP_NO_RENEGOTIATION); +#endif return (jlong)(intptr_t)ssl; } @@ -899,8 +1115,22 @@ Java_one_nio_net_NativeSslSocket_writeRaw(JNIEnv* env, jobject self, jlong buf, throw_socket_closed(env); return 0; } else { +#if (OPENSSL_VERSION_MAJOR >= 3) + if (!SSL_is_init_finished(ssl)) { + while (1) { + size_t written; + int result = SSL_write_early_data(ssl, (void*)(intptr_t)buf, count, &written); + if (result == 1) { + return written; + } else if ((result = check_ssl_error(env, ssl, 0)) != SSL_ERROR_WANT_WRITE || errno != EINTR) { + return result == SSL_ERROR_WANT_READ ? -1 : 0; + } + } + } +#endif while (1) { int result = SSL_write(ssl, (void*)(intptr_t)buf, count); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); if (result > 0) { return result; } else if (check_ssl_error(env, ssl, result) != SSL_ERROR_WANT_WRITE || errno != EINTR) { @@ -920,9 +1150,24 @@ Java_one_nio_net_NativeSslSocket_write(JNIEnv* env, jobject self, jbyteArray dat return 0; } else { if (count > MAX_STACK_BUF) count = MAX_STACK_BUF; - (*env)->GetByteArrayRegion(env, data, offset, count, buf); + (*env)->GetByteArrayRegion(env, data, offset, count <= MAX_STACK_BUF ? count : MAX_STACK_BUF, buf); + +#if (OPENSSL_VERSION_MAJOR >= 3) + if (!SSL_is_init_finished(ssl)) { + while (1) { + size_t written; + int result = SSL_write_early_data(ssl, (void*)(intptr_t)buf, count, &written); + if (result == 1) { + return written; + } else if ((result = check_ssl_error(env, ssl, 0)) != SSL_ERROR_WANT_WRITE || errno != EINTR) { + return result == SSL_ERROR_WANT_READ ? -1 : 0; + } + } + } +#endif while (1) { int result = SSL_write(ssl, (void*)(intptr_t)buf, count); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); if (result > 0) { return result; } else if ((result = check_ssl_error(env, ssl, result)) != SSL_ERROR_WANT_WRITE || errno != EINTR) { @@ -932,6 +1177,32 @@ Java_one_nio_net_NativeSslSocket_write(JNIEnv* env, jobject self, jbyteArray dat } } +JNIEXPORT jlong JNICALL +Java_one_nio_net_NativeSslSocket_sendFile0(JNIEnv* env, jobject self, jint sourceFD, jlong offset, jlong count) { +#if (OPENSSL_VERSION_MAJOR >= 3) + SSL* ssl = (SSL*)(intptr_t) (*env)->GetLongField(env, self, f_ssl); + if (ssl == NULL) { + throw_socket_closed(env); + } else if (count != 0) { + while (1) { + int result = SSL_sendfile(ssl, sourceFD, (off_t)offset, count, 0); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); + if (result > 0) { + return result; + } else if (result == 0) { + throw_socket_closed_cached(env); + break; + } else if ((result = check_ssl_error(env, ssl, result)) != SSL_ERROR_WANT_WRITE || errno != EINTR) { + return result == SSL_ERROR_WANT_READ ? -1 : 0; + } + } + } +#else + throw_by_name(env, "javax/net/ssl/SSLException", "Cannot use sendFile with SSL"); +#endif +} + + JNIEXPORT void JNICALL Java_one_nio_net_NativeSslSocket_writeFully(JNIEnv* env, jobject self, jbyteArray data, jint offset, jint count) { SSL* ssl = (SSL*)(intptr_t) (*env)->GetLongField(env, self, f_ssl); @@ -939,12 +1210,13 @@ Java_one_nio_net_NativeSslSocket_writeFully(JNIEnv* env, jobject self, jbyteArra if (ssl == NULL) { throw_socket_closed(env); - } else { + } else if (SSL_is_init_finished(ssl)) { while (count > 0) { int to_write = count <= MAX_STACK_BUF ? count : MAX_STACK_BUF; (*env)->GetByteArrayRegion(env, data, offset, to_write, buf); int result = SSL_write(ssl, (void*)(intptr_t)buf, to_write); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); if (result > 0) { offset += result; count -= result; @@ -952,7 +1224,43 @@ Java_one_nio_net_NativeSslSocket_writeFully(JNIEnv* env, jobject self, jbyteArra break; } } + } else { + throw_by_name(env, "javax/net/ssl/SSLException", "Too early. SSL Handshake is not finished"); + } +} + +static int ssl_socket_readRaw_early_data(JNIEnv* env, jobject self, SSL* ssl, jlong buf, jint count) { +#if (OPENSSL_VERSION_MAJOR >= 3) + intptr_t ssl_flags = (intptr_t)SSL_get_app_data(ssl); + + while (1) { + size_t bytes_read = 0; + int result; + int ed_status = SSL_read_early_data(ssl, (void*)buf, count, &bytes_read); + + switch (ed_status) { + case SSL_READ_EARLY_DATA_FINISH: + SSL_set_app_data(ssl, (char*)(ssl_flags | SF_EARLY_DATA_FINISHED)); + case SSL_READ_EARLY_DATA_SUCCESS: + update_NativeSslSocket_isEarlyDataAccepted_field(env, self, ssl); + return bytes_read; + case SSL_READ_EARLY_DATA_ERROR: + if ((result = check_ssl_error(env, ssl, ed_status)) != SSL_ERROR_WANT_READ || errno != EINTR) { + update_NativeSslSocket_isEarlyDataAccepted_field(env, self, ssl); + return result == SSL_ERROR_WANT_WRITE ? -1 : 0; + } + default: { + char error[64]; + snprintf(error, sizeof(error), "Unexpected Early data status (%d)", ed_status); + update_NativeSslSocket_isEarlyDataAccepted_field(env, self, ssl); + throw_by_name(env, "javax/net/ssl/SSLException", error); + } + } } +#else + // it may happen if early data settings are enabled on openssl ver < 3.0.0 + throw_by_name(env, "javax/net/ssl/SSLException", "Early data is not supported in this openssl version"); +#endif } JNIEXPORT jint JNICALL @@ -962,15 +1270,58 @@ Java_one_nio_net_NativeSslSocket_readRaw(JNIEnv* env, jobject self, jlong buf, j throw_socket_closed(env); return 0; } else { - while (1) { - int result = SSL_read(ssl, (void*)(intptr_t)buf, count); - if (result > 0) { - return result; - } else if ((result = check_ssl_error(env, ssl, result)) != SSL_ERROR_WANT_READ || errno != EINTR) { - return result == SSL_ERROR_WANT_WRITE ? -1 : 0; + intptr_t ssl_flags = (intptr_t)SSL_get_app_data(ssl); + bool early_data = ssl_flags & SF_EARLY_DATA_ENABLED; + if (!early_data || ssl_flags & SF_EARLY_DATA_FINISHED) { + while (1) { + int result = SSL_read(ssl, (void*)(intptr_t)buf, count); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); + if (result > 0) { + return result; + } else if ((result = check_ssl_error(env, ssl, result)) != SSL_ERROR_WANT_READ || errno != EINTR) { + return result == SSL_ERROR_WANT_WRITE ? -1 : 0; + } + } + } else { + return ssl_socket_readRaw_early_data(env, self, ssl, buf, count); + } + } +} + +static int ssl_socket_read_early_data(JNIEnv* env, jobject self, SSL* ssl, jbyteArray data, jint offset, jint count) { +#if (OPENSSL_VERSION_MAJOR >= 3) + jbyte buf[MAX_STACK_BUF]; + intptr_t ssl_flags = (intptr_t)SSL_get_app_data(ssl); + + while (1) { + size_t bytes_read = 0; + int result; + int ed_status = SSL_read_early_data(ssl, (void*)buf, count <= MAX_STACK_BUF ? count : MAX_STACK_BUF, &bytes_read); + + switch (ed_status) { + case SSL_READ_EARLY_DATA_FINISH: + SSL_set_app_data(ssl, (char*)(ssl_flags | SF_EARLY_DATA_FINISHED)); + case SSL_READ_EARLY_DATA_SUCCESS: + (*env)->SetByteArrayRegion(env, data, offset, bytes_read, buf); + update_NativeSslSocket_isEarlyDataAccepted_field(env, self, ssl); + return bytes_read; + case SSL_READ_EARLY_DATA_ERROR: + if ((result = check_ssl_error(env, ssl, ed_status)) != SSL_ERROR_WANT_READ || errno != EINTR) { + update_NativeSslSocket_isEarlyDataAccepted_field(env, self, ssl); + return result == SSL_ERROR_WANT_WRITE ? -1 : 0; + } + default: { + char error[64]; + snprintf(error, sizeof(error), "Unexpected Early data status (%d)", ed_status); + update_NativeSslSocket_isEarlyDataAccepted_field(env, self, ssl); + throw_by_name(env, "javax/net/ssl/SSLException", error); } } } +#else + // it may happen if early data settings are enabled on openssl ver < 3.0.0 + throw_by_name(env, "javax/net/ssl/SSLException", "Early data is not supported in this openssl version"); +#endif } JNIEXPORT int JNICALL @@ -982,14 +1333,21 @@ Java_one_nio_net_NativeSslSocket_read(JNIEnv* env, jobject self, jbyteArray data throw_socket_closed(env); return 0; } else { - while (1) { - int result = SSL_read(ssl, buf, count <= MAX_STACK_BUF ? count : MAX_STACK_BUF); - if (result > 0) { - (*env)->SetByteArrayRegion(env, data, offset, result, buf); - return result; - } else if ((result = check_ssl_error(env, ssl, result)) != SSL_ERROR_WANT_READ || errno != EINTR) { - return result == SSL_ERROR_WANT_WRITE ? -1 : 0; + intptr_t ssl_flags = (intptr_t)SSL_get_app_data(ssl); + bool early_data = ssl_flags & SF_EARLY_DATA_ENABLED; + if (!early_data || ssl_flags & SF_EARLY_DATA_FINISHED) { + while (1) { + int result = SSL_read(ssl, buf, count <= MAX_STACK_BUF ? count : MAX_STACK_BUF); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); + if (result > 0) { + (*env)->SetByteArrayRegion(env, data, offset, result, buf); + return result; + } else if ((result = check_ssl_error(env, ssl, result)) != SSL_ERROR_WANT_READ || errno != EINTR) { + return result == SSL_ERROR_WANT_WRITE ? -1 : 0; + } } + } else { + return ssl_socket_read_early_data(env, self, ssl, data, offset, count); } } } @@ -1004,6 +1362,7 @@ Java_one_nio_net_NativeSslSocket_readFully(JNIEnv* env, jobject self, jbyteArray } else { while (count > 0) { int result = SSL_read(ssl, buf, count <= MAX_STACK_BUF ? count : MAX_STACK_BUF); + update_NativeSslSocket_isHandshakeDone_field(env, self, ssl); if (result > 0) { (*env)->SetByteArrayRegion(env, data, offset, result, buf); offset += result; @@ -1022,7 +1381,7 @@ Java_one_nio_net_NativeSslSocket_sslPeerCertificate(JNIEnv* env, jobject self) { return NULL; } - X509* cert = SSL_get_peer_certificate(ssl); + X509* cert = SSL_get1_peer_certificate(ssl); if (cert == NULL) { return NULL; } @@ -1072,7 +1431,7 @@ Java_one_nio_net_NativeSslSocket_sslCertName(JNIEnv* env, jobject self, jint whi return NULL; } - X509* cert = SSL_get_peer_certificate(ssl); + X509* cert = SSL_get1_peer_certificate(ssl); if (cert == NULL) { return NULL; } @@ -1124,6 +1483,20 @@ Java_one_nio_net_NativeSslSocket_sslSessionReused(JNIEnv* env, jobject self) { return ssl != NULL && SSL_session_reused(ssl) ? JNI_TRUE : JNI_FALSE; } +static void update_NativeSslSocket_isEarlyDataAccepted_field(JNIEnv* env, jobject self, SSL* ssl) { +#ifdef SSL_EARLY_DATA_ACCEPTED + jboolean isEarlyDataAccepted = ssl != NULL + && SSL_get_early_data_status(ssl) == SSL_EARLY_DATA_ACCEPTED ? JNI_TRUE : JNI_FALSE; + (*env)->SetBooleanField(env, self, f_isEarlyDataAccepted, isEarlyDataAccepted); +#endif +} + + +static void update_NativeSslSocket_isHandshakeDone_field(JNIEnv* env, jobject self, SSL* ssl) { + jboolean isHandshakeDone = ssl != NULL && SSL_is_init_finished(ssl) ? JNI_TRUE : JNI_FALSE; + (*env)->SetBooleanField(env, self, f_isHandshakeDone, isHandshakeDone); +} + JNIEXPORT jint JNICALL Java_one_nio_net_NativeSslSocket_sslSessionTicket(JNIEnv* env, jobject self) { SSL* ssl = (SSL*)(intptr_t) (*env)->GetLongField(env, self, f_ssl); @@ -1140,3 +1513,29 @@ Java_one_nio_net_NativeSslSocket_sslCurrentCipher(JNIEnv* env, jobject self) { const char* name = SSL_CIPHER_get_name(SSL_get_current_cipher(ssl)); return name == NULL ? NULL : (*env)->NewStringUTF(env, name); } + +#if (OPENSSL_VERSION_MAJOR >= 3) +static void keylog_callback(const SSL *ssl, const char *line) { + JNIEnv* env; + if (JNI_OK != (*global_vm)->GetEnv(global_vm, (void**)&env, JNI_VERSION_1_8)) { + return; + } + + int fd = SSL_get_fd(ssl); + struct sockaddr_storage sa; + socklen_t len = sizeof(sa); + if (getpeername(fd, (struct sockaddr*)&sa, &len) == 0) { + jobject isa = sockaddr_to_java(env, &sa, len); + jstring key_line = (*env)->NewStringUTF(env, line); + (*env)->CallStaticVoidMethod(env, c_KeylogHolder, m_log, key_line, isa); + } +} +#endif + +JNIEXPORT void JNICALL +Java_one_nio_net_NativeSslContext_setKeylog(JNIEnv* env, jobject self, jboolean keylog) { +#if (OPENSSL_VERSION_MAJOR >= 3) + SSL_CTX* ctx = (SSL_CTX*)(intptr_t) (*env)->GetLongField(env, self, f_ctx); + SSL_CTX_set_keylog_callback(ctx, keylog ? keylog_callback : NULL); +#endif +} diff --git a/src/one/nio/os/Cpus.java b/src/one/nio/os/Cpus.java index 1f0de3f..5adcdb9 100644 --- a/src/one/nio/os/Cpus.java +++ b/src/one/nio/os/Cpus.java @@ -43,7 +43,7 @@ private static BitSet cpus(String rangeFile) { String[] s = range.split("-"); int from = Integer.parseInt(s[0]); int to = s.length == 1 ? from : Integer.parseInt(s[1]); - cpus.set(from, to); + cpus.set(from, to + 1); } return cpus; } catch (IOException e) { diff --git a/src/one/nio/os/Proc.java b/src/one/nio/os/Proc.java index 8a13680..64a0dc3 100755 --- a/src/one/nio/os/Proc.java +++ b/src/one/nio/os/Proc.java @@ -44,6 +44,10 @@ public final class Proc { /** * The same as above, but allows an arbitrary long mask + * + * @param pid an id of a thread. If pid is zero, then the calling thread is used + * @param mask a thread's CPU affinity mask + * @return 0 on success or errno on failure */ public static native int setAffinity(int pid, long[] mask); public static native long[] getAffinity(int pid); @@ -63,6 +67,7 @@ public static void setDedicatedCpu(int pid, int cpu) { /** * @param pid pid or tid. 0 for current thread + * @param policy one of the POSIX scheduling policies * @return 0 on success or errno on failure */ public static native int sched_setscheduler(int pid, int policy); @@ -90,6 +95,7 @@ public static void setDedicatedCpu(int pid, int cpu) { * setpriority() shall set the nice value to the highest supported value. * * @param pid pid or tid. 0 for current thread + * @param value a nice value * @return 0 on success; otherwise, -1 shall be returned and errno set to indicate the error. */ public static native int setpriority(int pid, int value); diff --git a/src/one/nio/os/bpf/Bpf.java b/src/one/nio/os/bpf/Bpf.java index b4a85bf..377ac33 100644 --- a/src/one/nio/os/bpf/Bpf.java +++ b/src/one/nio/os/bpf/Bpf.java @@ -43,6 +43,8 @@ public class Bpf { static native int[] progGetMapIds(int fd) throws IOException; + static native void progTestRun(int fd, byte[] dataIn, int lenDataIn, byte[] dataOut, byte[] ctxIn, int lenCtxIn, byte[] ctxOut, int[] retvals /* data_size_out,ctx_size_out,duration,retval */) throws IOException; + static native int rawTracepointOpen(int progFd, String name) throws IOException; static native String mapGetInfo(int fd, int[] result /*type,id,key_size,value_size,max_entries,flags*/) throws IOException; diff --git a/src/one/nio/os/bpf/BpfProg.java b/src/one/nio/os/bpf/BpfProg.java index 69b8189..55209f0 100644 --- a/src/one/nio/os/bpf/BpfProg.java +++ b/src/one/nio/os/bpf/BpfProg.java @@ -69,7 +69,36 @@ public int[] getMapIds() throws IOException { return Bpf.progGetMapIds(fd()); } + public void testRun(TestRunContext context) throws IOException { + assert context.ctxIn == null || context.lenCtxIn <= context.ctxIn.length; + assert context.dataIn == null || context.lenDataIn <= context.dataIn.length; + + Bpf.progTestRun(fd(), context.dataIn, context.lenDataIn, context.dataOut, context.ctxIn, context.lenCtxIn, context.ctxOut, context.retvals); + } + public static Iterable getAllIds() { return () -> new IdsIterator(Bpf.OBJ_PROG); } + + public static class TestRunContext { + public byte[] dataIn; + public int lenDataIn; + public byte[] ctxIn; + public int lenCtxIn; + public byte[] dataOut; + public byte[] ctxOut; + int[] retvals = new int[4]; + public int lenDataOut() { + return retvals[0]; + } + public int lenCtxOut() { + return retvals[1]; + } + public int duration() { + return retvals[2]; + } + public int result() { + return retvals[3]; + } + } } diff --git a/src/one/nio/os/native/bpf.c b/src/one/nio/os/native/bpf.c index 91a58ef..3fa42d6 100644 --- a/src/one/nio/os/native/bpf.c +++ b/src/one/nio/os/native/bpf.c @@ -53,7 +53,7 @@ Java_one_nio_os_bpf_Bpf_progLoad(JNIEnv* env, jclass cls, jstring pathname, jint if (libbpf == NULL) { libbpf = dlopen("libbpf.so.0", RTLD_LAZY | RTLD_GLOBAL); if (libbpf == NULL) { - throw_by_name(env, "java/lang/UnsupportedOperationException", "Failed to load libbpf.so"); + throw_by_name(env, "java/lang/UnsupportedOperationException", "Failed to load libbpf.so or libbpf.so.0"); return -EINVAL; } } @@ -225,6 +225,70 @@ Java_one_nio_os_bpf_Bpf_progGetMapIds(JNIEnv* env, jclass cls, int bpf_fd) { return result; } +JNIEXPORT void JNICALL +Java_one_nio_os_bpf_Bpf_progTestRun(JNIEnv* env, jclass cls, jint prog_fd, jbyteArray data_in, jint len_data_in, jbyteArray data_out, + jbyteArray ctx_in, jint len_ctx_in, jbyteArray ctx_out, jintArray retvals /* data_size_out,ctx_size_out,duration,retval */) { + + union bpf_attr attr; + int res; + + memset(&attr, 0, sizeof(attr)); + attr.test.prog_fd = prog_fd; + + jbyte *b_ctx_in=NULL, *b_data_in=NULL, *b_ctx_out=NULL, *b_data_out=NULL; + + + if (ctx_in != NULL) { + attr.test.ctx_size_in = len_ctx_in; + b_ctx_in = (*env)->GetByteArrayElements(env, ctx_in, NULL); + attr.test.ctx_in = ptr_to_u64(b_ctx_in); + } + if (data_in != NULL) { + attr.test.data_size_in = len_data_in; + b_data_in = (*env)->GetByteArrayElements(env, data_in, NULL); + attr.test.data_in = ptr_to_u64(b_data_in); + } + if (ctx_out != NULL) { + attr.test.ctx_size_out = (*env)->GetArrayLength(env, ctx_out); + b_ctx_out = (*env)->GetByteArrayElements(env, ctx_out, NULL); + attr.test.ctx_out = ptr_to_u64(b_ctx_out); + } + if (data_out != NULL) { + attr.test.data_size_out = (*env)->GetArrayLength(env, data_out); + b_data_out = (*env)->GetByteArrayElements(env, data_out, NULL); + attr.test.data_out = ptr_to_u64(b_data_out); + } + + res = sys_bpf(BPF_PROG_TEST_RUN, &attr, sizeof(attr)); + + if (retvals != NULL) { + const jint b_result[] = { + attr.test.data_size_out, + attr.test.ctx_size_out, + attr.test.duration, + attr.test.retval + }; + (*env)->SetIntArrayRegion(env, retvals, 0, sizeof(b_result)/sizeof(int), b_result); + } + + if (ctx_in != NULL) { + (*env)->ReleaseByteArrayElements(env, ctx_in, b_ctx_in, JNI_ABORT); + } + if (data_in != NULL) { + (*env)->ReleaseByteArrayElements(env, data_in, b_data_in, JNI_ABORT); + } + if (ctx_out != NULL) { + (*env)->ReleaseByteArrayElements(env, ctx_out, b_ctx_out, 0); + } + if (data_out != NULL) { + (*env)->ReleaseByteArrayElements(env, data_out, b_data_out, 0); + } + + if (res < 0) { + throw_io_exception(env); + } +} + JNIEXPORT jint JNICALL Java_one_nio_os_bpf_Bpf_progGetFdById(JNIEnv* env, jclass cls, jint id) { union bpf_attr attr; diff --git a/src/one/nio/pool/SocketPool.java b/src/one/nio/pool/SocketPool.java index d5b3f50..bc23db5 100755 --- a/src/one/nio/pool/SocketPool.java +++ b/src/one/nio/pool/SocketPool.java @@ -17,6 +17,7 @@ package one.nio.pool; import one.nio.mgt.Management; +import one.nio.net.SslClientContextFactory; import one.nio.net.ConnectionString; import one.nio.net.Proxy; import one.nio.net.Socket; @@ -28,6 +29,7 @@ public class SocketPool extends Pool implements SocketPoolMXBean { protected int readTimeout; protected int connectTimeout; protected int tos; + protected boolean thinLto; protected SslContext sslContext; protected Proxy proxy; @@ -42,6 +44,7 @@ public SocketPool(ConnectionString conn) { this.connectTimeout = conn.getIntParam("connectTimeout", 1000); this.tos = conn.getIntParam("tos", 0); this.fifo = conn.getBooleanParam("fifo", false); + this.thinLto = conn.getBooleanParam("thinLto", false); setProperties(conn); initialize(); @@ -53,7 +56,7 @@ public SocketPool(ConnectionString conn) { protected void setProperties(ConnectionString conn) { if ("ssl".equals(conn.getProtocol())) { - sslContext = SslContext.getDefault(); + sslContext = SslClientContextFactory.create(); } } @@ -152,7 +155,7 @@ public void setProxy(Proxy proxy) { public Socket createObject() throws PoolException { Socket socket = null; try { - socket = Socket.create(); + socket = Socket.createClientSocket(sslContext); socket.setKeepAlive(true); socket.setNoDelay(true); @@ -160,6 +163,10 @@ public Socket createObject() throws PoolException { socket.setTos(tos); } + if (thinLto) { + socket.setThinLinearTimeouts(true); + } + socket.setTimeout(connectTimeout); if (proxy == null) { socket.connect(host, port); diff --git a/src/one/nio/rpc/RpcClient.java b/src/one/nio/rpc/RpcClient.java index 0fdbc58..6d31ab5 100755 --- a/src/one/nio/rpc/RpcClient.java +++ b/src/one/nio/rpc/RpcClient.java @@ -36,8 +36,16 @@ public class RpcClient extends SocketPool implements InvocationHandler { protected static final byte[][] uidLocks = new byte[64][0]; + private final StackTraceElement remoteMarkerElement; + public RpcClient(ConnectionString conn) { super(conn); + + this.remoteMarkerElement = new StackTraceElement( + "<>", // pseudo class name + "remoteCall", + this.name(), // pseudo file name, will contain host and port + -1); } public Object invoke(Object request) throws Exception { @@ -72,11 +80,34 @@ public Object invoke(Object request, int timeout) throws Exception { provideSerializer(Repository.requestSerializer(uid)); rawResponse = invokeRaw(request, readTimeout); } else { - throw (Exception) response; + Exception exception = (Exception) response; + addLocalStack(exception, request); + throw exception; } } } + private void addLocalStack(Throwable e, Object remoteRequest) { + StackTraceElement[] remoteStackTrace = e.getStackTrace(); + StackTraceElement[] localStackTrace = new Exception().getStackTrace(); + + if (remoteStackTrace == null || localStackTrace == null) { + return; + } + StackTraceElement[] newStackTrace = new StackTraceElement[remoteStackTrace.length + localStackTrace.length]; + + System.arraycopy(remoteStackTrace, 0, newStackTrace, 0, remoteStackTrace.length); + newStackTrace[remoteStackTrace.length] = remoteMarkerElement; + + System.arraycopy(localStackTrace, + 1, // starting from 1 to skip 'addLocalStack' line in stack trace + newStackTrace, + remoteStackTrace.length + 1, + localStackTrace.length - 1); + + e.setStackTrace(newStackTrace); + } + @Override public Object invoke(Object proxy, Method method, Object... args) throws Exception { if (method.getDeclaringClass() == Object.class) { diff --git a/src/one/nio/rpc/RpcSession.java b/src/one/nio/rpc/RpcSession.java index a378c9e..036ea76 100755 --- a/src/one/nio/rpc/RpcSession.java +++ b/src/one/nio/rpc/RpcSession.java @@ -16,6 +16,14 @@ package one.nio.rpc; +import java.io.IOException; +import java.io.NotSerializableException; +import java.net.InetSocketAddress; +import java.util.concurrent.RejectedExecutionException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + import one.nio.http.Request; import one.nio.net.ProxyProtocol; import one.nio.net.Session; @@ -30,11 +38,8 @@ import one.nio.serial.SerializerNotFoundException; import one.nio.util.Utf8; -import java.io.IOException; -import java.net.InetSocketAddress; -import java.util.concurrent.RejectedExecutionException; - public class RpcSession extends Session { + protected static final Logger logSerialize = LoggerFactory.getLogger("one-serializer-logger"); protected static final int BUFFER_SIZE = 8000; protected static final byte HTTP_REQUEST_UID = (byte) Repository.get(Request.class).uid(); @@ -198,10 +203,14 @@ protected int writeResponse(Object response) throws IOException { RpcPacket.checkWriteSize(responseSize); byte[] buffer = new byte[responseSize + 4]; - DataStream ds = css.hasCycles() ? new SerializeStream(buffer, css.capacity()) : new DataStream(buffer); - ds.writeInt(responseSize); - ds.writeObject(response); - + try { + DataStream ds = css.hasCycles() ? new SerializeStream(buffer, css.capacity()) : new DataStream(buffer); + ds.writeInt(responseSize); + ds.writeObject(response); + } catch (IOException | RuntimeException e) { + logSerialize.warn("Exception while serializing: {}", response, e); + return writeResponse(new NotSerializableException(e.getMessage())); + } super.write(buffer, 0, buffer.length); return responseSize; } diff --git a/src/one/nio/serial/DataStream.java b/src/one/nio/serial/DataStream.java index 4a2a6a6..a74f75e 100755 --- a/src/one/nio/serial/DataStream.java +++ b/src/one/nio/serial/DataStream.java @@ -183,8 +183,7 @@ public void writeFrom(long address, int len) throws IOException { } public int read() throws IOException { - long offset = alloc(1); - return unsafe.getByte(array, offset); + return unsafe.getByte(array, alloc(1)); } public int read(byte[] b) throws IOException { @@ -198,13 +197,11 @@ public int read(byte[] b, int off, int len) throws IOException { } public void readFully(byte[] b) throws IOException { - long offset = alloc(b.length); - unsafe.copyMemory(array, offset, b, byteArrayOffset, b.length); + unsafe.copyMemory(array, alloc(b.length), b, byteArrayOffset, b.length); } public void readFully(byte[] b, int off, int len) throws IOException { - long offset = alloc(len); - unsafe.copyMemory(array, offset, b, byteArrayOffset + off, len); + unsafe.copyMemory(array, alloc(len), b, byteArrayOffset + off, len); } public long skip(long n) throws IOException { @@ -218,43 +215,35 @@ public int skipBytes(int n) throws IOException { } public boolean readBoolean() throws IOException { - long offset = alloc(1); - return unsafe.getBoolean(array, offset); + return unsafe.getBoolean(array, alloc(1)); } public byte readByte() throws IOException { - long offset = alloc(1); - return unsafe.getByte(array, offset); + return unsafe.getByte(array, alloc(1)); } public int readUnsignedByte() throws IOException { - long offset = alloc(1); - return unsafe.getByte(array, offset) & 0xff; + return unsafe.getByte(array, alloc(1)) & 0xff; } public short readShort() throws IOException { - long offset = alloc(2); - return Short.reverseBytes(unsafe.getShort(array, offset)); + return Short.reverseBytes(unsafe.getShort(array, alloc(2))); } public int readUnsignedShort() throws IOException { - long offset = alloc(2); - return Short.reverseBytes(unsafe.getShort(array, offset)) & 0xffff; + return Short.reverseBytes(unsafe.getShort(array, alloc(2))) & 0xffff; } public char readChar() throws IOException { - long offset = alloc(2); - return Character.reverseBytes(unsafe.getChar(array, offset)); + return Character.reverseBytes(unsafe.getChar(array, alloc(2))); } public int readInt() throws IOException { - long offset = alloc(4); - return Integer.reverseBytes(unsafe.getInt(array, offset)); + return Integer.reverseBytes(unsafe.getInt(array, alloc(4))); } public long readLong() throws IOException { - long offset = alloc(8); - return Long.reverseBytes(unsafe.getLong(array, offset)); + return Long.reverseBytes(unsafe.getLong(array, alloc(8))); } public float readFloat() throws IOException { @@ -287,8 +276,7 @@ public String readUTF() throws IOException { if (length > 0x7fff) { length = (length & 0x7fff) << 16 | readUnsignedShort(); } - long offset = alloc(length); - return Utf8.read(array, offset, length); + return Utf8.read(array, alloc(length), length); } public Object readObject() throws IOException, ClassNotFoundException { @@ -321,8 +309,7 @@ public void read(ByteBuffer dst) throws IOException { } public void readTo(long address, int len) throws IOException { - long offset = alloc(len); - unsafe.copyMemory(array, offset, null, address, len); + unsafe.copyMemory(array, alloc(len), null, address, len); } public ByteBuffer byteBuffer(int len) throws IOException { diff --git a/src/one/nio/serial/DeserializeStream.java b/src/one/nio/serial/DeserializeStream.java index 7f0f8c1..264b6ee 100755 --- a/src/one/nio/serial/DeserializeStream.java +++ b/src/one/nio/serial/DeserializeStream.java @@ -76,7 +76,7 @@ public Object readObject() throws IOException, ClassNotFoundException { } @Override - public void close() throws IOException { + public void close() { context = null; } diff --git a/src/one/nio/serial/PersistStream.java b/src/one/nio/serial/PersistStream.java index b476d6c..a5bf714 100755 --- a/src/one/nio/serial/PersistStream.java +++ b/src/one/nio/serial/PersistStream.java @@ -69,7 +69,7 @@ public void writeObject(Object obj) throws IOException { } @Override - protected long alloc(int size) throws IOException { + protected long alloc(int size) { long currentOffset = offset; if ((offset = currentOffset + size) > limit) { limit = Math.max(offset, limit * 2); diff --git a/src/one/nio/serial/SerializationContext.java b/src/one/nio/serial/SerializationContext.java index 4109896..b909208 100755 --- a/src/one/nio/serial/SerializationContext.java +++ b/src/one/nio/serial/SerializationContext.java @@ -42,6 +42,7 @@ public int capacity() { } /** + * @param obj an object to put in the context * @return index for existing objects, -1-index for new */ public int put(Object obj) { diff --git a/src/one/nio/server/AcceptorConfig.java b/src/one/nio/server/AcceptorConfig.java index df85bbc..9c6a544 100644 --- a/src/one/nio/server/AcceptorConfig.java +++ b/src/one/nio/server/AcceptorConfig.java @@ -31,10 +31,13 @@ public class AcceptorConfig { public int sendBuf; public int tos; public int backlog = 128; + @Converter(method = "size") + public int notsentLowat; public boolean keepAlive = true; public boolean noDelay = true; public boolean tcpFastOpen = true; public boolean deferAccept; public boolean reusePort; + public boolean thinLto; public SslConfig ssl; } diff --git a/src/one/nio/server/SelectorThread.java b/src/one/nio/server/SelectorThread.java index 762a994..d35286b 100755 --- a/src/one/nio/server/SelectorThread.java +++ b/src/one/nio/server/SelectorThread.java @@ -34,8 +34,8 @@ public final class SelectorThread extends PayloadThread { long sessions; int maxReady; - public SelectorThread(int num, int dedicatedCpu, SchedulingPolicy schedulingPolicy) throws IOException { - super("NIO Selector #" + num); + public SelectorThread(int num, int dedicatedCpu, SchedulingPolicy schedulingPolicy, String name) throws IOException { + super(name); this.selector = Selector.create(); this.dedicatedCpu = dedicatedCpu; setSchedulingPolicy(schedulingPolicy); diff --git a/src/one/nio/server/Server.java b/src/one/nio/server/Server.java index cb1f677..a3216c5 100755 --- a/src/one/nio/server/Server.java +++ b/src/one/nio/server/Server.java @@ -16,22 +16,21 @@ package one.nio.server; -import one.nio.net.Selector; -import one.nio.net.Session; -import one.nio.net.Socket; -import one.nio.mgt.Management; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import java.io.IOException; -import java.util.ArrayList; import java.util.Arrays; -import java.util.List; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ThreadLocalRandom; import java.util.concurrent.atomic.LongAdder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import one.nio.mgt.Management; +import one.nio.net.Selector; +import one.nio.net.Session; +import one.nio.net.Socket; +import one.nio.server.acceptor.Acceptor; +import one.nio.server.acceptor.AcceptorFactory; + public class Server implements ServerMXBean { private static final Logger log = LoggerFactory.getLogger(Server.class); @@ -42,34 +41,25 @@ public class Server implements ServerMXBean { private volatile QueueStats queueStats; protected final int port; - protected final CountDownLatch startSync; - protected volatile AcceptorThread[] acceptors; + + protected volatile Acceptor acceptor; + protected volatile SelectorThread[] selectors; protected boolean useWorkers; protected final WorkerPool workers; protected final CleanupThread cleanup; protected boolean closeSessions; + protected boolean pinAcceptors; public Server(ServerConfig config) throws IOException { - List acceptors = new ArrayList<>(); - for (AcceptorConfig ac : config.acceptors) { - for (int i = 0; i < ac.threads; i++) { - acceptors.add(new AcceptorThread(this, ac, i)); - } - } - - if (acceptors.isEmpty()) { - throw new IllegalArgumentException("No configured acceptors"); - } - - this.acceptors = acceptors.toArray(new AcceptorThread[0]); - this.startSync = new CountDownLatch(this.acceptors.length); - this.port = this.acceptors[0].port; + this.acceptor = AcceptorFactory.get(config).create(this, config.acceptors); + this.port = acceptor.getSinglePort(); int processors = Runtime.getRuntime().availableProcessors(); SelectorThread[] selectors = new SelectorThread[config.selectors != 0 ? config.selectors : processors]; for (int i = 0; i < selectors.length; i++) { - selectors[i] = new SelectorThread(i, config.affinity ? i % processors : -1, config.schedulingPolicy); + String threadName = config.formatSelectorThreadName(i); + selectors[i] = new SelectorThread(i, config.affinity ? i % processors : -1, config.schedulingPolicy, threadName); selectors[i].setPriority(config.threadPriority); } this.selectors = selectors; @@ -81,6 +71,7 @@ public Server(ServerConfig config) throws IOException { this.cleanup = new CleanupThread(selectors, config.keepAlive); this.closeSessions = config.closeSessions; + this.pinAcceptors = config.pinAcceptors; this.selectorStats = new SelectorStats(); this.queueStats = new QueueStats(); @@ -88,7 +79,7 @@ public Server(ServerConfig config) throws IOException { public synchronized void reconfigure(ServerConfig config) throws IOException { useWorkers = config.maxWorkers > 0; - if (config.minWorkers > workers.getMaximumPoolSize()) { + if (config.minWorkers > workers.getMaximumPoolSize()) { workers.setMaximumPoolSize(useWorkers ? config.maxWorkers : 2); workers.setCorePoolSize(config.minWorkers); } else { @@ -97,47 +88,15 @@ public synchronized void reconfigure(ServerConfig config) throws IOException { } workers.setQueueTime(config.queueTime); - // Create a copy of the array, since the elements will be nulled out - // to allow reconfiguring multiple acceptors with the same address:port - AcceptorThread[] oldAcceptors = acceptors.clone(); - List newAcceptors = new ArrayList<>(); - for (AcceptorConfig ac : config.acceptors) { - int threads = 0; - for (int i = 0; i < oldAcceptors.length; i++) { - AcceptorThread oldAcceptor = oldAcceptors[i]; - if (oldAcceptor != null && oldAcceptor.port == ac.port && oldAcceptor.address.equals(ac.address)) { - if (++threads <= ac.threads) { - log.info("Reconfiguring acceptor: {}", oldAcceptor.getName()); - oldAcceptor.reconfigure(ac); - oldAcceptors[i] = null; - newAcceptors.add(oldAcceptor); - } - } - } - - for (; threads < ac.threads; threads++) { - AcceptorThread newAcceptor = new AcceptorThread(this, ac, threads); - log.info("New acceptor: {}", newAcceptor.getName()); - newAcceptor.start(); - newAcceptors.add(newAcceptor); - } - } - - for (AcceptorThread oldAcceptor : oldAcceptors) { - if (oldAcceptor != null) { - log.info("Stopping acceptor: {}", oldAcceptor.getName()); - oldAcceptor.shutdown(); - } - } - - acceptors = newAcceptors.toArray(new AcceptorThread[0]); + acceptor.reconfigure(config.acceptors); int processors = Runtime.getRuntime().availableProcessors(); SelectorThread[] selectors = this.selectors; if (config.selectors > selectors.length) { SelectorThread[] newSelectors = Arrays.copyOf(selectors, config.selectors); for (int i = selectors.length; i < config.selectors; i++) { - newSelectors[i] = new SelectorThread(i, config.affinity ? i % processors : -1, config.schedulingPolicy); + String threadName = config.formatSelectorThreadName(i); + newSelectors[i] = new SelectorThread(i, config.affinity ? i % processors : -1, config.schedulingPolicy, threadName); newSelectors[i].setPriority(config.threadPriority); newSelectors[i].start(); } @@ -146,6 +105,7 @@ public synchronized void reconfigure(ServerConfig config) throws IOException { cleanup.update(this.selectors, config.keepAlive); closeSessions = config.closeSessions; + pinAcceptors = config.pinAcceptors; } public synchronized void start() { @@ -153,13 +113,10 @@ public synchronized void start() { selector.start(); } - for (AcceptorThread acceptor : acceptors) { - acceptor.start(); - } + acceptor.start(); - // Wait until all AcceptorThreads are listening for incoming connections try { - startSync.await(); + acceptor.syncStart(); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } @@ -174,9 +131,7 @@ public synchronized void stop() { cleanup.shutdown(); - for (AcceptorThread acceptor : acceptors) { - acceptor.shutdown(); - } + acceptor.shutdown(); for (SelectorThread selector : selectors) { if (closeSessions) { @@ -201,14 +156,64 @@ public void run() { }); } - protected Session createSession(Socket socket) throws RejectedSessionException { + public Session createSession(Socket socket) throws RejectedSessionException { return new Session(socket); } - protected void register(Session session) { + public void register(Session session, int acceptorIndex, int acceptorGroupSize) { + if (pinAcceptors) { + getSmallestPinnedSelector(acceptorIndex, acceptorGroupSize).register(session); + return; + } + register(session); + } + + public void register(Session session) { getSmallestSelector().register(session); } + /* + * If `pinAcceptors` is enabled for the server, accepted sessions are distributed across the disjunctive set of selectors. + * When server is configured to have less `acceptors`(K) less than `selectors`(N), the selectors group for the given acceptor + * by its indices forms an finite arithmetic sequence starting from the acceptor index with step K. + * Example: ServerConfig.acceptors = 3, ServerConfig.selectors = 8. + * Acceptor #0 -> Selectors #0, #3, #6 + * Acceptor #1 -> Selectors #1, #4, #7 + * Acceptor #2 -> Selectors #2, #5 + * Across the selectors' subset, the selector to serve the session is chosen on a random basis. + * Provided the server is configured to have more `acceptors`(K) than `selectors`(N), the serving selector index is calculated out of acceptor index modulo N. + * Example: ServerConfig.acceptors = 8, ServerConfig.selectors = 3. + * Acceptor #0 -> Selector #0 + * Acceptor #1 -> Selector #1 + * Acceptor #2 -> Selector #2 + * Acceptor #3 -> Selector #0 + * ... + * Acceptor #7 -> Selector #1 + * Base configuration 1: acceptors = 1, selectors = N. The single acceptor balances sessions across all N selectors randomly. + * Base configuration 2: acceptors = N, selectors = N. Each acceptor has a single designated selector to serve the sessions. + */ + private Selector getSmallestPinnedSelector(int idx, int total) { + Selector chosen; + SelectorThread[] selectors = this.selectors; + if (total >= selectors.length) { + chosen = selectors[idx % selectors.length].selector; + } else { + int q = selectors.length / total; + if (q * total + idx < selectors.length) { + q++; + } + if (q == 1) { + chosen = selectors[idx].selector; + } else { + ThreadLocalRandom r = ThreadLocalRandom.current(); + Selector a = selectors[r.nextInt(q) * total + idx].selector; + Selector b = selectors[r.nextInt(q) * total + idx].selector; + chosen = a.size() < b.size() ? a : b; + } + } + return chosen; + } + private Selector getSmallestSelector() { SelectorThread[] selectors = this.selectors; @@ -257,20 +262,12 @@ public int getWorkersActive() { @Override public long getAcceptedSessions() { - long result = 0; - for (AcceptorThread acceptor : acceptors) { - result += acceptor.acceptedSessions; - } - return result; + return acceptor.getAcceptedSessions(); } @Override public long getRejectedSessions() { - long result = 0; - for (AcceptorThread acceptor : acceptors) { - result += acceptor.rejectedSessions; - } - return result; + return acceptor.getRejectedSessions(); } @Override @@ -333,10 +330,7 @@ public long getRequestsRejected() { @Override public synchronized void reset() { - for (AcceptorThread acceptor : acceptors) { - acceptor.acceptedSessions = 0; - acceptor.rejectedSessions = 0; - } + acceptor.resetCounters(); for (SelectorThread selector : selectors) { selector.operations = 0; diff --git a/src/one/nio/server/ServerConfig.java b/src/one/nio/server/ServerConfig.java index 3fb2334..787a1ea 100644 --- a/src/one/nio/server/ServerConfig.java +++ b/src/one/nio/server/ServerConfig.java @@ -16,6 +16,8 @@ package one.nio.server; +import java.util.Locale; + import one.nio.config.Config; import one.nio.config.Converter; import one.nio.net.ConnectionString; @@ -24,7 +26,11 @@ @Config public class ServerConfig { + + public static String DEFAULT_SELECTOR_THREAD_NAME_FORMAT = "NIO Selector #%d"; + public AcceptorConfig[] acceptors; + public boolean multiAcceptor; public int selectors; public boolean affinity; public int minWorkers; @@ -36,6 +42,10 @@ public class ServerConfig { public int threadPriority = Thread.NORM_PRIORITY; public SchedulingPolicy schedulingPolicy; public boolean closeSessions; + public boolean pinAcceptors; + + @Converter(value = ServerConfig.class, method = "threadNameFormat") + public String selectorThreadNameFormat = DEFAULT_SELECTOR_THREAD_NAME_FORMAT; public ServerConfig() { } @@ -60,6 +70,8 @@ private ServerConfig(ConnectionString conn) { this.threadPriority = conn.getIntParam("threadPriority", Thread.NORM_PRIORITY); this.schedulingPolicy = SchedulingPolicy.valueOf(conn.getStringParam("schedulingPolicy", "OTHER")); this.closeSessions = conn.getBooleanParam("closeSessions", false); + this.keepAlive = conn.getIntParam("keepAlive", 0); + this.selectorThreadNameFormat = threadNameFormat(conn.getStringParam("selectorThreadNameFormat", DEFAULT_SELECTOR_THREAD_NAME_FORMAT)); } // Do not use for new servers! Use ConfigParser instead @@ -71,4 +83,15 @@ public static ServerConfig from(String conn) { public static ServerConfig from(ConnectionString conn) { return new ServerConfig(conn); } + + public String formatSelectorThreadName(int threadNumber) { + return String.format(Locale.ROOT, selectorThreadNameFormat, threadNumber); + } + + public static String threadNameFormat(String s) { + // validate pattern + String.format(Locale.ROOT, s, 0); + + return s; + } } diff --git a/src/one/nio/server/acceptor/Acceptor.java b/src/one/nio/server/acceptor/Acceptor.java new file mode 100644 index 0000000..17132c6 --- /dev/null +++ b/src/one/nio/server/acceptor/Acceptor.java @@ -0,0 +1,39 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; + +import one.nio.server.AcceptorConfig; + +public interface Acceptor { + void reconfigure(AcceptorConfig... configs) throws IOException; + + void start(); + + void shutdown(); + + void syncStart() throws InterruptedException; + + long getAcceptedSessions(); + + long getRejectedSessions(); + + void resetCounters(); + + int getSinglePort(); +} diff --git a/src/one/nio/server/acceptor/AcceptorFactory.java b/src/one/nio/server/acceptor/AcceptorFactory.java new file mode 100644 index 0000000..c16273c --- /dev/null +++ b/src/one/nio/server/acceptor/AcceptorFactory.java @@ -0,0 +1,56 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; + +import one.nio.server.AcceptorConfig; +import one.nio.server.Server; +import one.nio.server.ServerConfig; + +public abstract class AcceptorFactory { + private AcceptorFactory() { + } + + public abstract Acceptor create(Server s, AcceptorConfig... configs) throws IOException; + + public static AcceptorFactory get(ServerConfig sc) { + if (sc.multiAcceptor) { + return MultiAcceptorFactory.INSTANCE; + } else { + return DefaultAcceptorFactory.INSTANCE; + } + } + + private static class DefaultAcceptorFactory extends AcceptorFactory { + private static final DefaultAcceptorFactory INSTANCE = new DefaultAcceptorFactory(); + + @Override + public Acceptor create(Server s, AcceptorConfig... configs) throws IOException { + return new DefaultAcceptor(s, configs); + } + } + + private static class MultiAcceptorFactory extends AcceptorFactory { + private static final MultiAcceptorFactory INSTANCE = new MultiAcceptorFactory(); + + @Override + public Acceptor create(Server s, AcceptorConfig... configs) throws IOException { + return new MultiAcceptor(s, configs); + } + } +} \ No newline at end of file diff --git a/src/one/nio/server/acceptor/AcceptorSupport.java b/src/one/nio/server/acceptor/AcceptorSupport.java new file mode 100644 index 0000000..0ad8f21 --- /dev/null +++ b/src/one/nio/server/acceptor/AcceptorSupport.java @@ -0,0 +1,66 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; + +import one.nio.net.Socket; +import one.nio.net.SslContext; +import one.nio.server.AcceptorConfig; + +public class AcceptorSupport { + private AcceptorSupport() {} + + public static Socket createServerSocket(AcceptorConfig config) throws IOException { + Socket serverSocket = Socket.createServerSocket(); + if (config.ssl != null) { + SslContext sslContext = SslContext.create(); + sslContext.configure(config.ssl); + serverSocket = serverSocket.sslWrap(sslContext); + } + if (config.recvBuf != 0) serverSocket.setRecvBuffer(config.recvBuf); + if (config.sendBuf != 0) serverSocket.setSendBuffer(config.sendBuf); + if (config.tos != 0) serverSocket.setTos(config.tos); + if (config.notsentLowat != 0) serverSocket.setNotsentLowat(config.notsentLowat); + if (config.deferAccept) serverSocket.setDeferAccept(true); + + serverSocket.setKeepAlive(config.keepAlive); + serverSocket.setNoDelay(config.noDelay); + serverSocket.setTcpFastOpen(config.tcpFastOpen); + serverSocket.setReuseAddr(true, config.reusePort); + serverSocket.setThinLinearTimeouts(config.thinLto); + return serverSocket; + } + + public static void reconfigureSocket(Socket socket, AcceptorConfig config) throws IOException { + if (config.recvBuf != 0) socket.setRecvBuffer(config.recvBuf); + if (config.sendBuf != 0) socket.setSendBuffer(config.sendBuf); + if (config.tos != 0) socket.setTos(config.tos); + if (config.notsentLowat != 0) socket.setNotsentLowat(config.notsentLowat); + socket.setDeferAccept(config.deferAccept); + socket.setKeepAlive(config.keepAlive); + socket.setNoDelay(config.noDelay); + socket.setTcpFastOpen(config.tcpFastOpen); + socket.setReuseAddr(true, config.reusePort); + socket.setThinLinearTimeouts(config.thinLto); + + SslContext sslContext = socket.getSslContext(); + if (sslContext != null && config.ssl != null) { + sslContext.configure(config.ssl); + } + } +} \ No newline at end of file diff --git a/src/one/nio/server/AcceptorThread.java b/src/one/nio/server/acceptor/AcceptorThread.java similarity index 57% rename from src/one/nio/server/AcceptorThread.java rename to src/one/nio/server/acceptor/AcceptorThread.java index 0573c0f..891b1e4 100755 --- a/src/one/nio/server/AcceptorThread.java +++ b/src/one/nio/server/acceptor/AcceptorThread.java @@ -14,19 +14,24 @@ * limitations under the License. */ -package one.nio.server; - -import one.nio.net.Session; -import one.nio.net.Socket; -import one.nio.net.SslContext; +package one.nio.server.acceptor; import java.io.IOException; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import one.nio.net.Session; +import one.nio.net.Socket; +import one.nio.server.AcceptorConfig; +import one.nio.server.RejectedSessionException; +import one.nio.server.Server; + final class AcceptorThread extends Thread { private static final Logger log = LoggerFactory.getLogger(AcceptorThread.class); + final DefaultAcceptorGroup group; + final int num; final String address; final int port; final int backlog; @@ -36,53 +41,22 @@ final class AcceptorThread extends Thread { volatile long acceptedSessions; volatile long rejectedSessions; - AcceptorThread(Server server, AcceptorConfig config, int num) throws IOException { + AcceptorThread(Server server, AcceptorConfig config, DefaultAcceptorGroup group, int num) throws IOException { super("NIO Acceptor " + config.address + ":" + config.port + " #" + num); + this.group = group; + this.num = num; this.address = config.address; this.port = config.port; this.backlog = config.backlog; this.server = server; - Socket serverSocket = Socket.createServerSocket(); - if (config.ssl != null) { - SslContext sslContext = SslContext.create(); - sslContext.configure(config.ssl); - serverSocket = serverSocket.sslWrap(sslContext); - } - this.serverSocket = serverSocket; - - if (config.recvBuf != 0) serverSocket.setRecvBuffer(config.recvBuf); - if (config.sendBuf != 0) serverSocket.setSendBuffer(config.sendBuf); - if (config.tos != 0) serverSocket.setTos(config.tos); - if (config.deferAccept) serverSocket.setDeferAccept(true); - - serverSocket.setKeepAlive(config.keepAlive); - serverSocket.setNoDelay(config.noDelay); - serverSocket.setTcpFastOpen(config.tcpFastOpen); - serverSocket.setReuseAddr(true, config.reusePort); + Socket serverSocket = AcceptorSupport.createServerSocket(config); serverSocket.bind(address, port, backlog); + this.serverSocket = serverSocket; } void reconfigure(AcceptorConfig config) throws IOException { - if (config.recvBuf != 0) { - serverSocket.setRecvBuffer(config.recvBuf); - } - if (config.sendBuf != 0) { - serverSocket.setSendBuffer(config.sendBuf); - } - if (config.tos != 0) { - serverSocket.setTos(config.tos); - } - serverSocket.setDeferAccept(config.deferAccept); - serverSocket.setKeepAlive(config.keepAlive); - serverSocket.setNoDelay(config.noDelay); - serverSocket.setTcpFastOpen(config.tcpFastOpen); - serverSocket.setReuseAddr(true, config.reusePort); - - SslContext sslContext = serverSocket.getSslContext(); - if (sslContext != null && config.ssl != null) { - sslContext.configure(config.ssl); - } + AcceptorSupport.reconfigureSocket(serverSocket, config); } void shutdown() { @@ -102,7 +76,7 @@ public void run() { log.error("Cannot start listening at {}", port, e); return; } finally { - server.startSync.countDown(); + group.syncLatch.countDown(); } while (serverSocket.isOpen()) { @@ -110,12 +84,10 @@ public void run() { try { socket = serverSocket.acceptNonBlocking(); Session session = server.createSession(socket); - server.register(session); + server.register(session, num, group.size()); acceptedSessions++; } catch (RejectedSessionException e) { - if (log.isDebugEnabled()) { - log.debug("Rejected session from {}", socket.getRemoteAddress(), e); - } + log.debug("Rejected session from {}", socket.getRemoteAddress(), e); rejectedSessions++; socket.close(); } catch (Throwable e) { diff --git a/src/one/nio/server/acceptor/DefaultAcceptor.java b/src/one/nio/server/acceptor/DefaultAcceptor.java new file mode 100644 index 0000000..ebcdf04 --- /dev/null +++ b/src/one/nio/server/acceptor/DefaultAcceptor.java @@ -0,0 +1,124 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; +import java.util.Arrays; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import one.nio.server.AcceptorConfig; +import one.nio.server.Server; + +public class DefaultAcceptor implements Acceptor { + private static final Logger log = LoggerFactory.getLogger(DefaultAcceptor.class); + + private final Server server; + + private volatile DefaultAcceptorGroup[] acceptorGroups; + + DefaultAcceptor(Server server, AcceptorConfig... configs) throws IOException { + this.server = server; + + DefaultAcceptorGroup[] acceptorGroups = new DefaultAcceptorGroup[configs.length]; + for (int configIdx = 0; configIdx < configs.length; configIdx++) { + AcceptorConfig config = configs[configIdx]; + acceptorGroups[configIdx] = new DefaultAcceptorGroup(server, config); + } + this.acceptorGroups = acceptorGroups; + } + + @Override + public void reconfigure(AcceptorConfig... configs) throws IOException { + // Create a copy of the array, since the elements will be nulled out + // to allow reconfiguring multiple acceptors with the same address:port + DefaultAcceptorGroup[] oldAcceptorGroups = this.acceptorGroups.clone(); + DefaultAcceptorGroup[] newAcceptorGroups = new DefaultAcceptorGroup[configs.length]; + for (int configIdx = 0; configIdx < configs.length; configIdx++) { + AcceptorConfig ac = configs[configIdx]; + DefaultAcceptorGroup oldGroup = configIdx < oldAcceptorGroups.length ? oldAcceptorGroups[configIdx] : null; + if (oldGroup != null && oldGroup.isSameAddressPort(ac)) { + log.info("Reconfiguring acceptor group: {}", oldGroup); + oldGroup.reconfigure(ac); + newAcceptorGroups[configIdx] = oldGroup; + oldAcceptorGroups[configIdx] = null; + } else { + DefaultAcceptorGroup newGroup = new DefaultAcceptorGroup(server, ac); + log.info("New acceptor group: {}", newGroup); + newAcceptorGroups[configIdx] = newGroup; + newGroup.start(); + } + } + + for (DefaultAcceptorGroup oldGroup : oldAcceptorGroups) { + if (oldGroup != null) { + log.info("Stopping acceptor group: {}", oldGroup); + oldGroup.shutdown(); + } + } + + this.acceptorGroups = newAcceptorGroups; + } + + @Override + public void start() { + for (DefaultAcceptorGroup acceptorGroup : this.acceptorGroups) { + acceptorGroup.start(); + } + } + + @Override + public void shutdown() { + for (DefaultAcceptorGroup acceptorGroup : this.acceptorGroups) { + acceptorGroup.shutdown(); + } + } + + @Override + public void syncStart() throws InterruptedException { + for (DefaultAcceptorGroup acceptorGroup : this.acceptorGroups) { + acceptorGroup.syncStart(); + } + } + + @Override + public long getAcceptedSessions() { + return Arrays.stream(this.acceptorGroups) + .mapToLong(DefaultAcceptorGroup::getAcceptedSessions) + .sum(); + } + + @Override + public long getRejectedSessions() { + return Arrays.stream(this.acceptorGroups) + .mapToLong(DefaultAcceptorGroup::getRejectedSessions) + .sum(); + } + + @Override + public void resetCounters() { + for (DefaultAcceptorGroup acceptorGroup : this.acceptorGroups) { + acceptorGroup.resetCounters(); + } + } + + @Override + public int getSinglePort() { + return acceptorGroups[0].getPort(); + } +} \ No newline at end of file diff --git a/src/one/nio/server/acceptor/DefaultAcceptorGroup.java b/src/one/nio/server/acceptor/DefaultAcceptorGroup.java new file mode 100644 index 0000000..271c6c4 --- /dev/null +++ b/src/one/nio/server/acceptor/DefaultAcceptorGroup.java @@ -0,0 +1,143 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; +import java.util.Arrays; +import java.util.concurrent.CountDownLatch; + +import one.nio.server.AcceptorConfig; +import one.nio.server.Server; + +class DefaultAcceptorGroup { + private final Server server; + private final String address; + private final int port; + + private volatile AcceptorThread[] acceptors; + + CountDownLatch syncLatch; + + public DefaultAcceptorGroup(Server server, AcceptorConfig ac) throws IOException { + this.server = server; + this.address = ac.address; + this.port = ac.port; + + AcceptorThread[] acceptors = new AcceptorThread[ac.threads]; + for (int threadId = 0; threadId < ac.threads; threadId++) { + acceptors[threadId] = new AcceptorThread(server, ac, this, threadId); + } + this.acceptors = acceptors; + } + + public void reconfigure(AcceptorConfig ac) throws IOException { + if (!isSameAddressPort(ac)) { + throw new IllegalArgumentException("Acceptor config has different address:port"); + } + AcceptorThread[] oldAcceptors = this.acceptors; + if (ac.threads < oldAcceptors.length) { + for (int i = 0; i < oldAcceptors.length; i++) { + if (i < ac.threads) { + oldAcceptors[i].reconfigure(ac); + } else { + oldAcceptors[i].shutdown(); + } + } + this.acceptors = Arrays.copyOf(oldAcceptors, ac.threads); + } else { + AcceptorThread[] newAcceptors = Arrays.copyOf(oldAcceptors, ac.threads); + for (int i = 0; i < newAcceptors.length; i++) { + if (newAcceptors[i] != null) { + newAcceptors[i].reconfigure(ac); + } else { + newAcceptors[i] = new AcceptorThread(server, ac, this, i); + newAcceptors[i].start(); + } + } + this.acceptors = newAcceptors; + } + } + + public boolean isSameAddressPort(AcceptorConfig ac) { + return ac.address.equals(address) && ac.port == port; + } + + public void start() { + AcceptorThread[] acceptors = this.acceptors; + this.syncLatch = new CountDownLatch(acceptors.length); + for (AcceptorThread acceptor : acceptors) { + acceptor.start(); + } + } + + public void syncStart() throws InterruptedException { + if (this.syncLatch != null) { + this.syncLatch.await(); + } + } + + public void shutdown() { + for (AcceptorThread acceptor : acceptors) { + acceptor.shutdown(); + } + } + + public String getAddress() { + return address; + } + + public int getPort() { + return port; + } + + public int size() { + AcceptorThread[] acceptors = this.acceptors; + return acceptors == null ? 0 : acceptors.length; + } + + public long getAcceptedSessions() { + long sum = 0; + for (AcceptorThread acceptor : acceptors) { + sum += acceptor.acceptedSessions; + } + return sum; + } + + public long getRejectedSessions() { + long sum = 0; + for (AcceptorThread acceptor : acceptors) { + sum += acceptor.rejectedSessions; + } + return sum; + } + + public void resetCounters() { + for (AcceptorThread acceptor : acceptors) { + acceptor.acceptedSessions = 0; + acceptor.rejectedSessions = 0; + } + } + + @Override + public String toString() { + return "DefaultAcceptorGroup{" + + "address='" + address + '\'' + + ", port=" + port + + ", size=" + size() + + '}'; + } +} \ No newline at end of file diff --git a/src/one/nio/server/acceptor/MultiAcceptSession.java b/src/one/nio/server/acceptor/MultiAcceptSession.java new file mode 100644 index 0000000..9edea3a --- /dev/null +++ b/src/one/nio/server/acceptor/MultiAcceptSession.java @@ -0,0 +1,52 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; +import java.nio.channels.SelectionKey; + +import one.nio.net.Selector; +import one.nio.net.Session; +import one.nio.net.Socket; +import one.nio.os.NativeLibrary; +import one.nio.server.AcceptorConfig; + +class MultiAcceptSession extends Session { + final int backlog; + final MultiAcceptorGroup group; + final int idx; + + MultiAcceptSession(Socket socket, int backlog, MultiAcceptorGroup group, int idx) { + super(socket, acceptOp()); + this.backlog = backlog; + this.group = group; + this.idx = idx; + } + + void listen(Selector selector) throws IOException { + socket.listen(backlog); + selector.register(this); + } + + public void reconfigure(AcceptorConfig newConfig) throws IOException { + AcceptorSupport.reconfigureSocket(socket, newConfig); + } + + static int acceptOp() { + return NativeLibrary.IS_SUPPORTED ? READABLE : SelectionKey.OP_ACCEPT; + } +} \ No newline at end of file diff --git a/src/one/nio/server/acceptor/MultiAcceptor.java b/src/one/nio/server/acceptor/MultiAcceptor.java new file mode 100644 index 0000000..9f6f606 --- /dev/null +++ b/src/one/nio/server/acceptor/MultiAcceptor.java @@ -0,0 +1,134 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; +import java.util.Arrays; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import one.nio.server.AcceptorConfig; +import one.nio.server.Server; + +class MultiAcceptor implements Acceptor { + private static final Logger log = LoggerFactory.getLogger(MultiAcceptor.class); + + private final MultiAcceptorThread thread; + + private volatile MultiAcceptorGroup[] acceptorGroups; + + MultiAcceptor(Server server, AcceptorConfig... configs) throws IOException { + this.thread = new MultiAcceptorThread(server); + + MultiAcceptorGroup[] newGroups = new MultiAcceptorGroup[configs.length]; + for (int configIdx = 0; configIdx < configs.length; configIdx++) { + AcceptorConfig config = configs[configIdx]; + validateConfig(config); + newGroups[configIdx] = new MultiAcceptorGroup(thread, config); + } + + setAcceptorGroups(newGroups); + } + + @Override + public void reconfigure(AcceptorConfig... configs) throws IOException { + MultiAcceptorGroup[] oldGroups = this.acceptorGroups.clone(); + MultiAcceptorGroup[] newGroups = new MultiAcceptorGroup[configs.length]; + for (int configIdx = 0; configIdx < configs.length; configIdx++) { + AcceptorConfig newConfig = configs[configIdx]; + validateConfig(newConfig); + MultiAcceptorGroup oldGroup = configIdx < oldGroups.length ? oldGroups[configIdx] : null; + if (oldGroup != null && oldGroup.isSameAddressPort(newConfig)) { + log.info("Reconfiguring acceptor group: {}", oldGroup); + oldGroup.reconfigure(newConfig); + newGroups[configIdx] = oldGroup; + oldGroups[configIdx] = null; + } else { + MultiAcceptorGroup newGroup = new MultiAcceptorGroup(thread, newConfig); + log.info("New acceptor group: {}", newGroup); + newGroups[configIdx] = newGroup; + newGroup.start(); + } + } + + for (MultiAcceptorGroup oldGroup : oldGroups) { + if (oldGroup != null) { + oldGroup.close(); + } + } + + setAcceptorGroups(newGroups); + } + + @Override + public void start() { + thread.start(); + for (MultiAcceptorGroup group : acceptorGroups) { + try { + group.start(); + } catch (IOException e) { + log.error("Cannot start listening at {}", group, e); + } + } + } + + @Override + public void syncStart() { + // not needed, this is a single thread + } + + @Override + public void shutdown() { + thread.shutdown(); + } + + @Override + public long getAcceptedSessions() { + return thread.acceptedSessions; + } + + @Override + public long getRejectedSessions() { + return thread.rejectedSessions; + } + + @Override + public void resetCounters() { + thread.acceptedSessions = 0; + thread.rejectedSessions = 0; + } + + @Override + public int getSinglePort() { + return acceptorGroups[0].getPort(); + } + + private void validateConfig(AcceptorConfig newConfig) { + if (newConfig.threads <= 0) { + throw new IllegalArgumentException("Cannot create acceptor with 0 ports"); + } + if (newConfig.threads > 1 && !newConfig.reusePort) { + throw new IllegalArgumentException("Cannot create multiport acceptor without reusePort"); + } + } + + private void setAcceptorGroups(MultiAcceptorGroup[] newGroups) { + this.acceptorGroups = newGroups; + thread.setName("NIO MultiAcceptor " + Arrays.toString(newGroups)); + } +} \ No newline at end of file diff --git a/src/one/nio/server/acceptor/MultiAcceptorGroup.java b/src/one/nio/server/acceptor/MultiAcceptorGroup.java new file mode 100644 index 0000000..d2c44bd --- /dev/null +++ b/src/one/nio/server/acceptor/MultiAcceptorGroup.java @@ -0,0 +1,112 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; +import java.util.Arrays; + +import one.nio.net.Socket; +import one.nio.server.AcceptorConfig; + +public class MultiAcceptorGroup { + private final MultiAcceptorThread thread; + private final String address; + private final int port; + + private volatile MultiAcceptSession[] sessions; + + MultiAcceptorGroup(MultiAcceptorThread thread, AcceptorConfig config) throws IOException { + this.thread = thread; + this.address = config.address; + this.port = config.port; + + MultiAcceptSession[] sessions = new MultiAcceptSession[config.threads]; + for (int sessionIdx = 0; sessionIdx < config.threads; sessionIdx++) { + sessions[sessionIdx] = createMultiAcceptSession(config, sessionIdx); + } + this.sessions = sessions; + } + + boolean isSameAddressPort(AcceptorConfig config) { + return this.address.equals(config.address) && this.port == config.port; + } + + public String getAddress() { + return address; + } + + public int getPort() { + return port; + } + + void start() throws IOException { + for (MultiAcceptSession session : sessions) { + thread.register(session); + } + } + + void close() { + for (MultiAcceptSession session : sessions) { + session.close(); + } + } + + int size() { + MultiAcceptSession[] sessions = this.sessions; + return sessions == null ? 0 : sessions.length; + } + + void reconfigure(AcceptorConfig newConfig) throws IOException { + if (!isSameAddressPort(newConfig)) { + throw new IllegalArgumentException("Acceptor config has different address:port"); + } + MultiAcceptSession[] oldSessions = this.sessions; + MultiAcceptSession[] newSessions = Arrays.copyOf(oldSessions, newConfig.threads); + if (oldSessions.length > newConfig.threads) { + for (int sessionIdx = 0; sessionIdx < oldSessions.length; sessionIdx++) { + if (sessionIdx < newSessions.length) { + oldSessions[sessionIdx].reconfigure(newConfig); + } else { + oldSessions[sessionIdx].close(); + } + } + } else { + for (int sessionIdx = 0; sessionIdx < newSessions.length; sessionIdx++) { + MultiAcceptSession session = newSessions[sessionIdx]; + if (session != null) { + session.reconfigure(newConfig); + } else { + MultiAcceptSession acceptSession = createMultiAcceptSession(newConfig, sessionIdx); + thread.register(acceptSession); + } + } + } + this.sessions = newSessions; + } + + @Override + public String toString() { + return address + ':' + port + 'x' + size(); + } + + private MultiAcceptSession createMultiAcceptSession(AcceptorConfig config, int sessionIdx) throws IOException { + Socket serverSocket = AcceptorSupport.createServerSocket(config); + serverSocket.setBlocking(false); + serverSocket.bind(config.address, config.port, config.backlog); + return new MultiAcceptSession(serverSocket, config.backlog, this, sessionIdx); + } +} \ No newline at end of file diff --git a/src/one/nio/server/acceptor/MultiAcceptorThread.java b/src/one/nio/server/acceptor/MultiAcceptorThread.java new file mode 100644 index 0000000..6746566 --- /dev/null +++ b/src/one/nio/server/acceptor/MultiAcceptorThread.java @@ -0,0 +1,95 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.server.acceptor; + +import java.io.IOException; +import java.util.Iterator; +import java.util.concurrent.RejectedExecutionException; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import one.nio.net.Selector; +import one.nio.net.Session; +import one.nio.net.Socket; +import one.nio.server.Server; + +class MultiAcceptorThread extends Thread { + private static final Logger log = LoggerFactory.getLogger(MultiAcceptorThread.class); + + private static final int MAX_ACCEPTED_PER_SOCKET = 128; + + private final Server server; + private final Selector selector; + + volatile long acceptedSessions; + volatile long rejectedSessions; + + MultiAcceptorThread(Server server) throws IOException { + this.server = server; + this.selector = Selector.create(); + } + + void register(MultiAcceptSession session) throws IOException { + session.listen(selector); + } + + @Override + public void run() { + Socket clientSocket = null; + try { + while (!Thread.currentThread().isInterrupted() && selector.isOpen()) { + Iterator it = selector.select(); + while (it.hasNext()) { + MultiAcceptSession as = (MultiAcceptSession) it.next(); + int accepted = 0; + while (accepted < MAX_ACCEPTED_PER_SOCKET && (clientSocket = as.socket().acceptNonBlocking()) != null) { + try { + Session clientSession = server.createSession(clientSocket); + server.register(clientSession, as.idx, as.group.size()); + clientSocket = null; + acceptedSessions++; + accepted++; + } catch (RejectedExecutionException e) { + log.debug("Rejected session from {}", clientSocket.getRemoteAddress(), e); + rejectedSessions++; + clientSocket.close(); + clientSocket = null; + } + } + } + } + } catch (Throwable t) { + if (selector.isOpen()) { + log.error("Cannot accept incoming connection", t); + } + if (clientSocket != null) { + clientSocket.close(); + } + } + } + + void shutdown() { + selector.close(); + interrupt(); + try { + join(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } +} \ No newline at end of file diff --git a/src/one/nio/util/JavaFeatures.java b/src/one/nio/util/JavaFeatures.java index 29ec3ff..7298fea 100644 --- a/src/one/nio/util/JavaFeatures.java +++ b/src/one/nio/util/JavaFeatures.java @@ -56,7 +56,10 @@ public static void onSpinWait() { } /** - * Calls Class.isRecord() since Java 14 preview; returns false otherwise + * Calls Class.isRecord() since Java 14 preview + * + * @param cls a class object + * @return the result of the Class.isRecord() method invoked. It is always false, if the version of the JVM Runtime is less than 14 */ public static boolean isRecord(Class cls) { if (isRecord != null) { diff --git a/test/one/nio/config/ConfigParserTest.java b/test/one/nio/config/ConfigParserTest.java index a7afccb..14e550a 100644 --- a/test/one/nio/config/ConfigParserTest.java +++ b/test/one/nio/config/ConfigParserTest.java @@ -28,7 +28,7 @@ public class ConfigParserTest { "keepAlive: 120s\n" + "maxWorkers: 1000\n" + "queueTime: 50MS\n" + - "\n" + + "selectorThreadNameFormat: push sel-r #%d\n" + "acceptors:\n" + " - port: 443\n" + " backlog: 10000\n" + @@ -68,6 +68,8 @@ public void testConfigParser() throws Exception { assertEquals(50, config.queueTime); assertEquals(0, config.minWorkers); assertEquals(0, config.selectors); + assertEquals("push sel-r #%d", config.selectorThreadNameFormat); + assertEquals("push sel-r #42", config.formatSelectorThreadName(42)); assertEquals(false, config.affinity); assertEquals(Thread.NORM_PRIORITY, config.threadPriority); diff --git a/test/one/nio/http/ChunkedEventReaderTest.java b/test/one/nio/http/ChunkedEventReaderTest.java new file mode 100644 index 0000000..8c6ec74 --- /dev/null +++ b/test/one/nio/http/ChunkedEventReaderTest.java @@ -0,0 +1,261 @@ +/* + * Copyright 2019 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.http; + +import one.nio.http.EventSource.Event; +import one.nio.net.ConnectionString; +import one.nio.net.SocketUtil; +import one.nio.pool.PoolException; +import one.nio.util.Hex; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; + +import static org.junit.Assert.*; + +/** + * Unit tests for client and server support for HTTP Chunked transfer encoding + */ +public class ChunkedEventReaderTest { + private static final String ENDPOINT = "/echoChunked"; + + private static TestServer server; + private static HttpClient client; + + @BeforeClass + public static void beforeAll() throws IOException { + int availablePort = SocketUtil.getFreePort(); + server = new TestServer(HttpServerConfigFactory.create(availablePort)); + server.start(); + client = new HttpClient(new ConnectionString("http://127.0.0.1:" + availablePort+"?bufferSize=8000")); + } + + @AfterClass + public static void afterAll() { + client.close(); + server.stop(); + } + + private static String size(int size) { + return ENDPOINT + "?size=" + size; + } + + + @Test + public void testNoEvents() throws InterruptedException, PoolException, IOException, HttpException { + final Request req = echoReq( "".getBytes(), 0 ); + EventSourceResponse events = client.openEvents( req, 1000 ); + + assertNull("poll() must return null when EOF", events.poll() ); + } + + + @Test + public void testEmptyEvent() throws InterruptedException, PoolException, IOException, HttpException { + final Request req = echoReq( ":only comment and no data\n\n".getBytes(), 0 ); + EventSourceResponse events = client.openEvents( req, 1000 ); + + Event event = events.poll(); + assertEquals( "only comment and no data", event.comment() ); + assertTrue("Events consisting of only comment are empty", event.isEmpty() ); + } + + @Test + public void testEvents() throws InterruptedException, PoolException, IOException, HttpException { + final Request req = echoReq( + ("id:CAFEBABE1\n" + + "event:testmebabe\n" + + "data:dataisgold\n" + + "\n" + + "id: CAFEBABE2\n" + + "event: testmebabe2\n" + + "nosuchfield: oioioi\n" + + "data: dataisgold2\n" + + "\n" + + "\n").getBytes(), 7 ); + EventSourceResponse events = client.openEvents( req, 1000 ); + + { + Event event = events.poll(); + assertEquals( "CAFEBABE1", event.id() ); + assertEquals( "testmebabe", event.name() ); + assertEquals( "dataisgold", event.data() ); + } + { + Event event = events.poll(); + assertEquals( "Space after colon must be ignored in id", "CAFEBABE2", event.id() ); + assertEquals( "Space after colon must be ignored in event","testmebabe2", event.name() ); + assertEquals( "Space after colon must be ignored in data","dataisgold2", event.data() ); + } + + assertNull("poll() must return null when EOF", events.poll() ); + } + + @Test + public void testLargeEvent() throws InterruptedException, PoolException, IOException, HttpException { + byte[] bytes = new byte[20001]; + ThreadLocalRandom.current().nextBytes(bytes); + String data = Hex.toHex(bytes); // length=40000 > 32768 + final Request req = echoReq(32786, + "id:CAFEBABE1\n" + + "event:testmebabe\n" + + "data:", + data + "\n" + + "\n", + "id: CAFEBABE2\n" + + "event: testmebabe2\n" + + "nosuchfield: oioioi\n" + + "data: ", + data+"2\n" + + "\n" + + "\n"); + EventSourceResponse events = client.openEvents( req, 1000 ); + { + Event event = events.poll(); + assertEquals( "CAFEBABE1", event.id() ); + assertEquals( "testmebabe", event.name() ); + assertEquals( data, event.data() ); + } + { + Event event = events.poll(); + assertEquals( "Space after colon must be ignored in id", "CAFEBABE2", event.id() ); + assertEquals( "Space after colon must be ignored in event","testmebabe2", event.name() ); + assertEquals( "Space after colon must be ignored in data",data + "2", event.data() ); + } + + assertNull("poll() must return null when EOF", events.poll() ); + } + + private Request echoReq(byte[] body, int chunkSize) { + if ( chunkSize == 0 ) + chunkSize = body.length; + Request req = client.createRequest( Request.METHOD_PUT, size( chunkSize ) ); + server.data = Collections.singletonList(body); + return req; + } + + private Request echoReq(int chunkSize, String... parts) { + Request req = client.createRequest( Request.METHOD_PUT, size( chunkSize ) ); + server.data = new ArrayList<>(); + for (String part : parts) { + server.data.add(part.getBytes(StandardCharsets.UTF_8)); + } + return req; + } + + private Request echoReqRaw(byte[] body, int chunkSize) { + if ( chunkSize == 0 ) + chunkSize = body.length; + Request req = client.createRequest( Request.METHOD_PUT, size( chunkSize ) ); + server.dataRaw = body; + return req; + } + + private void check(final byte[] body, final int chunkSize) throws Exception { + final Response response = client.put(size(chunkSize), body); + assertEquals(200, response.getStatus()); + if (body == null) { + assertEquals(0, response.getBody().length); + } else { + assertArrayEquals(body, response.getBody()); + } + } + + public static class TestServer extends HttpServer { + private static final byte[] CRLF = "\r\n".getBytes(StandardCharsets.US_ASCII); + private static final byte[] EOF = "0\r\n\r\n".getBytes(StandardCharsets.US_ASCII); + + private TestServer(HttpServerConfig config) throws IOException { + super(config); + } + + List data; + byte[] dataRaw; + + @Path(ENDPOINT) + public Response echo( + final Request request, + @Param(value = "size", required = true) final int size) throws IOException { + if (data == null && dataRaw == null) { + return Response.ok(Response.EMPTY); + } + + if (size < 1) { + return new Response(Response.BAD_REQUEST, Response.EMPTY); + } + + if (dataRaw != null) { + final byte[] content = dataRaw; + dataRaw = null; + final Response response = new Response(Response.OK, content); + response.addHeader("Transfer-encoding: chunked"); + response.addHeader("Content-Type: text/event-stream"); + + return response; + } + + // Slice into parts + int contentLength = 0; + final List parts = data; + final List chunks = new ArrayList<>(); + for (byte[] part : parts) { + for (int start = 0; start < part.length; start += size) { + // Encode chunk + final int chunkLength = Math.min(part.length - start, size); + final byte[] encodedLength = Integer.toHexString(chunkLength).getBytes(StandardCharsets.US_ASCII); + final byte[] chunk = new byte[encodedLength.length + 2 + chunkLength + 2]; + final ByteBuffer buffer = ByteBuffer.wrap(chunk); + buffer.put(encodedLength); + buffer.put(CRLF); + buffer.put(part, start, chunkLength); + buffer.put(CRLF); + assert !buffer.hasRemaining(); + + // Save + chunks.add(chunk); + contentLength += chunk.length; + } + } + + // EOF + chunks.add(EOF); + contentLength += EOF.length; + + // Concat + final byte[] content = new byte[contentLength]; + final ByteBuffer contentBuffer = ByteBuffer.wrap(content); + for (final byte[] chunk : chunks) { + contentBuffer.put(chunk); + } + + final Response response = new Response(Response.OK, content); + response.addHeader("Transfer-encoding: chunked"); + response.addHeader("Content-Type: text/event-stream"); + + return response; + } + } +} diff --git a/test/one/nio/http/HttpHeaderTest.java b/test/one/nio/http/HttpHeaderTest.java index 70a2dfa..1a11095 100644 --- a/test/one/nio/http/HttpHeaderTest.java +++ b/test/one/nio/http/HttpHeaderTest.java @@ -23,6 +23,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; /** * Unit tests for HTTP header processing facilities. @@ -31,6 +32,7 @@ */ public class HttpHeaderTest { private static final String HEADER = "X-OK-Custom-Header: "; + private static final String HEADER_KEY = "X-OK-Custom-Header"; private void testHeaderConsumer(final String... values) { final Request request = new Request(Request.METHOD_GET, "/", true); @@ -41,6 +43,10 @@ private void testHeaderConsumer(final String... values) { final List sink = new ArrayList<>(values.length); request.consumeHeaders(HEADER, sink::add); assertEquals(Arrays.asList(values), sink); + + final List sinkValues = new ArrayList<>(values.length); + request.consumeHeaderValues(HEADER_KEY, sinkValues::add); + assertEquals(Arrays.asList(values), sinkValues); } @Test @@ -57,4 +63,19 @@ public void consumeSingle() { public void consumeDouble() { testHeaderConsumer("First", "Second"); } + + @Test + public void testHeaderValue() { + final Request request = new Request(Request.METHOD_GET, "/", true); + request.addHeader("X-Custom-Header-1: 01"); + request.addHeader("X-Custom-Header-2: 02"); + + assertEquals("01", request.getHeaderValue("X-Custom-Header-1")); + assertEquals("02", request.getHeaderValue("X-Custom-Header-2")); + assertNull(request.getHeaderValue("X-Custom-Header-3")); + assertNull(request.getHeaderValue("X-Very-Long-Key-Custom-Header")); + assertNull(request.getHeaderValue("X-Custom-Header")); + assertNull(request.getHeaderValue("X-Custom-Header ")); + assertNull(request.getHeaderValue("X-Custom-Header:")); + } } diff --git a/test/one/nio/mem/LongHashSetFuncTest.java b/test/one/nio/mem/LongHashSetFuncTest.java new file mode 100644 index 0000000..293a8a4 --- /dev/null +++ b/test/one/nio/mem/LongHashSetFuncTest.java @@ -0,0 +1,34 @@ +/* + * Copyright 2024 LLC VK + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.mem; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class LongHashSetFuncTest { + private final LongHashSet set = new LongHashSet(10); + + @Test + public void testClear() { + set.putKey(1L); + set.putKey(2L); + assertEquals(2, set.size()); + set.clear(); + assertEquals(0, set.size()); + } +} \ No newline at end of file diff --git a/test/one/nio/mem/SharedMemoryStringMapTest.java b/test/one/nio/mem/SharedMemoryStringMapTest.java new file mode 100644 index 0000000..fe149c5 --- /dev/null +++ b/test/one/nio/mem/SharedMemoryStringMapTest.java @@ -0,0 +1,32 @@ +/* + * Copyright 2015 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.mem; + +import org.junit.Test; + +import java.io.IOException; + +public class SharedMemoryStringMapTest { + @Test + public void create() throws IOException { + new SharedMemoryStringMap<>( + 100, + "/tmp/SharedMemoryStringMapTest", + 2 * 1024 * 1024 + ).close(); + } +} diff --git a/test/one/nio/net/SocketTest.java b/test/one/nio/net/SocketTest.java index fbd16b1..2bd7213 100755 --- a/test/one/nio/net/SocketTest.java +++ b/test/one/nio/net/SocketTest.java @@ -26,7 +26,7 @@ public class SocketTest { private static void testIPv4() throws IOException { - Socket s = Socket.create(); + Socket s = Socket.createClientSocket(); s.setTimeout(3000); s.connect("google.com", 80); @@ -35,7 +35,7 @@ private static void testIPv4() throws IOException { } private static void testIPv6() throws IOException { - Socket s = Socket.create(); + Socket s = Socket.createClientSocket(); s.setTimeout(3000); s.connect("2a00:1450:4010:c07::71", 80); @@ -129,6 +129,14 @@ public static void testSocketOpts(Socket socket, boolean datagram) { socket.setTos(96); assertEquals(96, socket.getTos()); } + + if (socket instanceof NativeSocket) { + socket.setNotsentLowat(67890); + assertEquals(67890, socket.getNotsentLowat()); + + socket.setThinLinearTimeouts(true); + assertTrue(socket.getThinLinearTimeouts()); + } } catch (Exception e) { throw e; } finally { diff --git a/test/one/nio/ssl/TLSCurveTest.java b/test/one/nio/ssl/TLSCurveTest.java new file mode 100644 index 0000000..853d69b --- /dev/null +++ b/test/one/nio/ssl/TLSCurveTest.java @@ -0,0 +1,183 @@ +/* + * Copyright 2015-2016 Odnoklassniki Ltd, Mail.Ru Group + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package one.nio.ssl; + +import java.io.IOException; +import java.util.Properties; + +import javax.net.ssl.SSLSocket; +import javax.net.ssl.SSLSocketFactory; + +import one.nio.config.ConfigParser; +import one.nio.server.Server; +import one.nio.server.ServerConfig; + +import org.junit.After; +import org.junit.AfterClass; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +public class TLSCurveTest { + + private static int port = 7443; + + + private static String clientAllowedCurves = "secp521r1,secp256r1"; + + static String startupConfigTemplate = "\n" + + "acceptors:\n" + + " - port: " + port +"\n" + + " ssl:\n" + + " applicationProtocols: http/1.1\n" + + " protocols: TLSv1.3\n" + + " certFile: %s\n" + + " privateKeyFile: %s\n"; + + static String curveConfigTemplate = "\n" + + " curve: %s\n"; + + private Server server; + + private static String cert; + private static String privKey; + SSLSocket socket; + + @BeforeClass + public static void beforeClass() { + Properties systemProps = System.getProperties(); + String truststorePath = TLSCurveTest.class.getClassLoader().getResource("ssl/trustore.jks").getFile(); + systemProps.put("javax.net.ssl.trustStore", truststorePath); + systemProps.put("javax.net.ssl.trustStorePassword","changeit"); + System.setProperties(systemProps); + cert = TLSCurveTest.class.getClassLoader().getResource("ssl/certificate.crt").getFile(); + privKey = TLSCurveTest.class.getClassLoader().getResource("ssl/certificate.key").getFile(); + + // set allowed curves list for the client once + // changing jdk.tls.namedGroups after first call SSLSocketFactory.getDefault() won't take effect + System.setProperty("jdk.tls.namedGroups", clientAllowedCurves); + } + + @AfterClass + public static void tearDownClass() { + System.clearProperty("javax.net.ssl.trustStore"); + System.clearProperty("javax.net.ssl.trustStorePassword"); + System.clearProperty("jdk.tls.namedGroups"); + } + + @After + public void tearDown() throws Exception { + if (socket != null) socket.close(); + if (server != null) server.stop(); + } + + private ServerConfig getServerConfig(String curve) { + String curveConfigPart = curve == null ? "" : String.format(curveConfigTemplate, curve); + return ConfigParser.parse(String.format(startupConfigTemplate, cert, privKey) + curveConfigPart, + ServerConfig.class); + } + + private void setupServer(String curve) throws IOException { + ServerConfig config = getServerConfig(curve); + server = new Server(config); + server.start(); + } + + private void tryHandshake() throws IOException { + socket = (SSLSocket) SSLSocketFactory.getDefault().createSocket("127.0.0.1", port); + socket.setEnabledProtocols(new String[] {"TLSv1.3"}); + socket.startHandshake(); + } + + /** + * Both client and server support secp521r1 - successful handshake. + */ + @Test + public void secp521r1() throws Exception { + setupServer("secp521r1"); + tryHandshake(); + } + + /** + * Both client and server support prime256v1 - successful handshake. + * + * Name of the curve depends on standards organization: + * server curve (openssl impl.) - uses ANSI X9.62 and SECG names; + * client curve (java impl.) - uses SECG names only, + * Curve name prime256v1 (ANSI) is alias of the curve secp256r1 (SECG). + */ + @Test + public void prime256v1() throws Exception { + setupServer("prime256v1"); + tryHandshake(); + } + + /** + * Both client and server support intersecting sets of curves - successful handshake. + */ + @Test + public void secp256k1_secp521r1() throws Exception { + setupServer("secp256k1:secp521r1"); + tryHandshake(); + } + + /** + * Both client and server support not intersecting sets of curves - handshake fails. + */ + @Test(expected = javax.net.ssl.SSLHandshakeException.class) + public void client_server_curve_mismatch() throws Exception { + setupServer("secp384r1"); + tryHandshake(); + } + + /** + * A curve is not specified in the server config - successful handshake. + * Using auto selection server curve + */ + @Test + public void no_server_curve_specified() throws Exception { + setupServer(null); + tryHandshake(); + } + + /** + * A curve specified in the server config not supported by linked openssl build - server fails startup. + */ + @Test(expected = javax.net.ssl.SSLException.class) + public void bad_curve_name() throws Exception { + setupServer("prime256v1:BAD_CURVE:secp521r1"); + tryHandshake(); + } + + /** + * The server started up with a curve not supported by the client - handshake fails. + * Then reconfigure the server to supported one - successful handshake. + */ + @Test + public void reconfigure_server_curve() throws Exception { + setupServer("secp384r1"); // a curve not supported by the client + try { + tryHandshake(); + Assert.fail("First handshake didn't fail"); + } catch (javax.net.ssl.SSLHandshakeException e) { + } + ServerConfig newConfig = getServerConfig("prime256v1"); // the curve supported by the client + server.reconfigure(newConfig); + tryHandshake(); + } + +} diff --git a/test_data/ssl/ca.crt b/test_data/ssl/ca.crt new file mode 100644 index 0000000..588ffae --- /dev/null +++ b/test_data/ssl/ca.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDBTCCAe2gAwIBAgIUe94Iu7/EgJaMCVs4xUFlSjDaCAYwDQYJKoZIhvcNAQEL +BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yMzAzMTAxMzAwNTdaFw00MzAzMDUx +MzAwNTdaMBIxEDAOBgNVBAMMB1Rlc3QgQ0EwggEiMA0GCSqGSIb3DQEBAQUAA4IB +DwAwggEKAoIBAQCujUGscEKlm19a4AFXttAFg/HoqtSc80IYzFh+Mg7uG4supniQ +xvrNEw7QYHp49ABPn76SdNNxD8EpIqmgRKwZ26tJkN0thR6CSIu9FouRT7wglfJD +G+3ZU4z7zLt1U4P6dgkXGosHrZ9W76ARcNkZ8mUlXbUCCj+cMuE5TWIsk1cIk5u7 +sumslTJcUL4mQ+mlxCvo+3dS6ql+ENS1c0OqOUcOoTrctNNJRJCOyvZQNHdvJ+tn +CXwxNwOJOkNjLkT9Iumjz/t9coTdYeRQUo2OzrP0/lwxaB+hZpX3xOaplxSHJA4x +8H330DpJ9pQoKz+Mca9rtD0CX+HtkHqfGh/XAgMBAAGjUzBRMB0GA1UdDgQWBBR+ +2zzbdSV06uu9rnYy9t6rNVuOGDAfBgNVHSMEGDAWgBR+2zzbdSV06uu9rnYy9t6r +NVuOGDAPBgNVHRMBAf8EBTADAQH/MA0GCSqGSIb3DQEBCwUAA4IBAQBdk9oL/poW +2PXuvn7oaRCx+z3kmxF2YVjIRGzn7XQoi1/cpadk9U8bgTuc6IWewycke0hSapGE +R8QPGi0uEIaVBpwWYPsu8DtIWk4ImYzCZXHrOS72Re+ZzeaPl7tFAVhhZm4wTI2y +jRLs79l5iw/bCOxbqBdjOZgpq+Uhx7tA/8TWqMWD8H5J8VTm9cNSpuGGxDEKLjbC +NAIbNqZfbcXYd4o+z2UE/ZL8zIGR9sGcJM6AaPA1y8v7Q8oShLTmoAYdWEt1Nq8l +G6/8ZLbv3MTB7pCGkvz/Pzd/6nO3KybaKa63wzICPpeChdaq3XpTEB0QiphXgFwt +TMt0naO2NEB0 +-----END CERTIFICATE----- diff --git a/test_data/ssl/ca.key b/test_data/ssl/ca.key new file mode 100644 index 0000000..cc31b06 --- /dev/null +++ b/test_data/ssl/ca.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvAIBADANBgkqhkiG9w0BAQEFAASCBKYwggSiAgEAAoIBAQCujUGscEKlm19a +4AFXttAFg/HoqtSc80IYzFh+Mg7uG4supniQxvrNEw7QYHp49ABPn76SdNNxD8Ep +IqmgRKwZ26tJkN0thR6CSIu9FouRT7wglfJDG+3ZU4z7zLt1U4P6dgkXGosHrZ9W +76ARcNkZ8mUlXbUCCj+cMuE5TWIsk1cIk5u7sumslTJcUL4mQ+mlxCvo+3dS6ql+ +ENS1c0OqOUcOoTrctNNJRJCOyvZQNHdvJ+tnCXwxNwOJOkNjLkT9Iumjz/t9coTd +YeRQUo2OzrP0/lwxaB+hZpX3xOaplxSHJA4x8H330DpJ9pQoKz+Mca9rtD0CX+Ht +kHqfGh/XAgMBAAECggEAGZCRuk40zKl9ZUNinvSk8fHZt9BxVE1idFxVZ6IwQpph +H6N9DANAj1GHvWgr99QQyZilYjDfGPOMQFiVzGXSaPUd2KfxZO3WwpEaeka3iUBK +1uDvOvn3s3lDsEtTd0MUlNW9zhbyntzEdgihgnxrJ5wzSS44yQQ0Pc3L48ccJlft +/VNGmbMEKNFiBvzCl8FWvPsocfSMn22Qt5ica3n4vAV1kK+7hZd19vz5QrMn1Jis +sZQhqRaWUux6qE27zvAP3KfaTGjsNk+GNRSa0mRkiaFYzHloM5o3H4NOHEHq151s +0KD2MOPyqBpvTW7icIAdFXsCUSRbZbxmkvZzwEr8uQKBgQDhput0hx+zWa+CeKBz +CUV27IKSmfcZxfbRIpcEMCZWaCFal4DbNyAaki73dvkdbKyRJjUEu/pBxuepGnQT +BI1zIKndtBV5RBtNS78yZj5RM8wQlQLTgN0b8LCeXfsgnFxTBFeY1UJMvoXuSUHm ++LxdnKPRqPiMANnarnGX70dzvwKBgQDGBvt2I2wu5oF9Js+8dsGm8zAlD5E5sThH +VJKayjvx9zyRLja7R3XEvr2T39At8t3P2RiUn3N5DSrWthTg1SgnJhQgGx24j0El +fyGH3i2aAnCGHy4mhwFjC7avG5S/1iWpQJuMh5OnTrsrnVNDwjAftDE/R08qsGxE +9APPE8P56QKBgHMdLTa4DCL5mMCNewTUcPx3MTFKnDUouX39EF9RlO01l/ZYHaM3 +XwhPFOinZ/Rr0SkG/bsZjlRp/RpWZWqiQuq9egqg6OeBKBBpmPNEF3xjWTIIYnW1 +YpbzVZakyQzc13h+WZWdDYKLG1XxR05mC+oRk7zMX2dEs96MaWSh66iPAoGAXAqQ +xGdQeDghRYdNlN5pwexUm1Ux/eu2KclQXuvSkvOETkJ5o/Bh97FkMiAcBc9Vc68H +MsUuCbyqBaQZ7iqFADU9s3KHDOpgsEn3zsvgzC2IGX7Sl4u2hE/EeH1MVSH23UUv +R1EYuvHoIyx7sAnJDmNVZIEDdecMW7xDLPOV6kkCgYBo6kny6y58W00orYfjdjJv +6GDggOWOaPhXI4hKlcsnrutHQZiRXjsIVFjaLSyj8LHpFcdtvgrRTkgxYGcrVwn1 +bsULiQ43j3Jz/K7kclOycdL366XX79NOmSO87hqFrzk34iQT8+SgVsHm7zXWhkWk +OjgWFgZtGxCEJYJVO6n2dg== +-----END PRIVATE KEY----- diff --git a/test_data/ssl/certificate.crt b/test_data/ssl/certificate.crt new file mode 100644 index 0000000..9e0bb0b --- /dev/null +++ b/test_data/ssl/certificate.crt @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDCzCCAfOgAwIBAgIUJvEKMtCoVcAv9tGqA8aPrixn71UwDQYJKoZIhvcNAQEL +BQAwEjEQMA4GA1UEAwwHVGVzdCBDQTAeFw0yMzAzMTAxMzAwNTdaFw00MzAzMDUx +MzAwNTdaMBsxGTAXBgNVBAMMEFRlc3QgY2VydGlmaWNhdGUwggEiMA0GCSqGSIb3 +DQEBAQUAA4IBDwAwggEKAoIBAQCnYBjsgsEptJnum+1fju2OF1jzLRMjp28pzzmZ +CTtvMjpyVK9Fr2cHIcE/ZUNStvQ8hIAAnfhkz65uYAOgZEHe7Qa04YitVUOworJl +Nr7YJ5eUEA9I9KifTS4bkQxfkn7Fbn7KPZBE6X6WO4cEMZMbTZ9rzE5xMyUcbvux +X8wE2ZyJxxm2pKt60dM5Jy+KS944fOEMqq4vlNp37Q1CPrlXzjzRRB152Ct93aWQ +OpDwTySv1ae0lxzrAmoUnIErnDylb0Q6W/u7fjVmWtia9CcsGgAUNwvi1fyS1oqF ++4PT6Dgl5XNoTXmCjF91eApRokvFKn6VQYvMhBy31mAoQlpjAgMBAAGjUDBOMB0G +A1UdDgQWBBRhmg2PRgxohtpcNSt83vhOOpVASjAfBgNVHSMEGDAWgBR+2zzbdSV0 +6uu9rnYy9t6rNVuOGDAMBgNVHRMBAf8EAjAAMA0GCSqGSIb3DQEBCwUAA4IBAQAU +ZYj95zUjgpaCeeHyKPOpebNG/34pNpFmdMNS7Far+DMe9MHlD1X/+4qqo07813KU +dlWCpK4opGIBMNEaHlzXj+EqZuQelwqWJbPF1kcY5tPLAcD9XEGtHevFa3zYVhcN +n/ixPQ/RZUNJffhQ+3qzzHCCmsrsV1DfQwiPL+7ohsxO9BBBlLtsF0ZcfwObYK0u +eXGYoB+n/rixlEYzvE0LTEo7QT62iATScKiwR0OYj98t7fkiIDPwI9IhG5cAcnGs +bBW0ost0ekTAKIFNOQdwkOmWYXHv1kgQTU/glyIKWKzkklcRYKVvo5NcXMIeKJdS +8PV+76i0uEye7GSU8D1I +-----END CERTIFICATE----- diff --git a/test_data/ssl/certificate.key b/test_data/ssl/certificate.key new file mode 100644 index 0000000..cbae8e5 --- /dev/null +++ b/test_data/ssl/certificate.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQCnYBjsgsEptJnu +m+1fju2OF1jzLRMjp28pzzmZCTtvMjpyVK9Fr2cHIcE/ZUNStvQ8hIAAnfhkz65u +YAOgZEHe7Qa04YitVUOworJlNr7YJ5eUEA9I9KifTS4bkQxfkn7Fbn7KPZBE6X6W +O4cEMZMbTZ9rzE5xMyUcbvuxX8wE2ZyJxxm2pKt60dM5Jy+KS944fOEMqq4vlNp3 +7Q1CPrlXzjzRRB152Ct93aWQOpDwTySv1ae0lxzrAmoUnIErnDylb0Q6W/u7fjVm +Wtia9CcsGgAUNwvi1fyS1oqF+4PT6Dgl5XNoTXmCjF91eApRokvFKn6VQYvMhBy3 +1mAoQlpjAgMBAAECggEAJECxRVEQ66qdiMwnPwJjO9qcvpctxRQ4BLNIw3f3cp9L +f4YOLKbQRwAtrkCNs3XDMvR0ES4mJhfBrVCImI3+on1ubXFIPFryfWjQARI5hfIV +Z9GOrfuoJlD+QqRJLM47PSIwvjdVb0evznR4SxU9yrUmaZ/oAabVS/JR+9pHA2mR +KEYkn+DCuPmmGHxLFCe5d6LCupVB7gSSzkr74CKfcRnBL7bCGpYFzjMp9mGNtc4L +QomNf4+T6ZFWv3u9Uvl8mH+BfJLmlTuvUrB4/QPCkv06lXZRfX+A1V/Wa31Ua9V/ +e/dcru40IA9K9VzvUDaSTW4V26soD5A22HFleb3DMQKBgQDbJYtSe+dzy/aPmCKI +wEzhHyrqWoc02BnAz9Oqmq5bv2g15nfgxUsqR2F7hJ7Xj/kwz4sDeYRlbPngIVh8 +pdnJAKqpvtCysWgggsZe0D0yLPTIV+H0ImUUbxMftIw11oLc9LgtZ/TUcz7uisJ0 +117TGgKYgW7/NknikgqHhcnHawKBgQDDhcEXWz+YpDD/vp+zv25uRuzkQiT6yd1f +nW0wUWwNPQzXVWs7lLe0Ez8BmiTSJjcFs6jVzjvBOnQXTtqud4ikDG22og8Piv3+ +2BPgnatCkVGXIDWWW3viQz1Q9/9B/nGBeJQKWx20+sDvDnSlvfeb3zbqkoz+sj+F +nDl/YRwO6QKBgQCPdHtIao2vssPbafapWGC7OaDpOoupnxD0s9dWpd0feCPqrMyq +mdxDd+irZ7xnVfsE5ceVZbWyg2zrOEjph9QSDVqqtZt+bj3Aknry1BRLRTyT7Vuf +aeiLQM7fAVyLXbnalGQAbT3K2QpIMxNqUxXi0PMEDC6x6ELji0BTSNQ2swKBgQCx +L8JHp1KfwpQA/7/8hcvOtgmx5VtbxpoOLz0nH3J7IMtFTUyLRv+feh2MFyOtKiKM +0T5825N9Tbqs8LHuj7bNa9H1QzHA1SXO0ARbdqcgAU9eVDsb0jYXWvXzLXsuVCaK +vTnzlJT8UI2NVp1RIjGFGSjMNRj/K4uzSls0200xGQKBgBJFaE5rYi508XtARl+U +LRR1SzLihy5JaSHKjJp+7K/vp0bm5CGFGpwK6QNAkEzv2Jvl47bauzcnafdgwnsC +w4okm5CziM+uFB8rzttd9fT1pHPW5HVUqrUoVv+sImw5x12AbStJEpDudEeFUGtW +kUeXVGvc5hQaUjsHA4r6XmL7 +-----END PRIVATE KEY----- diff --git a/test_data/ssl/generate.sh b/test_data/ssl/generate.sh new file mode 100755 index 0000000..6cf0842 --- /dev/null +++ b/test_data/ssl/generate.sh @@ -0,0 +1,29 @@ +#!/bin/bash + +rm -f ca.kley ca.crt certificate.key certificate.crt trustore.jks + +# CA +openssl req -new \ + -days 7300 \ + -nodes \ + -x509 \ + -sha256 \ + -newkey rsa:2048 \ + -keyout ca.key \ + -out ca.crt \ + -subj "/CN=Test CA" + +openssl req -new \ + -days 7300 \ + -nodes \ + -x509 \ + -sha256 \ + -newkey rsa:2048 \ + -addext basicConstraints=critical,CA:FALSE \ + -CA ca.crt \ + -CAkey ca.key \ + -keyout certificate.key \ + -out certificate.crt \ + -subj "/CN=Test certificate" + +keytool -import -file ca.crt -keystore trustore.jks -storepass changeit -noprompt diff --git a/test_data/ssl/trustore.jks b/test_data/ssl/trustore.jks new file mode 100644 index 0000000000000000000000000000000000000000..0f5505d3ad2efa0b997212aeb05ec4776b93b777 GIT binary patch literal 1142 zcmV-+1d01Ff&_8`0Ru3C1RMqlDuzgg_YDCD0ic2e4FrM&2{3{L1u%jH0R{;whDe6@ z4FLxRpn?PPFoFZ_0s#Opf&=FU2`Yw2hW8Bt2LUi<1_>&LNQU+thDZTr0|Wso1Q1+uX!GZkuEL338EyhFmr{U&18~+6luVK=Oe&&2co1nXrJ!S8 zX$$$`h~Gw>0Dx=QW^mr^*@%F!Fr$CbThFwUMOlY%g-d$T;JHQM)l<`Epo>I)fAsH@ zX1qIDfpRA^Yeu1fJLlqG)O{-RUPqoCdvmS89{HnASUY^TVeeeihQyWYt!59DBLeTM z`a6H(TFK5GFk6*pB57bPu>HX^lBJk@OcAsFz{6wSz(n*WPQQedLf9IT z7}d2YKA^!($HtbfFfQ?gTD`@ffYel^3B=PwI?eWI0tByl`&aVI>I;rM>l{Q*=eQdx z8u+yOxUcQtQzX^{)y4}f4~%2IxMYR1?lu%~t&Z$}RP$YZk)?X*jnmxLkoR`(E%8jJ zE|>OZjCc*lV)(5BYcXdR24g-%KGR78+T%4uig?w#6fKeGY5w zfBdn@rMZTabE^leH}|(pC@HF8?8cDE$DKjW*q&VJq8@5*X%VXQtI;MNt z8Iu53VLcg4X+RGwnQn^fn3%aze0AbW%RQj+;+j`e_^{*C5{^9!I;2>HEhg5$ATdxGhp_>VXK(o z(ZZ(xcxdRU;SiRgewwOF%H&(1`vV?1C!+Gnd@6BN7W2ptQVig>(%ph>kq`@y-rr6V z0#YOCA$Jr>dW!hDDG*^Gz%%+d$3Fe>MooWa+TU6TQ;U6gPqm0&e)GTc71=@f9Xj^v z060H`xGiA%Tboz_<3eQNC9O71OfpC00bb| z_jMo$UCP)$8-j^!XBP9K*vt@?Us;PV!O%tF!aAq~6giSG?UQ}$&%MXwaFghT<3$ie IpaKFX5cO^kG5`Po literal 0 HcmV?d00001