Skip to content

Commit

Permalink
Handle common methods in single place
Browse files Browse the repository at this point in the history
Consolidate the common method handlings in `CallbackHandlerSupport`.
This will provide consistent behavior across all proxy logic.

Fix #139
  • Loading branch information
ttddyy committed Nov 21, 2023
1 parent dfe23ad commit 50e3237
Show file tree
Hide file tree
Showing 13 changed files with 277 additions and 105 deletions.
11 changes: 5 additions & 6 deletions src/main/java/io/r2dbc/proxy/callback/BatchCallbackHandler.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018 the original author or authors.
* Copyright 2018-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
import io.r2dbc.spi.Batch;
import io.r2dbc.spi.Result;
import org.reactivestreams.Publisher;
import reactor.util.annotation.Nullable;

import java.lang.reflect.Method;
import java.util.ArrayList;
Expand Down Expand Up @@ -51,16 +52,14 @@ public BatchCallbackHandler(Batch batch, ConnectionInfo connectionInfo, ProxyCon

@Override
@SuppressWarnings("unchecked")
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
public Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable {
Assert.requireNonNull(proxy, "proxy must not be null");
Assert.requireNonNull(method, "method must not be null");

String methodName = method.getName();

if ("unwrap".equals(methodName)) {
return this.batch;
} else if ("unwrapConnection".equals(methodName)) {
return this.connectionInfo.getOriginalConnection();
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.batch, args, this.connectionInfo.getOriginalConnection());
}

Object result = proceedExecution(method, this.batch, args, this.proxyConfig.getListeners(), this.connectionInfo, null);
Expand Down
3 changes: 2 additions & 1 deletion src/main/java/io/r2dbc/proxy/callback/CallbackHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ public interface CallbackHandler {
* @param method the method that has invoked on the proxy instance
* @param args an array of objects that has passed to the method invocation.
* this can be {@code null} when method is invoked with no argument.
* @return result returned from the method invocation on the proxy instance
* @return result returned from the method invocation on the proxy instance. (can be {@code null}.)
* @throws Throwable the exception thrown from the method invocation on the proxy instance.
* @throws IllegalArgumentException if {@code proxy} is {@code null}
* @throws IllegalArgumentException if {@code method} is {@code null}
*/
@Nullable
Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable;

}
81 changes: 41 additions & 40 deletions src/main/java/io/r2dbc/proxy/callback/CallbackHandlerSupport.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018-2020 the original author or authors.
* Copyright 2018-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
import io.r2dbc.proxy.util.Assert;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Wrapped;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
Expand All @@ -32,12 +33,11 @@
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;

import static java.util.stream.Collectors.toSet;

/**
* Defines methods to augment execution of proxy methods used by child classes.
*
Expand Down Expand Up @@ -77,19 +77,12 @@ public interface MethodInvocationStrategy {
return result;
};

private static final Set<Method> PASS_THROUGH_METHODS;

static {
try {
Method objectToStringMethod = Object.class.getMethod("toString");
PASS_THROUGH_METHODS = Arrays.stream(Object.class.getMethods())
.filter(method -> !objectToStringMethod.equals(method))
.collect(toSet());

} catch (NoSuchMethodException e) {
throw new RuntimeException(e);
}
}
private static final Set<String> COMMON_METHODS = new HashSet<>(Arrays.asList(
"toString", "equals", "hashCode",
"unwrap", // "Wrapped#unwrap"
"getProxyConfig", // "ProxyConfigHolder#getProxyConfig"
"unwrapConnection" // "ConnectionHolder#unwrapConnection"
));

protected final ProxyConfig proxyConfig;

Expand All @@ -99,6 +92,38 @@ public CallbackHandlerSupport(ProxyConfig proxyConfig) {
this.proxyConfig = Assert.requireNonNull(proxyConfig, "proxyConfig must not be null");
}

protected boolean isCommonMethod(String methodName) {
return COMMON_METHODS.contains(methodName);
}

@Nullable
protected Object handleCommonMethod(String methodName, Object original, @Nullable Object[] args, @Nullable Connection originalConnection) {
if ("toString".equals(methodName)) {
StringBuilder sb = new StringBuilder();
sb.append(original.getClass().getSimpleName());
sb.append("-proxy [");
sb.append(original);
sb.append("]");
return sb.toString(); // differentiate toString message.
} else if ("equals".equals(methodName)) {
// when target is a proxy, also compares the proxied object
return original.equals(args[0]) || (args[0] instanceof Wrapped && args[0] instanceof ProxyConfigHolder && original.equals(((Wrapped<?>) args[0]).unwrap()));
} else if ("hashCode".equals(methodName)) {
return original.hashCode();
} else if ("getProxyConfig".equals(methodName)) {
return this.proxyConfig; // "ProxyConfigHolder#getProxyConfig"
} else if ("unwrapConnection".equals(methodName)) {
return originalConnection; // "ConnectionHolder#unwrapConnection"
} else if ("unwrap".equals(methodName)) {
if (args == null) {
return original; // for no-arg "unwrap"
} else {
return ((Wrapped<?>) original).unwrap((Class<?>) args[0]);
}
}
throw new IllegalStateException(methodName + " does not match to the common method names.");
}

/**
* Augment method invocation and call method listener.
*
Expand All @@ -121,30 +146,6 @@ protected Object proceedExecution(Method method, Object target, @Nullable Object
Assert.requireNonNull(target, "target must not be null");
Assert.requireNonNull(listener, "listener must not be null");

if (PASS_THROUGH_METHODS.contains(method)) {
try {
return method.invoke(target, args);
} catch (InvocationTargetException ex) {
throw ex.getTargetException();
}
}

// special handling for toString()
if ("toString".equals(method.getName())) {
StringBuilder sb = new StringBuilder();
sb.append(target.getClass().getSimpleName()); // ConnectionFactory, Connection, Batch, or Statement
sb.append("-proxy [");
sb.append(target.toString());
sb.append("]");
return sb.toString(); // differentiate toString message.
}

// special handling for "ProxyConfigHolder#getProxyConfig"
if ("getProxyConfig".equals(method.getName())) {
return this.proxyConfig;
}


StopWatch stopWatch = new StopWatch(this.proxyConfig.getClock());

MutableMethodExecutionInfo executionInfo = new MutableMethodExecutionInfo();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018 the original author or authors.
* Copyright 2018-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
import io.r2dbc.spi.Batch;
import io.r2dbc.spi.Connection;
import io.r2dbc.spi.Statement;
import reactor.util.annotation.Nullable;

import java.lang.reflect.Method;
import java.util.function.Consumer;
Expand All @@ -44,16 +45,13 @@ public ConnectionCallbackHandler(Connection connection, ConnectionInfo connectio
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
public Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable {
Assert.requireNonNull(proxy, "proxy must not be null");
Assert.requireNonNull(method, "method must not be null");

String methodName = method.getName();

if ("unwrap".equals(methodName)) {
return this.connection;
} else if ("unwrapConnection".equals(methodName)) {
return this.connection;
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.connection, args, this.connection);
}

Consumer<MethodExecutionInfo> onComplete = null;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018-2020 the original author or authors.
* Copyright 2018-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -23,6 +23,7 @@
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;
import reactor.core.publisher.Operators;
import reactor.util.annotation.Nullable;

import java.lang.reflect.Method;

Expand All @@ -41,14 +42,13 @@ public ConnectionFactoryCallbackHandler(ConnectionFactory connectionFactory, Pro
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
public Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable {
Assert.requireNonNull(proxy, "proxy must not be null");
Assert.requireNonNull(method, "method must not be null");

String methodName = method.getName();

if ("unwrap".equals(methodName)) {
return this.connectionFactory;
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.connectionFactory, args, null);
}

if ("create".equals(methodName)) {
Expand Down
11 changes: 5 additions & 6 deletions src/main/java/io/r2dbc/proxy/callback/ResultCallbackHandler.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Operators;
import reactor.util.annotation.Nullable;

import java.lang.reflect.Method;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -71,17 +72,15 @@ public ResultCallbackHandler(Result result, QueryExecutionInfo queryExecutionInf

@Override
@SuppressWarnings("unchecked")
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
public Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable {
Assert.requireNonNull(proxy, "proxy must not be null");
Assert.requireNonNull(method, "method must not be null");

String methodName = method.getName();
ConnectionInfo connectionInfo = this.queryExecutionInfo.getConnectionInfo();

if ("unwrap".equals(methodName)) { // for Wrapped
return this.result;
} else if ("unwrapConnection".equals(methodName)) { // for ConnectionHolder
return connectionInfo.getOriginalConnection();
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.result, args, connectionInfo.getOriginalConnection());
}

// replace mapping function
Expand Down Expand Up @@ -131,7 +130,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl
private Function<Result.Segment, Publisher<?>> createMappingForFlatMap(Function<Result.Segment, Publisher<?>> mapping) {
return (segment) -> {
if (segment instanceof Result.RowSegment) {
Result.RowSegment rowSegmentProxy = this.proxyConfig.getProxyFactory().wrapRowSegment((Result.RowSegment)segment, this.queryExecutionInfo);
Result.RowSegment rowSegmentProxy = this.proxyConfig.getProxyFactory().wrapRowSegment((Result.RowSegment) segment, this.queryExecutionInfo);
return mapping.apply(rowSegmentProxy);
}
return mapping.apply(segment);
Expand Down
11 changes: 5 additions & 6 deletions src/main/java/io/r2dbc/proxy/callback/RowCallbackHandler.java
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2021 the original author or authors.
* Copyright 2021-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -22,6 +22,7 @@
import io.r2dbc.proxy.listener.ResultRowConverter;
import io.r2dbc.proxy.util.Assert;
import io.r2dbc.spi.Row;
import reactor.util.annotation.Nullable;

import java.lang.reflect.Method;

Expand Down Expand Up @@ -54,17 +55,15 @@ public RowCallbackHandler(Row row, QueryExecutionInfo queryExecutionInfo, ProxyC
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
public Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable {
Assert.requireNonNull(proxy, "proxy must not be null");
Assert.requireNonNull(method, "method must not be null");

String methodName = method.getName();
ConnectionInfo connectionInfo = this.queryExecutionInfo.getConnectionInfo();

if ("unwrap".equals(methodName)) { // for Wrapped
return this.row;
} else if ("unwrapConnection".equals(methodName)) { // for ConnectionHolder
return connectionInfo.getOriginalConnection();
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.row, args, connectionInfo.getOriginalConnection());
}

// when converter decides to perform the original call("getOperation.proceed()"), this lambda is called.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.r2dbc.proxy.util.Assert;
import io.r2dbc.spi.Result;
import io.r2dbc.spi.Row;
import reactor.util.annotation.Nullable;

import java.lang.reflect.Method;

Expand Down Expand Up @@ -53,17 +54,15 @@ public RowSegmentCallbackHandler(Result.RowSegment rowSegment, QueryExecutionInf
}

@Override
public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
public Object invoke(Object proxy, Method method, @Nullable Object[] args) throws Throwable {
Assert.requireNonNull(proxy, "proxy must not be null");
Assert.requireNonNull(method, "method must not be null");

String methodName = method.getName();
ConnectionInfo connectionInfo = this.queryExecutionInfo.getConnectionInfo();

if ("unwrap".equals(methodName)) { // for Wrapped
return this.rowSegment;
} else if ("unwrapConnection".equals(methodName)) { // for ConnectionHolder
return connectionInfo.getOriginalConnection();
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.rowSegment, args, connectionInfo.getOriginalConnection());
}

Object result = proceedExecution(method, this.rowSegment, args, this.proxyConfig.getListeners(), connectionInfo, null);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2018 the original author or authors.
* Copyright 2018-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -69,12 +69,16 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl

String methodName = method.getName();

if ("unwrap".equals(methodName)) {
return this.statement;
} else if ("unwrapConnection".equals(methodName)) {
return this.connectionInfo.getOriginalConnection();
if (isCommonMethod(methodName)) {
return handleCommonMethod(methodName, this.statement, args, this.connectionInfo.getOriginalConnection());
}

// if ("unwrap".equals(methodName)) {
// return this.statement;
// } else if ("unwrapConnection".equals(methodName)) {
// return this.connectionInfo.getOriginalConnection();
// }

if ("bind".equals(methodName) || "bindNull".equals(methodName)) {

BoundValue boundValue;
Expand Down
Loading

0 comments on commit 50e3237

Please sign in to comment.