From 3e636c0728b93649646bf48379f39f67ae53d6ce Mon Sep 17 00:00:00 2001 From: Raul Casallas Date: Thu, 26 Aug 2021 06:33:10 -0500 Subject: [PATCH 1/2] feat: adding explainability --- src/main/java/com/modzy/sdk/JobClient.java | 8 +++--- src/main/java/com/modzy/sdk/ModzyClient.java | 26 +++++++++---------- src/main/java/com/modzy/sdk/model/Job.java | 7 +++++ .../modzy/sdk/samples/JobAwsInputSample.java | 2 +- .../sdk/samples/JobEmbeddedInputSample.java | 2 +- .../modzy/sdk/samples/JobFileInputSample.java | 2 +- .../modzy/sdk/samples/JobTextInputSample.java | 2 +- .../java/com/modzy/sdk/TestJobClient.java | 6 ++--- .../java/com/modzy/sdk/TestResultClient.java | 2 +- 9 files changed, 33 insertions(+), 24 deletions(-) diff --git a/src/main/java/com/modzy/sdk/JobClient.java b/src/main/java/com/modzy/sdk/JobClient.java index c477386..d7241f9 100644 --- a/src/main/java/com/modzy/sdk/JobClient.java +++ b/src/main/java/com/modzy/sdk/JobClient.java @@ -205,10 +205,11 @@ private Job closeJob(Job job) throws ApiException{ * @param model The model instance in which the model will run * @param modelVersion The specific version of the model * @param jobInput The inputs of the model to pass to Modzy + * @param explain If the model supports explainability, flag this job to return an explanation of the predictions * @return the updated instance of the Job returned by Modzy API * @throws ApiException if there is something wrong with the service or the call */ - public Job submitJob(Model model, ModelVersion modelVersion, JobInput jobInput) throws ApiException{ + public Job submitJob(Model model, ModelVersion modelVersion, JobInput jobInput, Boolean explain) throws ApiException{ return this.submitJob( new Job(model, modelVersion, jobInput) ); } @@ -219,15 +220,16 @@ public Job submitJob(Model model, ModelVersion modelVersion, JobInput jobInpu * @param modelId identifier of the model * @param modelVersionId identifier of the model version * @param jobInput the inputs of the model to pass to Modzy + * @param explain If the model supports explainability, flag this job to return an explanation of the predictions * @return the updated instance of the Job returned by Modzy API * @throws ApiException if there is something wrong with the service or the call */ - public Job submitJob(String modelId, String modelVersionId, JobInput jobInput) throws ApiException{ + public Job submitJob(String modelId, String modelVersionId, JobInput jobInput, Boolean explain) throws ApiException{ Model model = new Model(); model.setIdentifier(modelId); ModelVersion modelVersion = new ModelVersion(); modelVersion.setVersion(modelVersionId); - return this.submitJob( new Job(model, modelVersion, jobInput) ); + return this.submitJob( new Job(model, modelVersion, jobInput, explain) ); } /** diff --git a/src/main/java/com/modzy/sdk/ModzyClient.java b/src/main/java/com/modzy/sdk/ModzyClient.java index 3385c54..394c382 100644 --- a/src/main/java/com/modzy/sdk/ModzyClient.java +++ b/src/main/java/com/modzy/sdk/ModzyClient.java @@ -178,10 +178,10 @@ public TagWrapper getTagsAndModels(String ...tagsId) throws ApiException{ } /** - * @see JobClient#submitJob(Model, ModelVersion, JobInput) + * @see JobClient#submitJob(Model, ModelVersion, JobInput, explain) */ - public Job submitJob(Model model, ModelVersion modelVersion, JobInput jobInput) throws ApiException{ - return this.jobClient.submitJob(model, modelVersion, jobInput); + public Job submitJob(Model model, ModelVersion modelVersion, JobInput jobInput, Boolean explain) throws ApiException{ + return this.jobClient.submitJob(model, modelVersion, jobInput, explain); } @@ -202,7 +202,7 @@ public Job submitJobText(String modelId, List textSource) throws ApiExce ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, model.getLatestVersion()); JobInput jobInput = new JobInputText(modelVersion); jobInput.addSource(textSource); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -222,7 +222,7 @@ public Job submitJobText(String modelId, String versionId, List textSour ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, versionId); JobInput jobInput = new JobInputText(modelVersion); jobInput.addSource(textSource); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -241,7 +241,7 @@ public Job submitJobEmbedded(String modelId, List embeddedSource) ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, model.getLatestVersion()); JobInput jobInput = new JobInputEmbedded(modelVersion); jobInput.addSource(embeddedSource); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -261,7 +261,7 @@ public Job submitJobEmbedded(String modelId, String versionId, List jobInput = new JobInputEmbedded(modelVersion); jobInput.addSource(embeddedSource); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -283,7 +283,7 @@ public Job submitJobAWSS3(String modelId, String accessKeyID, String secretAcces ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, model.getLatestVersion()); JobInput jobInput = new JobInputS3(modelVersion, accessKeyID, secretAccessKey, region); jobInput.addSource( s3FileRefSource ); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -306,7 +306,7 @@ public Job submitJobAWSS3(String modelId, String versionId, String accessKeyID, ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, versionId); JobInput jobInput = new JobInputS3(modelVersion, accessKeyID, secretAccessKey, region); jobInput.addSource( s3FileRefSource ); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -330,7 +330,7 @@ public Job submitJobJDBC(String modelId, String url, String username, String pas Model model = this.modelClient.getModel(modelId); ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, model.getLatestVersion()); JobInput jobInput = new JobInputJDBC(url, username, password, driver, query); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -355,7 +355,7 @@ public Job submitJobJDBC(String modelId, String versionId, String url, String us Model model = this.modelClient.getModel(modelId); ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, versionId); JobInput jobInput = new JobInputJDBC(url, username, password, driver, query); - return this.submitJob(model, modelVersion, jobInput); + return this.submitJob(model, modelVersion, jobInput, false); } /** @@ -440,7 +440,7 @@ public > T getResult(Job job, Class outputClass) throw * @throws ApiException if there is something wrong with the services or the call */ public JobOutput submitJobBlockUntilComplete(String modelId, String modelVersionId, JobInput jobInput ) throws ApiException{ - Job job = this.jobClient.submitJob(modelId, modelVersionId, jobInput); + Job job = this.jobClient.submitJob(modelId, modelVersionId, jobInput, false); job = this.blockUntilNotInJobStatus(job, 20000, JobStatus.SUBMITTED); job = this.blockUntilNotInJobStatus(job, 30000, JobStatus.IN_PROGRESS); if( !job.getStatus().equals(JobStatus.COMPLETED) ) { @@ -462,7 +462,7 @@ public JobOutput submitJobBlockUntilComplete(String modelId, String mo * @throws ApiException if there is something wrong with the services or the call */ public JobOutput submitJobBlockUntilComplete(Model model, ModelVersion modelVersion, JobInput jobInput ) throws ApiException{ - Job job = this.jobClient.submitJob(model, modelVersion, jobInput); + Job job = this.jobClient.submitJob(model, modelVersion, jobInput, false); this.logger.info("["+job.getJobIdentifier()+"] "+model.getName()+" :: "+modelVersion.getVersion()+" :: waiting "); job = this.blockUntilNotInJobStatus(job, modelVersion.getTimeout().getStatus(), JobStatus.SUBMITTED); job = this.blockUntilNotInJobStatus(job, modelVersion.getTimeout().getRun(), JobStatus.IN_PROGRESS); diff --git a/src/main/java/com/modzy/sdk/model/Job.java b/src/main/java/com/modzy/sdk/model/Job.java index 28cb890..27a25eb 100644 --- a/src/main/java/com/modzy/sdk/model/Job.java +++ b/src/main/java/com/modzy/sdk/model/Job.java @@ -22,6 +22,8 @@ public class Job { @ToString.Include private Model model; + private Boolean explain; + @ToString.Include private JobStatus status; @@ -63,5 +65,10 @@ public Job(Model model, ModelVersion modelVersion, JobInput input) { this(model, modelVersion); this.input = input; } + + public Job(Model model, ModelVersion modelVersion, JobInput input, Boolean explain) { + this(model, modelVersion, input); + this.explain = explain; + } } diff --git a/src/main/java/com/modzy/sdk/samples/JobAwsInputSample.java b/src/main/java/com/modzy/sdk/samples/JobAwsInputSample.java index d10b525..b618fd8 100644 --- a/src/main/java/com/modzy/sdk/samples/JobAwsInputSample.java +++ b/src/main/java/com/modzy/sdk/samples/JobAwsInputSample.java @@ -109,7 +109,7 @@ public static void main(String[] args) throws ApiException { jobInput.addSource("wrong-value", mapSource); // When you have all your inputs ready, you can use our helper method to submit the job as follows: - Job job = modzyClient.submitJob(model, modelVersion, jobInput); + Job job = modzyClient.submitJob(model, modelVersion, jobInput, false); // Modzy creates the job and queue for processing. The job object contains all the info that you need to keep track // of the process, the most important being the job_identifier and the job status. System.out.println(String.format("job: %s", job)); diff --git a/src/main/java/com/modzy/sdk/samples/JobEmbeddedInputSample.java b/src/main/java/com/modzy/sdk/samples/JobEmbeddedInputSample.java index 49677b0..8d1927d 100644 --- a/src/main/java/com/modzy/sdk/samples/JobEmbeddedInputSample.java +++ b/src/main/java/com/modzy/sdk/samples/JobEmbeddedInputSample.java @@ -113,7 +113,7 @@ public static void main(String[] args) throws ApiException, IOException { jobInput.addSource("wrong-values", mapSource); // When you have all your inputs ready, you can use our helper method to submit the job as follows: - Job job = modzyClient.submitJob(model, modelVersion, jobInput); + Job job = modzyClient.submitJob(model, modelVersion, jobInput, false); // Modzy creates the job and queue for processing. The job object contains all the info that you need to keep track // of the process, the most important being the job_identifier and the job status. System.out.println(String.format("job: %s", job)); diff --git a/src/main/java/com/modzy/sdk/samples/JobFileInputSample.java b/src/main/java/com/modzy/sdk/samples/JobFileInputSample.java index df719b8..4efb950 100644 --- a/src/main/java/com/modzy/sdk/samples/JobFileInputSample.java +++ b/src/main/java/com/modzy/sdk/samples/JobFileInputSample.java @@ -123,7 +123,7 @@ public static void main(String[] args) throws ApiException, IOException { jobInput.addSource("wrong-values", mapSource); // When you have all your inputs ready, you can use our helper method to submit the job as follows: - Job job = modzyClient.submitJob(model, modelVersion, jobInput); + Job job = modzyClient.submitJob(model, modelVersion, jobInput, false); // Modzy creates the job and queue for processing. The job object contains all the info that you need to keep track // of the process, the most important being the job_identifier and the job status. System.out.println(String.format("job: %s", job)); diff --git a/src/main/java/com/modzy/sdk/samples/JobTextInputSample.java b/src/main/java/com/modzy/sdk/samples/JobTextInputSample.java index 072308b..6bacf5a 100644 --- a/src/main/java/com/modzy/sdk/samples/JobTextInputSample.java +++ b/src/main/java/com/modzy/sdk/samples/JobTextInputSample.java @@ -90,7 +90,7 @@ public static void main(String[] args) throws ApiException { mapSource.put("a.wrong.key", "This input is wrong!"); jobInput.addSource("wrong-key", mapSource); // When you have all your inputs ready, you can use our helper method to submit the job as follows: - Job job = modzyClient.submitJob(model, modelVersion, jobInput); + Job job = modzyClient.submitJob(model, modelVersion, jobInput, false); // Modzy creates the job and queue for processing. The job object contains all the info that you need to keep track // of the process, the most important being the job_identifier and the job status. System.out.println(String.format("job: %s", job)); diff --git a/src/test/java/com/modzy/sdk/TestJobClient.java b/src/test/java/com/modzy/sdk/TestJobClient.java index cfa1424..92a3048 100644 --- a/src/test/java/com/modzy/sdk/TestJobClient.java +++ b/src/test/java/com/modzy/sdk/TestJobClient.java @@ -81,7 +81,7 @@ public void testSubmitJob(){ jobInput.addSource(sourceMap); Job job = null; try { - job = this.jobClient.submitJob(model, modelVersion, jobInput); + job = this.jobClient.submitJob(model, modelVersion, jobInput, false); this.logger.info( job.toString() ); } catch (ApiException e) { fail(e.getMessage()); @@ -118,7 +118,7 @@ public void testGetJob() { jobInput.addSource(sourceMap); Job job = null; try { - job = this.jobClient.submitJob(model, modelVersion, jobInput); + job = this.jobClient.submitJob(model, modelVersion, jobInput, false); this.logger.info( job.toString() ); } catch (ApiException e) { fail(e.getMessage()); @@ -171,7 +171,7 @@ public void testCancelJob() { jobInput.addSource(sourceMap); Job job = null; try { - job = this.jobClient.submitJob(model, modelVersion, jobInput); + job = this.jobClient.submitJob(model, modelVersion, jobInput, false); this.logger.info( job.toString() ); } catch (ApiException e) { fail(e.getMessage()); diff --git a/src/test/java/com/modzy/sdk/TestResultClient.java b/src/test/java/com/modzy/sdk/TestResultClient.java index dcb745d..3d29b0e 100644 --- a/src/test/java/com/modzy/sdk/TestResultClient.java +++ b/src/test/java/com/modzy/sdk/TestResultClient.java @@ -74,7 +74,7 @@ public void testGetResult(){ jobInput.addSource(sourceMap); Job job = null; try { - job = this.jobClient.submitJob(model, modelVersion, jobInput); + job = this.jobClient.submitJob(model, modelVersion, jobInput, false); this.logger.info( job.toString() ); } catch (ApiException e) { fail(e.getMessage()); From 51bcf2a7627326dfcca5436158a2886c50952db4 Mon Sep 17 00:00:00 2001 From: Raul Casallas Date: Thu, 26 Aug 2021 08:25:40 -0500 Subject: [PATCH 2/2] small fix --- src/main/java/com/modzy/sdk/ModzyClient.java | 43 +++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/src/main/java/com/modzy/sdk/ModzyClient.java b/src/main/java/com/modzy/sdk/ModzyClient.java index 394c382..e192f76 100644 --- a/src/main/java/com/modzy/sdk/ModzyClient.java +++ b/src/main/java/com/modzy/sdk/ModzyClient.java @@ -1,5 +1,6 @@ package com.modzy.sdk; +import java.io.InputStream; import java.util.Arrays; import java.util.List; import java.util.Map; @@ -19,6 +20,7 @@ import com.modzy.sdk.model.JobInputEmbedded; import com.modzy.sdk.model.JobInputJDBC; import com.modzy.sdk.model.JobInputS3; +import com.modzy.sdk.model.JobInputStream; import com.modzy.sdk.model.JobInputText; import com.modzy.sdk.model.JobOutput; import com.modzy.sdk.model.JobStatus; @@ -263,6 +265,45 @@ public Job submitJobEmbedded(String modelId, String versionId, List streamSource) throws ApiException{ + Model model = this.modelClient.getModel(modelId); + ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, model.getLatestVersion()); + JobInput jobInput = new JobInputStream(modelVersion); + jobInput.addSource(streamSource); + return this.submitJob(model, modelVersion, jobInput, false); + } + + /** + * + * Create a new job for the model at the specific version with the input streams provided, + * this method try to match the streamSource values with the inputs of the specific version + * of the model. + * + * @param modelId the model id string + * @param versionId version id string + * @param streamSource the source(s) of the model + * @return the updated instance of the Job returned by Modzy API + * @throws ApiException if there is something wrong with the service or the call + */ + public Job submitJobFile(String modelId, String versionId, List streamSource) throws ApiException{ + Model model = this.modelClient.getModel(modelId); + ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, versionId); + JobInput jobInput = new JobInputStream(modelVersion); + jobInput.addSource(streamSource); + return this.submitJob(model, modelVersion, jobInput, false); + } /** * @@ -326,7 +367,7 @@ public Job submitJobAWSS3(String modelId, String versionId, String accessKeyID, * @return the updated instance of the Job returned by Modzy API * @throws ApiException if there is something wrong with the service or the call */ - public Job submitJobJDBC(String modelId, String url, String username, String password, String driver, String query ) throws ApiException{ + public Job submitJobJDBC(String modelId, String url, String username, String password, String driver, String query) throws ApiException{ Model model = this.modelClient.getModel(modelId); ModelVersion modelVersion = this.modelClient.getModelVersion(modelId, model.getLatestVersion()); JobInput jobInput = new JobInputJDBC(url, username, password, driver, query);