Skip to content

Commit b9d7877

Browse files
committed
Track const sizes in {create,read_from}_const_alloc, instead of mutating offsets.
1 parent 3df836e commit b9d7877

File tree

2 files changed

+142
-156
lines changed

2 files changed

+142
-156
lines changed

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 141 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,10 @@ use super::CodegenCx;
55
use crate::abi::ConvSpirvType;
66
use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind};
77
use crate::spirv_type::SpirvType;
8+
use itertools::Itertools as _;
89
use rspirv::spirv::Word;
910
use rustc_abi::{self as abi, AddressSpace, Float, HasDataLayout, Integer, Primitive, Size};
1011
use rustc_codegen_ssa::traits::{ConstCodegenMethods, MiscCodegenMethods, StaticCodegenMethods};
11-
use rustc_middle::bug;
1212
use rustc_middle::mir::interpret::{ConstAllocation, GlobalAlloc, Scalar, alloc_range};
1313
use rustc_middle::ty::layout::LayoutOf;
1414
use rustc_span::{DUMMY_SP, Span};
@@ -255,7 +255,11 @@ impl ConstCodegenMethods for CodegenCx<'_> {
255255
other.debug(ty, self)
256256
)),
257257
};
258-
let init = self.create_const_alloc(alloc, pointee);
258+
// FIXME(eddyb) always use `const_data_from_alloc`, and
259+
// defer the actual `try_read_from_const_alloc` step.
260+
let init = self
261+
.try_read_from_const_alloc(alloc, pointee)
262+
.unwrap_or_else(|| self.const_data_from_alloc(alloc));
259263
let value = self.static_addr_of(init, alloc.inner().align, None);
260264
(value, AddressSpace::DATA)
261265
}
@@ -280,7 +284,11 @@ impl ConstCodegenMethods for CodegenCx<'_> {
280284
other.debug(ty, self)
281285
)),
282286
};
283-
let init = self.create_const_alloc(alloc, pointee);
287+
// FIXME(eddyb) always use `const_data_from_alloc`, and
288+
// defer the actual `try_read_from_const_alloc` step.
289+
let init = self
290+
.try_read_from_const_alloc(alloc, pointee)
291+
.unwrap_or_else(|| self.const_data_from_alloc(alloc));
284292
let value = self.static_addr_of(init, alloc.inner().align, None);
285293
(value, AddressSpace::DATA)
286294
}
@@ -348,9 +356,8 @@ impl<'tcx> CodegenCx<'tcx> {
348356
&& let Some(SpirvConst::ConstDataFromAlloc(alloc)) =
349357
self.builder.lookup_const_by_id(pointee)
350358
&& let SpirvType::Pointer { pointee } = self.lookup_type(ty)
359+
&& let Some(init) = self.try_read_from_const_alloc(alloc, pointee)
351360
{
352-
let mut offset = Size::ZERO;
353-
let init = self.read_from_const_alloc(alloc, &mut offset, pointee);
354361
return self.static_addr_of(init, alloc.inner().align, None);
355362
}
356363

@@ -379,44 +386,38 @@ impl<'tcx> CodegenCx<'tcx> {
379386
}
380387
}
381388

382-
pub fn create_const_alloc(&self, alloc: ConstAllocation<'tcx>, ty: Word) -> SpirvValue {
383-
tracing::trace!(
384-
"Creating const alloc of type {} with {} bytes",
385-
self.debug_type(ty),
386-
alloc.inner().len()
387-
);
388-
let mut offset = Size::ZERO;
389-
let result = self.read_from_const_alloc(alloc, &mut offset, ty);
390-
assert_eq!(
391-
offset.bytes_usize(),
392-
alloc.inner().len(),
393-
"create_const_alloc must consume all bytes of an Allocation"
394-
);
395-
tracing::trace!("Done creating alloc of type {}", self.debug_type(ty));
396-
result
397-
}
398-
399-
fn read_from_const_alloc(
389+
/// Attempt to read a whole constant of type `ty` from `alloc`, but only
390+
/// returning that constant if its size covers the entirety of `alloc`.
391+
//
392+
// FIXME(eddyb) should this use something like `Result<_, PartialRead>`?
393+
pub fn try_read_from_const_alloc(
400394
&self,
401395
alloc: ConstAllocation<'tcx>,
402-
offset: &mut Size,
403396
ty: Word,
404-
) -> SpirvValue {
405-
let ty_concrete = self.lookup_type(ty);
406-
*offset = offset.align_to(ty_concrete.alignof(self));
407-
// these print statements are really useful for debugging, so leave them easily available
408-
// println!("const at {}: {}", offset.bytes(), self.debug_type(ty));
409-
match ty_concrete {
410-
SpirvType::Void => self
411-
.tcx
412-
.dcx()
413-
.fatal("cannot create const alloc of type void"),
397+
) -> Option<SpirvValue> {
398+
let (result, read_size) = self.read_from_const_alloc_at(alloc, ty, Size::ZERO);
399+
(read_size == alloc.inner().size()).then_some(result)
400+
}
401+
402+
// HACK(eddyb) the `Size` returned is the equivalent of `size_of_val` on
403+
// the returned constant, i.e. `ty.sizeof()` can be either `Some(read_size)`,
404+
// or `None` - i.e. unsized, in which case only the returned `Size` records
405+
// how much was read from `alloc` to build the returned constant value.
406+
#[tracing::instrument(level = "trace", skip(self), fields(ty = ?self.debug_type(ty), offset))]
407+
fn read_from_const_alloc_at(
408+
&self,
409+
alloc: ConstAllocation<'tcx>,
410+
ty: Word,
411+
offset: Size,
412+
) -> (SpirvValue, Size) {
413+
let ty_def = self.lookup_type(ty);
414+
match ty_def {
414415
SpirvType::Bool
415416
| SpirvType::Integer(..)
416417
| SpirvType::Float(_)
417418
| SpirvType::Pointer { .. } => {
418-
let size = ty_concrete.sizeof(self).unwrap();
419-
let primitive = match ty_concrete {
419+
let size = ty_def.sizeof(self).unwrap();
420+
let primitive = match ty_def {
420421
SpirvType::Bool => Primitive::Int(Integer::fit_unsigned(0), false),
421422
SpirvType::Integer(int_size, int_signedness) => Primitive::Int(
422423
match int_size {
@@ -445,147 +446,132 @@ impl<'tcx> CodegenCx<'tcx> {
445446
}
446447
}),
447448
SpirvType::Pointer { .. } => Primitive::Pointer(AddressSpace::DATA),
448-
unsupported_spirv_type => bug!(
449-
"invalid spirv type internal to create_alloc_const2: {:?}",
450-
unsupported_spirv_type
451-
),
449+
_ => unreachable!(),
452450
};
453-
// alloc_id is not needed by read_scalar, so we just use 0. If the context
454-
// refers to a pointer, read_scalar will find the actual alloc_id. It
455-
// only uses the input alloc_id in the case that the scalar is uninitialized
456-
// as part of the error output
457-
// tldr, the pointer here is only needed for the offset
458451
let value = match alloc.inner().read_scalar(
459452
self,
460-
alloc_range(*offset, size),
453+
alloc_range(offset, size),
461454
matches!(primitive, Primitive::Pointer(_)),
462455
) {
463456
Ok(scalar) => {
464457
self.scalar_to_backend(scalar, self.primitive_to_scalar(primitive), ty)
465458
}
459+
// FIXME(eddyb) this is really unsound, could be an error!
466460
_ => self.undef(ty),
467461
};
468-
*offset += size;
469-
value
462+
(value, size)
470463
}
471464
SpirvType::Adt {
472-
size,
473465
field_types,
474466
field_offsets,
475467
..
476468
} => {
477-
let base = *offset;
478-
let mut values = Vec::with_capacity(field_types.len());
479-
let mut occupied_spaces = Vec::with_capacity(field_types.len());
480-
for (&ty, &field_offset) in field_types.iter().zip(field_offsets.iter()) {
481-
let total_offset_start = base + field_offset;
482-
let mut total_offset_end = total_offset_start;
483-
values.push(
484-
self.read_from_const_alloc(alloc, &mut total_offset_end, ty)
485-
.def_cx(self),
486-
);
487-
occupied_spaces.push(total_offset_start..total_offset_end);
488-
}
489-
if let Some(size) = size {
490-
*offset += size;
491-
} else {
492-
assert_eq!(
493-
offset.bytes_usize(),
494-
alloc.inner().len(),
495-
"create_const_alloc must consume all bytes of an Allocation after an unsized struct"
469+
// HACK(eddyb) this accounts for unsized `struct`s, and allows
470+
// detecting gaps *only* at the end of the type, but is cheap.
471+
let mut tail_read_range = ..Size::ZERO;
472+
let result = self.constant_composite(
473+
ty,
474+
field_types
475+
.iter()
476+
.zip_eq(field_offsets.iter())
477+
.map(|(&f_ty, &f_offset)| {
478+
let (f, f_size) =
479+
self.read_from_const_alloc_at(alloc, f_ty, offset + f_offset);
480+
tail_read_range.end =
481+
tail_read_range.end.max(offset + f_offset + f_size);
482+
f.def_cx(self)
483+
}),
484+
);
485+
486+
let ty_size = ty_def.sizeof(self);
487+
488+
// HACK(eddyb) catch non-padding holes in e.g. `enum` values.
489+
if let Some(ty_size) = ty_size
490+
&& let Some(tail_gap) = (ty_size.bytes())
491+
.checked_sub(tail_read_range.end.align_to(ty_def.alignof(self)).bytes())
492+
&& tail_gap > 0
493+
{
494+
self.zombie_no_span(
495+
result.def_cx(self),
496+
&format!(
497+
"undersized `{}` constant (at least {tail_gap} bytes may be missing)",
498+
self.debug_type(ty)
499+
),
496500
);
497501
}
498-
self.constant_composite(ty, values.into_iter())
499-
}
500-
SpirvType::Array { element, count } => {
501-
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
502-
let values = (0..count).map(|_| {
503-
self.read_from_const_alloc(alloc, offset, element)
504-
.def_cx(self)
505-
});
506-
self.constant_composite(ty, values)
507-
}
508-
SpirvType::Vector { element, count } => {
509-
let total_size = ty_concrete
510-
.sizeof(self)
511-
.expect("create_const_alloc: Vectors must be sized");
512-
let final_offset = *offset + total_size;
513-
let values = (0..count).map(|_| {
514-
self.read_from_const_alloc(alloc, offset, element)
515-
.def_cx(self)
516-
});
517-
let result = self.constant_composite(ty, values);
518-
assert!(*offset <= final_offset);
519-
// Vectors sometimes have padding at the end (e.g. vec3), skip over it.
520-
*offset = final_offset;
521-
result
522-
}
523-
SpirvType::Matrix { element, count } => {
524-
let total_size = ty_concrete
525-
.sizeof(self)
526-
.expect("create_const_alloc: Matrices must be sized");
527-
let final_offset = *offset + total_size;
528-
let values = (0..count).map(|_| {
529-
self.read_from_const_alloc(alloc, offset, element)
530-
.def_cx(self)
531-
});
532-
let result = self.constant_composite(ty, values);
533-
assert!(*offset <= final_offset);
534-
// Matrices sometimes have padding at the end (e.g. Mat4x3), skip over it.
535-
*offset = final_offset;
536-
result
502+
503+
(result, ty_size.unwrap_or(tail_read_range.end))
537504
}
538-
SpirvType::RuntimeArray { element } => {
539-
let mut values = Vec::new();
540-
while offset.bytes_usize() != alloc.inner().len() {
541-
values.push(
542-
self.read_from_const_alloc(alloc, offset, element)
543-
.def_cx(self),
544-
);
505+
SpirvType::Vector { element, .. }
506+
| SpirvType::Matrix { element, .. }
507+
| SpirvType::Array { element, .. }
508+
| SpirvType::RuntimeArray { element } => {
509+
let stride = self.lookup_type(element).sizeof(self).unwrap();
510+
511+
let count = match ty_def {
512+
SpirvType::Vector { count, .. } | SpirvType::Matrix { count, .. } => {
513+
u64::from(count)
514+
}
515+
SpirvType::Array { count, .. } => {
516+
u64::try_from(self.builder.lookup_const_scalar(count).unwrap()).unwrap()
517+
}
518+
SpirvType::RuntimeArray { .. } => {
519+
(alloc.inner().size() - offset).bytes() / stride.bytes()
520+
}
521+
_ => unreachable!(),
522+
};
523+
524+
let result = self.constant_composite(
525+
ty,
526+
(0..count).map(|i| {
527+
let (e, e_size) =
528+
self.read_from_const_alloc_at(alloc, element, offset + i * stride);
529+
assert_eq!(e_size, stride);
530+
e.def_cx(self)
531+
}),
532+
);
533+
534+
// HACK(eddyb) `align_to` can only cause an increase for `Vector`,
535+
// because its `size`/`align` are rounded up to a power of two
536+
// (for now, at least, even if eventually that should go away).
537+
let read_size = (count * stride).align_to(ty_def.alignof(self));
538+
539+
if let Some(ty_size) = ty_def.sizeof(self) {
540+
assert_eq!(read_size, ty_size);
545541
}
546-
let result = self.constant_composite(ty, values.into_iter());
547-
// TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv:
548-
/*
549-
__constant struct A {
550-
float x;
551-
int y[];
552-
} a = {1, {2, 3, 4}};
553-
554-
__kernel void foo(__global int* data, __constant int* c) {
555-
__constant struct A* asdf = &a;
556-
*data = *c + asdf->y[*c];
542+
543+
if let SpirvType::RuntimeArray { .. } = ty_def {
544+
// FIXME(eddyb) values of this type should never be created,
545+
// the only reasonable encoding of e.g. `&str` consts should
546+
// be `&[u8; N]` consts, with the `static_addr_of` pointer
547+
// (*not* the value it points to) cast to `&str`, afterwards.
548+
self.zombie_no_span(
549+
result.def_cx(self),
550+
&format!("unsupported unsized `{}` constant", self.debug_type(ty)),
551+
);
557552
}
558-
*/
559-
// NOTE(eddyb) the above description is a bit outdated, it's now
560-
// clear `OpTypeRuntimeArray` does not belong in user code, and
561-
// is only for dynamically-sized SSBOs and descriptor indexing,
562-
// and a general solution looks similar to `union` handling, but
563-
// for the length of a fixed-length array.
564-
self.zombie_no_span(result.def_cx(self), "constant `OpTypeRuntimeArray` value");
565-
result
553+
554+
(result, read_size)
555+
}
556+
557+
SpirvType::Void
558+
| SpirvType::Function { .. }
559+
| SpirvType::Image { .. }
560+
| SpirvType::Sampler
561+
| SpirvType::SampledImage { .. }
562+
| SpirvType::InterfaceBlock { .. }
563+
| SpirvType::AccelerationStructureKhr
564+
| SpirvType::RayQueryKhr => {
565+
let result = self.undef(ty);
566+
self.zombie_no_span(
567+
result.def_cx(self),
568+
&format!(
569+
"cannot reinterpret Rust constant data as a `{}` value",
570+
self.debug_type(ty)
571+
),
572+
);
573+
(result, ty_def.sizeof(self).unwrap_or(Size::ZERO))
566574
}
567-
SpirvType::Function { .. } => self
568-
.tcx
569-
.dcx()
570-
.fatal("TODO: SpirvType::Function not supported yet in create_const_alloc"),
571-
SpirvType::Image { .. } => self.tcx.dcx().fatal("cannot create a constant image value"),
572-
SpirvType::Sampler => self
573-
.tcx
574-
.dcx()
575-
.fatal("cannot create a constant sampler value"),
576-
SpirvType::SampledImage { .. } => self
577-
.tcx
578-
.dcx()
579-
.fatal("cannot create a constant sampled image value"),
580-
SpirvType::InterfaceBlock { .. } => self
581-
.tcx
582-
.dcx()
583-
.fatal("cannot create a constant interface block value"),
584-
SpirvType::AccelerationStructureKhr => self
585-
.tcx
586-
.dcx()
587-
.fatal("cannot create a constant acceleration structure"),
588-
SpirvType::RayQueryKhr => self.tcx.dcx().fatal("cannot create a constant ray query"),
589575
}
590576
}
591577
}

crates/rustc_codegen_spirv/src/codegen_cx/declare.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ impl<'tcx> StaticCodegenMethods for CodegenCx<'tcx> {
394394
other.debug(g.ty, self)
395395
)),
396396
};
397-
let v = self.create_const_alloc(alloc, value_ty);
397+
let v = self.try_read_from_const_alloc(alloc, value_ty).unwrap();
398398
assert_ty_eq!(self, value_ty, v.ty);
399399
self.builder
400400
.set_global_initializer(g.def_cx(self), v.def_cx(self));

0 commit comments

Comments
 (0)