Skip to content

Commit 1704ae1

Browse files
committed
[InstCombine] Canonicalize signed saturated additions with positive numbers only
https://alive2.llvm.org/ce/z/YGT5SN This is tricky because with positive numbers, we only go up, so we can in fact always hit the signed_max boundary. This is important because the intrinsic we use has the behavior of going the OTHER way, aka clamp to INT_MIN if it goes in that direction. And the range checking we do only works for positive numbers. Because of this issue, we can only do this for constants as well.
1 parent e75d3b0 commit 1704ae1

File tree

3 files changed

+100
-33
lines changed

3 files changed

+100
-33
lines changed

llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,10 +1012,9 @@ static Value *canonicalizeSaturatedSubtract(const ICmpInst *ICI,
10121012
return Result;
10131013
}
10141014

1015-
static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
1016-
InstCombiner::BuilderTy &Builder) {
1017-
if (!Cmp->hasOneUse())
1018-
return nullptr;
1015+
static Value *
1016+
canonicalizeSaturatedAddUnsigned(ICmpInst *Cmp, Value *TVal, Value *FVal,
1017+
InstCombiner::BuilderTy &Builder) {
10191018

10201019
// Match unsigned saturated add with constant.
10211020
Value *Cmp0 = Cmp->getOperand(0);
@@ -1037,8 +1036,7 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
10371036
// uge -1 is canonicalized to eq -1 and requires special handling
10381037
// (a == -1) ? -1 : a + 1 -> uadd.sat(a, 1)
10391038
if (Pred == ICmpInst::ICMP_EQ) {
1040-
if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) &&
1041-
match(Cmp1, m_AllOnes())) {
1039+
if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) && Cmp1 == TVal) {
10421040
return Builder.CreateBinaryIntrinsic(
10431041
Intrinsic::uadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1));
10441042
}
@@ -1115,6 +1113,94 @@ static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
11151113
return nullptr;
11161114
}
11171115

1116+
static Value *canonicalizeSaturatedAddSigned(ICmpInst *Cmp, Value *TVal,
1117+
Value *FVal,
1118+
InstCombiner::BuilderTy &Builder) {
1119+
// Match saturated add with constant.
1120+
Value *Cmp0 = Cmp->getOperand(0);
1121+
Value *Cmp1 = Cmp->getOperand(1);
1122+
ICmpInst::Predicate Pred = Cmp->getPredicate();
1123+
const APInt *C;
1124+
1125+
// Canonicalize INT_MAX to true value of the select.
1126+
if (match(FVal, m_MaxSignedValue())) {
1127+
std::swap(TVal, FVal);
1128+
Pred = CmpInst::getInversePredicate(Pred);
1129+
}
1130+
if (!match(TVal, m_MaxSignedValue()))
1131+
return nullptr;
1132+
1133+
// sge maximum signed value is canonicalized to eq minimum signed value and
1134+
// requires special handling (a == INT_MAX) ? INT_MAX : a + 1 -> sadd.sat(a,
1135+
// 1)
1136+
if (Pred == ICmpInst::ICMP_EQ) {
1137+
if (match(FVal, m_Add(m_Specific(Cmp0), m_One())) && Cmp1 == TVal) {
1138+
return Builder.CreateBinaryIntrinsic(
1139+
Intrinsic::sadd_sat, Cmp0, ConstantInt::get(Cmp0->getType(), 1));
1140+
}
1141+
return nullptr;
1142+
}
1143+
1144+
if ((Pred == ICmpInst::ICMP_SGE || Pred == ICmpInst::ICMP_SGT) &&
1145+
match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
1146+
match(Cmp1, m_SpecificIntAllowPoison(
1147+
APInt::getSignedMaxValue(
1148+
Cmp1->getType()->getScalarSizeInBits()) -
1149+
*C)) &&
1150+
!C->isNegative()) {
1151+
// (X > INT_MAX - C) ? INT_MAX : (X + C) --> sadd.sat(X, C)
1152+
// (X >= INT_MAX - C) ? INT_MAX : (X + C) --> sadd.sat(X, C)
1153+
return Builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, Cmp0,
1154+
ConstantInt::get(Cmp0->getType(), *C));
1155+
}
1156+
1157+
if (Pred == ICmpInst::ICMP_SGT &&
1158+
match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
1159+
match(Cmp1, m_SpecificIntAllowPoison(
1160+
APInt::getSignedMaxValue(
1161+
Cmp1->getType()->getScalarSizeInBits()) -
1162+
*C - 1)) &&
1163+
!C->isNegative()) {
1164+
// (X > INT_MAX - C - 1) ? INT_MAX : (X + C) --> sadd.sat(X, C)
1165+
return Builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, Cmp0,
1166+
ConstantInt::get(Cmp0->getType(), *C));
1167+
}
1168+
1169+
if (Pred == ICmpInst::ICMP_SGE &&
1170+
match(FVal, m_Add(m_Specific(Cmp0), m_APIntAllowPoison(C))) &&
1171+
match(Cmp1, m_SpecificIntAllowPoison(
1172+
APInt::getSignedMinValue(
1173+
Cmp1->getType()->getScalarSizeInBits()) -
1174+
*C)) &&
1175+
!C->isNegative()) {
1176+
// (X >= INT_MAX - C + 1) ? INT_MAX : (X + C) --> sadd.sat(X, C)
1177+
return Builder.CreateBinaryIntrinsic(Intrinsic::sadd_sat, Cmp0,
1178+
ConstantInt::get(Cmp0->getType(), *C));
1179+
}
1180+
1181+
// TODO: Try to match variables. However, due to the fact that we can only
1182+
// fold if we know at least one is positive, we cannot fold for each and every
1183+
// time, unlike the unsigned case, where every number is positive.
1184+
1185+
// TODO: Match when known negatives go towards INT_MIN.
1186+
1187+
return nullptr;
1188+
}
1189+
1190+
static Value *canonicalizeSaturatedAdd(ICmpInst *Cmp, Value *TVal, Value *FVal,
1191+
InstCombiner::BuilderTy &Builder) {
1192+
if (!Cmp->hasOneUse())
1193+
return nullptr;
1194+
1195+
if (Value *V = canonicalizeSaturatedAddUnsigned(Cmp, TVal, FVal, Builder))
1196+
return V;
1197+
1198+
if (Value *V = canonicalizeSaturatedAddSigned(Cmp, TVal, FVal, Builder))
1199+
return V;
1200+
1201+
return nullptr;
1202+
}
1203+
11181204
/// Try to match patterns with select and subtract as absolute difference.
11191205
static Value *foldAbsDiff(ICmpInst *Cmp, Value *TVal, Value *FVal,
11201206
InstCombiner::BuilderTy &Builder) {

llvm/test/Transforms/InstCombine/canonicalize-const-to-bop.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,7 @@ define i8 @udiv_slt_exact(i8 %x) {
123123
define i8 @canonicalize_icmp_operands(i8 %x) {
124124
; CHECK-LABEL: define i8 @canonicalize_icmp_operands(
125125
; CHECK-SAME: i8 [[X:%.*]]) {
126-
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[X]], i8 119)
127-
; CHECK-NEXT: [[S:%.*]] = add nsw i8 [[TMP1]], 8
126+
; CHECK-NEXT: [[S:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X]], i8 8)
128127
; CHECK-NEXT: ret i8 [[S]]
129128
;
130129
%add = add nsw i8 %x, 8

llvm/test/Transforms/InstCombine/saturating-add-sub.ll

Lines changed: 7 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2378,9 +2378,7 @@ define i8 @sadd_sat_ugt_int_max(i8 %x, i8 %y) {
23782378

23792379
define i8 @sadd_sat_eq_int_max(i8 %x) {
23802380
; CHECK-LABEL: @sadd_sat_eq_int_max(
2381-
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[X:%.*]], 127
2382-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], 1
2383-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 127, i8 [[ADD]]
2381+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 1)
23842382
; CHECK-NEXT: ret i8 [[R]]
23852383
;
23862384
%cmp = icmp eq i8 %x, 127
@@ -2391,8 +2389,7 @@ define i8 @sadd_sat_eq_int_max(i8 %x) {
23912389

23922390
define i8 @sadd_sat_constant(i8 %x) {
23932391
; CHECK-LABEL: @sadd_sat_constant(
2394-
; CHECK-NEXT: [[TMP1:%.*]] = call i8 @llvm.smin.i8(i8 [[X:%.*]], i8 117)
2395-
; CHECK-NEXT: [[R:%.*]] = add nsw i8 [[TMP1]], 10
2392+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 10)
23962393
; CHECK-NEXT: ret i8 [[R]]
23972394
;
23982395
%cmp = icmp sge i8 %x, 118
@@ -2467,10 +2464,7 @@ define <2 x i8> @sadd_sat_vector_constant(<2 x i8> %x) {
24672464

24682465
define i8 @sadd_sat_int_max_minus_x(i8 %x, i8 %y) {
24692466
; CHECK-LABEL: @sadd_sat_int_max_minus_x(
2470-
; CHECK-NEXT: [[SUB:%.*]] = sub i8 127, [[X:%.*]]
2471-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[SUB]], [[Y:%.*]]
2472-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], [[Y]]
2473-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 127, i8 [[ADD]]
2467+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
24742468
; CHECK-NEXT: ret i8 [[R]]
24752469
;
24762470
%sub = sub i8 127, %x
@@ -2482,10 +2476,7 @@ define i8 @sadd_sat_int_max_minus_x(i8 %x, i8 %y) {
24822476

24832477
define i8 @sadd_sat_int_max_minus_x_commuted(i8 %x, i8 %y) {
24842478
; CHECK-LABEL: @sadd_sat_int_max_minus_x_commuted(
2485-
; CHECK-NEXT: [[SUB:%.*]] = sub i8 127, [[X:%.*]]
2486-
; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[Y:%.*]], [[SUB]]
2487-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], [[Y]]
2488-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP]], i8 127, i8 [[ADD]]
2479+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
24892480
; CHECK-NEXT: ret i8 [[R]]
24902481
;
24912482
%sub = sub i8 127, %x
@@ -2497,10 +2488,7 @@ define i8 @sadd_sat_int_max_minus_x_commuted(i8 %x, i8 %y) {
24972488

24982489
define i8 @sadd_sat_int_max_minus_x_nonstrict(i8 %x, i8 %y) {
24992490
; CHECK-LABEL: @sadd_sat_int_max_minus_x_nonstrict(
2500-
; CHECK-NEXT: [[SUB:%.*]] = sub i8 127, [[X:%.*]]
2501-
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp sgt i8 [[SUB]], [[Y:%.*]]
2502-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], [[Y]]
2503-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP_NOT]], i8 [[ADD]], i8 127
2491+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
25042492
; CHECK-NEXT: ret i8 [[R]]
25052493
;
25062494
%sub = sub i8 127, %x
@@ -2512,10 +2500,7 @@ define i8 @sadd_sat_int_max_minus_x_nonstrict(i8 %x, i8 %y) {
25122500

25132501
define i8 @sadd_sat_int_max_minus_x_commuted_nonstrict(i8 %x, i8 %y) {
25142502
; CHECK-LABEL: @sadd_sat_int_max_minus_x_commuted_nonstrict(
2515-
; CHECK-NEXT: [[SUB:%.*]] = sub i8 127, [[X:%.*]]
2516-
; CHECK-NEXT: [[CMP_NOT:%.*]] = icmp slt i8 [[Y:%.*]], [[SUB]]
2517-
; CHECK-NEXT: [[ADD:%.*]] = add i8 [[X]], [[Y]]
2518-
; CHECK-NEXT: [[R:%.*]] = select i1 [[CMP_NOT]], i8 [[ADD]], i8 127
2503+
; CHECK-NEXT: [[R:%.*]] = call i8 @llvm.sadd.sat.i8(i8 [[X:%.*]], i8 [[Y:%.*]])
25192504
; CHECK-NEXT: ret i8 [[R]]
25202505
;
25212506
%sub = sub i8 127, %x
@@ -2557,10 +2542,7 @@ define i8 @sadd_sat_int_max_minus_x_wrong_predicate(i8 %x, i8 %y) {
25572542

25582543
define <2 x i8> @sadd_sat_int_max_minus_x_vector(<2 x i8> %x, <2 x i8> %y) {
25592544
; CHECK-LABEL: @sadd_sat_int_max_minus_x_vector(
2560-
; CHECK-NEXT: [[SUB:%.*]] = sub <2 x i8> splat (i8 127), [[X:%.*]]
2561-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt <2 x i8> [[SUB]], [[Y:%.*]]
2562-
; CHECK-NEXT: [[ADD:%.*]] = add <2 x i8> [[X]], [[Y]]
2563-
; CHECK-NEXT: [[R:%.*]] = select <2 x i1> [[CMP]], <2 x i8> splat (i8 127), <2 x i8> [[ADD]]
2545+
; CHECK-NEXT: [[R:%.*]] = call <2 x i8> @llvm.sadd.sat.v2i8(<2 x i8> [[X:%.*]], <2 x i8> [[Y:%.*]])
25642546
; CHECK-NEXT: ret <2 x i8> [[R]]
25652547
;
25662548
%sub = sub <2 x i8> <i8 127, i8 127>, %x

0 commit comments

Comments
 (0)