Skip to content

Commit 92f1f5f

Browse files
committed
WIP: remove synchronization in the shader's for loop
It seems that all invocations of a shader that lie in the same wavefront on OpenCL advance at the same time, so we can multiply by a 0/1 flag instead of using if/else (and drop the barrier inside the loop). Not sure whether WebGPU upholds this property, though.
1 parent 9c22f11 commit 92f1f5f

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed
Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,35 @@
1-
struct PrimeIndices {
1+
struct Numbers {
22
data: [[stride(4)]] array<u32>;
3-
}; // this is used as both input and output for convenience
3+
};
44

55
[[group(0), binding(0)]]
6-
var<storage, read_write> v_indices: PrimeIndices;
6+
var<storage, read_write> numbers: Numbers;
77

8-
let blockSize = 16u;
9-
var<workgroup> sdata: array<u32, blockSize>;
8+
[[override]] let blockSize: u32;
9+
var<workgroup> workgroup_data: array<u32, blockSize>;
1010

1111
[[stage(compute), workgroup_size(16)]]
1212
fn main(
1313
[[builtin(global_invocation_id)]] global_id: vec3<u32>,
1414
[[builtin(local_invocation_id)]] local_id: vec3<u32>,
1515
[[builtin(workgroup_id)]] workgroup_id: vec3<u32>,
1616
) {
17-
var n: u32 = arrayLength(&v_indices.data);
17+
var n: u32 = arrayLength(&numbers.data);
1818

1919
if (global_id.x < n) {
20-
sdata[local_id.x] = v_indices.data[global_id.x];
20+
workgroup_data[local_id.x] = numbers.data[global_id.x];
2121
} else {
22-
sdata[local_id.x] = 0u;
22+
workgroup_data[local_id.x] = 0u;
2323
}
2424

2525
workgroupBarrier();
2626

2727
for (var stride: u32 = blockSize / 2u; stride > 0u; stride = stride >> 1u) {
28-
if (local_id.x < stride) {
29-
sdata[local_id.x] = sdata[local_id.x] + sdata[local_id.x + stride];
30-
}
31-
workgroupBarrier();
28+
var flag: u32 = u32(local_id.x < stride);
29+
workgroup_data[local_id.x] = workgroup_data[local_id.x] + flag * workgroup_data[local_id.x + flag * stride];
3230
}
3331

3432
if (local_id.x == 0u) {
35-
v_indices.data[workgroup_id.x] = sdata[0];
33+
numbers.data[workgroup_id.x] = workgroup_data[0];
3634
}
3735
}

0 commit comments

Comments
 (0)