Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Sve.Scatter() #104555

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26445,6 +26445,10 @@ bool GenTreeHWIntrinsic::OperIsMemoryStore(GenTree** pAddr) const
addr = Op(2);
break;

case NI_Sve_Scatter:
addr = Op(2);
break;

#endif // TARGET_ARM64

default:
Expand Down Expand Up @@ -26486,7 +26490,11 @@ bool GenTreeHWIntrinsic::OperIsMemoryStore(GenTree** pAddr) const

if (addr != nullptr)
{
#ifdef TARGET_ARM64
assert(varTypeIsI(addr) || (varTypeIsSIMD(addr) && ((intrinsicId >= NI_Sve_Scatter))));
#else
assert(varTypeIsI(addr));
#endif
return true;
}

Expand Down
23 changes: 22 additions & 1 deletion src/coreclr/jit/hwintrinsic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1869,7 +1869,11 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
case NI_Sve_GatherVectorUInt32ZeroExtend:
case NI_Sve_GatherVectorWithByteOffsets:
assert(varTypeIsSIMD(op3->TypeGet()));
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
if (numArgs == 3)
{
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(
getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
}
break;
#endif

Expand All @@ -1885,6 +1889,23 @@ GenTree* Compiler::impHWIntrinsic(NamedIntrinsic intrinsic,
assert(!isScalar);
retNode =
gtNewSimdHWIntrinsicNode(nodeRetType, op1, op2, op3, op4, intrinsic, simdBaseJitType, simdSize);

switch (intrinsic)
{
#if defined(TARGET_ARM64)
case NI_Sve_Scatter:
assert(varTypeIsSIMD(op3->TypeGet()));
if (numArgs == 4)
{
retNode->AsHWIntrinsic()->SetAuxiliaryJitType(
getBaseJitTypeOfSIMDType(sigReader.op3ClsHnd));
}
break;
#endif

default:
break;
}
break;
}

Expand Down
33 changes: 33 additions & 0 deletions src/coreclr/jit/hwintrinsiccodegenarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2052,6 +2052,39 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
GetEmitter()->emitIns_R_R(ins, emitSize, targetReg, op1Reg, opt);
break;

case NI_Sve_Scatter:
{
if (!varTypeIsSIMD(intrin.op2->gtType))
{
// Scatter(Vector<T1> mask, T1* address, Vector<T2> indicies, Vector<T> data)
assert(intrin.numOperands == 4);
SwapnilGaikwad marked this conversation as resolved.
Show resolved Hide resolved
emitAttr baseSize = emitActualTypeSize(intrin.baseType);

if (baseSize == EA_8BYTE)
{
// Index is multiplied by 8
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op2Reg, op3Reg, opt,
INS_SCALABLE_OPTS_LSL_N);
}
else
{
// Index is sign or zero extended to 64bits, then multiplied by 4
assert(baseSize == EA_4BYTE);
opt = varTypeIsUnsigned(node->GetAuxiliaryType()) ? INS_OPTS_SCALABLE_S_UXTW
: INS_OPTS_SCALABLE_S_SXTW;
GetEmitter()->emitIns_R_R_R_R(ins, emitSize, op4Reg, op1Reg, op2Reg, op3Reg, opt,
INS_SCALABLE_OPTS_MOD_N);
}
}
else
{
// Scatter(Vector<T> mask, Vector<T> addresses, Vector<T> data)
assert(intrin.numOperands == 3);
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);
}
break;
}

case NI_Sve_StoreNarrowing:
opt = emitter::optGetSveInsOpt(emitTypeSize(intrin.baseType));
GetEmitter()->emitIns_R_R_R_I(ins, emitSize, op3Reg, op1Reg, op2Reg, 0, opt);
Expand Down
1 change: 1 addition & 0 deletions src/coreclr/jit/hwintrinsiclistarm64sve.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ HARDWARE_INTRINSIC(Sve, SaturatingIncrementBy64BitElementCount,
HARDWARE_INTRINSIC(Sve, SaturatingIncrementBy8BitElementCount, 0, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_sqincb, INS_sve_uqincb, INS_sve_sqincb, INS_sve_uqincb, INS_invalid, INS_invalid}, HW_Category_Scalar, HW_Flag_Scalable|HW_Flag_HasImmediateOperand|HW_Flag_HasEnumOperand|HW_Flag_SpecialCodeGen|HW_Flag_SpecialImport|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, SaturatingIncrementByActiveElementCount, -1, 2, true, {INS_invalid, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_sve_sqincp, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_SpecialImport|HW_Flag_BaseTypeFromSecondArg|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, Scale, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_fscale, INS_sve_fscale}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, Scatter, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_st1w, INS_sve_st1w, INS_sve_st1d, INS_sve_st1d, INS_sve_st1w, INS_sve_st1d}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation)
HARDWARE_INTRINSIC(Sve, ShiftLeftLogical, -1, -1, false, {INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_sve_lsl, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, ShiftRightArithmetic, -1, -1, false, {INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_sve_asr, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics)
HARDWARE_INTRINSIC(Sve, ShiftRightArithmeticForDivide, -1, -1, false, {INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_sve_asrd, INS_invalid, INS_invalid, INS_invalid}, HW_Category_ShiftRightByImmediate, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_HasImmediateOperand)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4175,7 +4175,7 @@ internal Arm64() { }
/// svuint8_t svcls[_s8]_z(svbool_t pg, svint8_t op)
/// CLS Ztied.B, Pg/M, Zop.B
/// </summary>
public static unsafe Vector<byte> LeadingSignCount(Vector<sbyte> value){ throw new PlatformNotSupportedException(); }
public static unsafe Vector<byte> LeadingSignCount(Vector<sbyte> value) { throw new PlatformNotSupportedException(); }

/// <summary>
/// svuint16_t svcls[_s16]_m(svuint16_t inactive, svbool_t pg, svint16_t op)
Expand Down Expand Up @@ -7144,6 +7144,120 @@ internal Arm64() { }
public static unsafe Vector<float> Scale(Vector<float> left, Vector<int> right) { throw new PlatformNotSupportedException(); }


// Non-truncating store

// <summary>
// void svst1_scatter_[s64]offset[_f64](svbool_t pg, float64_t *base, svint64_t offsets, svfloat64_t data)
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
// </summary>
public static unsafe void Scatter(Vector<double> mask, double* address, Vector<long> indicies, Vector<double> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter[_u64base_f64](svbool_t pg, svuint64_t bases, svfloat64_t data)
// ST1D Zdata.D, Pg, [Zbases.D, #0]
// </summary>
public static unsafe void Scatter(Vector<double> mask, Vector<ulong> addresses, Vector<double> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[u64]offset[_f64](svbool_t pg, float64_t *base, svuint64_t offsets, svfloat64_t data)
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
// </summary>
public static unsafe void Scatter(Vector<double> mask, double* address, Vector<ulong> indicies, Vector<double> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[s32]offset[_s32](svbool_t pg, int32_t *base, svint32_t offsets, svint32_t data)
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
// </summary>
public static unsafe void Scatter(Vector<int> mask, int* address, Vector<int> indicies, Vector<int> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter[_u32base_s32](svbool_t pg, svuint32_t bases, svint32_t data)
// ST1W Zdata.S, Pg, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe void Scatter(Vector<int> mask, Vector<uint> addresses, Vector<int> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[u32]offset[_s32](svbool_t pg, int32_t *base, svuint32_t offsets, svint32_t data)
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
// </summary>
public static unsafe void Scatter(Vector<int> mask, int* address, Vector<uint> indicies, Vector<int> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[s64]offset[_s64](svbool_t pg, int64_t *base, svint64_t offsets, svint64_t data)
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
// </summary>
public static unsafe void Scatter(Vector<long> mask, long* address, Vector<long> indicies, Vector<long> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter[_u64base_s64](svbool_t pg, svuint64_t bases, svint64_t data)
// ST1D Zdata.D, Pg, [Zbases.D, #0]
// </summary>
public static unsafe void Scatter(Vector<long> mask, Vector<ulong> addresses, Vector<long> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[u64]offset[_s64](svbool_t pg, int64_t *base, svuint64_t offsets, svint64_t data)
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
// </summary>
public static unsafe void Scatter(Vector<long> mask, long* address, Vector<ulong> indicies, Vector<long> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[s32]offset[_f32](svbool_t pg, float32_t *base, svint32_t offsets, svfloat32_t data)
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
// </summary>
public static unsafe void Scatter(Vector<float> mask, float* address, Vector<int> indicies, Vector<float> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter[_u32base_f32](svbool_t pg, svuint32_t bases, svfloat32_t data)
// ST1W Zdata.S, Pg, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe void Scatter(Vector<float> mask, Vector<uint> addresses, Vector<float> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[u32]offset[_f32](svbool_t pg, float32_t *base, svuint32_t offsets, svfloat32_t data)
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
// </summary>
public static unsafe void Scatter(Vector<float> mask, float* address, Vector<uint> indicies, Vector<float> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[s32]offset[_u32](svbool_t pg, uint32_t *base, svint32_t offsets, svuint32_t data)
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, SXTW]
// </summary>
public static unsafe void Scatter(Vector<uint> mask, uint* address, Vector<int> indicies, Vector<uint> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter[_u32base_u32](svbool_t pg, svuint32_t bases, svuint32_t data)
// ST1W Zdata.S, Pg, [Zbases.S, #0]
// </summary>
// Removed as per #103297
// public static unsafe void Scatter(Vector<uint> mask, Vector<uint> addresses, Vector<uint> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[u32]offset[_u32](svbool_t pg, uint32_t *base, svuint32_t offsets, svuint32_t data)
// ST1W Zdata.S, Pg, [Xbase, Zoffsets.S, UXTW]
// </summary>
public static unsafe void Scatter(Vector<uint> mask, uint* address, Vector<uint> indicies, Vector<uint> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[s64]offset[_u64](svbool_t pg, uint64_t *base, svint64_t offsets, svuint64_t data)
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
// </summary>
public static unsafe void Scatter(Vector<ulong> mask, ulong* address, Vector<long> indicies, Vector<ulong> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter[_u64base_u64](svbool_t pg, svuint64_t bases, svuint64_t data)
// ST1D Zdata.D, Pg, [Zbases.D, #0]
// </summary>
public static unsafe void Scatter(Vector<ulong> mask, Vector<ulong> addresses, Vector<ulong> data) { throw new PlatformNotSupportedException(); }

// <summary>
// void svst1_scatter_[u64]offset[_u64](svbool_t pg, uint64_t *base, svuint64_t offsets, svuint64_t data)
// ST1D Zdata.D, Pg, [Xbase, Zoffsets.D]
// </summary>
public static unsafe void Scatter(Vector<ulong> mask, ulong* address, Vector<ulong> indicies, Vector<ulong> data) { throw new PlatformNotSupportedException(); }


/// Logical shift left

/// <summary>
Expand Down
Loading
Loading