// Layout of the storage buffer: a runtime-sized array of u32 values with a
// 4-byte element stride.
struct Numbers {
    data: [[stride(4)]] array<u32>;
};

// Bound read_write: `main` reads its input values from this buffer and also
// writes each workgroup's result back into it.
[[group(0), binding(0)]]
var<storage, read_write> numbers: Numbers;

// Element count of the workgroup scratch array. Declared as a
// pipeline-overridable constant with no default here, so the value comes from
// the pipeline that uses this module.
[[override]] let blockSize: u32;

// Scratch array in the workgroup address space; `main` uses it to hold the
// values being summed.
var<workgroup> workgroup_data: array<u32, blockSize>;

// Sums the values of `numbers.data` covered by one workgroup and stores the
// sum at `numbers.data[workgroup_id.x]`.
//
// Each invocation first copies one element into `workgroup_data`; the array is
// then folded in half repeatedly until slot 0 holds the total.
//
// NOTE(review): the workgroup size is fixed at 16 while the scratch array is
// sized by the overridable `blockSize`; the two only agree when the host
// supplies 16 -- confirm against the pipeline setup.
// NOTE(review): the result is written into the same buffer the input is read
// from. If more than one workgroup is dispatched, workgroup `w` overwrites
// element `w`, which other workgroups may still be reading as input --
// confirm the host accounts for this.
[[stage(compute), workgroup_size(16)]]
fn main(
    [[builtin(global_invocation_id)]] global_id: vec3<u32>,
    [[builtin(local_invocation_id)]] local_id: vec3<u32>,
    [[builtin(workgroup_id)]] workgroup_id: vec3<u32>,
) {
    let n: u32 = arrayLength(&numbers.data);

    // Invocations past the end of the buffer contribute 0 to the sum.
    if (global_id.x < n) {
        workgroup_data[local_id.x] = numbers.data[global_id.x];
    } else {
        workgroup_data[local_id.x] = 0u;
    }

    // Every slot must be filled before any invocation reads a neighbour's slot.
    workgroupBarrier();

    for (var stride: u32 = blockSize / 2u; stride > 0u; stride = stride >> 1u) {
        // Only the lower half writes, and it only reads slots in the upper
        // half, so no slot is read and written within the same step.
        if (local_id.x < stride) {
            workgroup_data[local_id.x] = workgroup_data[local_id.x] + workgroup_data[local_id.x + stride];
        }
        // Kept outside the `if` so all invocations reach it; the next step
        // reads the partial sums written in this one.
        workgroupBarrier();
    }

    // Slot 0 now holds the workgroup's total; a single invocation publishes it.
    if (local_id.x == 0u) {
        numbers.data[workgroup_id.x] = workgroup_data[0];
    }
}
// 0 commit comments