Skip to content

Commit cfc49ea

Browse files
committed
Enable int32 support for Clamp OP
1 parent 0836071 commit cfc49ea

File tree

4 files changed

+143
-18
lines changed

4 files changed

+143
-18
lines changed

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_dnnl_ext_emitters.hpp

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -85,22 +85,6 @@ class jit_elu_emitter : public jit_dnnl_emitter {
8585
}
8686
};
8787

88-
class jit_clamp_emitter : public jit_dnnl_emitter {
89-
public:
90-
jit_clamp_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
91-
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
92-
const std::shared_ptr<ov::Node>& n,
93-
ov::element::Type exec_prc = ov::element::f32)
94-
: jit_dnnl_emitter(host, host_isa, n, exec_prc) {
95-
kind = dnnl_eltwise_clip;
96-
auto op = ov::as_type_ptr<ov::op::v0::Clamp>(n);
97-
alpha = static_cast<float>(op->get_min());
98-
beta = static_cast<float>(op->get_max());
99-
100-
set_injector();
101-
}
102-
};
103-
10488
class jit_swish_emitter : public jit_dnnl_emitter {
10589
public:
10690
jit_swish_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.cpp

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2923,4 +2923,108 @@ void jit_abs_emitter::register_table_entries() {
29232923
push_arg_entry_of("positive_mask", 0x7fffffff, true);
29242924
}
29252925

2926+
/// CLAMP ///
2927+
jit_clamp_emitter::jit_clamp_emitter(x64::jit_generator_t* host,
2928+
x64::cpu_isa_t host_isa,
2929+
const std::shared_ptr<ov::Node>& node)
2930+
: jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {
2931+
const auto& clamp = ov::as_type_ptr<ov::op::v0::Clamp>(node);
2932+
double alpha = clamp->get_min();
2933+
double beta = clamp->get_max();
2934+
switch (exec_prc_) {
2935+
case element::i32:
2936+
minimum = static_cast<int>(std::max<int64_t>(static_cast<int64_t>(alpha), std::numeric_limits<int32_t>::min()));
2937+
maximum = static_cast<int>(std::min<int64_t>(static_cast<int64_t>(beta), std::numeric_limits<int32_t>::max()));
2938+
break;
2939+
case element::f32:
2940+
minimum = x64::float2int(alpha);
2941+
maximum = x64::float2int(beta);
2942+
break;
2943+
default:
2944+
OV_CPU_JIT_EMITTER_THROW("Unsupported precision");
2945+
}
2946+
prepare_table();
2947+
}
2948+
2949+
jit_clamp_emitter::jit_clamp_emitter(x64::jit_generator_t* host,
2950+
x64::cpu_isa_t host_isa,
2951+
ov::element::Type exec_prc,
2952+
float alpha,
2953+
float beta)
2954+
: jit_emitter(host, host_isa, exec_prc) {
2955+
// TODO: Duplicate code and abstract a method
2956+
switch (exec_prc_) {
2957+
case element::i32:
2958+
minimum = static_cast<int>(std::max<int64_t>(static_cast<int64_t>(alpha), std::numeric_limits<int>::min()));
2959+
maximum = static_cast<int>(std::min<int64_t>(static_cast<int64_t>(beta), std::numeric_limits<int>::max()));
2960+
break;
2961+
case element::f32:
2962+
minimum = x64::float2int(alpha);
2963+
maximum = x64::float2int(beta);
2964+
break;
2965+
default:
2966+
OV_CPU_JIT_EMITTER_THROW("Unsupported precision");
2967+
}
2968+
prepare_table();
2969+
}
2970+
2971+
size_t jit_clamp_emitter::get_inputs_num() const {
2972+
return 1;
2973+
}
2974+
2975+
void jit_clamp_emitter::emit_impl(const std::vector<size_t>& in_vec_idxs,
2976+
const std::vector<size_t>& out_vec_idxs) const {
2977+
if (host_isa_ == x64::sse41) {
2978+
emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
2979+
} else if (host_isa_ == x64::avx2) {
2980+
emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
2981+
} else if (host_isa_ == x64::avx512_core) {
2982+
emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
2983+
} else {
2984+
OV_CPU_JIT_EMITTER_THROW("Unsupported ISA ", host_isa_);
2985+
}
2986+
}
2987+
2988+
template <x64::cpu_isa_t isa>
2989+
void jit_clamp_emitter::emit_isa(const std::vector<size_t>& in_vec_idxs,
2990+
const std::vector<size_t>& out_vec_idxs) const {
2991+
using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
2992+
auto vmm_src0 = Vmm(in_vec_idxs[0]);
2993+
auto vmm_dst = Vmm(out_vec_idxs[0]);
2994+
2995+
auto uni_vclamp = [this](Vmm vmm_dst, Vmm vmm_src0) {
2996+
switch (exec_prc_) {
2997+
case ov::element::f32:
2998+
h->uni_vmaxps(vmm_dst, vmm_src0, table_val("min"));
2999+
h->uni_vminps(vmm_dst, vmm_dst, table_val("max"));
3000+
break;
3001+
case ov::element::i32:
3002+
h->uni_vpmaxsd(vmm_dst, vmm_src0, table_val("min"));
3003+
h->uni_vpminsd(vmm_dst, vmm_dst, table_val("max"));
3004+
break;
3005+
default:
3006+
OV_CPU_JIT_EMITTER_THROW("Unsupported precision");
3007+
}
3008+
};
3009+
3010+
if (isa == x64::sse41) {
3011+
if (vmm_src0.getIdx() != vmm_dst.getIdx()) {
3012+
h->uni_vmovups(vmm_dst, vmm_src0);
3013+
}
3014+
uni_vclamp(vmm_dst, vmm_dst);
3015+
} else {
3016+
uni_vclamp(vmm_dst, vmm_src0);
3017+
}
3018+
}
3019+
3020+
std::set<std::vector<element::Type>> jit_clamp_emitter::get_supported_precisions(
3021+
[[maybe_unused]] const std::shared_ptr<ov::Node>& node) {
3022+
return {{element::f32, element::f32}, {element::i32, element::i32}};
3023+
}
3024+
3025+
void jit_clamp_emitter::register_table_entries() {
3026+
push_arg_entry_of("min", minimum, true);
3027+
push_arg_entry_of("max", maximum, true);
3028+
}
3029+
29263030
} // namespace ov::intel_cpu

src/plugins/intel_cpu/src/emitters/plugin/x64/jit_eltwise_emitters.hpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -954,4 +954,30 @@ class jit_abs_emitter : public jit_emitter {
954954
void register_table_entries() override;
955955
};
956956

957+
class jit_clamp_emitter : public jit_emitter {
958+
public:
959+
jit_clamp_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
960+
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
961+
ov::element::Type exec_prc = ov::element::f32,
962+
float alpha = 0.0f,
963+
float beta = 0.0f);
964+
jit_clamp_emitter(dnnl::impl::cpu::x64::jit_generator_t* host,
965+
dnnl::impl::cpu::x64::cpu_isa_t host_isa,
966+
const std::shared_ptr<ov::Node>& n);
967+
968+
size_t get_inputs_num() const override;
969+
static std::set<std::vector<element::Type>> get_supported_precisions(
970+
const std::shared_ptr<ov::Node>& node = nullptr);
971+
972+
private:
973+
void emit_impl(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const override;
974+
975+
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
976+
void emit_isa(const std::vector<size_t>& in_vec_idxs, const std::vector<size_t>& out_vec_idxs) const;
977+
void register_table_entries() override;
978+
979+
int minimum;
980+
int maximum;
981+
};
982+
957983
} // namespace ov::intel_cpu

src/plugins/intel_cpu/src/nodes/kernels/x64/jit_uni_eltwise_generic.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,17 @@ struct EltwiseEmitter<jit_is_inf_emitter> {
378378
ctx.opData.beta);
379379
}
380380
};
381+
382+
template <>
383+
struct EltwiseEmitter<jit_clamp_emitter> {
384+
void operator()(EltwiseEmitterContext& ctx) {
385+
ctx.emitter = std::make_shared<jit_clamp_emitter>(ctx.host,
386+
ctx.host_isa,
387+
ctx.exec_prc,
388+
ctx.opData.alpha,
389+
ctx.opData.beta);
390+
}
391+
};
381392
} // namespace
382393

383394
template <dnnl::impl::cpu::x64::cpu_isa_t isa>
@@ -398,7 +409,7 @@ std::shared_ptr<jit_emitter> jit_uni_eltwise_generic<isa>::create_eltwise_emitte
398409
OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter),
399410
OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter),
400411
OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter),
401-
OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter),
412+
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
402413
OV_CASE(Algorithm::EltwiseSwish, jit_dnnl_aux_emitter),
403414
OV_CASE(Algorithm::EltwiseHswish, jit_dnnl_aux_emitter),
404415
OV_CASE(Algorithm::EltwiseMish, jit_dnnl_aux_emitter),
@@ -891,7 +902,7 @@ std::set<std::vector<element::Type>> eltwise_precision_helper::get_supported_pre
891902
OV_CASE(Algorithm::EltwiseAbs, jit_abs_emitter),
892903
OV_CASE(Algorithm::EltwiseSqrt, jit_dnnl_aux_emitter),
893904
OV_CASE(Algorithm::EltwiseSoftRelu, jit_dnnl_aux_emitter),
894-
OV_CASE(Algorithm::EltwiseClamp, jit_dnnl_aux_emitter),
905+
OV_CASE(Algorithm::EltwiseClamp, jit_clamp_emitter),
895906
OV_CASE(Algorithm::EltwiseSwish, jit_dnnl_aux_emitter),
896907
OV_CASE(Algorithm::EltwiseHswish, jit_dnnl_aux_emitter),
897908
OV_CASE(Algorithm::EltwiseMish, jit_dnnl_aux_emitter),

0 commit comments

Comments
 (0)