diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java index a12894d4da..dd9cf26bd3 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java @@ -18,6 +18,8 @@ package io.grpc.binder; import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; import android.app.Application; import android.app.Service; @@ -29,19 +31,26 @@ import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.common.io.ByteStreams; import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; import io.grpc.ManagedChannel; +import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Server; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptors; import io.grpc.ServerServiceDefinition; +import io.grpc.internal.GrpcUtil; import io.grpc.stub.ClientCalls; import io.grpc.stub.ServerCalls; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.TestUtils; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; +import java.util.ArrayList; import java.util.concurrent.Executors; -import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -70,8 +79,15 @@ public final class BinderChannelSmokeTest { .setType(MethodDescriptor.MethodType.SERVER_STREAMING) .build(); + final MethodDescriptor bidiMethod = + MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) + .setFullMethodName("test/bidiMethod") + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .build(); + AndroidComponentAddress serverAddress; ManagedChannel channel; + AtomicReference headersCapture = new AtomicReference<>(); @Before public void setUp() throws Exception { @@ -89,11 +105,17 @@ public final class BinderChannelSmokeTest { respObserver.onCompleted(); }); + ServerCallHandler bidiCallHandler = + ServerCalls.asyncBidiStreamingCall(ForwardingStreamObserver::new); // Echo it all back. + ServerServiceDefinition serviceDef = - ServerServiceDefinition.builder("test") - .addMethod(method, callHandler) - .addMethod(singleLargeResultMethod, singleLargeResultCallHandler) - .build(); + ServerInterceptors.intercept( + ServerServiceDefinition.builder("test") + .addMethod(method, callHandler) + .addMethod(singleLargeResultMethod, singleLargeResultCallHandler) + .addMethod(bidiMethod, bidiCallHandler) + .build(), + TestUtils.recordRequestHeadersInterceptor(headersCapture)); AndroidComponentAddress serverAddress = HostServices.allocateService(appContext); HostServices.configureService(serverAddress, @@ -119,10 +141,14 @@ public final class BinderChannelSmokeTest { private ListenableFuture doCall( MethodDescriptor methodDesc, String request) { + return doCall(methodDesc, request, CallOptions.DEFAULT); + } + + private ListenableFuture doCall( + MethodDescriptor methodDesc, String request, CallOptions callOptions) { ListenableFuture future = - ClientCalls.futureUnaryCall(channel.newCall(methodDesc, CallOptions.DEFAULT), request); - return Futures.withTimeout( - future, 5L, TimeUnit.SECONDS, Executors.newSingleThreadScheduledExecutor()); + ClientCalls.futureUnaryCall(channel.newCall(methodDesc, callOptions), request); + return withTestTimeout(future); } @Test @@ -147,6 +173,25 @@ public final class BinderChannelSmokeTest { assertThat(res.length()).isEqualTo(SLIGHTLY_MORE_THAN_ONE_BLOCK); } + @Test + public void testSingleRequestCallOptionHeaders() throws Exception { + CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(1234, MINUTES); + assertThat(doCall(method, "Hello", callOptions).get()).isEqualTo("Hello"); + assertThat(headersCapture.get().get(GrpcUtil.TIMEOUT_KEY)).isGreaterThan(0); + } + + @Test + public void testStreamingCallOptionHeaders() throws Exception { + CallOptions callOptions = CallOptions.DEFAULT.withDeadlineAfter(1234, MINUTES); + QueueingStreamObserver responseStreamObserver = new QueueingStreamObserver<>(); + StreamObserver streamObserver = + ClientCalls.asyncBidiStreamingCall( + channel.newCall(bidiMethod, callOptions), responseStreamObserver); + streamObserver.onCompleted(); + assertThat(withTestTimeout(responseStreamObserver.getAllStreamElements()).get()).isEmpty(); + assertThat(headersCapture.get().get(GrpcUtil.TIMEOUT_KEY)).isGreaterThan(0); + } + private static String createLargeString(int size) { StringBuilder sb = new StringBuilder(); while (sb.length() < size) { @@ -156,6 +201,10 @@ public final class BinderChannelSmokeTest { return sb.toString(); } + private static ListenableFuture withTestTimeout(ListenableFuture future) { + return Futures.withTimeout(future, 5L, SECONDS, Executors.newSingleThreadScheduledExecutor()); + } + private static class StringMarshaller implements MethodDescriptor.Marshaller { public static final StringMarshaller INSTANCE = new StringMarshaller(); @@ -173,4 +222,51 @@ public final class BinderChannelSmokeTest { } } } + + private static class QueueingStreamObserver implements StreamObserver { + private final ArrayList elements = new ArrayList<>(); + private final SettableFuture> result = SettableFuture.create(); + + public ListenableFuture> getAllStreamElements() { + return result; + } + + @Override + public void onNext(V value) { + elements.add(value); + } + + @Override + public void onError(Throwable t) { + result.setException(t); + } + + @Override + public void onCompleted() { + result.set(elements); + } + } + + private static class ForwardingStreamObserver implements StreamObserver { + private final StreamObserver delegate; + + ForwardingStreamObserver(StreamObserver delegate) { + this.delegate = delegate; + } + + @Override + public void onNext(V value) { + delegate.onNext(value); + } + + @Override + public void onError(Throwable t) { + delegate.onError(t); + } + + @Override + public void onCompleted() { + delegate.onCompleted(); + } + } } diff --git a/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java b/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java index 317925e2d0..9873adcb44 100644 --- a/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java +++ b/binder/src/main/java/io/grpc/binder/internal/MultiMessageClientStream.java @@ -64,14 +64,16 @@ final class MultiMessageClientStream implements ClientStream { } if (outbound.isReady()) { listener.onReady(); - try { - synchronized (outbound) { - outbound.send(); - } - } catch (StatusException se) { - synchronized (inbound) { - inbound.closeAbnormal(se.getStatus()); - } + } + try { + synchronized (outbound) { + // The ClientStream contract promises no more header changes after start(). + outbound.onPrefixReady(); + outbound.send(); + } + } catch (StatusException se) { + synchronized (inbound) { + inbound.closeAbnormal(se.getStatus()); } } } @@ -122,6 +124,13 @@ final class MultiMessageClientStream implements ClientStream { } } + @Override + public void setDeadline(@Nonnull Deadline deadline) { + synchronized (outbound) { + outbound.setDeadline(deadline); + } + } + @Override public Attributes getAttributes() { return attributes; @@ -150,11 +159,6 @@ final class MultiMessageClientStream implements ClientStream { // Ignore. } - @Override - public void setDeadline(@Nonnull Deadline deadline) { - // Ignore. (Deadlines should still work at a higher level). - } - @Override public void setAuthority(String authority) { // Ignore. diff --git a/binder/src/main/java/io/grpc/binder/internal/Outbound.java b/binder/src/main/java/io/grpc/binder/internal/Outbound.java index 5687461998..f629a05a44 100644 --- a/binder/src/main/java/io/grpc/binder/internal/Outbound.java +++ b/binder/src/main/java/io/grpc/binder/internal/Outbound.java @@ -18,8 +18,11 @@ package io.grpc.binder.internal; import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; +import static io.grpc.internal.GrpcUtil.TIMEOUT_KEY; +import static java.lang.Math.max; import android.os.Parcel; +import io.grpc.Deadline; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; @@ -29,6 +32,7 @@ import java.io.IOException; import java.io.InputStream; import java.util.Queue; import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; import javax.annotation.Nullable; import javax.annotation.concurrent.GuardedBy; @@ -108,7 +112,7 @@ abstract class Outbound { return statsTraceContext; } - /** Call to add a message to be delivered. */ + /** Call to add a message to be delivered. Implies onPrefixReady(). */ @GuardedBy("this") final void addMessage(InputStream message) throws StatusException { onPrefixReady(); // This is implied. @@ -215,7 +219,6 @@ abstract class Outbound { } @GuardedBy("this") - @SuppressWarnings("fallthrough") protected final void sendInternal() throws StatusException { Parcel parcel = Parcel.obtain(); int flags = 0; @@ -354,7 +357,6 @@ abstract class Outbound { this.method = method; this.headers = headers; this.statsTraceContext = statsTraceContext; - onPrefixReady(); // Client prefix is available immediately. } @Override @@ -369,6 +371,7 @@ abstract class Outbound { return 0; } + // Implies onPrefixReady() and onSuffixReady(). @GuardedBy("this") void sendSingleMessageAndHalfClose(@Nullable InputStream singleMessage) throws StatusException { if (singleMessage != null) { @@ -390,6 +393,14 @@ abstract class Outbound { // Client doesn't include anything in the suffix. return 0; } + + // Must not be called after onPrefixReady() (explicitly or via another method that implies it). + @GuardedBy("this") + void setDeadline(Deadline deadline) { + headers.discardAll(TIMEOUT_KEY); + long effectiveTimeoutNanos = max(0, deadline.timeRemaining(TimeUnit.NANOSECONDS)); + headers.put(TIMEOUT_KEY, effectiveTimeoutNanos); + } } // ====================================== diff --git a/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java b/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java index 9a1fa0ee6e..8a899d621a 100644 --- a/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java +++ b/binder/src/main/java/io/grpc/binder/internal/SingleMessageClientStream.java @@ -52,6 +52,7 @@ final class SingleMessageClientStream implements ClientStream { private final Attributes attributes; @Nullable private InputStream pendingSingleMessage; + @Nullable private Deadline pendingDeadline; SingleMessageClientStream( Inbound.ClientInbound inbound, Outbound.ClientOutbound outbound, Attributes attributes) { @@ -97,6 +98,10 @@ final class SingleMessageClientStream implements ClientStream { public void halfClose() { try { synchronized (outbound) { + if (pendingDeadline != null) { + outbound.setDeadline(pendingDeadline); + } + outbound.onPrefixReady(); outbound.sendSingleMessageAndHalfClose(pendingSingleMessage); } } catch (StatusException se) { @@ -113,6 +118,11 @@ final class SingleMessageClientStream implements ClientStream { } } + @Override + public void setDeadline(@Nonnull Deadline deadline) { + this.pendingDeadline = deadline; + } + @Override public Attributes getAttributes() { return attributes; @@ -141,11 +151,6 @@ final class SingleMessageClientStream implements ClientStream { // Ignore. } - @Override - public void setDeadline(@Nonnull Deadline deadline) { - // Ignore. (Deadlines should still work at a higher level). - } - @Override public void setAuthority(String authority) { // Ignore.