commit f425492126
Alan Hayward, 2025-07-30 13:29:19 +01:00 (committed by GitHub)
2 changed files with 199 additions and 61 deletions


@ -32364,7 +32364,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
// We shouldn't find AND_NOT nodes since it should only be produced in lowering
assert(oper != GT_AND_NOT);
#if defined(FEATURE_MASKED_HW_INTRINSICS) && defined(TARGET_XARCH)
#if defined(FEATURE_MASKED_HW_INTRINSICS)
#if defined(TARGET_XARCH)
if (GenTreeHWIntrinsic::OperIsBitwiseHWIntrinsic(oper))
{
// Comparisons that produce masks lead to more verbose trees than
@ -32482,7 +32483,65 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
}
}
}
#endif // FEATURE_MASKED_HW_INTRINSICS && TARGET_XARCH
#elif defined(TARGET_ARM64)
// Check if the tree can be folded into a mask variant
if (HWIntrinsicInfo::HasAllMaskVariant(tree->GetHWIntrinsicId()))
{
NamedIntrinsic maskVariant = HWIntrinsicInfo::GetMaskVariant(tree->GetHWIntrinsicId());
assert(opCount == (size_t)HWIntrinsicInfo::lookupNumArgs(maskVariant));
// Check all operands are valid
bool canFold = true;
for (size_t i = 1; i <= opCount && canFold; i++)
{
canFold &=
(varTypeIsMask(tree->Op(i)) || tree->Op(i)->OperIsConvertMaskToVector() || tree->Op(i)->IsVectorZero());
}
if (canFold)
{
// Convert all the operands to masks
for (size_t i = 1; i <= opCount; i++)
{
if (tree->Op(i)->OperIsConvertMaskToVector())
{
// Replace the conversion with its mask operand (op1).
tree->Op(i) = tree->Op(i)->AsHWIntrinsic()->Op(1);
}
else if (tree->Op(i)->IsVectorZero())
{
// Replace the vector of zeroes with a mask of zeroes.
tree->Op(i) = gtNewSimdFalseMaskByteNode();
tree->Op(i)->SetMorphed(this);
}
assert(varTypeIsMask(tree->Op(i)));
}
// Switch to the mask variant
switch (opCount)
{
case 1:
tree->ResetHWIntrinsicId(maskVariant, tree->Op(1));
break;
case 2:
tree->ResetHWIntrinsicId(maskVariant, tree->Op(1), tree->Op(2));
break;
case 3:
tree->ResetHWIntrinsicId(maskVariant, this, tree->Op(1), tree->Op(2), tree->Op(3));
break;
default:
unreached();
}
tree->gtType = TYP_MASK;
tree->SetMorphed(this);
tree = gtNewSimdCvtMaskToVectorNode(retType, tree, simdBaseJitType, simdSize)->AsHWIntrinsic();
tree->SetMorphed(this);
}
}
#endif // TARGET_ARM64
#endif // FEATURE_MASKED_HW_INTRINSICS
GenTree* cnsNode = nullptr;
GenTree* otherNode = nullptr;
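
As an illustration of the pattern this new ARM64 fold targets, here is a minimal C# sketch in the style of the PredicateInstructions test below. The before/after shapes in the comments are paraphrased JIT IR, and the `_Predicates` spelling of the mask variant is assumed from the naming used elsewhere in this change, not a public API:

    using System.Numerics;
    using System.Runtime.Intrinsics.Arm;

    static class MaskFoldSketch
    {
        // Both comparisons produce masks, so each ZipLow operand reaches
        // morph as ConvertMaskToVector(mask). The fold strips the
        // conversions, retypes the node to TYP_MASK via its mask variant,
        // and emits a single conversion on the outside:
        //   before: ZipLow(CvtMaskToVector(m1), CvtMaskToVector(m2))
        //   after:  CvtMaskToVector(ZipLow_Predicates(m1, m2))
        static Vector<short> ZipLowOfMasks(Vector<short> a, Vector<short> b) =>
            Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));

        // Here the second operand is an ordinary vector (not a mask, a mask
        // conversion, or a zero vector), so the canFold loop above rejects
        // the node and ZipLow keeps its vector form.
        static Vector<short> ZipLowMixed(Vector<short> a, Vector<short> b) =>
            Sve.ZipLow(Sve.CompareGreaterThan(a, b), b);
    }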
@ -33869,7 +33928,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
// op2 = op2 & op1
op2->AsVecCon()->EvaluateBinaryInPlace(GT_AND, false, simdBaseType, op1->AsVecCon());
// op3 = op2 & ~op1
// op3 = op3 & ~op1
op3->AsVecCon()->EvaluateBinaryInPlace(GT_AND_NOT, false, simdBaseType, op1->AsVecCon());
// op2 = op2 | op3
@ -33882,8 +33941,8 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
#if defined(TARGET_ARM64)
case NI_Sve_ConditionalSelect:
case NI_Sve_ConditionalSelect_Predicates:
{
assert(!varTypeIsMask(retType));
assert(varTypeIsMask(op1));
if (cnsNode != op1)
@ -33912,10 +33971,11 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
if (op2->IsCnsVec() && op3->IsCnsVec())
{
assert(ni == NI_Sve_ConditionalSelect);
assert(op2->gtType == TYP_SIMD16);
assert(op3->gtType == TYP_SIMD16);
simd16_t op1SimdVal;
simd16_t op1SimdVal = {};
EvaluateSimdCvtMaskToVector<simd16_t>(simdBaseType, &op1SimdVal, op1->AsMskCon()->gtSimdMaskVal);
// op2 = op2 & op1
@ -33924,7 +33984,7 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
op1SimdVal);
op2->AsVecCon()->gtSimd16Val = result;
// op3 = op2 & ~op1
// op3 = op3 & ~op1
result = {};
EvaluateBinarySimd<simd16_t>(GT_AND_NOT, false, simdBaseType, &result, op3->AsVecCon()->gtSimd16Val,
op1SimdVal);
@ -33935,6 +33995,30 @@ GenTree* Compiler::gtFoldExprHWIntrinsic(GenTreeHWIntrinsic* tree)
resultNode = op2;
}
else if (op2->IsCnsMsk() && op3->IsCnsMsk())
{
assert(ni == NI_Sve_ConditionalSelect_Predicates);
// op2 = op2 & op1
simdmask_t result = {};
EvaluateBinaryMask<simd16_t>(GT_AND, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
op1->AsMskCon()->gtSimdMaskVal);
op2->AsMskCon()->gtSimdMaskVal = result;
// op3 = op3 & ~op1
result = {};
EvaluateBinaryMask<simd16_t>(GT_AND_NOT, false, simdBaseType, &result,
op3->AsMskCon()->gtSimdMaskVal, op1->AsMskCon()->gtSimdMaskVal);
op3->AsMskCon()->gtSimdMaskVal = result;
// op2 = op2 | op3
result = {};
EvaluateBinaryMask<simd16_t>(GT_OR, false, simdBaseType, &result, op2->AsMskCon()->gtSimdMaskVal,
op3->AsMskCon()->gtSimdMaskVal);
op2->AsMskCon()->gtSimdMaskVal = result;
resultNode = op2;
}
break;
}
#endif // TARGET_ARM64
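
For reference, the constant folding in this case implements the standard bitwise select identity spelled out in the comments (`op2 = op2 & op1`, `op3 = op3 & ~op1`, `op2 = op2 | op3`). A minimal sketch of the same computation on plain integers, standing in for the JIT's EvaluateBinarySimd/EvaluateBinaryMask helpers:

    static class SelectFoldSketch
    {
        // select(mask, a, b) == (a & mask) | (b & ~mask), applied bitwise.
        // NI_Sve_ConditionalSelect folds this lane-wise over constant
        // vectors; NI_Sve_ConditionalSelect_Predicates folds it bit-wise
        // over constant predicate masks.
        static ulong BitwiseSelect(ulong mask, ulong a, ulong b) =>
            (a & mask) | (b & ~mask);
    }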


@ -17,110 +17,164 @@ public class PredicateInstructions
{
if (Sve.IsSupported)
{
ZipLow();
ZipHigh();
UnzipOdd();
UnzipEven();
TransposeOdd();
TransposeEven();
ReverseElement();
And();
BitwiseClear();
Xor();
Or();
ConditionalSelect();
Vector<sbyte> vecsb = Vector.Create<sbyte>(2);
Vector<short> vecs = Vector.Create<short>(2);
Vector<ushort> vecus = Vector.Create<ushort>(2);
Vector<int> veci = Vector.Create<int>(3);
Vector<uint> vecui = Vector.Create<uint>(5);
Vector<long> vecl = Vector.Create<long>(7);
ZipLowMask(vecs, vecs);
ZipHighMask(vecui, vecui);
UnzipOddMask(vecs, vecs);
UnzipEvenMask(vecsb, vecsb);
TransposeEvenMask(vecl, vecl);
TransposeOddMask(vecs, vecs);
ReverseElementMask(vecs, vecs);
AndMask(vecs, vecs);
BitwiseClearMask(vecs, vecs);
XorMask(veci, veci);
OrMask(vecs, vecs);
ConditionalSelectMask(veci, veci, veci);
UnzipEvenZipLowMask(vecs, vecs);
TransposeEvenAndMask(vecs, vecs, vecs);
}
}
// Each of these methods should compile down to the predicate variant of its instruction.
// Sve intrinsics that return masks (Compare) or take mask arguments (CreateBreakAfterMask)
// surround each call to guarantee the operands are masks.
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> ZipLow()
static Vector<short> ZipLowMask(Vector<short> a, Vector<short> b)
{
return Sve.ZipLow(Vector<short>.Zero, Sve.CreateTrueMaskInt16());
//ARM64-FULL-LINE: zip1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
return Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<uint> ZipHigh()
static Vector<uint> ZipHighMask(Vector<uint> a, Vector<uint> b)
{
return Sve.ZipHigh(Sve.CreateTrueMaskUInt32(), Sve.CreateTrueMaskUInt32());
//ARM64-FULL-LINE: zip2 {{p[0-9]+}}.s, {{p[0-9]+}}.s, {{p[0-9]+}}.s
return Sve.CreateBreakAfterMask(Sve.ZipHigh(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateTrueMaskUInt32());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<sbyte> UnzipEven()
static Vector<sbyte> UnzipEvenMask(Vector<sbyte> a, Vector<sbyte> b)
{
return Sve.UnzipEven(Sve.CreateTrueMaskSByte(), Vector<sbyte>.Zero);
//ARM64-FULL-LINE: uzp1 {{p[0-9]+}}.b, {{p[0-9]+}}.b, {{p[0-9]+}}.b
return Sve.UnzipEven(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> UnzipOdd()
static Vector<short> UnzipOddMask(Vector<short> a, Vector<short> b)
{
return Sve.UnzipOdd(Sve.CreateTrueMaskInt16(), Sve.CreateFalseMaskInt16());
//ARM64-FULL-LINE: uzp2 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
return Sve.CreateBreakAfterMask(Sve.UnzipOdd(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateTrueMaskInt16());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<long> TransposeEven()
static Vector<long> TransposeEvenMask(Vector<long> a, Vector<long> b)
{
return Sve.TransposeEven(Sve.CreateFalseMaskInt64(), Sve.CreateTrueMaskInt64());
//ARM64-FULL-LINE: trn1 {{p[0-9]+}}.d, {{p[0-9]+}}.d, {{p[0-9]+}}.d
return Sve.CreateBreakAfterMask(Sve.TransposeEven(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)), Sve.CreateFalseMaskInt64());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> TransposeOdd()
static Vector<short> TransposeOddMask(Vector<short> a, Vector<short> b)
{
return Sve.TransposeOdd(Vector<short>.Zero, Sve.CreateTrueMaskInt16());
//ARM64-FULL-LINE: trn2 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
return Sve.TransposeOdd(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b));
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> ReverseElement()
static Vector<short> ReverseElementMask(Vector<short> a, Vector<short> b)
{
return Sve.ReverseElement(Sve.CreateTrueMaskInt16());
//ARM64-FULL-LINE: rev {{p[0-9]+}}.h, {{p[0-9]+}}.h
return Sve.CreateBreakAfterMask(Sve.ReverseElement(Sve.CompareGreaterThan(a, b)), Sve.CreateFalseMaskInt16());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> And()
static Vector<short> AndMask(Vector<short> a, Vector<short> b)
{
//ARM64-FULL-LINE: and {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
return Sve.CreateBreakAfterMask(
Sve.ConditionalSelect(
Sve.CreateTrueMaskInt16(),
Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Vector<short>.Zero),
Sve.CreateFalseMaskInt16());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> BitwiseClearMask(Vector<short> a, Vector<short> b)
{
//ARM64-FULL-LINE: bic {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
return Sve.ConditionalSelect(
Sve.CreateTrueMaskInt16(),
Sve.And(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
Vector<short>.Zero
);
Sve.CreateTrueMaskInt16(),
Sve.BitwiseClear(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Vector<short>.Zero);
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> BitwiseClear()
static Vector<int> XorMask(Vector<int> a, Vector<int> b)
{
return Sve.ConditionalSelect(
Sve.CreateFalseMaskInt16(),
Sve.BitwiseClear(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
Vector<short>.Zero
);
//ARM64-FULL-LINE: eor {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
return Sve.CreateBreakAfterMask(
Sve.ConditionalSelect(
Sve.CreateTrueMaskInt32(),
Sve.Xor(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Vector<int>.Zero),
Sve.CreateFalseMaskInt32());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<int> Xor()
static Vector<short> OrMask(Vector<short> a, Vector<short> b)
{
//ARM64-FULL-LINE: orr {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
return Sve.ConditionalSelect(
Sve.CreateTrueMaskInt32(),
Sve.Xor(Sve.CreateTrueMaskInt32(), Sve.CreateTrueMaskInt32()),
Vector<int>.Zero
);
Sve.CreateTrueMaskInt16(),
Sve.Or(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Vector<short>.Zero);
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> Or()
static Vector<int> ConditionalSelectMask(Vector<int> v, Vector<int> a, Vector<int> b)
{
return Sve.ConditionalSelect(
Sve.CreateTrueMaskInt16(),
Sve.Or(Sve.CreateTrueMaskInt16(), Sve.CreateTrueMaskInt16()),
Vector<short>.Zero
);
// Use a passed-in vector for the mask to prevent the select from being optimised away
//ARM64-FULL-LINE: sel {{p[0-9]+}}.b, {{p[0-9]+}}, {{p[0-9]+}}.b, {{p[0-9]+}}.b
return Sve.CreateBreakAfterMask(
Sve.ConditionalSelect(v, Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Sve.CreateFalseMaskInt32());
}
// These tests use multiple predicate-variant instructions in one expression tree
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<short> UnzipEvenZipLowMask(Vector<short> a, Vector<short> b)
{
//ARM64-FULL-LINE: zip1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
//ARM64-FULL-LINE: uzp1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
return Sve.CreateBreakAfterMask(
Sve.UnzipEven(
Sve.ZipLow(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Sve.CompareLessThan(a, b)),
Sve.CreateTrueMaskInt16());
}
[MethodImpl(MethodImplOptions.NoInlining)]
static Vector<int> ConditionalSelect()
static Vector<short> TransposeEvenAndMask(Vector<short> v, Vector<short> a, Vector<short> b)
{
return Sve.ConditionalSelect(
Vector<int>.Zero,
Sve.CreateFalseMaskInt32(),
Sve.CreateTrueMaskInt32()
);
//ARM64-FULL-LINE: and {{p[0-9]+}}.b, {{p[0-9]+}}/z, {{p[0-9]+}}.b, {{p[0-9]+}}.b
//ARM64-FULL-LINE: trn1 {{p[0-9]+}}.h, {{p[0-9]+}}.h, {{p[0-9]+}}.h
return Sve.TransposeEven(
Sve.CompareGreaterThan(a, b),
Sve.ConditionalSelect(
Sve.CreateTrueMaskInt16(),
Sve.And(Sve.CompareGreaterThan(a, b), Sve.CompareEqual(a, b)),
Sve.CompareLessThan(a, b)));
}
}
}