From 2ad31a05452479a55ae61297c937b76e8e49e032 Mon Sep 17 00:00:00 2001 From: Arnab Phani Date: Sat, 18 Feb 2023 19:53:18 +0100 Subject: [PATCH] [SYSTEMDS-3499] Multi-statementblock LOP rewrites This patch extends the rewrites to add prefetch and broadcast to support multiple statement blocks, i.e. if the consumers are situated in different blocks. This change is currently not in effect as the dynamic recompilation does not allow multi-block rewrites easily. Moreover, need to find a way to delay the rmvar of inputs to Broadcast instructions. --- .../java/org/apache/sysds/common/Types.java | 8 +- .../org/apache/sysds/lops/compile/Dag.java | 2 +- .../lops/rewrite/RewriteAddBroadcastLop.java | 132 ++++++++++++++++-- .../lops/rewrite/RewriteAddPrefetchLop.java | 130 +++++++++++++++-- .../scripts/functions/async/BroadcastVar3.dml | 40 ++++++ .../scripts/functions/async/PrefetchRDD5.dml | 38 +++++ .../scripts/functions/async/PrefetchRDD6.dml | 40 ++++++ .../scripts/functions/async/PrefetchRDD7.dml | 43 ++++++ 8 files changed, 409 insertions(+), 24 deletions(-) create mode 100644 src/test/scripts/functions/async/BroadcastVar3.dml create mode 100644 src/test/scripts/functions/async/PrefetchRDD5.dml create mode 100644 src/test/scripts/functions/async/PrefetchRDD6.dml create mode 100644 src/test/scripts/functions/async/PrefetchRDD7.dml diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index ab81ff4e31a..7b8a9dc83e9 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -606,7 +606,13 @@ public boolean isWrite() { public boolean isRead() { return this == TRANSIENTREAD || this == PERSISTENTREAD; } - + public boolean isTransientRead() { + return this == TRANSIENTREAD; + } + public boolean isTransientWrite() { + return this == TRANSIENTWRITE; + } + @Override public String toString() { switch(this) { diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index 786e281d64a..6702df0549d 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -193,7 +193,7 @@ public ArrayList getJobs(StatementBlock sb, DMLConfig config) { */ private boolean inputNeedsPrefetch(Lop input, Lop lop){ return input.prefetchActivated() && lop.getExecType() != ExecType.FED - && input.getFederatedOutput().isForcedFederated(); + && input.getFederatedOutput() != null && input.getFederatedOutput().isForcedFederated(); } /** diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java index da22c511869..bfd47600227 100644 --- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java +++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java @@ -21,17 +21,29 @@ import org.apache.sysds.common.Types; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.lops.Data; import org.apache.sysds.lops.Lop; import org.apache.sysds.lops.OperatorOrderingUtils; import org.apache.sysds.lops.UnaryCP; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class RewriteAddBroadcastLop extends LopRewriteRule { + boolean MULTI_BLOCK_REWRITE = false; + @Override public List rewriteLOPinStatementBlock(StatementBlock sb) { @@ -47,16 +59,16 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) nodesWithBroadcast.add(l); if (isBroadcastNeeded(l)) { List oldOuts = new ArrayList<>(l.getOutputs()); - // Construct a Broadcast lop that takes this Spark node as an input + // Construct a Broadcast lop that takes this CP node as an input UnaryCP bc = new UnaryCP(l, Types.OpOp1.BROADCAST, l.getDataType(), l.getValueType(), Types.ExecType.CP); bc.setAsynchronous(true); - //FIXME: Wire Broadcast only with the necessary outputs - for (Lop outCP : oldOuts) { - // Rewire l -> outCP to l -> Broadcast -> outCP - bc.addOutput(outCP); - outCP.replaceInput(l, bc); - l.removeOutput(outCP); - //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) + // FIXME: Wire Broadcast only with the necessary outputs + for (Lop outSP : oldOuts) { + // Rewire l -> outSP to l -> Broadcast -> outSP + bc.addOutput(outSP); + outSP.replaceInput(l, bc); + l.removeOutput(outSP); + // FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) } //Place it immediately after the Spark lop in the node list nodesWithBroadcast.add(bc); @@ -67,17 +79,111 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) } @Override - public List rewriteLOPinStatementBlocks(List sbs) { + public List rewriteLOPinStatementBlocks(List sbs) + { + if (!MULTI_BLOCK_REWRITE) + return sbs; + // FIXME: Enable after handling of rmvar of asynchronous inputs + + if (!ConfigurationManager.isBroadcastEnabled()) + return sbs; + if (sbs == null || sbs.isEmpty()) + return sbs; + // The first statement block has to be a basic block + // TODO: Remove this constraints + StatementBlock sb1 = sbs.get(0); + if (!HopRewriteUtils.isLastLevelStatementBlock(sb1)) + return sbs; + if (sb1.getLops() == null || sb1.getLops().isEmpty()) + return sbs; + + // Gather the twrite names of the potential broadcast candidates from the first block + // TODO: Replace repetitive rewrite calls with a single one to place all prefetches + HashMap> twrites = new HashMap<>(); + HashMap broadcastCandidates = new HashMap<>(); + for (Lop root : sb1.getLops()) { + if (root instanceof Data && ((Data)root).getOperationType().isTransientWrite()) { + Lop written = root.getInputs().get(0); + if (written.getExecType() == Types.ExecType.CP && written.getDataType().isMatrix()) { + // Potential broadcast candidate. Save in the twrite map + twrites.put(root.getOutputParameters().getLabel(), new ArrayList<>()); + broadcastCandidates.put(root.getOutputParameters().getLabel(), written); + } + } + } + if (broadcastCandidates.isEmpty()) + return sbs; + + // Recursively check the consumers in the bellow blocks to find if broadcast is required + for (int i=1; i< sbs.size(); i++) + findConsumers(sbs.get(i), twrites); + + // Place a broadcast if any of the consumers are Spark + for (Map.Entry entry : broadcastCandidates.entrySet()) { + if (twrites.get(entry.getKey()).stream().anyMatch(outBC -> (outBC == true))) { + Lop candidate = entry.getValue(); + List oldOuts = new ArrayList<>(candidate.getOutputs()); + // Construct a broadcast lop that takes this CP node as an input + UnaryCP bc = new UnaryCP(candidate, Types.OpOp1.BROADCAST, candidate.getDataType(), + candidate.getValueType(), Types.ExecType.CP); + bc.setAsynchronous(true); + // FIXME: Wire Broadcast only with the necessary outputs + for (Lop outSP : oldOuts) { + // Rewire l -> outSP to l -> Broadcast -> outSP + bc.addOutput(outSP); + outSP.replaceInput(candidate, bc); + candidate.removeOutput(outSP); + // FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) + } + } + } return sbs; } private static boolean isBroadcastNeeded(Lop lop) { // Asynchronously broadcast a matrix if that is produced by a CP instruction, // and at least one Spark parent needs to broadcast this intermediate (eg. mapmm) - boolean isBc = lop.getOutputs().stream() + boolean isBcOutput = lop.getOutputs().stream() .anyMatch(out -> (out.getBroadcastInput() == lop)); - //TODO: Early broadcast objects that are bigger than a single block - boolean isCP = lop.getExecType() == Types.ExecType.CP; - return isCP && isBc && lop.getDataType() == Types.DataType.MATRIX; + // TODO: Early broadcast objects that are bigger than a single block + boolean isCPInput = lop.getExecType() == Types.ExecType.CP; + return isCPInput && isBcOutput && lop.getDataType() == Types.DataType.MATRIX; + } + + private void findConsumers(StatementBlock sb, HashMap> twrites) { + if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); + for (StatementBlock input : wstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement) fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + + // Find the execution types of the consumers + ArrayList lops = OperatorOrderingUtils.getLopList(sb); + if (lops == null) + return; + for (Lop l : lops) { + // Find consumers in this basic block + if (l instanceof Data && ((Data) l).getOperationType().isTransientRead() + && twrites.containsKey(l.getOutputParameters().getLabel())) { + // Check if the consumers satisfy broadcast conditions + for (Lop consumer : l.getOutputs()) + if (consumer.getBroadcastInput() == l) + twrites.get(l.getOutputParameters().getLabel()).add(true); + } + } + } } diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java index 6eb52e0d9fc..bae768e1245 100644 --- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java +++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java @@ -22,10 +22,12 @@ import org.apache.sysds.common.Types; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.lops.CSVReBlock; import org.apache.sysds.lops.CentralMoment; import org.apache.sysds.lops.Checkpoint; import org.apache.sysds.lops.CoVariance; +import org.apache.sysds.lops.Data; import org.apache.sysds.lops.DataGen; import org.apache.sysds.lops.GroupedAggregate; import org.apache.sysds.lops.GroupedAggregateM; @@ -40,11 +42,19 @@ import org.apache.sysds.lops.SpoofFused; import org.apache.sysds.lops.UAggOuterChain; import org.apache.sysds.lops.UnaryCP; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class RewriteAddPrefetchLop extends LopRewriteRule { @@ -62,13 +72,14 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) //Find the Spark nodes with all CP outputs for (Lop l : lops) { nodesWithPrefetch.add(l); - if (isPrefetchNeeded(l)) { + if (isPrefetchNeeded(l) && !l.prefetchActivated()) { List oldOuts = new ArrayList<>(l.getOutputs()); - //Construct a Prefetch lop that takes this Spark node as a input + //Construct a Prefetch lop that takes this Spark node as an input UnaryCP prefetch = new UnaryCP(l, Types.OpOp1.PREFETCH, l.getDataType(), l.getValueType(), Types.ExecType.CP); prefetch.setAsynchronous(true); //Reset asynchronous flag for the input if already set (e.g. mapmm -> prefetch) l.setAsynchronous(false); + l.activatePrefetch(); for (Lop outCP : oldOuts) { //Rewire l -> outCP to l -> Prefetch -> outCP prefetch.addOutput(outCP); @@ -85,13 +96,78 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) } @Override - public List rewriteLOPinStatementBlocks(List sbs) { + public List rewriteLOPinStatementBlocks(List sbs) + { + if (!ConfigurationManager.isPrefetchEnabled()) + return sbs; + if (sbs == null || sbs.isEmpty()) + return sbs; + // The first statement block has to be a basic block + // TODO: Remove this constraints + StatementBlock sb1 = sbs.get(0); + if (!HopRewriteUtils.isLastLevelStatementBlock(sb1)) + return sbs; + if (sb1.getLops() == null || sb1.getLops().isEmpty()) + return sbs; + + // Gather the twrite names of the potential prefetch candidates from the first block + // TODO: Replace repetitive rewrite calls with a single one to place all prefetches + HashMap> twrites = new HashMap<>(); + HashMap prefetchCandidates = new HashMap<>(); + for (Lop root : sb1.getLops()) { + if (root instanceof Data && ((Data)root).getOperationType().isTransientWrite()) { + Lop written = root.getInputs().get(0); + if (isTransformOP(written) && !hasParameterizedOut(written) && written.getDataType().isMatrix()) { + // Potential prefetch candidate. Save in the twrite map + twrites.put(root.getOutputParameters().getLabel(), new ArrayList<>()); + prefetchCandidates.put(root.getOutputParameters().getLabel(), written); + } + } + } + if (prefetchCandidates.isEmpty()) + return sbs; + + // Recursively check the consumers in the bellow blocks to find if prefetch is required + for (int i=1; i< sbs.size(); i++) + findConsumers(sbs.get(i), twrites); + + // Place a prefetch if all the consumers are CP + for (Map.Entry entry : prefetchCandidates.entrySet()) { + if (twrites.get(entry.getKey()).stream().allMatch(outCP -> (outCP == true))) { + Lop candidate = entry.getValue(); + // Add prefetch after prefetch candidate + List oldOuts = new ArrayList<>(candidate.getOutputs()); + // Construct a Prefetch lop that takes this Spark node as an input + UnaryCP prefetch = new UnaryCP(candidate, Types.OpOp1.PREFETCH, candidate.getDataType(), + candidate.getValueType(), Types.ExecType.CP); + prefetch.setAsynchronous(true); + // Reset asynchronous flag for the input if already set (e.g. mapmm -> prefetch) + candidate.setAsynchronous(false); + candidate.activatePrefetch(); + for (Lop outCP : oldOuts) { + // Rewire l -> outCP to l -> Prefetch -> outCP + prefetch.addOutput(outCP); + outCP.replaceInput(candidate, prefetch); + candidate.removeOutput(outCP); + } + } + } return sbs; } private boolean isPrefetchNeeded(Lop lop) { // Run Prefetch for a Spark instruction if the instruction is a Transformation // and the output is consumed by only CP instructions. + boolean transformOP = isTransformOP(lop); + //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) + boolean hasParameterizedOut = hasParameterizedOut(lop); + //TODO: support non-matrix outputs + return transformOP && !hasParameterizedOut + && (lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop)) + && lop.getDataType().isMatrix(); + } + + private boolean isTransformOP(Lop lop) { boolean transformOP = lop.getExecType() == Types.ExecType.SPARK && lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK // Always Action operations && !(lop.getDataType() == Types.DataType.SCALAR) @@ -104,15 +180,51 @@ private boolean isPrefetchNeeded(Lop lop) { // Cannot filter Transformation cases from Actions (FIXME) && !(lop instanceof MMTSJ) && !(lop instanceof UAggOuterChain) && !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused); + return transformOP; + } - //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) - boolean hasParameterizedOut = lop.getOutputs().stream() + private boolean hasParameterizedOut(Lop lop) { + return lop.getOutputs().stream() .anyMatch(out -> ((out instanceof ParameterizedBuiltin) || (out instanceof GroupedAggregate) || (out instanceof GroupedAggregateM))); - //TODO: support non-matrix outputs - return transformOP && !hasParameterizedOut - && (lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop)) - && lop.getDataType() == Types.DataType.MATRIX; + } + + private void findConsumers(StatementBlock sb, HashMap> twrites) { + if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); + for (StatementBlock input : wstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement) fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + + // Find the execution types of the consumers + ArrayList lops = OperatorOrderingUtils.getLopList(sb); + if (lops == null) + return; + for (Lop l : lops) { + // Find consumers in this basic block + if (l instanceof Data && ((Data) l).getOperationType().isTransientRead() + && twrites.containsKey(l.getOutputParameters().getLabel())) { + // Check if the consumers satisfy prefetch conditions + for (Lop consumer : l.getOutputs()) + if (consumer.getExecType() == Types.ExecType.CP + || consumer.getBroadcastInput()==l) + twrites.get(l.getOutputParameters().getLabel()).add(true); + } + } + } } diff --git a/src/test/scripts/functions/async/BroadcastVar3.dml b/src/test/scripts/functions/async/BroadcastVar3.dml new file mode 100644 index 00000000000..2a3fb0c853c --- /dev/null +++ b/src/test/scripts/functions/async/BroadcastVar3.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +# CP operations +v = ((v + v) * 1 - v) / (1+1); +v1 = t(v); + +# Break the statement block +while(FALSE){} + +# Spark transformation operations +sp = X + ceil(X); +sp = ((sp + sp) * 1 - sp) / (1+1); + +# mapmm - broadcast v +sp2 = sp %*% v; + +while(FALSE){} +R = sum(sp2); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/async/PrefetchRDD5.dml b/src/test/scripts/functions/async/PrefetchRDD5.dml new file mode 100644 index 00000000000..cd07de5946b --- /dev/null +++ b/src/test/scripts/functions/async/PrefetchRDD5.dml @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +# Spark transformation operations +sp1 = X + ceil(X); +sp2 = sp1 %*% v; #output fits in local + +# Break the statement block +while(FALSE){} + +# CP instructions +v = ((v + v) * 1 - v) / (1+1); +v = ((v + v) * 2 - v) / (2+1); + +# CP binary triggers the DAG of SP operations +cp = sp2 + sum(v); +R = sum(cp); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/async/PrefetchRDD6.dml b/src/test/scripts/functions/async/PrefetchRDD6.dml new file mode 100644 index 00000000000..3d6895da487 --- /dev/null +++ b/src/test/scripts/functions/async/PrefetchRDD6.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +# Spark transformation operations +sp1 = X + ceil(X); +sp2 = sp1 %*% v; #output fits in local + +# CP instructions +v = ((v + v) * 1 - v) / (1+1); +v = ((v + v) * 2 - v) / (2+1); + +# Consumer is in a different statement block +if (sum(v) > 1) + # CP binary triggers the DAG of SP operations + cp = sp2 + sum(v); +else + cp = rowSums(v); + +R = sum(cp); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/async/PrefetchRDD7.dml b/src/test/scripts/functions/async/PrefetchRDD7.dml new file mode 100644 index 00000000000..f821ea2d3e6 --- /dev/null +++ b/src/test/scripts/functions/async/PrefetchRDD7.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +func = function(Matrix[Double] X, Matrix[Double] v) return (Matrix[Double] cp) { + # Spark transformation operations + sp1 = X + ceil(X); + sp2 = sp1 %*% v; #output fits in local + + # CP instructions + v = ((v + v) * 1 - v) / (1+1); + v = ((v + v) * 2 - v) / (2+1); + + # Consumer is in a different statement block + for (i in 1:5) { + # CP binary triggers the DAG of SP operations + cp = sp2 + sum(v); + } +} + +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +cp = func(X, v); +R = sum(cp); +write(R, $1, format="text");