tensor.roll() #3281
base: main
Conversation
Probably needs some more tests; but I'd like a review of the ...
Codecov Report
Attention: Patch coverage is ...

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3281      +/-   ##
==========================================
+ Coverage   82.66%   82.73%   +0.06%
==========================================
  Files         995      997       +2
  Lines      127626   127933     +307
==========================================
+ Hits       105498   105839     +341
+ Misses      22128    22094      -34
Some thoughts about the indexing stuff:

ReflectableIndex seems very much in line with the IndexConversion trait (previously added to extend the slicing/indexing support). It should probably be unified, and perhaps AsIndex would be more idiomatic for a name.

Now, I am a bit conflicted about using the same trait for element indexing and dimension/axis selection. While technically we can consider them both forms of indexing into different structural levels of a tensor, it might be a little confusing. And for a public API, perhaps it is better to separate them. The same canonicalizing function could still be used under the hood to handle negative values based on the size of a dimension (for AsIndex) vs the rank of a tensor (for AsAxis/AsDim). Or we could have a parent trait common to both. But for the public API, it might make more sense to differentiate them. Let me know what you think.
Finally, the number of places by which elements are shifted in roll and roll_dim is semantically different. shift is just a type that supports negative values to wrap around, so I'm not sure the interface should have shift: ReflectableIndex. It's only being used to convert to isize and then wrap the negative values based on the size.
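For illustration, a minimal sketch of how one shared canonicalizing helper could sit beneath the conversion trait discussed above; every name here (AsIndex, canonicalize, the label parameter) is a placeholder from this discussion, not the merged API:

```rust
/// Hypothetical conversion trait shared by element indexing and
/// dimension/axis selection: everything funnels into a signed index
/// so negative values can wrap.
pub trait AsIndex {
    fn index(self) -> isize;
}

impl AsIndex for usize {
    fn index(self) -> isize {
        self as isize
    }
}

impl AsIndex for isize {
    fn index(self) -> isize {
        self
    }
}

/// Wrap a possibly-negative value into `0..bound`, panicking if out of range.
/// Used with `bound = dimension size` for element indices and
/// `bound = tensor rank` for dimension/axis arguments.
fn canonicalize(value: isize, bound: usize, what: &str) -> usize {
    let bound = bound as isize;
    let wrapped = if value < 0 { value + bound } else { value };
    assert!(
        (0..bound).contains(&wrapped),
        "{what} {value} out of range for bound {bound}"
    );
    wrapped as usize
}

fn main() {
    // -1 wraps to the last valid position for a bound of 4.
    assert_eq!(canonicalize((-1isize).index(), 4, "dim"), 3);
    assert_eq!(canonicalize(2usize.index(), 4, "index"), 2);
}
```

Whether the public bound ends up as one trait or two, the wrapping logic only needs the relevant bound (dimension size vs tensor rank).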
Copied from Discord thread:

I would like to lift the trait and utility code into burn::indexing, as we'll probably have a lot of related support machinery, including ravel/unravel support with bounds checking, as that is always part of tensor binding, particularly for things like binding external datatypes (such as Images) into/out-of TensorData.

I do want to open a disagreement about having separate semantic types, and clarify a strategy here. I am on board, in general, with the "new type" pattern to do param / unit separation, particularly in situations where we don't have keyword params; but I don't think that applies here. We're not discussing explicit casting into the appropriate semantic type (dimension, index, shape) at the call location; in fact most of the work is to permit usize and isize to be used interchangeably without additional call-site syntax.

There is some protection if we make the utility machinery require the newtypes, so you can't call something like wrap_index(size, index) and swap the params; but there's also a higher expectation of "know what I'm doing" when calling the indexing utilities explicitly; they will generally only be used by lower-level implementors. And your comments aren't proposing that change anyway. There's some additional (potential) utility in labeling the panic errors better, but those could be addressed by passing in labels to the utility mechanisms.

My preference would be to lift IndexConversion into burn::indexing, and not have separate types for dim, index, size, and extent (what shift is here); just because, with the current auto-convert plans, the other approach puts a lot of weight on implementors without having any resulting API change for callers.
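To make the trade-off concrete, a sketch of the newtype alternative being argued against (Size, Index, and wrap_index are illustrative names from this discussion only, not proposed API):

```rust
/// Illustrative newtypes: callers of the low-level utilities would have
/// to wrap values explicitly, which prevents swapping parameters but
/// adds call-site ceremony.
struct Size(usize);
struct Index(isize);

fn wrap_index(size: Size, index: Index) -> usize {
    let size = size.0 as isize;
    let wrapped = if index.0 < 0 { index.0 + size } else { index.0 };
    assert!((0..size).contains(&wrapped), "index out of range");
    wrapped as usize
}

fn main() {
    // With newtypes the arguments cannot be swapped silently...
    let i = wrap_index(Size(4), Index(-1));
    assert_eq!(i, 3);
    // ...but every explicit call pays the wrapping cost, and the
    // auto-converting public API (usize/isize via the conversion trait)
    // would look the same to callers either way.
}
```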
Per discussion with @laggui:
Looks pretty good to me! I agree with the changes as discussed 👍
My only comments are mostly related to convention / practices.
Also, we should probably make the canonicalize_* functions return a Result instead of panicking internally (esp. for external usage). And in roll we can unwrap.
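A sketch of the suggested change; the error type and exact signature here are hypothetical:

```rust
/// Hypothetical error type for out-of-range dimension arguments.
#[derive(Debug)]
pub struct IndexError(pub String);

/// Fallible variant of the canonicalization: external callers can handle
/// the error, while operations like `roll` may simply unwrap/expect.
pub fn canonicalize_dim(dim: isize, rank: usize) -> Result<usize, IndexError> {
    let rank_i = rank as isize;
    let wrapped = if dim < 0 { dim + rank_i } else { dim };
    if (0..rank_i).contains(&wrapped) {
        Ok(wrapped as usize)
    } else {
        Err(IndexError(format!(
            "dimension {dim} out of range for tensor of rank {rank}"
        )))
    }
}

fn main() {
    assert_eq!(canonicalize_dim(-1, 3).unwrap(), 2);
    assert!(canonicalize_dim(5, 3).is_err());
}
```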
pub fn roll_dim<I1, I2>(self, dim: I2, shift: I1) -> Self
where
    I1: AsIndex,
    I2: AsIndex,
Even if both generics are bound to AsIndex, I would use better / more descriptive names than I1 and I2. Since D is already used for the const rank generic, maybe Dim and Shift.

Also, I thought you wanted to change the order to tensor.roll_dim(shift, dim) (discord).
... my memory of the outcome of that discussion doesn't match the thread; I did extra work to make this change, because it used to be the other way ...
I'll change it back.
Could also be my memory 😅 but I think dim is typically after the input arguments to specify on which dim/axis the operation is applied. And that's how most of these ops are currently defined (though there might be some inconsistencies as you pointed out last time).
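Putting the naming suggestion in code form (a sketch on a stand-in type, not the actual Tensor; the argument order is left as in the diff above, since that is the separate open question):

```rust
// Stand-in trait and type only; this just shows the renamed generics
// (Dim / Shift instead of I1 / I2) bound to the conversion trait
// discussed earlier in the thread.
trait AsIndex {
    fn index(self) -> isize;
}

struct Example;

impl Example {
    fn roll_dim<Dim: AsIndex, Shift: AsIndex>(self, _dim: Dim, _shift: Shift) -> Self {
        self
    }
}
```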
fn test_roll_empty() {
    let device = Default::default();
    let input = TestTensorInt::<2>::zeros([12, 0], &device);

    // Rolling an empty tensor should return the same empty tensor
    input
        .clone()
        .roll(&[0, 1], &[1, 2])
        .to_data()
        .assert_eq(&input.to_data(), false);
}
I tested this use case locally for CUDA, which doesn't have the same restrictions on 0-size resources as wgpu. We have a separate issue when returning the data for a buffer of size 0.

Probably need to handle this special case for the cubecl into_data method when bytes.is_empty(). Can take care of this before merging.
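A rough sketch of the kind of guard being suggested; the real cubecl into_data path is async and backend specific, so the function below is purely illustrative:

```rust
// Illustration only: when the backing buffer is empty, skip the device
// read entirely instead of requesting a 0-byte resource (which wgpu
// rejects) and return an empty result directly.
fn read_bytes_sketch(bytes_len: usize, read_from_device: impl Fn() -> Vec<u8>) -> Vec<u8> {
    if bytes_len == 0 {
        return Vec::new(); // empty tensor: nothing to read back
    }
    read_from_device()
}

fn main() {
    let data = read_bytes_sketch(0, || panic!("no device read should happen"));
    assert!(data.is_empty());
}
```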
Done
Hmm, not sure why the macOS CI hangs... might actually be related to the empty test.
Will have to investigate before we merge. Otherwise, just some minor comments.
/// ## Arguments
///
/// * `idx` - The dimension index to canonicalize.
/// * `rank` - The size of the dimension.
rank != size of the dimension
fixed.
    size,
);
assert!(
    dim < size,
Didn't catch this in the previous round: is this check correct? size here is the size of the dim... why are we checking that the dimension specified is smaller than its size?
Hmm; yeah.
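To make the problem concrete (illustrative shape and numbers only):

```rust
// Small demonstration of why `assert!(dim < size)` is the wrong check:
// dimension 2 of a tensor with shape [4, 5, 1] is perfectly valid
// (the rank is 3), yet that dimension's size is 1, so a size-based
// bound would reject it.
fn main() {
    let shape = [4usize, 5, 1];
    let rank = shape.len();
    let dim = 2usize;

    assert!(dim < rank); // correct: compare against the rank
    assert!(dim >= shape[dim]); // the size-based check fails for this valid dim
}
```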
- Move IndexConversion to burn::indexing
- Reorder to (dim, shift)
- Improve error messages
tensor.roll()

Checklist
- cargo run-checks command has been executed.

Related Issues/PRs
See pytorch roll: https://docs.pytorch.org/docs/stable/generated/torch.roll.html

Changes
- burn::tensor::indexing
- ReversableIndex (for usize / isize dims)
- roll(), roll_dim()

Testing
Gen tests.
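For reference, the wraparound semantics being implemented, per the linked torch.roll docs, shown on a plain Vec; this is only an illustration of the behaviour, not the burn implementation:

```rust
// Roll a 1-D sequence by `shift` positions with wraparound, matching the
// semantics of torch.roll along a single dimension. A negative shift
// rolls toward the front. Illustration only.
fn roll_1d<T: Clone>(data: &[T], shift: isize) -> Vec<T> {
    let len = data.len();
    if len == 0 {
        return Vec::new(); // rolling an empty sequence is a no-op
    }
    // Wrap the shift into 0..len (handles negative and oversized shifts).
    let shift = shift.rem_euclid(len as isize) as usize;
    let split = len - shift;
    let mut out = Vec::with_capacity(len);
    out.extend_from_slice(&data[split..]);
    out.extend_from_slice(&data[..split]);
    out
}

fn main() {
    assert_eq!(roll_1d(&[1, 2, 3, 4, 5], 2), vec![4, 5, 1, 2, 3]);
    assert_eq!(roll_1d(&[1, 2, 3, 4, 5], -1), vec![2, 3, 4, 5, 1]);
    assert_eq!(roll_1d::<i32>(&[], 3), Vec::<i32>::new());
}
```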