Skip to content

Commit

Permalink
[SYSTEMDS-3421] Adds missing read/writeExternal to ColumnEncoderUDF.
Browse files Browse the repository at this point in the history
The read/writeExternal functions for ColumnEncoderUDF were missing.
This also adds the necessary switch case in the EncoderFactory
and a test that currently does nothing.

Closes #1681
  • Loading branch information
paginabianca authored and Baunsgaard committed Nov 11, 2022
1 parent 9681737 commit cb75e3f
Show file tree
Hide file tree
Showing 5 changed files with 105 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int ro
out.quickSetValue(i, outputCol, getCode(in, i));
}
}*/

protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){
// Apply loop tiling to exploit CPU caches
double[] codes = getCodeCol(in, rowStart, blk);
Expand Down Expand Up @@ -343,7 +343,7 @@ public Callable<Object> getBuildTask(CacheBlock in) {
throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building");
}

public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow,
public Callable<Object> getPartialBuildTask(CacheBlock in, int startRow,
int blockSize, HashMap<Integer, Object> ret) {
throw new DMLRuntimeException(
"Trying to get the PartialBuild task of an Encoder which does not support partial building");
Expand Down Expand Up @@ -409,7 +409,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) {
}

public enum EncoderType {
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite
Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, Udf
}

/*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -19,6 +19,9 @@

package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.List;

import org.apache.sysds.api.DMLScript;
Expand All @@ -45,8 +48,8 @@ public class ColumnEncoderUDF extends ColumnEncoder {

//TODO pass execution context through encoder factory for arbitrary functions not just builtin
//TODO integration into IPA to ensure existence of unoptimized functions
private final String _fName;

private String _fName;
public int _domainSize = 1;

protected ColumnEncoderUDF(int ptCols, String name) {
Expand All @@ -72,7 +75,7 @@ public void build(CacheBlock in) {
/**
 * Returns the build tasks for this encoder.
 *
 * @param in input cache block (unused)
 * @return null — presumably signals that a UDF encoder has no build phase; callers must tolerate null (TODO confirm against MultiColumnEncoder)
 */
public List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
return null;
}

@Override
public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
Expand All @@ -82,7 +85,7 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta
MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock());
ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)}));
ec.setVariable("O", ParamservUtils.newMatrixObject(col, true));

//call UDF function via eval machinery
var fun = new EvalNaryCPInstruction(null, "eval", "",
new CPOperand("O", ValueType.FP64, DataType.MATRIX),
Expand All @@ -93,9 +96,6 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta

//obtain result and in-place write back
MatrixBlock ret = ((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease();
//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE);
//out.leftIndexingOperations(ret, 0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, UpdateType.INPLACE);
//out.copy(0, in.getNumRows()-1, _colID-1, _colID-1, ret, true);
out.copy(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, true);

if (DMLScript.STATISTICS)
Expand Down Expand Up @@ -124,14 +124,14 @@ else if(columnEncoder instanceof ColumnEncoderFeatureHash){
}
}
}

/**
 * Sparse application is not supported by UDF encoders (they require full dense column access).
 *
 * @throws DMLRuntimeException always
 */
@Override
protected ColumnApplyTask<ColumnEncoderUDF> getSparseTask(CacheBlock in,
MatrixBlock out, int outputCol, int startRow, int blk)
{
throw new DMLRuntimeException("UDF encoders do not support sparse tasks.");
}

@Override
public void mergeAt(ColumnEncoder other) {
if(other instanceof ColumnEncoderUDF)
Expand Down Expand Up @@ -165,4 +165,21 @@ protected double getCode(CacheBlock in, int row) {
/**
 * Block-wise code retrieval is not supported; UDF encoders operate on entire columns only.
 *
 * @throws DMLRuntimeException always
 */
protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) {
throw new DMLRuntimeException("UDF encoders only support full column access.");
}

/**
 * Serializes this encoder: base encoder state (via super), the output
 * domain size, and the name of the UDF function to invoke.
 *
 * @param out object output stream
 * @throws IOException if writing to the stream fails
 */
@Override
public void writeExternal(ObjectOutput out) throws IOException {
// fixed log message: class is ColumnEncoderUDF, not "ColumnEncoderUTF"
LOG.debug("Writing ColumnEncoderUDF for serialization");
super.writeExternal(out);
out.writeInt(_domainSize);
out.writeUTF(_fName);
}

/**
 * Deserializes this encoder in the same order as {@code writeExternal}:
 * base encoder state, domain size, then the UDF function name.
 *
 * @param in object input stream
 * @throws IOException if reading from the stream fails
 */
@Override
public void readExternal(ObjectInput in) throws IOException {
// fixed log message: class is ColumnEncoderUDF, not "ColumnEncoderUTF"
LOG.debug("Reading ColumnEncoderUDF");
super.readExternal(in);
_domainSize = in.readInt();
_fName = in.readUTF();
LOG.debug("set _fName: " + _fName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol)));
List<Integer> udfIDs = TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol);

// create individual encoders
if(!rcIDs.isEmpty())
for(Integer id : rcIDs)
Expand All @@ -104,7 +104,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V
if(!ptIDs.isEmpty())
for(Integer id : ptIDs)
addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders);

if(!binIDs.isEmpty())
for(Object o : (JSONArray) jSpec.get(TfMethod.BIN.toString())) {
JSONObject colspec = (JSONObject) o;
Expand All @@ -131,7 +131,7 @@ else if ("EQUI-HEIGHT".equals(method))
for(Integer id : udfIDs)
addEncoderToMap(new ColumnEncoderUDF(id, name), colEncoders);
}

// create composite decoder of all created encoders
for(Entry<Integer, List<ColumnEncoder>> listEntry : colEncoders.entrySet()) {
if(DMLScript.STATISTICS)
Expand Down Expand Up @@ -189,8 +189,8 @@ private static void addEncoderToMap(ColumnEncoder encoder, HashMap<Integer, List
}

public static int getEncoderType(ColumnEncoder columnEncoder) {
//TODO replace with columnEncoder.getType().ordinal
//(which requires a cleanup of all type handling)
// TODO replace with columnEncoder.getType().ordinal
// (which requires a cleanup of all type handling)
if(columnEncoder instanceof ColumnEncoderBin)
return EncoderType.Bin.ordinal();
else if(columnEncoder instanceof ColumnEncoderDummycode)
Expand All @@ -201,6 +201,8 @@ else if(columnEncoder instanceof ColumnEncoderPassThrough)
return EncoderType.PassThrough.ordinal();
else if(columnEncoder instanceof ColumnEncoderRecode)
return EncoderType.Recode.ordinal();
else if(columnEncoder instanceof ColumnEncoderUDF)
return EncoderType.Udf.ordinal();
throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
}

Expand All @@ -217,6 +219,8 @@ public static ColumnEncoder createInstance(int type) {
return new ColumnEncoderPassThrough();
case Recode:
return new ColumnEncoderRecode();
case Udf:
return new ColumnEncoderUDF();
default:
throw new DMLRuntimeException("Unsupported encoder type: " + etype);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*
* http://www.apache.org/licenses/LICENSE-2.0
*
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
Expand All @@ -29,28 +29,29 @@
import org.apache.sysds.test.TestUtils;
import org.apache.sysds.utils.Statistics;

public class TransformEncodeUDFTest extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TransformEncodeUDF1"; //min-max
private final static String TEST_NAME2 = "TransformEncodeUDF2"; //scale w/ defaults
public class TransformEncodeUDFTest extends AutomatedTestBase {
private final static String TEST_NAME1 = "TransformEncodeUDF1"; // min-max
private final static String TEST_NAME2 = "TransformEncodeUDF2"; // scale w/ defaults
private final static String TEST_NAME3 = "TransformEncodeUDF3"; // simple custom UDF
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR + TransformEncodeUDFTest.class.getSimpleName() + "/";
//dataset and transform tasks without missing values

// dataset and transform tasks without missing values
private final static String DATASET = "homes3/homes.csv";

@Override
public void setUp() {
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}) );
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}));
addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"R"}));
addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"R"}));
}

@Test
public void testUDF1Singlenode() {
runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME1);
}

@Test
public void testUDF1Hybrid() {
runTransformTest(ExecMode.HYBRID, TEST_NAME1);
Expand All @@ -63,30 +64,37 @@ public void testUDF2Singlenode() {

@Test
public void testUDF2Hybrid() {
runTransformTest(ExecMode.HYBRID, TEST_NAME2);
runTransformTest(ExecMode.HYBRID, TEST_NAME2);
}

private void runTransformTest(ExecMode rt, String testname)
{
//set runtime platform

@Test
public void testUDF3Singlenode() {
runTransformTest(ExecMode.SINGLE_NODE, TEST_NAME3);
}

@Test
public void testUDF3Hybrid() {
runTransformTest(ExecMode.HYBRID, TEST_NAME3);
}

private void runTransformTest(ExecMode rt, String testname) {
// set runtime platform
ExecMode rtold = setExecMode(rt);

try
{

try {
getAndLoadTestConfiguration(testname);

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + testname + ".dml";
programArgs = new String[]{"-explain",
"-nvargs", "DATA=" + DATASET_DIR + DATASET, "R="+output("R")};
programArgs = new String[] {"-explain", "-nvargs", "DATA=" + DATASET_DIR + DATASET, "R=" + output("R")};

// compare transformencode+scale vs transformencode w/ UDF
runTest(true, false, null, -1);

//compare transformencode+scale vs transformencode w/ UDF
runTest(true, false, null, -1);

double ret = HDFSTool.readDoubleFromHDFSFile(output("R"));
Assert.assertEquals(Double.valueOf(148*9), Double.valueOf(ret));
if( rt == ExecMode.HYBRID ) {
Assert.assertEquals(Double.valueOf(148 * 9), Double.valueOf(ret));

if(rt == ExecMode.HYBRID) {
Long num = Long.valueOf(Statistics.getNoOfExecutedSPInst());
Assert.assertEquals("Wrong number of executed Spark instructions: " + num, Long.valueOf(0), num);
}
Expand Down
26 changes: 26 additions & 0 deletions src/test/scripts/functions/transform/TransformEncodeUDF3.dml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#-------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
#-------------------------------------------------------------

F1 = read($DATA, data_type="frame", format="csv");




0 comments on commit cb75e3f

Please sign in to comment.