diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java index 5d5933f6d2f..600b6989eb8 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java @@ -166,7 +166,7 @@ protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int ro out.quickSetValue(i, outputCol, getCode(in, i)); } }*/ - + protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ // Apply loop tiling to exploit CPU caches double[] codes = getCodeCol(in, rowStart, blk); @@ -343,7 +343,7 @@ public Callable getBuildTask(CacheBlock in) { throw new DMLRuntimeException("Trying to get the Build task of an Encoder which does not require building"); } - public Callable getPartialBuildTask(CacheBlock in, int startRow, + public Callable getPartialBuildTask(CacheBlock in, int startRow, int blockSize, HashMap ret) { throw new DMLRuntimeException( "Trying to get the PartialBuild task of an Encoder which does not support partial building"); @@ -409,7 +409,7 @@ protected void setApplyRowBlocksPerColumn(int nPart) { } public enum EncoderType { - Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite + Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, Udf } /* diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java index 114eb5f8efe..01c21e19169 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java @@ -6,9 +6,9 @@ * to you under the Apache License, Version 2.0 (the * "License"); you may not use this file except in compliance * with the License. You may obtain a copy of the License at - * + * * http://www.apache.org/licenses/LICENSE-2.0 - * + * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.transform.encode; +import java.io.IOException; +import java.io.ObjectInput; +import java.io.ObjectOutput; import java.util.List; import org.apache.sysds.api.DMLScript; @@ -45,8 +48,8 @@ public class ColumnEncoderUDF extends ColumnEncoder { //TODO pass execution context through encoder factory for arbitrary functions not just builtin //TODO integration into IPA to ensure existence of unoptimized functions - - private final String _fName; + + private String _fName; public int _domainSize = 1; protected ColumnEncoderUDF(int ptCols, String name) { @@ -72,7 +75,7 @@ public void build(CacheBlock in) { public List> getBuildTasks(CacheBlock in) { return null; } - + @Override public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; @@ -82,7 +85,7 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta MatrixBlock col = out.slice(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, new MatrixBlock()); ec.setVariable("I", new ListObject(new Data[] {ParamservUtils.newMatrixObject(col, true)})); ec.setVariable("O", ParamservUtils.newMatrixObject(col, true)); - + //call UDF function via eval machinery var fun = new EvalNaryCPInstruction(null, "eval", "", new CPOperand("O", ValueType.FP64, DataType.MATRIX), @@ -93,9 +96,6 @@ public void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowSta //obtain result and in-place write back MatrixBlock ret = ((MatrixObject)ec.getCacheableData("O")).acquireReadAndRelease(); - //out.leftIndexingOperations(ret, 0, in.getNumRows()-1, _colID-1, _colID-1, ret, UpdateType.INPLACE); - //out.leftIndexingOperations(ret, 0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, UpdateType.INPLACE); - //out.copy(0, in.getNumRows()-1, _colID-1, _colID-1, ret, true); out.copy(0, in.getNumRows()-1, outputCol, outputCol+_domainSize-1, ret, true); if (DMLScript.STATISTICS) @@ -124,14 +124,14 @@ else if(columnEncoder instanceof ColumnEncoderFeatureHash){ } } } - + @Override protected ColumnApplyTask getSparseTask(CacheBlock in, MatrixBlock out, int outputCol, int startRow, int blk) { throw new DMLRuntimeException("UDF encoders do not support sparse tasks."); } - + @Override public void mergeAt(ColumnEncoder other) { if(other instanceof ColumnEncoderUDF) @@ -165,4 +165,21 @@ protected double getCode(CacheBlock in, int row) { protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize) { throw new DMLRuntimeException("UDF encoders only support full column access."); } + + @Override + public void writeExternal(ObjectOutput out) throws IOException { + LOG.debug("Writing ColumnEncoderUTF to create"); + super.writeExternal(out); + out.writeInt(_domainSize); + out.writeUTF(_fName); + } + + @Override + public void readExternal(ObjectInput in) throws IOException { + LOG.debug("reading ColumnEncoderUTF"); + super.readExternal(in); + _domainSize = in.readInt(); + _fName = in.readUTF(); + LOG.debug("set _fName: " + _fName); + } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index f76a83e3623..9105ff79150 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -93,7 +93,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V List mvIDs = Arrays.asList(ArrayUtils.toObject( TfMetaUtils.parseJsonObjectIDList(jSpec, colnames, TfMethod.IMPUTE.toString(), minCol, maxCol))); List udfIDs = TfMetaUtils.parseUDFColIDs(jSpec, colnames, minCol, maxCol); - + // create individual encoders if(!rcIDs.isEmpty()) for(Integer id : rcIDs) @@ -104,7 +104,7 @@ public static MultiColumnEncoder createEncoder(String spec, String[] colnames, V if(!ptIDs.isEmpty()) for(Integer id : ptIDs) addEncoderToMap(new ColumnEncoderPassThrough(id), colEncoders); - + if(!binIDs.isEmpty()) for(Object o : (JSONArray) jSpec.get(TfMethod.BIN.toString())) { JSONObject colspec = (JSONObject) o; @@ -131,7 +131,7 @@ else if ("EQUI-HEIGHT".equals(method)) for(Integer id : udfIDs) addEncoderToMap(new ColumnEncoderUDF(id, name), colEncoders); } - + // create composite decoder of all created encoders for(Entry> listEntry : colEncoders.entrySet()) { if(DMLScript.STATISTICS) @@ -189,8 +189,8 @@ private static void addEncoderToMap(ColumnEncoder encoder, HashMap