Commit ca98faf

remove unnecessary allocations from caffenet

1 parent 843870c

9 files changed: +198 −69 lines changed

models/test/test.prototxt
Lines changed: 13 additions & 0 deletions

@@ -0,0 +1,13 @@
+name: "test"
+input: "data"
+input_shape {
+  dim: 256
+  dim: 3
+  dim: 227
+  dim: 227
+}
+input: "label"
+input_shape {
+  dim: 256
+  dim: 1
+}
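CaffeNet sizes its preallocated input buffers from these in-memory input declarations: the leading dim is the batch size, and the remaining dims give the per-example element count. A minimal sketch of that arithmetic in plain Scala (not code from this commit; it just mirrors netParam.input_shape(0).dim(0) and JavaCPPUtils.getInputShape(netParam, i).drop(1).product):

// Shapes as declared in test.prototxt above.
val dataShape = Array(256, 3, 227, 227)
val labelShape = Array(256, 1)

val batchSize = dataShape(0)                      // 256
val dataBufferSize = dataShape.drop(1).product    // 3 * 227 * 227 = 154587
val labelBufferSize = labelShape.drop(1).product  // 1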

src/main/scala/apps/CifarApp.scala
Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ object CifarApp {
       val numTestBatches = workerStore.get[Int]("testPartitionSize") / testBatchSize
       var accuracy = 0F
       for (j <- 0 to numTestBatches - 1) {
-        val out = workerStore.get[CaffeSolver]("solver").trainNet.forward(testIt)
+        val out = workerStore.get[CaffeSolver]("solver").trainNet.forward(testIt, List("accuracy", "loss", "prob"))
         accuracy += out("accuracy").get(Array())
       }
       Array[(Float, Int)]((accuracy, numTestBatches)).iterator
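After this change, forward no longer copies every output blob out of Caffe on each call; callers name the blobs they want materialized as NDArrays (see the CaffeNet.scala diff below). A hedged sketch of the new call pattern, with testIt and the solver setup assumed from the surrounding app code:

// Only the named blobs are copied back from the net; the test loop here
// asks for three of them and reads the scalar accuracy.
val net = workerStore.get[CaffeSolver]("solver").trainNet
val out = net.forward(testIt, List("accuracy", "loss", "prob"))
val batchAccuracy: Float = out("accuracy").get(Array())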

src/main/scala/apps/ImageNetApp.scala
Lines changed: 1 addition & 1 deletion

@@ -111,7 +111,7 @@ object ImageNetApp {
       val numTestBatches = workerStore.get[Int]("testPartitionSize") / testBatchSize
       var accuracy = 0F
       for (j <- 0 to numTestBatches - 1) {
-        val out = workerStore.get[CaffeSolver]("solver").trainNet.forward(testIt)
+        val out = workerStore.get[CaffeSolver]("solver").trainNet.forward(testIt, List("accuracy"))
         accuracy += out("accuracy").get(Array())
       }
       Array[(Float, Int)]((accuracy, numTestBatches)).iterator

src/main/scala/libs/CaffeNet.scala
Lines changed: 19 additions & 27 deletions

@@ -26,19 +26,19 @@ object CaffeNet {
 }
 
 class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preprocessor, caffeNet: FloatNet) {
-  private val inputSize = netParam.input_size
-  private val batchSize = netParam.input_shape(0).dim(0).toInt
-  private val transformations = new Array[Any => NDArray](inputSize)
+  val inputSize = netParam.input_size
+  val batchSize = netParam.input_shape(0).dim(0).toInt
+  private val transformations = new Array[(Any, Array[Float]) => Unit](inputSize)
   private val inputIndices = new Array[Int](inputSize)
   private val columnNames = schema.map(entry => entry.name)
   // private val caffeNet = new FloatNet(netParam)
   private val inputRef = new Array[FloatBlob](inputSize)
   def getNet = caffeNet // TODO: For debugging
 
-  private val numOutputs = caffeNet.num_outputs
-  private val numLayers = caffeNet.layers().size.toInt
-  private val layerNames = List.range(0, numLayers).map(i => caffeNet.layers.get(i).layer_param.name.getString)
-  private val numLayerBlobs = List.range(0, numLayers).map(i => caffeNet.layers.get(i).blobs().size.toInt)
+  val numOutputs = caffeNet.num_outputs
+  val numLayers = caffeNet.layers().size.toInt
+  val layerNames = List.range(0, numLayers).map(i => caffeNet.layers.get(i).layer_param.name.getString)
+  val numLayerBlobs = List.range(0, numLayers).map(i => caffeNet.layers.get(i).blobs().size.toInt)
 
   for (i <- 0 to inputSize - 1) {
     val name = netParam.input(i).getString
@@ -58,43 +58,35 @@ class CaffeNet(netParam: NetParameter, schema: StructType, preprocessor: Preproc
     inputRef(i) = new FloatBlob(dims)
     inputs.put(i, inputRef(i))
   }
-  val inputBuffer = new Array[Array[Float]](inputSize)
+  // in `inputBuffer`, the first index indexes the input argument, the second
+  // index indexes into the batch, the third index indexes the values in the
+  // data
+  val inputBuffer = new Array[Array[Array[Float]]](inputSize)
   val inputBufferSize = new Array[Int](inputSize)
   for (i <- 0 to inputSize - 1) {
     inputBufferSize(i) = JavaCPPUtils.getInputShape(netParam, i).drop(1).product // drop 1 to ignore batchSize
-    inputBuffer(i) = new Array[Float](inputBufferSize(i) * batchSize)
+    inputBuffer(i) = new Array[Array[Float]](batchSize)
+    for (batchIndex <- 0 to batchSize - 1) {
+      inputBuffer(i)(batchIndex) = new Array[Float](inputBufferSize(i))
+    }
   }
 
-  def transformInto(iterator: Iterator[Row], data: FloatBlobVector) = {
+  def transformInto(iterator: Iterator[Row], inputs: FloatBlobVector) = {
     var batchIndex = 0
     while (iterator.hasNext && batchIndex != batchSize) {
       val row = iterator.next
       for (i <- 0 to inputSize - 1) {
-        val result = transformations(i)(row(inputIndices(i)))
-        val flatArray = result.toFlat() // TODO: Make this efficient
-        System.arraycopy(flatArray, 0, inputBuffer(i), batchIndex * inputBufferSize(i), inputBufferSize(i))
+        transformations(i)(row(inputIndices(i)), inputBuffer(i)(batchIndex))
       }
       batchIndex += 1
     }
-    for (i <- 0 to inputSize - 1) {
-      val blob = data.get(i)
-      val buffer = blob.mutable_cpu_data()
-      buffer.put(inputBuffer(i), 0, batchSize * inputBufferSize(i))
-    }
+    JavaCPPUtils.arraysToFloatBlobVector(inputBuffer, inputs, batchSize, inputBufferSize, inputSize)
  }
 
  def forward(rowIt: Iterator[Row], dataBlobNames: List[String] = List[String]()): Map[String, NDArray] = {
    transformInto(rowIt, inputs)
-    val tops = caffeNet.Forward(inputs)
+    caffeNet.Forward(inputs)
    val outputs = Map[String, NDArray]()
-    for (j <- 0 to numOutputs - 1) {
-      val outputName = caffeNet.blob_names().get(caffeNet.output_blob_indices().get(j)).getString
-      val top = tops.get(j)
-      val shape = Array.range(0, top.num_axes).map(i => top.shape.get(i))
-      val output = new Array[Float](shape.product)
-      top.cpu_data().get(output, 0, shape.product)
-      outputs += (outputName -> NDArray(output, shape))
-    }
    for (name <- dataBlobNames) {
      val floatBlob = caffeNet.blob_by_name(name)
      if (floatBlob == null) {
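The per-row allocations this commit removes were in transformInto: each row produced a fresh NDArray plus a flat array via toFlat(), which was then copied again into the blob. With the new (Any, Array[Float]) => Unit transformations, each row is written straight into a buffer allocated once in the constructor. A minimal sketch of the write-in-place contract with toy sizes (plain Scala, no JavaCPP):

// Toy setup: one input, 4 floats per example, batch size 2. The nested
// buffer is allocated once up front, as in the constructor above.
val batchSize = 2
val inputBufferSize = Array(4)
val inputBuffer = Array.fill(1, batchSize)(new Array[Float](inputBufferSize(0)))

// A transformation of the new shape fills a row's slot directly instead of
// returning a fresh NDArray.
val fill: (Any, Array[Float]) => Unit =
  (element, buffer) => element.asInstanceOf[Array[Float]].copyToArray(buffer)

fill(Array(1f, 2f, 3f, 4f), inputBuffer(0)(0)) // row 0, written in place
fill(Array(5f, 6f, 7f, 8f), inputBuffer(0)(1)) // row 1, written in place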

src/main/scala/libs/JavaCPPUtils.scala
Lines changed: 37 additions & 1 deletion

@@ -25,7 +25,7 @@ object JavaCPPUtils {
     shape
   }
 
-  def getInputShape(netParam: NetParameter, i: Int): Array[Int] = {
+  def getInputShape(netParam: NetParameter, i: Int): Array[Int] = {
     val numAxes = netParam.input_shape(i).dim_size
     val shape = new Array[Int](numAxes)
     for (j <- 0 to numAxes - 1) {
@@ -34,5 +34,41 @@
     shape
   }
 
+  def arraysToFloatBlobVector(inputBuffer: Array[Array[Array[Float]]], inputs: FloatBlobVector, batchSize: Int, inputBufferSize: Array[Int], inputSize: Int) = {
+    for (i <- 0 to inputSize - 1) {
+      val blob = inputs.get(i)
+      val buffer = blob.mutable_cpu_data()
+      var batchIndex = 0
+      while (batchIndex < batchSize) {
+        var j = 0
+        while (j < inputBufferSize(i)) {
+          // it'd be preferable to do this with one call, but JavaCPP's FloatPointer API has confusing semantics
+          buffer.put(inputBufferSize(i) * batchIndex + j, inputBuffer(i)(batchIndex)(j))
+          j += 1
+        }
+        batchIndex += 1
+      }
+    }
+  }
+
+  // this method is just for testing
+  def arraysFromFloatBlobVector(inputs: FloatBlobVector, batchSize: Int, inputBufferSize: Array[Int], inputSize: Int): Array[Array[Array[Float]]] = {
+    val result = new Array[Array[Array[Float]]](inputSize)
+    for (i <- 0 to inputSize - 1) {
+      result(i) = new Array[Array[Float]](batchSize)
+      val blob = inputs.get(i)
+      val buffer = blob.cpu_data()
+      for (batchIndex <- 0 to batchSize - 1) {
+        result(i)(batchIndex) = new Array[Float](inputBufferSize(i))
+        var j = 0
+        while (j < inputBufferSize(i)) {
+          // it'd be preferable to do this with one call, but JavaCPP's FloatPointer API has confusing semantics
+          result(i)(batchIndex)(j) = buffer.get(inputBufferSize(i) * batchIndex + j)
+          j += 1
+        }
+      }
+    }
+    return result
+  }
 
 }
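Both helpers commit to the same flattening convention for input i: element j of batch entry batchIndex lives at flat offset inputBufferSize(i) * batchIndex + j, and arraysFromFloatBlobVector exists "just for testing" that the write and read agree. A pure-Scala roundtrip check of that convention, with a plain array standing in for the blob's cpu_data:

val batchSize = 2
val bufSize = 3
val rows = Array(Array(1f, 2f, 3f), Array(4f, 5f, 6f))

// Write, as arraysToFloatBlobVector does, into one flat backing array.
val flat = new Array[Float](bufSize * batchSize)
for (b <- 0 until batchSize; j <- 0 until bufSize)
  flat(bufSize * b + j) = rows(b)(j)

// Read back, as arraysFromFloatBlobVector does, and compare.
val back = Array.tabulate(batchSize, bufSize)((b, j) => flat(bufSize * b + j))
assert(back.map(_.toSeq) sameElements rows.map(_.toSeq))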

src/main/scala/libs/Preprocessor.scala
Lines changed: 35 additions & 27 deletions

@@ -6,8 +6,10 @@ import org.apache.spark.sql.types._
 import org.apache.spark.sql.{DataFrame, Row}
 import scala.collection.mutable._
 
+// The Preprocessor provides a function for reading data from a dataframe row
+// into the net
 trait Preprocessor {
-  def convert(name: String, shape: Array[Int]): Any => NDArray
+  def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit
 }
 
 trait TensorFlowPreprocessor {
@@ -20,62 +22,68 @@
 // allocation. This is designed to be easier to understand, whereas the
 // ImageNetPreprocessor is designed to be faster.
 class DefaultPreprocessor(schema: StructType) extends Preprocessor {
-  def convert(name: String, shape: Array[Int]): Any => NDArray = {
+  def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit = {
     schema(name).dataType match {
-      case FloatType => (element: Any) => {
-        NDArray(Array[Float](element.asInstanceOf[Float]), shape)
+      case FloatType => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
+        NDArray(Array[Float](element.asInstanceOf[Float]), shape).flatCopy(buffer)
       }
-      case DoubleType => (element: Any) => {
-        NDArray(Array[Float](element.asInstanceOf[Double].toFloat), shape)
+      case DoubleType => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
+        NDArray(Array[Float](element.asInstanceOf[Double].toFloat), shape).flatCopy(buffer)
       }
-      case IntegerType => (element: Any) => {
-        NDArray(Array[Float](element.asInstanceOf[Int].toFloat), shape)
+      case IntegerType => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
+        NDArray(Array[Float](element.asInstanceOf[Int].toFloat), shape).flatCopy(buffer)
      }
-      case LongType => (element: Any) => {
-        NDArray(Array[Float](element.asInstanceOf[Long].toFloat), shape)
+      case LongType => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
+        NDArray(Array[Float](element.asInstanceOf[Long].toFloat), shape).flatCopy(buffer)
      }
-      case BinaryType => (element: Any) => {
-        NDArray(element.asInstanceOf[Array[Byte]].map(e => (e & 0xFF).toFloat), shape)
+      case BinaryType => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
+        NDArray(element.asInstanceOf[Array[Byte]].map(e => (e & 0xFF).toFloat), shape).flatCopy(buffer)
      }
-      // case ArrayType(IntegerType, true) => (element: Any) => {} // TODO(rkn): implement
-      case ArrayType(FloatType, true) => (element: Any) => {
+      case ArrayType(FloatType, true) => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
        element match {
-          case element: Array[Float] => NDArray(element.asInstanceOf[Array[Float]], shape)
-          case element: WrappedArray[Float] => NDArray(element.asInstanceOf[WrappedArray[Float]].toArray, shape)
-          case element: ArrayBuffer[Float] => NDArray(element.asInstanceOf[ArrayBuffer[Float]].toArray, shape)
+          case element: Array[Float] => NDArray(element.asInstanceOf[Array[Float]], shape).flatCopy(buffer)
+          case element: WrappedArray[Float] => NDArray(element.asInstanceOf[WrappedArray[Float]].toArray, shape).flatCopy(buffer)
+          case element: ArrayBuffer[Float] => NDArray(element.asInstanceOf[ArrayBuffer[Float]].toArray, shape).flatCopy(buffer)
        }
      }
-      // case ArrayType(DoubleType, true) => (element: Any) => {} // TODO(rkn): implement
-      // case ArrayType(LongType, true) => (element: Any) => {} // TODO(rkn): implement
    }
  }
 }
 
 class ImageNetPreprocessor(schema: StructType, meanImage: Array[Float], fullHeight: Int = 256, fullWidth: Int = 256, croppedHeight: Int = 227, croppedWidth: Int = 227) extends Preprocessor {
-  def convert(name: String, shape: Array[Int]): Any => NDArray = {
+  def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit = {
    schema(name).dataType match {
-      case IntegerType => (element: Any) => {
-        NDArray(Array[Float](element.asInstanceOf[Int].toFloat), shape)
+      case IntegerType => (element: Any, buffer: Array[Float]) => {
+        if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
+        NDArray(Array[Float](element.asInstanceOf[Int].toFloat), shape).flatCopy(buffer)
      }
      case BinaryType => {
        if (shape(0) != 3) {
          throw new IllegalArgumentException("Expecting input image to have 3 channels.")
        }
-        val buffer = new Array[Float](3 * fullHeight * fullWidth)
-        (element: Any) => {
+        val tempBuffer = new Array[Float](3 * fullHeight * fullWidth)
+        (element: Any, buffer: Array[Float]) => {
+          if (buffer.length != shape.product) { throw new Exception("buffer.length and shape.product don't agree, buffer has length " + buffer.length.toString + ", but shape is " + shape.deep.toString) }
          element match {
            case element: Array[Byte] => {
              var index = 0
              while (index < 3 * fullHeight * fullWidth) {
-                buffer(index) = (element(index) & 0xFF).toFloat - meanImage(index)
+                tempBuffer(index) = (element(index) & 0xFF).toFloat - meanImage(index)
                index += 1
              }
            }
          }
          val heightOffset = Random.nextInt(fullHeight - croppedHeight + 1)
          val widthOffset = Random.nextInt(fullWidth - croppedWidth + 1)
-          NDArray(buffer.clone, Array[Int](shape(0), fullHeight, fullWidth)).subarray(Array[Int](0, heightOffset, widthOffset), Array[Int](shape(0), heightOffset + croppedHeight, widthOffset + croppedWidth))
-          // TODO(rkn): probably don't want to call buffer.clone
+          val lowerIndices = Array[Int](0, heightOffset, widthOffset)
+          val upperIndices = Array[Int](shape(0), heightOffset + croppedHeight, widthOffset + croppedWidth)
+          NDArray(tempBuffer, Array[Int](shape(0), fullHeight, fullWidth)).subarray(lowerIndices, upperIndices).flatCopy(buffer)
        }
      }
    }
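Under the new contract a Preprocessor hands back a filler rather than a producer: the caller owns the destination buffer, and the converter must fill exactly shape.product floats into it. A sketch of a custom implementation against the new trait (hypothetical class, reusing the same length guard the commit adds everywhere):

// Hypothetical preprocessor for columns that already hold flat Float arrays.
class PassThroughPreprocessor(schema: StructType) extends Preprocessor {
  def convert(name: String, shape: Array[Int]): (Any, Array[Float]) => Unit = {
    (element: Any, buffer: Array[Float]) => {
      if (buffer.length != shape.product) {
        throw new Exception("buffer.length and shape.product don't agree")
      }
      // Fill the caller's buffer in place; no NDArray, no per-row allocation.
      element.asInstanceOf[Array[Float]].copyToArray(buffer)
    }
  }
}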

src/test/scala/apps/LoadAdultDataSpec.scala
Lines changed: 6 additions & 4 deletions

@@ -19,11 +19,13 @@ class LoadAdultDataSpec extends FlatSpec {
 
     val function0 = preprocessor.convert("C0", Array[Int](1))
     val function2 = preprocessor.convert("C2", Array[Int](1))
-    val result0 = function0(df.take(1)(0)(0))
-    val result2 = function2(df.take(1)(0)(2))
+    val result0 = new Array[Float](1)
+    val result2 = new Array[Float](1)
+    function0(df.take(1)(0)(0), result0)
+    function2(df.take(1)(0)(2), result2)
 
-    assert((result0.get(Array[Int](0)) - 39.0).abs <= 1e-4)
-    assert((result2.get(Array[Int](0)) - 77516.0).abs <= 1e-4)
+    assert((result0(0) - 39.0).abs <= 1e-4)
+    assert((result2(0) - 77516.0).abs <= 1e-4)
 
     sc.stop()
  }
