Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FUTURE][SYSTEMML-2078] Add support for global variables #754

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/main/java/org/apache/sysml/parser/DMLProgram.java
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ public class DMLProgram
private ArrayList<StatementBlock> _blocks;
private HashMap<String, FunctionStatementBlock> _functionBlocks;
private HashMap<String,DMLProgram> _namespaces;
private VariableSet _globals; // global variables
public static final String DEFAULT_NAMESPACE = ".defaultNS";
public static final String INTERNAL_NAMESPACE = "_internal"; // used for multi-return builtin functions
private static final Log LOG = LogFactory.getLog(DMLProgram.class.getName());
Expand All @@ -41,6 +42,11 @@ public DMLProgram(){
_blocks = new ArrayList<>();
_functionBlocks = new HashMap<>();
_namespaces = new HashMap<>();
_globals = new VariableSet();
}

public VariableSet getGlobalVariables() {
return _globals;
}

public HashMap<String,DMLProgram> getNamespaces(){
Expand Down
40 changes: 30 additions & 10 deletions src/main/java/org/apache/sysml/parser/DMLTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,26 @@ public void validateParseTree(DMLProgram dmlp)
boolean fWriteRead = prepareReadAfterWrite(dmlp, new HashMap<String, DataIdentifier>());

//STEP2: Actual Validate

// handle regular blocks -- "main" program
VariableSet vs = new VariableSet();
HashMap<String, ConstIdentifier> constVars = new HashMap<>();
for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
StatementBlock sb = dmlp.getStatementBlock(i);
vs.addVariables(dmlp.getGlobalVariables());
vs = sb.validate(dmlp, vs, constVars, fWriteRead);
constVars = sb.getConstOut();
}

// handle functions in namespaces (current program has default namespace)
for (String namespaceKey : dmlp.getNamespaces().keySet()){

// for each function defined in the namespace
for (String fname : dmlp.getFunctionStatementBlocks(namespaceKey).keySet()) {
FunctionStatementBlock fblock = dmlp.getFunctionStatementBlock(namespaceKey,fname);

HashMap<String, ConstIdentifier> constVars = new HashMap<>();
VariableSet vs = new VariableSet();
constVars = new HashMap<>();
vs = new VariableSet();

// add the input variables for the function to input variable list
FunctionStatement fstmt = (FunctionStatement)fblock.getStatement(0);
Expand All @@ -130,19 +141,13 @@ public void validateParseTree(DMLProgram dmlp)
currVar.setDimensions(0, 0);
vs.addVariable(currVar.getName(), currVar);
}
vs.addVariables(dmlp.getGlobalVariables());
fblock.validate(dmlp, vs, constVars, false);
}

}

// handle regular blocks -- "main" program
VariableSet vs = new VariableSet();
HashMap<String, ConstIdentifier> constVars = new HashMap<>();
for (int i = 0; i < dmlp.getNumStatementBlocks(); i++) {
StatementBlock sb = dmlp.getStatementBlock(i);
vs = sb.validate(dmlp, vs, constVars, fWriteRead);
constVars = sb.getConstOut();
}


//STEP3: Post-processing steps after validate - e.g., prepare read-after-write meta data
if( fWriteRead )
Expand Down Expand Up @@ -177,6 +182,11 @@ public void liveVariableAnalysis(DMLProgram dmlp) {
activeIn.addVariable(id.getName(), id);
}
fsb.initializeforwardLV(activeIn);

// inject the needed global variables
HashMap<String, DataIdentifier> globals = new HashMap<>(fstmt.getBody().get(0).getGen().getVariables());
fstmt.getInputParams().forEach(p -> globals.remove(p.getName()));
fstmt.getScopeVariables().addAll(globals.values());
}
}

Expand Down Expand Up @@ -213,6 +223,16 @@ public void liveVariableAnalysis(DMLProgram dmlp) {
StatementBlock sb = dmlp.getStatementBlock(i);
activeIn = sb.initializeforwardLV(activeIn);
}

// find the global variables in regular program blocks
VariableSet globals = new VariableSet();
for (StatementBlock sb : dmlp.getStatementBlocks()) {
if (!(sb instanceof ForStatementBlock) && !(sb instanceof FunctionStatementBlock)
&& !(sb instanceof IfStatementBlock) && !(sb instanceof WhileStatementBlock)) {
globals = sb.initializeforwardLV(globals);
}
}
dmlp.getGlobalVariables().addVariables(globals);

if (dmlp.getNumStatementBlocks() > 0){
StatementBlock lastSb = dmlp.getStatementBlock(dmlp.getNumStatementBlocks() - 1);
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/org/apache/sysml/parser/FunctionStatement.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ public class FunctionStatement extends Statement
protected String _name;
protected ArrayList <DataIdentifier> _inputParams;
protected ArrayList <DataIdentifier> _outputParams;
protected ArrayList <DataIdentifier> _scopeVariables; // the copied variables of this function scope

@Override
public Statement rewriteStatement(String prefix) {
Expand All @@ -41,6 +42,7 @@ public FunctionStatement(){
_name = null;
_inputParams = new ArrayList<>();
_outputParams = new ArrayList<>();
_scopeVariables = new ArrayList<>();
}

public ArrayList<DataIdentifier> getInputParams(){
Expand All @@ -51,6 +53,10 @@ public ArrayList<DataIdentifier> getOutputParams(){
return _outputParams;
}

public ArrayList<DataIdentifier> getScopeVariables() {
return _scopeVariables;
}

public void setInputParams(ArrayList<DataIdentifier> inputParams){
_inputParams = inputParams;
}
Expand Down
46 changes: 27 additions & 19 deletions src/main/java/org/apache/sysml/parser/StatementBlock.java
Original file line number Diff line number Diff line change
Expand Up @@ -609,28 +609,16 @@ public ArrayList<Statement> rewriteFunctionCallStatements (DMLProgram dmlProg, A
fcall.getParamExprs().size() + " found, but " + fstmt.getInputParams().size()+" expected.");
}

for (int i =0; i < fstmt.getInputParams().size(); i++) {
for (int i = 0; i < fstmt.getInputParams().size(); i++) {
DataIdentifier currFormalParam = fstmt.getInputParams().get(i);
Expression exp = fcall.getParamExprs().get(i).getExpr();
bindScopeVariable(newStatements, exp, prefix, currFormalParam);
}

// copy the referenced global variables into function scope
for (DataIdentifier v : fstmt.getScopeVariables()) {
// create new assignment statement
String newFormalParameterName = prefix + currFormalParam.getName();
DataIdentifier newTarget = new DataIdentifier(currFormalParam);
newTarget.setName(newFormalParameterName);

Expression currCallParam = fcall.getParamExprs().get(i).getExpr();

//auto casting of inputs on inlining (if required)
ValueType targetVT = newTarget.getValueType();
if (newTarget.getDataType() == DataType.SCALAR && currCallParam.getOutput() != null
&& targetVT != currCallParam.getOutput().getValueType() && targetVT != ValueType.STRING) {
currCallParam = new BuiltinFunctionExpression(
BuiltinFunctionExpression.getValueTypeCastOperator(targetVT),
new Expression[] { currCallParam }, newTarget);
}

// create the assignment statement to bind the call parameter to formal parameter
AssignmentStatement binding = new AssignmentStatement(newTarget, currCallParam, newTarget);
newStatements.add(binding);
bindScopeVariable(newStatements, v, prefix, v);
}

for (Statement stmt : sblock._statements){
Expand Down Expand Up @@ -710,6 +698,26 @@ public ArrayList<Statement> rewriteFunctionCallStatements (DMLProgram dmlProg, A
return newStatements;
}

private void bindScopeVariable(ArrayList<Statement> newStatements, Expression currCallParam, String prefix, DataIdentifier currFormalParam) {
// create new assignment statement
String newFormalParameterName = prefix + currFormalParam.getName();
DataIdentifier newTarget = new DataIdentifier(currFormalParam);
newTarget.setName(newFormalParameterName);

//auto casting of inputs on inlining (if required)
ValueType targetVT = newTarget.getValueType();
if (newTarget.getDataType() == DataType.SCALAR && currCallParam.getOutput() != null
&& targetVT != currCallParam.getOutput().getValueType() && targetVT != ValueType.STRING) {
currCallParam = new BuiltinFunctionExpression(
BuiltinFunctionExpression.getValueTypeCastOperator(targetVT),
new Expression[] { currCallParam }, newTarget);
}

// create the assignment statement to bind the call parameter to formal parameter
AssignmentStatement binding = new AssignmentStatement(newTarget, currCallParam, newTarget);
newStatements.add(binding);
}

public VariableSet validate(DMLProgram dmlProg, VariableSet ids, HashMap<String, ConstIdentifier> constVars, boolean conditional)
{
_constVarsIn.putAll(constVars);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,14 @@

public class MLContextTest extends MLContextTestBase {

@Test
public void testBasicGlobalVariablesTest() {
System.out.println("MLContextTest - basic global variables test");
Script script = dmlFromFile(baseDirectory + File.separator + "global-variables-test.dml");
ml.execute(script);
Assert.assertTrue(Statistics.getNoOfExecutedSPInst() == 0);
}

@Test
public void testBuiltinConstantsTest() {
System.out.println("MLContextTest - basic builtin constants test");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

# a func reading the global value
f1 = function (matrix[double] M, double factor) return (double res) {
res = prod(M) * factor * gv
print(toString("local factor = "+factor))
}

# a func trying to modify the global value
f2 = function () return (double res) {
gv = 12345
res = gv
}

# global variables
X = matrix("1 2 3 4", rows=2, cols=2)
y = 10
gv = 5
factor = 1

print(toString("global factor = "+factor))
print(toString(f1(X, y)))

print(toString("before call of f2, gv="+gv))
print(f2())
print(toString("after call of f2, gv="+gv))