Skip to content

Commit e0c9f52

Browse files
authored
PDLL article (#43)
* update llvm commit hash to 2024-07-12 * update implementation for MLIR version bump * add dedicated test file * add first attempt at porting MulToAdd to PDLL * add LHS/RHS versions of power of two pattern * add PeelFromMul * add Cmake build --------- Co-authored-by: Jeremy Kun <[email protected]>
1 parent 8ac609a commit e0c9f52

File tree

13 files changed

+236
-11
lines changed

13 files changed

+236
-11
lines changed

bazel/import_llvm.bzl

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ load(
88
def import_llvm(name):
99
"""Imports LLVM."""
1010

11-
# 2023-11-13
12-
LLVM_COMMIT = "f778eafdd878e8b11ad76f9e0a312ce7791a7481"
11+
# 2024-07-12
12+
LLVM_COMMIT = "0913547d0e3939cc420e88ecd037240f33736820"
1313

1414
new_git_repository(
1515
name = name,

externals/llvm-project

lib/Analysis/ReduceNoiseAnalysis/ReduceNoiseAnalysis.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ ReduceNoiseAnalysis::ReduceNoiseAnalysis(Operation *op) {
4242
op->walk([&](Operation *op) {
4343
// FIXME: assumes all reduce_noise ops have already been removed and their
4444
// values forwarded.
45-
if (!llvm::isa<noisy::AddOp, noisy::SubOp, noisy::MulOp>(op)) {
45+
if (!isa<noisy::AddOp, noisy::SubOp, noisy::MulOp>(op)) {
4646
return;
4747
}
4848

@@ -81,8 +81,8 @@ ReduceNoiseAnalysis::ReduceNoiseAnalysis(Operation *op) {
8181
// In the tutorial, there is no control flow, so these are the function
8282
// arguments of the main function being analyzed. A real compiler would
8383
// need to handle this more generically.
84-
if (value.isa<BlockArgument>() ||
85-
llvm::isa<noisy::EncodeOp>(value.getDefiningOp())) {
84+
if (isa<BlockArgument>(value) ||
85+
isa<noisy::EncodeOp>(value.getDefiningOp())) {
8686
MPConstraint *const ct =
8787
solver->MakeRowConstraint(INITIAL_NOISE, INITIAL_NOISE, "");
8888
ct->SetCoefficient(var, 1);

lib/Dialect/Noisy/NoisyOps.td

+3-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1313
class Noisy_BinOp<string mnemonic> : Op<Noisy_Dialect, mnemonic, [
1414
Pure,
1515
SameOperandsAndResultType,
16-
DeclareOpInterfaceMethods<InferIntRangeInterface>
16+
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
1717
]> {
1818
let arguments = (ins Noisy_I32:$lhs, Noisy_I32:$rhs);
1919
let results = (outs Noisy_I32:$output);
@@ -33,7 +33,7 @@ def Noisy_MulOp : Noisy_BinOp<"mul"> {
3333
}
3434

3535
def Noisy_EncodeOp : Op<Noisy_Dialect, "encode", [
36-
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
36+
Pure, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
3737
let summary = "Encodes a noisy i32 from a small-width integer, injecting 12 bits of noise.";
3838
let arguments = (ins AnyIntOfWidths<[1, 2, 3, 4, 5]>:$input);
3939
let results = (outs Noisy_I32:$output);
@@ -48,7 +48,7 @@ def Noisy_DecodeOp : Op<Noisy_Dialect, "decode", [Pure]> {
4848
}
4949

5050
def Noisy_ReduceNoiseOp : Op<Noisy_Dialect, "reduce_noise", [
51-
Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<InferIntRangeInterface>]> {
51+
Pure, SameOperandsAndResultType, DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>]> {
5252
let summary = "Reduces the noise in a noisy integer to a fixed noise level. Expensive!";
5353
let arguments = (ins Noisy_I32:$input);
5454
let results = (outs Noisy_I32:$output);

lib/Dialect/Poly/PolyOps.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ OpFoldResult MulOp::fold(MulOp::FoldAdaptor adaptor) {
3131

3232
if (!lhs || !rhs) return nullptr;
3333

34-
auto degree = getResult().getType().cast<PolynomialType>().getDegreeBound();
34+
auto degree = llvm::cast<PolynomialType>(getResult().getType()).getDegreeBound();
3535
auto maxIndex = lhs.size() + rhs.size() - 1;
3636

3737
SmallVector<APInt, 8> result;
@@ -68,7 +68,7 @@ OpFoldResult FromTensorOp::fold(FromTensorOp::FoldAdaptor adaptor) {
6868
LogicalResult EvalOp::verify() {
6969
auto pointTy = getPoint().getType();
7070
bool isSignlessInteger = pointTy.isSignlessInteger(32);
71-
auto complexPt = llvm::dyn_cast<ComplexType>(pointTy);
71+
auto complexPt = dyn_cast<ComplexType>(pointTy);
7272
return isSignlessInteger || complexPt ? success()
7373
: emitOpError(
7474
"argument point must be a 32-bit "

lib/Transform/Arith/BUILD

+35
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@ gentbl_cc_library(
2525
td_file = "Passes.td",
2626
deps = [
2727
"@llvm-project//mlir:OpBaseTdFiles",
28+
"@llvm-project//mlir:PDLDialectTdFiles",
29+
"@llvm-project//mlir:PDLInterpOpsTdFiles",
2830
"@llvm-project//mlir:PassBaseTdFiles",
2931
],
3032
)
@@ -47,6 +49,39 @@ cc_library(
4749
hdrs = ["Passes.h"],
4850
deps = [
4951
":MulToAdd",
52+
":MulToAddPdll",
5053
":pass_inc_gen",
5154
],
5255
)
56+
57+
gentbl_cc_library(
58+
name = "MulToAddPdllIncGen",
59+
tbl_outs = [
60+
(
61+
["-x=cpp"],
62+
"MulToAddPdll.h.inc",
63+
),
64+
],
65+
tblgen = "@llvm-project//mlir:mlir-pdll",
66+
td_file = "MulToAdd.pdll",
67+
deps = [
68+
"@llvm-project//mlir:ArithDialect",
69+
"@llvm-project//mlir:FuncDialect",
70+
"@llvm-project//mlir:ArithOpsTdFiles",
71+
],
72+
)
73+
74+
cc_library(
75+
name = "MulToAddPdll",
76+
srcs = ["MulToAddPdll.cpp"],
77+
hdrs = ["MulToAddPdll.h"],
78+
deps = [
79+
":pass_inc_gen",
80+
":MulToAddPdllIncGen",
81+
"@llvm-project//mlir:ArithDialect",
82+
"@llvm-project//mlir:FuncDialect",
83+
"@llvm-project//mlir:Pass",
84+
"@llvm-project//mlir:Transforms",
85+
],
86+
)
87+

lib/Transform/Arith/CMakeLists.txt

+7
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
1+
add_mlir_pdll_library(MulToAddPdllIncGen
2+
MulToAdd.pdll
3+
MulToAddPdll.h.inc
4+
)
5+
16
add_mlir_library(MulToAdd
27
MulToAdd.cpp
8+
MulToAddPdll.cpp
39

410
${PROJECT_SOURCE_DIR}/lib/Transform/Arith/
511
ADDITIONAL_HEADER_DIRS
612

713
DEPENDS
814
MLIRMulToAddPasses
15+
MulToAddPdllIncGen
916

1017
LINK_LIBS PUBLIC
1118
)

lib/Transform/Arith/MulToAdd.pdll

+67
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#include "mlir/Dialect/Arith/IR/ArithOps.td"
2+
3+
Constraint IsPowerOfTwo(attr: Attr) [{
4+
int64_t value = cast<::mlir::IntegerAttr>(attr).getValue().getSExtValue();
5+
return success((value & (value - 1)) == 0);
6+
}];
7+
8+
// Currently, constraints that return values must be defined in C++
9+
Constraint Halve(atttr: Attr) -> Attr;
10+
Constraint MinusOne(attr: Attr) -> Attr;
11+
12+
// Replace y = C*x with y = C/2*x + C/2*x, when C is a power of 2, otherwise do
13+
// nothing.
14+
Pattern PowerOfTwoExpandRhs with benefit(2) {
15+
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value);
16+
IsPowerOfTwo(const);
17+
let halved: Attr = Halve(const);
18+
19+
rewrite root with {
20+
let newConst = op<arith.constant> {value = halved};
21+
let newMul = op<arith.muli>(newConst, rhs);
22+
let newAdd = op<arith.addi>(newMul, newMul);
23+
replace root with newAdd;
24+
};
25+
}
26+
27+
Pattern PowerOfTwoExpandLhs with benefit(2) {
28+
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr});
29+
IsPowerOfTwo(const);
30+
let halved: Attr = Halve(const);
31+
32+
rewrite root with {
33+
let newConst = op<arith.constant> {value = halved};
34+
let newMul = op<arith.muli>(lhs, newConst);
35+
let newAdd = op<arith.addi>(newMul, newMul);
36+
replace root with newAdd;
37+
};
38+
}
39+
40+
// Replace y = 9*x with y = 8*x + x
41+
Pattern PeelFromMulRhs with benefit(1) {
42+
let root = op<arith.muli>(lhs: Value, op<arith.constant> {value = const: Attr});
43+
44+
// We are guaranteed `value` is not a power of two, because the greedy
45+
// rewrite engine ensures the PowerOfTwoExpand pattern is run first, since
46+
// it has higher benefit.
47+
let minusOne: Attr = MinusOne(const);
48+
49+
rewrite root with {
50+
let newConst = op<arith.constant> {value = minusOne};
51+
let newMul = op<arith.muli>(lhs, newConst);
52+
let newAdd = op<arith.addi>(newMul, lhs);
53+
replace root with newAdd;
54+
};
55+
}
56+
57+
Pattern PeelFromMulLhs with benefit(1) {
58+
let root = op<arith.muli>(op<arith.constant> {value = const: Attr}, rhs: Value);
59+
let minusOne: Attr = MinusOne(const);
60+
61+
rewrite root with {
62+
let newConst = op<arith.constant> {value = minusOne};
63+
let newMul = op<arith.muli>(newConst, rhs);
64+
let newAdd = op<arith.addi>(newMul, rhs);
65+
replace root with newAdd;
66+
};
67+
}

lib/Transform/Arith/MulToAddPdll.cpp

+49
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include "lib/Transform/Arith/MulToAddPdll.h"
2+
#include "mlir/Dialect/Arith/IR/Arith.h"
3+
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
4+
#include "mlir/IR/PatternMatch.h"
5+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
6+
#include "mlir/include/mlir/Pass/Pass.h"
7+
8+
namespace mlir {
9+
namespace tutorial {
10+
11+
#define GEN_PASS_DEF_MULTOADDPDLL
12+
#include "lib/Transform/Arith/Passes.h.inc"
13+
14+
LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results,
15+
ArrayRef<PDLValue> args) {
16+
Attribute attr = args[0].cast<Attribute>();
17+
IntegerAttr cAttr = cast<IntegerAttr>(attr);
18+
int64_t value = cAttr.getValue().getSExtValue();
19+
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2));
20+
return success();
21+
}
22+
23+
LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results,
24+
ArrayRef<PDLValue> args) {
25+
Attribute attr = args[0].cast<Attribute>();
26+
IntegerAttr cAttr = cast<IntegerAttr>(attr);
27+
int64_t value = cAttr.getValue().getSExtValue();
28+
results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value - 1));
29+
return success();
30+
}
31+
32+
void registerNativeConstraints(RewritePatternSet &patterns) {
33+
patterns.getPDLPatterns().registerConstraintFunction("Halve", halveImpl);
34+
patterns.getPDLPatterns().registerConstraintFunction("MinusOne", minusOneImpl);
35+
}
36+
37+
struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> {
38+
using MulToAddPdllBase::MulToAddPdllBase;
39+
40+
void runOnOperation() {
41+
mlir::RewritePatternSet patterns(&getContext());
42+
populateGeneratedPDLLPatterns(patterns);
43+
registerNativeConstraints(patterns);
44+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
45+
}
46+
};
47+
48+
} // namespace tutorial
49+
} // namespace mlir

lib/Transform/Arith/MulToAddPdll.h

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_
2+
#define LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_
3+
4+
#include "mlir/Pass/Pass.h"
5+
#include "mlir/IR/PatternMatch.h"
6+
#include "mlir/Dialect/Arith/IR/Arith.h"
7+
#include "mlir/Parser/Parser.h"
8+
9+
namespace mlir {
10+
namespace tutorial {
11+
12+
#define GEN_PASS_DECL_MULTOADDPDLL
13+
#include "lib/Transform/Arith/Passes.h.inc"
14+
15+
#include "lib/Transform/Arith/MulToAddPdll.h.inc"
16+
17+
} // namespace tutorial
18+
} // namespace mlir
19+
20+
#endif // LIB_TRANSFORM_ARITH_MULTOADDPDLL_H_

lib/Transform/Arith/Passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#define LIB_TRANSFORM_ARITH_PASSES_H_
33

44
#include "lib/Transform/Arith/MulToAdd.h"
5+
#include "lib/Transform/Arith/MulToAddPdll.h"
56

67
namespace mlir {
78
namespace tutorial {

lib/Transform/Arith/Passes.td

+13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
#ifndef LIB_TRANSFORM_ARITH_PASSES_TD_
22
#define LIB_TRANSFORM_ARITH_PASSES_TD_
33

4+
include "mlir/Dialect/PDL/IR/PDLDialect.td"
5+
include "mlir/Dialect/PDLInterp/IR/PDLInterpOps.td"
46
include "mlir/Pass/PassBase.td"
57

68
def MulToAdd : Pass<"mul-to-add"> {
@@ -10,4 +12,15 @@ def MulToAdd : Pass<"mul-to-add"> {
1012
}];
1113
}
1214

15+
def MulToAddPdll : Pass<"mul-to-add-pdll"> {
16+
let summary = "Convert multiplications to repeated additions using pdll";
17+
let description = [{
18+
Convert multiplications to repeated additions (using pdll).
19+
}];
20+
let dependentDialects = [
21+
"mlir::pdl::PDLDialect",
22+
"mlir::pdl_interp::PDLInterpDialect",
23+
];
24+
}
25+
1326
#endif // LIB_TRANSFORM_ARITH_PASSES_TD_

tests/mul_to_add_pdll.mlir

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: tutorial-opt %s --mul-to-add-pdll | FileCheck %s
2+
3+
func.func @just_power_of_two(%arg: i32) -> i32 {
4+
%0 = arith.constant 8 : i32
5+
%1 = arith.muli %arg, %0 : i32
6+
func.return %1 : i32
7+
}
8+
9+
// CHECK-LABEL: func.func @just_power_of_two(
10+
// CHECK-SAME: %[[ARG:.*]]: i32
11+
// CHECK-SAME: ) -> i32 {
12+
// CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]]
13+
// CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]]
14+
// CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]]
15+
// CHECK: return %[[SUM_2]] : i32
16+
// CHECK: }
17+
18+
19+
func.func @power_of_two_plus_one(%arg: i32) -> i32 {
20+
%0 = arith.constant 9 : i32
21+
%1 = arith.muli %arg, %0 : i32
22+
func.return %1 : i32
23+
}
24+
25+
// CHECK-LABEL: func.func @power_of_two_plus_one(
26+
// CHECK-SAME: %[[ARG:.*]]: i32
27+
// CHECK-SAME: ) -> i32 {
28+
// CHECK: %[[SUM_0:.*]] = arith.addi %[[ARG]], %[[ARG]]
29+
// CHECK: %[[SUM_1:.*]] = arith.addi %[[SUM_0]], %[[SUM_0]]
30+
// CHECK: %[[SUM_2:.*]] = arith.addi %[[SUM_1]], %[[SUM_1]]
31+
// CHECK: %[[SUM_3:.*]] = arith.addi %[[SUM_2]], %[[ARG]]
32+
// CHECK: return %[[SUM_3]] : i32
33+
// CHECK: }

0 commit comments

Comments
 (0)