diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index 04d1bf9860..622ec88915 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -3,7 +3,7 @@ use crate::attr::{AggregatedSpirvAttributes, IntrinsicType}; use crate::codegen_cx::CodegenCx; -use crate::spirv_type::SpirvType; +use crate::spirv_type::{SpirvType, StorageClassKind}; use itertools::Itertools; use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word}; use rustc_data_structures::fx::FxHashMap; @@ -339,6 +339,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> { PointeeDefState::Defining => { let id = SpirvType::Pointer { pointee: pointee_spv, + storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class? } .def(span, cx); entry.insert(PointeeDefState::Defined(id)); @@ -350,6 +351,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> { entry.insert(PointeeDefState::Defined(id)); SpirvType::Pointer { pointee: pointee_spv, + storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class? } .def_with_id(cx, span, id) } diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs index a47b1ba1ad..0eab9927e5 100644 --- a/crates/rustc_codegen_spirv/src/attr.rs +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -337,7 +337,7 @@ impl CheckSpirvAttrVisitor<'_> { "attribute is only valid on a parameter of an entry-point function", ); } else { - // FIXME(eddyb) should we just remove all 5 of these storage class + // FIXME(eddyb) should we just remove all 6 of these storage class // attributes, instead of disallowing them here? if let SpirvAttribute::StorageClass(storage_class) = parsed_attr { let valid = match storage_class { @@ -347,7 +347,8 @@ impl CheckSpirvAttrVisitor<'_> { StorageClass::Private | StorageClass::Function - | StorageClass::Generic => { + | StorageClass::Generic + | StorageClass::PhysicalStorageBuffer => { Err("can not be used as part of an entry's interface") } diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 7e61c1b989..b0fcc913da 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -360,11 +360,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } fn zombie_convert_ptr_to_u(&self, def: Word) { - self.zombie(def, "cannot convert pointers to integers"); + if !self + .builder + .has_capability(Capability::PhysicalStorageBufferAddresses) + { + self.zombie(def, "cannot convert pointers to integers without OpCapability PhysicalStorageBufferAddresses"); + } } fn zombie_convert_u_to_ptr(&self, def: Word) { - self.zombie(def, "cannot convert integers to pointers"); + if !self + .builder + .has_capability(Capability::PhysicalStorageBufferAddresses) + { + self.zombie(def, "cannot convert integers to pointers without OpCapability PhysicalStorageBufferAddresses"); + } } fn zombie_ptr_equal(&self, def: Word, inst: &str) { @@ -407,14 +417,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { size: Size, ) -> Option<(SpirvValue, ::Type)> { let ptr = ptr.strip_ptrcasts(); - let mut leaf_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee } => pointee, + let pointee_ty = match self.lookup_type(ptr.ty) { + SpirvType::Pointer { pointee, .. } => pointee, other => self.fatal(format!("non-pointer type: {other:?}")), }; // FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset` // could instead be doing all the extra digging itself. let mut indices = SmallVec::<[_; 8]>::new(); + let mut leaf_ty = pointee_ty; while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset( leaf_ty, Size::ZERO, @@ -429,7 +440,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .then(|| self.type_ptr_to(leaf_ty))?; let leaf_ptr = if indices.is_empty() { - assert_ty_eq!(self, ptr.ty, leaf_ptr_ty); + // Compare pointee types instead of pointer types as storage class might be different. + assert_ty_eq!(self, pointee_ty, leaf_ty); ptr } else { let indices = indices @@ -586,7 +598,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let ptr = ptr.strip_ptrcasts(); let ptr_id = ptr.def(self); let original_pointee_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, other => self.fatal(format!("gep called on non-pointer type: {other:?}")), }; @@ -1461,11 +1473,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.fatal("dynamic alloca not supported yet") } - fn load(&mut self, ty: Self::Type, ptr: Self::Value, _align: Align) -> Self::Value { + fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value { let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty); let loaded_val = ptr.const_fold_load(self).unwrap_or_else(|| { self.emit() - .load(access_ty, None, ptr.def(self), None, empty()) + .load( + access_ty, + None, + ptr.def(self), + Some(MemoryAccess::ALIGNED), + std::iter::once(Operand::LiteralBit32(align.bytes() as _)), + ) .unwrap() .with_type(access_ty) }); @@ -1587,12 +1605,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // ignore } - fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value { + fn store(&mut self, val: Self::Value, ptr: Self::Value, align: Align) -> Self::Value { let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty); let val = self.bitcast(val, access_ty); self.emit() - .store(ptr.def(self), val.def(self), None, empty()) + .store( + ptr.def(self), + val.def(self), + Some(MemoryAccess::ALIGNED), + std::iter::once(Operand::LiteralBit32(align.bytes() as _)), + ) .unwrap(); // FIXME(eddyb) this is meant to be a handle the store instruction itself. val @@ -1750,20 +1773,23 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn inttoptr(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value { - match self.lookup_type(dest_ty) { - SpirvType::Pointer { .. } => (), + let result_ty = match self.lookup_type(dest_ty) { + SpirvType::Pointer { pointee, .. } => self.type_ptr_to_with_storage_class( + pointee, + StorageClassKind::Explicit(StorageClass::PhysicalStorageBuffer), + ), other => self.fatal(format!( "inttoptr called on non-pointer dest type: {other:?}" )), - } - if val.ty == dest_ty { + }; + if val.ty == result_ty { val } else { let result = self .emit() - .convert_u_to_ptr(dest_ty, None, val.def(self)) + .convert_u_to_ptr(result_ty, None, val.def(self)) .unwrap() - .with_type(dest_ty); + .with_type(result_ty); self.zombie_convert_u_to_ptr(result.def(self)); result } @@ -1926,6 +1952,25 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return ptr; } + // No cast is needed if only the storage class mismatches. + let ptr_pointee = match self.lookup_type(ptr.ty) { + SpirvType::Pointer { pointee, .. } => pointee, + other => self.fatal(format!( + "pointercast called on non-pointer source type: {other:?}" + )), + }; + let dest_pointee = match self.lookup_type(dest_ty) { + SpirvType::Pointer { pointee, .. } => pointee, + other => self.fatal(format!( + "pointercast called on non-pointer dest type: {other:?}" + )), + }; + + // FIXME(jwollen) Do we need to choose `dest_ty` if it has a fixed storage class and `ptr` has none? + if ptr_pointee == dest_pointee { + return ptr; + } + // Strip a previous `pointercast`, to reveal the original pointer type. let ptr = ptr.strip_ptrcasts(); @@ -1934,17 +1979,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } let ptr_pointee = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, other => self.fatal(format!( "pointercast called on non-pointer source type: {other:?}" )), }; - let dest_pointee = match self.lookup_type(dest_ty) { - SpirvType::Pointer { pointee } => pointee, - other => self.fatal(format!( - "pointercast called on non-pointer dest type: {other:?}" - )), - }; + + if ptr_pointee == dest_pointee { + return ptr; + } + let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self); if let Some((indices, _)) = self.recover_access_chain_from_offset( @@ -2229,9 +2273,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn memcpy( &mut self, dst: Self::Value, - _dst_align: Align, + dst_align: Align, src: Self::Value, - _src_align: Align, + src_align: Align, size: Self::Value, flags: MemFlags, ) { @@ -2269,12 +2313,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } }); + // Pass all operands as `additional_params` since rspirv doesn't allow specifying + // extra operands ofter the first `MemoryAccess` + let mut ops: SmallVec<[_; 4]> = Default::default(); + ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED)); + if src_align != dst_align { + if self.emit().version().unwrap() > (1, 3) { + ops.push(Operand::LiteralBit32(dst_align.bytes() as _)); + ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED)); + ops.push(Operand::LiteralBit32(src_align.bytes() as _)); + } else { + let align = dst_align.min(src_align); + ops.push(Operand::LiteralBit32(align.bytes() as _)); + } + } else { + ops.push(Operand::LiteralBit32(dst_align.bytes() as _)); + } + if let Some((dst, src)) = typed_copy_dst_src { if let Some(const_value) = src.const_fold_load(self) { self.store(const_value, dst, Align::from_bytes(0).unwrap()); } else { self.emit() - .copy_memory(dst.def(self), src.def(self), None, None, empty()) + .copy_memory(dst.def(self), src.def(self), None, None, ops) .unwrap(); } } else { @@ -2285,7 +2346,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { size.def(self), None, None, - empty(), + ops, ) .unwrap(); self.zombie(dst.def(self), "cannot memcpy dynamically sized data"); @@ -2324,7 +2385,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?))); let elem_ty = match self.lookup_type(ptr.ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, _ => self.fatal(format!( "memset called on non-pointer type: {}", self.debug_type(ptr.ty) @@ -2696,7 +2757,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { (callee.def(self), return_type, arguments) } - SpirvType::Pointer { pointee } => match self.lookup_type(pointee) { + SpirvType::Pointer { pointee, .. } => match self.lookup_type(pointee) { SpirvType::Function { return_type, arguments, diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index 972432267e..ff40be4512 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -15,8 +15,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use crate::abi::ConvSpirvType; use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt}; use crate::codegen_cx::CodegenCx; -use crate::spirv_type::SpirvType; -use rspirv::spirv::Word; +use crate::spirv_type::{SpirvType, StorageClassKind}; +use rspirv::spirv::{StorageClass, Word}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; use rustc_codegen_ssa::mir::place::PlaceRef; use rustc_codegen_ssa::traits::{ @@ -104,7 +104,23 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // HACK(eddyb) like the `CodegenCx` method but with `self.span()` awareness. pub fn type_ptr_to(&self, ty: Word) -> Word { - SpirvType::Pointer { pointee: ty }.def(self.span(), self) + SpirvType::Pointer { + pointee: ty, + storage_class: StorageClassKind::Inferred, + } + .def(self.span(), self) + } + + pub fn type_ptr_to_with_storage_class( + &self, + ty: Word, + storage_class: StorageClassKind, + ) -> Word { + SpirvType::Pointer { + pointee: ty, + storage_class, + } + .def(self.span(), self) } // TODO: Definitely add tests to make sure this impl is right. diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 91da22473d..b50e139f56 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -4,7 +4,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::Builder; use crate::builder_spirv::{BuilderCursor, SpirvValue}; use crate::codegen_cx::CodegenCx; -use crate::spirv_type::SpirvType; +use crate::spirv_type::{SpirvType, StorageClassKind}; use rspirv::dr; use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier, reflect}; use rspirv::spirv::{ @@ -307,19 +307,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } .def(self.span(), self), Op::TypePointer => { - let storage_class = inst.operands[0].unwrap_storage_class(); - if storage_class != StorageClass::Generic { - self.struct_err("TypePointer in asm! requires `Generic` storage class") - .with_note(format!( - "`{storage_class:?}` storage class was specified" - )) - .with_help(format!( - "the storage class will be inferred automatically (e.g. to `{storage_class:?}`)" - )) - .emit(); - } + // The storage class can be specified explicitly or inferred later by using StorageClass::Generic. + let storage_class = match inst.operands[0].unwrap_storage_class() { + StorageClass::Generic => StorageClassKind::Inferred, + storage_class => StorageClassKind::Explicit(storage_class), + }; SpirvType::Pointer { pointee: inst.operands[1].unwrap_id_ref(), + storage_class, } .def(self.span(), self) } @@ -678,6 +673,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { TyPat::Pointer(_, pat) => SpirvType::Pointer { pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?, + storage_class: StorageClassKind::Inferred, } .def(DUMMY_SP, cx), @@ -931,7 +927,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { Some(match kind { TypeofKind::Plain => ty, TypeofKind::Dereference => match self.lookup_type(ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, other => { self.tcx.dcx().span_err( span, @@ -953,7 +949,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { self.check_reg(span, reg); if let Some(place) = place { match self.lookup_type(place.val.llval.ty) { - SpirvType::Pointer { pointee } => Some(pointee), + SpirvType::Pointer { pointee, .. } => Some(pointee), other => { self.tcx.dcx().span_err( span, diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 3c91b48872..b71941125d 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -101,7 +101,7 @@ impl SpirvValue { match entry.val { SpirvConst::PtrTo { pointee } => { let ty = match cx.lookup_type(self.ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, ty => bug!("load called on value that wasn't a pointer: {:?}", ty), }; // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. @@ -193,17 +193,20 @@ impl SpirvValue { original_ptr_ty, bitcast_result_id, } => { - cx.zombie_with_span( - bitcast_result_id, - span, - &format!( - "cannot cast between pointer types\ - \nfrom `{}`\ - \n to `{}`", - cx.debug_type(original_ptr_ty), - cx.debug_type(self.ty) - ), - ); + // If physical poitners are supported, defer the error until after storage class inferrence. + if !cx.builder.has_capability(Capability::PhysicalStorageBufferAddresses) { + cx.zombie_with_span( + bitcast_result_id, + span, + &format!( + "cannot cast between pointer types\ + \nfrom `{}`\ + \n to `{}`", + cx.debug_type(original_ptr_ty), + cx.debug_type(self.ty) + ), + ); + } bitcast_result_id } @@ -492,7 +495,14 @@ impl<'tcx> BuilderSpirv<'tcx> { // The linker will always be ran on this module add_cap(&mut builder, &mut enabled_capabilities, Capability::Linkage); - builder.memory_model(AddressingModel::Logical, memory_model); + let addressing_model = + if enabled_capabilities.contains(&Capability::PhysicalStorageBufferAddresses) { + AddressingModel::PhysicalStorageBuffer64 + } else { + AddressingModel::Logical + }; + + builder.memory_model(addressing_model, memory_model); Self { source_map: tcx.sess.source_map(), diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index 162eb9dc8c..e32bfad3b5 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -239,7 +239,7 @@ impl<'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'tcx> { let (base_addr, _base_addr_space) = match self.tcx.global_alloc(alloc_id) { GlobalAlloc::Memory(alloc) => { let pointee = match self.lookup_type(ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, other => self.tcx.dcx().fatal(format!( "GlobalAlloc::Memory type not implemented: {}", other.debug(ty, self) @@ -259,7 +259,7 @@ impl<'tcx> ConstCodegenMethods<'tcx> for CodegenCx<'tcx> { .global_alloc(self.tcx.vtable_allocation((vty, dyn_ty.principal()))) .unwrap_memory(); let pointee = match self.lookup_type(ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, other => self.tcx.dcx().fatal(format!( "GlobalAlloc::VTable type not implemented: {}", other.debug(ty, self) @@ -328,7 +328,7 @@ impl<'tcx> CodegenCx<'tcx> { if let Some(SpirvConst::ConstDataFromAlloc(alloc)) = self.builder.lookup_const_by_id(pointee) { - if let SpirvType::Pointer { pointee } = self.lookup_type(ty) { + if let SpirvType::Pointer { pointee, .. } = self.lookup_type(ty) { let mut offset = Size::ZERO; let init = self.read_from_const_alloc(alloc, &mut offset, pointee); return self.static_addr_of(init, alloc.inner().align, None); diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index eae83d8e93..592cdaa434 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -6,7 +6,7 @@ use crate::abi::ConvSpirvType; use crate::attr::AggregatedSpirvAttributes; use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; use crate::custom_decorations::{CustomDecoration, SrcLocDecoration}; -use crate::spirv_type::SpirvType; +use crate::spirv_type::{SpirvType, StorageClassKind}; use itertools::Itertools; use rspirv::spirv::{FunctionControl, LinkageType, StorageClass, Word}; use rustc_attr::InlineAttr; @@ -267,7 +267,12 @@ impl<'tcx> CodegenCx<'tcx> { } fn declare_global(&self, span: Span, ty: Word) -> SpirvValue { - let ptr_ty = SpirvType::Pointer { pointee: ty }.def(span, self); + // Could be explicitly StorageClass::Private but is inferred anyway. + let ptr_ty = SpirvType::Pointer { + pointee: ty, + storage_class: StorageClassKind::Inferred, + } + .def(span, self); // FIXME(eddyb) figure out what the correct storage class is. let result = self .emit_global() @@ -353,7 +358,7 @@ impl<'tcx> StaticCodegenMethods for CodegenCx<'tcx> { Err(_) => return, }; let value_ty = match self.lookup_type(g.ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, other => self.tcx.dcx().fatal(format!( "global had non-pointer type {}", other.debug(g.ty, self) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index a2ad70a010..297c91be30 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -967,7 +967,7 @@ impl<'tcx> CodegenCx<'tcx> { | SpirvType::Matrix { element, .. } | SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } - | SpirvType::Pointer { pointee: element } + | SpirvType::Pointer { pointee: element, .. } | SpirvType::InterfaceBlock { inner_type: element, } => recurse(cx, element, has_bool, must_be_flat), diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index f3a3828018..6435a43e86 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -6,7 +6,7 @@ mod type_; use crate::builder::{ExtInst, InstructionTable}; use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvConst, SpirvValue, SpirvValueKind}; use crate::custom_decorations::{CustomDecoration, SrcLocDecoration, ZombieDecoration}; -use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache}; +use crate::spirv_type::{SpirvType, SpirvTypePrinter, StorageClassKind, TypeCache}; use crate::symbols::Symbols; use crate::target::SpirvTarget; @@ -234,11 +234,19 @@ impl<'tcx> CodegenCx<'tcx> { } pub fn type_ptr_to(&self, ty: Word) -> Word { - SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self) + SpirvType::Pointer { + pointee: ty, + storage_class: StorageClassKind::Inferred, + } + .def(DUMMY_SP, self) } pub fn type_ptr_to_ext(&self, ty: Word, _address_space: AddressSpace) -> Word { - SpirvType::Pointer { pointee: ty }.def(DUMMY_SP, self) + SpirvType::Pointer { + pointee: ty, + storage_class: StorageClassKind::Inferred, + } + .def(DUMMY_SP, self) } /// Zombie system: @@ -866,6 +874,7 @@ impl<'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'tcx> { let ty = SpirvType::Pointer { pointee: function.ty, + storage_class: StorageClassKind::Inferred, } .def(span, self); diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs index f0005a9e42..256f43b8af 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs @@ -219,7 +219,7 @@ impl<'tcx> BaseTypeCodegenMethods<'tcx> for CodegenCx<'tcx> { } fn element_type(&self, ty: Self::Type) -> Self::Type { match self.lookup_type(ty) { - SpirvType::Pointer { pointee } => pointee, + SpirvType::Pointer { pointee, .. } => pointee, SpirvType::Vector { element, .. } => element, spirv_type => self.tcx.dcx().fatal(format!( "element_type called on invalid type: {spirv_type:?}" diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs index ba82f95fee..cea97cd588 100644 --- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs +++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs @@ -11,11 +11,12 @@ use super::simple_passes::outgoing_edges; use super::{apply_rewrite_rules, id}; +use itertools::Itertools; use rspirv::dr::{Block, Function, Instruction, ModuleHeader, Operand}; use rspirv::spirv::{Op, Word}; use rustc_data_structures::fx::{FxHashMap, FxHashSet, FxIndexMap}; -use rustc_middle::bug; use std::collections::hash_map; +use std::iter; // HACK(eddyb) newtype instead of type alias to avoid mistakes. #[derive(Copy, Clone, PartialEq, Eq, Hash)] @@ -328,10 +329,15 @@ fn split_copy_memory( if inst.class.opcode == Op::CopyMemory { let target = inst.operands[0].id_ref_any().unwrap(); let source = inst.operands[1].id_ref_any().unwrap(); - if inst.operands.len() > 2 { - // TODO: Copy the memory operands to the load/store - bug!("mem2reg OpCopyMemory doesn't support memory operands yet"); - } + let mem_ops = &inst.operands[2..]; + let (store_mem_ops, load_mem_ops) = if let Some((index, _)) = mem_ops[1..] + .iter() + .find_position(|op| matches!(op, Operand::MemoryAccess(..))) + { + mem_ops.split_at(index) + } else { + (mem_ops, mem_ops) + }; let ty = match (var_map.get(&target), var_map.get(&source)) { (None, None) => { inst_index += 1; @@ -345,17 +351,22 @@ fn split_copy_memory( } }; let temp_id = id(header); + + let load_ops = iter::once(Operand::IdRef(source)) + .chain(load_mem_ops.iter().cloned()) + .collect(); + + let store_ops = [Operand::IdRef(target), Operand::IdRef(temp_id)] + .into_iter() + .chain(store_mem_ops.iter().cloned()) + .collect(); + block.instructions[inst_index] = - Instruction::new(Op::Load, Some(ty), Some(temp_id), vec![Operand::IdRef( - source, - )]); + Instruction::new(Op::Load, Some(ty), Some(temp_id), load_ops); inst_index += 1; block.instructions.insert( inst_index, - Instruction::new(Op::Store, None, None, vec![ - Operand::IdRef(target), - Operand::IdRef(temp_id), - ]), + Instruction::new(Op::Store, None, None, store_ops), ); } inst_index += 1; diff --git a/crates/rustc_codegen_spirv/src/linker/specializer.rs b/crates/rustc_codegen_spirv/src/linker/specializer.rs index 01c4c9f9b1..893102c048 100644 --- a/crates/rustc_codegen_spirv/src/linker/specializer.rs +++ b/crates/rustc_codegen_spirv/src/linker/specializer.rs @@ -1615,6 +1615,14 @@ impl<'a, S: Specialization> InferCx<'a, S> { #[allow(clippy::match_same_arms)] Ok(match (a.clone(), b.clone()) { + // Concrete result types explicitly created inside functions + // can be assigned to instances. + // FIXME(jwollen) do we need to infere instance generics? + (InferOperand::Instance(_), InferOperand::Concrete(new)) + | (InferOperand::Concrete(new), InferOperand::Instance(_)) => { + InferOperand::Concrete(new) + } + // Instances of "generic" globals/functions must be of the same ID, // and their `generic_args` inference variables must be unified. ( @@ -1999,13 +2007,13 @@ impl<'a, S: Specialization> InferCx<'a, S> { if let Some(type_of_result) = type_of_result { // Keep the (instantiated) *Result Type*, for future instructions to use - // (but only if it has any `InferVar`s at all). + // if it has any `InferVar`s at all or if it was a concrete type. match type_of_result { - InferOperand::Var(_) | InferOperand::Instance(_) => { + InferOperand::Var(_) | InferOperand::Instance(_) | InferOperand::Concrete(_) => { self.type_of_result .insert(inst.result_id.unwrap(), type_of_result); } - InferOperand::Unknown | InferOperand::Concrete(_) => {} + InferOperand::Unknown => {} } } } diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 0c8ce42ba0..dc584484ec 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -61,6 +61,7 @@ pub enum SpirvType<'tcx> { }, Pointer { pointee: Word, + storage_class: StorageClassKind, }, Function { return_type: Word, @@ -90,6 +91,17 @@ pub enum SpirvType<'tcx> { RayQueryKhr, } +#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] +pub enum StorageClassKind { + /// Inferred based on globals and other pointers with explicit storage classes. + /// This corresponds to `StorageClass::Generic` in inline `asm!` and intermediate SPIR-V. + Inferred, + + /// Explicitly set by an instruction that needs to create a storage class, + /// regardless of inputs. + Explicit(StorageClass), +} + impl SpirvType<'_> { /// Note: `Builder::type_*` should be called *nowhere else* but here, to ensure /// `CodegenCx::type_defs` stays up-to-date @@ -213,13 +225,18 @@ impl SpirvType<'_> { ); result } - Self::Pointer { pointee } => { + Self::Pointer { + pointee, + storage_class, + } => { // NOTE(eddyb) we emit `StorageClass::Generic` here, but later // the linker will specialize the entire SPIR-V module to use // storage classes inferred from `OpVariable`s. - let result = cx - .emit_global() - .type_pointer(id, StorageClass::Generic, pointee); + let storage_class = match storage_class { + StorageClassKind::Inferred => StorageClass::Generic, + StorageClassKind::Explicit(storage_class) => storage_class, + }; + let result = cx.emit_global().type_pointer(id, storage_class, pointee); // no pointers to functions if let SpirvType::Function { .. } = cx.lookup_type(pointee) { // FIXME(eddyb) use the `SPV_INTEL_function_pointers` extension. @@ -286,13 +303,20 @@ impl SpirvType<'_> { return cached; } let result = match self { - Self::Pointer { pointee } => { + Self::Pointer { + pointee, + storage_class, + } => { // NOTE(eddyb) we emit `StorageClass::Generic` here, but later // the linker will specialize the entire SPIR-V module to use // storage classes inferred from `OpVariable`s. - let result = - cx.emit_global() - .type_pointer(Some(id), StorageClass::Generic, pointee); + let storage_class = match storage_class { + StorageClassKind::Inferred => StorageClass::Generic, + StorageClassKind::Explicit(storage_class) => storage_class, + }; + let result = cx + .emit_global() + .type_pointer(Some(id), storage_class, pointee); // no pointers to functions if let SpirvType::Function { .. } = cx.lookup_type(pointee) { // FIXME(eddyb) use the `SPV_INTEL_function_pointers` extension. @@ -412,7 +436,13 @@ impl SpirvType<'_> { SpirvType::Matrix { element, count } => SpirvType::Matrix { element, count }, SpirvType::Array { element, count } => SpirvType::Array { element, count }, SpirvType::RuntimeArray { element } => SpirvType::RuntimeArray { element }, - SpirvType::Pointer { pointee } => SpirvType::Pointer { pointee }, + SpirvType::Pointer { + pointee, + storage_class, + } => SpirvType::Pointer { + pointee, + storage_class, + }, SpirvType::Image { sampled_type, dim, @@ -557,10 +587,14 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { .field("id", &self.id) .field("element", &self.cx.debug_type(element)) .finish(), - SpirvType::Pointer { pointee } => f + SpirvType::Pointer { + pointee, + storage_class, + } => f .debug_struct("Pointer") .field("id", &self.id) .field("pointee", &self.cx.debug_type(pointee)) + .field("storage_class", &storage_class) .finish(), SpirvType::Function { return_type, @@ -710,8 +744,14 @@ impl SpirvTypePrinter<'_, '_> { ty(self.cx, stack, f, element)?; f.write_str("]") } - SpirvType::Pointer { pointee } => { + SpirvType::Pointer { + pointee, + storage_class, + } => { f.write_str("*")?; + if let StorageClassKind::Explicit(storage_class) = storage_class { + write!(f, "{:?}", storage_class)?; + } ty(self.cx, stack, f, pointee) } SpirvType::Function { diff --git a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs index 148b44c61f..1321722903 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs @@ -427,7 +427,10 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { Op::ConvertPtrToU | Op::SatConvertSToU | Op::SatConvertUToS | Op::ConvertUToPtr => {} Op::PtrCastToGeneric | Op::GenericCastToPtr => sig! { (Pointer(_, T)) -> Pointer(_, T) }, Op::GenericCastToPtrExplicit => sig! { {S} (Pointer(_, T)) -> Pointer(S, T) }, - Op::Bitcast => {} + Op::Bitcast => sig! { + (Pointer(S, _)) -> Pointer(S, _) | + (_) -> _ + }, // 3.37.12. Composite Instructions Op::VectorExtractDynamic => sig! { (Vector(T), _) -> T }, diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 61c84dfc47..145d38d6fb 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -98,6 +98,7 @@ pub mod indirect_command; pub mod integer; pub mod memory; pub mod number; +pub mod ptr; pub mod ray_tracing; mod runtime_array; mod sampler; diff --git a/crates/spirv-std/src/ptr.rs b/crates/spirv-std/src/ptr.rs new file mode 100644 index 0000000000..c135f05dfa --- /dev/null +++ b/crates/spirv-std/src/ptr.rs @@ -0,0 +1,118 @@ +//! Physical pointers + +#[cfg(target_arch = "spirv")] +use core::arch::asm; +use core::marker::PhantomData; + +/// A physical pointer in the `PhysicalStorageBuffer` storage class +/// with semantics similar to `*mut T`. +/// +/// This is similar to a raw pointer retrieved through `u64 as *mut T`, but +/// provides utilities for pointer manipulation that are currently not +/// supported on raw pointers due to the otherwise logical addressing model +/// and 32-bit pointer size. +pub struct PhysicalPtr { + // Use uvec2 instead of u64 to avoid demepndency on the Int64 dependency. + addr: glam::UVec2, + //addr: u64, + _marker: PhantomData<*mut T>, +} + +impl Copy for PhysicalPtr {} + +impl Clone for PhysicalPtr { + fn clone(&self) -> Self { + Self { + addr: self.addr, + _marker: PhantomData, + } + } +} + +impl PhysicalPtr { + /// Get a mutaple pointer to the physical address. + /// The same aliasing rules that apply to FFI, apply to the returned pointer. + #[crate::macros::gpu_only] + pub fn get(self) -> *mut T { + let result: *mut T; + unsafe { + // FIXME(jwollen) add a way to dereference the result type further + // or to pass type parameters + let dummy: T = core::mem::MaybeUninit::uninit().assume_init(); + asm!( + "%ptr_type = OpTypePointer PhysicalStorageBuffer typeof*{dummy}", + "{result} = OpBitcast %ptr_type {addr}", + addr = in(reg) &self.addr, + dummy = in(reg) &dummy, + result = out(reg) result, + ); + result + } + } + + /// Creates a null physical pointer. + pub fn null() -> Self { + Self { + addr: glam::UVec2::ZERO, + _marker: PhantomData, + } + } + + /// Returns `true` if the pointer is null. + pub fn is_null(self) -> bool { + self.addr == glam::UVec2::ZERO + } + + /// Casts to a pointer of another type. + pub fn cast(self) -> PhysicalPtr { + PhysicalPtr { addr: self.addr, _marker: PhantomData } + } + + /// Returns `None` if the pointer is null, or else returns a shared reference to the value wrapped in `Some`. + pub unsafe fn as_ref<'a>(self) -> Option<&'a T> { + self.is_null().then_some(unsafe { self.as_ref_unchecked() }) + } + + /// Returns `None` if the pointer is null, or else returns a mutable reference to the value wrapped in `Some`. + pub unsafe fn as_mut<'a>(self) -> Option<&'a mut T> { + self.is_null().then_some(unsafe { self.as_mut_unchecked() }) + } + + /// Returns a shared reference to the value behind the pointer. + pub unsafe fn as_ref_unchecked<'a>(self) -> &'a T { + unsafe { &*self.get() } + } + + /// Returns a mutable reference to the value behind the pointer. + pub unsafe fn as_mut_unchecked<'a>(self) -> &'a mut T { + unsafe { &mut *self.get() } + } + + /// Gets the address portion of the pointer. All physical pointers are considered to have global provenance. + pub fn addr(self) -> u64 { + unsafe { core::mem::transmute(self.addr) } + } + + /// Forms a physical pointer from an address. All physical pointers are considered to have global provenance. + pub fn from_addr(addr: u64) -> Self { + Self { + addr: unsafe { core::mem::transmute(addr) }, + _marker: PhantomData, + } + } + + /// Creates a new pointer by mapping `self`’s address to a new one. + pub fn map_addr(self, f: impl FnOnce(u64) -> u64) -> Self { + Self::from_addr(f(self.addr())) + } + + /// Adds a signed offset to a pointer. + pub unsafe fn offset(self, count: i64) -> Self { + unsafe { self.byte_offset(count * core::mem::size_of::() as i64) } + } + + /// Adds a signed offset in bytes to a pointer. + pub unsafe fn byte_offset(self, count: i64) -> Self { + self.map_addr(|addr| addr.overflowing_add_signed(count).0) + } +}