Migrate users of ManagedChannelBuilder.{forTarget,forAddress} to ChannelCredentials

This commit is contained in:
Eric Anderson 2020-08-05 16:04:44 -07:00 committed by Eric Anderson
parent 5a687e3da8
commit a547e23f5e
9 changed files with 141 additions and 76 deletions

View File

@ -17,9 +17,12 @@
package io.grpc.android.integrationtest; package io.grpc.android.integrationtest;
import android.support.annotation.Nullable; import android.support.annotation.Nullable;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder; import io.grpc.okhttp.SslSocketFactoryChannelCredentials;
import java.io.InputStream; import java.io.InputStream;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.cert.CertificateFactory; import java.security.cert.CertificateFactory;
@ -40,21 +43,23 @@ class TesterOkHttpChannelBuilder {
@Nullable String serverHostOverride, @Nullable String serverHostOverride,
boolean useTls, boolean useTls,
@Nullable InputStream testCa) { @Nullable InputStream testCa) {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forAddress(host, port) ChannelCredentials credentials;
.maxInboundMessageSize(16 * 1024 * 1024);
if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses.
channelBuilder.overrideAuthority(serverHostOverride);
}
if (useTls) { if (useTls) {
try { try {
((OkHttpChannelBuilder) channelBuilder).useTransportSecurity(); credentials = SslSocketFactoryChannelCredentials.create(getSslSocketFactory(testCa));
((OkHttpChannelBuilder) channelBuilder).sslSocketFactory(getSslSocketFactory(testCa));
} catch (Exception e) { } catch (Exception e) {
throw new RuntimeException(e); throw new RuntimeException(e);
} }
} else { } else {
channelBuilder.usePlaintext(); credentials = InsecureChannelCredentials.create();
}
ManagedChannelBuilder<?> channelBuilder = Grpc.newChannelBuilderForAddress(
host, port, credentials)
.maxInboundMessageSize(16 * 1024 * 1024);
if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses.
channelBuilder.overrideAuthority(serverHostOverride);
} }
return channelBuilder.build(); return channelBuilder.build();
} }

View File

@ -17,8 +17,9 @@
package io.grpc.testing.integration; package io.grpc.testing.integration;
import com.google.protobuf.ByteString; import com.google.protobuf.ByteString;
import io.grpc.Grpc;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.TlsChannelCredentials;
import io.grpc.testing.integration.Messages.Payload; import io.grpc.testing.integration.Messages.Payload;
import io.grpc.testing.integration.Messages.SimpleRequest; import io.grpc.testing.integration.Messages.SimpleRequest;
import io.grpc.testing.integration.Messages.SimpleResponse; import io.grpc.testing.integration.Messages.SimpleResponse;
@ -40,7 +41,7 @@ import javax.servlet.http.HttpServletResponse;
public final class LongLivedChannel extends HttpServlet { public final class LongLivedChannel extends HttpServlet {
private static final String INTEROP_TEST_ADDRESS = "grpc-test.sandbox.googleapis.com:443"; private static final String INTEROP_TEST_ADDRESS = "grpc-test.sandbox.googleapis.com:443";
private final ManagedChannel channel = private final ManagedChannel channel =
ManagedChannelBuilder.forTarget(INTEROP_TEST_ADDRESS).build(); Grpc.newChannelBuilder(INTEROP_TEST_ADDRESS, TlsChannelCredentials.create()).build();
@Override @Override
public void destroy() { public void destroy() {

View File

@ -19,7 +19,9 @@ package io.grpc.testing.integration;
import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import io.grpc.Grpc;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.TlsChannelCredentials;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import java.io.IOException; import java.io.IOException;
import java.io.PrintWriter; import java.io.PrintWriter;
@ -131,7 +133,7 @@ public final class NettyClientInteropServlet extends HttpServlet {
"1.8", "1.8",
System.getProperty("java.specification.version")); System.getProperty("java.specification.version"));
ManagedChannelBuilder<?> builder = ManagedChannelBuilder<?> builder =
ManagedChannelBuilder.forTarget(INTEROP_TEST_ADDRESS) Grpc.newChannelBuilder(INTEROP_TEST_ADDRESS, TlsChannelCredentials.create())
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE); .maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE);
assertTrue(builder instanceof NettyChannelBuilder); assertTrue(builder instanceof NettyChannelBuilder);
((NettyChannelBuilder) builder) ((NettyChannelBuilder) builder)

View File

@ -18,26 +18,30 @@ package io.grpc.testing.integration;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import com.google.common.io.Files; import com.google.common.io.Files;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.ManagedChannelBuilder;
import io.grpc.alts.AltsChannelBuilder; import io.grpc.TlsChannelCredentials;
import io.grpc.alts.ComputeEngineChannelBuilder; import io.grpc.alts.AltsChannelCredentials;
import io.grpc.alts.GoogleDefaultChannelBuilder; import io.grpc.alts.ComputeEngineChannelCredentials;
import io.grpc.alts.GoogleDefaultChannelCredentials;
import io.grpc.internal.GrpcUtil; import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.TestUtils; import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.GrpcSslContexts; import io.grpc.netty.GrpcSslContexts;
import io.grpc.netty.InsecureFromHttp1ChannelCredentials;
import io.grpc.netty.InternalNettyChannelBuilder; import io.grpc.netty.InternalNettyChannelBuilder;
import io.grpc.netty.NegotiationType;
import io.grpc.netty.NettyChannelBuilder; import io.grpc.netty.NettyChannelBuilder;
import io.grpc.netty.NettySslContextChannelCredentials;
import io.grpc.okhttp.InternalOkHttpChannelBuilder; import io.grpc.okhttp.InternalOkHttpChannelBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder; import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.okhttp.SslSocketFactoryChannelCredentials;
import io.grpc.okhttp.internal.Platform; import io.grpc.okhttp.internal.Platform;
import io.netty.handler.ssl.SslContext;
import java.io.File; import java.io.File;
import java.io.FileInputStream; import java.io.FileInputStream;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLSocketFactory;
/** /**
* Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs * Application that starts a client for the {@link TestServiceGrpc.TestServiceImplBase} and runs
@ -282,8 +286,8 @@ public class TestServiceClient {
break; break;
case COMPUTE_ENGINE_CHANNEL_CREDENTIALS: { case COMPUTE_ENGINE_CHANNEL_CREDENTIALS: {
ManagedChannel channel = ComputeEngineChannelBuilder ManagedChannel channel = Grpc.newChannelBuilderForAddress(
.forAddress(serverHost, serverPort).build(); serverHost, serverPort, ComputeEngineChannelCredentials.create()).build();
try { try {
TestServiceGrpc.TestServiceBlockingStub computeEngineStub = TestServiceGrpc.TestServiceBlockingStub computeEngineStub =
TestServiceGrpc.newBlockingStub(channel); TestServiceGrpc.newBlockingStub(channel);
@ -323,8 +327,8 @@ public class TestServiceClient {
} }
case GOOGLE_DEFAULT_CREDENTIALS: { case GOOGLE_DEFAULT_CREDENTIALS: {
ManagedChannel channel = GoogleDefaultChannelBuilder.forAddress( ManagedChannel channel = Grpc.newChannelBuilderForAddress(
serverHost, serverPort).build(); serverHost, serverPort, GoogleDefaultChannelCredentials.create()).build();
try { try {
TestServiceGrpc.TestServiceBlockingStub googleDefaultStub = TestServiceGrpc.TestServiceBlockingStub googleDefaultStub =
TestServiceGrpc.newBlockingStub(channel); TestServiceGrpc.newBlockingStub(channel);
@ -392,34 +396,56 @@ public class TestServiceClient {
private class Tester extends AbstractInteropTest { private class Tester extends AbstractInteropTest {
@Override @Override
protected ManagedChannelBuilder<?> createChannelBuilder() { protected ManagedChannelBuilder<?> createChannelBuilder() {
if (customCredentialsType != null ChannelCredentials channelCredentials;
&& customCredentialsType.equals("google_default_credentials")) { if (customCredentialsType != null) {
return GoogleDefaultChannelBuilder.forAddress(serverHost, serverPort); useOkHttp = false; // Retain old behavior; avoids erroring if incompatible
} if (customCredentialsType.equals("google_default_credentials")) {
if (customCredentialsType != null channelCredentials = GoogleDefaultChannelCredentials.create();
&& customCredentialsType.equals("compute_engine_channel_creds")) { } else if (customCredentialsType.equals("compute_engine_channel_creds")) {
return ComputeEngineChannelBuilder.forAddress(serverHost, serverPort); channelCredentials = ComputeEngineChannelCredentials.create();
} } else {
if (useAlts) { throw new IllegalArgumentException(
return AltsChannelBuilder.forAddress(serverHost, serverPort); "Unknown custom credentials: " + customCredentialsType);
} }
if (!useOkHttp) { } else if (useAlts) {
SslContext sslContext = null; useOkHttp = false; // Retain old behavior; avoids erroring if incompatible
if (useTestCa) { channelCredentials = AltsChannelCredentials.create();
} else if (useTls) {
if (!useTestCa) {
channelCredentials = TlsChannelCredentials.create();
} else {
try { try {
sslContext = GrpcSslContexts.forClient().trustManager( if (useOkHttp) {
TestUtils.loadCert("ca.pem")).build(); channelCredentials = SslSocketFactoryChannelCredentials.create(
TestUtils.newSslSocketFactoryForCa(
Platform.get().getProvider(), TestUtils.loadCert("ca.pem")));
} else {
channelCredentials = NettySslContextChannelCredentials.create(
GrpcSslContexts.forClient().trustManager(
TestUtils.loadCert("ca.pem")).build());
}
} catch (Exception ex) { } catch (Exception ex) {
throw new RuntimeException(ex); throw new RuntimeException(ex);
} }
} }
} else {
if (useH2cUpgrade) {
if (useOkHttp) {
throw new IllegalArgumentException("OkHttp does not support HTTP/1 upgrade");
} else {
channelCredentials = InsecureFromHttp1ChannelCredentials.create();
}
} else {
channelCredentials = InsecureChannelCredentials.create();
}
}
if (!useOkHttp) {
NettyChannelBuilder nettyBuilder = NettyChannelBuilder nettyBuilder =
NettyChannelBuilder.forAddress(serverHost, serverPort) NettyChannelBuilder.forAddress(serverHost, serverPort, channelCredentials)
.flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW) .flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW);
.negotiationType(useTls ? NegotiationType.TLS :
(useH2cUpgrade ? NegotiationType.PLAINTEXT_UPGRADE : NegotiationType.PLAINTEXT))
.sslContext(sslContext);
if (serverHostOverride != null) { if (serverHostOverride != null) {
nettyBuilder.overrideAuthority(serverHostOverride); nettyBuilder.overrideAuthority(serverHostOverride);
} }
@ -431,25 +457,13 @@ public class TestServiceClient {
return nettyBuilder.intercept(createCensusStatsClientInterceptor()); return nettyBuilder.intercept(createCensusStatsClientInterceptor());
} }
OkHttpChannelBuilder okBuilder = OkHttpChannelBuilder.forAddress(serverHost, serverPort); OkHttpChannelBuilder okBuilder =
OkHttpChannelBuilder.forAddress(serverHost, serverPort, channelCredentials);
if (serverHostOverride != null) { if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses. // Force the hostname to match the cert the server uses.
okBuilder.overrideAuthority( okBuilder.overrideAuthority(
GrpcUtil.authorityFromHostAndPort(serverHostOverride, serverPort)); GrpcUtil.authorityFromHostAndPort(serverHostOverride, serverPort));
} }
if (useTls) {
if (useTestCa) {
try {
SSLSocketFactory factory = TestUtils.newSslSocketFactoryForCa(
Platform.get().getProvider(), TestUtils.loadCert("ca.pem"));
okBuilder.sslSocketFactory(factory);
} catch (Exception e) {
throw new RuntimeException(e);
}
}
} else {
okBuilder.usePlaintext();
}
if (fullStreamDecompression) { if (fullStreamDecompression) {
okBuilder.enableFullStreamDecompression(); okBuilder.enableFullStreamDecompression();
} }

View File

@ -35,8 +35,9 @@ import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry; import io.grpc.DecompressorRegistry;
import io.grpc.ForwardingClientCall.SimpleForwardingClientCall; import io.grpc.ForwardingClientCall.SimpleForwardingClientCall;
import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener; import io.grpc.ForwardingClientCallListener.SimpleForwardingClientCallListener;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Server; import io.grpc.Server;
@ -189,11 +190,11 @@ public class CompressionTest {
.build() .build()
.start(); .start();
channel = ManagedChannelBuilder.forAddress("localhost", server.getPort()) channel = Grpc.newChannelBuilder(
"localhost:" + server.getPort(), InsecureChannelCredentials.create())
.decompressorRegistry(clientDecompressors) .decompressorRegistry(clientDecompressors)
.compressorRegistry(clientCompressors) .compressorRegistry(clientCompressors)
.intercept(new ClientCompressorInterceptor()) .intercept(new ClientCompressorInterceptor())
.usePlaintext()
.build(); .build();
stub = TestServiceGrpc.newBlockingStub(channel); stub = TestServiceGrpc.newBlockingStub(channel);

View File

@ -0,0 +1,36 @@
/*
* Copyright 2020 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.testing.integration;
import io.grpc.ManagedChannelBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/** Tests for integration between {@link ManagedChannelBuilder} and providers. */
@RunWith(JUnit4.class)
public final class ManagedChannelBuilderSpiTest {
@Test
public void forAddress_isntObviouslyBroken() {
ManagedChannelBuilder.forAddress("localhost", 443).build().shutdownNow();
}
@Test
public void forTarget_isntObviouslyBroken() {
ManagedChannelBuilder.forTarget("localhost:443").build().shutdownNow();
}
}

View File

@ -22,8 +22,9 @@ import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertTrue;
import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.testing.integration.Metrics.EmptyMessage; import io.grpc.testing.integration.Metrics.EmptyMessage;
import io.grpc.testing.integration.Metrics.GaugeResponse; import io.grpc.testing.integration.Metrics.GaugeResponse;
import io.grpc.testing.integration.StressTestClient.TestCaseWeightPair; import io.grpc.testing.integration.StressTestClient.TestCaseWeightPair;
@ -128,8 +129,8 @@ public class StressTestClientTest {
client.runStressTest(); client.runStressTest();
// Connect to the metrics service // Connect to the metrics service
ManagedChannel ch = ManagedChannelBuilder.forAddress("localhost", client.getMetricServerPort()) ManagedChannel ch = Grpc.newChannelBuilder(
.usePlaintext() "localhost:" + client.getMetricServerPort(), InsecureChannelCredentials.create())
.build(); .build();
MetricsServiceGrpc.MetricsServiceBlockingStub stub = MetricsServiceGrpc.newBlockingStub(ch); MetricsServiceGrpc.MetricsServiceBlockingStub stub = MetricsServiceGrpc.newBlockingStub(ch);

View File

@ -17,9 +17,12 @@
package io.grpc.xds; package io.grpc.xds;
import com.google.common.annotations.VisibleForTesting; import com.google.common.annotations.VisibleForTesting;
import io.grpc.ChannelCredentials;
import io.grpc.Grpc;
import io.grpc.InsecureChannelCredentials;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder; import io.grpc.TlsChannelCredentials;
import io.grpc.alts.GoogleDefaultChannelBuilder; import io.grpc.alts.GoogleDefaultChannelCredentials;
import io.grpc.xds.Bootstrapper.ChannelCreds; import io.grpc.xds.Bootstrapper.ChannelCreds;
import io.grpc.xds.Bootstrapper.ServerInfo; import io.grpc.xds.Bootstrapper.ServerInfo;
import io.grpc.xds.XdsClient.XdsChannel; import io.grpc.xds.XdsClient.XdsChannel;
@ -50,33 +53,33 @@ abstract class XdsChannelFactory {
String serverUri = serverInfo.getServerUri(); String serverUri = serverInfo.getServerUri();
logger.log(XdsLogLevel.INFO, "Creating channel to {0}", serverUri); logger.log(XdsLogLevel.INFO, "Creating channel to {0}", serverUri);
List<ChannelCreds> channelCredsList = serverInfo.getChannelCredentials(); List<ChannelCreds> channelCredsList = serverInfo.getChannelCredentials();
ManagedChannelBuilder<?> channelBuilder = null; ChannelCredentials channelCreds = null;
// Use the first supported channel credentials configuration. // Use the first supported channel credentials configuration.
for (ChannelCreds creds : channelCredsList) { for (ChannelCreds creds : channelCredsList) {
switch (creds.getType()) { switch (creds.getType()) {
case "google_default": case "google_default":
logger.log(XdsLogLevel.INFO, "Using channel credentials: google_default"); logger.log(XdsLogLevel.INFO, "Using channel credentials: google_default");
channelBuilder = GoogleDefaultChannelBuilder.forTarget(serverUri); channelCreds = GoogleDefaultChannelCredentials.create();
break; break;
case "insecure": case "insecure":
logger.log(XdsLogLevel.INFO, "Using channel credentials: insecure"); logger.log(XdsLogLevel.INFO, "Using channel credentials: insecure");
channelBuilder = ManagedChannelBuilder.forTarget(serverUri).usePlaintext(); channelCreds = InsecureChannelCredentials.create();
break; break;
case "tls": case "tls":
logger.log(XdsLogLevel.INFO, "Using channel credentials: tls"); logger.log(XdsLogLevel.INFO, "Using channel credentials: tls");
channelBuilder = ManagedChannelBuilder.forTarget(serverUri); channelCreds = TlsChannelCredentials.create();
break; break;
default: default:
} }
if (channelBuilder != null) { if (channelCreds != null) {
break; break;
} }
} }
if (channelBuilder == null) { if (channelCreds == null) {
throw new XdsInitializationException("No server with supported channel creds found"); throw new XdsInitializationException("No server with supported channel creds found");
} }
ManagedChannel channel = channelBuilder ManagedChannel channel = Grpc.newChannelBuilder(serverUri, channelCreds)
.keepAliveTime(5, TimeUnit.MINUTES) .keepAliveTime(5, TimeUnit.MINUTES)
.build(); .build();
boolean useProtocolV3 = experimentalV3SupportEnvVar boolean useProtocolV3 = experimentalV3SupportEnvVar

View File

@ -39,13 +39,14 @@ import io.grpc.Channel;
import io.grpc.ClientCall; import io.grpc.ClientCall;
import io.grpc.ClientInterceptor; import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall; import io.grpc.ForwardingClientCall;
import io.grpc.Grpc;
import io.grpc.InternalLogId; import io.grpc.InternalLogId;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
import io.grpc.SynchronizationContext; import io.grpc.SynchronizationContext;
import io.grpc.TlsChannelCredentials;
import io.grpc.auth.MoreCallCredentials; import io.grpc.auth.MoreCallCredentials;
import io.grpc.internal.BackoffPolicy; import io.grpc.internal.BackoffPolicy;
import io.grpc.internal.TimeProvider; import io.grpc.internal.TimeProvider;
@ -361,8 +362,9 @@ final class MeshCaCertificateProvider extends CertificateProvider {
checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!"); checkArgument(serverUri != null && !serverUri.isEmpty(), "serverUri is null/empty!");
logger.log(Level.INFO, "Creating channel to {0}", serverUri); logger.log(Level.INFO, "Creating channel to {0}", serverUri);
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forTarget(serverUri); return Grpc.newChannelBuilder(serverUri, TlsChannelCredentials.create())
return channelBuilder.keepAliveTime(1, TimeUnit.MINUTES).build(); .keepAliveTime(1, TimeUnit.MINUTES)
.build();
} }
}; };