From 49c0032a11f2bac93d4ffc44be9fc80476a55968 Mon Sep 17 00:00:00 2001 From: Ben Christensen Date: Sat, 30 Aug 2014 21:09:36 -0700 Subject: [PATCH] Operator Scan Backpressure Fix Problem 1) The initial value was being emitted before subscription which caused issues with request/producer state, particularly if filter() was used to skip that initial value and then called request(1) before the real request had been sent. Problem 2) The initial value was not accounted for by the request so it was sending 1 more value than requested. It now modifies the request to account for it. Problem 3) Redo relied upon these nuances to work. I've fixed this by using a simpler implementation that just maintains state within a map function. --- .../internal/operators/OnSubscribeRedo.java | 27 ++-- .../rx/internal/operators/OperatorScan.java | 63 ++++++-- .../internal/operators/OperatorScanTest.java | 150 ++++++++++++++++++ 3 files changed, 219 insertions(+), 21 deletions(-) diff --git a/src/main/java/rx/internal/operators/OnSubscribeRedo.java b/src/main/java/rx/internal/operators/OnSubscribeRedo.java index e49a787102..21cd177323 100644 --- a/src/main/java/rx/internal/operators/OnSubscribeRedo.java +++ b/src/main/java/rx/internal/operators/OnSubscribeRedo.java @@ -74,18 +74,24 @@ public RedoFinite(long count) { @Override public Observable call(Observable> ts) { - final Notification first = count < 0 ? Notification. createOnCompleted() : Notification.createOnNext(0l); + return ts.map(new Func1, Notification>() { - return ts.scan(first, new Func2, Notification, Notification>() { - @SuppressWarnings("unchecked") + int num=0; + @Override - public Notification call(Notification n, Notification term) { - final long value = n.getValue(); - if (value < count) - return Notification.createOnNext(value + 1); - else - return (Notification) term; + public Notification call(Notification terminalNotification) { + if(count == 0) { + return terminalNotification; + } + + num++; + if(num <= count) { + return Notification.createOnNext(num); + } else { + return terminalNotification; + } } + }).dematerialize(); } } @@ -146,6 +152,9 @@ public static Observable repeat(Observable source, final long count) { } public static Observable repeat(Observable source, final long count, Scheduler scheduler) { + if(count == 0) { + return Observable.empty(); + } if (count < 0) throw new IllegalArgumentException("count >= 0 expected"); return repeat(source, new RedoFinite(count - 1), scheduler); diff --git a/src/main/java/rx/internal/operators/OperatorScan.java b/src/main/java/rx/internal/operators/OperatorScan.java index 6477fec7bf..35f653dc3b 100644 --- a/src/main/java/rx/internal/operators/OperatorScan.java +++ b/src/main/java/rx/internal/operators/OperatorScan.java @@ -15,7 +15,10 @@ */ package rx.internal.operators; +import java.util.concurrent.atomic.AtomicBoolean; + import rx.Observable.Operator; +import rx.Producer; import rx.Subscriber; import rx.exceptions.OnErrorThrowable; import rx.functions.Func2; @@ -70,37 +73,73 @@ public OperatorScan(final Func2 accumulator) { } @Override - public Subscriber call(final Subscriber observer) { - if (initialValue != NO_INITIAL_VALUE) { - observer.onNext(initialValue); - } - return new Subscriber(observer) { + public Subscriber call(final Subscriber child) { + return new Subscriber(child) { private R value = initialValue; + boolean initialized = false; @SuppressWarnings("unchecked") @Override - public void onNext(T value) { + public void onNext(T currentValue) { + emitInitialValueIfNeeded(child); + if (this.value == NO_INITIAL_VALUE) { // if there is NO_INITIAL_VALUE then we know it is type T for both so cast T to R - this.value = (R) value; + this.value = (R) currentValue; } else { try { - this.value = accumulator.call(this.value, value); + this.value = accumulator.call(this.value, currentValue); } catch (Throwable e) { - observer.onError(OnErrorThrowable.addValueAsLastCause(e, value)); + child.onError(OnErrorThrowable.addValueAsLastCause(e, currentValue)); } } - observer.onNext(this.value); + child.onNext(this.value); } @Override public void onError(Throwable e) { - observer.onError(e); + child.onError(e); } @Override public void onCompleted() { - observer.onCompleted(); + emitInitialValueIfNeeded(child); + child.onCompleted(); + } + + private void emitInitialValueIfNeeded(final Subscriber child) { + if (!initialized) { + initialized = true; + // we emit first time through if we have an initial value + if (initialValue != NO_INITIAL_VALUE) { + child.onNext(initialValue); + } + } + } + + /** + * We want to adjust the requested value by subtracting 1 if we have an initial value + */ + @Override + public void setProducer(final Producer producer) { + child.setProducer(new Producer() { + + final AtomicBoolean once = new AtomicBoolean(); + + @Override + public void request(long n) { + if (once.compareAndSet(false, true)) { + if (initialValue == NO_INITIAL_VALUE) { + producer.request(n); + } else { + producer.request(n - 1); + } + } else { + // pass-thru after first time + producer.request(n); + } + } + }); } }; } diff --git a/src/test/java/rx/internal/operators/OperatorScanTest.java b/src/test/java/rx/internal/operators/OperatorScanTest.java index dcd5606794..e379b65e6a 100644 --- a/src/test/java/rx/internal/operators/OperatorScanTest.java +++ b/src/test/java/rx/internal/operators/OperatorScanTest.java @@ -15,6 +15,7 @@ */ package rx.internal.operators; +import static org.junit.Assert.*; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyInt; import static org.mockito.Matchers.anyString; @@ -23,13 +24,21 @@ import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import java.util.concurrent.atomic.AtomicInteger; + import org.junit.Before; import org.junit.Test; import org.mockito.MockitoAnnotations; import rx.Observable; import rx.Observer; +import rx.Subscriber; +import rx.functions.Action1; +import rx.functions.Func1; import rx.functions.Func2; +import rx.internal.util.RxRingBuffer; +import rx.observers.TestSubscriber; +import rx.schedulers.Schedulers; public class OperatorScanTest { @@ -116,4 +125,145 @@ public Integer call(Integer t1, Integer t2) { verify(observer, times(1)).onCompleted(); verify(observer, never()).onError(any(Throwable.class)); } + + @Test + public void shouldNotEmitUntilAfterSubscription() { + TestSubscriber ts = new TestSubscriber(); + Observable.range(1, 100).scan(0, new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + + }).filter(new Func1() { + + @Override + public Boolean call(Integer t1) { + // this will cause request(1) when 0 is emitted + return t1 > 0; + } + + }).subscribe(ts); + + assertEquals(100, ts.getOnNextEvents().size()); + } + + @Test + public void testBackpressureWithInitialValue() { + final AtomicInteger count = new AtomicInteger(); + Observable.range(1, 100) + .scan(0, new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + + }) + .subscribe(new Subscriber() { + + @Override + public void onStart() { + request(10); + } + + @Override + public void onCompleted() { + + } + + @Override + public void onError(Throwable e) { + fail(e.getMessage()); + e.printStackTrace(); + } + + @Override + public void onNext(Integer t) { + count.incrementAndGet(); + } + + }); + + // we only expect to receive 10 since we request(10) + assertEquals(10, count.get()); + } + + @Test + public void testBackpressureWithoutInitialValue() { + final AtomicInteger count = new AtomicInteger(); + Observable.range(1, 100) + .scan(new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + + }) + .subscribe(new Subscriber() { + + @Override + public void onStart() { + request(10); + } + + @Override + public void onCompleted() { + + } + + @Override + public void onError(Throwable e) { + fail(e.getMessage()); + e.printStackTrace(); + } + + @Override + public void onNext(Integer t) { + count.incrementAndGet(); + } + + }); + + // we only expect to receive 10 since we request(10) + assertEquals(10, count.get()); + } + + @Test + public void testNoBackpressureWithInitialValue() { + final AtomicInteger count = new AtomicInteger(); + Observable.range(1, 100) + .scan(0, new Func2() { + + @Override + public Integer call(Integer t1, Integer t2) { + return t1 + t2; + } + + }) + .subscribe(new Subscriber() { + + @Override + public void onCompleted() { + + } + + @Override + public void onError(Throwable e) { + fail(e.getMessage()); + e.printStackTrace(); + } + + @Override + public void onNext(Integer t) { + count.incrementAndGet(); + } + + }); + + // we only expect to receive 101 as we'll receive all 100 + the initial value + assertEquals(101, count.get()); + } }