
tensor.roll() #3281

Open · wants to merge 16 commits into main

Conversation

@crutcher (Contributor) commented on Jun 10, 2025

tensor.roll()

Checklist

  • Confirmed that the cargo run-checks command has been executed.
  • Made sure the book is up to date with changes in this PR.

Related Issues/PRs

See pytorch roll:
https://docs.pytorch.org/docs/stable/generated/torch.roll.html

Changes

  • Added burn::tensor::indexing
  • Added ReversableIndex (for usize / isize dims)
  • Added roll(), roll_dim()
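
For reference, roll() shifts elements along the given dimensions with wrap-around, following the torch.roll semantics linked above. A minimal stand-alone sketch of that behavior on a plain Vec (illustrative only; the actual Tensor::roll / roll_dim signatures are the ones added in this PR and discussed below):

```rust
/// Roll a 1-D sequence by `shift` positions with wrap-around.
/// Negative shifts roll toward the front, mirroring `torch.roll`.
/// Sketch only; not the tensor implementation from this PR.
fn roll_vec<T: Clone>(data: &[T], shift: isize) -> Vec<T> {
    let len = data.len();
    if len == 0 {
        return Vec::new(); // rolling an empty sequence is a no-op
    }
    // Wrap the (possibly negative) shift into [0, len).
    let shift = shift.rem_euclid(len as isize) as usize;
    let split = len - shift;
    // The tail wraps around to the front.
    [&data[split..], &data[..split]].concat()
}

fn main() {
    assert_eq!(roll_vec(&[1, 2, 3, 4, 5], 2), vec![4, 5, 1, 2, 3]);
    assert_eq!(roll_vec(&[1, 2, 3, 4, 5], -1), vec![2, 3, 4, 5, 1]);
}
```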

Testing

Generated tests.

@crutcher changed the title from [WIP] tensor.roll() to tensor.roll() on Jun 11, 2025
@crutcher (Contributor Author):

Probably needs some more tests, but I'd like a review of the ReflectableIndex code and whether changes need to be made there.


codecov bot commented Jun 11, 2025

Codecov Report

Attention: Patch coverage is 89.55696% with 33 lines in your changes missing coverage. Please review.

Project coverage is 82.73%. Comparing base (81985bd) to head (6c1583a).

Files with missing lines                        Patch %   Lines
crates/burn-tensor/src/tensor/indexing/mod.rs   83.21%    23 Missing ⚠️
crates/burn-tensor/src/tensor/api/base.rs       90.47%    10 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3281      +/-   ##
==========================================
+ Coverage   82.66%   82.73%   +0.06%     
==========================================
  Files         995      997       +2     
  Lines      127626   127933     +307     
==========================================
+ Hits       105498   105839     +341     
+ Misses      22128    22094      -34     


@laggui (Member) left a comment

Some thoughts about the indexing stuff:

ReflectableIndex seems very much in line with the IndexConversion trait (previously added to extend the slicing/indexing support). The two should probably be unified, and perhaps AsIndex would be a more idiomatic name.

Now, I am a bit conflicted about using the same trait for element indexing and dimension/axis selection. While technically we can consider them both forms of indexing into different structural levels of a tensor, it might be a little confusing. And for a public API, perhaps it is better to separate them. The same canonicalizing function could still be used under the hood to handle negative values based on the size of a dimension (for AsIndex) vs the rank of a tensor (for AsAxis/AsDim). Or we could have a parent trait common to both. But for the public API, it might make more sense to differentiate them. Let me know what you think.

Finally, the number of places by which elements are shifted in roll and roll_dim is semantically different from an index. shift is just a value that supports negatives to wrap around, so I'm not sure the interface should have shift: ReflectableIndex. It's only being used to convert to isize and then wrap negative values based on the size.
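
For concreteness, a minimal sketch of the two conversions being contrasted here: canonicalizing a possibly negative dim/index against a size (the dimension size for element indices, the rank for dims), versus wrapping a signed shift into the valid range. The names canonicalize_index and wrap_shift are illustrative, not the helpers from this PR:

```rust
/// Convert a possibly negative index into [0, size).
/// For dimension selection, `size` would be the tensor rank instead.
/// Illustrative only; the PR's helpers may differ.
fn canonicalize_index(idx: isize, size: usize) -> usize {
    let adjusted = if idx < 0 { idx + size as isize } else { idx };
    assert!(
        adjusted >= 0 && (adjusted as usize) < size,
        "index {idx} out of range for size {size}"
    );
    adjusted as usize
}

/// Wrap a signed shift into [0, size); a negative shift just rolls the other way.
fn wrap_shift(shift: isize, size: usize) -> usize {
    if size == 0 {
        return 0; // nothing to shift in an empty dimension
    }
    shift.rem_euclid(size as isize) as usize
}
```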

@crutcher
Copy link
Contributor Author

Copied from the Discord thread:

I would like to lift the trait and utility code into burn::indexing, as we'll probably have a lot of related support machinery, including ravel/unravel support with bounds checking, since that is always part of tensor binding; particularly for things like binding external datatypes (such as Images) into/out of TensorData.
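
For illustration, the kind of ravel/unravel machinery meant here is a row-major conversion between a multi-index and a linear offset, with bounds checking. A hypothetical sketch, not code from this PR:

```rust
/// Row-major "ravel": convert a multi-index into a linear offset,
/// with bounds checking against `shape`. Hypothetical helper, not from this PR.
fn ravel_index(indices: &[usize], shape: &[usize]) -> Option<usize> {
    if indices.len() != shape.len() {
        return None;
    }
    let mut offset = 0;
    for (&i, &dim) in indices.iter().zip(shape) {
        if i >= dim {
            return None; // out of bounds
        }
        offset = offset * dim + i;
    }
    Some(offset)
}

/// Inverse "unravel": recover the multi-index from a linear offset.
fn unravel_index(mut offset: usize, shape: &[usize]) -> Option<Vec<usize>> {
    let mut indices = vec![0; shape.len()];
    for (idx, &dim) in indices.iter_mut().zip(shape).rev() {
        if dim == 0 {
            return None; // empty dimension: no valid offsets
        }
        *idx = offset % dim;
        offset /= dim;
    }
    (offset == 0).then_some(indices)
}
```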

I do want to register a disagreement about having separate semantic types, and clarify a strategy here.

I am on board, in general, with the "newtype" pattern for param / unit separation, particularly in situations where we don't have keyword params; but I don't think that applies here. We're not discussing explicit casting into the appropriate semantic type (dimension, index, shape) at the call location; in fact, most of the work is to permit usize and isize to be used interchangeably without additional call-site syntax.

There is some protection if we make the utility machinery require the newtypes, so you can't call something like wrap_index(size, index) and swap the params; but there's also a higher expectation of "knowing what I'm doing" when calling the indexing utilities explicitly; they will generally only be used by lower-level implementors. And your comments aren't proposing that change anyway.

There's some additional (potential) utility in labeling the panic errors better, but that could be addressed by passing labels into the utility mechanisms.

My preference would be to lift IndexConversion into burn::indexing, and not have separate types for dim, index, size, and extent (what shift is here); just because, with the current auto-convert plans, the other approach puts a lot of weight on implementors without any resulting API change for callers.

@crutcher (Contributor Author):

Per discussion with @laggui:

  • Renamed IndexConversion to AsIndex (rough trait shape sketched below)
  • Added more trait impls for the basic integer types
  • Moved to burn::indexing
  • Reordered the roll fns to (dim, shift) param order
  • Added tests
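
The rough shape of the renamed trait, for readers following along (a hypothetical sketch; the actual AsIndex trait and impls live in burn::indexing in this PR and may differ in detail):

```rust
/// Sketch of an `AsIndex`-style conversion trait: any basic integer type can
/// act as a (possibly negative) index, dim, or shift at the call site.
/// Illustrative only; the trait in this PR may differ.
pub trait AsIndex: Copy {
    fn index(self) -> isize;
}

macro_rules! impl_as_index {
    ($($t:ty),*) => {
        $(impl AsIndex for $t {
            fn index(self) -> isize {
                self as isize
            }
        })*
    };
}

// Basic integer types usable interchangeably at call sites.
impl_as_index!(usize, isize, u8, u16, u32, u64, i8, i16, i32, i64);
```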

@laggui (Member) left a comment

Looks pretty good to me! I agree with the changes as discussed 👍

My only comments are mostly related to convention / practices.

Also, we should probably make the canonicalize_* functions return a Result instead of panicking internally (especially for external usage). And in roll we can unwrap.
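
Roughly what the fallible version could look like (hypothetical error type and names; the PR's actual helpers may differ):

```rust
/// Hypothetical error for an out-of-range dim or index.
#[derive(Debug)]
pub struct IndexOutOfBounds {
    pub index: isize,
    pub size: usize,
}

/// Fallible canonicalization: external callers can propagate the error,
/// while internal callers such as `roll` can unwrap/expect.
pub fn canonicalize_dim(dim: isize, rank: usize) -> Result<usize, IndexOutOfBounds> {
    let adjusted = if dim < 0 { dim + rank as isize } else { dim };
    if adjusted >= 0 && (adjusted as usize) < rank {
        Ok(adjusted as usize)
    } else {
        Err(IndexOutOfBounds { index: dim, size: rank })
    }
}

// Inside roll / roll_dim, an invalid dim is a programmer error, so:
// let dim = canonicalize_dim(dim, rank).expect("dim out of range for tensor rank");
```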

Comment on lines 856 to 859

pub fn roll_dim<I1, I2>(self, dim: I2, shift: I1) -> Self
where
    I1: AsIndex,
    I2: AsIndex,
@laggui (Member):
Even if both generics are bound to AsIndex, I would use better / more descriptive names than I1 and I2. Since D is already used for the const rank generic, maybe Dim and Shift.

Also, I thought you wanted to change the order to tensor.roll_dim(shift, dim) (per the Discord discussion)?

@crutcher (Contributor Author):

... my memory of the outcome of that discussion doesn't match the thread; I did extra work to make this change, because it used to be the other way ...

I'll change it back.

@laggui (Member):

Could also be my memory 😅 but I think dim typically comes after the input arguments to specify on which dim/axis the operation is applied. And that's how most of these ops are currently defined (though there might be some inconsistencies, as you pointed out last time).

Comment on lines 9 to 19

fn test_roll_empty() {
    let device = Default::default();
    let input = TestTensorInt::<2>::zeros([12, 0], &device);

    // Rolling an empty tensor should return the same empty tensor
    input
        .clone()
        .roll(&[0, 1], &[1, 2])
        .to_data()
        .assert_eq(&input.to_data(), false);
}
@laggui (Member):

I tested this use case locally for CUDA, which doesn't have the same restrictions on 0-size resources as wgpu. We have a separate issue when returning the data for a buffer of size 0.

Probably need to handle this special case in the cubecl into_data method when bytes.is_empty(). I can take care of this before merging.

@crutcher (Contributor Author):

Done

@crutcher requested a review from laggui on June 16, 2025 at 19:27
@laggui (Member) left a comment

Hmm, not sure why the macOS CI hangs... it might actually be related to the empty-tensor test.

Will have to investigate before we merge. Otherwise, just some minor comments.

/// ## Arguments
///
/// * `idx` - The dimension index to canonicalize.
/// * `rank` - The size of the dimension.
@laggui (Member):

rank != size of the dimension

@crutcher (Contributor Author):

fixed.

size,
);
assert!(
dim < size,
@laggui (Member):

Didn't catch this in the previous round: is this check correct? size here is the size of the dim... why are we checking that the specified dimension is smaller than its size?

@crutcher (Contributor Author):

Hmm; yeah.

@crutcher requested a review from laggui on June 17, 2025 at 19:30