Skip to content

Commit 677e0c6

Browse files
authored
Update TypeMerging for exact types (#7580)
TypeMerging was already careful not to merge a subtype and its supertype where there is a cast to the subtype that could differentiate it from its supertype. Now with exact types, it's possible to have an exact cast to the supertype that can differentiate it from the subtype as well. Update the pass to collect types used as the target of exact casts and ensure they are not merged with their subtypes. It would be natural to use a single map from HeapType to bool to track the cast target types and whether each one is used in an exact cast, but that would regress performance because this pass previously used a SmallSet to track cast types. I spent a few hours trying to create a SmallMap that shared most of its code with SmallSet, but it would require more time to make that work. For now, just use a second SmallSet to track exact exact casts.
1 parent 15df7fe commit 677e0c6

File tree

2 files changed

+159
-27
lines changed

2 files changed

+159
-27
lines changed

src/passes/TypeMerging.cpp

Lines changed: 44 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ using CastTypes = SmallUnorderedSet<HeapType, 5>;
7171
struct CastFinder : public PostWalker<CastFinder> {
7272
CastTypes castTypes;
7373

74+
// For each cast target, record whether there is an exact cast. Exact casts
75+
// will additionally prevent subtypes from being merged into the cast target.
76+
// TODO: Use a SmallMap to combine this with `castTypes`.
77+
CastTypes exactCastTypes;
78+
7479
// If traps never happen, then ref.cast and call_indirect can never
7580
// differentiate between types since they always succeed. Take advantage of
7681
// that by not having those instructions inhibit merges in TNH mode.
@@ -83,6 +88,9 @@ struct CastFinder : public PostWalker<CastFinder> {
8388
template<typename T> void visitCast(T* curr) {
8489
if (auto type = curr->getCastType(); type != Type::unreachable) {
8590
castTypes.insert(type.getHeapType());
91+
if (type.isExact()) {
92+
exactCastTypes.insert(type.getHeapType());
93+
}
8694
}
8795
}
8896

@@ -126,8 +134,9 @@ struct TypeMerging : public Pass {
126134
// All private original types.
127135
std::unordered_set<HeapType> privateTypes;
128136

129-
// Types that are distinguished by cast instructions.
137+
// Types that are distinguished by casts and exact casts.
130138
CastTypes castTypes;
139+
CastTypes exactCastTypes;
131140

132141
// The list of remaining types that have not been merged into other types.
133142
// Candidates for further merging.
@@ -169,7 +178,8 @@ struct TypeMerging : public Pass {
169178
std::vector<std::vector<HeapType>>
170179
splitSupertypePartition(const std::vector<HeapType>&);
171180

172-
CastTypes findCastTypes();
181+
// Return the cast types and the exact cast types.
182+
std::pair<CastTypes, CastTypes> findCastTypes();
173183
std::vector<HeapType> getPublicChildren(HeapType type);
174184
DFA::State<HeapType> makeDFAState(HeapType type);
175185
void applyMerges();
@@ -220,7 +230,9 @@ void TypeMerging::run(Module* module_) {
220230
mergeable = ModuleUtils::getPrivateHeapTypes(*module);
221231
privateTypes =
222232
std::unordered_set<HeapType>(mergeable.begin(), mergeable.end());
223-
castTypes = findCastTypes();
233+
auto casts = findCastTypes();
234+
castTypes = std::move(casts.first);
235+
exactCastTypes = std::move(casts.second);
224236

225237
// Merging supertypes or siblings can unlock more sibling merging
226238
// opportunities, but merging siblings can never unlock more supertype merging
@@ -329,17 +341,18 @@ bool TypeMerging::merge(MergeKind kind) {
329341
switch (kind) {
330342
case Supertypes: {
331343
auto super = type.getDeclaredSuperType();
332-
if (super && shapeEq(type, *super)) {
333-
// The current type and its supertype have the same top-level
334-
// structure and are not distinguished, so add the current type to its
335-
// supertype's partition.
336-
auto it = ensurePartition(*super);
337-
it->push_back(makeDFAState(type));
338-
typePartitions[type] = it;
339-
} else {
340-
// Otherwise, create a new partition for this type.
344+
bool superHasExactCast = super && exactCastTypes.count(*super);
345+
if (!super || !shapeEq(type, *super) || superHasExactCast) {
346+
// Create a new partition for this type and bail.
341347
ensurePartition(type);
348+
break;
342349
}
350+
// The current type and its supertype have the same top-level
351+
// structure and are not distinguished, so add the current type to its
352+
// supertype's partition.
353+
auto it = ensurePartition(*super);
354+
it->push_back(makeDFAState(type));
355+
typePartitions[type] = it;
343356
break;
344357
}
345358
case Siblings: {
@@ -476,17 +489,19 @@ TypeMerging::splitSupertypePartition(const std::vector<HeapType>& types) {
476489
return partitions;
477490
}
478491

479-
CastTypes TypeMerging::findCastTypes() {
480-
ModuleUtils::ParallelFunctionAnalysis<CastTypes> analysis(
481-
*module, [&](Function* func, CastTypes& castTypes) {
482-
if (func->imported()) {
483-
return;
484-
}
492+
std::pair<CastTypes, CastTypes> TypeMerging::findCastTypes() {
493+
ModuleUtils::ParallelFunctionAnalysis<std::pair<CastTypes, CastTypes>>
494+
analysis(*module,
495+
[&](Function* func, std::pair<CastTypes, CastTypes>& castTypes) {
496+
if (func->imported()) {
497+
return;
498+
}
485499

486-
CastFinder finder(getPassOptions());
487-
finder.walk(func->body);
488-
castTypes = std::move(finder.castTypes);
489-
});
500+
CastFinder finder(getPassOptions());
501+
finder.walk(func->body);
502+
castTypes = {std::move(finder.castTypes),
503+
std::move(finder.exactCastTypes)};
504+
});
490505

491506
// Also find cast types in the module scope (not possible in the current
492507
// spec, but do it to be future-proof).
@@ -495,12 +510,17 @@ CastTypes TypeMerging::findCastTypes() {
495510

496511
// Accumulate all the castTypes.
497512
auto& allCastTypes = moduleFinder.castTypes;
498-
for (auto& [k, castTypes] : analysis.map) {
513+
auto& allExactCastTypes = moduleFinder.exactCastTypes;
514+
for (auto& [k, types] : analysis.map) {
515+
auto& [castTypes, exactCastTypes] = types;
499516
for (auto type : castTypes) {
500517
allCastTypes.insert(type);
501518
}
519+
for (auto type : exactCastTypes) {
520+
allExactCastTypes.insert(type);
521+
}
502522
}
503-
return allCastTypes;
523+
return {std::move(allCastTypes), std::move(allExactCastTypes)};
504524
}
505525

506526
std::vector<HeapType> TypeMerging::getPublicChildren(HeapType type) {

test/lit/passes/type-merging-exact.wast

Lines changed: 115 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
;; NOTE: Assertions have been generated by update_lit_checks.py and should not be edited.
22

3-
;; Check that types that differ only in exactness are not merged.
4-
5-
;; RUN: wasm-opt %s -all --closed-world --preserve-type-order \
3+
;; RUN: foreach %s %t wasm-opt -all --closed-world --preserve-type-order \
64
;; RUN: --type-merging --remove-unused-types -S -o - | filecheck %s
75

6+
;; Check that types that differ only in exactness are not merged.
87
(module
98
;; CHECK: (rec
109
;; CHECK-NEXT: (type $foo (struct))
@@ -19,3 +18,116 @@
1918
;; CHECK: (global $b (ref null $B) (ref.null none))
2019
(global $b (ref null $B) (ref.null none))
2120
)
21+
22+
;; Check that exact casts to a supertype prevent subtypes from being merged into
23+
;; it.
24+
(module
25+
;; CHECK: (rec
26+
;; CHECK-NEXT: (type $super (sub (struct)))
27+
(type $super (sub (struct)))
28+
;; CHECK: (type $sub (sub $super (struct)))
29+
(type $sub (sub $super (struct)))
30+
31+
;; CHECK: (func $ref-cast (type $2) (param $any anyref)
32+
;; CHECK-NEXT: (local $sub (ref null $sub))
33+
;; CHECK-NEXT: (drop
34+
;; CHECK-NEXT: (ref.cast (ref (exact $super))
35+
;; CHECK-NEXT: (local.get $any)
36+
;; CHECK-NEXT: )
37+
;; CHECK-NEXT: )
38+
;; CHECK-NEXT: )
39+
(func $ref-cast (param $any anyref)
40+
(local $sub (ref null $sub))
41+
(drop
42+
(ref.cast (ref (exact $super))
43+
(local.get $any)
44+
)
45+
)
46+
)
47+
)
48+
49+
;; Same as above but with ref.test.
50+
(module
51+
;; CHECK: (rec
52+
;; CHECK-NEXT: (type $super (sub (struct)))
53+
(type $super (sub (struct)))
54+
;; CHECK: (type $sub (sub $super (struct)))
55+
(type $sub (sub $super (struct)))
56+
57+
;; CHECK: (func $ref-test (type $2) (param $any anyref)
58+
;; CHECK-NEXT: (local $sub (ref null $sub))
59+
;; CHECK-NEXT: (drop
60+
;; CHECK-NEXT: (ref.test (ref (exact $super))
61+
;; CHECK-NEXT: (local.get $any)
62+
;; CHECK-NEXT: )
63+
;; CHECK-NEXT: )
64+
;; CHECK-NEXT: )
65+
(func $ref-test (param $any anyref)
66+
(local $sub (ref null $sub))
67+
(drop
68+
(ref.test (ref (exact $super))
69+
(local.get $any)
70+
)
71+
)
72+
)
73+
)
74+
75+
;; Same as above but with br_on_cast.
76+
(module
77+
;; CHECK: (rec
78+
;; CHECK-NEXT: (type $super (sub (struct)))
79+
(type $super (sub (struct)))
80+
;; CHECK: (type $sub (sub $super (struct)))
81+
(type $sub (sub $super (struct)))
82+
83+
;; CHECK: (func $br-on-cast (type $2) (param $any anyref)
84+
;; CHECK-NEXT: (local $sub (ref null $sub))
85+
;; CHECK-NEXT: (drop
86+
;; CHECK-NEXT: (block $l (result anyref)
87+
;; CHECK-NEXT: (br_on_cast $l anyref (ref (exact $super))
88+
;; CHECK-NEXT: (local.get $any)
89+
;; CHECK-NEXT: )
90+
;; CHECK-NEXT: )
91+
;; CHECK-NEXT: )
92+
;; CHECK-NEXT: )
93+
(func $br-on-cast (param $any anyref)
94+
(local $sub (ref null $sub))
95+
(drop
96+
(block $l (result anyref)
97+
(br_on_cast $l anyref (ref (exact $super))
98+
(local.get $any)
99+
)
100+
)
101+
)
102+
)
103+
)
104+
105+
;; Same as above but with br_on_cast_fail
106+
(module
107+
;; CHECK: (rec
108+
;; CHECK-NEXT: (type $super (sub (struct)))
109+
(type $super (sub (struct)))
110+
;; CHECK: (type $sub (sub $super (struct)))
111+
(type $sub (sub $super (struct)))
112+
113+
;; CHECK: (func $br-on-cast-fail (type $2) (param $any anyref)
114+
;; CHECK-NEXT: (local $sub (ref null $sub))
115+
;; CHECK-NEXT: (drop
116+
;; CHECK-NEXT: (block $l (result anyref)
117+
;; CHECK-NEXT: (br_on_cast_fail $l anyref (ref (exact $super))
118+
;; CHECK-NEXT: (local.get $any)
119+
;; CHECK-NEXT: )
120+
;; CHECK-NEXT: )
121+
;; CHECK-NEXT: )
122+
;; CHECK-NEXT: )
123+
(func $br-on-cast-fail (param $any anyref)
124+
(local $sub (ref null $sub))
125+
(drop
126+
(block $l (result anyref)
127+
(br_on_cast_fail $l anyref (ref (exact $super))
128+
(local.get $any)
129+
)
130+
)
131+
)
132+
)
133+
)

0 commit comments

Comments
 (0)