diff --git a/.github/workflows/nightly-release.yml b/.github/workflows/nightly-release.yml
index 0ecf037cd5..f5bd59791a 100644
--- a/.github/workflows/nightly-release.yml
+++ b/.github/workflows/nightly-release.yml
@@ -98,7 +98,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install build twine wheel
-          pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
+          pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"

       - name: Build flashinfer-cubin wheel
         env:
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 2a7c48ca23..14af22d986 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -136,7 +136,7 @@ jobs:
         run: |
           python -m pip install --upgrade pip
           pip install build twine wheel
-          pip install setuptools>=61.0 requests filelock torch tqdm numpy apache-tvm-ffi==0.1.0b15
+          pip install setuptools>=61.0 requests filelock torch tqdm numpy "apache-tvm-ffi>=0.1,<0.2"

       - name: Build flashinfer-cubin wheel
         run: |
diff --git a/csrc/batch_attention.cu b/csrc/batch_attention.cu
index 21c78e8a39..a3d36b7981 100644
--- a/csrc/batch_attention.cu
+++ b/csrc/batch_attention.cu
@@ -42,21 +42,21 @@ Array<int64_t> BatchPagedAttentionPlan(TensorView float_workspace_buffer,
                                        TensorView kv_len, int64_t batch_size, int64_t num_qo_heads,
                                        int64_t num_kv_heads, int64_t head_dim_o, bool causal) {
   size_t float_workspace_size_in_bytes =
-      float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer);
+      float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer);
   size_t int_workspace_size_in_bytes =
-      int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer);
+      int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer);

   HolisticPlanInfo<2> plan_info;

-  cudaSetDevice(float_workspace_buffer->device.device_id);
-  const cudaStream_t stream = get_stream(float_workspace_buffer->device);
+  cudaSetDevice(float_workspace_buffer.device().device_id);
+  const cudaStream_t stream = get_stream(float_workspace_buffer.device());

   cudaError_t status = TwoStageHolisticPlan<IdType>(
-      float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data,
-      page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info,
-      static_cast<IdType*>(qo_indptr->data), static_cast<IdType*>(kv_indptr->data),
-      static_cast<IdType*>(kv_len->data), batch_size, num_qo_heads, num_kv_heads, head_dim_o,
-      causal, stream);
+      float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes,
+      int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(),
+      int_workspace_size_in_bytes, plan_info, static_cast<IdType*>(qo_indptr.data_ptr()),
+      static_cast<IdType*>(kv_indptr.data_ptr()), static_cast<IdType*>(kv_len.data_ptr()),
+      batch_size, num_qo_heads, num_kv_heads, head_dim_o, causal, stream);

   TVM_FFI_ICHECK(status == cudaSuccess)
       << "Failed to plan persistent paged attention, error: " << cudaGetErrorString(status);
@@ -76,34 +76,34 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo
   HolisticPlanInfo<2> plan_info;
   plan_info.FromVector(std::vector<int64_t>(plan_info_vec.begin(), plan_info_vec.end()));

-  void* float_buffer_ptr = float_workspace_buffer->data;
-  void* int_buffer_ptr = int_workspace_buffer->data;
+  void* float_buffer_ptr = float_workspace_buffer.data_ptr();
+  void* int_buffer_ptr = int_workspace_buffer.data_ptr();

   const MaskMode mask_mode = static_cast<MaskMode>(mask_mode_code);

   // NOTE (Yilong): assume both q and o are NHD
-  unsigned int q_stride_n = q->strides[0];
-  unsigned 
int q_stride_h = q->strides[1]; + unsigned int q_stride_n = q.stride(0); + unsigned int q_stride_h = q.stride(1); // layout only constraint paged KV const QKVLayout kv_layout = static_cast(layout_code); - unsigned int k_stride_page = k_cache->strides[0]; - unsigned int v_stride_page = v_cache->strides[0]; + unsigned int k_stride_page = k_cache.stride(0); + unsigned int v_stride_page = v_cache.stride(0); unsigned int k_stride_n, k_stride_h, v_stride_n, v_stride_h; if (kv_layout == QKVLayout::kNHD) { - k_stride_h = k_cache->strides[2]; - k_stride_n = k_cache->strides[1]; - v_stride_h = v_cache->strides[2]; - v_stride_n = v_cache->strides[1]; + k_stride_h = k_cache.stride(2); + k_stride_n = k_cache.stride(1); + v_stride_h = v_cache.stride(2); + v_stride_n = v_cache.stride(1); } else { - k_stride_h = k_cache->strides[1]; - k_stride_n = k_cache->strides[2]; - v_stride_h = v_cache->strides[1]; - v_stride_n = v_cache->strides[2]; + k_stride_h = k_cache.stride(1); + k_stride_n = k_cache.stride(2); + v_stride_h = v_cache.stride(1); + v_stride_n = v_cache.stride(2); } - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, @@ -112,9 +112,9 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo IdType* len_kv_chunk = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.len_kv_chunk_offset); for (int i = 0; i < 2; i++) { - params[i].q = static_cast(q->data); - params[i].k = static_cast(k_cache->data); - params[i].v = static_cast(v_cache->data); + params[i].q = static_cast(q.data_ptr()); + params[i].k = static_cast(k_cache.data_ptr()); + params[i].v = static_cast(v_cache.data_ptr()); params[i].q_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].q_indptr_offset); @@ -122,7 +122,7 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].kv_indptr_offset); params[i].partial_indptr = GetPtrFromBaseOffset( int_buffer_ptr, plan_info.tasks[i].partial_indptr_offset); - params[i].kv_indices = static_cast(kv_indices->data); + params[i].kv_indices = static_cast(kv_indices.data_ptr()); params[i].q_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].q_len_offset); params[i].kv_len = @@ -139,9 +139,9 @@ void BatchPagedAttentionRun(TensorView float_workspace_buffer, TensorView int_wo GetPtrFromBaseOffset(int_buffer_ptr, plan_info.tasks[i].work_indptr_offset); params[i].len_kv_chunk = len_kv_chunk + i; - params[i].final_o = static_cast(o->data); + params[i].final_o = static_cast(o.data_ptr()); params[i].final_lse = - maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; + maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; params[i].partial_o = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); params[i].partial_lse = diff --git a/csrc/batch_decode.cu b/csrc/batch_decode.cu index afb105f442..c3ce1e2ecf 100644 --- a/csrc/batch_decode.cu +++ b/csrc/batch_decode.cu @@ -43,9 +43,9 @@ Array BatchDecodeWithPagedKVCachePlan( int64_t window_left, double logits_soft_cap, int64_t head_dim_qk, int64_t head_dim_vo, TensorView empty_q_data, TensorView empty_kv_data) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); DecodePlanInfo plan_info; @@ -53,8 +53,8 @@ Array BatchDecodeWithPagedKVCachePlan( << "CUDA cores template only supports equal head dim for QK and VO, please use tensor " "cores template for different head dim"; - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { @@ -62,10 +62,10 @@ Array BatchDecodeWithPagedKVCachePlan( auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatched< GROUP_SIZE, HEAD_DIM_QK, POS_ENCODING_MODE, AttentionVariant, Params>; cudaError_t status = DecodePlan( - static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer->data), - static_cast(page_locked_int_workspace_buffer->data), - int_workspace_size_in_bytes, plan_info, static_cast(indptr->data), + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + static_cast(page_locked_int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, plan_info, static_cast(indptr.data_ptr()), batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream, work_estimation_func); @@ -89,19 +89,19 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout = static_cast(kv_layout_code); - int64_t batch_size = q->shape[0]; - int64_t num_qo_heads = q->shape[1]; + int64_t batch_size = q.size(0); + int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->shape[1]; - page_size = paged_k_cache->shape[2]; + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size = paged_k_cache->shape[1]; - num_kv_heads = paged_k_cache->shape[2]; + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } - uint32_t head_dim_qk = q->shape[2]; - uint32_t head_dim_vo = paged_v_cache->shape[3]; + uint32_t head_dim_qk = q.size(2); + uint32_t head_dim_vo = paged_v_cache.size(3); TVM_FFI_ICHECK_EQ(head_dim_qk, head_dim_vo) << "CUDA cores template only supports equal head dim for QK and VO, please use tensor " @@ -109,16 +109,16 @@ void BatchDecodeWithPagedKVCacheRun(TensorView 
float_workspace_buffer, if (maybe_lse.has_value()) { const auto& lse = maybe_lse.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads); + TVM_FFI_ICHECK_EQ(lse.size(0), batch_size); + TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads); } - void* float_buffer = static_cast(float_workspace_buffer->data); - void* int_buffer = static_cast(int_workspace_buffer->data); + void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); // get q_stride_n and q_stride_h - const auto q_stride_n = q->strides[0]; - const auto q_stride_h = q->strides[1]; + const auto q_stride_n = q.stride(0); + const auto q_stride_h = q.stride(1); // get kv_cache_strides const int64_t* kv_cache_strides = nullptr; @@ -130,24 +130,26 @@ void BatchDecodeWithPagedKVCacheRun(TensorView float_workspace_buffer, } kv_cache_strides = k_strides.data(); - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { paged_kv_t paged_kv( num_kv_heads, page_size, HEAD_DIM_QK, batch_size, kv_layout, - static_cast(paged_k_cache->data), static_cast(paged_v_cache->data), - kv_cache_strides, static_cast(paged_kv_indices->data), - static_cast(paged_kv_indptr->data), - static_cast(paged_kv_last_page_len->data)); + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), kv_cache_strides, + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); Params params; - params.q = static_cast(q->data); + params.q = static_cast(q.data_ptr()); params.paged_kv = paged_kv; - params.o = static_cast(o->data); - params.lse = maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; + params.o = static_cast(o.data_ptr()); + params.lse = + maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; params.padded_batch_size = 0; params.num_qo_heads = num_qo_heads; params.q_stride_n = q_stride_n; diff --git a/csrc/batch_decode_mla_cute_sm80.cu b/csrc/batch_decode_mla_cute_sm80.cu index d96190e539..5679076438 100644 --- a/csrc/batch_decode_mla_cute_sm80.cu +++ b/csrc/batch_decode_mla_cute_sm80.cu @@ -18,23 +18,24 @@ Array BatchDecodeWithPagedKVCachePlanMLA(ffi::TensorView float_workspac int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); DecodePlanInfo plan_info; - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); auto work_estimation_func = BatchDecodeWithPagedKVCacheWorkEstimationDispatchedMlaCuteSM80< HEAD_DIM_CKV, HEAD_DIM_KPE, QO_TILE_LEN, AttentionVariant, Params>; cudaError_t status = DecodePlan( - static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer->data), - static_cast(page_locked_int_workspace_buffer->data), int_workspace_size_in_bytes, - plan_info, static_cast(indptr->data), batch_size, num_qo_heads, page_size, - enable_cuda_graph, /*stream=*/stream, work_estimation_func); + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + static_cast(page_locked_int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, plan_info, static_cast(indptr.data_ptr()), + batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream, + work_estimation_func); TVM_FFI_ICHECK(status == cudaSuccess) << "BatchDecodeWithPagedKVCachePlanMLA failed with error " << cudaGetErrorString(status); @@ -55,31 +56,32 @@ void BatchDecodeWithPagedKVCacheRunMLA( DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); - int64_t batch_size = q_nope->shape[0]; - int64_t num_qo_heads = q_nope->shape[1]; - int64_t page_size = paged_ckv_cache->shape[1]; + int64_t batch_size = q_nope.size(0); + int64_t num_qo_heads = q_nope.size(1); + int64_t page_size = paged_ckv_cache.size(1); if (maybe_lse.has_value()) { const auto& lse = maybe_lse.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads); + TVM_FFI_ICHECK_EQ(lse.size(0), batch_size); + TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads); } TVM_FFI_ICHECK_GE(logits_soft_cap, 0.f) << "logits_soft_cap must be non-negative"; - void* float_buffer = static_cast(float_workspace_buffer->data); - void* int_buffer = static_cast(int_workspace_buffer->data); + void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); paged_kv_mla_t paged_kv( page_size, HEAD_DIM_CKV, HEAD_DIM_KPE, batch_size, - static_cast(paged_ckv_cache->data), paged_ckv_cache.strides().data(), - static_cast(paged_kpe_cache->data), paged_kpe_cache.strides().data(), - static_cast(paged_kv_indices->data), static_cast(paged_kv_indptr->data), - 
static_cast(paged_kv_last_page_len->data)); + static_cast(paged_ckv_cache.data_ptr()), paged_ckv_cache.strides().data(), + static_cast(paged_kpe_cache.data_ptr()), paged_kpe_cache.strides().data(), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); Params params( - static_cast(q_nope->data), static_cast(q_pe->data), - /*q_offset=*/nullptr, paged_kv, static_cast(o->data), - /*lse=*/(maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr), + static_cast(q_nope.data_ptr()), static_cast(q_pe.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr), num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); DTypeO* tmp_v = nullptr; @@ -101,8 +103,8 @@ void BatchDecodeWithPagedKVCacheRunMLA( } params.padded_batch_size = plan_info.padded_batch_size; - cudaSetDevice(paged_ckv_cache->device.device_id); - const cudaStream_t stream = get_stream(paged_ckv_cache->device); + cudaSetDevice(paged_ckv_cache.device().device_id); + const cudaStream_t stream = get_stream(paged_ckv_cache.device()); cudaError_t status = BatchDecodeWithPagedKVCacheDispatchedMlaCuteSM80( params, tmp_v, tmp_s, /*stream=*/stream); diff --git a/csrc/batch_decode_mla_plan.cu b/csrc/batch_decode_mla_plan.cu index d7c41c90ca..7925a14f27 100644 --- a/csrc/batch_decode_mla_plan.cu +++ b/csrc/batch_decode_mla_plan.cu @@ -15,13 +15,13 @@ Array BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buf TensorView indptr, int64_t batch_size, int64_t num_qo_heads, int64_t page_size, bool enable_cuda_graph) { - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); DecodePlanInfo plan_info; @@ -30,11 +30,12 @@ Array BatchDecodeWithPagedKVCachePlanMLA(TensorView float_workspace_buf AttentionVariant, Params>; cudaError_t status = DecodePlan( - static_cast(float_workspace_buffer->data), float_workspace_size_in_bytes, - static_cast(int_workspace_buffer->data), - static_cast(page_locked_int_workspace_buffer->data), int_workspace_size_in_bytes, - plan_info, static_cast(indptr->data), batch_size, num_qo_heads, page_size, - enable_cuda_graph, /*stream=*/stream, work_estimation_func); + static_cast(float_workspace_buffer.data_ptr()), float_workspace_size_in_bytes, + static_cast(int_workspace_buffer.data_ptr()), + static_cast(page_locked_int_workspace_buffer.data_ptr()), + int_workspace_size_in_bytes, plan_info, static_cast(indptr.data_ptr()), + batch_size, num_qo_heads, page_size, enable_cuda_graph, /*stream=*/stream, + work_estimation_func); TVM_FFI_ICHECK(status == cudaSuccess) << "BatchDecodeWithPagedKVCachePlanMLA failed with error: " << cudaGetErrorString(status); diff --git a/csrc/batch_decode_mla_run.cu b/csrc/batch_decode_mla_run.cu index f4cef6dc4a..35d533b536 100644 --- a/csrc/batch_decode_mla_run.cu +++ 
b/csrc/batch_decode_mla_run.cu @@ -20,33 +20,34 @@ void BatchDecodeWithPagedKVCacheRunMLA( DecodePlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); - int64_t batch_size = q_nope->shape[0]; - int64_t num_qo_heads = q_nope->shape[1]; - int64_t page_size = paged_ckv_cache->shape[1]; + int64_t batch_size = q_nope.size(0); + int64_t num_qo_heads = q_nope.size(1); + int64_t page_size = paged_ckv_cache.size(1); if (maybe_lse.has_value()) { const auto& lse = maybe_lse.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads); + TVM_FFI_ICHECK_EQ(lse.size(0), batch_size); + TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads); } TVM_FFI_ICHECK_GE(logits_soft_cap, 0.f) << "logits_soft_cap must be non-negative"; - void* float_buffer = static_cast(float_workspace_buffer->data); - void* int_buffer = static_cast(int_workspace_buffer->data); + void* float_buffer = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer = static_cast(int_workspace_buffer.data_ptr()); - cudaSetDevice(q_nope->device.device_id); - const cudaStream_t stream = get_stream(q_nope->device); + cudaSetDevice(q_nope.device().device_id); + const cudaStream_t stream = get_stream(q_nope.device()); paged_kv_mla_t paged_kv( page_size, HEAD_DIM_CKV, HEAD_DIM_KPE, batch_size, - static_cast(paged_ckv_cache->data), paged_ckv_cache.strides().data(), - static_cast(paged_kpe_cache->data), paged_kpe_cache.strides().data(), - static_cast(paged_kv_indices->data), static_cast(paged_kv_indptr->data), - static_cast(paged_kv_last_page_len->data)); - Params params(static_cast(q_nope->data), static_cast(q_pe->data), - /*q_offset=*/nullptr, paged_kv, static_cast(o->data), - /*lse=*/(maybe_lse ? static_cast(maybe_lse.value()->data) : nullptr), + static_cast(paged_ckv_cache.data_ptr()), paged_ckv_cache.strides().data(), + static_cast(paged_kpe_cache.data_ptr()), paged_kpe_cache.strides().data(), + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + static_cast(paged_kv_last_page_len.data_ptr())); + Params params(static_cast(q_nope.data_ptr()), static_cast(q_pe.data_ptr()), + /*q_offset=*/nullptr, paged_kv, static_cast(o.data_ptr()), + /*lse=*/(maybe_lse ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr), num_qo_heads, window_left, logits_soft_cap, sm_scale, rope_scale, rope_theta); DTypeO* tmp_v = nullptr; diff --git a/csrc/batch_mla_config.jinja b/csrc/batch_mla_config.jinja index 3a8d2af048..2525d34273 100644 --- a/csrc/batch_mla_config.jinja +++ b/csrc/batch_mla_config.jinja @@ -13,7 +13,7 @@ using namespace flashinfer; #ifdef FLASHINFER_ENABLE_PROFILER #define ADDITIONAL_FUNC_PARAMS , Tensor profiler_buffer #define ADDITIONAL_PARAMS_SETTER \ - params.profiler_buffer = static_cast(profiler_buffer->data); + params.profiler_buffer = static_cast(profiler_buffer.data_ptr()); #else #define ADDITIONAL_FUNC_PARAMS #define ADDITIONAL_PARAMS_SETTER diff --git a/csrc/batch_mla_plan.cu b/csrc/batch_mla_plan.cu index 6715b5db2f..1f7176e452 100644 --- a/csrc/batch_mla_plan.cu +++ b/csrc/batch_mla_plan.cu @@ -30,22 +30,23 @@ Array BatchMLAPagedAttentionPlan(TensorView float_workspace_buffer, TensorView kv_len, int64_t num_heads, int64_t head_dim_o, bool causal) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); MLAPlanInfo plan_info; - int batch_size = kv_len->shape[0]; + int batch_size = kv_len.size(0); - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); - cudaError_t status = MLAPlan( - float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data, - page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info, - static_cast(qo_indptr->data), static_cast(kv_indptr->data), - static_cast(kv_len->data), batch_size, num_heads, head_dim_o, causal, stream); + cudaError_t status = + MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), static_cast(kv_len.data_ptr()), + batch_size, num_heads, head_dim_o, causal, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to plan MLA, error: " << cudaGetErrorString(status); diff --git a/csrc/batch_mla_run.cu b/csrc/batch_mla_run.cu index de7acedb04..dfa2442f1b 100644 --- a/csrc/batch_mla_run.cu +++ b/csrc/batch_mla_run.cu @@ -39,39 +39,39 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int MLAPlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); - void* float_buffer_ptr = float_workspace_buffer->data; - void* int_buffer_ptr = int_workspace_buffer->data; + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); const MaskMode mask_mode = static_cast(mask_mode_code); - unsigned int q_nope_stride_n = q_nope->strides[0]; - unsigned int q_nope_stride_h = q_nope->strides[1]; - unsigned int q_pe_stride_n = q_pe->strides[0]; - unsigned int q_pe_stride_h = q_pe->strides[1]; - unsigned int ckv_stride_page = ckv_cache->strides[0]; - unsigned int ckv_stride_n = ckv_cache->strides[1]; - 
unsigned int kpe_stride_page = kpe_cache->strides[0]; - unsigned int kpe_stride_n = kpe_cache->strides[1]; - unsigned int o_stride_n = o->strides[0]; - unsigned int o_stride_h = o->strides[1]; + unsigned int q_nope_stride_n = q_nope.stride(0); + unsigned int q_nope_stride_h = q_nope.stride(1); + unsigned int q_pe_stride_n = q_pe.stride(0); + unsigned int q_pe_stride_h = q_pe.stride(1); + unsigned int ckv_stride_page = ckv_cache.stride(0); + unsigned int ckv_stride_n = ckv_cache.stride(1); + unsigned int kpe_stride_page = kpe_cache.stride(0); + unsigned int kpe_stride_n = kpe_cache.stride(1); + unsigned int o_stride_n = o.stride(0); + unsigned int o_stride_h = o.stride(1); - cudaSetDevice(q_nope->device.device_id); - const cudaStream_t stream = get_stream(q_nope->device); + cudaSetDevice(q_nope.device().device_id); + const cudaStream_t stream = get_stream(q_nope.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { Params params; - params.q_nope = static_cast(q_nope->data); - params.q_pe = static_cast(q_pe->data); - params.ckv = static_cast(ckv_cache->data); - params.kpe = static_cast(kpe_cache->data); + params.q_nope = static_cast(q_nope.data_ptr()); + params.q_pe = static_cast(q_pe.data_ptr()); + params.ckv = static_cast(ckv_cache.data_ptr()); + params.kpe = static_cast(kpe_cache.data_ptr()); params.q_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_indptr_offset); params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); params.partial_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.partial_indptr_offset); - params.kv_indices = static_cast(kv_indices->data); + params.kv_indices = static_cast(kv_indices.data_ptr()); params.q_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_len_offset); params.kv_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); params.q_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_start_offset); @@ -89,9 +89,9 @@ void BatchMLAPagedAttentionRun(TensorView float_workspace_buffer, TensorView int int_buffer_ptr, plan_info.merge_partial_packed_offset_end_offset); params.merge_partial_stride = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_partial_stride_offset); - params.final_o = static_cast(o->data); + params.final_o = static_cast(o.data_ptr()); params.final_lse = - maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; + maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; params.partial_o = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); params.partial_lse = diff --git a/csrc/batch_mla_sm90_plan.cu b/csrc/batch_mla_sm90_plan.cu index 78427ffe57..d297ebab90 100644 --- a/csrc/batch_mla_sm90_plan.cu +++ b/csrc/batch_mla_sm90_plan.cu @@ -30,22 +30,23 @@ Array BatchMLAPagedAttentionSM90Plan(TensorView float_workspace_buffer, TensorView kv_len, int64_t num_heads, int64_t head_dim_o, bool causal) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); MLAPlanInfo plan_info; - int batch_size = kv_len->shape[0]; + int batch_size = kv_len.size(0); - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); - cudaError_t status = MLAPlan( - float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data, - page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info, - static_cast(qo_indptr->data), static_cast(kv_indptr->data), - static_cast(kv_len->data), batch_size, num_heads, head_dim_o, causal, stream); + cudaError_t status = + MLAPlan(float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), static_cast(kv_len.data_ptr()), + batch_size, num_heads, head_dim_o, causal, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to plan MLA, error: " << cudaGetErrorString(status); diff --git a/csrc/batch_mla_sm90_run.cu b/csrc/batch_mla_sm90_run.cu index efb744e9f2..8d6d80c223 100644 --- a/csrc/batch_mla_sm90_run.cu +++ b/csrc/batch_mla_sm90_run.cu @@ -40,39 +40,39 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, MLAPlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); - void* float_buffer_ptr = float_workspace_buffer->data; - void* int_buffer_ptr = int_workspace_buffer->data; + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); const MaskMode mask_mode = static_cast(mask_mode_code); - unsigned int q_nope_stride_n = q_nope->strides[0]; - unsigned int q_nope_stride_h = q_nope->strides[1]; - unsigned int q_pe_stride_n = q_pe->strides[0]; - unsigned int q_pe_stride_h = q_pe->strides[1]; - unsigned int ckv_stride_page = ckv_cache->strides[0]; - unsigned int ckv_stride_n = ckv_cache->strides[1]; - unsigned int kpe_stride_page = kpe_cache->strides[0]; - unsigned int kpe_stride_n = kpe_cache->strides[1]; - unsigned int o_stride_n = o->strides[0]; - unsigned int o_stride_h = o->strides[1]; + unsigned int q_nope_stride_n = q_nope.stride(0); + unsigned int q_nope_stride_h = q_nope.stride(1); + unsigned int q_pe_stride_n = q_pe.stride(0); + unsigned int q_pe_stride_h = q_pe.stride(1); + unsigned int ckv_stride_page = ckv_cache.stride(0); + unsigned int ckv_stride_n = ckv_cache.stride(1); + unsigned int 
kpe_stride_page = kpe_cache.stride(0); + unsigned int kpe_stride_n = kpe_cache.stride(1); + unsigned int o_stride_n = o.stride(0); + unsigned int o_stride_h = o.stride(1); - cudaSetDevice(q_nope->device.device_id); - const cudaStream_t stream = get_stream(q_nope->device); + cudaSetDevice(q_nope.device().device_id); + const cudaStream_t stream = get_stream(q_nope.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_CKV, HEAD_DIM_KPE, Params, [&] { Params params; - params.q_nope = static_cast(q_nope->data); - params.q_pe = static_cast(q_pe->data); - params.ckv = static_cast(ckv_cache->data); - params.kpe = static_cast(kpe_cache->data); + params.q_nope = static_cast(q_nope.data_ptr()); + params.q_pe = static_cast(q_pe.data_ptr()); + params.ckv = static_cast(ckv_cache.data_ptr()); + params.kpe = static_cast(kpe_cache.data_ptr()); params.q_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_indptr_offset); params.kv_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_indptr_offset); params.partial_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.partial_indptr_offset); - params.kv_indices = static_cast(kv_indices->data); + params.kv_indices = static_cast(kv_indices.data_ptr()); params.q_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_len_offset); params.kv_len = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.kv_len_offset); params.q_start = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.q_start_offset); @@ -90,9 +90,9 @@ void BatchMLAPagedAttentionSM90Run(TensorView float_workspace_buffer, int_buffer_ptr, plan_info.merge_partial_packed_offset_end_offset); params.merge_partial_stride = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.merge_partial_stride_offset); - params.final_o = static_cast(o->data); + params.final_o = static_cast(o.data_ptr()); params.final_lse = - maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; + maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; params.partial_o = GetPtrFromBaseOffset(float_buffer_ptr, plan_info.partial_o_offset); params.partial_lse = diff --git a/csrc/batch_prefill.cu b/csrc/batch_prefill.cu index 796295a2ea..5d7182bdc5 100644 --- a/csrc/batch_prefill.cu +++ b/csrc/batch_prefill.cu @@ -52,20 +52,20 @@ Array BatchPrefillWithKVCachePlan( int64_t head_dim_vo, bool causal, int64_t window_left, int64_t fixed_split_size, bool disable_split_kv) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); PrefillPlanInfo plan_info; - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = PrefillPlan( - float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data, - page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info, - static_cast(qo_indptr->data), static_cast(kv_indptr->data), total_num_rows, - batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, - enable_cuda_graph, + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), total_num_rows, batch_size, num_qo_heads, + num_kv_heads, head_dim_qk, head_dim_vo, page_size, enable_cuda_graph, /*sizeof_dtype_o=*/2, window_left, fixed_split_size, disable_split_kv, stream); TVM_FFI_ICHECK(status == cudaSuccess) @@ -85,36 +85,36 @@ void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout = static_cast(layout); - int64_t num_qo_heads = q->shape[1]; - int64_t head_dim_qk = q->shape[2]; - int64_t num_kv_heads = (kv_layout == QKVLayout::kNHD) ? k->shape[1] : k->shape[0]; - uint32_t q_stride_n = q->strides[0], q_stride_h = q->strides[1], k_stride_n, k_stride_h, - v_stride_n, v_stride_h; + int64_t num_qo_heads = q.size(1); + int64_t head_dim_qk = q.size(2); + int64_t num_kv_heads = (kv_layout == QKVLayout::kNHD) ? 
k.size(1) : k.size(0); + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), k_stride_n, k_stride_h, v_stride_n, + v_stride_h; if (kv_layout == QKVLayout::kNHD) { - k_stride_n = k->strides[0]; - k_stride_h = k->strides[1]; - v_stride_n = v->strides[0]; - v_stride_h = v->strides[1]; + k_stride_n = k.stride(0); + k_stride_h = k.stride(1); + v_stride_n = v.stride(0); + v_stride_h = v.stride(1); } else { - k_stride_h = k->strides[0]; - k_stride_n = k->strides[1]; - v_stride_h = v->strides[0]; - v_stride_n = v->strides[1]; + k_stride_h = k.stride(0); + k_stride_n = k.stride(1); + v_stride_h = v.stride(0); + v_stride_n = v.stride(1); } if (maybe_lse.has_value()) { const auto& lse = *maybe_lse; - TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]); - TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); } - void* float_buffer_ptr = float_workspace_buffer->data; - void* int_buffer_ptr = int_workspace_buffer->data; + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); const MaskMode mask_mode = static_cast(mask_mode_code); - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, @@ -122,13 +122,14 @@ void BatchPrefillWithRaggedKVCacheRun(TensorView float_workspace_buffer, RaggedParams, PagedParams, [&] { RaggedParams params; - params.q = static_cast(q->data); - params.k = static_cast(k->data); - params.v = static_cast(v->data); - params.o = static_cast(o->data); - params.lse = maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; - params.q_indptr = static_cast(qo_indptr->data); - params.kv_indptr = static_cast(kv_indptr->data); + params.q = static_cast(q.data_ptr()); + params.k = static_cast(k.data_ptr()); + params.v = static_cast(v.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.lse = + maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.kv_indptr = static_cast(kv_indptr.data_ptr()); params.num_qo_heads = num_qo_heads; params.num_kv_heads = num_kv_heads; params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); @@ -210,43 +211,43 @@ void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer, PrefillPlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout = static_cast(layout); - int64_t batch_size = paged_kv_indptr->shape[0] - 1; - int64_t num_qo_heads = q->shape[1]; + int64_t batch_size = paged_kv_indptr.size(0) - 1; + int64_t num_qo_heads = q.size(1); int64_t num_kv_heads, page_size; - uint32_t head_dim_qk = q->shape[2]; + uint32_t head_dim_qk = q.size(2); if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->shape[1]; - page_size = paged_k_cache->shape[2]; + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size = paged_k_cache->shape[1]; - num_kv_heads = paged_k_cache->shape[2]; + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } if (maybe_lse) { const auto& lse = *maybe_lse; - TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]); - TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); } - void* float_buffer_ptr = static_cast(float_workspace_buffer->data); - void* int_buffer_ptr = static_cast(int_workspace_buffer->data); + void* float_buffer_ptr = static_cast(float_workspace_buffer.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer.data_ptr()); const MaskMode mask_mode = static_cast(mask_mode_code); // get q_stride_n and q_stride_h - const auto q_stride_n = q->strides[0]; - const auto q_stride_h = q->strides[1]; + const auto q_stride_n = q.stride(0); + const auto q_stride_h = q.stride(1); // get kv_cache_strides const int64_t* kv_cache_strides = paged_k_cache.strides().data(); - TVM_FFI_ICHECK_EQ(paged_k_cache->ndim, paged_v_cache->ndim); - for (int i = 0; i < paged_k_cache->ndim; ++i) { - TVM_FFI_ICHECK_EQ(paged_k_cache->strides[i], paged_v_cache->strides[i]) + TVM_FFI_ICHECK_EQ(paged_k_cache.ndim(), paged_v_cache.ndim()); + for (int i = 0; i < paged_k_cache.ndim(); ++i) { + TVM_FFI_ICHECK_EQ(paged_k_cache.stride(i), paged_v_cache.stride(i)) << "k/v strides differs at " << i; } - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, @@ -254,18 +255,19 @@ void BatchPrefillWithPagedKVCacheRun(TensorView float_workspace_buffer, RaggedParams, PagedParams, [&] { PagedParams params; - params.q = static_cast(q->data); + params.q = static_cast(q.data_ptr()); paged_kv_t paged_kv( num_kv_heads, page_size, HEAD_DIM_VO, batch_size, kv_layout, - static_cast(paged_k_cache->data), static_cast(paged_v_cache->data), - kv_cache_strides, static_cast(paged_kv_indices->data), - static_cast(paged_kv_indptr->data), - static_cast(paged_kv_last_page_len->data)); + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), kv_cache_strides, + static_cast(paged_kv_indices.data_ptr()), + static_cast(paged_kv_indptr.data_ptr()), + 
static_cast(paged_kv_last_page_len.data_ptr())); params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr->data); - params.o = static_cast(o->data); + params.q_indptr = static_cast(qo_indptr.data_ptr()); + params.o = static_cast(o.data_ptr()); - params.lse = maybe_lse ? static_cast(maybe_lse.value()->data) : nullptr; + params.lse = maybe_lse ? static_cast(maybe_lse.value().data_ptr()) : nullptr; params.num_qo_heads = num_qo_heads; params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n; diff --git a/csrc/batch_prefill_fp8_sm90.cu b/csrc/batch_prefill_fp8_sm90.cu index 5e221fa2ff..7c8680dc0b 100644 --- a/csrc/batch_prefill_fp8_sm90.cu +++ b/csrc/batch_prefill_fp8_sm90.cu @@ -44,21 +44,22 @@ Array BatchPrefillWithKVCacheSM90Plan( bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); flashinfer::PrefillPlanSM90Info plan_info; - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = PrefillSM90Plan( - float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data, - page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info, - static_cast(qo_indptr->data), static_cast(kv_indptr->data), - static_cast(kv_len_arr->data), total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim_qk, head_dim_vo, page_size, causal, enable_cuda_graph, + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), static_cast(kv_len_arr.data_ptr()), + total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, + causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TVM_FFI_ICHECK(status == cudaSuccess) @@ -92,26 +93,26 @@ void BatchPrefillWithPagedKVCacheSM90Run( if (maybe_lse.has_value()) { const auto& lse = maybe_lse.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]); - TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); } QKVLayout kv_layout = static_cast(layout); int64_t num_kv_heads, page_size; - int64_t head_dim_qk = q->shape[2]; - int64_t head_dim_vo = paged_v_cache->shape[3]; + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = paged_v_cache.size(3); if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->shape[1]; - page_size = paged_k_cache->shape[2]; + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size = paged_k_cache->shape[1]; - num_kv_heads = paged_k_cache->shape[2]; + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } - void* float_buffer_ptr = float_workspace_buffer->data; - void* int_buffer_ptr = int_workspace_buffer->data; + void* 
float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); const MaskMode mask_mode = static_cast(mask_mode_code); bool use_swa = window_left != -1; @@ -120,30 +121,30 @@ void BatchPrefillWithPagedKVCacheSM90Run( USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { PagedParams params; - params.q_ptr = static_cast(q->data); - params.k_ptr = static_cast(paged_k_cache->data); - params.v_ptr = static_cast(paged_v_cache->data); - params.o_ptr = static_cast(o->data); - params.lse_ptr = maybe_lse ? static_cast(maybe_lse.value()->data) : nullptr; - params.q_stride_n = q->strides[0]; - params.q_stride_h = q->strides[1]; - params.o_stride_n = o->strides[0]; - params.o_stride_h = o->strides[1]; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); if (kv_layout == QKVLayout::kNHD) { // (num_pages, page_size, num_heads, head_dim) - params.k_stride_n = paged_k_cache->strides[1]; - params.k_stride_h = paged_k_cache->strides[2]; - params.v_stride_n = paged_v_cache->strides[1]; - params.v_stride_h = paged_v_cache->strides[2]; + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); } else { // (num_pages, num_heads, page_size, head_dim) - params.k_stride_h = paged_k_cache->strides[1]; - params.k_stride_n = paged_k_cache->strides[2]; - params.v_stride_h = paged_v_cache->strides[1]; - params.v_stride_n = paged_v_cache->strides[2]; + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); } - params.nnz_qo = q->shape[0]; - params.num_qo_heads = q->shape[1]; + params.nnz_qo = q.size(0); + params.num_qo_heads = q.size(1); params.num_kv_heads = num_kv_heads; params.group_size = params.num_qo_heads / num_kv_heads; params.page_size = page_size; @@ -161,7 +162,7 @@ void BatchPrefillWithPagedKVCacheSM90Run( GetPtrFromBaseOffset(int_buffer_ptr, plan_info.head_indices_offset); params.work_indptr = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); - params.kv_indices = static_cast(paged_kv_indices->data); + params.kv_indices = static_cast(paged_kv_indices.data_ptr()); ADDITIONAL_PARAMS_SETTER diff --git a/csrc/batch_prefill_sm90.cu b/csrc/batch_prefill_sm90.cu index a43060eba3..1cf78bab59 100644 --- a/csrc/batch_prefill_sm90.cu +++ b/csrc/batch_prefill_sm90.cu @@ -50,21 +50,22 @@ Array BatchPrefillWithKVCacheSM90Plan( bool enable_cuda_graph, int64_t head_dim_qk, int64_t head_dim_vo, bool causal, int64_t window_left) { size_t float_workspace_size_in_bytes = - float_workspace_buffer->shape[0] * get_element_size(float_workspace_buffer); + float_workspace_buffer.size(0) * get_element_size(float_workspace_buffer); size_t int_workspace_size_in_bytes = - 
int_workspace_buffer->shape[0] * get_element_size(int_workspace_buffer); + int_workspace_buffer.size(0) * get_element_size(int_workspace_buffer); flashinfer::PrefillPlanSM90Info plan_info; - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); cudaError_t status = PrefillSM90Plan( - float_workspace_buffer->data, float_workspace_size_in_bytes, int_workspace_buffer->data, - page_locked_int_workspace_buffer->data, int_workspace_size_in_bytes, plan_info, - static_cast(qo_indptr->data), static_cast(kv_indptr->data), - static_cast(kv_len_arr->data), total_num_rows, batch_size, num_qo_heads, - num_kv_heads, head_dim_qk, head_dim_vo, page_size, causal, enable_cuda_graph, + float_workspace_buffer.data_ptr(), float_workspace_size_in_bytes, + int_workspace_buffer.data_ptr(), page_locked_int_workspace_buffer.data_ptr(), + int_workspace_size_in_bytes, plan_info, static_cast(qo_indptr.data_ptr()), + static_cast(kv_indptr.data_ptr()), static_cast(kv_len_arr.data_ptr()), + total_num_rows, batch_size, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, page_size, + causal, enable_cuda_graph, /*sizeof_dtype_o=*/2, stream); TVM_FFI_ICHECK(status == cudaSuccess) @@ -84,20 +85,20 @@ void BatchPrefillWithRaggedKVCacheSM90Run( if (maybe_lse) { const auto& lse = *maybe_lse; - TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]); - TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); } - void* float_buffer_ptr = float_workspace_buffer->data; - void* int_buffer_ptr = int_workspace_buffer->data; + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - int64_t head_dim_qk = q->shape[2]; - int64_t head_dim_vo = v->shape[2]; + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = v.size(2); QKVLayout kv_layout = static_cast(layout); - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); const MaskMode mask_mode = static_cast(mask_mode_code); bool use_swa = window_left != -1; @@ -106,30 +107,30 @@ void BatchPrefillWithRaggedKVCacheSM90Run( USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { RaggedParams params; - params.q_ptr = static_cast(q->data); - params.k_ptr = static_cast(k->data); - params.v_ptr = static_cast(v->data); - params.o_ptr = static_cast(o->data); - params.lse_ptr = maybe_lse ? static_cast(maybe_lse.value()->data) : nullptr; - params.q_stride_n = q->strides[0]; - params.q_stride_h = q->strides[1]; - params.o_stride_n = o->strides[0]; - params.o_stride_h = o->strides[1]; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = k->strides[0]; - params.k_stride_h = k->strides[1]; - params.v_stride_n = v->strides[0]; - params.v_stride_h = v->strides[1]; + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); } else { - params.k_stride_h = k->strides[0]; - params.k_stride_n = k->strides[1]; - params.v_stride_h = v->strides[0]; - params.v_stride_n = v->strides[1]; + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); } - params.nnz_qo = q->shape[0]; - params.nnz_kv = k->shape[0]; - params.num_qo_heads = q->shape[1]; - params.num_kv_heads = k->shape[1]; + params.nnz_qo = q.size(0); + params.nnz_kv = k.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); params.group_size = params.num_qo_heads / params.num_kv_heads; params.window_left = window_left; params.causal = mask_mode_code == 1; @@ -174,26 +175,26 @@ void BatchPrefillWithPagedKVCacheSM90Run( if (maybe_lse) { const auto& lse = *maybe_lse; - TVM_FFI_ICHECK_EQ(lse->shape[0], q->shape[0]); - TVM_FFI_ICHECK_EQ(lse->shape[1], q->shape[1]); + TVM_FFI_ICHECK_EQ(lse.size(0), q.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q.size(1)); } QKVLayout kv_layout = static_cast(layout); int64_t num_kv_heads, page_size; - int64_t head_dim_qk = q->shape[2]; - int64_t head_dim_vo = paged_v_cache->shape[3]; + int64_t head_dim_qk = q.size(2); + int64_t head_dim_vo = paged_v_cache.size(3); if (kv_layout == QKVLayout::kHND) { - num_kv_heads = paged_k_cache->shape[1]; - page_size = paged_k_cache->shape[2]; + num_kv_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size = paged_k_cache->shape[1]; - num_kv_heads = paged_k_cache->shape[2]; + page_size = paged_k_cache.size(1); + num_kv_heads = paged_k_cache.size(2); } - void* float_buffer_ptr = float_workspace_buffer->data; - void* int_buffer_ptr = int_workspace_buffer->data; + void* float_buffer_ptr = float_workspace_buffer.data_ptr(); + void* int_buffer_ptr = int_workspace_buffer.data_ptr(); - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); const MaskMode mask_mode = static_cast(mask_mode_code); bool use_swa = window_left != -1; @@ -202,30 +203,30 @@ void BatchPrefillWithPagedKVCacheSM90Run( USE_LOGITS_SOFT_CAP, AttentionVariant, RaggedParams, PagedParams, [&] { PagedParams params; - params.q_ptr = static_cast(q->data); - params.k_ptr = static_cast(paged_k_cache->data); - params.v_ptr = static_cast(paged_v_cache->data); - params.o_ptr = static_cast(o->data); - params.lse_ptr = maybe_lse ? static_cast(maybe_lse.value()->data) : nullptr; - params.q_stride_n = q->strides[0]; - params.q_stride_h = q->strides[1]; - params.o_stride_n = o->strides[0]; - params.o_stride_h = o->strides[1]; + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(paged_k_cache.data_ptr()); + params.v_ptr = static_cast(paged_v_cache.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); + params.lse_ptr = maybe_lse ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); if (kv_layout == QKVLayout::kNHD) { // (num_pages, page_size, num_heads, head_dim) - params.k_stride_n = paged_k_cache->strides[1]; - params.k_stride_h = paged_k_cache->strides[2]; - params.v_stride_n = paged_v_cache->strides[1]; - params.v_stride_h = paged_v_cache->strides[2]; + params.k_stride_n = paged_k_cache.stride(1); + params.k_stride_h = paged_k_cache.stride(2); + params.v_stride_n = paged_v_cache.stride(1); + params.v_stride_h = paged_v_cache.stride(2); } else { // (num_pages, num_heads, page_size, head_dim) - params.k_stride_h = paged_k_cache->strides[1]; - params.k_stride_n = paged_k_cache->strides[2]; - params.v_stride_h = paged_v_cache->strides[1]; - params.v_stride_n = paged_v_cache->strides[2]; + params.k_stride_h = paged_k_cache.stride(1); + params.k_stride_n = paged_k_cache.stride(2); + params.v_stride_h = paged_v_cache.stride(1); + params.v_stride_n = paged_v_cache.stride(2); } - params.nnz_qo = q->shape[0]; - params.num_qo_heads = q->shape[1]; + params.nnz_qo = q.size(0); + params.num_qo_heads = q.size(1); params.num_kv_heads = num_kv_heads; params.group_size = params.num_qo_heads / num_kv_heads; params.page_size = page_size; @@ -243,7 +244,7 @@ void BatchPrefillWithPagedKVCacheSM90Run( GetPtrFromBaseOffset(int_buffer_ptr, plan_info.work_indptr_offset); params.batch_indices = GetPtrFromBaseOffset(int_buffer_ptr, plan_info.batch_indices_offset); - params.kv_indices = static_cast(paged_kv_indices->data); + params.kv_indices = static_cast(paged_kv_indices.data_ptr()); ADDITIONAL_PARAMS_SETTER diff --git a/csrc/blackwell_fmha_plan.cu b/csrc/blackwell_fmha_plan.cu index f976ce5b0b..ef9b1475ea 100644 --- a/csrc/blackwell_fmha_plan.cu +++ b/csrc/blackwell_fmha_plan.cu @@ -21,16 +21,17 @@ void blackwell_fmha_plan(TensorView qo_segment_offsets, TensorView kv_segment_of TensorView work_indptr, TensorView qo_tile_indices, TensorView head_indices, TensorView batch_indices, int64_t qo_tile_size, int64_t num_heads, int64_t num_buckets, bool causal) { - cudaSetDevice(qo_segment_offsets->device.device_id); - const cudaStream_t stream = get_stream(qo_tile_indices->device); - int batch_size = qo_segment_offsets->shape[0] - 1; + cudaSetDevice(qo_segment_offsets.device().device_id); + const cudaStream_t stream = get_stream(qo_tile_indices.device()); + int batch_size = qo_segment_offsets.size(0) - 1; auto status = flashinfer::plan_kernel_wrapper( - static_cast(qo_segment_offsets->data), static_cast(kv_segment_offsets->data), + static_cast(qo_segment_offsets.data_ptr()), + static_cast(kv_segment_offsets.data_ptr()), /*qo_lens=*/nullptr, - /*kv_lens=*/nullptr, static_cast(work_indptr->data), - static_cast(qo_tile_indices->data), static_cast(head_indices->data), - static_cast(batch_indices->data), qo_tile_size, batch_size, num_heads, num_buckets, + /*kv_lens=*/nullptr, static_cast(work_indptr.data_ptr()), + static_cast(qo_tile_indices.data_ptr()), static_cast(head_indices.data_ptr()), + static_cast(batch_indices.data_ptr()), qo_tile_size, batch_size, num_heads, num_buckets, causal, /*enable_pdl=*/true, stream); TVM_FFI_ICHECK_EQ(status, cudaSuccess) << "Failed to plan blackwell fmha" << cudaGetErrorString(status); diff --git a/csrc/bmm_fp8.cu b/csrc/bmm_fp8.cu index 2709191316..ea8417b617 100644 --- a/csrc/bmm_fp8.cu +++ b/csrc/bmm_fp8.cu @@ -28,32 +28,32 @@ void bmm_fp8(TensorView A, TensorView B, 
TensorView D, TensorView A_scale, Tenso CHECK_DIM(3, A); CHECK_DIM(3, B); CHECK_DIM(3, D); - TVM_FFI_ICHECK(A->shape[0] == B->shape[0] && A->shape[0] == D->shape[0]) - << "Batch sizes must match"; - TVM_FFI_ICHECK(A->shape[2] == B->shape[1]) << "Incompatible matrix sizes"; - TVM_FFI_ICHECK(A->shape[1] == D->shape[1] && B->shape[2] == D->shape[2]) + TVM_FFI_ICHECK(A.size(0) == B.size(0) && A.size(0) == D.size(0)) << "Batch sizes must match"; + TVM_FFI_ICHECK(A.size(2) == B.size(1)) << "Incompatible matrix sizes"; + TVM_FFI_ICHECK(A.size(1) == D.size(1) && B.size(2) == D.size(2)) << "Result tensor has incorrect shape"; // PyTorch is row major by default. cuBLASLt is column major by default. // We need row major D as expected. // A ^ T * B = D, so D ^ T = B ^ T * A - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B->dtype, b_type, [&] { - return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A->dtype, a_type, [&] { - return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D->dtype, d_type, [&] { - auto batch_size = A->shape[0]; - auto m = A->shape[1]; - auto k = A->shape[2]; - auto n = B->shape[2]; + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(B.dtype(), b_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(A.dtype(), a_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(D.dtype(), d_type, [&] { + auto batch_size = A.size(0); + auto m = A.size(1); + auto k = A.size(2); + auto n = B.size(2); auto lt_handle = reinterpret_cast(cublas_handle); - cudaSetDevice(A->device.device_id); - auto stream = get_stream(A->device); + cudaSetDevice(A.device().device_id); + auto stream = get_stream(A.device()); auto status = flashinfer::bmm_fp8::bmm_fp8_internal_cublaslt( - workspace_buffer->data, workspace_buffer.numel(), static_cast(B->data), - static_cast(A->data), static_cast(D->data), batch_size, n, m, k, - static_cast(B_scale->data), static_cast(A_scale->data), lt_handle, - stream); + workspace_buffer.data_ptr(), workspace_buffer.numel(), + static_cast(B.data_ptr()), static_cast(A.data_ptr()), + static_cast(D.data_ptr()), batch_size, n, m, k, + static_cast(B_scale.data_ptr()), static_cast(A_scale.data_ptr()), + lt_handle, stream); TVM_FFI_ICHECK(status == CUBLAS_STATUS_SUCCESS) << "bmm_fp8_internal_cublaslt failed: " << cublasGetStatusString(status); return true; diff --git a/csrc/cascade.cu b/csrc/cascade.cu index 186a5c113b..98e4a590dc 100644 --- a/csrc/cascade.cu +++ b/csrc/cascade.cu @@ -35,21 +35,21 @@ void merge_state(TensorView v_a, TensorView s_a, TensorView v_b, TensorView s_b, CHECK_DIM(2, s_b); CHECK_SHAPE(v_a, v_b); CHECK_SHAPE(s_a, s_b); - TVM_FFI_ICHECK_EQ(v_a->shape[0], s_a->shape[0]); - TVM_FFI_ICHECK_EQ(v_a->shape[1], s_b->shape[1]); - unsigned int seq_len = v_a->shape[0]; - unsigned int num_heads = v_a->shape[1]; - unsigned int head_dim = v_a->shape[2]; + TVM_FFI_ICHECK_EQ(v_a.size(0), s_a.size(0)); + TVM_FFI_ICHECK_EQ(v_a.size(1), s_b.size(1)); + unsigned int seq_len = v_a.size(0); + unsigned int num_heads = v_a.size(1); + unsigned int head_dim = v_a.size(2); - cudaSetDevice(v_a->device.device_id); - auto stream = get_stream(v_a->device); + cudaSetDevice(v_a.device().device_id); + auto stream = get_stream(v_a.device()); - bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v_a->dtype, c_type, [&] { + bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v_a.dtype(), c_type, [&] { cudaError_t status = - MergeState(static_cast(v_a->data), static_cast(s_a->data), - static_cast(v_b->data), static_cast(s_b->data), - static_cast(v_merged->data), static_cast(s_merged->data), - seq_len, num_heads, head_dim, stream); + 
MergeState(static_cast(v_a.data_ptr()), static_cast(s_a.data_ptr()), + static_cast(v_b.data_ptr()), static_cast(s_b.data_ptr()), + static_cast(v_merged.data_ptr()), + static_cast(s_merged.data_ptr()), seq_len, num_heads, head_dim, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "MergeState kernel launch failed: " << cudaGetErrorString(status); return true; @@ -72,26 +72,26 @@ void merge_state_in_place(TensorView v, TensorView s, TensorView v_other, Tensor CHECK_DIM(2, s_other); CHECK_SHAPE(v, v_other); CHECK_SHAPE(s, s_other); - TVM_FFI_ICHECK_EQ(v->shape[0], s->shape[0]); - TVM_FFI_ICHECK_EQ(v->shape[1], s->shape[1]); + TVM_FFI_ICHECK_EQ(v.size(0), s.size(0)); + TVM_FFI_ICHECK_EQ(v.size(1), s.size(1)); uint8_t* mask_ptr = nullptr; if (mask.has_value()) { CHECK_DIM(1, mask.value()); - TVM_FFI_ICHECK_EQ(v->shape[0], mask.value()->shape[0]); + TVM_FFI_ICHECK_EQ(v.size(0), mask.value().size(0)); CHECK_DEVICE(mask.value(), v); - mask_ptr = static_cast(mask.value()->data); + mask_ptr = static_cast(mask.value().data_ptr()); } - unsigned int seq_len = v->shape[0]; - unsigned int num_heads = v->shape[1]; - unsigned int head_dim = v->shape[2]; + unsigned int seq_len = v.size(0); + unsigned int num_heads = v.size(1); + unsigned int head_dim = v.size(2); - cudaSetDevice(v->device.device_id); - auto stream = get_stream(v->device); - bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v->dtype, c_type, [&] { - cudaError_t status = - MergeStateInPlace(static_cast(v->data), static_cast(s->data), - static_cast(v_other->data), static_cast(s_other->data), - seq_len, num_heads, head_dim, mask_ptr, stream); + cudaSetDevice(v.device().device_id); + auto stream = get_stream(v.device()); + bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v.dtype(), c_type, [&] { + cudaError_t status = MergeStateInPlace( + static_cast(v.data_ptr()), static_cast(s.data_ptr()), + static_cast(v_other.data_ptr()), static_cast(s_other.data_ptr()), seq_len, + num_heads, head_dim, mask_ptr, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "MergeStateInPlace kernel launch failed: " << cudaGetErrorString(status); return true; @@ -106,21 +106,21 @@ void merge_states(TensorView v, TensorView s, TensorView v_merged, TensorView s_ CHECK_DEVICE(s, v); CHECK_DIM(4, v); CHECK_DIM(3, s); - TVM_FFI_ICHECK_EQ(v->shape[0], s->shape[0]); - TVM_FFI_ICHECK_EQ(v->shape[1], s->shape[1]); - TVM_FFI_ICHECK_EQ(v->shape[2], s->shape[2]); - unsigned int seq_len = v->shape[0]; - unsigned int num_index_sets = v->shape[1]; - unsigned int num_heads = v->shape[2]; - unsigned int head_dim = v->shape[3]; + TVM_FFI_ICHECK_EQ(v.size(0), s.size(0)); + TVM_FFI_ICHECK_EQ(v.size(1), s.size(1)); + TVM_FFI_ICHECK_EQ(v.size(2), s.size(2)); + unsigned int seq_len = v.size(0); + unsigned int num_index_sets = v.size(1); + unsigned int num_heads = v.size(2); + unsigned int head_dim = v.size(3); - cudaSetDevice(v->device.device_id); - auto stream = get_stream(v->device); - bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v->dtype, c_type, [&] { - cudaError_t status = - MergeStates(static_cast(v->data), static_cast(s->data), - static_cast(v_merged->data), static_cast(s_merged->data), - num_index_sets, seq_len, num_heads, head_dim, stream); + cudaSetDevice(v.device().device_id); + auto stream = get_stream(v.device()); + bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(v.dtype(), c_type, [&] { + cudaError_t status = MergeStates( + static_cast(v.data_ptr()), static_cast(s.data_ptr()), + static_cast(v_merged.data_ptr()), static_cast(s_merged.data_ptr()), + num_index_sets, 
seq_len, num_heads, head_dim, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "MergeStates kernel launch failed: " << cudaGetErrorString(status); return true; diff --git a/csrc/cudnn_sdpa_kernel_launcher.cu b/csrc/cudnn_sdpa_kernel_launcher.cu index c1ebdafd66..12f7ac1dad 100644 --- a/csrc/cudnn_sdpa_kernel_launcher.cu +++ b/csrc/cudnn_sdpa_kernel_launcher.cu @@ -341,8 +341,8 @@ static void create_packed_tma_desc_kv_prefill(int b, int32_t* actual_seq_lens_kv std::array packed_tensor_stride_v = {h_kv * d_vo * BYTES_PER_ELEMENT, d_vo * BYTES_PER_ELEMENT, 0}; - uint16_t* k_ptr = reinterpret_cast(k->data + batch_offset_k); - uint16_t* v_ptr = reinterpret_cast(v->data + batch_offset_v); + uint16_t* k_ptr = reinterpret_cast(k.data_ptr() + batch_offset_k); + uint16_t* v_ptr = reinterpret_cast(v.data_ptr() + batch_offset_v); tma::cudaSetTmaTileDescriptor( &packed_tma_desc_k[i], (void*)k_ptr, DIMS_QKV, packed_tensor_size_k.data(), @@ -383,8 +383,8 @@ static void create_packed_tma_desc_qo_prefill(int b, int32_t* actual_seq_lens_q_ std::array packed_tensor_stride_o = {h_qo * d_vo * BYTES_PER_ELEMENT, d_vo * BYTES_PER_ELEMENT, 0}; - uint16_t* q_ptr = reinterpret_cast(q->data + batch_offset_q); - uint16_t* out_ptr = reinterpret_cast(out->data + batch_offset_o); + uint16_t* q_ptr = reinterpret_cast(q.data_ptr() + batch_offset_q); + uint16_t* out_ptr = reinterpret_cast(out.data_ptr() + batch_offset_o); tma::cudaSetTmaTileDescriptor( &packed_tma_desc_q[i], (void*)q_ptr, DIMS_QKV, packed_tensor_size_q.data(), @@ -528,7 +528,7 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView constexpr int32_t NUM_THREADS = 512; - const CUstream stream = get_stream(q->device); + const CUstream stream = get_stream(q.device()); int64_t* batch_offset_q_array_data = nullptr; int64_t* batch_offset_o_array_data = nullptr; @@ -537,16 +537,16 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView int64_t* batch_offset_array_data = nullptr; if (batch_offset_q_array.has_value()) { batch_offset_array_data = static_cast( - batch_offset_q_array.value()->data); // Fix this to make it operational later + batch_offset_q_array.value().data_ptr()); // Fix this to make it operational later } // Step 1: Setup the kernel pointer static CUfunction prefill_func[KERNEL_NUM_PREFILL_TYPES] = {nullptr, nullptr, nullptr, nullptr}; - int64_t d_qk = q->shape[2]; + int64_t d_qk = q.size(2); - int64_t d_vo = v_cache->ndim == 3 ? v_cache->shape[2] : v_cache->shape[3]; + int64_t d_vo = v_cache.ndim() == 3 ? v_cache.size(2) : v_cache.size(3); if (prefill_func[0] == nullptr) { setup_prefill(prefill_func); @@ -565,17 +565,17 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView // Step 2: Extract attention descriptor - int64_t h_qo = q->shape[1]; + int64_t h_qo = q.size(1); - int64_t h_kv = k_cache->shape[1]; + int64_t h_kv = k_cache.size(1); - int64_t page_size = k_cache->ndim == 4 ? k_cache->shape[2] : 1; + int64_t page_size = k_cache.ndim() == 4 ? k_cache.size(2) : 1; int64_t s_kv = max_s_kv; int64_t num_pages_per_seq = static_cast(std::ceil(1.0 * s_kv / page_size)); - int64_t total_num_pages = k_cache->ndim == 4 ? k_cache->shape[0] : 1; + int64_t total_num_pages = k_cache.ndim() == 4 ? k_cache.size(0) : 1; bool kv_cache_enabled = d_qk == 192 ? 
false : true; @@ -598,8 +598,8 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView if (is_cuda_graph_compatible == false) { CHECK_CPU(actual_seq_lens_q); CHECK_CPU(actual_seq_lens_kv); - auto actual_seq_lens_q_data = static_cast(actual_seq_lens_q->data); - auto actual_seq_lens_kv_data = static_cast(actual_seq_lens_kv->data); + auto actual_seq_lens_q_data = static_cast(actual_seq_lens_q.data_ptr()); + auto actual_seq_lens_kv_data = static_cast(actual_seq_lens_kv.data_ptr()); uint32_t actual_num_tiles_per_head = std::transform_reduce( actual_seq_lens_q_data, actual_seq_lens_q_data + b, 0U, std::plus<>(), [](int32_t seq_len) { @@ -623,7 +623,7 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView auto k_strides = k_cache.strides(); auto v_strides = v_cache.strides(); - bool is_kv_ragged = k_cache->ndim == 3; + bool is_kv_ragged = k_cache.ndim() == 3; std::array tensor_traversal_stride_qkv = {1, 1, 1, 1}; std::array tensor_size_k = {d_qk, page_size, h_kv, total_num_pages}; @@ -642,7 +642,7 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView 64, kv_cache_enabled ? std::min(TILE_N_1, page_size) : TILE_N_1, 1, 1}; uint64_t batch_offset_qo = 0; - int8_t* workspace_start = static_cast(workspace_buffer->data); + int8_t* workspace_start = static_cast(workspace_buffer.data_ptr()); // These tensors are allocated in the workspace buffer // Using 2 * b for q and o @@ -664,23 +664,23 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView if (is_cuda_graph_compatible == false) { if (is_kv_ragged) { - auto actual_seq_lens_kv_data = static_cast(actual_seq_lens_kv->data); + auto actual_seq_lens_kv_data = static_cast(actual_seq_lens_kv.data_ptr()); create_packed_tma_desc_kv_prefill( b, actual_seq_lens_kv_data, d_qk, d_vo, h_kv, tensor_traversal_stride_qkv.data(), tensor_box_size_k.data(), tma_desc_k_host, tma_desc_v_host, k_cache, v_cache); } else { // tma descriptors for k and v - tma::cudaSetTmaTileDescriptor(tma_desc_k_host, k_cache->data, DIMS_QKV, tensor_size_k.data(), - tensor_stride_k.data(), tensor_traversal_stride_qkv.data(), - tensor_box_size_k.data(), tma::cudaTmaDescFormat::BF16_RN, - tma::cudaTmaDescSwizzle::SWIZZLE_128B); - - tma::cudaSetTmaTileDescriptor(tma_desc_v_host, v_cache->data, DIMS_QKV, tensor_size_v.data(), - tensor_stride_v.data(), tensor_traversal_stride_qkv.data(), - tensor_box_size_v.data(), tma::cudaTmaDescFormat::BF16_RN, - tma::cudaTmaDescSwizzle::SWIZZLE_128B); + tma::cudaSetTmaTileDescriptor( + tma_desc_k_host, k_cache.data_ptr(), DIMS_QKV, tensor_size_k.data(), + tensor_stride_k.data(), tensor_traversal_stride_qkv.data(), tensor_box_size_k.data(), + tma::cudaTmaDescFormat::BF16_RN, tma::cudaTmaDescSwizzle::SWIZZLE_128B); + + tma::cudaSetTmaTileDescriptor( + tma_desc_v_host, v_cache.data_ptr(), DIMS_QKV, tensor_size_v.data(), + tensor_stride_v.data(), tensor_traversal_stride_qkv.data(), tensor_box_size_v.data(), + tma::cudaTmaDescFormat::BF16_RN, tma::cudaTmaDescSwizzle::SWIZZLE_128B); } - auto actual_seq_lens_q_data = static_cast(actual_seq_lens_q->data); + auto actual_seq_lens_q_data = static_cast(actual_seq_lens_q.data_ptr()); create_packed_tma_desc_qo_prefill(b, actual_seq_lens_q_data, d_qk, d_vo, h_qo, tensor_traversal_stride_qkv.data(), tensor_box_size_q.data(), packed_tma_desc_q, packed_tma_desc_o, q, out, @@ -692,7 +692,7 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView dim3 grid(1, 1, 1); dim3 block(128, 1, 1); - 
cudaStream_t raw_stream = get_stream(q->device); + cudaStream_t raw_stream = get_stream(q.device()); cudaError_t err = cudaStreamQuery(raw_stream); if (!(err == cudaSuccess || err == cudaErrorNotReady)) { @@ -700,11 +700,12 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView } qkv_tma_setup_prefill<<>>( - b, h_qo, h_kv, d_qk, d_vo, is_kv_ragged, page_size, total_num_pages, k_cache->strides[2], - k_cache->strides[1], k_cache->strides[0], v_cache->strides[2], v_cache->strides[1], - v_cache->strides[0], static_cast(actual_seq_lens_q_gpu->data), - static_cast(actual_seq_lens_kv_gpu->data), q->data, k_cache->data, v_cache->data, - out->data, packed_tma_desc_q_dev, tma_desc_k, tma_desc_v, packed_tma_desc_o_dev); + b, h_qo, h_kv, d_qk, d_vo, is_kv_ragged, page_size, total_num_pages, k_cache.stride(2), + k_cache.stride(1), k_cache.stride(0), v_cache.stride(2), v_cache.stride(1), + v_cache.stride(0), static_cast(actual_seq_lens_q_gpu.data_ptr()), + static_cast(actual_seq_lens_kv_gpu.data_ptr()), q.data_ptr(), k_cache.data_ptr(), + v_cache.data_ptr(), out.data_ptr(), packed_tma_desc_q_dev, tma_desc_k, tma_desc_v, + packed_tma_desc_o_dev); } cudnn_sdpa::AttentionDescriptor_t attn_desc{ @@ -722,11 +723,11 @@ void prefill(int64_t b, int64_t s_qo, int64_t max_s_kv, TensorView q, TensorView uint32_t page_size32 = static_cast(page_size); uint32_t num_pages_per_seq32 = static_cast(num_pages_per_seq); - void* lse_tensor_pointer = return_lse ? lse->data : NULL; + void* lse_tensor_pointer = return_lse ? lse.data_ptr() : NULL; - void* actual_seq_lens_q_gpu_pointer = static_cast(actual_seq_lens_q_gpu->data); - void* actual_seq_lens_kv_gpu_pointer = static_cast(actual_seq_lens_kv_gpu->data); - void* block_tables_pointer = d_qk == 192 ? NULL : static_cast(block_tables->data); + void* actual_seq_lens_q_gpu_pointer = static_cast(actual_seq_lens_q_gpu.data_ptr()); + void* actual_seq_lens_kv_gpu_pointer = static_cast(actual_seq_lens_kv_gpu.data_ptr()); + void* block_tables_pointer = d_qk == 192 ? 
NULL : static_cast(block_tables.data_ptr()); auto print_cudaTmaDescTiled = [](tma::cudaTmaDescTiled* desc) { printf("addr %p", desc->tensor_common0); @@ -884,8 +885,8 @@ void setup_tma_desc_decode(int64_t b, int64_t s_kv, int64_t h_qo, int64_t h_kv, std::array tensor_size_partial_o = {d, split_factor, h_qo, b}; std::array tensor_stride_partial_o = { h_qo * d * b * sizeof(float), d * b * sizeof(float), d * h_qo * sizeof(float)}; - uint16_t* q_ptr = reinterpret_cast(q->data); - uint16_t* out_ptr = reinterpret_cast(out->data); + uint16_t* q_ptr = reinterpret_cast(q.data_ptr()); + uint16_t* out_ptr = reinterpret_cast(out.data_ptr()); float* partial_o_ptr = reinterpret_cast(partial_o_dev); int64_t batch_offset_qo = 0; @@ -907,12 +908,12 @@ void setup_tma_desc_decode(int64_t b, int64_t s_kv, int64_t h_qo, int64_t h_kv, batch_offset_qo += h_qo * d; } - tma::cudaSetTmaTileDescriptor(tma_desc_k, k_cache->data, DIMS_QKV, tensor_size_kv.data(), + tma::cudaSetTmaTileDescriptor(tma_desc_k, k_cache.data_ptr(), DIMS_QKV, tensor_size_kv.data(), tensor_stride_kv.data(), tensor_traversal_stride_qkv.data(), tensor_box_size_kv.data(), tma::cudaTmaDescFormat::BF16_RN, tma::cudaTmaDescSwizzle::SWIZZLE_128B); - tma::cudaSetTmaTileDescriptor(tma_desc_v, v_cache->data, DIMS_QKV, tensor_size_kv.data(), + tma::cudaSetTmaTileDescriptor(tma_desc_v, v_cache.data_ptr(), DIMS_QKV, tensor_size_kv.data(), tensor_stride_kv.data(), tensor_traversal_stride_qkv.data(), tensor_box_size_kv.data(), tma::cudaTmaDescFormat::BF16_RN, tma::cudaTmaDescSwizzle::SWIZZLE_128B); @@ -931,10 +932,10 @@ void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cac int64_t* batch_offset_q_array_data = nullptr; if (batch_offset_q_array.has_value()) { - batch_offset_q_array_data = static_cast(batch_offset_q_array.value()->data); + batch_offset_q_array_data = static_cast(batch_offset_q_array.value().data_ptr()); } - const CUstream stream = get_stream(q->device); + const CUstream stream = get_stream(q.device()); constexpr int NUM_DECODE_KERNELS = 5; static CUfunction hfunc_decode[NUM_DECODE_KERNELS] = {nullptr, nullptr, nullptr, nullptr, @@ -973,15 +974,15 @@ void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cac cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, device_id); } - int64_t b = q->shape[0]; - int64_t h_qo = q->shape[1]; - int64_t d = q->shape[2]; + int64_t b = q.size(0); + int64_t h_qo = q.size(1); + int64_t d = q.size(2); - int64_t h_kv = k_cache->shape[1]; + int64_t h_kv = k_cache.size(1); - int64_t page_size = k_cache->ndim == 4 ? k_cache->shape[2] : 1; + int64_t page_size = k_cache.ndim() == 4 ? k_cache.size(2) : 1; - int64_t total_num_pages = k_cache->ndim == 4 ? k_cache->shape[0] : 1; + int64_t total_num_pages = k_cache.ndim() == 4 ? 
k_cache.size(0) : 1; int64_t s_kv = max_s_kv; @@ -1019,7 +1020,7 @@ void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cac config.hStream = stream; config.numAttrs = 1; - int8_t* workspace_start = static_cast(workspace_buffer->data); + int8_t* workspace_start = static_cast(workspace_buffer.data_ptr()); int8_t* partial_o_dev = workspace_start; int8_t* tma_descriptor_start = partial_o_dev + (b * s_qo * h_qo * d * sizeof(float) * split_factor); @@ -1064,9 +1065,9 @@ void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cac qkv_tma_setup_decode<<>>( b, h_qo, h_kv, d, total_num_pages, page_size, split_factor, TILE_M_1, TILE_N_1, - kv_strides[2], kv_strides[1], kv_strides[0], q->data, k_cache->data, v_cache->data, - out->data, partial_o_dev, packed_tma_desc_q_dev, tma_desc_k_dev, tma_desc_v_dev, - packed_tma_desc_o_dev, packed_tma_desc_partial_o_dev, + kv_strides[2], kv_strides[1], kv_strides[0], q.data_ptr(), k_cache.data_ptr(), + v_cache.data_ptr(), out.data_ptr(), partial_o_dev, packed_tma_desc_q_dev, tma_desc_k_dev, + tma_desc_v_dev, packed_tma_desc_o_dev, packed_tma_desc_partial_o_dev, reinterpret_cast(batch_strides_dev)); } else { std::unique_ptr tma_desc_host(new tma::cudaTmaDesc[5 * b]); @@ -1105,8 +1106,8 @@ void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cac float attn_scale = scale; void* actual_seq_lens_q_gpu_pointer = nullptr; - void* actual_seq_lens_kv_gpu_pointer = static_cast(actual_seq_lens_kv_gpu->data); - void* block_tables_pointer = static_cast(block_tables->data); + void* actual_seq_lens_kv_gpu_pointer = static_cast(actual_seq_lens_kv_gpu.data_ptr()); + void* block_tables_pointer = static_cast(block_tables.data_ptr()); cudnn_sdpa::strides_t lse_strides = {h_qo, 1, h_qo, 1}; cudnn_sdpa::strides_t partial_lse_strides = {h_qo, 1, h_qo * b, 1}; @@ -1142,7 +1143,7 @@ void decode(int64_t max_s_kv, TensorView q, TensorView k_cache, TensorView v_cac if (split_factor > 1) { // TODO: Add support for split_factor > 1 void* args_lean_attn_reduction[11]; - void* o_dev = out->data; + void* o_dev = out.data_ptr(); void* lse_final_dev = nullptr; diff --git a/csrc/cutlass_mla.cu b/csrc/cutlass_mla.cu index ecc528de40..f68df30bea 100644 --- a/csrc/cutlass_mla.cu +++ b/csrc/cutlass_mla.cu @@ -23,21 +23,21 @@ using namespace flashinfer::attention; void CutlassMLAPagedAttention(ffi::TensorView workspace, ffi::TensorView out, ffi::TensorView lse, ffi::TensorView q_nope_pe, ffi::TensorView ckv_kpe_cache, ffi::TensorView kv_lens, ffi::TensorView page_table) { - cudaSetDevice(q_nope_pe->device.device_id); - const cudaStream_t stream = get_stream(q_nope_pe->device); + cudaSetDevice(q_nope_pe.device().device_id); + const cudaStream_t stream = get_stream(q_nope_pe.device()); - int device_index = q_nope_pe->device.device_id; - int batches = q_nope_pe->shape[0]; - int page_count_per_seq = page_table->shape[1]; - int page_count_total = ckv_kpe_cache->shape[0]; - int page_size = ckv_kpe_cache->shape[1]; + int device_index = q_nope_pe.device().device_id; + int batches = q_nope_pe.size(0); + int page_count_per_seq = page_table.size(1); + int page_count_total = ckv_kpe_cache.size(0); + int page_size = ckv_kpe_cache.size(1); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_nope_pe->dtype, c_type, [&] { + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_nope_pe.dtype(), c_type, [&] { using cutlass_t = cutlass_dtype_t; - auto status = - runMla(workspace->data, out->data, lse->data, q_nope_pe->data, - ckv_kpe_cache->data, kv_lens->data, page_table->data, 
batches, - page_count_per_seq, page_count_total, page_size, device_index, stream); + auto status = runMla( + workspace.data_ptr(), out.data_ptr(), lse.data_ptr(), q_nope_pe.data_ptr(), + ckv_kpe_cache.data_ptr(), kv_lens.data_ptr(), page_table.data_ptr(), batches, + page_count_per_seq, page_count_total, page_size, device_index, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to run CutlassMLAPagedAttention: " << cudaGetErrorString(status); diff --git a/csrc/fmha_cutlass_sm100.cu b/csrc/fmha_cutlass_sm100.cu index ac62f0fed6..c50116fa7f 100644 --- a/csrc/fmha_cutlass_sm100.cu +++ b/csrc/fmha_cutlass_sm100.cu @@ -82,22 +82,22 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff Optional maybe_lse, int64_t mask_mode_code, double sm_scale, int64_t num_qo_heads, int64_t num_kv_heads, int64_t head_dim_qk, int64_t head_dim_vo, int64_t max_qo_len) { - TVM_FFI_ICHECK_EQ(q->dtype, k->dtype); - auto scalar_type_in = q->dtype; - auto scalar_type_out = o->dtype; + TVM_FFI_ICHECK_EQ(q.dtype(), k.dtype()); + auto scalar_type_in = q.dtype(); + auto scalar_type_out = o.dtype(); MaskMode mask_mode = static_cast(mask_mode_code); - int total_qo_len = q->shape[0]; - int total_kv_len = k->shape[0]; - int batch_size = qo_segment_offsets->shape[0] - 1; - int q_stride_n = q->strides[0]; - int q_stride_h = q->strides[1]; - int k_stride_n = k->strides[0]; - int k_stride_h = k->strides[1]; - int v_stride_n = v->strides[0]; - int v_stride_h = v->strides[1]; + int total_qo_len = q.size(0); + int total_kv_len = k.size(0); + int batch_size = qo_segment_offsets.size(0) - 1; + int q_stride_n = q.stride(0); + int q_stride_h = q.stride(1); + int k_stride_n = k.stride(0); + int k_stride_h = k.stride(1); + int v_stride_n = v.stride(0); + int v_stride_h = v.stride(1); - cudaSetDevice(qo_segment_offsets->device.device_id); - const cudaStream_t stream = get_stream(o->device); + cudaSetDevice(qo_segment_offsets.device().device_id); + const cudaStream_t stream = get_stream(o.device()); DISPATCH_context(DTypeIn, DTypeOut, HEAD_DIM_QK, HEAD_DIM_VO, MASK_MODE, [&] { using cutlass_type_in = cutlass_dtype_t; @@ -112,13 +112,14 @@ void FMHACutlassSM100Run(ffi::TensorView workspace_buffer, ffi::TensorView q, ff typename std::conditional::type; auto status = run_fmha_fwd( - workspace_buffer->data, static_cast(q->data), - static_cast(k->data), static_cast(v->data), - static_cast(qo_segment_offsets->data), static_cast(kv_segment_offsets->data), - static_cast(work_indptr->data), static_cast(qo_tile_indices->data), - static_cast(qo_head_indices->data), static_cast(batch_indices->data), - static_cast(o->data), - maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr, + workspace_buffer.data_ptr(), static_cast(q.data_ptr()), + static_cast(k.data_ptr()), static_cast(v.data_ptr()), + static_cast(qo_segment_offsets.data_ptr()), + static_cast(kv_segment_offsets.data_ptr()), static_cast(work_indptr.data_ptr()), + static_cast(qo_tile_indices.data_ptr()), + static_cast(qo_head_indices.data_ptr()), static_cast(batch_indices.data_ptr()), + static_cast(o.data_ptr()), + maybe_lse.has_value() ? 
static_cast(maybe_lse.value().data_ptr()) : nullptr, mask_mode_code, sm_scale, num_qo_heads, num_kv_heads, head_dim_qk, head_dim_vo, q_stride_n, q_stride_h, k_stride_n, k_stride_h, v_stride_n, v_stride_h, batch_size, total_qo_len, total_kv_len, max_qo_len, stream); diff --git a/csrc/fp4_gemm_cutlass.cu b/csrc/fp4_gemm_cutlass.cu index e55f6e6ed4..ae9b0aa658 100644 --- a/csrc/fp4_gemm_cutlass.cu +++ b/csrc/fp4_gemm_cutlass.cu @@ -69,18 +69,18 @@ void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Sc workspace_buffer.numel() * get_element_size(workspace_buffer); auto runKernel = [&](void* workspace) { - gemmRunner.gemm(out->data, mat1->data, mat2->data, mat1Scale->data, mat2Scale->data, - static_cast(globalScale->data), m, n, k, batch_count, gemmConfig, - reinterpret_cast(workspace), required_workspace_size, - get_stream(mat1->device)); + gemmRunner.gemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(), mat1Scale.data_ptr(), + mat2Scale.data_ptr(), static_cast(globalScale.data_ptr()), m, n, k, + batch_count, gemmConfig, reinterpret_cast(workspace), + required_workspace_size, get_stream(mat1.device())); }; if (provided_workspace_size < required_workspace_size) { Tensor new_workspace = - alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1->device); - runKernel(new_workspace->data); + alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device()); + runKernel(new_workspace.data_ptr()); } else { - runKernel(workspace_buffer->data); + runKernel(workspace_buffer.data_ptr()); } } @@ -108,27 +108,26 @@ void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tensor CHECK_INPUT_AND_TYPE(globalScale, dl_float32); int64_t m, n, k, b; - if (mat1->ndim == 2) { - TVM_FFI_ICHECK_EQ(mat2->ndim, 2) << "mat2 must be a matrix"; - TVM_FFI_ICHECK_EQ(mat1->shape[1], mat2->shape[1] * mat2_k_scale) - << "mat1 and mat2 shapes cannot be multiplied (" << mat1->shape[0] << "x" << mat1->shape[1] - << " and " << mat2->shape[0] << "x" << mat2->shape[1] << ")"; - m = mat1->shape[0]; - n = mat2->shape[0]; - k = mat2->shape[1] * 2; + if (mat1.ndim() == 2) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix"; + TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1) * mat2_k_scale) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1) + << " and " << mat2.size(0) << "x" << mat2.size(1) << ")"; + m = mat1.size(0); + n = mat2.size(0); + k = mat2.size(1) * 2; b = 1; - } else if (mat1->ndim == 3) { - TVM_FFI_ICHECK_EQ(mat2->ndim, 3) << "mat2 must be a batch of matrices"; - TVM_FFI_ICHECK_EQ(mat1->shape[0], mat2->shape[0]) - << "mat1 and mat2 must have the same batch size (" << mat1->shape[0] << " and " - << mat2->shape[0] << ")"; - TVM_FFI_ICHECK_EQ(mat1->shape[2], mat2->shape[2] * mat2_k_scale) - << "mat1 and mat2 shapes cannot be multiplied (" << mat1->shape[1] << "x" << mat1->shape[2] - << " and " << mat2->shape[1] << "x" << mat2->shape[2] << ")"; - m = mat1->shape[1]; - n = mat2->shape[1]; - k = mat2->shape[2] * 2; - b = mat1->shape[0]; + } else if (mat1.ndim() == 3) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices"; + TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size (" + << mat1.size(0) << " and " << mat2.size(0) << ")"; + TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2) * mat2_k_scale) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2) + << " and " << mat2.size(1) << "x" << mat2.size(2) << ")"; + m = 
mat1.size(1); + n = mat2.size(1); + k = mat2.size(2) * 2; + b = mat1.size(0); } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices"; } @@ -141,24 +140,24 @@ void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tensor constexpr int alignment = 32; TVM_FFI_ICHECK_EQ(k % alignment, 0) - << "Expected k to be divisible by " << alignment << ", but got mat1 shape: (" - << mat1->shape[0] << "x" << mat1->shape[1] << "), k: " << k << "."; + << "Expected k to be divisible by " << alignment << ", but got mat1 shape: (" << mat1.size(0) + << "x" << mat1.size(1) << "), k: " << k << "."; TVM_FFI_ICHECK_EQ(n % alignment, 0) - << "Expected n to be divisible by " << alignment << ", but got mat2 shape: (" - << mat2->shape[0] << "x" << mat2->shape[1] << ")."; + << "Expected n to be divisible by " << alignment << ", but got mat2 shape: (" << mat2.size(0) + << "x" << mat2.size(1) << ")."; // Validate out dimensions std::vector out_shape = - mat1->ndim == 2 ? std::vector{m, n} : std::vector{b, m, n}; - TVM_FFI_ICHECK_EQ(out->ndim, out_shape.size()) - << "out must have " << out_shape.size() << " dimensions, but got " << out->ndim; + mat1.ndim() == 2 ? std::vector{m, n} : std::vector{b, m, n}; + TVM_FFI_ICHECK_EQ(out.ndim(), out_shape.size()) + << "out must have " << out_shape.size() << " dimensions, but got " << out.ndim(); for (int i = 0; i < out_shape.size(); ++i) { - TVM_FFI_ICHECK_EQ(out->shape[i], out_shape[i]) + TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i]) << "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got " - << out->shape[i]; + << out.size(i); } - switch (encode_dlpack_dtype(out->dtype)) { + switch (encode_dlpack_dtype(out.dtype())) { case float16_code: runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config, workspace_buffer); diff --git a/csrc/fp4_gemm_cutlass_sm120.cu b/csrc/fp4_gemm_cutlass_sm120.cu index 3848d55f85..30080f0fce 100644 --- a/csrc/fp4_gemm_cutlass_sm120.cu +++ b/csrc/fp4_gemm_cutlass_sm120.cu @@ -61,18 +61,18 @@ void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView mat1Sc workspace_buffer.numel() * get_element_size(workspace_buffer); auto runKernel = [&](void* workspace) { - gemmRunner.gemm(out->data, mat1->data, mat2->data, mat1Scale->data, mat2Scale->data, - static_cast(globalScale->data), m, n, k, batch_count, gemmConfig, - reinterpret_cast(workspace), required_workspace_size, - get_stream(mat1->device)); + gemmRunner.gemm(out.data_ptr(), mat1.data_ptr(), mat2.data_ptr(), mat1Scale.data_ptr(), + mat2Scale.data_ptr(), static_cast(globalScale.data_ptr()), m, n, k, + batch_count, gemmConfig, reinterpret_cast(workspace), + required_workspace_size, get_stream(mat1.device())); }; if (provided_workspace_size < required_workspace_size) { Tensor new_workspace = - alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1->device); - runKernel(new_workspace->data); + alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device()); + runKernel(new_workspace.data_ptr()); } else { - runKernel(workspace_buffer->data); + runKernel(workspace_buffer.data_ptr()); } } @@ -83,19 +83,19 @@ void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tensor TensorView globalScale, TensorView out, TensorView workspace_buffer, int64_t tactic) { // Validate inputs - TVM_FFI_ICHECK_EQ(mat1->dtype, FLOAT4_E2M1X2) << "mat1 must be FLOAT4_E2M1X2 (uint8)"; - TVM_FFI_ICHECK_EQ(mat2->dtype, FLOAT4_E2M1X2) << "mat2 must be 
FLOAT4_E2M1X2 (uint8)"; - TVM_FFI_ICHECK_EQ(mat1Scale->dtype, SF_DTYPE) << "mat1Scale must be SF_DTYPE (uint8)"; - TVM_FFI_ICHECK_EQ(mat2Scale->dtype, SF_DTYPE) << "mat2Scale must be SF_DTYPE (uint8)"; - TVM_FFI_ICHECK_EQ(globalScale->dtype, dl_float32) << "globalScale must be float"; - TVM_FFI_ICHECK_EQ(mat1->device.device_type, kDLCUDA) << "mat1 must be on CUDA device"; - TVM_FFI_ICHECK_EQ(mat2->device.device_type, kDLCUDA) << "mat2 must be on CUDA device"; - TVM_FFI_ICHECK_EQ(mat1Scale->device.device_type, kDLCUDA) << "mat1Scale must be on CUDA device"; - TVM_FFI_ICHECK_EQ(mat2Scale->device.device_type, kDLCUDA) << "mat2Scale must be on CUDA device"; - TVM_FFI_ICHECK_EQ(globalScale->device.device_type, kDLCUDA) + TVM_FFI_ICHECK_EQ(mat1.dtype(), FLOAT4_E2M1X2) << "mat1 must be FLOAT4_E2M1X2 (uint8)"; + TVM_FFI_ICHECK_EQ(mat2.dtype(), FLOAT4_E2M1X2) << "mat2 must be FLOAT4_E2M1X2 (uint8)"; + TVM_FFI_ICHECK_EQ(mat1Scale.dtype(), SF_DTYPE) << "mat1Scale must be SF_DTYPE (uint8)"; + TVM_FFI_ICHECK_EQ(mat2Scale.dtype(), SF_DTYPE) << "mat2Scale must be SF_DTYPE (uint8)"; + TVM_FFI_ICHECK_EQ(globalScale.dtype(), dl_float32) << "globalScale must be float"; + TVM_FFI_ICHECK_EQ(mat1.device().device_type, kDLCUDA) << "mat1 must be on CUDA device"; + TVM_FFI_ICHECK_EQ(mat2.device().device_type, kDLCUDA) << "mat2 must be on CUDA device"; + TVM_FFI_ICHECK_EQ(mat1Scale.device().device_type, kDLCUDA) << "mat1Scale must be on CUDA device"; + TVM_FFI_ICHECK_EQ(mat2Scale.device().device_type, kDLCUDA) << "mat2Scale must be on CUDA device"; + TVM_FFI_ICHECK_EQ(globalScale.device().device_type, kDLCUDA) << "globalScale must be on CUDA device"; - TVM_FFI_ICHECK_EQ(out->device.device_type, kDLCUDA) << "out must be on CUDA device"; - TVM_FFI_ICHECK_EQ(workspace_buffer->device.device_type, kDLCUDA) + TVM_FFI_ICHECK_EQ(out.device().device_type, kDLCUDA) << "out must be on CUDA device"; + TVM_FFI_ICHECK_EQ(workspace_buffer.device().device_type, kDLCUDA) << "workspace_buffer must be on CUDA device"; // Check device consistency @@ -110,24 +110,24 @@ void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tensor int64_t b = 1; int64_t m, k_packed, n; - if (mat1->ndim == 2) { - m = mat1->shape[0]; - k_packed = mat1->shape[1]; - } else if (mat1->ndim == 3) { - b = mat1->shape[0]; - m = mat1->shape[1]; - k_packed = mat1->shape[2]; + if (mat1.ndim() == 2) { + m = mat1.size(0); + k_packed = mat1.size(1); + } else if (mat1.ndim() == 3) { + b = mat1.size(0); + m = mat1.size(1); + k_packed = mat1.size(2); } else { TVM_FFI_ICHECK(false) << "mat1 must be 2D or 3D tensor"; } - if (mat2->ndim == 2) { - n = mat2->shape[0]; - TVM_FFI_ICHECK_EQ(mat2->shape[1], k_packed) << "mat2->shape[1] must match mat1.size(-1)"; - } else if (mat2->ndim == 3) { - TVM_FFI_ICHECK_EQ(mat2->shape[0], b) << "Batch dimensions must match"; - n = mat2->shape[1]; - TVM_FFI_ICHECK_EQ(mat2->shape[2], k_packed) << "mat2->shape[2] must match mat1.size(-1)"; + if (mat2.ndim() == 2) { + n = mat2.size(0); + TVM_FFI_ICHECK_EQ(mat2.size(1), k_packed) << "mat2.size(1) must match mat1.size(-1)"; + } else if (mat2.ndim() == 3) { + TVM_FFI_ICHECK_EQ(mat2.size(0), b) << "Batch dimensions must match"; + n = mat2.size(1); + TVM_FFI_ICHECK_EQ(mat2.size(2), k_packed) << "mat2.size(2) must match mat1.size(-1)"; } else { TVM_FFI_ICHECK(false) << "mat2 must be 2D or 3D tensor"; } @@ -147,14 +147,14 @@ void fp4_bmm_impl(TensorView mat1, TensorView mat2, TensorView mat1Scale, Tensor // Validate output dimensions std::vector out_shape = (b > 1) ? 
std::vector{b, m, n} : std::vector{m, n}; - TVM_FFI_ICHECK_EQ(out->ndim, out_shape.size()) + TVM_FFI_ICHECK_EQ(out.ndim(), out_shape.size()) << "out must have " << out_shape.size() << " dimensions"; for (size_t i = 0; i < out_shape.size(); ++i) { - TVM_FFI_ICHECK_EQ(out->shape[i], out_shape[i]) - << "out.size(" << i << "): expected " << out_shape[i] << ", got " << out->shape[i]; + TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i]) + << "out.size(" << i << "): expected " << out_shape[i] << ", got " << out.size(i); } - switch (encode_dlpack_dtype(out->dtype)) { + switch (encode_dlpack_dtype(out.dtype())) { case float16_code: runGemm(out, mat1, mat2, mat1Scale, mat2Scale, globalScale, m, n, k, b, config, workspace_buffer); diff --git a/csrc/fp8_gemm_cutlass.cu b/csrc/fp8_gemm_cutlass.cu index f8cc28ab38..d2e6a63d82 100644 --- a/csrc/fp8_gemm_cutlass.cu +++ b/csrc/fp8_gemm_cutlass.cu @@ -67,20 +67,20 @@ void runGemm(TensorView out, TensorView mat1, TensorView mat2, TensorView scale_ workspace_buffer.numel() * get_element_size(workspace_buffer); auto runKernel = [&](void* workspace) { - gemmRunner.gemm(static_cast<__nv_fp8_e4m3*>(mat1->data), - static_cast<__nv_fp8_e4m3*>(mat2->data), static_cast(scale_a->data), - static_cast(scale_b->data), out->data, m, n, k, b, gemmConfig, - static_cast(workspace), required_workspace_size, - get_stream(mat1->device)); + gemmRunner.gemm( + static_cast<__nv_fp8_e4m3*>(mat1.data_ptr()), static_cast<__nv_fp8_e4m3*>(mat2.data_ptr()), + static_cast(scale_a.data_ptr()), static_cast(scale_b.data_ptr()), + out.data_ptr(), m, n, k, b, gemmConfig, static_cast(workspace), + required_workspace_size, get_stream(mat1.device())); }; if (provided_workspace_size < required_workspace_size) { Tensor new_workspace = - alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1->device); + alloc_tensor({required_workspace_size}, DLDataType{kDLInt, 8, 1}, mat1.device()); - runKernel(new_workspace->data); + runKernel(new_workspace.data_ptr()); } else { - runKernel(workspace_buffer->data); + runKernel(workspace_buffer.data_ptr()); } } @@ -94,27 +94,26 @@ void fp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView scale_a, TensorVi int mat2_k_scale = 1; int64_t m, n, k, b; - if (mat1->ndim == 2) { - TVM_FFI_ICHECK_EQ(mat2->ndim, 2) << "mat2 must be a matrix"; - TVM_FFI_ICHECK_EQ(mat1->shape[1], mat2->shape[1] * mat2_k_scale) - << "mat1 and mat2 shapes cannot be multiplied (" << mat1->shape[0] << "x" << mat1->shape[1] - << " and " << mat2->shape[0] << "x" << mat2->shape[1] << ")"; - m = mat1->shape[0]; - n = mat2->shape[0]; - k = mat2->shape[1]; + if (mat1.ndim() == 2) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 must be a matrix"; + TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(1) * mat2_k_scale) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(0) << "x" << mat1.size(1) + << " and " << mat2.size(0) << "x" << mat2.size(1) << ")"; + m = mat1.size(0); + n = mat2.size(0); + k = mat2.size(1); b = 1; - } else if (mat1->ndim == 3) { - TVM_FFI_ICHECK_EQ(mat2->ndim, 3) << "mat2 must be a batch of matrices"; - TVM_FFI_ICHECK_EQ(mat1->shape[0], mat2->shape[0]) - << "mat1 and mat2 must have the same batch size (" << mat1->shape[0] << " and " - << mat2->shape[0] << ")"; - TVM_FFI_ICHECK_EQ(mat1->shape[2], mat2->shape[2] * mat2_k_scale) - << "mat1 and mat2 shapes cannot be multiplied (" << mat1->shape[1] << "x" << mat1->shape[2] - << " and " << mat2->shape[1] << "x" << mat2->shape[2] << ")"; - m = mat1->shape[1]; - n = mat2->shape[1]; - k = mat2->shape[2]; - b = 
mat1->shape[0]; + } else if (mat1.ndim() == 3) { + TVM_FFI_ICHECK_EQ(mat2.ndim(), 3) << "mat2 must be a batch of matrices"; + TVM_FFI_ICHECK_EQ(mat1.size(0), mat2.size(0)) << "mat1 and mat2 must have the same batch size (" + << mat1.size(0) << " and " << mat2.size(0) << ")"; + TVM_FFI_ICHECK_EQ(mat1.size(2), mat2.size(2) * mat2_k_scale) + << "mat1 and mat2 shapes cannot be multiplied (" << mat1.size(1) << "x" << mat1.size(2) + << " and " << mat2.size(1) << "x" << mat2.size(2) << ")"; + m = mat1.size(1); + n = mat2.size(1); + k = mat2.size(2); + b = mat1.size(0); } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "mat1 must be a matrix or a batch of matrices"; } @@ -127,16 +126,16 @@ void fp8_bmm_impl(TensorView mat1, TensorView mat2, TensorView scale_a, TensorVi // Validate out dimensions std::vector out_shape = - mat1->ndim == 2 ? std::vector{m, n} : std::vector{b, m, n}; - TVM_FFI_ICHECK_EQ(out->ndim, out_shape.size()) - << "out must have " << out_shape.size() << " dimensions, but got " << out->ndim; + mat1.ndim() == 2 ? std::vector{m, n} : std::vector{b, m, n}; + TVM_FFI_ICHECK_EQ(out.ndim(), out_shape.size()) + << "out must have " << out_shape.size() << " dimensions, but got " << out.ndim(); for (int i = 0; i < out_shape.size(); ++i) { - TVM_FFI_ICHECK_EQ(out->shape[i], out_shape[i]) + TVM_FFI_ICHECK_EQ(out.size(i), out_shape[i]) << "out shape mismatch at dimension " << i << ": expected " << out_shape[i] << ", got " - << out->shape[i]; + << out.size(i); } - switch (encode_dlpack_dtype(out->dtype)) { + switch (encode_dlpack_dtype(out.dtype())) { case float16_code: runGemm(out, mat1, mat2, scale_a, scale_b, m, n, k, b, config, workspace_buffer); break; diff --git a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu index 298242acf4..051a9b0f96 100644 --- a/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu +++ b/csrc/fused_moe/cutlass_backend/flashinfer_cutlass_fused_moe_sm100_binding.cu @@ -267,36 +267,36 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { CHECK_DIM(2, fc1_expert_biases.value()); CHECK_DIM(2, fc2_expert_biases.value()); - TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[0], fc1_expert_biases.value()->shape[0]) + TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(0), fc1_expert_biases.value().size(0)) << "fc1_expert_weights and fc1_expert_biases must have the same number of experts."; - TVM_FFI_ICHECK_EQ(fc2_expert_weights->shape[0], fc2_expert_biases.value()->shape[0]) + TVM_FFI_ICHECK_EQ(fc2_expert_weights.size(0), fc2_expert_biases.value().size(0)) << "fc2_expert_weights and fc2_expert_biases must have the same number of experts."; - TVM_FFI_ICHECK_EQ(fc1_expert_biases.value()->shape[1], fc1_expert_weights->shape[1]) + TVM_FFI_ICHECK_EQ(fc1_expert_biases.value().size(1), fc1_expert_weights.size(1)) << "fc1_expert_biases should match fc1_expert_weights output shape."; - TVM_FFI_ICHECK_EQ(fc2_expert_biases.value()->shape[1], fc2_expert_weights->shape[1]) + TVM_FFI_ICHECK_EQ(fc2_expert_biases.value().size(1), fc2_expert_weights.size(1)) << "fc2_expert_biases should match fc2_expert_weights output shape."; } - TVM_FFI_ICHECK_EQ(input->shape[0], token_selected_experts->shape[0]) + TVM_FFI_ICHECK_EQ(input.size(0), token_selected_experts.size(0)) << "input and token_selected_experts must have the same num tokens."; if (token_final_scales.has_value()) { CHECK_DIM(2, token_final_scales.value()); - TVM_FFI_ICHECK_EQ(input->shape[0], 
token_final_scales.value()->shape[0]) + TVM_FFI_ICHECK_EQ(input.size(0), token_final_scales.value().size(0)) << "input and token_selected_experts_probs must have the same num tokens."; - TVM_FFI_ICHECK_EQ(token_selected_experts->shape[1], token_final_scales.value()->shape[1]) + TVM_FFI_ICHECK_EQ(token_selected_experts.size(1), token_final_scales.value().size(1)) << "token_selected_experts and token_final_scales must have the same number of " "experts per token."; } - TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[0], fc2_expert_weights->shape[0]) + TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(0), fc2_expert_weights.size(0)) << "fc1_expert_weights and fc2_expert_weights must have the same number of experts."; - TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1], - fc2_expert_weights->shape[2] * mInnerDimMultiplier * 2) + TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(1), + fc2_expert_weights.size(2) * mInnerDimMultiplier * 2) << "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."; - int experts_per_token = token_selected_experts->shape[1]; - int64_t num_rows = input->shape[0]; - int64_t hidden_size = fc2_expert_weights->shape[1]; - int64_t inter_size = fc2_expert_weights->shape[2] * mInnerDimMultiplier; + int experts_per_token = token_selected_experts.size(1); + int64_t num_rows = input.size(0); + int64_t hidden_size = fc2_expert_weights.size(1); + int64_t inter_size = fc2_expert_weights.size(2) * mInnerDimMultiplier; if (isWMxfp4AMxfp8Quant() || isWMxfp4AFp8Quant()) { // MXFP4 weights are required to bealigned to 128 bytes @@ -316,40 +316,40 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { << " for weights"; } - int const num_experts_on_rank = fc2_expert_weights->shape[0]; + int const num_experts_on_rank = fc2_expert_weights.size(0); auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank); ActivationType base_activation_type = ActivationType::Swiglu; if (swiglu_alpha.has_value()) { CHECK_INPUT_AND_TYPE(swiglu_alpha.value(), dl_float32); - TVM_FFI_ICHECK_EQ(swiglu_alpha.value()->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(swiglu_alpha.value().size(0), num_experts_on_rank) << "swiglu_alpha must have num_experts_on_rank elements."; base_activation_type = ActivationType::SwigluBias; } if (swiglu_beta.has_value()) { CHECK_INPUT_AND_TYPE(swiglu_beta.value(), dl_float32); - TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(swiglu_beta.value().size(0), num_experts_on_rank) << "swiglu_beta must have num_experts_on_rank elements."; base_activation_type = ActivationType::SwigluBias; } if (swiglu_limit.has_value()) { CHECK_INPUT_AND_TYPE(swiglu_limit.value(), dl_float32); - TVM_FFI_ICHECK_EQ(swiglu_limit.value()->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(swiglu_limit.value().size(0), num_experts_on_rank) << "swiglu_limit must have num_experts_on_rank elements."; base_activation_type = ActivationType::SwigluBias; } auto activation_params = ActivationParams( base_activation_type, - reinterpret_cast(swiglu_alpha.has_value() ? swiglu_alpha.value()->data + reinterpret_cast(swiglu_alpha.has_value() ? swiglu_alpha.value().data_ptr() : nullptr), - reinterpret_cast(swiglu_beta.has_value() ? swiglu_beta.value()->data + reinterpret_cast(swiglu_beta.has_value() ? swiglu_beta.value().data_ptr() : nullptr), - reinterpret_cast(swiglu_limit.has_value() ? swiglu_limit.value()->data + reinterpret_cast(swiglu_limit.has_value() ? 
swiglu_limit.value().data_ptr() : nullptr)); setRunnerProfiles(profile_ids); - auto stream = get_stream(input->device); + auto stream = get_stream(input.device()); WorkspaceInfo workspace_info = getWorkspaceInfo( num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), @@ -362,36 +362,38 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; #ifdef USING_OSS_CUTLASS_MOE_GEMM - mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, - reinterpret_cast(token_selected_experts->data), - token_final_scales.has_value() - ? reinterpret_cast(token_final_scales.value()->data) - : nullptr, - fc1_expert_weights->data, - fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, - activation_params, fc2_expert_weights->data, - fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, - static_cast(experts_per_token), - static_cast(workspace_info.workspace->data), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, - enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling, - min_latency_mode, min_latency_params, enable_pdl, stream); + mKernelRunner->runMoe( + input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, + reinterpret_cast(token_selected_experts.data_ptr()), + token_final_scales.has_value() + ? reinterpret_cast(token_final_scales.value().data_ptr()) + : nullptr, + fc1_expert_weights.data_ptr(), + fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, + activation_params, fc2_expert_weights.data_ptr(), + fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr, + quant_params, num_rows, hidden_size, inter_size, num_experts_total, + static_cast(experts_per_token), + static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), + static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, + false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #else mKernelRunner->runMoe( - input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, - reinterpret_cast(token_selected_experts->data), + input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, + reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() - ? reinterpret_cast(token_final_scales.value()->data) + ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, - fc1_expert_weights->data, - fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, - activation_params, fc2_expert_weights->data, - fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params, - num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), - static_cast(workspace_info.workspace), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, - mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); + fc1_expert_weights.data_ptr(), + fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, + activation_params, fc2_expert_weights.data_ptr(), + fc2_expert_biases.has_value() ? 
fc2_expert_biases.value().data_ptr() : nullptr, + quant_params, num_rows, hidden_size, inter_size, num_experts_total, + static_cast(experts_per_token), static_cast(workspace_info.workspace), + output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, + false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #endif } @@ -434,93 +436,95 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { CHECK_INPUT_TYPE(fc2_expert_biases.value(), mOutputDtype); CHECK_DIM(2, fc1_expert_biases.value()); CHECK_DIM(2, fc2_expert_biases.value()); - TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[0], fc1_expert_biases.value()->shape[0]) + TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(0), fc1_expert_biases.value().size(0)) << "fc1_expert_weights and fc1_expert_biases must have the same number of experts."; - TVM_FFI_ICHECK_EQ(fc2_expert_weights->shape[0], fc2_expert_biases.value()->shape[0]) + TVM_FFI_ICHECK_EQ(fc2_expert_weights.size(0), fc2_expert_biases.value().size(0)) << "fc2_expert_weights and fc2_expert_biases must have the same number of experts."; - TVM_FFI_ICHECK_EQ(fc1_expert_biases.value()->shape[1], fc1_expert_weights->shape[1]) + TVM_FFI_ICHECK_EQ(fc1_expert_biases.value().size(1), fc1_expert_weights.size(1)) << "fc1_expert_biases should match fc1_expert_weights output shape."; - TVM_FFI_ICHECK_EQ(fc2_expert_biases.value()->shape[1], fc2_expert_weights->shape[1]) + TVM_FFI_ICHECK_EQ(fc2_expert_biases.value().size(1), fc2_expert_weights.size(1)) << "fc2_expert_biases should match fc2_expert_weights output shape."; } - TVM_FFI_ICHECK_EQ(input->shape[0], token_selected_experts->shape[0]) + TVM_FFI_ICHECK_EQ(input.size(0), token_selected_experts.size(0)) << "input and token_selected_experts must have the same num tokens."; if (token_final_scales) { CHECK_DIM(2, token_final_scales.value()); - TVM_FFI_ICHECK_EQ(input->shape[0], token_final_scales.value()->shape[0]) + TVM_FFI_ICHECK_EQ(input.size(0), token_final_scales.value().size(0)) << "input and token_selected_experts_probs must have the same num tokens."; - TVM_FFI_ICHECK_EQ(token_selected_experts->shape[1], token_final_scales.value()->shape[1]) + TVM_FFI_ICHECK_EQ(token_selected_experts.size(1), token_final_scales.value().size(1)) << "token_selected_experts and token_final_scales must have the same number of " "experts per token."; } - TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[0], fc2_expert_weights->shape[0]) + TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(0), fc2_expert_weights.size(0)) << "fc1_expert_weights and fc2_expert_weights must have the same number of experts."; - TVM_FFI_ICHECK_EQ(fc1_expert_weights->shape[1], - fc2_expert_weights->shape[2] * mInnerDimMultiplier * 2) + TVM_FFI_ICHECK_EQ(fc1_expert_weights.size(1), + fc2_expert_weights.size(2) * mInnerDimMultiplier * 2) << "fc1_expert_weights inter size must be 2 times fc2_expert_weights inter size."; TVM_FFI_ICHECK(!input_sf.has_value() || isWMxfp4AMxfp8Quant() || isNvfp4Quant()) << "Block-scaling factors provided for non block-scaling quantization"; - int experts_per_token = token_selected_experts->shape[1]; - int64_t num_rows = input->shape[0]; - int64_t hidden_size = fc2_expert_weights->shape[1]; - int64_t inter_size = fc2_expert_weights->shape[2] * mInnerDimMultiplier; + int experts_per_token = token_selected_experts.size(1); + int64_t num_rows = input.size(0); + int64_t hidden_size = fc2_expert_weights.size(1); + int64_t inter_size = fc2_expert_weights.size(2) * mInnerDimMultiplier; - int const 
num_experts_on_rank = fc2_expert_weights->shape[0]; + int const num_experts_on_rank = fc2_expert_weights.size(0); auto const num_experts_total = static_cast(num_experts_on_rank * ep_size); auto parallelism_config = kernels::MOEParallelismConfig(tp_size, tp_rank, ep_size, ep_rank); ActivationType base_activation_type = ActivationType::Swiglu; if (swiglu_alpha.has_value()) { CHECK_INPUT_AND_TYPE(swiglu_alpha.value(), dl_float32); - TVM_FFI_ICHECK_EQ(swiglu_alpha.value()->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(swiglu_alpha.value().size(0), num_experts_on_rank) << "swiglu_alpha must have num_experts_on_rank elements."; base_activation_type = ActivationType::SwigluBias; } if (swiglu_beta.has_value()) { CHECK_INPUT_AND_TYPE(swiglu_beta.value(), dl_float32); - TVM_FFI_ICHECK_EQ(swiglu_beta.value()->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(swiglu_beta.value().size(0), num_experts_on_rank) << "swiglu_beta must have num_experts_on_rank elements."; base_activation_type = ActivationType::SwigluBias; } if (swiglu_limit.has_value()) { CHECK_INPUT_AND_TYPE(swiglu_limit.value(), dl_float32); - TVM_FFI_ICHECK_EQ(swiglu_limit.value()->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(swiglu_limit.value().size(0), num_experts_on_rank) << "swiglu_limit must have num_experts_on_rank elements."; base_activation_type = ActivationType::SwigluBias; } auto activation_params = ActivationParams( base_activation_type, - reinterpret_cast(swiglu_alpha.has_value() ? swiglu_alpha.value()->data + reinterpret_cast(swiglu_alpha.has_value() ? swiglu_alpha.value().data_ptr() : nullptr), - reinterpret_cast(swiglu_beta.has_value() ? swiglu_beta.value()->data + reinterpret_cast(swiglu_beta.has_value() ? swiglu_beta.value().data_ptr() : nullptr), - reinterpret_cast(swiglu_limit.has_value() ? swiglu_limit.value()->data + reinterpret_cast(swiglu_limit.has_value() ?
swiglu_limit.value().data_ptr() : nullptr)); setRunnerProfiles(profile_ids); - auto stream = get_stream(input->device); + auto stream = get_stream(input.device()); CHECK_DIM(1, num_active_experts_per_node); CHECK_INPUT_TYPE(num_active_experts_per_node, dl_int32); - TVM_FFI_ICHECK_EQ(num_active_experts_per_node->shape[0], 1); + TVM_FFI_ICHECK_EQ(num_active_experts_per_node.size(0), 1); CHECK_DIM(2, experts_to_token_score); CHECK_INPUT_TYPE(experts_to_token_score, dl_float32); - TVM_FFI_ICHECK_EQ(experts_to_token_score->shape[0], num_experts_on_rank); - TVM_FFI_ICHECK_EQ(experts_to_token_score->shape[1], num_rows); + TVM_FFI_ICHECK_EQ(experts_to_token_score.size(0), num_experts_on_rank); + TVM_FFI_ICHECK_EQ(experts_to_token_score.size(1), num_rows); CHECK_DIM(1, active_expert_global_ids); CHECK_INPUT_TYPE(active_expert_global_ids, dl_int32); - TVM_FFI_ICHECK_EQ(active_expert_global_ids->shape[0], num_experts_on_rank); + TVM_FFI_ICHECK_EQ(active_expert_global_ids.size(0), num_experts_on_rank); kernels::MoeMinLatencyParams min_latency_params{}; min_latency_params.num_active_experts_per_node = - static_cast(num_active_experts_per_node->data); - min_latency_params.experts_to_token_score = static_cast(experts_to_token_score->data); - min_latency_params.active_expert_global_ids = static_cast(active_expert_global_ids->data); + static_cast(num_active_experts_per_node.data_ptr()); + min_latency_params.experts_to_token_score = + static_cast(experts_to_token_score.data_ptr()); + min_latency_params.active_expert_global_ids = + static_cast(active_expert_global_ids.data_ptr()); WorkspaceInfo workspace_info = getWorkspaceInfo( num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), @@ -532,36 +536,38 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // TODO: support lora in the future ::tensorrt_llm::kernels::LoraParams lora_params{}; #ifdef USING_OSS_CUTLASS_MOE_GEMM - mKernelRunner->runMoe(input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, - reinterpret_cast(token_selected_experts->data), - token_final_scales.has_value() - ? reinterpret_cast(token_final_scales.value()->data) - : nullptr, - fc1_expert_weights->data, - fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, - activation_params, fc2_expert_weights->data, - fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, - quant_params, num_rows, hidden_size, inter_size, num_experts_total, - static_cast(experts_per_token), - static_cast(workspace_info.workspace->data), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, - enable_alltoall, false, lora_params, mUseDeepSeekFP8BlockScaling, - min_latency_mode, min_latency_params, enable_pdl, stream); + mKernelRunner->runMoe( + input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, + reinterpret_cast(token_selected_experts.data_ptr()), + token_final_scales.has_value() + ? reinterpret_cast(token_final_scales.value().data_ptr()) + : nullptr, + fc1_expert_weights.data_ptr(), + fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, + activation_params, fc2_expert_weights.data_ptr(), + fc2_expert_biases.has_value() ? 
fc2_expert_biases.value().data_ptr() : nullptr, + quant_params, num_rows, hidden_size, inter_size, num_experts_total, + static_cast(experts_per_token), + static_cast(workspace_info.workspace.data_ptr()), output.data_ptr(), + static_cast(workspace_info.src_to_dest_map), parallelism_config, enable_alltoall, + false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #else mKernelRunner->runMoe( - input->data, input_sf.has_value() ? input_sf.value()->data : nullptr, - reinterpret_cast(token_selected_experts->data), + input.data_ptr(), input_sf.has_value() ? input_sf.value().data_ptr() : nullptr, + reinterpret_cast(token_selected_experts.data_ptr()), token_final_scales.has_value() - ? reinterpret_cast(token_final_scales.value()->data) + ? reinterpret_cast(token_final_scales.value().data_ptr()) : nullptr, - fc1_expert_weights->data, - fc1_expert_biases.has_value() ? fc1_expert_biases.value()->data : nullptr, - activation_params, fc2_expert_weights->data, - fc2_expert_biases.has_value() ? fc2_expert_biases.value()->data : nullptr, quant_params, - num_rows, hidden_size, inter_size, num_experts_total, static_cast(experts_per_token), - static_cast(workspace_info.workspace), output->data, - static_cast(workspace_info.src_to_dest_map), parallelism_config, false, lora_params, - mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, enable_pdl, stream); + fc1_expert_weights.data_ptr(), + fc1_expert_biases.has_value() ? fc1_expert_biases.value().data_ptr() : nullptr, + activation_params, fc2_expert_weights.data_ptr(), + fc2_expert_biases.has_value() ? fc2_expert_biases.value().data_ptr() : nullptr, + quant_params, num_rows, hidden_size, inter_size, num_experts_total, + static_cast(experts_per_token), static_cast(workspace_info.workspace), + output.data_ptr(), static_cast(workspace_info.src_to_dest_map), parallelism_config, + false, lora_params, mUseDeepSeekFP8BlockScaling, min_latency_mode, min_latency_params, + enable_pdl, stream); #endif } @@ -583,9 +589,9 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { return; } - int64_t num_rows = input->shape[0]; - int64_t hidden_size = fc2_expert_weights->shape[1]; - int64_t inter_size = fc2_expert_weights->shape[2] * mInnerDimMultiplier; + int64_t num_rows = input.size(0); + int64_t hidden_size = fc2_expert_weights.size(1); + int64_t inter_size = fc2_expert_weights.size(2) * mInnerDimMultiplier; int64_t group_size_ = isInt4Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size : -1; @@ -593,17 +599,17 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { isWFP4A16Quant() ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size : group_size_; - int const num_experts = static_cast(fc2_expert_weights->shape[0] * ep_size); + int const num_experts = static_cast(fc2_expert_weights.size(0) * ep_size); // Get specific profile configs according to the profile_id. // Fallback tactic is set to be 0 // TODO: use the best tactic id found offline for a better default inference perf auto profile = profile_id == -1 ? mAllProfiles.front() : mAllProfiles[profile_id]; - auto stream = get_stream(input->device); + auto stream = get_stream(input.device()); auto const* expert_weights_ptr = - (gemm_idx == 1) ? fc1_expert_weights->data : fc2_expert_weights->data; + (gemm_idx == 1) ? fc1_expert_weights.data_ptr() : fc2_expert_weights.data_ptr(); // Preparation phase, only enabled during autotuning warmup phase. 
if (do_preparation) { @@ -644,12 +650,12 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { mProfileWorkspace = alloc_tensor({static_cast(profile_workspace_size)}, dl_int8, DLDevice{kDLCUDA, device_id}); - mProfiler->prepare(num_rows, static_cast(mProfileWorkspace->data), expert_weights_ptr, - enable_pdl, stream); + mProfiler->prepare(num_rows, static_cast(mProfileWorkspace.data_ptr()), + expert_weights_ptr, enable_pdl, stream); } // Profile specific tactic. Assuming at least one preparation phase has been executed already. - mProfiler->runProfiler(num_rows, profile, static_cast(mProfileWorkspace->data), + mProfiler->runProfiler(num_rows, profile, static_cast(mProfileWorkspace.data_ptr()), expert_weights_ptr, enable_pdl, stream); } @@ -781,8 +787,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { cudaGetDevice(&device_id); info.workspace = alloc_tensor({static_cast(total_workspace_size)}, dl_int8, DLDevice{kDLCUDA, device_id}); - info.src_to_dest_map = - common::nextWorkspacePtr(static_cast(info.workspace->data), moe_workspace_size); + info.src_to_dest_map = common::nextWorkspacePtr(static_cast(info.workspace.data_ptr()), + moe_workspace_size); return info; } @@ -807,22 +813,23 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { CHECK_INPUT_TYPE(fc1_input_dequant, dl_float32); // Check ranks CHECK_DIM(1, fc1_dequant); - TVM_FFI_ICHECK_LE(fc2_quant->ndim, 1) << "fc2 quant must be a scalar or 1-D tensor"; + TVM_FFI_ICHECK_LE(fc2_quant.ndim(), 1) << "fc2 quant must be a scalar or 1-D tensor"; CHECK_DIM(1, fc2_dequant); CHECK_DIM(0, fc1_input_dequant); // Check shapes - TVM_FFI_ICHECK_EQ(fc1_dequant->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc1_dequant.size(0), num_experts_on_rank) << "fc1 dequant size must be (num_experts_on_rank,)"; - TVM_FFI_ICHECK(fc2_quant->ndim == 0 || fc2_quant->shape[0] == num_experts_on_rank) + TVM_FFI_ICHECK(fc2_quant.ndim() == 0 || fc2_quant.size(0) == num_experts_on_rank) << "fc2 quant must be scalar or (num_experts_on_rank,)"; - TVM_FFI_ICHECK_EQ(fc2_dequant->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc2_dequant.size(0), num_experts_on_rank) << "fc2 dequant size must be (num_experts_on_rank,)"; - return kernels::QuantParams::FP8( - static_cast(fc1_dequant->data), static_cast(fc2_quant->data), - static_cast(fc2_dequant->data), - /* fp8 output quant scale */ nullptr, static_cast(fc1_input_dequant->data), - fc2_quant->ndim == 1); + return kernels::QuantParams::FP8(static_cast(fc1_dequant.data_ptr()), + static_cast(fc2_quant.data_ptr()), + static_cast(fc2_dequant.data_ptr()), + /* fp8 output quant scale */ nullptr, + static_cast(fc1_input_dequant.data_ptr()), + fc2_quant.ndim() == 1); } else if (isWMxfp4AFp8Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for W4A8_MXFP4_MXF8 quantization"; @@ -846,47 +853,48 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { // Check ranks CHECK_DIM(3, fc1_weight_block); CHECK_DIM(1, fc1_global); - TVM_FFI_ICHECK_LE(fc2_act_global->ndim, 1) << "fc2 act global must be a scalar or 1-D tensor"; + TVM_FFI_ICHECK_LE(fc2_act_global.ndim(), 1) + << "fc2 act global must be a scalar or 1-D tensor"; CHECK_DIM(3, fc2_weight_block); CHECK_DIM(1, fc2_global); // Check shapes TVM_FFI_ICHECK( - fc1_weight_block->shape[0] == num_experts_on_rank && - fc1_weight_block->shape[1] == + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) * 2 && 
- fc1_weight_block->shape[2] * FP8_PER_INT32 * + fc1_weight_block.size(2) * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " "// block_scale_vector_size)"; - TVM_FFI_ICHECK_EQ(fc1_global->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) << "fc1 global size must be (num_experts_on_rank,)"; - TVM_FFI_ICHECK(fc2_act_global->ndim == 0 || fc2_act_global->shape[0] == num_experts_on_rank) + TVM_FFI_ICHECK(fc2_act_global.ndim() == 0 || fc2_act_global.size(0) == num_experts_on_rank) << "fc2 act global must be scalar or (num_experts_on_rank,)"; TVM_FFI_ICHECK( - fc2_weight_block->shape[0] == num_experts_on_rank && - fc2_weight_block->shape[1] == + fc2_weight_block.size(0) == num_experts_on_rank && + fc2_weight_block.size(1) == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) && - fc2_weight_block->shape[2] * FP8_PER_INT32 * + fc2_weight_block.size(2) * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) << "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " "block_scale_vector_size)"; - TVM_FFI_ICHECK_EQ(fc2_global->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank) << "fc2 global size must be (num_experts_on_rank,)"; return kernels::QuantParams::FP8MXFP4( nullptr, - static_cast(fc1_weight_block->data), - static_cast(fc1_global->data), - static_cast(fc2_act_global->data), - static_cast(fc2_weight_block->data), - static_cast(fc2_global->data), false, fc2_act_global->ndim == 1); + static_cast(fc1_weight_block.data_ptr()), + static_cast(fc1_global.data_ptr()), + static_cast(fc2_act_global.data_ptr()), + static_cast(fc2_weight_block.data_ptr()), + static_cast(fc2_global.data_ptr()), false, fc2_act_global.ndim() == 1); } else if (isWMxfp4AMxfp8Quant()) { #ifdef USING_OSS_CUTLASS_MOE_GEMM TVM_FFI_ICHECK(quant_scales.has_value()) @@ -910,38 +918,38 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { CHECK_DIM(3, fc2_weight_block); CHECK_DIM(1, fc2_global); TVM_FFI_ICHECK( - fc1_weight_block->shape[0] == num_experts_on_rank && - fc1_weight_block->shape[1] == + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( inter_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) * 2 && - fc1_weight_block->shape[2] * FP8_PER_INT32 * + fc1_weight_block.size(2) * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " "// block_scale_vector_size)"; - TVM_FFI_ICHECK_EQ(fc1_global->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) << "fc1 global size must be (num_experts_on_rank,)"; TVM_FFI_ICHECK( - fc2_weight_block->shape[0] == num_experts_on_rank && - fc2_weight_block->shape[1] == + fc2_weight_block.size(0) == num_experts_on_rank && + fc2_weight_block.size(1) == 
TmaWarpSpecializedGroupedGemmInput::alignToSfDim( hidden_size, TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentMXFPX) && - fc2_weight_block->shape[2] * FP8_PER_INT32 * + fc2_weight_block.size(2) * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentMXFPX)) << "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " "block_scale_vector_size)"; - TVM_FFI_ICHECK_EQ(fc2_global->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank) << "fc2 global size must be (num_experts_on_rank,)"; return kernels::QuantParams::MXFP8MXFP4( - static_cast(fc1_weight_block->data), - static_cast(fc1_global->data), - static_cast(fc2_weight_block->data), - static_cast(fc2_global->data)); + static_cast(fc1_weight_block.data_ptr()), + static_cast(fc1_global.data_ptr()), + static_cast(fc2_weight_block.data_ptr()), + static_cast(fc2_global.data_ptr())); #else TVM_FFI_ICHECK(false) << "MXFP8 x MXFP4 quantization is not supported in OSS Cutlass Moe Gemm"; @@ -970,58 +978,61 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { CHECK_INPUT_TYPE(fc2_weight_block, dl_int32); CHECK_INPUT_TYPE(fc2_global, dl_float32); // Check ranks - TVM_FFI_ICHECK_LE(fc1_act_global->ndim, 1) << "fc1 act global must be a scalar or 1-D tensor"; + TVM_FFI_ICHECK_LE(fc1_act_global.ndim(), 1) + << "fc1 act global must be a scalar or 1-D tensor"; CHECK_DIM(3, fc1_weight_block); CHECK_DIM(1, fc1_global); - TVM_FFI_ICHECK_LE(fc2_act_global->ndim, 1) << "fc2 act global must be a scalar or 1-D tensor"; + TVM_FFI_ICHECK_LE(fc2_act_global.ndim(), 1) + << "fc2 act global must be a scalar or 1-D tensor"; CHECK_DIM(3, fc2_weight_block); CHECK_DIM(1, fc2_global); // Check shapes - TVM_FFI_ICHECK(fc1_act_global->ndim == 0 || fc1_act_global->shape[0] == num_experts_on_rank) + TVM_FFI_ICHECK(fc1_act_global.ndim() == 0 || fc1_act_global.size(0) == num_experts_on_rank) << "fc1 act global must be scalar or (num_experts_on_rank,)"; TVM_FFI_ICHECK( - fc1_weight_block->shape[0] == num_experts_on_rank && - fc1_weight_block->shape[1] == + fc1_weight_block.size(0) == num_experts_on_rank && + fc1_weight_block.size(1) == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4) * 2 && - fc1_weight_block->shape[2] * FP8_PER_INT32 * + fc1_weight_block.size(2) * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( hidden_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)) << "fc1 weight block size must be (num_experts_on_rank, inter_size * 2, hidden_size // 4 " "// block_scale_vector_size)"; - TVM_FFI_ICHECK_EQ(fc1_global->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc1_global.size(0), num_experts_on_rank) << "fc1 global size must be (num_experts_on_rank,)"; - TVM_FFI_ICHECK(fc2_act_global->ndim == 0 || fc2_act_global->shape[0] == num_experts_on_rank) + TVM_FFI_ICHECK(fc2_act_global.ndim() == 0 || fc2_act_global.size(0) == num_experts_on_rank) << "fc2 act global must be scalar or (num_experts_on_rank,)"; TVM_FFI_ICHECK( - fc2_weight_block->shape[0] == num_experts_on_rank && - fc2_weight_block->shape[1] == + fc2_weight_block.size(0) == num_experts_on_rank && + fc2_weight_block.size(1) == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( hidden_size, 
TmaWarpSpecializedGroupedGemmInput::MinNDimAlignmentNVFP4) && - fc2_weight_block->shape[2] * FP8_PER_INT32 * + fc2_weight_block.size(2) * FP8_PER_INT32 * TmaWarpSpecializedGroupedGemmInput::NVFP4BlockScaleVectorSize == TmaWarpSpecializedGroupedGemmInput::alignToSfDim( inter_size, TmaWarpSpecializedGroupedGemmInput::MinKDimAlignmentNVFP4)) << "fc2 weight block size must be (num_experts_on_rank, hidden_size, inter_size // 4 // " "block_scale_vector_size)"; - TVM_FFI_ICHECK_EQ(fc2_global->shape[0], num_experts_on_rank) + TVM_FFI_ICHECK_EQ(fc2_global.size(0), num_experts_on_rank) << "fc2 global size must be (num_experts_on_rank,)"; return kernels::QuantParams::FP4( - static_cast(fc1_act_global->data), - static_cast(fc1_weight_block->data), - static_cast(fc1_global->data), - static_cast(fc2_act_global->data), - static_cast(fc2_weight_block->data), - static_cast(fc2_global->data), fc1_act_global->ndim == 1, - fc2_act_global->ndim == 1); + static_cast(fc1_act_global.data_ptr()), + static_cast(fc1_weight_block.data_ptr()), + static_cast(fc1_global.data_ptr()), + static_cast(fc2_act_global.data_ptr()), + static_cast(fc2_weight_block.data_ptr()), + static_cast(fc2_global.data_ptr()), fc1_act_global.ndim() == 1, + fc2_act_global.ndim() == 1); } else if (mUseDeepSeekFP8BlockScaling) { TensorView fc1_scales = quant_scales.value()[0]; TensorView fc2_scales = quant_scales.value()[1]; - return kernels::QuantParams::FP8BlockScaling(static_cast(fc1_scales->data), - static_cast(fc2_scales->data)); + return kernels::QuantParams::FP8BlockScaling( + static_cast(fc1_scales.data_ptr()), + static_cast(fc2_scales.data_ptr())); } else if (isWFP4A16Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for W4 quantization"; TVM_FFI_ICHECK_EQ(quant_scales.value().size(), 2) @@ -1031,8 +1042,8 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TensorView fc2_weight_scales = quant_scales.value()[1]; int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size; return kernels::QuantParams::GroupWise(group_size, - static_cast(fc1_weight_scales->data), - static_cast(fc2_weight_scales->data), + static_cast(fc1_weight_scales.data_ptr()), + static_cast(fc2_weight_scales.data_ptr()), nullptr, nullptr, nullptr, nullptr, nullptr, nullptr); } else if (isInt4Quant()) { TVM_FFI_ICHECK(quant_scales.has_value()) << "Expecting quant scales for INT4 quantization"; @@ -1048,14 +1059,18 @@ class FusedMoeRunner : public tvm::ffi::ModuleObj { TensorView fc2_alpha = quant_scales.value()[7]; int group_size = TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size; return kernels::QuantParams::GroupWise( - group_size, static_cast(fc1_weight_scales->data), - static_cast(fc2_weight_scales->data), - static_cast(fc1_act_scales.numel() > 0 ? fc1_act_scales->data : nullptr), - static_cast(fc2_act_scales.numel() > 0 ? fc2_act_scales->data : nullptr), - static_cast(fc1_weight_zeros.numel() > 0 ? fc1_weight_zeros->data : nullptr), - static_cast(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros->data : nullptr), - static_cast(fc1_alpha.numel() > 0 ? fc1_alpha->data : nullptr), - static_cast(fc2_alpha.numel() > 0 ? fc2_alpha->data : nullptr)); + group_size, static_cast(fc1_weight_scales.data_ptr()), + static_cast(fc2_weight_scales.data_ptr()), + static_cast(fc1_act_scales.numel() > 0 ? fc1_act_scales.data_ptr() + : nullptr), + static_cast(fc2_act_scales.numel() > 0 ? fc2_act_scales.data_ptr() + : nullptr), + static_cast(fc1_weight_zeros.numel() > 0 ? 
fc1_weight_zeros.data_ptr() + : nullptr), + static_cast(fc2_weight_zeros.numel() > 0 ? fc2_weight_zeros.data_ptr() + : nullptr), + static_cast(fc1_alpha.numel() > 0 ? fc1_alpha.data_ptr() : nullptr), + static_cast(fc2_alpha.numel() > 0 ? fc2_alpha.data_ptr() : nullptr)); } else { return kernels::QuantParams{}; } diff --git a/csrc/gemm_groupwise_sm100.cu b/csrc/gemm_groupwise_sm100.cu index 56a89a23e0..30e7329db7 100644 --- a/csrc/gemm_groupwise_sm100.cu +++ b/csrc/gemm_groupwise_sm100.cu @@ -91,11 +91,11 @@ void CutlassGemmGroupwiseScaledSM100(TensorView float_workspace_buffer, TensorVi int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, std::string scale_major_mode, int64_t mma_sm) { - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(C->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(C.device()); DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] { return DISPATCH_MMA_SM(mma_sm, MMA_SM, [&] { - return DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A->dtype, C->dtype, c_type_in, c_type_out, [&] { + return DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A.dtype(), C.dtype(), c_type_in, c_type_out, [&] { return DISPATCH_SCALE_GRANULARITY( scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] { @@ -103,13 +103,13 @@ void CutlassGemmGroupwiseScaledSM100(TensorView float_workspace_buffer, TensorVi using cutlass_t_out = cutlass_dtype_t; auto status = flashinfer::gemm::CutlassGroupwiseScaledGEMMSM100< SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K, - MMA_SM>( - static_cast(float_workspace_buffer->data), - get_element_size(float_workspace_buffer) * float_workspace_buffer->shape[0], - static_cast(A->data), static_cast(B->data), - static_cast(SFA->data), static_cast(SFB->data), - static_cast(C->data), A->shape[0], B->shape[0], A->shape[1], 1, - stream); + MMA_SM>(static_cast(float_workspace_buffer.data_ptr()), + get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0), + static_cast(A.data_ptr()), + static_cast(B.data_ptr()), + static_cast(SFA.data_ptr()), static_cast(SFB.data_ptr()), + static_cast(C.data_ptr()), A.size(0), B.size(0), + A.size(1), 1, stream); TVM_FFI_ICHECK_EQ(status, cudaSuccess) << "Failed to run cutlass gemm groupwise scaled sm100" << cudaGetErrorString(status); diff --git a/csrc/gemm_groupwise_sm120.cu b/csrc/gemm_groupwise_sm120.cu index a434325a1a..28cbf58d33 100644 --- a/csrc/gemm_groupwise_sm120.cu +++ b/csrc/gemm_groupwise_sm120.cu @@ -86,8 +86,8 @@ void CutlassGemmGroupwiseScaledSM120(TensorView float_workspace_buffer, TensorVi TensorView SFA, TensorView SFB, TensorView C, int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, std::string scale_major_mode) { - cudaSetDevice(float_workspace_buffer->device.device_id); - auto stream = get_stream(C->device); + cudaSetDevice(float_workspace_buffer.device().device_id); + auto stream = get_stream(C.device()); // Ensure scales are contiguous // Note: We keep the original shape and let the kernel's layout handle interpretation @@ -95,7 +95,7 @@ void CutlassGemmGroupwiseScaledSM120(TensorView float_workspace_buffer, TensorVi CHECK_CONTIGUOUS(SFB); DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] { - return DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A->dtype, C->dtype, c_type_in, c_type_out, [&] { + return DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A.dtype(), 
C.dtype(), c_type_in, c_type_out, [&] { return DISPATCH_SCALE_GRANULARITY( scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, [&] { @@ -104,29 +104,29 @@ void CutlassGemmGroupwiseScaledSM120(TensorView float_workspace_buffer, TensorVi // Handle both 2D and 3D tensors (BMM) int m, n, k, l; - if (A->ndim == 2) { + if (A.ndim() == 2) { // 2D case: simple matrix multiplication - m = A->shape[0]; - k = A->shape[1]; - n = B->shape[0]; + m = A.size(0); + k = A.size(1); + n = B.size(0); l = 1; // no batch dimension - } else if (A->ndim == 3) { + } else if (A.ndim() == 3) { // 3D case: batch matrix multiplication - l = A->shape[0]; // batch size - m = A->shape[1]; // per-batch m dimension - k = A->shape[2]; // per-batch k dimension - n = B->shape[2]; // per-batch n dimension (B is [batch, k, n] column-major) + l = A.size(0); // batch size + m = A.size(1); // per-batch m dimension + k = A.size(2); // per-batch k dimension + n = B.size(2); // per-batch n dimension (B is [batch, k, n] column-major) } else { return false; // Unsupported tensor dimension } auto status = flashinfer::gemm::CutlassGroupwiseScaledGEMMSM120< SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K>( - static_cast(float_workspace_buffer->data), + static_cast(float_workspace_buffer.data_ptr()), get_element_size(float_workspace_buffer) * float_workspace_buffer.numel(), - static_cast(A->data), static_cast(B->data), - static_cast(SFA->data), static_cast(SFB->data), - static_cast(C->data), m, n, k, l, + static_cast(A.data_ptr()), static_cast(B.data_ptr()), + static_cast(SFA.data_ptr()), static_cast(SFB.data_ptr()), + static_cast(C.data_ptr()), m, n, k, l, stream); // C is the output (D) return status == cudaSuccess; }); diff --git a/csrc/group_gemm.cu b/csrc/group_gemm.cu index 73684a3fcf..100c7183f3 100644 --- a/csrc/group_gemm.cu +++ b/csrc/group_gemm.cu @@ -23,16 +23,16 @@ using namespace flashinfer::group_gemm; void CutlassSegmentGEMM(TensorView workspace_buffer, TensorView all_problems, TensorView x_ptr, TensorView w_ptr, TensorView y_ptr, TensorView x_ld, TensorView w_ld, TensorView y_ld, TensorView empty_x_data, bool weight_column_major) { - unsigned int batch_size = x_ptr->shape[0]; + unsigned int batch_size = x_ptr.size(0); - cudaSetDevice(workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(workspace_buffer->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(empty_x_data->dtype, c_type, [&] { + cudaSetDevice(workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(workspace_buffer.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(empty_x_data.dtype(), c_type, [&] { using cutlass_t = cutlass_dtype_t; auto status = CutlassSegmentGEMMRun( - workspace_buffer->data, get_element_size(workspace_buffer) * workspace_buffer->shape[0], - all_problems->data, batch_size, x_ptr->data, w_ptr->data, y_ptr->data, x_ld->data, - w_ld->data, y_ld->data, weight_column_major, stream); + workspace_buffer.data_ptr(), get_element_size(workspace_buffer) * workspace_buffer.size(0), + all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), y_ptr.data_ptr(), + x_ld.data_ptr(), w_ld.data_ptr(), y_ld.data_ptr(), weight_column_major, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to run CutlassSegmentGEMM: " << cudaGetErrorString(status); return true; diff --git a/csrc/group_gemm_fp8_groupwise_sm100.cu b/csrc/group_gemm_fp8_groupwise_sm100.cu index a3abef08a4..006314de71 100644 --- 
a/csrc/group_gemm_fp8_groupwise_sm100.cu +++ b/csrc/group_gemm_fp8_groupwise_sm100.cu @@ -91,11 +91,11 @@ void CutlassGroupGemmFP8GroupwiseScaledSM100( TensorView SFA, TensorView SFB, TensorView D, TensorView m_indptr, int64_t n, int64_t k, int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, std::string scale_major_mode, int64_t mma_sm) { - cudaSetDevice(float_workspace_buffer->device.device_id); - auto stream = get_stream(D->device); - int num_groups = m_indptr->shape[0] - 1; - int max_m = SFA->shape[1]; - DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A->dtype, D->dtype, c_type_in, c_type_out, [&] { + cudaSetDevice(float_workspace_buffer.device().device_id); + auto stream = get_stream(D.device()); + int num_groups = m_indptr.size(0) - 1; + int max_m = SFA.size(1); + DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A.dtype(), D.dtype(), c_type_in, c_type_out, [&] { return DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] { return DISPATCH_MMA_SM(mma_sm, MMA_SM, [&] { return DISPATCH_SCALE_GRANULARITY( @@ -105,15 +105,15 @@ void CutlassGroupGemmFP8GroupwiseScaledSM100( using cutlass_t_out = cutlass_dtype_t; auto status = flashinfer::group_gemm::CutlassFP8GroupwiseScaledGroupGEMMSM100< SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K, - MMA_SM>( - static_cast(int_workspace_buffer->data), - get_element_size(int_workspace_buffer) * int_workspace_buffer->shape[0], - static_cast(float_workspace_buffer->data), - get_element_size(float_workspace_buffer) * float_workspace_buffer->shape[0], - static_cast(A->data), static_cast(B->data), - static_cast(SFA->data), static_cast(SFB->data), - static_cast(D->data), static_cast(m_indptr->data), max_m, n, - k, num_groups, stream); + MMA_SM>(static_cast(int_workspace_buffer.data_ptr()), + get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0), + static_cast(float_workspace_buffer.data_ptr()), + get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0), + static_cast(A.data_ptr()), + static_cast(B.data_ptr()), + static_cast(SFA.data_ptr()), static_cast(SFB.data_ptr()), + static_cast(D.data_ptr()), + static_cast(m_indptr.data_ptr()), max_m, n, k, num_groups, stream); return status == cudaSuccess; }); }); diff --git a/csrc/group_gemm_fp8_groupwise_sm120.cu b/csrc/group_gemm_fp8_groupwise_sm120.cu index f19aecd15e..c0bbaa31b2 100644 --- a/csrc/group_gemm_fp8_groupwise_sm120.cu +++ b/csrc/group_gemm_fp8_groupwise_sm120.cu @@ -85,9 +85,9 @@ void CutlassGroupGemmFP8GroupwiseScaledSM120( TensorView SFA, TensorView SFB, TensorView D, TensorView m_indptr, int64_t n, int64_t k, int64_t scale_granularity_m, int64_t scale_granularity_n, int64_t scale_granularity_k, std::string scale_major_mode) { - cudaSetDevice(float_workspace_buffer->device.device_id); - auto stream = get_stream(D->device); - int num_groups = m_indptr->shape[0] - 1; + cudaSetDevice(float_workspace_buffer.device().device_id); + auto stream = get_stream(D.device()); + int num_groups = m_indptr.size(0) - 1; // Ensure scales are contiguous // Note: We keep the original shape and let the kernel's layout handle interpretation @@ -95,9 +95,9 @@ void CutlassGroupGemmFP8GroupwiseScaledSM120( CHECK_CONTIGUOUS(SFB); // Get max_m from SFA shape - int max_m = SFA->shape[SFA->ndim > 1 ? 1 : 0]; + int max_m = SFA.size(SFA.ndim() > 1 ? 
1 : 0); - DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A->dtype, D->dtype, c_type_in, c_type_out, [&] { + DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE(A.dtype(), D.dtype(), c_type_in, c_type_out, [&] { return DISPATCH_SCALE_MAJOR_K(scale_major_mode, SCALE_MAJOR_K, [&] { return DISPATCH_SCALE_GRANULARITY( scale_granularity_m, scale_granularity_n, scale_granularity_k, SCALE_GRANULARITY_M, @@ -107,14 +107,14 @@ void CutlassGroupGemmFP8GroupwiseScaledSM120( auto status = flashinfer::group_gemm::CutlassFP8GroupwiseScaledGroupGEMMSM120< SCALE_GRANULARITY_M, SCALE_GRANULARITY_N, SCALE_GRANULARITY_K, SCALE_MAJOR_K, cutlass_t_in, cutlass_t_out>( - static_cast(int_workspace_buffer->data), - get_element_size(int_workspace_buffer) * int_workspace_buffer->shape[0], - static_cast(float_workspace_buffer->data), - get_element_size(float_workspace_buffer) * float_workspace_buffer->shape[0], - static_cast(A->data), static_cast(B->data), - static_cast(SFA->data), static_cast(SFB->data), - static_cast(D->data), static_cast(m_indptr->data), max_m, n, - k, num_groups, stream); + static_cast(int_workspace_buffer.data_ptr()), + get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0), + static_cast(float_workspace_buffer.data_ptr()), + get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0), + static_cast(A.data_ptr()), static_cast(B.data_ptr()), + static_cast(SFA.data_ptr()), static_cast(SFB.data_ptr()), + static_cast(D.data_ptr()), static_cast(m_indptr.data_ptr()), + max_m, n, k, num_groups, stream); return status == cudaSuccess; }); }); diff --git a/csrc/group_gemm_mxfp4_groupwise_sm100.cu b/csrc/group_gemm_mxfp4_groupwise_sm100.cu index 1403420602..c4dc79e5d7 100644 --- a/csrc/group_gemm_mxfp4_groupwise_sm100.cu +++ b/csrc/group_gemm_mxfp4_groupwise_sm100.cu @@ -133,12 +133,12 @@ void CutlassGroupGemmMXFP4GroupwiseScaledSM100(TensorView int_workspace_buffer, TensorView D, TensorView m_indptr, int64_t n, int64_t k, int64_t mma_sm, int64_t tile_m, int64_t tile_n, int64_t tile_k, bool swap_ab) { - cudaSetDevice(float_workspace_buffer->device.device_id); - auto stream = get_stream(A->device); - int num_groups = m_indptr->shape[0] - 1; + cudaSetDevice(float_workspace_buffer.device().device_id); + auto stream = get_stream(A.device()); + int num_groups = m_indptr.size(0) - 1; DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE( - A->dtype, B->dtype, SFA->dtype, SFB->dtype, D->dtype, c_type_in_a, c_type_in_b, c_type_sf_a, - c_type_sf_b, c_type_out, [&] { + A.dtype(), B.dtype(), SFA.dtype(), SFB.dtype(), D.dtype(), c_type_in_a, c_type_in_b, + c_type_sf_a, c_type_sf_b, c_type_out, [&] { return DISPATCH_MMA_SM(mma_sm, MMA_SM, [&] { return DISPATCH_TILE_M(tile_m, TILE_M, [&] { return DISPATCH_TILE_N(tile_n, TILE_N, [&] { @@ -153,16 +153,16 @@ void CutlassGroupGemmMXFP4GroupwiseScaledSM100(TensorView int_workspace_buffer, using cutlass_t_out = cutlass_dtype_t; auto status = flashinfer::group_gemm::CutlassMXFP4GroupwiseScaledGroupGEMMSM100< TILE_M, TILE_N, TILE_K, MMA_SM, SWAP_AB>( - static_cast(int_workspace_buffer->data), - get_element_size(int_workspace_buffer) * int_workspace_buffer->shape[0], - static_cast(float_workspace_buffer->data), - get_element_size(float_workspace_buffer) * float_workspace_buffer->shape[0], - static_cast(A->data), - static_cast(B->data), - static_cast(SFA->data), - static_cast(SFB->data), - static_cast(D->data), static_cast(m_indptr->data), n, - k, num_groups, stream); + static_cast(int_workspace_buffer.data_ptr()), + get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0), + 
static_cast(float_workspace_buffer.data_ptr()), + get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0), + static_cast(A.data_ptr()), + static_cast(B.data_ptr()), + static_cast(SFA.data_ptr()), + static_cast(SFB.data_ptr()), + static_cast(D.data_ptr()), + static_cast(m_indptr.data_ptr()), n, k, num_groups, stream); return status == cudaSuccess; } else { TVM_FFI_ICHECK(false) << "Unsupported input data type"; diff --git a/csrc/group_gemm_sm90.cu b/csrc/group_gemm_sm90.cu index ab6ddd5e6c..9e2ee793e4 100644 --- a/csrc/group_gemm_sm90.cu +++ b/csrc/group_gemm_sm90.cu @@ -52,21 +52,22 @@ void CutlassSegmentGEMMSM90(TensorView float_workspace_buffer, TensorView int_wo TensorView y_ptr, TensorView x_stride, TensorView weight_stride, TensorView y_stride, TensorView empty_x_data, TensorView empty_y_data, bool weight_column_major) { - unsigned int batch_size = x_ptr->shape[0]; - cudaSetDevice(float_workspace_buffer->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer->device); + unsigned int batch_size = x_ptr.size(0); + cudaSetDevice(float_workspace_buffer.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer.device()); DISPATCH_DLPACK_INPUT_OUTPUT_DTYPE( - empty_x_data->dtype, empty_y_data->dtype, c_type_in, c_type_out, [&] { + empty_x_data.dtype(), empty_y_data.dtype(), c_type_in, c_type_out, [&] { using cutlass_t_in = cutlass_dtype_t; using cutlass_t_out = cutlass_dtype_t; auto status = flashinfer::group_gemm::CutlassSegmentGEMMSM90Run( - float_workspace_buffer->data, - get_element_size(float_workspace_buffer) * float_workspace_buffer->shape[0], - int_workspace_buffer->data, - get_element_size(int_workspace_buffer) * int_workspace_buffer->shape[0], - all_problems->data, batch_size, x_ptr->data, w_ptr->data, y_ptr->data, - x_stride->data, weight_stride->data, y_stride->data, weight_column_major, stream); + float_workspace_buffer.data_ptr(), + get_element_size(float_workspace_buffer) * float_workspace_buffer.size(0), + int_workspace_buffer.data_ptr(), + get_element_size(int_workspace_buffer) * int_workspace_buffer.size(0), + all_problems.data_ptr(), batch_size, x_ptr.data_ptr(), w_ptr.data_ptr(), + y_ptr.data_ptr(), x_stride.data_ptr(), weight_stride.data_ptr(), + y_stride.data_ptr(), weight_column_major, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "Failed to run CutlassSegmentGEMM: " << cudaGetErrorString(status); return true; diff --git a/csrc/norm.cu b/csrc/norm.cu index 6ce0c38c33..c5ff67eb10 100644 --- a/csrc/norm.cu +++ b/csrc/norm.cu @@ -26,24 +26,24 @@ void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps, CHECK_DEVICE(input, weight); CHECK_DIM(1, weight); // weight: (hidden_size) - auto input_ndim = input->ndim; + auto input_ndim = input.ndim(); if (input_ndim == 2) { // Normal RMSNorm: [batch_size, hidden_size] // Use CTA parallelization for better parallelism CHECK_DIM(2, output); - TVM_FFI_ICHECK_EQ(input->shape[1], weight->shape[0]); - unsigned int batch_size = input->shape[0]; - unsigned int hidden_size = input->shape[1]; - TVM_FFI_ICHECK_EQ(output->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size); - cudaSetDevice(input->device.device_id); - const cudaStream_t stream = get_stream(input->device); + TVM_FFI_ICHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + TVM_FFI_ICHECK_EQ(output.size(0), batch_size); + TVM_FFI_ICHECK_EQ(output.size(1), hidden_size); + 
cudaSetDevice(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { - cudaError_t status = - norm::RMSNorm(static_cast(input->data), static_cast(weight->data), - static_cast(output->data), batch_size, hidden_size, - input->strides[0], output->strides[0], eps, enable_pdl, stream); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { + cudaError_t status = norm::RMSNorm( + static_cast(input.data_ptr()), static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, hidden_size, input.stride(0), + output.stride(0), eps, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "RMSNorm failed with error code " << cudaGetErrorString(status); return true; @@ -52,21 +52,22 @@ void rmsnorm(TensorView output, TensorView input, TensorView weight, double eps, // QK RMSNorm: [batch_size, num_heads, head_dim] // Use warp-level parallization CHECK_DIM(3, output); // output: (batch_size, num_heads, hidden_size) - TVM_FFI_ICHECK_EQ(input->shape[2], weight->shape[0]); - unsigned int batch_size = input->shape[0]; - unsigned int num_heads = input->shape[1]; - unsigned int hidden_size = input->shape[2]; - TVM_FFI_ICHECK_EQ(output->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(output->shape[1], num_heads); - TVM_FFI_ICHECK_EQ(output->shape[2], hidden_size); + TVM_FFI_ICHECK_EQ(input.size(2), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int num_heads = input.size(1); + unsigned int hidden_size = input.size(2); + TVM_FFI_ICHECK_EQ(output.size(0), batch_size); + TVM_FFI_ICHECK_EQ(output.size(1), num_heads); + TVM_FFI_ICHECK_EQ(output.size(2), hidden_size); - cudaSetDevice(input->device.device_id); - const cudaStream_t stream = get_stream(input->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { + cudaSetDevice(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { cudaError_t status = norm::QKRMSNorm( - static_cast(input->data), static_cast(weight->data), - static_cast(output->data), batch_size, num_heads, hidden_size, input->strides[0], - input->strides[1], output->strides[0], output->strides[1], eps, enable_pdl, stream); + static_cast(input.data_ptr()), static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, num_heads, hidden_size, + input.stride(0), input.stride(1), output.stride(0), output.stride(1), eps, enable_pdl, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "QKRMSNorm failed with error code " << cudaGetErrorString(status); return true; @@ -86,19 +87,19 @@ void fused_add_rmsnorm(TensorView input, TensorView residual, TensorView weight, CHECK_DIM(2, input); // input: (batch_size, hidden_size) CHECK_DIM(2, residual); // residual: (batch_size, hidden_size) CHECK_DIM(1, weight); // weight: (hidden_size) - unsigned int batch_size = input->shape[0]; - unsigned int hidden_size = input->shape[1]; - TVM_FFI_ICHECK_EQ(residual->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(residual->shape[1], hidden_size); - TVM_FFI_ICHECK_EQ(weight->shape[0], hidden_size); - cudaSetDevice(input->device.device_id); - const cudaStream_t stream = get_stream(input->device); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + TVM_FFI_ICHECK_EQ(residual.size(0), batch_size); + TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size); + TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size); + 
cudaSetDevice(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { cudaError_t status = norm::FusedAddRMSNorm( - static_cast(input->data), static_cast(residual->data), - static_cast(weight->data), batch_size, hidden_size, input->strides[0], - residual->strides[0], eps, enable_pdl, stream); + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, input.stride(0), + residual.stride(0), eps, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "FusedAddRMSNorm failed with error code " << cudaGetErrorString(status); @@ -113,19 +114,19 @@ void gemma_rmsnorm(TensorView output, TensorView input, TensorView weight, doubl CHECK_DEVICE(input, weight); CHECK_DIM(2, input); // input: (batch_size, hidden_size) CHECK_DIM(1, weight); // weight: (hidden_size) - TVM_FFI_ICHECK_EQ(input->shape[1], weight->shape[0]); - unsigned int batch_size = input->shape[0]; - unsigned int hidden_size = input->shape[1]; - TVM_FFI_ICHECK_EQ(output->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size); - cudaSetDevice(input->device.device_id); - const cudaStream_t stream = get_stream(input->device); + TVM_FFI_ICHECK_EQ(input.size(1), weight.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + TVM_FFI_ICHECK_EQ(output.size(0), batch_size); + TVM_FFI_ICHECK_EQ(output.size(1), hidden_size); + cudaSetDevice(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { - cudaError_t status = - norm::GemmaRMSNorm(static_cast(input->data), static_cast(weight->data), - static_cast(output->data), batch_size, hidden_size, - input->strides[0], output->strides[0], eps, enable_pdl, stream); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { + cudaError_t status = norm::GemmaRMSNorm( + static_cast(input.data_ptr()), static_cast(weight.data_ptr()), + static_cast(output.data_ptr()), batch_size, hidden_size, input.stride(0), + output.stride(0), eps, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "GemmaRMSNorm failed with error code " << cudaGetErrorString(status); return true; @@ -142,19 +143,19 @@ void gemma_fused_add_rmsnorm(TensorView input, TensorView residual, TensorView w CHECK_DIM(2, input); // input: (batch_size, hidden_size) CHECK_DIM(2, residual); // residual: (batch_size, hidden_size) CHECK_DIM(1, weight); // weight: (hidden_size) - unsigned int batch_size = input->shape[0]; - unsigned int hidden_size = input->shape[1]; - TVM_FFI_ICHECK_EQ(residual->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(residual->shape[1], hidden_size); - TVM_FFI_ICHECK_EQ(weight->shape[0], hidden_size); - cudaSetDevice(input->device.device_id); - const cudaStream_t stream = get_stream(input->device); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + TVM_FFI_ICHECK_EQ(residual.size(0), batch_size); + TVM_FFI_ICHECK_EQ(residual.size(1), hidden_size); + TVM_FFI_ICHECK_EQ(weight.size(0), hidden_size); + cudaSetDevice(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { cudaError_t status = norm::GemmaFusedAddRMSNorm( - 
static_cast(input->data), static_cast(residual->data), - static_cast(weight->data), batch_size, hidden_size, input->strides[0], - residual->strides[0], eps, enable_pdl, stream); + static_cast(input.data_ptr()), static_cast(residual.data_ptr()), + static_cast(weight.data_ptr()), batch_size, hidden_size, input.stride(0), + residual.stride(0), eps, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "GemmaFusedAddRMSNorm failed with error code " << cudaGetErrorString(status); return true; @@ -170,24 +171,24 @@ void layernorm(Tensor output, Tensor input, Tensor gamma, Tensor beta, double ep CHECK_DIM(2, input); // input: (batch_size, hidden_size) CHECK_DIM(1, gamma); // gamma: (hidden_size) CHECK_DIM(1, beta); // beta: (hidden_size) - TVM_FFI_ICHECK_EQ(input->shape[1], gamma->shape[0]); - TVM_FFI_ICHECK_EQ(input->shape[1], beta->shape[0]); - unsigned int batch_size = input->shape[0]; - unsigned int hidden_size = input->shape[1]; - TVM_FFI_ICHECK_EQ(output->shape[0], batch_size); - TVM_FFI_ICHECK_EQ(output->shape[1], hidden_size); - cudaSetDevice(input->device.device_id); - const cudaStream_t stream = get_stream(input->device); + TVM_FFI_ICHECK_EQ(input.size(1), gamma.size(0)); + TVM_FFI_ICHECK_EQ(input.size(1), beta.size(0)); + unsigned int batch_size = input.size(0); + unsigned int hidden_size = input.size(1); + TVM_FFI_ICHECK_EQ(output.size(0), batch_size); + TVM_FFI_ICHECK_EQ(output.size(1), hidden_size); + cudaSetDevice(input.device().device_id); + const cudaStream_t stream = get_stream(input.device()); // TODO(kaixih): This is currently our only use case; Add more if needed. - TVM_FFI_ICHECK_EQ(input->dtype, dl_bfloat16) << "input must be bfloat16"; - TVM_FFI_ICHECK_EQ(gamma->dtype, dl_float32) << "gamma must be float32"; - TVM_FFI_ICHECK_EQ(beta->dtype, dl_float32) << "beta must be float32"; + TVM_FFI_ICHECK_EQ(input.dtype(), dl_bfloat16) << "input must be bfloat16"; + TVM_FFI_ICHECK_EQ(gamma.dtype(), dl_float32) << "gamma must be float32"; + TVM_FFI_ICHECK_EQ(beta.dtype(), dl_float32) << "beta must be float32"; - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { - cudaError_t status = - norm::LayerNorm(static_cast(input->data), static_cast(gamma->data), - static_cast(beta->data), static_cast(output->data), - batch_size, hidden_size, eps, stream); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { + cudaError_t status = norm::LayerNorm( + static_cast(input.data_ptr()), static_cast(gamma.data_ptr()), + static_cast(beta.data_ptr()), static_cast(output.data_ptr()), batch_size, + hidden_size, eps, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "LayerNorm failed with error code " << cudaGetErrorString(status); return true; diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp index c14d77a606..673bc27edd 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Op.cpp @@ -141,7 +141,7 @@ int computeSFIndex(int rowIdx, int colIdx, int totalRow, int totalColumn, // blockScale: [num_experts, rows, cols] or [rows, cols] // Return: num_experts * pad_up(rows, 128) * pad_up(cols, 4) void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScale) { - bool is_cuda = (blockScale->device.device_type == kDLCUDA); + bool is_cuda = (blockScale.device().device_type == kDLCUDA); if (is_cuda) { CHECK_CUDA(blockScale); } else { @@ -149,7 +149,7 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal } 
CHECK_CONTIGUOUS(blockScale); CHECK_INPUT_TYPE(blockScale, dl_uint8); - auto blockScaleShape = blockScale.shape(); + auto blockScaleShape = blockScale.sizes(); TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3) << "Block Scale should be 2D or 3D tensor."; auto num_experts = blockScaleShape.size() == 3 ? blockScaleShape[0] : 1; @@ -164,18 +164,19 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal if (is_cuda) { const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount(); - const cudaStream_t stream = get_stream(blockScale->device); + const cudaStream_t stream = get_stream(blockScale.device()); tensorrt_llm::kernels::invokeBlockScaleInterleave( - num_experts, rows, rows_padded, cols, cols_padded, static_cast(blockScale->data), - static_cast(interleavedBlockScale->data), smCount, stream); + num_experts, rows, rows_padded, cols, cols_padded, + static_cast(blockScale.data_ptr()), + static_cast(interleavedBlockScale.data_ptr()), smCount, stream); } else { for (int eIdx = 0; eIdx < static_cast(num_experts); eIdx++) { uint8_t* interleavedBlockScalePtr = - static_cast(interleavedBlockScale->data) + eIdx * expert_out_size; + static_cast(interleavedBlockScale.data_ptr()) + eIdx * expert_out_size; for (int rIdx = 0; rIdx < static_cast(rows_padded); ++rIdx) { auto globalRowIdx = eIdx * rows + rIdx; - uint8_t* blockScalePtr = static_cast(blockScale->data) + globalRowIdx * cols; + uint8_t* blockScalePtr = static_cast(blockScale.data_ptr()) + globalRowIdx * cols; for (int cIdx = 0; cIdx < static_cast(cols_padded); ++cIdx) { uint8_t sf_ori = 0; if (rIdx < static_cast(rows) && cIdx < static_cast(cols)) { @@ -195,7 +196,7 @@ void BlockScaleInterleave(TensorView blockScale, TensorView interleavedBlockScal // Note: rows and cols are the dimensions of the original unswizzled SFMatrix, so reshape input // before passing into this function! Return: The same shape as blockScale void BlockScaleInterleaveReverse(TensorView const& blockScale, TensorView reversedBlockScale) { - bool is_cuda = (blockScale->device.device_type == kDLCUDA); + bool is_cuda = (blockScale.device().device_type == kDLCUDA); if (is_cuda) { CHECK_CUDA(blockScale); } else { @@ -203,7 +204,7 @@ void BlockScaleInterleaveReverse(TensorView const& blockScale, TensorView revers } CHECK_CONTIGUOUS(blockScale); CHECK_INPUT_TYPE(blockScale, dl_uint8); - auto blockScaleShape = blockScale.shape(); + auto blockScaleShape = blockScale.sizes(); TVM_FFI_ICHECK(blockScaleShape.size() == 2 || blockScaleShape.size() == 3) << "Block Scale should be 2D or 3D tensor."; auto num_experts = blockScaleShape.size() == 3 ? 
blockScaleShape[0] : 1; @@ -215,10 +216,10 @@ void BlockScaleInterleaveReverse(TensorView const& blockScale, TensorView revers if (is_cuda) { const thread_local int smCount = tensorrt_llm::common::getMultiProcessorCount(); - const cudaStream_t stream = get_stream(blockScale->device); + const cudaStream_t stream = get_stream(blockScale.device()); tensorrt_llm::kernels::invokeBlockScaleInterleaveReverse( - num_experts, rows, cols, static_cast(blockScale->data), - static_cast(reversedBlockScale->data), smCount, stream); + num_experts, rows, cols, static_cast(blockScale.data_ptr()), + static_cast(reversedBlockScale.data_ptr()), smCount, stream); } else { // index in the swizzled SFMatrix -> (eIdx, rIdx, cIdx) in the unswizzled SFMatrix Map> identity; @@ -231,11 +232,11 @@ void BlockScaleInterleaveReverse(TensorView const& blockScale, TensorView revers } } } - uint8_t* blockScalePtr = static_cast(blockScale->data); + uint8_t* blockScalePtr = static_cast(blockScale.data_ptr()); for (int i = 0; i < expert_out_size * num_experts; i++) { auto loc_2d = identity[i]; if (loc_2d[1] < rows && loc_2d[2] < cols) { - uint8_t* reversedBlockScalePtr = static_cast(reversedBlockScale->data) + + uint8_t* reversedBlockScalePtr = static_cast(reversedBlockScale.data_ptr()) + (loc_2d[0] * rows + loc_2d[1]) * cols + loc_2d[2]; *reversedBlockScalePtr = blockScalePtr[i]; } @@ -250,15 +251,15 @@ void E2M1AndUFP8SFScaleToFloatV2(TensorView valueE2M1, TensorView scaleFP8SF, bool isSfSwizzledLayout = true) { CHECK_CPU_INPUT(valueE2M1, dl_uint8); CHECK_CPU_INPUT(scaleFP8SF, dl_uint8); - auto packedShape = valueE2M1.shape(); - auto scaleShape = scaleFP8SF.shape(); + auto packedShape = valueE2M1.sizes(); + auto scaleShape = scaleFP8SF.sizes(); TVM_FFI_ICHECK_EQ(packedShape.size(), 2) << "valueE2M1 should be 2D tensor."; TVM_FFI_ICHECK_EQ(scaleShape.size(), 1) << "scaleFP8SF should be 1D tensor."; float globalScaleVal{1.0f}; if (sfType == 1) { TVM_FFI_ICHECK(globalScale.has_value()) << "globalScale is required when sfType is 1."; - globalScaleVal = static_cast(globalScale.value()->data)[0]; + globalScaleVal = static_cast(globalScale.value().data_ptr())[0]; } int hiddenDim = packedShape[1] * 2; @@ -271,10 +272,10 @@ void E2M1AndUFP8SFScaleToFloatV2(TensorView valueE2M1, TensorView scaleFP8SF, for (size_t vIdx = 0; vIdx < static_cast(packedShape[0]); ++vIdx) { for (int group = 0; group < groupsPerHiddenDim; ++group) { float* floatPtr = - static_cast(floatTensorView->data) + vIdx * hiddenDim + group * sfVecSize; - uint8_t* packedFp4Ptr = static_cast(valueE2M1->data) + vIdx * packedFp4HiddenDim + - group * sfVecSize / 2; - uint8_t* scaleFP8SFPtr = static_cast(scaleFP8SF->data); + static_cast(floatTensorView.data_ptr()) + vIdx * hiddenDim + group * sfVecSize; + uint8_t* packedFp4Ptr = static_cast(valueE2M1.data_ptr()) + + vIdx * packedFp4HiddenDim + group * sfVecSize / 2; + uint8_t* scaleFP8SFPtr = static_cast(scaleFP8SF.data_ptr()); uint8_t fp8Scale = scaleFP8SFPtr[computeSFIndex(vIdx, group, packedShape[0], groupsPerHiddenDim, layout)]; float scaleFloat; @@ -307,25 +308,26 @@ void mxfp4_dequantize_host(TensorView weight, TensorView scale, TensorView dequa CHECK_CPU_INPUT(scale, dl_uint8); CHECK_CONTIGUOUS(weight); CHECK_CONTIGUOUS(scale); - TVM_FFI_ICHECK_NE(weight.shape().Product(), 0) << "weight should not be empty tensor"; + TVM_FFI_ICHECK_NE(weight.numel(), 0) << "weight should not be empty tensor"; CHECK_INPUT_TYPE(weight, dl_uint8); CHECK_INPUT_TYPE(scale, dl_uint8); - int const n = weight->shape[0]; - int const k = 
weight->shape[1] * 2; + int const n = weight.size(0); + int const k = weight.size(1) * 2; - TVM_FFI_ICHECK_EQ(n, scale->shape[0]) << "weight and scale must have the same number of rows"; - TVM_FFI_ICHECK_EQ(k, scale->shape[1] * group_size) + TVM_FFI_ICHECK_EQ(n, scale.size(0)) << "weight and scale must have the same number of rows"; + TVM_FFI_ICHECK_EQ(k, scale.size(1) * group_size) << "weight and scale must have the same number of columns"; #if CUDA_VERSION >= 12080 - uint8_t* weight_packed_ptr = static_cast(weight->data); - __nv_fp8_e8m0* scale_ptr = reinterpret_cast<__nv_fp8_e8m0*>(static_cast(scale->data)); - float* dequant_weight_ptr = static_cast(dequant_weight->data); + uint8_t* weight_packed_ptr = static_cast(weight.data_ptr()); + __nv_fp8_e8m0* scale_ptr = + reinterpret_cast<__nv_fp8_e8m0*>(static_cast(scale.data_ptr())); + float* dequant_weight_ptr = static_cast(dequant_weight.data_ptr()); float fp4_lut[] = {0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, 0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}; - const auto num_packed_elements = weight.shape().Product(); + const auto num_packed_elements = weight.numel(); for (int packed_idx = 0; packed_idx < num_packed_elements; ++packed_idx) { int8_t weight_packed_data = weight_packed_ptr[packed_idx]; @@ -338,7 +340,7 @@ void mxfp4_dequantize_host(TensorView weight, TensorView scale, TensorView dequa int scale_n_idx = packed_idx / (k / 2); int scale_k_idx = ((packed_idx * 2) % k) / group_size; - float scale_ = static_cast(scale_ptr[scale_n_idx * scale->shape[1] + scale_k_idx]); + float scale_ = static_cast(scale_ptr[scale_n_idx * scale.size(1) + scale_k_idx]); dequant_weight_ptr[2 * packed_idx] = weight_low * scale_; dequant_weight_ptr[2 * packed_idx + 1] = weight_high * scale_; diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp index 00ab6b1c15..55e86df406 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp4Quantize.cpp @@ -47,10 +47,10 @@ void fp4_quantize(TensorView self, Optional const& globalScale, Tens float* globalScalePtr{nullptr}; if (globalScale.has_value()) { - globalScalePtr = static_cast(globalScale.value()->data); + globalScalePtr = static_cast(globalScale.value().data_ptr()); } - auto const& inputShape = self.shape(); + auto const& inputShape = self.sizes(); auto const& rank = inputShape.size(); TVM_FFI_ICHECK_GE(rank, 2) << "Input should be >=2D tensor."; @@ -68,24 +68,24 @@ void fp4_quantize(TensorView self, Optional const& globalScale, Tens : tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4) : tensorrt_llm::QuantizationSFLayout::LINEAR; -#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ - tensorrt_llm::kernels::invokeFP4Quantization( \ - 1, m, k, reinterpret_cast(self->data), globalScalePtr, \ - reinterpret_cast(valueE2M1->data), reinterpret_cast(scaleFP8SF->data), \ - sfUseUE8M0, layout, mMultiProcessorCount, /*mask=*/nullptr, enable_pdl, \ - get_stream(self->device)); +#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ + tensorrt_llm::kernels::invokeFP4Quantization( \ + 1, m, k, reinterpret_cast(self.data_ptr()), globalScalePtr, \ + reinterpret_cast(valueE2M1.data_ptr()), \ + reinterpret_cast(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \ + /*mask=*/nullptr, enable_pdl, get_stream(self.device())); if (sfUseUE8M0) { - if (self->dtype == dl_float16) { + if (self.dtype() == dl_float16) { LAUNCH_FP4_QUANTIZE_KERNEL(half, 32) - } else if (self->dtype == dl_bfloat16) { + } else if 
(self.dtype() == dl_bfloat16) { #ifdef ENABLE_BF16 LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 32) #else TVM_FFI_LOG_AND_THROW(NotImplementedError) << "BFloat16 must be enabled to quantize an bf16 tensor to fp4."; #endif - } else if (self->dtype == dl_float8_e4m3fn) { + } else if (self.dtype() == dl_float8_e4m3fn) { #ifdef ENABLE_FP8 LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3, 32) #else @@ -97,16 +97,16 @@ void fp4_quantize(TensorView self, Optional const& globalScale, Tens << "fp4_quantize only supports input tensor with dtypes fp16/bf16/e4m3."; } } else { - if (self->dtype == dl_float16) { + if (self.dtype() == dl_float16) { LAUNCH_FP4_QUANTIZE_KERNEL(half, 16) - } else if (self->dtype == dl_bfloat16) { + } else if (self.dtype() == dl_bfloat16) { #ifdef ENABLE_BF16 LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 16) #else TVM_FFI_LOG_AND_THROW(NotImplementedError) << "BFloat16 must be enabled to quantize an bf16 tensor to fp4."; #endif - } else if (self->dtype == dl_float8_e4m3fn) { + } else if (self.dtype() == dl_float8_e4m3fn) { #ifdef ENABLE_FP8 LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3, 16) #else @@ -140,7 +140,7 @@ void fp4_batched_quantize(TensorView self, Optional const& mask, Ten CHECK_INPUT_TYPE(globalScale, fp32_dtype); TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16"; - auto const& inputShape = self.shape(); + auto const& inputShape = self.sizes(); auto const& rank = inputShape.size(); TVM_FFI_ICHECK_EQ(rank, 3) << "Input should be 3D tensor."; @@ -152,9 +152,9 @@ void fp4_batched_quantize(TensorView self, Optional const& mask, Ten TVM_FFI_ICHECK_EQ(k % sfVecSize, 0); bool use_mask = mask.has_value(); if (use_mask) { - auto const& mask_rank = mask.value().shape().size(); + auto const& mask_rank = mask.value().ndim(); TVM_FFI_ICHECK_EQ(mask_rank, 1) << "Mask should be 1D tensor."; - TVM_FFI_ICHECK_EQ(mask.value().shape()[0], b); + TVM_FFI_ICHECK_EQ(mask.value().size(0), b); } std::vector outputShape(inputShape.begin(), inputShape.end()); @@ -163,24 +163,24 @@ void fp4_batched_quantize(TensorView self, Optional const& mask, Ten const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); auto layout = tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4; -#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ - tensorrt_llm::kernels::invokeFP4Quantization( \ - b, m, k, reinterpret_cast(self->data), static_cast(globalScale->data), \ - reinterpret_cast(valueE2M1->data), reinterpret_cast(scaleFP8SF->data), \ - sfUseUE8M0, layout, mMultiProcessorCount, \ - use_mask ? static_cast(mask.value()->data) : nullptr, /*enable_pdl=*/false, \ - get_stream(self->device)); +#define LAUNCH_FP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ + tensorrt_llm::kernels::invokeFP4Quantization( \ + b, m, k, reinterpret_cast(self.data_ptr()), static_cast(globalScale.data_ptr()), \ + reinterpret_cast(valueE2M1.data_ptr()), \ + reinterpret_cast(scaleFP8SF.data_ptr()), sfUseUE8M0, layout, mMultiProcessorCount, \ + use_mask ? 
static_cast(mask.value().data_ptr()) : nullptr, /*enable_pdl=*/false, \ + get_stream(self.device())); - if (self->dtype == dl_float16) { + if (self.dtype() == dl_float16) { LAUNCH_FP4_QUANTIZE_KERNEL(half, 16) - } else if (self->dtype == dl_bfloat16) { + } else if (self.dtype() == dl_bfloat16) { #ifdef ENABLE_BF16 LAUNCH_FP4_QUANTIZE_KERNEL(__nv_bfloat16, 16) #else TVM_FFI_LOG_AND_THROW(NotImplementedError) << "BFloat16 must be enabled to quantize an bf16 tensor to fp4."; #endif - } else if (self->dtype == dl_float8_e4m3fn) { + } else if (self.dtype() == dl_float8_e4m3fn) { #ifdef ENABLE_FP8 LAUNCH_FP4_QUANTIZE_KERNEL(__nv_fp8_e4m3, 16) #else @@ -205,9 +205,9 @@ void silu_and_mul_nvfp4_batched_quantize(TensorView const& self, TensorView cons CHECK_INPUT_TYPE(globalScale, fp32_dtype); TVM_FFI_ICHECK_EQ(sfVecSize, 16) << "sfVecSize can only be 16"; - auto const& inputShape = self.shape(); + auto const& inputShape = self.sizes(); auto const& rank = inputShape.size(); - auto const& mask_rank = mask.shape().size(); + auto const& mask_rank = mask.ndim(); TVM_FFI_ICHECK_EQ(rank, 3) << "Input should be 3D tensor."; TVM_FFI_ICHECK_EQ(mask_rank, 1) << "Mask should be 1D tensor."; @@ -218,7 +218,7 @@ void silu_and_mul_nvfp4_batched_quantize(TensorView const& self, TensorView cons int64_t k = k_by_2 / 2; TVM_FFI_ICHECK_EQ(k % sfVecSize, 0); - TVM_FFI_ICHECK_EQ(mask.shape()[0], b); + TVM_FFI_ICHECK_EQ(mask.size(0), b); std::vector outputShape(inputShape.begin(), inputShape.end()); outputShape[rank - 1] = k / 2; @@ -226,16 +226,17 @@ void silu_and_mul_nvfp4_batched_quantize(TensorView const& self, TensorView cons const thread_local int mMultiProcessorCount = tensorrt_llm::common::getMultiProcessorCount(); auto layout = tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4; -#define LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ - tensorrt_llm::kernels::invokeSiluAndMulFP4Quantization( \ - b, m, k_by_2, reinterpret_cast(self->data), static_cast(globalScale->data), \ - static_cast(mask->data), reinterpret_cast(valueE2M1->data), \ - reinterpret_cast(scaleFP8SF->data), layout, mMultiProcessorCount, \ - get_stream(self->device)); +#define LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(T, SF_VEC_SIZE) \ + tensorrt_llm::kernels::invokeSiluAndMulFP4Quantization( \ + b, m, k_by_2, reinterpret_cast(self.data_ptr()), \ + static_cast(globalScale.data_ptr()), static_cast(mask.data_ptr()), \ + reinterpret_cast(valueE2M1.data_ptr()), \ + reinterpret_cast(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, \ + get_stream(self.device())); - if (self->dtype == dl_float16) { + if (self.dtype() == dl_float16) { LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(half, 16) - } else if (self->dtype == dl_bfloat16) { + } else if (self.dtype() == dl_bfloat16) { #ifdef ENABLE_BF16 LAUNCH_SILU_AND_MUL_NVFP4_QUANTIZE_KERNEL(__nv_bfloat16, 16) #else diff --git a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp index f943dbcfb4..0af0b1f030 100644 --- a/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp +++ b/csrc/nv_internal/tensorrt_llm/thop/fp8Quantize.cpp @@ -36,7 +36,7 @@ void mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF TVM_FFI_ICHECK_EQ(alignment % SF_VEC_SIZE, 0) << "alignment must be divisible by SF_VEC_SIZE = 32"; - auto const& inputShape = input.shape(); + auto const& inputShape = input.sizes(); auto const& rank = inputShape.size(); TVM_FFI_ICHECK_GE(rank, 2) << "Input should be >=2D tensor."; @@ -53,15 +53,16 @@ void 
mxfp8_quantize(TensorView input, TensorView valMxFP8, TensorView scaleFP8SF auto const layout = isSfSwizzledLayout ? tensorrt_llm::QuantizationSFLayout::SWIZZLED_128x4 : tensorrt_llm::QuantizationSFLayout::LINEAR; -#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ - tensorrt_llm::kernels::invokeMxFP8Quantization( \ - 1, m, k, padded_k, reinterpret_cast(input->data), \ - reinterpret_cast(valMxFP8->data), reinterpret_cast(scaleFP8SF->data), \ - layout, mMultiProcessorCount, enable_pdl, get_stream(input->device)); +#define LAUNCH_MXFP8_QUANTIZE_KERNEL(T) \ + tensorrt_llm::kernels::invokeMxFP8Quantization( \ + 1, m, k, padded_k, reinterpret_cast(input.data_ptr()), \ + reinterpret_cast(valMxFP8.data_ptr()), \ + reinterpret_cast(scaleFP8SF.data_ptr()), layout, mMultiProcessorCount, enable_pdl, \ + get_stream(input.device())); - if (input->dtype == dl_float16) { + if (input.dtype() == dl_float16) { LAUNCH_MXFP8_QUANTIZE_KERNEL(half) - } else if (input->dtype == dl_bfloat16) { + } else if (input.dtype() == dl_bfloat16) { #ifdef ENABLE_BF16 LAUNCH_MXFP8_QUANTIZE_KERNEL(__nv_bfloat16) #else @@ -97,7 +98,7 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc int32_t const sf_vec_size = 32; auto fp32_dtype = DLDataType{kDLFloat, 32, 1}; CHECK_INPUT_TYPE(x_fp32, fp32_dtype); - auto data_shape = x_fp32.shape(); + auto data_shape = x_fp32.sizes(); TVM_FFI_ICHECK_EQ(data_shape.size(), 2) << "x_fp32 should be 2D tensor."; int num_tokens = data_shape[0]; int hidden_dim = data_shape[1]; @@ -109,11 +110,12 @@ void mxfp8_quantize_host(TensorView x_fp32, TensorView fp8_tensor, TensorView sc for (size_t ti = 0; ti < static_cast(data_shape[0]); ++ti) { for (int group = 0; group < groups_per_hidden_dim; ++group) { - float* fp32_ptr = static_cast(x_fp32->data) + ti * hidden_dim + group * sf_vec_size; + float* fp32_ptr = + static_cast(x_fp32.data_ptr()) + ti * hidden_dim + group * sf_vec_size; uint8_t* fp8_ptr = - static_cast(fp8_tensor->data) + ti * hidden_dim + group * sf_vec_size; + static_cast(fp8_tensor.data_ptr()) + ti * hidden_dim + group * sf_vec_size; - uint8_t* scale_ue8m08sf_ptr = static_cast(scale_tensor->data); + uint8_t* scale_ue8m08sf_ptr = static_cast(scale_tensor.data_ptr()); float local_amax = 0.0f; for (int ki = 0; ki < sf_vec_size; ++ki) { @@ -143,8 +145,8 @@ void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, int32_t const sf_vec_size = 32; CHECK_INPUT_TYPE(value_e4m3, dl_uint8); CHECK_INPUT_TYPE(scale_ue8m08sf, dl_uint8); - auto data_shape = value_e4m3.shape(); - auto scale_shape = scale_ue8m08sf.shape(); + auto data_shape = value_e4m3.sizes(); + auto scale_shape = scale_ue8m08sf.sizes(); TVM_FFI_ICHECK_EQ(data_shape.size(), 2) << "value_e4m3 should be 2D tensor."; TVM_FFI_ICHECK_EQ(scale_shape.size(), 1) << "scale_ue8m08sf should be 1D tensor."; @@ -157,10 +159,10 @@ void mxfp8_dequantize_host(TensorView value_e4m3, TensorView scale_ue8m08sf, for (size_t ti = 0; ti < static_cast(data_shape[0]); ++ti) { for (int group = 0; group < groups_per_hidden_dim; ++group) { float* float_ptr = - static_cast(float_tensor->data) + ti * hidden_dim + group * sf_vec_size; + static_cast(float_tensor.data_ptr()) + ti * hidden_dim + group * sf_vec_size; uint8_t* fp8_ptr = - static_cast(value_e4m3->data) + ti * hidden_dim + group * sf_vec_size; - uint8_t* scale_ue8m08sf_ptr = static_cast(scale_ue8m08sf->data); + static_cast(value_e4m3.data_ptr()) + ti * hidden_dim + group * sf_vec_size; + uint8_t* scale_ue8m08sf_ptr = static_cast(scale_ue8m08sf.data_ptr()); uint8_t 
fp8_scale = scale_ue8m08sf_ptr[computeSFIndex(ti, group, data_shape[0], groups_per_hidden_dim, layout)]; diff --git a/csrc/nvshmem_binding.cu b/csrc/nvshmem_binding.cu index 6d7ecacacb..16526725d1 100644 --- a/csrc/nvshmem_binding.cu +++ b/csrc/nvshmem_binding.cu @@ -31,8 +31,8 @@ using tvm::ffi::Shape; void get_unique_id(TensorView uid) { CHECK_CONTIGUOUS(uid); TVM_FFI_ICHECK_EQ(uid.numel() * get_element_size(uid), nvshmemx_uniqueid_t_size); - TVM_FFI_ICHECK_EQ(uid->device.device_type, kDLCPU); - nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast(uid->data); + TVM_FFI_ICHECK_EQ(uid.device().device_type, kDLCPU); + nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast(uid.data_ptr()); *uid_ptr = NVSHMEMX_UNIQUEID_INITIALIZER; nvshmemx_get_uniqueid(uid_ptr); } @@ -42,8 +42,8 @@ int64_t unique_id_size() { return nvshmemx_uniqueid_t_size; } int64_t init(TensorView uid, int64_t rank, int64_t world_size) { CHECK_CONTIGUOUS(uid); TVM_FFI_ICHECK_EQ(uid.numel() * get_element_size(uid), nvshmemx_uniqueid_t_size); - TVM_FFI_ICHECK_EQ(uid->device.device_type, kDLCPU); - nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast(uid->data); + TVM_FFI_ICHECK_EQ(uid.device().device_type, kDLCPU); + nvshmemx_uniqueid_t* uid_ptr = reinterpret_cast(uid.data_ptr()); nvshmemx_init_attr_t attr = NVSHMEMX_INIT_ATTR_INITIALIZER; nvshmemx_set_attr_uniqueid_args(rank, world_size, uid_ptr, &attr); return nvshmemx_init_attr(NVSHMEMX_INIT_WITH_UNIQUEID, &attr); @@ -79,12 +79,13 @@ void barrier_all_on_current_stream() { void alltoall(TensorView dest, TensorView source) { CHECK_CONTIGUOUS(dest); CHECK_CONTIGUOUS(source); - TVM_FFI_ICHECK_EQ(dest->dtype, source->dtype) << "dest and source must have the same dtype"; + TVM_FFI_ICHECK_EQ(dest.dtype(), source.dtype()) << "dest and source must have the same dtype"; - size_t nbytes = dest.numel() * get_element_size(dest) / dest->shape[0]; - cudaStream_t stream = get_stream(dest->device); - NVSHMEMCHECK(nvshmemx_alltoallmem_on_stream(NVSHMEM_TEAM_WORLD, static_cast(dest->data), - static_cast(source->data), nbytes, stream)); + size_t nbytes = dest.numel() * get_element_size(dest) / dest.size(0); + cudaStream_t stream = get_stream(dest.device()); + NVSHMEMCHECK( + nvshmemx_alltoallmem_on_stream(NVSHMEM_TEAM_WORLD, static_cast(dest.data_ptr()), + static_cast(source.data_ptr()), nbytes, stream)); } void fake_alltoall(TensorView dest, TensorView source) {} @@ -92,35 +93,35 @@ void fake_alltoall(TensorView dest, TensorView source) {} void sum_reduce(TensorView dest, TensorView source, int64_t nelems) { CHECK_CONTIGUOUS(dest); CHECK_CONTIGUOUS(source); - TVM_FFI_ICHECK_EQ(dest->dtype, source->dtype) << "dest and source must have the same dtype"; + TVM_FFI_ICHECK_EQ(dest.dtype(), source.dtype()) << "dest and source must have the same dtype"; // Add validation and conversion TVM_FFI_ICHECK_GE(nelems, 0) << "nelems must be non-negative, got " << nelems; TVM_FFI_ICHECK_LE(nelems, SIZE_MAX) << "nelems too large: " << nelems << " > " << SIZE_MAX; size_t nelems_size_t = static_cast(nelems); - cudaStream_t stream = get_stream(dest->device); + cudaStream_t stream = get_stream(dest.device()); - switch (encode_dlpack_dtype(dest->dtype)) { + switch (encode_dlpack_dtype(dest.dtype())) { case float16_code: // float16 NVSHMEMCHECK(nvshmemx_half_sum_reduce_on_stream( - NVSHMEM_TEAM_WORLD, static_cast(dest->data), - static_cast(source->data), nelems_size_t, stream)); + NVSHMEM_TEAM_WORLD, static_cast(dest.data_ptr()), + static_cast(source.data_ptr()), nelems_size_t, stream)); break; case float32_code: // float32 
NVSHMEMCHECK(nvshmemx_float_sum_reduce_on_stream( - NVSHMEM_TEAM_WORLD, static_cast(dest->data), static_cast(source->data), - nelems_size_t, stream)); + NVSHMEM_TEAM_WORLD, static_cast(dest.data_ptr()), + static_cast(source.data_ptr()), nelems_size_t, stream)); break; case bfloat16_code: // bfloat16 NVSHMEMCHECK(nvshmemx_bfloat16_sum_reduce_on_stream( - NVSHMEM_TEAM_WORLD, static_cast(dest->data), - static_cast(source->data), nelems_size_t, stream)); + NVSHMEM_TEAM_WORLD, static_cast(dest.data_ptr()), + static_cast(source.data_ptr()), nelems_size_t, stream)); break; default: TVM_FFI_LOG_AND_THROW(NotImplementedError) - << "Unsupported dtype for nvshmem_sum_reduce: " << dest->dtype; + << "Unsupported dtype for nvshmem_sum_reduce: " << dest.dtype(); } } @@ -132,21 +133,21 @@ void allreduce_on_stream_with_copy(TensorView dest_symm, TensorView source_symm, CHECK_CONTIGUOUS(source_symm); CHECK_CONTIGUOUS(dest_local); CHECK_CONTIGUOUS(source_local); - TVM_FFI_ICHECK_EQ(dest_symm->dtype, source_symm->dtype) + TVM_FFI_ICHECK_EQ(dest_symm.dtype(), source_symm.dtype()) << "dest_symm and source_symm must have the same dtype"; - TVM_FFI_ICHECK_EQ(dest_symm->dtype, source_local->dtype) + TVM_FFI_ICHECK_EQ(dest_symm.dtype(), source_local.dtype()) << "dest_symm and source_local must have the same dtype"; - TVM_FFI_ICHECK_EQ(dest_local->dtype, source_local->dtype) + TVM_FFI_ICHECK_EQ(dest_local.dtype(), source_local.dtype()) << "dest_local and source_local must have the same dtype"; - cudaStream_t stream = get_stream(source_symm->device); + cudaStream_t stream = get_stream(source_symm.device()); - cudaMemcpyAsync(source_symm->data, source_local->data, nelems * get_element_size(source_local), - cudaMemcpyDefault, stream); + cudaMemcpyAsync(source_symm.data_ptr(), source_local.data_ptr(), + nelems * get_element_size(source_local), cudaMemcpyDefault, stream); nvshmemx_barrier_on_stream(NVSHMEM_TEAM_WORLD, stream); sum_reduce(dest_symm, source_symm, nelems); - cudaMemcpyAsync(dest_local->data, dest_symm->data, nelems * get_element_size(dest_local), - cudaMemcpyDefault, stream); + cudaMemcpyAsync(dest_local.data_ptr(), dest_symm.data_ptr(), + nelems * get_element_size(dest_local), cudaMemcpyDefault, stream); cudaStreamSynchronize(stream); } diff --git a/csrc/page.cu b/csrc/page.cu index e6397f6150..614fc96640 100644 --- a/csrc/page.cu +++ b/csrc/page.cu @@ -44,11 +44,11 @@ void append_paged_kv_cache(TensorView append_key, TensorView append_value, Tenso CHECK_DIM(1, kv_indices); CHECK_DIM(1, kv_indptr); CHECK_DIM(1, kv_last_page_len); - unsigned int nnz = append_key->shape[0]; - unsigned int batch_size = kv_last_page_len->shape[0]; - TVM_FFI_ICHECK_EQ(kv_indptr->shape[0], batch_size + 1); - TVM_FFI_ICHECK_EQ(batch_indices->shape[0], nnz); - TVM_FFI_ICHECK_EQ(positions->shape[0], nnz); + unsigned int nnz = append_key.size(0); + unsigned int batch_size = kv_last_page_len.size(0); + TVM_FFI_ICHECK_EQ(kv_indptr.size(0), batch_size + 1); + TVM_FFI_ICHECK_EQ(batch_indices.size(0), nnz); + TVM_FFI_ICHECK_EQ(positions.size(0), nnz); CHECK_DEVICE(append_key, append_key); CHECK_DEVICE(append_value, append_key); CHECK_DEVICE(paged_k_cache, append_key); @@ -60,53 +60,56 @@ void append_paged_kv_cache(TensorView append_key, TensorView append_value, Tenso QKVLayout kv_layout = QKVLayout(layout); unsigned int num_heads, page_size, head_dim; - head_dim = paged_k_cache->shape[3]; + head_dim = paged_k_cache.size(3); if (kv_layout == QKVLayout::kHND) { - num_heads = paged_k_cache->shape[1]; - page_size = paged_k_cache->shape[2]; + 
num_heads = paged_k_cache.size(1); + page_size = paged_k_cache.size(2); } else { - page_size = paged_k_cache->shape[1]; - num_heads = paged_k_cache->shape[2]; + page_size = paged_k_cache.size(1); + num_heads = paged_k_cache.size(2); } // get kv_cache_strides - auto k_strides = paged_k_cache->strides; - auto v_strides = paged_v_cache->strides; - auto k_dim = paged_k_cache->ndim; - TVM_FFI_ICHECK(std::equal(k_strides, k_strides + k_dim, v_strides)) + auto k_strides = paged_k_cache.strides(); + auto v_strides = paged_v_cache.strides(); + auto k_dim = paged_k_cache.ndim(); + TVM_FFI_ICHECK(std::equal(k_strides.begin(), k_strides.begin() + k_dim, v_strides.begin())) << "k/v strides must be identical"; - auto append_k_strides = append_key->strides; + auto append_k_strides = append_key.strides(); auto append_k_stride_n = append_k_strides[0]; auto append_k_stride_h = append_k_strides[1]; - auto append_v_strides = append_value->strides; + auto append_v_strides = append_value.strides(); auto append_v_stride_n = append_v_strides[0]; auto append_v_stride_h = append_v_strides[1]; - TVM_FFI_ICHECK_EQ(append_key->shape[1], num_heads); - TVM_FFI_ICHECK_EQ(append_key->shape[2], head_dim); - TVM_FFI_ICHECK_EQ(append_value->shape[1], num_heads); - TVM_FFI_ICHECK_EQ(append_value->shape[2], head_dim); + TVM_FFI_ICHECK_EQ(append_key.size(1), num_heads); + TVM_FFI_ICHECK_EQ(append_key.size(2), head_dim); + TVM_FFI_ICHECK_EQ(append_value.size(1), num_heads); + TVM_FFI_ICHECK_EQ(append_value.size(2), head_dim); - cudaSetDevice(append_key->device.device_id); - const cudaStream_t stream = get_stream(append_key->device); - bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE(paged_k_cache->dtype, c_type, [&] { + cudaSetDevice(append_key.device().device_id); + const cudaStream_t stream = get_stream(append_key.device()); + bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE(paged_k_cache.dtype(), c_type, [&] { paged_kv_t paged_kv( num_heads, page_size, head_dim, batch_size, kv_layout, - static_cast(paged_k_cache->data), static_cast(paged_v_cache->data), - k_strides, static_cast(kv_indices->data), static_cast(kv_indptr->data), - static_cast(kv_last_page_len->data)); - cudaError_t status = AppendPagedKVCache( - paged_kv, static_cast(append_key->data), static_cast(append_value->data), - static_cast(batch_indices->data), static_cast(positions->data), nnz, - append_k_stride_n, append_k_stride_h, append_v_stride_n, append_v_stride_h, stream); + static_cast(paged_k_cache.data_ptr()), + static_cast(paged_v_cache.data_ptr()), k_strides.data(), + static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), + static_cast(kv_last_page_len.data_ptr())); + cudaError_t status = + AppendPagedKVCache(paged_kv, static_cast(append_key.data_ptr()), + static_cast(append_value.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), nnz, append_k_stride_n, + append_k_stride_h, append_v_stride_n, append_v_stride_h, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "AppendPagedKVCache failed with error: " << cudaGetErrorString(status); return true; }); TVM_FFI_ICHECK(success) << "AppendPagedKVCache failed to dispatch with dtype " - << paged_k_cache->dtype; + << paged_k_cache.dtype(); } void block_sparse_indices_to_vector_sparse_offsets( @@ -119,15 +122,16 @@ void block_sparse_indices_to_vector_sparse_offsets( CHECK_INPUT(vector_sparse_indptr); CHECK_INPUT(kv_len_arr); - cudaSetDevice(block_sparse_indices->device.device_id); - const cudaStream_t stream = get_stream(block_sparse_indices->device); + 
cudaSetDevice(block_sparse_indices.device().device_id); + const cudaStream_t stream = get_stream(block_sparse_indices.device()); cudaError_t status = BlockSparseIndicesToVectorSparseOffset( - static_cast(block_sparse_indices->data), - static_cast(block_sparse_indptr->data), - static_cast(vector_sparse_offsets->data), - static_cast(vector_sparse_indptr->data), static_cast(kv_len_arr->data), - stride_block, stride_n, batch_size, block_size, stream); + static_cast(block_sparse_indices.data_ptr()), + static_cast(block_sparse_indptr.data_ptr()), + static_cast(vector_sparse_offsets.data_ptr()), + static_cast(vector_sparse_indptr.data_ptr()), + static_cast(kv_len_arr.data_ptr()), stride_block, stride_n, batch_size, block_size, + stream); TVM_FFI_ICHECK(status == cudaSuccess) << "BlockSparseIndicesToVectorSparseOffset failed with error: " << cudaGetErrorString(status); @@ -156,11 +160,11 @@ void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, CHECK_DIM(1, kv_indices); CHECK_DIM(1, kv_indptr); CHECK_DIM(1, kv_last_page_len); - unsigned int nnz = append_ckv->shape[0]; - unsigned int batch_size = kv_last_page_len->shape[0]; - TVM_FFI_ICHECK_EQ(kv_indptr->shape[0], batch_size + 1); - TVM_FFI_ICHECK_EQ(batch_indices->shape[0], nnz); - TVM_FFI_ICHECK_EQ(positions->shape[0], nnz); + unsigned int nnz = append_ckv.size(0); + unsigned int batch_size = kv_last_page_len.size(0); + TVM_FFI_ICHECK_EQ(kv_indptr.size(0), batch_size + 1); + TVM_FFI_ICHECK_EQ(batch_indices.size(0), nnz); + TVM_FFI_ICHECK_EQ(positions.size(0), nnz); CHECK_DEVICE(append_ckv, append_ckv); CHECK_DEVICE(append_kpe, append_ckv); CHECK_DEVICE(ckv_cache, append_ckv); @@ -170,39 +174,41 @@ void append_paged_mla_kv_cache(TensorView append_ckv, TensorView append_kpe, CHECK_DEVICE(kv_last_page_len, append_ckv); unsigned int page_size, ckv_dim, kpe_dim; - page_size = ckv_cache->shape[1]; - ckv_dim = ckv_cache->shape[2]; - kpe_dim = kpe_cache->shape[2]; + page_size = ckv_cache.size(1); + ckv_dim = ckv_cache.size(2); + kpe_dim = kpe_cache.size(2); // get kv_cache_strides - auto ckv_strides = ckv_cache->strides; - auto kpe_strides = kpe_cache->strides; + auto ckv_strides = ckv_cache.strides(); + auto kpe_strides = kpe_cache.strides(); - auto append_ckv_strides = append_ckv->strides; + auto append_ckv_strides = append_ckv.strides(); auto append_ckv_stride_n = append_ckv_strides[0]; - auto append_kpe_strides = append_kpe->strides; + auto append_kpe_strides = append_kpe.strides(); auto append_kpe_stride_n = append_kpe_strides[0]; - TVM_FFI_ICHECK_EQ(append_ckv->shape[1], ckv_dim); - TVM_FFI_ICHECK_EQ(append_kpe->shape[1], kpe_dim); + TVM_FFI_ICHECK_EQ(append_ckv.size(1), ckv_dim); + TVM_FFI_ICHECK_EQ(append_kpe.size(1), kpe_dim); - cudaSetDevice(append_ckv->device.device_id); - const cudaStream_t stream = get_stream(append_ckv->device); - bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE(ckv_cache->dtype, c_type, [&] { + cudaSetDevice(append_ckv.device().device_id); + const cudaStream_t stream = get_stream(append_ckv.device()); + bool success = DISPATCH_DLPACK_DTYPE_TO_CTYPE(ckv_cache.dtype(), c_type, [&] { paged_kv_mla_t paged_mla_kv( - page_size, ckv_dim, kpe_dim, batch_size, static_cast(ckv_cache->data), ckv_strides, - static_cast(kpe_cache->data), kpe_strides, static_cast(kv_indices->data), - static_cast(kv_indptr->data), static_cast(kv_last_page_len->data)); - cudaError_t status = AppendPagedKVMlaCache(paged_mla_kv, static_cast(append_ckv->data), - static_cast(append_kpe->data), - static_cast(batch_indices->data), - 
static_cast(positions->data), nnz, - append_ckv_stride_n, append_kpe_stride_n, stream); + page_size, ckv_dim, kpe_dim, batch_size, static_cast(ckv_cache.data_ptr()), + ckv_strides.data(), static_cast(kpe_cache.data_ptr()), kpe_strides.data(), + static_cast(kv_indices.data_ptr()), static_cast(kv_indptr.data_ptr()), + static_cast(kv_last_page_len.data_ptr())); + cudaError_t status = + AppendPagedKVMlaCache(paged_mla_kv, static_cast(append_ckv.data_ptr()), + static_cast(append_kpe.data_ptr()), + static_cast(batch_indices.data_ptr()), + static_cast(positions.data_ptr()), nnz, append_ckv_stride_n, + append_kpe_stride_n, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "AppendPagedKVMlaCache failed with error: " << cudaGetErrorString(status); return true; }); TVM_FFI_ICHECK(success) << "AppendPagedKVMlaCache failed to dispatch with dtype " - << ckv_cache->dtype; + << ckv_cache.dtype(); } diff --git a/csrc/pod.cu b/csrc/pod.cu index b9036c7532..e38796e0e3 100644 --- a/csrc/pod.cu +++ b/csrc/pod.cu @@ -55,32 +55,32 @@ void pod_with_kv_cache_tensor( double logits_soft_cap_d, double sm_scale_d, double rope_rcp_scale_d, double rope_rcp_theta_d, bool enable_pdl) { // Prefill setup - unsigned int head_dim_qk = q_p->shape[2]; + unsigned int head_dim_qk = q_p.size(2); unsigned int kv_len_p, qo_len_p, num_kv_heads, num_qo_heads; QKVLayout kv_layout_p = static_cast(layout_p); - qo_len_p = q_p->shape[0]; - num_qo_heads = q_p->shape[1]; - uint32_t q_stride_n_p = q_p->strides[0], q_stride_h_p = q_p->strides[1], k_stride_n_p, - k_stride_h_p, v_stride_n_p, v_stride_h_p; + qo_len_p = q_p.size(0); + num_qo_heads = q_p.size(1); + uint32_t q_stride_n_p = q_p.stride(0), q_stride_h_p = q_p.stride(1), k_stride_n_p, k_stride_h_p, + v_stride_n_p, v_stride_h_p; if (kv_layout_p == QKVLayout::kNHD) { - kv_len_p = k_p->shape[0]; - num_kv_heads = k_p->shape[1]; - k_stride_n_p = k_p->strides[0]; - k_stride_h_p = k_p->strides[1]; - v_stride_n_p = v_p->strides[0]; - v_stride_h_p = v_p->strides[1]; + kv_len_p = k_p.size(0); + num_kv_heads = k_p.size(1); + k_stride_n_p = k_p.stride(0); + k_stride_h_p = k_p.stride(1); + v_stride_n_p = v_p.stride(0); + v_stride_h_p = v_p.stride(1); } else { - kv_len_p = k_p->shape[1]; - num_kv_heads = k_p->shape[0]; - k_stride_h_p = k_p->strides[0]; - k_stride_n_p = k_p->strides[1]; - v_stride_h_p = v_p->strides[0]; - v_stride_n_p = v_p->strides[1]; + kv_len_p = k_p.size(1); + num_kv_heads = k_p.size(0); + k_stride_h_p = k_p.stride(0); + k_stride_n_p = k_p.stride(1); + v_stride_h_p = v_p.stride(0); + v_stride_n_p = v_p.stride(1); } if (maybe_lse_p.has_value()) { const auto& lse = maybe_lse_p.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], qo_len_p); - TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads); + TVM_FFI_ICHECK_EQ(lse.size(0), qo_len_p); + TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads); } const MaskMode mask_mode_p = static_cast(mask_mode_code_p); @@ -89,20 +89,20 @@ void pod_with_kv_cache_tensor( PrefillPlanInfo plan_info; plan_info.FromVector(std::vector(plan_info_vec.begin(), plan_info_vec.end())); QKVLayout kv_layout_d = static_cast(layout_d); - int64_t batch_size = paged_kv_indptr_d->shape[0] - 1; - int64_t num_qo_heads_d = q_d->shape[1]; + int64_t batch_size = paged_kv_indptr_d.size(0) - 1; + int64_t num_qo_heads_d = q_d.size(1); TVM_FFI_ICHECK_EQ(num_qo_heads, num_qo_heads_d) << "POD currently requires same # Query heads for prefill and decode"; int64_t num_kv_heads_d, page_size_d; - uint32_t head_dim_qk_d = q_d->shape[2]; + uint32_t head_dim_qk_d = q_d.size(2); if (kv_layout_d == 
QKVLayout::kHND) { - num_kv_heads_d = paged_k_cache_d->shape[1]; - page_size_d = paged_k_cache_d->shape[2]; + num_kv_heads_d = paged_k_cache_d.size(1); + page_size_d = paged_k_cache_d.size(2); } else { - page_size_d = paged_k_cache_d->shape[1]; - num_kv_heads_d = paged_k_cache_d->shape[2]; + page_size_d = paged_k_cache_d.size(1); + num_kv_heads_d = paged_k_cache_d.size(2); } TVM_FFI_ICHECK_EQ(num_kv_heads, num_kv_heads_d) << "POD currently requires same # KV heads for prefill and decode; Prefill: " << num_kv_heads @@ -110,18 +110,18 @@ void pod_with_kv_cache_tensor( if (maybe_lse_d.has_value()) { const auto& lse = maybe_lse_d.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], q_d->shape[0]); - TVM_FFI_ICHECK_EQ(lse->shape[1], q_d->shape[1]); + TVM_FFI_ICHECK_EQ(lse.size(0), q_d.size(0)); + TVM_FFI_ICHECK_EQ(lse.size(1), q_d.size(1)); } - void* float_buffer_ptr = static_cast(float_workspace_buffer_d->data); - void* int_buffer_ptr = static_cast(int_workspace_buffer_d->data); + void* float_buffer_ptr = static_cast(float_workspace_buffer_d.data_ptr()); + void* int_buffer_ptr = static_cast(int_workspace_buffer_d.data_ptr()); const MaskMode mask_mode_d = static_cast(mask_mode_code_d); // get q_stride_n and q_stride_h - const auto q_stride_n_d = q_d->strides[0]; - const auto q_stride_h_d = q_d->strides[1]; + const auto q_stride_n_d = q_d.stride(0); + const auto q_stride_h_d = q_d.stride(1); // get kv_cache_strides const int64_t* kv_cache_strides_d = nullptr; @@ -133,8 +133,8 @@ void pod_with_kv_cache_tensor( } kv_cache_strides_d = k_strides_d.data(); - cudaSetDevice(float_workspace_buffer_d->device.device_id); - const cudaStream_t stream = get_stream(float_workspace_buffer_d->device); + cudaSetDevice(float_workspace_buffer_d.device().device_id); + const cudaStream_t stream = get_stream(float_workspace_buffer_d.device()); DISPATCH_context( MASK_MODE_P, MASK_MODE_D, DTypeQ, DTypeKV, HEAD_DIM_QK, USE_SLIDING_WINDOW_P, @@ -143,12 +143,12 @@ void pod_with_kv_cache_tensor( { // Make params a reference to prefill_params to set values PrefillParams& params = prefill_params; - params.q = static_cast(q_p->data); - params.k = static_cast(k_p->data); - params.v = static_cast(v_p->data); - params.o = static_cast(o_p->data); - params.lse = - maybe_lse_p.has_value() ? static_cast(maybe_lse_p.value()->data) : nullptr; + params.q = static_cast(q_p.data_ptr()); + params.k = static_cast(k_p.data_ptr()); + params.v = static_cast(v_p.data_ptr()); + params.o = static_cast(o_p.data_ptr()); + params.lse = maybe_lse_p.has_value() ? static_cast(maybe_lse_p.value().data_ptr()) + : nullptr; params.num_qo_heads = num_qo_heads; params.num_kv_heads = num_kv_heads; params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); @@ -164,12 +164,14 @@ void pod_with_kv_cache_tensor( params.window_left = window_left_p; params.partition_kv = false; - params.maybe_custom_mask = maybe_custom_mask_p.has_value() - ? static_cast(maybe_custom_mask_p.value()->data) - : nullptr; - params.maybe_alibi_slopes = maybe_alibi_slopes_p.has_value() - ? static_cast(maybe_alibi_slopes_p.value()->data) - : nullptr; + params.maybe_custom_mask = + maybe_custom_mask_p.has_value() + ? static_cast(maybe_custom_mask_p.value().data_ptr()) + : nullptr; + params.maybe_alibi_slopes = + maybe_alibi_slopes_p.has_value() + ? 
static_cast(maybe_alibi_slopes_p.value().data_ptr()) + : nullptr; params.logits_soft_cap = logits_soft_cap_p; params.sm_scale = sm_scale_p; params.rope_rcp_scale = rope_rcp_scale_p; @@ -181,20 +183,20 @@ void pod_with_kv_cache_tensor( float* tmp_s = nullptr; { DecodeParams& params = decode_params; - params.q = static_cast(q_d->data); + params.q = static_cast(q_d.data_ptr()); paged_kv_t paged_kv( num_kv_heads, page_size_d, HEAD_DIM_VO, batch_size, kv_layout_d, - static_cast(paged_k_cache_d->data), - static_cast(paged_v_cache_d->data), kv_cache_strides_d, - static_cast(paged_kv_indices_d->data), - static_cast(paged_kv_indptr_d->data), - static_cast(paged_kv_last_page_len_d->data)); + static_cast(paged_k_cache_d.data_ptr()), + static_cast(paged_v_cache_d.data_ptr()), kv_cache_strides_d, + static_cast(paged_kv_indices_d.data_ptr()), + static_cast(paged_kv_indptr_d.data_ptr()), + static_cast(paged_kv_last_page_len_d.data_ptr())); params.paged_kv = paged_kv; - params.q_indptr = static_cast(qo_indptr_d->data); - params.o = static_cast(o_d->data); + params.q_indptr = static_cast(qo_indptr_d.data_ptr()); + params.o = static_cast(o_d.data_ptr()); - params.lse = - maybe_lse_d.has_value() ? static_cast(maybe_lse_d.value()->data) : nullptr; + params.lse = maybe_lse_d.has_value() ? static_cast(maybe_lse_d.value().data_ptr()) + : nullptr; params.num_qo_heads = num_qo_heads; params.group_size = uint_fastdiv(num_qo_heads / paged_kv.num_heads); params.q_stride_n = q_stride_n_d; @@ -213,12 +215,14 @@ void pod_with_kv_cache_tensor( params.padded_batch_size = 0; params.partition_kv = false; - params.maybe_mask_indptr = maybe_mask_indptr_d.has_value() - ? static_cast(maybe_mask_indptr_d.value()->data) - : nullptr; - params.maybe_alibi_slopes = maybe_alibi_slopes_d.has_value() - ? static_cast(maybe_alibi_slopes_d.value()->data) - : nullptr; + params.maybe_mask_indptr = + maybe_mask_indptr_d.has_value() + ? static_cast(maybe_mask_indptr_d.value().data_ptr()) + : nullptr; + params.maybe_alibi_slopes = + maybe_alibi_slopes_d.has_value() + ? static_cast(maybe_alibi_slopes_d.value().data_ptr()) + : nullptr; params.logits_soft_cap = logits_soft_cap_d; params.sm_scale = sm_scale_d; params.rope_rcp_scale = rope_rcp_scale_d; @@ -264,7 +268,7 @@ void pod_with_kv_cache_tensor( cudaError_t status = flashinfer::PODWithKVCacheTensorDispatched< HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, USE_FP16_QK_REDUCTION, MASK_MODE_P, CTA_TILE_Q, MASK_MODE_D, PrefillAttentionVariant, DecodeAttentionVariant>( - prefill_params, static_cast(tmp_p->data), decode_params, tmp_v, tmp_s, + prefill_params, static_cast(tmp_p.data_ptr()), decode_params, tmp_v, tmp_s, enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "PODWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); diff --git a/csrc/quantization.cu b/csrc/quantization.cu index 557d79af79..ddea5c100e 100644 --- a/csrc/quantization.cu +++ b/csrc/quantization.cu @@ -21,13 +21,13 @@ using namespace flashinfer; void packbits(TensorView x, const std::string& bitorder, TensorView y) { CHECK_INPUT(x); - auto device = x->device; + auto device = x.device(); TVM_FFI_ICHECK(bitorder == "big" || bitorder == "little") << "bitorder must be 'big' or 'little'"; int64_t num_elements = x.numel(); - auto stream = get_stream(x->device); + auto stream = get_stream(x.device()); cudaError_t status = quantization::PackBits( - static_cast(x->data), static_cast(y->data), num_elements, + static_cast(x.data_ptr()), static_cast(y.data_ptr()), num_elements, bitorder == "big" ? 
quantization::BitOrder::kBig : quantization::BitOrder::kLittle, stream); TVM_FFI_ICHECK(status == cudaSuccess) @@ -42,14 +42,14 @@ void segment_packbits(TensorView x, TensorView input_indptr, TensorView output_i CHECK_DEVICE(input_indptr, x); CHECK_DEVICE(output_indptr, x); TVM_FFI_ICHECK(bitorder == "big" || bitorder == "little") << "bitorder must be 'big' or 'little'"; - unsigned int batch_size = input_indptr->shape[0] - 1; - TVM_FFI_ICHECK_EQ(output_indptr->shape[0], batch_size + 1) + unsigned int batch_size = input_indptr.size(0) - 1; + TVM_FFI_ICHECK_EQ(output_indptr.size(0), batch_size + 1) << "output_indptr must be on the same device as x"; - auto stream = get_stream(x->device); + auto stream = get_stream(x.device()); cudaError_t status = quantization::SegmentPackBits( - static_cast(x->data), static_cast(y->data), - static_cast(input_indptr->data), static_cast(output_indptr->data), - batch_size, + static_cast(x.data_ptr()), static_cast(y.data_ptr()), + static_cast(input_indptr.data_ptr()), + static_cast(output_indptr.data_ptr()), batch_size, bitorder == "big" ? quantization::BitOrder::kBig : quantization::BitOrder::kLittle, stream); } diff --git a/csrc/renorm.cu b/csrc/renorm.cu index d186b40eae..1e2aa45769 100644 --- a/csrc/renorm.cu +++ b/csrc/renorm.cu @@ -25,15 +25,15 @@ void top_p_renorm_probs(TensorView probs, TensorView renorm_probs, Optional maybe_top_p_arr, double top_p_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - unsigned int batch_size = probs->shape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = probs.size(0); + unsigned int vocab_size = probs.size(1); bool has_top_p_arr = maybe_top_p_arr.has_value(); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::TopPRenormProb( - static_cast(probs->data), static_cast(renorm_probs->data), - has_top_p_arr ? static_cast(maybe_top_p_arr.value()->data) : nullptr, batch_size, + static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), + has_top_p_arr ? static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, batch_size, top_p_val, vocab_size, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopPRenormProb failed with error code " << cudaGetErrorString(status); @@ -43,15 +43,15 @@ void top_k_renorm_probs(TensorView probs, TensorView renorm_probs, Optional maybe_top_k_arr, int64_t top_k_val) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - unsigned int batch_size = probs->shape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = probs.size(0); + unsigned int vocab_size = probs.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::TopKRenormProb( - static_cast(probs->data), static_cast(renorm_probs->data), - has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, batch_size, + static_cast(probs.data_ptr()), static_cast(renorm_probs.data_ptr()), + has_top_k_arr ? 
static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, stream); TVM_FFI_ICHECK(status == cudaSuccess) @@ -62,15 +62,15 @@ void top_k_mask_logits(TensorView logits, TensorView mask_logits, Optional maybe_top_k_arr, int64_t top_k_val) { CHECK_INPUT(logits); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) - unsigned int batch_size = logits->shape[0]; - unsigned int vocab_size = logits->shape[1]; + unsigned int batch_size = logits.size(0); + unsigned int vocab_size = logits.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); - cudaSetDevice(logits->device.device_id); - auto stream = get_stream(logits->device); + cudaSetDevice(logits.device().device_id); + auto stream = get_stream(logits.device()); cudaError_t status = sampling::TopKMaskLogits( - static_cast(logits->data), static_cast(mask_logits->data), - has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, batch_size, + static_cast(logits.data_ptr()), static_cast(mask_logits.data_ptr()), + has_top_k_arr ? static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, stream); TVM_FFI_ICHECK(status == cudaSuccess) diff --git a/csrc/rope.cu b/csrc/rope.cu index 8d17a439a3..3e85711918 100644 --- a/csrc/rope.cu +++ b/csrc/rope.cu @@ -34,34 +34,34 @@ void apply_rope(TensorView q, TensorView k, TensorView q_rope, TensorView k_rope CHECK_DIM(3, k); // k: (nnz, H_K, D) CHECK_DIM(1, indptr); // indptr: (B + 1) CHECK_DIM(1, offsets); // offsets: (B) - TVM_FFI_ICHECK_EQ(q->shape[0], k->shape[0]); - TVM_FFI_ICHECK_EQ(q->shape[2], k->shape[2]); - unsigned int num_qo_heads = q->shape[1]; - unsigned int num_kv_heads = k->shape[1]; - unsigned int head_dim = q->shape[2]; - unsigned int batch_size = offsets->shape[0]; - TVM_FFI_ICHECK_EQ(indptr->shape[0], batch_size + 1); - size_t q_stride_n = q->strides[0]; - size_t q_stride_h = q->strides[1]; - size_t k_stride_n = k->strides[0]; - size_t k_stride_h = k->strides[1]; - size_t q_rope_stride_n = q_rope->strides[0]; - size_t q_rope_stride_h = q_rope->strides[1]; - size_t k_rope_stride_n = k_rope->strides[0]; - size_t k_rope_stride_h = k_rope->strides[1]; - TVM_FFI_ICHECK_EQ(indptr->dtype, offsets->dtype); - - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q->dtype, c_type, [&] { - return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(indptr->dtype, c_idtype, [&] { + TVM_FFI_ICHECK_EQ(q.size(0), k.size(0)); + TVM_FFI_ICHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + TVM_FFI_ICHECK_EQ(indptr.size(0), batch_size + 1); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + TVM_FFI_ICHECK_EQ(indptr.dtype(), offsets.dtype()); + + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q.dtype(), c_type, [&] { + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(indptr.dtype(), c_idtype, [&] { cudaError_t status = BatchQKApplyRotary( - static_cast(q->data), static_cast(k->data), - static_cast(q_rope->data), static_cast(k_rope->data), - static_cast(indptr->data), 
static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, - interleave, rope_scale, rope_theta, stream); + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, + k_rope_stride_h, interleave, rope_scale, rope_theta, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "BatchQKApplyRotary failed with error code " << cudaGetErrorString(status); return true; @@ -79,29 +79,29 @@ void apply_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, TensorVie CHECK_DEVICE(q, k); CHECK_DIM(3, q); // q: (nnz, H_Q, D) CHECK_DIM(3, k); // k: (nnz, H_K, D) - TVM_FFI_ICHECK_EQ(q->shape[0], k->shape[0]); - TVM_FFI_ICHECK_EQ(q->shape[2], k->shape[2]); - unsigned int num_qo_heads = q->shape[1]; - unsigned int num_kv_heads = k->shape[1]; - unsigned int head_dim = q->shape[2]; - unsigned int nnz = q->shape[0]; - size_t q_stride_n = q->strides[0]; - size_t q_stride_h = q->strides[1]; - size_t k_stride_n = k->strides[0]; - size_t k_stride_h = k->strides[1]; - size_t q_rope_stride_n = q_rope->strides[0]; - size_t q_rope_stride_h = q_rope->strides[1]; - size_t k_rope_stride_n = k_rope->strides[0]; - size_t k_rope_stride_h = k_rope->strides[1]; - - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q->dtype, c_type, [&] { - return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids->dtype, c_idtype, [&] { + TVM_FFI_ICHECK_EQ(q.size(0), k.size(0)); + TVM_FFI_ICHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int nnz = q.size(0); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q.dtype(), c_type, [&] { + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] { cudaError_t status = BatchQKApplyRotaryPosIds( - static_cast(q->data), static_cast(k->data), - static_cast(q_rope->data), static_cast(k_rope->data), - static_cast(pos_ids->data), nnz, num_qo_heads, num_kv_heads, rotary_dim, + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, stream); @@ -128,31 +128,31 @@ void apply_rope_pos_ids_cos_sin_cache(TensorView q, TensorView k, TensorView q_r // cos_sin_cache: (max_seq_len, R) // First half of R is cos, second half is sin CHECK_DIM(2, cos_sin_cache); - TVM_FFI_ICHECK_EQ(q->shape[0], k->shape[0]); - TVM_FFI_ICHECK_EQ(q->shape[2], k->shape[2]); - unsigned int rotary_dim = 
cos_sin_cache->shape[1]; - unsigned int num_qo_heads = q->shape[1]; - unsigned int num_kv_heads = k->shape[1]; - unsigned int head_dim = q->shape[2]; - unsigned int nnz = q->shape[0]; - size_t q_stride_n = q->strides[0]; - size_t q_stride_h = q->strides[1]; - size_t k_stride_n = k->strides[0]; - size_t k_stride_h = k->strides[1]; - size_t q_rope_stride_n = q_rope->strides[0]; - size_t q_rope_stride_h = q_rope->strides[1]; - size_t k_rope_stride_n = k_rope->strides[0]; - size_t k_rope_stride_h = k_rope->strides[1]; - - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q->dtype, c_type, [&] { - return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids->dtype, c_idtype, [&] { + TVM_FFI_ICHECK_EQ(q.size(0), k.size(0)); + TVM_FFI_ICHECK_EQ(q.size(2), k.size(2)); + unsigned int rotary_dim = cos_sin_cache.size(1); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int nnz = q.size(0); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q.dtype(), c_type, [&] { + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] { cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache( - static_cast(q->data), static_cast(k->data), - static_cast(q_rope->data), static_cast(k_rope->data), - static_cast(cos_sin_cache->data), static_cast(pos_ids->data), nnz, - num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), static_cast(pos_ids.data_ptr()), + nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, stream); @@ -178,36 +178,36 @@ void apply_llama31_rope(TensorView q, TensorView k, TensorView q_rope, TensorVie CHECK_DIM(3, k); // k: (nnz, H_K, D) CHECK_DIM(1, indptr); // indptr: (B + 1) CHECK_DIM(1, offsets); // offsets: (B) - TVM_FFI_ICHECK_EQ(q->shape[0], k->shape[0]); - TVM_FFI_ICHECK_EQ(q->shape[2], k->shape[2]); - unsigned int num_qo_heads = q->shape[1]; - unsigned int num_kv_heads = k->shape[1]; - unsigned int head_dim = q->shape[2]; - unsigned int batch_size = offsets->shape[0]; - TVM_FFI_ICHECK_EQ(indptr->shape[0], batch_size + 1); - TVM_FFI_ICHECK_EQ(indptr->dtype, offsets->dtype); - size_t q_stride_n = q->strides[0]; - size_t q_stride_h = q->strides[1]; - size_t k_stride_n = k->strides[0]; - size_t k_stride_h = k->strides[1]; - size_t q_rope_stride_n = q_rope->strides[0]; - size_t q_rope_stride_h = q_rope->strides[1]; - size_t k_rope_stride_n = k_rope->strides[0]; - size_t k_rope_stride_h = k_rope->strides[1]; - TVM_FFI_ICHECK_EQ(indptr->dtype, offsets->dtype); - - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q->dtype, c_type, [&] { - return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(indptr->dtype, c_idtype, [&] { + TVM_FFI_ICHECK_EQ(q.size(0), k.size(0)); + 
TVM_FFI_ICHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int batch_size = offsets.size(0); + TVM_FFI_ICHECK_EQ(indptr.size(0), batch_size + 1); + TVM_FFI_ICHECK_EQ(indptr.dtype(), offsets.dtype()); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t k_rope_stride_h = k_rope.stride(1); + TVM_FFI_ICHECK_EQ(indptr.dtype(), offsets.dtype()); + + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q.dtype(), c_type, [&] { + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(indptr.dtype(), c_idtype, [&] { cudaError_t status = BatchQKApplyLlama31Rotary( - static_cast(q->data), static_cast(k->data), - static_cast(q_rope->data), static_cast(k_rope->data), - static_cast(indptr->data), static_cast(offsets->data), batch_size, - num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, - k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, - interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, - stream); + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(indptr.data_ptr()), static_cast(offsets.data_ptr()), + batch_size, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, + k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, + k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, + old_context_length, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "BatchQKApplyLlama31Rotary failed with error code " << cudaGetErrorString(status); @@ -227,29 +227,29 @@ void apply_llama31_rope_pos_ids(TensorView q, TensorView k, TensorView q_rope, T CHECK_DEVICE(q, k); CHECK_DIM(3, q); // q: (nnz, H_Q, D) CHECK_DIM(3, k); // k: (nnz, H_K, D) - TVM_FFI_ICHECK_EQ(q->shape[0], k->shape[0]); - TVM_FFI_ICHECK_EQ(q->shape[2], k->shape[2]); - unsigned int num_qo_heads = q->shape[1]; - unsigned int num_kv_heads = k->shape[1]; - unsigned int head_dim = q->shape[2]; - unsigned int nnz = q->shape[0]; - size_t q_stride_n = q->strides[0]; - size_t q_stride_h = q->strides[1]; - size_t k_stride_n = k->strides[0]; - size_t k_stride_h = k->strides[1]; - size_t q_rope_stride_n = q_rope->strides[0]; - size_t q_rope_stride_h = q_rope->strides[1]; - size_t k_rope_stride_n = k_rope->strides[0]; - size_t k_rope_stride_h = k_rope->strides[1]; - - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q->dtype, c_type, [&] { - return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids->dtype, c_idtype, [&] { + TVM_FFI_ICHECK_EQ(q.size(0), k.size(0)); + TVM_FFI_ICHECK_EQ(q.size(2), k.size(2)); + unsigned int num_qo_heads = q.size(1); + unsigned int num_kv_heads = k.size(1); + unsigned int head_dim = q.size(2); + unsigned int nnz = q.size(0); + size_t q_stride_n = q.stride(0); + size_t q_stride_h = q.stride(1); + size_t k_stride_n = k.stride(0); + size_t k_stride_h = k.stride(1); + size_t q_rope_stride_n = q_rope.stride(0); + size_t q_rope_stride_h = q_rope.stride(1); + size_t k_rope_stride_n = k_rope.stride(0); + size_t 
k_rope_stride_h = k_rope.stride(1); + + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q.dtype(), c_type, [&] { + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] { cudaError_t status = BatchQKApplyLlama31RotaryPosIds( - static_cast(q->data), static_cast(k->data), - static_cast(q_rope->data), static_cast(k_rope->data), - static_cast(pos_ids->data), nnz, num_qo_heads, num_kv_heads, rotary_dim, + static_cast(q.data_ptr()), static_cast(k.data_ptr()), + static_cast(q_rope.data_ptr()), static_cast(k_rope.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rotary_dim, head_dim, q_stride_n, q_stride_h, k_stride_n, k_stride_h, q_rope_stride_n, q_rope_stride_h, k_rope_stride_n, k_rope_stride_h, interleave, rope_scale, rope_theta, low_freq_factor, high_freq_factor, old_context_length, stream); @@ -285,29 +285,29 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope CHECK_INPUT(pos_ids); // Extract dimensions from tensor shapes (flexible) - uint32_t rope_dim = q_rope_in->shape[q_rope_in->ndim - 1]; - uint32_t no_rope_dim = q_nope_in->shape[q_nope_in->ndim - 1]; + uint32_t rope_dim = q_rope_in.size(-1); + uint32_t no_rope_dim = q_nope_in.size(-1); // Validate rope and no_rope dimensions are consistent - TVM_FFI_ICHECK_EQ(k_rope_in->shape[k_rope_in->ndim - 1], rope_dim); - TVM_FFI_ICHECK_EQ(k_nope_in->shape[k_nope_in->ndim - 1], no_rope_dim); - TVM_FFI_ICHECK_EQ(q_rope_out->shape[q_rope_out->ndim - 1], rope_dim); - TVM_FFI_ICHECK_EQ(k_rope_out->shape[k_rope_out->ndim - 1], rope_dim); - TVM_FFI_ICHECK_EQ(q_nope_out->shape[q_nope_out->ndim - 1], no_rope_dim); - TVM_FFI_ICHECK_EQ(k_nope_out->shape[k_nope_out->ndim - 1], no_rope_dim); - TVM_FFI_ICHECK_EQ(q_rope_in->dtype, k_rope_in->dtype); - TVM_FFI_ICHECK_EQ(q_rope_in->dtype, q_nope_in->dtype); - TVM_FFI_ICHECK_EQ(q_rope_in->dtype, k_nope_in->dtype); - TVM_FFI_ICHECK_EQ(q_rope_out->dtype, k_rope_out->dtype); - TVM_FFI_ICHECK_EQ(q_rope_out->dtype, q_nope_out->dtype); - TVM_FFI_ICHECK_EQ(q_rope_out->dtype, k_nope_out->dtype); + TVM_FFI_ICHECK_EQ(k_rope_in.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(k_nope_in.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_out.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(k_rope_out.size(-1), rope_dim); + TVM_FFI_ICHECK_EQ(q_nope_out.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(k_nope_out.size(-1), no_rope_dim); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_rope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), q_nope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_in.dtype(), k_nope_in.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_out.dtype(), k_rope_out.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_out.dtype(), q_nope_out.dtype()); + TVM_FFI_ICHECK_EQ(q_rope_out.dtype(), k_nope_out.dtype()); // Validate supported input data types (float16 or bfloat16) - TVM_FFI_ICHECK(q_rope_in->dtype == dl_float16 || q_rope_in->dtype == dl_bfloat16) + TVM_FFI_ICHECK(q_rope_in.dtype() == dl_float16 || q_rope_in.dtype() == dl_bfloat16) << "Input dtype must be float16 or bfloat16"; // Validate supported output quantization data types (float8_e4m3fn or float8_e5m2) - TVM_FFI_ICHECK(q_rope_out->dtype == dl_float8_e4m3fn || q_rope_out->dtype == dl_float8_e5m2) + TVM_FFI_ICHECK(q_rope_out.dtype() == dl_float8_e4m3fn || q_rope_out.dtype() == dl_float8_e5m2) << "Output dtype must be float8_e4m3fn or float8_e5m2"; // Q tensors are always 3D: (nnz, num_qo_heads, rope_dim/no_rope_dim) @@ -318,7 +318,7 @@ void 
rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope // K tensors can be 2D (MLA) or 3D (GQA/MHA) uint32_t num_kv_heads; - if (k_rope_in->ndim == 2) { + if (k_rope_in.ndim() == 2) { // MLA case: k_rope_in: (nnz, rope_dim), k_nope_in: (nnz, no_rope_dim) CHECK_DIM(2, k_rope_in); CHECK_DIM(2, k_nope_in); @@ -331,51 +331,51 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope CHECK_DIM(3, k_nope_in); CHECK_DIM(3, k_rope_out); CHECK_DIM(3, k_nope_out); - num_kv_heads = k_rope_in->shape[1]; + num_kv_heads = k_rope_in.size(1); } - uint32_t nnz = q_rope_in->shape[0]; - uint32_t num_qo_heads = q_rope_in->shape[1]; + uint32_t nnz = q_rope_in.size(0); + uint32_t num_qo_heads = q_rope_in.size(1); // Validate consistent dimensions across all tensors - TVM_FFI_ICHECK_EQ(q_nope_in->shape[0], nnz); - TVM_FFI_ICHECK_EQ(k_rope_in->shape[0], nnz); - TVM_FFI_ICHECK_EQ(k_nope_in->shape[0], nnz); - TVM_FFI_ICHECK_EQ(q_rope_out->shape[0], nnz); - TVM_FFI_ICHECK_EQ(k_rope_out->shape[0], nnz); - TVM_FFI_ICHECK_EQ(q_nope_out->shape[0], nnz); - TVM_FFI_ICHECK_EQ(k_nope_out->shape[0], nnz); + TVM_FFI_ICHECK_EQ(q_nope_in.size(0), nnz); + TVM_FFI_ICHECK_EQ(k_rope_in.size(0), nnz); + TVM_FFI_ICHECK_EQ(k_nope_in.size(0), nnz); + TVM_FFI_ICHECK_EQ(q_rope_out.size(0), nnz); + TVM_FFI_ICHECK_EQ(k_rope_out.size(0), nnz); + TVM_FFI_ICHECK_EQ(q_nope_out.size(0), nnz); + TVM_FFI_ICHECK_EQ(k_nope_out.size(0), nnz); // Validate Q tensor head dimensions are consistent - TVM_FFI_ICHECK_EQ(q_nope_in->shape[1], num_qo_heads); - TVM_FFI_ICHECK_EQ(q_rope_out->shape[1], num_qo_heads); - TVM_FFI_ICHECK_EQ(q_nope_out->shape[1], num_qo_heads); + TVM_FFI_ICHECK_EQ(q_nope_in.size(1), num_qo_heads); + TVM_FFI_ICHECK_EQ(q_rope_out.size(1), num_qo_heads); + TVM_FFI_ICHECK_EQ(q_nope_out.size(1), num_qo_heads); // Validate K tensor head dimensions (if 3D) - if (k_rope_in->ndim == 3) { - TVM_FFI_ICHECK_EQ(k_nope_in->shape[1], num_kv_heads); - TVM_FFI_ICHECK_EQ(k_rope_out->shape[1], num_kv_heads); - TVM_FFI_ICHECK_EQ(k_nope_out->shape[1], num_kv_heads); + if (k_rope_in.ndim() == 3) { + TVM_FFI_ICHECK_EQ(k_nope_in.size(1), num_kv_heads); + TVM_FFI_ICHECK_EQ(k_rope_out.size(1), num_kv_heads); + TVM_FFI_ICHECK_EQ(k_nope_out.size(1), num_kv_heads); } - const uint32_t q_rope_in_stride_n = q_rope_in->strides[0]; - const uint32_t q_rope_in_stride_h = q_rope_in->strides[1]; - const uint32_t q_nope_in_stride_n = q_nope_in->strides[0]; - const uint32_t q_nope_in_stride_h = q_nope_in->strides[1]; - const uint32_t q_rope_out_stride_n = q_rope_out->strides[0]; - const uint32_t q_rope_out_stride_h = q_rope_out->strides[1]; - const uint32_t q_nope_out_stride_n = q_nope_out->strides[0]; - const uint32_t q_nope_out_stride_h = q_nope_out->strides[1]; + const uint32_t q_rope_in_stride_n = q_rope_in.stride(0); + const uint32_t q_rope_in_stride_h = q_rope_in.stride(1); + const uint32_t q_nope_in_stride_n = q_nope_in.stride(0); + const uint32_t q_nope_in_stride_h = q_nope_in.stride(1); + const uint32_t q_rope_out_stride_n = q_rope_out.stride(0); + const uint32_t q_rope_out_stride_h = q_rope_out.stride(1); + const uint32_t q_nope_out_stride_n = q_nope_out.stride(0); + const uint32_t q_nope_out_stride_h = q_nope_out.stride(1); // K tensor strides depend on dimensionality uint32_t k_rope_in_stride, k_nope_in_stride, k_rope_out_stride, k_nope_out_stride; uint32_t k_rope_in_stride_h, k_nope_in_stride_h, k_rope_out_stride_h, k_nope_out_stride_h; - if (k_rope_in->ndim == 2) { + if (k_rope_in.ndim() == 2) { // 2D K tensors 
(MLA): only have batch stride - k_rope_in_stride = k_rope_in->strides[0]; - k_nope_in_stride = k_nope_in->strides[0]; - k_rope_out_stride = k_rope_out->strides[0]; - k_nope_out_stride = k_nope_out->strides[0]; + k_rope_in_stride = k_rope_in.stride(0); + k_nope_in_stride = k_nope_in.stride(0); + k_rope_out_stride = k_rope_out.stride(0); + k_nope_out_stride = k_nope_out.stride(0); // For 2D tensors, head stride is the same as batch stride (shared K/V) k_rope_in_stride_h = k_rope_in_stride; k_nope_in_stride_h = k_nope_in_stride; @@ -383,29 +383,30 @@ void rope_quantize(TensorView q_rope_in, TensorView k_rope_in, TensorView q_nope k_nope_out_stride_h = k_nope_out_stride; } else { // 3D K tensors (GQA/MHA): have both batch and head strides - k_rope_in_stride = k_rope_in->strides[0]; - k_rope_in_stride_h = k_rope_in->strides[1]; - k_nope_in_stride = k_nope_in->strides[0]; - k_nope_in_stride_h = k_nope_in->strides[1]; - k_rope_out_stride = k_rope_out->strides[0]; - k_rope_out_stride_h = k_rope_out->strides[1]; - k_nope_out_stride = k_nope_out->strides[0]; - k_nope_out_stride_h = k_nope_out->strides[1]; + k_rope_in_stride = k_rope_in.stride(0); + k_rope_in_stride_h = k_rope_in.stride(1); + k_nope_in_stride = k_nope_in.stride(0); + k_nope_in_stride_h = k_nope_in.stride(1); + k_rope_out_stride = k_rope_out.stride(0); + k_rope_out_stride_h = k_rope_out.stride(1); + k_nope_out_stride = k_nope_out.stride(0); + k_nope_out_stride_h = k_nope_out.stride(1); } - cudaSetDevice(q_rope_in->device.device_id); - const cudaStream_t stream = get_stream(q_rope_in->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in->dtype, c_type, [&] { - return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out->dtype, c_quant_type, [&] { - return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids->dtype, c_idtype, [&] { + cudaSetDevice(q_rope_in.device().device_id); + const cudaStream_t stream = get_stream(q_rope_in.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(q_rope_in.dtype(), c_type, [&] { + return DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP8(q_rope_out.dtype(), c_quant_type, [&] { + return DISPATCH_DLPACK_IDTYPE_TO_CTYPE(pos_ids.dtype(), c_idtype, [&] { cudaError_t status = RopeQuantize( - static_cast(q_rope_in->data), static_cast(k_rope_in->data), - static_cast(q_nope_in->data), static_cast(k_nope_in->data), - static_cast(q_rope_out->data), - static_cast(k_rope_out->data), - static_cast(q_nope_out->data), - static_cast(k_nope_out->data), static_cast(cos_sin_cache->data), - static_cast(pos_ids->data), nnz, num_qo_heads, num_kv_heads, rope_dim, + static_cast(q_rope_in.data_ptr()), static_cast(k_rope_in.data_ptr()), + static_cast(q_nope_in.data_ptr()), static_cast(k_nope_in.data_ptr()), + static_cast(q_rope_out.data_ptr()), + static_cast(k_rope_out.data_ptr()), + static_cast(q_nope_out.data_ptr()), + static_cast(k_nope_out.data_ptr()), + static_cast(cos_sin_cache.data_ptr()), + static_cast(pos_ids.data_ptr()), nnz, num_qo_heads, num_kv_heads, rope_dim, no_rope_dim, q_rope_in_stride_n, q_rope_in_stride_h, q_nope_in_stride_n, q_nope_in_stride_h, q_rope_out_stride_n, q_rope_out_stride_h, q_nope_out_stride_n, q_nope_out_stride_h, k_rope_in_stride, k_rope_in_stride_h, k_nope_in_stride, diff --git a/csrc/sampling.cu b/csrc/sampling.cu index d17295d091..7210ecb440 100644 --- a/csrc/sampling.cu +++ b/csrc/sampling.cu @@ -27,18 +27,19 @@ void softmax(TensorView workspace_buffer, TensorView logits, TensorView output, CHECK_INPUT(logits); CHECK_INPUT(output); CHECK_DIM(2, logits); // logits: (batch_size, vocab_size) - unsigned int batch_size = 
logits->shape[0]; - unsigned int vocab_size = logits->shape[1]; + unsigned int batch_size = logits.size(0); + unsigned int vocab_size = logits.size(1); bool has_temperature_arr = maybe_temperature_arr.has_value(); - cudaSetDevice(logits->device.device_id); - auto stream = get_stream(logits->device); + cudaSetDevice(logits.device().device_id); + auto stream = get_stream(logits.device()); cudaError_t status = sampling::OnlineSoftmax( - static_cast(logits->data), static_cast(output->data), batch_size, vocab_size, - has_temperature_arr ? static_cast(maybe_temperature_arr.value()->data) : nullptr, - temperature_val, workspace_buffer->data, - get_element_size(workspace_buffer) * workspace_buffer->shape[0], enable_pdl, stream); + static_cast(logits.data_ptr()), static_cast(output.data_ptr()), batch_size, + vocab_size, + has_temperature_arr ? static_cast(maybe_temperature_arr.value().data_ptr()) : nullptr, + temperature_val, workspace_buffer.data_ptr(), + get_element_size(workspace_buffer) * workspace_buffer.size(0), enable_pdl, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "OnlineSoftmax failed with error code " << cudaGetErrorString(status); } @@ -47,14 +48,14 @@ void sampling_from_logits(TensorView logits, TensorView output, Optionalshape[0]; - unsigned int vocab_size = logits->shape[1]; + unsigned int batch_size = output.size(0); + unsigned int vocab_size = logits.size(1); - cudaSetDevice(logits->device.device_id); - auto stream = get_stream(logits->device); + cudaSetDevice(logits.device().device_id); + auto stream = get_stream(logits.device()); cudaError_t status = sampling::SamplingFromLogits( - static_cast(logits->data), static_cast(output->data), - maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr, + static_cast(logits.data_ptr()), static_cast(output.data_ptr()), + maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SamplingFromLogits failed with error code " << cudaGetErrorString(status); @@ -64,14 +65,14 @@ void sampling_from_probs(TensorView probs, TensorView output, Optionalshape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = output.size(0); + unsigned int vocab_size = probs.size(1); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::SamplingFromProb( - static_cast(probs->data), static_cast(output->data), - maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr, + static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + maybe_indices.has_value() ? 
static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SamplingFromProbs failed with error code " << cudaGetErrorString(status); @@ -83,16 +84,16 @@ void top_p_sampling_from_probs(TensorView probs, TensorView output, bool deterministic, uint64_t philox_seed, uint64_t philox_offset) { CHECK_INPUT(probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) - unsigned int batch_size = output->shape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = output.size(0); + unsigned int vocab_size = probs.size(1); bool has_top_p_arr = maybe_top_p_arr.has_value(); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::TopPSamplingFromProb( - static_cast(probs->data), static_cast(output->data), - maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr, - has_top_p_arr ? static_cast(maybe_top_p_arr.value()->data) : nullptr, batch_size, + static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, + has_top_p_arr ? static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, batch_size, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopPSamplingFromProbs failed with error code " << cudaGetErrorString(status); @@ -107,16 +108,16 @@ void top_k_sampling_from_probs(TensorView probs, TensorView output, CHECK_DEVICE(output, probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(1, output); // output: (batch_size) - unsigned int batch_size = output->shape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = output.size(0); + unsigned int vocab_size = probs.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::TopKSamplingFromProb( - static_cast(probs->data), static_cast(output->data), - maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr, - has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, batch_size, + static_cast(probs.data_ptr()), static_cast(output.data_ptr()), + maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, + has_top_k_arr ? 
static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, batch_size, top_k_val, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "TopKSamplingFromProbs failed with error code " << cudaGetErrorString(status); @@ -131,17 +132,17 @@ void min_p_sampling_from_probs(TensorView probs, TensorView output, CHECK_DEVICE(output, probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(1, output); // output: (batch_size) - unsigned int batch_size = output->shape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = output.size(0); + unsigned int vocab_size = probs.size(1); bool has_min_p_arr = maybe_min_p_arr.has_value(); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::MinPSamplingFromProb( - static_cast(probs->data), - has_min_p_arr ? static_cast(maybe_min_p_arr.value()->data) : nullptr, - static_cast(output->data), - maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr, + static_cast(probs.data_ptr()), + has_min_p_arr ? static_cast(maybe_min_p_arr.value().data_ptr()) : nullptr, + static_cast(output.data_ptr()), + maybe_indices.has_value() ? static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, min_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "MinPSamplingFromProb failed with error code " << cudaGetErrorString(status); @@ -158,19 +159,19 @@ void top_k_top_p_sampling_from_probs(TensorView probs, TensorView output, CHECK_DEVICE(output, probs); CHECK_DIM(2, probs); // probs: (batch_size, vocab_size) CHECK_DIM(1, output); // output: (batch_size) - unsigned int batch_size = output->shape[0]; - unsigned int vocab_size = probs->shape[1]; + unsigned int batch_size = output.size(0); + unsigned int vocab_size = probs.size(1); bool has_top_k_arr = maybe_top_k_arr.has_value(); bool has_top_p_arr = maybe_top_p_arr.has_value(); - cudaSetDevice(probs->device.device_id); - auto stream = get_stream(probs->device); + cudaSetDevice(probs.device().device_id); + auto stream = get_stream(probs.device()); cudaError_t status = sampling::TopKTopPSamplingFromProb( - static_cast(probs->data), - has_top_k_arr ? static_cast(maybe_top_k_arr.value()->data) : nullptr, - has_top_p_arr ? static_cast(maybe_top_p_arr.value()->data) : nullptr, - static_cast(output->data), - maybe_indices.has_value() ? static_cast(maybe_indices.value()->data) : nullptr, + static_cast(probs.data_ptr()), + has_top_k_arr ? static_cast(maybe_top_k_arr.value().data_ptr()) : nullptr, + has_top_p_arr ? static_cast(maybe_top_p_arr.value().data_ptr()) : nullptr, + static_cast(output.data_ptr()), + maybe_indices.has_value() ? 
static_cast(maybe_indices.value().data_ptr()) : nullptr, batch_size, top_k_val, top_p_val, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) @@ -190,24 +191,24 @@ void chain_speculative_sampling(TensorView draft_probs, TensorView draft_token_i CHECK_DIM(3, draft_probs); // draft_probs: (batch_size, num_speculate_tokens, vocab_size) CHECK_DIM(2, draft_token_ids); // draft_token_ids: (batch_size, num_speculate_tokens) CHECK_DIM(3, target_probs); // target_probs: (batch_size, num_speculate_tokens + 1, vocab_size) - unsigned int batch_size = draft_probs->shape[0]; - unsigned int num_speculate_tokens = draft_probs->shape[1]; - unsigned int vocab_size = draft_probs->shape[2]; - TVM_FFI_ICHECK_EQ(batch_size, draft_token_ids->shape[0]); - TVM_FFI_ICHECK_EQ(batch_size, target_probs->shape[0]); - TVM_FFI_ICHECK_EQ(num_speculate_tokens + 1, target_probs->shape[1]); - TVM_FFI_ICHECK_EQ(vocab_size, target_probs->shape[2]); - TVM_FFI_ICHECK_EQ(batch_size, output_accepted_token_num->shape[0]); - TVM_FFI_ICHECK_EQ(batch_size, output_emitted_draft_token_num->shape[0]); - - cudaSetDevice(draft_probs->device.device_id); - auto stream = get_stream(draft_probs->device); + unsigned int batch_size = draft_probs.size(0); + unsigned int num_speculate_tokens = draft_probs.size(1); + unsigned int vocab_size = draft_probs.size(2); + TVM_FFI_ICHECK_EQ(batch_size, draft_token_ids.size(0)); + TVM_FFI_ICHECK_EQ(batch_size, target_probs.size(0)); + TVM_FFI_ICHECK_EQ(num_speculate_tokens + 1, target_probs.size(1)); + TVM_FFI_ICHECK_EQ(vocab_size, target_probs.size(2)); + TVM_FFI_ICHECK_EQ(batch_size, output_accepted_token_num.size(0)); + TVM_FFI_ICHECK_EQ(batch_size, output_emitted_draft_token_num.size(0)); + + cudaSetDevice(draft_probs.device().device_id); + auto stream = get_stream(draft_probs.device()); cudaError_t status = sampling::ChainSpeculativeSampling( - static_cast(draft_probs->data), static_cast(draft_token_ids->data), - static_cast(target_probs->data), static_cast(output_token_ids->data), - static_cast(output_accepted_token_num->data), - static_cast(output_emitted_draft_token_num->data), batch_size, num_speculate_tokens, - vocab_size, deterministic, philox_seed, philox_offset, stream); + static_cast(draft_probs.data_ptr()), static_cast(draft_token_ids.data_ptr()), + static_cast(target_probs.data_ptr()), static_cast(output_token_ids.data_ptr()), + static_cast(output_accepted_token_num.data_ptr()), + static_cast(output_emitted_draft_token_num.data_ptr()), batch_size, + num_speculate_tokens, vocab_size, deterministic, philox_seed, philox_offset, stream); TVM_FFI_ICHECK(status == cudaSuccess) << "ChainSpeculativeSampling failed with error code " << cudaGetErrorString(status); diff --git a/csrc/single_decode.cu b/csrc/single_decode.cu index dfb9107191..84f36d8388 100644 --- a/csrc/single_decode.cu +++ b/csrc/single_decode.cu @@ -44,26 +44,26 @@ void single_decode_with_kv_cache(TensorView q, TensorView k, TensorView v, Tenso CHECK_DIM(3, k); CHECK_DIM(3, v); CHECK_SHAPE(k, v); - TVM_FFI_ICHECK_EQ(q->shape[1], k->shape[2]); - TVM_FFI_ICHECK_EQ(v->dtype, k->dtype); - unsigned int num_qo_heads = q->shape[0]; - unsigned int head_dim_qk = q->shape[1]; - unsigned int head_dim_vo = v->shape[2]; + TVM_FFI_ICHECK_EQ(q.size(1), k.size(2)); + TVM_FFI_ICHECK_EQ(v.dtype(), k.dtype()); + unsigned int num_qo_heads = q.size(0); + unsigned int head_dim_qk = q.size(1); + unsigned int head_dim_vo = v.size(2); unsigned int kv_len, num_kv_heads; QKVLayout kv_layout = 
static_cast(layout); if (kv_layout == QKVLayout::kNHD) { - kv_len = k->shape[0]; - num_kv_heads = k->shape[1]; + kv_len = k.size(0); + num_kv_heads = k.size(1); } else { - num_kv_heads = k->shape[0]; - kv_len = k->shape[1]; + num_kv_heads = k.size(0); + kv_len = k.size(1); } TVM_FFI_ICHECK_EQ(num_qo_heads % num_kv_heads, 0) << "num_qo_heads(" << num_qo_heads << ") must be divisible by num_kv_heads(" << num_kv_heads << ")"; - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); TVM_FFI_ICHECK_EQ(head_dim_qk, head_dim_vo) << "CUDA cores template only supports equal head dim for QK and VO, please use tensor " @@ -74,11 +74,12 @@ void single_decode_with_kv_cache(TensorView q, TensorView k, TensorView v, Tenso USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { Params params; - params.q = static_cast(q->data); - params.k = static_cast(k->data); - params.v = static_cast(v->data); - params.o = static_cast(o->data); - params.lse = maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; + params.q = static_cast(q.data_ptr()); + params.k = static_cast(k.data_ptr()); + params.v = static_cast(v.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.lse = + maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr; params.kv_len = kv_len; params.num_qo_heads = num_qo_heads; params.num_kv_heads = num_kv_heads; @@ -95,7 +96,7 @@ void single_decode_with_kv_cache(TensorView q, TensorView k, TensorView v, Tenso cudaError_t status = flashinfer::SingleDecodeWithKVCacheDispatched( - params, static_cast(tmp->data), stream); + params, static_cast(tmp.data_ptr()), stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SingleDecodeWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); diff --git a/csrc/single_prefill.cu b/csrc/single_prefill.cu index 46d24a71ea..1d66db57ba 100644 --- a/csrc/single_prefill.cu +++ b/csrc/single_prefill.cu @@ -38,38 +38,38 @@ void single_prefill_with_kv_cache(ffi::TensorView q, ffi::TensorView k, ffi::Ten ffi::TensorView tmp, ffi::TensorView o, Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { - unsigned int head_dim_qk = q->shape[2]; + unsigned int head_dim_qk = q.size(2); unsigned int kv_len, qo_len, num_kv_heads, num_qo_heads; QKVLayout kv_layout = static_cast(layout); - qo_len = q->shape[0]; - num_qo_heads = q->shape[1]; - uint32_t q_stride_n = q->strides[0], q_stride_h = q->strides[1], k_stride_n, k_stride_h, - v_stride_n, v_stride_h; + qo_len = q.size(0); + num_qo_heads = q.size(1); + uint32_t q_stride_n = q.stride(0), q_stride_h = q.stride(1), k_stride_n, k_stride_h, v_stride_n, + v_stride_h; if (kv_layout == QKVLayout::kNHD) { - kv_len = k->shape[0]; - num_kv_heads = k->shape[1]; - k_stride_n = k->strides[0]; - k_stride_h = k->strides[1]; - v_stride_n = v->strides[0]; - v_stride_h = v->strides[1]; + kv_len = k.size(0); + num_kv_heads = k.size(1); + k_stride_n = k.stride(0); + k_stride_h = k.stride(1); + v_stride_n = v.stride(0); + v_stride_h = v.stride(1); } else { - kv_len = k->shape[1]; - num_kv_heads = k->shape[0]; - k_stride_h = k->strides[0]; - k_stride_n = k->strides[1]; - v_stride_h = v->strides[0]; - v_stride_n = v->strides[1]; + kv_len = k.size(1); + num_kv_heads = k.size(0); + k_stride_h = k.stride(0); + k_stride_n = k.stride(1); + v_stride_h = v.stride(0); + v_stride_n = v.stride(1); } if 
(maybe_lse.has_value()) { const auto& lse = maybe_lse.value(); - TVM_FFI_ICHECK_EQ(lse->shape[0], qo_len); - TVM_FFI_ICHECK_EQ(lse->shape[1], num_qo_heads); + TVM_FFI_ICHECK_EQ(lse.size(0), qo_len); + TVM_FFI_ICHECK_EQ(lse.size(1), num_qo_heads); } const MaskMode mask_mode = static_cast(mask_mode_code); - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, @@ -77,11 +77,12 @@ void single_prefill_with_kv_cache(ffi::TensorView q, ffi::TensorView k, ffi::Ten [&] { Params params; - params.q = static_cast(q->data); - params.k = static_cast(k->data); - params.v = static_cast(v->data); - params.o = static_cast(o->data); - params.lse = maybe_lse.has_value() ? static_cast(maybe_lse.value()->data) : nullptr; + params.q = static_cast(q.data_ptr()); + params.k = static_cast(k.data_ptr()); + params.v = static_cast(v.data_ptr()); + params.o = static_cast(o.data_ptr()); + params.lse = + maybe_lse.has_value() ? static_cast(maybe_lse.value().data_ptr()) : nullptr; params.num_qo_heads = num_qo_heads; params.num_kv_heads = num_kv_heads; params.group_size = uint_fastdiv(num_qo_heads / num_kv_heads); @@ -102,7 +103,7 @@ void single_prefill_with_kv_cache(ffi::TensorView q, ffi::TensorView k, ffi::Ten cudaError_t status = flashinfer::SinglePrefillWithKVCacheDispatched< HEAD_DIM_QK, HEAD_DIM_VO, POS_ENCODING_MODE, /*use_fp16_qk_reduction=*/USE_FP16_QK_REDUCTION, MASK_MODE, AttentionVariant>( - params, static_cast(tmp->data), stream); + params, static_cast(tmp.data_ptr()), stream); TVM_FFI_ICHECK(status == cudaSuccess) << "SinglePrefillWithKVCache kernel launch failed, error: " << cudaGetErrorString(status); diff --git a/csrc/single_prefill_fp8_sm90.cu b/csrc/single_prefill_fp8_sm90.cu index 6051d2c737..d6830e2863 100644 --- a/csrc/single_prefill_fp8_sm90.cu +++ b/csrc/single_prefill_fp8_sm90.cu @@ -36,45 +36,45 @@ void single_prefill_with_kv_cache_sm90(ffi::TensorView q, ffi::TensorView k, ffi ffi::TensorView tmp, ffi::TensorView o, Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { - unsigned int head_dim_qk = q->shape[2]; - unsigned int head_dim_vo = v->shape[2]; - unsigned int num_qo_heads = q->shape[1]; - unsigned int qo_len = q->shape[0]; + unsigned int head_dim_qk = q.size(2); + unsigned int head_dim_vo = v.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); QKVLayout kv_layout = static_cast(layout); - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); const MaskMode mask_mode = static_cast(mask_mode_code); DISPATCH_context( DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { Params params; - params.q_ptr = static_cast(q->data); - params.k_ptr = static_cast(k->data); - params.v_ptr = static_cast(v->data); - params.o_ptr = static_cast(o->data); + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); params.lse_ptr = - maybe_lse.has_value() ? 
(static_cast(maybe_lse.value()->data)) : nullptr; - params.q_stride_n = q->strides[0]; - params.q_stride_h = q->strides[1]; - params.o_stride_n = o->strides[0]; - params.o_stride_h = o->strides[1]; + maybe_lse.has_value() ? (static_cast(maybe_lse.value().data_ptr())) : nullptr; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = k->strides[0]; - params.k_stride_h = k->strides[1]; - params.v_stride_n = v->strides[0]; - params.v_stride_h = v->strides[1]; + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); } else { - params.k_stride_h = k->strides[0]; - params.k_stride_n = k->strides[1]; - params.v_stride_h = v->strides[0]; - params.v_stride_n = v->strides[1]; + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); } - params.qo_len = q->shape[0]; - params.kv_len = k->shape[0]; - params.num_qo_heads = q->shape[1]; - params.num_kv_heads = k->shape[1]; + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); params.causal = mask_mode == MaskMode::kCausal; params.group_size = params.num_qo_heads / params.num_kv_heads; diff --git a/csrc/single_prefill_sm90.cu b/csrc/single_prefill_sm90.cu index 8cfe8e3713..b1f757d462 100644 --- a/csrc/single_prefill_sm90.cu +++ b/csrc/single_prefill_sm90.cu @@ -36,45 +36,45 @@ void single_prefill_with_kv_cache_sm90(ffi::TensorView q, ffi::TensorView k, ffi ffi::TensorView tmp, ffi::TensorView o, Optional maybe_lse, int64_t mask_mode_code, int64_t layout, int64_t window_left ADDITIONAL_FUNC_PARAMS) { - unsigned int head_dim_qk = q->shape[2]; - unsigned int head_dim_vo = v->shape[2]; - unsigned int num_qo_heads = q->shape[1]; - unsigned int qo_len = q->shape[0]; + unsigned int head_dim_qk = q.size(2); + unsigned int head_dim_vo = v.size(2); + unsigned int num_qo_heads = q.size(1); + unsigned int qo_len = q.size(0); QKVLayout kv_layout = static_cast(layout); - cudaSetDevice(q->device.device_id); - const cudaStream_t stream = get_stream(q->device); + cudaSetDevice(q.device().device_id); + const cudaStream_t stream = get_stream(q.device()); const MaskMode mask_mode = static_cast(mask_mode_code); DISPATCH_context(DTypeQ, DTypeKV, DTypeO, IdType, MASK_MODE, HEAD_DIM_QK, HEAD_DIM_VO, USE_SLIDING_WINDOW, USE_LOGITS_SOFT_CAP, AttentionVariant, Params, [&] { Params params; - params.q_ptr = static_cast(q->data); - params.k_ptr = static_cast(k->data); - params.v_ptr = static_cast(v->data); - params.o_ptr = static_cast(o->data); + params.q_ptr = static_cast(q.data_ptr()); + params.k_ptr = static_cast(k.data_ptr()); + params.v_ptr = static_cast(v.data_ptr()); + params.o_ptr = static_cast(o.data_ptr()); params.lse_ptr = maybe_lse.has_value() - ? (static_cast(maybe_lse.value()->data)) + ? 
(static_cast(maybe_lse.value().data_ptr())) : nullptr; - params.q_stride_n = q->strides[0]; - params.q_stride_h = q->strides[1]; - params.o_stride_n = o->strides[0]; - params.o_stride_h = o->strides[1]; + params.q_stride_n = q.stride(0); + params.q_stride_h = q.stride(1); + params.o_stride_n = o.stride(0); + params.o_stride_h = o.stride(1); if (kv_layout == QKVLayout::kNHD) { - params.k_stride_n = k->strides[0]; - params.k_stride_h = k->strides[1]; - params.v_stride_n = v->strides[0]; - params.v_stride_h = v->strides[1]; + params.k_stride_n = k.stride(0); + params.k_stride_h = k.stride(1); + params.v_stride_n = v.stride(0); + params.v_stride_h = v.stride(1); } else { - params.k_stride_h = k->strides[0]; - params.k_stride_n = k->strides[1]; - params.v_stride_h = v->strides[0]; - params.v_stride_n = v->strides[1]; + params.k_stride_h = k.stride(0); + params.k_stride_n = k.stride(1); + params.v_stride_h = v.stride(0); + params.v_stride_n = v.stride(1); } - params.qo_len = q->shape[0]; - params.kv_len = k->shape[0]; - params.num_qo_heads = q->shape[1]; - params.num_kv_heads = k->shape[1]; + params.qo_len = q.size(0); + params.kv_len = k.size(0); + params.num_qo_heads = q.size(1); + params.num_kv_heads = k.size(1); params.causal = mask_mode == MaskMode::kCausal; params.group_size = params.num_qo_heads / params.num_kv_heads; params.window_left = window_left; diff --git a/csrc/tgv_gemm.cu b/csrc/tgv_gemm.cu index 73025e22f5..75019221c6 100644 --- a/csrc/tgv_gemm.cu +++ b/csrc/tgv_gemm.cu @@ -119,12 +119,12 @@ void tgv_gemm_impl(input_type* mat1_ptr, input_type* mat2_ptr, output_type* outp void tgv_gemm(TensorView mat1, TensorView mat2, Optional bias, int64_t tactic, TensorView out, bool pdl) { // Input validation - TVM_FFI_ICHECK_EQ(mat1->device.device_type, kDLCUDA) << "mat1 tensor must be on CUDA"; - TVM_FFI_ICHECK_EQ(mat2->device.device_type, kDLCUDA) << "mat2 tensor must be on CUDA"; - TVM_FFI_ICHECK_EQ(mat1->ndim, 2) << "mat1 tensor must be 2D (M, K)"; - TVM_FFI_ICHECK_EQ(mat2->ndim, 2) << "mat2 tensor must be 2D (K, N)"; - TVM_FFI_ICHECK_EQ(mat1->shape[1], mat2->shape[0]) << "mat1.K must match mat2.K"; - TVM_FFI_ICHECK_EQ(mat1->dtype, mat2->dtype) << "mat1 and mat2 must have the same dtype"; + TVM_FFI_ICHECK_EQ(mat1.device().device_type, kDLCUDA) << "mat1 tensor must be on CUDA"; + TVM_FFI_ICHECK_EQ(mat2.device().device_type, kDLCUDA) << "mat2 tensor must be on CUDA"; + TVM_FFI_ICHECK_EQ(mat1.ndim(), 2) << "mat1 tensor must be 2D (M, K)"; + TVM_FFI_ICHECK_EQ(mat2.ndim(), 2) << "mat2 tensor must be 2D (K, N)"; + TVM_FFI_ICHECK_EQ(mat1.size(1), mat2.size(0)) << "mat1.K must match mat2.K"; + TVM_FFI_ICHECK_EQ(mat1.dtype(), mat2.dtype()) << "mat1 and mat2 must have the same dtype"; // No heuristic for now, we use 64x8 with 8 DMA stages as the default tactic. 
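// Illustrative sketch (annotation, not from the upstream patch): the accessor
// migration this diff applies throughout, written against hypothetical tensors
// `t` (a TensorView) and `opt` (an optional TensorView argument); every
// spelling below already appears elsewhere in the patch:
//   t->shape[i]           becomes   t.size(i)
//   t->strides[i]         becomes   t.stride(i)
//   t->ndim               becomes   t.ndim()
//   t->data               becomes   t.data_ptr()
//   t->dtype              becomes   t.dtype()
//   t->device.device_id   becomes   t.device().device_id
//   opt.value()->data     becomes   opt.value().data_ptr()
// For example, the workspace-size computation used in earlier hunks becomes:
//   size_t bytes = get_element_size(buf) * buf.size(0);  // buf: a hypothetical workspace TensorView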
if (tactic == -1) { @@ -144,9 +144,9 @@ void tgv_gemm(TensorView mat1, TensorView mat2, Optional bias, int64 TVM_FFI_ICHECK(cta_m == 64 || cta_m == 128) << "cta_m must be one of: 64, 128"; // Get dimensions - int M = mat1->shape[0]; - int K = mat1->shape[1]; - int N = mat2->shape[1]; + int M = mat1.size(0); + int K = mat1.size(1); + int N = mat2.size(1); int64_t element_size = get_element_size(mat1); TVM_FFI_ICHECK(int64_t(M) * N * element_size < std::numeric_limits::max()) @@ -158,46 +158,46 @@ void tgv_gemm(TensorView mat1, TensorView mat2, Optional bias, int64 // validity check for bias if (bias.has_value()) { - TVM_FFI_ICHECK_EQ(bias.value()->device.device_type, kDLCUDA) << "Bias tensor must be on CUDA"; - TVM_FFI_ICHECK_EQ(bias.value()->ndim, 1) << "Bias tensor must be 1D (M,)"; - TVM_FFI_ICHECK_EQ(bias.value()->shape[0], M) << "Bias tensor must have M elements"; - TVM_FFI_ICHECK_EQ(bias.value()->dtype, mat1->dtype) + TVM_FFI_ICHECK_EQ(bias.value().device().device_type, kDLCUDA) << "Bias tensor must be on CUDA"; + TVM_FFI_ICHECK_EQ(bias.value().ndim(), 1) << "Bias tensor must be 1D (M,)"; + TVM_FFI_ICHECK_EQ(bias.value().size(0), M) << "Bias tensor must have M elements"; + TVM_FFI_ICHECK_EQ(bias.value().dtype(), mat1.dtype()) << "Bias tensor must have the same dtype as input matrices"; - TVM_FFI_ICHECK_EQ(bias.value()->strides[0], 1) << "Bias tensor must be M contiguous"; + TVM_FFI_ICHECK_EQ(bias.value().stride(0), 1) << "Bias tensor must be M contiguous"; } // Create output tensor [N, M] row major - TVM_FFI_ICHECK_EQ(out->shape[0], N); - TVM_FFI_ICHECK_EQ(out->shape[1], M); - TVM_FFI_ICHECK_EQ(out->dtype, mat1->dtype); + TVM_FFI_ICHECK_EQ(out.size(0), N); + TVM_FFI_ICHECK_EQ(out.size(1), M); + TVM_FFI_ICHECK_EQ(out.dtype(), mat1.dtype()); // manually calculate the L stride // A [M, K] row major - int stride_A_M = mat1->strides[0]; - int stride_A_K = mat1->strides[1]; + int stride_A_M = mat1.stride(0); + int stride_A_K = mat1.stride(1); int stride_A_L = M * K; // B [K, N] column major - int stride_B_N = mat2->strides[1]; - int stride_B_K = mat2->strides[0]; + int stride_B_N = mat2.stride(1); + int stride_B_K = mat2.stride(0); int stride_B_L = N * K; // original C [N, M] row major - int stride_C_M = out->strides[1]; - int stride_C_N = out->strides[0]; + int stride_C_M = out.stride(1); + int stride_C_N = out.stride(0); int stride_C_L = M * N; // Get CUDA stream - cudaStream_t stream = get_stream(out->device); + cudaStream_t stream = get_stream(out.device()); // Dispatch based on dtype - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(mat1->dtype, c_type, [&] { + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(mat1.dtype(), c_type, [&] { using cutlass_input_type = flashinfer::cutlass_dtype_t; using cutlass_output_type = flashinfer::cutlass_dtype_t; - cutlass_input_type* mat1_ptr = static_cast(mat1->data); - cutlass_input_type* mat2_ptr = static_cast(mat2->data); - cutlass_output_type* output_ptr = static_cast(out->data); + cutlass_input_type* mat1_ptr = static_cast(mat1.data_ptr()); + cutlass_input_type* mat2_ptr = static_cast(mat2.data_ptr()); + cutlass_output_type* output_ptr = static_cast(out.data_ptr()); cutlass_output_type* bias_ptr = - bias.has_value() ? static_cast(bias.value()->data) : nullptr; + bias.has_value() ? 
static_cast(bias.value().data_ptr()) : nullptr; tgv_gemm_impl( mat1_ptr, mat2_ptr, output_ptr, bias_ptr, M, N, K, stride_A_M, stride_A_K, stride_A_L, @@ -205,20 +205,14 @@ void tgv_gemm(TensorView mat1, TensorView mat2, Optional bias, int64 dma_stage, pdl, stream); return true; }); - - // original C is [N, M] row major - // after transpose, it's [M, N] column major - // the storage is unchanged, only the logical coordinates are changed - std::swap(out->shape[0], out->shape[1]); - std::swap(out->strides[0], out->strides[1]); } // Keep backward compatibility functions void bf16_gemm(TensorView mat1, TensorView mat2, std::optional bias, int64_t tactic, TensorView out, bool pdl) { // Check that inputs are bfloat16 for backward compatibility - TVM_FFI_ICHECK_EQ(mat1->dtype, dl_bfloat16) << "mat1 tensor must be bfloat16"; - TVM_FFI_ICHECK_EQ(mat2->dtype, dl_bfloat16) << "mat2 tensor must be bfloat16"; + TVM_FFI_ICHECK_EQ(mat1.dtype(), dl_bfloat16) << "mat1 tensor must be bfloat16"; + TVM_FFI_ICHECK_EQ(mat2.dtype(), dl_bfloat16) << "mat2 tensor must be bfloat16"; tgv_gemm(mat1, mat2, bias, tactic, out, pdl); } diff --git a/csrc/trtllm_allreduce.cu b/csrc/trtllm_allreduce.cu index 7546658282..e985f50403 100644 --- a/csrc/trtllm_allreduce.cu +++ b/csrc/trtllm_allreduce.cu @@ -96,11 +96,11 @@ void trtllm_custom_all_reduce(TensorView in, TensorView out, int64_t tp_size, in Optional lamport_peer_comm_buffer_ptrs_1, Optional lamport_peer_comm_buffer_ptrs_2) { AllReduceFusionOp fusion_op = static_cast(fusion_op_code); - cudaSetDevice(in->device.device_id); - auto stream = get_stream(in->device); + cudaSetDevice(in.device().device_id); + auto stream = get_stream(in.device()); // TODO(zihao): review dispatch type - support fp16, bf16 only - DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(in->dtype, c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(in.dtype(), c_type, [&] { // TODO(yingyi): remove type template here (used to check if lamport is supported) int64_t message_size = in.numel(); int64_t hidden_size = in.numel() / token_num; @@ -109,29 +109,31 @@ void trtllm_custom_all_reduce(TensorView in, TensorView out, int64_t tp_size, in params.elts_total = message_size; params.local_rank = tp_rank; params.ranks_per_node = tp_size; - params.local_input_buffer_ptr = in->data; - params.local_output_buffer_ptr = out->data; + params.local_input_buffer_ptr = in.data_ptr(); + params.local_output_buffer_ptr = out.data_ptr(); params.barrier_flag = flag_value; // add fusion params - params.fusion_params.bias_buffer = bias.has_value() ? bias.value()->data : nullptr; - params.fusion_params.residual_buffer = residual.has_value() ? residual.value()->data : nullptr; + params.fusion_params.bias_buffer = bias.has_value() ? bias.value().data_ptr() : nullptr; + params.fusion_params.residual_buffer = + residual.has_value() ? residual.value().data_ptr() : nullptr; params.fusion_params.hidden_size = hidden_size; - params.fusion_params.weight_buffer = weight.has_value() ? weight.value()->data : nullptr; + params.fusion_params.weight_buffer = weight.has_value() ? weight.value().data_ptr() : nullptr; params.fusion_params.weight_buffer_pre_residual_norm = - weight_pre_residual_norm.has_value() ? weight_pre_residual_norm.value()->data : nullptr; + weight_pre_residual_norm.has_value() ? weight_pre_residual_norm.value().data_ptr() + : nullptr; params.fusion_params.eps = eps.has_value() ? eps.value() : 1e-5f; params.fusion_params.intermediate_buffer = - intermediate_buffer.has_value() ? 
intermediate_buffer.value()->data : nullptr; + intermediate_buffer.has_value() ? intermediate_buffer.value().data_ptr() : nullptr; // add ipc buffer pointers for (int i = 0; i < tp_size; ++i) { params.peer_comm_buffer_ptrs[i] = - reinterpret_cast(static_cast(peer_comm_buffer_ptrs->data)[i]); + reinterpret_cast(static_cast(peer_comm_buffer_ptrs.data_ptr())[i]); params.peer_barrier_ptrs_in[i] = - reinterpret_cast(static_cast(peer_barrier_ptrs_in->data)[i]); + reinterpret_cast(static_cast(peer_barrier_ptrs_in.data_ptr())[i]); params.peer_barrier_ptrs_out[i] = - reinterpret_cast(static_cast(peer_barrier_ptrs_out->data)[i]); + reinterpret_cast(static_cast(peer_barrier_ptrs_out.data_ptr())[i]); } if (lamport_peer_comm_buffer_ptrs_0.has_value()) { @@ -143,12 +145,12 @@ void trtllm_custom_all_reduce(TensorView in, TensorView out, int64_t tp_size, in "is provided"; for (int i = 0; i < tp_size; ++i) { params.fusion_params.lamport_peer_comm_buffer_ptrs[i] = reinterpret_cast( - static_cast(lamport_peer_comm_buffer_ptrs_0.value()->data)[i]); + static_cast(lamport_peer_comm_buffer_ptrs_0.value().data_ptr())[i]); params.fusion_params.lamport_peer_comm_buffer_ptrs[i + tp_size] = reinterpret_cast( - static_cast(lamport_peer_comm_buffer_ptrs_1.value()->data)[i]); + static_cast(lamport_peer_comm_buffer_ptrs_1.value().data_ptr())[i]); params.fusion_params.lamport_peer_comm_buffer_ptrs[i + tp_size * 2] = reinterpret_cast( - static_cast(lamport_peer_comm_buffer_ptrs_2.value()->data)[i]); + static_cast(lamport_peer_comm_buffer_ptrs_2.value().data_ptr())[i]); } } diff --git a/csrc/trtllm_allreduce_fusion.cu b/csrc/trtllm_allreduce_fusion.cu index 4b465c7ffc..c0f74194e4 100644 --- a/csrc/trtllm_allreduce_fusion.cu +++ b/csrc/trtllm_allreduce_fusion.cu @@ -37,43 +37,46 @@ void trtllm_allreduce_fusion(TensorView allreduce_in, int64_t world_size, int64_ Optional quant_out, Optional scale_out, Optional rms_gamma, Optional rms_eps, Optional scale_factor, Optional layout_code) { - cudaSetDevice(allreduce_in->device.device_id); + cudaSetDevice(allreduce_in.device().device_id); // todo(Yingyi): add dispatch for float and bfloat16 - DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(allreduce_in->dtype, c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(allreduce_in.dtype(), c_type, [&] { AllReduceFusionParams params; params.nranks = world_size; params.rank = world_rank; params.size = token_num * hidden_size; params.hidden_dim = hidden_size; - params.workspace = reinterpret_cast(workspace_ptrs->data); + params.workspace = reinterpret_cast(workspace_ptrs.data_ptr()); // todo(Yingyi): update optional params // todo(Yingyi): add params check with pattern - params.allreduce_in = reinterpret_cast(allreduce_in->data); - params.allreduce_out = - allreduce_out.has_value() ? reinterpret_cast(allreduce_out.value()->data) : nullptr; + params.allreduce_in = reinterpret_cast(allreduce_in.data_ptr()); + params.allreduce_out = allreduce_out.has_value() + ? reinterpret_cast(allreduce_out.value().data_ptr()) + : nullptr; params.residual_in = - residual_in.has_value() ? reinterpret_cast(residual_in.value()->data) : nullptr; - params.residual_out = - residual_out.has_value() ? reinterpret_cast(residual_out.value()->data) : nullptr; + residual_in.has_value() ? reinterpret_cast(residual_in.value().data_ptr()) : nullptr; + params.residual_out = residual_out.has_value() + ? reinterpret_cast(residual_out.value().data_ptr()) + : nullptr; params.norm_out = - norm_out.has_value() ? reinterpret_cast(norm_out.value()->data) : nullptr; + norm_out.has_value() ? 
reinterpret_cast(norm_out.value().data_ptr()) : nullptr; params.quant_out = - quant_out.has_value() ? reinterpret_cast(quant_out.value()->data) : nullptr; + quant_out.has_value() ? reinterpret_cast(quant_out.value().data_ptr()) : nullptr; params.scale_out = - scale_out.has_value() ? reinterpret_cast(scale_out.value()->data) : nullptr; + scale_out.has_value() ? reinterpret_cast(scale_out.value().data_ptr()) : nullptr; params.rms_gamma = - rms_gamma.has_value() ? reinterpret_cast(rms_gamma.value()->data) : nullptr; + rms_gamma.has_value() ? reinterpret_cast(rms_gamma.value().data_ptr()) : nullptr; params.rms_eps = rms_eps.has_value() ? static_cast(rms_eps.value()) : 0.0f; - params.scale_factor = - scale_factor.has_value() ? reinterpret_cast(scale_factor.value()->data) : nullptr; + params.scale_factor = scale_factor.has_value() + ? reinterpret_cast(scale_factor.value().data_ptr()) + : nullptr; params.use_oneshot = use_oneshot; params.layout = layout_code.has_value() ? static_cast(layout_code.value()) : QuantizationSFLayout::SWIZZLED_128x4; params.pattern = static_cast(pattern_code); params.trigger_completion_at_end = trigger_completion_at_end; - params.stream = get_stream(allreduce_in->device); + params.stream = get_stream(allreduce_in.device()); auto status = allreduce_fusion_op(params, launch_with_pdl, fp32_acc); TVM_FFI_ICHECK(status == cudaSuccess) diff --git a/csrc/trtllm_alltoall.cu b/csrc/trtllm_alltoall.cu index 51e0b2c2d1..9ad4fff110 100644 --- a/csrc/trtllm_alltoall.cu +++ b/csrc/trtllm_alltoall.cu @@ -34,20 +34,20 @@ void moeCommPrepareIndicesOp(TensorView gatheredTargetRankIds, TensorView backwardRecvRankLocalIndices, int64_t maxTokenCountPerRank, int64_t expertCount, int64_t topK, int64_t epRank, int64_t epSize) { CHECK_INPUT_TYPE(gatheredTargetRankIds, dl_int32); - TVM_FFI_ICHECK_EQ(gatheredTargetRankIds->ndim, 2) << "gatheredTargetRankIds must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(gatheredTargetRankIds->shape[1], topK) + TVM_FFI_ICHECK_EQ(gatheredTargetRankIds.ndim(), 2) << "gatheredTargetRankIds must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(gatheredTargetRankIds.size(1), topK) << "gatheredTargetRankIds must have topK columns"; int const* realRankTokenCountCumSumPtr = nullptr; if (realRankTokenCountCumSum.has_value()) { - TVM_FFI_ICHECK_EQ(realRankTokenCountCumSum.value()->ndim, 1) + TVM_FFI_ICHECK_EQ(realRankTokenCountCumSum.value().ndim(), 1) << "realRankTokenCountCumSum must be a 1D tensor"; CHECK_INPUT_TYPE(realRankTokenCountCumSum.value(), dl_int32) - TVM_FFI_ICHECK_EQ(realRankTokenCountCumSum.value()->shape[0], epSize) + TVM_FFI_ICHECK_EQ(realRankTokenCountCumSum.value().size(0), epSize) << "realRankTokenCountCumSum must have epSize elements"; - realRankTokenCountCumSumPtr = static_cast(realRankTokenCountCumSum.value()->data); + realRankTokenCountCumSumPtr = static_cast(realRankTokenCountCumSum.value().data_ptr()); } else { - TVM_FFI_ICHECK_EQ(gatheredTargetRankIds->shape[0], epSize * maxTokenCountPerRank) + TVM_FFI_ICHECK_EQ(gatheredTargetRankIds.size(0), epSize * maxTokenCountPerRank) << "gatheredTargetRankIds should have shape (epSize * maxTokenCountPerRank, topK)"; } TVM_FFI_ICHECK_GT(maxTokenCountPerRank, 0) << "maxTokenCountPerRank must be greater than 0"; @@ -61,28 +61,28 @@ void moeCommPrepareIndicesOp(TensorView gatheredTargetRankIds, int maxSendRanksPerToken = std::max(static_cast(epSize), static_cast(topK)); CHECK_INPUT_TYPE(localGatherIndices, dl_int32); - TVM_FFI_ICHECK_EQ(localGatherIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(localGatherIndices->shape[0], 
maxTokenCountPerRank * epSize); + TVM_FFI_ICHECK_EQ(localGatherIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(localGatherIndices.size(0), maxTokenCountPerRank * epSize); CHECK_INPUT_TYPE(sendRankCountCumSum, dl_int32); - TVM_FFI_ICHECK_EQ(sendRankCountCumSum->ndim, 1); - TVM_FFI_ICHECK_EQ(sendRankCountCumSum->shape[0], epSize); + TVM_FFI_ICHECK_EQ(sendRankCountCumSum.ndim(), 1); + TVM_FFI_ICHECK_EQ(sendRankCountCumSum.size(0), epSize); CHECK_INPUT_TYPE(sendRankLocalIndices, dl_int32); - TVM_FFI_ICHECK_EQ(sendRankLocalIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(sendRankLocalIndices->shape[0], maxTokenCountPerRank * maxSendRanksPerToken); + TVM_FFI_ICHECK_EQ(sendRankLocalIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(sendRankLocalIndices.size(0), maxTokenCountPerRank * maxSendRanksPerToken); CHECK_INPUT_TYPE(recvRankCountCumSum, dl_int32); - TVM_FFI_ICHECK_EQ(recvRankCountCumSum->ndim, 1); - TVM_FFI_ICHECK_EQ(recvRankCountCumSum->shape[0], epSize); + TVM_FFI_ICHECK_EQ(recvRankCountCumSum.ndim(), 1); + TVM_FFI_ICHECK_EQ(recvRankCountCumSum.size(0), epSize); CHECK_INPUT_TYPE(recvRankLocalIndices, dl_int32); - TVM_FFI_ICHECK_EQ(recvRankLocalIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(recvRankLocalIndices->shape[0], maxTokenCountPerRank * epSize); + TVM_FFI_ICHECK_EQ(recvRankLocalIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(recvRankLocalIndices.size(0), maxTokenCountPerRank * epSize); CHECK_INPUT_TYPE(backwardRecvRankLocalIndices, dl_int32); - TVM_FFI_ICHECK_EQ(backwardRecvRankLocalIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(backwardRecvRankLocalIndices->shape[0], + TVM_FFI_ICHECK_EQ(backwardRecvRankLocalIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(backwardRecvRankLocalIndices.size(0), maxTokenCountPerRank * maxSendRanksPerToken); flashinfer::trtllm_alltoall::MoeExpertParallelInfo expertParallelInfo; @@ -94,11 +94,13 @@ void moeCommPrepareIndicesOp(TensorView gatheredTargetRankIds, auto cudaResult = flashinfer::trtllm_alltoall::moeAllToAllPrepareIndices( worldInfo, expertParallelInfo, static_cast(maxTokenCountPerRank), - static_cast(gatheredTargetRankIds->data), realRankTokenCountCumSumPtr, - static_cast(localGatherIndices->data), static_cast(sendRankCountCumSum->data), - static_cast(sendRankLocalIndices->data), static_cast(recvRankCountCumSum->data), - static_cast(recvRankLocalIndices->data), - static_cast(backwardRecvRankLocalIndices->data), stream); + static_cast(gatheredTargetRankIds.data_ptr()), realRankTokenCountCumSumPtr, + static_cast(localGatherIndices.data_ptr()), + static_cast(sendRankCountCumSum.data_ptr()), + static_cast(sendRankLocalIndices.data_ptr()), + static_cast(recvRankCountCumSum.data_ptr()), + static_cast(recvRankLocalIndices.data_ptr()), + static_cast(backwardRecvRankLocalIndices.data_ptr()), stream); TVM_FFI_ICHECK(cudaResult == cudaSuccess) << "CUDA error in moeAllToAllPrepareIndices: " << cudaGetErrorString(cudaResult); } @@ -121,23 +123,22 @@ void moeLocalGatherOp(TensorView recvRankCumSum, TensorView localGatherIndices, TVM_FFI_ICHECK_LE(topK, expertCount) << "topK must be less than or equal to expertCount"; TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize) << "epRank must be in the range [0, epSize)"; - TVM_FFI_ICHECK_EQ(recvRankCumSum->ndim, 1) << "recvRankCumSum must be a 1D tensor"; - TVM_FFI_ICHECK_EQ(recvRankCumSum->shape[0], epSize) << "recvRankCumSum must have epSize elements"; - TVM_FFI_ICHECK_EQ(localGatherIndices->ndim, 1) << "localGatherIndices must be a 1D tensor"; - TVM_FFI_ICHECK_EQ(gatheredExpertIds->ndim, 2) << "gatheredExpertIds must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(gatheredScales->ndim, 
2) << "gatheredScales must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(localExpertIds->ndim, 2) << "localExpertIds must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(localScales->ndim, 2) << "localScales must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(gatheredExpertIds->shape[1], topK) - << "gatheredExpertIds must have topK columns"; - TVM_FFI_ICHECK_EQ(gatheredScales->shape[1], topK) << "gatheredScales must have topK columns"; - TVM_FFI_ICHECK_EQ(localExpertIds->shape[1], topK) << "localExpertIds must have topK columns"; - TVM_FFI_ICHECK_EQ(localScales->shape[1], topK) << "localScales must have topK columns"; - - int localMaxTokenCount = static_cast(localGatherIndices->shape[0]); - TVM_FFI_ICHECK_EQ(localExpertIds->shape[0], localMaxTokenCount) + TVM_FFI_ICHECK_EQ(recvRankCumSum.ndim(), 1) << "recvRankCumSum must be a 1D tensor"; + TVM_FFI_ICHECK_EQ(recvRankCumSum.size(0), epSize) << "recvRankCumSum must have epSize elements"; + TVM_FFI_ICHECK_EQ(localGatherIndices.ndim(), 1) << "localGatherIndices must be a 1D tensor"; + TVM_FFI_ICHECK_EQ(gatheredExpertIds.ndim(), 2) << "gatheredExpertIds must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(gatheredScales.ndim(), 2) << "gatheredScales must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(localExpertIds.ndim(), 2) << "localExpertIds must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(localScales.ndim(), 2) << "localScales must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(gatheredExpertIds.size(1), topK) << "gatheredExpertIds must have topK columns"; + TVM_FFI_ICHECK_EQ(gatheredScales.size(1), topK) << "gatheredScales must have topK columns"; + TVM_FFI_ICHECK_EQ(localExpertIds.size(1), topK) << "localExpertIds must have topK columns"; + TVM_FFI_ICHECK_EQ(localScales.size(1), topK) << "localScales must have topK columns"; + + int localMaxTokenCount = static_cast(localGatherIndices.size(0)); + TVM_FFI_ICHECK_EQ(localExpertIds.size(0), localMaxTokenCount) << "localExpertIds must have localMaxTokenCount rows"; - TVM_FFI_ICHECK_EQ(localScales->shape[0], localMaxTokenCount) + TVM_FFI_ICHECK_EQ(localScales.size(0), localMaxTokenCount) << "localScales must have localMaxTokenCount rows"; auto stream = get_current_stream(); @@ -150,9 +151,11 @@ void moeLocalGatherOp(TensorView recvRankCumSum, TensorView localGatherIndices, static_cast(epRank)}; flashinfer::trtllm_alltoall::moeLocalGather( worldInfo, expertParallelInfo, static_cast(maxTokenCountPerRank), localMaxTokenCount, - static_cast(recvRankCumSum->data), static_cast(localGatherIndices->data), - static_cast(gatheredExpertIds->data), static_cast(gatheredScales->data), - static_cast(localExpertIds->data), static_cast(localScales->data), stream); + static_cast(recvRankCumSum.data_ptr()), + static_cast(localGatherIndices.data_ptr()), + static_cast(gatheredExpertIds.data_ptr()), + static_cast(gatheredScales.data_ptr()), static_cast(localExpertIds.data_ptr()), + static_cast(localScales.data_ptr()), stream); } void moeCommOp(TensorView input, TensorView sendRankCumSum, TensorView sendIndices, @@ -165,19 +168,19 @@ void moeCommOp(TensorView input, TensorView sendRankCumSum, TensorView sendIndic // allWorkspaces is a uint64 tensor, but may not be contiguous CHECK_INPUT_TYPE(allWorkspaces, dl_uint64); - TVM_FFI_ICHECK_EQ(input->ndim, 2) << "input must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(output->ndim, 2) << "output must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(sendRankCumSum->ndim, 1) << "sendRankCumSum must be a 1D tensor"; - TVM_FFI_ICHECK_EQ(sendIndices->ndim, 1) << "sendIndices must be a 1D tensor"; - TVM_FFI_ICHECK_EQ(recvRankCumSum->ndim, 1) << "recvRankCumSum must 
be a 1D tensor"; - TVM_FFI_ICHECK_EQ(recvIndices->ndim, 1) << "recvIndices must be a 1D tensor"; - TVM_FFI_ICHECK_EQ(allWorkspaces->ndim, 2) << "allWorkspaces must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(input.ndim(), 2) << "input must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(output.ndim(), 2) << "output must be a 2D tensor"; + TVM_FFI_ICHECK_EQ(sendRankCumSum.ndim(), 1) << "sendRankCumSum must be a 1D tensor"; + TVM_FFI_ICHECK_EQ(sendIndices.ndim(), 1) << "sendIndices must be a 1D tensor"; + TVM_FFI_ICHECK_EQ(recvRankCumSum.ndim(), 1) << "recvRankCumSum must be a 1D tensor"; + TVM_FFI_ICHECK_EQ(recvIndices.ndim(), 1) << "recvIndices must be a 1D tensor"; + TVM_FFI_ICHECK_EQ(allWorkspaces.ndim(), 2) << "allWorkspaces must be a 2D tensor"; - TVM_FFI_ICHECK_EQ(input->shape[1], output->shape[1]) + TVM_FFI_ICHECK_EQ(input.size(1), output.size(1)) << "input and output must have the same second dimension"; - TVM_FFI_ICHECK_EQ(sendRankCumSum->shape[0], epSize) << "sendRankCumSum must have epSize elements"; - TVM_FFI_ICHECK_EQ(recvRankCumSum->shape[0], epSize) << "recvRankCumSum must have epSize elements"; - TVM_FFI_ICHECK_EQ(allWorkspaces->shape[0], epSize) << "allWorkspaces must have epSize elements"; + TVM_FFI_ICHECK_EQ(sendRankCumSum.size(0), epSize) << "sendRankCumSum must have epSize elements"; + TVM_FFI_ICHECK_EQ(recvRankCumSum.size(0), epSize) << "recvRankCumSum must have epSize elements"; + TVM_FFI_ICHECK_EQ(allWorkspaces.size(0), epSize) << "allWorkspaces must have epSize elements"; TVM_FFI_ICHECK(epRank >= 0 && epRank < epSize) << "epRank must be in the range [0, epSize)"; @@ -187,25 +190,25 @@ void moeCommOp(TensorView input, TensorView sendRankCumSum, TensorView sendIndic size_t eltSize = get_element_size(input); size_t eltCountPerU64 = sizeof(uint64_t) / eltSize; - TVM_FFI_ICHECK_EQ(input->shape[1] % (eltCountPerU64 * 2), 0) - << "input->shape[1] must be aligned to 16 bytes"; - sendRecvDataInfo.vectorSizeInU64 = input->shape[1] / eltCountPerU64; + TVM_FFI_ICHECK_EQ(input.size(1) % (eltCountPerU64 * 2), 0) + << "input.size(1) must be aligned to 16 bytes"; + sendRecvDataInfo.vectorSizeInU64 = input.size(1) / eltCountPerU64; sendRecvDataInfo.DoPreCompute(); flashinfer::trtllm_alltoall::SendRecvDispls sendDispls, recvDispls; - sendDispls.dataPtr = static_cast(input->data); - sendDispls.rankCountCumSum = static_cast(sendRankCumSum->data); - sendDispls.rankLocalIndices = static_cast(sendIndices->data); - sendDispls.vectorStrideInU64 = input->strides[0] / eltCountPerU64; + sendDispls.dataPtr = static_cast(input.data_ptr()); + sendDispls.rankCountCumSum = static_cast(sendRankCumSum.data_ptr()); + sendDispls.rankLocalIndices = static_cast(sendIndices.data_ptr()); + sendDispls.vectorStrideInU64 = input.stride(0) / eltCountPerU64; - recvDispls.dataPtr = static_cast(output->data); - recvDispls.rankCountCumSum = static_cast(recvRankCumSum->data); - recvDispls.rankLocalIndices = static_cast(recvIndices->data); - recvDispls.vectorStrideInU64 = output->strides[0] / eltCountPerU64; + recvDispls.dataPtr = static_cast(output.data_ptr()); + recvDispls.rankCountCumSum = static_cast(recvRankCumSum.data_ptr()); + recvDispls.rankLocalIndices = static_cast(recvIndices.data_ptr()); + recvDispls.vectorStrideInU64 = output.stride(0) / eltCountPerU64; flashinfer::trtllm_alltoall::MoeCommWorkspace workspace; - workspace.workspacePtr = static_cast(allWorkspaces->data); - workspace.rankStrideInU64 = allWorkspaces->strides[0]; + workspace.workspacePtr = static_cast(allWorkspaces.data_ptr()); + workspace.rankStrideInU64 = 
allWorkspaces.stride(0); auto stream = get_current_stream(); @@ -238,110 +241,118 @@ void moePrepareOp(TensorView expertsIds, Optional scales, TVM_FFI_ICHECK_EQ(slotCount % 4, 0) << "slotCount must be divisible by 4"; int64_t maxSendRanksPerToken = std::max(epSize, topK); - int64_t tokenCount = expertsIds->shape[0]; + int64_t tokenCount = expertsIds.size(0); CHECK_DEVICE(preparedLocalExpertIds, expertsIds); CHECK_INPUT_TYPE(preparedLocalExpertIds, dl_int32); - TVM_FFI_ICHECK_EQ(preparedLocalExpertIds->ndim, 2); - TVM_FFI_ICHECK_EQ(preparedLocalExpertIds->shape[0], maxTokenCountPerRank * epSize); - TVM_FFI_ICHECK_EQ(preparedLocalExpertIds->shape[1], topK); + TVM_FFI_ICHECK_EQ(preparedLocalExpertIds.ndim(), 2); + TVM_FFI_ICHECK_EQ(preparedLocalExpertIds.size(0), maxTokenCountPerRank * epSize); + TVM_FFI_ICHECK_EQ(preparedLocalExpertIds.size(1), topK); CHECK_DEVICE(sendRankCountCumSum, expertsIds); CHECK_INPUT_TYPE(sendRankCountCumSum, dl_int32); - TVM_FFI_ICHECK_EQ(sendRankCountCumSum->ndim, 1); - TVM_FFI_ICHECK_EQ(sendRankCountCumSum->shape[0], epSize); + TVM_FFI_ICHECK_EQ(sendRankCountCumSum.ndim(), 1); + TVM_FFI_ICHECK_EQ(sendRankCountCumSum.size(0), epSize); CHECK_DEVICE(recvRankCountCumSum, expertsIds); CHECK_INPUT_TYPE(recvRankCountCumSum, dl_int32); - TVM_FFI_ICHECK_EQ(recvRankCountCumSum->ndim, 1); - TVM_FFI_ICHECK_EQ(recvRankCountCumSum->shape[0], epSize); + TVM_FFI_ICHECK_EQ(recvRankCountCumSum.ndim(), 1); + TVM_FFI_ICHECK_EQ(recvRankCountCumSum.size(0), epSize); CHECK_DEVICE(gatherRecvRankIndices, expertsIds); CHECK_INPUT_TYPE(gatherRecvRankIndices, dl_int32); - TVM_FFI_ICHECK_EQ(gatherRecvRankIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(gatherRecvRankIndices->shape[0], maxTokenCountPerRank * epSize); + TVM_FFI_ICHECK_EQ(gatherRecvRankIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(gatherRecvRankIndices.size(0), maxTokenCountPerRank * epSize); CHECK_DEVICE(recvRankIndices, expertsIds); CHECK_INPUT_TYPE(recvRankIndices, dl_int32); - TVM_FFI_ICHECK_EQ(recvRankIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(recvRankIndices->shape[0], maxTokenCountPerRank * epSize); + TVM_FFI_ICHECK_EQ(recvRankIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(recvRankIndices.size(0), maxTokenCountPerRank * epSize); CHECK_DEVICE(gatherBackwardRecvRankIndices, expertsIds); CHECK_INPUT_TYPE(gatherBackwardRecvRankIndices, dl_int32); - TVM_FFI_ICHECK_EQ(gatherBackwardRecvRankIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(gatherBackwardRecvRankIndices->shape[0], + TVM_FFI_ICHECK_EQ(gatherBackwardRecvRankIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(gatherBackwardRecvRankIndices.size(0), maxTokenCountPerRank * maxSendRanksPerToken); CHECK_DEVICE(backwardRecvRankIndices, expertsIds); CHECK_INPUT_TYPE(backwardRecvRankIndices, dl_int32); - TVM_FFI_ICHECK_EQ(backwardRecvRankIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(backwardRecvRankIndices->shape[0], maxTokenCountPerRank * maxSendRanksPerToken); + TVM_FFI_ICHECK_EQ(backwardRecvRankIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(backwardRecvRankIndices.size(0), maxTokenCountPerRank * maxSendRanksPerToken); CHECK_DEVICE(gatherSendRankIndices, expertsIds); CHECK_INPUT_TYPE(gatherSendRankIndices, dl_int32); - TVM_FFI_ICHECK_EQ(gatherSendRankIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(gatherSendRankIndices->shape[0], maxTokenCountPerRank * maxSendRanksPerToken); + TVM_FFI_ICHECK_EQ(gatherSendRankIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(gatherSendRankIndices.size(0), maxTokenCountPerRank * maxSendRanksPerToken); CHECK_DEVICE(sendRankIndices, expertsIds); CHECK_INPUT_TYPE(sendRankIndices, dl_int32); - 
TVM_FFI_ICHECK_EQ(sendRankIndices->ndim, 1); - TVM_FFI_ICHECK_EQ(sendRankIndices->shape[0], maxTokenCountPerRank * maxSendRanksPerToken); + TVM_FFI_ICHECK_EQ(sendRankIndices.ndim(), 1); + TVM_FFI_ICHECK_EQ(sendRankIndices.size(0), maxTokenCountPerRank * maxSendRanksPerToken); float* scalesPtr = nullptr; float* preparedLocalScalesPtr = nullptr; if (scales.has_value()) { CHECK_INPUT_TYPE(scales.value(), dl_float32); - scalesPtr = static_cast(scales.value()->data); + scalesPtr = static_cast(scales.value().data_ptr()); CHECK_DEVICE(preparedLocalScales.value(), expertsIds); CHECK_INPUT_TYPE(preparedLocalScales.value(), dl_int32); - TVM_FFI_ICHECK_EQ(preparedLocalScales.value()->ndim, 2); - TVM_FFI_ICHECK_EQ(preparedLocalScales.value()->shape[0], maxTokenCountPerRank * epSize); - TVM_FFI_ICHECK_EQ(preparedLocalScales.value()->shape[1], topK); - preparedLocalScalesPtr = static_cast(preparedLocalScales.value()->data); + TVM_FFI_ICHECK_EQ(preparedLocalScales.value().ndim(), 2); + TVM_FFI_ICHECK_EQ(preparedLocalScales.value().size(0), maxTokenCountPerRank * epSize); + TVM_FFI_ICHECK_EQ(preparedLocalScales.value().size(1), topK); + preparedLocalScalesPtr = static_cast(preparedLocalScales.value().data_ptr()); } int* localExpertStaticsPtr = nullptr; int* gatheredExpertStaticsPtr = nullptr; if (expertsStatics.has_value()) { - localExpertStaticsPtr = static_cast(expertsStatics.value()->data); + localExpertStaticsPtr = static_cast(expertsStatics.value().data_ptr()); CHECK_DEVICE(gatheredExpertStatics.value(), expertsIds); CHECK_INPUT_TYPE(gatheredExpertStatics.value(), dl_int32); - TVM_FFI_ICHECK_EQ(gatheredExpertStatics.value()->ndim, 2); - TVM_FFI_ICHECK_EQ(gatheredExpertStatics.value()->shape[0], epSize); - TVM_FFI_ICHECK_EQ(gatheredExpertStatics.value()->shape[1], expertCount); - gatheredExpertStaticsPtr = static_cast(gatheredExpertStatics.value()->data); + TVM_FFI_ICHECK_EQ(gatheredExpertStatics.value().ndim(), 2); + TVM_FFI_ICHECK_EQ(gatheredExpertStatics.value().size(0), epSize); + TVM_FFI_ICHECK_EQ(gatheredExpertStatics.value().size(1), expertCount); + gatheredExpertStaticsPtr = static_cast(gatheredExpertStatics.value().data_ptr()); } flashinfer::trtllm_alltoall::moe_prepare::MoeCommWorkspace workspace; - workspace.workspacePtr = static_cast(allWorkspaces->data); - workspace.rankStrideInU64 = allWorkspaces->strides[0]; + workspace.workspacePtr = static_cast(allWorkspaces.data_ptr()); + workspace.rankStrideInU64 = allWorkspaces.stride(0); auto stream = get_current_stream(); flashinfer::trtllm_alltoall::moe_prepare::computeCountAndIndice( - static_cast(expertsIds->data), static_cast(sendRankCountCumSum->data), - static_cast(recvRankCountCumSum->data), static_cast(sendRankIndices->data), - static_cast(backwardRecvRankIndices->data), static_cast(recvRankIndices->data), - workspace, tokenCount, maxTokenCountPerRank, topK, slotCount, epRank, epSize, stream); + static_cast(expertsIds.data_ptr()), static_cast(sendRankCountCumSum.data_ptr()), + static_cast(recvRankCountCumSum.data_ptr()), + static_cast(sendRankIndices.data_ptr()), + static_cast(backwardRecvRankIndices.data_ptr()), + static_cast(recvRankIndices.data_ptr()), workspace, tokenCount, maxTokenCountPerRank, + topK, slotCount, epRank, epSize, stream); flashinfer::trtllm_alltoall::moe_prepare::computeCumsum( - static_cast(sendRankCountCumSum->data), static_cast(recvRankCountCumSum->data), - epRank, epSize, stream); + static_cast(sendRankCountCumSum.data_ptr()), + static_cast(recvRankCountCumSum.data_ptr()), epRank, epSize, stream); 
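// Annotation (not from the upstream patch): this hunk and the ones below only
// change how raw pointers are obtained (tensor->data becomes tensor.data_ptr(),
// opt.value()->data becomes opt.value().data_ptr()); the launch sequence of
// moePrepareOp is unchanged. Judging by the argument names: computeCountAndIndice
// fills the per-rank counts and index lists, computeCumsum (above) turns the
// counts into cumulative sums, moveIndice (below) derives the gathered index
// lists from them, and allToAllMetadata finally exchanges the expert ids,
// scales, and optional expert statistics. All launches use the same CUDA stream.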
flashinfer::trtllm_alltoall::moe_prepare::moveIndice( - static_cast(sendRankCountCumSum->data), static_cast(recvRankCountCumSum->data), - static_cast(sendRankIndices->data), static_cast(gatherSendRankIndices->data), - static_cast(backwardRecvRankIndices->data), - static_cast(gatherBackwardRecvRankIndices->data), - static_cast(recvRankIndices->data), static_cast(gatherRecvRankIndices->data), - epRank, epSize, maxTokenCountPerRank, stream); + static_cast(sendRankCountCumSum.data_ptr()), + static_cast(recvRankCountCumSum.data_ptr()), + static_cast(sendRankIndices.data_ptr()), + static_cast(gatherSendRankIndices.data_ptr()), + static_cast(backwardRecvRankIndices.data_ptr()), + static_cast(gatherBackwardRecvRankIndices.data_ptr()), + static_cast(recvRankIndices.data_ptr()), + static_cast(gatherRecvRankIndices.data_ptr()), epRank, epSize, maxTokenCountPerRank, + stream); flashinfer::trtllm_alltoall::moe_prepare::allToAllMetadata( - static_cast(expertsIds->data), static_cast(preparedLocalExpertIds->data), - scalesPtr, preparedLocalScalesPtr, localExpertStaticsPtr, gatheredExpertStaticsPtr, workspace, - static_cast(sendRankCountCumSum->data), static_cast(sendRankIndices->data), - static_cast(recvRankCountCumSum->data), static_cast(recvRankIndices->data), - tokenCount, maxTokenCountPerRank, topK, expertCount, slotCount, epRank, epSize, stream); + static_cast(expertsIds.data_ptr()), + static_cast(preparedLocalExpertIds.data_ptr()), scalesPtr, preparedLocalScalesPtr, + localExpertStaticsPtr, gatheredExpertStaticsPtr, workspace, + static_cast(sendRankCountCumSum.data_ptr()), + static_cast(sendRankIndices.data_ptr()), + static_cast(recvRankCountCumSum.data_ptr()), + static_cast(recvRankIndices.data_ptr()), tokenCount, maxTokenCountPerRank, topK, + expertCount, slotCount, epRank, epSize, stream); } TVM_FFI_DLL_EXPORT_TYPED_FUNC(moe_comm_prepare_indices, moeCommPrepareIndicesOp); diff --git a/csrc/trtllm_fmha_kernel_launcher.cu b/csrc/trtllm_fmha_kernel_launcher.cu index 05e92a1721..c40e773e64 100644 --- a/csrc/trtllm_fmha_kernel_launcher.cu +++ b/csrc/trtllm_fmha_kernel_launcher.cu @@ -202,59 +202,57 @@ void trtllm_paged_attention_decode(TensorView out, Optional out_scal int64_t o_sf_start_index, int64_t window_left, int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { - auto q_data_type = dl_dtype_to_tllm_data_type(query->dtype); - auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache->dtype); - TVM_FFI_ICHECK_EQ(key_cache->ndim, value_cache->ndim); - for (int i = 0; i < key_cache->ndim; i++) { - TVM_FFI_ICHECK_EQ(key_cache->shape[i], value_cache->shape[i]); + auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); + auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); + TVM_FFI_ICHECK_EQ(key_cache.ndim(), value_cache.ndim()); + for (int i = 0; i < key_cache.ndim(); i++) { + TVM_FFI_ICHECK_EQ(key_cache.size(i), value_cache.size(i)); } - auto o_data_type = dl_dtype_to_tllm_data_type(out->dtype); + auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); // NOTE(Zihao): query is [B, Q, H, D] // where Q is the number of query tokens per request, used in MTP // based on profiled results, always use decode mode for MTP (q_len is small) // example: when kv_len = 10000, q < 200, decode mode is faster - int batch_size = query->shape[0]; - int q_len_per_request = query->shape[1]; + int batch_size = query.size(0); + int q_len_per_request = query.size(1); int sum_seq_q = batch_size * q_len_per_request; - int num_qo_heads = query->shape[2]; + int 
num_qo_heads = query.size(2); // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even. - int head_dim_k = is_4bit(kv_data_type) ? key_cache->shape[key_cache->ndim - 1] * 2 - : key_cache->shape[key_cache->ndim - 1]; - int head_dim_q = - is_4bit(q_data_type) ? query->shape[query->ndim - 1] * 2 : query->shape[query->ndim - 1]; - int head_dim_v = is_4bit(kv_data_type) ? value_cache->shape[value_cache->ndim - 1] * 2 - : value_cache->shape[value_cache->ndim - 1]; - int head_dim_o = is_4bit(o_data_type) ? out->shape[out->ndim - 1] * 2 : out->shape[out->ndim - 1]; + int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1); + int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1); + int head_dim_v = is_4bit(kv_data_type) ? value_cache.size(-1) * 2 : value_cache.size(-1); + int head_dim_o = is_4bit(o_data_type) ? out.size(-1) * 2 : out.size(-1); TVM_FFI_ICHECK_EQ(head_dim_k, head_dim_q) << "head_dim_k and head_dim_q must be the same, got " << std::to_string(head_dim_k) << " and " << std::to_string(head_dim_q); TVM_FFI_ICHECK((head_dim_v == 576 && head_dim_o == 512) || head_dim_v == head_dim_o) << "head_dim_v and head_dim_o must be the same for non-MLA attention, got " << std::to_string(head_dim_v) << " and " << std::to_string(head_dim_o); - int page_size = key_cache->shape[key_cache->ndim - 2]; - int num_kv_heads = key_cache->shape[key_cache->ndim - 3]; - int max_num_blocks_per_seq = block_tables->shape[block_tables->ndim - 1]; - bool is_shared_kv = key_cache->data == value_cache->data; - int num_pages_in_mem_pool = is_shared_kv ? key_cache->shape[0] : key_cache->shape[0] * 2; + int page_size = key_cache.size(-2); + int num_kv_heads = key_cache.size(-3); + int max_num_blocks_per_seq = block_tables.size(-1); + bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr(); + int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2; - int kv_stride_keys_values = key_cache->strides[key_cache->ndim - 2]; // key/values - int kv_stride_heads = key_cache->strides[key_cache->ndim - 3]; // head - int kv_stride_batch = key_cache->strides[0]; // batch + int kv_stride_keys_values = key_cache.stride(-2); // key/values + int kv_stride_heads = key_cache.stride(-3); // head + int kv_stride_batch = key_cache.stride(0); // batch - const auto stream = get_stream(query->device); - void* output_sf_ptr = out_scale_factor.has_value() ? out_scale_factor.value()->data : nullptr; + const auto stream = get_stream(query.device()); + void* output_sf_ptr = + out_scale_factor.has_value() ? 
out_scale_factor.value().data_ptr() : nullptr; float* attention_sinks_ptr = nullptr; if (attention_sinks.has_value()) { - TVM_FFI_ICHECK_EQ(attention_sinks.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32) << "attention_sinks must be a float tensor"; - attention_sinks_ptr = static_cast(attention_sinks.value()->data); + attention_sinks_ptr = static_cast(attention_sinks.value().data_ptr()); } trtllm_paged_attention_launcher( - out->data, output_sf_ptr, query->data, key_cache->data, value_cache->data, - workspace_buffer->data, static_cast(block_tables->data), - static_cast(seq_lens->data), + out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), + workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), + static_cast(seq_lens.data_ptr()), /*cum_seq_lens_q=*/nullptr, /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::ForGen, batch_size, /*max_q_len=*/q_len_per_request, max_kv_len, @@ -274,51 +272,49 @@ void trtllm_paged_attention_context(TensorView out, Optional out_sca TensorView cum_seq_lens_q, TensorView cum_seq_lens_kv, int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional attention_sinks) { - auto q_data_type = dl_dtype_to_tllm_data_type(query->dtype); - auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache->dtype); - auto o_data_type = dl_dtype_to_tllm_data_type(out->dtype); - int num_qo_heads = query->shape[1]; - int sum_seq_q = query->shape[0]; + auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); + auto kv_data_type = dl_dtype_to_tllm_data_type(key_cache.dtype()); + auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); + int num_qo_heads = query.size(1); + int sum_seq_q = query.size(0); // Multiply by two for FP4 tensor as it is stored as UINT8 dtype. Assume the dim is even. - int head_dim_k = is_4bit(kv_data_type) ? key_cache->shape[key_cache->ndim - 1] * 2 - : key_cache->shape[key_cache->ndim - 1]; - int head_dim_q = - is_4bit(q_data_type) ? query->shape[query->ndim - 1] * 2 : query->shape[query->ndim - 1]; - int head_dim_v = is_4bit(kv_data_type) ? value_cache->shape[value_cache->ndim - 1] * 2 - : value_cache->shape[value_cache->ndim - 1]; - int head_dim_o = is_4bit(o_data_type) ? out->shape[out->ndim - 1] * 2 : out->shape[out->ndim - 1]; + int head_dim_k = is_4bit(kv_data_type) ? key_cache.size(-1) * 2 : key_cache.size(-1); + int head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1); + int head_dim_v = is_4bit(kv_data_type) ? value_cache.size(-1) * 2 : value_cache.size(-1); + int head_dim_o = is_4bit(o_data_type) ? out.size(-1) * 2 : out.size(-1); TVM_FFI_ICHECK_EQ(head_dim_k, head_dim_q) << "head_dim_k and head_dim_q must be the same, got " << std::to_string(head_dim_k) << " and " << std::to_string(head_dim_q); TVM_FFI_ICHECK_EQ(head_dim_v, head_dim_o) << "head_dim_v and head_dim_o must be the same, got " << std::to_string(head_dim_v) << " and " << std::to_string(head_dim_o); - int max_num_blocks_per_seq = block_tables->shape[block_tables->ndim - 1]; - bool is_shared_kv = key_cache->data == value_cache->data; - int num_pages_in_mem_pool = is_shared_kv ? 
key_cache->shape[0] : key_cache->shape[0] * 2; - int page_size = key_cache->shape[key_cache->ndim - 2]; - int num_kv_heads = key_cache->shape[key_cache->ndim - 3]; + int max_num_blocks_per_seq = block_tables.size(-1); + bool is_shared_kv = key_cache.data_ptr() == value_cache.data_ptr(); + int num_pages_in_mem_pool = is_shared_kv ? key_cache.size(0) : key_cache.size(0) * 2; + int page_size = key_cache.size(-2); + int num_kv_heads = key_cache.size(-3); - int kv_stride_keys_values = key_cache->strides[key_cache->ndim - 2]; // key/values - int kv_stride_heads = key_cache->strides[key_cache->ndim - 3]; // head - int kv_stride_batch = key_cache->strides[0]; // batch + int kv_stride_keys_values = key_cache.stride(-2); // key/values + int kv_stride_heads = key_cache.stride(-3); // head + int kv_stride_batch = key_cache.stride(0); // batch - const auto stream = get_stream(query->device); - void* output_sf_ptr = out_scale_factor.has_value() ? out_scale_factor.value()->data : nullptr; + const auto stream = get_stream(query.device()); + void* output_sf_ptr = + out_scale_factor.has_value() ? out_scale_factor.value().data_ptr() : nullptr; float* attention_sinks_ptr = nullptr; if (attention_sinks.has_value()) { - TVM_FFI_ICHECK_EQ(attention_sinks.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32) << "attention_sinks must be a float tensor"; - attention_sinks_ptr = static_cast(attention_sinks.value()->data); + attention_sinks_ptr = static_cast(attention_sinks.value().data_ptr()); } trtllm_paged_attention_launcher( - out->data, output_sf_ptr, query->data, key_cache->data, value_cache->data, - workspace_buffer->data, static_cast(block_tables->data), - static_cast(seq_lens->data), - /*cum_seq_lens_q=*/static_cast(cum_seq_lens_q->data), - /*cum_seq_lens_kv=*/static_cast(cum_seq_lens_kv->data), attention_sinks_ptr, + out.data_ptr(), output_sf_ptr, query.data_ptr(), key_cache.data_ptr(), value_cache.data_ptr(), + workspace_buffer.data_ptr(), static_cast(block_tables.data_ptr()), + static_cast(seq_lens.data_ptr()), + /*cum_seq_lens_q=*/static_cast(cum_seq_lens_q.data_ptr()), + /*cum_seq_lens_kv=*/static_cast(cum_seq_lens_kv.data_ptr()), attention_sinks_ptr, q_data_type, kv_data_type, o_data_type, TllmPagedAttentionMode::Context, batch_size, max_q_len, max_kv_len, num_pages_in_mem_pool, num_qo_heads, num_kv_heads, head_dim_q, head_dim_o, page_size, kv_stride_keys_values, kv_stride_heads, kv_stride_batch, @@ -423,45 +419,46 @@ void trtllm_ragged_attention(TensorView out, TensorView query, TensorView key, T Optional lse) { float* attention_sinks_ptr = nullptr; if (attention_sinks.has_value()) { - TVM_FFI_ICHECK_EQ(attention_sinks.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(attention_sinks.value().dtype(), dl_float32) << "attention_sinks must be a float tensor"; - attention_sinks_ptr = static_cast(attention_sinks.value()->data); + attention_sinks_ptr = static_cast(attention_sinks.value().data_ptr()); } float* lse_ptr = nullptr; if (lse.has_value()) { - TVM_FFI_ICHECK_EQ(lse.value()->dtype, dl_float32) << "lse must be a float tensor"; - lse_ptr = static_cast(lse.value()->data); + TVM_FFI_ICHECK_EQ(lse.value().dtype(), dl_float32) << "lse must be a float tensor"; + lse_ptr = static_cast(lse.value().data_ptr()); } - TVM_FFI_ICHECK_EQ(out->ndim, 3) << "out must be a 3D tensor"; - TVM_FFI_ICHECK_EQ(query->ndim, 3) << "query must be a 3D tensor"; - TVM_FFI_ICHECK_EQ(key->ndim, 3) << "key must be a 3D tensor"; - TVM_FFI_ICHECK_EQ(value->ndim, 3) << "value must be a 3D tensor"; 
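The optional arguments in these launchers (out_scale_factor, attention_sinks, lse) all follow the same unwrap pattern visible in the hunks around this point: check presence, validate the dtype, then pass a raw pointer to the kernel, with nullptr meaning the feature is not used. A minimal sketch of that idiom, assuming Optional and TensorView are the tvm-ffi types used throughout this patch and that the casts elided here target float*; the helper name is hypothetical.

    // Mirrors the attention_sinks / lse handling in this file; not part of the patch.
    inline float* optional_float_ptr(const Optional<TensorView>& opt, const char* name) {
      if (!opt.has_value()) return nullptr;               // absent tensor -> nullptr for the launcher
      TVM_FFI_ICHECK_EQ(opt.value().dtype(), dl_float32)  // same dtype check as attention_sinks / lse
          << name << " must be a float tensor";
      return static_cast<float*>(opt.value().data_ptr()); // raw device pointer handed to the kernel
    }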
- - auto q_data_type = dl_dtype_to_tllm_data_type(query->dtype); - auto kv_data_type = dl_dtype_to_tllm_data_type(key->dtype); - auto o_data_type = dl_dtype_to_tllm_data_type(out->dtype); - const auto stream = get_stream(query->device); - int num_qo_heads = query->shape[1]; - int num_kv_heads = key->shape[1]; - int sum_seq_q = query->shape[0]; - int sum_seq_kv = key->shape[0]; - int head_dim_qk = query->shape[2]; - int head_dim_v = value->shape[2]; - int k_stride_keys_values = key->strides[0]; - int k_stride_heads = key->strides[1]; + TVM_FFI_ICHECK_EQ(out.ndim(), 3) << "out must be a 3D tensor"; + TVM_FFI_ICHECK_EQ(query.ndim(), 3) << "query must be a 3D tensor"; + TVM_FFI_ICHECK_EQ(key.ndim(), 3) << "key must be a 3D tensor"; + TVM_FFI_ICHECK_EQ(value.ndim(), 3) << "value must be a 3D tensor"; + + auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype()); + auto kv_data_type = dl_dtype_to_tllm_data_type(key.dtype()); + auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype()); + const auto stream = get_stream(query.device()); + int num_qo_heads = query.size(1); + int num_kv_heads = key.size(1); + int sum_seq_q = query.size(0); + int sum_seq_kv = key.size(0); + int head_dim_qk = query.size(2); + int head_dim_v = value.size(2); + int k_stride_keys_values = key.stride(0); + int k_stride_heads = key.stride(1); int k_stride_batch = key.numel(); - int v_stride_keys_values = value->strides[0]; - int v_stride_heads = value->strides[1]; + int v_stride_keys_values = value.stride(0); + int v_stride_heads = value.stride(1); int v_stride_batch = value.numel(); trtllm_ragged_attention_launcher( - out->data, query->data, key->data, value->data, workspace_buffer->data, - static_cast(seq_lens->data), static_cast(cum_seq_lens_q->data), - static_cast(cum_seq_lens_kv->data), attention_sinks_ptr, lse_ptr, q_data_type, - kv_data_type, o_data_type, max_q_len, max_kv_len, num_qo_heads, num_kv_heads, head_dim_qk, - head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale, bmm2_scale, o_sf_scale, batch_size, - window_left, sm_count, enable_pdl, is_causal, k_stride_keys_values, k_stride_heads, - k_stride_batch, v_stride_keys_values, v_stride_heads, v_stride_batch, workspace_size, stream); + out.data_ptr(), query.data_ptr(), key.data_ptr(), value.data_ptr(), + workspace_buffer.data_ptr(), static_cast(seq_lens.data_ptr()), + static_cast(cum_seq_lens_q.data_ptr()), static_cast(cum_seq_lens_kv.data_ptr()), + attention_sinks_ptr, lse_ptr, q_data_type, kv_data_type, o_data_type, max_q_len, max_kv_len, + num_qo_heads, num_kv_heads, head_dim_qk, head_dim_v, sum_seq_q, sum_seq_kv, bmm1_scale, + bmm2_scale, o_sf_scale, batch_size, window_left, sm_count, enable_pdl, is_causal, + k_stride_keys_values, k_stride_heads, k_stride_batch, v_stride_keys_values, v_stride_heads, + v_stride_batch, workspace_size, stream); } namespace trtllm_cubin_loader { diff --git a/csrc/trtllm_fused_moe_kernel_launcher.cu b/csrc/trtllm_fused_moe_kernel_launcher.cu index edcfa5e176..3681b47fd2 100644 --- a/csrc/trtllm_fused_moe_kernel_launcher.cu +++ b/csrc/trtllm_fused_moe_kernel_launcher.cu @@ -50,9 +50,9 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - hidden_states->device.device_id); + hidden_states.device().device_id); cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - hidden_states->device.device_id); + hidden_states.device().device_id); return std::make_tuple(major, minor); }(); @@ -61,18 
+61,18 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( << std::get<0>(device_props) << std::get<1>(device_props); if (use_routing_scales_on_input) { - TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_bfloat16) << "routing_logits must be bfloat16."; + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_bfloat16) << "routing_logits must be bfloat16."; } else { - TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float."; + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; } - TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) << "routing_logits has incorrect shape."; + TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits has incorrect shape."; if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 || - routing_bias.value()->dtype == dl_float32) + TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || + routing_bias.value().dtype() == dl_float32) << "routing_bias must be bfloat16 or float."; - TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D."; - TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts) + TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) << "routing_bias has incorrect shape."; } @@ -98,7 +98,7 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; // Convert PyTorch dtype to TensorRT-LLM dtype - auto dtype = hidden_states->dtype; + auto dtype = hidden_states.dtype(); if (dtype == dl_float16) { args.mDtypeElt = btg::Dtype::Fp16; } else if (dtype == dl_bfloat16) { @@ -109,23 +109,23 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported input dtype for MoE."; } - args.routing_logits = routing_logits->data; + args.routing_logits = routing_logits.data_ptr(); auto const routing_bias_dtype = - routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16; + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; auto btg_routing_bias_dtype = btg::Dtype::Fp32; if (routing_bias_dtype == dl_bfloat16) { btg_routing_bias_dtype = btg::Dtype::Bfloat16; } - args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr; - args.hidden_states = hidden_states->data; - args.gemm1_weights = gemm1_weights->data; - args.output1_scales_scalar = static_cast(output1_scales_scalar->data); - args.output1_scales_gate_scalar = static_cast(output1_scales_gate_scalar->data); - args.gemm2_weights = gemm2_weights->data; - args.output2_scales_scalar = static_cast(output2_scales_scalar->data); - args.num_tokens = hidden_states->shape[0]; + args.routing_bias = routing_bias.has_value() ? 
routing_bias.value().data_ptr() : nullptr; + args.hidden_states = hidden_states.data_ptr(); + args.gemm1_weights = gemm1_weights.data_ptr(); + args.output1_scales_scalar = static_cast(output1_scales_scalar.data_ptr()); + args.output1_scales_gate_scalar = static_cast(output1_scales_gate_scalar.data_ptr()); + args.gemm2_weights = gemm2_weights.data_ptr(); + args.output2_scales_scalar = static_cast(output2_scales_scalar.data_ptr()); + args.num_tokens = hidden_states.size(0); args.num_experts = num_experts; - args.hidden_size = hidden_states->shape[1]; + args.hidden_size = hidden_states.size(1); args.hidden_size_output = args.hidden_size; args.top_k = top_k; args.n_group = n_group; @@ -137,132 +137,132 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( args.mUseRoutingScalesOnInput = use_routing_scales_on_input; // allocate workspace for routing kernel - Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, routing_logits->device); + Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, routing_logits.device()); int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); - Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device); + Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device()); Tensor expanded_idx_to_permuted_idx = - alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device); + alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device()); Tensor permuted_idx_to_token_idx = - alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device); + alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits.device()); Tensor expert_weights = - alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device); + alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits.device()); Tensor expert_indexes = - alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device); + alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits.device()); Tensor expert_count_histogram = alloc_tensor( {2 * 256}, dl_int32, // 256 is the max number of threads per block and max number of experts - routing_logits->device); + routing_logits.device()); // allocate workspace for activation/gemm/finalize kernels // Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, - // dl_float8_e4m3fn, hidden_states->device); + // dl_float8_e4m3fn, hidden_states.device()); // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, - // dl_float8_e4m3fn, hidden_states->device); - Tensor gemm1_output = - alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, hidden_states->device); + // dl_float8_e4m3fn, hidden_states.device()); + Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, dl_uint8, + hidden_states.device()); Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states->device); + dl_float32, hidden_states.device()); Tensor activation_output = - alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states->device); + alloc_tensor({max_num_padded_tokens, intermediate_size}, dl_uint8, hidden_states.device()); Tensor activation_output_scale = alloc_tensor({intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states->device); + dl_float32, 
hidden_states.device()); Tensor gemm2_output = - alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states->device); + alloc_tensor({max_num_padded_tokens, args.hidden_size}, dl_bfloat16, hidden_states.device()); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); - Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); - Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); - Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits->device); + Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); + Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); + Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits.device()); tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t stream = get_stream(routing_logits->device); + cudaStream_t stream = get_stream(routing_logits.device()); routing_runner.run( - routing_logits->data, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, + routing_logits.data_ptr(), args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, - args.routed_scaling_factor, static_cast(expert_indexes->data), - static_cast(expert_count_histogram->data), - static_cast(total_num_padded_tokens->data), - static_cast(expanded_idx_to_permuted_idx->data), - nullptr /*static_cast(permuted_idx_to_expanded_idx->data)*/, - static_cast(permuted_idx_to_token_idx->data), expert_weights->data, - static_cast(num_tokens_per_expert->data), - static_cast(cta_idx_xy_to_batch_idx->data), - static_cast(cta_idx_xy_to_mn_limit->data), - static_cast(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype, + args.routed_scaling_factor, static_cast(expert_indexes.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*static_cast(permuted_idx_to_expanded_idx.data_ptr())*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), + static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args.mDtypeElt, btg_routing_bias_dtype, use_routing_scales_on_input, false /* use_deep_seek_fp8 */, static_cast(routing_method_type), stream); // MoE kernel except routing - TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights->dtype, dl_float8_e4m3fn) << "gemm1_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights->ndim, 3) << "gemm1_weights must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights->shape[1] % 2, 0) + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm1_weights.ndim(), 3) << "gemm1_weights must be 3D."; + TVM_FFI_ICHECK_EQ(gemm1_weights.size(1) % 2, 0) << "the second dimension of weights must be even."; - TVM_FFI_ICHECK_EQ(intermediate_size, gemm1_weights->shape[1] / 2) + TVM_FFI_ICHECK_EQ(intermediate_size, 
gemm1_weights.size(1) / 2) << "intermediate_size has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm1_weights->shape[2], hidden_states->shape[1]) + TVM_FFI_ICHECK_EQ(gemm1_weights.size(2), hidden_states.size(1)) << "the third dimension of weights must be equal to hidden_size."; TVM_FFI_ICHECK_EQ(intermediate_size % 128, 0) << "the second dimension of weights must be a multiple of 128."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(output1_scales_scalar.dtype(), dl_float32) << "output1_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar->ndim, 1) << "output1_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(output1_scales_scalar.ndim(), 1) << "output1_scales_scalar must be 1D."; + TVM_FFI_ICHECK_EQ(output1_scales_scalar.size(0), local_num_experts) << "output1_scales_scalar has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.dtype(), dl_float32) << "output1_scales_gate_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar->ndim, 1) + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.ndim(), 1) << "output1_scales_gate_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.size(0), local_num_experts) << "output1_scales_gate_scalar has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(gemm2_weights->dtype, dl_float8_e4m3fn) << "gemm2_weights must be fp8."; - TVM_FFI_ICHECK_EQ(gemm2_weights->ndim, 3) << "gemm2_weights must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights->shape[2], intermediate_size) + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm2_weights.ndim(), 3) << "gemm2_weights must be 3D."; + TVM_FFI_ICHECK_EQ(gemm2_weights.size(2), intermediate_size) << "the third dimension of weights must be equal to intermediate_size."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(output2_scales_scalar.dtype(), dl_float32) << "output2_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar->ndim, 1) << "output2_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(output2_scales_scalar.ndim(), 1) << "output2_scales_scalar must be 1D."; + TVM_FFI_ICHECK_EQ(output2_scales_scalar.size(0), local_num_experts) << "output2_scales_scalar has incorrect dim 0."; // allocate output - TVM_FFI_ICHECK_EQ(output->shape[0], args.num_tokens); - TVM_FFI_ICHECK_EQ(output->shape[1], args.hidden_size); + TVM_FFI_ICHECK_EQ(output.size(0), args.num_tokens); + TVM_FFI_ICHECK_EQ(output.size(1), args.hidden_size); CHECK_INPUT_TYPE(output, dl_bfloat16); CHECK_DEVICE(output, hidden_states); // setup workspace - workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); workspace.total_max_padded_tokens = max_num_padded_tokens; workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = static_cast(expert_indexes->data); - workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); + workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); workspace.expanded_idx_to_permuted_idx = static_cast( - 
expanded_idx_to_permuted_idx->data); // Needed by activation/finalize kernels + expanded_idx_to_permuted_idx.data_ptr()); // Needed by activation/finalize kernels workspace.permuted_idx_to_token_idx = - static_cast(permuted_idx_to_token_idx->data); // Needed by permuteGemm1 kernel - workspace.expert_weights = expert_weights->data; // Consumed by finalize kernel + static_cast(permuted_idx_to_token_idx.data_ptr()); // Needed by permuteGemm1 kernel + workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel - workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx->data); - workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit->data); - workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas->data); + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); + workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); // gemm1 intermediate ws - workspace.gemm1_output = gemm1_output->data; - workspace.gemm1_output_scale = static_cast(gemm1_output_scale->data); + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); // activation intermediate ws - workspace.activation_output = activation_output->data; - workspace.activation_output_scale = static_cast(activation_output_scale->data); + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); // gemm2 intermediate ws - workspace.gemm2_output = gemm2_output->data; + workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; - args.output = output->data; + args.output = output.data_ptr(); args.output_scale = nullptr; tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner moe_runner( @@ -274,13 +274,13 @@ void trtllm_fp8_per_tensor_scale_moe_launcher( auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); Tensor workspace_fc1 = - alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states->device); + alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); Tensor workspace_fc2 = - alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states->device); - workspace.bmm1_workspace = workspace_fc1->data; - workspace.bmm2_workspace = workspace_fc2->data; - cudaStream_t moe_stream = get_stream(hidden_states->device); - moe_runner.run(args, workspace, hidden_states->device.device_id, moe_stream, moeConfigIndex, + alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); + workspace.bmm1_workspace = workspace_fc1.data_ptr(); + workspace.bmm2_workspace = workspace_fc2.data_ptr(); + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner.run(args, workspace, hidden_states.device().device_id, moe_stream, moeConfigIndex, enable_pdl); } @@ -292,7 +292,7 @@ void trtllm_fp8_per_tensor_scale_moe( int64_t n_group, int64_t topk_group, int64_t intermediate_size, int64_t local_expert_offset, int64_t local_num_experts, double routed_scaling_factor, bool use_routing_scales_on_input, int64_t tile_tokens_dim, int64_t routing_method_type, bool enable_pdl) { - auto dtype = hidden_states->dtype; + auto dtype = hidden_states.dtype(); if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { trtllm_fp8_per_tensor_scale_moe_launcher( routing_logits, routing_bias, 
hidden_states, gemm1_weights, output1_scales_scalar, @@ -318,9 +318,9 @@ void trtllm_fp8_block_scale_moe_launcher( static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - hidden_states->device.device_id); + hidden_states.device().device_id); cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - hidden_states->device.device_id); + hidden_states.device().device_id); return std::make_tuple(major, minor); }(); @@ -328,18 +328,18 @@ void trtllm_fp8_block_scale_moe_launcher( << "This kernel requires 10.x architecture. Current device has SM " << std::get<0>(device_props) << std::get<1>(device_props); - TVM_FFI_ICHECK_EQ(routing_logits->dtype, dl_float32) << "routing_logits must be float."; - TVM_FFI_ICHECK_EQ(routing_logits->ndim, 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits->shape[0], hidden_states->shape[0]) + TVM_FFI_ICHECK_EQ(routing_logits.dtype(), dl_float32) << "routing_logits must be float."; + TVM_FFI_ICHECK_EQ(routing_logits.ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.size(0), hidden_states.size(0)) << "routing_logits and hidden_states must have the same number of tokens."; - TVM_FFI_ICHECK_EQ(routing_logits->shape[1], num_experts) + TVM_FFI_ICHECK_EQ(routing_logits.size(1), num_experts) << "routing_logits dim1 must match num_experts."; if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 || - routing_bias.value()->dtype == dl_float32) + TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || + routing_bias.value().dtype() == dl_float32) << "routing_bias must be bfloat16 or float."; - TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D."; - TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts) + TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) << "routing_bias has incorrect shape."; } @@ -363,7 +363,7 @@ void trtllm_fp8_block_scale_moe_launcher( tensorrt_llm::kernels::trtllmgen_moe::MoE::MoEWorkspace workspace; // Convert PyTorch dtype to TensorRT-LLM dtype - auto dtype = hidden_states->dtype; + auto dtype = hidden_states.dtype(); if (dtype == dl_float16) { args.mDtypeElt = btg::Dtype::Fp16; } else if (dtype == dl_bfloat16) { @@ -375,21 +375,21 @@ void trtllm_fp8_block_scale_moe_launcher( } auto const routing_bias_dtype = - routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16; + routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; auto btg_routing_bias_dtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; - args.routing_logits = static_cast(routing_logits->data); - args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr; - args.hidden_states = hidden_states->data; - args.hidden_states_scale = static_cast(hidden_states_scale->data); - args.gemm1_weights = gemm1_weights->data; - args.gemm1_weights_scale = static_cast(gemm1_weights_scale->data); - args.gemm2_weights = gemm2_weights->data; - args.gemm2_weights_scale = static_cast(gemm2_weights_scale->data); - args.num_tokens = hidden_states->shape[0]; + args.routing_logits = static_cast(routing_logits.data_ptr()); + args.routing_bias = routing_bias.has_value() ? 
routing_bias.value().data_ptr() : nullptr; + args.hidden_states = hidden_states.data_ptr(); + args.hidden_states_scale = static_cast(hidden_states_scale.data_ptr()); + args.gemm1_weights = gemm1_weights.data_ptr(); + args.gemm1_weights_scale = static_cast(gemm1_weights_scale.data_ptr()); + args.gemm2_weights = gemm2_weights.data_ptr(); + args.gemm2_weights_scale = static_cast(gemm2_weights_scale.data_ptr()); + args.num_tokens = hidden_states.size(0); args.num_experts = num_experts; - args.hidden_size = hidden_states->shape[1]; + args.hidden_size = hidden_states.size(1); args.hidden_size_output = args.hidden_size; args.top_k = top_k; args.n_group = n_group; @@ -401,7 +401,7 @@ void trtllm_fp8_block_scale_moe_launcher( args.mUseDeepSeekFp8 = true; // allocate workspace for routing kernel - Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, routing_logits->device); + Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, routing_logits.device()); int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); @@ -411,172 +411,176 @@ void trtllm_fp8_block_scale_moe_launcher( int32_t max_num_padded_tokens_gemm2 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); - Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits->device); + Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, routing_logits.device()); Tensor expanded_idx_to_permuted_idx = - alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits->device); + alloc_tensor({args.num_tokens * args.top_k}, dl_int32, routing_logits.device()); Tensor permuted_idx_to_token_idx = - alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits->device); + alloc_tensor({max_num_padded_tokens}, dl_int32, routing_logits.device()); Tensor expert_weights = - alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits->device); + alloc_tensor({args.num_tokens, args.top_k}, dl_bfloat16, routing_logits.device()); // NOTE: the output type of routing kernel is currently always bfloat16 Tensor expert_indexes = - alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits->device); + alloc_tensor({args.num_tokens, args.top_k}, dl_int32, routing_logits.device()); int64_t const size_of_expert_count_histogram = std::max(num_experts * 2, int64_t(256 * 2)); Tensor expert_count_histogram = alloc_tensor( {size_of_expert_count_histogram}, dl_int32, // 256 is the max number of threads per block and max number of experts - routing_logits->device); + routing_logits.device()); // allocate workspace for activation/gemm/finalize kernels // Tensor gemm1_output = alloc_tensor({max_num_padded_tokens, 2 * intermediate_size}, - // dl_float8_e4m3fn, hidden_states->device); + // dl_float8_e4m3fn, hidden_states.device()); // Tensor activation_output = alloc_tensor({max_num_padded_tokens, intermediate_size}, - // dl_float8_e4m3fn, hidden_states->device); + // dl_float8_e4m3fn, hidden_states.device()); Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, 2 * intermediate_size}, dl_uint8, - hidden_states->device); + hidden_states.device()); Tensor gemm1_output_scale = alloc_tensor({2 * intermediate_size / 128, max_num_padded_tokens}, - dl_float32, hidden_states->device); + dl_float32, hidden_states.device()); Tensor activation_output = alloc_tensor({max_num_padded_tokens_gemm1, 
intermediate_size}, - dl_uint8, hidden_states->device); + dl_uint8, hidden_states.device()); Tensor activation_output_scale = alloc_tensor( - {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states->device); + {intermediate_size / 128, max_num_padded_tokens_gemm1}, dl_float32, hidden_states.device()); Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, - hidden_states->device); + hidden_states.device()); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); - Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); - Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits->device); - Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits->device); + Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); + Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, routing_logits.device()); + Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, routing_logits.device()); tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t stream = get_stream(routing_logits->device); - routing_runner.run( - static_cast(routing_logits->data), args.routing_bias, args.num_tokens, - args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, - args.local_num_experts, args.routed_scaling_factor, static_cast(expert_indexes->data), - static_cast(expert_count_histogram->data), - static_cast(total_num_padded_tokens->data), - static_cast(expanded_idx_to_permuted_idx->data), - nullptr /*static_cast(permuted_idx_to_expanded_idx->data)*/, - static_cast(permuted_idx_to_token_idx->data), expert_weights->data, - static_cast(num_tokens_per_expert->data), - static_cast(cta_idx_xy_to_batch_idx->data), - static_cast(cta_idx_xy_to_mn_limit->data), - static_cast(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype, - false /* use_routing_scales_on_input */, true /* use_deep_seek_fp8 */, - static_cast(routing_method_type), stream); + cudaStream_t stream = get_stream(routing_logits.device()); + routing_runner.run(static_cast(routing_logits.data_ptr()), args.routing_bias, + args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, + args.local_expert_offset, args.local_num_experts, args.routed_scaling_factor, + static_cast(expert_indexes.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr /*static_cast(permuted_idx_to_expanded_idx.data_ptr())*/, + static_cast(permuted_idx_to_token_idx.data_ptr()), + expert_weights.data_ptr(), static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args.mDtypeElt, + btg_routing_bias_dtype, false /* use_routing_scales_on_input */, + true /* use_deep_seek_fp8 */, + static_cast(routing_method_type), stream); // MoE kernel except routing - TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8."; - TVM_FFI_ICHECK_EQ(hidden_states_scale->dtype, dl_float32) << "hidden_states_scale must be float."; - TVM_FFI_ICHECK_EQ(hidden_states_scale->ndim, 2) << "hidden_states_scale must be 2D."; 
- TVM_FFI_ICHECK_EQ(hidden_states_scale->shape[0], hidden_states->shape[1] / 128) + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.dtype(), dl_float32) + << "hidden_states_scale must be float."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.ndim(), 2) << "hidden_states_scale must be 2D."; + TVM_FFI_ICHECK_EQ(hidden_states_scale.size(0), hidden_states.size(1) / 128) << "hidden_states_scale dim0 must match hidden_states dim1 / 128."; - TVM_FFI_ICHECK_EQ(hidden_states_scale->shape[1], args.num_tokens) + TVM_FFI_ICHECK_EQ(hidden_states_scale.size(1), args.num_tokens) << "hidden_states_scale dim1 must match num_tokens."; - TVM_FFI_ICHECK_EQ(gemm1_weights->dtype, dl_float8_e4m3fn) << "gemm1_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_float8_e4m3fn) << "gemm1_weights must be fp8."; - TVM_FFI_ICHECK(gemm1_weights->ndim == 3 || gemm1_weights->ndim == 4) + TVM_FFI_ICHECK(gemm1_weights.ndim() == 3 || gemm1_weights.ndim() == 4) << "gemm1_weights must be 3D or 4D."; { int64_t Mn = 0, K = 0; - if (gemm1_weights->ndim == 3) { + if (gemm1_weights.ndim() == 3) { // MajorK [num_experts, M, K] - Mn = gemm1_weights->shape[1]; - K = gemm1_weights->shape[2]; - } else if (gemm1_weights->ndim == 4) { + Mn = gemm1_weights.size(1); + K = gemm1_weights.size(2); + } else if (gemm1_weights.ndim() == 4) { // BlockMajorK [num_experts, K/block_k, M, block_k] - Mn = gemm1_weights->shape[2]; - int64_t block_k = gemm1_weights->shape[3]; - K = gemm1_weights->shape[1] * block_k; + Mn = gemm1_weights.size(2); + int64_t block_k = gemm1_weights.size(3); + K = gemm1_weights.size(1) * block_k; } TVM_FFI_ICHECK_EQ(Mn % 2, 0) << "the second dimension of weights must be even."; TVM_FFI_ICHECK_EQ(intermediate_size, Mn / 2) << "intermediate_size has incorrect shape."; - TVM_FFI_ICHECK_EQ(K, hidden_states->shape[1]) + TVM_FFI_ICHECK_EQ(K, hidden_states.size(1)) << "the third dimension of weights must be equal to hidden_size."; } - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->dtype, dl_float32) << "gemm1_weights_scale must be float."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->ndim, 3) << "gemm1_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float32) + << "gemm1_weights_scale must be float."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), local_num_experts) << "gemm1_weights_scale has incorrect shape."; TVM_FFI_ICHECK_EQ(intermediate_size % 128, 0) << "the second dimension of weights must be a multiple of 128."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->shape[1], 2 * intermediate_size / 128) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * intermediate_size / 128) << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->shape[2], args.hidden_size / 128) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args.hidden_size / 128) << "gemm1_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights->dtype, dl_float8_e4m3fn) << "gemm2_weights must be fp8."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_float8_e4m3fn) << "gemm2_weights must be fp8."; - TVM_FFI_ICHECK(gemm2_weights->ndim == 3 || gemm2_weights->ndim == 4) + TVM_FFI_ICHECK(gemm2_weights.ndim() == 3 || gemm2_weights.ndim() == 4) << "gemm2_weights must be 3D or 4D."; { int64_t K = 0; - if (gemm2_weights->ndim == 3) { + if (gemm2_weights.ndim() == 3) { // 
MajorK [num_experts, M, K] - K = gemm2_weights->shape[2]; - } else if (gemm2_weights->ndim == 4) { + K = gemm2_weights.size(2); + } else if (gemm2_weights.ndim() == 4) { // BlockMajorK [num_experts, K/block_k, M, block_k] - int64_t block_k = gemm2_weights->shape[3]; - K = gemm2_weights->shape[1] * block_k; + int64_t block_k = gemm2_weights.size(3); + K = gemm2_weights.size(1) * block_k; } TVM_FFI_ICHECK_EQ(K, intermediate_size) << "the third dimension of weights must be equal to intermediate_size."; } - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->dtype, dl_float32) << "gemm2_weights_scale must be float."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->ndim, 3) << "gemm2_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float32) + << "gemm2_weights_scale must be float."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), local_num_experts) << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->shape[1], args.hidden_size / 128) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args.hidden_size / 128) << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->shape[2], intermediate_size / 128) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), intermediate_size / 128) << "gemm2_weights_scale has incorrect shape."; - TVM_FFI_ICHECK_EQ(output->shape[0], args.num_tokens) << "output has incorrect shape."; - TVM_FFI_ICHECK_EQ(output->shape[1], args.hidden_size) << "output has incorrect shape."; - TVM_FFI_ICHECK_EQ(output->dtype, dl_bfloat16) << "output must be bf16."; + TVM_FFI_ICHECK_EQ(output.size(0), args.num_tokens) << "output has incorrect shape."; + TVM_FFI_ICHECK_EQ(output.size(1), args.hidden_size) << "output has incorrect shape."; + TVM_FFI_ICHECK_EQ(output.dtype(), dl_bfloat16) << "output must be bf16."; // setup workspace - workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); workspace.total_max_padded_tokens = std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = static_cast(expert_indexes->data); - workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); + workspace.routing_expert_indexes = static_cast(expert_indexes.data_ptr()); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); workspace.expanded_idx_to_permuted_idx = static_cast( - expanded_idx_to_permuted_idx->data); // Needed by activation/finalize kernels + expanded_idx_to_permuted_idx.data_ptr()); // Needed by activation/finalize kernels workspace.permuted_idx_to_token_idx = - static_cast(permuted_idx_to_token_idx->data); // Needed by permuteGemm1 kernel - workspace.expert_weights = expert_weights->data; // Consumed by finalize kernel + static_cast(permuted_idx_to_token_idx.data_ptr()); // Needed by permuteGemm1 kernel + workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel - workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx->data); - workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit->data); - workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas->data); + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); + 
workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); // gemm1 intermediate ws - workspace.gemm1_output = gemm1_output->data; - workspace.gemm1_output_scale = static_cast(gemm1_output_scale->data); + workspace.gemm1_output = gemm1_output.data_ptr(); + workspace.gemm1_output_scale = static_cast(gemm1_output_scale.data_ptr()); // activation intermediate ws - workspace.activation_output = activation_output->data; - workspace.activation_output_scale = static_cast(activation_output_scale->data); + workspace.activation_output = activation_output.data_ptr(); + workspace.activation_output_scale = static_cast(activation_output_scale.data_ptr()); // gemm2 intermediate ws - workspace.gemm2_output = gemm2_output->data; + workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; - args.output = output->data; + args.output = output.data_ptr(); args.output_scale = nullptr; auto workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); Tensor workspace_fc1 = - alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states->device); + alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); Tensor workspace_fc2 = - alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states->device); - workspace.bmm1_workspace = workspace_fc1->data; - workspace.bmm2_workspace = workspace_fc2->data; + alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); + workspace.bmm1_workspace = workspace_fc1.data_ptr(); + workspace.bmm2_workspace = workspace_fc2.data_ptr(); - cudaStream_t moe_stream = get_stream(hidden_states->device); - moe_runner.run(args, workspace, hidden_states->device.device_id, moe_stream, moeConfigIndex, + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner.run(args, workspace, hidden_states.device().device_id, moe_stream, moeConfigIndex, enable_pdl); } @@ -590,7 +594,7 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional double routed_scaling_factor, int64_t tile_tokens_dim, int64_t routing_method_type, bool use_shuffled_weight, int64_t weight_layout, bool enable_pdl) { - auto dtype = hidden_states->dtype; + auto dtype = hidden_states.dtype(); if (dtype == dl_float16 || dtype == dl_bfloat16 || dtype == dl_float8_e4m3fn) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; @@ -606,8 +610,8 @@ void trtllm_fp8_block_scale_moe(TensorView routing_logits, Optional static_cast(weight_layout)); // Always use fallback config (equivalent to moeConfigIndex == -1 case from original code) - auto const num_tokens = hidden_states->shape[0]; - auto const hidden_size = hidden_states->shape[1]; + auto const num_tokens = hidden_states.size(0); + auto const hidden_size = hidden_states.size(1); int64_t moeConfigIndex = mRunner->getDefaultValidConfigIndex( top_k, hidden_size, intermediate_size, local_num_experts, num_tokens); @@ -644,9 +648,9 @@ Array trtllm_fp4_block_scale_moe_launcher( static const std::tuple device_props = [hidden_states] { int major, minor; cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, - hidden_states->device.device_id); + hidden_states.device().device_id); cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, - hidden_states->device.device_id); + hidden_states.device().device_id); return std::make_tuple(major, minor); }(); @@ -687,20 +691,20 @@ Array trtllm_fp4_block_scale_moe_launcher( } if 
(routing_logits.has_value()) { - TVM_FFI_ICHECK(routing_logits.value()->dtype == dl_float32 || - routing_logits.value()->dtype == dl_bfloat16) + TVM_FFI_ICHECK(routing_logits.value().dtype() == dl_float32 || + routing_logits.value().dtype() == dl_bfloat16) << "routing_logits must be float or bfloat16."; - TVM_FFI_ICHECK_EQ(routing_logits.value()->ndim, 2) << "routing_logits must be 2D."; - TVM_FFI_ICHECK_EQ(routing_logits.value()->shape[1], num_experts) + TVM_FFI_ICHECK_EQ(routing_logits.value().ndim(), 2) << "routing_logits must be 2D."; + TVM_FFI_ICHECK_EQ(routing_logits.value().size(1), num_experts) << "routing_logits has incorrect shape."; } if (routing_bias.has_value()) { - TVM_FFI_ICHECK(routing_bias.value()->dtype == dl_bfloat16 || - routing_bias.value()->dtype == dl_float32) + TVM_FFI_ICHECK(routing_bias.value().dtype() == dl_bfloat16 || + routing_bias.value().dtype() == dl_float32) << "routing_bias must be bfloat16 or float."; - TVM_FFI_ICHECK_EQ(routing_bias.value()->ndim, 1) << "routing_bias must be 1D."; - TVM_FFI_ICHECK_EQ(routing_bias.value()->shape[0], num_experts) + TVM_FFI_ICHECK_EQ(routing_bias.value().ndim(), 1) << "routing_bias must be 1D."; + TVM_FFI_ICHECK_EQ(routing_bias.value().size(0), num_experts) << "routing_bias has incorrect shape."; } @@ -743,38 +747,38 @@ Array trtllm_fp4_block_scale_moe_launcher( // setup args args.mDtypeElt = dtype_act; // note: the assumption is that output data type is always Bfloat16 (the default) - auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value()->dtype : dl_bfloat16; + auto routing_bias_dtype = routing_bias.has_value() ? routing_bias.value().dtype() : dl_bfloat16; auto btg_routing_bias_dtype = routing_bias_dtype == dl_bfloat16 ? btg::Dtype::Bfloat16 : btg::Dtype::Fp32; // We shouln't use args.mDtypeExpW since it indicates the output data type of routing kernel, // which is currently always bfloat16 for routing kernel while the data type of routing bias now // can be fp32 - args.routing_logits = routing_logits.has_value() ? routing_logits.value()->data : nullptr; - args.routing_bias = routing_bias.has_value() ? routing_bias.value()->data : nullptr; - args.hidden_states = hidden_states->data; + args.routing_logits = routing_logits.has_value() ? routing_logits.value().data_ptr() : nullptr; + args.routing_bias = routing_bias.has_value() ? routing_bias.value().data_ptr() : nullptr; + args.hidden_states = hidden_states.data_ptr(); args.hidden_states_scale = - hidden_states_scale.has_value() ? hidden_states_scale.value()->data : nullptr; - args.gemm1_weights = gemm1_weights->data; - args.gemm1_weights_scale = gemm1_weights_scale->data; + hidden_states_scale.has_value() ? hidden_states_scale.value().data_ptr() : nullptr; + args.gemm1_weights = gemm1_weights.data_ptr(); + args.gemm1_weights_scale = gemm1_weights_scale.data_ptr(); args.gemm1_bias = - gemm1_bias.has_value() ? static_cast(gemm1_bias.value()->data) : nullptr; + gemm1_bias.has_value() ? static_cast(gemm1_bias.value().data_ptr()) : nullptr; args.gemm1_alpha = - gemm1_alpha.has_value() ? static_cast(gemm1_alpha.value()->data) : nullptr; + gemm1_alpha.has_value() ? static_cast(gemm1_alpha.value().data_ptr()) : nullptr; args.gemm1_beta = - gemm1_beta.has_value() ? static_cast(gemm1_beta.value()->data) : nullptr; + gemm1_beta.has_value() ? static_cast(gemm1_beta.value().data_ptr()) : nullptr; args.gemm1_clamp_limit = gemm1_clamp_limit.has_value() - ? static_cast(gemm1_clamp_limit.value()->data) + ? 
static_cast(gemm1_clamp_limit.value().data_ptr()) : nullptr; - args.gemm2_weights = gemm2_weights->data; - args.gemm2_weights_scale = gemm2_weights_scale->data; + args.gemm2_weights = gemm2_weights.data_ptr(); + args.gemm2_weights_scale = gemm2_weights_scale.data_ptr(); args.gemm2_bias = - gemm2_bias.has_value() ? static_cast(gemm2_bias.value()->data) : nullptr; - args.num_tokens = hidden_states->shape[0]; + gemm2_bias.has_value() ? static_cast(gemm2_bias.value().data_ptr()) : nullptr; + args.num_tokens = hidden_states.size(0); args.num_experts = num_experts; // * 2 to compensate for the fact that sizeof(hidden_states.dtype) is 1 because we pack 2 e2m1 // into 1 byte. auto const hidden_states_hidden_size = - dtype_act == btg::Dtype::E2m1 ? hidden_states->shape[1] * 2 : hidden_states->shape[1]; + dtype_act == btg::Dtype::E2m1 ? hidden_states.size(1) * 2 : hidden_states.size(1); args.hidden_size = hidden_states_hidden_size; args.hidden_size_output = args.hidden_size; args.top_k = top_k; @@ -786,7 +790,7 @@ Array trtllm_fp4_block_scale_moe_launcher( args.intermediate_size = intermediate_size; // allocate workspace for routing kernel - Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, hidden_states->device); + Tensor num_tokens_per_expert = alloc_tensor({num_experts}, dl_int32, hidden_states.device()); int32_t max_num_padded_tokens = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxPermutedPaddedCount( args.num_tokens, top_k, num_experts, tile_tokens_dim); @@ -796,21 +800,21 @@ Array trtllm_fp4_block_scale_moe_launcher( int32_t max_num_padded_tokens_gemm2 = tensorrt_llm::kernels::trtllmgen_moe::Routing::maybeGetMinTokenCount( max_num_padded_tokens, args.hidden_size, btg::dtypeGetNumBits(args.mDtypeOut)); - Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states->device); + Tensor total_num_padded_tokens = alloc_tensor({1}, dl_int32, hidden_states.device()); Tensor expanded_idx_to_permuted_idx = - alloc_tensor({args.num_tokens, args.top_k}, dl_int32, hidden_states->device); + alloc_tensor({args.num_tokens, args.top_k}, dl_int32, hidden_states.device()); Tensor permuted_idx_to_token_idx = - alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states->device); + alloc_tensor({max_num_padded_tokens}, dl_int32, hidden_states.device()); // Tensor expert_weights = alloc_tensor( - // {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states->device); + // {args.num_tokens, args.top_k}, dl_bfloat16, hidden_states.device()); // Tensor expert_indexes = alloc_tensor( - // {args.num_tokens, args.top_k}, dl_int32, hidden_states->device); + // {args.num_tokens, args.top_k}, dl_int32, hidden_states.device()); int constexpr MAX_NUM_EXPERTS = 384; Tensor expert_count_histogram = alloc_tensor( {2 * MAX_NUM_EXPERTS}, dl_int32, // 256 is the max number of threads per block and max number of experts - hidden_states->device); + hidden_states.device()); auto const sf_vec_size = dtype_weights == btg::Dtype::MxE2m1 ? 32 : 16; @@ -819,47 +823,48 @@ Array trtllm_fp4_block_scale_moe_launcher( dtype_act == btg::Dtype::E2m1 ? intermediate_size / 2 : intermediate_size; // Tensor gemm1_output = alloc_tensor( // {max_num_padded_tokens, gemm1_output_hidden}, - // dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, hidden_states->device); + // dtype_act == btg::Dtype::Bfloat16 ? dl_bfloat16 : dl_float8_e4m3fn, + // hidden_states.device()); Tensor gemm1_output = alloc_tensor({max_num_padded_tokens_gemm1, gemm1_output_hidden}, dtype_act == btg::Dtype::Bfloat16 ? 
dl_bfloat16 : dl_uint8, - hidden_states->device); + hidden_states.device()); Optional gemm1_output_scale = std::nullopt; if (dtype_act == btg::Dtype::E2m1 || dtype_act == btg::Dtype::MxE4m3) { int64_t sf_size = tensorrt_llm::computeSwizzledLayoutSFSize(max_num_padded_tokens_gemm1, intermediate_size / sf_vec_size); - // gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states->device); - gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states->device); + // gemm1_output_scale = alloc_tensor({sf_size}, dl_float8_e4m3fn, hidden_states.device()); + gemm1_output_scale = alloc_tensor({sf_size}, dl_uint8, hidden_states.device()); } Tensor gemm2_output = alloc_tensor({max_num_padded_tokens_gemm2, args.hidden_size}, dl_bfloat16, - hidden_states->device); + hidden_states.device()); int32_t max_num_ctas = tensorrt_llm::kernels::trtllmgen_moe::Routing::getMaxNumCtasInBatchDim( args.num_tokens, args.top_k, args.num_experts, tile_tokens_dim); - Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states->device); - Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, hidden_states->device); - Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, hidden_states->device); + Tensor cta_idx_xy_to_batch_idx = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); + Tensor cta_idx_xy_to_mn_limit = alloc_tensor({max_num_ctas}, dl_int32, hidden_states.device()); + Tensor num_non_exiting_ctas = alloc_tensor({1}, dl_int32, hidden_states.device()); // // TopK routing // tensorrt_llm::kernels::trtllmgen_moe::Routing::Runner routing_runner(tile_tokens_dim); - cudaStream_t stream = get_stream(hidden_states->device); + cudaStream_t stream = get_stream(hidden_states.device()); routing_runner.run( args.routing_logits, args.routing_bias, args.num_tokens, args.num_experts, args.top_k, args.n_group, args.topk_group, args.local_expert_offset, args.local_num_experts, - args.routed_scaling_factor, static_cast(expert_indices->data), - static_cast(expert_count_histogram->data), - static_cast(total_num_padded_tokens->data), - static_cast(expanded_idx_to_permuted_idx->data), - nullptr, /*static_cast(permuted_idx_to_expanded_idx->data),*/ - static_cast(permuted_idx_to_token_idx->data), expert_weights->data, - static_cast(num_tokens_per_expert->data), - static_cast(cta_idx_xy_to_batch_idx->data), - static_cast(cta_idx_xy_to_mn_limit->data), - static_cast(num_non_exiting_ctas->data), args.mDtypeElt, btg_routing_bias_dtype, + args.routed_scaling_factor, static_cast(expert_indices.data_ptr()), + static_cast(expert_count_histogram.data_ptr()), + static_cast(total_num_padded_tokens.data_ptr()), + static_cast(expanded_idx_to_permuted_idx.data_ptr()), + nullptr, /*static_cast(permuted_idx_to_expanded_idx.data_ptr()),*/ + static_cast(permuted_idx_to_token_idx.data_ptr()), expert_weights.data_ptr(), + static_cast(num_tokens_per_expert.data_ptr()), + static_cast(cta_idx_xy_to_batch_idx.data_ptr()), + static_cast(cta_idx_xy_to_mn_limit.data_ptr()), + static_cast(num_non_exiting_ctas.data_ptr()), args.mDtypeElt, btg_routing_bias_dtype, false /* use_routing_scales_on_input */, false /* use_deep_seek_fp8 */, static_cast(routing_method_type), stream); @@ -868,17 +873,17 @@ Array trtllm_fp4_block_scale_moe_launcher( // if (dtype_act == btg::Dtype::E2m1) { - TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_uint8) << "hidden_states must be byte."; + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_uint8) << "hidden_states must be byte."; } else if (dtype_act == 
btg::Dtype::E4m3 || dtype_act == btg::Dtype::MxE4m3) { - TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_float8_e4m3fn) << "hidden_states must be fp8."; + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_float8_e4m3fn) << "hidden_states must be fp8."; } else if (dtype_act == btg::Dtype::Bfloat16) { - TVM_FFI_ICHECK_EQ(hidden_states->dtype, dl_bfloat16) << "hidden_states must be bfloat16."; + TVM_FFI_ICHECK_EQ(hidden_states.dtype(), dl_bfloat16) << "hidden_states must be bfloat16."; } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported act dtype."; } if (hidden_states_scale.has_value()) { - TVM_FFI_ICHECK_EQ(hidden_states_scale.value()->dtype, dl_float8_e4m3fn) + TVM_FFI_ICHECK_EQ(hidden_states_scale.value().dtype(), dl_float8_e4m3fn) << "hidden_states_scale must be fp8."; TVM_FFI_ICHECK_EQ( @@ -887,159 +892,159 @@ Array trtllm_fp4_block_scale_moe_launcher( << "hidden_states_scale has incorrect size"; } - TVM_FFI_ICHECK_EQ(gemm1_weights->dtype, dl_uint8) << "gemm1_weights must be byte."; + TVM_FFI_ICHECK_EQ(gemm1_weights.dtype(), dl_uint8) << "gemm1_weights must be byte."; - TVM_FFI_ICHECK_EQ(gemm1_weights->ndim, 3) << "gemm1_weights must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights->shape[1] % 2, 0) + TVM_FFI_ICHECK_EQ(gemm1_weights.ndim(), 3) << "gemm1_weights must be 3D."; + TVM_FFI_ICHECK_EQ(gemm1_weights.size(1) % 2, 0) << "the second dimension of weights must be even."; - TVM_FFI_ICHECK_EQ(intermediate_size, gemm1_weights->shape[1] / 2) + TVM_FFI_ICHECK_EQ(intermediate_size, gemm1_weights.size(1) / 2) << "intermediate_size has incorrect dim 1."; // This check passes even though the actual shape of the weights[2] and hidden_states[1] is // 2 times larger due to the fact that 2 e2m1 are packed into 1 byte. TVM_FFI_ICHECK_EQ( - gemm1_weights->shape[2], - (dtype_act == btg::Dtype::E2m1 ? hidden_states->shape[1] : hidden_states->shape[1] / 2)) + gemm1_weights.size(2), + (dtype_act == btg::Dtype::E2m1 ? 
hidden_states.size(1) : hidden_states.size(1) / 2)) << "the third dimension of weights must be equal to hidden_size."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->dtype, dl_float8_e4m3fn) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.dtype(), dl_float8_e4m3fn) << "gemm1_weights_scale must be fp8."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->ndim, 3) << "gemm1_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.ndim(), 3) << "gemm1_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(0), local_num_experts) << "gemm1_weights_scale has incorrect dim 0."; TVM_FFI_ICHECK_EQ(intermediate_size % sf_vec_size, 0) << "the second dimension of weights must be a multiple of ", sf_vec_size; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->shape[1], 2 * intermediate_size) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(1), 2 * intermediate_size) << "gemm1_weights_scale has incorrect dim 1."; - TVM_FFI_ICHECK_EQ(gemm1_weights_scale->shape[2], args.hidden_size / sf_vec_size) + TVM_FFI_ICHECK_EQ(gemm1_weights_scale.size(2), args.hidden_size / sf_vec_size) << "gemm1_weights_scale has incorrect dim 2."; if (gemm1_bias.has_value()) { - TVM_FFI_ICHECK_EQ(gemm1_bias.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(gemm1_bias.value().dtype(), dl_float32) << "gemm1_bias must be float, got " - << tvm::ffi::DLDataTypeToString(gemm1_bias.value()->dtype); - TVM_FFI_ICHECK_EQ(gemm1_bias.value()->ndim, 2) << "gemm1_bias must be 2D."; - TVM_FFI_ICHECK_EQ(gemm1_bias.value()->shape[0], local_num_experts) + << tvm::ffi::DLDataTypeToString(gemm1_bias.value().dtype()); + TVM_FFI_ICHECK_EQ(gemm1_bias.value().ndim(), 2) << "gemm1_bias must be 2D."; + TVM_FFI_ICHECK_EQ(gemm1_bias.value().size(0), local_num_experts) << "gemm1_bias has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(gemm1_bias.value()->shape[1], 2 * intermediate_size) + TVM_FFI_ICHECK_EQ(gemm1_bias.value().size(1), 2 * intermediate_size) << "gemm1_bias has incorrect dim 1."; } if (gemm1_alpha.has_value()) { - TVM_FFI_ICHECK_EQ(gemm1_alpha.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(gemm1_alpha.value().dtype(), dl_float32) << "gemm1_alpha must be float, got " - << tvm::ffi::DLDataTypeToString(gemm1_alpha.value()->dtype); - TVM_FFI_ICHECK_EQ(gemm1_alpha.value()->ndim, 1) << "gemm1_alpha must be 1D."; - TVM_FFI_ICHECK_EQ(gemm1_alpha.value()->shape[0], local_num_experts) + << tvm::ffi::DLDataTypeToString(gemm1_alpha.value().dtype()); + TVM_FFI_ICHECK_EQ(gemm1_alpha.value().ndim(), 1) << "gemm1_alpha must be 1D."; + TVM_FFI_ICHECK_EQ(gemm1_alpha.value().size(0), local_num_experts) << "gemm1_alpha has incorrect dim 0."; } if (gemm1_beta.has_value()) { - TVM_FFI_ICHECK_EQ(gemm1_beta.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(gemm1_beta.value().dtype(), dl_float32) << "gemm1_beta must be float, got " - << tvm::ffi::DLDataTypeToString(gemm1_beta.value()->dtype); - TVM_FFI_ICHECK_EQ(gemm1_beta.value()->ndim, 1) << "gemm1_beta must be 1D."; - TVM_FFI_ICHECK_EQ(gemm1_beta.value()->shape[0], local_num_experts) + << tvm::ffi::DLDataTypeToString(gemm1_beta.value().dtype()); + TVM_FFI_ICHECK_EQ(gemm1_beta.value().ndim(), 1) << "gemm1_beta must be 1D."; + TVM_FFI_ICHECK_EQ(gemm1_beta.value().size(0), local_num_experts) << "gemm1_beta has incorrect dim 0."; } - TVM_FFI_ICHECK_EQ(gemm2_weights->dtype, dl_uint8) << "gemm2_weights must be byte."; + TVM_FFI_ICHECK_EQ(gemm2_weights.dtype(), dl_uint8) << "gemm2_weights must be byte."; - TVM_FFI_ICHECK_EQ(gemm2_weights->ndim, 3) << "gemm2_weights must be 
3D."; + TVM_FFI_ICHECK_EQ(gemm2_weights.ndim(), 3) << "gemm2_weights must be 3D."; // / 2 to compensate for the fact that we pack 2 e2m1 into 1 byte. - TVM_FFI_ICHECK_EQ(gemm2_weights->shape[2], intermediate_size / 2) + TVM_FFI_ICHECK_EQ(gemm2_weights.size(2), intermediate_size / 2) << "the third dimension of weights must be equal to intermediate_size."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->dtype, dl_float8_e4m3fn) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.dtype(), dl_float8_e4m3fn) << "gemm2_weights_scale must be fp8."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->ndim, 3) << "gemm2_weights_scale must be 3D."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.ndim(), 3) << "gemm2_weights_scale must be 3D."; + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(0), local_num_experts) << "gemm2_weights_scale has incorrect dim 0."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->shape[1], args.hidden_size) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(1), args.hidden_size) << "gemm2_weights_scale has incorrect dim 1."; - TVM_FFI_ICHECK_EQ(gemm2_weights_scale->shape[2], intermediate_size / sf_vec_size) + TVM_FFI_ICHECK_EQ(gemm2_weights_scale.size(2), intermediate_size / sf_vec_size) << "gemm2_weights_scale has incorrect dim 2."; if (output1_scales_scalar.has_value()) { - TVM_FFI_ICHECK_EQ(output1_scales_scalar.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().dtype(), dl_float32) << "output1_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar.value()->ndim, 1) + TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().ndim(), 1) << "output1_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_scalar.value()->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(output1_scales_scalar.value().size(0), local_num_experts) << "output1_scales_scalar has incorrect dim 0."; } if (output1_scales_gate_scalar.has_value()) { - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().dtype(), dl_float32) << "output1_scales_gate_scalar must be float."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value()->ndim, 1) + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().ndim(), 1) << "output1_scales_gate_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value()->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(output1_scales_gate_scalar.value().size(0), local_num_experts) << "output1_scales_gate_scalar has incorrect dim 0."; } if (output2_scales_scalar.has_value()) { - TVM_FFI_ICHECK_EQ(output2_scales_scalar.value()->dtype, dl_float32) + TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().dtype(), dl_float32) << "output2_scales_scalar must be float."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar.value()->ndim, 1) + TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().ndim(), 1) << "output2_scales_scalar must be 1D."; - TVM_FFI_ICHECK_EQ(output2_scales_scalar.value()->shape[0], local_num_experts) + TVM_FFI_ICHECK_EQ(output2_scales_scalar.value().size(0), local_num_experts) << "output2_scales_scalar has incorrect dim 0."; } // setup workspace - workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens->data); + workspace.total_num_padded_tokens = static_cast(total_num_padded_tokens.data_ptr()); workspace.total_max_padded_tokens = std::max(max_num_padded_tokens_gemm1, max_num_padded_tokens_gemm2); workspace.ProjUpTileN = tile_tokens_dim; - workspace.routing_expert_indexes = static_cast(expert_indices->data); - 
workspace.permuted_idx_size = static_cast(total_num_padded_tokens->data); - workspace.expanded_idx_to_permuted_idx = - static_cast(expanded_idx_to_permuted_idx->data); // Needed by permute/finalize kernels + workspace.routing_expert_indexes = static_cast(expert_indices.data_ptr()); + workspace.permuted_idx_size = static_cast(total_num_padded_tokens.data_ptr()); + workspace.expanded_idx_to_permuted_idx = static_cast( + expanded_idx_to_permuted_idx.data_ptr()); // Needed by permute/finalize kernels workspace.permuted_idx_to_token_idx = - static_cast(permuted_idx_to_token_idx->data); // Needed by permuteGemm1 kernel - workspace.expert_weights = expert_weights->data; // Consumed by finalize kernel + static_cast(permuted_idx_to_token_idx.data_ptr()); // Needed by permuteGemm1 kernel + workspace.expert_weights = expert_weights.data_ptr(); // Consumed by finalize kernel - workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx->data); - workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit->data); - workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas->data); + workspace.cta_idx_xy_to_batch_idx = static_cast(cta_idx_xy_to_batch_idx.data_ptr()); + workspace.cta_idx_xy_to_mn_limit = static_cast(cta_idx_xy_to_mn_limit.data_ptr()); + workspace.num_non_exiting_ctas = static_cast(num_non_exiting_ctas.data_ptr()); workspace.hidden_states_scale_linear = nullptr; // gemm1 intermediate ws - workspace.gemm1_output = gemm1_output->data; + workspace.gemm1_output = gemm1_output.data_ptr(); workspace.gemm1_output_scale = gemm1_output_scale.has_value() - ? static_cast(gemm1_output_scale.value()->data) + ? static_cast(gemm1_output_scale.value().data_ptr()) : nullptr; // gemm2 intermediate ws - workspace.gemm2_output = gemm2_output->data; + workspace.gemm2_output = gemm2_output.data_ptr(); workspace.gemm2_output_scale = nullptr; - args.output = output->data; + args.output = output.data_ptr(); args.output_scale = nullptr; args.output1_scales_scalar = output1_scales_scalar.has_value() - ? static_cast(output1_scales_scalar.value()->data) + ? static_cast(output1_scales_scalar.value().data_ptr()) : nullptr; args.output1_scales_gate_scalar = output1_scales_gate_scalar.has_value() - ? static_cast(output1_scales_gate_scalar.value()->data) + ? static_cast(output1_scales_gate_scalar.value().data_ptr()) : nullptr; args.output2_scales_scalar = output2_scales_scalar.has_value() - ? static_cast(output2_scales_scalar.value()->data) + ? 
static_cast(output2_scales_scalar.value().data_ptr()) : nullptr; args.do_finalize = do_finalize; auto const workspace_sizes = moe_runner.getWorkspaceSizeInBytes(args, moeConfigIndex); Tensor workspace_fc1 = - alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states->device); + alloc_tensor({std::get<0>(workspace_sizes)}, dl_int8, hidden_states.device()); Tensor workspace_fc2 = - alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states->device); - workspace.bmm1_workspace = workspace_fc1->data; - workspace.bmm2_workspace = workspace_fc2->data; - cudaStream_t moe_stream = get_stream(hidden_states->device); - moe_runner.run(args, workspace, hidden_states->device.device_id, moe_stream, moeConfigIndex, + alloc_tensor({std::get<1>(workspace_sizes)}, dl_int8, hidden_states.device()); + workspace.bmm1_workspace = workspace_fc1.data_ptr(); + workspace.bmm2_workspace = workspace_fc2.data_ptr(); + cudaStream_t moe_stream = get_stream(hidden_states.device()); + moe_runner.run(args, workspace, hidden_states.device().device_id, moe_stream, moeConfigIndex, enable_pdl); if (!do_finalize) { @@ -1064,9 +1069,9 @@ Array trtllm_fp4_block_scale_moe( int64_t gated_act_type, TensorView output, int64_t config_index) { using RunnerType = tensorrt_llm::kernels::trtllmgen_moe::MoE::Runner; - int const num_tokens = hidden_states->shape[0]; - int hidden_size = hidden_states->shape[1]; - if (hidden_states->dtype == dl_uint8) hidden_size *= 2; + int const num_tokens = hidden_states.size(0); + int hidden_size = hidden_states.size(1); + if (hidden_states.dtype() == dl_uint8) hidden_size *= 2; int hidden_states_scale_vec_size = -1; if (hidden_states_scale.has_value()) { hidden_states_scale_vec_size = (num_tokens * hidden_size) / hidden_states_scale.value().numel(); @@ -1077,15 +1082,15 @@ Array trtllm_fp4_block_scale_moe( << "unsupported weight_scale_vec_size."; auto mDtypeWeights = weight_scale_vec_size == 16 ? 
btg::Dtype::E2m1 : btg::Dtype::MxE2m1; - TVM_FFI_ICHECK(gemm1_weights->dtype == dl_uint8 && gemm2_weights->dtype == dl_uint8) + TVM_FFI_ICHECK(gemm1_weights.dtype() == dl_uint8 && gemm2_weights.dtype() == dl_uint8) << "weights must be fp4 packed in uint8."; - TVM_FFI_ICHECK(hidden_states->dtype == dl_uint8 || hidden_states->dtype == dl_bfloat16 || - hidden_states->dtype == dl_float8_e4m3fn) + TVM_FFI_ICHECK(hidden_states.dtype() == dl_uint8 || hidden_states.dtype() == dl_bfloat16 || + hidden_states.dtype() == dl_float8_e4m3fn) << "hidden_states must be bf16, fp8 or uint8 (packed fp4)."; auto mDtypeAct = btg::Dtype::Bfloat16; - if (hidden_states->dtype == dl_uint8) { + if (hidden_states.dtype() == dl_uint8) { TVM_FFI_ICHECK(hidden_states_scale.has_value() && - hidden_states_scale.value()->dtype == dl_float8_e4m3fn) + hidden_states_scale.value().dtype() == dl_float8_e4m3fn) << "hidden_states_scale must be provided for fp4 activation."; if (hidden_states_scale_vec_size == 16) { mDtypeAct = btg::Dtype::E2m1; @@ -1094,7 +1099,7 @@ Array trtllm_fp4_block_scale_moe( } else { TVM_FFI_LOG_AND_THROW(NotImplementedError) << "Unsupported hidden state scale shape."; } - } else if (hidden_states->dtype == dl_float8_e4m3fn) { + } else if (hidden_states.dtype() == dl_float8_e4m3fn) { if (hidden_states_scale.has_value()) { if (hidden_states_scale_vec_size == 32) { mDtypeAct = btg::Dtype::MxE4m3; diff --git a/csrc/trtllm_gemm_runner.cu b/csrc/trtllm_gemm_runner.cu index 72cdb893a0..ffe0bc2cf8 100644 --- a/csrc/trtllm_gemm_runner.cu +++ b/csrc/trtllm_gemm_runner.cu @@ -266,13 +266,13 @@ void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, Tensor CHECK_INPUT(b); CHECK_INPUT(out); CHECK_INPUT(workspace_buffer); - TVM_FFI_ICHECK_EQ(workspace_buffer->ndim, 1); + TVM_FFI_ICHECK_EQ(workspace_buffer.ndim(), 1); CHECK_DIM(2, a); CHECK_DIM(2, b); - TVM_FFI_ICHECK_EQ(a->dtype, b->dtype); - TVM_FFI_ICHECK(a->dtype == dl_float8_e4m3fn || a->dtype == dl_uint8) + TVM_FFI_ICHECK_EQ(a.dtype(), b.dtype()); + TVM_FFI_ICHECK(a.dtype() == dl_float8_e4m3fn || a.dtype() == dl_uint8) << "a must be a Float8 or Byte(e2m1) tensor"; - bool is_fp8 = a->dtype == dl_float8_e4m3fn; + bool is_fp8 = a.dtype() == dl_float8_e4m3fn; if (is_fp8) { TVM_FFI_ICHECK(!globalScale.has_value()) << "globalScale must be a none tensor"; } else { @@ -283,11 +283,11 @@ void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, Tensor } } - int32_t m = a->shape[0]; - int32_t k = is_fp8 ? a->shape[1] : a->shape[1] * 2; - int32_t n = b->shape[0]; - TVM_FFI_ICHECK_EQ(b->shape[1], a->shape[1]) << "Matrix dimensions don't match for multiplication"; - TVM_FFI_ICHECK(out->shape[0] == m && out->shape[1] == n) << "Output tensor has wrong dimensions"; + int32_t m = a.size(0); + int32_t k = is_fp8 ? a.size(1) : a.size(1) * 2; + int32_t n = b.size(0); + TVM_FFI_ICHECK_EQ(b.size(1), a.size(1)) << "Matrix dimensions don't match for multiplication"; + TVM_FFI_ICHECK(out.size(0) == m && out.size(1) == n) << "Output tensor has wrong dimensions"; auto runner = flashinfer::TrtllmGenGemmRunner(flashinfer::TrtllmGenGemmRunnerOptions{ .eltType = is_fp8 ? 
gemm::trtllm::gen::Dtype::E4m3 : gemm::trtllm::gen::Dtype::E2m1, @@ -301,22 +301,22 @@ void trtllm_gemm(TensorView workspace_buffer, TensorView a, TensorView b, Tensor tactic = runner.selectHeuristic(m, n, k); } - auto stream = get_stream(a->device); + auto stream = get_stream(a.device()); auto runKernel = [&](void* workspace) { - runner.run(m, n, k, a->data, a_scale->data, b->data, b_scale->data, out->data, - globalScale.has_value() ? globalScale.value()->data : nullptr, nullptr, workspace, - stream, a->device.device_id, tactic); + runner.run(m, n, k, a.data_ptr(), a_scale.data_ptr(), b.data_ptr(), b_scale.data_ptr(), + out.data_ptr(), globalScale.has_value() ? globalScale.value().data_ptr() : nullptr, + nullptr, workspace, stream, a.device().device_id, tactic); }; int64_t const required_workspace_size = runner.getWorkspaceSizeInBytes(m, n, k, tactic); int64_t const provided_workspace_size = workspace_buffer.numel() * get_element_size(workspace_buffer); if (provided_workspace_size < required_workspace_size) { - Tensor new_workspace = alloc_tensor({required_workspace_size}, dl_int8, a->device); - runKernel(new_workspace->data); + Tensor new_workspace = alloc_tensor({required_workspace_size}, dl_int8, a.device()); + runKernel(new_workspace.data_ptr()); } else { - runKernel(workspace_buffer->data); + runKernel(workspace_buffer.data_ptr()); } } diff --git a/csrc/trtllm_low_latency_gemm_runner.cu b/csrc/trtllm_low_latency_gemm_runner.cu index 489ae05c73..f3ce0d43c3 100644 --- a/csrc/trtllm_low_latency_gemm_runner.cu +++ b/csrc/trtllm_low_latency_gemm_runner.cu @@ -251,18 +251,18 @@ void trtllm_low_latency_gemm(TensorView workspace_buffer, TensorView a, TensorVi CHECK_INPUT(out); CHECK_INPUT(workspace_buffer); CHECK_DIM(2, a); - TVM_FFI_ICHECK(b->ndim == 3) << "b must be a block layout matrix (3D tensor with " - "dims [N/BLOCK_SIZE, K, BLOCK_SIZE])"; - TVM_FFI_ICHECK_EQ(a->dtype, b->dtype); - TVM_FFI_ICHECK(a->dtype == dl_float8_e4m3fn) << "a must be a Float8 tensor"; - - int32_t m = a->shape[0]; - int32_t k = a->shape[1]; - int32_t n = b->shape[1]; - auto const blockSize = b->shape[2]; - auto const kFromB = b->shape[0] * blockSize; - TVM_FFI_ICHECK(kFromB == a->shape[1]) << "Matrix dimensions don't match for multiplication"; - TVM_FFI_ICHECK(out->shape[0] == m && out->shape[1] == n) << "Output tensor has wrong dimensions"; + TVM_FFI_ICHECK(b.ndim() == 3) << "b must be a block layout matrix (3D tensor with " + "dims [N/BLOCK_SIZE, K, BLOCK_SIZE])"; + TVM_FFI_ICHECK_EQ(a.dtype(), b.dtype()); + TVM_FFI_ICHECK(a.dtype() == dl_float8_e4m3fn) << "a must be a Float8 tensor"; + + int32_t m = a.size(0); + int32_t k = a.size(1); + int32_t n = b.size(1); + auto const blockSize = b.size(2); + auto const kFromB = b.size(0) * blockSize; + TVM_FFI_ICHECK(kFromB == a.size(1)) << "Matrix dimensions don't match for multiplication"; + TVM_FFI_ICHECK(out.size(0) == m && out.size(1) == n) << "Output tensor has wrong dimensions"; if (tactic == -1) { tactic = select_kernel(m, n, k, gemm::gemm::GemmInterface()); @@ -274,7 +274,7 @@ void trtllm_low_latency_gemm(TensorView workspace_buffer, TensorView a, TensorVi .outputType = gemm::trtllm::gen::Dtype::Bfloat16, }); - auto stream = get_stream(a->device); + auto stream = get_stream(a.device()); int64_t const required_workspace_size = getWorkspaceSizeInBytes(m, n, k, tactic); int64_t const provided_workspace_size = @@ -286,8 +286,8 @@ void trtllm_low_latency_gemm(TensorView workspace_buffer, TensorView a, TensorVi "workspace."; } - runner.run(m, n, k, a->data, b->data, 
out->data, globalScale->data, workspace_buffer->data, - stream, a->device.device_id, tactic); + runner.run(m, n, k, a.data_ptr(), b.data_ptr(), out.data_ptr(), globalScale.data_ptr(), + workspace_buffer.data_ptr(), stream, a.device().device_id, tactic); } enum class Dtype : int64_t { diff --git a/csrc/trtllm_mnnvl_allreduce.cu b/csrc/trtllm_mnnvl_allreduce.cu index 76d11d8624..6bac5372a8 100644 --- a/csrc/trtllm_mnnvl_allreduce.cu +++ b/csrc/trtllm_mnnvl_allreduce.cu @@ -30,13 +30,13 @@ void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_ int64_t buffer_M, TensorView buffer_flags_mnnvl, int64_t nranks, int64_t rank, bool wait_for_results, bool launch_with_pdl, Optional out) { - cudaSetDevice(in->device.device_id); - auto stream = get_stream(in->device); + cudaSetDevice(in.device().device_id); + auto stream = get_stream(in.device()); - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in->dtype, c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(in.dtype(), c_type, [&] { // Extract parameters from tensors - int64_t num_tokens = in->shape[0]; - int64_t token_dim = in->shape[1]; + int64_t num_tokens = in.size(0); + int64_t token_dim = in.size(1); // Validate input parameters TVM_FFI_ICHECK_EQ(token_dim % (sizeof(float2) / sizeof(c_type)), 0) @@ -57,11 +57,11 @@ void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_ params.token_dim = token_dim; params.buffer_ptrs_dev = reinterpret_cast(buffer_ptrs_dev); params.multicast_ptr = reinterpret_cast(multicast_buffer_ptr); - params.buffer_flags = buffer_flags_mnnvl->data; + params.buffer_flags = buffer_flags_mnnvl.data_ptr(); params.wait_for_results = wait_for_results; params.launch_with_pdl = launch_with_pdl; - params.input = in->data; - params.output = out.has_value() ? out.value()->data : nullptr; + params.input = in.data_ptr(); + params.output = out.has_value() ? 
out.value().data_ptr() : nullptr; params.stream = stream; auto status = twoshot_allreduce_dispatch_world_size(params); @@ -74,21 +74,21 @@ void trtllm_mnnvl_all_reduce(TensorView in, int64_t multicast_buffer_ptr, int64_ void trtllm_mnnvl_rmsnorm(int64_t multicast_buffer_ptr, TensorView prenorm_output, TensorView normed_output, TensorView gamma, double epsilon, TensorView residual, TensorView buffer_flags, bool launch_with_pdl) { - cudaSetDevice(prenorm_output->device.device_id); - auto stream = get_stream(prenorm_output->device); + cudaSetDevice(prenorm_output.device().device_id); + auto stream = get_stream(prenorm_output.device()); - DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output->dtype, c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_MNNVL_ALLREDUCE(prenorm_output.dtype(), c_type, [&] { // Create the parameters struct RMSNormParams params; - params.residual_output = prenorm_output->data; - params.output = normed_output->data; + params.residual_output = prenorm_output.data_ptr(); + params.output = normed_output.data_ptr(); params.input = reinterpret_cast(multicast_buffer_ptr); - params.gamma = gamma->data; + params.gamma = gamma.data_ptr(); params.epsilon = epsilon; - params.residual = residual->data; - params.buffer_flags = reinterpret_cast(buffer_flags->data); - params.batch = normed_output->shape[0]; - params.hidden_dim = normed_output->shape[1]; + params.residual = residual.data_ptr(); + params.buffer_flags = reinterpret_cast(buffer_flags.data_ptr()); + params.batch = normed_output.size(0); + params.hidden_dim = normed_output.size(1); params.stream = stream; params.launch_with_pdl = launch_with_pdl; auto status = twoshot_rmsnorm_dispatch_hidden_dim(params); diff --git a/csrc/trtllm_moe_allreduce_fusion.cu b/csrc/trtllm_moe_allreduce_fusion.cu index a7ee3fc0c4..de12dad4f8 100644 --- a/csrc/trtllm_moe_allreduce_fusion.cu +++ b/csrc/trtllm_moe_allreduce_fusion.cu @@ -32,32 +32,33 @@ void trtllm_moe_allreduce_fusion( TensorView moe_reduction_token_input, Optional layout_code, Optional moe_allreduce_out, Optional residual_out, Optional norm_out, Optional quant_out, Optional scale_out) { - cudaSetDevice(moe_reduction_active_experts_token_input->device.device_id); - auto stream = get_stream(moe_reduction_active_experts_token_input->device); + cudaSetDevice(moe_reduction_active_experts_token_input.device().device_id); + auto stream = get_stream(moe_reduction_active_experts_token_input.device()); DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE( - moe_reduction_active_experts_token_input->dtype, c_type, [&] { + moe_reduction_active_experts_token_input.dtype(), c_type, [&] { MoeReductionAllReduceFusionParams params; params.nranks = world_size; params.rank = world_rank; params.size = token_num * hidden_size; params.hidden_dim = hidden_size; - params.workspace = reinterpret_cast(workspace_ptrs->data); + params.workspace = reinterpret_cast(workspace_ptrs.data_ptr()); - params.moe_allreduce_out = moe_allreduce_out.has_value() - ? reinterpret_cast(moe_allreduce_out.value()->data) - : nullptr; - params.residual_in = reinterpret_cast(residual_in->data); + params.moe_allreduce_out = + moe_allreduce_out.has_value() + ? reinterpret_cast(moe_allreduce_out.value().data_ptr()) + : nullptr; + params.residual_in = reinterpret_cast(residual_in.data_ptr()); params.residual_out = residual_out.has_value() - ? reinterpret_cast(residual_out.value()->data) + ? reinterpret_cast(residual_out.value().data_ptr()) : nullptr; params.norm_out = - norm_out.has_value() ? 
reinterpret_cast(norm_out.value()->data) : nullptr; + norm_out.has_value() ? reinterpret_cast(norm_out.value().data_ptr()) : nullptr; params.quant_out = - quant_out.has_value() ? reinterpret_cast(quant_out.value()->data) : nullptr; + quant_out.has_value() ? reinterpret_cast(quant_out.value().data_ptr()) : nullptr; params.scale_out = - scale_out.has_value() ? reinterpret_cast(scale_out.value()->data) : nullptr; - params.rms_gamma = reinterpret_cast(rms_gamma->data); + scale_out.has_value() ? reinterpret_cast(scale_out.value().data_ptr()) : nullptr; + params.rms_gamma = reinterpret_cast(rms_gamma.data_ptr()); params.rms_eps = static_cast(rms_eps); params.scale_factor = static_cast(scale_factor); params.layout = layout_code.has_value() @@ -67,10 +68,11 @@ void trtllm_moe_allreduce_fusion( params.moe_reduction_device_num_experts = moe_reduction_device_num_experts; params.moe_reduction_scale_input = - reinterpret_cast(moe_reduction_scale_input->data); + reinterpret_cast(moe_reduction_scale_input.data_ptr()); params.moe_reduction_active_experts_token_input = - reinterpret_cast(moe_reduction_active_experts_token_input->data); - params.moe_reduction_token_input = reinterpret_cast(moe_reduction_token_input->data); + reinterpret_cast(moe_reduction_active_experts_token_input.data_ptr()); + params.moe_reduction_token_input = + reinterpret_cast(moe_reduction_token_input.data_ptr()); auto status = moereduction_allreduce_fusion_op(params, launch_with_pdl); TVM_FFI_ICHECK(status == cudaSuccess) @@ -85,11 +87,11 @@ void trtllm_moe_finalize_allreduce_fusion( bool launch_with_pdl, TensorView workspace, int64_t const world_rank, int64_t const world_size, double const eps, Optional shared_expert_output, Optional expert_scale_factor) { - DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in->dtype, c_type, [&] { + DISPATCH_FLOATING_TYPES_FOR_ALLREDUCE(residual_in.dtype(), c_type, [&] { MoeFinalizeAllReduceFusionParams params; - int hidden_dim = residual_in->shape[residual_in->ndim - 1]; - int top_k = expanded_idx_to_permuted_idx->shape[expanded_idx_to_permuted_idx->ndim - 1]; + int hidden_dim = residual_in.size(-1); + int top_k = expanded_idx_to_permuted_idx.size(-1); params.quant_out = nullptr; params.scale_out = nullptr; @@ -101,26 +103,27 @@ void trtllm_moe_finalize_allreduce_fusion( params.hidden_dim = hidden_dim; // workspace: AR scratch space - params.workspace = reinterpret_cast(workspace->data); - params.rms_gamma = norm_weight->data; + params.workspace = reinterpret_cast(workspace.data_ptr()); + params.rms_gamma = norm_weight.data_ptr(); params.rms_eps = static_cast(eps); - params.residual_in = residual_in->data; - params.stream = get_stream(norm_weight->device); + params.residual_in = residual_in.data_ptr(); + params.stream = get_stream(norm_weight.device()); // MOE Reduction specific params params.top_k = top_k; - params.allreduce_in = allreduce_in->data; + params.allreduce_in = allreduce_in.data_ptr(); params.expert_scale_factor = - expert_scale_factor.has_value() ? expert_scale_factor.value()->data : nullptr; - TVM_FFI_ICHECK_EQ(expanded_idx_to_permuted_idx->dtype, dl_int32) + expert_scale_factor.has_value() ? 
expert_scale_factor.value().data_ptr() : nullptr; + TVM_FFI_ICHECK_EQ(expanded_idx_to_permuted_idx.dtype(), dl_int32) << "expanded_idx_to_permuted_idx must be int32"; - params.expanded_idx_to_permuted_idx = static_cast(expanded_idx_to_permuted_idx->data); + params.expanded_idx_to_permuted_idx = + static_cast(expanded_idx_to_permuted_idx.data_ptr()); params.shared_expert_output = - shared_expert_output.has_value() ? shared_expert_output.value()->data : nullptr; + shared_expert_output.has_value() ? shared_expert_output.value().data_ptr() : nullptr; // output tensors - params.norm_out = norm_out->data; - params.residual_out = residual_out->data; + params.norm_out = norm_out.data_ptr(); + params.residual_out = residual_out.data_ptr(); auto status = moefinalize_allreduce_fusion_op(params, launch_with_pdl); TVM_FFI_ICHECK(status == cudaSuccess) diff --git a/csrc/tvm_ffi_utils.h b/csrc/tvm_ffi_utils.h index 1c3f7d4952..402c9933dd 100644 --- a/csrc/tvm_ffi_utils.h +++ b/csrc/tvm_ffi_utils.h @@ -221,35 +221,35 @@ constexpr DLDevice cpu = DLDevice{kDLCPU, 0}; inline void check_shape(const tvm::ffi::Tensor& a, const tvm::ffi::Tensor& b, const char* a_name, const char* b_name) { - TVM_FFI_ICHECK_EQ(a->ndim, b->ndim) << a_name << "->ndim and " << b_name << "->ndim mismatch"; - for (int i = 0; i < a->ndim; ++i) { - TVM_FFI_ICHECK_EQ(a->shape[i], b->shape[i]) - << a_name << "->shape[" << i << "] and " << b_name << "->shape[" << i << "] mismatch"; + TVM_FFI_ICHECK_EQ(a.ndim(), b.ndim()) << a_name << ".ndim() and " << b_name << ".ndim() mismatch"; + for (int i = 0; i < a.ndim(); ++i) { + TVM_FFI_ICHECK_EQ(a.size(i), b.size(i)) + << a_name << ".size(" << i << ") and " << b_name << ".size(" << i << ") mismatch"; } } inline void check_shape(const tvm::ffi::TensorView& a, const tvm::ffi::TensorView& b, const char* a_name, const char* b_name) { - TVM_FFI_ICHECK_EQ(a->ndim, b->ndim) << a_name << "->ndim and " << b_name << "->ndim mismatch"; - for (int i = 0; i < a->ndim; ++i) { - TVM_FFI_ICHECK_EQ(a->shape[i], b->shape[i]) - << a_name << "->shape[" << i << "] and " << b_name << "->shape[" << i << "] mismatch"; + TVM_FFI_ICHECK_EQ(a.ndim(), b.ndim()) << a_name << ".ndim() and " << b_name << ".ndim() mismatch"; + for (int i = 0; i < a.ndim(); ++i) { + TVM_FFI_ICHECK_EQ(a.size(i), b.size(i)) + << a_name << ".size(" << i << ") and " << b_name << ".size(" << i << ") mismatch"; } } #define CHECK_CUDA(x) \ - TVM_FFI_ICHECK_EQ(x->device.device_type, kDLCUDA) << #x " must be a CUDA tensor"; + TVM_FFI_ICHECK_EQ(x.device().device_type, kDLCUDA) << #x " must be a CUDA tensor"; #define CHECK_CPU(x) \ - TVM_FFI_ICHECK_EQ(x->device.device_type, kDLCPU) << #x " must be a host tensor"; + TVM_FFI_ICHECK_EQ(x.device().device_type, kDLCPU) << #x " must be a host tensor"; #define CHECK_CONTIGUOUS(x) TVM_FFI_ICHECK(x.IsContiguous()) << #x " must be contiguous"; -#define CHECK_LAST_DIM_CONTIGUOUS(x) \ - TVM_FFI_ICHECK_EQ(x->strides[x->ndim - 1], 1) \ +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TVM_FFI_ICHECK_EQ(x.stride(-1), 1) \ #x "must be contiguous at last dimension"; #define CHECK_INPUT(x) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x) #define CHECK_INPUT_TYPE(x, st) \ - TVM_FFI_ICHECK_EQ(x->dtype, st) << "Inconsistency of Tensor type: " #x; + TVM_FFI_ICHECK_EQ(x.dtype(), st) << "Inconsistency of Tensor type: " #x; #define CHECK_INPUT_AND_TYPE(x, st) \ CHECK_CUDA(x); \ CHECK_CONTIGUOUS(x); \ @@ -257,11 +257,11 @@ inline void check_shape(const tvm::ffi::TensorView& a, const tvm::ffi::TensorVie #define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ 
CHECK_CUDA(x); \ CHECK_LAST_DIM_CONTIGUOUS(x) -#define CHECK_DIM(d, x) TVM_FFI_ICHECK_EQ(x->ndim, d) << #x " must be a " #d "D tensor"; +#define CHECK_DIM(d, x) TVM_FFI_ICHECK_EQ(x.ndim(), d) << #x " must be a " #d "D tensor"; #define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) -#define CHECK_DEVICE(a, b) \ - TVM_FFI_ICHECK_EQ(a->device.device_type, b->device.device_type); \ - TVM_FFI_ICHECK_EQ(a->device.device_id, b->device.device_id); +#define CHECK_DEVICE(a, b) \ + TVM_FFI_ICHECK_EQ(a.device().device_type, b.device().device_type); \ + TVM_FFI_ICHECK_EQ(a.device().device_id, b.device().device_id); inline cudaStream_t get_current_stream() { int device; @@ -273,10 +273,12 @@ inline cudaStream_t get_stream(DLDevice device) { return static_cast(TVMFFIEnvGetStream(device.device_type, device.device_id)); } -inline int64_t get_element_size(ffi::Tensor x) { return (x->dtype.bits * x->dtype.lanes) / 8; } +inline int64_t get_element_size(ffi::Tensor x) { return (x.dtype().bits * x.dtype().lanes) / 8; } -inline int64_t get_element_size(ffi::TensorView x) { return (x->dtype.bits * x->dtype.lanes) / 8; } +inline int64_t get_element_size(ffi::TensorView x) { + return (x.dtype().bits * x.dtype().lanes) / 8; +} inline ffi::Tensor alloc_tensor(tvm::ffi::Shape shape, DLDataType dtype, DLDevice device) { - return ffi::Tensor::FromDLPackAlloc(TVMFFIEnvGetTensorAllocator(), shape, dtype, device); + return ffi::Tensor::FromEnvAlloc(TVMFFIEnvTensorAlloc, shape, dtype, device); } diff --git a/csrc/vllm_custom_all_reduce.cu b/csrc/vllm_custom_all_reduce.cu index 14b1425e5c..49fbefebc0 100644 --- a/csrc/vllm_custom_all_reduce.cu +++ b/csrc/vllm_custom_all_reduce.cu @@ -28,12 +28,12 @@ fptr_t init_custom_ar(Array fake_ipc_ptrs, TensorView rank_data, int64_t for (int i = 0; i < world_size; i++) { ipc_ptrs[i] = reinterpret_cast(fake_ipc_ptrs[i]); } - return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data->data, rank_data.numel(), rank, + return (fptr_t) new vllm::CustomAllreduce(ipc_ptrs, rank_data.data_ptr(), rank_data.numel(), rank, world_size, full_nvlink); } /** - * Make sure tensor t's data lies completely within ((char)t->data) + + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() * because it allows transpose of contiguous slice (i.e. slicing the first * dimension). Currently, we require this because stride information is not @@ -52,24 +52,24 @@ bool _is_weak_contiguous(TensorView t) { auto numel = t.numel(); auto element_size = get_element_size(t); return t.IsContiguous() || - (tvm::ffi::GetDataSize(numel, t->dtype) - t->byte_offset * element_size == + (tvm::ffi::GetDataSize(numel, t.dtype()) - t.byte_offset() * element_size == numel * element_size); } /** * Performs an out-of-place allreduce and stores result in out. * - * If _reg_buffer is null, assumes inp->data is already IPC-registered. + * If _reg_buffer is null, assumes inp.data_ptr() is already IPC-registered. * Otherwise, _reg_buffer is assumed to be IPC-registered and inp is first * copied into _reg_buffer. 
*/ void all_reduce(fptr_t _fa, TensorView inp, TensorView out, fptr_t _reg_buffer, int64_t reg_buffer_sz_bytes, int64_t num_ctas) { auto fa = reinterpret_cast(_fa); - cudaSetDevice(inp->device.device_id); - auto stream = get_stream(inp->device); + cudaSetDevice(inp.device().device_id); + auto stream = get_stream(inp.device()); - TVM_FFI_ICHECK_EQ(inp->dtype, out->dtype); + TVM_FFI_ICHECK_EQ(inp.dtype(), out.dtype()); TVM_FFI_ICHECK_EQ(inp.numel(), out.numel()); TVM_FFI_ICHECK(_is_weak_contiguous(out)); TVM_FFI_ICHECK(_is_weak_contiguous(inp)); @@ -78,26 +78,27 @@ void all_reduce(fptr_t _fa, TensorView inp, TensorView out, fptr_t _reg_buffer, if (reg_buffer) { TVM_FFI_ICHECK_LE(input_size, reg_buffer_sz_bytes); auto status = - cudaMemcpyAsync(reg_buffer, inp->data, input_size, cudaMemcpyDeviceToDevice, stream); + cudaMemcpyAsync(reg_buffer, inp.data_ptr(), input_size, cudaMemcpyDeviceToDevice, stream); TVM_FFI_ICHECK(status == cudaSuccess); } else { - reg_buffer = inp->data; + reg_buffer = inp.data_ptr(); } - switch (encode_dlpack_dtype(out->dtype)) { + switch (encode_dlpack_dtype(out.dtype())) { case float32_code: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out->data), out.numel(), num_ctas); + reinterpret_cast(out.data_ptr()), out.numel(), num_ctas); break; } case float16_code: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out->data), out.numel(), num_ctas); + reinterpret_cast(out.data_ptr()), out.numel(), num_ctas); break; } #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) case bfloat16_code: { fa->allreduce(stream, reinterpret_cast(reg_buffer), - reinterpret_cast(out->data), out.numel(), num_ctas); + reinterpret_cast(out.data_ptr()), out.numel(), + num_ctas); break; } #endif diff --git a/csrc/xqa/xqa_wrapper.cu b/csrc/xqa/xqa_wrapper.cu index 1a5d636e10..e4484088b5 100644 --- a/csrc/xqa/xqa_wrapper.cu +++ b/csrc/xqa/xqa_wrapper.cu @@ -31,25 +31,25 @@ void xqa_wrapper(int64_t multiProcessorCount, int64_t nbKHeads, int64_t slidingW int64_t qSeqLen, TensorView qCuSeqLens, TensorView mask, #endif TensorView semaphores, TensorView scratch) { - auto stream = get_stream(output->device); + auto stream = get_stream(output.device()); float const* attentionSinksPtr = - attentionSinks.has_value() ? reinterpret_cast(attentionSinks.value()->data) + attentionSinks.has_value() ? 
reinterpret_cast(attentionSinks.value().data_ptr()) : nullptr; launchMHAFlashInfer(multiProcessorCount, nbKHeads, slidingWinSize, qScale, - reinterpret_cast(output->data), + reinterpret_cast(output.data_ptr()), #if LOW_PREC_OUTPUT - reinterpret_cast(rcpOutScale->data), + reinterpret_cast(rcpOutScale.data_ptr()), #endif - reinterpret_cast(q->data), attentionSinksPtr, - reinterpret_cast(pool->data), - reinterpret_cast(kvCachePageList->data), maxSeqLen, - reinterpret_cast(seqLen->data), batchSize, - reinterpret_cast(kvCacheScale->data), + reinterpret_cast(q.data_ptr()), attentionSinksPtr, + reinterpret_cast(pool.data_ptr()), + reinterpret_cast(kvCachePageList.data_ptr()), + maxSeqLen, reinterpret_cast(seqLen.data_ptr()), batchSize, + reinterpret_cast(kvCacheScale.data_ptr()), #if SPEC_DEC - qSeqLen, reinterpret_cast(qCuSeqLens->data), - reinterpret_cast(mask->data), + qSeqLen, reinterpret_cast(qCuSeqLens.data_ptr()), + reinterpret_cast(mask.data_ptr()), #endif - reinterpret_cast(semaphores->data), - reinterpret_cast(scratch->data), stream); + reinterpret_cast(semaphores.data_ptr()), + reinterpret_cast(scratch.data_ptr()), stream); } diff --git a/flashinfer-cubin/pyproject.toml b/flashinfer-cubin/pyproject.toml index 65d1115122..866ff08db2 100644 --- a/flashinfer-cubin/pyproject.toml +++ b/flashinfer-cubin/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=61.0", "wheel", "requests", "filelock", "torch", "tqdm", "numpy", "apache-tvm-ffi==0.1.0b15"] +requires = ["setuptools>=61.0", "wheel", "requests", "filelock", "torch", "tqdm", "numpy", "apache-tvm-ffi>=0.1,<0.2"] build-backend = "build_backend" backend-path = ["."] diff --git a/flashinfer-jit-cache/pyproject.toml b/flashinfer-jit-cache/pyproject.toml index a37329341a..5cd679745f 100644 --- a/flashinfer-jit-cache/pyproject.toml +++ b/flashinfer-jit-cache/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["setuptools>=77", "packaging>=24", "wheel", "torch", "ninja", "requests", "numpy", "nvidia-ml-py", "nvidia-nvshmem-cu12", "apache-tvm-ffi==0.1.0b15"] +requires = ["setuptools>=77", "packaging>=24", "wheel", "torch", "ninja", "requests", "numpy", "nvidia-ml-py", "nvidia-nvshmem-cu12", "apache-tvm-ffi>=0.1,<0.2"] build-backend = "build_backend" backend-path = ["."] diff --git a/flashinfer/gemm.py b/flashinfer/gemm.py index 2ef535f97c..63a2f7e211 100644 --- a/flashinfer/gemm.py +++ b/flashinfer/gemm.py @@ -566,7 +566,7 @@ def forward( (a.shape[0], b.shape[1]), dtype=a.dtype, device=a.device ) gemm_fn(b.t(), a.t(), bias, tactic, c, pdl) - return c.t() + return c return TGVGemmRunner() diff --git a/flashinfer/jit/activation.py b/flashinfer/jit/activation.py index 4407ad146f..d29166addf 100644 --- a/flashinfer/jit/activation.py +++ b/flashinfer/jit/activation.py @@ -34,13 +34,13 @@ {{ act_func_def }} void {{ func_name }}(TensorView out, TensorView input, bool enable_pdl) { - int d = input->shape[input->ndim -1] / 2; - int64_t num_tokens = input.numel() / input->shape[input->ndim -1]; + int d = input.size(input.ndim() -1) / 2; + int64_t num_tokens = input.numel() / input.size(input.ndim() -1); dim3 grid(num_tokens); - cudaSetDevice(out->device.device_id); - const cudaStream_t stream = get_stream(out->device); - DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input->dtype, c_type, [&] { + cudaSetDevice(out.device().device_id); + const cudaStream_t stream = get_stream(out.device()); + DISPATCH_DLPACK_DTYPE_TO_CTYPE_FP16(input.dtype(), c_type, [&] { uint32_t vec_size = 16 / sizeof(c_type); cudaLaunchConfig_t config; config.gridDim = 
num_tokens; @@ -55,8 +55,8 @@ auto kernel = flashinfer::activation::act_and_mul_kernel; - cudaLaunchKernelEx(&config, kernel, static_cast(out->data), - static_cast(input->data), d); + cudaLaunchKernelEx(&config, kernel, static_cast(out.data_ptr()), + static_cast(input.data_ptr()), d); cudaError_t err = cudaGetLastError(); TVM_FFI_ICHECK(err == cudaSuccess) << "Failed to launch kernel: " << cudaGetErrorString(err); diff --git a/flashinfer/jit/attention/utils.py b/flashinfer/jit/attention/utils.py index ed617c71b7..ac033a65b8 100644 --- a/flashinfer/jit/attention/utils.py +++ b/flashinfer/jit/attention/utils.py @@ -55,9 +55,9 @@ def generate_additional_params( additional_params_setter = " \\\n".join( [ ( - f"params.additional_params.{var} = {var} ? static_cast<{dtype}*>({var}.value()->data): nullptr;" + f"params.additional_params.{var} = {var} ? static_cast<{dtype}*>({var}.value().data_ptr()): nullptr;" if var.startswith("maybe") - else f"params.additional_params.{var} = static_cast<{dtype}*>({var}->data);" + else f"params.additional_params.{var} = static_cast<{dtype}*>({var}.data_ptr());" ) for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) ] @@ -70,9 +70,9 @@ def generate_additional_params( additional_params_setter = " \\\n".join( [ ( - f"params.{var} = {var} ? static_cast<{dtype}*>({var}.value()->data): nullptr;" + f"params.{var} = {var} ? static_cast<{dtype}*>({var}.value().data_ptr()): nullptr;" if var.startswith("maybe") - else f"params.{var} = static_cast<{dtype}*>({var}->data);" + else f"params.{var} = static_cast<{dtype}*>({var}.data_ptr());" ) for dtype, var in zip(additional_tensor_dtypes, additional_tensor_names) ] diff --git a/pyproject.toml b/pyproject.toml index 38467d4b40..57a966c04d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,7 @@ license-files = ["LICENSE", "LICENSE*.txt"] flashinfer = "flashinfer.__main__:cli" [build-system] -requires = ["setuptools>=77", "packaging>=24", "apache-tvm-ffi==0.1.0b15"] +requires = ["setuptools>=77", "packaging>=24", "apache-tvm-ffi>=0.1,<0.2"] build-backend = "build_backend" backend-path = ["."] diff --git a/requirements.txt b/requirements.txt index a4e391c38d..a31b6ebdc8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -apache-tvm-ffi==0.1.0b15 +apache-tvm-ffi>=0.1,<0.2 click einops ninja
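Not part of the patch: a minimal sketch that collects the accessor migration from the hunks above in one place, for reviewers who want the before/after mapping without scanning every file. The helper name copy_2d_to_fresh_buffer is invented for illustration; everything it calls (CHECK_INPUT, CHECK_DIM, get_element_size, alloc_tensor, get_stream, TVM_FFI_ICHECK) is declared in csrc/tvm_ffi_utils.h as modified above, and the accessor methods are exactly the ones the diff switches to.

// Sketch only (not part of this patch); assumes csrc/tvm_ffi_utils.h as changed above.
#include <cuda_runtime.h>

#include "tvm_ffi_utils.h"

using tvm::ffi::Tensor;
using tvm::ffi::TensorView;

// Hypothetical helper: validate a 2-D CUDA tensor and copy it into a freshly
// allocated buffer on the same device, using only patterns that appear in the diff.
Tensor copy_2d_to_fresh_buffer(TensorView src) {
  CHECK_INPUT(src);   // CUDA + contiguous; the macros now use src.device() / src.IsContiguous()
  CHECK_DIM(2, src);  // was src->ndim == 2, now src.ndim() == 2

  // Field access on the old DLTensor handle   ->  method calls on Tensor/TensorView
  //   src->shape[i]                           ->  src.size(i)
  //   src->strides[src->ndim - 1]             ->  src.stride(-1)
  //   src->dtype                              ->  src.dtype()
  //   src->data                               ->  src.data_ptr()
  //   src->device                             ->  src.device()
  int64_t rows = src.size(0);
  int64_t cols = src.size(1);
  int64_t nbytes = rows * cols * get_element_size(src);

  // alloc_tensor keeps its call signature; internally it now allocates via FromEnvAlloc.
  Tensor dst = alloc_tensor({rows, cols}, src.dtype(), src.device());

  cudaSetDevice(src.device().device_id);
  const cudaStream_t stream = get_stream(src.device());
  auto status = cudaMemcpyAsync(dst.data_ptr(), src.data_ptr(), nbytes,
                                cudaMemcpyDeviceToDevice, stream);
  TVM_FFI_ICHECK(status == cudaSuccess)
      << "device-to-device copy failed: " << cudaGetErrorString(status);
  return dst;
}

The same mapping is what drives the csrc/tvm_ffi_utils.h hunk: the CHECK_* macros and get_element_size move to the method accessors so they work uniformly for Tensor and TensorView arguments, and alloc_tensor switches from FromDLPackAlloc to FromEnvAlloc.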