xds: Reuse filter interceptors across RPCs

This moves the interceptor creation from the ConfigSelector to the
resource update handling.

The code structure changes will make adding support for filter
lifecycles (for RLQS) a bit easier. The filter lifecycles will allow
filters to share state across interceptors, and constructing all the
interceptors on a single thread will mean filters wouldn't need to be
thread-safe (but their interceptors would be thread-safe).
This commit is contained in:
Eric Anderson 2025-01-30 12:43:51 -08:00 committed by GitHub
parent 90aefb26e7
commit c506190b0f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 136 additions and 82 deletions

View File

@ -37,7 +37,6 @@ import io.grpc.Context;
import io.grpc.Deadline;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
@ -183,7 +182,7 @@ final class FaultFilter implements Filter, ClientInterceptorBuilder {
@Nullable
@Override
public ClientInterceptor buildClientInterceptor(
FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args,
FilterConfig config, @Nullable FilterConfig overrideConfig,
final ScheduledExecutorService scheduler) {
checkNotNull(config, "config");
if (overrideConfig != null) {

View File

@ -19,7 +19,6 @@ package io.grpc.xds;
import com.google.common.base.MoreObjects;
import com.google.protobuf.Message;
import io.grpc.ClientInterceptor;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.ServerInterceptor;
import java.util.Objects;
import java.util.concurrent.ScheduledExecutorService;
@ -59,7 +58,7 @@ interface Filter {
interface ClientInterceptorBuilder {
@Nullable
ClientInterceptor buildClientInterceptor(
FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args,
FilterConfig config, @Nullable FilterConfig overrideConfig,
ScheduledExecutorService scheduler);
}

View File

@ -31,7 +31,6 @@ import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.CompositeCallCredentials;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Status;
@ -97,8 +96,7 @@ final class GcpAuthenticationFilter implements Filter, ClientInterceptorBuilder
@Nullable
@Override
public ClientInterceptor buildClientInterceptor(FilterConfig config,
@Nullable FilterConfig overrideConfig, PickSubchannelArgs args,
ScheduledExecutorService scheduler) {
@Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) {
ComputeEngineCredentials credentials = ComputeEngineCredentials.create();
LruCache<String, CallCredentials> callCredentialsCache =

View File

@ -18,7 +18,6 @@ package io.grpc.xds;
import com.google.protobuf.Message;
import io.grpc.ClientInterceptor;
import io.grpc.LoadBalancer.PickSubchannelArgs;
import io.grpc.ServerInterceptor;
import io.grpc.xds.Filter.ClientInterceptorBuilder;
import io.grpc.xds.Filter.ServerInterceptorBuilder;
@ -64,7 +63,7 @@ enum RouterFilter implements Filter, ClientInterceptorBuilder, ServerInterceptor
@Nullable
@Override
public ClientInterceptor buildClientInterceptor(
FilterConfig config, @Nullable FilterConfig overrideConfig, PickSubchannelArgs args,
FilterConfig config, @Nullable FilterConfig overrideConfig,
ScheduledExecutorService scheduler) {
return null;
}

View File

@ -59,6 +59,7 @@ import io.grpc.xds.VirtualHost.Route.RouteAction;
import io.grpc.xds.VirtualHost.Route.RouteAction.ClusterWeight;
import io.grpc.xds.VirtualHost.Route.RouteAction.HashPolicy;
import io.grpc.xds.VirtualHost.Route.RouteAction.RetryPolicy;
import io.grpc.xds.VirtualHost.Route.RouteMatch;
import io.grpc.xds.XdsNameResolverProvider.CallCounterProvider;
import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate;
import io.grpc.xds.client.Bootstrapper.AuthorityInfo;
@ -384,20 +385,17 @@ final class XdsNameResolver extends NameResolver {
@Override
public Result selectConfig(PickSubchannelArgs args) {
String cluster = null;
Route selectedRoute = null;
ClientInterceptor filters = null; // null iff cluster is null
RouteData selectedRoute = null;
RoutingConfig routingCfg;
Map<String, FilterConfig> selectedOverrideConfigs;
List<ClientInterceptor> filterInterceptors = new ArrayList<>();
Metadata headers = args.getHeaders();
do {
routingCfg = routingConfig;
selectedOverrideConfigs = new HashMap<>(routingCfg.virtualHostOverrideConfig);
for (Route route : routingCfg.routes) {
for (RouteData route : routingCfg.routes) {
if (RoutingUtils.matchRoute(
route.routeMatch(), "/" + args.getMethodDescriptor().getFullMethodName(),
headers, random)) {
route.routeMatch, "/" + args.getMethodDescriptor().getFullMethodName(),
headers, random)) {
selectedRoute = route;
selectedOverrideConfigs.putAll(route.filterConfigOverrides());
break;
}
}
@ -405,13 +403,14 @@ final class XdsNameResolver extends NameResolver {
return Result.forError(
Status.UNAVAILABLE.withDescription("Could not find xDS route matching RPC"));
}
if (selectedRoute.routeAction() == null) {
if (selectedRoute.routeAction == null) {
return Result.forError(Status.UNAVAILABLE.withDescription(
"Could not route RPC to Route with non-forwarding action"));
}
RouteAction action = selectedRoute.routeAction();
RouteAction action = selectedRoute.routeAction;
if (action.cluster() != null) {
cluster = prefixedClusterName(action.cluster());
filters = selectedRoute.filterChoices.get(0);
} else if (action.weightedClusters() != null) {
long totalWeight = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) {
@ -419,23 +418,25 @@ final class XdsNameResolver extends NameResolver {
}
long select = random.nextLong(totalWeight);
long accumulator = 0;
for (ClusterWeight weightedCluster : action.weightedClusters()) {
for (int i = 0; i < action.weightedClusters().size(); i++) {
ClusterWeight weightedCluster = action.weightedClusters().get(i);
accumulator += weightedCluster.weight();
if (select < accumulator) {
cluster = prefixedClusterName(weightedCluster.name());
selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides());
filters = selectedRoute.filterChoices.get(i);
break;
}
}
} else if (action.namedClusterSpecifierPluginConfig() != null) {
cluster =
prefixedClusterSpecifierPluginName(action.namedClusterSpecifierPluginConfig().name());
filters = selectedRoute.filterChoices.get(0);
}
} while (!retainCluster(cluster));
Long timeoutNanos = null;
if (enableTimeout) {
if (selectedRoute != null) {
timeoutNanos = selectedRoute.routeAction().timeoutNano();
timeoutNanos = selectedRoute.routeAction.timeoutNano();
}
if (timeoutNanos == null) {
timeoutNanos = routingCfg.fallbackTimeoutNano;
@ -445,7 +446,7 @@ final class XdsNameResolver extends NameResolver {
}
}
RetryPolicy retryPolicy =
selectedRoute == null ? null : selectedRoute.routeAction().retryPolicy();
selectedRoute == null ? null : selectedRoute.routeAction.retryPolicy();
// TODO(chengyuanzhang): avoid service config generation and parsing for each call.
Map<String, ?> rawServiceConfig =
generateServiceConfigWithMethodConfig(timeoutNanos, retryPolicy);
@ -457,24 +458,9 @@ final class XdsNameResolver extends NameResolver {
parsedServiceConfig.getError().augmentDescription(
"Failed to parse service config (method config)"));
}
if (routingCfg.filterChain != null) {
for (NamedFilterConfig namedFilter : routingCfg.filterChain) {
FilterConfig filterConfig = namedFilter.filterConfig;
Filter filter = filterRegistry.get(filterConfig.typeUrl());
if (filter instanceof ClientInterceptorBuilder) {
ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter)
.buildClientInterceptor(
filterConfig, selectedOverrideConfigs.get(namedFilter.name),
args, scheduler);
if (interceptor != null) {
filterInterceptors.add(interceptor);
}
}
}
}
final String finalCluster = cluster;
final long hash = generateHash(selectedRoute.routeAction().hashPolicies(), headers);
Route finalSelectedRoute = selectedRoute;
final long hash = generateHash(selectedRoute.routeAction.hashPolicies(), headers);
RouteData finalSelectedRoute = selectedRoute;
class ClusterSelectionInterceptor implements ClientInterceptor {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
@ -483,7 +469,7 @@ final class XdsNameResolver extends NameResolver {
CallOptions callOptionsForCluster =
callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster)
.withOption(RPC_HASH_KEY, hash);
if (finalSelectedRoute.routeAction().autoHostRewrite()) {
if (finalSelectedRoute.routeAction.autoHostRewrite()) {
callOptionsForCluster = callOptionsForCluster.withOption(AUTO_HOST_REWRITE_KEY, true);
}
return new SimpleForwardingClientCall<ReqT, RespT>(
@ -514,11 +500,11 @@ final class XdsNameResolver extends NameResolver {
}
}
filterInterceptors.add(new ClusterSelectionInterceptor());
return
Result.newBuilder()
.setConfig(config)
.setInterceptor(combineInterceptors(filterInterceptors))
.setInterceptor(combineInterceptors(
ImmutableList.of(filters, new ClusterSelectionInterceptor())))
.build();
}
@ -584,8 +570,18 @@ final class XdsNameResolver extends NameResolver {
}
}
static final class PassthroughClientInterceptor implements ClientInterceptor {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
return next.newCall(method, callOptions);
}
}
private static ClientInterceptor combineInterceptors(final List<ClientInterceptor> interceptors) {
checkArgument(!interceptors.isEmpty(), "empty interceptors");
if (interceptors.size() == 0) {
return new PassthroughClientInterceptor();
}
if (interceptors.size() == 1) {
return interceptors.get(0);
}
@ -722,6 +718,7 @@ final class XdsNameResolver extends NameResolver {
}
List<Route> routes = virtualHost.routes();
ImmutableList.Builder<RouteData> routesData = ImmutableList.builder();
// Populate all clusters to which requests can be routed to through the virtual host.
Set<String> clusters = new HashSet<>();
@ -732,26 +729,34 @@ final class XdsNameResolver extends NameResolver {
for (Route route : routes) {
RouteAction action = route.routeAction();
String prefixedName;
if (action != null) {
if (action.cluster() != null) {
prefixedName = prefixedClusterName(action.cluster());
if (action == null) {
routesData.add(new RouteData(route.routeMatch(), null, ImmutableList.of()));
} else if (action.cluster() != null) {
prefixedName = prefixedClusterName(action.cluster());
clusters.add(prefixedName);
clusterNameMap.put(prefixedName, action.cluster());
ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null);
routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters));
} else if (action.weightedClusters() != null) {
ImmutableList.Builder<ClientInterceptor> filterList = ImmutableList.builder();
for (ClusterWeight weightedCluster : action.weightedClusters()) {
prefixedName = prefixedClusterName(weightedCluster.name());
clusters.add(prefixedName);
clusterNameMap.put(prefixedName, action.cluster());
} else if (action.weightedClusters() != null) {
for (ClusterWeight weighedCluster : action.weightedClusters()) {
prefixedName = prefixedClusterName(weighedCluster.name());
clusters.add(prefixedName);
clusterNameMap.put(prefixedName, weighedCluster.name());
}
} else if (action.namedClusterSpecifierPluginConfig() != null) {
PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config();
if (pluginConfig instanceof RlsPluginConfig) {
prefixedName = prefixedClusterSpecifierPluginName(
action.namedClusterSpecifierPluginConfig().name());
clusters.add(prefixedName);
rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig);
}
clusterNameMap.put(prefixedName, weightedCluster.name());
filterList.add(createFilters(filterConfigs, virtualHost, route, weightedCluster));
}
routesData.add(
new RouteData(route.routeMatch(), route.routeAction(), filterList.build()));
} else if (action.namedClusterSpecifierPluginConfig() != null) {
PluginConfig pluginConfig = action.namedClusterSpecifierPluginConfig().config();
if (pluginConfig instanceof RlsPluginConfig) {
prefixedName = prefixedClusterSpecifierPluginName(
action.namedClusterSpecifierPluginConfig().name());
clusters.add(prefixedName);
rlsPluginConfigMap.put(prefixedName, (RlsPluginConfig) pluginConfig);
}
ClientInterceptor filters = createFilters(filterConfigs, virtualHost, route, null);
routesData.add(new RouteData(route.routeMatch(), route.routeAction(), filters));
}
}
@ -796,10 +801,7 @@ final class XdsNameResolver extends NameResolver {
}
// Make newly added clusters selectable by config selector and deleted clusters no longer
// selectable.
routingConfig =
new RoutingConfig(
httpMaxStreamDurationNano, routes, filterConfigs,
virtualHost.filterConfigOverrides());
routingConfig = new RoutingConfig(httpMaxStreamDurationNano, routesData.build());
shouldUpdateResult = false;
for (String cluster : deletedClusters) {
int count = clusterRefs.get(cluster).refCount.decrementAndGet();
@ -813,6 +815,37 @@ final class XdsNameResolver extends NameResolver {
}
}
private ClientInterceptor createFilters(
@Nullable List<NamedFilterConfig> filterConfigs,
VirtualHost virtualHost,
Route route,
@Nullable ClusterWeight weightedCluster) {
if (filterConfigs == null) {
return new PassthroughClientInterceptor();
}
Map<String, FilterConfig> selectedOverrideConfigs =
new HashMap<>(virtualHost.filterConfigOverrides());
selectedOverrideConfigs.putAll(route.filterConfigOverrides());
if (weightedCluster != null) {
selectedOverrideConfigs.putAll(weightedCluster.filterConfigOverrides());
}
ImmutableList.Builder<ClientInterceptor> filterInterceptors = ImmutableList.builder();
for (NamedFilterConfig namedFilter : filterConfigs) {
FilterConfig filterConfig = namedFilter.filterConfig;
Filter filter = filterRegistry.get(filterConfig.typeUrl());
if (filter instanceof ClientInterceptorBuilder) {
ClientInterceptor interceptor = ((ClientInterceptorBuilder) filter)
.buildClientInterceptor(
filterConfig, selectedOverrideConfigs.get(namedFilter.name),
scheduler);
if (interceptor != null) {
filterInterceptors.add(interceptor);
}
}
}
return combineInterceptors(filterInterceptors.build());
}
private void cleanUpRoutes(String error) {
if (existingClusters != null) {
for (String cluster : existingClusters) {
@ -903,22 +936,50 @@ final class XdsNameResolver extends NameResolver {
*/
private static class RoutingConfig {
private final long fallbackTimeoutNano;
final List<Route> routes;
// Null if HttpFilter is not supported.
@Nullable final List<NamedFilterConfig> filterChain;
final Map<String, FilterConfig> virtualHostOverrideConfig;
final ImmutableList<RouteData> routes;
private static RoutingConfig empty = new RoutingConfig(
0, Collections.emptyList(), null, Collections.emptyMap());
private static RoutingConfig empty = new RoutingConfig(0, ImmutableList.of());
private RoutingConfig(
long fallbackTimeoutNano, List<Route> routes, @Nullable List<NamedFilterConfig> filterChain,
Map<String, FilterConfig> virtualHostOverrideConfig) {
private RoutingConfig(long fallbackTimeoutNano, ImmutableList<RouteData> routes) {
this.fallbackTimeoutNano = fallbackTimeoutNano;
this.routes = routes;
checkArgument(filterChain == null || !filterChain.isEmpty(), "filterChain is empty");
this.filterChain = filterChain == null ? null : Collections.unmodifiableList(filterChain);
this.virtualHostOverrideConfig = Collections.unmodifiableMap(virtualHostOverrideConfig);
this.routes = checkNotNull(routes, "routes");
}
}
static final class RouteData {
final RouteMatch routeMatch;
/** null implies non-forwarding action. */
@Nullable
final RouteAction routeAction;
/**
* Only one of these interceptors should be used per-RPC. There are only multiple values in the
* list for weighted clusters, in which case the order of the list mirrors the weighted
* clusters.
*/
final ImmutableList<ClientInterceptor> filterChoices;
RouteData(RouteMatch routeMatch, @Nullable RouteAction routeAction, ClientInterceptor filter) {
this(routeMatch, routeAction, ImmutableList.of(filter));
}
RouteData(
RouteMatch routeMatch,
@Nullable RouteAction routeAction,
ImmutableList<ClientInterceptor> filterChoices) {
this.routeMatch = checkNotNull(routeMatch, "routeMatch");
checkArgument(
routeAction == null || !filterChoices.isEmpty(),
"filter may be empty only for non-forwarding action");
this.routeAction = routeAction;
if (routeAction != null && routeAction.weightedClusters() != null) {
checkArgument(
routeAction.weightedClusters().size() == filterChoices.size(),
"filter choices must match size of weighted clusters");
}
for (ClientInterceptor filter : filterChoices) {
checkNotNull(filter, "entry in filterChoices is null");
}
this.filterChoices = checkNotNull(filterChoices, "filterChoices");
}
}

View File

@ -92,7 +92,7 @@ public class GcpAuthenticationFilterTest {
GcpAuthenticationFilter filter = new GcpAuthenticationFilter();
// Create interceptor
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null, null);
ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null);
MethodDescriptor<Void, Void> methodDescriptor = TestMethodDescriptors.voidMethod();
// Mock channel and capture CallOptions

View File

@ -113,7 +113,6 @@ import io.envoyproxy.envoy.type.v3.Int64Range;
import io.grpc.ClientInterceptor;
import io.grpc.EquivalentAddressGroup;
import io.grpc.InsecureChannelCredentials;
import io.grpc.LoadBalancer;
import io.grpc.LoadBalancerRegistry;
import io.grpc.Status.Code;
import io.grpc.internal.JsonUtil;
@ -1266,7 +1265,6 @@ public class GrpcXdsClientImplDataTest {
@Override
public ClientInterceptor buildClientInterceptor(FilterConfig config,
@Nullable FilterConfig overrideConfig,
LoadBalancer.PickSubchannelArgs args,
ScheduledExecutorService scheduler) {
return null;
}