Skip to content

Commit

Permalink
Merge pull request #421 from Ladicek/fix-connection-sharing
Browse files Browse the repository at this point in the history
fix Redis connection sharing
  • Loading branch information
Ladicek authored Dec 6, 2023
2 parents 86462a9 + e97ca14 commit 305269e
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@

import io.vertx.core.Future;
import io.vertx.core.Handler;
import io.vertx.core.Promise;
import io.vertx.core.impl.ContextInternal;
import io.vertx.core.impl.VertxInternal;
import io.vertx.core.impl.future.PromiseInternal;
import io.vertx.core.impl.logging.Logger;
import io.vertx.core.impl.logging.LoggerFactory;
import io.vertx.core.net.NetClient;
import io.vertx.core.net.ConnectOptions;
import io.vertx.core.net.NetClientOptions;
import io.vertx.core.net.NetSocket;
import io.vertx.core.net.impl.NetClientInternal;
import io.vertx.core.net.impl.endpoint.EndpointManager;
import io.vertx.core.net.impl.endpoint.EndpointProvider;
import io.vertx.core.net.impl.endpoint.Endpoint;
Expand All @@ -45,7 +46,7 @@ class RedisConnectionManager implements EndpointProvider<RedisConnectionManager.
private static final Handler<Throwable> DEFAULT_EXCEPTION_HANDLER = t -> LOG.error("Unhandled Error", t);

private final VertxInternal vertx;
private final NetClient netClient;
private final NetClientInternal netClient;
private final PoolMetrics metrics;
private final NetClientOptions tcpOptions;
private final PoolOptions poolOptions;
Expand All @@ -63,7 +64,7 @@ class RedisConnectionManager implements EndpointProvider<RedisConnectionManager.
this.tracingPolicy = tracingPolicy;
VertxMetrics metricsSPI = this.vertx.metricsSPI();
metrics = metricsSPI != null ? metricsSPI.createPoolMetrics("redis", poolOptions.getName(), poolOptions.getMaxSize()) : null;
this.netClient = vertx.createNetClient(tcpOptions);
this.netClient = (NetClientInternal) vertx.createNetClient(tcpOptions);
this.pooledConnectionManager = new EndpointManager<>();
}

Expand Down Expand Up @@ -118,15 +119,15 @@ public int hashCode() {
static class RedisConnectionProvider implements PoolConnector<RedisConnectionInternal> {

private final VertxInternal vertx;
private final NetClient netClient;
private final NetClientInternal netClient;
private final RedisURI redisURI;
private final Request setup;
private final NetClientOptions netClientOptions;
private final PoolOptions poolOptions;
private final RedisConnectOptions options;
private final TracingPolicy tracingPolicy;

public RedisConnectionProvider(VertxInternal vertx, NetClient netClient, NetClientOptions netClientOptions, PoolOptions poolOptions, RedisConnectOptions options, TracingPolicy tracingPolicy, String connectionString, Request setup) {
public RedisConnectionProvider(VertxInternal vertx, NetClientInternal netClient, NetClientOptions netClientOptions, PoolOptions poolOptions, RedisConnectOptions options, TracingPolicy tracingPolicy, String connectionString, Request setup) {
this.vertx = vertx;
this.netClient = netClient;
this.netClientOptions = netClientOptions;
Expand Down Expand Up @@ -169,8 +170,13 @@ private Future<ConnectResult<RedisConnectionInternal>> connectAndSetup(
boolean connectionStringSsl,
boolean netClientSsl) {
try {
return netClient
.connect(redisURI.socketAddress())
ConnectOptions connectOptions = new ConnectOptions()
.setRemoteAddress(redisURI.socketAddress())
.setSsl(netClientOptions.isSsl())
.setSslOptions(netClientOptions.getSslOptions());
Promise<NetSocket> promise = ctx.promise();
netClient.connectInternal(connectOptions, promise, ctx);
return promise.future()
.compose(so -> {
// upgrade to ssl is only possible for inet sockets
if (connectionStringInetSocket && !netClientSsl && connectionStringSsl) {
Expand Down Expand Up @@ -373,14 +379,14 @@ static class RedisEndpoint extends Endpoint {

final ConnectionPool<RedisConnectionInternal> pool;

public RedisEndpoint(VertxInternal vertx, NetClient netClient, NetClientOptions netClientOptions, PoolOptions poolOptions, RedisConnectOptions connectOptions, TracingPolicy tracingPolicy, Runnable dispose, String connectionString, Request setup) {
public RedisEndpoint(VertxInternal vertx, NetClientInternal netClient, NetClientOptions netClientOptions, PoolOptions poolOptions, RedisConnectOptions connectOptions, TracingPolicy tracingPolicy, Runnable dispose, String connectionString, Request setup) {
super(dispose);
PoolConnector<RedisConnectionInternal> connector = new RedisConnectionProvider(vertx, netClient, netClientOptions, poolOptions, connectOptions, tracingPolicy, connectionString, setup);
pool = ConnectionPool.pool(connector, new int[]{poolOptions.getMaxSize()}, poolOptions.getMaxWaiting());
}

public Future<Lease<RedisConnectionInternal>> requestConnection(ContextInternal ctx) {
PromiseInternal<Lease<RedisConnectionInternal>> promise = ctx.promise();
Promise<Lease<RedisConnectionInternal>> promise = ctx.promise();
pool.acquire(ctx, 0, ar -> {
if (ar.succeeded()) {
// increment the reference counter to avoid the pool to be closed too soon
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,15 +180,21 @@ private void taintCheck(CommandImpl cmd) {

@Override
public Future<Response> send(final Request request) {
//System.out.println("send()#" + this.hashCode());
final Promise<Response> promise;
Promise<Response> promise = vertx.promise();
context.execute(() -> doSend(request, promise));
return promise.future();
}

private void doSend(final Request request, Promise<Response> promise) {
//System.out.println("send()#" + this.hashCode());
if (closed) {
throw new IllegalStateException("Connection is closed");
promise.fail("Connection is closed");
return;
}

if (!((RequestImpl) request).valid()) {
return Future.failedFuture("Redis command is not valid, check https://redis.io/commands");
promise.fail("Redis command is not valid, check https://redis.io/commands: " + request);
return;
}

final CommandImpl cmd = (CommandImpl) request.command();
Expand All @@ -204,17 +210,11 @@ public Future<Response> send(final Request request) {
// we might have switch thread/context
synchronized (waiting) {
if (waiting.isFull()) {
return Future.failedFuture("Redis waiting Queue is full");
promise.fail("Redis waiting queue is full");
return;
}
// create a new promise bound to the caller not
// the instance of this object (a.k.a. "context")
promise = vertx.promise();
waiting.offer(promise);
}
} else {
// create a new promise bound to the caller not
// the instance of this object (a.k.a. "context")
promise = vertx.promise();
}
// write to the socket
try {
Expand All @@ -239,26 +239,26 @@ public Future<Response> send(final Request request) {
context.execute(err, this::fail);
promise.fail(err);
}

return promise.future();
}

@Override
public Future<List<Response>> batch(List<Request> commands) {
//System.out.println("batch()#" + this.hashCode());
Promise<List<Response>> promise = vertx.promise();
context.execute(() -> doBatch(commands, promise));
return promise.future();
}

private void doBatch(List<Request> commands, Promise<List<Response>> promise) {
//System.out.println("batch()#" + this.hashCode());
if (closed) {
throw new IllegalStateException("Connection is closed");
promise.fail("Connection is closed");
return;
}

if (commands.isEmpty()) {
LOG.debug("Empty batch");
return Future.succeededFuture(Collections.emptyList());
promise.complete(Collections.emptyList());
} else {
// create a new promise bound to the caller not
// the instance of this object (a.k.a. "context")
final Promise<List<Response>> promise = vertx.promise();

// will re-encode the handler into a list of promises
final List<Promise<Response>> callbacks = new ArrayList<>(commands.size());
final Response[] replies = new Response[commands.size()];
Expand All @@ -274,12 +274,14 @@ public Future<List<Response>> batch(List<Request> commands) {
final CommandImpl cmd = (CommandImpl) req.command();

if (!req.valid()) {
return Future.failedFuture("Redis command is not valid, check https://redis.io/commands");
promise.fail("Redis command is not valid, check https://redis.io/commands: " + req);
return;
}

if (cmd.isPubSub()) {
// mixing pubSub cannot be used on a one-shot operation
return Future.failedFuture("PubSub command in batch not allowed");
promise.fail("PubSub command in batch not allowed");
return;
}
// encode to the single buffer
req.encode(messages);
Expand Down Expand Up @@ -334,7 +336,8 @@ public Future<List<Response>> batch(List<Request> commands) {
// we might have switch thread/context
// this means the check needs to be performed again
if (waiting.freeSlots() < callbacks.size()) {
return Future.failedFuture("Redis waiting Queue is full");
promise.fail("Redis waiting queue is full");
return;
}
// offer all handlers to the waiting queue
for (Promise<Response> callback : callbacks) {
Expand All @@ -352,8 +355,6 @@ public Future<List<Response>> batch(List<Request> commands) {
context.execute(err, this::fail);
promise.fail(err);
}

return promise.future();
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/test/java/io/vertx/redis/client/test/RedisTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ public void simpleFullQueue(TestContext should) {

Future.all(futures)
.onFailure(f -> {
should.assertEquals("Redis waiting Queue is full", f.getMessage());
should.assertEquals("Redis waiting queue is full", f.getMessage());
test.complete();
})
.onSuccess(r -> {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package io.vertx.redis.client.test;

import io.vertx.core.AbstractVerticle;
import io.vertx.core.DeploymentOptions;
import io.vertx.core.Future;
import io.vertx.core.Promise;
import io.vertx.core.Vertx;
import io.vertx.ext.unit.Async;
import io.vertx.ext.unit.TestContext;
import io.vertx.ext.unit.junit.VertxUnitRunner;
import io.vertx.redis.client.Redis;
import io.vertx.redis.client.RedisAPI;
import io.vertx.redis.client.RedisClientType;
import io.vertx.redis.client.RedisOptions;
import io.vertx.redis.client.Response;
import org.junit.After;
import org.junit.Before;
import org.junit.ClassRule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.testcontainers.containers.GenericContainer;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;

@RunWith(VertxUnitRunner.class)
public class SharedRedisConnectionTest {

@ClassRule
public static final GenericContainer<?> redis = new GenericContainer<>("redis:7")
.withExposedPorts(6379);

private static final int VERTICLES_COUNT = 10;
private static final int ITERATIONS_COUNT = 1000;

private static final String REDIS_NUMBER_VALUE_KEY = "user:post:pinned:1372";
private static final String REDIS_SET_VALUE_KEY = "user:like:post:975";

Vertx vertx;
RedisAPI conn;

@Before
public void setup(TestContext test) {
Async async = test.async();
vertx = Vertx.vertx();
RedisOptions options = new RedisOptions()
.setConnectionString("redis://" + redis.getHost() + ":" + redis.getFirstMappedPort())
.setMaxWaitingHandlers(VERTICLES_COUNT * ITERATIONS_COUNT * 2); // 2 requests per iteration
Redis.createClient(vertx, options)
.connect()
.map(RedisAPI::api)
.flatMap(api -> {
return api.set(Arrays.asList(REDIS_NUMBER_VALUE_KEY, "42"))
.map(api);
}).flatMap(api -> {
return api.sadd(Arrays.asList(REDIS_SET_VALUE_KEY, "100", "101", "102"))
.map(api);
})
.onComplete(result -> {
if (result.succeeded()) {
conn = result.result();
} else {
test.fail(result.cause());
}
async.complete();
});
}

@After
public void teardown(TestContext test) {
conn.close();
vertx.close().onComplete(test.asyncAssertSuccess());
}

@Test
public void test(TestContext test) {
vertx.deployVerticle(() -> new MyVerticle(conn, test), new DeploymentOptions().setInstances(VERTICLES_COUNT));
}

public static class MyVerticle extends AbstractVerticle {
private final RedisAPI conn;
private final TestContext test;

public MyVerticle(RedisAPI conn, TestContext test) {
this.conn = conn;
this.test = test;
}

@Override
public void start() {
Async async = test.async(ITERATIONS_COUNT);
for (int i = 0; i < ITERATIONS_COUNT; i++) {
test()
.onSuccess(ignored -> async.countDown())
.onFailure(test::fail);
}
}

private Future<?> test() {
Future<Response> fetchNumberFuture = conn.get(REDIS_NUMBER_VALUE_KEY)
.onSuccess(response -> {
try {
response.toInteger();
} catch (Exception e) {
test.fail(e);
}
});

Future<Response> fetchSetFuture = conn.smembers(REDIS_SET_VALUE_KEY)
.onSuccess(response -> {
try {
for (Response part : response) {
part.toInteger();
}
} catch (Exception e) {
test.fail(e);
}
});

return Future.all(fetchNumberFuture, fetchSetFuture);
}
}
}
29 changes: 13 additions & 16 deletions src/test/java/io/vertx/test/redis/RedisClientTLSTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -116,18 +116,13 @@ public void testConnectionStringUpgrade(TestContext should) {
.setConnectionString("rediss://0.0.0.0:" + port + "?test=upgrade"));

client.connect()
.onFailure(err -> {
System.out.println("REDIS CLIENT (CONNECT) ERR: " + err);
})
.onSuccess(conn -> {
.onComplete(should.asyncAssertSuccess(conn -> {
conn.send(Request.cmd(Command.PING))
.onFailure(should::fail)
.onSuccess(res -> {
System.out.println("REDIS CLIENT SUCCESS");
.onComplete(should.asyncAssertSuccess(ignored -> {
conn.close();
test.complete();
});
});
}));
}));
}

@Test(timeout = 30_000L)
Expand All @@ -145,12 +140,13 @@ public void testConnectionOptions(TestContext should) {
.setConnectionString("rediss://localhost:" + server.actualPort()));

client.connect()
.onFailure(should::fail)
.onSuccess(conn -> {
.onComplete(should.asyncAssertSuccess(conn -> {
conn.send(Request.cmd(Command.PING))
.onFailure(should::fail)
.onSuccess(res -> test.complete());
});
.onComplete(should.asyncAssertSuccess(ignored -> {
conn.close();
test.complete();
}));
}));
}

@Test(timeout = 30_000L)
Expand All @@ -169,7 +165,8 @@ public void testInvalidOptions(TestContext should) {
.setConnectionString("redis://localhost:" + server.actualPort()));

client.connect()
.onFailure(t -> test.complete())
.onSuccess(res -> should.fail());
.onComplete(should.asyncAssertFailure(ignored -> {
test.complete();
}));
}
}

0 comments on commit 305269e

Please sign in to comment.