diff --git a/src/main/java/org/apache/sysds/common/Types.java b/src/main/java/org/apache/sysds/common/Types.java index ab81ff4e31a..7b8a9dc83e9 100644 --- a/src/main/java/org/apache/sysds/common/Types.java +++ b/src/main/java/org/apache/sysds/common/Types.java @@ -606,7 +606,13 @@ public boolean isWrite() { public boolean isRead() { return this == TRANSIENTREAD || this == PERSISTENTREAD; } - + public boolean isTransientRead() { + return this == TRANSIENTREAD; + } + public boolean isTransientWrite() { + return this == TRANSIENTWRITE; + } + @Override public String toString() { switch(this) { diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index 786e281d64a..6702df0549d 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -193,7 +193,7 @@ public ArrayList getJobs(StatementBlock sb, DMLConfig config) { */ private boolean inputNeedsPrefetch(Lop input, Lop lop){ return input.prefetchActivated() && lop.getExecType() != ExecType.FED - && input.getFederatedOutput().isForcedFederated(); + && input.getFederatedOutput() != null && input.getFederatedOutput().isForcedFederated(); } /** diff --git a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java index da22c511869..bfd47600227 100644 --- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java +++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java @@ -21,17 +21,29 @@ import org.apache.sysds.common.Types; import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; +import org.apache.sysds.lops.Data; import org.apache.sysds.lops.Lop; import org.apache.sysds.lops.OperatorOrderingUtils; import org.apache.sysds.lops.UnaryCP; +import org.apache.sysds.parser.ForStatement; +import 
org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class RewriteAddBroadcastLop extends LopRewriteRule { + boolean MULTI_BLOCK_REWRITE = false; + @Override public List rewriteLOPinStatementBlock(StatementBlock sb) { @@ -47,16 +59,16 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) nodesWithBroadcast.add(l); if (isBroadcastNeeded(l)) { List oldOuts = new ArrayList<>(l.getOutputs()); - // Construct a Broadcast lop that takes this Spark node as an input + // Construct a Broadcast lop that takes this CP node as an input UnaryCP bc = new UnaryCP(l, Types.OpOp1.BROADCAST, l.getDataType(), l.getValueType(), Types.ExecType.CP); bc.setAsynchronous(true); - //FIXME: Wire Broadcast only with the necessary outputs - for (Lop outCP : oldOuts) { - // Rewire l -> outCP to l -> Broadcast -> outCP - bc.addOutput(outCP); - outCP.replaceInput(l, bc); - l.removeOutput(outCP); - //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) + // FIXME: Wire Broadcast only with the necessary outputs + for (Lop outSP : oldOuts) { + // Rewire l -> outSP to l -> Broadcast -> outSP + bc.addOutput(outSP); + outSP.replaceInput(l, bc); + l.removeOutput(outSP); + // FIXME: Rewire _inputParams when needed (e.g. 
GroupedAggregate) } //Place it immediately after the Spark lop in the node list nodesWithBroadcast.add(bc); @@ -67,17 +79,111 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) } @Override - public List rewriteLOPinStatementBlocks(List sbs) { + public List rewriteLOPinStatementBlocks(List sbs) + { + if (!MULTI_BLOCK_REWRITE) + return sbs; + // FIXME: Enable after handling of rmvar of asynchronous inputs + + if (!ConfigurationManager.isBroadcastEnabled()) + return sbs; + if (sbs == null || sbs.isEmpty()) + return sbs; + // The first statement block has to be a basic block + // TODO: Remove this constraint + StatementBlock sb1 = sbs.get(0); + if (!HopRewriteUtils.isLastLevelStatementBlock(sb1)) + return sbs; + if (sb1.getLops() == null || sb1.getLops().isEmpty()) + return sbs; + + // Gather the twrite names of the potential broadcast candidates from the first block + // TODO: Replace repetitive rewrite calls with a single one to place all broadcasts + HashMap> twrites = new HashMap<>(); + HashMap broadcastCandidates = new HashMap<>(); + for (Lop root : sb1.getLops()) { + if (root instanceof Data && ((Data)root).getOperationType().isTransientWrite()) { + Lop written = root.getInputs().get(0); + if (written.getExecType() == Types.ExecType.CP && written.getDataType().isMatrix()) { + // Potential broadcast candidate. 
Save in the twrite map + twrites.put(root.getOutputParameters().getLabel(), new ArrayList<>()); + broadcastCandidates.put(root.getOutputParameters().getLabel(), written); + } + } + } + if (broadcastCandidates.isEmpty()) + return sbs; + + // Recursively check the consumers in the blocks below to find if broadcast is required + for (int i=1; i< sbs.size(); i++) + findConsumers(sbs.get(i), twrites); + + // Place a broadcast if any of the consumers are Spark + for (Map.Entry entry : broadcastCandidates.entrySet()) { + if (twrites.get(entry.getKey()).stream().anyMatch(outBC -> (outBC == true))) { + Lop candidate = entry.getValue(); + List oldOuts = new ArrayList<>(candidate.getOutputs()); + // Construct a broadcast lop that takes this CP node as an input + UnaryCP bc = new UnaryCP(candidate, Types.OpOp1.BROADCAST, candidate.getDataType(), + candidate.getValueType(), Types.ExecType.CP); + bc.setAsynchronous(true); + // FIXME: Wire Broadcast only with the necessary outputs + for (Lop outSP : oldOuts) { + // Rewire candidate -> outSP to candidate -> Broadcast -> outSP + bc.addOutput(outSP); + outSP.replaceInput(candidate, bc); + candidate.removeOutput(outSP); + // FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) + } + } + } return sbs; } private static boolean isBroadcastNeeded(Lop lop) { // Asynchronously broadcast a matrix if that is produced by a CP instruction, // and at least one Spark parent needs to broadcast this intermediate (eg. 
mapmm) - boolean isBc = lop.getOutputs().stream() + boolean isBcOutput = lop.getOutputs().stream() .anyMatch(out -> (out.getBroadcastInput() == lop)); - //TODO: Early broadcast objects that are bigger than a single block - boolean isCP = lop.getExecType() == Types.ExecType.CP; - return isCP && isBc && lop.getDataType() == Types.DataType.MATRIX; + // TODO: Early broadcast objects that are bigger than a single block + boolean isCPInput = lop.getExecType() == Types.ExecType.CP; + return isCPInput && isBcOutput && lop.getDataType() == Types.DataType.MATRIX; + } + + private void findConsumers(StatementBlock sb, HashMap> twrites) { + if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); + for (StatementBlock input : wstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof ForStatementBlock) { //incl parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement) fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + + // Find the execution types of the consumers + ArrayList lops = OperatorOrderingUtils.getLopList(sb); + if (lops == null) + return; + for (Lop l : lops) { + // Find consumers in this basic block + if (l instanceof Data && ((Data) l).getOperationType().isTransientRead() + && twrites.containsKey(l.getOutputParameters().getLabel())) { + // Check if the consumers satisfy broadcast conditions + for (Lop consumer : l.getOutputs()) + if (consumer.getBroadcastInput() == l) + twrites.get(l.getOutputParameters().getLabel()).add(true); + } + } + } } diff --git 
a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java index 6eb52e0d9fc..bae768e1245 100644 --- a/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java +++ b/src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java @@ -22,10 +22,12 @@ import org.apache.sysds.common.Types; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.AggBinaryOp; +import org.apache.sysds.hops.rewrite.HopRewriteUtils; import org.apache.sysds.lops.CSVReBlock; import org.apache.sysds.lops.CentralMoment; import org.apache.sysds.lops.Checkpoint; import org.apache.sysds.lops.CoVariance; +import org.apache.sysds.lops.Data; import org.apache.sysds.lops.DataGen; import org.apache.sysds.lops.GroupedAggregate; import org.apache.sysds.lops.GroupedAggregateM; @@ -40,11 +42,19 @@ import org.apache.sysds.lops.SpoofFused; import org.apache.sysds.lops.UAggOuterChain; import org.apache.sysds.lops.UnaryCP; +import org.apache.sysds.parser.ForStatement; +import org.apache.sysds.parser.ForStatementBlock; +import org.apache.sysds.parser.FunctionStatement; +import org.apache.sysds.parser.FunctionStatementBlock; import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.parser.WhileStatement; +import org.apache.sysds.parser.WhileStatementBlock; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; public class RewriteAddPrefetchLop extends LopRewriteRule { @@ -62,13 +72,14 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) //Find the Spark nodes with all CP outputs for (Lop l : lops) { nodesWithPrefetch.add(l); - if (isPrefetchNeeded(l)) { + if (isPrefetchNeeded(l) && !l.prefetchActivated()) { List oldOuts = new ArrayList<>(l.getOutputs()); - //Construct a Prefetch lop that takes this Spark node as a input + //Construct a Prefetch lop that takes this Spark node as an input 
UnaryCP prefetch = new UnaryCP(l, Types.OpOp1.PREFETCH, l.getDataType(), l.getValueType(), Types.ExecType.CP); prefetch.setAsynchronous(true); //Reset asynchronous flag for the input if already set (e.g. mapmm -> prefetch) l.setAsynchronous(false); + l.activatePrefetch(); for (Lop outCP : oldOuts) { //Rewire l -> outCP to l -> Prefetch -> outCP prefetch.addOutput(outCP); @@ -85,13 +96,78 @@ public List rewriteLOPinStatementBlock(StatementBlock sb) } @Override - public List rewriteLOPinStatementBlocks(List sbs) { + public List rewriteLOPinStatementBlocks(List sbs) + { + if (!ConfigurationManager.isPrefetchEnabled()) + return sbs; + if (sbs == null || sbs.isEmpty()) + return sbs; + // The first statement block has to be a basic block + // TODO: Remove this constraint + StatementBlock sb1 = sbs.get(0); + if (!HopRewriteUtils.isLastLevelStatementBlock(sb1)) + return sbs; + if (sb1.getLops() == null || sb1.getLops().isEmpty()) + return sbs; + + // Gather the twrite names of the potential prefetch candidates from the first block + // TODO: Replace repetitive rewrite calls with a single one to place all prefetches + HashMap> twrites = new HashMap<>(); + HashMap prefetchCandidates = new HashMap<>(); + for (Lop root : sb1.getLops()) { + if (root instanceof Data && ((Data)root).getOperationType().isTransientWrite()) { + Lop written = root.getInputs().get(0); + if (isTransformOP(written) && !hasParameterizedOut(written) && written.getDataType().isMatrix()) { + // Potential prefetch candidate. 
Save in the twrite map + twrites.put(root.getOutputParameters().getLabel(), new ArrayList<>()); + prefetchCandidates.put(root.getOutputParameters().getLabel(), written); + } + } + } + if (prefetchCandidates.isEmpty()) + return sbs; + + // Recursively check the consumers in the blocks below to find if prefetch is required + for (int i=1; i< sbs.size(); i++) + findConsumers(sbs.get(i), twrites); + + // Place a prefetch if all the consumers are CP + for (Map.Entry entry : prefetchCandidates.entrySet()) { + if (twrites.get(entry.getKey()).stream().allMatch(outCP -> (outCP == true))) { + Lop candidate = entry.getValue(); + // Add prefetch after prefetch candidate + List oldOuts = new ArrayList<>(candidate.getOutputs()); + // Construct a Prefetch lop that takes this Spark node as an input + UnaryCP prefetch = new UnaryCP(candidate, Types.OpOp1.PREFETCH, candidate.getDataType(), + candidate.getValueType(), Types.ExecType.CP); + prefetch.setAsynchronous(true); + // Reset asynchronous flag for the input if already set (e.g. mapmm -> prefetch) + candidate.setAsynchronous(false); + candidate.activatePrefetch(); + for (Lop outCP : oldOuts) { + // Rewire candidate -> outCP to candidate -> Prefetch -> outCP + prefetch.addOutput(outCP); + outCP.replaceInput(candidate, prefetch); + candidate.removeOutput(outCP); + } + } + } return sbs; } private boolean isPrefetchNeeded(Lop lop) { // Run Prefetch for a Spark instruction if the instruction is a Transformation // and the output is consumed by only CP instructions. + boolean transformOP = isTransformOP(lop); + //FIXME: Rewire _inputParams when needed (e.g. 
GroupedAggregate) + boolean hasParameterizedOut = hasParameterizedOut(lop); + //TODO: support non-matrix outputs + return transformOP && !hasParameterizedOut + && (lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop)) + && lop.getDataType().isMatrix(); + } + + private boolean isTransformOP(Lop lop) { boolean transformOP = lop.getExecType() == Types.ExecType.SPARK && lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK // Always Action operations && !(lop.getDataType() == Types.DataType.SCALAR) @@ -104,15 +180,51 @@ private boolean isPrefetchNeeded(Lop lop) { // Cannot filter Transformation cases from Actions (FIXME) && !(lop instanceof MMTSJ) && !(lop instanceof UAggOuterChain) && !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused); + return transformOP; + } - //FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate) - boolean hasParameterizedOut = lop.getOutputs().stream() + private boolean hasParameterizedOut(Lop lop) { + return lop.getOutputs().stream() .anyMatch(out -> ((out instanceof ParameterizedBuiltin) || (out instanceof GroupedAggregate) || (out instanceof GroupedAggregateM))); - //TODO: support non-matrix outputs - return transformOP && !hasParameterizedOut - && (lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop)) - && lop.getDataType() == Types.DataType.MATRIX; + } + + private void findConsumers(StatementBlock sb, HashMap> twrites) { + if (sb instanceof FunctionStatementBlock) { + FunctionStatementBlock fsb = (FunctionStatementBlock)sb; + FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof WhileStatementBlock) { + WhileStatementBlock wsb = (WhileStatementBlock) sb; + WhileStatement wstmt = (WhileStatement) wsb.getStatement(0); + for (StatementBlock input : wstmt.getBody()) + findConsumers(input, twrites); + } + else if (sb instanceof ForStatementBlock) { //incl 
parfor + ForStatementBlock fsb = (ForStatementBlock) sb; + ForStatement fstmt = (ForStatement) fsb.getStatement(0); + for (StatementBlock input : fstmt.getBody()) + findConsumers(input, twrites); + } + + // Find the execution types of the consumers + ArrayList lops = OperatorOrderingUtils.getLopList(sb); + if (lops == null) + return; + for (Lop l : lops) { + // Find consumers in this basic block + if (l instanceof Data && ((Data) l).getOperationType().isTransientRead() + && twrites.containsKey(l.getOutputParameters().getLabel())) { + // Check if the consumers satisfy prefetch conditions + for (Lop consumer : l.getOutputs()) + if (consumer.getExecType() == Types.ExecType.CP + || consumer.getBroadcastInput()==l) + twrites.get(l.getOutputParameters().getLabel()).add(true); + } + } + } } diff --git a/src/test/scripts/functions/async/BroadcastVar3.dml b/src/test/scripts/functions/async/BroadcastVar3.dml new file mode 100644 index 00000000000..2a3fb0c853c --- /dev/null +++ b/src/test/scripts/functions/async/BroadcastVar3.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +# CP operations +v = ((v + v) * 1 - v) / (1+1); +v1 = t(v); + +# Break the statement block +while(FALSE){} + +# Spark transformation operations +sp = X + ceil(X); +sp = ((sp + sp) * 1 - sp) / (1+1); + +# mapmm - broadcast v +sp2 = sp %*% v; + +while(FALSE){} +R = sum(sp2); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/async/PrefetchRDD5.dml b/src/test/scripts/functions/async/PrefetchRDD5.dml new file mode 100644 index 00000000000..cd07de5946b --- /dev/null +++ b/src/test/scripts/functions/async/PrefetchRDD5.dml @@ -0,0 +1,38 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +# Spark transformation operations +sp1 = X + ceil(X); +sp2 = sp1 %*% v; #output fits in local + +# Break the statement block +while(FALSE){} + +# CP instructions +v = ((v + v) * 1 - v) / (1+1); +v = ((v + v) * 2 - v) / (2+1); + +# CP binary triggers the DAG of SP operations +cp = sp2 + sum(v); +R = sum(cp); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/async/PrefetchRDD6.dml b/src/test/scripts/functions/async/PrefetchRDD6.dml new file mode 100644 index 00000000000..3d6895da487 --- /dev/null +++ b/src/test/scripts/functions/async/PrefetchRDD6.dml @@ -0,0 +1,40 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +# Spark transformation operations +sp1 = X + ceil(X); +sp2 = sp1 %*% v; #output fits in local + +# CP instructions +v = ((v + v) * 1 - v) / (1+1); +v = ((v + v) * 2 - v) / (2+1); + +# Consumer is in a different statement block +if (sum(v) > 1) + # CP binary triggers the DAG of SP operations + cp = sp2 + sum(v); +else + cp = rowSums(v); + +R = sum(cp); +write(R, $1, format="text"); diff --git a/src/test/scripts/functions/async/PrefetchRDD7.dml b/src/test/scripts/functions/async/PrefetchRDD7.dml new file mode 100644 index 00000000000..f821ea2d3e6 --- /dev/null +++ b/src/test/scripts/functions/async/PrefetchRDD7.dml @@ -0,0 +1,43 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# +#------------------------------------------------------------- + +func = function(Matrix[Double] X, Matrix[Double] v) return (Matrix[Double] cp) { + # Spark transformation operations + sp1 = X + ceil(X); + sp2 = sp1 %*% v; #output fits in local + + # CP instructions + v = ((v + v) * 1 - v) / (1+1); + v = ((v + v) * 2 - v) / (2+1); + + # Consumer is in a different statement block + for (i in 1:5) { + # CP binary triggers the DAG of SP operations + cp = sp2 + sum(v); + } +} + +X = rand(rows=10000, cols=200, seed=42); #sp_rand +v = rand(rows=200, cols=1, seed=42); #cp_rand + +cp = func(X, v); +R = sum(cp); +write(R, $1, format="text");