Skip to content

Commit 3fb6fbb

Browse files
SVE2 API DotProductComplex and DotProductComplexBySelectedIndex (#117706)
* SVE2 DotProductComplex
* Implementation for DotProductComplex and DotProductComplexBySelectedIndex
* Renaming to DotProductRotateComplex
* Adding constant range to rotation parameter
* Update src/coreclr/jit/hwintrinsiccodegenarm64.cpp

---------

Co-authored-by: Aman Khalid <[email protected]>
1 parent 2fe0457 commit 3fb6fbb

File tree

13 files changed

+476
-111
lines changed

13 files changed

+476
-111
lines changed

src/coreclr/jit/emitarm64sve.cpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -4560,14 +4560,13 @@ void emitter::emitInsSve_R_R_R_I(instruction ins,
45604560
case INS_sve_cdot:
45614561
assert(insScalableOptsNone(sopt));
45624562
assert(insOptsScalableWords(opt));
4563-
assert(isVectorRegister(reg1)); // ddddd
4564-
assert(isVectorRegister(reg2)); // nnnnn
4565-
assert(isVectorRegister(reg3)); // mmmmm
4566-
assert(isValidRot(imm)); // rr
4567-
assert(isValidVectorElemsize(optGetSveElemsize(opt))); // xx
4563+
assert(isVectorRegister(reg1)); // ddddd
4564+
assert(isVectorRegister(reg2)); // nnnnn
4565+
assert(isVectorRegister(reg3)); // mmmmm
4566+
assert(isValidRot(emitDecodeRotationImm0_to_270(imm))); // rr
4567+
assert(isValidVectorElemsize(optGetSveElemsize(opt))); // xx
45684568

45694569
// imm already holds the 2-bit rotation encoding (0-3, validated by the assert above), so no conversion is needed here
4570-
imm = emitEncodeRotationImm0_to_270(imm);
45714570
fmt = IF_SVE_EJ_3A;
45724571
break;
45734572

@@ -5764,12 +5763,12 @@ void emitter::emitInsSve_R_R_R_I_I(instruction ins,
57645763
switch (ins)
57655764
{
57665765
case INS_sve_cdot:
5767-
assert(isVectorRegister(reg1)); // ddddd
5768-
assert(isVectorRegister(reg2)); // nnnnn
5769-
assert(isLowVectorRegister(reg3)); // mmmm
5770-
assert(isValidRot(imm2)); // rr
5771-
// Convert imm2 from rotation value (0-270) to bitwise representation (0-3)
5772-
imm = (imm1 << 2) | emitEncodeRotationImm0_to_270(imm2);
5766+
assert(isVectorRegister(reg1)); // ddddd
5767+
assert(isVectorRegister(reg2)); // nnnnn
5768+
assert(isLowVectorRegister(reg3)); // mmmm
5769+
assert(isValidRot(emitDecodeRotationImm0_to_270(imm2))); // rr
5770+
5771+
imm = (imm1 << 2) | imm2;
57735772

57745773
if (opt == INS_OPTS_SCALABLE_B)
57755774
{

src/coreclr/jit/hwintrinsic.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1275,6 +1275,7 @@ struct HWIntrinsicInfo
12751275
}
12761276

12771277
case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
1278+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
12781279
{
12791280
assert(sig->numArgs == 5);
12801281
*imm1Pos = 0;

src/coreclr/jit/hwintrinsicarm64.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ void HWIntrinsicInfo::lookupImmBounds(
493493
break;
494494

495495
case NI_Sve_MultiplyAddRotateComplex:
496+
case NI_Sve2_DotProductRotateComplex:
496497
immLowerBound = 0;
497498
immUpperBound = 3;
498499
break;
@@ -518,6 +519,23 @@ void HWIntrinsicInfo::lookupImmBounds(
518519
}
519520
break;
520521

522+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
523+
if (immNumber == 1)
524+
{
525+
// Bounds for rotation
526+
immLowerBound = 0;
527+
immUpperBound = 3;
528+
}
529+
else
530+
{
531+
// Bounds for index
532+
assert(immNumber == 2);
533+
assert(baseType == TYP_BYTE || baseType == TYP_SHORT);
534+
immLowerBound = 0;
535+
immUpperBound = (baseType == TYP_BYTE) ? 3 : 1;
536+
}
537+
break;
538+
521539
case NI_Sve_TrigonometricMultiplyAddCoefficient:
522540
immLowerBound = 0;
523541
immUpperBound = 7;
@@ -3205,6 +3223,7 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
32053223
}
32063224

32073225
case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
3226+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
32083227
{
32093228
assert(sig->numArgs == 5);
32103229
assert(!isScalar);

src/coreclr/jit/hwintrinsiccodegenarm64.cpp

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2737,6 +2737,116 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)
27372737
GetEmitter()->emitInsSve_R_R_R(ins, emitSize, targetReg, op3Reg, op1Reg, INS_OPTS_SCALABLE_D);
27382738
break;
27392739

2740+
case NI_Sve2_DotProductRotateComplex:
2741+
{
2742+
assert(isRMW);
2743+
assert(hasImmediateOperand);
2744+
2745+
HWIntrinsicImmOpHelper helper(this, intrin.op4, node, (targetReg != op1Reg) ? 2 : 1);
2746+
2747+
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
2748+
{
2749+
if (targetReg != op1Reg)
2750+
{
2751+
assert(targetReg != op2Reg);
2752+
2753+
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
2754+
}
2755+
2756+
GetEmitter()->emitInsSve_R_R_R_I(ins, emitSize, targetReg, op2Reg, op3Reg, helper.ImmValue(), opt);
2757+
}
2758+
break;
2759+
}
2760+
2761+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
2762+
{
2763+
assert(isRMW);
2764+
assert(hasImmediateOperand);
2765+
2766+
// If both immediates are constant, we don't need a jump table
2767+
if (intrin.op4->IsCnsIntOrI() && intrin.op5->IsCnsIntOrI())
2768+
{
2769+
if (targetReg != op1Reg)
2770+
{
2771+
assert(targetReg != op2Reg);
2772+
assert(targetReg != op3Reg);
2773+
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
2774+
}
2775+
2776+
assert(intrin.op4->isContainedIntOrIImmed() && intrin.op5->isContainedIntOrIImmed());
2777+
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg,
2778+
intrin.op4->AsIntCon()->gtIconVal,
2779+
intrin.op5->AsIntCon()->gtIconVal, opt);
2780+
}
2781+
else
2782+
{
2783+
// Use the helper to generate a table. The table can only use a single lookup value, therefore
2784+
// the two immediates index and rotation must be combined to a single value
2785+
assert(!intrin.op4->isContainedIntOrIImmed() && !intrin.op5->isContainedIntOrIImmed());
2786+
emitAttr scalarSize = emitActualTypeSize(node->GetSimdBaseType());
2787+
2788+
var_types baseType = node->GetSimdBaseType();
2789+
2790+
if (baseType == TYP_BYTE)
2791+
{
2792+
GetEmitter()->emitIns_R_R_I(INS_lsl, scalarSize, op5Reg, op5Reg, 2);
2793+
GetEmitter()->emitIns_R_R_R(INS_orr, scalarSize, op4Reg, op4Reg, op5Reg);
2794+
2795+
// index and rotation both take values 0 to 3 so must be
2796+
// combined to a single value (0 to 15)
2797+
HWIntrinsicImmOpHelper helper(this, op4Reg, 0, 15, node, (targetReg != op1Reg) ? 2 : 1);
2798+
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
2799+
{
2800+
if (targetReg != op1Reg)
2801+
{
2802+
assert(targetReg != op2Reg);
2803+
assert(targetReg != op3Reg);
2804+
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
2805+
}
2806+
2807+
const int value = helper.ImmValue();
2808+
const ssize_t index = value & 3;
2809+
const ssize_t rotation = (value >> 2) & 3;
2810+
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg, index,
2811+
rotation, opt);
2812+
}
2813+
2814+
GetEmitter()->emitIns_R_R_I(INS_and, scalarSize, op4Reg, op4Reg, 3);
2815+
GetEmitter()->emitIns_R_R_I(INS_lsr, scalarSize, op5Reg, op5Reg, 2);
2816+
}
2817+
else
2818+
{
2819+
assert(baseType == TYP_SHORT);
2820+
GetEmitter()->emitIns_R_R_I(INS_lsl, scalarSize, op5Reg, op5Reg, 1);
2821+
GetEmitter()->emitIns_R_R_R(INS_orr, scalarSize, op4Reg, op4Reg, op5Reg);
2822+
2823+
// index (0 to 1, in op4Reg) and rotation (0 to 3, in op5Reg) must be
2824+
// combined to a single value (0 to 7)
2825+
HWIntrinsicImmOpHelper helper(this, op4Reg, 0, 7, node, (targetReg != op1Reg) ? 2 : 1);
2826+
for (helper.EmitBegin(); !helper.Done(); helper.EmitCaseEnd())
2827+
{
2828+
if (targetReg != op1Reg)
2829+
{
2830+
assert(targetReg != op2Reg);
2831+
assert(targetReg != op3Reg);
2832+
GetEmitter()->emitInsSve_R_R(INS_sve_movprfx, EA_SCALABLE, targetReg, op1Reg);
2833+
}
2834+
2835+
const int value = helper.ImmValue();
2836+
const ssize_t index = value & 1;
2837+
const ssize_t rotation = value >> 1;
2838+
GetEmitter()->emitInsSve_R_R_R_I_I(ins, emitSize, targetReg, op2Reg, op3Reg, index,
2839+
rotation, opt);
2840+
}
2841+
2842+
GetEmitter()->emitIns_R_R_I(INS_and, scalarSize, op4Reg, op4Reg, 1);
2843+
GetEmitter()->emitIns_R_R_I(INS_lsr, scalarSize, op5Reg, op5Reg, 1);
2844+
}
2845+
}
2846+
2847+
break;
2848+
}
2849+
27402850
case NI_Sve2_SubtractWideningEven:
27412851
{
27422852
var_types returnType = node->AsHWIntrinsic()->GetSimdBaseType();

src/coreclr/jit/hwintrinsiclistarm64sve.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,8 @@ HARDWARE_INTRINSIC(Sve2, BitwiseClearXor,
338338
HARDWARE_INTRINSIC(Sve2, BitwiseSelect, -1, 3, {INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_sve_bsl, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
339339
HARDWARE_INTRINSIC(Sve2, BitwiseSelectLeftInverted, -1, 3, {INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_sve_bsl1n, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
340340
HARDWARE_INTRINSIC(Sve2, BitwiseSelectRightInverted, -1, 3, {INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_sve_bsl2n, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_HasRMWSemantics)
341+
HARDWARE_INTRINSIC(Sve2, DotProductRotateComplex, -1, 4, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_cdot, INS_invalid, INS_sve_cdot, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasRMWSemantics|HW_Flag_SpecialCodeGen|HW_Flag_HasImmediateOperand)
342+
HARDWARE_INTRINSIC(Sve2, DotProductRotateComplexBySelectedIndex, -1, 5, {INS_sve_cdot, INS_invalid, INS_sve_cdot, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_HasRMWSemantics|HW_Flag_SpecialCodeGen|HW_Flag_HasImmediateOperand|HW_Flag_LowVectorOperation|HW_Flag_SpecialImport|HW_Flag_BaseTypeFromSecondArg)
341343
HARDWARE_INTRINSIC(Sve2, FusedAddHalving, -1, -1, {INS_sve_shadd, INS_sve_uhadd, INS_sve_shadd, INS_sve_uhadd, INS_sve_shadd, INS_sve_uhadd, INS_sve_shadd, INS_sve_uhadd, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
342344
HARDWARE_INTRINSIC(Sve2, FusedAddRoundedHalving, -1, -1, {INS_sve_srhadd, INS_sve_urhadd, INS_sve_srhadd, INS_sve_urhadd, INS_sve_srhadd, INS_sve_urhadd, INS_sve_srhadd, INS_sve_urhadd, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
343345
HARDWARE_INTRINSIC(Sve2, FusedSubtractHalving, -1, -1, {INS_sve_shsub, INS_sve_uhsub, INS_sve_shsub, INS_sve_uhsub, INS_sve_shsub, INS_sve_uhsub, INS_sve_shsub, INS_sve_uhsub, INS_invalid, INS_invalid}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)

src/coreclr/jit/lowerarmarch.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4091,6 +4091,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
40914091
case NI_Sve_FusedMultiplyAddBySelectedScalar:
40924092
case NI_Sve_FusedMultiplySubtractBySelectedScalar:
40934093
case NI_Sve_MultiplyAddRotateComplex:
4094+
case NI_Sve2_DotProductRotateComplex:
40944095
assert(hasImmediateOperand);
40954096
assert(varTypeIsIntegral(intrin.op4));
40964097
if (intrin.op4->IsCnsIntOrI())
@@ -4148,6 +4149,7 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
41484149
break;
41494150

41504151
case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
4152+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
41514153
assert(hasImmediateOperand);
41524154
assert(varTypeIsIntegral(intrin.op4));
41534155
assert(varTypeIsIntegral(intrin.op5));

src/coreclr/jit/lsraarm64.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,6 +1710,7 @@ void LinearScan::BuildHWIntrinsicImmediate(GenTreeHWIntrinsic* intrinsicTree, co
17101710
break;
17111711

17121712
case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
1713+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
17131714
// This API has two immediates, one of which is used to index pairs of floats in a vector.
17141715
// For a vector width of 128 bits, this means the index's range is [0, 1],
17151716
// which means we will skip the above jump table register check,
@@ -1734,6 +1735,7 @@ void LinearScan::BuildHWIntrinsicImmediate(GenTreeHWIntrinsic* intrinsicTree, co
17341735
break;
17351736

17361737
case NI_Sve_MultiplyAddRotateComplex:
1738+
case NI_Sve2_DotProductRotateComplex:
17371739
needBranchTargetReg = !intrin.op4->isContainedIntOrIImmed();
17381740
break;
17391741

@@ -2164,6 +2166,7 @@ SingleTypeRegSet LinearScan::getOperandCandidates(GenTreeHWIntrinsic* intrinsicT
21642166
case NI_Sve_FusedMultiplyAddBySelectedScalar:
21652167
case NI_Sve_FusedMultiplySubtractBySelectedScalar:
21662168
case NI_Sve_MultiplyAddRotateComplexBySelectedScalar:
2169+
case NI_Sve2_DotProductRotateComplexBySelectedIndex:
21672170
case NI_Sve2_MultiplyAddBySelectedScalar:
21682171
case NI_Sve2_MultiplyBySelectedScalarWideningEvenAndAdd:
21692172
case NI_Sve2_MultiplyBySelectedScalarWideningOddAndAdd:
@@ -2189,6 +2192,10 @@ SingleTypeRegSet LinearScan::getOperandCandidates(GenTreeHWIntrinsic* intrinsicT
21892192
if (isLowVectorOpNum)
21902193
{
21912194
unsigned baseElementSize = genTypeSize(intrin.baseType);
2195+
if (intrin.id == NI_Sve2_DotProductRotateComplexBySelectedIndex)
2196+
{
2197+
baseElementSize = intrin.baseType == TYP_BYTE ? 4 : 8;
2198+
}
21922199

21932200
if (baseElementSize == 8)
21942201
{

src/libraries/System.Private.CoreLib/src/System/Runtime/Intrinsics/Arm/Sve2.PlatformNotSupported.cs

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1121,6 +1121,33 @@ internal Arm64() { }
11211121
/// </summary>
11221122
public static Vector<ulong> BitwiseSelectRightInverted(Vector<ulong> select, Vector<ulong> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }
11231123

1124+
// Complex dot product
1125+
1126+
/// <summary>
1127+
/// svint32_t svcdot[_s32](svint32_t op1, svint8_t op2, svint8_t op3, uint64_t imm_rotation)
1128+
/// CDOT Ztied1.S, Zop2.B, Zop3.B, #imm_rotation
1129+
/// </summary>
1130+
public static Vector<int> DotProductRotateComplex(Vector<int> op1, Vector<sbyte> op2, Vector<sbyte> op3, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }
1131+
1132+
/// <summary>
1133+
/// svint64_t svcdot[_s64](svint64_t op1, svint16_t op2, svint16_t op3, uint64_t imm_rotation)
1134+
/// CDOT Ztied1.D, Zop2.H, Zop3.H, #imm_rotation
1135+
/// </summary>
1136+
public static Vector<long> DotProductRotateComplex(Vector<long> op1, Vector<short> op2, Vector<short> op3, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }
1137+
1138+
/// <summary>
1139+
/// svint32_t svcdot_lane[_s32](svint32_t op1, svint8_t op2, svint8_t op3, uint64_t imm_index, uint64_t imm_rotation)
1140+
/// CDOT Ztied1.S, Zop2.B, Zop3.B[imm_index], #imm_rotation
1141+
/// </summary>
1142+
public static Vector<int> DotProductRotateComplexBySelectedIndex(Vector<int> op1, Vector<sbyte> op2, Vector<sbyte> op3, [ConstantExpected(Min = 0, Max = (byte)(3))] byte imm_index, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }
1143+
1144+
/// <summary>
1145+
/// svint64_t svcdot_lane[_s64](svint64_t op1, svint16_t op2, svint16_t op3, uint64_t imm_index, uint64_t imm_rotation)
1146+
/// CDOT Ztied1.D, Zop2.H, Zop3.H[imm_index], #imm_rotation
1147+
/// </summary>
1148+
public static Vector<long> DotProductRotateComplexBySelectedIndex(Vector<long> op1, Vector<short> op2, Vector<short> op3, [ConstantExpected(Min = 0, Max = (byte)(1))] byte imm_index, [ConstantExpected(Min = 0, Max = (byte)(3))] byte rotation) { throw new PlatformNotSupportedException(); }
1149+
1150+
11241151
// Halving add
11251152

11261153
/// <summary>

0 commit comments

Comments
 (0)