diff --git a/src/caffe/data_transformer.cpp b/src/caffe/data_transformer.cpp index 1ef4bc670..3165ee057 100644 --- a/src/caffe/data_transformer.cpp +++ b/src/caffe/data_transformer.cpp @@ -370,7 +370,9 @@ void DataTransformer::Transform(const cv::Mat& cv_img, CHECK_LE(width, img_width); CHECK_GE(num, 1); - CHECK(cv_img.depth() == CV_8U) << "Image data type must be unsigned byte"; + CHECK(cv_img.depth() == CV_8U || cv_img.depth() == CV_32F || + cv_img.depth() == CV_32S) << + "Image data type must be unsigned byte, float or integer"; const Dtype scale = param_.scale(); @@ -422,7 +424,9 @@ void DataTransformer::Transform(const cv::Mat& cv_img, Dtype* transformed_data = transformed_blob->mutable_cpu_data(); int top_index; for (int h = 0; h < height; ++h) { - const uchar* ptr = cv_cropped_img.ptr(h); + const uchar* ptrUchar = cv_cropped_img.ptr(h); + const float* ptrFloat = cv_cropped_img.ptr(h); + const int32_t* ptrInt = cv_cropped_img.ptr(h); int img_index = 0; for (int w = 0; w < width; ++w) { for (int c = 0; c < img_channels; ++c) { @@ -432,7 +436,14 @@ void DataTransformer::Transform(const cv::Mat& cv_img, top_index = (c * height + h) * width + w; } // int top_index = (c * height + h) * width + w; - Dtype pixel = static_cast(ptr[img_index++]); + Dtype pixel = 0; + + if (cv_img.depth() == CV_8U) + pixel = static_cast(ptrUchar[img_index++]); + if (cv_img.depth() == CV_32F) + pixel = static_cast(ptrFloat[img_index++]); + if (cv_img.depth() == CV_32S) + pixel = static_cast(ptrInt[img_index++]); if (has_mean_file) { int mean_index = (c * img_height + h_off + h) * img_width + w_off + w; transformed_data[top_index] =