diff --git a/aws/aws.gradle.kts b/aws/aws.gradle.kts index 7c07452f..66009d99 100644 --- a/aws/aws.gradle.kts +++ b/aws/aws.gradle.kts @@ -22,11 +22,26 @@ dependencies { api(project(":tf-data")) api(project(":db-data")) - // implementation(platform("software.amazon.awssdk:bom:2.9.9")) - implementation( + api( group = "software.amazon.awssdk", name = "aws-sdk-java", - version = property("aws-sdk-java.version") as String + version = "2.10.12" + ) + implementation( + group = "com.amazonaws", + name = "aws-java-sdk", + version = "1.11.674" + ) + + implementation( + group = "com.beust", + name = "klaxon", + version = property("klaxon.version") as String + ) + + implementation(group = "mysql", + name = "mysql-connector-java", + version = property("mysql-connector-java.version") as String ) implementation(project(":logging")) diff --git a/aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt b/aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt index 55836427..42dc9abb 100644 --- a/aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt +++ b/aws/src/main/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunner.kt @@ -8,8 +8,6 @@ import java.util.concurrent.atomic.AtomicLong import mu.KotlinLogging import org.apache.commons.lang3.RandomStringUtils import org.koin.core.KoinComponent -import org.koin.core.inject -import org.koin.core.qualifier.named import software.amazon.awssdk.core.sync.RequestBody import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.ec2.Ec2Client @@ -24,14 +22,19 @@ import software.amazon.awssdk.services.s3.model.PutObjectRequest * S3. This implementation requires that the script does not try to manage models with S3 itself: * this class will handle all of that. The script should just load and save the model from/to its * current directory. + * + * @param bucketName The S3 bucket name to use for dataset and models. + * @param instanceType The type of the EC2 instance to run the training script on. + * @param region The region to connect to, or `null` to autodetect the region. */ -class EC2TrainingScriptRunner : TrainingScriptRunner, KoinComponent { - - private val region: Region by inject() - private val bucketName: String by inject(named("bucketName")) - private val instanceType: InstanceType by inject() - private val s3 by lazy { S3Client.builder().region(region).build() } - private val ec2 by lazy { Ec2Client.builder().region(region).build() } +class EC2TrainingScriptRunner( + private val bucketName: String, + private val instanceType: InstanceType, + private val region: Region? +) : TrainingScriptRunner, KoinComponent { + + private val s3 by lazy { S3Client.builder().apply { region?.let { region(it) } }.build() } + private val ec2 by lazy { Ec2Client.builder().apply { region?.let { region(it) } }.build() } private val nextScriptId = AtomicLong() private val instanceIds = mutableMapOf() diff --git a/aws/src/main/kotlin/edu/wpi/axon/aws/db/DynamoJobDB.kt b/aws/src/main/kotlin/edu/wpi/axon/aws/db/DynamoJobDB.kt deleted file mode 100644 index 25280e28..00000000 --- a/aws/src/main/kotlin/edu/wpi/axon/aws/db/DynamoJobDB.kt +++ /dev/null @@ -1,169 +0,0 @@ -package edu.wpi.axon.aws.db - -import arrow.core.Left -import arrow.core.Right -import arrow.fx.IO -import edu.wpi.axon.dbdata.Job -import edu.wpi.axon.dbdata.TrainingScriptProgress -import kotlinx.coroutines.delay -import org.koin.core.KoinComponent -import org.koin.core.inject -import software.amazon.awssdk.regions.Region -import software.amazon.awssdk.services.dynamodb.DynamoDbClient -import software.amazon.awssdk.services.dynamodb.model.AttributeDefinition -import software.amazon.awssdk.services.dynamodb.model.AttributeValue -import software.amazon.awssdk.services.dynamodb.model.KeySchemaElement -import software.amazon.awssdk.services.dynamodb.model.KeyType -import software.amazon.awssdk.services.dynamodb.model.ScalarAttributeType -import software.amazon.awssdk.services.dynamodb.model.TableStatus - -class DynamoJobDB( - private val tableName: String -) : JobDB, KoinComponent { - - private val region: Region by inject() - - private val dbClient: IO = IO { - DynamoDbClient.builder().region(region).build() - } - - override fun putJob(job: Job): IO = dbClient.flatMap { dbClient -> - ensureJobTable(dbClient).flatMap { - waitForTableStatus(dbClient, TableStatus.ACTIVE) - }.flatMap { - IO { - dbClient.putItem { - it.tableName(tableName) - .item( - mapOf( - KEY_JOB_NAME to AttributeValue.builder().s(job.name).build(), - KEY_DATA to AttributeValue.builder().s(job.serialize()).build() - ) - ) - } - - Unit - } - } - } - - override fun updateJobStatus(job: Job, newStatus: TrainingScriptProgress): IO = - dbClient.flatMap { dbClient -> - ensureJobTable(dbClient).flatMap { - waitForTableStatus(dbClient, TableStatus.ACTIVE) - }.flatMap { - IO { - val newJob = job.copy(status = newStatus) - - dbClient.updateItem { - it.tableName(tableName) - .key(mapOf(KEY_JOB_NAME to AttributeValue.builder().s(job.name).build())) - .updateExpression( - """SET $KEY_DATA = :newJobData""" - ) - .expressionAttributeValues( - mapOf(""":newJobData""" to AttributeValue.builder().s(newJob.serialize()).build()) - ) - } - - newJob - } - } - } - - override fun getJobWithName(name: String): IO = dbClient.flatMap { dbClient -> - ensureJobTable(dbClient).flatMap { - waitForTableStatus(dbClient, TableStatus.ACTIVE) - }.flatMap { - IO { - Job.deserialize( - dbClient.getItem { - it.tableName(tableName) - .key(mapOf(KEY_JOB_NAME to AttributeValue.builder().s(name).build())) - .projectionExpression(KEY_DATA) - }.item()[KEY_DATA]!!.s() - ) - } - } - } - - override fun getJobsWithStatus(status: TrainingScriptProgress): IO> = TODO() - - override fun deleteTable(): IO = dbClient.flatMap { dbClient -> - IO { - dbClient.deleteTable { - it.tableName(tableName) - } - - Unit - } - } - - /** - * Ensures the job table exists. - * - * @param client The client to use. - */ - private fun ensureJobTable(client: DynamoDbClient): IO { - return IO { - if (!client.listTables().tableNames().contains(tableName)) { - // Only create the table if it does not exist - client.createTable { - it.tableName(tableName) - .keySchema( - KeySchemaElement.builder() - .keyType(KeyType.HASH) - .attributeName(KEY_JOB_NAME) - .build() - ) - .attributeDefinitions( - AttributeDefinition.builder() - .attributeName(KEY_JOB_NAME) - .attributeType(ScalarAttributeType.S) - .build() - ) - .provisionedThroughput { - // TODO: Expose these - it.readCapacityUnits(1) - .writeCapacityUnits(1) - } - } - } - - Unit - } - } - - /** - * Waits (tail-recursively) for the [TableStatus] to equal a [desiredStatus]. - * - * @param dbClient The client to use. - * @param desiredStatus The desired status. - * @return Waits for the [desiredStatus]. - */ - private fun waitForTableStatus( - dbClient: DynamoDbClient, - desiredStatus: TableStatus - ): IO = - IO.tailRecM(dbClient) { i -> - IO { - val currentStatus = i.describeTable { - it.tableName(tableName) - }.table().tableStatus() - - if (currentStatus == desiredStatus) { - Right(i) - } else { - // Table is not at the status yet, so wait to check again - delay(500) - Left(i) - } - } - } - - companion object { - - const val KEY_JOB_NAME = "JobName" - const val KEY_DATA = "JobData" - } -} diff --git a/aws/src/main/kotlin/edu/wpi/axon/aws/db/JobDB.kt b/aws/src/main/kotlin/edu/wpi/axon/aws/db/JobDB.kt deleted file mode 100644 index e195264d..00000000 --- a/aws/src/main/kotlin/edu/wpi/axon/aws/db/JobDB.kt +++ /dev/null @@ -1,48 +0,0 @@ -package edu.wpi.axon.aws.db - -import arrow.fx.IO -import edu.wpi.axon.dbdata.Job -import edu.wpi.axon.dbdata.TrainingScriptProgress - -interface JobDB { - - /** - * Puts (create or replace) the [job] in the DB. - * - * @param job The new [Job]. - * @return An effect for continuation. - */ - fun putJob(job: Job): IO - - /** - * Updates the status of a job in the DB. - * - * @param job The [Job] to update (with the old status). - * @param newStatus The new status. - * @return The new [Job] (with the [newStatus]). - */ - fun updateJobStatus(job: Job, newStatus: TrainingScriptProgress): IO - - /** - * Gets a job with a [name]. - * - * @param name The name of the [Job]. - * @return The [Job]. - */ - fun getJobWithName(name: String): IO - - /** - * Gets all [Job]s with the [status]. - * - * @param status The status. - * @return The matching [Job]s. - */ - fun getJobsWithStatus(status: TrainingScriptProgress): IO> - - /** - * Deletes the table containing the jobs. - * - * @return An effect for continuation. - */ - fun deleteTable(): IO -} diff --git a/aws/src/main/kotlin/edu/wpi/axon/aws/db/RDSJobDBConfigurator.kt b/aws/src/main/kotlin/edu/wpi/axon/aws/db/RDSJobDBConfigurator.kt new file mode 100644 index 00000000..255b6406 --- /dev/null +++ b/aws/src/main/kotlin/edu/wpi/axon/aws/db/RDSJobDBConfigurator.kt @@ -0,0 +1,141 @@ +package edu.wpi.axon.aws.db + +import arrow.fx.IO +import arrow.fx.extensions.fx +import com.amazonaws.regions.Regions +import com.amazonaws.services.rds.AmazonRDSClient +import com.amazonaws.services.rds.model.CreateDBClusterRequest +import com.amazonaws.services.rds.model.DBCluster +import com.amazonaws.services.rds.model.DescribeDBClustersRequest +import com.amazonaws.services.rds.model.ScalingConfiguration +import com.amazonaws.services.rds.model.VpcSecurityGroupMembership +import mu.KotlinLogging +import org.koin.core.KoinComponent +import org.koin.core.inject +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.ec2.Ec2Client +import software.amazon.awssdk.services.ec2.model.IpPermission +import software.amazon.awssdk.services.ec2.model.IpRange +import software.amazon.awssdk.services.ec2.model.Ipv6Range + +/** + * Configures a Job DB in RDS. + */ +class RDSJobDBConfigurator : KoinComponent { + + private val regions: Regions by inject() + private val region: Region by inject() + private val ec2Client by lazy { Ec2Client.builder().region(region).build() } + private val rdsClient by lazy { AmazonRDSClient.builder().withRegion(regions).build() } + + private fun ensureDBCluster(): IO = IO { + val clusters = rdsClient.describeDBClusters( + DescribeDBClustersRequest() + .withDBClusterIdentifier(clusterId) + ).dbClusters + LOGGER.debug { "Got clusters: ${clusters.joinToString { it.dbClusterIdentifier }}" } + check(clusters.size == 1) + clusters.first() + }.redeemWith( + { + LOGGER.warn { "Did not find DB cluster with id $clusterId. Creating a new DB cluster." } + ensureSG().flatMap { sgId -> + IO { + rdsClient.createDBCluster( + CreateDBClusterRequest() + .withDBClusterIdentifier(clusterId) + .withEngine("aurora") + .withEngineVersion("5.6.10a") + .withEngineMode("serverless") + .withScalingConfiguration( + ScalingConfiguration().withMinCapacity(1) + .withMaxCapacity(2) + .withSecondsUntilAutoPause(300) + .withAutoPause(true) + ) + .withMasterUsername("axonusername") + .withMasterUserPassword("axonpassword") + .withEnableHttpEndpoint(true) + .withVpcSecurityGroupIds(sgId) + ) + } + } + }, + { IO.just(it) } + ) + + private fun ensureDBClusterHasCorrectSG(dbCluster: DBCluster): IO = IO.fx { + val sgId = ensureSG().bind() + val sg = VpcSecurityGroupMembership().withVpcSecurityGroupId(sgId).withStatus("active") + val vpcSecurityGroups = dbCluster.vpcSecurityGroups + if (!vpcSecurityGroups.any { it == sg }) { + // SG is not present, need to add it + dbCluster.setVpcSecurityGroups(vpcSecurityGroups + sg) + } + } + + /** + * Ensures the correct SecurityGroup exists. + * + * @return The SecurityGroup id. + */ + private fun ensureSG(): IO = IO { + ec2Client.describeSecurityGroups { + it.groupNames(rdsSecurityGroupName) + }.securityGroups() + }.flatMap { rdsGroups -> + IO { + rdsGroups.first { + it.groupName() == rdsSecurityGroupName + }.groupId() + } + }.redeemWith( + { + LOGGER.debug { "SecurityGroup $rdsSecurityGroupName not found. Creating one." } + createSG() + }, + { IO.just(it) } + ) + + private fun createSG(): IO = IO { + LOGGER.debug { "Creating new SecurityGroup $rdsSecurityGroupName." } + val sg = ec2Client.createSecurityGroup { + it.groupName(rdsSecurityGroupName) + .description("Axon autogenerated for RDS.") + }.groupId() + + ec2Client.authorizeSecurityGroupIngress { + it.groupId(sg) + .ipPermissions( + IpPermission.builder() + .fromPort(3306) + .toPort(3306) + .ipProtocol("TCP") + .ipRanges(IpRange.builder().cidrIp("0.0.0.0/0").build()) + .ipv6Ranges(Ipv6Range.builder().cidrIpv6("::/0").build()) + .build() + ) + } + + ec2Client.authorizeSecurityGroupEgress { + it.groupId(sg) + .ipPermissions( + IpPermission.builder() + .fromPort(-1) + .toPort(-1) + .ipProtocol("-1") + .ipRanges(IpRange.builder().cidrIp("0.0.0.0/0").build()) + .ipv6Ranges(Ipv6Range.builder().cidrIpv6("::/0").build()) + .build() + ) + } + + sg + } + + companion object { + private val LOGGER = KotlinLogging.logger { } + private const val clusterId = "axon-cluster" + private const val rdsSecurityGroupName = "axon-rds-autogenerated" + } +} diff --git a/aws/src/test/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunnerTest.kt b/aws/src/test/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunnerTest.kt index c6b26983..e594c1d7 100644 --- a/aws/src/test/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunnerTest.kt +++ b/aws/src/test/kotlin/edu/wpi/axon/aws/EC2TrainingScriptRunnerTest.kt @@ -1,28 +1,21 @@ package edu.wpi.axon.aws -import edu.wpi.axon.testutil.KoinTestFixture import org.junit.jupiter.api.Disabled import org.junit.jupiter.api.Test -import org.koin.core.context.startKoin -import org.koin.core.qualifier.named -import org.koin.dsl.module import software.amazon.awssdk.regions.Region import software.amazon.awssdk.services.ec2.model.InstanceType -internal class EC2TrainingScriptRunnerTest : KoinTestFixture() { +internal class EC2TrainingScriptRunnerTest { @Test @Disabled("Needs EC2 supervision.") fun `test running mnist training script`() { - startKoin { - modules(module { - single { Region.US_EAST_1 } - single(named("bucketName")) { "axon-salmon-testbucket2" } - single { InstanceType.T3_MEDIUM } - }) - } + val runner = EC2TrainingScriptRunner( + "axon-salmon-testbucket2", + InstanceType.T3_MEDIUM, + Region.US_EAST_1 + ) - val runner = EC2TrainingScriptRunner() runner.startScript( "custom_fashion_mnist.h5", "custom_fashion_mnist-trained.h5", diff --git a/aws/src/test/kotlin/edu/wpi/axon/aws/db/DynamoJobDBTest.kt b/aws/src/test/kotlin/edu/wpi/axon/aws/db/DynamoJobDBTest.kt deleted file mode 100644 index 3f21ddf2..00000000 --- a/aws/src/test/kotlin/edu/wpi/axon/aws/db/DynamoJobDBTest.kt +++ /dev/null @@ -1,132 +0,0 @@ -package edu.wpi.axon.aws.db - -import arrow.fx.IO -import edu.wpi.axon.dbdata.Job -import edu.wpi.axon.dbdata.TrainingScriptProgress -import edu.wpi.axon.dbdata.nextDataset -import edu.wpi.axon.testutil.KoinTestFixture -import edu.wpi.axon.tfdata.loss.Loss -import edu.wpi.axon.tfdata.optimizer.Optimizer -import io.kotlintest.assertions.arrow.either.shouldBeRight -import io.kotlintest.matchers.collections.shouldContain -import io.kotlintest.shouldBe -import kotlin.random.Random -import org.apache.commons.lang3.RandomStringUtils -import org.junit.jupiter.api.Disabled -import org.junit.jupiter.api.Test -import org.koin.core.context.startKoin -import org.koin.core.get -import org.koin.dsl.module -import software.amazon.awssdk.regions.Region -import software.amazon.awssdk.services.dynamodb.DynamoDbClient -import software.amazon.awssdk.services.dynamodb.model.AttributeValue - -internal class DynamoJobDBTest : KoinTestFixture() { - - @Test - @Disabled("Needs DynamoDB supervision.") - fun `create new job`() { - startKoin { - modules(module { - single { Region.US_EAST_1 } - }) - } - - withRandomTable { db, tableName -> - IO { - val job = Random.nextJob() - - val result = db.putJob(job).attempt().unsafeRunSync() - - result.shouldBeRight { - val dbClient = DynamoDbClient.builder().region(get()).build() - dbClient.listTables().tableNames().shouldContain(tableName) - dbClient.getItem { - it.tableName(tableName) - .key( - mapOf( - DynamoJobDB.KEY_JOB_NAME to AttributeValue.builder() - .s(job.name) - .build() - ) - ) - }.item().let { jobFromDB -> - jobFromDB[DynamoJobDB.KEY_JOB_NAME]!!.s().shouldBe(job.name) - jobFromDB[DynamoJobDB.KEY_DATA]!!.s().let { jobData -> - jobData.shouldBe(job.serialize()) - Job.deserialize(jobData).shouldBe(job) - } - } - } - } - } - } - - @Test - @Disabled("Needs DynamoDB supervision.") - fun `update job status`() { - startKoin { - modules(module { - single { Region.US_EAST_1 } - }) - } - - withRandomTable { db, _ -> - IO { - val job = Random.nextJob() - - val result = db.putJob(job).attempt().unsafeRunSync() - - result.shouldBeRight { - db.getJobWithName(job.name).map { - it.shouldBe(job) - }.flatMap { - db.updateJobStatus(job, TrainingScriptProgress.NotStarted) - }.map { - it.shouldBe(job.copy(status = TrainingScriptProgress.NotStarted)) - }.unsafeRunSync() - } - } - } - } - - /** - * Runs the [testBody] with a [DynamoJobDB] and a random, newly created table. Deletes the table - * when finished. - * - * @param testBody The body of the test method. Given a new [DynamoJobDB] and a table name. - */ - private fun withRandomTable(testBody: (DynamoJobDB, String) -> IO) { - val tableName = RandomStringUtils.randomAlphanumeric(10) - IO { - DynamoJobDB(tableName) - }.bracket( - release = { - it.deleteTable() - }, - use = { testBody(it, tableName) } - ).unsafeRunSync() - } - - private fun Random.nextJob() = Job( - RandomStringUtils.randomAlphanumeric(10), - TrainingScriptProgress.Completed, - RandomStringUtils.randomAlphanumeric(10), - RandomStringUtils.randomAlphanumeric(10), - nextDataset(), - Optimizer.Adam( - nextDouble(), - nextDouble(), - nextDouble(), - nextDouble(), - nextBoolean() - ), - Loss.SparseCategoricalCrossentropy, - setOf( - RandomStringUtils.randomAlphanumeric(10), - RandomStringUtils.randomAlphanumeric(10) - ), - nextInt(), - nextBoolean() - ) -} diff --git a/axon.gradle.kts b/axon.gradle.kts index 2d2c771d..ba8ed493 100644 --- a/axon.gradle.kts +++ b/axon.gradle.kts @@ -22,6 +22,7 @@ plugins { } val awsProject = project(":aws") +val dbProject = project(":db") val dbDataProject = project(":db-data") val dbDataTestUtilProject = project(":db-data-test-util") val dslProject = project(":dsl") @@ -38,10 +39,12 @@ val tfLayerLoaderProject = project(":tf-layer-loader") val uiElectronProject = project(":ui-electron") val uiVaadinProject = project(":ui-vaadin") val trainingProject = project(":training") +val trainingTestUtilProject = project(":training-test-util") val utilProject = project(":util") val kotlinProjects = setOf( awsProject, + dbProject, dbDataProject, dbDataTestUtilProject, dslProject, @@ -58,6 +61,7 @@ val kotlinProjects = setOf( uiElectronProject, uiVaadinProject, trainingProject, + trainingTestUtilProject, utilProject ) @@ -65,6 +69,7 @@ val javaProjects = setOf() + kotlinProjects val publishedProjects = setOf( awsProject, + dbProject, dbDataProject, dslProject, dslInterfaceProject, @@ -126,6 +131,7 @@ allprojects { maven("https://oss.sonatype.org/content/repositories/staging/") maven("https://dl.bintray.com/arrow-kt/arrow-kt/") maven("https://dl.bintray.com/jamesmudd/jhdf") + maven("https://dl.bintray.com/kotlin/exposed") maven("https://dl.bintray.com/octogonapus/maven-artifacts") } diff --git a/db-data-test-util/src/main/kotlin/edu/wpi/axon/dbdata/TestUtil.kt b/db-data-test-util/src/main/kotlin/edu/wpi/axon/dbdata/TestUtil.kt index 08c70c0e..b3146e5c 100644 --- a/db-data-test-util/src/main/kotlin/edu/wpi/axon/dbdata/TestUtil.kt +++ b/db-data-test-util/src/main/kotlin/edu/wpi/axon/dbdata/TestUtil.kt @@ -10,6 +10,9 @@ fun Random.nextDataset(): Dataset { it[nextInt(it.size)].objectInstance!! } } else { - Dataset.Custom(RandomStringUtils.randomAlphanumeric(20)) + Dataset.Custom( + RandomStringUtils.randomAlphanumeric(20), + RandomStringUtils.randomAlphanumeric(20) + ) } } diff --git a/db-data/db-data.gradle.kts b/db-data/db-data.gradle.kts index dbca9280..5ee008b3 100644 --- a/db-data/db-data.gradle.kts +++ b/db-data/db-data.gradle.kts @@ -4,4 +4,6 @@ dependencies { api(project(":tf-data")) implementation(project(":util")) + + testImplementation(project(":test-util")) } diff --git a/db-data/src/main/kotlin/edu/wpi/axon/dbdata/Job.kt b/db-data/src/main/kotlin/edu/wpi/axon/dbdata/Job.kt index 4397b8b9..a7449874 100644 --- a/db-data/src/main/kotlin/edu/wpi/axon/dbdata/Job.kt +++ b/db-data/src/main/kotlin/edu/wpi/axon/dbdata/Job.kt @@ -23,7 +23,8 @@ data class Job( @Polymorphic val userLoss: Loss, val userMetrics: Set, val userEpochs: Int, - val generateDebugComments: Boolean + val generateDebugComments: Boolean, + val id: Int = -1 ) { fun serialize(): String = Json( diff --git a/db-data/src/main/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgress.kt b/db-data/src/main/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgress.kt index cd9eee85..635c11de 100644 --- a/db-data/src/main/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgress.kt +++ b/db-data/src/main/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgress.kt @@ -1,13 +1,16 @@ package edu.wpi.axon.dbdata import edu.wpi.axon.util.ObjectSerializer +import kotlinx.serialization.Polymorphic import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonConfiguration import kotlinx.serialization.modules.SerializersModule /** * The states a training script can be in. */ -sealed class TrainingScriptProgress { +sealed class TrainingScriptProgress : Comparable { /** * The script has not been started yet. @@ -26,6 +29,27 @@ sealed class TrainingScriptProgress { * The training is finished. */ object Completed : TrainingScriptProgress() + + fun serialize(): String = Json( + JsonConfiguration.Stable, + context = trainingScriptProgressModule + ).stringify(PolymorphicWrapper.serializer(), PolymorphicWrapper(this)) + + override fun compareTo(other: TrainingScriptProgress): Int { + return COMPARATOR.compare(this, other) + } + + companion object { + fun deserialize(data: String): TrainingScriptProgress = Json( + JsonConfiguration.Stable, + context = trainingScriptProgressModule + ).parse(PolymorphicWrapper.serializer(), data).wrapped + + private val COMPARATOR = Comparator.comparing { it.ordinal() } + } + + @Serializable + private data class PolymorphicWrapper(@Polymorphic val wrapped: TrainingScriptProgress) } val trainingScriptProgressModule = SerializersModule { @@ -44,3 +68,6 @@ val trainingScriptProgressModule = SerializersModule { ) } } + +inline fun T.ordinal() = + T::class.java.superclass.classes.indexOfFirst { sub -> sub == this@ordinal::class.java } diff --git a/db-data/src/test/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgressTest.kt b/db-data/src/test/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgressTest.kt new file mode 100644 index 00000000..37667f9f --- /dev/null +++ b/db-data/src/test/kotlin/edu/wpi/axon/dbdata/TrainingScriptProgressTest.kt @@ -0,0 +1,11 @@ +package edu.wpi.axon.dbdata + +import io.kotlintest.shouldBe +import org.junit.jupiter.api.Test + +internal class TrainingScriptProgressTest { + @Test + fun `test serialize`() { + TrainingScriptProgress.deserialize(TrainingScriptProgress.Completed.serialize()).shouldBe(TrainingScriptProgress.Completed) + } +} diff --git a/db/db.gradle.kts b/db/db.gradle.kts new file mode 100644 index 00000000..33b0f940 --- /dev/null +++ b/db/db.gradle.kts @@ -0,0 +1,23 @@ +description = "Classes to interact with a database" + +dependencies { + api(project(":db-data")) + api(project(":tf-data")) + + implementation(project(":util")) + + testImplementation(project(":test-util")) + testImplementation(project(":db-data-test-util")) + + implementation( + group = "org.jetbrains.exposed", + name = "exposed", + version = property("exposed.version") as String + ) + + implementation( + group = "com.beust", + name = "klaxon", + version = property("klaxon.version") as String + ) +} diff --git a/db/src/main/kotlin/edu/wpi/axon/db/JobDb.kt b/db/src/main/kotlin/edu/wpi/axon/db/JobDb.kt new file mode 100644 index 00000000..e20be7a0 --- /dev/null +++ b/db/src/main/kotlin/edu/wpi/axon/db/JobDb.kt @@ -0,0 +1,91 @@ +package edu.wpi.axon.db + +import com.beust.klaxon.Klaxon +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.dbdata.TrainingScriptProgress +import edu.wpi.axon.tfdata.Dataset +import edu.wpi.axon.tfdata.loss.Loss +import edu.wpi.axon.tfdata.optimizer.Optimizer +import org.jetbrains.exposed.dao.IntIdTable +import org.jetbrains.exposed.sql.Database +import org.jetbrains.exposed.sql.ResultRow +import org.jetbrains.exposed.sql.SchemaUtils +import org.jetbrains.exposed.sql.deleteWhere +import org.jetbrains.exposed.sql.insertAndGetId +import org.jetbrains.exposed.sql.select +import org.jetbrains.exposed.sql.selectAll +import org.jetbrains.exposed.sql.transactions.transaction + +private val klaxon = Klaxon() + +internal object Jobs : IntIdTable() { + val name = varchar("name", 255).uniqueIndex() + val status = varchar("status", 255) + val userOldModelPath = varchar("userOldModelPath", 255) + val userNewModelName = varchar("userNewModelName", 255) + val userDataset = varchar("dataset", 255) + val userOptimizer = varchar("userOptimizer", 255) + val userLoss = varchar("userLoss", 255) + val userMetrics = varchar("userMetrics", 255) + val userEpochs = integer("userEpochs") + val generateDebugComments = bool("generateDebugComments") + + fun toDomain(row: ResultRow): Job { + return Job( + name = row[name], + status = TrainingScriptProgress.deserialize(row[status]), + userOldModelPath = row[userOldModelPath], + userNewModelName = row[userNewModelName], + userDataset = Dataset.deserialize(row[userDataset]), + userOptimizer = Optimizer.deserialize(row[userOptimizer]), + userLoss = Loss.deserialize(row[userLoss]), + userMetrics = klaxon.parseArray(row[userMetrics])!!.toSet(), + userEpochs = row[userEpochs], + generateDebugComments = row[generateDebugComments], + id = row[id].value + ) + } +} + +class JobDb(private val database: Database) { + init { + transaction(database) { + SchemaUtils.create(Jobs) + } + } + + fun create(job: Job): Int? = transaction(database) { + Jobs.insertAndGetId { row -> + row[name] = job.name + row[status] = job.status.serialize() + row[userOldModelPath] = job.userOldModelPath + row[userNewModelName] = job.userNewModelName + row[userDataset] = job.userDataset.serialize() + row[userOptimizer] = job.userOptimizer.serialize() + row[userLoss] = job.userLoss.serialize() + row[userMetrics] = klaxon.toJsonString(job.userMetrics) + row[userEpochs] = job.userEpochs + row[generateDebugComments] = job.generateDebugComments + }.value + } + + fun count(): Int = transaction(database) { + Jobs.selectAll().count() + } + + fun fetch(limit: Int, offset: Int): List = transaction(database) { + Jobs.selectAll() + .limit(limit, offset) + .map { Jobs.toDomain(it) } + } + + fun findByName(name: String): Job? = transaction(database) { + Jobs.select { Jobs.name eq name } + .map { Jobs.toDomain(it) } + .firstOrNull() + } + + fun remove(id: Int): Int? = transaction(database) { + Jobs.deleteWhere { Jobs.id eq id } + } +} diff --git a/db/src/test/kotlin/edu/wpi/axon/db/JobDbTest.kt b/db/src/test/kotlin/edu/wpi/axon/db/JobDbTest.kt new file mode 100644 index 00000000..410f3c94 --- /dev/null +++ b/db/src/test/kotlin/edu/wpi/axon/db/JobDbTest.kt @@ -0,0 +1,99 @@ +package edu.wpi.axon.db + +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.dbdata.TrainingScriptProgress +import edu.wpi.axon.dbdata.nextDataset +import edu.wpi.axon.tfdata.loss.Loss +import edu.wpi.axon.tfdata.optimizer.Optimizer +import io.kotlintest.matchers.collections.shouldContainExactly +import io.kotlintest.matchers.nulls.shouldBeNull +import io.kotlintest.matchers.nulls.shouldNotBeNull +import io.kotlintest.shouldBe +import java.io.File +import java.nio.file.Paths +import kotlin.random.Random +import org.apache.commons.lang3.RandomStringUtils +import org.jetbrains.exposed.sql.Database +import org.jetbrains.exposed.sql.select +import org.jetbrains.exposed.sql.transactions.transaction +import org.junit.jupiter.api.Test +import org.junit.jupiter.api.io.TempDir + +internal class JobDbTest { + @Test + fun `create test`(@TempDir tempDir: File) { + val db = createDb(tempDir) + val job = Random.nextJob() + + val id = db.create(job) + id.shouldNotBeNull() + + transaction { + Jobs.select { Jobs.name eq job.name } + .map { Jobs.toDomain(it) } + .shouldContainExactly(job.copy(id = id)) + } + } + + @Test + fun `find by name test`(@TempDir tempDir: File) { + val db = createDb(tempDir) + val job = Random.nextJob() + + val id = db.create(job)!! + + db.findByName(job.name).shouldBe(job.copy(id = id)) + } + + @Test + fun `count test`(@TempDir tempDir: File) { + val db = createDb(tempDir) + + db.count().shouldBe(0) + + db.create(Random.nextJob()) + + db.count().shouldBe(1) + } + + @Test + fun `remove test`(@TempDir tempDir: File) { + val db = createDb(tempDir) + val job = Random.nextJob() + + val id = db.create(job)!! + + db.remove(id).shouldBe(id) + db.findByName(job.name).shouldBeNull() + } + + private fun createDb(tempDir: File) = JobDb( + Database.connect( + url = "jdbc:h2:file:${Paths.get(tempDir.absolutePath, "test.db")}", + driver = "org.h2.Driver" + ) + ) + + private fun Random.nextJob() = Job( + RandomStringUtils.randomAlphanumeric(10), + TrainingScriptProgress.Completed, + RandomStringUtils.randomAlphanumeric(10), + RandomStringUtils.randomAlphanumeric(10), + nextDataset(), + Optimizer.Adam( + nextDouble(), + nextDouble(), + nextDouble(), + nextDouble(), + nextBoolean() + ), + Loss.SparseCategoricalCrossentropy, + setOf( + RandomStringUtils.randomAlphanumeric(10), + RandomStringUtils.randomAlphanumeric(10) + ), + nextInt(), + nextBoolean(), + -1 + ) +} diff --git a/gradle.properties b/gradle.properties index 5504e9f3..4609a5f2 100644 --- a/gradle.properties +++ b/gradle.properties @@ -5,7 +5,7 @@ kotlin.parallel.tasks.in.project=true axon.version=0.1.0 kotlinVersion=1.3.50 -gradle-wrapper.version=6.0 +gradle-wrapper.version=5.6.4 spotlessPluginVersion=3.26.0 ktlintPluginVersion=9.1.1 @@ -46,3 +46,5 @@ kt-guava.version=0.2.6 commons-lang3.version=3.9 commons-io.version=1.3.2 aws-sdk-java.version=2.10.12 +exposed.version=0.17.7 +mysql-connector-java.version=5.1.46 \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar index 5c2d1cf0..cc4fdc29 100644 Binary files a/gradle/wrapper/gradle-wrapper.jar and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 562e2c88..0ebb3108 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,5 +1,5 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-6.0-all.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-5.6.4-all.zip zipStoreBase=GRADLE_USER_HOME zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew index 83f2acfd..2fe81a7d 100755 --- a/gradlew +++ b/gradlew @@ -154,19 +154,19 @@ if [ "$cygwin" = "true" -o "$msys" = "true" ] ; then else eval `echo args$i`="\"$arg\"" fi - i=$((i+1)) + i=`expr $i + 1` done case $i in - (0) set -- ;; - (1) set -- "$args0" ;; - (2) set -- "$args0" "$args1" ;; - (3) set -- "$args0" "$args1" "$args2" ;; - (4) set -- "$args0" "$args1" "$args2" "$args3" ;; - (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; - (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; - (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; - (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; - (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + 0) set -- ;; + 1) set -- "$args0" ;; + 2) set -- "$args0" "$args1" ;; + 3) set -- "$args0" "$args1" "$args2" ;; + 4) set -- "$args0" "$args1" "$args2" "$args3" ;; + 5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + 6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + 7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + 8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + 9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; esac fi @@ -175,14 +175,9 @@ save () { for i do printf %s\\n "$i" | sed "s/'/'\\\\''/g;1s/^/'/;\$s/\$/' \\\\/" ; done echo " " } -APP_ARGS=$(save "$@") +APP_ARGS=`save "$@"` # Collect all arguments for the java command, following the shell quoting and substitution rules eval set -- $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS "\"-Dorg.gradle.appname=$APP_BASE_NAME\"" -classpath "\"$CLASSPATH\"" org.gradle.wrapper.GradleWrapperMain "$APP_ARGS" -# by default we should be in the correct project dir, but when run from Finder on Mac, the cwd is wrong -if [ "$(uname)" = "Darwin" ] && [ "$HOME" = "$PWD" ]; then - cd "$(dirname "$0")" -fi - exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat index 9618d8d9..24467a14 100644 --- a/gradlew.bat +++ b/gradlew.bat @@ -1,100 +1,100 @@ -@rem -@rem Copyright 2015 the original author or authors. -@rem -@rem Licensed under the Apache License, Version 2.0 (the "License"); -@rem you may not use this file except in compliance with the License. -@rem You may obtain a copy of the License at -@rem -@rem https://www.apache.org/licenses/LICENSE-2.0 -@rem -@rem Unless required by applicable law or agreed to in writing, software -@rem distributed under the License is distributed on an "AS IS" BASIS, -@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -@rem See the License for the specific language governing permissions and -@rem limitations under the License. -@rem - -@if "%DEBUG%" == "" @echo off -@rem ########################################################################## -@rem -@rem Gradle startup script for Windows -@rem -@rem ########################################################################## - -@rem Set local scope for the variables with windows NT shell -if "%OS%"=="Windows_NT" setlocal - -set DIRNAME=%~dp0 -if "%DIRNAME%" == "" set DIRNAME=. -set APP_BASE_NAME=%~n0 -set APP_HOME=%DIRNAME% - -@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. -set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" - -@rem Find java.exe -if defined JAVA_HOME goto findJavaFromJavaHome - -set JAVA_EXE=java.exe -%JAVA_EXE% -version >NUL 2>&1 -if "%ERRORLEVEL%" == "0" goto init - -echo. -echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:findJavaFromJavaHome -set JAVA_HOME=%JAVA_HOME:"=% -set JAVA_EXE=%JAVA_HOME%/bin/java.exe - -if exist "%JAVA_EXE%" goto init - -echo. -echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% -echo. -echo Please set the JAVA_HOME variable in your environment to match the -echo location of your Java installation. - -goto fail - -:init -@rem Get command-line arguments, handling Windows variants - -if not "%OS%" == "Windows_NT" goto win9xME_args - -:win9xME_args -@rem Slurp the command line arguments. -set CMD_LINE_ARGS= -set _SKIP=2 - -:win9xME_args_slurp -if "x%~1" == "x" goto execute - -set CMD_LINE_ARGS=%* - -:execute -@rem Setup the command line - -set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar - -@rem Execute Gradle -"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% - -:end -@rem End local scope for the variables with windows NT shell -if "%ERRORLEVEL%"=="0" goto mainEnd - -:fail -rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of -rem the _cmd.exe /c_ return code! -if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 -exit /b 1 - -:mainEnd -if "%OS%"=="Windows_NT" endlocal - -:omega +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. +@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windows variants + +if not "%OS%" == "Windows_NT" goto win9xME_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/settings.gradle.kts b/settings.gradle.kts index 8057b1db..7dcfa76b 100644 --- a/settings.gradle.kts +++ b/settings.gradle.kts @@ -31,6 +31,7 @@ pluginManagement { rootProject.name = "axon" include(":aws") +include(":db") include(":db-data") include(":db-data-test-util") include(":dsl") @@ -47,6 +48,7 @@ include(":tf-layer-loader") include(":ui-electron") include(":ui-vaadin") include(":training") +include(":training-test-util") include(":util") /** diff --git a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/Dataset.kt b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/Dataset.kt index 40ab1def..2c40514d 100644 --- a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/Dataset.kt +++ b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/Dataset.kt @@ -1,23 +1,46 @@ package edu.wpi.axon.tfdata import edu.wpi.axon.util.ObjectSerializer +import kotlinx.serialization.Polymorphic import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonConfiguration import kotlinx.serialization.modules.SerializersModule -sealed class Dataset { +sealed class Dataset : Comparable { + abstract val displayName: String - sealed class ExampleDataset(val name: String) : Dataset() { - object BostonHousing : ExampleDataset("boston_housing") - object Cifar10 : ExampleDataset("cifar10") - object Cifar100 : ExampleDataset("cifar100") - object FashionMnist : ExampleDataset("fashion_mnist") - object IMDB : ExampleDataset("imdb") - object Mnist : ExampleDataset("mnist") - object Reuters : ExampleDataset("reuters") + sealed class ExampleDataset(val name: String, override val displayName: String) : Dataset() { + object BostonHousing : ExampleDataset("boston_housing", "Boston Housing") + object Cifar10 : ExampleDataset("cifar10", "CIFAR-10") + object Cifar100 : ExampleDataset("cifar100", "CIFAR-100") + object FashionMnist : ExampleDataset("fashion_mnist", "Fashion MNIST") + object IMDB : ExampleDataset("imdb", "IMBD") + object Mnist : ExampleDataset("mnist", "MNIST") + object Reuters : ExampleDataset("reuters", "Reuters") } @Serializable - data class Custom(val pathInS3: String) : Dataset() + data class Custom(val pathInS3: String, override val displayName: String) : Dataset() + + override fun compareTo(other: Dataset) = COMPARATOR.compare(this, other) + + companion object { + private val COMPARATOR = Comparator.comparing { it.displayName } + + fun deserialize(data: String): Dataset = Json( + JsonConfiguration.Stable, + context = datasetModule + ).parse(PolymorphicWrapper.serializer(), data).wrapped + } + + fun serialize(): String = Json( + JsonConfiguration.Stable, + context = datasetModule + ).stringify(PolymorphicWrapper.serializer(), PolymorphicWrapper(this)) + + @Serializable + private data class PolymorphicWrapper(@Polymorphic val wrapped: Dataset) } val datasetModule = SerializersModule { diff --git a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/TrainState.kt b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/TrainState.kt deleted file mode 100644 index 8ecefd6a..00000000 --- a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/TrainState.kt +++ /dev/null @@ -1,27 +0,0 @@ -package edu.wpi.axon.tfdata - -import edu.wpi.axon.tfdata.loss.Loss -import edu.wpi.axon.tfdata.optimizer.Optimizer - -/** - * All the data needed to train a model. - * - * @param modelPath The path to the model file. - * @param dataset The dataset to train on. - * @param optimizer The [Optimizer] to use. - * @param loss The [Loss] function to use. - * @param metrics Any metrics. - * @param epochs The number of epochs. - * @param newModel The new model. - * @param generateDebugComments Whether to put debug comments in the output. - */ -data class TrainState( - val modelPath: String, - val dataset: Dataset, - val optimizer: Optimizer, - val loss: Loss, - val metrics: Set, - val epochs: Int, - val newModel: Model, - val generateDebugComments: Boolean = false -) diff --git a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/loss/Loss.kt b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/loss/Loss.kt index 8771a019..87531894 100644 --- a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/loss/Loss.kt +++ b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/loss/Loss.kt @@ -1,11 +1,30 @@ package edu.wpi.axon.tfdata.loss import edu.wpi.axon.util.ObjectSerializer +import kotlinx.serialization.Polymorphic +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonConfiguration import kotlinx.serialization.modules.SerializersModule sealed class Loss { object SparseCategoricalCrossentropy : Loss() + + fun serialize(): String = Json( + JsonConfiguration.Stable, + context = lossModule + ).stringify(PolymorphicWrapper.serializer(), PolymorphicWrapper(this)) + + companion object { + fun deserialize(data: String): Loss = Json( + JsonConfiguration.Stable, + context = lossModule + ).parse(PolymorphicWrapper.serializer(), data).wrapped + } + + @Serializable + private data class PolymorphicWrapper(@Polymorphic val wrapped: Loss) } val lossModule = SerializersModule { diff --git a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/optimizer/Optimizer.kt b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/optimizer/Optimizer.kt index d563534e..9530adcf 100644 --- a/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/optimizer/Optimizer.kt +++ b/tf-data/src/main/kotlin/edu/wpi/axon/tfdata/optimizer/Optimizer.kt @@ -1,6 +1,9 @@ package edu.wpi.axon.tfdata.optimizer +import kotlinx.serialization.Polymorphic import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import kotlinx.serialization.json.JsonConfiguration import kotlinx.serialization.modules.SerializersModule sealed class Optimizer { @@ -13,6 +16,21 @@ sealed class Optimizer { val epsilon: Double, val amsGrad: Boolean ) : Optimizer() + + fun serialize(): String = Json( + JsonConfiguration.Stable, + context = optimizerModule + ).stringify(PolymorphicWrapper.serializer(), PolymorphicWrapper(this)) + + companion object { + fun deserialize(data: String): Optimizer = Json( + JsonConfiguration.Stable, + context = optimizerModule + ).parse(PolymorphicWrapper.serializer(), data).wrapped + } + + @Serializable + private data class PolymorphicWrapper(@Polymorphic val wrapped: Optimizer) } val optimizerModule = SerializersModule { diff --git a/tf-layer-loader/src/main/kotlin/edu/wpi/axon/tflayerloader/LoadLayersFromHDF5.kt b/tf-layer-loader/src/main/kotlin/edu/wpi/axon/tflayerloader/LoadLayersFromHDF5.kt index 494f2525..3f492755 100644 --- a/tf-layer-loader/src/main/kotlin/edu/wpi/axon/tflayerloader/LoadLayersFromHDF5.kt +++ b/tf-layer-loader/src/main/kotlin/edu/wpi/axon/tflayerloader/LoadLayersFromHDF5.kt @@ -116,11 +116,9 @@ class LoadLayersFromHDF5( // Don't wrap a MetaLayer more than once is Layer.MetaLayer -> layer - else -> { - when (val trainable = json["trainable"] as Boolean?) { - null -> layer.untrainable() - else -> layer.trainable(trainable) - } + else -> when (val trainable = json["trainable"] as Boolean?) { + null -> layer.untrainable() + else -> layer.trainable(trainable) } } } diff --git a/training/src/test/kotlin/edu/wpi/axon/training/TrainTestUtil.kt b/training-test-util/src/main/kotlin/edu/wpi/axon/training/testutil/TrainTestUtil.kt similarity index 92% rename from training/src/test/kotlin/edu/wpi/axon/training/TrainTestUtil.kt rename to training-test-util/src/main/kotlin/edu/wpi/axon/training/testutil/TrainTestUtil.kt index 23c304af..5ea37586 100644 --- a/training/src/test/kotlin/edu/wpi/axon/training/TrainTestUtil.kt +++ b/training-test-util/src/main/kotlin/edu/wpi/axon/training/testutil/TrainTestUtil.kt @@ -1,4 +1,4 @@ -package edu.wpi.axon.training +package edu.wpi.axon.training.testutil import arrow.core.Tuple3 import arrow.fx.IO @@ -19,10 +19,10 @@ private val LOGGER = KotlinLogging.logger("training-test-util") * Loads a model with name [modelName] from the test resources. * * @param modelName The name of the model. - * @param stub Used to get the calling class. Do not use this parameter. + * @param stub Used to get the calling class. Just supply an empty lambda. * @return The model and its path. */ -internal fun loadModel(modelName: String, stub: () -> Unit = {}): Pair { +fun loadModel(modelName: String, stub: () -> Unit): Pair { val localModelPath = Paths.get(stub::class.java.getResource(modelName).toURI()).toString() val layers = LoadLayersFromHDF5(DefaultLayersToGraph()) .load(File(localModelPath)) @@ -41,7 +41,7 @@ internal fun loadModel(modelName: String, stub: () -> Unit = {}): Pair, env: Map, dir: File diff --git a/training-test-util/training-test-util.gradle.kts b/training-test-util/training-test-util.gradle.kts new file mode 100644 index 00000000..e5f25d58 --- /dev/null +++ b/training-test-util/training-test-util.gradle.kts @@ -0,0 +1,9 @@ +description = "Utilities for testing code that trains models." + +dependencies { + api(project(":test-util")) + api(project(":dsl")) + api(project(":tf-layer-loader")) + + implementation(project(":logging")) +} diff --git a/training/src/main/kotlin/edu/wpi/axon/training/TrainGeneral.kt b/training/src/main/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGenerator.kt similarity index 92% rename from training/src/main/kotlin/edu/wpi/axon/training/TrainGeneral.kt rename to training/src/main/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGenerator.kt index 50421a43..f72387e8 100644 --- a/training/src/main/kotlin/edu/wpi/axon/training/TrainGeneral.kt +++ b/training/src/main/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGenerator.kt @@ -20,9 +20,9 @@ import java.io.File * * @param trainState The train state to pull all the configuration data from. */ -class TrainGeneral( - private val trainState: TrainState -) { +class TrainGeneralModelScriptGenerator( + override val trainState: TrainState +) : TrainModelScriptGenerator { init { require(trainState.userOldModelName != trainState.userNewModelName) { @@ -34,7 +34,7 @@ class TrainGeneral( private val loadLayersFromHDF5 = LoadLayersFromHDF5(DefaultLayersToGraph()) @Suppress("UNUSED_VARIABLE") - fun generateScript(): Validated, String> = + override fun generateScript(): Validated, String> = loadLayersFromHDF5.load(File(trainState.userOldModelPath)).map { userOldModel -> require(userOldModel is Model.General) diff --git a/training/src/main/kotlin/edu/wpi/axon/training/TrainModelScriptGenerator.kt b/training/src/main/kotlin/edu/wpi/axon/training/TrainModelScriptGenerator.kt new file mode 100644 index 00000000..ab14c49d --- /dev/null +++ b/training/src/main/kotlin/edu/wpi/axon/training/TrainModelScriptGenerator.kt @@ -0,0 +1,23 @@ +package edu.wpi.axon.training + +import arrow.core.NonEmptyList +import arrow.core.Validated +import edu.wpi.axon.tfdata.Model + +/** + * Trains a [Model]. + */ +interface TrainModelScriptGenerator { + + /** + * The train state to pull all the configuration data from. + */ + val trainState: TrainState + + /** + * Generates a script that trains a [Model]. + * + * @return The script or a nel of errors. + */ + fun generateScript(): Validated, String> +} diff --git a/training/src/main/kotlin/edu/wpi/axon/training/TrainSequential.kt b/training/src/main/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGenerator.kt similarity index 92% rename from training/src/main/kotlin/edu/wpi/axon/training/TrainSequential.kt rename to training/src/main/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGenerator.kt index b0d33422..7b801d6d 100644 --- a/training/src/main/kotlin/edu/wpi/axon/training/TrainSequential.kt +++ b/training/src/main/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGenerator.kt @@ -21,9 +21,9 @@ import java.io.File * @param trainState The train state to pull all the configuration data from. */ @Suppress("UNUSED_VARIABLE") -class TrainSequential( - private val trainState: TrainState -) { +class TrainSequentialModelScriptGenerator( + override val trainState: TrainState +) : TrainModelScriptGenerator { init { require(trainState.userOldModelName != trainState.userNewModelName) { @@ -34,7 +34,7 @@ class TrainSequential( private val loadLayersFromHDF5 = LoadLayersFromHDF5(DefaultLayersToGraph()) - fun generateScript(): Validated, String> = + override fun generateScript(): Validated, String> = loadLayersFromHDF5.load(File(trainState.userOldModelPath)).map { oldModel -> require(oldModel is Model.Sequential) diff --git a/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-14-IntegrationTest.kt b/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-14-IntegrationTest.kt index 8efe0667..a64aa204 100644 --- a/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-14-IntegrationTest.kt +++ b/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-14-IntegrationTest.kt @@ -8,17 +8,16 @@ import edu.wpi.axon.tfdata.Dataset import edu.wpi.axon.tfdata.Model import edu.wpi.axon.tfdata.loss.Loss import edu.wpi.axon.tfdata.optimizer.Optimizer +import edu.wpi.axon.training.testutil.loadModel import io.kotlintest.assertions.arrow.validation.shouldBeValid import io.kotlintest.matchers.types.shouldBeInstanceOf -import java.io.File import org.junit.jupiter.api.Test -import org.junit.jupiter.api.io.TempDir import org.koin.core.context.startKoin internal class `Mobilenet-v-1-14-IntegrationTest` : KoinTestFixture() { @Test - fun `test with mobilenet`(@TempDir tempDir: File) { + fun `test with mobilenet`() { startKoin { modules(defaultModule()) } @@ -27,9 +26,9 @@ internal class `Mobilenet-v-1-14-IntegrationTest` : KoinTestFixture() { val modelName = "mobilenetv2_1.00_224.h5" val newModelName = "mobilenetv2_1.00_224-trained.h5" - val (model, path) = loadModel(modelName) + val (model, path) = loadModel(modelName) {} model.shouldBeInstanceOf { - TrainGeneral( + TrainGeneralModelScriptGenerator( TrainState( userOldModelPath = path, userNewModelName = newModelName, diff --git a/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-15-IntegrationTest.kt b/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-15-IntegrationTest.kt index 65d5e561..8795bbab 100644 --- a/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-15-IntegrationTest.kt +++ b/training/src/test/kotlin/edu/wpi/axon/training/Mobilenet-v-1-15-IntegrationTest.kt @@ -9,6 +9,7 @@ import edu.wpi.axon.tfdata.Model import edu.wpi.axon.tfdata.layer.Activation import edu.wpi.axon.tfdata.layer.DataFormat import edu.wpi.axon.tfdata.layer.Layer +import edu.wpi.axon.training.testutil.loadModel import io.kotlintest.matchers.collections.shouldHaveSize import io.kotlintest.matchers.types.shouldBeInstanceOf import io.kotlintest.shouldBe @@ -24,7 +25,7 @@ internal class `Mobilenet-v-1-15-IntegrationTest` : KoinTestFixture() { } val modelName = "mobilenet_tf_1_15_0.h5" - val (model, _) = loadModel(modelName) + val (model, _) = loadModel(modelName) {} model.shouldBeInstanceOf { it.layers.shouldHaveSize(3) it.layers.toList().let { diff --git a/training/src/test/kotlin/edu/wpi/axon/training/TrainGeneralIntegrationTest.kt b/training/src/test/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGeneratorIntegrationTest.kt similarity index 82% rename from training/src/test/kotlin/edu/wpi/axon/training/TrainGeneralIntegrationTest.kt rename to training/src/test/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGeneratorIntegrationTest.kt index 98e2d691..3936e4dc 100644 --- a/training/src/test/kotlin/edu/wpi/axon/training/TrainGeneralIntegrationTest.kt +++ b/training/src/test/kotlin/edu/wpi/axon/training/TrainGeneralModelScriptGeneratorIntegrationTest.kt @@ -8,26 +8,25 @@ import edu.wpi.axon.tfdata.Dataset import edu.wpi.axon.tfdata.Model import edu.wpi.axon.tfdata.loss.Loss import edu.wpi.axon.tfdata.optimizer.Optimizer +import edu.wpi.axon.training.testutil.loadModel import io.kotlintest.assertions.arrow.validation.shouldBeValid import io.kotlintest.matchers.types.shouldBeInstanceOf -import java.io.File import org.junit.jupiter.api.Test -import org.junit.jupiter.api.io.TempDir import org.koin.core.context.startKoin -internal class TrainGeneralIntegrationTest : KoinTestFixture() { +internal class TrainGeneralModelScriptGeneratorIntegrationTest : KoinTestFixture() { @Test - fun `test with custom model with an add`(@TempDir tempDir: File) { + fun `test with custom model with an add`() { startKoin { modules(defaultModule()) } val modelName = "network_with_add.h5" val newModelName = "network_with_add-trained.h5" - val (model, path) = loadModel(modelName) + val (model, path) = loadModel(modelName) {} model.shouldBeInstanceOf { - TrainGeneral( + TrainGeneralModelScriptGenerator( TrainState( userOldModelPath = path, userNewModelName = newModelName, diff --git a/training/src/test/kotlin/edu/wpi/axon/training/TrainIntegrationTest.kt b/training/src/test/kotlin/edu/wpi/axon/training/TrainIntegrationTest.kt index fd9d6eb1..3c7ffae9 100644 --- a/training/src/test/kotlin/edu/wpi/axon/training/TrainIntegrationTest.kt +++ b/training/src/test/kotlin/edu/wpi/axon/training/TrainIntegrationTest.kt @@ -6,6 +6,7 @@ import edu.wpi.axon.tfdata.Dataset import edu.wpi.axon.tfdata.Model import edu.wpi.axon.tfdata.loss.Loss import edu.wpi.axon.tfdata.optimizer.Optimizer +import edu.wpi.axon.training.testutil.loadModel import io.kotlintest.assertions.arrow.validation.shouldBeInvalid import io.kotlintest.assertions.arrow.validation.shouldBeValid import io.kotlintest.matchers.types.shouldBeInstanceOf @@ -22,9 +23,9 @@ internal class TrainIntegrationTest : KoinTestFixture() { modules(defaultModule()) } - val (model, path) = loadModel("network_with_add.h5") + val (model, path) = loadModel("network_with_add.h5") {} model.shouldBeInstanceOf { - TrainGeneral( + TrainGeneralModelScriptGenerator( TrainState( userOldModelPath = path, userNewModelName = "network_with_add-trained.h5", @@ -45,9 +46,9 @@ internal class TrainIntegrationTest : KoinTestFixture() { modules(defaultModule()) } - val (model, path) = loadModel("custom_fashion_mnist.h5") + val (model, path) = loadModel("custom_fashion_mnist.h5") {} model.shouldBeInstanceOf { - TrainSequential( + TrainSequentialModelScriptGenerator( TrainState( userOldModelPath = path, userNewModelName = "custom_fashion_mnist-trained.h5", @@ -68,7 +69,7 @@ internal class TrainIntegrationTest : KoinTestFixture() { modules(defaultModule()) } - TrainGeneral( + TrainGeneralModelScriptGenerator( TrainState( userOldModelPath = Paths.get( this::class.java.getResource("badModel1.h5").toURI() diff --git a/training/src/test/kotlin/edu/wpi/axon/training/TrainSequentialIntegrationTest.kt b/training/src/test/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGeneratorIntegrationTest.kt similarity index 89% rename from training/src/test/kotlin/edu/wpi/axon/training/TrainSequentialIntegrationTest.kt rename to training/src/test/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGeneratorIntegrationTest.kt index 1be00aca..3c13adbf 100644 --- a/training/src/test/kotlin/edu/wpi/axon/training/TrainSequentialIntegrationTest.kt +++ b/training/src/test/kotlin/edu/wpi/axon/training/TrainSequentialModelScriptGeneratorIntegrationTest.kt @@ -8,6 +8,8 @@ import edu.wpi.axon.tfdata.Dataset import edu.wpi.axon.tfdata.Model import edu.wpi.axon.tfdata.loss.Loss import edu.wpi.axon.tfdata.optimizer.Optimizer +import edu.wpi.axon.training.testutil.loadModel +import edu.wpi.axon.training.testutil.testTrainingScript import io.kotlintest.assertions.arrow.validation.shouldBeInvalid import io.kotlintest.assertions.arrow.validation.shouldBeValid import io.kotlintest.matchers.types.shouldBeInstanceOf @@ -18,7 +20,7 @@ import org.junit.jupiter.api.Test import org.junit.jupiter.api.io.TempDir import org.koin.core.context.startKoin -internal class TrainSequentialIntegrationTest : KoinTestFixture() { +internal class TrainSequentialModelScriptGeneratorIntegrationTest : KoinTestFixture() { @Test fun `test with bad model`() { @@ -28,7 +30,7 @@ internal class TrainSequentialIntegrationTest : KoinTestFixture() { val localModelPath = Paths.get(this::class.java.getResource("badModel1.h5").toURI()).toString() - TrainSequential( + TrainSequentialModelScriptGenerator( TrainState( userOldModelPath = localModelPath, userNewModelName = "badModel1-trained.h5", @@ -55,9 +57,9 @@ internal class TrainSequentialIntegrationTest : KoinTestFixture() { val modelName = "custom_fashion_mnist.h5" val newModelName = "custom_fashion_mnist-trained.h5" - val (model, path) = loadModel(modelName) + val (model, path) = loadModel(modelName) {} model.shouldBeInstanceOf { - TrainSequential( + TrainSequentialModelScriptGenerator( TrainState( userOldModelPath = path, userNewModelName = newModelName, diff --git a/training/training.gradle.kts b/training/training.gradle.kts index b76c1ff7..9f1ba231 100644 --- a/training/training.gradle.kts +++ b/training/training.gradle.kts @@ -5,5 +5,5 @@ dependencies { implementation(project(":tf-layer-loader")) implementation(project(":logging")) - testImplementation(project(":test-util")) + testImplementation(project(":training-test-util")) } diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/AxonLayout.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/AxonLayout.kt new file mode 100644 index 00000000..38ac46d8 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/AxonLayout.kt @@ -0,0 +1,45 @@ +package edu.wpi.axon.ui + +import com.github.mvysny.karibudsl.v10.h3 +import com.github.mvysny.karibudsl.v10.navbar +import com.github.mvysny.karibudsl.v10.routerLink +import com.github.mvysny.karibudsl.v10.tab +import com.github.mvysny.karibudsl.v10.tabs +import com.vaadin.flow.component.applayout.AppLayout +import com.vaadin.flow.component.icon.VaadinIcon +import com.vaadin.flow.component.page.BodySize +import com.vaadin.flow.component.page.Viewport +import com.vaadin.flow.component.tabs.Tabs +import edu.wpi.axon.ui.view.AboutView +import edu.wpi.axon.ui.view.DatasetView +import edu.wpi.axon.ui.view.JobsView +import edu.wpi.axon.ui.view.TrainingView + +/** + * The main layout of the application. + */ +@BodySize(height = "100vh", width = "100vw") +@Viewport("width=device-width, minimum-scale=1.0, initial-scale=1.0, user-scalable=yes") +class AxonLayout : AppLayout() { + init { + isDrawerOpened = false + navbar { + h3("Axon") + tabs { + orientation = Tabs.Orientation.HORIZONTAL + tab { + add(routerLink(VaadinIcon.CONTROLLER, "Jobs", JobsView::class)) + } + tab { + add(routerLink(VaadinIcon.CAMERA, "Dataset", DatasetView::class)) + } + tab { + add(routerLink(VaadinIcon.AUTOMATION, "Training", TrainingView::class)) + } + tab { + add(routerLink(VaadinIcon.INFO, "About", AboutView::class)) + } + } + } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/JobRunner.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/JobRunner.kt new file mode 100644 index 00000000..1359ed63 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/JobRunner.kt @@ -0,0 +1,81 @@ +package edu.wpi.axon.ui + +import arrow.fx.IO +import arrow.fx.extensions.fx +import edu.wpi.axon.aws.EC2TrainingScriptRunner +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.tfdata.Model +import edu.wpi.axon.tflayerloader.DefaultLayersToGraph +import edu.wpi.axon.tflayerloader.LoadLayersFromHDF5 +import edu.wpi.axon.training.TrainGeneralModelScriptGenerator +import edu.wpi.axon.training.TrainSequentialModelScriptGenerator +import edu.wpi.axon.training.TrainState +import java.nio.file.Paths +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.ec2.model.InstanceType + +/** + * @param bucketName The S3 bucket name to use for dataset and models. + * @param instanceType The type of the EC2 instance to run the training script on. + * @param region The region to connect to, or `null` to autodetect the region. + */ +class JobRunner( + bucketName: String, + instanceType: InstanceType, + region: Region? +) { + + private val loadLayersFromHDF5 = LoadLayersFromHDF5(DefaultLayersToGraph()) + private val scriptRunner = EC2TrainingScriptRunner(bucketName, instanceType, region) + + /** + * Generates the code for a job and starts it on EC2. + * + * @param job The [Job] to run. + * @return The script id of the script that was started. + */ + fun startJob(job: Job): Long = IO.fx { + val modelFile = Paths.get(job.userOldModelPath).toFile() + val trainModelScriptGenerator = when ( + val model = loadLayersFromHDF5.load(modelFile).bind()) { + is Model.Sequential -> TrainSequentialModelScriptGenerator(toTrainState(job, model)) + is Model.General -> TrainGeneralModelScriptGenerator(toTrainState(job, model)) + } + + val script = trainModelScriptGenerator.generateScript().fold( + { + IO.raiseError( + IllegalStateException( + """ + |Got errors when generating script: + |${it.all.joinToString("\n")} + """.trimMargin() + ) + ) + }, + { IO.just(it) } + ).bind() + + scriptRunner.startScript( + oldModelName = trainModelScriptGenerator.trainState.userOldModelName, + newModelName = job.userNewModelName, + scriptContents = script + ).bind() + }.unsafeRunSync() + + private fun toTrainState( + job: Job, + model: T + ) = TrainState( + userOldModelPath = job.userOldModelPath, + userNewModelName = job.userNewModelName, + userDataset = job.userDataset, + userOptimizer = job.userOptimizer, + userLoss = job.userLoss, + userMetrics = job.userMetrics, + userEpochs = job.userEpochs, + userNewModel = model, + userAuth = null, + generateDebugComments = job.generateDebugComments + ) +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/UiUI.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/UiUI.kt deleted file mode 100644 index a6804d78..00000000 --- a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/UiUI.kt +++ /dev/null @@ -1,10 +0,0 @@ -package edu.wpi.axon.ui - -import com.vaadin.flow.component.UI -import com.vaadin.flow.component.dependency.HtmlImport -import com.vaadin.flow.theme.Theme -import com.vaadin.flow.theme.lumo.Lumo - -@HtmlImport("frontend://styles/ui-theme.html") -@Theme(Lumo::class) -class UiUI : UI() diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/UiView.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/UiView.kt deleted file mode 100644 index 0cffe34d..00000000 --- a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/UiView.kt +++ /dev/null @@ -1,27 +0,0 @@ -package edu.wpi.axon.ui - -import com.vaadin.flow.component.button.Button -import com.vaadin.flow.component.html.Label -import com.vaadin.flow.component.notification.Notification -import com.vaadin.flow.component.orderedlayout.VerticalLayout -import com.vaadin.flow.component.page.BodySize -import com.vaadin.flow.component.page.Viewport -import com.vaadin.flow.router.Route - -/** - * The main view of the application - */ -@Route("") -@BodySize(height = "100vh", width = "100vw") -@Viewport("width=device-width, minimum-scale=1.0, initial-scale=1.0, user-scalable=yes") -class UiView : VerticalLayout() { - init { - val hello = Label("Hello Kotlin app!") - add(hello) - - val button = Button("Click me") { - Notification.show("Clicked!") - } - add(button) - } -} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/WebAppListener.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/WebAppListener.kt new file mode 100644 index 00000000..385c7947 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/WebAppListener.kt @@ -0,0 +1,28 @@ +package edu.wpi.axon.ui + +import edu.wpi.axon.dsl.defaultModule +import javax.servlet.ServletContextEvent +import javax.servlet.ServletContextListener +import javax.servlet.annotation.WebListener +import mu.KotlinLogging +import org.koin.core.context.startKoin + +@WebListener +class WebAppListener : ServletContextListener { + + override fun contextInitialized(sce: ServletContextEvent?) { + LOGGER.info { "Starting web app." } + + startKoin { + modules(defaultModule()) + } + } + + override fun contextDestroyed(sce: ServletContextEvent?) { + LOGGER.info { "Stopping web app." } + } + + companion object { + private val LOGGER = KotlinLogging.logger { } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/service/JobService.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/service/JobService.kt new file mode 100644 index 00000000..d90131c2 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/service/JobService.kt @@ -0,0 +1,68 @@ +package edu.wpi.axon.ui.service + +import com.vaadin.flow.data.provider.DataProvider +import edu.wpi.axon.db.JobDb +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.dbdata.TrainingScriptProgress +import edu.wpi.axon.tfdata.Dataset +import edu.wpi.axon.tfdata.loss.Loss +import edu.wpi.axon.tfdata.optimizer.Optimizer +import kotlin.random.Random +import org.apache.commons.lang3.RandomStringUtils +import org.jetbrains.exposed.sql.Database + +object JobService { + + val jobs = JobDb( + Database.connect( + url = "jdbc:h2:mem:test;DB_CLOSE_DELAY=-1", + driver = "org.h2.Driver" + ) + ) + + val dataProvider = DataProvider.fromCallbacks( + { jobs.fetch(it.limit, it.offset).stream() }, + { jobs.count() } + ) + + init { + for (i in 1..20) { + jobs.create(Random.nextJob()) + } + } + + private fun Random.nextJob() = Job( + RandomStringUtils.randomAlphanumeric(10), + TrainingScriptProgress.Completed, + RandomStringUtils.randomAlphanumeric(10), + RandomStringUtils.randomAlphanumeric(10), + nextDataset(), + Optimizer.Adam( + nextDouble(), + nextDouble(), + nextDouble(), + nextDouble(), + nextBoolean() + ), + Loss.SparseCategoricalCrossentropy, + setOf( + RandomStringUtils.randomAlphanumeric(10), + RandomStringUtils.randomAlphanumeric(10) + ), + nextInt(), + nextBoolean() + ) + + fun Random.nextDataset(): Dataset { + return if (nextBoolean()) { + Dataset.ExampleDataset::class.sealedSubclasses.let { + it[nextInt(it.size)].objectInstance!! + } + } else { + Dataset.Custom( + RandomStringUtils.randomAlphanumeric(20), + RandomStringUtils.randomAlphanumeric(20) + ) + } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/AboutView.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/AboutView.kt new file mode 100644 index 00000000..42b381ca --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/AboutView.kt @@ -0,0 +1,16 @@ +package edu.wpi.axon.ui.view + +import com.github.mvysny.karibudsl.v10.KComposite +import com.github.mvysny.karibudsl.v10.h1 +import com.github.mvysny.karibudsl.v10.verticalLayout +import com.vaadin.flow.router.Route +import edu.wpi.axon.ui.AxonLayout + +@Route(layout = AxonLayout::class) +class AboutView : KComposite() { + private val root = ui { + verticalLayout { + h1("Created by: Austin & Ryan") + } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/DatasetView.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/DatasetView.kt new file mode 100644 index 00000000..0561d786 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/DatasetView.kt @@ -0,0 +1,18 @@ +package edu.wpi.axon.ui.view + +import com.github.mvysny.karibudsl.v10.KComposite +import com.github.mvysny.karibudsl.v10.horizontalLayout +import com.vaadin.flow.router.Route +import edu.wpi.axon.ui.AxonLayout +import edu.wpi.axon.ui.view.composite.DatasetSelector +import edu.wpi.axon.ui.view.composite.JobsList + +@Route(layout = AxonLayout::class) +class DatasetView : KComposite() { + private val root = ui { + horizontalLayout { + add(JobsList()) + add(DatasetSelector()) + } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/JobsView.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/JobsView.kt new file mode 100644 index 00000000..5838b164 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/JobsView.kt @@ -0,0 +1,16 @@ +package edu.wpi.axon.ui.view + +import com.github.mvysny.karibudsl.v10.KComposite +import com.github.mvysny.karibudsl.v10.verticalLayout +import com.vaadin.flow.router.Route +import edu.wpi.axon.ui.AxonLayout +import edu.wpi.axon.ui.view.composite.JobsGrid + +@Route(layout = AxonLayout::class) +class JobsView : KComposite() { + private val root = ui { + verticalLayout { + add(JobsGrid()) + } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/TrainingView.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/TrainingView.kt new file mode 100644 index 00000000..2df7ac83 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/TrainingView.kt @@ -0,0 +1,77 @@ +package edu.wpi.axon.ui.view + +import com.github.mvysny.karibudsl.v10.KComposite +import com.vaadin.flow.router.Route +import edu.wpi.axon.ui.AxonLayout + +@Route(layout = AxonLayout::class) +class TrainingView : KComposite() { + // private val binder = BeanValidationBinder(TrainingModel::class.java) + // + // private val root = ui { + // verticalLayout { + // val infoLabel = label { + // text = "Info" + // } + // formLayout { + // formItem { + // comboBox>("Optimizer") { + // setItems(Optimizer::class.sealedSubclasses) + // setItemLabelGenerator { + // it.simpleName + // } + // + // bind(binder).bind(TrainingModel::userOptimizer) + // } + // } + // formItem { + // comboBox>("Loss") { + // setItems(Loss::class.sealedSubclasses) + // setItemLabelGenerator { + // it.simpleName + // } + // + // bind(binder).bind(TrainingModel::userLoss) + // } + // } + // formItem { + // numberField("Epochs") { + // setHasControls(true) + // isPreventInvalidInput = true + // + // bind(binder) + // .asRequired() + // .withValidator { value, _ -> + // if (value.isWholeNumber()) { + // ValidationResult.ok() + // } else { + // ValidationResult.error("Must be an integer!") + // } + // } + // .toInt() + // .bind(TrainingModel::userEpochs) + // } + // } + // formItem { + // checkBox("Generate Debug Output") { + // bind(binder).bind(TrainingModel::generateDebugComments) + // } + // } + // button("Generate") { + // onLeftClick { + // val validate = binder.validate() + // if (validate.isOk) { + // infoLabel.text = "Saved bean values: ${binder.bean}" + // } else { + // infoLabel.text = "There are errors: ${validate.validationErrors}" + // } + // } + // } + // } + // } + // } + // + // init { + // binder.bean = VaadinSession.getCurrent().getAttribute(TrainingModel::class.java) + // } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/DatasetSelector.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/DatasetSelector.kt new file mode 100644 index 00000000..5a0670e4 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/DatasetSelector.kt @@ -0,0 +1,23 @@ +package edu.wpi.axon.ui.view.composite + +import com.github.mvysny.karibudsl.v10.KComposite +import com.github.mvysny.karibudsl.v10.comboBox +import com.github.mvysny.karibudsl.v10.formItem +import com.github.mvysny.karibudsl.v10.formLayout +import com.github.mvysny.karibudsl.v10.verticalLayout +import edu.wpi.axon.tfdata.Dataset + +class DatasetSelector : KComposite() { + private val root = ui { + verticalLayout { + formLayout { + formItem { + comboBox("Dataset") { + setItems(Dataset.ExampleDataset::class.sealedSubclasses.mapNotNull { it.objectInstance }) + setItemLabelGenerator { it.displayName } + } + } + } + } + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/JobsGrid.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/JobsGrid.kt new file mode 100644 index 00000000..0cbee4ef --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/JobsGrid.kt @@ -0,0 +1,62 @@ +package edu.wpi.axon.ui.view.composite + +import com.github.mvysny.karibudsl.v10.KComposite +import com.github.mvysny.karibudsl.v10.addColumnFor +import com.github.mvysny.karibudsl.v10.grid +import com.github.mvysny.karibudsl.v10.gridContextMenu +import com.github.mvysny.karibudsl.v10.item +import com.github.mvysny.karibudsl.v10.verticalLayout +import com.vaadin.flow.component.button.Button +import com.vaadin.flow.component.grid.ColumnTextAlign +import com.vaadin.flow.component.grid.Grid +import com.vaadin.flow.component.notification.Notification +import com.vaadin.flow.data.renderer.ComponentRenderer +import com.vaadin.flow.data.renderer.TextRenderer +import com.vaadin.flow.function.SerializableFunction +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.ui.service.JobService + +class JobsGrid : KComposite() { + private lateinit var grid: Grid + + private val root = ui { + verticalLayout { + grid = grid(JobService.dataProvider) { + addColumnFor(Job::name) + addColumnFor(Job::status, TextRenderer { it.status.javaClass.simpleName }) + addColumnFor(Job::userDataset, TextRenderer { it.userDataset.displayName }) + + addColumn(ComponentRenderer(SerializableFunction { job -> + Button("Clone") { + Notification.show("Clone Button: ${job.name}") + } + })).apply { + textAlign = ColumnTextAlign.END + } + addColumn(ComponentRenderer(SerializableFunction { job -> + Button("Run") { + Notification.show("Run Button: ${job.name}") + } + })).apply { + textAlign = ColumnTextAlign.END + } + addColumn(ComponentRenderer(SerializableFunction { job -> + Button("Remove") { deleteJob(job) } + })).apply { + textAlign = ColumnTextAlign.END + } + + gridContextMenu { + item("Clone", { if (it != null) Notification.show("Clone Context: ${it.name}") }) + item("Run", { if (it != null) Notification.show("Run Context: ${it.name}") }) + item("Remove", { if (it != null) deleteJob(it) }) + } + } + } + } + + private fun deleteJob(job: Job) { + JobService.jobs.remove(job.id) + grid.dataProvider.refreshAll() + } +} diff --git a/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/JobsList.kt b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/JobsList.kt new file mode 100644 index 00000000..f59abd60 --- /dev/null +++ b/ui-vaadin/src/main/kotlin/edu/wpi/axon/ui/view/composite/JobsList.kt @@ -0,0 +1,43 @@ +package edu.wpi.axon.ui.view.composite + +import com.github.mvysny.karibudsl.v10.KComposite +import com.github.mvysny.karibudsl.v10.grid +import com.github.mvysny.karibudsl.v10.gridContextMenu +import com.github.mvysny.karibudsl.v10.h4 +import com.github.mvysny.karibudsl.v10.item +import com.github.mvysny.karibudsl.v10.sortProperty +import com.github.mvysny.karibudsl.v10.verticalLayout +import com.vaadin.flow.component.grid.Grid +import com.vaadin.flow.component.notification.Notification +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.ui.service.JobService + +class JobsList : KComposite() { + private lateinit var grid: Grid + + private val root = ui { + verticalLayout { + h4("Jobs") + grid = grid(dataProvider = JobService.dataProvider) { + addColumn { Job::name.get(it) }.apply { + key = Job::name.name + sortProperty = Job::name + } + + gridContextMenu { + item( + "Clone", + { if (it != null) Notification.show("Clone Context: ${it.name}") } + ) + item("Run", { if (it != null) Notification.show("Run Context: ${it.name}") }) + item("Remove", { if (it != null) deleteJob(it) }) + } + } + } + } + + private fun deleteJob(job: Job) { + JobService.jobs.remove(job.id) + grid.dataProvider.refreshAll() + } +} diff --git a/ui-vaadin/src/main/webapp/frontend/styles/ui-theme.css b/ui-vaadin/src/main/webapp/frontend/styles/ui-theme.css index 22d3dde3..20309a12 100644 --- a/ui-vaadin/src/main/webapp/frontend/styles/ui-theme.css +++ b/ui-vaadin/src/main/webapp/frontend/styles/ui-theme.css @@ -1,5 +1,4 @@ /* * This file contains the Application theme */ -label { color: green; } -label.clicked { color: red; } \ No newline at end of file +h3 { color: green; } diff --git a/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/UiViewTest.kt b/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/AboutViewTest.kt similarity index 54% rename from ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/UiViewTest.kt rename to ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/AboutViewTest.kt index 4b64e911..1704c868 100644 --- a/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/UiViewTest.kt +++ b/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/AboutViewTest.kt @@ -1,16 +1,15 @@ package edu.wpi.axon.ui -import com.github.mvysny.kaributesting.v10.LocatorJ._click import com.github.mvysny.kaributesting.v10.LocatorJ._get import com.github.mvysny.kaributesting.v10.MockVaadin import com.github.mvysny.kaributesting.v10.Routes -import com.github.mvysny.kaributesting.v10.expectNotifications -import com.vaadin.flow.component.button.Button +import com.vaadin.flow.component.UI +import com.vaadin.flow.component.html.H1 import org.junit.jupiter.api.BeforeAll import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test -class UiViewTest { +class AboutViewTest { companion object { private lateinit var routes: Routes @@ -24,15 +23,13 @@ class UiViewTest { @BeforeEach fun setupVaadin() { - MockVaadin.setup(routes!!) + MockVaadin.setup(routes) + + UI.getCurrent().navigate("about") } @Test - fun testGreeting() { - // simulate a button click as if clicked by the user - _click(_get(Button::class.java) { spec -> spec.withCaption("Click me") }) - - // look up the Example Template and assert on its value - expectNotifications("Clicked!") + fun testAboutText() { + _get(H1::class.java) { spec -> spec.withText("Created by: Austin & Ryan") } } } diff --git a/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/DatasetViewTest.kt b/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/DatasetViewTest.kt new file mode 100644 index 00000000..04208877 --- /dev/null +++ b/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/DatasetViewTest.kt @@ -0,0 +1,45 @@ +package edu.wpi.axon.ui + +import com.github.mvysny.kaributesting.v10.LocatorJ._get +import com.github.mvysny.kaributesting.v10.MockVaadin +import com.github.mvysny.kaributesting.v10.Routes +import com.github.mvysny.kaributesting.v10.getSuggestionItems +import com.github.mvysny.kaributesting.v10.setUserInput +import com.vaadin.flow.component.UI +import com.vaadin.flow.component.combobox.ComboBox +import org.junit.jupiter.api.Assertions +import org.junit.jupiter.api.BeforeAll +import org.junit.jupiter.api.BeforeEach +import org.junit.jupiter.params.ParameterizedTest +import org.junit.jupiter.params.provider.ValueSource + +class DatasetViewTest { + + companion object { + private lateinit var routes: Routes + + @BeforeAll + @JvmStatic + fun createRoutes() { + routes = Routes().autoDiscoverViews("edu.wpi.axon.ui") + } + } + + @BeforeEach + fun setupVaadin() { + MockVaadin.setup(routes) + + UI.getCurrent().navigate("dataset") + } + + @ParameterizedTest + @ValueSource(strings = ["Boston Housing", "CIFAR-10", "Fashion MNIST", "MNIST"]) + fun testDatasetSelection(input: String) { + val combobox = _get(ComboBox::class.java) { spec -> spec.withCaption("Dataset") } + + combobox.setUserInput(input) + combobox.value = combobox.getSuggestionItems()[0] + + Assertions.assertNotNull(combobox.value) + } +} diff --git a/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/JobRunnerIntegTest.kt b/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/JobRunnerIntegTest.kt new file mode 100644 index 00000000..d6a74e86 --- /dev/null +++ b/ui-vaadin/src/test/kotlin/edu/wpi/axon/ui/JobRunnerIntegTest.kt @@ -0,0 +1,49 @@ +package edu.wpi.axon.ui + +import edu.wpi.axon.dbdata.Job +import edu.wpi.axon.dbdata.TrainingScriptProgress +import edu.wpi.axon.dsl.defaultModule +import edu.wpi.axon.testutil.KoinTestFixture +import edu.wpi.axon.tfdata.Dataset +import edu.wpi.axon.tfdata.loss.Loss +import edu.wpi.axon.tfdata.optimizer.Optimizer +import edu.wpi.axon.training.testutil.loadModel +import org.junit.jupiter.api.Disabled +import org.junit.jupiter.api.Test +import org.koin.core.context.startKoin +import software.amazon.awssdk.regions.Region +import software.amazon.awssdk.services.ec2.model.InstanceType + +internal class JobRunnerIntegTest : KoinTestFixture() { + + @Test + @Disabled("Needs AWS supervision.") + fun `test starting job for mobilenet`() { + startKoin { + modules(defaultModule()) + } + + val jobRunner = JobRunner( + "axon-salmon-testbucket2", + InstanceType.T3_MEDIUM, + Region.US_EAST_1 + ) + + val newModelName = "mobilenetv2_1.00_224-trained.h5" + val (_, path) = loadModel("mobilenetv2_1.00_224.h5") {} + val job = Job( + "Job 1", + TrainingScriptProgress.NotStarted, + path, + newModelName, + Dataset.ExampleDataset.Mnist, + Optimizer.Adam(0.001, 0.9, 0.999, 1e-7, false), + Loss.SparseCategoricalCrossentropy, + setOf("accuracy"), + 1, + false + ) + + jobRunner.startJob(job) + } +} diff --git a/ui-vaadin/src/test/resources/edu/wpi/axon/ui/mobilenetv2_1.00_224.h5 b/ui-vaadin/src/test/resources/edu/wpi/axon/ui/mobilenetv2_1.00_224.h5 new file mode 100644 index 00000000..6a85fd2c Binary files /dev/null and b/ui-vaadin/src/test/resources/edu/wpi/axon/ui/mobilenetv2_1.00_224.h5 differ diff --git a/ui-vaadin/ui-vaadin.gradle.kts b/ui-vaadin/ui-vaadin.gradle.kts index 8d2a1aef..bb63cc3b 100644 --- a/ui-vaadin/ui-vaadin.gradle.kts +++ b/ui-vaadin/ui-vaadin.gradle.kts @@ -5,7 +5,7 @@ plugins { gretty { // https://akhikhl.github.io/gretty-doc/Gretty-configuration.html - host = "localhost" + host = "0.0.0.0" httpPort = 8080 contextPath = "axon" } @@ -21,13 +21,31 @@ vaadin { } dependencies { + implementation("org.jetbrains.exposed:exposed:0.17.7") // temp + + api(project(":db-data")) + + implementation(project(":aws")) + implementation(project(":db")) + implementation(project(":dsl")) + implementation(project(":tf-data")) + implementation(project(":tf-layer-loader")) + implementation(project(":training")) + implementation(project(":util")) + implementation(project(":logging")) + implementation(platform(vaadin.bom())) implementation(vaadin.core()) implementation(vaadin.lumoTheme()) + implementation("com.github.mvysny.karibudsl:karibu-dsl-v10:0.7.0") + + implementation(group = "org.hibernate", name = "hibernate-validator", version = "5.4.1.Final") + testImplementation( group = "com.github.mvysny.kaributesting", name = "karibu-testing-v10", version = property("karibu-testing-v10.version") as String ) + testImplementation(project(":training-test-util")) }