Skip to content

Commit 53dba3f

Browse files
committed
wip sum reduction done!
1 parent a2da0e6 commit 53dba3f

File tree

1 file changed

+10
-16
lines changed

1 file changed

+10
-16
lines changed

src/main.rs

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ fn bytes_as_slice<T>(v: &[u8]) -> &[T] {
1111
unsafe { std::slice::from_raw_parts(v.as_ptr() as *const T, v.len() / mem::size_of::<T>()) }
1212
}
1313

14-
async fn steps_many(numbers: &[u32]) -> Vec<u32> {
14+
async fn steps_many(numbers: &[u32]) -> u32 {
1515
let instance = wgpu::Instance::new(wgpu::Backends::all());
1616
let adapter = instance
1717
.request_adapter(&wgpu::RequestAdapterOptions {
@@ -62,8 +62,8 @@ async fn steps_many(numbers: &[u32]) -> Vec<u32> {
6262
});
6363

6464
const BLOCK_SIZE: u32 = 16;
65-
let mut n = numbers.len() as u32 / BLOCK_SIZE;
66-
while n > 0 {
65+
let mut n = numbers.len() as u32;
66+
while n >= BLOCK_SIZE {
6767
println!("Iter: {n}");
6868

6969
let mut command_encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
@@ -88,10 +88,10 @@ async fn steps_many(numbers: &[u32]) -> Vec<u32> {
8888
let mut command_encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
8989
label: Some("my command encoder"),
9090
});
91-
command_encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, number_buf_size);
91+
command_encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffer, 0, n as u64 * 4);
9292
queue.submit(Some(command_encoder.finish()));
9393

94-
let buffer_slice = staging_buffer.slice(..);
94+
let buffer_slice = staging_buffer.slice(..n as u64 * 4);
9595
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
9696

9797
device.poll(wgpu::Maintain::Wait);
@@ -101,22 +101,16 @@ async fn steps_many(numbers: &[u32]) -> Vec<u32> {
101101
}
102102

103103
let data = buffer_slice.get_mapped_range();
104-
let result = bytes_as_slice(&data).to_vec();
104+
let result: u32 = bytes_as_slice::<u32>(&data).iter().sum();
105105
mem::drop(data);
106106
staging_buffer.unmap();
107107

108108
result
109109
}
110110

111111
fn main() {
112-
let numbers: Vec<u32> = (1..=256).collect();
113-
let steps = futures_lite::future::block_on(steps_many(&numbers));
114-
115-
for step in steps.into_iter().take(32) {
116-
if step == u32::MAX {
117-
println!("overflow");
118-
} else {
119-
println!("{step}");
120-
}
121-
}
112+
let numbers: Vec<u32> = (1..=128).collect();
113+
let sum = futures_lite::future::block_on(steps_many(&numbers));
114+
115+
println!("{sum}");
122116
}

0 commit comments

Comments
 (0)