Skip to content

Commit 7c1fa59

Browse files
authored
Use TensorIndexer for the view tests (#4237)
Enabled TensorIndexer for the reshape tests. I temporarily added a codegen diff result to this PR. This one is more concise as I disabled index hoisting. As far as I can see, there's no concerning change. I haven't verified everything, but I believe most of them are because TensorIndexer can detect more divisible splits, which helps generate simplified indices through more aggressive contig indexing. Once approved, I'll remove the html file. ### Context Part of #4175. I'm planning to enable the new indexer globally by default once we are sufficiently confident with it. I'm going to enable it for some of the C++ tests for now. Just manually checking the diff results seems to be the only way to gain some confidence. All the tests are passing in my local branch, but just having green test results don't necessarily mean everything is properly ported to the new indexer. I'll also check perf changes with the benchmarks, but they may not give clear signals as indexing is just one piece of performance bottlenecks.
1 parent d0ecedd commit 7c1fa59

File tree

1 file changed

+15
-2
lines changed

1 file changed

+15
-2
lines changed

tests/cpp/test_gpu_view.cpp

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,13 @@ namespace nvfuser {
5151

5252
using testing::UnorderedElementsAre;
5353

54-
using GpuViewTest = NVFuserTest;
54+
class GpuViewTest : public NVFuserTest {
55+
protected:
56+
void SetUp() override {
57+
NVFuserTest::SetUp();
58+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
59+
}
60+
};
5561

5662
TEST_F(GpuViewTest, FusionViewDtypeSameSizeOutput) {
5763
Fusion fusion;
@@ -386,7 +392,14 @@ std::vector<ReshapeReductionParam> generateReshapeReductionParams() {
386392
return params;
387393
}
388394

389-
using ReshapeReduction = NVFuserFixtureParamTest<ReshapeReductionParam>;
395+
class ReshapeReduction : public NVFuserFixtureParamTest<ReshapeReductionParam> {
396+
protected:
397+
void SetUp() override {
398+
NVFuserFixtureParamTest<ReshapeReductionParam>::SetUp();
399+
EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"});
400+
}
401+
};
402+
390403
TEST_P(ReshapeReduction, FusionReshapeReduction) {
391404
const auto& param = GetParam();
392405
const auto& [input_shape, output_shape] = param.reshape_example;

0 commit comments

Comments
 (0)