Skip to content
This repository was archived by the owner on Dec 30, 2019. It is now read-only.

Commit 645d22c

Browse files
committed
add mapPartitionsWithIndex
1 parent ea2ee6c commit 645d22c

File tree

4 files changed

+55
-25
lines changed

src/Spock.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module Spock
44
using Scotty
55
using JavaCall
66
import Base: Callable, map, collect, convert, count, reduce
7-
export SparkContext, RDD, parallelize
7+
export SparkContext, RDD, parallelize, transform
88

99
const classpath = get(ENV, "CLASSPATH", "")
1010
JavaCall.init(["-ea", "-Xmx1024M", "-Djava.class.path=$(classpath)"])
@@ -13,7 +13,7 @@ module Spock
1313
JArrays = @jimport java.util.Arrays
1414
JList = @jimport java.util.List
1515
JFunction = @jimport org.apache.spark.api.java.function.Function
16-
JFlatMapFunction = @jimport org.apache.spark.api.java.function.FlatMapFunction
16+
JFunction2 = @jimport org.apache.spark.api.java.function.Function2
1717
JJavaRDD = @jimport org.apache.spark.api.java.JavaRDD
1818
JJavaSparkContext = @jimport org.apache.spark.api.java.JavaSparkContext
1919
JJuliaRDD = @jimport edu.berkeley.bids.spock.JuliaRDD
@@ -49,7 +49,7 @@ module Spock
4949
function jrdd(rdd::TransformedRDD)
5050
if rdd.jrdd === nothing
5151
jfunc = JJuliaFunction((JJuliaObject,), jbox(rdd.task))
52-
rdd.jrdd = jcall(jrdd(rdd.parent), "mapPartitions", JJavaRDD, (JFlatMapFunction,), jfunc)
52+
rdd.jrdd = jcall(jrdd(rdd.parent), "mapPartitionsWithIndex", JJavaRDD, (JFunction2, jboolean), jfunc, false)
5353
end
5454
rdd.jrdd::JJavaRDD
5555
end
@@ -78,7 +78,11 @@ module Spock
7878
deserialize(IOBuffer(payload))
7979
end
8080

81-
function transform(rdd::RDD, task)
81+
# Analogous to `mapPartitionsWithIndex`.
82+
#
83+
# `task` will be called once per input partition with arguments of
84+
# (partition_id, input_iter) and returns an iterable of new contents.
85+
function transform(task::Function, rdd::RDD)
8286
if isa(rdd, TransformedRDD) && ispipelineable(rdd)
8387
TransformedRDD(rdd.parent, pipetask(task, rdd.task))
8488
else
@@ -87,11 +91,11 @@ module Spock
8791
end
8892

8993
function map(f::Callable, rdd::RDD)
90-
transform(rdd, maptask(f))
94+
transform(maptask(f), rdd)
9195
end
9296

9397
function reduce(f::Callable, rdd::RDD)
94-
reduce(f, collect(transform(rdd, reducetask(f))))
98+
reduce(f, collect(transform(reducetask(f), rdd)))
9599
end
96100

97101
function collect(rdd::RDD)

src/edu/berkeley/cs/amplab/spock/JuliaFunction.java
[NOTE(review): this path (edu/berkeley/cs/amplab/spock) does not match the `package edu.berkeley.bids.spock;` declaration and the `@jimport edu.berkeley.bids.spock.JuliaRDD` shown in the diff — confirm which package path is correct.]

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
package edu.berkeley.bids.spock;
22
import org.apache.spark.rdd.RDD;
33
import org.apache.spark.api.java.JavaRDD;
4-
import org.apache.spark.api.java.function.FlatMapFunction;
4+
import org.apache.spark.api.java.function.Function2;
55
import java.io.BufferedInputStream;
66
import java.io.BufferedOutputStream;
77
import java.io.DataInputStream;
@@ -10,7 +10,7 @@
1010
import java.util.Iterator;
1111
import java.util.LinkedList;
1212

13-
public class JuliaFunction implements FlatMapFunction<Iterator<JuliaObject>, JuliaObject> {
13+
public class JuliaFunction implements Function2<Integer, Iterator<JuliaObject>, Iterator<JuliaObject>> {
1414
private static final long serialVersionUID = 1;
1515
final JuliaObject func;
1616

@@ -23,7 +23,7 @@ String getScottyPath() {
2323
}
2424

2525
@Override
26-
public Iterable<JuliaObject> call(Iterator<JuliaObject> args) throws Exception {
26+
public Iterator<JuliaObject> call(Integer partId, Iterator<JuliaObject> args) throws Exception {
2727
// launch worker
2828
ProcessBuilder pb = new ProcessBuilder("julia", "-L", getScottyPath(), "-e", "Scotty.worker()");
2929
pb.redirectError(ProcessBuilder.Redirect.INHERIT);
@@ -32,6 +32,7 @@ public Iterable<JuliaObject> call(Iterator<JuliaObject> args) throws Exception {
3232
// send input
3333
DataOutputStream out = new DataOutputStream(new BufferedOutputStream(worker.getOutputStream()));
3434
func.write(out);
35+
out.writeInt(partId.intValue());
3536
while(args.hasNext()) {
3637
args.next().write(out);
3738
}
@@ -51,7 +52,7 @@ public Iterable<JuliaObject> call(Iterator<JuliaObject> args) throws Exception {
5152
if(worker.waitFor() != 0) {
5253
throw new RuntimeException(String.format("Spock worker died with exitValue=%d", worker.exitValue()));
5354
} else {
54-
return results;
55+
return results.iterator();
5556
}
5657
}
5758
}

src/scotty.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
module Scotty
2-
export maptask, reducetask, pipetask, intask, outtask
2+
export maptask, reducetask, pipetask
33
const inf = STDIN
44
const outf = STDOUT
55

@@ -24,23 +24,23 @@ module Scotty
2424
end
2525
end
2626

27-
function outtask(inp)
28-
map(writeobj, inp)
27+
function outtask(iter)
28+
map(writeobj, iter)
2929
end
3030

3131
function maptask(f)
32-
(inp) -> begin
32+
(partid, iter) -> begin
3333
Task() do
34-
map(arg -> produce(f(arg)), inp)
34+
map(arg -> produce(f(arg)), iter)
3535
end
3636
end
3737
end
3838

3939
function reducetask(f)
40-
(inp) -> begin
40+
(partid, iter) -> begin
4141
Task() do
4242
accum = nothing
43-
for arg in inp
43+
for arg in iter
4444
if accum == nothing
4545
accum = arg
4646
else
@@ -53,16 +53,17 @@ module Scotty
5353
end
5454

5555
function pipetask(t2, t1)
56-
(inp) -> t2(t1(inp))
56+
(partid, iter) -> t2(partid, t1(partid, iter))
5757
end
5858

5959
function worker()
6060
try
6161
redirect_stdout(STDERR)
6262
close(redirect_stdin()[2])
6363
task = readobj()
64+
partid = readint()
6465
try
65-
outtask(task(intask()))
66+
outtask(task(partid, intask()))
6667
catch exc
6768
writeint(0)
6869
writeint(2) # OOB 2: task error (fatal)

test/runtests.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,19 @@ using Spock
22
using Base.Test
33

44
sc = SparkContext()
5-
65
rdd1 = parallelize(sc, 1:10, 2)
6+
rdd2 = map(x -> x^2, rdd1)
7+
rddx = parallelize(sc, fill("x", 10), 2)
8+
allequal(val, rdd) = all(x -> x == val, collect(rdd))
9+
10+
# test basics
711
@assert 10 == count(rdd1)
812
@assert 55 == sum(collect(rdd1))
913
@assert 55 == reduce(+, rdd1)
10-
11-
rdd2 = map(x -> x^2, rdd1)
1214
@assert 385 == sum(collect(rdd2))
1315
@assert 385 == reduce(+, rdd2)
1416

17+
# test propagation of exceptions from workers
1518
@assert "moo" == begin
1619
try
1720
collect(map(x->throw("woof"), rdd1))
@@ -23,6 +26,7 @@ rdd2 = map(x -> x^2, rdd1)
2326
end
2427

2528
let driverpid = getpid()
29+
# test containment of worker I/O
2630
@assert 55 == reduce(rdd1) do x, y
2731
if getpid() != driverpid
2832
println("Ekke Ekke Ekke Ekke Ptangya Zoooooooom Boing Ni!")
@@ -31,13 +35,33 @@ let driverpid = getpid()
3135
x + y
3236
end
3337

34-
rdd3 = map(x->getpid(), rdd1)
35-
@assert 2 == length(Set(collect(rdd3)))
36-
@assert driverpid == reduce(rdd3) do l, r
38+
# test pipelining
39+
pids = map(x->getpid(), rdd1)
40+
@assert 2 == length(Set(collect(pids)))
41+
@assert driverpid == reduce(pids) do l, r
3742
mypid = getpid()
3843
@assert mypid == driverpid || (mypid == l && mypid == r)
3944
mypid
4045
end
4146
end
4247

48+
# test mixing synchronous and asynchronous transforms
49+
frob = (partid, iter) -> ["f$(x)" for x in collect(iter)]
50+
brof = (partid, iter) -> begin
51+
Task() do
52+
for x in iter
53+
produce("b$(x)")
54+
end
55+
end
56+
end
57+
@assert allequal("fbfbx", transform(frob, transform(brof, transform(frob, transform(brof, rddx)))))
58+
@assert allequal("bfbfx", transform(brof, transform(frob, transform(brof, transform(frob, rddx)))))
59+
@assert allequal("ffbfx", transform(frob, transform(frob, transform(brof, transform(frob, rddx)))))
60+
@assert allequal("fbbfx", transform(frob, transform(brof, transform(brof, transform(frob, rddx)))))
61+
@assert allequal("bbfbx", transform(brof, transform(brof, transform(frob, transform(brof, rddx)))))
62+
@assert allequal("bffbx", transform(brof, transform(frob, transform(frob, transform(brof, rddx)))))
63+
partids = collect(transform((partid, iter) -> [partid], rddx))
64+
@assert length(partids) == 2
65+
@assert Set([0, 1]) == Set(partids)
66+
4367
println("Spock: all tests passed")

0 commit comments

Comments (0)