6 changes: 6 additions & 0 deletions backends/metax_gpu/CMakeLists.txt
@@ -39,6 +39,9 @@ include(cutlass)
include(dgc)
include(warpctc)
include(warprnnt)
if(WITH_MAGMA)
include(external/magma)
endif()

set(PLUGIN_VERSION ${PADDLE_VERSION})

@@ -124,6 +127,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/dense_tensor_iterator.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/math/*.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/eigen/*.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/funcs/magma/magma_function.cc
# kernels/gpu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_send_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/partial_recv_kernel.cu
@@ -668,6 +672,7 @@ file(
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/binomial_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/bernoulli_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/cufft.cc
${PADDLE_SOURCE_DIR}/paddle/phi/backends/dynload/magma.cc
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/box_coder_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/broadcast_tensors_kernel.cu
${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/broadcast_tensors_grad_kernel.cu
@@ -782,6 +787,7 @@ target_link_libraries(
protobuf
external_error_proto
dgc
magma
${WARPCTC_LIBRARIES}
${WARPRNNT_LIBRARIES}
${PADDLE_CORE_LIB})
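
Note: `paddle/phi/backends/dynload/magma.cc` resolves the MAGMA entry points at runtime instead of binding them at link time. The snippet below is a minimal standalone sketch of the dlopen/dlsym pattern such a dynload wrapper typically follows; `LoadMagmaLib` and `LoadMagmaSymbol` are illustrative names, not Paddle's actual macros or MAGMA's real signatures.

// Minimal sketch of a dlopen/dlsym-based loader (illustrative only;
// not Paddle's dynload macros).
#include <dlfcn.h>
#include <stdexcept>
#include <string>

static void* LoadMagmaLib() {
  // Resolve libmagma.so once and cache the handle.
  static void* handle = dlopen("libmagma.so", RTLD_LAZY | RTLD_LOCAL);
  if (handle == nullptr) {
    throw std::runtime_error("cannot load libmagma.so");
  }
  return handle;
}

template <typename FuncPtr>
FuncPtr LoadMagmaSymbol(const char* name) {
  // Look the symbol up lazily; callers cast to the expected prototype.
  void* sym = dlsym(LoadMagmaLib(), name);
  if (sym == nullptr) {
    throw std::runtime_error(std::string("missing MAGMA symbol: ") + name);
  }
  return reinterpret_cast<FuncPtr>(sym);
}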
123 changes: 123 additions & 0 deletions backends/metax_gpu/kernels/metax_kernel/eig_kernel.cu
@@ -0,0 +1,123 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/eig.h"
#include "paddle/phi/kernels/eig_kernel.h"

namespace phi {

template <typename T, typename Context>
void EigKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out_w,
DenseTensor* out_v) {
dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_w);
dev_ctx.template Alloc<phi::dtype::Complex<T>>(out_v);

if (x.numel() == 0) {
return;
}

auto cpu_place = phi::CPUPlace();
phi::DeviceContextPool& pool = phi::DeviceContextPool::Instance();
auto* cpu_ctx = static_cast<phi::CPUContext*>(pool.Get(cpu_place));

// prepare CPU tensors here, since MAGMA produces its outputs on the host
DenseTensor out_w_cpu, out_v_cpu;
out_w_cpu.Resize(out_w->dims());
(*cpu_ctx).template Alloc<phi::dtype::Complex<T>>(&out_w_cpu);
out_v_cpu.Resize(x.dims());
(*cpu_ctx).template Alloc<phi::dtype::Complex<T>>(&out_v_cpu);

if (!IsComplexType(x.dtype())) {
// the output is still complex even though the input is real
int batch_count = BatchCount(x);
int order = static_cast<int>(x.dims()[x.dims().size() - 1]);

DenseTensor real_w_cpu, real_v_cpu;

std::vector<int64_t> real_w_dim = common::vectorize<int64_t>(out_w->dims());
real_w_dim.back() *= 2;
real_w_cpu.Resize(common::make_ddim(real_w_dim));
(*cpu_ctx).template Alloc<phi::dtype::Real<T>>(&real_w_cpu);
real_v_cpu.Resize(x.dims());
(*cpu_ctx).template Alloc<phi::dtype::Real<T>>(&real_v_cpu);

phi::ApplyEigKernelMagma<phi::dtype::Real<T>, Context>(
dev_ctx, x, &real_w_cpu, &real_v_cpu);

// 1. extract real part & imag part from real_w_cpu
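// (ApplyEigKernelMagma packs them along the last axis of real_w_cpu as
// [real parts | imaginary parts], which is why real_w_dim.back() was
// doubled above)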
DenseTensor real_part_cpu = phi::funcs::Slice<phi::dtype::Real<T>>(
(*cpu_ctx), real_w_cpu, {-1}, {0}, {order});
DenseTensor imag_part_cpu = phi::funcs::Slice<phi::dtype::Real<T>>(
(*cpu_ctx), real_w_cpu, {-1}, {order}, {order * 2});

// 2. construct complex values
auto* real_part_data = real_part_cpu.data<phi::dtype::Real<T>>();
auto* imag_part_data = imag_part_cpu.data<phi::dtype::Real<T>>();
int64_t out_w_numel = static_cast<int64_t>(out_w->numel());

phi::funcs::ForRange<phi::CPUContext> for_range((*cpu_ctx), out_w_numel);
phi::funcs::RealImagToComplexFunctor<phi::dtype::Complex<T>> functor(
real_part_data,
imag_part_data,
out_w_cpu.data<phi::dtype::Complex<T>>(),
out_w_numel);
for_range(functor);

// 3. construct complex vectors; transpose the last two dims first so each
// LAPACK-style eigenvector is contiguous, then transpose back afterwards
DenseTensor real_v_trans_cpu =
phi::TransposeLast2Dim<phi::dtype::Real<T>, phi::CPUContext>(
(*cpu_ctx), real_v_cpu);
DenseTensor out_v_trans_cpu;
out_v_trans_cpu.Resize(x.dims());
(*cpu_ctx).template Alloc<phi::dtype::Complex<T>>(&out_v_trans_cpu);

phi::ConstructComplexVectors<phi::dtype::Real<T>,
phi::dtype::Complex<T>,
phi::CPUContext>(&out_v_trans_cpu,
out_w_cpu,
real_v_trans_cpu,
(*cpu_ctx),
batch_count,
order);

TransposeTwoAxis<phi::dtype::Complex<T>, phi::CPUContext>(
out_v_trans_cpu,
&out_v_cpu,
x.dims().size() - 1,
x.dims().size() - 2,
(*cpu_ctx));

} else {
phi::ApplyEigKernelMagma<T, Context>(dev_ctx, x, &out_w_cpu, &out_v_cpu);
}

// copy the results from the CPU tensors back to the device tensors
phi::Copy(dev_ctx, out_w_cpu, dev_ctx.GetPlace(), false, out_w);
phi::Copy(dev_ctx, out_v_cpu, dev_ctx.GetPlace(), false, out_v);
}

} // namespace phi

PD_REGISTER_PLUGIN_KERNEL(
eig, metax_gpu, ALL_LAYOUT, phi::EigKernel, float, phi::complex64) {
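// for a real (float) input, eig still produces complex eigenvalues and
// eigenvectors, so promote the registered output dtype to its complex
// counterpart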
if (kernel_key.dtype() == phi::DataType::FLOAT32) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
kernel->OutputAt(1).SetDataType(phi::dtype::ToComplex(kernel_key.dtype()));
}
}
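
For context on step 3 above: `phi::ConstructComplexVectors` rebuilds complex eigenvectors from the real matrix MAGMA returns. Assuming MAGMA follows the standard LAPACK `*geev` packing — a real eigenvalue owns one real column, while a conjugate pair (w[j], w[j+1]) shares two columns holding the common real and imaginary parts — a self-contained sketch of that unpacking looks like this; `UnpackEigvecs` is an illustrative helper, not Paddle's implementation.

// Sketch of undoing the LAPACK/MAGMA *geev eigenvector packing
// (illustration of the format only, not Paddle's actual code).
#include <complex>
#include <vector>

// w: n eigenvalues; vr: column-major n x n real matrix in the *geev
// layout described above. Returns the complex eigenvectors, column j
// belonging to eigenvalue w[j].
std::vector<std::complex<float>> UnpackEigvecs(
    const std::vector<std::complex<float>>& w,
    const std::vector<float>& vr,
    int n) {
  std::vector<std::complex<float>> v(static_cast<size_t>(n) * n);
  for (int j = 0; j < n; ++j) {
    if (w[j].imag() == 0.0f) {
      // Real eigenvalue: column j is already the eigenvector.
      for (int i = 0; i < n; ++i) v[j * n + i] = vr[j * n + i];
    } else {
      // Conjugate pair: columns j and j+1 hold Re(v) and Im(v); the two
      // eigenvectors are Re(v) + i*Im(v) and Re(v) - i*Im(v).
      for (int i = 0; i < n; ++i) {
        float re = vr[j * n + i];
        float im = vr[(j + 1) * n + i];
        v[j * n + i] = {re, im};
        v[(j + 1) * n + i] = {re, -im};
      }
      ++j;  // the second column of the pair was consumed above
    }
  }
  return v;
}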