Skip to content

Commit d67f108

Browse files
authored
fix: respect PDML timeout when using streaming RPC (#338)
* fix: respect PDML timeout when using streaming RPC * fix: check for negative or zero deadline * fix: subtract from original timeout to get remaining
1 parent 78c3192 commit d67f108

File tree

3 files changed

+313
-5
lines changed

3 files changed

+313
-5
lines changed

google-cloud-spanner/src/main/java/com/google/cloud/spanner/PartitionedDMLTransaction.java

+14-5
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import static com.google.common.base.Preconditions.checkState;
2020

2121
import com.google.api.gax.grpc.GrpcStatusCode;
22+
import com.google.api.gax.rpc.DeadlineExceededException;
2223
import com.google.api.gax.rpc.ServerStream;
2324
import com.google.api.gax.rpc.UnavailableException;
2425
import com.google.cloud.spanner.SessionImpl.SessionTransaction;
@@ -77,13 +78,12 @@ private ByteString initTransaction() {
7778
* statement, and will retry the stream if an {@link UnavailableException} is thrown, using the
7879
* last seen resume token if the server returns any.
7980
*/
80-
long executeStreamingPartitionedUpdate(final Statement statement, Duration timeout) {
81+
long executeStreamingPartitionedUpdate(final Statement statement, final Duration timeout) {
8182
checkState(isValid, "Partitioned DML has been invalidated by a new operation on the session");
8283
log.log(Level.FINER, "Starting PartitionedUpdate statement");
8384
boolean foundStats = false;
8485
long updateCount = 0L;
85-
Duration remainingTimeout = timeout;
86-
Stopwatch stopWatch = Stopwatch.createStarted();
86+
Stopwatch stopWatch = createStopwatchStarted();
8787
try {
8888
// Loop to catch AbortedExceptions.
8989
while (true) {
@@ -105,8 +105,13 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
105105
}
106106
}
107107
while (true) {
108-
remainingTimeout =
109-
remainingTimeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
108+
Duration remainingTimeout =
109+
timeout.minus(stopWatch.elapsed(TimeUnit.MILLISECONDS), ChronoUnit.MILLIS);
110+
if (remainingTimeout.isNegative() || remainingTimeout.isZero()) {
111+
// The total deadline has been exceeded while retrying.
112+
throw new DeadlineExceededException(
113+
null, GrpcStatusCode.of(Code.DEADLINE_EXCEEDED), false);
114+
}
110115
try {
111116
builder.setResumeToken(resumeToken);
112117
ServerStream<PartialResultSet> stream =
@@ -157,6 +162,10 @@ long executeStreamingPartitionedUpdate(final Statement statement, Duration timeo
157162
}
158163
}
159164

165+
Stopwatch createStopwatchStarted() {
166+
return Stopwatch.createStarted();
167+
}
168+
160169
@Override
161170
public void invalidate() {
162171
isValid = false;

google-cloud-spanner/src/main/java/com/google/cloud/spanner/spi/v1/GapicSpannerRpc.java

+1
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,7 @@ public GapicSpannerRpc(final SpannerOptions options) {
369369
.setStreamWatchdogProvider(watchdogProvider)
370370
.executeSqlSettings()
371371
.setRetrySettings(partitionedDmlRetrySettings);
372+
pdmlSettings.executeStreamingSqlSettings().setRetrySettings(partitionedDmlRetrySettings);
372373
// The stream watchdog will by default only check for a timeout every 10 seconds, so if the
373374
// timeout is less than 10 seconds, it would be ignored for the first 10 seconds unless we
374375
// also change the StreamWatchdogCheckInterval.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
/*
2+
* Copyright 2020 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package com.google.cloud.spanner;
18+
19+
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.fail;
21+
import static org.mockito.Matchers.any;
22+
import static org.mockito.Matchers.anyMap;
23+
import static org.mockito.Mockito.mock;
24+
import static org.mockito.Mockito.times;
25+
import static org.mockito.Mockito.verify;
26+
import static org.mockito.Mockito.when;
27+
28+
import com.google.api.gax.grpc.GrpcStatusCode;
29+
import com.google.api.gax.rpc.AbortedException;
30+
import com.google.api.gax.rpc.ServerStream;
31+
import com.google.api.gax.rpc.UnavailableException;
32+
import com.google.cloud.spanner.spi.v1.SpannerRpc;
33+
import com.google.common.base.Stopwatch;
34+
import com.google.common.base.Ticker;
35+
import com.google.common.collect.ImmutableList;
36+
import com.google.protobuf.ByteString;
37+
import com.google.spanner.v1.BeginTransactionRequest;
38+
import com.google.spanner.v1.ExecuteSqlRequest;
39+
import com.google.spanner.v1.ExecuteSqlRequest.QueryMode;
40+
import com.google.spanner.v1.PartialResultSet;
41+
import com.google.spanner.v1.ResultSetStats;
42+
import com.google.spanner.v1.Transaction;
43+
import com.google.spanner.v1.TransactionSelector;
44+
import io.grpc.Status.Code;
45+
import java.util.Collections;
46+
import java.util.Iterator;
47+
import java.util.concurrent.TimeUnit;
48+
import org.junit.Before;
49+
import org.junit.Test;
50+
import org.junit.runner.RunWith;
51+
import org.junit.runners.JUnit4;
52+
import org.mockito.Mock;
53+
import org.mockito.Mockito;
54+
import org.mockito.MockitoAnnotations;
55+
import org.mockito.invocation.InvocationOnMock;
56+
import org.mockito.stubbing.Answer;
57+
import org.threeten.bp.Duration;
58+
59+
@SuppressWarnings("unchecked")
60+
@RunWith(JUnit4.class)
61+
public class PartitionedDmlTransactionTest {
62+
63+
@Mock private SpannerRpc rpc;
64+
65+
@Mock private SessionImpl session;
66+
67+
private final String sessionId = "projects/p/instances/i/databases/d/sessions/s";
68+
private final ByteString txId = ByteString.copyFromUtf8("tx");
69+
private final ByteString resumeToken = ByteString.copyFromUtf8("resume");
70+
private final String sql = "UPDATE FOO SET BAR=1 WHERE TRUE";
71+
private final ExecuteSqlRequest executeRequestWithoutResumeToken =
72+
ExecuteSqlRequest.newBuilder()
73+
.setQueryMode(QueryMode.NORMAL)
74+
.setSession(sessionId)
75+
.setSql(sql)
76+
.setTransaction(TransactionSelector.newBuilder().setId(txId))
77+
.build();
78+
private final ExecuteSqlRequest executeRequestWithResumeToken =
79+
executeRequestWithoutResumeToken.toBuilder().setResumeToken(resumeToken).build();
80+
81+
@Before
82+
public void setup() {
83+
MockitoAnnotations.initMocks(this);
84+
when(session.getName()).thenReturn(sessionId);
85+
when(session.getOptions()).thenReturn(Collections.EMPTY_MAP);
86+
when(rpc.beginTransaction(any(BeginTransactionRequest.class), anyMap()))
87+
.thenReturn(Transaction.newBuilder().setId(txId).build());
88+
}
89+
90+
@Test
91+
public void testExecuteStreamingPartitionedUpdate() {
92+
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
93+
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
94+
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
95+
ServerStream<PartialResultSet> stream = mock(ServerStream.class);
96+
when(stream.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
97+
when(rpc.executeStreamingPartitionedDml(
98+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
99+
.thenReturn(stream);
100+
101+
PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
102+
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
103+
assertThat(count).isEqualTo(1000L);
104+
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
105+
verify(rpc)
106+
.executeStreamingPartitionedDml(
107+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
108+
}
109+
110+
@Test
111+
public void testExecuteStreamingPartitionedUpdateAborted() {
112+
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
113+
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
114+
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
115+
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
116+
Iterator<PartialResultSet> iterator = mock(Iterator.class);
117+
when(iterator.hasNext()).thenReturn(true, true, false);
118+
when(iterator.next())
119+
.thenReturn(p1)
120+
.thenThrow(
121+
new AbortedException(
122+
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
123+
when(stream1.iterator()).thenReturn(iterator);
124+
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
125+
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
126+
when(rpc.executeStreamingPartitionedDml(
127+
any(ExecuteSqlRequest.class), anyMap(), any(Duration.class)))
128+
.thenReturn(stream1, stream2);
129+
130+
PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
131+
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
132+
assertThat(count).isEqualTo(1000L);
133+
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
134+
verify(rpc, times(2))
135+
.executeStreamingPartitionedDml(
136+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
137+
}
138+
139+
@Test
140+
public void testExecuteStreamingPartitionedUpdateUnavailable() {
141+
ResultSetStats stats = ResultSetStats.newBuilder().setRowCountLowerBound(1000L).build();
142+
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
143+
PartialResultSet p2 = PartialResultSet.newBuilder().setStats(stats).build();
144+
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
145+
Iterator<PartialResultSet> iterator = mock(Iterator.class);
146+
when(iterator.hasNext()).thenReturn(true, true, false);
147+
when(iterator.next())
148+
.thenReturn(p1)
149+
.thenThrow(
150+
new UnavailableException(
151+
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
152+
when(stream1.iterator()).thenReturn(iterator);
153+
ServerStream<PartialResultSet> stream2 = mock(ServerStream.class);
154+
when(stream2.iterator()).thenReturn(ImmutableList.of(p1, p2).iterator());
155+
when(rpc.executeStreamingPartitionedDml(
156+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
157+
.thenReturn(stream1);
158+
when(rpc.executeStreamingPartitionedDml(
159+
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class)))
160+
.thenReturn(stream2);
161+
162+
PartitionedDMLTransaction tx = new PartitionedDMLTransaction(session, rpc);
163+
long count = tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
164+
assertThat(count).isEqualTo(1000L);
165+
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
166+
verify(rpc)
167+
.executeStreamingPartitionedDml(
168+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
169+
verify(rpc)
170+
.executeStreamingPartitionedDml(
171+
Mockito.eq(executeRequestWithResumeToken), anyMap(), any(Duration.class));
172+
}
173+
174+
@Test
175+
public void testExecuteStreamingPartitionedUpdateUnavailableAndThenDeadlineExceeded() {
176+
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
177+
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
178+
Iterator<PartialResultSet> iterator = mock(Iterator.class);
179+
when(iterator.hasNext()).thenReturn(true, true, false);
180+
when(iterator.next())
181+
.thenReturn(p1)
182+
.thenThrow(
183+
new UnavailableException(
184+
"temporary unavailable", null, GrpcStatusCode.of(Code.UNAVAILABLE), true));
185+
when(stream1.iterator()).thenReturn(iterator);
186+
when(rpc.executeStreamingPartitionedDml(
187+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
188+
.thenReturn(stream1);
189+
190+
PartitionedDMLTransaction tx =
191+
new PartitionedDMLTransaction(session, rpc) {
192+
@Override
193+
Stopwatch createStopwatchStarted() {
194+
Ticker ticker = mock(Ticker.class);
195+
when(ticker.read())
196+
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
197+
return Stopwatch.createStarted(ticker);
198+
}
199+
};
200+
try {
201+
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
202+
fail("missing expected DEADLINE_EXCEEDED exception");
203+
} catch (SpannerException e) {
204+
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
205+
verify(rpc).beginTransaction(any(BeginTransactionRequest.class), anyMap());
206+
verify(rpc)
207+
.executeStreamingPartitionedDml(
208+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
209+
}
210+
}
211+
212+
@Test
213+
public void testExecuteStreamingPartitionedUpdateAbortedAndThenDeadlineExceeded() {
214+
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
215+
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
216+
Iterator<PartialResultSet> iterator = mock(Iterator.class);
217+
when(iterator.hasNext()).thenReturn(true, true, false);
218+
when(iterator.next())
219+
.thenReturn(p1)
220+
.thenThrow(
221+
new AbortedException(
222+
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
223+
when(stream1.iterator()).thenReturn(iterator);
224+
when(rpc.executeStreamingPartitionedDml(
225+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
226+
.thenReturn(stream1);
227+
228+
PartitionedDMLTransaction tx =
229+
new PartitionedDMLTransaction(session, rpc) {
230+
@Override
231+
Stopwatch createStopwatchStarted() {
232+
Ticker ticker = mock(Ticker.class);
233+
when(ticker.read())
234+
.thenReturn(0L, 1L, TimeUnit.NANOSECONDS.convert(10L, TimeUnit.MINUTES));
235+
return Stopwatch.createStarted(ticker);
236+
}
237+
};
238+
try {
239+
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
240+
fail("missing expected DEADLINE_EXCEEDED exception");
241+
} catch (SpannerException e) {
242+
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
243+
verify(rpc, times(2)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
244+
verify(rpc)
245+
.executeStreamingPartitionedDml(
246+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
247+
}
248+
}
249+
250+
@Test
251+
public void testExecuteStreamingPartitionedUpdateMultipleAbortsUntilDeadlineExceeded() {
252+
PartialResultSet p1 = PartialResultSet.newBuilder().setResumeToken(resumeToken).build();
253+
ServerStream<PartialResultSet> stream1 = mock(ServerStream.class);
254+
Iterator<PartialResultSet> iterator = mock(Iterator.class);
255+
when(iterator.hasNext()).thenReturn(true);
256+
when(iterator.next())
257+
.thenReturn(p1)
258+
.thenThrow(
259+
new AbortedException(
260+
"transaction aborted", null, GrpcStatusCode.of(Code.ABORTED), true));
261+
when(stream1.iterator()).thenReturn(iterator);
262+
when(rpc.executeStreamingPartitionedDml(
263+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class)))
264+
.thenReturn(stream1);
265+
266+
PartitionedDMLTransaction tx =
267+
new PartitionedDMLTransaction(session, rpc) {
268+
long ticks = 0L;
269+
270+
@Override
271+
Stopwatch createStopwatchStarted() {
272+
Ticker ticker = mock(Ticker.class);
273+
when(ticker.read())
274+
.thenAnswer(
275+
new Answer<Long>() {
276+
@Override
277+
public Long answer(InvocationOnMock invocation) throws Throwable {
278+
return TimeUnit.NANOSECONDS.convert(++ticks, TimeUnit.MINUTES);
279+
}
280+
});
281+
return Stopwatch.createStarted(ticker);
282+
}
283+
};
284+
try {
285+
tx.executeStreamingPartitionedUpdate(Statement.of(sql), Duration.ofMinutes(10));
286+
fail("missing expected DEADLINE_EXCEEDED exception");
287+
} catch (SpannerException e) {
288+
assertThat(e.getErrorCode()).isEqualTo(ErrorCode.DEADLINE_EXCEEDED);
289+
// It should start a transaction exactly 10 times (10 ticks == 10 minutes).
290+
verify(rpc, times(10)).beginTransaction(any(BeginTransactionRequest.class), anyMap());
291+
// The last transaction should timeout before it starts the actual statement execution, which
292+
// means that the execute method is only executed 9 times.
293+
verify(rpc, times(9))
294+
.executeStreamingPartitionedDml(
295+
Mockito.eq(executeRequestWithoutResumeToken), anyMap(), any(Duration.class));
296+
}
297+
}
298+
}

0 commit comments

Comments
 (0)