This commit is contained in:
Mateus Azis 2025-07-24 07:57:25 -07:00 committed by GitHub
commit faf2f1e5e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 103 additions and 16 deletions

View File

@ -1,22 +1,39 @@
package io.grpc.binder.internal;
import static com.google.common.truth.Truth.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.verify;
import static org.junit.Assert.assertThrows;
import static org.mockito.Mockito.verifyNoInteractions;
import static org.mockito.Mockito.when;
import io.grpc.CallOptions;
import io.grpc.Channel;
import io.grpc.ClientCall;
import io.grpc.ClientInterceptor;
import io.grpc.ForwardingClientCall;
import io.grpc.ForwardingClientCallListener;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor;
import io.grpc.Server;
import io.grpc.ServerCall;
import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptor;
import io.grpc.ServerServiceDefinition;
import io.grpc.Status;
import io.grpc.StatusRuntimeException;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.ClientCalls;
import io.grpc.stub.ServerCalls;
import io.grpc.testing.GrpcCleanupRule;
import io.grpc.testing.TestMethodDescriptors;
import java.io.IOException;
import java.time.Duration;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.Mockito;
@ -26,12 +43,15 @@ import org.mockito.junit.MockitoRule;
@RunWith(JUnit4.class)
public final class PendingAuthListenerTest {
private static final MethodDescriptor<Void, Void> TEST_METHOD =
TestMethodDescriptors.voidMethod();
@Rule public final MockitoRule mocks = MockitoJUnit.rule();
@Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule();
@Mock ServerCallHandler<Object, Object> next;
@Mock ServerCall<Object, Object> call;
@Mock ServerCall.Listener<Object> delegate;
@Captor ArgumentCaptor<Status> statusCaptor;
private final Metadata headers = new Metadata();
private final PendingAuthListener<Object, Object> listener = new PendingAuthListener<>();
@ -86,16 +106,80 @@ public final class PendingAuthListenerTest {
}
@Test
public void whenStartCallFails_closesTheCallWithInternalStatus() {
IllegalStateException exception = new IllegalStateException("oops");
when(next.startCall(any(), any())).thenThrow(exception);
public void whenStartCallFails_closesTheCallWithInternalStatus() throws Exception {
// Arrange
ServerCallHandler<Void, Void> callHandler =
ServerCalls.asyncUnaryCall(
(req, respObserver) -> {
throw new IllegalStateException("ooops");
});
ManagedChannel channel = startServer(callHandler);
listener.onReady();
listener.startCall(call, headers, next);
// Act
StatusRuntimeException ex =
assertThrows(
StatusRuntimeException.class,
() ->
ClientCalls.blockingUnaryCall(
channel,
TEST_METHOD,
CallOptions.DEFAULT.withDeadlineAfter(Duration.ofSeconds(5)),
/* request= */ null));
verify(call).close(statusCaptor.capture(), any());
Status status = statusCaptor.getValue();
assertThat(status.getCode()).isEqualTo(Status.Code.INTERNAL);
assertThat(status.getCause()).isSameInstanceAs(exception);
// Assert
assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.INTERNAL);
}
private ManagedChannel startServer(ServerCallHandler<Void, Void> callHandler) throws IOException {
String name = TestMethodDescriptors.SERVICE_NAME;
ServerServiceDefinition serviceDef =
ServerServiceDefinition.builder(name).addMethod(TEST_METHOD, callHandler).build();
Server server =
InProcessServerBuilder.forName(name)
.addService(serviceDef)
.intercept(
new ServerInterceptor() {
@SuppressWarnings("unchecked")
@Override
public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(
ServerCall<ReqT, RespT> call,
Metadata headers,
ServerCallHandler<ReqT, RespT> next) {
listener.startCall(
(ServerCall<Object, Object>) call,
headers,
(ServerCallHandler<Object, Object>) next);
return (ServerCall.Listener<ReqT>) listener;
}
})
.build()
.start();
ManagedChannel channel =
InProcessChannelBuilder.forName(name)
.intercept(
new ClientInterceptor() {
@Override
public <ReqT, RespT> ClientCall<ReqT, RespT> interceptCall(
MethodDescriptor<ReqT, RespT> method, CallOptions callOptions, Channel next) {
ClientCall<ReqT, RespT> delegate = next.newCall(method, callOptions);
return new ForwardingClientCall.SimpleForwardingClientCall<ReqT, RespT>(
delegate) {
@Override
public void start(Listener<RespT> responseListener, Metadata headers) {
ClientCall.Listener<RespT> wrappedListener =
new ForwardingClientCallListener.SimpleForwardingClientCallListener<
RespT>(responseListener) {};
super.start(wrappedListener, headers);
}
};
}
})
.build();
grpcCleanupRule.register(server);
grpcCleanupRule.register(channel);
return channel;
}
}

View File

@ -30,16 +30,19 @@ import java.io.InputStream;
public final class TestMethodDescriptors {
private TestMethodDescriptors() {}
/** The name of the service that the method returned by {@link #voidMethod()} uses. */
public static final String SERVICE_NAME = "service_foo";
/**
* Creates a new method descriptor that always creates zero length messages, and always parses to
* null objects.
* null objects. It is part of the service named {@link #SERVICE_NAME}.
*
* @since 1.1.0
*/
public static MethodDescriptor<Void, Void> voidMethod() {
return MethodDescriptor.<Void, Void>newBuilder()
.setType(MethodType.UNARY)
.setFullMethodName(MethodDescriptor.generateFullMethodName("service_foo", "method_bar"))
.setFullMethodName(MethodDescriptor.generateFullMethodName(SERVICE_NAME, "method_bar"))
.setRequestMarshaller(TestMethodDescriptors.voidMarshaller())
.setResponseMarshaller(TestMethodDescriptors.voidMarshaller())
.build();