Skip to content

Commit 4c4c0fa

Browse files
authored
Visualisations of model with letsplot (#70)
* introduce proposed plot api to jvm example * fix docs * rename function and refactor common code * update plot api proposition * create separate module for visualization * naming fixes * use it instead of model * rename helper function xyio * add kdocs * add wav visualization * add network visualization * Add comments * exclude slf4j, mark casts * clean gradle configuration * fix dependency in tests
1 parent 0991b2f commit 4c4c0fa

File tree

15 files changed

+552
-52
lines changed

15 files changed

+552
-52
lines changed

api/src/main/kotlin/org/jetbrains/kotlinx/dl/dataset/embeddedDatasets.kt

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,6 @@ import java.nio.file.Paths
2121
import java.nio.file.StandardCopyOption
2222
import java.util.zip.ZipEntry
2323
import java.util.zip.ZipFile
24-
import java.io.File
25-
26-
import java.io.IOException
27-
28-
import java.io.FileOutputStream
29-
import java.lang.IllegalStateException
3024

3125

3226
/**
@@ -191,7 +185,7 @@ public fun freeSpokenDigits(
191185
cacheDirectory.existsOrMkdirs()
192186

193187
val path = freeSpokenDigitDatasetPath(cacheDirectory)
194-
val dataset = File("$path/free-spoken-digit-dataset-master/recordings")
188+
val dataset = File(path)
195189
.listFiles()?.flatMap(::extractWavFileSamples)
196190
?: throw IllegalStateException("Cannot find Free Spoken Digits Dataset files in $path")
197191
val maxDataSize = dataset.map { it.first.size }.maxOrNull()
@@ -275,7 +269,8 @@ public fun dogsCatsDatasetPath(cacheDirectory: File = File("cache")): String =
275269
unzipDatasetPath(
276270
cacheDirectory,
277271
loadFile(cacheDirectory, DOGS_CATS_IMAGES_ARCHIVE),
278-
"/datasets/dogs-vs-cats")
272+
"/datasets/dogs-vs-cats"
273+
)
279274

280275
/** Path to the subset of Dogs-vs-Cats dataset. */
281276
private const val DOGS_CATS_SMALL_IMAGES_ARCHIVE: String = "datasets/small_catdogs/data.zip"
@@ -285,20 +280,25 @@ public fun dogsCatsSmallDatasetPath(cacheDirectory: File = File("cache")): Strin
285280
unzipDatasetPath(
286281
cacheDirectory,
287282
loadFile(cacheDirectory, DOGS_CATS_SMALL_IMAGES_ARCHIVE),
288-
"/datasets/small-dogs-vs-cats")
283+
"/datasets/small-dogs-vs-cats"
284+
)
289285

290286
/** Path to the Free Spoken Digits Dataset. */
291287
private const val FSDD_SOUNDS_ARCHIVE: String = "datasets/fsdd.zip"
292288

293289
/** Path to download the Free Spoken Digits Dataset. */
294-
private const val FSS_SOUNDS_SOURCE: String = "https://codeload.github.com/Jakobovski/free-spoken-digit-dataset/zip/refs/heads/master"
290+
private const val FSS_SOUNDS_SOURCE: String =
291+
"https://codeload.github.com/Jakobovski/free-spoken-digit-dataset/zip/refs/heads/master"
295292

296-
/** Returns path to images of the subset of the Dogs-vs-Cats dataset. */
293+
/** Returns path to sound data files from Free Spoken Digits Dataset. */
297294
public fun freeSpokenDigitDatasetPath(cacheDirectory: File = File("cache")): String =
298295
unzipDatasetPath(
299296
cacheDirectory,
300297
loadFile(cacheDirectory, FSDD_SOUNDS_ARCHIVE, downloadURLFromRelativePath = { FSS_SOUNDS_SOURCE }),
301-
"/datasets/free-spoken-digit")
298+
"/datasets/free-spoken-digit"
299+
).run {
300+
"$this/free-spoken-digit-dataset-master/recordings"
301+
}
302302

303303
/**
304304
* Download the compressed dataset from external source, decompress the file and remove the downloaded file

api/src/main/kotlin/org/jetbrains/kotlinx/dl/dataset/sound/wav/WavFile.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,11 @@ private const val RIFF_TYPE_ID: Long = 0x45564157
2828
*
2929
* Based on code written by [Andrew Greensted](http://www.labbookpages.co.uk/)
3030
* but modified to more Kotlin idiomatic way with only read option for simplicity.
31+
*
32+
* @property bufferSize the size of the buffer used when reading the next frames from the given file.
33+
* @constructor creates [WavFile]
34+
*
35+
* @param file to read the WAV file data from
3136
*/
3237
public class WavFile(
3338
file: File,

build.gradle

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ repositories {
2828
subprojects {
2929
repositories {
3030
mavenCentral()
31+
maven {
32+
url "https://dl.bintray.com/kotlin/kotlinx.html"
33+
}
3134
}
3235
}
33-
34-

examples/build.gradle

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
dependencies {
22
implementation "org.jetbrains.kotlin:kotlin-stdlib-jdk8"
3-
api(project(":api"))
3+
api project(":api")
4+
implementation project(":visualization")
45
}
56

67
dependencies {
78
compile 'org.apache.logging.log4j:log4j-api:2.14.0'
89
compile 'org.apache.logging.log4j:log4j-core:2.14.0'
910
compile 'org.apache.logging.log4j:log4j-slf4j-impl:2.14.0'
11+
1012
testCompile 'org.junit.jupiter:junit-jupiter-api:5.5.2'
1113
testCompile 'org.junit.jupiter:junit-jupiter-engine:5.5.2'
1214
testCompile 'org.junit.jupiter:junit-jupiter-params:5.5.2'
@@ -19,6 +21,7 @@ dependencies {
1921
compileKotlin {
2022
kotlinOptions.jvmTarget = "1.8"
2123
}
24+
2225
compileTestKotlin {
2326
kotlinOptions.jvmTarget = "1.8"
2427
}
@@ -29,4 +32,3 @@ test {
2932
minHeapSize = "1024m"
3033
maxHeapSize = "8g"
3134
}
32-

examples/src/main/kotlin/examples/cnn/fsdd/SoundNet.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ private const val SEED = 12L
3737
* @param poolStride stride for poolSize and stride in maxpooling layer
3838
* @return array of layers to be registered in [Sequential] as vararg
3939
*/
40-
private fun soundBlock(filters: Long, kernelSize: Long, poolStride: Long): Array<Layer> =
40+
internal fun soundBlock(filters: Long, kernelSize: Long, poolStride: Long): Array<Layer> =
4141
arrayOf(
4242
Conv1D(
4343
filters = filters,

examples/src/main/kotlin/examples/visualisation/LeNetFashionMnistVisualisation.kt renamed to examples/src/main/kotlin/examples/visualization/LeNetFashionMnistVisualization.kt

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

6-
package examples.visualisation
6+
package examples.visualization
77

88
import examples.inference.lenet5
9+
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D
910
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
1011
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
1112
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
1213
import org.jetbrains.kotlinx.dl.dataset.fashionMnist
14+
import org.jetbrains.kotlinx.dl.visualization.letsplot.*
15+
import org.jetbrains.kotlinx.dl.visualization.swing.*
1316

1417
private const val EPOCHS = 1
1518
private const val TRAINING_BATCH_SIZE = 500
@@ -38,7 +41,12 @@ fun main() {
3841

3942
val (newTrain, validation) = train.split(0.95)
4043

44+
val sampleIndex = 42
45+
val x = test.getX(sampleIndex)
46+
val y = test.getY(sampleIndex).toInt()
47+
4148
lenet5().use {
49+
4250
it.compile(
4351
optimizer = Adam(),
4452
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
@@ -53,28 +61,39 @@ fun main() {
5361
validationBatchSize = TEST_BATCH_SIZE
5462
)
5563

56-
val imageId = 0
57-
58-
val weights = it.layers[1].weights // first conv2d layer
59-
60-
drawFilters(weights.values.toTypedArray()[0], colorCoefficient = 10.0)
61-
62-
val weights2 = it.layers[3].weights // first conv2d layer
63-
64-
drawFilters(weights2.values.toTypedArray()[0], colorCoefficient = 12.0)
64+
val fashionPlots = List(3) { imageIndex ->
65+
flattenImagePlot(imageIndex, test,
66+
predict = it::predict,
67+
labelEncoding = fashionMnistLabelEncoding::get,
68+
plotFeature = PlotFeature.GRAY
69+
)
70+
}
71+
columnPlot(fashionPlots, 3, 256).show()
6572

6673
val accuracy = it.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE).metrics[Metrics.ACCURACY]
67-
6874
println("Accuracy $accuracy")
6975

70-
val (prediction, activations) = it.predictAndGetActivations(train.getX(imageId))
76+
val fstConv2D = it.layers[1] as Conv2D
77+
val sndConv2D = it.layers[3] as Conv2D
7178

72-
println("Prediction: ${fashionMnistLabelEncoding[prediction]}")
79+
// lets-plot approach
80+
filtersPlot(fstConv2D, columns = 16).show()
81+
filtersPlot(sndConv2D, columns = 16).show()
7382

74-
drawActivations(activations)
83+
// swing approach
84+
drawFilters(fstConv2D.weights.values.toTypedArray()[0], colorCoefficient = 10.0)
85+
drawFilters(sndConv2D.weights.values.toTypedArray()[0], colorCoefficient = 10.0)
7586

76-
val trainImageLabel = train.getY(imageId)
87+
val layersActivations = modelActivationOnLayersPlot(it, x)
88+
val (prediction, activations) = it.predictAndGetActivations(x)
89+
println("Prediction: ${fashionMnistLabelEncoding[prediction]}")
90+
println("Ground Truth: ${fashionMnistLabelEncoding[y]}")
91+
92+
// lets-plot approach
93+
layersActivations[0].show()
94+
layersActivations[1].show()
7795

78-
println("Ground Truth: ${fashionMnistLabelEncoding[trainImageLabel.toInt()]}")
96+
// swing approach
97+
drawActivations(activations)
7998
}
8099
}

examples/src/main/kotlin/examples/visualisation/LeNetMnistVisualisation.kt renamed to examples/src/main/kotlin/examples/visualization/LeNetMnistVisualization.kt

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,16 @@
33
* Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE.txt file.
44
*/
55

6-
package examples.visualisation
6+
package examples.visualization
77

88
import examples.inference.lenet5
9+
import org.jetbrains.kotlinx.dl.api.core.layer.convolutional.Conv2D
910
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
1011
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
1112
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
1213
import org.jetbrains.kotlinx.dl.dataset.mnist
14+
import org.jetbrains.kotlinx.dl.visualization.letsplot.*
15+
import org.jetbrains.kotlinx.dl.visualization.swing.*
1316

1417
private const val EPOCHS = 1
1518
private const val TRAINING_BATCH_SIZE = 500
@@ -21,11 +24,15 @@ private const val TEST_BATCH_SIZE = 1000
2124
* Model is trained on Mnist dataset.
2225
*/
2326
fun main() {
27+
2428
val (train, test) = mnist()
2529

26-
val imageId = 1
30+
val sampleIndex = 42
31+
val x = test.getX(sampleIndex)
32+
val y = test.getY(sampleIndex).toInt()
2733

2834
lenet5().use {
35+
2936
it.compile(
3037
optimizer = Adam(),
3138
loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
@@ -42,26 +49,35 @@ fun main() {
4249
validationBatchSize = TEST_BATCH_SIZE
4350
)
4451

45-
val weights = it.layers[1].weights // first conv2d layer
46-
47-
drawFilters(weights.values.toTypedArray()[0], colorCoefficient = 10.0)
48-
49-
val weights2 = it.layers[3].weights // first conv2d layer
50-
51-
drawFilters(weights2.values.toTypedArray()[0], colorCoefficient = 10.0)
52+
val numbersPlots = List(3) { imageIndex ->
53+
flattenImagePlot(imageIndex, test, it::predict)
54+
}
55+
columnPlot(numbersPlots, 3, 256).show()
5256

5357
val accuracy = it.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE).metrics[Metrics.ACCURACY]
54-
5558
println("Accuracy $accuracy")
5659

57-
val (prediction, activations) = it.predictAndGetActivations(train.getX(imageId))
60+
val fstConv2D = it.layers[1] as Conv2D
61+
val sndConv2D = it.layers[3] as Conv2D
5862

59-
println("Prediction: $prediction")
63+
// lets-plot approach
64+
filtersPlot(fstConv2D, columns = 16).show()
65+
filtersPlot(sndConv2D, columns = 16).show()
6066

61-
drawActivations(activations)
67+
// swing approach
68+
drawFilters(fstConv2D.weights.values.toTypedArray()[0], colorCoefficient = 10.0)
69+
drawFilters(sndConv2D.weights.values.toTypedArray()[0], colorCoefficient = 10.0)
6270

63-
val trainImageLabel = train.getY(imageId)
71+
val layersActivations = modelActivationOnLayersPlot(it, x)
72+
val (prediction, activations) = it.predictAndGetActivations(x)
73+
println("Prediction: $prediction")
74+
println("Ground Truth: $y")
6475

65-
println("Ground Truth: $trainImageLabel")
76+
// lets-plot approach
77+
layersActivations[0].show()
78+
layersActivations[1].show()
79+
80+
// swing approach
81+
drawActivations(activations)
6682
}
6783
}

0 commit comments

Comments
 (0)