[API Proposal]: AVX-512 Masking helper functions #96986

Closed
MineCake147E opened this issue Jan 15, 2024 · 8 comments
Labels
api-suggestion Early API idea and discussion, it is NOT ready for implementation area-System.Runtime.Intrinsics


MineCake147E commented Jan 15, 2024

Background and motivation

#87097 aims to implement some AVX-512 mask-related intrinsics.
It was intended as a way to handle mask-related operations while retaining source compatibility with older hardware that supports neither AVX-512 nor SVE2.

It makes sense to use the existing Vector*<T> types to define a register (or variable) that behaves like a mask, as we usually do in AVX2, which has no physical mask registers at all. AVX2 lets a vector register act as a mask by looking at the most significant bit of each element.
But the ideas in #87097 rely extensively on RyuJIT pattern-matching, which I don't consider reliable, predictable, or optimal for now.
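
To make the most-significant-bit convention concrete, here is a minimal sketch that packs the per-element sign bits of a comparison result into a scalar bitmask with the existing Vector128.GreaterThan and Vector128.ExtractMostSignificantBits APIs. The class and method names (MsbMaskDemo, GreaterThanMask) are made up for this illustration and are not part of the proposal.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MsbMaskDemo
{
    // Compares element-wise; each result element is all-ones (true) or zero (false).
    // The most significant bit of each element is then packed into bit i of the result.
    public static uint GreaterThanMask(Vector128<int> left, Vector128<int> right)
    {
        Vector128<int> comparison = Vector128.GreaterThan(left, right);
        return Vector128.ExtractMostSignificantBits(comparison);
    }
}
```

For example, comparing (5, 1, 7, 0) against (3, 2, 6, 0) sets bits 0 and 2, so the returned value is 0b0101.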

Even worse, Vector*<T>.ConditionalSelect<T>(condition, left, right) always emits something equivalent to vpternlogd condition, left, right, 0xca, no matter how mask-like the condition is, as the documentation says:

Conditionally selects a value from two vectors on a bitwise basis.

When the condition is in a mask register, the JIT emits vpmovm2* just to copy the value into a vector register, which takes 3 precious clock cycles on Intel CPUs.
As a result, the whole operation is slower than using the instruction better suited to the job, such as vpblendm* on AVX-512 hardware.
On the other hand, when the condition has been spilled into a vector register, for example after several mask-writing instructions such as vpcmpeqb, bringing it back into a mask register emits vpmov*2m, which also takes 3 clock cycles on Intel CPUs.
In that case, vpternlogd condition, left, right, 0xca is actually better than reloading the mask register.
Vector*<T>.ConditionalSelect<T>(condition, left, right) always takes the vpternlogd approach, as the documentation says, so a dedicated masked blending function is needed anyway.
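
The difference between a bitwise select and an element-wise masked select can be seen with a condition that is not "mask-like". The sketch below uses the existing Vector128.ConditionalSelect API for the bitwise case; SelectByMostSignificantBit is a hypothetical scalar model of an element-wise select keyed on each element's most significant bit, written only for comparison. It is not how the JIT or the proposed APIs would be implemented.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class SelectDemo
{
    // Bitwise: every bit of the result comes from left where the condition bit is 1,
    // and from right where it is 0.
    public static Vector128<uint> BitwiseSelect(
        Vector128<uint> condition, Vector128<uint> left, Vector128<uint> right)
        => Vector128.ConditionalSelect(condition, left, right);

    // Hypothetical element-wise model: the whole element comes from left when the
    // most significant bit of the condition element is 1, otherwise from right.
    public static Vector128<uint> SelectByMostSignificantBit(
        Vector128<uint> condition, Vector128<uint> left, Vector128<uint> right)
    {
        uint bits = Vector128.ExtractMostSignificantBits(condition);
        uint Pick(int i) => ((bits >> i) & 1u) != 0 ? left.GetElement(i) : right.GetElement(i);
        return Vector128.Create(Pick(0), Pick(1), Pick(2), Pick(3));
    }
}
```

With condition elements of 0x0000FFFF, left elements of 0x11112222, and right elements of 0x33334444, the bitwise select yields 0x33332222 in every element, while the element-wise model yields 0x33334444 because the most significant bit of each condition element is 0.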

Also, as I noted in #92261, pattern-matching that drastically changes the behavior of functions hurts the readability of the code, no matter how much effort is invested to overcome it.

Hardware Intrinsics used to be 'specific' in the .NET Core 3.1 days: every single function specified, at least implicitly, the instructions to be executed.
.NET 5 didn't change anything about that; at most, it made them more readable and portable.
.NET 6 just made them more readable and portable.
.NET 7 just made them more readable and portable as well.
.NET 8 almost broke this consistency as of RC 1, as I wrote in #92261, but the decision was reverted in RC 2.

Hardware Intrinsics should stay consistent forever, even as they adopt many new instructions and new optimization techniques.
So I need these new APIs to be added in .NET 9.

API Proposal

Cross-platform APIs

BlendVariable is designed so that its parameter order matches that of Avx512BW.BlendVariable.

namespace System.Runtime.Intrinsics
{
    // This lets an analyzer ensure that the parameter is already in a mask register.
    [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
    public sealed class MaskExpectedAttribute : Attribute
    {
    }

    // To be added to the `condition` parameter of the ConditionalSelect methods.
    [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
    public sealed class MaskNotExpectedAttribute : Attribute
    {
    }

    // This lets an analyzer assume that the result is stored in a mask register.
    [AttributeUsage(AttributeTargets.ReturnValue, Inherited = false, AllowMultiple = false)]
    public sealed class MaskAttribute : Attribute
    {
    }

    public static class Vector512
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs a masked merging operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is taken from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector512<T> MergeWith<T>(this Vector512<T> newValue, Vector512<T> destinationOperand, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Performs a masked zeroing operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is 0.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector512<T> ZeroIfNot<T>(this Vector512<T> newValue, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="right"/>; otherwise, it is taken from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector512<T> BlendVariable<T>(Vector512<T> left, Vector512<T> right, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector512<T> MaskifyMostSignificantBits<T>(this Vector512<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose elements should have their most significant bit extracted and stored in a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector512<T> VectorifyMask<T>([MaskExpected] this Vector512<T> mask);

        [Mask]
        public static Vector512<float> CreateMaskSingle(short value);
        [Mask]
        public static Vector512<float> CreateMaskSingle(ushort value);
        [Mask]
        public static Vector512<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector512<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector512<Half> CreateMaskHalf(int value);
        [Mask]
        public static Vector512<Half> CreateMaskHalf(uint value);
        [Mask]
        public static Vector512<byte> CreateMaskByte(long value);
        [Mask]
        public static Vector512<byte> CreateMaskByte(ulong value);
        [Mask]
        public static Vector512<sbyte> CreateMaskSByte(long value);
        [Mask]
        public static Vector512<sbyte> CreateMaskSByte(ulong value);
        [Mask]
        public static Vector512<ushort> CreateMaskUInt16(int value);
        [Mask]
        public static Vector512<ushort> CreateMaskUInt16(uint value);
        [Mask]
        public static Vector512<short> CreateMaskInt16(int value);
        [Mask]
        public static Vector512<short> CreateMaskInt16(uint value);
        [Mask]
        public static Vector512<uint> CreateMaskUInt32(short value);
        [Mask]
        public static Vector512<uint> CreateMaskUInt32(ushort value);
        [Mask]
        public static Vector512<int> CreateMaskInt32(short value);
        [Mask]
        public static Vector512<int> CreateMaskInt32(ushort value);
        [Mask]
        public static Vector512<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector512<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector512<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector512<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector512<T> MaskShiftLeft<T>([MaskExpected] Vector512<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector512<T> MaskShiftRightLogical<T>([MaskExpected] Vector512<T> left, [ConstantExpected] byte right);
    }

    public static class Vector256
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs a masked merging operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is taken from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector256<T> MergeWith<T>(this Vector256<T> newValue, Vector256<T> destinationOperand, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Performs a masked zeroing operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is 0.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector256<T> ZeroIfNot<T>(this Vector256<T> newValue, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="right"/>; otherwise, it is taken from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector256<T> BlendVariable<T>(Vector256<T> left, Vector256<T> right, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector256<T> MaskifyMostSignificantBits<T>(this Vector256<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose elements should have their most significant bit extracted and stored in a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector256<T> VectorifyMask<T>([MaskExpected] this Vector256<T> mask);

        [Mask]
        public static Vector256<float> CreateMaskSingle(sbyte value);
        [Mask]
        public static Vector256<float> CreateMaskSingle(byte value);
        [Mask]
        public static Vector256<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector256<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector256<Half> CreateMaskHalf(short value);
        [Mask]
        public static Vector256<Half> CreateMaskHalf(ushort value);
        [Mask]
        public static Vector256<byte> CreateMaskByte(int value);
        [Mask]
        public static Vector256<byte> CreateMaskByte(uint value);
        [Mask]
        public static Vector256<sbyte> CreateMaskSByte(int value);
        [Mask]
        public static Vector256<sbyte> CreateMaskSByte(uint value);
        [Mask]
        public static Vector256<ushort> CreateMaskUInt16(short value);
        [Mask]
        public static Vector256<ushort> CreateMaskUInt16(ushort value);
        [Mask]
        public static Vector256<short> CreateMaskInt16(short value);
        [Mask]
        public static Vector256<short> CreateMaskInt16(ushort value);
        [Mask]
        public static Vector256<uint> CreateMaskUInt32(sbyte value);
        [Mask]
        public static Vector256<uint> CreateMaskUInt32(byte value);
        [Mask]
        public static Vector256<int> CreateMaskInt32(sbyte value);
        [Mask]
        public static Vector256<int> CreateMaskInt32(byte value);
        [Mask]
        public static Vector256<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector256<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector256<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector256<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector256<T> MaskShiftLeft<T>([MaskExpected] Vector256<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector256<T> MaskShiftRightLogical<T>([MaskExpected] Vector256<T> left, [ConstantExpected] byte right);
    }

    public static class Vector128
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs a masked merging operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is taken from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector128<T> MergeWith<T>(this Vector128<T> newValue, Vector128<T> destinationOperand, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Performs a masked zeroing operation.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="newValue"/>; otherwise, it is 0.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector128<T> ZeroIfNot<T>(this Vector128<T> newValue, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If the corresponding element of <paramref name="mask"/> is 1, the element is taken from <paramref name="right"/>; otherwise, it is taken from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector128<T> BlendVariable<T>(Vector128<T> left, Vector128<T> right, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector128<T> MaskifyMostSignificantBits<T>(this Vector128<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose elements should have their most significant bit extracted and stored in a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector128<T> VectorifyMask<T>([MaskExpected] this Vector128<T> mask);

        [Mask]
        public static Vector128<float> CreateMaskSingle(sbyte value);
        [Mask]
        public static Vector128<float> CreateMaskSingle(byte value);
        [Mask]
        public static Vector128<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector128<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector128<Half> CreateMaskHalf(sbyte value);
        [Mask]
        public static Vector128<Half> CreateMaskHalf(byte value);
        [Mask]
        public static Vector128<byte> CreateMaskByte(short value);
        [Mask]
        public static Vector128<byte> CreateMaskByte(ushort value);
        [Mask]
        public static Vector128<sbyte> CreateMaskSByte(short value);
        [Mask]
        public static Vector128<sbyte> CreateMaskSByte(ushort value);
        [Mask]
        public static Vector128<ushort> CreateMaskUInt16(sbyte value);
        [Mask]
        public static Vector128<ushort> CreateMaskUInt16(byte value);
        [Mask]
        public static Vector128<short> CreateMaskInt16(sbyte value);
        [Mask]
        public static Vector128<short> CreateMaskInt16(byte value);
        [Mask]
        public static Vector128<uint> CreateMaskUInt32(sbyte value);
        [Mask]
        public static Vector128<uint> CreateMaskUInt32(byte value);
        [Mask]
        public static Vector128<int> CreateMaskInt32(sbyte value);
        [Mask]
        public static Vector128<int> CreateMaskInt32(byte value);
        [Mask]
        public static Vector128<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector128<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector128<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector128<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector128<T> MaskShiftLeft<T>([MaskExpected] Vector128<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector128<T> MaskShiftRightLogical<T>([MaskExpected] Vector128<T> left, [ConstantExpected] byte right);
    }
}
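
As a reading aid for the doc comments above, the intended per-element behavior of MergeWith, ZeroIfNot, and BlendVariable can be written as a scalar reference model. This is only a sketch: the proposal passes masks as Vector*<T> values, whereas the model below represents vectors as int arrays and the mask as the bits of a ulong, which is an assumption made purely for illustration. The class name MaskSemanticsModel is hypothetical.

```csharp
using System;

// Scalar reference model only. The method names mirror the proposal, but the
// int[] / ulong representation is an assumption for illustration and supports
// at most 64 elements (one mask bit per element).
public static class MaskSemanticsModel
{
    // Bit i of mask set: take newValue[i]; otherwise keep destinationOperand[i].
    public static int[] MergeWith(int[] newValue, int[] destinationOperand, ulong mask)
    {
        if (newValue.Length != destinationOperand.Length || newValue.Length > 64)
            throw new ArgumentException("Operands must have the same length of at most 64.");

        var result = new int[newValue.Length];
        for (int i = 0; i < result.Length; i++)
            result[i] = ((mask >> i) & 1UL) != 0 ? newValue[i] : destinationOperand[i];
        return result;
    }

    // Bit i of mask set: take newValue[i]; otherwise produce 0.
    public static int[] ZeroIfNot(int[] newValue, ulong mask)
        => MergeWith(newValue, new int[newValue.Length], mask);

    // Same selection as MergeWith with the operands in (left, right, mask) order:
    // bit i of mask set takes right[i], otherwise left[i].
    public static int[] BlendVariable(int[] left, int[] right, ulong mask)
        => MergeWith(right, left, mask);
}
```

Under this model, MergeWith({1, 2, 3, 4}, {10, 20, 30, 40}, 0b0101) gives {1, 20, 3, 40}, and ZeroIfNot({1, 2, 3, 4}, 0b0110) gives {0, 2, 3, 0}.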

Additional AVX-512 Intrinsics

Some bit operations are omitted from this list.

namespace System.Runtime.Intrinsics.X86
{
    public abstract class Avx512F : Avx2
    {
        [Mask]
        public static Vector512<float> MaskUnpack16([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<float> MaskUnpack16([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector512<uint> MaskUnpack16([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<int> MaskUnpack16([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);
        [Mask]
        public static Vector512<uint> MaskUnpack16([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector512<int> MaskUnpack16([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector256<Half> MaskUnpack16([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector256<ushort> MaskUnpack16([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<short> MaskUnpack16([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector256<Half> MaskUnpack16([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector256<ushort> MaskUnpack16([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector256<short> MaskUnpack16([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector128<byte> MaskUnpack16([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<sbyte> MaskUnpack16([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);

        [Mask]
        public static Vector512<float> MaskXnor([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector256<Half> MaskXnor([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector128<byte> MaskXnor([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector128<sbyte> MaskXnor([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector256<ushort> MaskXnor([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<short> MaskXnor([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector512<uint> MaskXnor([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<int> MaskXnor([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
    }

    public abstract class Avx512DQ : Avx512F
    {
        [Mask]
        public static Vector512<float> MaskAdd([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector256<Half> MaskAdd([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector128<byte> MaskAdd([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector128<sbyte> MaskAdd([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector256<ushort> MaskAdd([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<short> MaskAdd([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector512<uint> MaskAdd([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<int> MaskAdd([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
        [Mask]
        public static Vector256<float> MaskAdd([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<double> MaskAdd([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector128<Half> MaskAdd([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector128<ushort> MaskAdd([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<short> MaskAdd([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector256<uint> MaskAdd([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<int> MaskAdd([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector512<ulong> MaskAdd([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<long> MaskAdd([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);

        [Mask]
        public static Vector256<float> MaskXnor([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<double> MaskXnor([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector128<Half> MaskXnor([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector128<ushort> MaskXnor([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<short> MaskXnor([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector256<uint> MaskXnor([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<int> MaskXnor([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector512<ulong> MaskXnor([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<long> MaskXnor([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);

    }
    
    public abstract class Avx512BW : Avx512F
    {
        [Mask]
        public static Vector512<Half> MaskUnpack32([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector512<Half> MaskUnpack32([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector512<ushort> MaskUnpack32([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<short> MaskUnpack32([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
        [Mask]
        public static Vector512<ushort> MaskUnpack32([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector512<short> MaskUnpack32([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector256<byte> MaskUnpack32([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<sbyte> MaskUnpack32([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector256<byte> MaskUnpack32([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskUnpack32([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector512<byte> MaskUnpack64([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<sbyte> MaskUnpack64([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskUnpack64([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskUnpack64([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);

        [Mask]
        public static Vector512<Half> MaskAdd([MaskExpected] Vector512<Half> left, [MaskExpected] Vector512<Half> right);
        [Mask]
        public static Vector256<byte> MaskAdd([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskAdd([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);
        [Mask]
        public static Vector512<ushort> MaskAdd([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<short> MaskAdd([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskAdd([MaskExpected] Vector512<byte> left, [MaskExpected] Vector512<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskAdd([MaskExpected] Vector512<sbyte> left, [MaskExpected] Vector512<sbyte> right);

        [Mask]
        public static Vector512<Half> MaskXnor([MaskExpected] Vector512<Half> left, [MaskExpected] Vector512<Half> right);
        [Mask]
        public static Vector256<byte> MaskXnor([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskXnor([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);
        [Mask]
        public static Vector512<ushort> MaskXnor([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<short> MaskXnor([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskXnor([MaskExpected] Vector512<byte> left, [MaskExpected] Vector512<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskXnor([MaskExpected] Vector512<sbyte> left, [MaskExpected] Vector512<sbyte> right);
    }
}
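
For readers less familiar with mask-register arithmetic, here is a minimal scalar sketch of what the mask-to-mask helpers above would compute. It assumes MaskXnor, MaskAdd and MaskUnpack16 map to the kxnor*, kadd* and kunpck* instruction families; the MaskRegisterModel class and its method names are hypothetical and only model a 16-lane mask held in a ushort.

```csharp
using System;

// Hypothetical scalar model: a 16-lane mask is held in a ushort, one bit per lane.
public static class MaskRegisterModel
{
    // Models kxnorw: a result bit is set where both inputs agree.
    public static ushort Xnor16(ushort left, ushort right)
        => unchecked((ushort)~(left ^ right));

    // Models kaddw: integer addition of the two mask values, wrapping at 16 bits.
    public static ushort Add16(ushort left, ushort right)
        => unchecked((ushort)(left + right));

    // Models kunpckbw: the low 8 bits of left become the high half,
    // the low 8 bits of right become the low half.
    public static ushort Unpack16(ushort left, ushort right)
        => (ushort)(((left & 0xFF) << 8) | (right & 0xFF));
}
```

None of this touches a real mask register; it only pins down the bit-level results that the proposed signatures would be expected to produce.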

API Usage

Merge-masking and zero-masking can be written like this:

zmm0 = Avx512BW.Subtract(zmm1, zmm2).MergeWith(zmm0, k1);   // vpsubb zmm0 {k1}, zmm1, zmm2
zmm0 = Avx512BW.Subtract(zmm3, zmm4).ZeroIfNot(k2);         // vpsubb zmm0 {k2}{z}, zmm3, zmm4
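
The difference between the two forms is easier to see on plain arrays. The sketch below is a scalar model only, assuming one mask bit per lane with bit 0 controlling lane 0; MaskingModel is a hypothetical helper, not part of the proposed API.

```csharp
using System;

// Hypothetical scalar model of the two masking modes (at most 32 lanes).
public static class MaskingModel
{
    // Merge-masking: lanes whose mask bit is 0 keep the old destination value.
    public static int[] MergeWith(int[] newValue, int[] destination, uint mask)
    {
        var result = new int[newValue.Length];
        for (int i = 0; i < result.Length; i++)
            result[i] = ((mask >> i) & 1) != 0 ? newValue[i] : destination[i];
        return result;
    }

    // Zero-masking: lanes whose mask bit is 0 are cleared.
    public static int[] ZeroIfNot(int[] newValue, uint mask)
    {
        var result = new int[newValue.Length];
        for (int i = 0; i < result.Length; i++)
            result[i] = ((mask >> i) & 1) != 0 ? newValue[i] : 0;
        return result;
    }
}
```

With a mask of 0b0101, MergeWith keeps lanes 1 and 3 from the destination, while ZeroIfNot sets them to 0.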

Alternative Designs

There may be better names for each identifier.

Risks

None I can come up with.

@MineCake147E MineCake147E added the api-suggestion Early API idea and discussion, it is NOT ready for implementation label Jan 15, 2024
@ghost ghost added the untriaged New issue has not been triaged by the area owner label Jan 15, 2024
@ghost
ghost commented Jan 15, 2024

Tagging subscribers to this area: @dotnet/area-system-runtime-intrinsics
See info in area-owners.md if you want to be subscribed.

Issue Details

Background and motivation

#87097 aims to implement some AVX-512 mask-related intrinsics.
It was supposed to be a solution to deal with some mask-related things while also retaining the code compatibility with older hardware that supports neither AVX-512 nor SVE2.

It makes sense to use existing Vector*<T> types to define a register (or variable) that behaves like a mask, as we usually do in AVX2, which doesn't even have a physical mask register at all. AVX2 provides a way to treat a vector register as a mask by looking at the most significant bit of each element.
But the ideas in #87097 rely extensively on RyuJIT pattern-matching, which I don't think is reliable, predictable, or optimal for now.

Even worse, Vector*<T>.ConditionalSelect<T>(condition, left, right) always emits something equivalent to vpternlogd condition, left, right, 0xca, no matter how mask-like the condition is, as the document says:

Conditionally selects a value from two vectors on a bitwise basis.

When the condition is in a mask register, it emits vpmovm2* just to copy the value to a vector register, which takes 3 precious clock cycles to complete on Intel CPUs.
As a result, the whole operation is slower than using an instruction better suited to the job, like vpblendm* in AVX-512 environments.
On the other hand, when the condition has somehow been spilled into a vector register, for example when the code emits multiple mask-writing instructions like vpcmpeqb, bringing the spilled value back into a mask register requires vpmov*2m, which also takes 3 clock cycles to complete on Intel CPUs.
In this case, vpternlogd condition, left, right, 0xca is actually better than reloading the mask register.
Vector*<T>.ConditionalSelect<T>(condition, left, right) always takes the vpternlogd approach, as the documentation says, so a dedicated masked blending function is needed anyway.
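
To illustrate why a bitwise select is not interchangeable with a mask-driven blend, here is a small sketch using the existing Vector128.ConditionalSelect together with a scalar model of an element-wise blend keyed on the most significant bit. The SelectDemo class and the operand values are made up for illustration, and the scalar model follows ConditionalSelect's operand order (a set condition picks left), not the order of the proposed BlendVariable.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class SelectDemo
{
    public static (int Bitwise, int ElementWise) Compare()
    {
        // A condition that is NOT mask-like: only the low 16 bits of each lane are set.
        Vector128<int> condition = Vector128.Create(0x0000FFFF);
        Vector128<int> left = Vector128.Create(0x11111111);
        Vector128<int> right = Vector128.Create(0x22222222);

        // Existing API: per bit, (condition & left) | (~condition & right).
        int bitwise = Vector128.ConditionalSelect(condition, left, right).GetElement(0);

        // Scalar model of an element-wise blend: the lane's most significant bit decides.
        int c = condition.GetElement(0);
        int elementWise = c < 0 ? left.GetElement(0) : right.GetElement(0);

        return (bitwise, elementWise);
    }
}
```

The bitwise path mixes halves of both operands (0x22221111), while the element-wise model returns a whole lane from right (0x22222222), so the two only agree when every lane of the condition is all ones or all zeros.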

Also, as I noted in #92261, pattern-matching that actually changes the behavior of functions drastically hurts the readability of the code, no matter how much effort is invested to overcome it.

Hardware Intrinsics used to be 'specific' in the .NET Core 3.1 days. Every single function specified the instructions to be executed, at least implicitly.
.NET 5 didn't change anything about it, or at least made it more readable and portable.
.NET 6 just made it more readable and portable.
.NET 7 just made it more readable and portable as well.
.NET 8 almost broke the consistency by RC 1, as I wrote in #92261, but it reverted its decision in RC 2.

Hardware Intrinsics should stay consistent forever, even as they adopt many new instructions and many new optimization techniques.
So I need these new APIs to be added in .NET 9 for now.

API Proposal

Cross-platform APIs

BlendVariable is designed so that the order of its parameters matches that of Avx512BW.BlendVariable.

namespace System.Runtime.Intrinsics
{
    // This lets an analyzer ensure the parameter is already in a mask register.
    [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
    public sealed class MaskExpectedAttribute : Attribute
    {
    }

    // To be added to `condition` parameter of ConditionalSelect methods
    [AttributeUsage(AttributeTargets.Parameter, Inherited = false, AllowMultiple = false)]
    public sealed class MaskNotExpectedAttribute : Attribute
    {
    }

    // This lets an analyzer assume the result is stored in a mask register.
    [AttributeUsage(AttributeTargets.ReturnValue, Inherited = false, AllowMultiple = false)]
    public sealed class MaskAttribute : Attribute
    {
    }

    public static class Vector512
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs masked merging operation.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="newValue"/>, otherwise, one from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector512<T> MergeWith<T>(this Vector512<T> newValue, Vector512<T> destinationOperand, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Performs masked zeroing operation.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="newValue"/>, or 0 otherwise.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector512<T> ZeroIfNot<T>(this Vector512<T> newValue, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="right"/>, otherwise, one from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector512<T> BlendVariable<T>(Vector512<T> left, Vector512<T> right, [MaskExpected] Vector512<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector512<T> MaskifyMostSignificantBits<T>(this Vector512<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose bits should be expanded into the elements of a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector512<T> VectorifyMask<T>([MaskExpected] this Vector512<T> mask);

        [Mask]
        public static Vector512<float> CreateMaskSingle(short value);
        [Mask]
        public static Vector512<float> CreateMaskSingle(ushort value);
        [Mask]
        public static Vector512<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector512<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector512<Half> CreateMaskHalf(int value);
        [Mask]
        public static Vector512<Half> CreateMaskHalf(uint value);
        [Mask]
        public static Vector512<byte> CreateMaskByte(long value);
        [Mask]
        public static Vector512<byte> CreateMaskByte(ulong value);
        [Mask]
        public static Vector512<sbyte> CreateMaskSByte(long value);
        [Mask]
        public static Vector512<sbyte> CreateMaskSByte(ulong value);
        [Mask]
        public static Vector512<ushort> CreateMaskUInt16(int value);
        [Mask]
        public static Vector512<ushort> CreateMaskUInt16(uint value);
        [Mask]
        public static Vector512<short> CreateMaskInt16(int value);
        [Mask]
        public static Vector512<short> CreateMaskInt16(uint value);
        [Mask]
        public static Vector512<uint> CreateMaskUInt32(short value);
        [Mask]
        public static Vector512<uint> CreateMaskUInt32(ushort value);
        [Mask]
        public static Vector512<int> CreateMaskInt32(short value);
        [Mask]
        public static Vector512<int> CreateMaskInt32(ushort value);
        [Mask]
        public static Vector512<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector512<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector512<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector512<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector512<T> MaskShiftLeft<T>([MaskExpected] Vector512<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector512<T> MaskShiftRightLogical<T>([MaskExpected] Vector512<T> left, [ConstantExpected] byte right);
    }

    public static class Vector256
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs masked merging operation.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="newValue"/>, otherwise, one from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector256<T> MergeWith<T>(this Vector256<T> newValue, Vector256<T> destinationOperand, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Performs masked zeroing operation.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="newValue"/>, or 0 otherwise.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector256<T> ZeroIfNot<T>(this Vector256<T> newValue, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="right"/>, otherwise, one from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector256<T> BlendVariable<T>(Vector256<T> left, Vector256<T> right, [MaskExpected] Vector256<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector256<T> MaskifyMostSignificantBits<T>(this Vector256<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose bits should be expanded into the elements of a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector256<T> VectorifyMask<T>([MaskExpected] this Vector256<T> mask);

        [Mask]
        public static Vector256<float> CreateMaskSingle(sbyte value);
        [Mask]
        public static Vector256<float> CreateMaskSingle(byte value);
        [Mask]
        public static Vector256<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector256<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector256<Half> CreateMaskHalf(short value);
        [Mask]
        public static Vector256<Half> CreateMaskHalf(ushort value);
        [Mask]
        public static Vector256<byte> CreateMaskByte(int value);
        [Mask]
        public static Vector256<byte> CreateMaskByte(uint value);
        [Mask]
        public static Vector256<sbyte> CreateMaskSByte(int value);
        [Mask]
        public static Vector256<sbyte> CreateMaskSByte(uint value);
        [Mask]
        public static Vector256<ushort> CreateMaskUInt16(short value);
        [Mask]
        public static Vector256<ushort> CreateMaskUInt16(ushort value);
        [Mask]
        public static Vector256<short> CreateMaskInt16(short value);
        [Mask]
        public static Vector256<short> CreateMaskInt16(ushort value);
        [Mask]
        public static Vector256<uint> CreateMaskUInt32(sbyte value);
        [Mask]
        public static Vector256<uint> CreateMaskUInt32(byte value);
        [Mask]
        public static Vector256<int> CreateMaskInt32(sbyte value);
        [Mask]
        public static Vector256<int> CreateMaskInt32(byte value);
        [Mask]
        public static Vector256<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector256<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector256<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector256<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector256<T> MaskShiftLeft<T>([MaskExpected] Vector256<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector256<T> MaskShiftRightLogical<T>([MaskExpected] Vector256<T> left, [ConstantExpected] byte right);
    }

    public static class Vector128
    {
        public static bool IsSupported { get; }

        /// <summary>
        /// Performs masked merging operation.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="newValue"/>, otherwise, one from <paramref name="destinationOperand"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="destinationOperand">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : <paramref name="destinationOperand"/></returns>
        public static Vector128<T> MergeWith<T>(this Vector128<T> newValue, Vector128<T> destinationOperand, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Performs masked zeroing operation.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="newValue"/>, or 0 otherwise.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="newValue">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="newValue"/> : 0</returns>
        public static Vector128<T> ZeroIfNot<T>(this Vector128<T> newValue, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Conditionally selects a value from two vectors on an element-wise basis.
        /// If corresponding element of <paramref name="mask"/> is 1, then the element would be one from <paramref name="right"/>, otherwise, one from <paramref name="left"/>.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the vector.</typeparam>
        /// <param name="left">The values to be selected when the corresponding element in <paramref name="mask"/> is false.</param>
        /// <param name="right">The values to be selected when the corresponding element in <paramref name="mask"/> is true.</param>
        /// <param name="mask">The mask to control which value to take.</param>
        /// <returns>For each element, <paramref name="mask"/> ? <paramref name="right"/> : <paramref name="left"/></returns>
        public static Vector128<T> BlendVariable<T>(Vector128<T> left, Vector128<T> right, [MaskExpected] Vector128<T> mask);

        /// <summary>
        /// Creates a mask consisting of the most significant bit of each element in a vector.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a mask register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in the <paramref name="vector"/>.</typeparam>
        /// <param name="vector">The vector whose elements should have their most significant bit extracted and stored in a mask register.</param>
        /// <returns>The packed most significant bits extracted from the elements in <paramref name="vector"/>.</returns>
        [Mask]
        public static Vector128<T> MaskifyMostSignificantBits<T>(this Vector128<T> vector);

        /// <summary>
        /// Creates a vector which represents a mask.
        /// This function acts like a hint for compilers so that the compiler can assume the returned value is stored in a vector register.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="mask">The mask whose bits should be expanded into the elements of a vector register.</param>
        /// <returns>The expanded <paramref name="mask"/> as a vector.</returns>
        public static Vector128<T> VectorifyMask<T>([MaskExpected] this Vector128<T> mask);

        [Mask]
        public static Vector128<float> CreateMaskSingle(sbyte value);
        [Mask]
        public static Vector128<float> CreateMaskSingle(byte value);
        [Mask]
        public static Vector128<double> CreateMaskDouble(sbyte value);
        [Mask]
        public static Vector128<double> CreateMaskDouble(byte value);
        [Mask]
        public static Vector128<Half> CreateMaskHalf(sbyte value);
        [Mask]
        public static Vector128<Half> CreateMaskHalf(byte value);
        [Mask]
        public static Vector128<byte> CreateMaskByte(short value);
        [Mask]
        public static Vector128<byte> CreateMaskByte(ushort value);
        [Mask]
        public static Vector128<sbyte> CreateMaskSByte(short value);
        [Mask]
        public static Vector128<sbyte> CreateMaskSByte(ushort value);
        [Mask]
        public static Vector128<ushort> CreateMaskUInt16(sbyte value);
        [Mask]
        public static Vector128<ushort> CreateMaskUInt16(byte value);
        [Mask]
        public static Vector128<short> CreateMaskInt16(sbyte value);
        [Mask]
        public static Vector128<short> CreateMaskInt16(byte value);
        [Mask]
        public static Vector128<uint> CreateMaskUInt32(sbyte value);
        [Mask]
        public static Vector128<uint> CreateMaskUInt32(byte value);
        [Mask]
        public static Vector128<int> CreateMaskInt32(sbyte value);
        [Mask]
        public static Vector128<int> CreateMaskInt32(byte value);
        [Mask]
        public static Vector128<ulong> CreateMaskUInt64(sbyte value);
        [Mask]
        public static Vector128<ulong> CreateMaskUInt64(byte value);
        [Mask]
        public static Vector128<long> CreateMaskInt64(sbyte value);
        [Mask]
        public static Vector128<long> CreateMaskInt64(byte value);

        /// <summary>
        /// Shifts the value of a mask left by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector128<T> MaskShiftLeft<T>([MaskExpected] Vector128<T> left, [ConstantExpected] byte right);

        /// <summary>
        /// Shifts the value of a mask right by the specified amount.
        /// </summary>
        /// <typeparam name="T">The type of the elements in a vector to mask with.</typeparam>
        /// <param name="left">The mask that represents the bits to be shifted.</param>
        /// <param name="right">The number of bits to shift.</param>
        /// <returns>A mask that represents the shifted value.</returns>
        [Mask]
        public static Vector128<T> MaskShiftRightLogical<T>([MaskExpected] Vector128<T> left, [ConstantExpected] byte right);
    }
}
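
As a rough illustration of the conversions and mask shifts declared above, the following sketch models a 4-lane mask as the low bits of a uint. ToBits uses the existing Vector128.ExtractMostSignificantBits; ToVector and ShiftLeft4 are hypothetical stand-ins for what VectorifyMask and MaskShiftLeft might compute, and MaskConversionModel itself is not part of the proposal.

```csharp
using System;
using System.Runtime.Intrinsics;

public static class MaskConversionModel
{
    // Packs the most significant bit of each lane into an integer (existing API).
    public static uint ToBits(Vector128<int> vector)
        => vector.ExtractMostSignificantBits();

    // Hypothetical counterpart: expands each of the low 4 bits
    // into an all-ones or all-zero lane.
    public static Vector128<int> ToVector(uint bits) => Vector128.Create(
        (bits & 1) != 0 ? -1 : 0,
        (bits & 2) != 0 ? -1 : 0,
        (bits & 4) != 0 ? -1 : 0,
        (bits & 8) != 0 ? -1 : 0);

    // Models a left shift on a 4-lane mask for counts below 32:
    // bits shifted past lane 3 are dropped.
    public static uint ShiftLeft4(uint bits, int count)
        => (bits << count) & 0xF;
}
```

Round-tripping a mask-like vector through ToBits and ToVector should give back the same vector, which is the property the [Mask] and [MaskExpected] annotations are meant to let the compiler exploit without emitting the conversion.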

Additional AVX-512 Intrinsics

Some bit operations are omitted from this list.

namespace System.Runtime.Intrinsics.X86
{
    public abstract class Avx512F : Avx2
    {
        [Mask]
        public static Vector512<float> MaskUnpack16([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<float> MaskUnpack16([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector512<uint> MaskUnpack16([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<int> MaskUnpack16([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);
        [Mask]
        public static Vector512<uint> MaskUnpack16([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector512<int> MaskUnpack16([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector256<Half> MaskUnpack16([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector256<ushort> MaskUnpack16([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<short> MaskUnpack16([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector256<Half> MaskUnpack16([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector256<ushort> MaskUnpack16([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector256<short> MaskUnpack16([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector128<byte> MaskUnpack16([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<sbyte> MaskUnpack16([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);

        [Mask]
        public static Vector512<float> MaskXnor([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector256<Half> MaskXnor([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector128<byte> MaskXnor([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector128<sbyte> MaskXnor([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector256<ushort> MaskXnor([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<short> MaskXnor([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector512<uint> MaskXnor([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<int> MaskXnor([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
    }

    public abstract class Avx512DQ : Avx512F
    {
        [Mask]
        public static Vector512<float> MaskAdd([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector256<Half> MaskAdd([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector128<byte> MaskAdd([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector128<sbyte> MaskAdd([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector256<ushort> MaskAdd([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<short> MaskAdd([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector512<uint> MaskAdd([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<int> MaskAdd([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
        [Mask]
        public static Vector256<float> MaskAdd([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<double> MaskAdd([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector128<Half> MaskAdd([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector128<ushort> MaskAdd([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<short> MaskAdd([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector256<uint> MaskAdd([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<int> MaskAdd([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector512<ulong> MaskAdd([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<long> MaskAdd([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);

        [Mask]
        public static Vector256<float> MaskXnor([MaskExpected] Vector256<float> left, [MaskExpected] Vector256<float> right);
        [Mask]
        public static Vector512<double> MaskXnor([MaskExpected] Vector512<double> left, [MaskExpected] Vector512<double> right);
        [Mask]
        public static Vector128<Half> MaskXnor([MaskExpected] Vector128<Half> left, [MaskExpected] Vector128<Half> right);
        [Mask]
        public static Vector128<ushort> MaskXnor([MaskExpected] Vector128<ushort> left, [MaskExpected] Vector128<ushort> right);
        [Mask]
        public static Vector128<short> MaskXnor([MaskExpected] Vector128<short> left, [MaskExpected] Vector128<short> right);
        [Mask]
        public static Vector256<uint> MaskXnor([MaskExpected] Vector256<uint> left, [MaskExpected] Vector256<uint> right);
        [Mask]
        public static Vector256<int> MaskXnor([MaskExpected] Vector256<int> left, [MaskExpected] Vector256<int> right);
        [Mask]
        public static Vector512<ulong> MaskXnor([MaskExpected] Vector512<ulong> left, [MaskExpected] Vector512<ulong> right);
        [Mask]
        public static Vector512<long> MaskXnor([MaskExpected] Vector512<long> left, [MaskExpected] Vector512<long> right);

    }
    
    public abstract class Avx512BW : Avx512F
    {
        [Mask]
        public static Vector512<Half> MaskUnpack32([MaskExpected] Vector512<float> left, [MaskExpected] Vector512<float> right);
        [Mask]
        public static Vector512<Half> MaskUnpack32([MaskExpected] Vector256<Half> left, [MaskExpected] Vector256<Half> right);
        [Mask]
        public static Vector512<ushort> MaskUnpack32([MaskExpected] Vector512<uint> left, [MaskExpected] Vector512<uint> right);
        [Mask]
        public static Vector512<short> MaskUnpack32([MaskExpected] Vector512<int> left, [MaskExpected] Vector512<int> right);
        [Mask]
        public static Vector512<ushort> MaskUnpack32([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector512<short> MaskUnpack32([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector256<byte> MaskUnpack32([MaskExpected] Vector256<ushort> left, [MaskExpected] Vector256<ushort> right);
        [Mask]
        public static Vector256<sbyte> MaskUnpack32([MaskExpected] Vector256<short> left, [MaskExpected] Vector256<short> right);
        [Mask]
        public static Vector256<byte> MaskUnpack32([MaskExpected] Vector128<byte> left, [MaskExpected] Vector128<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskUnpack32([MaskExpected] Vector128<sbyte> left, [MaskExpected] Vector128<sbyte> right);
        [Mask]
        public static Vector512<byte> MaskUnpack64([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<sbyte> MaskUnpack64([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskUnpack64([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskUnpack64([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);

        [Mask]
        public static Vector512<Half> MaskAdd([MaskExpected] Vector512<Half> left, [MaskExpected] Vector512<Half> right);
        [Mask]
        public static Vector256<byte> MaskAdd([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskAdd([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);
        [Mask]
        public static Vector512<ushort> MaskAdd([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<short> MaskAdd([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskAdd([MaskExpected] Vector512<byte> left, [MaskExpected] Vector512<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskAdd([MaskExpected] Vector512<sbyte> left, [MaskExpected] Vector512<sbyte> right);

        [Mask]
        public static Vector512<Half> MaskXnor([MaskExpected] Vector512<Half> left, [MaskExpected] Vector512<Half> right);
        [Mask]
        public static Vector256<byte> MaskXnor([MaskExpected] Vector256<byte> left, [MaskExpected] Vector256<byte> right);
        [Mask]
        public static Vector256<sbyte> MaskXnor([MaskExpected] Vector256<sbyte> left, [MaskExpected] Vector256<sbyte> right);
        [Mask]
        public static Vector512<ushort> MaskXnor([MaskExpected] Vector512<ushort> left, [MaskExpected] Vector512<ushort> right);
        [Mask]
        public static Vector512<short> MaskXnor([MaskExpected] Vector512<short> left, [MaskExpected] Vector512<short> right);
        [Mask]
        public static Vector512<byte> MaskXnor([MaskExpected] Vector512<byte> left, [MaskExpected] Vector512<byte> right);
        [Mask]
        public static Vector512<sbyte> MaskXnor([MaskExpected] Vector512<sbyte> left, [MaskExpected] Vector512<sbyte> right);
    }
}

API Usage

Merge-masking and zero-masking can be written like:

zmm0 = Avx512BW.Subtract(zmm1, zmm2).MergeWith(zmm0, k1);   // vpsubb zmm0 {k1}, zmm1, zmm2
zmm0 = Avx512BW.Subtract(zmm3, zmm4).ZeroIfNot(k2);         // vpsubb zmm0 {k2}{z}, zmm3, zmm4

Alternative Designs

There may be better names for each identifier.

Risks

None I can come up with.

Author: MineCake147E
Assignees: -
Labels:

api-suggestion, area-System.Runtime.Intrinsics

Milestone: -

@tannergooding
Member

> Which I don't really think it to be reliable, predictable, and optimal for now.

This is a point in time problem and was a known limitation around the release of AVX-512 in .NET 8 due to the massive scope and size of the feature. There were multiple areas related to masking that didn't make the cut in the first implementation and which are being worked on for .NET 9. That includes implementing the necessary pattern recognition and enabling optimizations or alternative instruction emission for various cases, particularly when it comes to the xplat APIs.

> Hardware Intrinsics should be consistent forever, while it should adopt many new instructions and much new optimization techniques.

The intrinsics space needs to evolve and fit the needs of a growing number of platforms while continuing to avoid massive amounts of bloat, complexity, and other problematic considerations that come about from newer concepts and how they impact the overall JIT, whether they only add cost if used vs overhead even if not used, etc.

The intrinsics space has always used and relied on pattern recognition in a number of scenarios, especially when it comes to "embedded operations" as are common on x86/x64 (such as embedded loads/stores). Embedded masking is no different.


This proposal, particularly as it applies to the xplat APIs proposed, isn't going to help with codegen. It's only going to complicate the required recognition and make it harder to use the APIs on platforms where masking doesn't exist.

For the platform-specific APIs, there are a couple concepts that might be worth exposing such as adding or unpacking masks given that the behavior is different from doing the same with two vectors that are known to be allbitsset or zero per-element. However, there are also ones proposed like MaskXnor which do not make sense to expose because there is no difference in behavior from a regular Xnor operation.
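
To make the distinction concrete, here is a minimal sketch of why adding masks is not the same as adding mask-like vectors. It is not part of the proposal: `MaskAddDemo` and its method names are hypothetical, the 4-lane mask is modeled with a plain `uint`, and the integer addition only stands in for what a mask-register add would compute.

```csharp
using System.Runtime.Intrinsics;

// Hypothetical illustration only; none of these names exist in the BCL.
public static class MaskAddDemo
{
    // Adds two 4-lane masks as integers (an assumed model of a mask-register add).
    public static uint AddAsMasks(Vector128<int> left, Vector128<int> right)
    {
        uint l = left.ExtractMostSignificantBits();
        uint r = right.ExtractMostSignificantBits();
        return (l + r) & 0xFu; // keep only the 4 lane bits of this model
    }

    // Adds the vectors lane by lane, then reads the sign bits back as a mask.
    public static uint AddAsVectors(Vector128<int> left, Vector128<int> right)
        => (left + right).ExtractMostSignificantBits();
}
```

With `m = Vector128.Create(-1, 0, 0, 0)`, `AddAsMasks(m, m)` carries into the next bit and gives 2, while `AddAsVectors(m, m)` turns lane 0 into -2, whose sign bit is still set, and gives 1.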

I'd recommend opening issues to track individual cases where the pattern recognition around masking doesn't work as expected (such as no embedded masking support today or the case where ConditionalSelect could emit vpblendm instead of vpternlog).

I'd then recommend opening a standalone proposal for the mask like concepts which can't be trivially handled, such as addition and unpacking.

@ghost removed the untriaged label Jan 15, 2024
@tannergooding closed this as not planned Jan 15, 2024
@MineCake147E
Contributor Author

MineCake147E commented Jan 17, 2024

> I'd then recommend opening a standalone proposal for the mask like concepts which can't be trivially handled, such as addition and unpacking.

> I'd recommend opening issues to track individual cases where the pattern recognition around masking doesn't work as expected (such as no embedded masking support today or the case where ConditionalSelect could emit vpblendm instead of vpternlog).

Sure. I will.

> This proposal, particularly as it applies to the xplat APIs proposed, isn't going to help with codegen. It's only going to complicate the required recognition and make it harder to use the APIs on platforms where masking doesn't exist.

I'm afraid I forgot to mention that MergeWith and ZeroIfNot were ideas mainly from readability issues.
CreateMask* were for kmov* instructions.
MaskifyMostSignificantBits was for vpmov*2m.
VectorifyMask was for vpmovm2*.
MaskShift* could be necessary because some of the alternative approaches, such as those using one or more AlignRight calls on different element types, or those using Permute*, can't be recognized very easily.
The same applies for CreateMask*.

I personally think that they don't hurt anything on platforms where masking doesn't exist, as most of their fallback code could easily be implemented using non-mask instructions.
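
As a rough sketch of such a fallback, the proposed `CreateMaskInt32(byte)` could be written with ordinary vector operations as below. The class name `MaskFallbackDemo` and the bit-to-lane mapping (bit i selects lane i) are assumptions made for illustration, not something the proposal specifies.

```csharp
using System.Runtime.Intrinsics;

// Hypothetical fallback; not an existing BCL API.
public static class MaskFallbackDemo
{
    // Expands the low 4 bits of value into a 4-lane vector mask.
    public static Vector128<int> CreateMaskInt32(byte value)
    {
        Vector128<int> laneBits = Vector128.Create(1, 2, 4, 8);
        Vector128<int> broadcast = Vector128.Create((int)value);
        // Lanes whose bit is set compare equal and become all-bits-set; the rest become zero.
        return Vector128.Equals(broadcast & laneBits, laneBits);
    }
}
```

Under that assumed mapping, `CreateMaskInt32(0b1010)` sets lanes 1 and 3, and `ExtractMostSignificantBits` on the result gives back `0b1010`.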

@MineCake147E
Contributor Author

Also, covering all known/potential workaround with pattern matching doesn't seem to be a good idea anyway. It bloats the RyuJIT up pretty quickly, or may even end up missing more optimization opportunities in large methods, because RyuJIT needs more time to optimize each method.
CreateMask*, MaskifyMostSignificantBits, VectorifyMask, and MaskShift* are the ones which can't be implemented trivially without corresponding instructions.

@MineCake147E
Contributor Author

MineCake147E commented Jan 17, 2024

I forgot to mention, though, the optimal code for CPUs without masking support could be suboptimal in CPUs with masking support.

Consider subtracting short values in ymm1 from ones in ymm0, if the corresponding mask bit represents true.

Code optimized for AVX2 looks like this, if ymm2 is the mask:

vpand ymm1, ymm1, ymm2
vpsubw ymm0, ymm0, ymm1

Code optimized for AVX-512 looks like this, if k1 is the mask:

vpsubw ymm0 {k1}, ymm0, ymm1

Will a future RyuJIT be able to recognize ymm0 -= k1 & ymm1 and emit the code above?
Or, if ymm2 were given as the mask instead, here's what optimized code looks like:

vpand ymm1, ymm1, ymm2
vpsubw ymm0, ymm0, ymm1

Will RyuJIT be able to recognize ymm0 = Avx512BW.BlendVariable(ymm0, ymm0 - ymm1, ymm2) and emit the code above as well?
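
For reference, the two ways of writing the masked subtraction can be sketched with the cross-platform API as follows. `MaskedSubtractDemo` and its method names are hypothetical, and the mask is assumed to be all-bits-set or zero per element. The sketch only shows that both forms compute the same values; it says nothing about which instructions RyuJIT emits for either.

```csharp
using System.Runtime.Intrinsics;

// Hypothetical illustration only.
public static class MaskedSubtractDemo
{
    // AND-style: zero out the unselected lanes of right, then subtract.
    public static Vector256<short> SubtractWithAnd(
        Vector256<short> left, Vector256<short> right, Vector256<short> mask)
        => left - (right & mask);

    // Blend-style: compute the full difference, then keep it only where the mask is set.
    public static Vector256<short> SubtractWithSelect(
        Vector256<short> left, Vector256<short> right, Vector256<short> mask)
        => Vector256.ConditionalSelect(mask, left - right, left);
}
```

Both methods return `left - right` in lanes where the mask is set and `left` unchanged elsewhere, so the choice between them is purely about which pattern the JIT handles better on a given platform.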

Consider multiplying short values in ymm0 by 16, if the corresponding mask bit represents true.

Code optimized for AVX2 looks like this, if ymm1 is the mask:

vpsllw ymm2, ymm0, 4
vpblendvb ymm0, ymm0, ymm2, ymm1

Code optimized for AVX-512 looks like this, if k1 is the mask:

vpsllw ymm0 {k1}, ymm0, 4

Can RyuJIT recognize ymm0 <<= k1 & Vector256.Create((short)4) and emit the code above in the future?
Or, if ymm1 were given as the mask instead, here's what optimized code looks like:

vpandd ymm1, ymm1, dword ptr [rip + .DISPLACEMENT]{1to8} ; The memory address stores 0x0004_0004
vpsllvw ymm0, ymm0, ymm1

Can RyuJIT recognize ymm0 = Avx512BW.BlendVariable(ymm0, Avx512BW.ShiftLeftLogical(ymm0, 4), ymm1) and emit the code above in the future?

I think that the user should be aware of masking availability anyway.
No matter how far RyuJIT evolves, the optimal C# code should vary from platform to platform in some cases, as I showed above.
Even LLVM sometimes fails to generate optimal code today.
It is nearly impossible to create a compiler that optimizes the code perfectly at all times.

@tannergooding
Member

> I'm afraid I forgot to mention that MergeWith and ZeroIfNot were ideas mainly from readability issues.

I don't think they help that much with readability and there are definitely some ambiguities in how the names can be interpreted.

On the other hand, ConditionalSelect is very clear on what it does and fits the pattern that people have already been following for years. Users that "really" want to have some kind of helper like MergeWith can trivially define it over ConditionalSelect as an extension method.
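
A minimal sketch of such extension methods, assuming .NET 8 or later for `Vector512`, might look like this. The class name `MaskExtensions` is hypothetical, the parameter order follows the usage shown in the proposal (`value.MergeWith(fallback, mask)`), and the mask is assumed to be all-bits-set or zero per element.

```csharp
using System.Runtime.Intrinsics;

// Hypothetical user-defined helpers; not part of the BCL.
public static class MaskExtensions
{
    // Keeps value where the mask is set and takes fallback elsewhere.
    public static Vector512<T> MergeWith<T>(
        this Vector512<T> value, Vector512<T> fallback, Vector512<T> mask)
        where T : struct
        => Vector512.ConditionalSelect(mask, value, fallback);

    // Keeps value where the mask is set and zeroes the other lanes.
    public static Vector512<T> ZeroIfNot<T>(this Vector512<T> value, Vector512<T> mask)
        where T : struct
        => Vector512.ConditionalSelect(mask, value, Vector512<T>.Zero);
}
```

Because these are thin wrappers over `ConditionalSelect`, they inherit whatever pattern recognition the JIT applies to it and need no dedicated runtime support.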

> CreateMask* were for kmov* instructions.
> ...

As indicated, any APIs which can't be trivially handled via pattern recognition, such as because the naive operation isn't clear or 1-to-1, can and likely should have platform specific APIs exposed. So exposing an API like Avx512F.AddMask is fine and an appropriate proposal should be opened. Having an xplat API like Vector128.AddMask is likely not a good idea on the other hand and will lead to pessimizations in xplat code.

> It bloats the RyuJIT up pretty quickly

It does not and is significantly less expensive than the alternative, which is one of the reasons why we went with it.

> Also, covering all known/potential workaround with pattern matching doesn't seem to be a good idea anyway.

It is not a goal to cover every potential pattern, no compiler does this. Even LLVM doesn't cover "everything".

It is instead a goal of the compiler to cover common/typical patterns. Users can request that a new pattern be recognized as well, but it all comes down to cost, complexity, and benefit.

There is some responsibility on the developer, especially in perf critical code, to write their code in a way that fits the well-established and recognized patterns to ensure the underlying compiler can do its job.

> could be suboptimal in CPUs with masking support.

It is not the goal of the compiler to generate "ideal" codegen in every possible scenario. This is effectively impossible and even LLVM when you're targeting a specific micro-architecture gets it wrong in many cases.

At the end of the day, losing 1-3 cycles isn't going to matter for many code patterns and the smallest code isn't necessarily the fastest code. Many of the examples you've provided aren't necessarily standard/common patterns that would be encountered for typical SIMD code, and where they are encountered, the difference between the given and suggested codegen isn't significant. It's within the realm of noise, especially when compared to the general performance deltas caused by inherent latency (cache, ram, disk, etc), variable processor speeds, resource contention between cores/hyper-threads, etc.

You might be able to measure the difference in a micro-benchmark, but the difference is unlikely to surface in most real world applications, unless they are highly specialized and those lines happen to be the bottleneck over gigabytes of data.

> I think that the user should be aware of masking availability anyway.
> No matter how far RyuJIT evolves, the optimal C# code should vary from platform to platform in some cases, as I showed above.

Most code doesn't need to be "optimal" and as the number of platforms and scenarios needing to be supported expands, so does the need to write portable and reusable code.

The BCL is explicitly utilizing the xplat intrinsics, with selective usage of the platform specific intrinsics, in most code paths because losing a nanosecond here or there is acceptable for the massively reduced complexity, the increased confidence that the code is working as expected, and the ability to rapidly bring online 2-16x perf gains for new platforms. -- Even if a 20x perf gain is possible with hand tuned code, it's a diminishing return, especially when viewed in the context of typical applications.

> Even LLVM sometimes fails to generate optimal code today.

Yes, and there are places where RyuJIT provides better controls/guarantees than LLVM provides. There is also the inverse case where LLVM does it better.

At the end of the day, we are confident we will be able to get the most common patterns and support in such that you will be able to get nearly ideal code for the vast majority of scenarios. There will be some places over time that users will surface as needing improvements and we will look at those as they come in, determining whether the pattern can/should be supported or if the user should modify their code slightly to fit in with the already recognized patterns.

@MineCake147E
Contributor Author

Oh I get it now.

> As indicated, any APIs which can't be trivially handled via pattern recognition, such as because the naive operation isn't clear or 1-to-1, can and likely should have platform specific APIs exposed. So exposing an API like Avx512F.AddMask is fine and an appropriate proposal should be opened.

So am I allowed to open one for MaskShift*, VectorifyMask, MaskifyMostSignificantBits, and CreateMask* to be included in Avx512* instead then?

I opened one for MaskAdd and MaskUnpack* yesterday by the way.

@tannergooding
Member

tannergooding commented Jan 18, 2024

> So am I allowed to open one for MaskShift*, VectorifyMask, MaskifyMostSignificantBits, and CreateMask* to be included in Avx512* instead then?

Most of these APIs should be *Mask rather than Mask*. We already have Mask* APIs (such as Sse.MaskStore) and using it as a postfix also better matches other APIs where the type appears at the end (LoadVector128, ConvertToDouble, AsInt64, etc).

As for the ones listed:

  • ShiftMask sounds fine.
  • It's not clear what VectorifyMask does
  • MaskifyMostSignificantBits sounds like it is identical to ExtractMostSignificantBits, an existing API
  • CreateMask sounds like something that would be better as part of the xplat API surface, similar to the other Create APIs

@github-actions bot locked and limited conversation to collaborators Feb 18, 2024