From c00388a2ef933f243b28d89bafb1b329d72557ad Mon Sep 17 00:00:00 2001 From: Connor Holmes Date: Wed, 20 Dec 2023 16:05:26 -0800 Subject: [PATCH] Mixtral FastGen Support (#4828) Adds support for Mixtral with FastGen. Key features implemented: 1. Top-2 MoE support 2. Better support for RoPE thetas 3. The mistral model implementation --------- Co-authored-by: Michael Wyatt --- .../v2/checkpoint/huggingface_engine.py | 2 +- deepspeed/inference/v2/engine_factory.py | 7 + .../v2/kernels/ragged_ops/__init__.py | 2 +- .../kernels/ragged_ops/includes/top_k_utils.h | 15 + .../blocked_kv_rotary.cpp | 4 + .../blocked_kv_rotary.cu | 26 +- .../blocked_kv_rotary.cuh | 1 + .../blocked_kv_rotary.h | 1 + .../blocked_kv_rotary.py | 12 +- .../ragged_ops/moe_gather/moe_gather.cpp | 10 +- .../ragged_ops/moe_gather/moe_gather.cu | 107 +++++-- .../ragged_ops/moe_gather/moe_gather.cuh | 2 + .../ragged_ops/moe_gather/moe_gather.h | 3 +- .../ragged_ops/moe_gather/moe_gather.py | 9 +- .../ragged_ops/moe_scatter/moe_scatter.cpp | 7 +- .../ragged_ops/moe_scatter/moe_scatter.cu | 188 ++++++------ .../ragged_ops/moe_scatter/moe_scatter.cuh | 1 + .../ragged_ops/moe_scatter/moe_scatter.py | 8 +- .../v2/kernels/ragged_ops/ragged_ops.cpp | 6 +- .../__init__.py | 2 +- .../top_k_gating.cpp} | 26 +- .../top_k_gating.cu} | 69 +++-- .../top_k_gating.cuh} | 3 +- .../top_k_gating.h} | 4 +- .../top_k_gating.py} | 14 +- .../v2/model_implementations/__init__.py | 1 + .../common_parameters/moe_parameters.py | 23 +- .../model_implementations/falcon/__init__.py | 2 +- .../{falcon_containers.py => container.py} | 4 +- .../falcon/{falcon_model.py => model.py} | 4 +- .../falcon/{falcon_policy.py => policy.py} | 8 +- .../inference_transformer_base.py | 20 +- .../llama_v2/__init__.py | 2 +- .../{llama_v2_containers.py => container.py} | 4 +- .../llama_v2/{llama_v2_model.py => model.py} | 29 +- .../{llama_v2_policy.py => policy.py} | 6 +- .../v2/model_implementations/mistral/model.py | 25 +- .../model_implementations/mistral/policy.py | 8 +- .../model_implementations/mixtral/__init__.py | 6 + .../mixtral/container.py | 46 +++ .../v2/model_implementations/mixtral/model.py | 274 ++++++++++++++++++ .../model_implementations/mixtral/policy.py | 31 ++ .../v2/model_implementations/opt/container.py | 4 +- .../v2/model_implementations/opt/model.py | 3 +- .../v2/model_implementations/opt/policy.py | 6 +- .../inference/v2/modules/configs/__init__.py | 7 +- .../v2/modules/configs/attention_configs.py | 24 +- .../v2/modules/configs/moe_config.py | 6 + .../attention/dense_blocked_attention.py | 20 +- .../implementations/moe/cutlass_multi_gemm.py | 88 ++++-- op_builder/ragged_ops.py | 7 +- .../v2/kernels/ragged_ops/test_moe_gather.py | 67 ++++- .../v2/kernels/ragged_ops/test_moe_scatter.py | 69 +++-- ...t_top_1_gating.py => test_top_k_gating.py} | 83 +++++- .../parameters/test_parameter_list.py | 2 +- .../inference/v2/modules/test_blocked_attn.py | 11 +- .../inference/v2/modules/test_cutlass_moe.py | 114 ++++++++ 57 files changed, 1193 insertions(+), 340 deletions(-) create mode 100644 deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating => top_k_gating}/__init__.py (69%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.cpp => top_k_gating/top_k_gating.cpp} (67%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.cu => top_k_gating/top_k_gating.cu} (59%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.cuh => top_k_gating/top_k_gating.cuh} (87%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.h => top_k_gating/top_k_gating.h} (86%) rename deepspeed/inference/v2/kernels/ragged_ops/{top_1_gating/top_1_gating.py => top_k_gating/top_k_gating.py} (87%) rename deepspeed/inference/v2/model_implementations/falcon/{falcon_containers.py => container.py} (97%) rename deepspeed/inference/v2/model_implementations/falcon/{falcon_model.py => model.py} (98%) rename deepspeed/inference/v2/model_implementations/falcon/{falcon_policy.py => policy.py} (74%) rename deepspeed/inference/v2/model_implementations/llama_v2/{llama_v2_containers.py => container.py} (95%) rename deepspeed/inference/v2/model_implementations/llama_v2/{llama_v2_model.py => model.py} (83%) rename deepspeed/inference/v2/model_implementations/llama_v2/{llama_v2_policy.py => policy.py} (76%) create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/__init__.py create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/container.py create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/model.py create mode 100644 deepspeed/inference/v2/model_implementations/mixtral/policy.py rename tests/unit/inference/v2/kernels/ragged_ops/{test_top_1_gating.py => test_top_k_gating.py} (51%) diff --git a/deepspeed/inference/v2/checkpoint/huggingface_engine.py b/deepspeed/inference/v2/checkpoint/huggingface_engine.py index 6b64ed3185a2..ca9fb113b15a 100644 --- a/deepspeed/inference/v2/checkpoint/huggingface_engine.py +++ b/deepspeed/inference/v2/checkpoint/huggingface_engine.py @@ -61,7 +61,7 @@ def model_has_safetensors(model_name_or_path: str) -> bool: # We need to download the checkpoint files from HF if model_has_safetensors(self.model_name_or_path): # Prioritize downloading safetensors if they are available - allow_patterns = ["*.safetensors", "*.json", "*.pt"] + allow_patterns = ["*.safetensors", "*.json"] else: # Fallback to bin files when safetensors are not present allow_patterns = ["*.bin", "*.json", "*.pt"] diff --git a/deepspeed/inference/v2/engine_factory.py b/deepspeed/inference/v2/engine_factory.py index 9558125ff934..a0dc050bbbf9 100644 --- a/deepspeed/inference/v2/engine_factory.py +++ b/deepspeed/inference/v2/engine_factory.py @@ -17,6 +17,7 @@ OPTPolicy, Llama2Policy, MistralPolicy, + MixtralPolicy, FalconPolicy, ) from .model_implementations.inference_policy_base import POLICIES, InferenceV2Policy @@ -105,6 +106,12 @@ def build_hf_engine(path: str, assert version.parse(transformers.__version__) >= version.parse("4.34.0"), \ f"Mistral requires transformers >= 4.34.0, you have version {transformers.__version__}" policy = MistralPolicy(model_config, checkpoint_engine=checkpoint_engine) + elif model_config.model_type == "mixtral": + # Ensure we're using the correct version of transformers for mistral + import transformers + assert version.parse(transformers.__version__) >= version.parse("4.36.1"), \ + f"Mistral requires transformers >= 4.36.1, you have version {transformers.__version__}" + policy = MixtralPolicy(model_config, checkpoint_engine=checkpoint_engine) elif model_config.model_type == "falcon": policy = FalconPolicy(model_config, checkpoint_engine=checkpoint_engine) else: diff --git a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py index 988152b2e7c0..38a4ebd6fba3 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/__init__.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/__init__.py @@ -10,4 +10,4 @@ from .logits_gather import * from .moe_gather import * from .moe_scatter import * -from .top_1_gating import * +from .top_k_gating import * diff --git a/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h new file mode 100644 index 000000000000..abb9e15f8f6f --- /dev/null +++ b/deepspeed/inference/v2/kernels/ragged_ops/includes/top_k_utils.h @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. +// SPDX-License-Identifier: Apache-2.0 + +// DeepSpeed Team + +#define TOP_K_SWITCH(N_TOP_K, ...) \ + [&] { \ + if (1 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 1; \ + __VA_ARGS__(); \ + } else if (2 == N_TOP_K) { \ + constexpr int CONST_TOP_K = 2; \ + __VA_ARGS__(); \ + } \ + }() diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp index 8493bbf4b9af..a640c2b30164 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cpp @@ -13,6 +13,7 @@ (C_TYPE*)k.data_ptr(), \ (C_TYPE*)v.data_ptr(), \ (C_TYPE*)inv_freq_ptr, \ + theta_base, \ batch_wrapper, \ qkv_stride, \ kv_cache_stride, \ @@ -51,6 +52,8 @@ void kv_trained_rotary_embeddings(torch::Tensor& kv_cache, TORCH_CHECK(n_tokens == k.size(0)); TORCH_CHECK(n_tokens == v.size(0)); + const float theta_base = 0.f; + // Dimensions const int32_t block_size = kv_cache.size(1); const int32_t n_kv_heads = kv_cache.size(3); @@ -91,6 +94,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache, torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, + const float theta_base, torch::Tensor& batch_metadata, torch::Tensor& seq_metadata, torch::Tensor& tokens_to_seq, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu index 980334f02b0b..5dd79f0c636a 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cu @@ -27,6 +27,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, T* k, T* v, const T* inv_freq, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, @@ -114,7 +115,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, // Conversion to T and back means that both branches of this if statement // will produce the same results if using the same algo for producing the // freqs. - T trunc_freq = conversion::to(1.0 / powf(10000.0, inv_freq_flt)); + T trunc_freq = conversion::to(1.0 / powf(theta_base, inv_freq_flt)); inv_freq_flt = conversion::to(trunc_freq) * (float)global_token_idx; } @@ -158,7 +159,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, } else { inv_freq_flt = (float)((head_neuron_idx % half_head_size) * 2) / (float)headSize; - inv_freq_flt = 1.0 / powf(10000.0, inv_freq_flt) * (float)global_token_idx; + inv_freq_flt = 1.0 / powf(theta_base, inv_freq_flt) * (float)global_token_idx; } float rotary_sign = (head_neuron_idx >= half_head_size) ? -1.0f : 1.0f; @@ -186,6 +187,7 @@ __global__ void kv_rotary_pos_kernel(T* kv_cache, k, \ v, \ inv_freq, \ + theta_base, \ batch_desc, \ qkv_stride, \ kv_cache_stride, \ @@ -198,6 +200,7 @@ void launch_kv_rotary_kernel(T* kv_cache, T* k, T* v, T* inv_freq, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, @@ -245,6 +248,7 @@ void launch_kv_rotary_kernel(T* kv_cache, TYPE * k, \ TYPE * v, \ TYPE * inv_freq, \ + const float theta_base, \ const BatchWrapperCPP batch_desc, \ const int qkv_stride, \ const int kv_cache_stride, \ @@ -262,10 +266,20 @@ INSTANTIATE_KV_ROTARY_KERNEL(__half) INSTANTIATE_KV_ROTARY_KERNEL(__nv_bfloat16) #endif -#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ - if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ - kv_rotary_pos_kernel<<>>( \ - kv_cache, q, k, v, nullptr, batch_desc, qkv_stride, kv_cache_stride, v_offset, 0); +#define DISPATCH_KV_COPY_IMPL(Q_RATIO, HEAD_SIZE) \ + if (q_ratio == Q_RATIO && head_size == HEAD_SIZE) \ + kv_rotary_pos_kernel \ + <<>>(kv_cache, \ + q, \ + k, \ + v, \ + nullptr, \ + 0.f, \ + batch_desc, \ + qkv_stride, \ + kv_cache_stride, \ + v_offset, \ + 0); template void launch_kv_copy_kernel(T* kv_cache, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh index be38ff30c46c..41a69d3b397b 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.cuh @@ -18,6 +18,7 @@ void launch_kv_rotary_kernel(T* kv_cache, T* k, T* v, T* inv_freq, + const float theta_base, const BatchWrapperCPP batch_desc, const int qkv_stride, const int kv_cache_stride, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h index 0615825c0a21..e56ce644dbbc 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.h @@ -45,6 +45,7 @@ void kv_rotary_embeddings(torch::Tensor& kv_cache, torch::Tensor& q, torch::Tensor& k, torch::Tensor& v, + const float theta_base, torch::Tensor& batch_metadata, torch::Tensor& seq_metadata, torch::Tensor& tokens_to_seq, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py index 50d9aca061f3..f206a4f5d28c 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary/blocked_kv_rotary.py @@ -21,7 +21,12 @@ class BlockedRotaryEmbeddings(DSKernelBase): supported_head_sizes = [64, 128] supported_q_ratios = [1, 2, 4, 5, 8, 16, 29, 35, 36, 71] - def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch.dtype) -> None: + def __init__(self, + head_size: int, + n_q_heads: int, + n_kv_heads: int, + dtype: torch.dtype, + theta_base: float = 10000.0) -> None: """ Args: head_size: The size of the attention head. @@ -51,6 +56,7 @@ def __init__(self, head_size: int, n_q_heads: int, n_kv_heads: int, dtype: torch self.head_size = head_size self.n_q_heads = n_q_heads self.n_kv_heads = n_kv_heads + self.theta_base = theta_base def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: RaggedBatchWrapper) -> None: """ @@ -66,5 +72,5 @@ def __call__(self, kv_cache: torch.Tensor, qkv: torch.Tensor, ragged_batch: Ragg k = qkv[:, self.head_size * self.n_q_heads:self.head_size * (self.n_q_heads + self.n_kv_heads)] v = qkv[:, self.head_size * (self.n_q_heads + self.n_kv_heads):] - self.kernel(kv_cache, q, k, v, ragged_batch.batch_metadata_buffer(), ragged_batch.inflight_seq_descriptors(), - ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) + self.kernel(kv_cache, q, k, v, self.theta_base, ragged_batch.batch_metadata_buffer(), + ragged_batch.inflight_seq_descriptors(), ragged_batch.tokens_to_seq(), ragged_batch.kv_ptrs()) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp index e55e1f48c125..506629406f0d 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp @@ -16,6 +16,8 @@ n_channels, \ n_experts, \ n_tokens, \ + n_top_k, \ + normalize_scales, \ at::cuda::getCurrentCUDAStream()); \ return; \ } @@ -27,17 +29,21 @@ void moe_gather(torch::Tensor& layer_output, const torch::Tensor& moe_output, const torch::Tensor& scores, const torch::Tensor& mapped_slots, - const torch::Tensor& expert_count) + const torch::Tensor& expert_count, + const bool normalize_scales) { const int32_t n_channels = layer_output.size(1); const int32_t n_experts = expert_count.size(0); const int32_t n_tokens = layer_output.size(0); + const int32_t n_top_k = mapped_slots.size(1); - TORCH_CHECK(moe_output.size(0) == n_tokens); + TORCH_CHECK(moe_output.size(0) == n_tokens * n_top_k); TORCH_CHECK(moe_output.size(1) == n_channels); TORCH_CHECK(scores.size(0) == n_tokens); TORCH_CHECK(mapped_slots.size(0) == n_tokens); + TORCH_CHECK(scores.size(1) == n_top_k); + TORCH_CHECK(layer_output.scalar_type() == moe_output.scalar_type()); TORCH_CHECK(scores.scalar_type() == torch::kFloat32); TORCH_CHECK(mapped_slots.scalar_type() == torch::kInt32); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu index c2fae24f5080..4153a2a3636f 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu @@ -7,7 +7,8 @@ #include "ds_kernel_utils.h" #include "moe_gather.cuh" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" namespace gather { @@ -16,65 +17,105 @@ constexpr int threads = 256; } // namespace gather -template +template __global__ void moe_gather_kernel(T* layer_output, const T* moe_output, const float* scores, const int32_t* mapped_slots, int32_t* expert_counts, const int32_t n_channels, - const int32_t n_experts) + const int32_t n_experts, + const bool normalize_scales) { constexpr int32_t vector_size = gather::access_granularity / sizeof(T); constexpr int32_t stride = vector_size * gather::threads; const int32_t token_idx = blockIdx.x; - const int32_t mapped_slot = mapped_slots[token_idx]; + int32_t token_mapped_slots[N_TOP_K]; + + bool all_slots_invalid = true; + for (int i = 0; i < N_TOP_K; i++) { + token_mapped_slots[i] = mapped_slots[token_idx * N_TOP_K + i]; + all_slots_invalid &= (token_mapped_slots[i] == gating::unassigned); + } if (token_idx == 0) { // Reset expert counts for its next use. if (threadIdx.x < n_experts) { expert_counts[threadIdx.x] = 0; } } - if (mapped_slot == gating::unassigned) { - // This token was not assigned. + if (all_slots_invalid) { + // This token was not assigned to anything. // TODO(cmikeh2): It's possible we want different behavior here moving forward. return; } - const float score = scores[token_idx]; + float token_scores[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] = scores[token_idx * N_TOP_K + i]; } + + if (normalize_scales) { + // Normalize the scores so that they sum to 1. + float sum = 0.0f; + for (int i = 0; i < N_TOP_K; i++) { sum += token_scores[i]; } + + if (sum > 0.0f) { + for (int i = 0; i < N_TOP_K; i++) { token_scores[i] /= sum; } + } + } + const int32_t channel_offset = threadIdx.x * vector_size; - const T* moe_output_base = moe_output + mapped_slot * n_channels + channel_offset; + const T* moe_output_bases[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + moe_output_bases[i] = moe_output + token_mapped_slots[i] * n_channels + channel_offset; + } + T* layer_output_base = layer_output + token_idx * n_channels + channel_offset; #pragma unroll for (int i = 0; i < copyUnroll; i++) { - T reg_buffer[vector_size]; - if (i * stride + channel_offset < n_channels) { - mem_access::load_global(reg_buffer, - moe_output_base + i * stride); + float accum_buffer[vector_size]; + for (int j = 0; j < vector_size; j++) { + accum_buffer[j] = reduce::init(); + } + +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + T reg_buffer[vector_size]; + mem_access::load_global( + reg_buffer, moe_output_bases[j] + i * stride); +#pragma unroll + for (int k = 0; k < vector_size; k++) { + float up_cast = conversion::to(reg_buffer[k]); + accum_buffer[k] += up_cast * token_scores[j]; + } + } + + T store_buffer[vector_size]; #pragma unroll for (int j = 0; j < vector_size; j++) { - // There are accuracy implications of downcasting the score to a 16-bit - // data type, so we up-convert the input to 32-bit, multiply, and then - // down-convert back to 16-bit. - float up_cast = conversion::to(reg_buffer[j]); - reg_buffer[j] = conversion::to(up_cast * score); + store_buffer[j] = conversion::to(accum_buffer[j]); } mem_access::store_global(layer_output_base + i * stride, - reg_buffer); + store_buffer); } } } -#define LAUNCH_FOR_UNROLL(COUNT) \ - case COUNT: \ - moe_gather_kernel<<>>( \ - layer_output, moe_output, scores, mapped_slots, expert_counts, n_channels, n_experts); \ +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_gather_kernel<<>>(layer_output, \ + moe_output, \ + scores, \ + mapped_slots, \ + expert_counts, \ + n_channels, \ + n_experts, \ + normalize_scales); \ break; template @@ -86,6 +127,8 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, const int32_t n_experts, const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, cudaStream_t stream) { constexpr int vals_per_unroll = gather::threads * gather::access_granularity / sizeof(T); @@ -94,14 +137,16 @@ void launch_moe_gather(T* layer_output, const dim3 block(gather::threads); const dim3 grid(n_tokens); - switch (copy_unroll) { - LAUNCH_FOR_UNROLL(1) - LAUNCH_FOR_UNROLL(2) - LAUNCH_FOR_UNROLL(3) - LAUNCH_FOR_UNROLL(4) - LAUNCH_FOR_UNROLL(5) - LAUNCH_FOR_UNROLL(6) - } + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1) + LAUNCH_FOR_UNROLL(2) + LAUNCH_FOR_UNROLL(3) + LAUNCH_FOR_UNROLL(4) + LAUNCH_FOR_UNROLL(5) + LAUNCH_FOR_UNROLL(6) + } + }); } #define INSTANTIATE_GATHER_FOR_TYPE(TYPE) \ @@ -113,6 +158,8 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, \ const int32_t n_experts, \ const int32_t n_tokens, \ + const int32_t n_top_k, \ + const bool normalize_scales, \ cudaStream_t stream); INSTANTIATE_GATHER_FOR_TYPE(__half) diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh index f98a727ead58..b348d0cfb330 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cuh @@ -17,4 +17,6 @@ void launch_moe_gather(T* layer_output, const int32_t n_channels, const int32_t n_experts, const int32_t n_tokens, + const int32_t n_top_k, + const bool normalize_scales, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h index 7ffe9f8b4dc6..ec9e03057eb8 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.h @@ -16,4 +16,5 @@ void moe_gather(torch::Tensor& layer_output, const torch::Tensor& moe_output, const torch::Tensor& scores, const torch::Tensor& mapped_slots, - const torch::Tensor& expert_counts); + const torch::Tensor& expert_counts, + const bool normalize_scales); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py index c37683d03fbe..f03938171ba4 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_gather/moe_gather.py @@ -18,7 +18,7 @@ class MoEGather(DSKernelBase): supported_dtypes = [DtypeEnum.fp16, DtypeEnum.bf16] - def __init__(self, dtype: DtypeEnum, channels: int) -> None: + def __init__(self, dtype: DtypeEnum, channels: int, normalize_scores: bool = False) -> None: if not isinstance(dtype, DtypeEnum): dtype = DtypeEnum(dtype) @@ -31,6 +31,7 @@ def __init__(self, dtype: DtypeEnum, channels: int) -> None: inf_module = RaggedOpsBuilder().load() self.kernel = inf_module.moe_gather + self.normalize_scores = normalize_scores def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: torch.Tensor, mapped_slots: torch.Tensor, expert_counts: torch.Tensor) -> torch.Tensor: @@ -40,13 +41,13 @@ def __call__(self, layer_output: torch.Tensor, moe_output: torch.Tensor, scores: Arguments: layer_output (torch.Tensor): The output of the layer of shape [n_tokens, hidden_size]. This has been scaled appropriately. - moe_output (torch.Tensor): The output of the MoE of shape [n_tokens, hidden_size]. + moe_output (torch.Tensor): The output of the MoE of shape [n_tokens * n_top_k, hidden_size]. scores (torch.Tensor): The gating scores of shape [n_tokens]. - mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. The index of token ``i`` in layer_output is ``mapped_slots[i]``. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. The indices of token ``i`` in layer_output is ``mapped_slots[i]``. expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. This is passed to fuse the clearing of this data structure into the gather. Returns: layer_output """ - self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts) + self.kernel(layer_output, moe_output, scores, mapped_slots, expert_counts, self.normalize_scores) return layer_output diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp index 902f1cc0ea15..8f7ecbd1a287 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cpp @@ -18,6 +18,7 @@ n_channels, \ n_tokens, \ n_experts, \ + n_top_k, \ at::cuda::getCurrentCUDAStream()); \ return; \ } @@ -36,13 +37,17 @@ void moe_scatter(torch::Tensor& moe_input, { const int32_t n_tokens = activations.size(0); const int32_t n_channels = activations.size(1); + const int32_t n_top_k = assignments.size(1); // Should have a lot of matching buffer sizes here. - TORCH_CHECK(n_tokens == moe_input.size(0)); TORCH_CHECK(n_tokens == assignments.size(0)); TORCH_CHECK(n_tokens == offsets.size(0)); TORCH_CHECK(n_channels == moe_input.size(1)); + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k * n_tokens == moe_input.size(0)); + TORCH_CHECK(n_top_k == mapped_slots.size(1)); + const int32_t n_experts = expert_count_cumsums.size(0); TORCH_CHECK(moe_input.scalar_type() == activations.scalar_type()); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu index 0746cd7be645..d3eb4f649e79 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cu @@ -4,9 +4,9 @@ // DeepSpeed Team #include "ds_kernel_utils.h" -#include "moe_scatter.cuh" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" using ROp = reduce::ROpType; @@ -15,10 +15,11 @@ namespace scatter { constexpr int access_granularity = 16; constexpr int threads = 256; constexpr int warps = threads / hw_warp_size; +constexpr int max_experts = 1024; } // namespace scatter -template +template __global__ void moe_scatter_kernel(T* moe_input, int64_t* expert_count_cumsums, int32_t* mapped_slots, @@ -38,88 +39,78 @@ __global__ void moe_scatter_kernel(T* moe_input, // Bank aligned and sufficient __shared__ int32_t red_buffer[32]; - __shared__ int32_t token_0_row; + __shared__ int32_t expert_offsets[scatter::max_experts]; // CG helpers cg::thread_block tb = cg::this_thread_block(); cg::thread_block_tile warp = cg::tiled_partition(tb); - const int assigned_expert = assignments[token_idx]; - - // For the different codepaths, we'll converge on this variable for doing - // the token copy. - int32_t token_base_row; + // Fetch the assigned experts for this token. + int assigned_experts[N_TOP_K]; + for (int i = 0; i < N_TOP_K; i++) { + assigned_experts[i] = assignments[token_idx * N_TOP_K + i]; + } - if (token_idx == 0) { - // Token 0 will perform a cumsum on the data - int32_t expert_vals; - if (tidx < n_experts) { - expert_vals = expert_counts[tidx]; + bool all_unassigned = true; + for (int i = 0; i < N_TOP_K; i++) { + if (assigned_experts[i] != gating::unassigned) { + all_unassigned = false; } else { - expert_vals = 0; + mapped_slots[token_idx * N_TOP_K + i] = gating::unassigned; } + } + if (all_unassigned && token_idx != 0) return; + + // Do a prefix scan on the expert counts to get the base offsets. Here we use the + // single up-sweep variant. + int32_t expert_vals; + if (tidx < n_experts) { + expert_vals = expert_counts[tidx]; + } else { + expert_vals = 0; + } #pragma unroll - for (int i = 1; i < hw_warp_size; i *= 2) { - int32_t maybe_add = warp.shfl_up(expert_vals, i); - expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; - } + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(expert_vals, i); + expert_vals = (warp.thread_rank() < i) ? expert_vals : expert_vals + maybe_add; + } - if (warp.thread_rank() == hw_warp_size - 1) { - mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); - } + if (warp.thread_rank() == hw_warp_size - 1) { + mem_access::store_shared<4>(red_buffer + warp_rank, &expert_vals); + } - tb.sync(); + tb.sync(); - int32_t phase_2_val = 0; - if (warp.thread_rank() < scatter::warps) { - mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); - } + int32_t phase_2_val = 0; + if (warp.thread_rank() < scatter::warps) { + mem_access::load_shared<4>(&phase_2_val, red_buffer + warp.thread_rank()); + } #pragma unroll - for (int i = 1; i < hw_warp_size; i *= 2) { - int32_t maybe_add = warp.shfl_up(phase_2_val, i); - phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; - } - - int warp_offset = 0; - if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } - const int32_t expert_cumsum = warp_offset + expert_vals; - - if (tidx < n_experts) { - int64_t expert_cumsum_64 = (int64_t)expert_cumsum; - expert_count_cumsums[tidx] = expert_cumsum_64; - } - - if (assigned_expert == gating::unassigned) return; - if (assigned_expert - 1 == tidx) token_0_row = expert_cumsum; + for (int i = 1; i < hw_warp_size; i *= 2) { + int32_t maybe_add = warp.shfl_up(phase_2_val, i); + phase_2_val = (warp.thread_rank() < i) ? phase_2_val : phase_2_val + maybe_add; + } - tb.sync(); + int warp_offset = 0; + if (warp_rank > 0) { warp_offset = warp.shfl(phase_2_val, warp_rank - 1); } + const int32_t expert_cumsum = warp_offset + expert_vals; - if (assigned_expert != 0) { - token_base_row = token_0_row; - } else { - token_base_row = 0; - } + // Token 0 will write the + if (token_idx == 0 && tidx < n_experts) { + int64_t expert_cumsum_64 = (int64_t)expert_cumsum; + expert_count_cumsums[tidx] = expert_cumsum_64; + } - } else if (assigned_expert == gating::unassigned) { - // For whatever reason, don't need to perform the copy, so we'll early return - // and signal this wasn't mapped with a negative 1. - if (tidx == 0) mapped_slots[token_idx] = gating::unassigned; - return; - } else { - // For all other valid tokens, we can just do a block-scoped sum. - if (tidx < assigned_expert) { - token_base_row = expert_counts[tidx]; - } else { - token_base_row = 0; - } + // Since token 0 has now written the expert cumsum to global memory, + // if it has no valid experts, we can early return. + if (token_idx == 0 && all_unassigned) return; - warp.sync(); + if (tidx < n_experts) { expert_offsets[tidx] = expert_cumsum; } - // TODO(cmikeh2): Shouldn't use the internal api. - reduce::_block(tb, warp, &token_base_row); - } + // Ensure all the expert offsets are written in shared memory. + tb.sync(); // Data copy to appropriate location const int32_t thread_offset = tidx * vector_size; @@ -127,9 +118,16 @@ __global__ void moe_scatter_kernel(T* moe_input, const int32_t base_load_offset = token_idx * n_channels + thread_offset; const T* load_base_ptr = activations + base_load_offset; - const int32_t store_row = token_base_row + offsets[token_idx]; - const int32_t base_store_offset = store_row * n_channels + thread_offset; - T* store_base_ptr = moe_input + base_store_offset; + int32_t store_rows[N_TOP_K]; + T* store_base_ptrs[N_TOP_K]; +#pragma unroll + for (int i = 0; i < N_TOP_K; i++) { + const int32_t cur_expert_offset = + (assigned_experts[i] > 0) ? expert_offsets[assigned_experts[i] - 1] : 0; + store_rows[i] = cur_expert_offset + offsets[token_idx * N_TOP_K + i]; + const int32_t base_store_offset = store_rows[i] * n_channels + thread_offset; + store_base_ptrs[i] = moe_input + base_store_offset; + } #pragma unroll for (int i = 0; i < copyUnroll; i++) { @@ -138,25 +136,31 @@ __global__ void moe_scatter_kernel(T* moe_input, if (i * load_stride + thread_offset < n_channels) { mem_access::load_global(tmp_buf, load_base_ptr + i * load_stride); - mem_access::store_global(store_base_ptr + i * load_stride, - tmp_buf); +#pragma unroll + for (int j = 0; j < N_TOP_K; j++) { + mem_access::store_global( + store_base_ptrs[j] + i * load_stride, tmp_buf); + } } } - if (threadIdx.x == 0) { mapped_slots[token_idx] = store_row; } + if (threadIdx.x == 0) { + for (int i = 0; i < N_TOP_K; i++) { mapped_slots[token_idx * N_TOP_K + i] = store_rows[i]; } + } } -#define LAUNCH_FOR_UNROLL(COUNT) \ - case COUNT: \ - moe_scatter_kernel<<>>(moe_input, \ - expert_count_cumsums, \ - mapped_slots, \ - activations, \ - assignments, \ - expert_counts, \ - offsets, \ - n_channels, \ - n_experts); \ +#define LAUNCH_FOR_UNROLL(COUNT) \ + case COUNT: \ + moe_scatter_kernel \ + <<>>(moe_input, \ + expert_count_cumsums, \ + mapped_slots, \ + activations, \ + assignments, \ + expert_counts, \ + offsets, \ + n_channels, \ + n_experts); \ break; template @@ -170,6 +174,7 @@ void launch_moe_scatter(T* moe_input, const int32_t n_channels, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream) { constexpr int vals_per_unroll = scatter::threads * scatter::access_granularity / sizeof(T); @@ -178,14 +183,16 @@ void launch_moe_scatter(T* moe_input, const dim3 block(scatter::threads); const dim3 grid(n_tokens); - switch (copy_unroll) { - LAUNCH_FOR_UNROLL(1); - LAUNCH_FOR_UNROLL(2); - LAUNCH_FOR_UNROLL(3); - LAUNCH_FOR_UNROLL(4); - LAUNCH_FOR_UNROLL(5); - LAUNCH_FOR_UNROLL(6); - } + TOP_K_SWITCH(n_top_k, [&] { + switch (copy_unroll) { + LAUNCH_FOR_UNROLL(1); + LAUNCH_FOR_UNROLL(2); + LAUNCH_FOR_UNROLL(3); + LAUNCH_FOR_UNROLL(4); + LAUNCH_FOR_UNROLL(5); + LAUNCH_FOR_UNROLL(6); + } + }); } #define INSTANTIATE_SCATTER_FOR_TYPE(TYPE) \ @@ -199,6 +206,7 @@ void launch_moe_scatter(T* moe_input, const int32_t, \ const int32_t, \ const int32_t, \ + const int32_t, \ cudaStream_t); INSTANTIATE_SCATTER_FOR_TYPE(__half); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh index 5c94cb0ef734..d9756c80f05a 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.cuh @@ -19,4 +19,5 @@ void launch_moe_scatter(T* moe_input, const int32_t n_channels, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py index 5cd6ae5f0fe2..7efcedb4e880 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/moe_scatter/moe_scatter.py @@ -40,13 +40,13 @@ def __call__(self, moe_input: torch.Tensor, expert_cumsum: torch.Tensor, mapped_ Scatters the hidden states such that the token stride for each expert's input is contiguous. Arguments: - moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens, hidden_size]. + moe_input (torch.Tensor): The direct input for the MoE GEMM of shape [n_tokens * n_top_k, hidden_size]. expert_cumsum (torch.Tensor): The cumulative sum of the expert counts of shape [n_experts]. - mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens]. + mapped_slots (torch.Tensor): The index of the token in the expert's input of shape [n_tokens, n_top_k]. hidden_states (torch.Tensor): The hidden states of shape [n_tokens, hidden_size]. expert_counts (torch.Tensor): The number of tokens assigned to each expert of shape [n_experts]. - assignments (torch.Tensor): The expert assignments of shape [n_tokens]. - offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens]. + assignments (torch.Tensor): The expert assignments of shape [n_tokens, n_top_k]. + offsets (torch.Tensor): The offsets into the expert for a given token of shape [n_tokens, n_top_K]. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: The MoE input (with scattered values), the cumsum of the offsets (for the MoE kernels themselves), and the assignments Tensor modified in place to show which row that token was mapped to in the input. diff --git a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp index 1c09fc52bbb1..f320f46e2620 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/ragged_ops.cpp @@ -12,7 +12,7 @@ #include "logits_gather.h" #include "moe_gather.h" #include "moe_scatter.h" -#include "top_1_gating.h" +#include "top_k_gating.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -43,6 +43,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) // moe_scatter.h m.def("moe_scatter", &moe_scatter, "MoE scatter for top-1-gating."); - // top_1_gating.h - m.def("top_1_gating", &top_1_gating, "Top-1 gating for MoE with ragged batch awareness."); + // top_k_gating.h + m.def("top_k_gating", &top_k_gating, "Top-1 gating for MoE with ragged batch awareness."); } diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py similarity index 69% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py index b50a0838d9f8..487735b015b0 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/__init__.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .top_1_gating import RaggedTop1Gating +from .top_k_gating import RaggedTopKGating diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp similarity index 67% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp index 55c68454b228..5eec7e2b955f 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp @@ -3,12 +3,12 @@ // DeepSpeed Team -#include "top_1_gating.h" +#include "top_k_gating.h" #include -#define DISPATCH_TOP_1_GATING(T_TYPE, C_TYPE) \ +#define DISPATCH_TOP_K_GATING(T_TYPE, C_TYPE) \ if (logits.options().dtype() == torch::T_TYPE) { \ - launch_top_1_gating((int32_t*)expert_counts.data_ptr(), \ + launch_top_k_gating((int32_t*)expert_counts.data_ptr(), \ (float*)scores.data_ptr(), \ (int32_t*)assignments.data_ptr(), \ (int32_t*)offsets.data_ptr(), \ @@ -16,14 +16,15 @@ batch_metadata_ptr, \ n_tokens, \ n_experts, \ + n_top_k, \ at::cuda::getCurrentCUDAStream()); \ return; \ } /* -Perform softmax plus atomics in order to do first pass of top_1_gating. +Perform softmax plus atomics in order to do first pass of top_k_gating. */ -void top_1_gating(torch::Tensor& expert_counts, +void top_k_gating(torch::Tensor& expert_counts, torch::Tensor& scores, torch::Tensor& assignments, torch::Tensor& offsets, @@ -31,10 +32,15 @@ void top_1_gating(torch::Tensor& expert_counts, torch::Tensor& batch_metadata) { const int32_t n_tokens = scores.size(0); + const int32_t n_top_k = scores.size(1); - // Should have the same buffer size for scores and offsets + // Should have the same buffer size for scores, offsets, and assignments TORCH_CHECK(n_tokens == offsets.size(0)); TORCH_CHECK(n_tokens == logits.size(0)); + TORCH_CHECK(n_tokens == assignments.size(0)); + + TORCH_CHECK(n_top_k == offsets.size(1)); + TORCH_CHECK(n_top_k == assignments.size(1)); TORCH_CHECK(expert_counts.scalar_type() == torch::kInt32); TORCH_CHECK(scores.scalar_type() == torch::kFloat); @@ -45,11 +51,11 @@ void top_1_gating(torch::Tensor& expert_counts, const RaggedBatchDescriptor* batch_metadata_ptr = reinterpret_cast(batch_metadata.data_ptr()); - DISPATCH_TOP_1_GATING(kFloat, float) - DISPATCH_TOP_1_GATING(kHalf, __half) + DISPATCH_TOP_K_GATING(kFloat, float) + DISPATCH_TOP_K_GATING(kHalf, __half) #ifdef BF16_AVAILABLE - DISPATCH_TOP_1_GATING(kBFloat16, __nv_bfloat16) + DISPATCH_TOP_K_GATING(kBFloat16, __nv_bfloat16) #endif - TORCH_CHECK(false, "Unsupported dtype for logits in top_1_gating"); + TORCH_CHECK(false, "Unsupported dtype for logits in top_k_gating"); } diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu similarity index 59% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu index 02daee9f692e..58f95c045593 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu @@ -6,12 +6,13 @@ #include "conversion_utils.h" #include "memory_access_utils.h" #include "reduction_utils.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" +#include "top_k_utils.h" using ROp = reduce::ROpType; -template -__global__ void top_1_gating_kernel(int32_t* expert_counts, +template +__global__ void top_k_gating_kernel(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -30,8 +31,11 @@ __global__ void top_1_gating_kernel(int32_t* expert_counts, // Padding tokens do not require if (token_idx >= batch_metadata->n_tokens) { if (threadIdx.x == 0) { - offsets[token_idx] = gating::unassigned; - assignments[token_idx] = gating::unassigned; +#pragma unroll + for (int i = 0; i < TOP_K; i++) { + assignments[token_idx * TOP_K + i] = gating::unassigned; + offsets[token_idx * TOP_K + i] = gating::unassigned; + } } return; } @@ -44,34 +48,46 @@ __global__ void top_1_gating_kernel(int32_t* expert_counts, } else { reduce::init(&logit_val); } + float reduce_val = logit_val; + + int32_t local_assigned_experts[TOP_K]; + float local_assigned_logits[TOP_K]; // Training code tends to use ``torch.argmax`` to select the expert, which // which has ties broken by the lower index. Since our fused comparison algorithm // breaks ties by the higher index (since it's the lower 32-bits of the 64-bit // comparison), we invert the expert index to break ties by the lower index. int32_t inverted_expert = n_experts - expert_idx - 1; - // Perform softmax - const reduce::IdxReduceResult res = - reduce::idx_reduce(tb, warp, logit_val, inverted_expert); - // Recover the original expert index - const int32_t assigned_expert = n_experts - res.idx - 1; - const float max_logit = res.val; + // Find the top k logits + for (int i = 0; i < TOP_K; ++i) { + const reduce::IdxReduceResult res = + reduce::idx_reduce(tb, warp, reduce_val, inverted_expert); + local_assigned_experts[i] = n_experts - res.idx - 1; + local_assigned_logits[i] = res.val; + + // Set the max logit to -inf so that it is not selected again + if (threadIdx.x == n_experts - res.idx - 1) { reduce::init(&reduce_val); } + } + + const float max_logit = local_assigned_logits[0]; float softmax_sum = __expf(logit_val - max_logit); reduce::block(tb, warp, softmax_sum); - // Compute the score - const float score = __expf(max_logit - max_logit) / softmax_sum; + for (int i = 0; i < TOP_K; ++i) { + const float softmax = __expf(local_assigned_logits[i] - max_logit) / softmax_sum; - if (threadIdx.x == 0) { - scores[token_idx] = score; - assignments[token_idx] = assigned_expert; - offsets[token_idx] = atomicAdd(expert_counts + assigned_expert, 1); + if (threadIdx.x == 0) { + scores[token_idx * TOP_K + i] = softmax; + assignments[token_idx * TOP_K + i] = local_assigned_experts[i]; + offsets[token_idx * TOP_K + i] = + atomicAdd(expert_counts + local_assigned_experts[i], 1); + } } } template -void launch_top_1_gating(int32_t* expert_counts, +void launch_top_k_gating(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -79,17 +95,20 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream) { const dim3 grid(n_tokens); const dim3 block(((n_experts + hw_warp_size - 1) / hw_warp_size) * hw_warp_size); - top_1_gating_kernel<<>>( - expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + TOP_K_SWITCH(n_top_k, [&] { + top_k_gating_kernel<<>>( + expert_counts, scores, assignments, offsets, logits, batch_metadata, n_experts); + }); } -#define INSTANTIATE_TOP_1_KERNEL(T) \ - template void launch_top_1_gating(int32_t * expert_counts, \ +#define INSTANTIATE_top_k_KERNEL(T) \ + template void launch_top_k_gating(int32_t * expert_counts, \ float* scores, \ int32_t* assignments, \ int32_t* offsets, \ @@ -97,10 +116,10 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, \ const int32_t n_tokens, \ const int32_t n_experts, \ + const int32_t n_top_k, \ cudaStream_t stream); -INSTANTIATE_TOP_1_KERNEL(float) -INSTANTIATE_TOP_1_KERNEL(__half) +INSTANTIATE_top_k_KERNEL(float) INSTANTIATE_top_k_KERNEL(__half) #ifdef BF16_AVAILABLE -INSTANTIATE_TOP_1_KERNEL(__nv_bfloat16) + INSTANTIATE_top_k_KERNEL(__nv_bfloat16) #endif diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh similarity index 87% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh index c83ad56ff2f1..c525cc5f524e 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cuh +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cuh @@ -13,7 +13,7 @@ constexpr int unassigned = -1; } // namespace gating template -void launch_top_1_gating(int32_t* expert_counts, +void launch_top_k_gating(int32_t* expert_counts, float* scores, int32_t* assignments, int32_t* offsets, @@ -21,4 +21,5 @@ void launch_top_1_gating(int32_t* expert_counts, const RaggedBatchDescriptor* batch_metadata, const int32_t n_tokens, const int32_t n_experts, + const int32_t n_top_k, cudaStream_t stream); diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h similarity index 86% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h index b431f4cad30c..00840c3c93b5 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.h +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.h @@ -8,12 +8,12 @@ #include #include #include "ragged_dtypes.h" -#include "top_1_gating.cuh" +#include "top_k_gating.cuh" /* Perform softmax plus atomics to get token mapping. */ -void top_1_gating(torch::Tensor& expert_counts, +void top_k_gating(torch::Tensor& expert_counts, torch::Tensor& scores, torch::Tensor& assignments, torch::Tensor& offsets, diff --git a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py similarity index 87% rename from deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py rename to deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py index 1df97c2e9f8d..72ba2b6019bb 100644 --- a/deepspeed/inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.py +++ b/deepspeed/inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.py @@ -13,7 +13,7 @@ from deepspeed.ops.op_builder import RaggedOpsBuilder -class RaggedTop1Gating(DSKernelBase): +class RaggedTopKGating(DSKernelBase): """ CUDA implementation of top-1 gating. This will perform a softmax on the logits, and return the scale as well as its idx within that expert's allocation. @@ -26,28 +26,28 @@ def __init__(self, logit_dtype: DtypeEnum) -> None: if not isinstance(logit_dtype, DtypeEnum): logit_dtype = DtypeEnum(logit_dtype) - if logit_dtype not in RaggedTop1Gating.supported_logit_dtypes: + if logit_dtype not in RaggedTopKGating.supported_logit_dtypes: raise RuntimeError(f"Unsupported logit dtype {logit_dtype}") inf_module = RaggedOpsBuilder().load() - self.kernel = inf_module.top_1_gating + self.kernel = inf_module.top_k_gating def __call__(self, expert_counts: torch.Tensor, scores: torch.Tensor, assignments: torch.Tensor, offsets: torch.Tensor, logits: torch.Tensor, batch: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ - Perform the ragged top_1_gating. + Perform the ragged top_k_gating. Arguments: expert_counts (torch.Tensor): Tensor of 0s of shape [n_experts] to be filled with number of tokens assigned to each expert. This must be filled with 0s else the copy kernel will buffer overflow. In order to minimize the zero-fill cost, it is recommended to write to 0 during the MoE output remapping. - scores (torch.Tensor): Preallocated output of shape [n_tokens] to place expert scaling + scores (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place expert scaling value. - expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens] to place + expert_assignment (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which expert a token has been assigned to. - expert_offset (torch.Tensor): Preallocated output of shape [n_tokens] to place which + expert_offset (torch.Tensor): Preallocated output of shape [n_tokens, n_top_k] to place which offset within an experts group a token is. logits (torch.Tensor): Raw logits of gating function. batch (RaggedBatchWrapper): Batch information for ragged tensor. diff --git a/deepspeed/inference/v2/model_implementations/__init__.py b/deepspeed/inference/v2/model_implementations/__init__.py index 481be2e5940e..ab1f984fba7e 100644 --- a/deepspeed/inference/v2/model_implementations/__init__.py +++ b/deepspeed/inference/v2/model_implementations/__init__.py @@ -12,4 +12,5 @@ from .llama_v2 import * from .opt import * from .mistral import * +from .mixtral import * from .falcon import * diff --git a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py index df5f1427a5cf..8ababf567ba9 100644 --- a/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py +++ b/deepspeed/inference/v2/model_implementations/common_parameters/moe_parameters.py @@ -33,7 +33,7 @@ class UnfusedMoEMLP1Parameter(ParameterBase): and need to be joined into a single group. """ - experts: ParamList("num_experts") # noqa: F821 + experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: stacked_experts = torch.stack([p for p in self.experts], dim=0) @@ -46,7 +46,7 @@ class UnfusedMoEMLP2Parameter(ParameterBase): and need to be joined into a single group. """ - experts: ParamList("num_experts") # noqa: F821 + experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: stacked_experts = torch.stack([p for p in self.experts], dim=0) @@ -57,13 +57,22 @@ class UnfusedMoEGatedMLPParameter(ParameterBase): """ MoE Parameter for a gated activation function in which the gating matrix is not fused in the same parameter as the non-gating matrix. + + This is a stacked version of the ``GatedMLPParameter``. Please see that class for more + documentation on the layout of the parameters. """ - gating_experts: ParamList("num_experts") # noqa: F821 + gating_experts: ParamList("n_experts") # noqa: F821 - up_experts: ParamList("num_experts") # noqa: F821 + up_experts: ParamList("n_experts") # noqa: F821 def finalize(self) -> torch.Tensor: - fused_params = [torch.cat([gate, weight], dim=0) for gate, weight in zip(self.gating_experts, self.up_experts)] - stacked_params = torch.stack(fused_params, dim=0) - return self.inference_model.transform_moe_mlp_2_param(stacked_params) + transposed_experts = [] + for gate, up in zip(self.gating_experts, self.up_experts): + assert gate.shape[0] == up.shape[0], "Gated MLP parameters must have the same number of neurons." + total_neurons = gate.shape[0] + up.shape[0] + fused_expert = torch.cat([gate, up], dim=-1).reshape(total_neurons, -1) + transposed_experts.append(fused_expert) + + stacked_experts = torch.stack(transposed_experts, dim=0) + return self.inference_model.transform_moe_mlp_1_param(stacked_experts) diff --git a/deepspeed/inference/v2/model_implementations/falcon/__init__.py b/deepspeed/inference/v2/model_implementations/falcon/__init__.py index ff66879b44be..20f37538274c 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/__init__.py +++ b/deepspeed/inference/v2/model_implementations/falcon/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .falcon_policy import FalconPolicy +from .policy import FalconPolicy diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py b/deepspeed/inference/v2/model_implementations/falcon/container.py similarity index 97% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py rename to deepspeed/inference/v2/model_implementations/falcon/container.py index f3cbe6609cdd..caccfe1ecb00 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_containers.py +++ b/deepspeed/inference/v2/model_implementations/falcon/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF Falcon 7b model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py b/deepspeed/inference/v2/model_implementations/falcon/model.py similarity index 98% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_model.py rename to deepspeed/inference/v2/model_implementations/falcon/model.py index a00f754744a4..d1ccc38280a0 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_model.py +++ b/deepspeed/inference/v2/model_implementations/falcon/model.py @@ -11,12 +11,12 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum -from ...model_implementations import * +from .. import * from ...modules.configs import * from ...modules.interfaces import * from ...ragged import RaggedBatchWrapper -from .falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNonTransformerContainer, FalconTransformerContainer class FalconInferenceModel(DSTransformerModelBase): diff --git a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py b/deepspeed/inference/v2/model_implementations/falcon/policy.py similarity index 74% rename from deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py rename to deepspeed/inference/v2/model_implementations/falcon/policy.py index 5672d45a8d13..c6612090a0df 100644 --- a/deepspeed/inference/v2/model_implementations/falcon/falcon_policy.py +++ b/deepspeed/inference/v2/model_implementations/falcon/policy.py @@ -6,10 +6,10 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.falcon.falcon_containers import FalconNonTransformerContainer, FalconTransformerContainer -from ...model_implementations.falcon.falcon_containers import FalconNewArchTransformerContainer -from ...model_implementations.falcon.falcon_model import FalconInferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import FalconNonTransformerContainer, FalconTransformerContainer +from .container import FalconNewArchTransformerContainer +from .model import FalconInferenceModel class FalconPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py index 8f6a0b7fa688..e78a161b4cd0 100644 --- a/deepspeed/inference/v2/model_implementations/inference_transformer_base.py +++ b/deepspeed/inference/v2/model_implementations/inference_transformer_base.py @@ -521,12 +521,26 @@ def transform_norm_param(self, param: torch.Tensor) -> InferenceParameter: class DSMoETransformerModelBase(DSTransformerModelBase): @property - def num_experts(self) -> int: + def n_experts(self) -> int: """ Return the number of experts in the model. """ raise NotImplementedError("Attempted to access an unimplemented number of experts") + @property + def n_top_k(self) -> int: + """ + Number of experts per token. + """ + raise NotImplementedError("Attempted to access an unimplemented number of experts per token") + + @property + def normalize_expert_scores(self) -> bool: + """ + Whether to normalize expert scores. If true, sum(expert_scores) = 1. + """ + raise NotImplementedError("Attempted to access an unimplemented normalization flag") + def make_moe_layer(self) -> None: """ Instantiates the MoE layer for the model. This sets the `self.moe` attribute. @@ -538,9 +552,11 @@ def make_moe_layer(self) -> None: model_dim=self.model_dim, intermediate_features=sharded_dim, activation=self.mlp_activation_fn, - n_experts=self.num_experts, + n_experts=self.n_experts, + top_k=self.n_top_k, input_dtype=self.activation_dtype, output_dtype=self.activation_dtype, + normalize_scores=self.normalize_expert_scores, ) self.moe = heuristics.instantiate_moe(moe_config, self._engine_config) diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py index 5d2b5ae562ee..79605a76a4c2 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/__init__.py @@ -3,4 +3,4 @@ # DeepSpeed Team -from .llama_v2_policy import Llama2Policy +from .policy import Llama2Policy diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py b/deepspeed/inference/v2/model_implementations/llama_v2/container.py similarity index 95% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py rename to deepspeed/inference/v2/model_implementations/llama_v2/container.py index e9c473ce512b..9de9bdb34574 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_containers.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF Llama model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py b/deepspeed/inference/v2/model_implementations/llama_v2/model.py similarity index 83% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py rename to deepspeed/inference/v2/model_implementations/llama_v2/model.py index 9b628f77de01..b91e3258caa0 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_model.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/model.py @@ -11,12 +11,13 @@ from ...allocator import empty_from from ...inference_utils import ActivationType, DtypeEnum -from ...model_implementations import * +from .. import * from ...modules.configs import * from ...modules.interfaces import * +from ...modules import heuristics from ...ragged import RaggedBatchWrapper -from .llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer class Llama2InferenceModel(DSTransformerModelBase): @@ -105,6 +106,27 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + """ Forward implementations """ @@ -145,8 +167,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) if self.tp_size > 1: diff --git a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py similarity index 76% rename from deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py rename to deepspeed/inference/v2/model_implementations/llama_v2/policy.py index c8253be79fad..bb13ab6d5bf4 100644 --- a/deepspeed/inference/v2/model_implementations/llama_v2/llama_v2_policy.py +++ b/deepspeed/inference/v2/model_implementations/llama_v2/policy.py @@ -6,9 +6,9 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.llama_v2.llama_v2_containers import Llama2NonTransformerContainer, Llama2TransformerContainer -from ...model_implementations.llama_v2.llama_v2_model import Llama2InferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import Llama2NonTransformerContainer, Llama2TransformerContainer +from .model import Llama2InferenceModel class Llama2Policy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/mistral/model.py b/deepspeed/inference/v2/model_implementations/mistral/model.py index d9b06b91e308..08a9dae78e43 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/model.py +++ b/deepspeed/inference/v2/model_implementations/mistral/model.py @@ -14,6 +14,7 @@ from ...model_implementations import * from ...modules.configs import * from ...modules.interfaces import * +from ...modules import heuristics from ...ragged import RaggedBatchWrapper from .container import MistralNonTransformerContainer, MistralTransformerContainer @@ -104,6 +105,27 @@ def norm_type(self) -> NormTypeEnum: def positional_embedding_type(self) -> PositionalEmbeddingType: return PositionalEmbeddingType.rotate_half + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + """ Forward implementations """ @@ -144,8 +166,7 @@ def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_st kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=None) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=None) if self.tp_size > 1: diff --git a/deepspeed/inference/v2/model_implementations/mistral/policy.py b/deepspeed/inference/v2/model_implementations/mistral/policy.py index f6d0a0fe5987..b67ec311c952 100644 --- a/deepspeed/inference/v2/model_implementations/mistral/policy.py +++ b/deepspeed/inference/v2/model_implementations/mistral/policy.py @@ -5,10 +5,10 @@ from typing import Any -from deepspeed.inference.v2.config_v2 import RaggedInferenceEngineConfig -from deepspeed.inference.v2.model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from deepspeed.inference.v2.model_implementations.mistral.container import MistralNonTransformerContainer, MistralTransformerContainer -from deepspeed.inference.v2.model_implementations.mistral.model import MistralInferenceModel +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MistralNonTransformerContainer, MistralTransformerContainer +from .model import MistralInferenceModel class MistralPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/model_implementations/mixtral/__init__.py b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py new file mode 100644 index 000000000000..2cb1aa889291 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from .policy import MixtralPolicy diff --git a/deepspeed/inference/v2/model_implementations/mixtral/container.py b/deepspeed/inference/v2/model_implementations/mixtral/container.py new file mode 100644 index 000000000000..6ec4a0552b8f --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/container.py @@ -0,0 +1,46 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +# Create a container object to save model-specific tensors using the policy file above. + +from deepspeed.inference.v2.model_implementations.common_parameters import * +from deepspeed.inference.v2.model_implementations.layer_container_base import LayerContainer + + +class MixtralTransformerContainer(LayerContainer): + + qkv_w: UnfusedQKVParameter + attn_out_w: AttentionOutputParameter + moe_gate: MoEGatingWeightParameter + moe_mlp_1: UnfusedMoEGatedMLPParameter + moe_mlp_2: UnfusedMoEMLP2Parameter + attn_norm_gamma: NormParameter + mlp_norm_gamma: NormParameter + + PARAM_MAPPING = { + "input_layernorm.weight": "attn_norm_gamma.params", + "post_attention_layernorm.weight": "mlp_norm_gamma.params", + "self_attn.q_proj.weight": "qkv_w.q_params", + "self_attn.k_proj.weight": "qkv_w.k_params", + "self_attn.v_proj.weight": "qkv_w.v_params", + "self_attn.o_proj.weight": "attn_out_w.params", + "block_sparse_moe.gate.weight": "moe_gate.params", + "block_sparse_moe.experts.*.w1.weight": "moe_mlp_1.gating_experts", + "block_sparse_moe.experts.*.w3.weight": "moe_mlp_1.up_experts", + "block_sparse_moe.experts.*.w2.weight": "moe_mlp_2.experts", + } + + +class MixtralNonTransformerContainer(LayerContainer): + + word_emb: EmbeddingParameter + word_unembed: UnembedParameter + final_norm: NormParameter + + PARAM_MAPPING = { + "model.embed_tokens.weight": "word_emb.params", + "lm_head.weight": "word_unembed.params", + "model.norm.weight": "final_norm.params", + } diff --git a/deepspeed/inference/v2/model_implementations/mixtral/model.py b/deepspeed/inference/v2/model_implementations/mixtral/model.py new file mode 100644 index 000000000000..731a907716f4 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/model.py @@ -0,0 +1,274 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Iterable, Optional, Tuple + +import torch + +import deepspeed.comm as dist + +from ...allocator import empty_from +from ...config_v2 import RaggedInferenceEngineConfig +from ...inference_utils import ActivationType, DtypeEnum +from ...model_implementations import * +from ...modules.configs import * +from ...modules.interfaces import * +from ...modules import heuristics +from ...ragged import RaggedBatchWrapper +from ..inference_model_base import ( + DSModelImplementationConfig, + MPType, +) + +from .container import MixtralNonTransformerContainer, MixtralTransformerContainer + + +class MixtralInferenceModel(DSMoETransformerModelBase): + """ + Inference model implementation for Mixtral models. + """ + + _non_transformer: Optional[MixtralNonTransformerContainer] + """ + Embed + unembed container. Specializing the type annotation. + """ + + _transformer: Optional[Iterable[MixtralTransformerContainer]] + """ + Per-layer transformer container. Specializing the type annotation. + """ + """ + Properties ineherited from `DSInferenceModelBase` + """ + + @property + def max_sequence_length(self) -> int: + return self._config.max_position_embeddings + + """ + Properties ineherited from `DSTransformerModelBase` + """ + + @property + def num_layers(self) -> int: + return self._config.num_hidden_layers + + @property + def model_dim(self) -> int: + return self._config.hidden_size + + @property + def vocab_size(self) -> int: + return self._config.vocab_size + + @property + def head_size(self) -> int: + return self.model_dim // self.n_heads + + @property + def n_heads(self) -> int: + return self._config.num_attention_heads + + @property + def intermediate_dim(self) -> int: + return self._config.intermediate_size + + @property + def n_heads_kv(self) -> int: + return self._config.num_key_value_heads + + @property + def activation_dtype(self) -> DtypeEnum: + if self._config.torch_dtype == torch.float16: + return DtypeEnum.fp16 + elif self._config.torch_dtype == torch.bfloat16: + return DtypeEnum.bf16 + else: + raise NotImplementedError("Only fp16 and bf16 are supported") + + @property + def mlp_activation_fn(self) -> ActivationType: + activation = self._config.hidden_act.lower() + if activation == "gelu": + return ActivationType.GEGLU + elif activation == "relu": + return ActivationType.ReGLU + elif activation == "gegelu": + return ActivationType.GEGLU + elif activation == "silu": + return ActivationType.SiGLU + else: + raise NotImplementedError(f"Activation {activation} not supported") + + @property + def norm_type(self) -> NormTypeEnum: + return NormTypeEnum.RMSNorm + + @property + def positional_embedding_type(self) -> PositionalEmbeddingType: + return PositionalEmbeddingType.rotate_half + + """ + Inherited from `DSMoETransformerModelBase` + """ + + @property + def n_experts(self) -> int: + return self._config.num_local_experts + + @property + def n_top_k(self) -> int: + return self._config.num_experts_per_tok + + @property + def normalize_expert_scores(self) -> bool: + return True + + """ + Model implementation + """ + + def __init__(self, config: DSModelImplementationConfig, engine_config: RaggedInferenceEngineConfig, + base_mp_group: MPType) -> None: + """ + Base implementation for initialization. By default, this will initialize + the traditional components of a transformer model: + - Embedding + - QKV projection + - Self attention + - Attention output projection + - Feed forward network + - Normalization + - Unembedding + + Arguments: + config (DSModelImplementationConfig): Model-specific configuration. No assumptions + should be made about this config that are not closely tied to the specific + model implementation. + engine_config (RaggedInferenceEngineConfig): Engine configuration. + base_mp_group (MPType): Base communication group for Tensor-parallel inference. + """ + super().__init__(config, engine_config, base_mp_group) + + self.make_norm_layer() + self.make_qkv_layer() + self.make_attn_layer() + self.make_attn_out_layer() + self.make_moe_layer() + self.make_embedding_layer() + self.make_unembedding_layer() + self._kv_cache_config = None + + def make_attn_layer(self) -> None: + """ + Builds the attention layer for the model. This sets the `self.attn` attribute. + """ + softmax_scale = 1.0 / (self.head_size**0.5) + + rotary_config = RotateHalfConfig(theta_base=self._config.rope_theta) + + attn_config = DSSelfAttentionConfig(max_tokens=self._engine_config.state_manager.max_ragged_batch_size, + n_heads_q=self.n_heads_q_local, + n_heads_kv=self.n_heads_kv_local, + head_size=self.head_size, + max_sequences=self._engine_config.state_manager.max_ragged_sequence_count, + scale_factor=softmax_scale, + input_dtype=self.activation_dtype, + output_dtype=self.activation_dtype, + positional_embedding_type=self.positional_embedding_type, + positional_embedding_config=rotary_config) + + self.attn = heuristics.instantiate_attention(attn_config, self._engine_config) + + def _forward_embed(self, ragged_batch: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs the embedding lookup prior to running the transformer of the model. + + Arguments: + ragged_batch (RaggedBatchWrapper): The batch to embed. + + Returns: + torch.Tensor: The embedded batch. + """ + embed = self.embed(ragged_batch, self._non_transformer.word_emb) + + if embed.shape[-1] != self.model_dim: + raise ValueError(f"Embedding output shape {embed.shape} does not match model_dim {self.model_dim}") + + return embed + + def _forward_transformer(self, layer_idx: int, residual: torch.Tensor, hidden_states: torch.Tensor, + ragged_batch_info: RaggedBatchWrapper) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Executes one (slightly offset) layer of the transformer. This implementation does a peak-ahead + optimization to fuse the layer norm of the next layer into the current layer. + + Arguments: + layer_idx (int): The index of the layer to execute. + residual (torch.Tensor): The residual tensor from the previous layer. + hidden_states (torch.Tensor): The hidden states from the previous layer. This is the + hidden states after pre normalization. + ragged_batch_info (RaggedBatchWrapper): The batch metadata. + """ + # TODO(cmikeh2): Distribute ragged_batch_info to all modules + + cur_params = self._transformer[layer_idx] + kv_cache = self.state_manager.get_cache(layer_idx) + + hidden_states = self.qkv(hidden_states, cur_params.qkv_w) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) + hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + residual, hidden_states = self.norm(residual, hidden_states, cur_params.mlp_norm_gamma) + + hidden_states = self.moe(hidden_states, ragged_batch_info, cur_params.moe_gate, cur_params.moe_mlp_1, + cur_params.moe_mlp_2) + + if self.tp_size > 1: + dist.all_reduce(hidden_states, group=self._base_mp_group) + + if layer_idx != self.num_layers - 1: + next_params = self._transformer[layer_idx + 1] + residual, hidden_states = self.norm(residual, hidden_states, next_params.attn_norm_gamma) + else: + # On last layer, we just need to perform the residual add. Adding into the residual + # here is safe. + residual.add_(hidden_states) + + return residual, hidden_states + + def _forward_unembed(self, hidden_states: torch.Tensor, ragged_batch_info: RaggedBatchWrapper) -> torch.Tensor: + """ + Performs unembedding of the hidden states to logits. This will only sample the final + token of each sequence. + """ + logits = self.unembed(hidden_states, self._non_transformer.word_unembed, ragged_batch_info, + self._non_transformer.final_norm) + + if self.tp_size > 1: + comm_buffer = empty_from(self._comm_logits, (self.tp_size, logits.shape[0], logits.shape[1])) + full_logits = empty_from(self._return_logits, (logits.shape[0], self.vocab_size)) + + dist.all_gather_into_tensor(comm_buffer, logits, group=self._base_mp_group) + + full_logits.copy_(comm_buffer.permute(1, 0, 2).reshape(logits.shape[0], self.vocab_size)) + + return full_logits + else: + return logits + + def forward(self, wrapped_batch: RaggedBatchWrapper) -> torch.Tensor: + + residual = self._forward_embed(wrapped_batch) + + residual, hidden_states = self.norm(residual, None, self._transformer[0].attn_norm_gamma, beta=None) + + for layer_idx in range(self.num_layers): + residual, hidden_states = self._forward_transformer(layer_idx, residual, hidden_states, wrapped_batch) + + return self._forward_unembed(residual, wrapped_batch) diff --git a/deepspeed/inference/v2/model_implementations/mixtral/policy.py b/deepspeed/inference/v2/model_implementations/mixtral/policy.py new file mode 100644 index 000000000000..2f0087919720 --- /dev/null +++ b/deepspeed/inference/v2/model_implementations/mixtral/policy.py @@ -0,0 +1,31 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +from typing import Any + +from ...config_v2 import RaggedInferenceEngineConfig +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import MixtralTransformerContainer, MixtralNonTransformerContainer +from .model import MixtralInferenceModel + + +class MixtralPolicy(InferenceV2Policy): + + def instantiate_model(self, engine_config: RaggedInferenceEngineConfig, mp_group: Any) -> MixtralInferenceModel: + return MixtralInferenceModel(config=self._model_config, engine_config=engine_config, base_mp_group=mp_group) + + def build_container_map(self) -> ContainerMap: + + map = ContainerMap() + + transformer_containers = [MixtralTransformerContainer(self.model) for _ in range(self.model.num_layers)] + + map.set_transformer_params(['model.layers'], transformer_containers) + + map.set_non_transformer_params(MixtralNonTransformerContainer(self.model)) + + map.set_unmapped_params([]) + + return map diff --git a/deepspeed/inference/v2/model_implementations/opt/container.py b/deepspeed/inference/v2/model_implementations/opt/container.py index 5ddbbde3f141..e97599ef8e50 100644 --- a/deepspeed/inference/v2/model_implementations/opt/container.py +++ b/deepspeed/inference/v2/model_implementations/opt/container.py @@ -5,8 +5,8 @@ # Create a container object to save model-specific tensors using the policy file above. -from ...model_implementations.common_parameters import * -from ...model_implementations.layer_container_base import LayerContainer +from ..common_parameters import * +from ..layer_container_base import LayerContainer ''' # HF OPT model looks like this: diff --git a/deepspeed/inference/v2/model_implementations/opt/model.py b/deepspeed/inference/v2/model_implementations/opt/model.py index fa221e15a0b7..8bd26ba044e5 100644 --- a/deepspeed/inference/v2/model_implementations/opt/model.py +++ b/deepspeed/inference/v2/model_implementations/opt/model.py @@ -131,8 +131,7 @@ def _forward_transformer_layer(self, layer_idx: int, residual: torch.Tensor, hid kv_cache = self.state_manager.get_cache(layer_idx) hidden_states = self.qkv(hidden_states, cur_params.qkv_w, b=cur_params.qkv_b) - hidden_states = self.attn(hidden_states, kv_cache, - ragged_batch_info) #, inv_freqs=None) #cur_params.rotary_emb) + hidden_states = self.attn(hidden_states, kv_cache, ragged_batch_info) hidden_states = self.attn_out(hidden_states, cur_params.attn_out_w, b=cur_params.attn_out_b) if self.tp_size > 1: diff --git a/deepspeed/inference/v2/model_implementations/opt/policy.py b/deepspeed/inference/v2/model_implementations/opt/policy.py index af5750260ead..d57d5beb48d5 100644 --- a/deepspeed/inference/v2/model_implementations/opt/policy.py +++ b/deepspeed/inference/v2/model_implementations/opt/policy.py @@ -6,9 +6,9 @@ from typing import Any from ...config_v2 import RaggedInferenceEngineConfig -from ...model_implementations.inference_policy_base import ContainerMap, InferenceV2Policy -from ...model_implementations.opt.container import OPTNonTransformerContainer, OPTTransformerContainer -from ...model_implementations.opt.model import OPTInferenceModel +from ..inference_policy_base import ContainerMap, InferenceV2Policy +from .container import OPTNonTransformerContainer, OPTTransformerContainer +from .model import OPTInferenceModel class OPTPolicy(InferenceV2Policy): diff --git a/deepspeed/inference/v2/modules/configs/__init__.py b/deepspeed/inference/v2/modules/configs/__init__.py index 19b9fb99ddea..3429e69b47de 100644 --- a/deepspeed/inference/v2/modules/configs/__init__.py +++ b/deepspeed/inference/v2/modules/configs/__init__.py @@ -3,7 +3,12 @@ # DeepSpeed Team -from .attention_configs import (DSSelfAttentionConfig, PositionalEmbeddingType, MaskingType) +from .attention_configs import ( + DSSelfAttentionConfig, + PositionalEmbeddingType, + MaskingType, + RotateHalfConfig, +) from .embedding_config import DSEmbeddingsConfig from .linear_config import DSLinearConfig from .moe_config import DSMoEConfig diff --git a/deepspeed/inference/v2/modules/configs/attention_configs.py b/deepspeed/inference/v2/modules/configs/attention_configs.py index bcdc3d2613d5..823104b13fc2 100644 --- a/deepspeed/inference/v2/modules/configs/attention_configs.py +++ b/deepspeed/inference/v2/modules/configs/attention_configs.py @@ -4,10 +4,11 @@ # DeepSpeed Team from enum import Enum -from typing import Dict +from typing import Dict, Optional from ...inference_utils import DtypeEnum from ...modules.ds_module import DSModuleConfig +from deepspeed.runtime.config_utils import DeepSpeedConfigModel class PositionalEmbeddingType(Enum): @@ -25,6 +26,20 @@ class PositionalEmbeddingType(Enum): alibi = "alibi" +class RotateHalfConfig(DeepSpeedConfigModel): + + use_trained_freqs: bool = False + """ + Whether to use a passed `trained_freqs` tensor for the attention implementation + or to use default synthesized frequencies. + """ + + theta_base: float = 10_000.0 + """ + Base for theta. This will only be used if `use_trained_freqs` is False. + """ + + class MaskingType(Enum): # No masking @@ -79,4 +94,9 @@ class DSSelfAttentionConfig(DSModuleConfig): positional_embedding_type: PositionalEmbeddingType = PositionalEmbeddingType.none # Positional embedding args - positional_embedding_args: Dict = {} + positional_embedding_config: Optional[RotateHalfConfig] = None + """ + To extend this for the other positional embedding types, we would need to add + new configs for each type (as necessary) and annotate this with the + Union[RotateHalfConfig, OtherConfig, ...] type. + """ diff --git a/deepspeed/inference/v2/modules/configs/moe_config.py b/deepspeed/inference/v2/modules/configs/moe_config.py index 1a88d54af19f..7bc944f55e17 100644 --- a/deepspeed/inference/v2/modules/configs/moe_config.py +++ b/deepspeed/inference/v2/modules/configs/moe_config.py @@ -48,3 +48,9 @@ class DSMoEConfig(DSModuleConfig): """ Activation function of the first MLP1 """ + + normalize_scores: bool = False + """ + Whether normalization is applied to the selected scores. If true, the module + should rescale the scores such that their sum is 1.0. + """ diff --git a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py index bb482f0c58d6..b2727ffca620 100644 --- a/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py +++ b/deepspeed/inference/v2/modules/implementations/attention/dense_blocked_attention.py @@ -68,9 +68,16 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st Args: config (DSSelfAttentionConfig): The self attention config for all attention DSModules. - implementation_config (Dict[str, Any]): The implementation config for this DSModule may - contain a `trained_freqs` key. If passed, the implementation will expect a `trained_freqs` - tensor in the `forward` method and will not synthesize the frequencies internally. + implementation_config (Dict[str, Any]): + There are two (dependent) potential components in the implementtion config. + + 1. `trained_freqs` - If the embedding weights for RoPE are trained, the implementation + config should contain {'trained_freqs': True}. This will mean the implementation will + expect a `trained_freqs` tensor in the `forward` method and will not synthesize the + values internally. + + 2. `theta_base` - The base value for synthesized frequencies in the rotary embeddings. + This will only be used if `trained_freqs` is False or not present in the `implementation_config`. If this is not included, the default value of 10000.0 will be used. """ super().__init__(config, implementation_config) @@ -79,14 +86,13 @@ def __init__(self, config: DSSelfAttentionConfig, implementation_config: Dict[st self._kv_copy = LinearBlockedKVCopy(self._config.head_size, self._config.n_heads_q, self._config.n_heads_kv, self._config.input_dtype) elif embed_type == PositionalEmbeddingType.rotate_half: - use_trained_freqs = "trained_freqs" in self._config.positional_embedding_args and self._config.positional_embedding_args[ - "trained_freqs"] - if use_trained_freqs: + if config.positional_embedding_config.use_trained_freqs: self._kv_copy = BlockedTrainedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, self._config.n_heads_kv, self._config.input_dtype) else: + theta_base = config.positional_embedding_config.theta_base self._kv_copy = BlockedRotaryEmbeddings(self._config.head_size, self._config.n_heads_q, - self._config.n_heads_kv, self._config.input_dtype) + self._config.n_heads_kv, self._config.input_dtype, theta_base) self._softmax_scale = self._config.scale_factor diff --git a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py index e43a737515ed..38c0000d7f78 100644 --- a/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py +++ b/deepspeed/inference/v2/modules/implementations/moe/cutlass_multi_gemm.py @@ -9,12 +9,12 @@ from deepspeed.accelerator import get_accelerator from ....allocator import empty_from -from ....inference_utils import ActivationType -from ....kernels.core_ops import BlasLibLinear +from ....inference_utils import ActivationType, is_gated +from ....kernels.core_ops import BlasLibLinear, CUDAGatedActivation from ....kernels.ragged_ops import ( MoEGather, MoEScatter, - RaggedTop1Gating, + RaggedTopKGating, ) from ....ragged import RaggedBatchWrapper @@ -42,11 +42,7 @@ def supports_config(config: DSMoEConfig) -> bool: if config.input_dtype != torch.float16 and config.input_dtype != torch.bfloat16: return False - if config.top_k != 1: - return False - - if config.activation in [ActivationType.GEGLU, ActivationType.ReGLU, ActivationType.SiGLU]: - # Currently not supporting gated activations in MoE + if config.top_k != 1 and config.top_k != 2: return False return True @@ -57,15 +53,24 @@ def __init__(self, config: DSMoEConfig, implementation_config: Dict[str, Any]) - # Convenience variables for frequently accessed items. self.max_tokens = self._config.max_tokens self.n_experts = self._config.n_experts + self.n_top_k = self._config.top_k self.intermediate_dim = self._config.intermediate_features - self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=config.activation) + moe_op_act_fn = ActivationType.IDENTITY if is_gated(self._config.activation) else self._config.activation + + self._mlp_1 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=moe_op_act_fn) self._mlp_2 = MoEGEMM(fp_dtype=implementation_config['weight_dtype'], act_fn=ActivationType.IDENTITY) + if is_gated(self._config.activation): + self._activation = CUDAGatedActivation(self._config.model_dim, self._config.input_dtype, + self._config.activation) + else: + self._activation = None + self._gate_proj = BlasLibLinear(self._config.input_dtype) - self._top_1_gate = RaggedTop1Gating(config.input_dtype) + self._top_1_gate = RaggedTopKGating(config.input_dtype) self._moe_scatter = MoEScatter(config.input_dtype, config.model_dim) - self._moe_gather = MoEGather(config.input_dtype, config.model_dim) + self._moe_gather = MoEGather(config.input_dtype, config.model_dim, config.normalize_scores) self._create_buffers() @@ -78,32 +83,38 @@ def _create_buffers(self): self._expert_counts = torch.empty((self.n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - self._scores = torch.empty((self._config.max_tokens, ), + self._scores = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) - self._assignments = torch.empty((self._config.max_tokens, ), + self._assignments = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - self._offsets = torch.empty((self._config.max_tokens, ), + self._offsets = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) # Scatter buffers - self._moe_input = torch.empty((self._config.max_tokens, self._config.model_dim), + self._moe_input = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), dtype=self._config.input_dtype, device=get_accelerator().current_device()) self._expert_cumsum = torch.empty((self._config.n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - self._mapped_slots = torch.empty((self._config.max_tokens, ), + self._mapped_slots = torch.empty((self._config.max_tokens, self.n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) # GEMM Buffers - self._intermediate = torch.empty((self._config.max_tokens, self._config.intermediate_features), + self._intermediate = torch.empty((self._config.max_tokens * self.n_top_k, self._config.intermediate_features), dtype=self._config.output_dtype, device=get_accelerator().current_device()) - self._output_unordered = torch.empty((self._config.max_tokens, self._config.model_dim), + if self._activation is not None: + self._gated_intermediate = torch.empty( + (self._config.max_tokens * self.n_top_k, self._config.intermediate_features * 2), + dtype=self._config.output_dtype, + device=get_accelerator().current_device()) + + self._output_unordered = torch.empty((self._config.max_tokens * self.n_top_k, self._config.model_dim), dtype=self._config.output_dtype, device=get_accelerator().current_device()) @@ -167,11 +178,11 @@ def _gate(self, hidden_states: torch.Tensor, batch_metadata: RaggedBatchWrapper, # Get views on the buffers for gating logits = empty_from(self._logits, (hidden_states.shape[0], self._logits.shape[-1])) - scores = empty_from(self._scores, (hidden_states.shape[0], )) - assignments = empty_from(self._assignments, (hidden_states.shape[0], )) - offsets = empty_from(self._offsets, (hidden_states.shape[0], )) - mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], )) - moe_input = empty_from(self._moe_input, (hidden_states.shape[0], self._moe_input.shape[-1])) + scores = empty_from(self._scores, (hidden_states.shape[0], self.n_top_k)) + assignments = empty_from(self._assignments, (hidden_states.shape[0], self.n_top_k)) + offsets = empty_from(self._offsets, (hidden_states.shape[0], self.n_top_k)) + mapped_slots = empty_from(self._mapped_slots, (hidden_states.shape[0], self.n_top_k)) + moe_input = empty_from(self._moe_input, (hidden_states.shape[0] * self.n_top_k, self._moe_input.shape[-1])) self._gate_proj(logits, hidden_states, gate_w) self._expert_counts.zero_() @@ -200,18 +211,31 @@ def forward(self, moe_input, expert_cumsum, scores, mapped_slots = self._gate(hidden_states, batch_metadata, gate_w) # Get views on the buffers for GEMM - intermediate = empty_from(self._intermediate, (hidden_states.shape[0], self._intermediate.shape[-1])) + intermediate = empty_from(self._intermediate, + (hidden_states.shape[0] * self.n_top_k, self._intermediate.shape[-1])) output_unordered = empty_from(self._output_unordered, - (hidden_states.shape[0], self._output_unordered.shape[-1])) + (hidden_states.shape[0] * self.n_top_k, self._output_unordered.shape[-1])) output = empty_from(self._output, (hidden_states.shape[0], self._output.shape[-1])) - self._mlp_1( - intermediate, - moe_input, - mlp_1_w, - expert_cumsum, - mlp_1_b, - ) + if self._activation is not None: + gated_intermediate = empty_from( + self._gated_intermediate, (hidden_states.shape[0] * self.n_top_k, self._gated_intermediate.shape[-1])) + self._mlp_1( + gated_intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) + self._activation(intermediate, gated_intermediate) + else: + self._mlp_1( + intermediate, + moe_input, + mlp_1_w, + expert_cumsum, + mlp_1_b, + ) self._mlp_2( output_unordered, diff --git a/op_builder/ragged_ops.py b/op_builder/ragged_ops.py index 13d71b476b5a..8cb372e96c37 100644 --- a/op_builder/ragged_ops.py +++ b/op_builder/ragged_ops.py @@ -73,8 +73,8 @@ def sources(self): "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cpp", "inference/v2/kernels/ragged_ops/moe_gather/moe_gather.cu", "inference/v2/kernels/ragged_ops/ragged_helpers/ragged_kernel_helpers.cpp", - "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cpp", - "inference/v2/kernels/ragged_ops/top_1_gating/top_1_gating.cu", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cpp", + "inference/v2/kernels/ragged_ops/top_k_gating/top_k_gating.cu", ] prefix = self.get_prefix() @@ -101,12 +101,13 @@ def include_paths(self): 'inference/v2/kernels/ragged_ops/atom_builder', 'inference/v2/kernels/ragged_ops/blocked_flash', 'inference/v2/kernels/ragged_ops/embed', + 'inference/v2/kernels/ragged_ops/includes', 'inference/v2/kernels/ragged_ops/linear_blocked_kv_rotary', 'inference/v2/kernels/ragged_ops/logits_gather', 'inference/v2/kernels/ragged_ops/moe_gather', 'inference/v2/kernels/ragged_ops/moe_scatter', 'inference/v2/kernels/ragged_ops/ragged_helpers', - 'inference/v2/kernels/ragged_ops/top_1_gating', + 'inference/v2/kernels/ragged_ops/top_k_gating', ] prefix = self.get_prefix() diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py index 5fa375b49c19..3907fc3e3a4b 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_gather.py @@ -11,18 +11,28 @@ from deepspeed.inference.v2.kernels.ragged_ops import ( MoEGather, MoEScatter, - RaggedTop1Gating, + RaggedTopKGating, ) from .ragged_testing_utils import build_simple_batch """ -For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` and +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` and ``MoEScatter`` to produce correct inputs. If either of these kernels is broken these tests will fail, so double check the unit test results there before debugging here. """ +TEST_CASES = [ + # (n_tokens, n_experts, n_top_k) + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] -def build_inputs(n_tokens, n_experts, do_padding): + +def build_inputs(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): assert n_tokens <= 2048, "This test will break if n_tokens > 2048" @@ -39,22 +49,28 @@ def build_inputs(n_tokens, n_experts, do_padding): device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( batch.tensor_toks, 4096).contiguous() - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) # Gating outputs expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((batch.tensor_toks, ), + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) # Scatter outputs - moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) scatter = MoEScatter(DtypeEnum.fp16, 4096) scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) @@ -63,11 +79,12 @@ def build_inputs(n_tokens, n_experts, do_padding): @pytest.mark.inference_v2_ops -@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) -@pytest.mark.parametrize("do_padding", [True, False]) -def test_moe_gather(n_tokens, n_experts, do_padding): +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CASES) +@pytest.mark.parametrize("do_padding", [False]) +def test_moe_gather(n_tokens: int, n_experts: int, n_top_k: int, do_padding: bool): + get_accelerator().manual_seed(0xC0FFEE) - batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, do_padding) + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) @@ -75,9 +92,31 @@ def test_moe_gather(n_tokens, n_experts, do_padding): gather(output, moe_input, scores, mapped_slots, expert_counts) for token_idx in range(n_tokens): + effective_score = scores[token_idx].sum().item() assert torch.equal( output[token_idx], torch.full((4096, ), - token_idx * scores[token_idx], + token_idx * effective_score, dtype=torch.float16, device=get_accelerator().current_device())) + + +@pytest.mark.inference_v2_ops +def test_moe_gather_normalize_scales(): + get_accelerator().manual_seed(0xC0FFEE) + + n_tokens = 72 + n_experts = 8 + n_top_k = 2 + do_padding = False + + batch, moe_input, scores, mapped_slots, expert_counts = build_inputs(n_tokens, n_experts, n_top_k, do_padding) + output = torch.randn((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + + gather = MoEGather(DtypeEnum.fp16, 4096, normalize_scores=True) + gather(output, moe_input, scores, mapped_slots, expert_counts) + + for token_idx in range(n_tokens): + assert torch.equal( + output[token_idx], + torch.full((4096, ), token_idx, dtype=torch.float16, device=get_accelerator().current_device())) diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py index 4ca051410c1c..aae459f06a6f 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_moe_scatter.py @@ -8,19 +8,28 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import DtypeEnum -from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTop1Gating +from deepspeed.inference.v2.kernels.ragged_ops import MoEScatter, RaggedTopKGating from .ragged_testing_utils import build_simple_batch """ -For simplicity's sake, these tests do rely on ``RaggedTop1Gating`` to produce correct -inputs. If ``RaggedTop1Gating`` is broken, these tests will fail, so double check +For simplicity's sake, these tests do rely on ``RaggedTopKGating`` to produce correct +inputs. If ``RaggedTopKGating`` is broken, these tests will fail, so double check the unit test results there before debugging here. """ +TEST_CONFIGS = [ + (13, 64, 1), + (278, 64, 1), + (1977, 64, 1), + (13, 8, 2), + (278, 8, 2), + (1977, 8, 2), +] + @pytest.mark.inference_v2_ops -@pytest.mark.parametrize("n_tokens, n_experts", [(13, 64), (278, 64), (1977, 64)]) -@pytest.mark.parametrize("do_padding", [True, False]) -def test_moe_scatter(n_tokens, n_experts, do_padding): +@pytest.mark.parametrize("n_tokens, n_experts, n_top_k", TEST_CONFIGS) +@pytest.mark.parametrize("do_padding", [False, True]) +def test_moe_scatter(n_tokens, n_experts, n_top_k, do_padding): # Sequence composition shouldn't matter here batch = build_simple_batch([n_tokens], padding=do_padding) @@ -35,40 +44,52 @@ def test_moe_scatter(n_tokens, n_experts, do_padding): device=get_accelerator().current_device()).repeat_interleave(4096, dim=0).reshape( batch.tensor_toks, 4096).contiguous() - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) # Gating outputs expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((batch.tensor_toks, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((batch.tensor_toks, ), + scores = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((batch.tensor_toks, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) # Scatter outputs - moe_input = torch.empty((batch.tensor_toks, 4096), dtype=torch.float16, device=get_accelerator().current_device()) + moe_input = torch.empty((batch.tensor_toks * n_top_k, 4096), + dtype=torch.float16, + device=get_accelerator().current_device()) expert_cumsum = torch.empty((n_experts, ), dtype=torch.int64, device=get_accelerator().current_device()) - mapped_slots = torch.empty((batch.tensor_toks, ), dtype=torch.int32, device=get_accelerator().current_device()) + mapped_slots = torch.empty((batch.tensor_toks, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) scatter = MoEScatter(DtypeEnum.fp16, 4096) scatter(moe_input, expert_cumsum, mapped_slots, hidden_states, expert_counts, expert_assignment, expert_offset) + get_accelerator().synchronize() assert torch.equal(expert_cumsum, torch.cumsum(expert_counts, dim=0).to(torch.int64)) + if not do_padding: + assert torch.unique(mapped_slots).size(0) == n_top_k * n_tokens + for token_idx in range(batch.tensor_toks): if token_idx < n_tokens: - expert_idx = expert_assignment[token_idx].item() - if expert_idx == 0: - expert_cumsum_val = 0 - else: - expert_cumsum_val = expert_cumsum[expert_idx - 1] - offset = expert_offset[token_idx] - total_offset = offset + expert_cumsum_val - - assert total_offset == mapped_slots[token_idx].item() - assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) + for k in range(n_top_k): + expert_idx = expert_assignment[token_idx][k].item() + if expert_idx == 0: + expert_cumsum_val = 0 + else: + expert_cumsum_val = expert_cumsum[expert_idx - 1] + offset = expert_offset[token_idx][k] + total_offset = offset + expert_cumsum_val + + assert total_offset == mapped_slots[token_idx][k].item() + assert torch.equal(moe_input[total_offset], hidden_states[token_idx]) else: - assert mapped_slots[token_idx].item() == -1 + for k in range(n_top_k): + assert mapped_slots[token_idx][k].item() == -1 - assert expert_cumsum[-1] == n_tokens + assert expert_cumsum[-1] == n_tokens * n_top_k diff --git a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py similarity index 51% rename from tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py rename to tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py index 6ff2508bf320..5fa0c8a079f0 100644 --- a/tests/unit/inference/v2/kernels/ragged_ops/test_top_1_gating.py +++ b/tests/unit/inference/v2/kernels/ragged_ops/test_top_k_gating.py @@ -9,9 +9,52 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.inference_utils import DtypeEnum -from deepspeed.inference.v2.kernels.ragged_ops import RaggedTop1Gating +from deepspeed.inference.v2.kernels.ragged_ops import RaggedTopKGating from .ragged_testing_utils import build_simple_batch -from ....v2.inference_test_utils import allclose +from ...inference_test_utils import allclose + + +def _top_k_gating_testing_helper(n_tokens: int, n_experts: int, n_top_k: int, seed: int = 0xC0FFEE) -> None: + + torch.manual_seed(seed) + logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) + batch = build_simple_batch([n_tokens], padding=False) + gate = RaggedTopKGating(DtypeEnum.fp16) + + expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + + gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + + ref_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + ref_scores, ref_indices = torch.topk(ref_weights, n_top_k, dim=-1) + + assert allclose(scores, ref_scores), f"expected {ref_scores}, got {scores}" + assert torch.equal(expert_assignment, + ref_indices.to(torch.int32)), f"expected {ref_indices}, got {expert_assignment}" + assert expert_counts.sum( + ) == n_tokens * n_top_k, f"expected {n_tokens * n_top_k} tokens, got {expert_counts.sum()}" + + # Ensure that the expert offsets are unique + for i in range(n_experts): + expert_idxs = torch.where(expert_assignment == i, expert_offset, 0) + if expert_counts[i] > 0: + assert expert_idxs.unique().shape[0] == expert_counts[ + i], f"expected {expert_counts[i]} unique offsets, got {expert_idxs.unique().shape[0]}" + assert expert_idxs.max( + ) == expert_counts[i] - 1, f"expected max offset {expert_counts[i] - 1}, got {expert_idxs.max()}" + else: + # Should have all 0's so one unique value + assert expert_idxs.unique().shape[0] == 1 + assert expert_idxs.max() == 0 + + +@pytest.mark.inference_v2_ops +@pytest.mark.parametrize('n_tokens', [1, 17, 32, 89, 433]) +def test_top_2_e_8_gating(n_tokens: int) -> None: + _top_k_gating_testing_helper(n_tokens=n_tokens, n_experts=8, n_top_k=2) def _test_single_mapping_helper(n_tokens: int, @@ -19,6 +62,8 @@ def _test_single_mapping_helper(n_tokens: int, assigned_expert: int, logit_fill: float = 0.0, match_fill: float = 1.0) -> None: + + n_top_k = 1 logits = torch.full((n_tokens, n_experts), logit_fill, dtype=torch.float16, @@ -26,12 +71,12 @@ def _test_single_mapping_helper(n_tokens: int, logits[:, assigned_expert] = match_fill - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) @@ -39,7 +84,7 @@ def _test_single_mapping_helper(n_tokens: int, assert expert_counts[assigned_expert] == n_tokens assert torch.all(expert_assignment == assigned_expert) assert torch.unique(expert_offset).shape[0] == n_tokens - assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert]) + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, assigned_expert].reshape(-1, n_top_k)) @pytest.mark.inference_v2_ops @@ -72,6 +117,7 @@ def test_determinism(): n_tokens = 512 n_experts = 64 + n_top_k = 1 logits = torch.zeros((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) @@ -79,13 +125,15 @@ def test_determinism(): logits[:, 19] = 1.0 logits[:, 26] = 1.0 - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) for _ in range(1024): expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), + dtype=torch.int32, + device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) @@ -94,7 +142,7 @@ def test_determinism(): assert expert_counts[26] == 0 assert torch.all(expert_assignment == 19) assert torch.unique(expert_offset).shape[0] == n_tokens - assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19]) + assert allclose(scores, F.softmax(logits.float(), dim=1)[:, 19].reshape(-1, 1)) @pytest.mark.inference_v2_ops @@ -105,16 +153,19 @@ def test_score_accuracy(n_tokens: int, n_experts: int) -> None: """ logits = torch.randn((n_tokens, n_experts), dtype=torch.float16, device=get_accelerator().current_device()) batch = build_simple_batch([n_tokens], padding=False) + n_top_k = 1 - gate = RaggedTop1Gating(DtypeEnum.fp16) + gate = RaggedTopKGating(DtypeEnum.fp16) expert_counts = torch.zeros((n_experts, ), dtype=torch.int32, device=get_accelerator().current_device()) - scores = torch.empty((n_tokens, ), dtype=torch.float32, device=get_accelerator().current_device()) - expert_assignment = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) - expert_offset = torch.empty((n_tokens, ), dtype=torch.int32, device=get_accelerator().current_device()) + scores = torch.empty((n_tokens, n_top_k), dtype=torch.float32, device=get_accelerator().current_device()) + expert_assignment = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) + expert_offset = torch.empty((n_tokens, n_top_k), dtype=torch.int32, device=get_accelerator().current_device()) ref_scores = F.softmax(logits.float(), dim=1).max(dim=1).values + ref_scores = ref_scores.reshape(-1, 1) gate(expert_counts, scores, expert_assignment, expert_offset, logits, batch) + assert allclose(scores, ref_scores) assert expert_counts.sum() == n_tokens diff --git a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py index 260236562ee9..06ff9047d648 100644 --- a/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py +++ b/tests/unit/inference/v2/model_implementations/parameters/test_parameter_list.py @@ -26,7 +26,7 @@ def __init__(self, experts_per_rank: int) -> None: self._num_experts = experts_per_rank @property - def num_experts(self) -> int: + def n_experts(self) -> int: return self._num_experts @on_device diff --git a/tests/unit/inference/v2/modules/test_blocked_attn.py b/tests/unit/inference/v2/modules/test_blocked_attn.py index 215ad64636b1..6556aa460a44 100644 --- a/tests/unit/inference/v2/modules/test_blocked_attn.py +++ b/tests/unit/inference/v2/modules/test_blocked_attn.py @@ -12,7 +12,7 @@ from deepspeed.accelerator import get_accelerator from deepspeed.inference.v2.modules import ConfigBundle -from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType +from deepspeed.inference.v2.modules.configs import DSSelfAttentionConfig, PositionalEmbeddingType, RotateHalfConfig from deepspeed.inference.v2.modules.interfaces import DSSelfAttentionRegistry, DSSelfAttentionBase from ..kernels.ragged_ops.ragged_testing_utils import build_batch_and_manager @@ -37,13 +37,10 @@ def _blocked_flash_testing_helper(head_size: int, """ if trained_freqs is None: embed_type = PositionalEmbeddingType.none - embed_args = {} + embed_args = None else: embed_type = PositionalEmbeddingType.rotate_half - if trained_freqs: - embed_args = {'trained_freqs': True} - else: - embed_args = {'trained_freqs': False} + embed_args = RotateHalfConfig(use_trained_freqs=trained_freqs) attn_config = DSSelfAttentionConfig(max_tokens=2048, n_heads_q=n_heads_q, @@ -51,7 +48,7 @@ def _blocked_flash_testing_helper(head_size: int, head_size=head_size, max_sequences=32, positional_embedding_type=embed_type, - positional_embedding_args=embed_args) + positional_embedding_config=embed_args) config = ConfigBundle(name='dense_blocked_attention', config=attn_config) attn_module: DSSelfAttentionBase = DSSelfAttentionRegistry.instantiate_config(config) diff --git a/tests/unit/inference/v2/modules/test_cutlass_moe.py b/tests/unit/inference/v2/modules/test_cutlass_moe.py index e21170c9ed8f..b14ba127c6be 100644 --- a/tests/unit/inference/v2/modules/test_cutlass_moe.py +++ b/tests/unit/inference/v2/modules/test_cutlass_moe.py @@ -212,3 +212,117 @@ def test_in_out_channels(in_channels: int, out_channels: int) -> None: dtype=DtypeEnum.fp16, activation_type=ActivationType.IDENTITY, use_bias=True) + + +def _mixtral_moe_baseline(hidden_states: torch.Tensor, + gate_weight: torch.Tensor, + mlp_w1: torch.Tensor, + mlp_w2: torch.Tensor, + mlp_w3: torch.Tensor, + force_float: bool = False) -> torch.Tensor: + """ + Baseline implementation for mixtral MoE module. + + Based on transformers implementation: https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py + """ + output_dtype = hidden_states.dtype + if force_float: + hidden_states = hidden_states.float() + gate_weight = gate_weight.float() + mlp_w1 = mlp_w1.float() + mlp_w2 = mlp_w2.float() + mlp_w3 = mlp_w3.float() + + router_logits = torch.nn.functional.linear(hidden_states, gate_weight) + routing_weights = torch.nn.functional.softmax(router_logits, dim=-1, dtype=torch.float) + routing_weights, selected_experts = routing_weights.topk(k=2, dim=-1) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # NOTE(cmikeh2): This is a difference implementation, ours will preserve the original scale + # as float32 and perform in-kernel fused FP16->FP32->FP16 conversion. + routing_weights = routing_weights.to(hidden_states.dtype) + + final_hidden_states = torch.zeros_like(hidden_states) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=gate_weight.shape[0]).permute(2, 1, 0) + get_accelerator().synchronize() + + for expert_idx in range(gate_weight.shape[0]): + exp_mlp_w1 = mlp_w1[expert_idx] + exp_mlp_w2 = mlp_w2[expert_idx] + exp_mlp_w3 = mlp_w3[expert_idx] + + idx, top_x = torch.where(expert_mask[expert_idx]) + + if top_x.shape[0] == 0: + continue + + top_x_list = top_x.tolist() + idx_list = idx.tolist() + + current_state = hidden_states[top_x_list] + + linear = torch.nn.functional.linear + intermediate = torch.nn.functional.silu(linear(current_state, exp_mlp_w1)) * linear(current_state, exp_mlp_w3) + output = linear(intermediate, exp_mlp_w2) * routing_weights[top_x_list, idx_list].unsqueeze(-1) + final_hidden_states.index_add_(0, top_x, output.to(final_hidden_states.dtype)) + + return final_hidden_states.to(output_dtype) + + +@pytest.mark.inference_v2_ops +def test_mixtral_moe_config(): + + experts = 8 + n_top_k = 2 + in_channels = 4096 + intermediate_dim = 2048 + dtype = DtypeEnum.bf16 + + # Parameters + gate_weight = torch.randn( + (experts, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + mlp_w1 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w3 = torch.randn( + (experts, intermediate_dim, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + mlp_w2 = torch.randn( + (experts, in_channels, intermediate_dim), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + n_tokens = 256 + hidden_states = torch.randn( + (n_tokens, in_channels), dtype=dtype.value, device=get_accelerator().current_device()) * .1 + + baseline = _mixtral_moe_baseline(hidden_states, gate_weight, mlp_w1, mlp_w2, mlp_w3) + + mlp_w13_fused = torch.cat([mlp_w1, mlp_w3], dim=-1).reshape(experts, 2 * intermediate_dim, in_channels) + + config = DSMoEConfig(max_tokens=4096, + model_dim=in_channels, + intermediate_features=intermediate_dim, + n_experts=experts, + activation=ActivationType.SiGLU, + input_dtype=dtype, + output_dtype=dtype, + top_k=n_top_k, + normalize_scores=True) + + implementation_config = {"weight_dtype": DtypeEnum(dtype)} + + bundle = ConfigBundle(name='cutlass_multi_gemm_moe', config=config, implementation_config=implementation_config) + moe_module = DSMoERegistry.instantiate_config(bundle) + + batch = build_simple_batch([n_tokens]) + + gate_ds = moe_module.transform_gate_param(gate_weight) + mlp_w1_ds = moe_module.transform_moe_mlp_1_param(mlp_w13_fused) + mlp_w2_ds = moe_module.transform_moe_mlp_2_param(mlp_w2) + + output = moe_module(hidden_states, batch, gate_ds, mlp_w1_ds, mlp_w2_ds) + + # NOTE(cmikeh2): These are higher than the other tests for reasons that aren't quite + # clear to me. My best guess is that the SiGLU activation is causing larger numerical + # divergence. The thresholds chosen here is based on the observed error between the + # float and bfloat16 reference implementations. + assert allclose(output, baseline.to(dtype.value), tolerances=(5e-2, 5e-2))