Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3421] Adds missing read/writeExternal to ColumnEncoderUDF. #1716

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int ro
out.quickSetValue(i, outputCol, getCode(in, i));
}
}*/

protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){
// Apply loop tiling to exploit CPU caches
double[] codes = getCodeCol(in, rowStart, blk);
Expand Down Expand Up @@ -343,7 +343,7 @@ public Callable<Object> getBuildTask(CacheBlock in) {
throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building");
}

public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow,
public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow,
int blockSize, HashMap<Integer, Object> ret) {
throw new DMLRuntimeException(
"Trying to get the PartialBuild task of an Encoder which does not support partial building");
Expand Down Expand Up @@ -409,7 +409,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
}

public enum EncoderType {
	// NOTE: serialization maps encoders to types via ordinal() (see getEncoderType /
	// createInstance in EncoderFactory), so new entries must only be APPENDED here —
	// reordering or inserting would silently break persisted/shipped encoders.
	Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, Udf
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -19,6 +19,9 @@

package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;

import org.apache.sysds.api.DMLScript;
Expand All @@ -45,8 +48,8 @@ public class ColumnEncoderUDF extends ColumnEncoder {

//TODO pass execution context through encoder factory for arbitrary functions not just builtin
//TODO integration into IPA to ensure existence of unoptimized functions
private final String _fName;

private String _fName;
public int _domainSize = 1;

protected ColumnEncoderUDF(int ptCols, String name) {
Expand All @@ -72,7 +75,7 @@ public void build(CacheBlock in) {
public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
return null;
}

@Override
public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
Expand All @@ -82,7 +85,7 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta
MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock());
ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)}));
ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));

//call UDF function via eval machinery
var fun = new EvalNaryCPInstruction(null, "eval", "",
new CPOperand("O", ValueType.FP64, DataType.MATRIX),
Expand All @@ -93,9 +96,6 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta

//obtain result and in-place write back
MatrixBlock ret = ((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease();
//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE);
//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, UpdateType.INPLACE);
//out.copy(0, in.getNumRows()-1, _colID-1, _colID-1, ret, true);
out.copy(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, true);

if (DMLScript.STATISTICS)
Expand Down Expand Up @@ -124,14 +124,14 @@ else if(columnEncoder instanceof ColumnEncoderFeatureHash){
}
}
}

@Override
protected ColumnApplyTask<ColumnEncoderUDF> getSparseTask(CacheBlock in,
MatrixBlock out, int outputCol, int startRow, int blk)
{
throw new DMLRuntimeException("UDF encoders do not support sparse tasks.");
}

@Override
public void mergeAt(ColumnEncoder other) {
if(other instanceof ColumnEncoderUDF)
Expand Down Expand Up @@ -165,4 +165,21 @@ protected double getCode(CacheBlock in, int row) {
protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) {
throw new DMLRuntimeException("UDF encoders only support full column access.");
}

@Override
public void writeExternal(ObjectOutput out) throws IOException {
	// Serialize this encoder for shipping to remote workers / spark tasks:
	// base ColumnEncoder state first, then the UDF-specific fields.
	// Field order MUST match readExternal exactly.
	LOG.debug("Writing ColumnEncoderUDF to output stream");
	super.writeExternal(out);
	out.writeInt(_domainSize);
	// NOTE(review): writeUTF throws NullPointerException if _fName is null —
	// presumably encoders are always constructed with a function name before
	// serialization; confirm the no-arg createInstance path never writes.
	out.writeUTF(_fName);
}

@Override
public void readExternal(ObjectInput in) throws IOException {
	// Deserialize in the exact order written by writeExternal:
	// base ColumnEncoder state, then domain size, then UDF function name.
	LOG.debug("Reading ColumnEncoderUDF from input stream");
	super.readExternal(in);
	_domainSize = in.readInt();
	_fName = in.readUTF();
	// Guard the concatenation so the message is only built when debug logging is on.
	if(LOG.isDebugEnabled())
		LOG.debug("set _fName: " + _fName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol)));
List<Integer> udfIDs = TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol);

// create individual encoders
if(!rcIDs.isEmpty())
for(Integer id : rcIDs)
Expand All @@ -104,7 +104,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
if(!ptIDs.isEmpty())
for(Integer id : ptIDs)
addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders);

if(!binIDs.isEmpty())
for(Object o : (JSONArray) jSpec.get(TfMethod.BIN.toString())) {
JSONObject colspec = (JSONObject) o;
Expand All @@ -131,7 +131,7 @@ else if ("EQUI-HEIGHT".equals(method))
for(Integer id : udfIDs)
addEncoderToMap(new ColumnEncoderUDF(id, name), colEncoders);
}

// create composite decoder of all created encoders
for(Entry<Integer, List<ColumnEncoder>> listEntry : colEncoders.entrySet()) {
if(DMLScript.STATISTICS)
Expand Down Expand Up @@ -189,8 +189,8 @@ private static void addEncoderToMap(ColumnEncoder encoder, HashMap<Integer, List
}

public static int getEncoderType(ColumnEncoder columnEncoder) {
//TODO replace with columnEncoder.getType().ordinal
//(which requires a cleanup of all type handling)
// TODO replace with columnEncoder.getType().ordinal
// (which requires a cleanup of all type handling)
if(columnEncoder instanceof ColumnEncoderBin)
return EncoderType.Bin.ordinal();
else if(columnEncoder instanceof ColumnEncoderDummycode)
Expand All @@ -201,6 +201,8 @@ else if(columnEncoder instanceof ColumnEncoderPassThrough)
return EncoderType.PassThrough.ordinal();
else if(columnEncoder instanceof ColumnEncoderRecode)
return EncoderType.Recode.ordinal();
else if(columnEncoder instanceof ColumnEncoderUDF)
return EncoderType.Udf.ordinal();
throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
}

Expand All @@ -217,6 +219,8 @@ public static ColumnEncoder createInstance(int type) {
return new ColumnEncoderPassThrough();
case Recode:
return new ColumnEncoderRecode();
case Udf:
return new ColumnEncoderUDF();
default:
throw new DMLRuntimeException("Unsupported encoder type: " + etype);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -29,28 +29,29 @@
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;

public class TransformEncodeUDFTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TransformEncodeUDF1"; //min-max
private final static String TEST_NAME2 = "TransformEncodeUDF2"; //scale w/ defaults
public class TransformEncodeUDFTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "TransformEncodeUDF1"; // min-max
private final static String TEST_NAME2 = "TransformEncodeUDF2"; // scale w/ defaults
private final static String TEST_NAME3 = "TransformEncodeUDF3"; // simple custom UDF
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR + TransformEncodeUDFTest.class.getSimpleName() + "/";
//dataset and transform tasks without missing values

// dataset and transform tasks without missing values
private final static String DATASET = "homes3/homes.csv";

@Override
public void setUp() {
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}));
}

// Runs the min-max UDF transform script (TEST_NAME1) in single-node execution mode.
@Test
public void testUDF1Singlenode() {
runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME1);
}

@Test
public void testUDF1Hybrid() {
runTransformTest(ExecMode.HYBRID, TEST_NAME1);
Expand All @@ -63,30 +64,37 @@ public void testUDF2Singlenode() {

// Runs the scale-with-defaults UDF transform script (TEST_NAME2) in hybrid CP/Spark mode.
@Test
public void testUDF2Hybrid() {
	runTransformTest(ExecMode.HYBRID, TEST_NAME2);
}

private void runTransformTest(ExecMode rt, String testname)
{
//set runtime platform

// Runs the simple custom UDF transform script (TEST_NAME3) in single-node execution mode.
@Test
public void testUDF3Singlenode() {
runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME3);
}

// Runs the simple custom UDF transform script (TEST_NAME3) in hybrid CP/Spark mode.
@Test
public void testUDF3Hybrid() {
runTransformTest(ExecMode.HYBRID, TEST_NAME3);
}

private void runTransformTest(ExecMode rt, String testname) {
// set runtime platform
ExecMode rtold = setExecMode(rt);

try
{

try {
getAndLoadTestConfiguration(testname);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-explain",
"-nvargs", "DATA=" + DATASET_DIR + DATASET, "R="+output("R")};
programArgs = new String[] {"-explain", "-nvargs", "DATA=" + DATASET_DIR + DATASET, "R=" + output("R")};

// compare transformencode+scale vs transformencode w/ UDF
runTest(true, false, null, -1);

//compare transformencode+scale vs transformencode w/ UDF
runTest(true, false, null, -1);

double ret = HDFSTool.readDoubleFromHDFSFile(output("R"));
Assert.assertEquals(Double.valueOf(148*9), Double.valueOf(ret));
if( rt == ExecMode.HYBRID ) {
Assert.assertEquals(Double.valueOf(148 * 9), Double.valueOf(ret));

if(rt == ExecMode.HYBRID) {
Long num = Long.valueOf(Statistics.getNoOfExecutedSPInst());
Assert.assertEquals("Wrong number of executed Spark instructions: " + num, Long.valueOf(0), num);
}
Expand Down
26 changes: 26 additions & 0 deletions src/test/scripts/functions/transform/TransformEncodeUDF3.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

F1 = read($DATA, data_type="frame", format="csv");