Perform factorization as a last resort of unsafe fadd/fsub simplification.
Rules include:
1)1 x*y +/- x*z => x*(y +/- z)
(the order of operands dosen't matter)
2) y/x +/- z/x => (y +/- z)/x
The transformation is disabled if the new add/sub expr "y +/- z" is a
denormal/naz/inifinity.
rdar://12911472
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@177088 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index c6d60d6..3c5781c 100644
--- a/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -150,7 +150,9 @@
typedef SmallVector<const FAddend*, 4> AddendVect;
Value *simplifyFAdd(AddendVect& V, unsigned InstrQuota);
-
+
+ Value *performFactorization(Instruction *I);
+
/// Convert given addend to a Value
Value *createAddendVal(const FAddend &A, bool& NeedNeg);
@@ -159,6 +161,7 @@
Value *createFSub(Value *Opnd0, Value *Opnd1);
Value *createFAdd(Value *Opnd0, Value *Opnd1);
Value *createFMul(Value *Opnd0, Value *Opnd1);
+ Value *createFDiv(Value *Opnd0, Value *Opnd1);
Value *createFNeg(Value *V);
Value *createNaryFAdd(const AddendVect& Opnds, unsigned InstrQuota);
void createInstPostProc(Instruction *NewInst);
@@ -388,6 +391,78 @@
return BreakNum;
}
+// Try to perform following optimization on the input instruction I. Return the
+// simplified expression if was successful; otherwise, return 0.
+//
+// Instruction "I" is Simplified into
+// -------------------------------------------------------
+// (x * y) +/- (x * z) x * (y +/- z)
+// (y / x) +/- (z / x) (y +/- z) / x
+//
+Value *FAddCombine::performFactorization(Instruction *I) {
+ assert((I->getOpcode() == Instruction::FAdd ||
+ I->getOpcode() == Instruction::FSub) && "Expect add/sub");
+
+ Instruction *I0 = dyn_cast<Instruction>(I->getOperand(0));
+ Instruction *I1 = dyn_cast<Instruction>(I->getOperand(1));
+
+ if (!I0 || !I1 || I0->getOpcode() != I1->getOpcode())
+ return 0;
+
+ bool isMpy = false;
+ if (I0->getOpcode() == Instruction::FMul)
+ isMpy = true;
+ else if (I0->getOpcode() != Instruction::FDiv)
+ return 0;
+
+ Value *Opnd0_0 = I0->getOperand(0);
+ Value *Opnd0_1 = I0->getOperand(1);
+ Value *Opnd1_0 = I1->getOperand(0);
+ Value *Opnd1_1 = I1->getOperand(1);
+
+ // Input Instr I Factor AddSub0 AddSub1
+ // ----------------------------------------------
+ // (x*y) +/- (x*z) x y z
+ // (y/x) +/- (z/x) x y z
+ //
+ Value *Factor = 0;
+ Value *AddSub0 = 0, *AddSub1 = 0;
+
+ if (isMpy) {
+ if (Opnd0_0 == Opnd1_0 || Opnd0_0 == Opnd1_1)
+ Factor = Opnd0_0;
+ else if (Opnd0_1 == Opnd1_0 || Opnd0_1 == Opnd1_1)
+ Factor = Opnd0_1;
+
+ if (Factor) {
+ AddSub0 = (Factor == Opnd0_0) ? Opnd0_1 : Opnd0_0;
+ AddSub1 = (Factor == Opnd1_0) ? Opnd1_1 : Opnd1_0;
+ }
+ } else if (Opnd0_1 == Opnd1_1) {
+ Factor = Opnd0_1;
+ AddSub0 = Opnd0_0;
+ AddSub1 = Opnd1_0;
+ }
+
+ if (!Factor)
+ return 0;
+
+ // Create expression "NewAddSub = AddSub0 +/- AddsSub1"
+ Value *NewAddSub = (I->getOpcode() == Instruction::FAdd) ?
+ createFAdd(AddSub0, AddSub1) :
+ createFSub(AddSub0, AddSub1);
+ if (ConstantFP *CFP = dyn_cast<ConstantFP>(NewAddSub)) {
+ const APFloat &F = CFP->getValueAPF();
+ if (!F.isNormal() || F.isDenormal())
+ return 0;
+ }
+
+ if (isMpy)
+ return createFMul(Factor, NewAddSub);
+
+ return createFDiv(NewAddSub, Factor);
+}
+
Value *FAddCombine::simplify(Instruction *I) {
assert(I->hasUnsafeAlgebra() && "Should be in unsafe mode");
@@ -471,7 +546,8 @@
return R;
}
- return 0;
+ // step 6: Try factorization as the last resort,
+ return performFactorization(I);
}
Value *FAddCombine::simplifyFAdd(AddendVect& Addends, unsigned InstrQuota) {
@@ -627,7 +703,8 @@
Value *FAddCombine::createFSub
(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFSub(Opnd0, Opnd1);
- createInstPostProc(cast<Instruction>(V));
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ createInstPostProc(I);
return V;
}
@@ -639,13 +716,22 @@
Value *FAddCombine::createFAdd
(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFAdd(Opnd0, Opnd1);
- createInstPostProc(cast<Instruction>(V));
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ createInstPostProc(I);
return V;
}
Value *FAddCombine::createFMul(Value *Opnd0, Value *Opnd1) {
Value *V = Builder->CreateFMul(Opnd0, Opnd1);
- createInstPostProc(cast<Instruction>(V));
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ createInstPostProc(I);
+ return V;
+}
+
+Value *FAddCombine::createFDiv(Value *Opnd0, Value *Opnd1) {
+ Value *V = Builder->CreateFDiv(Opnd0, Opnd1);
+ if (Instruction *I = dyn_cast<Instruction>(V))
+ createInstPostProc(I);
return V;
}
diff --git a/test/Transforms/InstCombine/fast-math.ll b/test/Transforms/InstCombine/fast-math.ll
index 3e32a2e..47f1ec4 100644
--- a/test/Transforms/InstCombine/fast-math.ll
+++ b/test/Transforms/InstCombine/fast-math.ll
@@ -350,3 +350,108 @@
; CHECK: @fdiv9
; CHECK: fmul fast float %x, 5.000000e+00
}
+
+; =========================================================================
+;
+; Testing-cases about factorization
+;
+; =========================================================================
+; x*z + y*z => (x+y) * z
+define float @fact_mul1(float %x, float %y, float %z) {
+ %t1 = fmul fast float %x, %z
+ %t2 = fmul fast float %y, %z
+ %t3 = fadd fast float %t1, %t2
+ ret float %t3
+; CHECK: @fact_mul1
+; CHECK: fmul fast float %1, %z
+}
+
+; z*x + y*z => (x+y) * z
+define float @fact_mul2(float %x, float %y, float %z) {
+ %t1 = fmul fast float %z, %x
+ %t2 = fmul fast float %y, %z
+ %t3 = fsub fast float %t1, %t2
+ ret float %t3
+; CHECK: @fact_mul2
+; CHECK: fmul fast float %1, %z
+}
+
+; z*x - z*y => (x-y) * z
+define float @fact_mul3(float %x, float %y, float %z) {
+ %t2 = fmul fast float %z, %y
+ %t1 = fmul fast float %z, %x
+ %t3 = fsub fast float %t1, %t2
+ ret float %t3
+; CHECK: @fact_mul3
+; CHECK: fmul fast float %1, %z
+}
+
+; x*z - z*y => (x-y) * z
+define float @fact_mul4(float %x, float %y, float %z) {
+ %t1 = fmul fast float %x, %z
+ %t2 = fmul fast float %z, %y
+ %t3 = fsub fast float %t1, %t2
+ ret float %t3
+; CHECK: @fact_mul4
+; CHECK: fmul fast float %1, %z
+}
+
+; x/y + x/z, no xform
+define float @fact_div1(float %x, float %y, float %z) {
+ %t1 = fdiv fast float %x, %y
+ %t2 = fdiv fast float %x, %z
+ %t3 = fadd fast float %t1, %t2
+ ret float %t3
+; CHECK: fact_div1
+; CHECK: fadd fast float %t1, %t2
+}
+
+; x/y + z/x; no xform
+define float @fact_div2(float %x, float %y, float %z) {
+ %t1 = fdiv fast float %x, %y
+ %t2 = fdiv fast float %z, %x
+ %t3 = fadd fast float %t1, %t2
+ ret float %t3
+; CHECK: fact_div2
+; CHECK: fadd fast float %t1, %t2
+}
+
+; y/x + z/x => (y+z)/x
+define float @fact_div3(float %x, float %y, float %z) {
+ %t1 = fdiv fast float %y, %x
+ %t2 = fdiv fast float %z, %x
+ %t3 = fadd fast float %t1, %t2
+ ret float %t3
+; CHECK: fact_div3
+; CHECK: fdiv fast float %1, %x
+}
+
+; y/x - z/x => (y-z)/x
+define float @fact_div4(float %x, float %y, float %z) {
+ %t1 = fdiv fast float %y, %x
+ %t2 = fdiv fast float %z, %x
+ %t3 = fsub fast float %t1, %t2
+ ret float %t3
+; CHECK: fact_div4
+; CHECK: fdiv fast float %1, %x
+}
+
+; y/x - z/x => (y-z)/x is disabled if y-z is denormal.
+define float @fact_div5(float %x) {
+ %t1 = fdiv fast float 0x3810000000000000, %x
+ %t2 = fdiv fast float 0x3800000000000000, %x
+ %t3 = fadd fast float %t1, %t2
+ ret float %t3
+; CHECK: fact_div5
+; CHECK: fdiv fast float 0x3818000000000000, %x
+}
+
+; y/x - z/x => (y-z)/x is disabled if y-z is denormal.
+define float @fact_div6(float %x) {
+ %t1 = fdiv fast float 0x3810000000000000, %x
+ %t2 = fdiv fast float 0x3800000000000000, %x
+ %t3 = fsub fast float %t1, %t2
+ ret float %t3
+; CHECK: fact_div6
+; CHECK: %t3 = fsub fast float %t1, %t2
+}