Inform the server of the client's deadline using the standard header. (#8286)

This commit is contained in:
John Cormie 2021-07-02 09:31:50 -07:00 committed by GitHub
parent 06d34925f9
commit 380f26fd8d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 145 additions and 29 deletions

View File

@ -18,6 +18,8 @@ package io.grpc.binder;
import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertThat;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.TimeUnit.MINUTES;
import static java.util.concurrent.TimeUnit.SECONDS;
import android.app.Application; import android.app.Application;
import android.app.Service; import android.app.Service;
@ -29,19 +31,26 @@ import androidx.test.ext.junit.runners.AndroidJUnit4;
import com.google.common.io.ByteStreams; import com.google.common.io.ByteStreams;
import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.ListenableFuture;
import com.google.common.util.concurrent.SettableFuture;
import io.grpc.CallOptions; import io.grpc.CallOptions;
import io.grpc.ManagedChannel; import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Server; import io.grpc.Server;
import io.grpc.ServerCallHandler; import io.grpc.ServerCallHandler;
import io.grpc.ServerInterceptors;
import io.grpc.ServerServiceDefinition; import io.grpc.ServerServiceDefinition;
import io.grpc.internal.GrpcUtil;
import io.grpc.stub.ClientCalls; import io.grpc.stub.ClientCalls;
import io.grpc.stub.ServerCalls; import io.grpc.stub.ServerCalls;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.TestUtils;
import java.io.ByteArrayInputStream; import java.io.ByteArrayInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.ArrayList;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicReference;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Test; import org.junit.Test;
@ -70,8 +79,15 @@ public final class BinderChannelSmokeTest {
.setType(MethodDescriptor.MethodType.SERVER_STREAMING) .setType(MethodDescriptor.MethodType.SERVER_STREAMING)
.build(); .build();
final MethodDescriptor<String, String> bidiMethod =
MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE)
.setFullMethodName("test/bidiMethod")
.setType(MethodDescriptor.MethodType.BIDI_STREAMING)
.build();
AndroidComponentAddress serverAddress; AndroidComponentAddress serverAddress;
ManagedChannel channel; ManagedChannel channel;
AtomicReference<Metadata> headersCapture = new AtomicReference<>();
@Before @Before
public void setUp() throws Exception { public void setUp() throws Exception {
@ -89,11 +105,17 @@ public final class BinderChannelSmokeTest {
respObserver.onCompleted(); respObserver.onCompleted();
}); });
ServerCallHandler<String, String> bidiCallHandler =
ServerCalls.asyncBidiStreamingCall(ForwardingStreamObserver::new); // Echo it all back.
ServerServiceDefinition serviceDef = ServerServiceDefinition serviceDef =
ServerServiceDefinition.builder("test") ServerInterceptors.intercept(
.addMethod(method, callHandler) ServerServiceDefinition.builder("test")
.addMethod(singleLargeResultMethod, singleLargeResultCallHandler) .addMethod(method, callHandler)
.build(); .addMethod(singleLargeResultMethod, singleLargeResultCallHandler)
.addMethod(bidiMethod, bidiCallHandler)
.build(),
TestUtils.recordRequestHeadersInterceptor(headersCapture));
AndroidComponentAddress serverAddress = HostServices.allocateService(appContext); AndroidComponentAddress serverAddress = HostServices.allocateService(appContext);
HostServices.configureService(serverAddress, HostServices.configureService(serverAddress,
@ -119,10 +141,14 @@ public final class BinderChannelSmokeTest {
private ListenableFuture<String> doCall( private ListenableFuture<String> doCall(
MethodDescriptor<String, String> methodDesc, String request) { MethodDescriptor<String, String> methodDesc, String request) {
return doCall(methodDesc, request, CallOptions.DEFAULT);
}
private ListenableFuture<String> doCall(
MethodDescriptor<String, String> methodDesc, String request, CallOptions callOptions) {
ListenableFuture<String> future = ListenableFuture<String> future =
ClientCalls.futureUnaryCall(channel.newCall(methodDesc, CallOptions.DEFAULT), request); ClientCalls.futureUnaryCall(channel.newCall(methodDesc, callOptions), request);
return Futures.withTimeout( return withTestTimeout(future);
future, 5L, TimeUnit.SECONDS, Executors.newSingleThreadScheduledExecutor());
} }
@Test @Test
@ -147,6 +173,25 @@ public final class BinderChannelSmokeTest {
assertThat(res.length()).isEqualTo(SLIGHTLY_MORE_THAN_ONE_BLOCK); assertThat(res.length()).isEqualTo(SLIGHTLY_MORE_THAN_ONE_BLOCK);
} }
@Test
public void testSingleRequestCallOptionHeaders() throws Exception {
CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(1234, MINUTES);
assertThat(doCall(method, "Hello", callOptions).get()).isEqualTo("Hello");
assertThat(headersCapture.get().get(GrpcUtil.TIMEOUT_KEY)).isGreaterThan(0);
}
@Test
public void testStreamingCallOptionHeaders() throws Exception {
CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(1234, MINUTES);
QueueingStreamObserver<String> responseStreamObserver = new QueueingStreamObserver<>();
StreamObserver<String> streamObserver =
ClientCalls.asyncBidiStreamingCall(
channel.newCall(bidiMethod, callOptions), responseStreamObserver);
streamObserver.onCompleted();
assertThat(withTestTimeout(responseStreamObserver.getAllStreamElements()).get()).isEmpty();
assertThat(headersCapture.get().get(GrpcUtil.TIMEOUT_KEY)).isGreaterThan(0);
}
private static String createLargeString(int size) { private static String createLargeString(int size) {
StringBuilder sb = new StringBuilder(); StringBuilder sb = new StringBuilder();
while (sb.length() < size) { while (sb.length() < size) {
@ -156,6 +201,10 @@ public final class BinderChannelSmokeTest {
return sb.toString(); return sb.toString();
} }
private static <V> ListenableFuture<V> withTestTimeout(ListenableFuture<V> future) {
return Futures.withTimeout(future, 5L, SECONDS, Executors.newSingleThreadScheduledExecutor());
}
private static class StringMarshaller implements MethodDescriptor.Marshaller<String> { private static class StringMarshaller implements MethodDescriptor.Marshaller<String> {
public static final StringMarshaller INSTANCE = new StringMarshaller(); public static final StringMarshaller INSTANCE = new StringMarshaller();
@ -173,4 +222,51 @@ public final class BinderChannelSmokeTest {
} }
} }
} }
private static class QueueingStreamObserver<V> implements StreamObserver<V> {
private final ArrayList<V> elements = new ArrayList<>();
private final SettableFuture<Iterable<V>> result = SettableFuture.create();
public ListenableFuture<Iterable<V>> getAllStreamElements() {
return result;
}
@Override
public void onNext(V value) {
elements.add(value);
}
@Override
public void onError(Throwable t) {
result.setException(t);
}
@Override
public void onCompleted() {
result.set(elements);
}
}
private static class ForwardingStreamObserver<V> implements StreamObserver<V> {
private final StreamObserver<V> delegate;
ForwardingStreamObserver(StreamObserver<V> delegate) {
this.delegate = delegate;
}
@Override
public void onNext(V value) {
delegate.onNext(value);
}
@Override
public void onError(Throwable t) {
delegate.onError(t);
}
@Override
public void onCompleted() {
delegate.onCompleted();
}
}
} }

View File

@ -64,14 +64,16 @@ final class MultiMessageClientStream implements ClientStream {
} }
if (outbound.isReady()) { if (outbound.isReady()) {
listener.onReady(); listener.onReady();
try { }
synchronized (outbound) { try {
outbound.send(); synchronized (outbound) {
} // The ClientStream contract promises no more header changes after start().
} catch (StatusException se) { outbound.onPrefixReady();
synchronized (inbound) { outbound.send();
inbound.closeAbnormal(se.getStatus()); }
} } catch (StatusException se) {
synchronized (inbound) {
inbound.closeAbnormal(se.getStatus());
} }
} }
} }
@ -122,6 +124,13 @@ final class MultiMessageClientStream implements ClientStream {
} }
} }
@Override
public void setDeadline(@Nonnull Deadline deadline) {
synchronized (outbound) {
outbound.setDeadline(deadline);
}
}
@Override @Override
public Attributes getAttributes() { public Attributes getAttributes() {
return attributes; return attributes;
@ -150,11 +159,6 @@ final class MultiMessageClientStream implements ClientStream {
// Ignore. // Ignore.
} }
@Override
public void setDeadline(@Nonnull Deadline deadline) {
// Ignore. (Deadlines should still work at a higher level).
}
@Override @Override
public void setAuthority(String authority) { public void setAuthority(String authority) {
// Ignore. // Ignore.

View File

@ -18,8 +18,11 @@ package io.grpc.binder.internal;
import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkNotNull;
import static com.google.common.base.Preconditions.checkState; import static com.google.common.base.Preconditions.checkState;
import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY;
import static java.lang.Math.max;
import android.os.Parcel; import android.os.Parcel;
import io.grpc.Deadline;
import io.grpc.Metadata; import io.grpc.Metadata;
import io.grpc.MethodDescriptor; import io.grpc.MethodDescriptor;
import io.grpc.Status; import io.grpc.Status;
@ -29,6 +32,7 @@ import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.util.Queue; import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable; import javax.annotation.Nullable;
import javax.annotation.concurrent.GuardedBy; import javax.annotation.concurrent.GuardedBy;
@ -108,7 +112,7 @@ abstract class Outbound {
return statsTraceContext; return statsTraceContext;
} }
/** Call to add a message to be delivered. */ /** Call to add a message to be delivered. Implies onPrefixReady(). */
@GuardedBy("this") @GuardedBy("this")
final void addMessage(InputStream message) throws StatusException { final void addMessage(InputStream message) throws StatusException {
onPrefixReady(); // This is implied. onPrefixReady(); // This is implied.
@ -215,7 +219,6 @@ abstract class Outbound {
} }
@GuardedBy("this") @GuardedBy("this")
@SuppressWarnings("fallthrough")
protected final void sendInternal() throws StatusException { protected final void sendInternal() throws StatusException {
Parcel parcel = Parcel.obtain(); Parcel parcel = Parcel.obtain();
int flags = 0; int flags = 0;
@ -354,7 +357,6 @@ abstract class Outbound {
this.method = method; this.method = method;
this.headers = headers; this.headers = headers;
this.statsTraceContext = statsTraceContext; this.statsTraceContext = statsTraceContext;
onPrefixReady(); // Client prefix is available immediately.
} }
@Override @Override
@ -369,6 +371,7 @@ abstract class Outbound {
return 0; return 0;
} }
// Implies onPrefixReady() and onSuffixReady().
@GuardedBy("this") @GuardedBy("this")
void sendSingleMessageAndHalfClose(@Nullable InputStream singleMessage) throws StatusException { void sendSingleMessageAndHalfClose(@Nullable InputStream singleMessage) throws StatusException {
if (singleMessage != null) { if (singleMessage != null) {
@ -390,6 +393,14 @@ abstract class Outbound {
// Client doesn't include anything in the suffix. // Client doesn't include anything in the suffix.
return 0; return 0;
} }
// Must not be called after onPrefixReady() (explicitly or via another method that implies it).
@GuardedBy("this")
void setDeadline(Deadline deadline) {
headers.discardAll(TIMEOUT_KEY);
long effectiveTimeoutNanos = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS));
headers.put(TIMEOUT_KEY, effectiveTimeoutNanos);
}
} }
// ====================================== // ======================================

View File

@ -52,6 +52,7 @@ final class SingleMessageClientStream implements ClientStream {
private final Attributes attributes; private final Attributes attributes;
@Nullable private InputStream pendingSingleMessage; @Nullable private InputStream pendingSingleMessage;
@Nullable private Deadline pendingDeadline;
SingleMessageClientStream( SingleMessageClientStream(
Inbound.ClientInbound inbound, Outbound.ClientOutbound outbound, Attributes attributes) { Inbound.ClientInbound inbound, Outbound.ClientOutbound outbound, Attributes attributes) {
@ -97,6 +98,10 @@ final class SingleMessageClientStream implements ClientStream {
public void halfClose() { public void halfClose() {
try { try {
synchronized (outbound) { synchronized (outbound) {
if (pendingDeadline != null) {
outbound.setDeadline(pendingDeadline);
}
outbound.onPrefixReady();
outbound.sendSingleMessageAndHalfClose(pendingSingleMessage); outbound.sendSingleMessageAndHalfClose(pendingSingleMessage);
} }
} catch (StatusException se) { } catch (StatusException se) {
@ -113,6 +118,11 @@ final class SingleMessageClientStream implements ClientStream {
} }
} }
@Override
public void setDeadline(@Nonnull Deadline deadline) {
this.pendingDeadline = deadline;
}
@Override @Override
public Attributes getAttributes() { public Attributes getAttributes() {
return attributes; return attributes;
@ -141,11 +151,6 @@ final class SingleMessageClientStream implements ClientStream {
// Ignore. // Ignore.
} }
@Override
public void setDeadline(@Nonnull Deadline deadline) {
// Ignore. (Deadlines should still work at a higher level).
}
@Override @Override
public void setAuthority(String authority) { public void setAuthority(String authority) {
// Ignore. // Ignore.