
ONNX Import: switch to rank inferencing, rename shape to static_shape, decouple tensor shape info #3037

Merged 21 commits on Apr 25, 2025
Changes from 10 commits
Commits (21)
f56c4c9
Update ONNX-IR documentation with more comprehensive description
antimora Mar 4, 2025
47339fa
Fix build issues with data structure changes
antimora Mar 4, 2025
e1557e9
Fix build issues with TensorType structure changes
antimora Mar 7, 2025
1946252
Add static shape handling and rank inference for tensor operations
antimora Apr 16, 2025
02d7b29
Fix clippy warnings
antimora Apr 16, 2025
56de3f1
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 16, 2025
731c6d9
Fix merge issues
antimora Apr 17, 2025
b23f3c6
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 17, 2025
d554562
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 17, 2025
f4e815c
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 18, 2025
d5acc51
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 23, 2025
8728372
Enable unsqueeze with runtime axes values
antimora Apr 24, 2025
55a677a
Fix clippy error
antimora Apr 24, 2025
43af757
Remove default fallback
antimora Apr 24, 2025
03cdbe5
Removed dead code.
antimora Apr 24, 2025
c9b32f2
Removed rank from TensorData
antimora Apr 24, 2025
5ae4685
Removed elem_type from TensorData
antimora Apr 24, 2025
ee7f329
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 24, 2025
a29aba2
Merge remote-tracking branch 'upstream/main' into onnx-shape
antimora Apr 25, 2025
1aeb4a1
Simplify elem_type match expressions with pattern grouping
antimora Apr 25, 2025
285e361
Add static_shape back
antimora Apr 25, 2025
2 changes: 1 addition & 1 deletion crates/burn-import/onnx-tests/build.rs
@@ -129,7 +129,7 @@ fn main() {
.input("tests/trilu/trilu_upper.onnx")
.input("tests/trilu/trilu_lower.onnx")
.input("tests/transpose/transpose.onnx")
.input("tests/unsqueeze/unsqueeze.onnx")
// .input("tests/unsqueeze/unsqueeze.onnx") disabled for now because dynamic ranks are not supported in Burn
.input("tests/unsqueeze/unsqueeze_opset11.onnx")
.input("tests/unsqueeze/unsqueeze_opset16.onnx")
.input("tests/split/split.onnx")
23 changes: 12 additions & 11 deletions crates/burn-import/onnx-tests/tests/test_onnx.rs
@@ -138,7 +138,7 @@ include_models!(
trilu_upper,
trilu_lower,
transpose,
unsqueeze,
// unsqueeze, Disabled for now because dynamic ranks are not supported in Burn
unsqueeze_opset11,
unsqueeze_opset16,
split
@@ -2081,16 +2081,17 @@ mod tests {
output.assert_eq(&expected, true);
}

#[test]
fn unsqueeze() {
let device = Default::default();
let model: unsqueeze::Model<Backend> = unsqueeze::Model::new(&device);
let input_shape = Shape::from([3, 4, 5]);
let expected_shape = Shape::from([1, 1, 3, 4, 5, 1]);
let input = Tensor::ones(input_shape, &device);
let output = model.forward(input);
assert_eq!(output.shape(), expected_shape);
}
// NOTE: unsqueeze, Disabled for now because dynamic ranks are not supported in Burn
// #[test]
// fn unsqueeze() {
// let device = Default::default();
// let model: unsqueeze::Model<Backend> = unsqueeze::Model::new(&device);
// let input_shape = Shape::from([3, 4, 5]);
// let expected_shape = Shape::from([1, 1, 3, 4, 5, 1]);
// let input = Tensor::ones(input_shape, &device);
// let output = model.forward(input);
// assert_eq!(output.shape(), expected_shape);
// }

#[test]
fn unsqueeze_opset16() {
58 changes: 19 additions & 39 deletions crates/burn-import/src/burn/node/constant.rs
@@ -117,33 +117,40 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {

fn field_init(&self) -> Option<TokenStream> {
match &self.value {
ConstantValue::Tensor(tensor_type, _) => {
ConstantValue::Tensor(tensor_type, data) => {
let ty = tensor_type.ty();
let name = Ident::new(self.name.as_ref(), Span::call_site());
let shape = tensor_type.clone().shape.unwrap().to_tokens();
let dim = tensor_type.rank.to_tokens();

assert_eq!(
data.shape.len(),
tensor_type.rank,
"Tensor data shape does not match tensor type rank"
);

let shape = data.shape.to_tokens();
let rank = tensor_type.rank.to_tokens();

match tensor_type.kind {
crate::burn::TensorKind::Int => Some(quote! {
let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
burn::module::ParamId::new(),
move |device, _require_grad| Tensor::<B, #dim, Int>::zeros(#shape, &device),
move |device, _require_grad| Tensor::<B, #rank, Int>::zeros(#shape, &device),
device.clone(),
false
);
}),
crate::burn::TensorKind::Float => Some(quote! {
let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
burn::module::ParamId::new(),
move |device, _require_grad| Tensor::<B, #dim>::zeros(#shape, &device),
move |device, _require_grad| Tensor::<B, #rank>::zeros(#shape, &device),
device.clone(),
false,
);
}),
crate::burn::TensorKind::Bool => Some(quote! {
let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
burn::module::ParamId::new(),
move |device, _require_grad| Tensor::<B, #dim, Bool>::empty(#shape, &device),
move |device, _require_grad| Tensor::<B, #rank, Bool>::empty(#shape, &device),
device.clone(),
false,
);
@@ -288,23 +295,14 @@ mod tests {

let const_tensor = Ident::new("const_tensor", Span::call_site());
let dimensions = 1;
let shape = vec![4];
let data = TensorData::from([2f32, 2f32, 2f32, 2f32]);
let tensor_type = TensorType::new_float_with_shape(
const_tensor.to_string(),
dimensions,
Some(shape.clone()),
);
let tensor_type = TensorType::new_float(const_tensor.to_string(), dimensions);
let constant = ConstantValue::Tensor(tensor_type.clone(), data);

graph.register(ConstantNode::new(
const_tensor.to_string(),
constant.clone(),
Type::Tensor(TensorType::new_float_with_shape(
"output",
dimensions,
Some(shape.clone()),
)),
Type::Tensor(TensorType::new_float("output", dimensions)),
));

graph.register_input_output(vec![], vec!["output".to_string()]);
@@ -356,23 +354,14 @@ mod tests {

let const_tensor = Ident::new("const_tensor_int", Span::call_site());
let dimensions = 1;
let shape = vec![3];
let data = TensorData::from([1i32, 2i32, 3i32]);
let tensor_type = TensorType::new_int_with_shape(
const_tensor.to_string(),
dimensions,
Some(shape.clone()),
);
let tensor_type = TensorType::new_int(const_tensor.to_string(), dimensions);
let constant = ConstantValue::Tensor(tensor_type.clone(), data);

graph.register(ConstantNode::new(
const_tensor.to_string(),
constant.clone(),
Type::Tensor(TensorType::new_int_with_shape(
"output",
dimensions,
Some(shape.clone()),
)),
Type::Tensor(TensorType::new_int("output", dimensions)),
));

graph.register_input_output(vec![], vec!["output".to_string()]);
@@ -425,23 +414,14 @@ mod tests {

let const_tensor = Ident::new("const_tensor_3d", Span::call_site());
let dimensions = 3;
let shape = vec![1, 3, 2];
let data = TensorData::from([[[true, false], [true, false], [true, false]]]);
let tensor_type = TensorType::new_bool_with_shape(
const_tensor.to_string(),
dimensions,
Some(shape.clone()),
);
let tensor_type = TensorType::new_bool(const_tensor.to_string(), dimensions);
let constant = ConstantValue::Tensor(tensor_type.clone(), data);

graph.register(ConstantNode::new(
const_tensor.to_string(),
constant.clone(),
Type::Tensor(TensorType::new_bool_with_shape(
"output",
dimensions,
Some(shape.clone()),
)),
Type::Tensor(TensorType::new_bool("output", dimensions)),
));

graph.register_input_output(vec![], vec!["output".to_string()]);
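Note: with this change, a constant's initializer takes its shape from the embedded TensorData, while the tensor type only tracks rank. A hand-assembled sketch of roughly what `field_init` now emits for the rank-1 float constant in the test above (expanded from the `quote!` template; exact generated formatting may differ):

```rust
// Hypothetical expansion for a float constant named `const_tensor` whose
// TensorData has shape [4]: #shape comes from data.shape, #rank from
// tensor_type.rank.
let const_tensor: burn::module::Param<Tensor<B, 1>> = burn::module::Param::uninitialized(
    burn::module::ParamId::new(),
    move |device, _require_grad| Tensor::<B, 1>::zeros([4], &device),
    device.clone(),
    false,
);
```

The new `assert_eq!` catches, at codegen time, a TensorData whose shape length disagrees with the declared rank.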
17 changes: 8 additions & 9 deletions crates/burn-import/src/burn/node/expand.rs
@@ -35,22 +35,22 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ExpandNode {
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
let input = scope.tensor_use_owned(&self.input, node_position);
let output = &self.output.name;
let output_rank = &self.output.rank;

let shape = match &self.shape {
ExpandShape::Static(static_shape) => static_shape.to_tokens(),
ExpandShape::Runtime(Type::Tensor(shape_tensor)) => {
// since we don't take ownership of the shape_tensor, we don't need `tensor_use_owned` here:
// Since we don't take ownership of the shape_tensor, `tensor_use_owned` is not needed here.
let tensor_name = &shape_tensor.name;
let dim = shape_tensor.shape.as_ref().unwrap()[0];
// the shape of the tensor is already validated statically to be rank one when parsing the input
// we'll need to download the Tensor from device to cpu for expand operation.
// Also, we'll need to convert it to an array for conversion into BroadcastArgs
// The shape of the tensor is statically validated to be rank one during input parsing.
// The tensor must be downloaded from device to CPU for the expand operation.
// Additionally, it needs to be converted to an array for use in BroadcastArgs.
quote! {
TryInto::<[B::IntElem; #dim]>::try_into(#tensor_name.to_data().as_slice::<B::IntElem>().unwrap()).unwrap()
TryInto::<[B::IntElem; #output_rank]>::try_into(#tensor_name.to_data().as_slice::<B::IntElem>().unwrap()).unwrap()
}
}
ExpandShape::Runtime(Type::Shape(shape)) => {
// Shape implements BroadcastArgs, so it can be passed to expand directly
// Shape implements BroadcastArgs, allowing it to be passed directly to the expand method.
let shape_name = &shape.name;
quote! { #shape_name }
}
@@ -177,8 +177,7 @@ mod tests {
fn test_codegen_expand_tensor() {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

let mut shape_tensor_type = TensorType::new_int("tensor3", 4);
shape_tensor_type.shape = Some(vec![4]);
let shape_tensor_type = TensorType::new_int("tensor3", 4);

graph.register(ExpandNode::new(
TensorType::new_float("tensor1", 4),
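For the runtime-shape case, the fixed array length now comes from the output rank rather than from the shape tensor's removed static shape. A sketch of what the generated forward code might look like for a rank-4 output fed by an int shape tensor `tensor3` (tensor names follow the test above; the final `expand` call is assumed from the node's forward template, which is not part of this hunk):

```rust
// Move the rank-1 shape tensor to CPU and convert it into a fixed-size array
// so it can act as BroadcastArgs; the array length (4) is the output rank.
let shape = TryInto::<[B::IntElem; 4]>::try_into(
    tensor3.to_data().as_slice::<B::IntElem>().unwrap(),
)
.unwrap();
// Broadcast the input to the requested shape.
let tensor2 = tensor1.expand(shape);
```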
14 changes: 6 additions & 8 deletions crates/burn-import/src/burn/node/random_normal.rs
@@ -9,24 +9,21 @@ pub struct RandomNormalNode {
pub mean: f64,
pub scale: f64,
pub output_ty: TensorType,
pub shape: Vec<usize>,
}

impl RandomNormalNode {
pub fn new(output_ty: TensorType, mean: f64, scale: f64) -> Self {
pub fn new(output_ty: TensorType, mean: f64, scale: f64, shape: Vec<usize>) -> Self {
Self {
mean,
scale,
output_ty,
shape,
}
}

fn get_output_shape(&self) -> TokenStream {
let shape_it = self
.output_ty
.shape
.as_ref()
.expect("RandomNormal output has no shape!")
.iter();
let shape_it = self.shape.iter();
quote! { Shape::new([#(#shape_it),*]) }
}

@@ -81,9 +78,10 @@ mod tests {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(RandomNormalNode::new(
TensorType::new("tensor1", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("tensor1", 2, TensorKind::Float),
0.0f64,
1.0f64,
vec![2, 3],
));

graph.register_input_output(vec![], vec!["tensor1".to_string()]);
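The static output shape is now a separate constructor argument instead of being read from `output_ty.shape`; `RandomUniformNode` below follows the same pattern. A minimal usage sketch mirroring the updated test:

```rust
// Rank lives on the TensorType, the concrete output shape on the node itself.
let node = RandomNormalNode::new(
    TensorType::new("tensor1", 2, TensorKind::Float), // rank-2 float output
    0.0f64,                                           // mean
    1.0f64,                                           // scale
    vec![2, 3],                                       // static output shape
);
// get_output_shape() then renders `Shape::new([2, 3])` in the generated code.
```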
5 changes: 3 additions & 2 deletions crates/burn-import/src/burn/node/random_normal_like.rs
@@ -50,6 +50,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomNormalLikeNode {

#[cfg(test)]
mod tests {

use super::*;
use crate::burn::{TensorKind, TensorType, graph::BurnGraph, node::test::assert_tokens};
use burn::record::FullPrecisionSettings;
@@ -61,8 +62,8 @@ mod tests {
graph.register(RandomNormalLikeNode::new(
0.0f64,
1.0f64,
TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("input", 2, TensorKind::Float),
TensorType::new("output", 2, TensorKind::Float),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
14 changes: 6 additions & 8 deletions crates/burn-import/src/burn/node/random_uniform.rs
@@ -9,24 +9,21 @@ pub struct RandomUniformNode {
pub low: f64,
pub high: f64,
pub output_ty: TensorType,
pub shape: Vec<usize>,
}

impl RandomUniformNode {
pub fn new(output_ty: TensorType, low: f64, high: f64) -> Self {
pub fn new(output_ty: TensorType, low: f64, high: f64, shape: Vec<usize>) -> Self {
Self {
low,
high,
output_ty,
shape,
}
}

fn get_output_shape(&self) -> TokenStream {
let shape_it = self
.output_ty
.shape
.as_ref()
.expect("RandomUniform output has no shape!")
.iter();
let shape_it = self.shape.iter();
quote! { Shape::new([#(#shape_it),*]) }
}

@@ -81,9 +78,10 @@ mod tests {
let mut graph = BurnGraph::<FullPrecisionSettings>::default();

graph.register(RandomUniformNode::new(
TensorType::new("tensor1", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("tensor1", 2, TensorKind::Float),
0.0f64,
1.0f64,
vec![2, 3],
));

graph.register_input_output(vec![], vec!["tensor1".to_string()]);
5 changes: 3 additions & 2 deletions crates/burn-import/src/burn/node/random_uniform_like.rs
@@ -50,6 +50,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for RandomUniformLikeNode {

#[cfg(test)]
mod tests {

use super::*;
use crate::burn::{TensorKind, TensorType, graph::BurnGraph, node::test::assert_tokens};
use burn::record::FullPrecisionSettings;
@@ -61,8 +62,8 @@ mod tests {
graph.register(RandomUniformLikeNode::new(
0.0f64,
1.0f64,
TensorType::new("input", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("output", 2, TensorKind::Float, Some(vec![2, 3])),
TensorType::new("input", 2, TensorKind::Float),
TensorType::new("output", 2, TensorKind::Float),
));

graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
6 changes: 3 additions & 3 deletions crates/burn-import/src/burn/node/split.rs
@@ -48,14 +48,14 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for SplitNode {
if let Some(split_sizes) = &self.config.split_sizes {
let split_sizes_tokens = split_sizes.to_tokens();
quote! {
let mut split_tensors = #input.split_with_sizes(#split_sizes_tokens, #axis);
let split_tensors = #input.split_with_sizes(#split_sizes_tokens.to_vec(), #axis);
#unpack_outputs
}
} else {
let split_size = &self.config.split_size.unwrap();
let split_size_tokens = split_size.to_tokens();
quote! {
let mut split_tensors = #input.split(#split_size_tokens, #axis);
let split_tensors = #input.split(#split_size_tokens, #axis);
#unpack_outputs
}
}
@@ -125,7 +125,7 @@ mod tests {
&self,
tensor1: Tensor<B, 2>,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
let mut split_tensors = tensor1.split(2, 0);
let split_tensors = tensor1.split(2, 0);

let [tensor2, tensor3] = split_tensors.try_into().unwrap();
(tensor2, tensor3)
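The generated bindings are no longer `mut`, and explicit split sizes are converted to a `Vec` before calling `split_with_sizes`. A sketch of the explicit-sizes variant (hypothetical config with `split_sizes = [2, 1]` on axis 0; the even-split case appears verbatim in the expected test output above):

```rust
// split_with_sizes takes a Vec of chunk sizes, so the static token list is
// converted with .to_vec(); the result is destructured immediately, so no
// `mut` binding is needed.
let split_tensors = tensor1.split_with_sizes([2, 1].to_vec(), 0);
let [tensor2, tensor3] = split_tensors.try_into().unwrap();
```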