diff --git a/examples/c3d_feature_extraction/c3d_sport1m_feature_extraction_video_stream.sh b/examples/c3d_feature_extraction/c3d_sport1m_feature_extraction_video_stream.sh new file mode 100644 index 0000000000..b0ed74bad9 --- /dev/null +++ b/examples/c3d_feature_extraction/c3d_sport1m_feature_extraction_video_stream.sh @@ -0,0 +1,4 @@ +mkdir -p output/c3d/v_ApplyEyeMakeup_g01_c01 +mkdir -p output/c3d/v_BaseballPitch_g01_c01 +mkdir -p output/c3d/CAMERA +GLOG_logtosterr=1 ../../build/tools/extract_image_features.bin prototxt/c3d_sport1m_feature_extractor_video_stream.prototxt conv3d_deepnetA_sport1m_iter_1900000 0 50 -1 prototxt/output_list_video_stream_prefix.txt fc7-1 fc6-1 prob diff --git a/examples/c3d_feature_extraction/prototxt/c3d_sport1m_feature_extractor_video_stream.prototxt b/examples/c3d_feature_extraction/prototxt/c3d_sport1m_feature_extractor_video_stream.prototxt new file mode 100644 index 0000000000..7a5a31ba71 --- /dev/null +++ b/examples/c3d_feature_extraction/prototxt/c3d_sport1m_feature_extractor_video_stream.prototxt @@ -0,0 +1,450 @@ +name: "DeepConv3DNet_Sport1M_Val" +layers { + name: "data" + type: VIDEO_DATA + top: "data" + top: "label" + image_data_param { + source: "prototxt/input_list_video_stream.txt" + use_image: false + use_stream: true + mean_file: "sport1m_train16_128_mean.binaryproto" + batch_size: 50 + crop_size: 112 + mirror: false + show_data: 0 + new_height: 128 + new_width: 171 + new_length: 16 + shuffle: false + } +} +# ----------- 1st layer group --------------- +layers { + name: "conv1a" + type: CONVOLUTION3D + bottom: "data" + top: "conv1a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 64 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "relu1a" + type: RELU + bottom: "conv1a" + top: "conv1a" +} +layers { + name: "pool1" + type: POOLING3D + bottom: "conv1a" + top: "pool1" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 1 + stride: 2 + temporal_stride: 1 + } +} +# ------------- 2nd layer group -------------- +layers { + name: "conv2a" + type: CONVOLUTION3D + bottom: "pool1" + top: "conv2a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 128 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu2a" + type: RELU + bottom: "conv2a" + top: "conv2a" +} +layers { + name: "pool2" + type: POOLING3D + bottom: "conv2a" + top: "pool2" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} +# ----------------- 3rd layer group -------------- +layers { + name: "conv3a" + type: CONVOLUTION3D + bottom: "pool2" + top: "conv3a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu3a" + type: RELU + bottom: "conv3a" + top: "conv3a" +} +layers { + name: "conv3b" + type: CONVOLUTION3D + bottom: "conv3a" + top: "conv3b" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 256 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu3b" + type: RELU + bottom: "conv3b" + top: "conv3b" +} +layers { + name: "pool3" + type: POOLING3D + bottom: "conv3b" + top: "pool3" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} + +# --------- 4th layer group +layers { + name: "conv4a" + type: CONVOLUTION3D + bottom: "pool3" + top: "conv4a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 512 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4a" + type: RELU + bottom: "conv4a" + top: "conv4a" +} +layers { + name: "conv4b" + type: CONVOLUTION3D + bottom: "conv4a" + top: "conv4b" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 512 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu4b" + type: RELU + bottom: "conv4b" + top: "conv4b" +} +layers { + name: "pool4" + type: POOLING3D + bottom: "conv4b" + top: "pool4" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} + +# --------------- 5th layer group -------- +layers { + name: "conv5a" + type: CONVOLUTION3D + bottom: "pool4" + top: "conv5a" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 512 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5a" + type: RELU + bottom: "conv5a" + top: "conv5a" +} +layers { + name: "conv5b" + type: CONVOLUTION3D + bottom: "conv5a" + top: "conv5b" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + convolution_param { + num_output: 512 + kernel_size: 3 + kernel_depth: 3 + pad: 1 + temporal_pad: 1 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu5b" + type: RELU + bottom: "conv5b" + top: "conv5b" +} + +layers { + name: "pool5" + type: POOLING3D + bottom: "conv5b" + top: "pool5" + pooling_param { + pool: MAX + kernel_size: 2 + kernel_depth: 2 + stride: 2 + temporal_stride: 2 + } +} +# ---------------- fc layers ------------- +layers { + name: "fc6-1" + type: INNER_PRODUCT + bottom: "pool5" + top: "fc6-1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu6" + type: RELU + bottom: "fc6-1" + top: "fc6-1" +} +layers { + name: "drop6" + type: DROPOUT + bottom: "fc6-1" + top: "fc6-1" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc7-1" + type: INNER_PRODUCT + bottom: "fc6-1" + top: "fc7-1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 4096 + weight_filler { + type: "gaussian" + std: 0.005 + } + bias_filler { + type: "constant" + value: 1 + } + } +} +layers { + name: "relu7" + type: RELU + bottom: "fc7-1" + top: "fc7-1" +} +layers { + name: "drop7" + type: DROPOUT + bottom: "fc7-1" + top: "fc7-1" + dropout_param { + dropout_ratio: 0.5 + } +} +layers { + name: "fc8-1" + type: INNER_PRODUCT + bottom: "fc7-1" + top: "fc8-1" + blobs_lr: 1 + blobs_lr: 2 + weight_decay: 1 + weight_decay: 0 + inner_product_param { + num_output: 487 + weight_filler { + type: "gaussian" + std: 0.01 + } + bias_filler { + type: "constant" + value: 0 + } + } +} +layers { + name: "prob" + type: SOFTMAX + bottom: "fc8-1" + top: "prob" +} +layers { + name: "accuracy" + type: ACCURACY + bottom: "prob" + bottom: "label" + top: "accuracy" + #top: "prediction_truth" +} diff --git a/examples/c3d_feature_extraction/prototxt/input_list_video_stream.txt b/examples/c3d_feature_extraction/prototxt/input_list_video_stream.txt new file mode 100644 index 0000000000..6c8316fa44 --- /dev/null +++ b/examples/c3d_feature_extraction/prototxt/input_list_video_stream.txt @@ -0,0 +1,3 @@ +input/avi/v_ApplyEyeMakeup_g01_c01.avi 0 16 0 +input/avi/v_BaseballPitch_g01_c01.avi 0 16 0 +CAMERA 0 16 0 diff --git a/examples/c3d_feature_extraction/prototxt/output_list_video_stream_prefix.txt b/examples/c3d_feature_extraction/prototxt/output_list_video_stream_prefix.txt new file mode 100644 index 0000000000..69721440b7 --- /dev/null +++ b/examples/c3d_feature_extraction/prototxt/output_list_video_stream_prefix.txt @@ -0,0 +1 @@ +output/c3d/ diff --git a/include/caffe/util/image_io.hpp b/include/caffe/util/image_io.hpp index 54337bde34..782eb77cb1 100644 --- a/include/caffe/util/image_io.hpp +++ b/include/caffe/util/image_io.hpp @@ -59,6 +59,9 @@ inline bool ReadImageSequenceToVolumeDatum(const char* img_dir, const int start_ return ReadImageSequenceToVolumeDatum(img_dir, start_frm, label, length, 0, 0, sampling_rate, datum); } +bool ReadImageVectorToVolumeDatum(const std::vector& imgs, const int label, + const int length, const int height, const int width, VolumeDatum* datum); + template bool load_blob_from_binary(const string fn_blob, Blob* blob); diff --git a/include/caffe/video_data_layer.hpp b/include/caffe/video_data_layer.hpp index ab2f3e7bb1..1980cb6f99 100644 --- a/include/caffe/video_data_layer.hpp +++ b/include/caffe/video_data_layer.hpp @@ -21,6 +21,7 @@ #include #include #include +#include #include "pthread.h" #include "boost/scoped_ptr.hpp" @@ -49,6 +50,9 @@ class VideoDataLayer : public Layer { virtual void SetUp(const vector*>& bottom, vector*>* top); + vector pop_stream_names(int num); + bool is_stream_done() { return is_stream_done_; } + protected: virtual Dtype Forward_cpu(const vector*>& bottom, vector*>* top); @@ -66,9 +70,14 @@ class VideoDataLayer : public Layer { shared_ptr prefetch_rng_; vector file_list_; vector start_frm_list_; + vector num_of_frames_list_; + vector interval_list_; vector label_list_; vector shuffle_index_; int lines_id_; + int interval_jump_; + std::deque stream_names_; + bool is_stream_done_; int datum_channels_; int datum_length_; diff --git a/src/caffe/layers/video_data_layer.cpp b/src/caffe/layers/video_data_layer.cpp index 53898a480d..f45981e9bc 100644 --- a/src/caffe/layers/video_data_layer.cpp +++ b/src/caffe/layers/video_data_layer.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include @@ -43,6 +44,34 @@ using std::string; namespace caffe { +const string win_name = "Live video"; +const int waitKeyDelay = 10; +cv::VideoCapture video_cap_; + +string GetStreamName(string path, int start_frm) +{ + string name, ext; + size_t sep = path.find_last_of("\\/"); + if (sep != std::string::npos) + path = path.substr(sep + 1, path.size() - sep - 1); + size_t dot = path.find_last_of("."); + if (dot != std::string::npos) + { + name = path.substr(0, dot); + ext = path.substr(dot, path.size() - dot); + } + else + { + name = path; + ext = ""; + } + + std::ostringstream oss; + oss << std::setfill('0') << std::setw(6) << start_frm; + + return name + '/' + oss.str(); +} + template void* VideoDataLayerPrefetch(void* layer_pointer) { CHECK(layer_pointer); @@ -64,6 +93,7 @@ void* VideoDataLayerPrefetch(void* layer_pointer) { const int new_height = layer->layer_param_.image_data_param().new_height(); const int new_width = layer->layer_param_.image_data_param().new_width(); const bool use_image = layer->layer_param_.image_data_param().use_image(); + const bool use_stream = layer->layer_param_.image_data_param().use_stream(); const int sampling_rate = layer->layer_param_.image_data_param().sampling_rate(); const bool use_temporal_jitter = layer->layer_param_.image_data_param().use_temporal_jitter(); @@ -88,7 +118,57 @@ void* VideoDataLayerPrefetch(void* layer_pointer) { CHECK_GT(chunks_size, layer->lines_id_); bool read_status; int id = layer->shuffle_index_[layer->lines_id_]; - if (!use_image){ + if (use_stream){ + if (!layer->file_list_[id].compare("CAMERA")){ + int camera_index = layer->start_frm_list_[id]; + if (!video_cap_.isOpened()){ + video_cap_.open(camera_index); + if (!video_cap_.isOpened()){ + LOG(ERROR) << "Cannot open CAM " << camera_index; + return static_cast(NULL); + } + } + + vector live_imgs; + for (int i=0; i(NULL); + } + imshow(win_name, img); + sampling_count++; + } + live_imgs.push_back(img); + { + cv::Mat img_disp = img.clone(); + copyMakeBorder( img_disp, img_disp, 2, 2, 2, 2, cv::BORDER_CONSTANT, cv::Scalar(0,255,0) ); + imshow(win_name, img_disp); + } + } + int start_frm = layer->interval_jump_++ * layer->interval_list_[id]; + layer->stream_names_.push_back(GetStreamName("CAMERA", start_frm)); + read_status = ReadImageVectorToVolumeDatum(live_imgs, + layer->label_list_[id], new_length, new_height, new_width, &datum); + + } else { + int start_frm = layer->start_frm_list_[id] + layer->interval_jump_++ * layer->interval_list_[id]; + int end_frm = start_frm + new_length * sampling_rate; + if (end_frm > layer->num_of_frames_list_[id]) { + LOG(INFO) << "Done extracting features from " << layer->file_list_[id]; + read_status = false; + } else { + layer->stream_names_.push_back(GetStreamName(layer->file_list_[id].c_str(), start_frm)); + read_status = ReadVideoToVolumeDatum(layer->file_list_[id].c_str(), start_frm, + layer->label_list_[id], new_length, new_height, new_width, sampling_rate, &datum); + } + } + } + else if (!use_image){ if (!use_temporal_jitter){ read_status = ReadVideoToVolumeDatum(layer->file_list_[id].c_str(), layer->start_frm_list_[id], layer->label_list_[id], new_length, new_height, new_width, sampling_rate, &datum); @@ -119,13 +199,14 @@ void* VideoDataLayerPrefetch(void* layer_pointer) { } } - if (layer->phase_ == Caffe::TEST){ + if (layer->phase_ == Caffe::TEST && !use_stream){ CHECK(read_status) << "Testing must not miss any example"; } if (!read_status) { //LOG(ERROR) << "cannot read " << layer->file_list_[id]; layer->lines_id_++; + layer->interval_jump_ = 0; if (layer->lines_id_ >= chunks_size) { // We have reached the end. Restart from the first. DLOG(INFO) << "Restarting data prefetching from start."; @@ -236,7 +317,9 @@ void* VideoDataLayerPrefetch(void* layer_pointer) { // LOG(INFO) << "fetching label" << datum.label() << std::endl; } - layer->lines_id_++; + if (!use_stream) { + layer->lines_id_++; + } if (layer->lines_id_ >= chunks_size) { // We have reached the end. Restart from the first. DLOG(INFO) << "Restarting data prefetching from start."; @@ -280,14 +363,31 @@ void VideoDataLayer::SetUp(const vector*>& bottom, const string& source = this->layer_param_.image_data_param().source(); const bool use_temporal_jitter = this->layer_param_.image_data_param().use_temporal_jitter(); const bool use_image = this->layer_param_.image_data_param().use_image(); + const bool use_stream = this->layer_param_.image_data_param().use_stream(); LOG(INFO) << "Opening file " << source; std::ifstream infile(source.c_str()); int count = 0; string filename; - int start_frm, label; + int start_frm, num_of_frames, interval, label; - if ((!use_image) && use_temporal_jitter){ + if (use_stream){ + while (infile >> filename >> start_frm >> interval >> label) { + num_of_frames = -1; + if (filename.compare("CAMERA")){ + cv::VideoCapture cap; + CHECK(cap.open(filename)) << "Cannot open " << filename; + num_of_frames = cap.get(CV_CAP_PROP_FRAME_COUNT); + } + file_list_.push_back(filename); + start_frm_list_.push_back(start_frm); + num_of_frames_list_.push_back(num_of_frames); + interval_list_.push_back(interval); + label_list_.push_back(label); + shuffle_index_.push_back(count); + count++; + } + } else if ((!use_image) && use_temporal_jitter){ while (infile >> filename >> label) { file_list_.push_back(filename); label_list_.push_back(label); @@ -317,6 +417,7 @@ void VideoDataLayer::SetUp(const vector*>& bottom, LOG(INFO) << "A total of " << shuffle_index_.size() << " video chunks."; lines_id_ = 0; + interval_jump_ = 0; // Check if we would need to randomly skip a few data points if (this->layer_param_.image_data_param().rand_skip()) { @@ -330,8 +431,48 @@ void VideoDataLayer::SetUp(const vector*>& bottom, // Read a data point, and use it to initialize the top blob. VolumeDatum datum; int id = shuffle_index_[lines_id_]; - if (!use_image){ - if (use_temporal_jitter){ + if (use_stream){ + if (!file_list_[id].compare("CAMERA")){ + int camera_index = start_frm_list_[id]; + if (!video_cap_.isOpened()){ + video_cap_.open(camera_index); + if (!video_cap_.isOpened()){ + LOG(ERROR) << "Cannot open CAM " << camera_index; + return; + } + } + + cv::namedWindow(win_name, CV_WINDOW_AUTOSIZE); + vector live_imgs; + for (int i=0; i::SetUp(const vector*>& bottom, // Now, start the prefetch thread. Before calling prefetch, we make two // cpu_data calls so that the prefetch thread does not accidentally make // simultaneous cudaMalloc calls when the main thread is running. In some - // GPUs this seems to cause failures if we do not so. + // GPUs this seems to cause failures if we do not do so. prefetch_data_->mutable_cpu_data(); if (output_labels_) { prefetch_label_->mutable_cpu_data(); @@ -465,6 +606,18 @@ Dtype VideoDataLayer::Forward_cpu(const vector*>& bottom, return Dtype(0.); } +template +vector VideoDataLayer::pop_stream_names(int num) { + vector popped_names; + for (int i=0; i& imgs, const int label, + const int length, const int height, const int width, VolumeDatum* datum){ + cv::Mat img; + char *buffer; + int offset, channel_size, image_size, data_size; + + datum->set_channels(3); + datum->set_length(length); + datum->set_label(label); + datum->clear_data(); + datum->clear_float_data(); + + offset = 0; + for (int i=0; i 0 && width > 0) { + cv::resize(img, img, cv::Size(width, height)); + } + + if (!img.data){ + LOG(ERROR) << "Could not open image"; + return false; + } + + if (i==0){ + datum->set_height(img.rows); + datum->set_width(img.cols); + image_size = img.rows * img.cols; + channel_size = image_size * length; + data_size = channel_size * 3; + buffer = new char[data_size]; + } + for (int c=0; c<3; c++){ + ImageChannelToBuffer(&img, buffer + c * channel_size + offset, c); + } + offset += image_size; + } + CHECK(offset == channel_size) << "wrong offset size" << std::endl; + datum->set_data(buffer, data_size); + delete []buffer; + return true; +} + template <> bool load_blob_from_binary(const string fn_blob, Blob* blob){ FILE *f; diff --git a/tools/extract_image_features.cpp b/tools/extract_image_features.cpp index 20504705da..2cbc48cffb 100644 --- a/tools/extract_image_features.cpp +++ b/tools/extract_image_features.cpp @@ -26,6 +26,7 @@ #include "caffe/common.hpp" #include "caffe/net.hpp" #include "caffe/vision_layers.hpp" +#include "caffe/video_data_layer.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/util/io.hpp" #include "caffe/util/image_io.hpp" @@ -45,18 +46,18 @@ int feature_extraction_pipeline(int argc, char** argv) { char* pretrained_model = argv[2]; int device_id = atoi(argv[3]); uint batch_size = atoi(argv[4]); - uint num_mini_batches = atoi(argv[5]); + int num_mini_batches = atoi(argv[5]); char* fn_feat = argv[6]; Caffe::set_phase(Caffe::TEST); if (device_id>=0){ - Caffe::set_mode(Caffe::GPU); - Caffe::SetDevice(device_id); - LOG(ERROR) << "Using GPU #" << device_id; + Caffe::set_mode(Caffe::GPU); + Caffe::SetDevice(device_id); + LOG(ERROR) << "Using GPU #" << device_id; } else{ - Caffe::set_mode(Caffe::CPU); - LOG(ERROR) << "Using CPU"; + Caffe::set_mode(Caffe::CPU); + LOG(ERROR) << "Using CPU"; } shared_ptr > feature_extraction_net( @@ -69,6 +70,51 @@ int feature_extraction_pipeline(int argc, char** argv) { << " in the network " << string(net_proto); } + if (num_mini_batches < 0) + { + LOG(ERROR)<< "Extracting features until program is terminated"; + std::ifstream infile(fn_feat); + string feat_prefix; + std::vector list_prefix; + int c = 0; + infile >> feat_prefix; + + vector*> input_vec; + int image_index = 0; + + boost::shared_ptr > data_layer; + boost::shared_ptr< VideoDataLayer > video_data_layer; + data_layer = feature_extraction_net->layers()[0]; + video_data_layer = boost::dynamic_pointer_cast< VideoDataLayer >(data_layer); + if (!video_data_layer) { + LOG(ERROR)<< "This mode may only be used if the first layer is a VideoDataLayer"; + } + + while (1) { + feature_extraction_net->Forward(input_vec); + list_prefix = video_data_layer->pop_stream_names(batch_size); + + if (list_prefix.empty()) + break; + for (int k=7; k > feature_blob = feature_extraction_net + ->blob_by_name(string(argv[k])); + int num_features = feature_blob->num(); + + Dtype* feature_blob_data; + for (int n = 0; n < num_features; ++n) { + if (list_prefix.size()>n){ + string fn_feat = feat_prefix + list_prefix[n] + string(".") + string(argv[k]); + save_blob_to_binary(feature_blob.get(), fn_feat, n); + } + } + } + image_index += list_prefix.size(); + } + LOG(ERROR)<< "Successfully extracted " << image_index << " features!"; + return 0; + } + LOG(ERROR)<< "Extracting features for " << num_mini_batches << " batches"; std::ifstream infile(fn_feat); string feat_prefix; @@ -82,26 +128,26 @@ int feature_extraction_pipeline(int argc, char** argv) { feature_extraction_net->Forward(input_vec); list_prefix.clear(); for (int n=0; n> feat_prefix) - list_prefix.push_back(feat_prefix); - else - break; + if (infile >> feat_prefix) + list_prefix.push_back(feat_prefix); + else + break; } if (list_prefix.empty()) - break; + break; for (int k=7; k > feature_blob = feature_extraction_net + const shared_ptr > feature_blob = feature_extraction_net ->blob_by_name(string(argv[k])); - int num_features = feature_blob->num(); + int num_features = feature_blob->num(); - Dtype* feature_blob_data; - for (int n = 0; n < num_features; ++n) { - if (list_prefix.size()>n){ - string fn_feat = list_prefix[n] + string(".") + string(argv[k]); - save_blob_to_binary(feature_blob.get(), fn_feat, n); - } + Dtype* feature_blob_data; + for (int n = 0; n < num_features; ++n) { + if (list_prefix.size()>n){ + string fn_feat = list_prefix[n] + string(".") + string(argv[k]); + save_blob_to_binary(feature_blob.get(), fn_feat, n); } + } } image_index += list_prefix.size(); if (batch_index % 100 == 0) {