diff --git a/examples/c3d_train_ucf101/conv3d_ucf101_deploy.prototxt b/examples/c3d_train_ucf101/conv3d_ucf101_deploy.prototxt new file mode 100644 index 0000000000..829d331645 --- /dev/null +++ b/examples/c3d_train_ucf101/conv3d_ucf101_deploy.prototxt @@ -0,0 +1,342 @@ +name: "deep_c3d_ucf101" +input: "data" +input_dim: 30 +input_dim: 3 +input_dim: 16 +input_dim: 112 +input_dim: 112 +# ----------- 1st layer group --------------- +layers { + name: "conv1a" + type: CONVOLUTION3D + bottom: "data" + top: "conv1a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + stride: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1a" + type: RELU + bottom: "conv1a" + top: "conv1a" +} +layers { + name: "pool1" + type: POOLING3D + bottom: "conv1a" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 1 + stride: 2 + temporal_stride: 1 + } +} +# ------------- 2nd layer group -------------- +layers { + name: "conv2a" + type: CONVOLUTION3D + bottom: "pool1" + top: "conv2a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2a" + type: RELU + bottom: "conv2a" + top: "conv2a" +} +layers { + name: "pool2" + type: POOLING3D + bottom: "conv2a" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} +# ----------------- 3rd layer group -------------- +layers { + name: "conv3a" + type: CONVOLUTION3D + bottom: "pool2" + top: "conv3a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 3 + 
kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu3a" + type: RELU + bottom: "conv3a" + top: "conv3a" +} +layers { + name: "pool3" + type: POOLING3D + bottom: "conv3a" + top: "pool3" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} + +# --------- 4th layer group +layers { + name: "conv4a" + type: CONVOLUTION3D + bottom: "pool3" + top: "conv4a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4a" + type: RELU + bottom: "conv4a" + top: "conv4a" +} +layers { + name: "pool4" + type: POOLING3D + bottom: "conv4a" + top: "pool4" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} + +# --------------- 5th layer group -------- +layers { + name: "conv5a" + type: CONVOLUTION3D + bottom: "pool4" + top: "conv5a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5a" + type: RELU + bottom: "conv5a" + top: "conv5a" +} +layers { + name: "pool5" + type: POOLING3D + bottom: "conv5a" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} +# ---------------- fc layers ------------- +layers { + name: "fc6" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 2048 + weight_filler { + type: "gaussian" + 
std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6" + top: "fc6" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6" + top: "fc6" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7" + type: INNER_PRODUCT + bottom: "fc6" + top: "fc7" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 2048 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7" + top: "fc7" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7" + top: "fc7" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8" + type: INNER_PRODUCT + bottom: "fc7" + top: "fc8" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 101 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "prob" + type: SOFTMAX + bottom: "fc8" + top: "prob" +} +#layers { +# top: "accuracy" +# name: "accuracy" +# type: ACCURACY +# bottom: "prob" +# bottom: "label" +#} diff --git a/examples/c3d_train_ucf101/run_c3d_classification.py b/examples/c3d_train_ucf101/run_c3d_classification.py new file mode 100755 index 0000000000..c0f3c0f026 --- /dev/null +++ b/examples/c3d_train_ucf101/run_c3d_classification.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python + +''' +A sample script to run c3d classifications on multiple videos +''' + +import os +import numpy as np +import math +import json +import sys +sys.path.append("/home/chuck/projects/C3D/python") +import caffe +from c3d_classify import c3d_classify +import csv + +# UCF101 categories +def get_ucf_categories(): + category = [ + 'ApplyEyeMakeup', + 'ApplyLipstick', + 'Archery', + 'BabyCrawling', + 'BalanceBeam', + 'BandMarching', + 'BaseballPitch', + 'Basketball', + 'BasketballDunk', + 
'BenchPress', + 'Biking', + 'Billiards', + 'BlowDryHair', + 'BlowingCandles', + 'BodyWeightSquats', + 'Bowling', + 'BoxingPunchingBag', + 'BoxingSpeedBag', + 'BreastStroke', + 'BrushingTeeth', + 'CleanAndJerk', + 'CliffDiving', + 'CricketBowling', + 'CricketShot', + 'CuttingInKitchen', + 'Diving', + 'Drumming', + 'Fencing', + 'FieldHockeyPenalty', + 'FloorGymnastics', + 'FrisbeeCatch', + 'FrontCrawl', + 'GolfSwing', + 'Haircut', + 'Hammering', + 'HammerThrow', + 'HandstandPushups', + 'HandstandWalking', + 'HeadMassage', + 'HighJump', + 'HorseRace', + 'HorseRiding', + 'HulaHoop', + 'IceDancing', + 'JavelinThrow', + 'JugglingBalls', + 'JumpingJack', + 'JumpRope', + 'Kayaking', + 'Knitting', + 'LongJump', + 'Lunges', + 'MilitaryParade', + 'Mixing', + 'MoppingFloor', + 'Nunchucks', + 'ParallelBars', + 'PizzaTossing', + 'PlayingCello', + 'PlayingDaf', + 'PlayingDhol', + 'PlayingFlute', + 'PlayingGuitar', + 'PlayingPiano', + 'PlayingSitar', + 'PlayingTabla', + 'PlayingViolin', + 'PoleVault', + 'PommelHorse', + 'PullUps', + 'Punch', + 'PushUps', + 'Rafting', + 'RockClimbingIndoor', + 'RopeClimbing', + 'Rowing', + 'SalsaSpin', + 'ShavingBeard', + 'Shotput', + 'SkateBoarding', + 'Skiing', + 'Skijet', + 'SkyDiving', + 'SoccerJuggling', + 'SoccerPenalty', + 'StillRings', + 'SumoWrestling', + 'Surfing', + 'Swing', + 'TableTennisShot', + 'TaiChi', + 'TennisSwing', + 'ThrowDiscus', + 'TrampolineJumping', + 'Typing', + 'UnevenBars', + 'VolleyballSpiking', + 'WalkingWithDog', + 'WallPushups', + 'WritingOnBoard', + 'YoYo' + ] + + return category + +def main(): + '''Classify every clip in the test list with c3d, print the top-N classes and log the rank of the ground-truth class to a CSV.''' + # get UCF101 categories (needed below to print the predicted class names) + ucf_categories = get_ucf_categories() + + # save class probs + force_save_result = False + cwd = os.path.dirname(os.path.realpath(__file__)) + result_path = 'ucf101_c3d_intermediate_results' + + # save prob ranks for all videos + output_file = 'ucf101_c3d_performance.csv' + bufsize = 0 + out = open(output_file, "w", bufsize) + + # model + model_def_file = './conv3d_ucf101_deploy.prototxt' + 
model_file = './conv3d_ucf101_iter_50000' + mean_file = './ucf101_train_mean.binaryproto' + net = caffe.Net(model_def_file, model_file) + + # caffe init + gpu_id = 0 + net.set_device(gpu_id) + net.set_mode_gpu() + net.set_phase_test() + + # read test video list + test_video_list = '../c3d_finetuning/test_01.lst' + reader = csv.reader(open(test_video_list), delimiter=" ") + + # show top_N + top_N = 3 + + # network param + prob_layer = 'prob' + + for count, video_and_category in enumerate(reader): + ''' + e.g. video_name: /datasets/ucf101/ApplyEyeMakeup/v_ApplyEyeMakeup_g01_c03/ + ''' + (video_name, start_frame, category_id) = video_and_category + video_name = video_name.rstrip('/') + start_frame = int(start_frame) + category_id = int(category_id) + if not os.path.isdir(video_name): + print "[Error] video_name path={} does not exist. Skipping...".format(video_name) + continue + video_id = video_name.split('/')[-1][2:] # e.g. ApplyEyeMakeup_g01_c03 + category = video_name.split('/')[-2] # e.g. ApplyEyeMakeup + + print "-"*79 + print "video_name={} ({}-th), video_id={}, start_frame={}, category={}, category_id={}".format(video_name, count+1, video_id, start_frame, category, category_id) + + # save class probs + result = os.path.join(cwd, result_path, '{0}_frame_{1:05d}_c3d.txt'.format(video_id, start_frame)) + if os.path.isfile(result) and not force_save_result: + print "[Info] intermediate output file={} has been already saved. 
Skipping...".format(result) + avg_pred = np.loadtxt(result) + else: + blob = caffe.proto.caffe_pb2.BlobProto() + data = open(mean_file,'rb').read() + blob.ParseFromString(data) + image_mean = np.array(caffe.io.blobproto_to_array(blob)) + prediction = c3d_classify( + vid_name=video_name, + image_mean=image_mean, + net=net, + start_frame=start_frame, + prob_layer=prob_layer, + multi_crop=False + ) + if prediction.ndim == 2: + avg_pred = np.mean(prediction, axis=1) + else: + avg_pred = prediction + np.savetxt(result, avg_pred, delimiter=",") + sorted_indices = sorted(range(len(avg_pred)), key=lambda k: -avg_pred[k]) + print "-"*5 + for x in range(top_N): + index = sorted_indices[x] + prob = round(avg_pred[index]*100,10) + if category_id == index: + hit_or_miss = 'hit!' + else: + hit_or_miss = 'miss' + print "[Info] GT:{}, c3d detected:{} (p={}%): {}".format(category, ucf_categories[index], prob, hit_or_miss) + c3d_rank = sorted_indices.index(category_id) + 1 + + # save class ranking + out.write("{0}_{1:05d}, {2}\n".format(video_id, start_frame, c3d_rank)) + out.close() + +if __name__ == "__main__": + main() diff --git a/include/caffe/vision_layers.hpp b/include/caffe/vision_layers.hpp index fbb73c0abd..f67d9de9e5 100644 --- a/include/caffe/vision_layers.hpp +++ b/include/caffe/vision_layers.hpp @@ -267,6 +267,7 @@ class MemoryDataLayer : public Layer { // will be given to Blob, which is mutable void Reset(Dtype* data, Dtype* label, int n); int datum_channels() { return datum_channels_; } + int datum_length() { return datum_length_; } int datum_height() { return datum_height_; } int datum_width() { return datum_width_; } int batch_size() { return batch_size_; } @@ -282,6 +283,7 @@ class MemoryDataLayer : public Layer { Dtype* data_; Dtype* labels_; int datum_channels_; + int datum_length_; int datum_height_; int datum_width_; int datum_size_; diff --git a/python/c3d_classify.py b/python/c3d_classify.py new file mode 100755 index 0000000000..4de32f96b3 --- /dev/null +++ 
b/python/c3d_classify.py @@ -0,0 +1,99 @@ +''' +A sample function to run classification using c3d model. +''' + +import os +import numpy as np +import math +import cv2 + +def c3d_classify( + vid_name, + image_mean, + net, + start_frame, + prob_layer='prob', + multi_crop=False + ): + ''' + vid_name: a directory that contains extracted images (image_%04d.jpg) + image_mean: (3,c3d_depth=16,height,width)-dim image mean + net: a caffe network object + start_frame: number of the first image file to read; files numbered + start_frame .. start_frame+c3d_depth-1 are loaded as-is (no offset), + so pass 1 to start at image_0001.jpg + prob_layer: name of the blob that holds the class probabilities + multi_crop: if True, use 4-corner + 1-center crops of both the original + and the mirrored clip; otherwise only the center crop and its mirror + returns: (num_categories, num_crops) array of class probabilities + ''' + + # infer net params + batch_size = net.blobs['data'].data.shape[0] + c3d_depth = net.blobs['data'].data.shape[2] + num_categories = net.blobs[prob_layer].data.shape[1] + + # init + dims = (128,171,3,c3d_depth) + rgb = np.zeros(shape=dims, dtype=np.float32) + rgb_flip = np.zeros(shape=dims, dtype=np.float32) + + for i in range(c3d_depth): + img_file = os.path.join(vid_name, 'image_{0:04d}.jpg'.format(start_frame+i)) + img = cv2.imread(img_file, cv2.IMREAD_UNCHANGED) + if img is None: + raise IOError('[Error] cannot read image file={}'.format(img_file)) + img = cv2.resize(img, dims[1::-1]) + rgb[:,:,:,i] = img + rgb_flip[:,:,:,i] = img[:,::-1,:] + + # subtract mean + image_mean = np.transpose(np.squeeze(image_mean), (2,3,0,1)) + rgb -= image_mean + rgb_flip -= image_mean[:,::-1,:,:] + + if multi_crop: + # crop (112-by-112) in upperleft, upperright, lowerleft, lowerright + # corners and the center, for both original and flipped images + rgb_1 = rgb[:112, :112, :,:] + rgb_2 = rgb[:112, -112:, :,:] + rgb_3 = rgb[8:120, 30:142, :,:] + rgb_4 = rgb[-112:, :112, :,:] + rgb_5 = rgb[-112:, -112:, :,:] + rgb_f_1 = rgb_flip[:112, :112, :,:] + rgb_f_2 = rgb_flip[:112, -112:, :,:] + rgb_f_3 = rgb_flip[8:120, 30:142, :,:] + rgb_f_4 = rgb_flip[-112:, :112, :,:] + rgb_f_5 = rgb_flip[-112:, -112:, :,:] + rgb = np.concatenate((rgb_1[...,np.newaxis], + rgb_2[...,np.newaxis], + rgb_3[...,np.newaxis], + 
rgb_4[...,np.newaxis], + rgb_5[...,np.newaxis], + rgb_f_1[...,np.newaxis], + rgb_f_2[...,np.newaxis], + rgb_f_3[...,np.newaxis], + rgb_f_4[...,np.newaxis], + rgb_f_5[...,np.newaxis]), axis=4) + else: + # crop (112-by-112) for both original/flipped images + rgb_3 = rgb[8:120, 30:142, :,:] + rgb_f_3 = rgb_flip[8:120, 30:142, :,:] + rgb = np.concatenate((rgb_3[...,np.newaxis], + rgb_f_3[...,np.newaxis]), axis=4) + + # run classifications on batches + prediction = np.zeros((num_categories,rgb.shape[4])) + if rgb.shape[4] < batch_size: + net.blobs['data'].data[:rgb.shape[4],:,:,:,:] = np.transpose(rgb, (4,2,3,0,1)) + output = net.forward() + prediction = np.transpose(np.squeeze(output[prob_layer][:rgb.shape[4],:,:,:,:], axis=(2,3,4))) + else: + num_batches = int(math.ceil(float(rgb.shape[4])/batch_size)) + for bb in range(num_batches): + span = range(batch_size*bb, min(rgb.shape[4],batch_size*(bb+1))) + # the last batch may be partial: fill and read back only the first len(span) slots + net.blobs['data'].data[:len(span),:,:,:,:] = np.transpose(rgb[:,:,:,:,span], (4,2,3,0,1)) + output = net.forward() + prediction[:, span] = np.transpose(np.squeeze(output[prob_layer][:len(span),:,:,:,:], axis=(2,3,4))) + return prediction diff --git a/python/c3d_classify.pyc b/python/c3d_classify.pyc new file mode 100644 index 0000000000..72c8157987 Binary files /dev/null and b/python/c3d_classify.pyc differ diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index e9fe5cd3b0..93500d9546 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -55,6 +55,7 @@ class CaffeBlob { string name() const { return name_; } int num() const { return blob_->num(); } int channels() const { return blob_->channels(); } + int length() const { return blob_->length(); } int height() const { return blob_->height(); } int width() const { return blob_->width(); } int count() const { return blob_->count(); } @@ -79,9 +80,9 @@ class CaffeBlobWrap : public CaffeBlob { : CaffeBlob(blob), self_(p) {} object get_data() { - npy_intp dims[] = {num(), channels(), height(), width()}; + npy_intp dims[] = {num(), 
channels(), length(), height(), width()}; - PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, + PyObject *obj = PyArray_SimpleNewFromData(5, dims, NPY_FLOAT32, blob_->mutable_cpu_data()); PyArray_SetBaseObject(reinterpret_cast(obj), self_); Py_INCREF(self_); @@ -91,9 +92,9 @@ class CaffeBlobWrap : public CaffeBlob { } object get_diff() { - npy_intp dims[] = {num(), channels(), height(), width()}; + npy_intp dims[] = {num(), channels(), length(), height(), width()}; - PyObject *obj = PyArray_SimpleNewFromData(4, dims, NPY_FLOAT32, + PyObject *obj = PyArray_SimpleNewFromData(5, dims, NPY_FLOAT32, blob_->mutable_cpu_diff()); PyArray_SetBaseObject(reinterpret_cast(obj), self_); Py_INCREF(self_); @@ -160,12 +161,12 @@ struct CaffeNet { // Generate Python exceptions for badly shaped or discontiguous arrays. inline void check_contiguous_array(PyArrayObject* arr, string name, - int channels, int height, int width) { + int channels, int length, int height, int width) { if (!(PyArray_FLAGS(arr) & NPY_ARRAY_C_CONTIGUOUS)) { throw std::runtime_error(name + " must be C contiguous"); } - if (PyArray_NDIM(arr) != 4) { - throw std::runtime_error(name + " must be 4-d"); + if (PyArray_NDIM(arr) != 5) { + throw std::runtime_error(name + " must be 5-d"); } if (PyArray_TYPE(arr) != NPY_FLOAT32) { throw std::runtime_error(name + " must be float32"); @@ -173,10 +174,13 @@ struct CaffeNet { if (PyArray_DIMS(arr)[1] != channels) { throw std::runtime_error(name + " has wrong number of channels"); } - if (PyArray_DIMS(arr)[2] != height) { + if (PyArray_DIMS(arr)[2] != length) { + throw std::runtime_error(name + " has wrong length"); + } + if (PyArray_DIMS(arr)[3] != height) { throw std::runtime_error(name + " has wrong height"); } - if (PyArray_DIMS(arr)[3] != width) { + if (PyArray_DIMS(arr)[4] != width) { throw std::runtime_error(name + " has wrong width"); } } @@ -204,8 +208,9 @@ struct CaffeNet { PyArrayObject* labels_arr = reinterpret_cast(labels_obj.ptr()); 
check_contiguous_array(data_arr, "data array", md_layer->datum_channels(), - md_layer->datum_height(), md_layer->datum_width()); - check_contiguous_array(labels_arr, "labels array", 1, 1, 1); + md_layer->datum_length(), md_layer->datum_height(), + md_layer->datum_width()); + check_contiguous_array(labels_arr, "labels array", 1, 1, 1, 1); if (PyArray_DIMS(data_arr)[0] != PyArray_DIMS(labels_arr)[0]) { throw std::runtime_error("data and labels must have the same first" " dimension"); @@ -330,6 +335,7 @@ BOOST_PYTHON_MODULE(_caffe) { .add_property("name", &CaffeBlob::name) .add_property("num", &CaffeBlob::num) .add_property("channels", &CaffeBlob::channels) + .add_property("length", &CaffeBlob::length) .add_property("height", &CaffeBlob::height) .add_property("width", &CaffeBlob::width) .add_property("count", &CaffeBlob::count) diff --git a/python/caffe/io.py b/python/caffe/io.py index 0bd2f812be..c1539ea7f8 100644 --- a/python/caffe/io.py +++ b/python/caffe/io.py @@ -28,14 +28,43 @@ def resize_image(im, new_dims, interp_order=1): Resize an image array with interpolation. Take - im: (H x W x K) ndarray + im: (H x W x K) or (H x W x K x L) ndarray new_dims: (height, width) tuple of new dimensions. interp_order: interpolation order, default is linear. Give im: resized ndarray with shape (new_dims[0], new_dims[1], K) """ - return skimage.transform.resize(im, new_dims, order=interp_order) + + im_min, im_max = im.min(), im.max() + if im_max > im_min: + # skimage is fast but only understands {1,3} channel images + # in [0, 1]. 
+ im_std = (im - im_min) / (im_max - im_min) + else: + # the image is a constant -- avoid divide by 0 + # TODO(chuck): cover for 4-dim im case + ret = np.empty((new_dims[0], new_dims[1], im.shape[-1]), + dtype=np.float32) + ret.fill(im_min) + return ret + + if im.ndim == 3: + resized = skimage.transform.resize(im_std, new_dims, order=interp_order) + resized = resized * (im_max - im_min) + im_min + elif im.ndim == 4: + resized = np.empty(new_dims + im.shape[-2:]) + for l in range(im.shape[3]): + resized[:,:,:,l] = skimage.transform.resize( + im_std[:,:,:,l], + new_dims, + order=interp_order + ) + resized[:,:,:,l] = resized[:,:,:,l] * (im_max - im_min) + im_min + else: + raise ValueError('Incorrect input array shape.') + + return resized def oversample(images, crop_dims): @@ -87,10 +116,10 @@ def blobproto_to_array(blob, return_diff=False): """ if return_diff: return np.array(blob.diff).reshape( - blob.num, blob.channels, blob.height, blob.width) + blob.num, blob.channels, blob.length, blob.height, blob.width) else: return np.array(blob.data).reshape( - blob.num, blob.channels, blob.height, blob.width) + blob.num, blob.channels, blob.length, blob.height, blob.width) def array_to_blobproto(arr, diff=None): @@ -98,10 +127,10 @@ def array_to_blobproto(arr, diff=None): convert the diff. You need to make sure that arr and diff have the same shape, and this function does not do sanity check. """ - if arr.ndim != 4: + if arr.ndim != 5: raise ValueError('Incorrect array shape.') blob = caffe_pb2.BlobProto() - blob.num, blob.channels, blob.height, blob.width = arr.shape; + blob.num, blob.channels, blob.length, blob.height, blob.width = arr.shape; blob.data.extend(arr.astype(float).flat) if diff is not None: blob.diff.extend(diff.astype(float).flat) @@ -130,10 +159,10 @@ def array_to_datum(arr, label=0): the output data will be encoded as a string. Otherwise, the output data will be stored in float format. 
""" - if arr.ndim != 3: + if arr.ndim != 4: raise ValueError('Incorrect array shape.') datum = caffe_pb2.Datum() - datum.channels, datum.height, datum.width = arr.shape + datum.channels, datum.length, datum.height, datum.width = arr.shape if arr.dtype == np.uint8: datum.data = arr.tostring() else: @@ -148,7 +177,7 @@ def datum_to_array(datum): """ if len(datum.data): return np.fromstring(datum.data, dtype = np.uint8).reshape( - datum.channels, datum.height, datum.width) + datum.channels, datum.length, datum.height, datum.width) else: return np.array(datum.float_data).astype(float).reshape( - datum.channels, datum.height, datum.width) + datum.channels, datum.length, datum.height, datum.width) diff --git a/python/caffe/pycaffe.py b/python/caffe/pycaffe.py index 5c1512cd8b..46587fbbea 100644 --- a/python/caffe/pycaffe.py +++ b/python/caffe/pycaffe.py @@ -59,8 +59,8 @@ def _Net_forward(self, blobs=None, **kwargs): for in_, blob in kwargs.iteritems(): if blob.shape[0] != self.blobs[in_].num: raise Exception('Input is not batch sized') - if blob.ndim != 4: - raise Exception('{} blob is not 4-d'.format(in_)) + if blob.ndim != 5: + raise Exception('{} blob is not 5-d'.format(in_)) self.blobs[in_].data[...] = blob self._forward() @@ -93,8 +93,8 @@ def _Net_backward(self, diffs=None, **kwargs): for top, diff in kwargs.iteritems(): if diff.shape[0] != self.blobs[top].num: raise Exception('Diff is not batch sized') - if diff.ndim != 4: - raise Exception('{} diff is not 4-d'.format(top)) + if diff.ndim != 5: + raise Exception('{} diff is not 5-d'.format(top)) self.blobs[top].diff[...] = diff self._backward() @@ -176,7 +176,7 @@ def _Net_forward_backward_all(self, blobs=None, diffs=None, **kwargs): return all_outs, all_diffs -def _Net_set_mean(self, input_, mean_f, mode='elementwise'): +def _Net_set_mean(self, input_, mean_f): """ Set the mean to subtract for data centering. 
@@ -184,7 +184,6 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'): input_: which input to assign this mean. mean_f: path to mean .npy with ndarray (input dimensional or broadcastable) mode: elementwise = use the whole mean (and check dimensions) - channel = channel constant (e.g. mean pixel instead of mean image) """ if not hasattr(self, 'mean'): self.mean = {} @@ -192,19 +191,17 @@ def _Net_set_mean(self, input_, mean_f, mode='elementwise'): raise Exception('Input not in {}'.format(self.inputs)) in_shape = self.blobs[input_].data.shape mean = np.load(mean_f) - if mode == 'elementwise': - if mean.shape != in_shape[1:]: - # Resize mean (which requires H x W x K input in range [0,1]). - m_min, m_max = mean.min(), mean.max() - normal_mean = (mean - m_min) / (m_max - m_min) - mean = caffe.io.resize_image(normal_mean.transpose((1,2,0)), - in_shape[2:]).transpose((2,0,1)) * (m_max - m_min) + m_min - self.mean[input_] = mean - elif mode == 'channel': - self.mean[input_] = mean.mean(1).mean(1).reshape((in_shape[1], 1, 1)) - else: - raise Exception('Mode not in {}'.format(['elementwise', 'channel'])) - + if mean.ndim == 5: + mean = np.squeeze(mean, 0) + if mean.shape != in_shape[1:]: + # Resize mean (which requires H x W x K input in range [0,1]). 
+ m_min, m_max = mean.min(), mean.max() + normal_mean = (mean - m_min) / (m_max - m_min) + ''' [info] normal_mean.shape=(16, 3, 128, 171),in_shape=(1, 3, 16, 112, 112) ''' + mean = caffe.io.resize_image( + normal_mean.transpose((2,3,0,1)), + in_shape[3:]).transpose((2,3,0,1)) * (m_max - m_min) + m_min + self.mean[input_] = mean def _Net_set_input_scale(self, input_, scale): @@ -247,27 +244,27 @@ def _Net_preprocess(self, input_name, input_): - scale feature - reorder channels (for instance color to BGR) - subtract mean - - transpose dimensions to K x H x W + - transpose dimensions to K x L X H x W (L: c3d_depth) Take input_name: name of input blob to preprocess for - input_: (H' x W' x K) ndarray + input_: (H' x W' x K X L) ndarray Give - caffe_inputs: (K x H x W) ndarray + caffe_inputs: (K x L X H x W) ndarray """ caffe_in = input_.astype(np.float32) input_scale = self.input_scale.get(input_name) channel_order = self.channel_swap.get(input_name) mean = self.mean.get(input_name) - in_size = self.blobs[input_name].data.shape[2:] + in_size = self.blobs[input_name].data.shape[3:] if caffe_in.shape[:2] != in_size: caffe_in = caffe.io.resize_image(caffe_in, in_size) if input_scale: caffe_in *= input_scale if channel_order: - caffe_in = caffe_in[:, :, channel_order] - caffe_in = caffe_in.transpose((2, 0, 1)) + caffe_in = caffe_in[:, :, channel_order, :] + caffe_in = caffe_in.transpose((2, 3, 0, 1)) if mean is not None: caffe_in -= mean return caffe_in @@ -283,11 +280,11 @@ def _Net_deprocess(self, input_name, input_): mean = self.mean.get(input_name) if mean is not None: decaf_in += mean - decaf_in = decaf_in.transpose((1,2,0)) + decaf_in = decaf_in.transpose((2,3,0,1)) if channel_order: channel_order_inverse = [channel_order.index(i) for i in range(decaf_in.shape[2])] - decaf_in = decaf_in[:, :, channel_order_inverse] + decaf_in = decaf_in[:, :, channel_order_inverse, :] if input_scale: decaf_in /= input_scale return decaf_in diff --git 
a/src/caffe/layers/memory_data_layer.cpp b/src/caffe/layers/memory_data_layer.cpp index e71c8b784a..6d11d7ffbb 100644 --- a/src/caffe/layers/memory_data_layer.cpp +++ b/src/caffe/layers/memory_data_layer.cpp @@ -14,12 +14,13 @@ void MemoryDataLayer::SetUp(const vector*>& bottom, CHECK_EQ(top->size(), 2) << "Memory Data Layer takes two blobs as output."; batch_size_ = this->layer_param_.memory_data_param().batch_size(); datum_channels_ = this->layer_param_.memory_data_param().channels(); + datum_length_ = this->layer_param_.memory_data_param().length(); datum_height_ = this->layer_param_.memory_data_param().height(); datum_width_ = this->layer_param_.memory_data_param().width(); - datum_size_ = datum_channels_ * datum_height_ * datum_width_; - CHECK_GT(batch_size_ * datum_size_, 0) << "batch_size, channels, height," + datum_size_ = datum_channels_ * datum_length_ * datum_height_ * datum_width_; + CHECK_GT(batch_size_ * datum_size_, 0) << "batch_size, channels, length, height," " and width must be specified and positive in memory_data_param"; - (*top)[0]->Reshape(batch_size_, datum_channels_, 1, datum_height_, datum_width_); + (*top)[0]->Reshape(batch_size_, datum_channels_, datum_length_, datum_height_, datum_width_); (*top)[1]->Reshape(batch_size_, 1, 1, 1, 1); data_ = NULL; labels_ = NULL; diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index b21561fa44..4e1ff4afcc 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -341,6 +341,7 @@ message LRNParameter { message MemoryDataParameter { optional uint32 batch_size = 1; optional uint32 channels = 2; + optional uint32 length = 5; optional uint32 height = 3; optional uint32 width = 4; } diff --git a/src/caffe/test/test_memory_data_layer.cpp b/src/caffe/test/test_memory_data_layer.cpp index 15f01bd41e..2edb09f530 100644 --- a/src/caffe/test/test_memory_data_layer.cpp +++ b/src/caffe/test/test_memory_data_layer.cpp @@ -19,6 +19,7 @@ class MemoryDataLayerTest : public 
::testing::Test { batch_size_ = 8; batches_ = 12; channels_ = 4; + length_ = 5; height_ = 7; width_ = 11; blob_top_vec_.push_back(data_blob_); @@ -26,8 +27,8 @@ class MemoryDataLayerTest : public ::testing::Test { // pick random input data FillerParameter filler_param; GaussianFiller filler(filler_param); - data_->Reshape(batches_ * batch_size_, channels_, height_, width_); - labels_->Reshape(batches_ * batch_size_, 1, 1, 1); + data_->Reshape(batches_ * batch_size_, channels_, length_, height_, width_); + labels_->Reshape(batches_ * batch_size_, 1, 1, 1, 1); filler.Fill(this->data_); filler.Fill(this->labels_); } @@ -41,6 +42,7 @@ class MemoryDataLayerTest : public ::testing::Test { int batch_size_; int batches_; int channels_; + int length_; int height_; int width_; // we don't really need blobs for the input data, but it makes it @@ -62,6 +64,7 @@ TYPED_TEST(MemoryDataLayerTest, TestSetup) { MemoryDataParameter* md_param = layer_param.mutable_memory_data_param(); md_param->set_batch_size(this->batch_size_); md_param->set_channels(this->channels_); + md_param->set_length(this->length_); md_param->set_height(this->height_); md_param->set_width(this->width_); shared_ptr > layer( @@ -69,10 +72,12 @@ TYPED_TEST(MemoryDataLayerTest, TestSetup) { layer->SetUp(this->blob_bottom_vec_, &(this->blob_top_vec_)); EXPECT_EQ(this->data_blob_->num(), this->batch_size_); EXPECT_EQ(this->data_blob_->channels(), this->channels_); + EXPECT_EQ(this->data_blob_->length(), this->length_); EXPECT_EQ(this->data_blob_->height(), this->height_); EXPECT_EQ(this->data_blob_->width(), this->width_); EXPECT_EQ(this->label_blob_->num(), this->batch_size_); EXPECT_EQ(this->label_blob_->channels(), 1); + EXPECT_EQ(this->label_blob_->length(), 1); EXPECT_EQ(this->label_blob_->height(), 1); EXPECT_EQ(this->label_blob_->width(), 1); } @@ -83,6 +88,7 @@ TYPED_TEST(MemoryDataLayerTest, TestForward) { MemoryDataParameter* md_param = layer_param.mutable_memory_data_param(); 
md_param->set_batch_size(this->batch_size_); md_param->set_channels(this->channels_); + md_param->set_length(this->length_); md_param->set_height(this->height_); md_param->set_width(this->width_); shared_ptr > layer(