okhttp: Move HostnameVerifier tests to TlsTest

Http2OkHttp is now unnecessary, as Http2Test tests OkHttp client to
Netty server. receivedDataForFinishedStream() was the only remaining
unique test and it seems already covered by AbstractInteropTest these
days.
This commit is contained in:
Eric Anderson 2024-02-09 21:09:09 -08:00
parent 92463f62bf
commit 7e72413233
2 changed files with 56 additions and 212 deletions

View File

@ -1,211 +0,0 @@
/*
* Copyright 2014 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.testing.integration;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import com.google.common.base.Throwables;
import com.squareup.okhttp.ConnectionSpec;
import io.grpc.ChannelCredentials;
import io.grpc.ManagedChannel;
import io.grpc.ServerBuilder;
import io.grpc.ServerCredentials;
import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.testing.StreamRecorder;
import io.grpc.internal.testing.TestUtils;
import io.grpc.netty.InternalNettyServerBuilder;
import io.grpc.netty.NettyServerBuilder;
import io.grpc.okhttp.InternalOkHttpChannelBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.okhttp.internal.Platform;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TlsTesting;
import io.grpc.testing.integration.EmptyProtos.Empty;
import java.io.IOException;
import java.net.InetSocketAddress;
import javax.net.ssl.HostnameVerifier;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSession;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
* Integration tests for GRPC over Http2 using the OkHttp framework.
*/
@RunWith(JUnit4.class)
public class Http2OkHttpTest extends AbstractInteropTest {
private static final String BAD_HOSTNAME = "I.am.a.bad.hostname";
@Override
protected ServerBuilder<?> getServerBuilder() {
// Starts the server with HTTPS.
try {
ServerCredentials serverCreds = TlsServerCredentials.create(
TlsTesting.loadCert("server1.pem"), TlsTesting.loadCert("server1.key"));
NettyServerBuilder builder = NettyServerBuilder.forPort(0, serverCreds)
.flowControlWindow(AbstractInteropTest.TEST_FLOW_CONTROL_WINDOW)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE);
// Disable the default census stats tracer, use testing tracer instead.
InternalNettyServerBuilder.setStatsEnabled(builder, false);
return builder.addStreamTracerFactory(createCustomCensusTracerFactory());
} catch (IOException ex) {
throw new RuntimeException(ex);
}
}
@Override
protected OkHttpChannelBuilder createChannelBuilder() {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ChannelCredentials channelCreds;
try {
channelCreds = TlsChannelCredentials.newBuilder()
.trustManager(TlsTesting.loadCert("ca.pem"))
.build();
} catch (IOException ex) {
throw new RuntimeException(ex);
}
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port, channelCreds)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port));
// Disable the default census stats interceptor, use testing interceptor instead.
InternalOkHttpChannelBuilder.setStatsEnabled(builder, false);
return builder.intercept(createCensusStatsClientInterceptor());
}
private OkHttpChannelBuilder createChannelBuilderPreCredentialsApi() {
int port = ((InetSocketAddress) getListenAddress()).getPort();
OkHttpChannelBuilder builder = OkHttpChannelBuilder.forAddress("localhost", port)
.maxInboundMessageSize(AbstractInteropTest.MAX_MESSAGE_SIZE)
.connectionSpec(new ConnectionSpec.Builder(ConnectionSpec.MODERN_TLS)
.build())
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port));
try {
builder.sslSocketFactory(TestUtils.newSslSocketFactoryForCa(Platform.get().getProvider(),
TestUtils.loadCert("ca.pem")));
} catch (Exception e) {
throw new RuntimeException(e);
}
// Disable the default census stats interceptor, use testing interceptor instead.
InternalOkHttpChannelBuilder.setStatsEnabled(builder, false);
return builder.intercept(createCensusStatsClientInterceptor());
}
@Test
public void receivedDataForFinishedStream() throws Exception {
Messages.ResponseParameters.Builder responseParameters =
Messages.ResponseParameters.newBuilder()
.setSize(1);
Messages.StreamingOutputCallRequest.Builder requestBuilder =
Messages.StreamingOutputCallRequest.newBuilder();
for (int i = 0; i < 1000; i++) {
requestBuilder.addResponseParameters(responseParameters);
}
StreamRecorder<Messages.StreamingOutputCallResponse> recorder = StreamRecorder.create();
StreamObserver<Messages.StreamingOutputCallRequest> requestStream =
asyncStub.fullDuplexCall(recorder);
Messages.StreamingOutputCallRequest request = requestBuilder.build();
requestStream.onNext(request);
recorder.firstValue().get();
requestStream.onError(new Exception("failed"));
recorder.awaitCompletion();
assertEquals(EMPTY, blockingStub.emptyCall(EMPTY));
}
@Test
public void wrongHostNameFailHostnameVerification() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
BAD_HOSTNAME, port))
.build();
TestServiceGrpc.TestServiceBlockingStub blockingStub =
TestServiceGrpc.newBlockingStub(channel);
Throwable actualThrown = null;
try {
blockingStub.emptyCall(Empty.getDefaultInstance());
} catch (Throwable t) {
actualThrown = t;
}
assertNotNull("The rpc should have been failed due to hostname verification", actualThrown);
Throwable cause = Throwables.getRootCause(actualThrown);
assertTrue(
"Failed by unexpected exception: " + cause, cause instanceof SSLPeerUnverifiedException);
channel.shutdown();
}
@Test
public void hostnameVerifierWithBadHostname() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
BAD_HOSTNAME, port))
.hostnameVerifier(new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
return true;
}
})
.build();
TestServiceGrpc.TestServiceBlockingStub blockingStub =
TestServiceGrpc.newBlockingStub(channel);
blockingStub.emptyCall(Empty.getDefaultInstance());
channel.shutdown();
}
@Test
public void hostnameVerifierWithCorrectHostname() throws Exception {
int port = ((InetSocketAddress) getListenAddress()).getPort();
ManagedChannel channel = createChannelBuilderPreCredentialsApi()
.overrideAuthority(GrpcUtil.authorityFromHostAndPort(
TestUtils.TEST_SERVER_HOST, port))
.hostnameVerifier(new HostnameVerifier() {
@Override
public boolean verify(String hostname, SSLSession session) {
return false;
}
})
.build();
TestServiceGrpc.TestServiceBlockingStub blockingStub =
TestServiceGrpc.newBlockingStub(channel);
Throwable actualThrown = null;
try {
blockingStub.emptyCall(Empty.getDefaultInstance());
} catch (Throwable t) {
actualThrown = t;
}
assertNotNull("The rpc should have been failed due to hostname verification", actualThrown);
Throwable cause = Throwables.getRootCause(actualThrown);
assertTrue(
"Failed by unexpected exception: " + cause, cause instanceof SSLPeerUnverifiedException);
channel.shutdown();
}
}

View File

@ -31,6 +31,7 @@ import io.grpc.StatusRuntimeException;
import io.grpc.TlsChannelCredentials;
import io.grpc.TlsServerCredentials;
import io.grpc.internal.testing.TestUtils;
import io.grpc.okhttp.internal.Platform;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.TlsTesting;
@ -41,6 +42,8 @@ import java.io.IOException;
import java.io.InputStream;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLPeerUnverifiedException;
import javax.net.ssl.SSLSocketFactory;
import org.junit.Assume;
import org.junit.Before;
import org.junit.Rule;
@ -230,6 +233,55 @@ public class TlsTest {
assertRpcFails(channel);
}
@Test
public void hostnameVerifierTrusts_succeeds()
throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.build();
}
SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa(
Platform.get().getProvider(), TestUtils.loadCert("ca.pem"));
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(
OkHttpChannelBuilder.forAddress("localhost", server.getPort())
.directExecutor()
.overrideAuthority("notgonnamatch.example.com")
.sslSocketFactory(sslSocketFactory)
.hostnameVerifier((hostname, session) -> true)
.build());
SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance());
}
@Test
public void hostnameVerifierFails_fails()
throws Exception {
ServerCredentials serverCreds;
try (InputStream serverCert = TlsTesting.loadCert("server1.pem");
InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) {
serverCreds = TlsServerCredentials.newBuilder()
.keyManager(serverCert, serverPrivateKey)
.build();
}
SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa(
Platform.get().getProvider(), TestUtils.loadCert("ca.pem"));
Server server = grpcCleanupRule.register(server(serverCreds));
ManagedChannel channel = grpcCleanupRule.register(
OkHttpChannelBuilder.forAddress("localhost", server.getPort())
.directExecutor()
.overrideAuthority(TestUtils.TEST_SERVER_HOST)
.sslSocketFactory(sslSocketFactory)
.hostnameVerifier((hostname, session) -> false)
.build());
Status status = assertRpcFails(channel);
assertThat(status.getCause()).isInstanceOf(SSLPeerUnverifiedException.class);
}
private static Server server(ServerCredentials creds) throws IOException {
return OkHttpServerBuilder.forPort(0, creds)
.directExecutor()
@ -249,7 +301,8 @@ public class TlsTest {
return clientChannelBuilder(server, creds).build();
}
private static void assertRpcFails(ManagedChannel channel) {
private static Status assertRpcFails(ManagedChannel channel) {
Status status = null;
SimpleServiceGrpc.SimpleServiceBlockingStub stub = SimpleServiceGrpc.newBlockingStub(channel);
try {
stub.unaryRpc(SimpleRequest.getDefaultInstance());
@ -258,6 +311,7 @@ public class TlsTest {
} catch (StatusRuntimeException e) {
assertWithMessage(Throwables.getStackTraceAsString(e))
.that(e.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE);
status = e.getStatus();
}
// We really want to see TRANSIENT_FAILURE here, but if the test runs slowly the 1s backoff
// may be exceeded by the time the failure happens (since it counts from the start of the
@ -265,6 +319,7 @@ public class TlsTest {
// expect READY or IDLE.
assertThat(channel.getState(false))
.isAnyOf(ConnectivityState.TRANSIENT_FAILURE, ConnectivityState.CONNECTING);
return status;
}
private static final class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase {