Skip to content

Commit

Permalink
Support function calls in FedAll and Heuristic planners
Browse files Browse the repository at this point in the history
  • Loading branch information
kev-inn committed Jul 17, 2022
1 parent f586eaa commit 38102ab
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.apache.sysds.hops.fedplanner;

import java.util.HashMap;
import java.util.Map;

import org.apache.sysds.common.Types;
Expand All @@ -28,6 +29,7 @@
import org.apache.sysds.hops.AggUnaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.apache.sysds.utils.Explain;
import org.apache.sysds.utils.Explain.ExplainType;
import org.jetbrains.annotations.NotNull;

public class FederatedPlannerCostbased extends AFederatedPlanner {
private static final Log LOG = LogFactory.getLog(FederatedPlannerCostbased.class.getName());
Expand Down Expand Up @@ -194,32 +195,37 @@ private ArrayList<StatementBlock> rewriteDefaultStatementBlock(DMLProgram prog,
if(sbHop instanceof FunctionOp) {
String funcName = ((FunctionOp) sbHop).getFunctionName();
String funcNamespace = ((FunctionOp) sbHop).getFunctionNamespace();
Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop);
if ( paramMap != null && funcParamMap != null)
funcParamMap.putAll(paramMap);
paramMap = funcParamMap;
FunctionStatementBlock sbFuncBlock = prog.getFunctionDictionary(funcNamespace)
.getFunction(funcName);
rewriteStatementBlock(prog, sbFuncBlock, paramMap);

FunctionStatementBlock sbFuncBlock = prog.getFunctionDictionary(funcNamespace).getFunction(funcName);
FunctionStatement funcStatement = (FunctionStatement) sbFuncBlock.getStatement(0);

paramMap = createFunctionFedVarTable(paramMap, (FunctionOp) sbHop);
rewriteStatementBlock(prog, sbFuncBlock, paramMap);
mapFunctionOutputs((FunctionOp) sbHop, funcStatement);
}
}
}
return new ArrayList<>(Collections.singletonList(sb));
}

@NotNull
private Map<String, Hop> createFunctionFedVarTable(Map<String, Hop> paramMap, FunctionOp sbHop) {
Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap(sbHop);
if ( paramMap != null && funcParamMap != null)
funcParamMap.putAll(paramMap);
paramMap = funcParamMap;
return paramMap;
}

/**
* Saves the HOPs (TWrite) of the function return values for
* the variable name used when calling the function.
*
* Example:
* <code>
* f = function() return (matrix[double] model) {a = rand(1, 1);}
* f = function() return (matrix[double] a) {a = rand(1, 1);}
* b = f();
* </code>
* This function saves the HOP writing to <code>a</code> for identifier <code>b</code>.
* This function saves the HOP writing to <code>a</code> (transient write) for identifier <code>b</code>.
*
* @param sbHop The <code>FunctionOp</code> for the call
* @param funcStatement The <code>FunctionStatement</code> of the called function
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.FunctionOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
Expand All @@ -45,6 +46,7 @@
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
import org.jetbrains.annotations.NotNull;

/**
* Baseline federated planner that compiles all hops
Expand Down Expand Up @@ -89,14 +91,14 @@ private void rRewriteStatementBlock(StatementBlock sb, Map<String, FType> fedVar
else if (sb instanceof WhileStatementBlock) {
WhileStatementBlock wsb = (WhileStatementBlock) sb;
WhileStatement wstmt = (WhileStatement)wsb.getStatement(0);
rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), Collections.emptyMap());
rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
for (StatementBlock csb : wstmt.getBody())
rRewriteStatementBlock(csb, fedVars);
}
else if (sb instanceof IfStatementBlock) {
IfStatementBlock isb = (IfStatementBlock) sb;
IfStatement istmt = (IfStatement)isb.getStatement(0);
rRewriteHop(isb.getPredicateHops(), new HashMap<>(), Collections.emptyMap());
rRewriteHop(isb.getPredicateHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
for (StatementBlock csb : istmt.getIfBody())
rRewriteStatementBlock(csb, fedVars);
for (StatementBlock csb : istmt.getElseBody())
Expand All @@ -105,9 +107,9 @@ else if (sb instanceof IfStatementBlock) {
else if (sb instanceof ForStatementBlock) { //incl parfor
ForStatementBlock fsb = (ForStatementBlock) sb;
ForStatement fstmt = (ForStatement)fsb.getStatement(0);
rRewriteHop(fsb.getFromHops(), new HashMap<>(), Collections.emptyMap());
rRewriteHop(fsb.getToHops(), new HashMap<>(), Collections.emptyMap());
rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), Collections.emptyMap());
rRewriteHop(fsb.getFromHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
rRewriteHop(fsb.getToHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg());
for (StatementBlock csb : fstmt.getBody())
rRewriteStatementBlock(csb, fedVars);
}
Expand All @@ -117,7 +119,7 @@ else if (sb instanceof ForStatementBlock) { //incl parfor
Map<Long, FType> fedHops = new HashMap<>();
if( sb.getHops() != null )
for( Hop c : sb.getHops() )
rRewriteHop(c, fedHops, fedVars);
rRewriteHop(c, fedHops, fedVars, sb.getDMLProg());

//TODO handle function calls

Expand All @@ -129,19 +131,31 @@ else if (sb instanceof ForStatementBlock) { //incl parfor
}
}

private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String, FType> fedVars) {
private void rRewriteHop(Hop hop, Map<Long, FType> memo, Map<String, FType> fedVars, DMLProgram program) {
if( memo.containsKey(hop.getHopID()) )
return; //already processed

//process children first
for( Hop c : hop.getInput() )
rRewriteHop(c, memo, fedVars);
rRewriteHop(c, memo, fedVars, program);

//handle specific operators (except transient writes)
if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
if(hop instanceof FunctionOp) {
String funcName = ((FunctionOp) hop).getFunctionName();
String funcNamespace = ((FunctionOp) hop).getFunctionNamespace();
FunctionStatementBlock sbFuncBlock = program.getFunctionDictionary(funcNamespace).getFunction(funcName);
FunctionStatement funcStatement = (FunctionStatement) sbFuncBlock.getStatement(0);

Map<String, FType> funcFedVars = createFunctionFedVarTable((FunctionOp) hop, memo);
rRewriteStatementBlock(sbFuncBlock, funcFedVars);
mapFunctionOutputs((FunctionOp) hop, funcStatement, funcFedVars, fedVars);
}
else if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) )
memo.put(hop.getHopID(), deriveFType((DataOp)hop));
else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) )
memo.put(hop.getHopID(), fedVars.get(hop.getName()));
else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) )
fedVars.put(hop.getName(), memo.get(hop.getHopID()));
else if( allowsFederated(hop, memo) ) {
hop.setForcedExecType(ExecType.FED);
memo.put(hop.getHopID(), getFederatedOut(hop, memo));
Expand All @@ -151,4 +165,23 @@ else if( allowsFederated(hop, memo) ) {
else // memoization as processed, but not federated
memo.put(hop.getHopID(), null);
}

@NotNull
static private Map<String, FType> createFunctionFedVarTable(FunctionOp hop, Map<Long, FType> memo) {
Map<String, Hop> funcParamMap = FederatedPlannerUtils.getParamMap(hop);
Map<String, FType> funcFedVars = new HashMap<>();
funcParamMap.forEach((key, value) -> {
funcFedVars.put(key, memo.get(value.getHopID()));
});
return funcFedVars;
}

// TODO: Reduce code duplication. The general structure of the federated planners is similar, but memo tables are different.
private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement,
Map<String, FType> funcFedVars, Map<String, FType> callFedVars) {
for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) {
FType outputFType = funcFedVars.get(funcStatement.getOutputParams().get(i).getName());
callFedVars.put(sbHop.getOutputVariableNames()[i], outputFType);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ public void runL2SVMCostBasedTest(){
}

@Test
@Ignore
public void runL2SVMFunctionFOUTTest(){
String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*",
"fed_max", "fed_1-*", "fed_tsmm", "fed_>"};
Expand All @@ -90,7 +89,6 @@ public void runL2SVMFunctionFOUTTest(){
}

@Test
@Ignore
public void runL2SVMFunctionHeuristicTest(){
String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"};
setTestConf("SystemDS-config-heuristic.xml");
Expand Down

0 comments on commit 38102ab

Please sign in to comment.