Skip to content

Add support for PhysicalStorageBuffer #237

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
4 changes: 3 additions & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use itertools::Itertools;
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
use rustc_data_structures::fx::FxHashMap;
Expand Down Expand Up @@ -339,6 +339,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
PointeeDefState::Defining => {
let id = SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def(span, cx);
entry.insert(PointeeDefState::Defined(id));
Expand All @@ -350,6 +351,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
entry.insert(PointeeDefState::Defined(id));
SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def_with_id(cx, span, id)
}
Expand Down
5 changes: 3 additions & 2 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ impl CheckSpirvAttrVisitor<'_> {
"attribute is only valid on a parameter of an entry-point function",
);
} else {
// FIXME(eddyb) should we just remove all 5 of these storage class
// FIXME(eddyb) should we just remove all 6 of these storage class
// attributes, instead of disallowing them here?
if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
let valid = match storage_class {
Expand All @@ -347,7 +347,8 @@ impl CheckSpirvAttrVisitor<'_> {

StorageClass::Private
| StorageClass::Function
| StorageClass::Generic => {
| StorageClass::Generic
| StorageClass::PhysicalStorageBuffer => {
Err("can not be used as part of an entry's interface")
}

Expand Down
119 changes: 90 additions & 29 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

fn zombie_convert_ptr_to_u(&self, def: Word) {
self.zombie(def, "cannot convert pointers to integers");
if !self
.builder
.has_capability(Capability::PhysicalStorageBufferAddresses)
{
self.zombie(def, "cannot convert pointers to integers without OpCapability PhysicalStorageBufferAddresses");
}
}

fn zombie_convert_u_to_ptr(&self, def: Word) {
self.zombie(def, "cannot convert integers to pointers");
if !self
.builder
.has_capability(Capability::PhysicalStorageBufferAddresses)
{
self.zombie(def, "cannot convert integers to pointers without OpCapability PhysicalStorageBufferAddresses");
}
}

fn zombie_ptr_equal(&self, def: Word, inst: &str) {
Expand Down Expand Up @@ -407,14 +417,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
size: Size,
) -> Option<(SpirvValue, <Self as BackendTypes>::Type)> {
let ptr = ptr.strip_ptrcasts();
let mut leaf_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
let pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("non-pointer type: {other:?}")),
};

// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
// could instead be doing all the extra digging itself.
let mut indices = SmallVec::<[_; 8]>::new();
let mut leaf_ty = pointee_ty;
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
Expand All @@ -429,7 +440,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.then(|| self.type_ptr_to(leaf_ty))?;

let leaf_ptr = if indices.is_empty() {
assert_ty_eq!(self, ptr.ty, leaf_ptr_ty);
// Compare pointee types instead of pointer types as storage class might be different.
assert_ty_eq!(self, pointee_ty, leaf_ty);
ptr
} else {
let indices = indices
Expand Down Expand Up @@ -586,7 +598,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let ptr = ptr.strip_ptrcasts();
let ptr_id = ptr.def(self);
let original_pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
};

Expand Down Expand Up @@ -1461,11 +1473,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
self.fatal("dynamic alloca not supported yet")
}

fn load(&mut self, ty: Self::Type, ptr: Self::Value, _align: Align) -> Self::Value {
fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value {
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty);
let loaded_val = ptr.const_fold_load(self).unwrap_or_else(|| {
self.emit()
.load(access_ty, None, ptr.def(self), None, empty())
.load(
access_ty,
None,
ptr.def(self),
Some(MemoryAccess::ALIGNED),
std::iter::once(Operand::LiteralBit32(align.bytes() as _)),
)
.unwrap()
.with_type(access_ty)
});
Expand Down Expand Up @@ -1587,12 +1605,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// ignore
}

fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value {
fn store(&mut self, val: Self::Value, ptr: Self::Value, align: Align) -> Self::Value {
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty);
let val = self.bitcast(val, access_ty);

self.emit()
.store(ptr.def(self), val.def(self), None, empty())
.store(
ptr.def(self),
val.def(self),
Some(MemoryAccess::ALIGNED),
std::iter::once(Operand::LiteralBit32(align.bytes() as _)),
)
.unwrap();
// FIXME(eddyb) this is meant to be a handle the store instruction itself.
val
Expand Down Expand Up @@ -1750,20 +1773,23 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn inttoptr(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
match self.lookup_type(dest_ty) {
SpirvType::Pointer { .. } => (),
let result_ty = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => self.type_ptr_to_with_storage_class(
pointee,
StorageClassKind::Explicit(StorageClass::PhysicalStorageBuffer),
),
other => self.fatal(format!(
"inttoptr called on non-pointer dest type: {other:?}"
)),
}
if val.ty == dest_ty {
};
if val.ty == result_ty {
val
} else {
let result = self
.emit()
.convert_u_to_ptr(dest_ty, None, val.def(self))
.convert_u_to_ptr(result_ty, None, val.def(self))
.unwrap()
.with_type(dest_ty);
.with_type(result_ty);
self.zombie_convert_u_to_ptr(result.def(self));
result
}
Expand Down Expand Up @@ -1926,6 +1952,25 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return ptr;
}

// No cast is needed if only the storage class mismatches.
let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

// FIXME(jwollen) Do we need to choose `dest_ty` if it has a fixed storage class and `ptr` has none?
if ptr_pointee == dest_pointee {
return ptr;
}

// Strip a previous `pointercast`, to reveal the original pointer type.
let ptr = ptr.strip_ptrcasts();

Expand All @@ -1934,17 +1979,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

if ptr_pointee == dest_pointee {
return ptr;
}

let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);

if let Some((indices, _)) = self.recover_access_chain_from_offset(
Expand Down Expand Up @@ -2229,9 +2273,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn memcpy(
&mut self,
dst: Self::Value,
_dst_align: Align,
dst_align: Align,
src: Self::Value,
_src_align: Align,
src_align: Align,
size: Self::Value,
flags: MemFlags,
) {
Expand Down Expand Up @@ -2269,12 +2313,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
});

// Pass all operands as `additional_params` since rspirv doesn't allow specifying
// extra operands ofter the first `MemoryAccess`
let mut ops: SmallVec<[_; 4]> = Default::default();
ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED));
if src_align != dst_align {
if self.emit().version().unwrap() > (1, 3) {
ops.push(Operand::LiteralBit32(dst_align.bytes() as _));
ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED));
ops.push(Operand::LiteralBit32(src_align.bytes() as _));
} else {
let align = dst_align.min(src_align);
ops.push(Operand::LiteralBit32(align.bytes() as _));
}
} else {
ops.push(Operand::LiteralBit32(dst_align.bytes() as _));
}

if let Some((dst, src)) = typed_copy_dst_src {
if let Some(const_value) = src.const_fold_load(self) {
self.store(const_value, dst, Align::from_bytes(0).unwrap());
} else {
self.emit()
.copy_memory(dst.def(self), src.def(self), None, None, empty())
.copy_memory(dst.def(self), src.def(self), None, None, ops)
.unwrap();
}
} else {
Expand All @@ -2285,7 +2346,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
size.def(self),
None,
None,
empty(),
ops,
)
.unwrap();
self.zombie(dst.def(self), "cannot memcpy dynamically sized data");
Expand Down Expand Up @@ -2324,7 +2385,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));

let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
_ => self.fatal(format!(
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
Expand Down Expand Up @@ -2696,7 +2757,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
(callee.def(self), return_type, arguments)
}

SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Pointer { pointee, .. } => match self.lookup_type(pointee) {
SpirvType::Function {
return_type,
arguments,
Expand Down
22 changes: 19 additions & 3 deletions crates/rustc_codegen_spirv/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::spirv::{StorageClass, Word};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
Expand Down Expand Up @@ -104,7 +104,23 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {

// HACK(eddyb) like the `CodegenCx` method but with `self.span()` awareness.
pub fn type_ptr_to(&self, ty: Word) -> Word {
SpirvType::Pointer { pointee: ty }.def(self.span(), self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(self.span(), self)
}

pub fn type_ptr_to_with_storage_class(
&self,
ty: Word,
storage_class: StorageClassKind,
) -> Word {
SpirvType::Pointer {
pointee: ty,
storage_class,
}
.def(self.span(), self)
}

// TODO: Definitely add tests to make sure this impl is right.
Expand Down
24 changes: 10 additions & 14 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::builder_spirv::{BuilderCursor, SpirvValue};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::dr;
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier, reflect};
use rspirv::spirv::{
Expand Down Expand Up @@ -307,19 +307,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
}
.def(self.span(), self),
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
if storage_class != StorageClass::Generic {
self.struct_err("TypePointer in asm! requires `Generic` storage class")
.with_note(format!(
"`{storage_class:?}` storage class was specified"
))
.with_help(format!(
"the storage class will be inferred automatically (e.g. to `{storage_class:?}`)"
))
.emit();
}
// The storage class can be specified explicitly or inferred later by using StorageClass::Generic.
let storage_class = match inst.operands[0].unwrap_storage_class() {
StorageClass::Generic => StorageClassKind::Inferred,
storage_class => StorageClassKind::Explicit(storage_class),
};
SpirvType::Pointer {
pointee: inst.operands[1].unwrap_id_ref(),
storage_class,
}
.def(self.span(), self)
}
Expand Down Expand Up @@ -678,6 +673,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {

TyPat::Pointer(_, pat) => SpirvType::Pointer {
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, cx),

Expand Down Expand Up @@ -931,7 +927,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Some(match kind {
TypeofKind::Plain => ty,
TypeofKind::Dereference => match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => {
self.tcx.dcx().span_err(
span,
Expand All @@ -953,7 +949,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.check_reg(span, reg);
if let Some(place) = place {
match self.lookup_type(place.val.llval.ty) {
SpirvType::Pointer { pointee } => Some(pointee),
SpirvType::Pointer { pointee, .. } => Some(pointee),
other => {
self.tcx.dcx().span_err(
span,
Expand Down
Loading
Loading