Skip to content
This repository has been archived by the owner on Sep 26, 2020. It is now read-only.

Commit

Permalink
Wizard (#165)
Browse files Browse the repository at this point in the history
* Wizard steps

* Wizardy things

* Implement onCommit for JobWizardItem

* Add IMDB example

* Added auto mpg to wizard and fix creation bug

* Allow the wizard to be run multiple times

* Possibly fixed the wizard model bug

* Bump ci

* Solve missing input selection options on first Wizard

* Add name field description

* Start work on togglegroup button

* Fix JobWizardModel

* Data grid with images

* spotless

* Center and style images

* Convert task selection to datagrid

* Add images

* Remove 'a' from imdb description

* Add target datagrid with images

* Commit photos

* Write to remaining wizard model fields in onCommit

Co-authored-by: Austin Shalit <[email protected]>
  • Loading branch information
Octogonapus and AustinShalit authored May 6, 2020
1 parent e0ea6ee commit 77e9bc8
Show file tree
Hide file tree
Showing 14 changed files with 520 additions and 73 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package edu.wpi.axon.ui.controller

import edu.wpi.axon.tfdata.Dataset
import edu.wpi.axon.tfdata.loss.Loss
import edu.wpi.axon.tfdata.optimizer.Optimizer
import edu.wpi.axon.training.ModelDeploymentTarget
import edu.wpi.axon.ui.model.DatasetType
import edu.wpi.axon.ui.model.TaskInput
import edu.wpi.axon.ui.model.WizardTarget
import edu.wpi.axon.ui.model.WizardTask
import javafx.collections.FXCollections
import javafx.collections.ObservableList
import tornadofx.Controller

class WizardTaskService : Controller() {
val tasks: ObservableList<WizardTask> = FXCollections.observableArrayList(
WizardTask("Classification",
"Separate items into categories",
resources["/classification.png"],
listOf(
TaskInput("MNIST",
"Classify handwritten digits",
resources["/MNIST.png"],
DatasetType.EXAMPLE,
Dataset.ExampleDataset.Mnist,
Optimizer.Adam(),
Loss.CategoricalCrossentropy),
TaskInput("Fashion MNIST",
"Classify photos of clothing",
resources["/Fashion_MNIST.png"],
DatasetType.EXAMPLE,
Dataset.ExampleDataset.FashionMnist,
Optimizer.Adam(),
Loss.CategoricalCrossentropy),
TaskInput("IMDB",
"Classify positive or negative movie reviews",
resources["/imdb.png"],
DatasetType.EXAMPLE,
Dataset.ExampleDataset.IMDB,
Optimizer.Adam(),
Loss.CategoricalCrossentropy)
)),
WizardTask("Regression",
"Perform a regression on a set of data",
resources["/regression.png"],
listOf(
TaskInput("Auto MPG",
"Predict the MPG of a provided vehicle configuration",
resources["/auto_mpg_car.jpg"],
DatasetType.EXAMPLE,
Dataset.ExampleDataset.AutoMPG,
Optimizer.RMSprop(),
Loss.MeanSquaredError)
))
)

val targets: ObservableList<WizardTarget> = FXCollections.observableArrayList(
WizardTarget("Desktop", "For inference on a computer", resources["/desktop.png"], ModelDeploymentTarget.Desktop::class),
WizardTarget("Coral", "For inference with a Google™ Coral", resources["/coral.jpg"], ModelDeploymentTarget.Coral::class)
)
}
135 changes: 114 additions & 21 deletions ui-javafx/src/main/kotlin/edu/wpi/axon/ui/model/Job.kt
Original file line number Diff line number Diff line change
@@ -1,81 +1,99 @@
package edu.wpi.axon.ui.model

import edu.wpi.axon.db.JobDb
import edu.wpi.axon.db.data.InternalJobTrainingMethod
import edu.wpi.axon.db.data.Job
import edu.wpi.axon.db.data.ModelSource
import edu.wpi.axon.db.data.TrainingScriptProgress
import edu.wpi.axon.examplemodel.ExampleModelManager
import edu.wpi.axon.plugin.DatasetPlugins
import edu.wpi.axon.plugin.Plugin
import edu.wpi.axon.tfdata.Dataset
import edu.wpi.axon.tfdata.Model
import edu.wpi.axon.tfdata.loss.Loss
import edu.wpi.axon.tfdata.optimizer.Optimizer
import edu.wpi.axon.training.ModelDeploymentTarget
import edu.wpi.axon.ui.ModelManager
import edu.wpi.axon.util.FilePath
import edu.wpi.axon.util.getOutputModelName
import javafx.beans.property.SimpleIntegerProperty
import javafx.beans.property.SimpleObjectProperty
import javafx.beans.property.SimpleSetProperty
import javafx.beans.property.SimpleStringProperty
import kotlin.reflect.KClass
import tornadofx.Commit
import tornadofx.ItemViewModel
import tornadofx.asObservable
import tornadofx.getValue
import tornadofx.setValue
import tornadofx.toObservable

data class JobDto(val job: Job) {
val nameProperty = SimpleStringProperty(job.name)
data class JobDto(val job: Job?) {
val nameProperty = SimpleStringProperty(job?.name)
var name by nameProperty

val statusProperty = SimpleObjectProperty(job.status)
val statusProperty = SimpleObjectProperty<TrainingScriptProgress>(job?.status)
var status by statusProperty

val userOldModelPathProperty = SimpleObjectProperty(job.userOldModelPath)
val userOldModelPathProperty = SimpleObjectProperty<ModelSource>(job?.userOldModelPath)
var userOldModelPath by userOldModelPathProperty

val oldModelTypeProperty = SimpleObjectProperty(
when (job.userOldModelPath) {
is ModelSource.FromExample -> ModelSourceType.EXAMPLE
val oldModelTypeProperty = SimpleObjectProperty<ModelSourceType>(
when (job?.userOldModelPath) {
is ModelSource.FromExample, null -> ModelSourceType.EXAMPLE
is ModelSource.FromFile -> ModelSourceType.FILE
}
)
var oldModelType by oldModelTypeProperty

val userDatasetProperty = SimpleObjectProperty(job.userDataset)
val userDatasetProperty = SimpleObjectProperty<Dataset>(job?.userDataset)
var userDataset by userDatasetProperty

val userOptimizerProperty = SimpleObjectProperty(job.userOptimizer)
val userOptimizerProperty = SimpleObjectProperty<Optimizer>(job?.userOptimizer)
var userOptimizer by userOptimizerProperty

val optimizerTypeProperty = SimpleObjectProperty(job.userOptimizer::class)
val optimizerTypeProperty =
SimpleObjectProperty<KClass<out Optimizer>>(job?.let { it.userOptimizer::class })
var optimizerType by optimizerTypeProperty

val userLossProperty = SimpleObjectProperty(job.userLoss)
val userLossProperty = SimpleObjectProperty<Loss>(job?.userLoss)
var userLoss by userLossProperty

val lossTypeProperty = SimpleObjectProperty(job.userLoss::class)
val lossTypeProperty = SimpleObjectProperty<KClass<out Loss>>(job?.let { it.userLoss::class })
var lossType by lossTypeProperty

val userMetricsProperty = SimpleSetProperty(job.userMetrics.asObservable())
val userMetricsProperty = SimpleSetProperty<String>(job?.userMetrics?.asObservable())
var userMetrics by userMetricsProperty

val userEpochsProperty = SimpleIntegerProperty(job.userEpochs)
val userEpochsProperty = SimpleIntegerProperty(job?.userEpochs ?: 1)
var userEpochs by userEpochsProperty

val userNewModelProperty = SimpleObjectProperty(job.userNewModel)
val userNewModelProperty = SimpleObjectProperty<Model>(job?.userNewModel)
var userNewModel by userNewModelProperty

val userNewModelFilenameProperty = SimpleObjectProperty(job.userNewModelFilename)
val userNewModelFilenameProperty = SimpleObjectProperty<String>(job?.userNewModelFilename)
var userNewModelFilename by userNewModelFilenameProperty

val internalTrainingMethodProperty = SimpleObjectProperty(job.internalTrainingMethod)
val internalTrainingMethodProperty =
SimpleObjectProperty<InternalJobTrainingMethod>(job?.internalTrainingMethod)
var internalTrainingMethod by internalTrainingMethodProperty

val targetProperty = SimpleObjectProperty<ModelDeploymentTarget>(job.target)
val targetProperty = SimpleObjectProperty<ModelDeploymentTarget>(job?.target)
var target by targetProperty

val targetTypeProperty = SimpleObjectProperty(job.target::class)
val targetTypeProperty =
SimpleObjectProperty<KClass<out ModelDeploymentTarget>>(job?.let { it.target::class })
var targetType by targetTypeProperty

val datasetPluginProperty = SimpleObjectProperty<Plugin>(job.datasetPlugin)
val datasetPluginProperty = SimpleObjectProperty<Plugin>(job?.datasetPlugin)
var datasetPlugin by datasetPluginProperty

val idProperty = SimpleIntegerProperty(job.id)
val idProperty = SimpleIntegerProperty(job?.id ?: -1)
var id by idProperty

override fun toString(): String {
return "JobDto(job=$job, nameProperty=$nameProperty, statusProperty=$statusProperty, userOldModelPathProperty=$userOldModelPathProperty, oldModelTypeProperty=$oldModelTypeProperty, userDatasetProperty=$userDatasetProperty, userOptimizerProperty=$userOptimizerProperty, optimizerTypeProperty=$optimizerTypeProperty, userLossProperty=$userLossProperty, lossTypeProperty=$lossTypeProperty, userMetricsProperty=$userMetricsProperty, userEpochsProperty=$userEpochsProperty, userNewModelProperty=$userNewModelProperty, userNewModelFilenameProperty=$userNewModelFilenameProperty, internalTrainingMethodProperty=$internalTrainingMethodProperty, targetProperty=$targetProperty, targetTypeProperty=$targetTypeProperty, datasetPluginProperty=$datasetPluginProperty, idProperty=$idProperty)"
}
}

class JobModel : ItemViewModel<JobDto>() {
Expand Down Expand Up @@ -121,3 +139,78 @@ class JobModel : ItemViewModel<JobDto>() {

override fun toString() = "JobModel($item)"
}

class JobWizardModel : ItemViewModel<JobDto>() {

private val modelManager by di<ModelManager>()
private val exampleModelManager by di<ExampleModelManager>()

val name = bind(JobDto::nameProperty, autocommit = true)
val status = bind(JobDto::statusProperty, autocommit = true)
val userOldModelPath = bind(JobDto::userOldModelPathProperty, autocommit = true)
val oldModelType = bind(JobDto::oldModelTypeProperty, autocommit = true)
val userDataset = bind(JobDto::userDatasetProperty, autocommit = true)
val userOptimizer = bind(JobDto::userOptimizerProperty, autocommit = true)
val optimizerType = bind(JobDto::optimizerTypeProperty, autocommit = true)
val userLoss = bind(JobDto::userLossProperty, autocommit = true)
val lossType = bind(JobDto::lossTypeProperty, autocommit = true)
val userMetrics = bind(JobDto::userMetricsProperty, autocommit = true)
val userEpochs = bind(JobDto::userEpochsProperty, autocommit = true)
val userNewModel = bind(JobDto::userNewModelProperty, autocommit = true)
val userNewModelFilename = bind(JobDto::userNewModelFilenameProperty, autocommit = true)
val internalTrainingMethod = bind(JobDto::internalTrainingMethodProperty, autocommit = true)
val target = bind(JobDto::targetProperty, autocommit = true)
val targetType = bind(JobDto::targetTypeProperty, autocommit = true)
val datasetPlugin = bind(JobDto::datasetPluginProperty, autocommit = true)

val task = bind(autocommit = true) { SimpleObjectProperty<WizardTask>() }
val taskInput = bind(autocommit = true) { SimpleObjectProperty<TaskInput>() }
val wizardTarget = bind(autocommit = true) { SimpleObjectProperty<WizardTarget>() }

override fun onCommit() {
// Logic for detecting parameters goes here
val exampleModelName = "${task.value.title} - ${taskInput.value.title}"
val exampleModel = exampleModelManager.getAllExampleModels()
.unsafeRunSync()
.firstOrNull { it.name == exampleModelName }

check(exampleModel != null) {
"No example model was found with name `$exampleModelName`"
}

val modelSource = ModelSource.FromExample(exampleModel)

status.value = TrainingScriptProgress.NotStarted
userOldModelPath.value = modelSource
oldModelType.value = ModelSourceType.EXAMPLE
userDataset.value = taskInput.value.dataset
userOptimizer.value = taskInput.value.optimizer
optimizerType.value = taskInput.value.optimizer::class
userLoss.value = taskInput.value.loss
lossType.value = taskInput.value.loss::class
userMetrics.value = setOf("accuracy").toObservable()
userNewModel.value = modelManager.loadModel(modelSource)
userNewModelFilename.value = getOutputModelName(modelSource.filename)
internalTrainingMethod.value = InternalJobTrainingMethod.Untrained
targetType.value = wizardTarget.value.targetClass
target.value = when (targetType.value) {
ModelDeploymentTarget.Desktop::class -> ModelDeploymentTarget.Desktop
ModelDeploymentTarget.Coral::class -> ModelDeploymentTarget.Coral()
else -> error("Invalid target")
}
datasetPlugin.value = extractDatasetPlugin(taskInput.value.dataset, modelSource)
}

private fun extractDatasetPlugin(dataset: Dataset, model: ModelSource.FromExample) =
when (dataset) {
Dataset.ExampleDataset.BostonHousing -> TODO()
Dataset.ExampleDataset.Cifar10 -> TODO()
Dataset.ExampleDataset.Cifar100 -> TODO()
Dataset.ExampleDataset.FashionMnist -> DatasetPlugins.processMnistTypePlugin
Dataset.ExampleDataset.IMDB -> TODO()
Dataset.ExampleDataset.Mnist -> DatasetPlugins.processMnistTypePlugin
Dataset.ExampleDataset.Reuters -> TODO()
Dataset.ExampleDataset.AutoMPG -> DatasetPlugins.datasetPassthroughPlugin
is Dataset.Custom -> TODO()
}
}
13 changes: 13 additions & 0 deletions ui-javafx/src/main/kotlin/edu/wpi/axon/ui/model/WizardTask.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package edu.wpi.axon.ui.model

import edu.wpi.axon.tfdata.Dataset
import edu.wpi.axon.tfdata.loss.Loss
import edu.wpi.axon.tfdata.optimizer.Optimizer
import edu.wpi.axon.training.ModelDeploymentTarget
import kotlin.reflect.KClass

data class TaskInput(val title: String = "", val description: String = "", val graphic: String? = null, val datasetType: DatasetType, val dataset: Dataset, val optimizer: Optimizer, val loss: Loss)

data class WizardTask(val title: String = "", val description: String = "", val graphic: String? = null, val supportedInputs: List<TaskInput> = listOf())

data class WizardTarget(val title: String = "", val description: String = "", val graphic: String? = null, val targetClass: KClass<out ModelDeploymentTarget>)
Loading

0 comments on commit 77e9bc8

Please sign in to comment.