grpc-java/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java

271 lines
9.9 KiB
Java

/*
* Copyright 2016 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.grpc.okhttp;
import static com.google.common.truth.Truth.assertThat;
import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.ArgumentMatchers.isA;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
import com.google.common.io.BaseEncoding;
import io.grpc.CallOptions;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Status;
import io.grpc.internal.GrpcUtil;
import io.grpc.internal.NoopClientStreamListener;
import io.grpc.internal.StatsTraceContext;
import io.grpc.internal.TransportTracer;
import io.grpc.okhttp.internal.framed.ErrorCode;
import io.grpc.okhttp.internal.framed.FrameWriter;
import io.grpc.okhttp.internal.framed.Header;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.charset.Charset;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
/**
 * Unit tests for {@link OkHttpClientStream}.
 *
 * <p>The stream under test is wired to a mocked {@link OkHttpClientTransport} and a mocked
 * {@link FrameWriter}, so each test drives the stream directly and then inspects the calls that
 * reached those mocks.
 */
@RunWith(JUnit4.class)
public class OkHttpClientStreamTest {
  private static final int MAX_MESSAGE_SIZE = 100;
  private static final int INITIAL_WINDOW_SIZE = 65535;
  private static final String AUTHORITY = "localhost";
  /** User agent handed to the stream constructor by the header-related tests. */
  private static final String TRANSPORT_USER_AGENT = "good-application";
  /** User agent placed in the request metadata by the application; must not reach the wire. */
  private static final String APPLICATION_USER_AGENT = "misbehaving-application";

  @Mock private MethodDescriptor.Marshaller<Void> marshaller;
  @Mock private FrameWriter mockedFrameWriter;
  private ExceptionHandlingFrameWriter frameWriter;
  @Mock private OkHttpClientTransport transport;
  @Mock private OutboundFlowController flowController;
  @Captor private ArgumentCaptor<List<Header>> headersCaptor;
  private final Object lock = new Object();
  private final TransportTracer transportTracer = new TransportTracer();

  private MethodDescriptor<?, ?> methodDescriptor;
  private OkHttpClientStream stream;

  /** Initializes the mocks and builds a default unary, non-GET stream. */
  @Before
  public void setUp() {
    MockitoAnnotations.initMocks(this);
    methodDescriptor = MethodDescriptor.<Void, Void>newBuilder()
        .setType(MethodDescriptor.MethodType.UNARY)
        .setFullMethodName("testService/test")
        .setRequestMarshaller(marshaller)
        .setResponseMarshaller(marshaller)
        .build();

    frameWriter = new ExceptionHandlingFrameWriter(transport, mockedFrameWriter);
    stream = newStream(methodDescriptor, new Metadata(), "userAgent", false);
  }

  @Test
  public void getType() {
    assertEquals(MethodType.UNARY, stream.getType());
  }

  /** Cancelling before a stream id is assigned closes the listener while holding the lock. */
  @Test
  public void cancel_notStarted() {
    final AtomicReference<Status> statusRef = new AtomicReference<>();
    stream.start(new BaseClientStreamListener() {
      @Override
      public void closed(Status status, RpcProgress rpcProgress, Metadata trailers) {
        statusRef.set(status);
        assertTrue(Thread.holdsLock(lock));
      }
    });

    stream.cancel(Status.CANCELLED);

    // Fail with a clear assertion (rather than an NPE below) if closed() was never invoked.
    assertThat(statusRef.get()).isNotNull();
    assertEquals(Status.Code.CANCELLED, statusRef.get().getCode());
  }

  /** Cancelling after a stream id is assigned delegates to the transport under the lock. */
  @Test
  @SuppressWarnings("GuardedBy")
  public void cancel_started() {
    stream.start(new BaseClientStreamListener());
    stream.transportState().start(1234);
    Mockito.doAnswer(new Answer<Void>() {
      @Override
      public Void answer(InvocationOnMock invocation) throws Throwable {
        assertTrue(Thread.holdsLock(lock));
        return null;
      }
    }).when(transport).finishStream(
        1234, Status.CANCELLED, PROCESSED, true, ErrorCode.CANCEL, null);

    stream.cancel(Status.CANCELLED);

    verify(transport).finishStream(
        1234, Status.CANCELLED, PROCESSED, true, ErrorCode.CANCEL, null);
  }

  /** Assigning a stream id to an already-cancelled stream must not touch the frame writer. */
  @Test
  @SuppressWarnings("GuardedBy")
  public void start_alreadyCancelled() {
    stream.start(new BaseClientStreamListener());
    stream.cancel(Status.CANCELLED);

    stream.transportState().start(1234);

    verifyNoMoreInteractions(mockedFrameWriter);
  }

  /** The user-agent in the application's metadata is replaced by the one given to the stream. */
  @Test
  @SuppressWarnings("GuardedBy")
  public void start_userAgentRemoved() throws IOException {
    Metadata metaData = new Metadata();
    metaData.put(GrpcUtil.USER_AGENT_KEY, APPLICATION_USER_AGENT);
    stream = newStream(methodDescriptor, metaData, TRANSPORT_USER_AGENT, false);
    stream.start(new BaseClientStreamListener());

    stream.transportState().start(3);

    verify(mockedFrameWriter)
        .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
    List<Header> sentHeaders = headersCaptor.getValue();
    assertThat(sentHeaders)
        .contains(new Header(GrpcUtil.USER_AGENT_KEY.name(), TRANSPORT_USER_AGENT));
    // Presence of the good value alone does not prove the application's value was dropped.
    assertThat(sentHeaders)
        .doesNotContain(new Header(GrpcUtil.USER_AGENT_KEY.name(), APPLICATION_USER_AGENT));
  }

  /** Request headers are emitted in a fixed order, with the https scheme by default. */
  @Test
  @SuppressWarnings("GuardedBy")
  public void start_headerFieldOrder() throws IOException {
    Metadata metaData = new Metadata();
    metaData.put(GrpcUtil.USER_AGENT_KEY, APPLICATION_USER_AGENT);
    stream = newStream(methodDescriptor, metaData, TRANSPORT_USER_AGENT, false);
    stream.start(new BaseClientStreamListener());

    stream.transportState().start(3);

    verify(mockedFrameWriter)
        .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
    assertThat(headersCaptor.getValue()).containsExactly(
        Headers.HTTPS_SCHEME_HEADER,
        Headers.METHOD_HEADER,
        new Header(Header.TARGET_AUTHORITY, AUTHORITY),
        new Header(Header.TARGET_PATH, "/" + methodDescriptor.getFullMethodName()),
        new Header(GrpcUtil.USER_AGENT_KEY.name(), TRANSPORT_USER_AGENT),
        Headers.CONTENT_TYPE_HEADER,
        Headers.TE_HEADER)
        .inOrder();
  }

  /** When the transport reports plaintext, the http scheme header is sent instead of https. */
  @Test
  @SuppressWarnings("GuardedBy")
  public void start_headerPlaintext() throws IOException {
    Metadata metaData = new Metadata();
    metaData.put(GrpcUtil.USER_AGENT_KEY, APPLICATION_USER_AGENT);
    when(transport.isUsingPlaintext()).thenReturn(true);
    stream = newStream(methodDescriptor, metaData, TRANSPORT_USER_AGENT, false);
    stream.start(new BaseClientStreamListener());

    stream.transportState().start(3);

    verify(mockedFrameWriter)
        .synStream(eq(false), eq(false), eq(3), eq(0), headersCaptor.capture());
    assertThat(headersCaptor.getValue()).containsExactly(
        Headers.HTTP_SCHEME_HEADER,
        Headers.METHOD_HEADER,
        new Header(Header.TARGET_AUTHORITY, AUTHORITY),
        new Header(Header.TARGET_PATH, "/" + methodDescriptor.getFullMethodName()),
        new Header(GrpcUtil.USER_AGENT_KEY.name(), TRANSPORT_USER_AGENT),
        Headers.CONTENT_TYPE_HEADER,
        Headers.TE_HEADER)
        .inOrder();
  }

  /**
   * A safe, idempotent unary call on a GET-enabled stream defers its headers until half-close and
   * carries the base64-encoded request in the path.
   */
  @Test
  @SuppressWarnings("GuardedBy")
  public void getUnaryRequest() throws IOException {
    MethodDescriptor<?, ?> getMethod = MethodDescriptor.<Void, Void>newBuilder()
        .setType(MethodDescriptor.MethodType.UNARY)
        .setFullMethodName("service/method")
        .setIdempotent(true)
        .setSafe(true)
        .setRequestMarshaller(marshaller)
        .setResponseMarshaller(marshaller)
        .build();
    stream = newStream(getMethod, new Metadata(), TRANSPORT_USER_AGENT, true);
    stream.start(new BaseClientStreamListener());

    // GET streams send headers after halfClose is called.
    // Match any arguments here: a check pinned to specific values would pass vacuously, because
    // the synStream call verified further down uses different arguments (outFinished == true).
    verify(mockedFrameWriter, times(0)).synStream(
        Mockito.anyBoolean(),
        Mockito.anyBoolean(),
        Mockito.anyInt(),
        Mockito.anyInt(),
        Mockito.<List<Header>>any());
    verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class));

    byte[] msg = "request".getBytes(Charset.forName("UTF-8"));
    stream.writeMessage(new ByteArrayInputStream(msg));
    stream.halfClose();
    verify(transport).streamReadyToStart(eq(stream));
    stream.transportState().start(3);

    verify(mockedFrameWriter)
        .synStream(eq(true), eq(false), eq(3), eq(0), headersCaptor.capture());
    List<Header> sentHeaders = headersCaptor.getValue();
    assertThat(sentHeaders).contains(Headers.METHOD_GET_HEADER);
    assertThat(sentHeaders).contains(
        new Header(
            Header.TARGET_PATH,
            "/" + getMethod.getFullMethodName() + "?" + BaseEncoding.base64().encode(msg)));
  }

  /**
   * Builds an {@link OkHttpClientStream} backed by this test's mocks and shared fixtures.
   *
   * @param method the method descriptor for the call
   * @param headers the application-supplied request metadata
   * @param userAgent the user agent passed to the stream constructor
   * @param useGet the value for the constructor's final boolean flag ({@code true} only in the
   *     GET test)
   */
  private OkHttpClientStream newStream(
      MethodDescriptor<?, ?> method, Metadata headers, String userAgent, boolean useGet) {
    return new OkHttpClientStream(
        method,
        headers,
        frameWriter,
        transport,
        flowController,
        lock,
        MAX_MESSAGE_SIZE,
        INITIAL_WINDOW_SIZE,
        AUTHORITY,
        userAgent,
        StatsTraceContext.NOOP,
        transportTracer,
        CallOptions.DEFAULT,
        useGet);
  }

  // TODO(carl-mastrangelo): extract this out into a testing/ directory and remove other definitions
  // of it.
  /** Listener that pulls every message from the producer and discards it. */
  private static class BaseClientStreamListener extends NoopClientStreamListener {
    @Override
    public void messagesAvailable(MessageProducer producer) {
      while (producer.next() != null) {}
    }
  }
}