From 8164ea9e6d35dfbbe62b0f15448daf05268b9e81 Mon Sep 17 00:00:00 2001 From: Reza Yazdani <44502768+RezaYazdaniAminabadi@users.noreply.github.com> Date: Tue, 24 May 2022 13:27:50 -0700 Subject: [PATCH] Fixing several bugs in the inference-api and the kernels (#1951) Co-authored-by: Jeff Rasley --- .github/workflows/amd.yml | 11 + .github/workflows/nv-torch12-p40.yml | 9 + .github/workflows/nv-torch18-v100.yml | 11 + .../inference/csrc/apply_rotary_pos_emb.cu | 8 +- csrc/transformer/inference/csrc/gelu.cu | 45 +- .../transformer/inference/csrc/pt_binding.cpp | 300 ++++++++- csrc/transformer/inference/csrc/transform.cu | 587 ++++++++++++++++++ .../inference/includes/custom_cuda_layers.h | 31 + deepspeed/module_inject/replace_module.py | 10 +- deepspeed/module_inject/replace_policy.py | 27 +- .../inference/transformer_inference.py | 146 +++-- op_builder/transformer_inference.py | 1 + requirements/requirements-dev.txt | 1 + tests/unit/test_inference.py | 34 + 14 files changed, 1087 insertions(+), 134 deletions(-) create mode 100644 csrc/transformer/inference/csrc/transform.cu create mode 100644 tests/unit/test_inference.py diff --git a/.github/workflows/amd.yml b/.github/workflows/amd.yml index 7925ee545acb..91e8825b83e1 100644 --- a/.github/workflows/amd.yml +++ b/.github/workflows/amd.yml @@ -37,12 +37,23 @@ jobs: python -c "import torch; print('CUDA available:', torch.cuda.is_available())" sudo apt-get update sudo apt-get install -y libaio-dev + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 1cc453d33 + git rev-parse --short HEAD + pip install . + # Runs a set of commands using the runners shell - name: Install deepspeed run: | sudo /opt/conda/bin/pip install .[dev,1bit,autotuning] #python -c "from deepspeed.env_report import cli_main; cli_main()" ds_report + # Runs a set of commands using the runners shell - name: Unit tests run: | diff --git a/.github/workflows/nv-torch12-p40.yml b/.github/workflows/nv-torch12-p40.yml index b9a5c4112194..080543df6980 100644 --- a/.github/workflows/nv-torch12-p40.yml +++ b/.github/workflows/nv-torch12-p40.yml @@ -32,6 +32,15 @@ jobs: python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 1cc453d33 + git rev-parse --short HEAD + pip install . + - name: Install deepspeed run: | pip install .[dev,autotuning] diff --git a/.github/workflows/nv-torch18-v100.yml b/.github/workflows/nv-torch18-v100.yml index da8caa54774e..0afac798119a 100644 --- a/.github/workflows/nv-torch18-v100.yml +++ b/.github/workflows/nv-torch18-v100.yml @@ -32,10 +32,21 @@ jobs: pip install torch==1.8.2+cu111 torchvision==0.9.2+cu111 -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html python -c "import torch; print('torch:', torch.__version__, torch)" python -c "import torch; print('CUDA available:', torch.cuda.is_available())" + + - name: Install transformers + run: | + git clone https://github.com/huggingface/transformers + cd transformers + # if needed switch to the last known good SHA until transformers@master is fixed + # git checkout 1cc453d33 + git rev-parse --short HEAD + pip install . + - name: Install deepspeed run: | pip install .[dev,1bit,autotuning,sparse_attn] ds_report + - name: Unit tests run: | unset TORCH_CUDA_ARCH_LIST # only jit compile for current arch diff --git a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu index 175854b8860b..8a34bb2017f1 100644 --- a/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu +++ b/csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu @@ -4,6 +4,7 @@ #include #endif +namespace cg = cooperative_groups; namespace cg = cooperative_groups; __global__ void apply_rotary_pos_emb(float* mixed_query, @@ -153,7 +154,9 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, int lane = id & 0x1f; unsigned head_id = blockIdx.x * MAX_WARP_NUM + gid; + unsigned seq_index = head_id % seq_len; unsigned offset = head_id * head_size; + unsigned k_offset = (seq_index + (head_id / seq_len) * MAX_OUT_TOKES) * head_size; constexpr unsigned mask[32] = { 0x1 | 0x1000, 0x2 | 0x2000, 0x4 | 0x4000, 0x8 | 0x8000, 0x10 | 0x10000, @@ -171,7 +174,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, float inv_freq = (float)((lane % half_dim) * 2) / (float)rotary_dim; inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; float q = (float)mixed_query[offset + lane]; - float k = (float)key_layer[offset + lane]; + float k = (float)key_layer[k_offset + lane]; float rotary_sign = (lane > (half_dim - 1) ? -1.0 : 1.0); float q_rot = (q * rotary_sign); float k_rot = (k * rotary_sign); @@ -183,7 +186,7 @@ __global__ void apply_rotary_pos_emb1(__half* mixed_query, k = k * cosf(inv_freq) + k_rot_tmp * sinf(inv_freq); mixed_query[offset + lane] = (__half)q; - key_layer[offset + lane] = (__half)k; + key_layer[k_offset + lane] = (__half)k; lane += WARP_SIZE; } @@ -237,6 +240,7 @@ template void launch_apply_rotary_pos_emb<__half>(__half*, bool, bool, cudaStream_t); + /* __global__ void apply_rotary_pos_emb(float* mixed_query, float* key_layer, diff --git a/csrc/transformer/inference/csrc/gelu.cu b/csrc/transformer/inference/csrc/gelu.cu index 70bbf42cf9ed..00c70eea22b5 100644 --- a/csrc/transformer/inference/csrc/gelu.cu +++ b/csrc/transformer/inference/csrc/gelu.cu @@ -317,12 +317,18 @@ __global__ void gptj_residual_add(float* input, float4 out = output_cast[offset]; float4 res_vec = attn_cast[offset]; float4 bias_data = bias_cast[offset % intermediate_size]; - float4 attn_bias = attnbias_cast[offset % intermediate_size]; - data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x + attn_bias.x); - data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y + attn_bias.y); - data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z + attn_bias.z); - data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w + attn_bias.w); + if (attnbias) { + float4 attn_bias = attnbias_cast[offset % intermediate_size]; + data.x += attn_bias.x; + data.y += attn_bias.y; + data.z += attn_bias.z; + data.w += attn_bias.w; + } + data.x = data.x * mp_size + (out.x + res_vec.x + bias_data.x); + data.y = data.y * mp_size + (out.y + res_vec.y + bias_data.y); + data.z = data.z * mp_size + (out.z + res_vec.z + bias_data.z); + data.w = data.w * mp_size + (out.w + res_vec.w + bias_data.w); output_cast[offset] = data; } @@ -354,13 +360,11 @@ __global__ void gptj_residual_add(__half* input, float2 res_vec = attn_cast[offset]; float2 bias_vec = bias_cast[offset % intermediate_size]; - float2 attn_bias_vec = attnbias_cast[offset % intermediate_size]; __half2* vals_half = reinterpret_cast<__half2*>(&vals_vec); __half2* out_half = reinterpret_cast<__half2*>(&out_vec); __half2* res_half = reinterpret_cast<__half2*>(&res_vec); __half2* bias_half = reinterpret_cast<__half2*>(&bias_vec); - __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec); float2 low_data = __half22float2(vals_half[0]); float2 high_data = __half22float2(vals_half[1]); @@ -373,18 +377,21 @@ __global__ void gptj_residual_add(__half* input, float2 low_bias = __half22float2(bias_half[0]); float2 high_bias = __half22float2(bias_half[1]); - - float2 attn_low_bias = __half22float2(attnbias_half[0]); - float2 attn_high_bias = __half22float2(attnbias_half[1]); - - low_data.x = - low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x + attn_low_bias.x)); - low_data.y = - low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y + attn_low_bias.y)); - high_data.x = - high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x + attn_high_bias.x)); - high_data.y = - high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y + attn_high_bias.y)); + if (attn_bias) { + float2 attn_bias_vec = attnbias_cast[offset % intermediate_size]; + __half2* attnbias_half = reinterpret_cast<__half2*>(&attn_bias_vec); + float2 attn_low_bias = __half22float2(attnbias_half[0]); + float2 attn_high_bias = __half22float2(attnbias_half[1]); + low_data.x += attn_low_bias.x; + low_data.y += attn_low_bias.y; + high_data.x += attn_high_bias.x; + high_data.y += attn_high_bias.y; + } + + low_data.x = low_data.x * mp_size + (low_out.x + low_res.x + (low_bias.x)); + low_data.y = low_data.y * mp_size + (low_out.y + low_res.y + (low_bias.y)); + high_data.x = high_data.x * mp_size + (high_out.x + high_res.x + (high_bias.x)); + high_data.y = high_data.y * mp_size + (high_out.y + high_res.y + (high_bias.y)); vals_half[0] = __float22half2_rn(low_data); vals_half[1] = __float22half2_rn(high_data); diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp index 320e6491b1cd..f6999e3c92a1 100644 --- a/csrc/transformer/inference/csrc/pt_binding.cpp +++ b/csrc/transformer/inference/csrc/pt_binding.cpp @@ -8,8 +8,6 @@ std::array gemm_algos = std::array({99, 99, 99}); -#define MAX_OUT_TOKES 10 - template at::Tensor ds_softmax(at::Tensor& attn_scores, at::Tensor& attn_mask, @@ -52,9 +50,11 @@ template void allocate_workspace(size_t hidden_dim, size_t max_seq_len, size_t batch_size, + unsigned num_layers, size_t head_size = 128) { - size_t _workSpaceSize = (hidden_dim * batch_size * max_seq_len); + size_t _workSpaceSize = 16 * (hidden_dim * batch_size * max_seq_len) + + (num_layers * batch_size * max_seq_len * hidden_dim * 2); // KV-cache Context::Instance().GenWorkSpace(_workSpaceSize * sizeof(T)); } @@ -71,7 +71,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W) float gemm_beta = 0.0; if (!workspace) { - allocate_workspace(W.size(1), MAX_OUT_TOKES, Q.size(0)); + allocate_workspace(W.size(1), MAX_OUT_TOKES, Q.size(0), 1); workspace = (T*)Context::Instance().GetWorkSpace(); } @@ -170,19 +170,19 @@ void attention_unfused(at::Tensor& prev_key_cont, } template -std::vector ds_softmax_context(at::Tensor& query, - at::Tensor& prev_key, - at::Tensor& new_key, - at::Tensor& attn_mask, - at::Tensor& prev_value, - at::Tensor& new_value, - int heads, - float norm_factor, - bool merging, - bool triangular, - bool local_attention, - int window_size, - bool no_masking) +std::vector ds_softmax_context1(at::Tensor& query, + at::Tensor& prev_key, + at::Tensor& new_key, + at::Tensor& attn_mask, + at::Tensor& prev_value, + at::Tensor& new_value, + int heads, + float norm_factor, + bool merging, + bool triangular, + bool local_attention, + int window_size, + bool no_masking) { auto query_cont = query.contiguous(); auto prev_key_cont = prev_key.contiguous(); @@ -222,6 +222,211 @@ std::vector ds_softmax_context(at::Tensor& query, return {output, prev_key, prev_value}; } +template +void ds_softmax_internal(T* attn_scores, + at::Tensor& attn_mask, + bool triangular, + bool recompute, + bool local_attention, + int window_size, + int bsz, + int seq_len, + int soft_len, + int heads) +{ + launch_attn_softmax_v2((T*)attn_scores, + (attn_mask.sizes().size() > 1 ? (T*)attn_mask.data_ptr() : nullptr), + triangular, + recompute, + local_attention, + window_size, + bsz, + heads, + seq_len, + soft_len, + 1.0, + at::cuda::getCurrentCUDAStream()); +} + +template +void attention_unfused(T* prev_key_cont, + T* query_cont, + at::Tensor& attn_mask, + T* prev_value_cont, + T* output, + unsigned& bsz, + int& k, + unsigned& seq_len, + unsigned& soft_len, + int& heads, + float& norm_factor, + bool triangular, + bool recompute, + bool local_attention, + int window_size) +{ + float alpha = norm_factor * norm_factor; + float gemm_beta = 0.0; + T* workspace = (T*)output + bsz * seq_len * heads * k; + + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); + cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + soft_len, + seq_len, + k, + &alpha, + &gemm_beta, + (T*)prev_key_cont, + (T*)query_cont, + workspace, + CUBLAS_OP_T, + CUBLAS_OP_N, + MAX_OUT_TOKES * k, + seq_len * k, + seq_len * soft_len, + bsz * heads, +#ifdef __HIP_PLATFORM_HCC__ + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif + ds_softmax_internal(workspace, + attn_mask, + triangular, + recompute, + local_attention, + window_size, + bsz, + seq_len, + soft_len, + heads); + alpha = 1.0; + cublas_strided_batched_gemm(Context::Instance().GetCublasHandle(), + k, + seq_len, + soft_len, + &alpha, + &gemm_beta, + (T*)prev_value_cont, + workspace, + (T*)output, + CUBLAS_OP_N, + CUBLAS_OP_N, + MAX_OUT_TOKES * k, + seq_len * soft_len, + seq_len * k, + bsz * heads, +#ifdef __HIP_PLATFORM_HCC__ + rocblas_gemm_algo_standard); +#else + CUBLAS_GEMM_DEFAULT_TENSOR_OP); +#endif +} + +template +std::vector ds_softmax_context(at::Tensor& query_key_value, + at::Tensor& attn_mask, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int heads, + float norm_factor, + bool triangular, + bool local_attention, + int window_size, + bool no_masking, + unsigned layer_id, + unsigned num_layers) +{ + unsigned bsz = query_key_value.size(0); + unsigned seq_len = query_key_value.size(1); + unsigned hidden_dim = query_key_value.size(2) / 3; + + bool is_prompt = (seq_len > 1); + + if (is_prompt) Context::Instance().reset_tokens(seq_len); + unsigned soft_len = Context::Instance().current_tokens(); + + int k = hidden_dim / heads; + auto options = at::TensorOptions() + .dtype(query_key_value.options().dtype()) + .layout(at::kStrided) + .device(at::kCUDA) + .requires_grad(false); + + T* workspace = (T*)Context::Instance().GetWorkSpace(); + size_t buf_size = bsz * seq_len * hidden_dim; + auto output = torch::from_blob(workspace + 4 * buf_size, {bsz, seq_len, hidden_dim}, options); + + auto query_cont = workspace + 8 * buf_size; + size_t offset = + 16 * (hidden_dim * bsz * MAX_OUT_TOKES) + layer_id * 2 * bsz * MAX_OUT_TOKES * hidden_dim; + + unsigned all_tokens = soft_len; + auto kv_cache = workspace + offset + (hidden_dim / heads) * (is_prompt ? 0 : soft_len - 1); + size_t value_offset = bsz * MAX_OUT_TOKES * hidden_dim; + + T* temp_buf = (T*)output.data_ptr() + at::numel(output); + launch_bias_add_transform_0213((T*)query_cont, + kv_cache, + kv_cache + value_offset, + (T*)query_key_value.data_ptr(), + nullptr, + bsz, + seq_len, + (is_prompt ? 0 : soft_len - 1), + soft_len, + hidden_dim, + heads, + rotary_dim, + rotate_half, + rotate_every_two, + Context::Instance().GetCurrentStream(), + 3); + if (rotary_dim > 0 && rotate_half) + launch_apply_rotary_pos_emb(query_cont, + kv_cache, + k, + seq_len, + rotary_dim, + (is_prompt ? 0 : soft_len - 1), + heads, + bsz, + rotate_half, + rotate_every_two, + Context::Instance().GetCurrentStream()); + + attention_unfused(workspace + offset, + (T*)query_cont, + attn_mask, + workspace + offset + value_offset, + temp_buf, + bsz, + k, + seq_len, + all_tokens, + heads, + norm_factor, + (triangular && is_prompt), + is_prompt, + local_attention, + window_size); + launch_transform4d_0213((T*)output.data_ptr(), + temp_buf, + bsz, + heads, + seq_len, + output.size(2), + Context::Instance().GetCurrentStream(false), + 1); + + if (layer_id == num_layers - 1) Context::Instance().advance_tokens(); + auto prev_key = torch::from_blob(workspace + offset, {bsz, all_tokens, hidden_dim}, options); + auto prev_value = + torch::from_blob(workspace + offset + value_offset, {bsz, all_tokens, hidden_dim}, options); + return {output, prev_key, prev_value}; +} + template at::Tensor ds_bias_gelu(at::Tensor& input, at::Tensor& bias) { @@ -271,6 +476,24 @@ at::Tensor ds_layernorm(at::Tensor& input_cont, at::Tensor& gamma, at::Tensor& b return inp_norm; } +template +void ds_layernorm_internal(T* workspace, + at::Tensor& input, + at::Tensor& gamma, + at::Tensor& betta, + float epsilon) +{ + int bsz = input.size(0) * input.size(1); + launch_layer_norm(workspace, + (T*)input.data_ptr(), + (T*)gamma.data_ptr(), + (T*)betta.data_ptr(), + epsilon, + bsz, + input.size(2), + Context::Instance().GetCurrentStream()); +} + template at::Tensor qkv_unfused_cublas(at::Tensor& output, at::Tensor& input, @@ -281,13 +504,15 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, const float epsilon, bool add_bias) { - auto inp_norm = ds_layernorm(input, gamma, beta, epsilon); - + int bsz = input.size(0) * input.size(1); + T* workspace = (T*)Context::Instance().GetWorkSpace(); + workspace += (3 * input.size(0) * MAX_OUT_TOKES * input.size(2)); + ds_layernorm_internal(workspace, input, gamma, beta, epsilon); // cudaEventRecord(Context::Instance().GetCompEvent(1), Context::Instance().GetCurrentStream()); float alpha = (T)1.0; float gemm_beta = (T)0.0; - int bsz = input.size(0) * input.size(1); + cublasSetStream(Context::Instance().GetCublasHandle(), Context::Instance().GetCurrentStream()); cublas_gemm_ex(Context::Instance().GetCublasHandle(), CUBLAS_OP_N, @@ -298,7 +523,7 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, &alpha, &gemm_beta, (T*)weight.data_ptr(), - (T*)inp_norm.data_ptr(), + workspace, (T*)output.data_ptr(), #ifdef __HIP_PLATFORM_HCC__ rocblas_gemm_algo_standard); @@ -311,7 +536,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output, weight.size(1), bsz, Context::Instance().GetCurrentStream()); - return inp_norm; + + return torch::from_blob(workspace, input.sizes(), input.options()); } template @@ -321,19 +547,27 @@ std::vector ds_qkv_gemm(at::Tensor& input, at::Tensor& gamma, at::Tensor& beta, const float epsilon, - bool add_bias) + bool add_bias, + unsigned num_layers) { - auto input_cont = input.contiguous(); + int bsz = input.size(0) * input.size(1); + int out_size = weight.size(1); + T* workspace = (T*)Context::Instance().GetWorkSpace(); + if (!workspace) { + cublasSetStream(Context::Instance().GetCublasHandle(), + Context::Instance().GetCurrentStream()); + allocate_workspace(input.size(2), MAX_OUT_TOKES, input.size(0), num_layers); + workspace = (T*)Context::Instance().GetWorkSpace(); + } auto options = at::TensorOptions() - .dtype(input_cont.options().dtype()) + .dtype(input.options().dtype()) .layout(at::kStrided) .device(at::kCUDA) .requires_grad(false); - auto output = at::empty({input_cont.size(0), input_cont.size(1), weight.size(1)}, options); - int bsz = input_cont.size(0) * input_cont.size(1); + auto output = at::from_blob(workspace, {input.size(0), input.size(1), weight.size(1)}, options); auto inp_norm = - qkv_unfused_cublas(output, input_cont, weight, bias, gamma, beta, epsilon, add_bias); + qkv_unfused_cublas(output, input, weight, bias, gamma, beta, epsilon, add_bias); return {output, inp_norm}; } @@ -757,7 +991,8 @@ void residual_add_bias(at::Tensor& output, at::Tensor& output_b, at::Tensor& attention_b, int mp_size, - bool mlp_after_attn) + bool mlp_after_attn, + bool add_bias) { int bsz = input.size(0) * input.size(1); int hidden_size = input.size(2); @@ -779,7 +1014,7 @@ void residual_add_bias(at::Tensor& output, (float*)output.data_ptr(), (float*)attention_output.data_ptr(), (float*)output_b.data_ptr(), - (float*)attention_b.data_ptr(), + (float*)(add_bias ? attention_b.data_ptr() : nullptr), hidden_size, bsz, mp_size, @@ -799,7 +1034,7 @@ void residual_add_bias(at::Tensor& output, (__half*)output.data_ptr(), (__half*)attention_output.data_ptr(), (__half*)output_b.data_ptr(), - (__half*)attention_b.data_ptr(), + (__half*)(add_bias ? attention_b.data_ptr() : nullptr), hidden_size, bsz, mp_size, @@ -910,6 +1145,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) m.def("softmax_context_fp16", &ds_softmax_context<__half>, "DeepSpeed attention with fp32 (CUDA)"); + m.def("softmax_context_int8", + &ds_softmax_context1<__half>, + "DeepSpeed attention with fp32 (CUDA)"); m.def("bias_gelu_fp32", &ds_bias_gelu, "DeepSpeed Gelu with fp32 (CUDA)"); m.def("bias_gelu_fp16", &ds_bias_gelu<__half>, "DeepSpeed Gelu with fp32 (CUDA)"); m.def("bias_residual_fp32", diff --git a/csrc/transformer/inference/csrc/transform.cu b/csrc/transformer/inference/csrc/transform.cu new file mode 100644 index 000000000000..dd7adb7a0508 --- /dev/null +++ b/csrc/transformer/inference/csrc/transform.cu @@ -0,0 +1,587 @@ +#ifndef __HIP_PLATFORM_HCC__ +#include +#endif +#include "custom_cuda_layers.h" +namespace cg = cooperative_groups; + +// Bias add + +__global__ void bias_add_transform_0213(float* output, + float* k_cache, + float* v_cache, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + unsigned seq_offset, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + // int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = + reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + unsigned seq_id = d1 + seq_offset; + float4 inputs = vals_vec[d3]; + int lane = d3 & 0x1f; + if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { + float4 q = vals_vec[d3]; + float2* q_f = reinterpret_cast(&q); + if (rotate_every_two) { +#pragma unroll + for (int o = 0; o < 2; o++) { + float inv_freq = (float)(((d3 << 1) + o) * 2) / (float)(rotary_dim << 2); + inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; + q_f[o].x = (-1.0 * q_f[o].y * sinf(inv_freq) + q_f[o].x * cosf(inv_freq)); + q_f[o].y = (q_f[o].x * sinf(inv_freq) + q_f[o].y * cosf(inv_freq)); + } + } + output_vec[d3] = q; + } else + output_vec[d3] = inputs; +} + +#define ATTN_H 3 +#define MAX_SEQ_LINE 10 + +__global__ void bias_add_transform_0213(__half* output, // q + __half* k_cache, + __half* v_cache, + const __half* vals, // qkv + const __half* bias, + int hidden_dim, + int seq_length, + unsigned seq_offset, + int all_tokens, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + unsigned half_dim = (rotary_dim << 3) >> 1; + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + int d2_out_stride = d2_stride * (cnt == 0 ? seq_length : MAX_OUT_TOKES); + float4 vals_arr; + float4 output_arr; + + __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); + __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + float4* output_vec = + reinterpret_cast(cnt == 0 ? output : (cnt == 1 ? k_cache : v_cache)); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + unsigned seq_id = d1 + seq_offset; + + int lane = d3 & 0x1f; + if (cnt < 2 && rotary_dim > 0 && d3 < rotary_dim) { + float4 q = vals_vec[d3]; + __half2* q_h = reinterpret_cast<__half2*>(&q); + if (rotate_every_two) { +#pragma unroll + for (int o = 0; o < 4; o++) { + float inv_freq = (float)(((d3 << 2) + o) * 2) / (float)(rotary_dim << 3); + inv_freq = 1.0 / powf(10000.0, inv_freq) * (float)seq_id; + float q_data[2]; + q_data[0] = (float)q_h[o].x; + q_data[1] = (float)q_h[o].y; + q_h[o].x = (__half)(-1.0 * q_data[1] * sinf(inv_freq) + q_data[0] * cosf(inv_freq)); + q_h[o].y = (__half)(q_data[0] * sinf(inv_freq) + q_data[1] * cosf(inv_freq)); + } + } + output_vec[d3] = q; + } else + output_vec[d3] = vals_vec[d3]; + +#endif +} + +// [B S C*H] - > C * [B A S N] +template <> +void launch_bias_add_transform_0213(float* output, + float* k_cache, + float* v_cache, + const float* vals, + const float* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(batch_size, seq_length, (trans_count * head_ext)); + + bias_add_transform_0213<<>>(output, + k_cache, + v_cache, + vals, + bias, + hidden_dim, + seq_length, + seq_offset, + heads, + rotary_dim >> 2, + rotate_half, + rotate_every_two, + head_ext); +} +template +void launch_bias_add_transform_0213(T* outputs, + T* vals, + T* vals1, + const T* vals2, + const T* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int seq_length1, + int hidden_dim, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count); +template <> +void launch_bias_add_transform_0213<__half>(__half* output, + __half* k_cache, + __half* v_cache, + const __half* vals, + const __half* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int all_tokens, + int hidden_dim, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + int head_ext = 1; // (hidden_dim - 1) / MAX_THREADS + 1; + dim3 block_dim(hidden_dim / heads, (heads / head_ext)); + dim3 grid_dim(1, seq_length, (trans_count * head_ext)); + bias_add_transform_0213<<>>(output, + k_cache, + v_cache, + vals, + bias, + hidden_dim, + seq_length, + seq_offset, + all_tokens, + heads, + rotary_dim >> 3, + rotate_half, + rotate_every_two, + head_ext); +} + +// Bias add +template +__global__ void bias_add_transform_0213(T* output, + const T* vals, + const T* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext); + +template <> +__global__ void bias_add_transform_0213(float* output, + const float* vals, + const float* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + float4 inputs = vals_vec[d0 * d0_stride * (gridDim.z / head_ext) + cnt * d1_stride + + d1 * d1_stride * (gridDim.z / head_ext) + d2 * d2_stride + d3]; + float4 biases = bias_vec[cnt * d1_stride + d2 * d2_stride + d3]; + + float4 outputs; + outputs.x = inputs.x + biases.x; + outputs.y = inputs.y + biases.y; + outputs.z = inputs.z + biases.z; + outputs.w = inputs.w + biases.w; + + output_vec[cnt * d0_out_stride * gridDim.x + d0 * d0_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride + d3] = outputs; +} + +template <> +__global__ void bias_add_transform_0213<__half>(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads, + int head_ext) +{ +#ifdef HALF_PRECISION_AVAILABLE + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = blockIdx.z / head_ext; // Hidden count + int d2 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr; + float4 bias_arr; + float4 output_arr; + __half2* vals_half = reinterpret_cast<__half2*>(&vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(&bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(&output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + vals_vec += (d0 * d0_stride * (gridDim.z / head_ext)); + vals_vec += (d1 * d1_stride * (gridDim.z / head_ext)); + vals_vec += (cnt * d1_stride); + vals_vec += (d2 * d2_stride); + + bias_vec += (cnt * d1_stride); + bias_vec += (d2 * d2_stride); + + output_vec += (cnt * d0_stride * gridDim.x); + output_vec += (d1 * d2_stride); + output_vec += (d0 * d0_stride); + output_vec += (d2 * d2_out_stride); + + bias_arr = bias_vec[d3]; + vals_arr = vals_vec[d3]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + output_vec[d3] = output_arr; + +#endif +} + +__global__ void bias_add_transform_0213_v2(__half* output, + const __half* vals, + const __half* bias, + int hidden_dim, + int seq_length, + int heads) +{ +#ifdef HALF_PRECISION_AVAILABLE + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + int iteration_stride = d1_stride * blockDim.z; // Hidden * 3 / 8 + int batch_stride = d0_stride * blockDim.z; // Hidden * S * 3 / 8 + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = d2_stride * seq_length; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y; // Sequence ID (0-127) + int cnt = threadIdx.z; // blockIdx.z; // Hidden count + int d2 = threadIdx.y; // Head (0-11) + int d3 = threadIdx.x; // Values (groups of 4) + + float4 vals_arr[1]; + float4 bias_arr[1]; + float4 output_arr[1]; + __half2* vals_half = reinterpret_cast<__half2*>(vals_arr); + __half2* bias_half = reinterpret_cast<__half2*>(bias_arr); + __half2* output_half = reinterpret_cast<__half2*>(output_arr); + + const float4* vals_vec = reinterpret_cast(vals); + const float4* bias_vec = reinterpret_cast(bias); + float4* output_vec = reinterpret_cast(output); + + int iter_index = cnt * d1_stride + d2 * d2_stride + d3; + int input_offset = d0 * batch_stride + d1 * (iteration_stride << 1); + bias_arr[0] = bias_vec[iter_index]; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + vals_arr[0] = vals_vec[input_offset + iter_id]; + + output_half[0] = vals_half[0] + bias_half[0]; + output_half[1] = vals_half[1] + bias_half[1]; + output_half[2] = vals_half[2] + bias_half[2]; + output_half[3] = vals_half[3] + bias_half[3]; + + in_data[iter_id] = output_arr[0]; + } + __syncthreads(); + + iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_out_stride * gridDim.x); + int head_count = (d2 >> 1) + cnt * (blockDim.y >> 1); + + int out_index = d0 * d0_out_stride + d1 * (d1_out_stride << 1) + d3 + (d2 % 2) * d2_stride; + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = (iter * iteration_stride) + head_count; + int iter_offset = + (iter_row % blockDim.y) * d2_out_stride + (iter_row / blockDim.y) * matrix_stride; + output_vec[out_index + iter_offset] = + in_data[iter_row * d2_stride + d3 + (d2 % 2) * (d1_stride * blockDim.z)]; + } +#endif +} + +template +__global__ void transform4d_0213(T* out, + const T* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext); + +template <> +__global__ void transform4d_0213(float* out, + const float* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ + int d0_stride = hidden_dim * seq_length; + int d1_stride = d0_stride / heads; + int d2_stride = hidden_dim / heads; + + int d0_out_stride = d0_stride; + int d1_out_stride = d2_stride; + int d2_out_stride = hidden_dim; + + int d0 = blockIdx.x; // Batch + int d1 = blockIdx.y / ((seq_length - 1) / blockDim.y + 1); // Head + int d2 = (threadIdx.y + blockDim.y * blockIdx.y) % seq_length; + int cnt = blockIdx.z; + int d3 = threadIdx.x; // Values (groups of 8) + + if (d2 < seq_length) { + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + float4 vals_vec = in_vec[cnt * d0_stride * gridDim.x + d0 * d0_stride + d1 * d1_stride + + d2 * d2_stride + d3]; + out_vec[d0 * d0_out_stride * gridDim.z + cnt * d2_out_stride + d1 * d1_out_stride + + d2 * d2_out_stride * gridDim.z + d3] = vals_vec; + } +} + +template <> +__global__ void transform4d_0213<__half>(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim, + int head_ext) +{ +#if __CUDA_ARCH__ >= 700 + + int d0_stride = hidden_dim * (seq_length / head_ext); + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y + (blockIdx.z % head_ext) * (heads / head_ext); // Head + int d2 = blockIdx.z / head_ext; // Sequence + int cnt = blockIdx.y; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + in_vec += (cnt * d0_stride * gridDim.x); + in_vec += (d0 * d0_stride); + in_vec += (d2 * d2_stride); + in_vec += (d1 * d2_stride * seq_length); + + out_vec += (cnt * d1_stride); + out_vec += (d1 * d2_stride); + out_vec += (d0 * d0_stride * gridDim.y); + out_vec += (d2 * d1_stride * gridDim.y); + + out_vec[d3] = in_vec[d3]; + +#endif +} + +__global__ void transform4d_0213_v2(__half* out, + const __half* in, + int heads, + int seq_length, + int hidden_dim) +{ +#if __CUDA_ARCH__ >= 700 + __shared__ float4 in_data[3072]; + + int d0_stride = hidden_dim * seq_length; + int d1_stride = hidden_dim; + int d2_stride = hidden_dim / heads; + + int d0 = blockIdx.x; // Batch + int d1 = threadIdx.y; // Head + int d2 = blockIdx.y; // Sequence + int cnt = threadIdx.z; // Hidden count + int d3 = threadIdx.x; // Values (groups of 8) + + const float4* in_vec = reinterpret_cast(in); + float4* out_vec = reinterpret_cast(out); + + int input_offset = d0 * d0_stride + d2 * (d2_stride << 1) + d3 + (d1 % 2) * d2_stride; + int head_count = (d1 >> 1) + cnt * (blockDim.y >> 1); + int iteration_stride = blockDim.z * (blockDim.y >> 1); + int matrix_stride = (d0_stride * gridDim.x); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_row = iter * iteration_stride + head_count; + int iter_offset = (iter_row % blockDim.y) * d2_stride; + + in_data[d3 + iter_offset + (iter_row / blockDim.y + (d1 % 2) * blockDim.z) * d1_stride] = + in_vec[input_offset + iter_offset * seq_length + + (iter_row / blockDim.y) * matrix_stride]; + } + __syncthreads(); + + iteration_stride = d1_stride * blockDim.z; + int iter_index = cnt * d1_stride + d1 * d2_stride + d3; + int output_offset = d0 * d0_stride * blockDim.z + d2 * (iteration_stride << 1); + +#pragma unroll + for (int iter = 0; iter < 2; iter++) { + int iter_id = iter * iteration_stride + iter_index; + out_vec[output_offset + iter_id] = in_data[iter_id]; + } +#endif +} + +// 3 * [B A S N] - > [B S C*H] +template <> +void launch_transform4d_0213(float* out, + const float* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 2; + dim3 grid_dims(batch_size, heads * ((seq_length - 1) / 8 + 1), trans_count); + dim3 block_dims(hidden_dim / heads, 8); + transform4d_0213 + <<>>(out, in, heads, seq_length, hidden_dim, 1); +} + +template <> +void launch_transform4d_0213<__half>(__half* out, + const __half* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count) +{ + hidden_dim >>= 3; + int head_ext = (hidden_dim - 1) / MAX_THREADS + 1; + dim3 grid_dims(batch_size, trans_count, (seq_length * head_ext)); + dim3 block_dims(hidden_dim / heads, (heads / head_ext)); + transform4d_0213<__half> + <<>>(out, in, heads, seq_length, hidden_dim, head_ext); +} diff --git a/csrc/transformer/inference/includes/custom_cuda_layers.h b/csrc/transformer/inference/includes/custom_cuda_layers.h index 06b4340061c9..c8a0b79a111b 100644 --- a/csrc/transformer/inference/includes/custom_cuda_layers.h +++ b/csrc/transformer/inference/includes/custom_cuda_layers.h @@ -17,8 +17,11 @@ #include #include +#define MAX_OUT_TOKES 1024 #define MAX_WARP_NUM 32 #define WARP_SIZE 32 + +#define MAX_THREADS 1024 #define SMs 80 #define MAX_REGISTERS 256 @@ -122,3 +125,31 @@ void launch_moe_res_matmul(T* residual, int seq_len, int hidden_dim, cudaStream_t stream); + +// 4D transform [0, 1, 2, 3] -> [0, 2, 1, 3] +template +void launch_transform4d_0213(T* out, + const T* in, + int batch_size, + int heads, + int seq_length, + int hidden_dim, + cudaStream_t stream, + int trans_count); +template +void launch_bias_add_transform_0213(T* outputs, + T* vals, + T* vals1, + const T* vals2, + const T* bias, + int batch_size, + int seq_length, + unsigned seq_offset, + int seq_length1, + int hidden_dim, + int heads, + int rotary_dim, + bool rotate_half, + bool rotate_every_two, + cudaStream_t stream, + int trans_count); diff --git a/deepspeed/module_inject/replace_module.py b/deepspeed/module_inject/replace_module.py index 62d5cd75e9a1..79c89ffd6bd8 100755 --- a/deepspeed/module_inject/replace_module.py +++ b/deepspeed/module_inject/replace_module.py @@ -53,7 +53,7 @@ def merge_assert(self, dim1, dim2): def qkv_copy(self, dst, src): if src is None: - return torch.nn.Parameter(src) + return src src_shape = src.shape dst_shape = dst.shape @@ -90,7 +90,7 @@ def qkv_copy(self, dst, src): def copy(self, dst, src): if src is None: - return torch.nn.Parameter(src) + return src src_shape = src.shape dst_shape = dst.shape @@ -351,7 +351,7 @@ def replace_with_policy(child, # linear layer is created with [input, output] shape # transpose it here to reduce inference cost! def transpose(data): - data.view(-1).copy_(data.transpose(-1, -2).contiguous().view(-1)) + data.reshape(-1).copy_(data.transpose(-1, -2).contiguous().reshape(-1)) data = data.reshape(data.shape[-1], data.shape[-2]) return data @@ -391,8 +391,8 @@ def _transpose(x): qkvw = torch.nn.Parameter(_transpose(qkvw).contiguous()) qkvb = torch.nn.Parameter(_transpose(qkvb).contiguous()) - dense_b = dense_b * (transformer_config.training_mp_size / - transformer_config.mp_size) + dense_b = dense_b if dense_b is None else dense_b * ( + transformer_config.training_mp_size / transformer_config.mp_size) _4hh_b = _4hh_b * (transformer_config.training_mp_size / transformer_config.mp_size) diff --git a/deepspeed/module_inject/replace_policy.py b/deepspeed/module_inject/replace_policy.py index c8d14e431d08..f4054c633810 100755 --- a/deepspeed/module_inject/replace_policy.py +++ b/deepspeed/module_inject/replace_policy.py @@ -2,6 +2,7 @@ import torch from torch.nn.parameter import Parameter +from packaging import version as pkg_version class DSPolicy(ABC): @@ -210,12 +211,15 @@ def __init__(self, client_module, inference=True): # we use megatron version to differentiate between the old and new # megatron-lm source code if MegatronLayerPolicy._orig_layer_class is None: - try: - import megatron - from megatron.model.transformer import ParallelTransformerLayer - MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer - except ImportError: + if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"): MegatronLayerPolicy._orig_layer_class = None + else: + try: + import megatron + from megatron.model.transformer import ParallelTransformerLayer + MegatronLayerPolicy._orig_layer_class = ParallelTransformerLayer + except ImportError: + MegatronLayerPolicy._orig_layer_class = None def get_hidden_heads(self): return self.client_module.attention.query_key_value.weight.shape[1], \ @@ -325,12 +329,15 @@ def __init__(self, client_module, inference=True, megatron_v2=True): super().__init__(inference, megatron_v2=megatron_v2) self.client_module = client_module if GPTNEOXLayerPolicy._orig_layer_class is None: - try: - import megatron - from megatron.model.transformer import ParallelTransformerLayerPipe - GPTNEOXLayerPolicy._orig_layer_class = ParallelTransformerLayerPipe - except ImportError: + if pkg_version.parse(torch.__version__) <= pkg_version.parse("1.2"): GPTNEOXLayerPolicy._orig_layer_class = None + else: + try: + import megatron + from megatron.model.transformer import ParallelTransformerLayerPipe + GPTNEOXLayerPolicy._orig_layer_class = ParallelTransformerLayerPipe + except ImportError: + GPTNEOXLayerPolicy._orig_layer_class = None def get_hidden_heads(self): if GPTNEOXLayerPolicy.version == 0: diff --git a/deepspeed/ops/transformer/inference/transformer_inference.py b/deepspeed/ops/transformer/inference/transformer_inference.py index aed03148e919..6b91841801f5 100755 --- a/deepspeed/ops/transformer/inference/transformer_inference.py +++ b/deepspeed/ops/transformer/inference/transformer_inference.py @@ -135,7 +135,8 @@ def forward(ctx, q_scales, q_groups, merge_count, - qkv_merging): + qkv_merging, + score_context_func): def _transpose_for_scores(x, key=False, reshape=False): attention_head_size = x.shape[-1] // num_attention_heads_per_partition new_x_shape = x.size()[:-1] + (num_attention_heads_per_partition, @@ -156,9 +157,11 @@ def _transpose_for_context(x): return x.view(*new_x_layer_shape).contiguous() def compute_attention(qkv_out, input_mask): - score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \ - inference_cuda_module.softmax_context_fp16 + no_masking = input_mask is None + head_size = (qkv_out.shape[-1] // 3 // num_attention_heads_per_partition) + if no_masking: + input_mask = torch.empty(1) if merge_count > 0 and config.q_int8: split_dim = (qkv_out.dim() - 1) qkv_split = torch.split(qkv_out, @@ -175,86 +178,89 @@ def compute_attention(qkv_out, input_mask): torch.cat([s[i] for s in qkv_split], axis=-1) for i in range(len(qkv_split[0])) ] - else: - (mixed_query, - key_layer, - value_layer) = torch.split(qkv_out, - (qkv_out.shape[-1] // 3), - dim=(qkv_out.dim() - 1)) - no_masking = input_mask is None - if no_masking: - input_mask = torch.empty(1) - head_size = (mixed_query.shape[-1] // num_attention_heads_per_partition) - unfused_mode = not config.specialized_mode or \ - mixed_query.shape[1] >= 32 or head_size > 128 - - if config.rotary_dim > 0: - mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb( - mixed_query, - key_layer, - config.rotary_dim, - 0 if layer_past is None else layer_past[0].shape[-2], - num_attention_heads_per_partition, - config.rotate_half, - config.rotate_every_two) - if layer_past is not None: - past_key, past_value = layer_past - if unfused_mode: + if config.rotary_dim > 0: + mixed_query, key_layer = inference_cuda_module.apply_rotary_pos_emb( + mixed_query, + key_layer, + config.rotary_dim, + 0 if layer_past is None else layer_past[0].shape[-2], + num_attention_heads_per_partition, + config.rotate_half, + config.rotate_every_two) + if layer_past is not None: + past_key, past_value = layer_past key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-2) value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2) - presents = (key_layer, value_layer) - if unfused_mode: + presents = (key_layer, value_layer) mixed_query = _transpose_for_scores(mixed_query, False, True) key_layer = _transpose_for_scores( key_layer, True, True) / (norm_factor if config.scale_attention else 1.0) value_layer = _transpose_for_scores(value_layer, False, True) - #print(f'[{torch.distributed.get_rank()}] {config.layer_id}: {mixed_query.norm()}') - if layer_past is None: - attn_key_value = score_context_func( - mixed_query, - key_layer, - torch.empty(1), - input_mask, - value_layer, - torch.empty(1), - num_attention_heads_per_partition, - (1 / norm_factor if config.scale_attention else 1.0), - (not unfused_mode), - config.triangular_masking, - config.local_attention, - config.window_size, - no_masking) + if layer_past is None: + attn_key_value = score_context_func( + mixed_query, + key_layer, + torch.empty(1), + input_mask, + value_layer, + torch.empty(1), + num_attention_heads_per_partition, + (1 / norm_factor if config.scale_attention else 1.0), + (not unfused_mode), + config.triangular_masking, + config.local_attention, + config.window_size, + no_masking) + else: + attn_key_value = score_context_func( + mixed_query, + (key_layer if unfused_mode else past_key.type_as(key_layer)), + key_layer, + input_mask, + (value_layer + if unfused_mode else past_value.type_as(value_layer)), + value_layer, + num_attention_heads_per_partition, + (1 / norm_factor if config.scale_attention else 1.0), + (not unfused_mode), + config.triangular_masking, + config.local_attention, + config.window_size, + no_masking) + if unfused_mode: + context_layer, _, _ = attn_key_value + else: + context_layer, key_layer, value_layer = attn_key_value + + # Transpose Context + context_layer = _transpose_for_context(context_layer) + + return context_layer, presents[0], presents[1] # atten_output, key_layer, value_layer else: attn_key_value = score_context_func( - mixed_query, - (key_layer if unfused_mode else past_key.type_as(key_layer)), - key_layer, + qkv_out, input_mask, - (value_layer if unfused_mode else past_value.type_as(value_layer)), - value_layer, + config.rotary_dim, + config.rotate_half, + config.rotate_every_two, num_attention_heads_per_partition, (1 / norm_factor if config.scale_attention else 1.0), - (not unfused_mode), config.triangular_masking, config.local_attention, config.window_size, - no_masking) - if unfused_mode: - context_layer, _, _ = attn_key_value - else: - context_layer, key_layer, value_layer = attn_key_value - - # Transpose Context - context_layer = _transpose_for_context(context_layer) + no_masking, + config.layer_id, + DeepSpeedTransformerInference.layer_id) - return context_layer, presents[0], presents[1] # atten_output, key_layer, value_layer + context_layer, key_layer, value_layer = attn_key_value + return context_layer, key_layer, value_layer def selfAttention_fp(): vector_matmul_func = inference_cuda_module.vector_matmul_fp16 if config.fp16 else \ @@ -267,19 +273,20 @@ def selfAttention_fp(): else: qkv_func = inference_cuda_module.qkv_gemm_fp16 if config.fp16 else \ inference_cuda_module.qkv_gemm_fp32 + qkv_out = qkv_func(input, attn_qkvw, (attn_qkvb if attn_qkvb is not None else norm_b), norm_w, norm_b, config.epsilon, - (attn_qkvb is not None)) + (attn_qkvb is not None), + DeepSpeedTransformerInference.layer_id) context_layer, key_layer, value_layer = compute_attention(qkv_out[0] if isinstance(qkv_out, list) else qkv_out, input_mask) output = vector_matmul_func(context_layer, attn_ow, False) - #print(f'[{torch.distributed.get_rank()}] {config.layer_id}: oooooo -> {output.norm()}') - return output, key_layer, value_layer, context_layer, qkv_out[-1] # attn_out, present_key, present_value, context_output, inp_norm + return output, key_layer, value_layer, context_layer, qkv_out[-1] def selfAttention_int8(): if not config.pre_layer_norm: @@ -366,6 +373,9 @@ def __init__(self, math.sqrt(self.config.hidden_size // self.config.heads)) self.qkv_merging = qkv_merging + self.score_context_func = inference_cuda_module.softmax_context_fp32 if (not config.fp16) else \ + inference_cuda_module.softmax_context_fp16 + def forward(self, input, input_mask, @@ -400,7 +410,8 @@ def forward(self, self.q_scales, self.q_groups, self.merge_count, - self.qkv_merging) + self.qkv_merging, + self.score_context_func) return output @@ -472,9 +483,10 @@ def forward(ctx, residual, input, output_b, - bias, + bias if bias is not None else output_b, config.mp_size, - config.mlp_after_attn) + config.mlp_after_attn, + bias is not None) if mp_group is not None and dist.get_world_size(group=mp_group) > 1: dist.all_reduce(output, group=mp_group) return output diff --git a/op_builder/transformer_inference.py b/op_builder/transformer_inference.py index 23eab4886e80..2f05230dbada 100755 --- a/op_builder/transformer_inference.py +++ b/op_builder/transformer_inference.py @@ -20,6 +20,7 @@ def sources(self): 'csrc/transformer/inference/csrc/softmax.cu', 'csrc/transformer/inference/csrc/dequantize.cu', 'csrc/transformer/inference/csrc/apply_rotary_pos_emb.cu', + 'csrc/transformer/inference/csrc/transform.cu', ] def extra_ldflags(self): diff --git a/requirements/requirements-dev.txt b/requirements/requirements-dev.txt index 313379c4ecc2..d0258ff660f4 100644 --- a/requirements/requirements-dev.txt +++ b/requirements/requirements-dev.txt @@ -11,3 +11,4 @@ recommonmark sphinx sphinx-rtd-theme torchvision +transformers diff --git a/tests/unit/test_inference.py b/tests/unit/test_inference.py new file mode 100644 index 000000000000..0a0e5d4a36b7 --- /dev/null +++ b/tests/unit/test_inference.py @@ -0,0 +1,34 @@ +import os +import torch +import pytest +import deepspeed +from transformers import pipeline +from .common import distributed_test +from packaging import version as pkg_version + + +@pytest.mark.parametrize("dtype", [(torch.float), (torch.half)]) +def test_gpt2_inject(dtype): + if pkg_version.parse(torch.__version__) <= pkg_version.parse('1.2'): + pytest.skip("DS inference injection doesn't work well on older torch versions") + + @distributed_test(world_size=[1]) + def _go(): + local_rank = int(os.getenv("LOCAL_RANK", "0")) + world_size = int(os.getenv("WORLD_SIZE", "1")) + generator = pipeline("text-generation", model="gpt2", device=local_rank) + + generator.model = deepspeed.init_inference( + generator.model, + mp_size=world_size, + dtype=dtype, + replace_method="auto", + replace_with_kernel_inject=True, + ) + + prompt = "DeepSpeed is" + string_1 = generator(prompt, do_sample=False, max_length=128) + string_2 = generator(prompt, do_sample=False, max_length=128) + assert string_1 == string_2 + + _go()