Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3499] Multi-statementblock LOP rewrites #1784

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/main/java/org/apache/sysds/common/Types.java
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,13 @@ public boolean isWrite() {
public boolean isRead() {
return this == TRANSIENTREAD || this == PERSISTENTREAD;
}

// Returns true iff this data operation is a transient read (TRANSIENTREAD),
// i.e., a read of an in-memory intermediate rather than persistent storage.
public boolean isTransientRead() {
return this == TRANSIENTREAD;
}
// Returns true iff this data operation is a transient write (TRANSIENTWRITE),
// i.e., a write of an in-memory intermediate rather than persistent storage.
public boolean isTransientWrite() {
return this == TRANSIENTWRITE;
}

@Override
public String toString() {
switch(this) {
Expand Down
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysds/lops/compile/Dag.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public ArrayList<Instruction> getJobs(StatementBlock sb, DMLConfig config) {
*/
private boolean inputNeedsPrefetch(Lop input, Lop lop){
return input.prefetchActivated() && lop.getExecType() != ExecType.FED
&& input.getFederatedOutput().isForcedFederated();
&& input.getFederatedOutput() != null && input.getFederatedOutput().isForcedFederated();
}

/**
Expand Down
132 changes: 119 additions & 13 deletions src/main/java/org/apache/sysds/lops/rewrite/RewriteAddBroadcastLop.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,29 @@

import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RewriteAddBroadcastLop extends LopRewriteRule
{
boolean MULTI_BLOCK_REWRITE = false;

@Override
public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
{
Expand All @@ -47,16 +59,16 @@ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
nodesWithBroadcast.add(l);
if (isBroadcastNeeded(l)) {
List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
// Construct a Broadcast lop that takes this Spark node as an input
// Construct a Broadcast lop that takes this CP node as an input
UnaryCP bc = new UnaryCP(l, Types.OpOp1.BROADCAST, l.getDataType(), l.getValueType(), Types.ExecType.CP);
bc.setAsynchronous(true);
//FIXME: Wire Broadcast only with the necessary outputs
for (Lop outCP : oldOuts) {
// Rewire l -> outCP to l -> Broadcast -> outCP
bc.addOutput(outCP);
outCP.replaceInput(l, bc);
l.removeOutput(outCP);
//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
// FIXME: Wire Broadcast only with the necessary outputs
for (Lop outSP : oldOuts) {
// Rewire l -> outSP to l -> Broadcast -> outSP
bc.addOutput(outSP);
outSP.replaceInput(l, bc);
l.removeOutput(outSP);
// FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
}
//Place it immediately after the Spark lop in the node list
nodesWithBroadcast.add(bc);
Expand All @@ -67,17 +79,111 @@ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
}

@Override
public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs)
{
if (!MULTI_BLOCK_REWRITE)
return sbs;
// FIXME: Enable after handling of rmvar of asynchronous inputs

if (!ConfigurationManager.isBroadcastEnabled())
return sbs;
if (sbs == null || sbs.isEmpty())
return sbs;
// The first statement block has to be a basic block
// TODO: Remove this constraints
StatementBlock sb1 = sbs.get(0);
if (!HopRewriteUtils.isLastLevelStatementBlock(sb1))
return sbs;
if (sb1.getLops() == null || sb1.getLops().isEmpty())
return sbs;

// Gather the twrite names of the potential broadcast candidates from the first block
// TODO: Replace repetitive rewrite calls with a single one to place all prefetches
HashMap<String, List<Boolean>> twrites = new HashMap<>();
HashMap<String, Lop> broadcastCandidates = new HashMap<>();
for (Lop root : sb1.getLops()) {
if (root instanceof Data && ((Data)root).getOperationType().isTransientWrite()) {
Lop written = root.getInputs().get(0);
if (written.getExecType() == Types.ExecType.CP && written.getDataType().isMatrix()) {
// Potential broadcast candidate. Save in the twrite map
twrites.put(root.getOutputParameters().getLabel(), new ArrayList<>());
broadcastCandidates.put(root.getOutputParameters().getLabel(), written);
}
}
}
if (broadcastCandidates.isEmpty())
return sbs;

// Recursively check the consumers in the bellow blocks to find if broadcast is required
for (int i=1; i< sbs.size(); i++)
findConsumers(sbs.get(i), twrites);

// Place a broadcast if any of the consumers are Spark
for (Map.Entry<String, Lop> entry : broadcastCandidates.entrySet()) {
if (twrites.get(entry.getKey()).stream().anyMatch(outBC -> (outBC == true))) {
Lop candidate = entry.getValue();
List<Lop> oldOuts = new ArrayList<>(candidate.getOutputs());
// Construct a broadcast lop that takes this CP node as an input
UnaryCP bc = new UnaryCP(candidate, Types.OpOp1.BROADCAST, candidate.getDataType(),
candidate.getValueType(), Types.ExecType.CP);
bc.setAsynchronous(true);
// FIXME: Wire Broadcast only with the necessary outputs
for (Lop outSP : oldOuts) {
// Rewire l -> outSP to l -> Broadcast -> outSP
bc.addOutput(outSP);
outSP.replaceInput(candidate, bc);
candidate.removeOutput(outSP);
// FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
}
}
}
return sbs;
}

private static boolean isBroadcastNeeded(Lop lop) {
// Asynchronously broadcast a matrix if that is produced by a CP instruction,
// and at least one Spark parent needs to broadcast this intermediate (eg. mapmm)
boolean isBc = lop.getOutputs().stream()
boolean isBcOutput = lop.getOutputs().stream()
.anyMatch(out -> (out.getBroadcastInput() == lop));
//TODO: Early broadcast objects that are bigger than a single block
boolean isCP = lop.getExecType() == Types.ExecType.CP;
return isCP && isBc && lop.getDataType() == Types.DataType.MATRIX;
// TODO: Early broadcast objects that are bigger than a single block
boolean isCPInput = lop.getExecType() == Types.ExecType.CP;
return isCPInput && isBcOutput && lop.getDataType() == Types.DataType.MATRIX;
}

/**
 * Recursively walks the given statement block (descending into function,
 * while and for/parfor bodies) and, for every transient read of a tracked
 * twrite label, records a true entry per consumer that takes the read as a
 * broadcast input.
 *
 * @param sb statement block to inspect
 * @param twrites map from twrite label to recorded consumer verdicts (updated in place)
 */
private void findConsumers(StatementBlock sb, HashMap<String, List<Boolean>> twrites) {
	// Descend into the bodies of control-flow blocks
	if (sb instanceof FunctionStatementBlock) {
		FunctionStatement fstmt = (FunctionStatement)((FunctionStatementBlock)sb).getStatement(0);
		for (StatementBlock child : fstmt.getBody())
			findConsumers(child, twrites);
	}
	else if (sb instanceof WhileStatementBlock) {
		WhileStatement wstmt = (WhileStatement)((WhileStatementBlock)sb).getStatement(0);
		for (StatementBlock child : wstmt.getBody())
			findConsumers(child, twrites);
	}
	else if (sb instanceof ForStatementBlock) { //incl parfor
		ForStatement fstmt = (ForStatement)((ForStatementBlock)sb).getStatement(0);
		for (StatementBlock child : fstmt.getBody())
			findConsumers(child, twrites);
	}

	// Inspect the lops of this (basic) block for reads of the candidates
	ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
	if (lops == null)
		return;
	for (Lop lop : lops) {
		if (!(lop instanceof Data) || !((Data)lop).getOperationType().isTransientRead())
			continue;
		String label = lop.getOutputParameters().getLabel();
		if (!twrites.containsKey(label))
			continue;
		// Check if the consumers satisfy broadcast conditions
		for (Lop consumer : lop.getOutputs())
			if (consumer.getBroadcastInput() == lop)
				twrites.get(label).add(true);
	}
}
}
130 changes: 121 additions & 9 deletions src/main/java/org/apache/sysds/lops/rewrite/RewriteAddPrefetchLop.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
import org.apache.sysds.lops.CoVariance;
import org.apache.sysds.lops.Data;
import org.apache.sysds.lops.DataGen;
import org.apache.sysds.lops.GroupedAggregate;
import org.apache.sysds.lops.GroupedAggregateM;
Expand All @@ -40,11 +42,19 @@
import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.FunctionStatement;
import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class RewriteAddPrefetchLop extends LopRewriteRule
{
Expand All @@ -62,13 +72,14 @@ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
//Find the Spark nodes with all CP outputs
for (Lop l : lops) {
nodesWithPrefetch.add(l);
if (isPrefetchNeeded(l)) {
if (isPrefetchNeeded(l) && !l.prefetchActivated()) {
List<Lop> oldOuts = new ArrayList<>(l.getOutputs());
//Construct a Prefetch lop that takes this Spark node as a input
//Construct a Prefetch lop that takes this Spark node as an input
UnaryCP prefetch = new UnaryCP(l, Types.OpOp1.PREFETCH, l.getDataType(), l.getValueType(), Types.ExecType.CP);
prefetch.setAsynchronous(true);
//Reset asynchronous flag for the input if already set (e.g. mapmm -> prefetch)
l.setAsynchronous(false);
l.activatePrefetch();
for (Lop outCP : oldOuts) {
//Rewire l -> outCP to l -> Prefetch -> outCP
prefetch.addOutput(outCP);
Expand All @@ -85,13 +96,78 @@ public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
}

@Override
public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs)
{
if (!ConfigurationManager.isPrefetchEnabled())
return sbs;
if (sbs == null || sbs.isEmpty())
return sbs;
// The first statement block has to be a basic block
// TODO: Remove this constraints
StatementBlock sb1 = sbs.get(0);
if (!HopRewriteUtils.isLastLevelStatementBlock(sb1))
return sbs;
if (sb1.getLops() == null || sb1.getLops().isEmpty())
return sbs;

// Gather the twrite names of the potential prefetch candidates from the first block
// TODO: Replace repetitive rewrite calls with a single one to place all prefetches
HashMap<String, List<Boolean>> twrites = new HashMap<>();
HashMap<String, Lop> prefetchCandidates = new HashMap<>();
for (Lop root : sb1.getLops()) {
if (root instanceof Data && ((Data)root).getOperationType().isTransientWrite()) {
Lop written = root.getInputs().get(0);
if (isTransformOP(written) && !hasParameterizedOut(written) && written.getDataType().isMatrix()) {
// Potential prefetch candidate. Save in the twrite map
twrites.put(root.getOutputParameters().getLabel(), new ArrayList<>());
prefetchCandidates.put(root.getOutputParameters().getLabel(), written);
}
}
}
if (prefetchCandidates.isEmpty())
return sbs;

// Recursively check the consumers in the bellow blocks to find if prefetch is required
for (int i=1; i< sbs.size(); i++)
findConsumers(sbs.get(i), twrites);

// Place a prefetch if all the consumers are CP
for (Map.Entry<String, Lop> entry : prefetchCandidates.entrySet()) {
if (twrites.get(entry.getKey()).stream().allMatch(outCP -> (outCP == true))) {
Lop candidate = entry.getValue();
// Add prefetch after prefetch candidate
List<Lop> oldOuts = new ArrayList<>(candidate.getOutputs());
// Construct a Prefetch lop that takes this Spark node as an input
UnaryCP prefetch = new UnaryCP(candidate, Types.OpOp1.PREFETCH, candidate.getDataType(),
candidate.getValueType(), Types.ExecType.CP);
prefetch.setAsynchronous(true);
// Reset asynchronous flag for the input if already set (e.g. mapmm -> prefetch)
candidate.setAsynchronous(false);
candidate.activatePrefetch();
for (Lop outCP : oldOuts) {
// Rewire l -> outCP to l -> Prefetch -> outCP
prefetch.addOutput(outCP);
outCP.replaceInput(candidate, prefetch);
candidate.removeOutput(outCP);
}
}
}
return sbs;
}

/**
 * Checks if a prefetch should be placed after this lop: a Spark
 * transformation producing a matrix whose outputs are either all CP
 * or collected for a broadcast.
 *
 * @param lop the candidate Spark lop
 * @return true if a PREFETCH lop should be inserted after lop
 */
private boolean isPrefetchNeeded(Lop lop) {
	// Run Prefetch for a Spark instruction if the instruction is a Transformation
	// and the output is consumed by only CP instructions.
	//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
	//TODO: support non-matrix outputs
	return isTransformOP(lop)
		&& !hasParameterizedOut(lop)
		&& (lop.isAllOutputsCP() || OperatorOrderingUtils.isCollectForBroadcast(lop))
		&& lop.getDataType().isMatrix();
}

private boolean isTransformOP(Lop lop) {
boolean transformOP = lop.getExecType() == Types.ExecType.SPARK && lop.getAggType() != AggBinaryOp.SparkAggType.SINGLE_BLOCK
// Always Action operations
&& !(lop.getDataType() == Types.DataType.SCALAR)
Expand All @@ -104,15 +180,51 @@ private boolean isPrefetchNeeded(Lop lop) {
// Cannot filter Transformation cases from Actions (FIXME)
&& !(lop instanceof MMTSJ) && !(lop instanceof UAggOuterChain)
&& !(lop instanceof ParameterizedBuiltin) && !(lop instanceof SpoofFused);
return transformOP;
}

//FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
boolean hasParameterizedOut = lop.getOutputs().stream()
/**
 * Checks if any consumer of this lop takes it via parameterized inputs
 * (_inputParams), which the prefetch/broadcast rewiring does not yet handle.
 * FIXME: Rewire _inputParams when needed (e.g. GroupedAggregate)
 *
 * @param lop the producer lop
 * @return true if any output is a parameterized-builtin style consumer
 */
private boolean hasParameterizedOut(Lop lop) {
	return lop.getOutputs().stream()
		.anyMatch(out -> ((out instanceof ParameterizedBuiltin)
			|| (out instanceof GroupedAggregate)
			|| (out instanceof GroupedAggregateM)));
}

private void findConsumers(StatementBlock sb, HashMap<String, List<Boolean>> twrites) {
if (sb instanceof FunctionStatementBlock) {
FunctionStatementBlock fsb = (FunctionStatementBlock)sb;
FunctionStatement fstmt = (FunctionStatement)fsb.getStatement(0);
for (StatementBlock input : fstmt.getBody())
findConsumers(input, twrites);
}
else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement) wsb.getStatement(0);
for (StatementBlock input : wstmt.getBody())
findConsumers(input, twrites);
}
else if (sb instanceof ForStatementBlock) { //incl parfor
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement) fsb.getStatement(0);
for (StatementBlock input : fstmt.getBody())
findConsumers(input, twrites);
}

// Find the execution types of the consumers
ArrayList<Lop> lops = OperatorOrderingUtils.getLopList(sb);
if (lops == null)
return;
for (Lop l : lops) {
// Find consumers in this basic block
if (l instanceof Data && ((Data) l).getOperationType().isTransientRead()
&& twrites.containsKey(l.getOutputParameters().getLabel())) {
// Check if the consumers satisfy prefetch conditions
for (Lop consumer : l.getOutputs())
if (consumer.getExecType() == Types.ExecType.CP
|| consumer.getBroadcastInput()==l)
twrites.get(l.getOutputParameters().getLabel()).add(true);
}
}

}
}
Loading