|
| 1 | +import org.scalatest._ |
| 2 | +import org.apache.spark.sparknet.loaders.CifarLoader |
| 3 | +import org.apache.spark.sparknet.CaffeLibrary |
| 4 | +import com.sun.jna.Pointer |
| 5 | + |
| 6 | +// for this test to work, $SPARKNET_HOME/caffe should be the caffe root directory |
| 7 | +// and you need to run $SPARKNET_HOME/caffe/data/cifar10/get_cifar10.sh |
| 8 | +class CifarSpec extends FlatSpec { |
| 9 | + "A Cifar net" should "get chance digits right on randomly initialized net" in { |
| 10 | + val sparkNetHome = sys.env("SPARKNET_HOME") |
| 11 | + val loader = new CifarLoader(sparkNetHome + "/caffe/data/cifar10/") |
| 12 | + |
| 13 | + System.load(sparkNetHome + "/build/libccaffe.so") |
| 14 | + val caffeLib = CaffeLibrary.INSTANCE |
| 15 | + |
| 16 | + caffeLib.set_basepath(sparkNetHome + "/caffe/") |
| 17 | + val net = caffeLib.make_solver_from_prototxt(sparkNetHome + "/caffe/examples/cifar10/cifar10_full_java_solver.prototxt") |
| 18 | + |
| 19 | + val dtypeSize = caffeLib.get_dtype_size() |
| 20 | + val intSize = caffeLib.get_int_size() |
| 21 | + |
| 22 | + def makeImageCallback(images: Array[Array[Float]]) : CaffeLibrary.java_callback_t = { |
| 23 | + return new CaffeLibrary.java_callback_t() { |
| 24 | + var currImage = 0 |
| 25 | + def invoke(data: Pointer, batch_size: Int, num_dims: Int, shape: Pointer) { |
| 26 | + var size = 1 |
| 27 | + for(i <- 0 to num_dims-1) { |
| 28 | + val dim = shape.getInt(i * intSize) |
| 29 | + size *= dim |
| 30 | + } |
| 31 | + for(j <- 0 to batch_size-1) { |
| 32 | + assert(size == images(currImage).length) |
| 33 | + for(i <- 0 to size-1) { |
| 34 | + data.setFloat((j * size + i) * dtypeSize, images(currImage)(i)) |
| 35 | + } |
| 36 | + currImage += 1 |
| 37 | + if(currImage == images.length) { |
| 38 | + currImage = 0 |
| 39 | + } |
| 40 | + } |
| 41 | + } |
| 42 | + }; |
| 43 | + } |
| 44 | + |
| 45 | + def makeLabelCallback(labels: Array[Float]) : CaffeLibrary.java_callback_t = { |
| 46 | + return new CaffeLibrary.java_callback_t() { |
| 47 | + var currImage = 0 |
| 48 | + def invoke(data: Pointer, batch_size: Int, num_dims: Int, shape: Pointer) { |
| 49 | + for(j <- 0 to batch_size-1) { |
| 50 | + assert(shape.getInt(0) == 1) |
| 51 | + data.setFloat(j * dtypeSize, labels(currImage)) |
| 52 | + currImage += 1 |
| 53 | + if(currImage == labels.length) { |
| 54 | + currImage = 0 |
| 55 | + } |
| 56 | + } |
| 57 | + } |
| 58 | + }; |
| 59 | + } |
| 60 | + |
| 61 | + val loadTrainImageFn = makeImageCallback(loader.trainImages) |
| 62 | + val loadTrainLabelFn = makeLabelCallback(loader.trainLabels) |
| 63 | + caffeLib.set_train_data_callback(net, 0, loadTrainImageFn) |
| 64 | + caffeLib.set_train_data_callback(net, 1, loadTrainLabelFn) |
| 65 | + |
| 66 | + val loadTestImageFn = makeImageCallback(loader.testImages) |
| 67 | + val loadTestLabelFn = makeLabelCallback(loader.testLabels) |
| 68 | + caffeLib.set_test_data_callback(net, 0, loadTestImageFn) |
| 69 | + caffeLib.set_test_data_callback(net, 1, loadTestLabelFn) |
| 70 | + |
| 71 | + caffeLib.solver_test(net, 10) // TODO: shouldn't be hard coded |
| 72 | + |
| 73 | + val numTestScores = caffeLib.num_test_scores(net) |
| 74 | + |
| 75 | + val testScores = new Array[Float](numTestScores) |
| 76 | + |
| 77 | + // perform test on random net |
| 78 | + for (i <- 0 to numTestScores - 1) { |
| 79 | + testScores(i) = caffeLib.get_test_score(net, i) * 100 // TODO: this batch size shouldn't be hard coded |
| 80 | + } |
| 81 | + |
| 82 | + assert(70.0 <= testScores(0) && testScores(0) <= 130.0) |
| 83 | + } |
| 84 | +} |
0 commit comments