@@ -2923,4 +2923,108 @@ void jit_abs_emitter::register_table_entries() {
2923
2923
push_arg_entry_of (" positive_mask" , 0x7fffffff , true );
2924
2924
}
2925
2925
2926
+ // / CLAMP ///
2927
+ jit_clamp_emitter::jit_clamp_emitter (x64::jit_generator_t * host,
2928
+ x64::cpu_isa_t host_isa,
2929
+ const std::shared_ptr<ov::Node>& node)
2930
+ : jit_emitter(host, host_isa, get_arithmetic_binary_exec_precision(node)) {
2931
+ const auto & clamp = ov::as_type_ptr<ov::op::v0::Clamp>(node);
2932
+ double alpha = clamp->get_min ();
2933
+ double beta = clamp->get_max ();
2934
+ switch (exec_prc_) {
2935
+ case element::i32 :
2936
+ minimum = static_cast <int >(std::max<int64_t >(static_cast <int64_t >(alpha), std::numeric_limits<int32_t >::min ()));
2937
+ maximum = static_cast <int >(std::min<int64_t >(static_cast <int64_t >(beta), std::numeric_limits<int32_t >::max ()));
2938
+ break ;
2939
+ case element::f32 :
2940
+ minimum = x64::float2int (alpha);
2941
+ maximum = x64::float2int (beta);
2942
+ break ;
2943
+ default :
2944
+ OV_CPU_JIT_EMITTER_THROW (" Unsupported precision" );
2945
+ }
2946
+ prepare_table ();
2947
+ }
2948
+
2949
+ jit_clamp_emitter::jit_clamp_emitter (x64::jit_generator_t * host,
2950
+ x64::cpu_isa_t host_isa,
2951
+ ov::element::Type exec_prc,
2952
+ float alpha,
2953
+ float beta)
2954
+ : jit_emitter(host, host_isa, exec_prc) {
2955
+ // TODO: Duplicate code and abstract a method
2956
+ switch (exec_prc_) {
2957
+ case element::i32 :
2958
+ minimum = static_cast <int >(std::max<int64_t >(static_cast <int64_t >(alpha), std::numeric_limits<int >::min ()));
2959
+ maximum = static_cast <int >(std::min<int64_t >(static_cast <int64_t >(beta), std::numeric_limits<int >::max ()));
2960
+ break ;
2961
+ case element::f32 :
2962
+ minimum = x64::float2int (alpha);
2963
+ maximum = x64::float2int (beta);
2964
+ break ;
2965
+ default :
2966
+ OV_CPU_JIT_EMITTER_THROW (" Unsupported precision" );
2967
+ }
2968
+ prepare_table ();
2969
+ }
2970
+
2971
+ size_t jit_clamp_emitter::get_inputs_num () const {
2972
+ return 1 ;
2973
+ }
2974
+
2975
+ void jit_clamp_emitter::emit_impl (const std::vector<size_t >& in_vec_idxs,
2976
+ const std::vector<size_t >& out_vec_idxs) const {
2977
+ if (host_isa_ == x64::sse41) {
2978
+ emit_isa<x64::sse41>(in_vec_idxs, out_vec_idxs);
2979
+ } else if (host_isa_ == x64::avx2) {
2980
+ emit_isa<x64::avx2>(in_vec_idxs, out_vec_idxs);
2981
+ } else if (host_isa_ == x64::avx512_core) {
2982
+ emit_isa<x64::avx512_core>(in_vec_idxs, out_vec_idxs);
2983
+ } else {
2984
+ OV_CPU_JIT_EMITTER_THROW (" Unsupported ISA " , host_isa_);
2985
+ }
2986
+ }
2987
+
2988
+ template <x64::cpu_isa_t isa>
2989
+ void jit_clamp_emitter::emit_isa (const std::vector<size_t >& in_vec_idxs,
2990
+ const std::vector<size_t >& out_vec_idxs) const {
2991
+ using Vmm = typename conditional3<isa == x64::sse41, Xmm, isa == x64::avx2, Ymm, Zmm>::type;
2992
+ auto vmm_src0 = Vmm (in_vec_idxs[0 ]);
2993
+ auto vmm_dst = Vmm (out_vec_idxs[0 ]);
2994
+
2995
+ auto uni_vclamp = [this ](Vmm vmm_dst, Vmm vmm_src0) {
2996
+ switch (exec_prc_) {
2997
+ case ov::element::f32 :
2998
+ h->uni_vmaxps (vmm_dst, vmm_src0, table_val (" min" ));
2999
+ h->uni_vminps (vmm_dst, vmm_dst, table_val (" max" ));
3000
+ break ;
3001
+ case ov::element::i32 :
3002
+ h->uni_vpmaxsd (vmm_dst, vmm_src0, table_val (" min" ));
3003
+ h->uni_vpminsd (vmm_dst, vmm_dst, table_val (" max" ));
3004
+ break ;
3005
+ default :
3006
+ OV_CPU_JIT_EMITTER_THROW (" Unsupported precision" );
3007
+ }
3008
+ };
3009
+
3010
+ if (isa == x64::sse41) {
3011
+ if (vmm_src0.getIdx () != vmm_dst.getIdx ()) {
3012
+ h->uni_vmovups (vmm_dst, vmm_src0);
3013
+ }
3014
+ uni_vclamp (vmm_dst, vmm_dst);
3015
+ } else {
3016
+ uni_vclamp (vmm_dst, vmm_src0);
3017
+ }
3018
+ }
3019
+
3020
+ std::set<std::vector<element::Type>> jit_clamp_emitter::get_supported_precisions (
3021
+ [[maybe_unused]] const std::shared_ptr<ov::Node>& node) {
3022
+ return {{element::f32 , element::f32 }, {element::i32 , element::i32 }};
3023
+ }
3024
+
3025
+ void jit_clamp_emitter::register_table_entries () {
3026
+ push_arg_entry_of (" min" , minimum, true );
3027
+ push_arg_entry_of (" max" , maximum, true );
3028
+ }
3029
+
2926
3030
} // namespace ov::intel_cpu
0 commit comments