Commit diff: 7 files changed, +141 −1 lines.

include/kernel/collectors/attention.h (new file)

#ifndef KERNEL_ATTENTION_H
#define KERNEL_ATTENTION_H

#include "../collector.h"

namespace refactor::kernel {

    struct AttentionCollector final : public InfoCollector {
        dim_t maxSeqLen;

        AttentionCollector(decltype(_target), decltype(maxSeqLen)) noexcept;

        std::vector<KernelBox>
        filter(TensorRefs inputs, TensorRefs outputs) const final;
    };

}// namespace refactor::kernel

#endif// KERNEL_ATTENTION_H

attention.cc for the kernel collector (new file)

#include "kernel/collectors/attention.h"
#include "kernel/kernel.h"
#include "kernel/tensor.h"
// #include "../kernels/attention/cpu_kernel.hh"
// #include "../kernels/attention/cuda_kernel.hh"

namespace refactor::kernel {

    AttentionCollector::AttentionCollector(
        decltype(_target) target,
        decltype(maxSeqLen) maxSeqLen_) noexcept
        : InfoCollector(target),
          maxSeqLen(maxSeqLen_) {}

    std::vector<KernelBox>
    AttentionCollector::filter(TensorRefs inputs, TensorRefs outputs) const {
        std::vector<KernelBox> ans;
        // Kernel registration is still a stub: each target case will push its
        // attention kernels into `ans` once the commented-out kernel headers land.
        switch (_target) {
            case decltype(_target)::Cpu:
                break;
            case decltype(_target)::Nvidia:
                break;
            case decltype(_target)::Mlu:
                break;
            default:
                UNREACHABLEX(void, "Unknown target");
        }
        return ans;
    }

}// namespace refactor::kernel
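
The switch above follows the project's collector pattern: the collector looks at the compile target and returns the kernel candidates that can serve the operator. Below is a standalone sketch of that pattern using only invented placeholder types (Target, Kernel, CpuAttention are illustrations, not this repository's API):

#include <cstdio>
#include <memory>
#include <vector>

enum class Target { Cpu, Nvidia, Mlu };

struct Kernel {
    virtual ~Kernel() = default;
    virtual const char *name() const = 0;
};
using KernelBox = std::unique_ptr<Kernel>;

struct CpuAttention final : Kernel {
    const char *name() const override { return "attention/cpu"; }
};

// Hypothetical stand-in for AttentionCollector::filter: each target branch
// pushes whatever kernels it supports; unsupported targets return nothing.
std::vector<KernelBox> filterAttentionKernels(Target target) {
    std::vector<KernelBox> ans;
    switch (target) {
        case Target::Cpu:
            ans.emplace_back(std::make_unique<CpuAttention>());
            break;
        case Target::Nvidia:
        case Target::Mlu:
            break;// no kernels wired up in this sketch
    }
    return ans;
}

int main() {
    for (auto const &k : filterAttentionKernels(Target::Cpu)) {
        std::printf("candidate kernel: %s\n", k->name());
    }
    return 0;
}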

include/computation/operators/attention.h (new file)

#ifndef COMPUTATION_ATTENTION_H
#define COMPUTATION_ATTENTION_H

#include "../operator.h"

namespace refactor::computation {

    struct Attention final : public Operator {
        dim_t maxSeqLen;

        constexpr Attention(decltype(maxSeqLen) maxSeqLen_) noexcept
            : Operator(), maxSeqLen(maxSeqLen_) {}

        static size_t typeId() noexcept;
        size_t opTypeId() const noexcept final;
        std::string_view name() const noexcept final;
    };

}// namespace refactor::computation

#endif// COMPUTATION_ATTENTION_H

attention.cc for the computation operator (new file)

#include "computation/operators/attention.h"

namespace refactor::computation {
    using Op = Attention;

    auto Op::typeId() noexcept -> size_t {
        // The address of this function-local static is unique and stable,
        // so it serves as the operator's process-wide type id.
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }
    auto Op::opTypeId() const noexcept -> size_t { return typeId(); }
    auto Op::name() const noexcept -> std::string_view { return "Attention"; }

}// namespace refactor::computation
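
The typeId implementation above relies on a common C++ idiom: a function-local static object has a distinct address for the whole run of the program, so that address doubles as a cheap, registry-free type identifier. A self-contained sketch of the same idiom (the template is only used here to show that each type gets a different address):

#include <cstdint>
#include <cstdio>

// Each instantiation owns its own static byte, so &ID differs per type
// and stays stable for the lifetime of the process.
template <class T>
std::size_t typeId() noexcept {
    static std::uint8_t ID = 1;
    return reinterpret_cast<std::size_t>(&ID);
}

struct Attention {};
struct RmsNormalization {};

int main() {
    std::printf("Attention        id: %zu\n", typeId<Attention>());
    std::printf("RmsNormalization id: %zu\n", typeId<RmsNormalization>());
    // The two ids differ, and repeated calls return the same value per type.
    return 0;
}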

attention.cc for the llm frontend operator (new file)

#include "computation/operators/attention.h"
#include "attention.hh"
#include "common.h"

namespace refactor::llm {
    using Op = Attention;

    Op::Attention(decltype(maxSeqLen) maxSeqLen_)
        : Operator(), maxSeqLen(maxSeqLen_) {}

    auto Op::build(ModelContext const &, std::string_view, Attributes attributes) -> OpBox {
        auto maxSeqLen = attributes.getOrInsert("max_seq_len", {0}).float_();
        return OpBox(std::make_unique<Op>(maxSeqLen));
    }
    auto Op::typeId() -> size_t {
        static uint8_t ID = 1;
        return reinterpret_cast<size_t>(&ID);
    }

    auto Op::opTypeId() const -> size_t { return typeId(); }
    auto Op::opTypeName() const -> std::string_view { return "llm::Attention"; }

    auto Op::infer(TensorRefs inputs, InferOptions const &) const -> InferResult {
        // Shape inference is not implemented yet in this commit.
        TODO("");
    }

    auto Op::lower(TensorRefs) const -> computation::OpBox {
        // Lowering to the computation-level Attention is not implemented yet either.
        TODO("");
    }

}// namespace refactor::llm
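
Op::infer is left as a TODO above. For orientation, here is a codebase-independent sketch of the shape rule one would usually expect from an attention operator: the output keeps the query's shape, while the key and value inputs must agree with it on batch, heads, and head dimension. The 4-D (batch, heads, seq, head_dim) convention and the plain-vector shapes are assumptions for illustration, not this project's InferResult API:

#include <cstdint>
#include <optional>
#include <vector>

using Shape = std::vector<std::int64_t>;

// Hypothetical rule: q is (batch, heads, seq_q, head_dim),
// k and v are (batch, heads, seq_kv, head_dim); the output matches q.
std::optional<Shape> inferAttention(Shape const &q, Shape const &k, Shape const &v) {
    if (q.size() != 4 || k.size() != 4 || k != v) {
        return std::nullopt;// wrong rank, or K/V shapes disagree
    }
    if (q[0] != k[0] || q[1] != k[1] || q[3] != k[3]) {
        return std::nullopt;// batch, heads, and head_dim must all match
    }
    return q;// attention output has the query's shape
}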

attention.hh for the llm frontend operator (new file)

#ifndef LLM_RMS_ATTENTION_HH
#define LLM_RMS_ATTENTION_HH

#include "frontend/operator.h"

namespace refactor::llm {
    using namespace frontend;

    struct Attention final : public Operator {
        dim_t maxSeqLen;

        explicit Attention(decltype(maxSeqLen));

        static OpBox build(ModelContext const &, std::string_view, Attributes);
        static size_t typeId();

        size_t opTypeId() const final;
        std::string_view opTypeName() const final;
        InferResult infer(TensorRefs, InferOptions const &) const final;
        computation::OpBox lower(TensorRefs) const final;
    };

}// namespace refactor::llm

#endif// LLM_RMS_ATTENTION_HH

rms_normalization.hh for the llm frontend operator (modified)

@@ -9,7 +9,7 @@ namespace refactor::llm {
     struct RmsNormalization final : public Operator {
         float epsilon;

-        RmsNormalization(decltype(epsilon));
+        explicit RmsNormalization(decltype(epsilon));

         static OpBox build(ModelContext const &, std::string_view, Attributes);
         static size_t typeId();