Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,40 @@ public class GcpManagedChannel extends ManagedChannel {
public static final CallOptions.Key<Integer> CHANNEL_ID_KEY =
CallOptions.Key.create("GcpChannelId");

/** CallOptions key for sticky channel routing without affinity-key map state. */
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall, there are quite big changes to grpc-gcp being made in this pull request, but there are no tests that verify these changes. Can we add tests that cover the changes that we make to grpc-gcp here? Relying on tests in the Spanner client is not enough, as this is a standalone library that can be used by other clients.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test coverage for the changes to the Spanner client is also quite thin, but the existing tests generally do cover these changes. One interesting test (if possible) would be a test that really verifies that all requests in a single read/write or multi-use read-only transaction really all use the same gRPC channel (so basically checking the local port where the requests are coming from on the mock server).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some more tests

public static final CallOptions.Key<ChannelAffinityRef> CHANNEL_AFFINITY_REF_KEY =
CallOptions.Key.create("GcpChannelAffinityRef");

/** Opaque sticky channel reference for callers that should not depend on {@link ChannelRef}. */
public static final class ChannelAffinityRef {
private static final int USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK = 1 << 31;
private static final int CHANNEL_ID_MASK = ~USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK;
private static final int NO_CHANNEL_ID = -1;

// Single allocation hot-path state:
// * lower 31 bits: channel id + 1, or 0 when unset.
// * high bit: use a different active channel on the next call.
private final AtomicInteger state = new AtomicInteger();

/** Forces the next RPC to prefer a different active channel if one is available. */
public void useDifferentChannelOnNextCall() {
state.getAndUpdate(value -> value | USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK);
}

private static int channelIdFromState(int state) {
int encodedChannelId = state & CHANNEL_ID_MASK;
return encodedChannelId == 0 ? NO_CHANNEL_ID : encodedChannelId - 1;
}

private static boolean useDifferentChannelOnNextCallFromState(int state) {
return (state & USE_DIFFERENT_CHANNEL_ON_NEXT_CALL_MASK) != 0;
}

private static int stateFromChannelId(int channelId) {
return (channelId + 1) & CHANNEL_ID_MASK;
}
}

@GuardedBy("this")
private Integer bindingIndex = -1;

Expand Down Expand Up @@ -140,6 +174,7 @@ public class GcpManagedChannel extends ManagedChannel {

// The channel pool.
@VisibleForTesting final List<ChannelRef> channelRefs = new CopyOnWriteArrayList<>();
private final Map<Integer, ChannelRef> channelIdToChannelRef = new ConcurrentHashMap<>();
// A set of channels that we removed from the pool and wait for their RPCs to be completed before
// we can shut them down.
final Set<ChannelRef> removedChannelRefs = new HashSet<>();
Expand Down Expand Up @@ -352,6 +387,7 @@ private synchronized void checkScaleDown() {
channelRef.getChannel().shutdown();
// Remove channel from broken channels map.
fallbackMap.remove(channelRef.getId());
channelIdToChannelRef.remove(channelRef.getId());
}
}

Expand All @@ -372,6 +408,7 @@ private void removeOldestChannels(int num) {

for (ChannelRef channelRef : channelsToRemove) {
channelRef.resetAffinityCount();
channelRef.deactivate();
if (channelRef.getState() == ConnectivityState.READY) {
decReadyChannels(false);
}
Expand Down Expand Up @@ -1678,6 +1715,59 @@ protected ChannelRef getChannelRef(@Nullable String key) {
return mappedChannel;
}

/**
* Pick a {@link ChannelRef} using a caller-owned reference instead of grpc-gcp's affinity map.
*/
protected ChannelRef getChannelRefByAffinityRef(ChannelAffinityRef affinityRef) {
maybeDynamicUpscale();
// Retry if another thread updates the caller-owned affinity ref while we are picking a channel.
while (true) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: can we add a clarifying comment for why we are looping here?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added

int state = affinityRef.state.get();
int channelId = ChannelAffinityRef.channelIdFromState(state);
boolean useDifferentChannel =
ChannelAffinityRef.useDifferentChannelOnNextCallFromState(state);
ChannelRef channelRef =
channelId == ChannelAffinityRef.NO_CHANNEL_ID
? null
: channelIdToChannelRef.get(channelId);
if (!useDifferentChannel && channelRef != null && channelRef.isActive()) {
return channelRef;
}

ChannelRef selectedChannelRef =
useDifferentChannel
? pickLeastBusyChannelDifferentFrom(channelRef)
: pickLeastBusyChannel(/* forFallback= */ false);
if (affinityRef.state.compareAndSet(
state, ChannelAffinityRef.stateFromChannelId(selectedChannelRef.getId()))) {
return selectedChannelRef;
Comment thread
rahul2393 marked this conversation as resolved.
}
}
Comment thread
rahul2393 marked this conversation as resolved.
}

private ChannelRef pickLeastBusyChannelDifferentFrom(@Nullable ChannelRef excludedChannelRef) {
ChannelRef channelRef = pickLeastBusyChannel(/* forFallback= */ false);
if (excludedChannelRef == null || channelRefs.size() <= 1) {
return channelRef;
}
if (channelRef != excludedChannelRef && channelRef.isActive()) {
return channelRef;
}
ChannelRef leastBusyChannelRef = null;
int leastBusyStreams = Integer.MAX_VALUE;
for (ChannelRef candidate : channelRefs) {
if (candidate == excludedChannelRef || !candidate.isActive()) {
continue;
}
int streams = candidate.getActiveStreamsCount();
if (leastBusyChannelRef == null || streams < leastBusyStreams) {
leastBusyChannelRef = candidate;
leastBusyStreams = streams;
}
}
Comment thread
rahul2393 marked this conversation as resolved.
return leastBusyChannelRef == null ? channelRef : leastBusyChannelRef;
}

// Create a new channel and add it to channelRefs.
// If we have a ready channel not in the pool that we wait for completing its RPCs,
// then re-use that channel instead.
Expand All @@ -1688,6 +1778,8 @@ ChannelRef createNewChannel() {
ChannelRef chRef = reusedChannelRef.get();
channelRefs.add(chRef);
removedChannelRefs.remove(chRef);
channelIdToChannelRef.put(chRef.getId(), chRef);
chRef.activate();
logger.finer(log("Channel %d reused.", chRef.getId()));
incReadyChannels(false);
maxChannels.accumulateAndGet(getNumberOfChannels(), Math::max);
Expand All @@ -1696,6 +1788,7 @@ ChannelRef createNewChannel() {

ChannelRef channelRef = new ChannelRef(delegateChannelBuilder.build());
channelRefs.add(channelRef);
channelIdToChannelRef.put(channelRef.getId(), channelRef);
logger.finer(log("Channel %d created.", channelRef.getId()));
maxChannels.accumulateAndGet(getNumberOfChannels(), Math::max);
return channelRef;
Expand Down Expand Up @@ -1961,6 +2054,12 @@ public String authority() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> newCall(
MethodDescriptor<ReqT, RespT> methodDescriptor, CallOptions callOptions) {
ChannelAffinityRef channelAffinityRef = callOptions.getOption(CHANNEL_AFFINITY_REF_KEY);
if (channelAffinityRef != null) {
return new GcpClientCall.SimpleGcpClientCall<>(
this, getChannelRefByAffinityRef(channelAffinityRef), methodDescriptor, callOptions);
Comment thread
rahul2393 marked this conversation as resolved.
}

if (callOptions.getOption(DISABLE_AFFINITY_KEY)
|| DISABLE_AFFINITY_CTX_KEY.get(Context.current())) {
if (logger.isLoggable(Level.FINEST)) {
Expand Down Expand Up @@ -2314,6 +2413,7 @@ protected class ChannelRef {
private final AtomicLong okCalls = new AtomicLong();
private final AtomicLong errCalls = new AtomicLong();
private final ChannelStateMonitor channelStateMonitor;
private volatile boolean active = true;

protected ChannelRef(ManagedChannel channel) {
this(channel, 0, 0);
Expand Down Expand Up @@ -2343,6 +2443,18 @@ protected int getId() {
return channelId;
}

protected boolean isActive() {
return active;
}

private void activate() {
active = true;
}

private void deactivate() {
active = false;
}

protected void affinityCountIncr() {
int count = affinityCount.incrementAndGet();
maxAffinity.accumulateAndGet(count, Math::max);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import static com.google.common.truth.Truth.assertThat;

import com.google.cloud.grpc.GcpManagedChannel.ChannelAffinityRef;
import com.google.cloud.grpc.GcpManagedChannelOptions.GcpChannelPoolOptions;
import io.grpc.CallOptions;
import io.grpc.Channel;
Expand All @@ -28,13 +29,22 @@
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public class ChannelIdPropagationTest {
private static final MethodDescriptor<String, String> METHOD_DESCRIPTOR =
MethodDescriptor.<String, String>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("test/method")
.setRequestMarshaller(new FakeMarshaller<>())
.setResponseMarshaller(new FakeMarshaller<>())
.build();

private static class FakeMarshaller<T> implements MethodDescriptor.Marshaller<T> {
@Override
Expand Down Expand Up @@ -85,16 +95,8 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
.build())
.build();

MethodDescriptor<String, String> methodDescriptor =
MethodDescriptor.<String, String>newBuilder()
.setType(MethodDescriptor.MethodType.UNARY)
.setFullMethodName("test/method")
.setRequestMarshaller(new FakeMarshaller<>())
.setResponseMarshaller(new FakeMarshaller<>())
.build();

// Use the pool directly (interceptor is already inside)
ClientCall<String, String> newCall = pool.newCall(methodDescriptor, CallOptions.DEFAULT);
ClientCall<String, String> newCall = pool.newCall(METHOD_DESCRIPTOR, CallOptions.DEFAULT);
Metadata headers = new Metadata();

// First call (should initialize channel and correct ID)
Expand All @@ -105,7 +107,7 @@ public void start(Listener<RespT> responseListener, Metadata headers) {
assertThat(channelId.get()).isAnyOf(0, 1, 2);

// Attempt 2
newCall = pool.newCall(methodDescriptor, CallOptions.DEFAULT);
newCall = pool.newCall(METHOD_DESCRIPTOR, CallOptions.DEFAULT);
newCall.start(
new ForwardingClientCall.SimpleForwardingClientCall.Listener<String>() {}, headers);

Expand All @@ -114,4 +116,82 @@ public void start(Listener<RespT> responseListener, Metadata headers) {

pool.shutdownNow();
}

@Test
public void testChannelAffinityRefSticksToSameChannel() {
List<Integer> channelIds = new ArrayList<>();
GcpManagedChannel pool = newPoolWithChannelIdInterceptor(channelIds);

try {
ChannelAffinityRef affinityRef = new ChannelAffinityRef();
CallOptions callOptions =
CallOptions.DEFAULT.withOption(GcpManagedChannel.CHANNEL_AFFINITY_REF_KEY, affinityRef);

startCall(pool, callOptions);
startCall(pool, callOptions);
startCall(pool, callOptions);

assertThat(channelIds).hasSize(3);
assertThat(channelIds.get(1)).isEqualTo(channelIds.get(0));
assertThat(channelIds.get(2)).isEqualTo(channelIds.get(0));
assertThat(pool.affinityKeyToChannelRef).isEmpty();
} finally {
pool.shutdownNow();
}
}

@Test
public void testChannelAffinityRefCanMoveToDifferentChannelOnNextCall() {
List<Integer> channelIds = new ArrayList<>();
GcpManagedChannel pool = newPoolWithChannelIdInterceptor(channelIds);

try {
ChannelAffinityRef affinityRef = new ChannelAffinityRef();
CallOptions callOptions =
CallOptions.DEFAULT.withOption(GcpManagedChannel.CHANNEL_AFFINITY_REF_KEY, affinityRef);

startCall(pool, callOptions);
affinityRef.useDifferentChannelOnNextCall();
startCall(pool, callOptions);
startCall(pool, callOptions);

assertThat(channelIds).hasSize(3);
assertThat(channelIds.get(1)).isNotEqualTo(channelIds.get(0));
assertThat(channelIds.get(2)).isEqualTo(channelIds.get(1));
assertThat(pool.affinityKeyToChannelRef).isEmpty();
} finally {
pool.shutdownNow();
}
}

private static GcpManagedChannel newPoolWithChannelIdInterceptor(List<Integer> channelIds) {
ManagedChannelBuilder<?> builder = ManagedChannelBuilder.forAddress("localhost", 443);
builder.intercept(
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
Integer channelId = callOptions.getOption(GcpManagedChannel.CHANNEL_ID_KEY);
if (channelId != null) {
channelIds.add(channelId);
}
return next.newCall(method, callOptions);
}
});
return (GcpManagedChannel)
GcpManagedChannelBuilder.forDelegateBuilder(builder)
.withOptions(
GcpManagedChannelOptions.newBuilder()
.withChannelPoolOptions(
GcpChannelPoolOptions.newBuilder().setMinSize(3).setMaxSize(3).build())
.build())
.build();
}

private static void startCall(GcpManagedChannel pool, CallOptions callOptions) {
pool.newCall(METHOD_DESCRIPTOR, callOptions)
.start(
new ForwardingClientCall.SimpleForwardingClientCall.Listener<String>() {},
new Metadata());
}
}
Loading
Loading