Skip to content

Commit 30b6b08

Browse files
Merge pull request #502 from benjchristensen/observeOn-parallelMerge
Fix ObserveOn and add ParallelMerge Scheduler overload
2 parents a6a2440 + bc6965c commit 30b6b08

File tree

4 files changed

+139
-11
lines changed

4 files changed

+139
-11
lines changed

rxjava-core/src/main/java/rx/Observable.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4071,6 +4071,23 @@ public static <T> Observable<Observable<T>> parallelMerge(Observable<Observable<
40714071
return OperationParallelMerge.parallelMerge(source, parallelObservables);
40724072
}
40734073

4074+
/**
4075+
* Merges an <code>Observable<Observable<T>></code> to <code>Observable<Observable<T>></code>
4076+
* with number of inner Observables as defined by <code>parallelObservables</code> and runs each Observable on the defined Scheduler.
4077+
* <p>
4078+
* For example, if the original <code>Observable<Observable<T>></code> has 100 Observables to be emitted and <code>parallelObservables</code>
4079+
* is defined as 8, the 100 will be grouped onto 8 output Observables.
4080+
* <p>
4081+
* This is a mechanism for efficiently processing N number of Observables on a smaller N number of resources (typically CPU cores).
4082+
*
4083+
* @param parallelObservables
4084+
* the number of Observables to merge into.
4085+
* @return an Observable of Observables constrained to number defined by <code>parallelObservables</code>.
4086+
*/
4087+
public static <T> Observable<Observable<T>> parallelMerge(Observable<Observable<T>> source, int parallelObservables, Scheduler scheduler) {
4088+
return OperationParallelMerge.parallelMerge(source, parallelObservables, scheduler);
4089+
}
4090+
40744091
/**
40754092
* Returns a {@link ConnectableObservable}, which waits until its
40764093
* {@link ConnectableObservable#connect connect} method is called before it

rxjava-core/src/main/java/rx/operators/OperationObserveOn.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@
2727
import rx.concurrency.CurrentThreadScheduler;
2828
import rx.concurrency.ImmediateScheduler;
2929
import rx.subscriptions.CompositeSubscription;
30+
import rx.subscriptions.Subscriptions;
3031
import rx.util.functions.Action0;
3132
import rx.util.functions.Action1;
33+
import rx.util.functions.Func2;
3234

3335
/**
3436
* Asynchronously notify Observers on the specified Scheduler.
@@ -44,6 +46,7 @@ public static <T> OnSubscribeFunc<T> observeOn(Observable<? extends T> source, S
4446
private static class ObserveOn<T> implements OnSubscribeFunc<T> {
4547
private final Observable<? extends T> source;
4648
private final Scheduler scheduler;
49+
private volatile Scheduler recursiveScheduler;
4750

4851
final ConcurrentLinkedQueue<Notification<? extends T>> queue = new ConcurrentLinkedQueue<Notification<? extends T>>();
4952
final AtomicInteger counter = new AtomicInteger(0);
@@ -66,7 +69,7 @@ public Subscription onSubscribe(final Observer<? super T> observer) {
6669
}
6770
}
6871

69-
public Subscription observeOn(final Observer<? super T> observer, Scheduler scheduler) {
72+
public Subscription observeOn(final Observer<? super T> observer, final Scheduler scheduler) {
7073
final CompositeSubscription s = new CompositeSubscription();
7174

7275
s.add(source.materialize().subscribe(new Action1<Notification<? extends T>>() {
@@ -80,7 +83,22 @@ public void call(Notification<? extends T> e) {
8083
// it will be 0 if it's the first notification or the scheduler has finished processing work
8184
// and we need to start doing it again
8285
if (counter.getAndIncrement() == 0) {
83-
processQueue(s, observer);
86+
if (recursiveScheduler == null) {
87+
s.add(scheduler.schedule(null, new Func2<Scheduler, T, Subscription>() {
88+
89+
@Override
90+
public Subscription call(Scheduler innerScheduler, T state) {
91+
// record innerScheduler so 'processQueue' can use it for all subsequent executions
92+
recursiveScheduler = innerScheduler;
93+
94+
processQueue(s, observer);
95+
96+
return Subscriptions.empty();
97+
}
98+
}));
99+
} else {
100+
processQueue(s, observer);
101+
}
84102
}
85103

86104
}
@@ -89,8 +107,13 @@ public void call(Notification<? extends T> e) {
89107
return s;
90108
}
91109

92-
private void processQueue(CompositeSubscription s, final Observer<? super T> observer) {
93-
s.add(scheduler.schedule(new Action1<Action0>() {
110+
/**
111+
* This uses 'recursiveScheduler' NOT 'scheduler' as it should reuse the same scheduler each time it processes.
112+
* This means it must first get the recursiveScheduler when it first executes.
113+
*/
114+
private void processQueue(final CompositeSubscription s, final Observer<? super T> observer) {
115+
116+
s.add(recursiveScheduler.schedule(new Action1<Action0>() {
94117
@Override
95118
public void call(Action0 self) {
96119
Notification<? extends T> not = queue.poll();

rxjava-core/src/main/java/rx/operators/OperationParallelMerge.java

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,29 +18,31 @@
1818
import java.util.concurrent.atomic.AtomicLong;
1919

2020
import rx.Observable;
21+
import rx.Scheduler;
22+
import rx.concurrency.Schedulers;
2123
import rx.observables.GroupedObservable;
2224
import rx.util.functions.Func1;
2325

2426
public class OperationParallelMerge {
2527

26-
public static <T> Observable<Observable<T>> parallelMerge(final Observable<Observable<T>> source, final int num) {
28+
public static <T> Observable<Observable<T>> parallelMerge(final Observable<Observable<T>> source, final int parallelObservables) {
29+
return parallelMerge(source, parallelObservables, Schedulers.currentThread());
30+
}
31+
32+
public static <T> Observable<Observable<T>> parallelMerge(final Observable<Observable<T>> source, final int parallelObservables, final Scheduler scheduler) {
2733

2834
return source.groupBy(new Func1<Observable<T>, Integer>() {
2935
final AtomicLong rollingCount = new AtomicLong();
3036

3137
@Override
3238
public Integer call(Observable<T> o) {
33-
return (int) rollingCount.incrementAndGet() % num;
39+
return (int) rollingCount.incrementAndGet() % parallelObservables;
3440
}
3541
}).map(new Func1<GroupedObservable<Integer, Observable<T>>, Observable<T>>() {
3642

37-
/**
38-
* Safe to cast from GroupedObservable to Observable so suppressing warning
39-
*/
40-
@SuppressWarnings("unchecked")
4143
@Override
4244
public Observable<T> call(GroupedObservable<Integer, Observable<T>> o) {
43-
return (Observable<T>) o;
45+
return Observable.merge(o).observeOn(scheduler);
4446
}
4547

4648
});

rxjava-core/src/main/java/rx/operators/OperationParallelMergeTest.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,16 @@
1818
import static org.junit.Assert.*;
1919

2020
import java.util.List;
21+
import java.util.concurrent.ConcurrentHashMap;
22+
import java.util.concurrent.TimeUnit;
2123

2224
import org.junit.Test;
2325

2426
import rx.Observable;
27+
import rx.concurrency.Schedulers;
2528
import rx.subjects.PublishSubject;
29+
import rx.util.functions.Action1;
30+
import rx.util.functions.Func1;
2631

2732
public class OperationParallelMergeTest {
2833

@@ -42,8 +47,89 @@ public void testParallelMerge() {
4247
List<? super Observable<String>> threeList = threeStreams.toList().toBlockingObservable().last();
4348
List<? super Observable<String>> twoList = twoStreams.toList().toBlockingObservable().last();
4449

50+
System.out.println("two list: " + twoList);
51+
System.out.println("three list: " + threeList);
52+
System.out.println("four list: " + fourList);
53+
4554
assertEquals(4, fourList.size());
4655
assertEquals(3, threeList.size());
4756
assertEquals(2, twoList.size());
4857
}
58+
59+
@Test
60+
public void testNumberOfThreads() {
61+
final ConcurrentHashMap<String, String> threads = new ConcurrentHashMap<String, String>();
62+
Observable.merge(getStreams())
63+
.toBlockingObservable().forEach(new Action1<String>() {
64+
65+
@Override
66+
public void call(String o) {
67+
System.out.println("o: " + o + " Thread: " + Thread.currentThread());
68+
threads.put(Thread.currentThread().getName(), Thread.currentThread().getName());
69+
}
70+
});
71+
72+
// without injecting anything, the getStream() method uses Interval which runs on a default scheduler
73+
assertEquals(Runtime.getRuntime().availableProcessors(), threads.keySet().size());
74+
75+
// clear
76+
threads.clear();
77+
78+
// now we parallelMerge into 3 streams and observeOn for each
79+
// we expect 3 threads in the output
80+
OperationParallelMerge.parallelMerge(getStreams(), 3)
81+
.flatMap(new Func1<Observable<String>, Observable<String>>() {
82+
83+
@Override
84+
public Observable<String> call(Observable<String> o) {
85+
// for each of the parallel
86+
return o.observeOn(Schedulers.newThread());
87+
}
88+
})
89+
.toBlockingObservable().forEach(new Action1<String>() {
90+
91+
@Override
92+
public void call(String o) {
93+
System.out.println("o: " + o + " Thread: " + Thread.currentThread());
94+
threads.put(Thread.currentThread().getName(), Thread.currentThread().getName());
95+
}
96+
});
97+
98+
assertEquals(3, threads.keySet().size());
99+
}
100+
101+
@Test
102+
public void testNumberOfThreadsOnScheduledMerge() {
103+
final ConcurrentHashMap<String, String> threads = new ConcurrentHashMap<String, String>();
104+
105+
// now we parallelMerge into 3 streams and observeOn for each
106+
// we expect 3 threads in the output
107+
Observable.merge(OperationParallelMerge.parallelMerge(getStreams(), 3, Schedulers.newThread()))
108+
.toBlockingObservable().forEach(new Action1<String>() {
109+
110+
@Override
111+
public void call(String o) {
112+
System.out.println("o: " + o + " Thread: " + Thread.currentThread());
113+
threads.put(Thread.currentThread().getName(), Thread.currentThread().getName());
114+
}
115+
});
116+
117+
assertEquals(3, threads.keySet().size());
118+
}
119+
120+
private static Observable<Observable<String>> getStreams() {
121+
return Observable.range(0, 10).map(new Func1<Integer, Observable<String>>() {
122+
123+
@Override
124+
public Observable<String> call(final Integer i) {
125+
return Observable.interval(10, TimeUnit.MILLISECONDS).map(new Func1<Long, String>() {
126+
127+
@Override
128+
public String call(Long l) {
129+
return "Stream " + i + " Value: " + l;
130+
}
131+
}).take(5);
132+
}
133+
});
134+
}
49135
}

0 commit comments

Comments
 (0)