diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp
index 62b53f9a76c..a85229b2b86 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.cpp
+++ b/backends/vulkan/runtime/api/containers/Tensor.cpp
@@ -343,8 +343,7 @@ vTensorStorage::vTensorStorage(
           storage_type_,
           dtype,
           allocate_memory)),
-      last_access_{},
-      has_copies_{false} {}
+      last_access_{} {}
 
 vTensorStorage::vTensorStorage(
     Context* const context,
@@ -361,21 +360,6 @@ vTensorStorage::vTensorStorage(
       buffer_(vkapi::VulkanBuffer()),
       last_access_{} {}
 
-vTensorStorage::vTensorStorage(
-    vTensorStorage& other,
-    const int64_t buffer_offset)
-    : context_(other.context_),
-      storage_type_{other.storage_type_},
-      image_extents_(other.image_extents_),
-      buffer_length_{other.buffer_length_},
-      buffer_offset_{buffer_offset},
-      image_(other.image_),
-      buffer_(other.buffer_, buffer_offset),
-      last_access_{other.last_access_},
-      has_copies_{false} {
-  other.has_copies_ = true;
-}
-
 vTensorStorage::~vTensorStorage() {
   flush();
 }
@@ -397,21 +381,6 @@ void vTensorStorage::transition(
   vkapi::PipelineStageFlags prev_stage = last_access_.stage;
   vkapi::MemoryAccessFlags prev_access = last_access_.access;
 
-  // If the underlying resource is a copy of another tensor's resource the
-  // last_access may not be accurate, since the original storage may have been
-  // written to as part of the original tensor. Likewise, if the underlying
-  // resource has copies, then the resource may have been updated as part of the
-  // view tensors.
-  //
-  // If the resource is a copy, or has copies of it, then cowardly assume that
-  // it has previously been written to as part of a compute shader before the
-  // current access event so that the appropriate memory barriers may be
-  // inserted.
-  if (is_copy() || has_copies_) {
-    prev_stage = vkapi::PipelineStage::COMPUTE;
-    prev_access = vkapi::kWrite;
-  }
-
   const bool prev_written = (prev_access & vkapi::MemoryAccessType::WRITE) != 0;
 
   VkImageLayout cur_layout = VK_IMAGE_LAYOUT_UNDEFINED;
@@ -458,20 +427,6 @@ void vTensorStorage::transition(
   last_access_.access = cur_access;
 }
 
-bool vTensorStorage::is_copy() const {
-  if (storage_type_ == utils::kBuffer) {
-    return buffer_.is_copy();
-  }
-  return image_.is_copy();
-}
-
-bool vTensorStorage::is_copy_of(const vTensorStorage& other) const {
-  if (storage_type_ == utils::kBuffer) {
-    return buffer_.is_copy_of(other.buffer_);
-  }
-  return image_.is_copy_of(other.image_);
-}
-
 //
 // vTensor
 //
@@ -503,14 +458,14 @@ vTensor::vTensor(
       numel_uniform_offset_(kUniformOffsetUnset),
       logical_limits_uniform_offset_(kUniformOffsetUnset),
       // Construct Tensor storage
-      storage_(
+      storage_(std::make_shared<vTensorStorage>(
           context,
           storage_type,
           axis_map_,
           packed_dim_,
           padded_sizes_,
          dtype_,
-          allocate_memory) {
+          allocate_memory)) {
   uniform_data_ = std::make_shared<UniformData>(UniformData{
       sizes_,
       unsqueezed_strides_,
@@ -519,7 +474,7 @@
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "computed dim order is invalid");
 
-  set_logical_limits(storage_.image_extents_);
+  set_logical_limits(storage_->image_extents_);
 }
 
 // NOLINTNEXTLINE
@@ -546,13 +501,13 @@ vTensor::vTensor(
       numel_uniform_offset_(kUniformOffsetUnset),
       logical_limits_uniform_offset_(kUniformOffsetUnset),
       // Construct Tensor storage
-      storage_(context, image) {
+      storage_(std::make_shared<vTensorStorage>(context, image)) {
   uniform_data_ = std::make_shared<UniformData>(UniformData{
       sizes_,
       {0, 0, 0, 0},
       {{0, 0, 0}},
       static_cast(utils::multiply_integers(sizes_))});
-  set_logical_limits(storage_.image_extents_);
+  set_logical_limits(storage_->image_extents_);
 }
 
 vTensor::vTensor(vTensor& other)
@@ -583,8 +538,7 @@
 vTensor::vTensor(
     vTensor& other,
     const std::vector<int64_t>& sizes,
-    const std::vector<int64_t>& dim_order,
-    const int64_t offset_numel)
+    const std::vector<int64_t>& dim_order)
     : dtype_(other.dtype_),
       // Copy tensor size metadata
       sizes_(sizes.begin(), sizes.end()),
@@ -604,7 +558,7 @@ vTensor::vTensor(
       numel_uniform_offset_(kUniformOffsetUnset),
       logical_limits_uniform_offset_(kUniformOffsetUnset),
       // Copy Tensor storage
-      storage_(other.storage_, vkapi::element_size(dtype_) * offset_numel) {
+      storage_(other.storage_) {
   uniform_data_ = std::make_shared<UniformData>(UniformData{
       sizes_,
       unsqueezed_strides_,
@@ -613,10 +567,6 @@
   VK_CHECK_COND(
       dim_order_is_valid(dim_order_), "new dim order provided is invalid");
 
-  VK_CHECK_COND(
-      offset_numel + numel() <= other.numel(),
-      "Tensor alias cannot access more elements than available in the original"
-      "tensor");
 }
 
 uint32_t vTensor::UniformData::write_attribute(
@@ -647,31 +597,31 @@ uint32_t vTensor::UniformData::write_attribute(
 vkapi::VulkanImage& vTensor::image(
     vkapi::PipelineBarrier& pipeline_barrier,
     const vkapi::PipelineStageFlags stage) & {
-  storage_.transition(pipeline_barrier, stage, vkapi::MemoryAccessType::READ);
-  return storage_.image_;
+  storage_->transition(pipeline_barrier, stage, vkapi::MemoryAccessType::READ);
+  return storage_->image_;
 }
 
 vkapi::VulkanImage& vTensor::image(
     vkapi::PipelineBarrier& pipeline_barrier,
     const vkapi::PipelineStageFlags stage,
     const vkapi::MemoryAccessFlags access) & {
-  storage_.transition(pipeline_barrier, stage, access);
-  return storage_.image_;
+  storage_->transition(pipeline_barrier, stage, access);
+  return storage_->image_;
 }
 
 vkapi::VulkanBuffer& vTensor::buffer(
    vkapi::PipelineBarrier& pipeline_barrier,
     const vkapi::PipelineStageFlags stage) & {
-  storage_.transition(pipeline_barrier, stage, vkapi::MemoryAccessType::READ);
-  return storage_.buffer_;
+  storage_->transition(pipeline_barrier, stage, vkapi::MemoryAccessType::READ);
+  return storage_->buffer_;
 }
 
 vkapi::VulkanBuffer& vTensor::buffer(
     vkapi::PipelineBarrier& pipeline_barrier,
     const vkapi::PipelineStageFlags stage,
     const vkapi::MemoryAccessFlags access) & {
-  storage_.transition(pipeline_barrier, stage, access);
-  return storage_.buffer_;
+  storage_->transition(pipeline_barrier, stage, access);
+  return storage_->buffer_;
 }
 
 void vTensor::set_logical_limits(const utils::uvec3& image_extents) {
@@ -695,10 +645,10 @@ utils::GPUMemoryLayout vTensor::estimate_memory_layout() const {
 
 const vkapi::BufferBindInfo vTensor::sizes_ubo() {
   const size_t size_per_ubo =
-      storage_.context_->adapter_ptr()->min_ubo_alignment();
+      storage_->context_->adapter_ptr()->min_ubo_alignment();
   const size_t max_ubo_size = kMaxMetadataFieldCount * size_per_ubo;
   if (!uniforms_.buffer()) {
-    uniforms_ = ParamsBuffer(storage_.context_, max_ubo_size, true);
+    uniforms_ = ParamsBuffer(storage_->context_, max_ubo_size, true);
   }
   if (sizes_uniform_offset_ == kUniformOffsetUnset) {
     VK_CHECK_COND(
@@ -714,10 +664,10 @@ const vkapi::BufferBindInfo vTensor::sizes_ubo() {
 
 const vkapi::BufferBindInfo vTensor::strides_ubo() {
   const size_t size_per_ubo =
-      storage_.context_->adapter_ptr()->min_ubo_alignment();
+      storage_->context_->adapter_ptr()->min_ubo_alignment();
   const size_t max_ubo_size = kMaxMetadataFieldCount * size_per_ubo;
   if (!uniforms_.buffer()) {
-    uniforms_ = ParamsBuffer(storage_.context_, max_ubo_size, true);
+    uniforms_ = ParamsBuffer(storage_->context_, max_ubo_size, true);
   }
   if (unsqueezed_strides_offset_ == kUniformOffsetUnset) {
     VK_CHECK_COND(
@@ -735,10 +685,10 @@ const vkapi::BufferBindInfo vTensor::strides_ubo() {
 
 const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
   const size_t size_per_ubo =
-      storage_.context_->adapter_ptr()->min_ubo_alignment();
+      storage_->context_->adapter_ptr()->min_ubo_alignment();
   const size_t max_ubo_size = kMaxMetadataFieldCount * size_per_ubo;
   if (!uniforms_.buffer()) {
-    uniforms_ = ParamsBuffer(storage_.context_, max_ubo_size, true);
+    uniforms_ = ParamsBuffer(storage_->context_, max_ubo_size, true);
   }
   if (logical_limits_uniform_offset_ == kUniformOffsetUnset) {
     VK_CHECK_COND(
@@ -754,10 +704,10 @@ const vkapi::BufferBindInfo vTensor::logical_limits_ubo() {
 
 const vkapi::BufferBindInfo vTensor::numel_ubo() {
   const size_t size_per_ubo =
-      storage_.context_->adapter_ptr()->min_ubo_alignment();
+      storage_->context_->adapter_ptr()->min_ubo_alignment();
   const size_t max_ubo_size = kMaxMetadataFieldCount * size_per_ubo;
   if (!uniforms_.buffer()) {
-    uniforms_ = ParamsBuffer(storage_.context_, max_ubo_size, true);
+    uniforms_ = ParamsBuffer(storage_->context_, max_ubo_size, true);
   }
   if (numel_uniform_offset_ == kUniformOffsetUnset) {
     VK_CHECK_COND(
@@ -774,7 +724,7 @@ const vkapi::BufferBindInfo vTensor::numel_ubo() {
 size_t vTensor::staging_buffer_numel() const {
   const bool is_int8 = dtype_ == vkapi::kChar;
   const bool int8_supported =
-      storage_.context_->adapter_ptr()->has_full_int8_buffers_support();
+      storage_->context_->adapter_ptr()->has_full_int8_buffers_support();
   if (is_int8 && !int8_supported) {
     return utils::align_up_4(numel());
   }
@@ -787,10 +737,10 @@ size_t vTensor::staging_buffer_numel() const {
 VkMemoryRequirements vTensor::get_memory_requirements() const {
   switch (storage_type()) {
     case utils::kBuffer:
-      return storage_.buffer_.get_memory_requirements();
+      return storage_->buffer_.get_memory_requirements();
     case utils::kTexture2D:
     case utils::kTexture3D:
-      return storage_.image_.get_memory_requirements();
+      return storage_->image_.get_memory_requirements();
   }
   return {};
 }
@@ -798,11 +748,11 @@ VkMemoryRequirements vTensor::get_memory_requirements() const {
 void vTensor::bind_allocation(const vkapi::Allocation& allocation) {
   switch (storage_type()) {
     case utils::kBuffer:
-      storage_.buffer_.bind_allocation(allocation);
+      storage_->buffer_.bind_allocation(allocation);
       break;
     case utils::kTexture2D:
     case utils::kTexture3D:
-      storage_.image_.bind_allocation(allocation);
+      storage_->image_.bind_allocation(allocation);
       break;
   }
 }
@@ -845,11 +795,11 @@ void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
     utils::uvec3 virtual_extents =
         calculate_image_extents(padded_sizes_, axis_map_, packed_dim_);
 
-    bool valid_resize = virtual_extents[0] <= storage_.image_extents_[0];
+    bool valid_resize = virtual_extents[0] <= storage_->image_extents_[0];
     valid_resize =
-        valid_resize && virtual_extents[1] <= storage_.image_extents_[1];
+        valid_resize && virtual_extents[1] <= storage_->image_extents_[1];
     valid_resize =
-        valid_resize && virtual_extents[2] <= storage_.image_extents_[2];
+        valid_resize && virtual_extents[2] <= storage_->image_extents_[2];
 
     VK_CHECK_COND(
         valid_resize,
@@ -859,7 +809,7 @@ void vTensor::check_sizes(const std::vector<int64_t>& sizes) const {
     // new sizes of the tensor.
     int64_t numel = utils::multiply_integers(sizes);
     bool valid_resize =
-        numel + storage_.buffer_offset_ <= storage_.buffer_length_;
+        numel + storage_->buffer_offset_ <= storage_->buffer_length_;
     VK_CHECK_COND(
         valid_resize,
         "tensor sizes requires a larger buffer than the current one.");
diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h
index d9cbadb46b9..850dc2d7fab 100644
--- a/backends/vulkan/runtime/api/containers/Tensor.h
+++ b/backends/vulkan/runtime/api/containers/Tensor.h
@@ -97,19 +97,8 @@ class vTensorStorage final {
 
   vTensorStorage(Context* const context, const vkapi::VulkanImage& image);
 
- protected:
-  /*
-   * This allows for creation of tensors that use the same underlying storage
-   * as another tensor. Note that this functionality is currently enabled for
-   * tensors that have buffer storage only. The created tensor will not have
-   * ownership of the underlying VkBuffer. This constructor is marked protected
-   * because this behaviour is unsafe, since the original tensor may be
-   * destroyed before the copy is destroyed.
-   */
-  vTensorStorage(vTensorStorage& other, const int64_t buffer_offset = 0);
-
  public:
-  // To discourage creating copies, the assignment operator is still deleted.
+  vTensorStorage(vTensorStorage& other) = delete;
   vTensorStorage& operator=(const vTensorStorage& other) = delete;
 
   vTensorStorage(vTensorStorage&& other) = default;
@@ -136,8 +125,6 @@ class vTensorStorage final {
 
   // Last Access - used to insert memory barriers
   LastAccess last_access_;
-  // Indicates whether copies of this vTensorStorage have been made
-  bool has_copies_;
 
  private:
   // Registers underlying memory for cleanup
@@ -156,16 +143,6 @@ class vTensorStorage final {
   inline VkFormat texture_format() {
     return image_.format();
   }
-
-  /*
-   * Check if the underlying resource is a copy of another resource
-   */
-  bool is_copy() const;
-
-  /*
-   * Used for checking if this vTensorStorage is a copy of another instance
-   */
-  bool is_copy_of(const vTensorStorage& other) const;
 };
 
 class vTensor final {
@@ -222,8 +199,7 @@ class vTensor final {
   vTensor(
       vTensor& other,
       const std::vector<int64_t>& sizes,
-      const std::vector<int64_t>& dim_order,
-      const int64_t offset_numel = 0);
+      const std::vector<int64_t>& dim_order);
 
   // To discourage making copies, the copy assignment operator is still deleted
   vTensor& operator=(const vTensor& other) = delete;
@@ -358,7 +334,7 @@ class vTensor final {
   // impossible for a ubo to have an offset of 1.
   constexpr static uint32_t kUniformOffsetUnset = 1;
 
-  vTensorStorage storage_;
+  std::shared_ptr<vTensorStorage> storage_;
 
   std::shared_ptr<UniformData> uniform_data_;
 
@@ -368,7 +344,7 @@ class vTensor final {
   */
   inline vkapi::VulkanImage& image() const& {
-    return storage_.image_;
+    return storage_->image_;
   }
 
   vkapi::VulkanImage& image(
@@ -381,7 +357,7 @@ class vTensor final {
       const vkapi::MemoryAccessFlags) &;
 
   inline vkapi::VulkanBuffer& buffer() const& {
-    return storage_.buffer_;
+    return storage_->buffer_;
   }
 
   vkapi::VulkanBuffer& buffer(
@@ -398,11 +374,11 @@ class vTensor final {
   */
   inline utils::StorageType storage_type() const {
-    return storage_.storage_type_;
+    return storage_->storage_type_;
   }
 
   inline bool has_buffer_storage() const {
-    return storage_.storage_type_ == utils::kBuffer;
+    return storage_->storage_type_ == utils::kBuffer;
   }
 
  private:
@@ -623,7 +599,7 @@ class vTensor final {
   * Check if this vTensor instance is a view of another vTensor instance
   */
   inline bool is_view_of(const vTensor& other) const {
-    return storage_.is_copy_of(other.storage_);
+    return storage_.get() == other.storage_.get();
   }
 
   const std::shared_ptr<UniformData>& get_uniform_data() const {
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.cpp b/backends/vulkan/runtime/graph/ComputeGraph.cpp
index 8e498d5f2d1..5dc26286682 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.cpp
+++ b/backends/vulkan/runtime/graph/ComputeGraph.cpp
@@ -378,27 +378,16 @@ ValueRef ComputeGraph::add_tensor_view(const ValueRef vref) {
   const vTensorPtr t = get_tensor(vref);
   ValueRef idx(static_cast(values_.size()));
   values_.emplace_back(api::vTensor(*t));
-  for (SharedObject& sobj : shared_objects_) {
-    if (sobj.has_user(vref)) {
-      sobj.add_user(this, idx);
-    }
-  }
   return idx;
 }
 
 ValueRef ComputeGraph::add_tensor_view(
     const ValueRef vref,
     const std::vector<int64_t>& sizes,
-    const std::vector<int64_t>& strides,
-    const size_t offset_numel) {
+    const std::vector<int64_t>& strides) {
   const vTensorPtr t = get_tensor(vref);
   ValueRef idx(static_cast(values_.size()));
-  values_.emplace_back(api::vTensor(*t, sizes, strides, offset_numel));
-  for (SharedObject& sobj : shared_objects_) {
-    if (sobj.has_user(vref)) {
-      sobj.add_user(this, idx);
-    }
-  }
+  values_.emplace_back(api::vTensor(*t, sizes, strides));
   return idx;
 }
 
diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h
index 90f89ea18d6..31514989dfc 100644
--- a/backends/vulkan/runtime/graph/ComputeGraph.h
+++ b/backends/vulkan/runtime/graph/ComputeGraph.h
@@ -573,8 +573,7 @@ class ComputeGraph final {
   ValueRef add_tensor_view(
       const ValueRef vref,
       const std::vector<int64_t>& sizes,
-      const std::vector<int64_t>& dim_order,
-      const size_t offset_numel = 0);
+      const std::vector<int64_t>& dim_order);
 
   /*
    * Add a `TensorRef` value to the graph with the specific properties. A
diff --git a/backends/vulkan/test/utils/test_utils.cpp b/backends/vulkan/test/utils/test_utils.cpp
index e842500e6be..3497aeb5705 100644
--- a/backends/vulkan/test/utils/test_utils.cpp
+++ b/backends/vulkan/test/utils/test_utils.cpp
@@ -137,45 +137,6 @@ void record_bitw8_image_to_nchw_nobitw8buffer_op(
       v_src.numel_ubo());
 }
 
-void record_conv2d_prepack_weights_op(
-    api::Context* const context,
-    vkapi::VulkanBuffer& src_buffer,
-    api::vTensor& v_dst,
-    const std::vector<int64_t>& original_sizes,
-    const bool transposed) {
-  vkapi::PipelineBarrier pipeline_barrier{};
-
-  std::string kernel_name;
-  if (transposed) {
-    kernel_name = "conv_transpose2d";
-  } else {
-    kernel_name = "conv2d";
-  }
-  kernel_name += "_prepack_weights";
-  add_dtype_suffix(kernel_name, v_dst);
-  vkapi::ShaderInfo shader = VK_KERNEL_FROM_STR(kernel_name);
-
-  api::ParamsBuffer original_sizes_ubo(
-      context, utils::make_ivec4(original_sizes, /*reverse = */ true));
-
-  vkapi::SpecVarList specialization_constants = {};
-  context->submit_compute_job(
-      shader,
-      pipeline_barrier,
-      v_dst.logical_limits(),
-      adaptive_work_group_size(v_dst.logical_limits()),
-      specialization_constants,
-      VK_NULL_HANDLE,
-      0,
-      v_dst.image(
-          pipeline_barrier,
-          vkapi::PipelineStage::COMPUTE,
-          vkapi::MemoryAccessType::WRITE),
-      src_buffer,
-      v_dst.sizes_ubo(),
-      original_sizes_ubo.buffer());
-}
-
 void record_binary_op(
     api::Context* const context,
     const std::string& op_name,
diff --git a/backends/vulkan/test/vulkan_compute_api_test.cpp b/backends/vulkan/test/vulkan_compute_api_test.cpp
index 60b5ccb1a80..f89d4dca705 100644
--- a/backends/vulkan/test/vulkan_compute_api_test.cpp
+++ b/backends/vulkan/test/vulkan_compute_api_test.cpp
@@ -821,58 +821,6 @@ TEST_F(VulkanComputeAPITest, tensor_no_copy_transpose_test) {
   }
 }
 
-TEST_F(VulkanComputeAPITest, tensor_no_copy_slice_test) {
-  constexpr int L = 31;
-
-  // S{N} refers to slice {N}
-  constexpr int L_S1 = 17;
-  constexpr int O_S1 = 5;
-
-  constexpr int L_S2 = 7;
-  constexpr int O_S2 = 3;
-
-  std::vector<int64_t> dim_order = {0};
-
-  std::vector<int64_t> t_sizes = {L};
-  std::vector<int64_t> s1_sizes = {L_S1};
-  std::vector<int64_t> s2_sizes = {L_S2};
-
-  vTensor orig = CREATE_FLOAT_BUFFER(t_sizes, /*allocate_memory=*/true);
-
-  fill_vtensor(orig, 0);
-
-  vTensor s1 = vTensor(orig, s1_sizes, dim_order, O_S1);
-  vTensor s2 = vTensor(s1, s2_sizes, dim_order, O_S2);
-
-  record_scalar_add_buffer(api::context(), s1, 4.5f);
-  record_scalar_add_buffer(api::context(), s2, 7.5f);
-
-  std::vector<float> orig_data(orig.staging_buffer_numel());
-  extract_vtensor(orig, orig_data);
-
-  int id = 0;
-  while (id < O_S1) {
-    EXPECT_TRUE(orig_data[id] == 0);
-    ++id;
-  }
-  while (id < O_S1 + O_S2) {
-    EXPECT_TRUE(orig_data[id] == 4.5);
-    ++id;
-  }
-  while (id < O_S1 + O_S2 + L_S2) {
-    EXPECT_TRUE(orig_data[id] == 12);
-    ++id;
-  }
-  while (id < O_S1 + L_S1) {
-    EXPECT_TRUE(orig_data[id] == 4.5);
-    ++id;
-  }
-  while (id < L) {
-    EXPECT_TRUE(orig_data[id] == 0);
-    ++id;
-  }
-}
-
 TEST_F(VulkanComputeAPITest, texture_deferred_allocation_test) {
   // This test is the same as texture_add_sanity_check, except that the tensor
   // memory is allocated in a deferred fashion
@@ -1303,62 +1251,6 @@ TEST(VulkanComputeGraphTest, test_simple_graph_with_buffer) {
   }
 }
 
-TEST(VulkanComputeGraphTest, test_simple_graph_with_view) {
-  constexpr int W = 7;
-  constexpr int H = 7;
-  // slice height
-  constexpr int S_H = 2;
-  // slice offset
-  constexpr int S_O = 3;
-
-  GraphConfig config;
-  config.set_storage_type_override(utils::kBuffer);
-  ComputeGraph graph(config);
-
-  std::vector<int64_t> dim_order = {0, 1};
-
-  std::vector<int64_t> orig_sizes = {H, W};
-  std::vector<int64_t> slice_sizes = {S_H, W};
-  const int offset = S_O * W;
-
-  // Build graph
-
-  IOValueRef orig = graph.add_input_tensor(orig_sizes, vkapi::kFloat);
-  ValueRef slice =
-      graph.add_tensor_view(orig.value, slice_sizes, dim_order, offset);
-
-  EXPECT_TRUE(graph.val_is_view_of(slice, orig.value));
-
-  IOValueRef out = {};
-
-  out.value = graph.add_tensor(slice_sizes, vkapi::kFloat);
-
-  auto opFn = VK_GET_OP_FN("aten.abs.default");
-  opFn(graph, {slice, out.value, kDummyValueRef, kDummyValueRef});
-
-  out.staging = graph.set_output_tensor(out.value);
-
-  graph.prepare();
-  graph.encode_execute();
-
-  // Run graph
-
-  for (float i = 5.0f; i < 30.0f; i += 10.0f) {
-    float start_val = -130 + i;
-
-    fill_vtensor(graph, orig, start_val, true);
-
-    graph.execute();
-
-    EXTRACT_TENSOR(out);
-
-    for (size_t i = 0; i < graph.get_tensor(out.value)->numel(); ++i) {
-      const float expected_val = std::abs(start_val) - float(offset) - i;
-      CHECK_VALUE(data_out, i, expected_val);
-    }
-  }
-}
-
 TEST(VulkanComputeGraphTest, test_graph_view_of_view) {
   GraphConfig config;
   config.set_storage_type_override(utils::kTexture3D);
@@ -2044,7 +1936,7 @@ TEST(VulkanComputeGraphTest, test_etvk_copy_offset_node) {
   }
 }
 
-TEST(VulkanComputeGraphTest, test_etvk_copy_channel_offset_node) {
+TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_node) {
   GraphConfig config;
   ComputeGraph graph(config);
 
@@ -2103,7 +1995,7 @@ TEST(VulkanComputeGraphTest, test_etvk_copy_channel_offset_node) {
 
 TEST(
     VulkanComputeGraphTest,
-    test_etvk_copy_channel_offset_node_clean_boundary) {
+    DISABLED_test_etvk_copy_channel_offset_node_clean_boundary) {
   // Tricky part for channel copy is handling the boundary across multiple copy.
   // For example, when we concat two [3, 1, 1] nchw-tensors along the channel
   // dimension, due to channel packing, elements from different source texel
@@ -2312,7 +2204,7 @@ TEST(VulkanComputeGraphTest, test_etvk_copy_offset_int_node) {
   }
 }
 
-TEST(VulkanComputeGraphTest, test_etvk_copy_channel_offset_int_node) {
+TEST(VulkanComputeGraphTest, DISABLED_test_etvk_copy_channel_offset_int_node) {
   GraphConfig config;
   ComputeGraph graph(config);
 
@@ -2966,71 +2858,6 @@ TEST(VulkanComputeGraphOpsTest, max_pool2d_smoke_test) {
       kernel);
 }
 
-void test_conv2d(
-    const std::vector<int64_t>& original_sizes,
-    const std::vector<int64_t>& padded_sizes,
-    const std::vector<int64_t>& gpu_sizes,
-    const bool transposed,
-    const std::vector<float>& data_out_expected) {
-  vTensor vten = vTensor(
-      context(),
-      gpu_sizes,
-      vkapi::kFloat,
-      utils::StorageType::TEXTURE_2D,
-      utils::GPUMemoryLayout::TENSOR_CHANNELS_PACKED);
-
-  // Create and fill input staging buffer
-  const int64_t in_numel = utils::multiply_integers(original_sizes);
-  StagingBuffer staging_buffer_in(context(), vkapi::kFloat, in_numel);
-
-  std::vector<float> data_in(in_numel);
-  for (int i = 0; i < in_numel; i++) {
-    data_in[i] = i + 1;
-  }
-  staging_buffer_in.copy_from(data_in.data(), sizeof(float) * in_numel);
-
-  // Output staging buffer
-  const int64_t out_numel =
-      padded_sizes[0] * padded_sizes[1] * original_sizes[2] * original_sizes[3];
-  StagingBuffer staging_buffer_out(context(), vkapi::kFloat, out_numel);
-
-  // Copy data in and out of the tensor
-  record_conv2d_prepack_weights_op(
-      context(), staging_buffer_in.buffer(), vten, original_sizes, transposed);
-  record_image_to_nchw_op(context(), vten, staging_buffer_out.buffer());
-
-  // Execute command buffer
-  submit_to_gpu();
-
-  // Extract data from output staging buffer
-  std::vector<float> data_out(out_numel);
-  staging_buffer_out.copy_to(data_out.data(), sizeof(float) * out_numel);
-
-  // Check data matches results copied from ATen-VK
-  for (int i = 0; i < vten.numel(); i++) {
-    CHECK_VALUE(data_out, i, data_out_expected[i]);
-  }
-}
-
-TEST(VulkanComputeGraphOpsTest, conv2d_prepack_test) {
-  test_conv2d(
-      /*original_sizes = */ {2, 3, 1, 2},
-      /*padded_sizes = */ {4, 4},
-      /*gpu_sizes = */ {4, 1, 8},
-      /*transposed = */ false,
-      /*data_out_expected = */ {1, 3, 5, 0, 2, 4, 6, 0, 7, 9, 11,
-                                0, 8, 10, 12, 0, 0, 0, 0, 0, 0, 0,
-                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
-  test_conv2d(
-      /*original_sizes = */ {2, 3, 1, 2},
-      /*padded_sizes = */ {4, 4},
-      /*gpu_sizes = */ {4, 1, 8},
-      /*transposed = */ true,
-      /*data_out_expected = */ {2, 8, 0, 0, 1, 7, 0, 0, 4, 10, 0,
-                                0, 3, 9, 0, 0, 6, 12, 0, 0, 5, 11,
-                                0, 0, 0, 0, 0, 0, 0, 0, 0, 0});
-}
-
 void test_grid_priors(
     std::vector<int64_t> input_sizes,
     std::vector<int64_t> output_sizes,
@@ -3242,8 +3069,10 @@ void test_to_copy() {
 
   EXPECT_EQ(data_in.size(), output_data.size());
 
+#ifdef VULKAN_DEBUG
   float mse_ex = 0.0f;
   float mse_vk = 0.0f;
+#endif
 
   // check results
   for (size_t i = 0; i < output_data.size(); ++i) {
@@ -3254,6 +3083,7 @@ void test_to_copy() {
     torch::executor::Half output = output_data[i];
     uint16_t* output_bits = reinterpret_cast<uint16_t*>(&output);
 
+#ifdef VULKAN_DEBUG
     std::string msg;
     msg.reserve(64);
     msg = "input = " + std::to_string(input) + "(0b" +
@@ -3265,6 +3095,10 @@
 
     std::cout << msg << std::endl;
 
+    mse_ex += std::pow(expected_output - input, 2);
+    mse_vk += std::pow(output - input, 2);
+#endif
+
     // Note: Torch executor half "rounds up" when converting to fp16 whereas
     // most driver implementations of Vulkan's opFConvert() just truncates the
     // extra bits for performance (rounding introduces conditional).
@@ -3284,15 +3118,16 @@ void test_to_copy() {
     EXPECT_TRUE(
         (*output_bits == *expected_bits) ||
         /*rounding error*/ ((*output_bits + 1u) == *expected_bits));
-    mse_ex += std::pow(expected_output - input, 2);
-    mse_vk += std::pow(output - input, 2);
   }
 
+#ifdef VULKAN_DEBUG
   mse_ex /= output_data.size();
   mse_vk /= output_data.size();
+
   std::cout << "========================================================="
             << std::endl;
   std::cout << "mse_ex = " << mse_ex << ", mse_vk = " << mse_vk << std::endl;
+#endif
 }
 
 TEST(VulkanComputeGraphOpsTest, test_to_copy) {