Skip to content

Commit efce818

Browse files
committed
Interceptor logic.
1 parent bee005f commit efce818

File tree

3 files changed

+63
-26
lines changed

3 files changed

+63
-26
lines changed

interop-testing/src/main/java/io/grpc/testing/integration/TestCases.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,7 @@ public enum TestCases {
6060
CHANNEL_SOAK("sends 'soak_iterations' large_unary rpcs in a loop, each on a new channel"),
6161
ORCA_PER_RPC("report backend metrics per query"),
6262
ORCA_OOB("report backend metrics out-of-band"),
63-
MCS("max concurrent streaming"),
64-
MCSSS("mcs server streaming");
63+
MCS_CS("max concurrent streaming connection scaling");
6564

6665
private final String description;
6766

interop-testing/src/main/java/io/grpc/testing/integration/TestServiceClient.java

Lines changed: 22 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package io.grpc.testing.integration;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20-
import static io.grpc.testing.integration.TestCases.MCS;
20+
import static io.grpc.testing.integration.TestCases.MCS_CS;
2121
import static org.junit.Assert.assertEquals;
2222
import static org.junit.Assert.assertFalse;
2323
import static org.junit.Assert.assertNotEquals;
@@ -576,16 +576,10 @@ private void runTest(TestCases testCase) throws Exception {
576576
break;
577577
}
578578

579-
case MCS: {
579+
case MCS_CS: {
580580
tester.testMcs();
581581
break;
582582
}
583-
584-
case MCSSS: {
585-
tester.testMcs_serverStreaming();
586-
break;
587-
}
588-
589583
default:
590584
throw new IllegalArgumentException("Unknown test case: " + testCase);
591585
}
@@ -622,8 +616,8 @@ private class Tester extends AbstractInteropTest {
622616

623617
@Override
624618
protected ManagedChannelBuilder<?> createChannelBuilder() {
625-
boolean useSubchannelMetricsSink = testCase.equals(MCS.toString());
626-
boolean useGeneric = testCase.equals(MCS.toString())? true : false;
619+
boolean useSubchannelMetricsSink = testCase.equals(MCS_CS.toString());
620+
boolean useGeneric = testCase.equals(MCS_CS.toString())? true : false;
627621
ChannelCredentials channelCredentials;
628622
if (customCredentialsType != null) {
629623
useGeneric = true; // Retain old behavior; avoids erroring if incompatible
@@ -683,7 +677,7 @@ protected ManagedChannelBuilder<?> createChannelBuilder() {
683677
if (serverHostOverride != null) {
684678
channelBuilder.overrideAuthority(serverHostOverride);
685679
}
686-
if (testCase.equals(MCS.toString())) {
680+
if (testCase.equals(MCS_CS.toString())) {
687681
channelBuilder.disableServiceConfigLookUp();
688682
try {
689683
@SuppressWarnings("unchecked")
@@ -1101,32 +1095,37 @@ Object take() throws InterruptedException {
11011095
}
11021096

11031097
public void testMcs() throws Exception {
1104-
StreamingOutputCallResponseObserver responseObserver1 = new StreamingOutputCallResponseObserver();
1098+
StreamingOutputCallResponseObserver responseObserver1 =
1099+
new StreamingOutputCallResponseObserver();
11051100
StreamObserver<StreamingOutputCallRequest> streamObserver1 =
11061101
asyncStub.fullDuplexCall(responseObserver1);
1107-
streamObserver1.onNext(StreamingOutputCallRequest.newBuilder()
1108-
.addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build());
1109-
assertThat(responseObserver1.take()).isInstanceOf(StreamingOutputCallResponse.class);
1102+
StreamingOutputCallRequest request = StreamingOutputCallRequest.newBuilder()
1103+
.setPayload(Payload.newBuilder().setBody(
1104+
ByteString.copyFrom(MCS_CS.description().getBytes())).build()).build();
1105+
streamObserver1.onNext(request);
1106+
Object responseObj = responseObserver1.take();
1107+
StreamingOutputCallResponse callResponse = (StreamingOutputCallResponse) responseObj;
1108+
String clientSocketAddressInCall1 = new String(callResponse.getPayload().getBody().toByteArray());
11101109

11111110
StreamingOutputCallResponseObserver responseObserver2 = new StreamingOutputCallResponseObserver();
11121111
StreamObserver<StreamingOutputCallRequest> streamObserver2 =
11131112
asyncStub.fullDuplexCall(responseObserver2);
1114-
streamObserver2.onNext(StreamingOutputCallRequest.newBuilder()
1115-
.addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build());
1116-
assertThat(responseObserver2.take()).isInstanceOf(StreamingOutputCallResponse.class);
1113+
streamObserver2.onNext(request);
1114+
callResponse = (StreamingOutputCallResponse) responseObserver2.take();
1115+
String clientSocketAddressInCall2 = new String(callResponse.getPayload().getBody().toByteArray());
11171116

1118-
assertThat(fakeMetricsSink.openConnectionCount).isEqualTo(1);
1117+
assertThat(clientSocketAddressInCall1).isEqualTo(clientSocketAddressInCall2);
11191118

11201119
// The first connection is at max rpc call count of 2, so the 3rd rpc will cause a new
11211120
// connection to be created in the same subchannel and not get queued.
11221121
StreamingOutputCallResponseObserver responseObserver3 = new StreamingOutputCallResponseObserver();
11231122
StreamObserver<StreamingOutputCallRequest> streamObserver3 =
11241123
asyncStub.fullDuplexCall(responseObserver3);
1125-
streamObserver3.onNext(StreamingOutputCallRequest.newBuilder()
1126-
.addResponseParameters(ResponseParameters.newBuilder().setSize(1).build()).build());
1127-
assertThat(responseObserver3.take()).isInstanceOf(StreamingOutputCallResponse.class);
1124+
streamObserver3.onNext(request);
1125+
callResponse = (StreamingOutputCallResponse) responseObserver3.take();
1126+
String clientSocketAddressInCall3 = new String(callResponse.getPayload().getBody().toByteArray());
11281127

1129-
assertThat(fakeMetricsSink.openConnectionCount).isEqualTo(2);
1128+
assertThat(clientSocketAddressInCall3).isNotEqualTo(clientSocketAddressInCall1);
11301129
}
11311130

11321131
public void testMcs_serverStreaming() throws Exception {

interop-testing/src/main/java/io/grpc/testing/integration/TestServiceImpl.java

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,19 @@
1616

1717
package io.grpc.testing.integration;
1818

19+
import static io.grpc.Grpc.TRANSPORT_ATTR_REMOTE_ADDR;
20+
import static io.grpc.testing.integration.TestCases.MCS_CS;
21+
1922
import com.google.common.base.Preconditions;
2023
import com.google.common.collect.Queues;
2124
import com.google.errorprone.annotations.concurrent.GuardedBy;
2225
import com.google.protobuf.ByteString;
26+
import io.grpc.Attributes;
27+
import io.grpc.Contexts;
2328
import io.grpc.ForwardingServerCall.SimpleForwardingServerCall;
2429
import io.grpc.Metadata;
2530
import io.grpc.ServerCall;
31+
import io.grpc.ServerCall.Listener;
2632
import io.grpc.ServerCallHandler;
2733
import io.grpc.ServerInterceptor;
2834
import io.grpc.Status;
@@ -42,6 +48,7 @@
4248
import io.grpc.testing.integration.Messages.StreamingOutputCallResponse;
4349
import io.grpc.testing.integration.Messages.TestOrcaReport;
4450
import io.grpc.testing.integration.TestServiceGrpc.AsyncService;
51+
import java.net.SocketAddress;
4552
import java.util.ArrayDeque;
4653
import java.util.Arrays;
4754
import java.util.HashMap;
@@ -55,12 +62,14 @@
5562
import java.util.concurrent.ScheduledExecutorService;
5663
import java.util.concurrent.Semaphore;
5764
import java.util.concurrent.TimeUnit;
65+
import io.grpc.Context;
5866

5967
/**
6068
* Implementation of the business logic for the TestService. Uses an executor to schedule chunks
6169
* sent in response streams.
6270
*/
6371
public class TestServiceImpl implements io.grpc.BindableService, AsyncService {
72+
static Context.Key<SocketAddress> PEER_ADDRESS_CONTEXT_KEY = Context.key("peer-address");
6473
private final Random random = new Random();
6574

6675
private final ScheduledExecutorService executor;
@@ -235,6 +244,16 @@ public void onNext(StreamingOutputCallRequest request) {
235244
.asRuntimeException());
236245
return;
237246
}
247+
if (new String(request.getPayload().getBody().toByteArray()).equals(MCS_CS.description())) {
248+
SocketAddress peerAddress = PEER_ADDRESS_CONTEXT_KEY.get();
249+
ByteString payload = ByteString.copyFrom(peerAddress.toString().getBytes());
250+
StreamingOutputCallResponse.Builder responseBuilder =
251+
StreamingOutputCallResponse.newBuilder();
252+
responseBuilder.setPayload(
253+
Payload.newBuilder()
254+
.setBody(payload));
255+
responseObserver.onNext(responseBuilder.build());
256+
}
238257
dispatcher.enqueue(toChunkQueue(request));
239258
}
240259

@@ -507,7 +526,8 @@ public static List<ServerInterceptor> interceptors() {
507526
return Arrays.asList(
508527
echoRequestHeadersInterceptor(Util.METADATA_KEY),
509528
echoRequestMetadataInHeaders(Util.ECHO_INITIAL_METADATA_KEY),
510-
echoRequestMetadataInTrailers(Util.ECHO_TRAILING_METADATA_KEY));
529+
echoRequestMetadataInTrailers(Util.ECHO_TRAILING_METADATA_KEY),
530+
new McsScalingTestcaseInterceptor());
511531
}
512532

513533
/**
@@ -539,6 +559,25 @@ public void close(Status status, Metadata trailers) {
539559
};
540560
}
541561

562+
static class McsScalingTestcaseInterceptor implements ServerInterceptor {
563+
@Override
564+
public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
565+
Metadata headers, ServerCallHandler<ReqT, RespT> next) {
566+
SocketAddress peerAddress = call.getAttributes().get(TRANSPORT_ATTR_REMOTE_ADDR);
567+
568+
// Create a new context with the peer address value
569+
Context newContext = Context.current().withValue(PEER_ADDRESS_CONTEXT_KEY, peerAddress);
570+
try {
571+
572+
// Continue the call processing within the new context
573+
// return newContext.call(() -> next.startCall(call, headers));
574+
return Contexts.interceptCall(newContext, call, headers, next);
575+
} catch (Exception ex) {
576+
throw new RuntimeException(ex);
577+
}
578+
}
579+
}
580+
542581
/**
543582
* Echoes request headers with the specified key(s) from a client into response headers only.
544583
*/

0 commit comments

Comments
 (0)