-
Notifications
You must be signed in to change notification settings - Fork 573
ONNX Import: switch to rank inferencing, rename shape to static_shape, decouple tensor shape info #3037
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <[email protected]>
- Update struct TensorData field access patterns - Add support for non-optional data fields - Fix issues with tensor data handling
The PR resolves several compilation errors caused by changes to the TensorType structure: 1. Modified code to adapt to the removal of the `shape` field from TensorType 2. Fixed pattern matching issues in rank_inference.rs to properly match against TensorData.data 3. Updated from_onnx.rs's remap_unsqueeze_to_reshape function to work with the new API 4. Fixed unused imports across multiple files 5. Fixed function calls that were using Option.len() incorrectly
Enhance tensor type system to support both static shapes and dynamic ranks across multiple ONNX operations including Expand, RandomNormal, Constant, and related nodes. Ensure proper shape validation and improve type safety throughout the conversion process.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR modernizes the ONNX IR processing by switching from shape inferencing to rank inferencing, renaming the "shape" field to "static_shape", and decoupling tensor shape information by encapsulating it in a dedicated TensorData structure.
- Replaces shape inferencing with rank inferencing
- Renames the field "shape" to "static_shape" to clarify its purpose
- Refactors tensor data conversion to include explicit shape details via TensorData
Reviewed Changes
Copilot reviewed 25 out of 25 changed files in this pull request and generated no comments.
Show a summary per file
File | Description |
---|---|
crates/onnx-ir/src/from_onnx.rs | Refactors unsqueeze node remapping to use TensorData and replaces direct shape extraction with a decoupled approach. |
crates/onnx-ir/src/coalesce.rs | Adjusts tensor weight conversion logic to extract shape information from TensorData rather than a raw option. |
crates/burn-import/* | Updates multiple modules to remove direct use of shape in TensorType constructors and to rely on TensorData, reflecting the new rank inferencing model. |
crates/burn-ndarray/* | Minor changes including added allow attributes for unused imports. |
crates/burn-import/onnx-tests/* | Temporarily disables unsqueeze tests due to dynamic rank support changes in Burn. |
Comments suppressed due to low confidence (3)
crates/onnx-ir/src/onnx/op_configuration.rs:1891
- Using 'static_shape' with expect() assumes it is always available; verify that all tensors processed in this context have their static_shape set to avoid potential runtime panics.
let dim_size = tensor.static_shape.as_ref().expect("Split: Static shape must be known to calculate split size")[axis as usize];
crates/onnx-ir/src/from_onnx.rs:404
- The new field 'static_shape' is assigned None here; please confirm this aligns with the intended decoupling of tensor shape information in favor of runtime rank inferencing.
static_shape: None,
crates/burn-import/src/onnx/to_burn.rs:438
- The TensorType constructor has been updated to omit the explicit shape parameter; ensure that downstream logic correctly derives shape information from the TensorData structure.
ConstantValue::Tensor(TensorType::new(name, rank, kind), tensor_data)
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #3037 +/- ##
==========================================
+ Coverage 81.16% 81.17% +0.01%
==========================================
Files 815 815
Lines 117201 117294 +93
==========================================
+ Hits 95121 95213 +92
- Misses 22080 22081 +1 ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for tackling the ONNX import issues 🙏
See my comments below.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A couple minor comments, otherwise should be good to go after! 🙂
Pull Request Template
Checklist
run-checks all
script has been executed.Related Issues/PRs
Fixes #2478
Changes
shape
field name tostatic_shape
to signal that the shape information is captured from ONNX file during build. Burn relies on dynamic shape information. This static_shape field was kept (instead of removing it completely) because there are some rare edge cases when we can use dimension information from static_shape.Testing