Skip to content
This repository has been archived by the owner on Sep 26, 2020. It is now read-only.

Commit

Permalink
Add the test view (#161)
Browse files Browse the repository at this point in the history
* Start implementing a test view

* Put the trained model path on the Job. Add a load test data plugin to load example datasets. Refine the load test data plugin input format.

* Simplify ResultFragment

* Fix tracking the trained model file

* Fix the trained model filename and test plugins

* Show existing test results

* Add show file support

* Fix a bug not setting the internal training method

* Shorten test result filenames. Write test runner errors to a log file the user can read.

* Clean up test view ui

* Fix LocalTestRunnerIntegTest

* Remove ModelSource.FromJob

* Load CSV files into a linechart in ResultFragment.kt

* Update the CLI

* Add image support in ResultFragment.kt

Co-authored-by: Austin Shalit <[email protected]>
  • Loading branch information
Octogonapus and AustinShalit authored Mar 27, 2020
1 parent 789390e commit 7877834
Show file tree
Hide file tree
Showing 30 changed files with 756 additions and 210 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class EC2TrainingScriptRunner(
|apt-cache policy docker-ce
|apt install -y docker-ce
|systemctl status docker
|pip3 install https://github.com/wpilibsuite/axon-cli/releases/download/v0.1.16/axon-0.1.16-py2.py3-none-any.whl
|pip3 install https://github.com/wpilibsuite/axon-cli/releases/download/v0.1.17/axon-0.1.17-py2.py3-none-any.whl
|axon create-heartbeat ${config.id}
|axon update-training-progress ${config.id} "initializing"
|axon download-model "${config.oldModelName.path}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import edu.wpi.axon.tfdata.loss.Loss
import edu.wpi.axon.tfdata.optimizer.Optimizer
import edu.wpi.axon.training.ModelDeploymentTarget
import edu.wpi.axon.util.FilePath
import edu.wpi.axon.util.getOutputModelName
import kotlin.random.Random
import org.apache.commons.lang3.RandomStringUtils

Expand Down Expand Up @@ -89,7 +90,6 @@ fun Random.nextModelSource(): ModelSource =
when (nextInt(ModelSource::class.sealedSubclasses.count())) {
0 -> ModelSource.FromExample(nextExampleModel())
1 -> ModelSource.FromFile(nextFilePath())
2 -> ModelSource.FromJob(nextInt(1, Int.MAX_VALUE))
else -> error("Missing a ModelSource case.")
}

Expand Down Expand Up @@ -127,6 +127,7 @@ fun Random.nextJob(
Layer.AveragePooling2D(RandomStringUtils.randomAlphanumeric(10), null).untrainable()
)
),
userNewModelPath: String = getOutputModelName(userOldModelPath.filename),
generateDebugComments: Boolean = nextBoolean(),
trainingMethod: InternalJobTrainingMethod = nextTrainingMethod(),
target: ModelDeploymentTarget = nextTarget(),
Expand All @@ -141,6 +142,7 @@ fun Random.nextJob(
userMetrics,
userEpochs,
userNewModel,
userNewModelPath,
generateDebugComments,
trainingMethod,
target,
Expand Down
35 changes: 14 additions & 21 deletions db/src/main/kotlin/edu/wpi/axon/db/JobDb.kt
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import edu.wpi.axon.tfdata.Model
import edu.wpi.axon.tfdata.loss.Loss
import edu.wpi.axon.tfdata.optimizer.Optimizer
import edu.wpi.axon.training.ModelDeploymentTarget
import edu.wpi.axon.util.FilePath
import org.jetbrains.exposed.dao.IntIdTable
import org.jetbrains.exposed.sql.Database
import org.jetbrains.exposed.sql.ResultRow
Expand All @@ -36,6 +37,7 @@ internal object Jobs : IntIdTable() {
val userMetricsCol = varchar("userMetrics", 255)
val userEpochsCol = integer("userEpochs")
val userModelCol = text("userModel")
val userNewModelFilenameCol = text("userNewModelFilename")
val generateDebugCommentsCol = bool("generateDebugComments")
val internalTrainingMethodCol = varchar("internalTrainingMethod", 255)
val targetCol = varchar("target", 255)
Expand All @@ -52,6 +54,7 @@ internal object Jobs : IntIdTable() {
userEpochs = row[userEpochsCol],
generateDebugComments = row[generateDebugCommentsCol],
userNewModel = Model.deserialize(row[userModelCol]),
userNewModelFilename = row[userNewModelFilenameCol],
internalTrainingMethod = InternalJobTrainingMethod.deserialize(
row[internalTrainingMethodCol]
),
Expand Down Expand Up @@ -96,6 +99,7 @@ class JobDb(private val database: Database) {
userMetrics: Set<String>,
userEpochs: Int,
userNewModel: Model,
userNewModelFilename: String,
generateDebugComments: Boolean,
internalTrainingMethod: InternalJobTrainingMethod,
target: ModelDeploymentTarget,
Expand All @@ -112,6 +116,7 @@ class JobDb(private val database: Database) {
row[userMetricsCol] = klaxon.toJsonString(userMetrics)
row[userEpochsCol] = userEpochs
row[userModelCol] = userNewModel.serialize()
row[userNewModelFilenameCol] = userNewModelFilename
row[generateDebugCommentsCol] = generateDebugComments
row[internalTrainingMethodCol] = internalTrainingMethod.serialize()
row[targetCol] = target.serialize()
Expand All @@ -129,6 +134,7 @@ class JobDb(private val database: Database) {
userMetrics = userMetrics,
userEpochs = userEpochs,
userNewModel = userNewModel,
userNewModelFilename = userNewModelFilename,
generateDebugComments = generateDebugComments,
internalTrainingMethod = internalTrainingMethod,
target = target,
Expand All @@ -141,23 +147,6 @@ class JobDb(private val database: Database) {
return job
}

fun update(job: Job) = update(
job.id,
job.name,
job.status,
job.userOldModelPath,
job.userDataset,
job.userOptimizer,
job.userLoss,
job.userMetrics,
job.userEpochs,
job.userNewModel,
job.generateDebugComments,
job.internalTrainingMethod,
job.target,
job.datasetPlugin
)

fun update(
id: Int,
name: String? = null,
Expand All @@ -169,6 +158,7 @@ class JobDb(private val database: Database) {
userMetrics: Set<String>? = null,
userEpochs: Int? = null,
userNewModel: Model? = null,
userNewModelFilename: FilePath.Local? = null,
generateDebugComments: Boolean? = null,
internalJobTrainingMethod: InternalJobTrainingMethod? = null,
target: ModelDeploymentTarget? = null,
Expand All @@ -185,6 +175,9 @@ class JobDb(private val database: Database) {
userMetrics?.let { row[userMetricsCol] = klaxon.toJsonString(userMetrics) }
userEpochs?.let { row[userEpochsCol] = userEpochs }
userNewModel?.let { row[userModelCol] = userNewModel.serialize() }
userNewModelFilename?.let {
row[userNewModelFilenameCol] = userNewModelFilename.path
}
generateDebugComments?.let { row[generateDebugCommentsCol] = generateDebugComments }
internalJobTrainingMethod?.let {
row[internalTrainingMethodCol] = internalJobTrainingMethod.serialize()
Expand Down Expand Up @@ -228,10 +221,10 @@ class JobDb(private val database: Database) {
fun fetchRunningJobs(): List<Job> = transaction(database) {
// TODO: Split the error log off of the status so we don't need to do this filter
Jobs.select {
(Jobs.statusCol like "%Creating%") or
(Jobs.statusCol like "%Initializing%") or
(Jobs.statusCol like "%InProgress%")
}.map { Jobs.toDomain(it) }
(Jobs.statusCol like "%Creating%") or
(Jobs.statusCol like "%Initializing%") or
(Jobs.statusCol like "%InProgress%")
}.map { Jobs.toDomain(it) }
.filter { it.status !is TrainingScriptProgress.Error }
}

Expand Down
2 changes: 2 additions & 0 deletions db/src/main/kotlin/edu/wpi/axon/db/data/Job.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import edu.wpi.axon.training.ModelDeploymentTarget
* @param userEpochs The number of epochs.
* @param userNewModel The new model configuration (the old model after it was configured by the
* user).
* @param userNewModelFilename The filename of the new model (after training).
* @param generateDebugComments Whether to put debug comments in the output.
* @param internalTrainingMethod Do not set this directly, this should always start as
* [InternalJobTrainingMethod.Untrained]. If you want to control where the Job is trained, set the
Expand All @@ -35,6 +36,7 @@ data class Job(
val userMetrics: Set<String>,
val userEpochs: Int,
val userNewModel: Model,
val userNewModelFilename: String,
val generateDebugComments: Boolean,
val internalTrainingMethod: InternalJobTrainingMethod,
val target: ModelDeploymentTarget,
Expand Down
17 changes: 9 additions & 8 deletions db/src/main/kotlin/edu/wpi/axon/db/data/ModelSource.kt
Original file line number Diff line number Diff line change
Expand Up @@ -12,29 +12,30 @@ import kotlinx.serialization.json.JsonConfiguration
@Serializable
sealed class ModelSource {

abstract val filename: String

/**
* From an example model.
*/
@Serializable
data class FromExample(val exampleModel: ExampleModel) : ModelSource()
data class FromExample(val exampleModel: ExampleModel) : ModelSource() {
override val filename: String = exampleModel.fileName
}

/**
* From a FilePath.
*/
@Serializable
data class FromFile(val filePath: FilePath) : ModelSource()

/**
* From the trained output of a Job.
*/
@Serializable
data class FromJob(val jobId: Int) : ModelSource()
data class FromFile(val filePath: FilePath) : ModelSource() {
override val filename: String = filePath.filename
}

fun serialize(): String = Json(
JsonConfiguration.Stable
).stringify(serializer(), this)

companion object {

fun deserialize(data: String) = Json(
JsonConfiguration.Stable
).parse(serializer(), data)
Expand Down
3 changes: 2 additions & 1 deletion dsl/src/main/kotlin/edu/wpi/axon/dsl/task/LoadStringTask.kt
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ class LoadStringTask(name: String) : BaseTask(name) {

override val dependencies: MutableSet<Code<*>> = mutableSetOf()

override fun code() = """${output.name} = "$data""""
override fun code() =
"""${output.name} = "${data.replace("\\", "\\\\").replace("\"", "\\\"")}""""
}
11 changes: 11 additions & 0 deletions dsl/src/test/kotlin/edu/wpi/axon/dsl/task/LoadStringTaskTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,15 @@ internal class LoadStringTaskTest : KoinTestFixture() {
output = configuredCorrectly("output")
}.code().shouldBe("output = \"data\"")
}

@Test
fun `test escaping the string`() {
startKoin { }
LoadStringTask("").apply {
data = """"quoted" \backslashes\"""
output = configuredCorrectly("output")
}.code().shouldBe("""
output = "\"quoted\" \\backslashes\\"
""".trimIndent())
}
}
41 changes: 41 additions & 0 deletions plugin/src/main/kotlin/edu/wpi/axon/plugin/LoadTestDataPlugins.kt
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package edu.wpi.axon.plugin

object LoadTestDataPlugins {

// TODO: Don't hardcode which element to return
val loadExampleDatasetPlugin = Plugin.Official(
"Load an Example Dataset",
"""
|def load_test_data(input):
| import json
| loaded_json = json.loads(input)
|
| try:
| type = loaded_json["example_dataset"]
| if type == "boston_housing":
| dataset = tf.keras.datasets.boston_housing
| elif type == "cifar10":
| dataset = tf.keras.datasets.cifar10
| elif type == "cifar100":
| dataset = tf.keras.datasets.cifar100
| elif type == "fashion_mnist":
| dataset = tf.keras.datasets.fashion_mnist
| elif type == "imdb":
| dataset = tf.keras.datasets.imdb
| elif type == "mnist":
| dataset = tf.keras.datasets.mnist
| elif type == "reuters":
| dataset = tf.keras.datasets.reuters
| else:
| raise RuntimeError("Cannot load the dataset.")
| except KeyError:
| raise RuntimeError("Cannot load the dataset.")
|
| (x_train, y_train), (x_test, y_test) = dataset.load_data()
| x_test = x_test[:1]
| x_test = tf.cast(x_test / 255, tf.float32)
| x_test = x_test[..., tf.newaxis]
| return (x_test, y_test[:1], 1)
""".trimMargin()
)
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package edu.wpi.axon.plugin

object ProcessTestOutputPlugins {

val serializeModelOutputPlugin = Plugin.Official(
"Serialize Model Output",
"""
|def process_model_output(model_input, expected_output, model_output):
| # import json
| import numpy as np
| with open("output/expected_output.txt", "w+") as f:
| # json.dump(expected_output, f)
| np.savetxt(f, expected_output)
| with open("output/model_output.txt", "w+") as f:
| # json.dump(model_output, f)
| np.savetxt(f, model_output)
""".trimMargin()
)
}
Loading

0 comments on commit 7877834

Please sign in to comment.