From 773752e7f4335f6dc42e39e9d80d16b0805bccc8 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Sun, 17 Jul 2022 21:14:08 +0200 Subject: [PATCH 1/2] Support function calls in `FedAll` and `Heuristic` planners --- .../hops/fedplanner/AFederatedPlanner.java | 2 + .../fedplanner/FederatedPlannerCostbased.java | 22 ++++---- .../fedplanner/FederatedPlannerFedAll.java | 51 +++++++++++++++---- .../FederatedL2SVMPlanningTest.java | 2 - 4 files changed, 57 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java index 1b4382bb051..5d1edfa1a1f 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java @@ -19,6 +19,7 @@ package org.apache.sysds.hops.fedplanner; +import java.util.HashMap; import java.util.Map; import org.apache.sysds.common.Types; @@ -28,6 +29,7 @@ import org.apache.sysds.hops.AggUnaryOp; import org.apache.sysds.hops.BinaryOp; import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.TernaryOp; import org.apache.sysds.hops.fedplanner.FTypes.FType; diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java index d809544f6bc..f8d2b5aa6f6 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java @@ -194,22 +194,26 @@ private ArrayList rewriteDefaultStatementBlock(DMLProgram prog, if(sbHop instanceof FunctionOp) { String funcName = ((FunctionOp) sbHop).getFunctionName(); String funcNamespace = ((FunctionOp) sbHop).getFunctionNamespace(); - Map funcParamMap = FederatedPlannerUtils.getParamMap((FunctionOp) sbHop); - if ( paramMap != null && funcParamMap != null) - funcParamMap.putAll(paramMap); - paramMap = funcParamMap; - FunctionStatementBlock sbFuncBlock = prog.getFunctionDictionary(funcNamespace) - .getFunction(funcName); - rewriteStatementBlock(prog, sbFuncBlock, paramMap); - + FunctionStatementBlock sbFuncBlock = prog.getFunctionDictionary(funcNamespace).getFunction(funcName); FunctionStatement funcStatement = (FunctionStatement) sbFuncBlock.getStatement(0); - FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement, transientWrites); + + paramMap = createFunctionFedVarTable(paramMap, (FunctionOp) sbHop); + rewriteStatementBlock(prog, sbFuncBlock, paramMap); + FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement); } } } return new ArrayList<>(Collections.singletonList(sb)); } + private Map createFunctionFedVarTable(Map paramMap, FunctionOp sbHop) { + Map funcParamMap = FederatedPlannerUtils.getParamMap(sbHop); + if ( paramMap != null && funcParamMap != null) + funcParamMap.putAll(paramMap); + paramMap = funcParamMap; + return paramMap; + } + /** * Set final fedouts of all hops starting from terminal hops. */ diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java index 4bf2e5606a9..aec6d54cddb 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java @@ -26,6 +26,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.hops.DataOp; +import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; import org.apache.sysds.hops.fedplanner.FTypes.FType; import org.apache.sysds.hops.ipa.FunctionCallGraph; @@ -45,6 +46,7 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; +import org.jetbrains.annotations.NotNull; /** * Baseline federated planner that compiles all hops @@ -89,14 +91,14 @@ private void rRewriteStatementBlock(StatementBlock sb, Map fedVar else if (sb instanceof WhileStatementBlock) { WhileStatementBlock wsb = (WhileStatementBlock) sb; WhileStatement wstmt = (WhileStatement)wsb.getStatement(0); - rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), Collections.emptyMap()); + rRewriteHop(wsb.getPredicateHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg()); for (StatementBlock csb : wstmt.getBody()) rRewriteStatementBlock(csb, fedVars); } else if (sb instanceof IfStatementBlock) { IfStatementBlock isb = (IfStatementBlock) sb; IfStatement istmt = (IfStatement)isb.getStatement(0); - rRewriteHop(isb.getPredicateHops(), new HashMap<>(), Collections.emptyMap()); + rRewriteHop(isb.getPredicateHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg()); for (StatementBlock csb : istmt.getIfBody()) rRewriteStatementBlock(csb, fedVars); for (StatementBlock csb : istmt.getElseBody()) @@ -105,9 +107,9 @@ else if (sb instanceof IfStatementBlock) { else if (sb instanceof ForStatementBlock) { //incl parfor ForStatementBlock fsb = (ForStatementBlock) sb; ForStatement fstmt = (ForStatement)fsb.getStatement(0); - rRewriteHop(fsb.getFromHops(), new HashMap<>(), Collections.emptyMap()); - rRewriteHop(fsb.getToHops(), new HashMap<>(), Collections.emptyMap()); - rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), Collections.emptyMap()); + rRewriteHop(fsb.getFromHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg()); + rRewriteHop(fsb.getToHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg()); + rRewriteHop(fsb.getIncrementHops(), new HashMap<>(), new HashMap<>(), sb.getDMLProg()); for (StatementBlock csb : fstmt.getBody()) rRewriteStatementBlock(csb, fedVars); } @@ -117,7 +119,7 @@ else if (sb instanceof ForStatementBlock) { //incl parfor Map fedHops = new HashMap<>(); if( sb.getHops() != null ) for( Hop c : sb.getHops() ) - rRewriteHop(c, fedHops, fedVars); + rRewriteHop(c, fedHops, fedVars, sb.getDMLProg()); //TODO handle function calls @@ -129,19 +131,31 @@ else if (sb instanceof ForStatementBlock) { //incl parfor } } - private void rRewriteHop(Hop hop, Map memo, Map fedVars) { + private void rRewriteHop(Hop hop, Map memo, Map fedVars, DMLProgram program) { if( memo.containsKey(hop.getHopID()) ) return; //already processed //process children first for( Hop c : hop.getInput() ) - rRewriteHop(c, memo, fedVars); + rRewriteHop(c, memo, fedVars, program); //handle specific operators (except transient writes) - if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) ) + if(hop instanceof FunctionOp) { + String funcName = ((FunctionOp) hop).getFunctionName(); + String funcNamespace = ((FunctionOp) hop).getFunctionNamespace(); + FunctionStatementBlock sbFuncBlock = program.getFunctionDictionary(funcNamespace).getFunction(funcName); + FunctionStatement funcStatement = (FunctionStatement) sbFuncBlock.getStatement(0); + + Map funcFedVars = createFunctionFedVarTable((FunctionOp) hop, memo); + rRewriteStatementBlock(sbFuncBlock, funcFedVars); + mapFunctionOutputs((FunctionOp) hop, funcStatement, funcFedVars, fedVars); + } + else if( HopRewriteUtils.isData(hop, OpOpData.FEDERATED) ) memo.put(hop.getHopID(), deriveFType((DataOp)hop)); else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTREAD) ) memo.put(hop.getHopID(), fedVars.get(hop.getName())); + else if( HopRewriteUtils.isData(hop, OpOpData.TRANSIENTWRITE) ) + fedVars.put(hop.getName(), memo.get(hop.getHopID())); else if( allowsFederated(hop, memo) ) { hop.setForcedExecType(ExecType.FED); memo.put(hop.getHopID(), getFederatedOut(hop, memo)); @@ -151,4 +165,23 @@ else if( allowsFederated(hop, memo) ) { else // memoization as processed, but not federated memo.put(hop.getHopID(), null); } + + @NotNull + static private Map createFunctionFedVarTable(FunctionOp hop, Map memo) { + Map funcParamMap = FederatedPlannerUtils.getParamMap(hop); + Map funcFedVars = new HashMap<>(); + funcParamMap.forEach((key, value) -> { + funcFedVars.put(key, memo.get(value.getHopID())); + }); + return funcFedVars; + } + + // TODO: Reduce code duplication. The general structure of the federated planners is similar, but memo tables are different. + private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement, + Map funcFedVars, Map callFedVars) { + for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) { + FType outputFType = funcFedVars.get(funcStatement.getOutputParams().get(i).getName()); + callFedVars.put(sbHop.getOutputVariableNames()[i], outputFType); + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java index 60ab0d93ce5..f939fd06a8b 100644 --- a/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java +++ b/src/test/java/org/apache/sysds/test/functions/privacy/fedplanning/FederatedL2SVMPlanningTest.java @@ -81,7 +81,6 @@ public void runL2SVMCostBasedTest(){ } @Test - @Ignore public void runL2SVMFunctionFOUTTest(){ String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*", "fed_tak+*", "fed_+*", "fed_max", "fed_1-*", "fed_tsmm", "fed_>"}; @@ -90,7 +89,6 @@ public void runL2SVMFunctionFOUTTest(){ } @Test - @Ignore public void runL2SVMFunctionHeuristicTest(){ String[] expectedHeavyHitters = new String[]{ "fed_fedinit", "fed_ba+*"}; setTestConf("SystemDS-config-heuristic.xml"); From 25d8e834618f0c9bad16b4f1b5417fc347732d74 Mon Sep 17 00:00:00 2001 From: Kevin Innerebner Date: Sun, 17 Jul 2022 21:54:31 +0200 Subject: [PATCH 2/2] Update and method cleanup --- .../hops/fedplanner/FederatedPlannerCostbased.java | 12 ++---------- .../hops/fedplanner/FederatedPlannerFedAll.java | 6 ++---- .../hops/fedplanner/FederatedPlannerUtils.java | 13 +++++++++++-- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java index f8d2b5aa6f6..396d75e4b94 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerCostbased.java @@ -197,23 +197,15 @@ private ArrayList rewriteDefaultStatementBlock(DMLProgram prog, FunctionStatementBlock sbFuncBlock = prog.getFunctionDictionary(funcNamespace).getFunction(funcName); FunctionStatement funcStatement = (FunctionStatement) sbFuncBlock.getStatement(0); - paramMap = createFunctionFedVarTable(paramMap, (FunctionOp) sbHop); + paramMap = FederatedPlannerUtils.createFunctionFedVarTable(paramMap, (FunctionOp) sbHop); rewriteStatementBlock(prog, sbFuncBlock, paramMap); - FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement); + FederatedPlannerUtils.mapFunctionOutputs((FunctionOp) sbHop, funcStatement, transientWrites); } } } return new ArrayList<>(Collections.singletonList(sb)); } - private Map createFunctionFedVarTable(Map paramMap, FunctionOp sbHop) { - Map funcParamMap = FederatedPlannerUtils.getParamMap(sbHop); - if ( paramMap != null && funcParamMap != null) - funcParamMap.putAll(paramMap); - paramMap = funcParamMap; - return paramMap; - } - /** * Set final fedouts of all hops starting from terminal hops. */ diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java index aec6d54cddb..ed953e7fb04 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java @@ -19,7 +19,6 @@ package org.apache.sysds.hops.fedplanner; -import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -46,7 +45,6 @@ import org.apache.sysds.runtime.controlprogram.caching.CacheableData; import org.apache.sysds.runtime.instructions.cp.Data; import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput; -import org.jetbrains.annotations.NotNull; /** * Baseline federated planner that compiles all hops @@ -166,7 +164,8 @@ else if( allowsFederated(hop, memo) ) { memo.put(hop.getHopID(), null); } - @NotNull + // TODO: Reduce code duplication. See `createFunctionFedVarTable` and `mapFunctionOutputs` in + // `FederationPlannerUtils.java` static private Map createFunctionFedVarTable(FunctionOp hop, Map memo) { Map funcParamMap = FederatedPlannerUtils.getParamMap(hop); Map funcFedVars = new HashMap<>(); @@ -176,7 +175,6 @@ static private Map createFunctionFedVarTable(FunctionOp hop, Map< return funcFedVars; } - // TODO: Reduce code duplication. The general structure of the federated planners is similar, but memo tables are different. private void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement, Map funcFedVars, Map callFedVars) { for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) { diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java index 42c5f648f18..1d596550a15 100644 --- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java +++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerUtils.java @@ -92,10 +92,19 @@ public static Map getParamMap(FunctionOp funcOp){ * @param funcStatement The FunctionStatement of the called function * @param transientWrites map of transient writes */ - public static void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement, Map transientWrites) { - for (int i = 0; i < sbHop.getOutputVariableNames().length; ++i) { + public static void mapFunctionOutputs(FunctionOp sbHop, FunctionStatement funcStatement, + Map transientWrites) { + for(int i = 0; i < sbHop.getOutputVariableNames().length; ++i) { Hop outputWrite = transientWrites.get(funcStatement.getOutputParams().get(i).getName()); transientWrites.put(sbHop.getOutputVariableNames()[i], outputWrite); } } + + public static Map createFunctionFedVarTable(Map paramMap, FunctionOp sbHop) { + Map funcParamMap = FederatedPlannerUtils.getParamMap(sbHop); + if(paramMap != null) + funcParamMap.putAll(paramMap); + paramMap = funcParamMap; + return paramMap; + } }