Add ClientTransportFilter (#10646)

* Add ClientTransportFilter
This commit is contained in:
joybestourous 2024-01-03 13:45:22 -05:00 committed by GitHub
parent 7692a9f5db
commit 91d15ce4e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 205 additions and 14 deletions

View File

@ -0,0 +1,51 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc;
/**
* Listens on the client transport life-cycle events. These filters do not have the capability
* to modify the channels or transport life-cycle event behavior, but they can be useful hooks
* for transport observability. Multiple filters may be registered to the client.
*
* @since 1.61.0
*/
@ExperimentalApi("https://gitub.com/grpc/grpc-java/issues/10652")
public abstract class ClientTransportFilter {
/**
* Called when a transport is ready to accept traffic (when a connection has been established).
* The default implementation is a no-op.
*
* @param transportAttrs current transport attributes
*
* @return new transport attributes. Default implementation returns the passed-in attributes
* intact.
*/
public Attributes transportReady(Attributes transportAttrs) {
return transportAttrs;
}
/**
* Called when a transport completed shutting down. All resources have been released.
* All streams have either been closed or transferred off this transport.
* Default implementation is a no-op
*
* @param transportAttrs the effective transport attributes, which is what is returned by {@link
* #transportReady} of the last executed filter.
*/
public void transportTerminated(Attributes transportAttrs) {
}
}

View File

@ -94,6 +94,12 @@ public abstract class ForwardingChannelBuilder2<T extends ManagedChannelBuilder<
return thisT(); return thisT();
} }
@Override
public T addTransportFilter(ClientTransportFilter transportFilter) {
delegate().addTransportFilter(transportFilter);
return thisT();
}
@Override @Override
public T userAgent(String userAgent) { public T userAgent(String userAgent) {
delegate().userAgent(userAgent); delegate().userAgent(userAgent);

View File

@ -159,6 +159,18 @@ public abstract class ManagedChannelBuilder<T extends ManagedChannelBuilder<T>>
*/ */
public abstract T intercept(ClientInterceptor... interceptors); public abstract T intercept(ClientInterceptor... interceptors);
/**
* Adds a {@link ClientTransportFilter}. The order of filters being added is the order they will
* be executed
*
* @return this
* @since 1.60.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10652")
public T addTransportFilter(ClientTransportFilter filter) {
throw new UnsupportedOperationException();
}
/** /**
* Provides a custom {@code User-Agent} for the application. * Provides a custom {@code User-Agent} for the application.
* *

View File

@ -758,6 +758,7 @@ public abstract class BinderTransport
// triggers), could have shut us down. // triggers), could have shut us down.
if (!isShutdown()) { if (!isShutdown()) {
setState(TransportState.READY); setState(TransportState.READY);
attributes = clientTransportListener.filterTransport(attributes);
clientTransportListener.transportReady(); clientTransportListener.transportReady();
} }
} }

View File

@ -35,6 +35,7 @@ import io.grpc.CallOptions;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel; import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer;
import io.grpc.ClientTransportFilter;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo; import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
@ -77,6 +78,8 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
private final ChannelTracer channelTracer; private final ChannelTracer channelTracer;
private final ChannelLogger channelLogger; private final ChannelLogger channelLogger;
private final List<ClientTransportFilter> transportFilters;
/** /**
* All field must be mutated in the syncContext. * All field must be mutated in the syncContext.
*/ */
@ -159,7 +162,8 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor, ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Supplier<Stopwatch> stopwatchSupplier, SynchronizationContext syncContext, Callback callback, Supplier<Stopwatch> stopwatchSupplier, SynchronizationContext syncContext, Callback callback,
InternalChannelz channelz, CallTracer callsTracer, ChannelTracer channelTracer, InternalChannelz channelz, CallTracer callsTracer, ChannelTracer channelTracer,
InternalLogId logId, ChannelLogger channelLogger) { InternalLogId logId, ChannelLogger channelLogger,
List<ClientTransportFilter> transportFilters) {
Preconditions.checkNotNull(addressGroups, "addressGroups"); Preconditions.checkNotNull(addressGroups, "addressGroups");
Preconditions.checkArgument(!addressGroups.isEmpty(), "addressGroups is empty"); Preconditions.checkArgument(!addressGroups.isEmpty(), "addressGroups is empty");
checkListHasNoNulls(addressGroups, "addressGroups contains null entry"); checkListHasNoNulls(addressGroups, "addressGroups contains null entry");
@ -180,6 +184,7 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
this.channelTracer = Preconditions.checkNotNull(channelTracer, "channelTracer"); this.channelTracer = Preconditions.checkNotNull(channelTracer, "channelTracer");
this.logId = Preconditions.checkNotNull(logId, "logId"); this.logId = Preconditions.checkNotNull(logId, "logId");
this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger"); this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger");
this.transportFilters = transportFilters;
} }
ChannelLogger getChannelLogger() { ChannelLogger getChannelLogger() {
@ -539,6 +544,15 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
this.transport = transport; this.transport = transport;
} }
@Override
public Attributes filterTransport(Attributes attributes) {
for (ClientTransportFilter filter : transportFilters) {
attributes = Preconditions.checkNotNull(filter.transportReady(attributes),
"Filter %s returned null", filter);
}
return attributes;
}
@Override @Override
public void transportReady() { public void transportReady() {
channelLogger.log(ChannelLogLevel.INFO, "READY"); channelLogger.log(ChannelLogLevel.INFO, "READY");
@ -607,6 +621,9 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
channelLogger.log(ChannelLogLevel.INFO, "{0} Terminated", transport.getLogId()); channelLogger.log(ChannelLogLevel.INFO, "{0} Terminated", transport.getLogId());
channelz.removeClientSocket(transport); channelz.removeClientSocket(transport);
handleTransportInUseState(transport, false); handleTransportInUseState(transport, false);
for (ClientTransportFilter filter : transportFilters) {
filter.transportTerminated(transport.getAttributes());
}
syncContext.execute(new Runnable() { syncContext.execute(new Runnable() {
@Override @Override
public void run() { public void run() {

View File

@ -42,6 +42,7 @@ import io.grpc.ClientCall;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors; import io.grpc.ClientInterceptors;
import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer;
import io.grpc.ClientTransportFilter;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo; import io.grpc.ConnectivityStateInfo;
@ -209,6 +210,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
* {@link RealChannel}. * {@link RealChannel}.
*/ */
private final Channel interceptorChannel; private final Channel interceptorChannel;
private final List<ClientTransportFilter> transportFilters;
@Nullable private final String userAgent; @Nullable private final String userAgent;
// Only null after channel is terminated. Must be assigned from the syncContext. // Only null after channel is terminated. Must be assigned from the syncContext.
@ -661,6 +664,7 @@ final class ManagedChannelImpl extends ManagedChannel implements
channel = builder.binlog.wrapChannel(channel); channel = builder.binlog.wrapChannel(channel);
} }
this.interceptorChannel = ClientInterceptors.intercept(channel, interceptors); this.interceptorChannel = ClientInterceptors.intercept(channel, interceptors);
this.transportFilters = new ArrayList<>(builder.transportFilters);
this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier"); this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier");
if (builder.idleTimeoutMillis == IDLE_TIMEOUT_MILLIS_DISABLE) { if (builder.idleTimeoutMillis == IDLE_TIMEOUT_MILLIS_DISABLE) {
this.idleTimeoutMillis = builder.idleTimeoutMillis; this.idleTimeoutMillis = builder.idleTimeoutMillis;
@ -1566,7 +1570,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
callTracerFactory.create(), callTracerFactory.create(),
subchannelTracer, subchannelTracer,
subchannelLogId, subchannelLogId,
subchannelLogger); subchannelLogger,
transportFilters);
oobChannelTracer.reportEvent(new ChannelTrace.Event.Builder() oobChannelTracer.reportEvent(new ChannelTrace.Event.Builder()
.setDescription("Child Subchannel created") .setDescription("Child Subchannel created")
.setSeverity(ChannelTrace.Event.Severity.CT_INFO) .setSeverity(ChannelTrace.Event.Severity.CT_INFO)
@ -1990,7 +1995,8 @@ final class ManagedChannelImpl extends ManagedChannel implements
callTracerFactory.create(), callTracerFactory.create(),
subchannelTracer, subchannelTracer,
subchannelLogId, subchannelLogId,
subchannelLogger); subchannelLogger,
transportFilters);
channelTracer.reportEvent(new ChannelTrace.Event.Builder() channelTracer.reportEvent(new ChannelTrace.Event.Builder()
.setDescription("Child Subchannel started") .setDescription("Child Subchannel started")
@ -2148,6 +2154,11 @@ final class ManagedChannelImpl extends ManagedChannel implements
// Don't care // Don't care
} }
@Override
public Attributes filterTransport(Attributes attributes) {
return attributes;
}
@Override @Override
public void transportInUse(final boolean inUse) { public void transportInUse(final boolean inUse) {
inUseStateAggregator.updateObjectInUse(delayedTransport, inUse); inUseStateAggregator.updateObjectInUse(delayedTransport, inUse);

View File

@ -17,6 +17,7 @@
package io.grpc.internal; package io.grpc.internal;
import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions; import com.google.common.base.Preconditions;
@ -27,6 +28,7 @@ import io.grpc.BinaryLog;
import io.grpc.CallCredentials; import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials; import io.grpc.ChannelCredentials;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.ClientTransportFilter;
import io.grpc.CompressorRegistry; import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
import io.grpc.EquivalentAddressGroup; import io.grpc.EquivalentAddressGroup;
@ -137,6 +139,8 @@ public final class ManagedChannelImplBuilder
private final List<ClientInterceptor> interceptors = new ArrayList<>(); private final List<ClientInterceptor> interceptors = new ArrayList<>();
NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry(); NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry();
final List<ClientTransportFilter> transportFilters = new ArrayList<>();
final String target; final String target;
@Nullable @Nullable
final ChannelCredentials channelCredentials; final ChannelCredentials channelCredentials;
@ -267,11 +271,11 @@ public final class ManagedChannelImplBuilder
String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds, String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds,
ClientTransportFactoryBuilder clientTransportFactoryBuilder, ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) { @Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = Preconditions.checkNotNull(target, "target"); this.target = checkNotNull(target, "target");
this.channelCredentials = channelCreds; this.channelCredentials = channelCreds;
this.callCredentials = callCreds; this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions this.clientTransportFactoryBuilder = checkNotNull(clientTransportFactoryBuilder,
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); "clientTransportFactoryBuilder");
this.directServerAddress = null; this.directServerAddress = null;
if (channelBuilderDefaultPortProvider != null) { if (channelBuilderDefaultPortProvider != null) {
@ -323,8 +327,8 @@ public final class ManagedChannelImplBuilder
this.target = makeTargetStringForDirectAddress(directServerAddress); this.target = makeTargetStringForDirectAddress(directServerAddress);
this.channelCredentials = channelCreds; this.channelCredentials = channelCreds;
this.callCredentials = callCreds; this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions this.clientTransportFactoryBuilder = checkNotNull(clientTransportFactoryBuilder,
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder"); "clientTransportFactoryBuilder");
this.directServerAddress = directServerAddress; this.directServerAddress = directServerAddress;
NameResolverRegistry reg = new NameResolverRegistry(); NameResolverRegistry reg = new NameResolverRegistry();
reg.register(new DirectAddressNameResolverProvider(directServerAddress, reg.register(new DirectAddressNameResolverProvider(directServerAddress,
@ -374,6 +378,12 @@ public final class ManagedChannelImplBuilder
return intercept(Arrays.asList(interceptors)); return intercept(Arrays.asList(interceptors));
} }
@Override
public ManagedChannelImplBuilder addTransportFilter(ClientTransportFilter hook) {
transportFilters.add(checkNotNull(hook, "transport filter"));
return this;
}
@Deprecated @Deprecated
@Override @Override
public ManagedChannelImplBuilder nameResolverFactory(NameResolver.Factory resolverFactory) { public ManagedChannelImplBuilder nameResolverFactory(NameResolver.Factory resolverFactory) {

View File

@ -16,6 +16,7 @@
package io.grpc.internal; package io.grpc.internal;
import io.grpc.Attributes;
import io.grpc.Status; import io.grpc.Status;
import javax.annotation.CheckReturnValue; import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -104,5 +105,11 @@ public interface ManagedClientTransport extends ClientTransport {
* at least one stream. * at least one stream.
*/ */
void transportInUse(boolean inUse); void transportInUse(boolean inUse);
/**
* Called just before {@link #transportReady} to allow direct modification of transport
* Attributes.
*/
Attributes filterTransport(Attributes attributes);
} }
} }

View File

@ -130,6 +130,11 @@ final class OobChannel extends ManagedChannel implements InternalInstrumented<Ch
// Don't care // Don't care
} }
@Override
public Attributes filterTransport(Attributes attributes) {
return attributes;
}
@Override @Override
public void transportInUse(boolean inUse) { public void transportInUse(boolean inUse) {
// Don't care // Don't care

View File

@ -54,6 +54,7 @@ import io.grpc.internal.InternalSubchannel.TransportLogger;
import io.grpc.internal.TestUtils.MockClientTransportInfo; import io.grpc.internal.TestUtils.MockClientTransportInfo;
import java.net.SocketAddress; import java.net.SocketAddress;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList; import java.util.LinkedList;
import java.util.List; import java.util.List;
import java.util.concurrent.BlockingQueue; import java.util.concurrent.BlockingQueue;
@ -1360,7 +1361,8 @@ public class InternalSubchannelTest {
channelz, CallTracer.getDefaultFactory().create(), channelz, CallTracer.getDefaultFactory().create(),
subchannelTracer, subchannelTracer,
logId, logId,
new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider())); new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()),
Collections.emptyList());
} }
private void assertNoCallbackInvoke() { private void assertNoCallbackInvoke() {

View File

@ -72,6 +72,7 @@ import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors; import io.grpc.ClientInterceptors;
import io.grpc.ClientStreamTracer; import io.grpc.ClientStreamTracer;
import io.grpc.ClientStreamTracer.StreamInfo; import io.grpc.ClientStreamTracer.StreamInfo;
import io.grpc.ClientTransportFilter;
import io.grpc.CompositeChannelCredentials; import io.grpc.CompositeChannelCredentials;
import io.grpc.ConnectivityState; import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo; import io.grpc.ConnectivityStateInfo;
@ -139,6 +140,7 @@ import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong; import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable; import javax.annotation.Nullable;
@ -4240,6 +4242,48 @@ public class ManagedChannelImplTest {
} }
} }
@Test
public void transportFilters() {
final AtomicInteger readyCallbackCalled = new AtomicInteger(0);
final AtomicInteger terminationCallbackCalled = new AtomicInteger(0);
ClientTransportFilter transportFilter = new ClientTransportFilter() {
@Override
public Attributes transportReady(Attributes transportAttrs) {
readyCallbackCalled.incrementAndGet();
return transportAttrs;
}
@Override
public void transportTerminated(Attributes transportAttrs) {
terminationCallbackCalled.incrementAndGet();
}
};
channelBuilder.addTransportFilter(transportFilter);
assertEquals(0, readyCallbackCalled.get());
createChannel();
final Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory)
.newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
MockClientTransportInfo transportInfo = transports.poll();
ManagedClientTransport.Listener transportListener = transportInfo.listener;
transportListener.filterTransport(Attributes.EMPTY);
transportListener.transportReady();
assertEquals(1, readyCallbackCalled.get());
assertEquals(0, terminationCallbackCalled.get());
transportListener.transportShutdown(Status.OK);
transportListener.transportTerminated();
assertEquals(1, terminationCallbackCalled.get());
}
private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider { private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider {
@Override @Override
public BackoffPolicy get() { public BackoffPolicy get() {

View File

@ -38,6 +38,7 @@ import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.timeout; import static org.mockito.Mockito.timeout;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import com.google.common.base.Objects; import com.google.common.base.Objects;
import com.google.common.collect.Lists; import com.google.common.collect.Lists;
@ -219,6 +220,7 @@ public abstract class AbstractTransportTest {
@Before @Before
public void setUp() { public void setUp() {
server = newServer(Arrays.asList(serverStreamTracerFactory)); server = newServer(Arrays.asList(serverStreamTracerFactory));
when(mockClientTransportListener.filterTransport(any())).thenAnswer(i -> i.getArguments()[0]);
} }
@After @After
@ -2136,6 +2138,7 @@ public abstract class AbstractTransportTest {
ManagedClientTransport clientTransport, ManagedClientTransport clientTransport,
ManagedClientTransport.Listener listener) { ManagedClientTransport.Listener listener) {
runIfNotNull(clientTransport.start(listener)); runIfNotNull(clientTransport.start(listener));
verify(listener, timeout(TIMEOUT_MS)).filterTransport(any());
verify(listener, timeout(TIMEOUT_MS)).transportReady(); verify(listener, timeout(TIMEOUT_MS)).transportReady();
} }

View File

@ -61,7 +61,7 @@ class CronetClientTransport implements ConnectionClientTransport {
private final int maxMessageSize; private final int maxMessageSize;
private final boolean alwaysUsePut; private final boolean alwaysUsePut;
private final TransportTracer transportTracer; private final TransportTracer transportTracer;
private final Attributes attrs; private Attributes attrs;
private final boolean useGetForSafeMethods; private final boolean useGetForSafeMethods;
private final boolean usePutForIdempotentMethods; private final boolean usePutForIdempotentMethods;
// Indicates the transport is in go-away state: no new streams will be processed, // Indicates the transport is in go-away state: no new streams will be processed,
@ -169,6 +169,7 @@ class CronetClientTransport implements ConnectionClientTransport {
return new Runnable() { return new Runnable() {
@Override @Override
public void run() { public void run() {
attrs = CronetClientTransport.this.listener.filterTransport(attrs);
// Listener callbacks should not be called simultaneously // Listener callbacks should not be called simultaneously
CronetClientTransport.this.listener.transportReady(); CronetClientTransport.this.listener.transportReady();
} }

View File

@ -80,6 +80,8 @@ public final class CronetClientTransportTest {
@Before @Before
public void setUp() { public void setUp() {
when(clientTransportListener.filterTransport(any()))
.thenAnswer(i -> i.getArgument(0, Attributes.class));
transport = transport =
new CronetClientTransport( new CronetClientTransport(
streamFactory, streamFactory,

View File

@ -106,7 +106,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans
new IdentityHashMap<InProcessStream, Boolean>()); new IdentityHashMap<InProcessStream, Boolean>());
@GuardedBy("this") @GuardedBy("this")
private List<ServerStreamTracer.Factory> serverStreamTracerFactories; private List<ServerStreamTracer.Factory> serverStreamTracerFactories;
private final Attributes attributes; private Attributes attributes;
private Thread.UncaughtExceptionHandler uncaughtExceptionHandler = private Thread.UncaughtExceptionHandler uncaughtExceptionHandler =
new Thread.UncaughtExceptionHandler() { new Thread.UncaughtExceptionHandler() {
@ -213,6 +213,7 @@ final class InProcessTransport implements ServerTransport, ConnectionClientTrans
.set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address) .set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, address)
.build(); .build();
serverStreamAttributes = serverTransportListener.transportReady(serverTransportAttrs); serverStreamAttributes = serverTransportListener.transportReady(serverTransportAttrs);
attributes = clientTransportListener.filterTransport(attributes);
clientTransportListener.transportReady(); clientTransportListener.transportReady();
} }
} }

View File

@ -17,6 +17,7 @@
package io.grpc.netty; package io.grpc.netty;
import com.google.errorprone.annotations.CanIgnoreReturnValue; import com.google.errorprone.annotations.CanIgnoreReturnValue;
import io.grpc.Attributes;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.internal.ManagedClientTransport; import io.grpc.internal.ManagedClientTransport;
@ -36,12 +37,14 @@ final class ClientTransportLifecycleManager {
this.listener = listener; this.listener = listener;
} }
public void notifyReady() { public Attributes notifyReady(Attributes attributes) {
if (transportReady || transportShutdown) { if (transportReady || transportShutdown) {
return; return attributes;
} }
transportReady = true; transportReady = true;
attributes = listener.filterTransport(attributes);
listener.transportReady(); listener.transportReady();
return attributes;
} }
/** /**

View File

@ -917,7 +917,7 @@ class NettyClientHandler extends AbstractNettyHandler {
public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) { public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) {
if (firstSettings) { if (firstSettings) {
firstSettings = false; firstSettings = false;
lifecycleManager.notifyReady(); attributes = lifecycleManager.notifyReady(attributes);
} }
} }

View File

@ -36,6 +36,8 @@ import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame; import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail; import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import com.google.common.base.Strings; import com.google.common.base.Strings;
import com.google.common.base.Ticker; import com.google.common.base.Ticker;
@ -111,6 +113,7 @@ import javax.annotation.Nullable;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException; import javax.net.ssl.SSLHandshakeException;
import org.junit.After; import org.junit.After;
import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.runner.RunWith; import org.junit.runner.RunWith;
@ -149,6 +152,11 @@ public class NettyClientTransportTest {
private String authority; private String authority;
private NettyServer server; private NettyServer server;
@Before
public void setup() {
when(clientTransportListener.filterTransport(any())).thenAnswer(i -> i.getArguments()[0]);
}
@After @After
public void teardown() throws Exception { public void teardown() throws Exception {
for (NettyClientTransport transport : transports) { for (NettyClientTransport transport : transports) {

View File

@ -20,6 +20,7 @@ import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import com.google.common.util.concurrent.SettableFuture; import com.google.common.util.concurrent.SettableFuture;
import io.grpc.Attributes;
import io.grpc.ChannelLogger; import io.grpc.ChannelLogger;
import io.grpc.ServerStreamTracer; import io.grpc.ServerStreamTracer;
import io.grpc.Status; import io.grpc.Status;
@ -144,6 +145,11 @@ public class NettyTransportTest extends AbstractTransportTest {
Throwable t = new Throwable("transport should have failed and shutdown but didnt"); Throwable t = new Throwable("transport should have failed and shutdown but didnt");
future.setException(t); future.setException(t);
} }
@Override
public Attributes filterTransport(Attributes attributes) {
return attributes;
}
}); });
if (runnable != null) { if (runnable != null) {
runnable.run(); runnable.run();

View File

@ -1284,6 +1284,7 @@ class OkHttpClientTransport implements ConnectionClientTransport, TransportExcep
outboundWindowSizeIncreased = outboundFlow.initialOutboundWindowSize(initialWindowSize); outboundWindowSizeIncreased = outboundFlow.initialOutboundWindowSize(initialWindowSize);
} }
if (firstSettings) { if (firstSettings) {
attributes = listener.filterTransport(attributes);
listener.transportReady(); listener.transportReady();
firstSettings = false; firstSettings = false;
} }