diff --git a/compiler/rustc_codegen_cranelift/src/base.rs b/compiler/rustc_codegen_cranelift/src/base.rs index 0b641ba64b7a4..1b68c6535da5f 100644 --- a/compiler/rustc_codegen_cranelift/src/base.rs +++ b/compiler/rustc_codegen_cranelift/src/base.rs @@ -407,6 +407,18 @@ fn codegen_fn_body(fx: &mut FunctionCx<'_, '_, '_>, start_block: Block) { source_info.span, ) } + AssertKind::InvalidEnumConstruction(source) => { + let source = codegen_operand(fx, source).load_scalar(fx); + let location = fx.get_caller_location(source_info).load_scalar(fx); + + codegen_panic_inner( + fx, + rustc_hir::LangItem::PanicInvalidEnumConstruction, + &[source, location], + *unwind, + source_info.span, + ) + } _ => { let location = fx.get_caller_location(source_info).load_scalar(fx); diff --git a/compiler/rustc_codegen_ssa/src/mir/block.rs b/compiler/rustc_codegen_ssa/src/mir/block.rs index 43b87171d510d..f064104042883 100644 --- a/compiler/rustc_codegen_ssa/src/mir/block.rs +++ b/compiler/rustc_codegen_ssa/src/mir/block.rs @@ -777,6 +777,12 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> { // `#[track_caller]` adds an implicit argument. (LangItem::PanicNullPointerDereference, vec![location]) } + AssertKind::InvalidEnumConstruction(source) => { + let source = self.codegen_operand(bx, source).immediate(); + // It's `fn panic_invalid_enum_construction(source: u128)`, + // `#[track_caller]` adds an implicit argument. + (LangItem::PanicInvalidEnumConstruction, vec![source, location]) + } _ => { // It's `pub fn panic_...()` and `#[track_caller]` adds an implicit argument. (msg.panic_function(), vec![location]) diff --git a/compiler/rustc_const_eval/src/const_eval/machine.rs b/compiler/rustc_const_eval/src/const_eval/machine.rs index a68dcf2998866..a4c5ca09a72a3 100644 --- a/compiler/rustc_const_eval/src/const_eval/machine.rs +++ b/compiler/rustc_const_eval/src/const_eval/machine.rs @@ -508,6 +508,7 @@ impl<'tcx> interpret::Machine<'tcx> for CompileTimeMachine<'tcx> { found: eval_to_int(found)?, }, NullPointerDereference => NullPointerDereference, + InvalidEnumConstruction(source) => InvalidEnumConstruction(eval_to_int(source)?), }; Err(ConstEvalErrKind::AssertFailure(err)).into() } diff --git a/compiler/rustc_hir/src/lang_items.rs b/compiler/rustc_hir/src/lang_items.rs index 21d36ed54cdf4..b537dceaddd8b 100644 --- a/compiler/rustc_hir/src/lang_items.rs +++ b/compiler/rustc_hir/src/lang_items.rs @@ -310,6 +310,7 @@ language_item_table! { PanicAsyncGenFnResumedPanic, sym::panic_const_async_gen_fn_resumed_panic, panic_const_async_gen_fn_resumed_panic, Target::Fn, GenericRequirement::None; PanicGenFnNonePanic, sym::panic_const_gen_fn_none_panic, panic_const_gen_fn_none_panic, Target::Fn, GenericRequirement::None; PanicNullPointerDereference, sym::panic_null_pointer_dereference, panic_null_pointer_dereference, Target::Fn, GenericRequirement::None; + PanicInvalidEnumConstruction, sym::panic_invalid_enum_construction, panic_invalid_enum_construction, Target::Fn, GenericRequirement::None; PanicCoroutineResumedDrop, sym::panic_const_coroutine_resumed_drop, panic_const_coroutine_resumed_drop, Target::Fn, GenericRequirement::None; PanicAsyncFnResumedDrop, sym::panic_const_async_fn_resumed_drop, panic_const_async_fn_resumed_drop, Target::Fn, GenericRequirement::None; PanicAsyncGenFnResumedDrop, sym::panic_const_async_gen_fn_resumed_drop, panic_const_async_gen_fn_resumed_drop, Target::Fn, GenericRequirement::None; diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 3d27e587b6cb4..7aa4743241202 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -17,6 +17,9 @@ middle_assert_gen_resume_after_drop = `gen` fn or block cannot be further iterat middle_assert_gen_resume_after_panic = `gen` fn or block cannot be further iterated on after it panicked +middle_assert_invalid_enum_construction = + trying to construct an enum from an invalid value `{$source}` + middle_assert_misaligned_ptr_deref = misaligned pointer dereference: address must be a multiple of {$required} but is {$found} diff --git a/compiler/rustc_middle/src/mir/syntax.rs b/compiler/rustc_middle/src/mir/syntax.rs index bb068f3821db8..802583f283dad 100644 --- a/compiler/rustc_middle/src/mir/syntax.rs +++ b/compiler/rustc_middle/src/mir/syntax.rs @@ -1075,6 +1075,7 @@ pub enum AssertKind { ResumedAfterDrop(CoroutineKind), MisalignedPointerDereference { required: O, found: O }, NullPointerDereference, + InvalidEnumConstruction(O), } #[derive(Clone, Debug, PartialEq, TyEncodable, TyDecodable, Hash, HashStable)] diff --git a/compiler/rustc_middle/src/mir/terminator.rs b/compiler/rustc_middle/src/mir/terminator.rs index 0834fa8844c00..4034a3a06e943 100644 --- a/compiler/rustc_middle/src/mir/terminator.rs +++ b/compiler/rustc_middle/src/mir/terminator.rs @@ -208,6 +208,7 @@ impl AssertKind { LangItem::PanicGenFnNonePanic } NullPointerDereference => LangItem::PanicNullPointerDereference, + InvalidEnumConstruction(_) => LangItem::PanicInvalidEnumConstruction, ResumedAfterDrop(CoroutineKind::Coroutine(_)) => LangItem::PanicCoroutineResumedDrop, ResumedAfterDrop(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => { LangItem::PanicAsyncFnResumedDrop @@ -284,6 +285,9 @@ impl AssertKind { ) } NullPointerDereference => write!(f, "\"null pointer dereference occurred\""), + InvalidEnumConstruction(source) => { + write!(f, "\"trying to construct an enum from an invalid value {{}}\", {source:?}") + } ResumedAfterReturn(CoroutineKind::Coroutine(_)) => { write!(f, "\"coroutine resumed after completion\"") } @@ -367,6 +371,7 @@ impl AssertKind { middle_assert_coroutine_resume_after_panic } NullPointerDereference => middle_assert_null_ptr_deref, + InvalidEnumConstruction(_) => middle_assert_invalid_enum_construction, ResumedAfterDrop(CoroutineKind::Desugared(CoroutineDesugaring::Async, _)) => { middle_assert_async_resume_after_drop } @@ -420,6 +425,9 @@ impl AssertKind { add!("required", format!("{required:#?}")); add!("found", format!("{found:#?}")); } + InvalidEnumConstruction(source) => { + add!("source", format!("{source:#?}")); + } } } } diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs index 1777756174bf9..929ebe1aee181 100644 --- a/compiler/rustc_middle/src/mir/visit.rs +++ b/compiler/rustc_middle/src/mir/visit.rs @@ -642,7 +642,7 @@ macro_rules! make_mir_visitor { self.visit_operand(l, location); self.visit_operand(r, location); } - OverflowNeg(op) | DivisionByZero(op) | RemainderByZero(op) => { + OverflowNeg(op) | DivisionByZero(op) | RemainderByZero(op) | InvalidEnumConstruction(op) => { self.visit_operand(op, location); } ResumedAfterReturn(_) | ResumedAfterPanic(_) | NullPointerDereference | ResumedAfterDrop(_) => { diff --git a/compiler/rustc_mir_transform/src/check_enums.rs b/compiler/rustc_mir_transform/src/check_enums.rs new file mode 100644 index 0000000000000..e06e0c6122e8d --- /dev/null +++ b/compiler/rustc_mir_transform/src/check_enums.rs @@ -0,0 +1,501 @@ +use rustc_abi::{Scalar, Size, TagEncoding, Variants, WrappingRange}; +use rustc_hir::LangItem; +use rustc_index::IndexVec; +use rustc_middle::bug; +use rustc_middle::mir::visit::Visitor; +use rustc_middle::mir::*; +use rustc_middle::ty::layout::PrimitiveExt; +use rustc_middle::ty::{self, Ty, TyCtxt, TypingEnv}; +use rustc_session::Session; +use tracing::debug; + +/// This pass inserts checks for a valid enum discriminant where they are most +/// likely to find UB, because checking everywhere like Miri would generate too +/// much MIR. +pub(super) struct CheckEnums; + +impl<'tcx> crate::MirPass<'tcx> for CheckEnums { + fn is_enabled(&self, sess: &Session) -> bool { + sess.ub_checks() + } + + fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { + // This pass emits new panics. If for whatever reason we do not have a panic + // implementation, running this pass may cause otherwise-valid code to not compile. + if tcx.lang_items().get(LangItem::PanicImpl).is_none() { + return; + } + + let typing_env = body.typing_env(tcx); + let basic_blocks = body.basic_blocks.as_mut(); + let local_decls = &mut body.local_decls; + + // This operation inserts new blocks. Each insertion changes the Location for all + // statements/blocks after. Iterating or visiting the MIR in order would require updating + // our current location after every insertion. By iterating backwards, we dodge this issue: + // The only Locations that an insertion changes have already been handled. + for block in basic_blocks.indices().rev() { + for statement_index in (0..basic_blocks[block].statements.len()).rev() { + let location = Location { block, statement_index }; + let statement = &basic_blocks[block].statements[statement_index]; + let source_info = statement.source_info; + + let mut finder = EnumFinder::new(tcx, local_decls, typing_env); + finder.visit_statement(statement, location); + + for check in finder.into_found_enums() { + debug!("Inserting enum check"); + let new_block = split_block(basic_blocks, location); + + match check { + EnumCheckType::Direct { source_op, discr, op_size, valid_discrs } => { + insert_direct_enum_check( + tcx, + local_decls, + basic_blocks, + block, + source_op, + discr, + op_size, + valid_discrs, + source_info, + new_block, + ) + } + EnumCheckType::Uninhabited => insert_uninhabited_enum_check( + tcx, + local_decls, + &mut basic_blocks[block], + source_info, + new_block, + ), + EnumCheckType::WithNiche { + source_op, + discr, + op_size, + offset, + valid_range, + } => insert_niche_check( + tcx, + local_decls, + &mut basic_blocks[block], + source_op, + valid_range, + discr, + op_size, + offset, + source_info, + new_block, + ), + } + } + } + } + } + + fn is_required(&self) -> bool { + true + } +} + +/// Represent the different kind of enum checks we can insert. +enum EnumCheckType<'tcx> { + /// We know we try to create an uninhabited enum from an inhabited variant. + Uninhabited, + /// We know the enum does no niche optimizations and can thus easily compute + /// the valid discriminants. + Direct { + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + valid_discrs: Vec, + }, + /// We try to construct an enum that has a niche. + WithNiche { + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Size, + valid_range: WrappingRange, + }, +} + +struct TyAndSize<'tcx> { + pub ty: Ty<'tcx>, + pub size: Size, +} + +/// A [Visitor] that finds the construction of enums and evaluates which checks +/// we should apply. +struct EnumFinder<'a, 'tcx> { + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: TypingEnv<'tcx>, + enums: Vec>, +} + +impl<'a, 'tcx> EnumFinder<'a, 'tcx> { + fn new( + tcx: TyCtxt<'tcx>, + local_decls: &'a mut LocalDecls<'tcx>, + typing_env: TypingEnv<'tcx>, + ) -> Self { + EnumFinder { tcx, local_decls, typing_env, enums: Vec::new() } + } + + /// Returns the found enum creations and which checks should be inserted. + fn into_found_enums(self) -> Vec> { + self.enums + } +} + +impl<'a, 'tcx> Visitor<'tcx> for EnumFinder<'a, 'tcx> { + fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { + if let Rvalue::Cast(CastKind::Transmute, op, ty) = rvalue { + let ty::Adt(adt_def, _) = ty.kind() else { + return; + }; + if !adt_def.is_enum() { + return; + } + + let Ok(enum_layout) = self.tcx.layout_of(self.typing_env.as_query_input(*ty)) else { + return; + }; + let Ok(op_layout) = self + .tcx + .layout_of(self.typing_env.as_query_input(op.ty(self.local_decls, self.tcx))) + else { + return; + }; + + match enum_layout.variants { + Variants::Empty if op_layout.is_uninhabited() => return, + // An empty enum that tries to be constructed from an inhabited value, this + // is never correct. + Variants::Empty => { + // The enum layout is uninhabited but we construct it from sth inhabited. + // This is always UB. + self.enums.push(EnumCheckType::Uninhabited); + } + // Construction of Single value enums is always fine. + Variants::Single { .. } => {} + // Construction of an enum with multiple variants but no niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Direct, + tag: Scalar::Initialized { value, .. }, + .. + } => { + let valid_discrs = + adt_def.discriminants(self.tcx).map(|(_, discr)| discr.val).collect(); + + let discr = + TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::Direct { + source_op: op.to_copy(), + discr, + op_size: op_layout.size, + valid_discrs, + }); + } + // Construction of an enum with multiple variants and niche optimizations. + Variants::Multiple { + tag_encoding: TagEncoding::Niche { .. }, + tag: Scalar::Initialized { value, valid_range, .. }, + tag_field, + .. + } => { + let discr = + TyAndSize { ty: value.to_int_ty(self.tcx), size: value.size(&self.tcx) }; + self.enums.push(EnumCheckType::WithNiche { + source_op: op.to_copy(), + discr, + op_size: op_layout.size, + offset: enum_layout.fields.offset(tag_field.as_usize()), + valid_range, + }); + } + _ => return, + } + + self.super_rvalue(rvalue, location); + } + } +} + +fn split_block( + basic_blocks: &mut IndexVec>, + location: Location, +) -> BasicBlock { + let block_data = &mut basic_blocks[location.block]; + + // Drain every statement after this one and move the current terminator to a new basic block. + let new_block = BasicBlockData { + statements: block_data.statements.split_off(location.statement_index), + terminator: block_data.terminator.take(), + is_cleanup: block_data.is_cleanup, + }; + + basic_blocks.push(new_block) +} + +/// Inserts the cast of an operand (any type) to a u128 value that holds the discriminant value. +fn insert_discr_cast_to_u128<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec>, + block_data: &mut BasicBlockData<'tcx>, + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Option, + source_info: SourceInfo, +) -> Place<'tcx> { + let get_ty_for_size = |tcx: TyCtxt<'tcx>, size: Size| -> Ty<'tcx> { + match size.bytes() { + 1 => tcx.types.u8, + 2 => tcx.types.u16, + 4 => tcx.types.u32, + 8 => tcx.types.u64, + 16 => tcx.types.u128, + invalid => bug!("Found discriminant with invalid size, has {} bytes", invalid), + } + }; + + let (cast_kind, discr_ty_bits) = if discr.size.bytes() < op_size.bytes() { + // The discriminant is less wide than the operand, cast the operand into + // [MaybeUninit; N] and then index into it. + let mu = Ty::new_maybe_uninit(tcx, tcx.types.u8); + let array_len = op_size.bytes(); + let mu_array_ty = Ty::new_array(tcx, mu, array_len); + let mu_array = + local_decls.push(LocalDecl::with_source_info(mu_array_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, mu_array_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((mu_array, rvalue))), + }); + + // Index into the array of MaybeUninit to get something that is actually + // as wide as the discriminant. + let offset = offset.unwrap_or(Size::ZERO); + let smaller_mu_array = mu_array.project_deeper( + &[ProjectionElem::Subslice { + from: offset.bytes(), + to: offset.bytes() + discr.size.bytes(), + from_end: false, + }], + tcx, + ); + + (CastKind::Transmute, Operand::Copy(smaller_mu_array)) + } else { + let operand_int_ty = get_ty_for_size(tcx, op_size); + + let op_as_int = + local_decls.push(LocalDecl::with_source_info(operand_int_ty, source_info)).into(); + let rvalue = Rvalue::Cast(CastKind::Transmute, source_op, operand_int_ty); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((op_as_int, rvalue))), + }); + + (CastKind::IntToInt, Operand::Copy(op_as_int)) + }; + + // Cast the resulting value to the actual discriminant integer type. + let rvalue = Rvalue::Cast(cast_kind, discr_ty_bits, discr.ty); + let discr_in_discr_ty = + local_decls.push(LocalDecl::with_source_info(discr.ty, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new((discr_in_discr_ty, rvalue))), + }); + + // Cast the discriminant to a u128 (base for comparisions of enum discriminants). + let const_u128 = Ty::new_uint(tcx, ty::UintTy::U128); + let rvalue = Rvalue::Cast(CastKind::IntToInt, Operand::Copy(discr_in_discr_ty), const_u128); + let discr = local_decls.push(LocalDecl::with_source_info(const_u128, source_info)).into(); + block_data + .statements + .push(Statement { source_info, kind: StatementKind::Assign(Box::new((discr, rvalue))) }); + + discr +} + +fn insert_direct_enum_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec>, + basic_blocks: &mut IndexVec>, + current_block: BasicBlock, + source_op: Operand<'tcx>, + discr: TyAndSize<'tcx>, + op_size: Size, + discriminants: Vec, + source_info: SourceInfo, + new_block: BasicBlock, +) { + // Insert a new target block that is branched to in case of an invalid discriminant. + let invalid_discr_block_data = BasicBlockData::new(None, false); + let invalid_discr_block = basic_blocks.push(invalid_discr_block_data); + let block_data = &mut basic_blocks[current_block]; + let discr = insert_discr_cast_to_u128( + tcx, + local_decls, + block_data, + source_op, + discr, + op_size, + None, + source_info, + ); + + // Branch based on the discriminant value. + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::SwitchInt { + discr: Operand::Copy(discr), + targets: SwitchTargets::new( + discriminants.into_iter().map(|discr| (discr, new_block)), + invalid_discr_block, + ), + }, + }); + + // Abort in case of an invalid enum discriminant. + basic_blocks[invalid_discr_block].terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool), + })), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} + +fn insert_uninhabited_enum_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec>, + block_data: &mut BasicBlockData<'tcx>, + source_info: SourceInfo, + new_block: BasicBlock, +) { + let is_ok: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::Use(Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_bool(false), tcx.types.bool), + }))), + ))), + }); + + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Constant(Box::new( + ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(0), tcx.types.u128), + }, + )))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} + +fn insert_niche_check<'tcx>( + tcx: TyCtxt<'tcx>, + local_decls: &mut IndexVec>, + block_data: &mut BasicBlockData<'tcx>, + source_op: Operand<'tcx>, + valid_range: WrappingRange, + discr: TyAndSize<'tcx>, + op_size: Size, + offset: Size, + source_info: SourceInfo, + new_block: BasicBlock, +) { + let discr = insert_discr_cast_to_u128( + tcx, + local_decls, + block_data, + source_op, + discr, + op_size, + Some(offset), + source_info, + ); + + // Compare the discriminant agains the valid_range. + let start_const = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val(ConstValue::from_u128(valid_range.start), tcx.types.u128), + })); + let end_start_diff_const = Operand::Constant(Box::new(ConstOperand { + span: source_info.span, + user_ty: None, + const_: Const::Val( + ConstValue::from_u128(u128::wrapping_sub(valid_range.end, valid_range.start)), + tcx.types.u128, + ), + })); + + let discr_diff: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.u128, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + discr_diff, + Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(discr), start_const))), + ))), + }); + + let is_ok: Place<'_> = + local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); + block_data.statements.push(Statement { + source_info, + kind: StatementKind::Assign(Box::new(( + is_ok, + Rvalue::BinaryOp( + // This is a `WrappingRange`, so make sure to get the wrapping right. + BinOp::Le, + Box::new((Operand::Copy(discr_diff), end_start_diff_const)), + ), + ))), + }); + + block_data.terminator = Some(Terminator { + source_info, + kind: TerminatorKind::Assert { + cond: Operand::Copy(is_ok), + expected: true, + target: new_block, + msg: Box::new(AssertKind::InvalidEnumConstruction(Operand::Copy(discr))), + // This calls panic_invalid_enum_construction, which is #[rustc_nounwind]. + // We never want to insert an unwind into unsafe code, because unwinding could + // make a failing UB check turn into much worse UB when we start unwinding. + unwind: UnwindAction::Unreachable, + }, + }); +} diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index 572ad585c8c87..6b32254b0514c 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -117,6 +117,7 @@ declare_passes! { mod check_inline : CheckForceInline; mod check_call_recursion : CheckCallRecursion, CheckDropRecursion; mod check_alignment : CheckAlignment; + mod check_enums : CheckEnums; mod check_const_item_mutation : CheckConstItemMutation; mod check_null : CheckNull; mod check_packed_ref : CheckPackedRef; @@ -666,6 +667,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<' // Add some UB checks before any UB gets optimized away. &check_alignment::CheckAlignment, &check_null::CheckNull, + &check_enums::CheckEnums, // Before inlining: trim down MIR with passes to reduce inlining work. // Has to be done before inlining, otherwise actual call will be almost always inlined. diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 173030e0326e8..ce61d37947351 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -834,6 +834,9 @@ impl<'a, 'tcx> MirVisitor<'tcx> for MirUsedCollector<'a, 'tcx> { mir::AssertKind::NullPointerDereference => { push_mono_lang_item(self, LangItem::PanicNullPointerDereference); } + mir::AssertKind::InvalidEnumConstruction(_) => { + push_mono_lang_item(self, LangItem::PanicInvalidEnumConstruction); + } _ => { push_mono_lang_item(self, msg.panic_function()); } diff --git a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs index 42b3e59b73ab9..85e71ed2c2550 100644 --- a/compiler/rustc_smir/src/rustc_smir/convert/mir.rs +++ b/compiler/rustc_smir/src/rustc_smir/convert/mir.rs @@ -506,6 +506,9 @@ impl<'tcx> Stable<'tcx> for mir::AssertMessage<'tcx> { AssertKind::NullPointerDereference => { stable_mir::mir::AssertMessage::NullPointerDereference } + AssertKind::InvalidEnumConstruction(source) => { + stable_mir::mir::AssertMessage::InvalidEnumConstruction(source.stable(tables)) + } } } } diff --git a/compiler/rustc_smir/src/stable_mir/mir/body.rs b/compiler/rustc_smir/src/stable_mir/mir/body.rs index 660cd7db0800d..e4b7659ce7f91 100644 --- a/compiler/rustc_smir/src/stable_mir/mir/body.rs +++ b/compiler/rustc_smir/src/stable_mir/mir/body.rs @@ -270,6 +270,7 @@ pub enum AssertMessage { ResumedAfterDrop(CoroutineKind), MisalignedPointerDereference { required: Operand, found: Operand }, NullPointerDereference, + InvalidEnumConstruction(Operand), } impl AssertMessage { @@ -342,6 +343,9 @@ impl AssertMessage { Ok("misaligned pointer dereference") } AssertMessage::NullPointerDereference => Ok("null pointer dereference occurred"), + AssertMessage::InvalidEnumConstruction(_) => { + Ok("trying to construct an enum from an invalid value") + } } } } diff --git a/compiler/rustc_smir/src/stable_mir/mir/pretty.rs b/compiler/rustc_smir/src/stable_mir/mir/pretty.rs index ba20651f993d1..b068a9a1081f1 100644 --- a/compiler/rustc_smir/src/stable_mir/mir/pretty.rs +++ b/compiler/rustc_smir/src/stable_mir/mir/pretty.rs @@ -313,6 +313,10 @@ fn pretty_assert_message(writer: &mut W, msg: &AssertMessage) -> io::R AssertMessage::NullPointerDereference => { write!(writer, "\"null pointer dereference occurred\"") } + AssertMessage::InvalidEnumConstruction(op) => { + let pretty_op = pretty_operand(op); + write!(writer, "\"trying to construct an enum from an invalid value {{}}\",{pretty_op}") + } AssertMessage::ResumedAfterReturn(_) | AssertMessage::ResumedAfterPanic(_) | AssertMessage::ResumedAfterDrop(_) => { diff --git a/compiler/rustc_smir/src/stable_mir/mir/visit.rs b/compiler/rustc_smir/src/stable_mir/mir/visit.rs index e21dc11eea9ca..b7dd433eb0938 100644 --- a/compiler/rustc_smir/src/stable_mir/mir/visit.rs +++ b/compiler/rustc_smir/src/stable_mir/mir/visit.rs @@ -367,7 +367,8 @@ macro_rules! make_mir_visitor { } AssertMessage::OverflowNeg(op) | AssertMessage::DivisionByZero(op) - | AssertMessage::RemainderByZero(op) => { + | AssertMessage::RemainderByZero(op) + | AssertMessage::InvalidEnumConstruction(op) => { self.visit_operand(op, location); } AssertMessage::ResumedAfterReturn(_) diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index d66f98871b97d..95366d3bd9292 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1575,6 +1575,7 @@ symbols! { panic_implementation, panic_in_cleanup, panic_info, + panic_invalid_enum_construction, panic_location, panic_misaligned_pointer_dereference, panic_nounwind, diff --git a/library/core/src/panicking.rs b/library/core/src/panicking.rs index d87f4814f0218..812bc5e614572 100644 --- a/library/core/src/panicking.rs +++ b/library/core/src/panicking.rs @@ -314,6 +314,22 @@ fn panic_null_pointer_dereference() -> ! { ) } +#[cfg_attr(not(feature = "panic_immediate_abort"), inline(never), cold, optimize(size))] +#[cfg_attr(feature = "panic_immediate_abort", inline)] +#[track_caller] +#[lang = "panic_invalid_enum_construction"] // needed by codegen for panic on invalid enum construction. +#[rustc_nounwind] // `CheckEnums` MIR pass requires this function to never unwind +fn panic_invalid_enum_construction(source: u128) -> ! { + if cfg!(feature = "panic_immediate_abort") { + super::intrinsics::abort() + } + + panic_nounwind_fmt( + format_args!("trying to construct an enum from an invalid value {source:#x}"), + /* force_no_backtrace */ false, + ) +} + /// Panics because we cannot unwind out of a function. /// /// This is a separate function to avoid the codesize impact of each crate containing the string to diff --git a/src/tools/miri/src/lib.rs b/src/tools/miri/src/lib.rs index 51ec19af52a07..12a9a33175038 100644 --- a/src/tools/miri/src/lib.rs +++ b/src/tools/miri/src/lib.rs @@ -165,7 +165,7 @@ pub const MIRI_DEFAULT_ARGS: &[&str] = &[ "-Zmir-emit-retag", "-Zmir-preserve-ub", "-Zmir-opt-level=0", - "-Zmir-enable-passes=-CheckAlignment,-CheckNull", + "-Zmir-enable-passes=-CheckAlignment,-CheckNull,-CheckEnums", // Deduplicating diagnostics means we miss events when tracking what happens during an // execution. Let's not do that. "-Zdeduplicate-diagnostics=no", diff --git a/tests/ui/mir/enum/convert_non_enum_break.rs b/tests/ui/mir/enum/convert_non_enum_break.rs new file mode 100644 index 0000000000000..de062c39907a7 --- /dev/null +++ b/tests/ui/mir/enum/convert_non_enum_break.rs @@ -0,0 +1,20 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x10000 + +#[allow(dead_code)] +#[repr(u32)] +enum Foo { + A, + B, +} + +#[allow(dead_code)] +struct Bar { + a: u16, + b: u16, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::<_, Foo>(Bar { a: 0, b: 1 }) }; +} diff --git a/tests/ui/mir/enum/convert_non_enum_niche_break.rs b/tests/ui/mir/enum/convert_non_enum_niche_break.rs new file mode 100644 index 0000000000000..9ff4849c5b1f3 --- /dev/null +++ b/tests/ui/mir/enum/convert_non_enum_niche_break.rs @@ -0,0 +1,27 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x5 + +#[allow(dead_code)] +#[repr(u16)] +enum Mix { + A, + B(u16), +} + +#[allow(dead_code)] +enum Nested { + C(Mix), + D, + E, +} + +#[allow(dead_code)] +struct Bar { + a: u16, + b: u16, +} + +fn main() { + let _val: Nested = unsafe { std::mem::transmute::<_, Nested>(Bar { a: 5, b: 0 }) }; +} diff --git a/tests/ui/mir/enum/convert_non_enum_niche_ok.rs b/tests/ui/mir/enum/convert_non_enum_niche_ok.rs new file mode 100644 index 0000000000000..24027da54589a --- /dev/null +++ b/tests/ui/mir/enum/convert_non_enum_niche_ok.rs @@ -0,0 +1,29 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +enum Mix { + A, + B(u16), +} + +#[allow(dead_code)] +enum Nested { + C(Mix), + D, + E, +} + +#[allow(dead_code)] +struct Bar { + a: u16, + b: u16, +} + +fn main() { + let _val: Nested = unsafe { std::mem::transmute::<_, Nested>(Bar { a: 0, b: 0 }) }; + let _val: Nested = unsafe { std::mem::transmute::<_, Nested>(Bar { a: 1, b: 0 }) }; + let _val: Nested = unsafe { std::mem::transmute::<_, Nested>(Bar { a: 2, b: 0 }) }; + let _val: Nested = unsafe { std::mem::transmute::<_, Nested>(Bar { a: 3, b: 0 }) }; +} diff --git a/tests/ui/mir/enum/convert_non_enum_ok.rs b/tests/ui/mir/enum/convert_non_enum_ok.rs new file mode 100644 index 0000000000000..37fc64342ca93 --- /dev/null +++ b/tests/ui/mir/enum/convert_non_enum_ok.rs @@ -0,0 +1,20 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u32)] +enum Foo { + A, + B, +} + +#[allow(dead_code)] +struct Bar { + a: u16, + b: u16, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::<_, Foo>(Bar { a: 0, b: 0 }) }; + let _val: Foo = unsafe { std::mem::transmute::<_, Foo>(Bar { a: 1, b: 0 }) }; +} diff --git a/tests/ui/mir/enum/niche_option_tuple_break.rs b/tests/ui/mir/enum/niche_option_tuple_break.rs new file mode 100644 index 0000000000000..43eef3a4cc5f0 --- /dev/null +++ b/tests/ui/mir/enum/niche_option_tuple_break.rs @@ -0,0 +1,20 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x3 + +#[allow(dead_code)] +enum Foo { + A, + B, +} + +#[allow(dead_code)] +struct Bar { + a: usize, + b: usize, +} + +fn main() { + let _val: Option<(usize, Foo)> = + unsafe { std::mem::transmute::<_, Option<(usize, Foo)>>(Bar { a: 3, b: 3 }) }; +} diff --git a/tests/ui/mir/enum/niche_option_tuple_ok.rs b/tests/ui/mir/enum/niche_option_tuple_ok.rs new file mode 100644 index 0000000000000..71c885c7edbf8 --- /dev/null +++ b/tests/ui/mir/enum/niche_option_tuple_ok.rs @@ -0,0 +1,21 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +enum Foo { + A, + B, +} + +#[allow(dead_code)] +struct Bar { + a: usize, + b: usize, +} + +fn main() { + let _val: Option<(usize, Foo)> = + unsafe { std::mem::transmute::<_, Option<(usize, Foo)>>(Bar { a: 0, b: 0 }) }; + let _val: Option<(usize, Foo)> = + unsafe { std::mem::transmute::<_, Option<(usize, Foo)>>(Bar { a: 1, b: 0 }) }; +} diff --git a/tests/ui/mir/enum/numbered_variants_break.rs b/tests/ui/mir/enum/numbered_variants_break.rs new file mode 100644 index 0000000000000..e3e71dc8aec4f --- /dev/null +++ b/tests/ui/mir/enum/numbered_variants_break.rs @@ -0,0 +1,13 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x3 + +#[allow(dead_code)] +enum Foo { + A, + B, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::(3) }; +} diff --git a/tests/ui/mir/enum/numbered_variants_ok.rs b/tests/ui/mir/enum/numbered_variants_ok.rs new file mode 100644 index 0000000000000..995a2f6511b13 --- /dev/null +++ b/tests/ui/mir/enum/numbered_variants_ok.rs @@ -0,0 +1,13 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +enum Foo { + A, + B, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::(0) }; + let _val: Foo = unsafe { std::mem::transmute::(1) }; +} diff --git a/tests/ui/mir/enum/option_with_bigger_niche_break.rs b/tests/ui/mir/enum/option_with_bigger_niche_break.rs new file mode 100644 index 0000000000000..c66614b845b5d --- /dev/null +++ b/tests/ui/mir/enum/option_with_bigger_niche_break.rs @@ -0,0 +1,14 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x0 + +#[repr(u32)] +#[allow(dead_code)] +enum Foo { + A = 2, + B, +} + +fn main() { + let _val: Option = unsafe { std::mem::transmute::>(0) }; +} diff --git a/tests/ui/mir/enum/option_with_bigger_niche_ok.rs b/tests/ui/mir/enum/option_with_bigger_niche_ok.rs new file mode 100644 index 0000000000000..1d44ffd28fcee --- /dev/null +++ b/tests/ui/mir/enum/option_with_bigger_niche_ok.rs @@ -0,0 +1,14 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[repr(u32)] +#[allow(dead_code)] +enum Foo { + A = 2, + B, +} + +fn main() { + let _val: Option = unsafe { std::mem::transmute::>(2) }; + let _val: Option = unsafe { std::mem::transmute::>(3) }; +} diff --git a/tests/ui/mir/enum/plain_no_data_break.rs b/tests/ui/mir/enum/plain_no_data_break.rs new file mode 100644 index 0000000000000..db68e752479dd --- /dev/null +++ b/tests/ui/mir/enum/plain_no_data_break.rs @@ -0,0 +1,14 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x1 + +#[repr(u32)] +#[allow(dead_code)] +enum Foo { + A = 2, + B, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::(1) }; +} diff --git a/tests/ui/mir/enum/plain_no_data_ok.rs b/tests/ui/mir/enum/plain_no_data_ok.rs new file mode 100644 index 0000000000000..bbdc18f96dc41 --- /dev/null +++ b/tests/ui/mir/enum/plain_no_data_ok.rs @@ -0,0 +1,14 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[repr(u32)] +#[allow(dead_code)] +enum Foo { + A = 2, + B, +} + +fn main() { + let _val: Foo = unsafe { std::mem::transmute::(2) }; + let _val: Foo = unsafe { std::mem::transmute::(3) }; +} diff --git a/tests/ui/mir/enum/single_ok.rs b/tests/ui/mir/enum/single_ok.rs new file mode 100644 index 0000000000000..06b5a237c6874 --- /dev/null +++ b/tests/ui/mir/enum/single_ok.rs @@ -0,0 +1,11 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +enum Single { + A +} + +fn main() { + let _val: Single = unsafe { std::mem::transmute::<(), Single>(()) }; +} diff --git a/tests/ui/mir/enum/single_with_repr_break.rs b/tests/ui/mir/enum/single_with_repr_break.rs new file mode 100644 index 0000000000000..5a4ec85a9b555 --- /dev/null +++ b/tests/ui/mir/enum/single_with_repr_break.rs @@ -0,0 +1,13 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x1 + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A +} + +fn main() { + let _val: Single = unsafe { std::mem::transmute::(1) }; +} diff --git a/tests/ui/mir/enum/single_with_repr_ok.rs b/tests/ui/mir/enum/single_with_repr_ok.rs new file mode 100644 index 0000000000000..b0ed2bad6608b --- /dev/null +++ b/tests/ui/mir/enum/single_with_repr_ok.rs @@ -0,0 +1,12 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +enum Single { + A +} + +fn main() { + let _val: Single = unsafe { std::mem::transmute::(0) }; +} diff --git a/tests/ui/mir/enum/with_niche_int_break.rs b/tests/ui/mir/enum/with_niche_int_break.rs new file mode 100644 index 0000000000000..0ec60a33564af --- /dev/null +++ b/tests/ui/mir/enum/with_niche_int_break.rs @@ -0,0 +1,21 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x4 + +#[allow(dead_code)] +#[repr(u16)] +enum Mix { + A, + B(u16), +} + +#[allow(dead_code)] +enum Nested { + C(Mix), + D, + E, +} + +fn main() { + let _val: Nested = unsafe { std::mem::transmute::(4) }; +} diff --git a/tests/ui/mir/enum/with_niche_int_ok.rs b/tests/ui/mir/enum/with_niche_int_ok.rs new file mode 100644 index 0000000000000..9a3ff3a73beb9 --- /dev/null +++ b/tests/ui/mir/enum/with_niche_int_ok.rs @@ -0,0 +1,23 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +#[repr(u16)] +enum Mix { + A, + B(u16), +} + +#[allow(dead_code)] +enum Nested { + C(Mix), + D, + E, +} + +fn main() { + let _val: Nested = unsafe { std::mem::transmute::(0) }; + let _val: Nested = unsafe { std::mem::transmute::(1) }; + let _val: Nested = unsafe { std::mem::transmute::(2) }; + let _val: Nested = unsafe { std::mem::transmute::(3) }; +} diff --git a/tests/ui/mir/enum/with_niche_ptr_ok.rs b/tests/ui/mir/enum/with_niche_ptr_ok.rs new file mode 100644 index 0000000000000..969d955f7a4b9 --- /dev/null +++ b/tests/ui/mir/enum/with_niche_ptr_ok.rs @@ -0,0 +1,14 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +fn main() { + let _val = unsafe { + std::mem::transmute::<*const usize, Option>(std::ptr::null()) + }; + let _val = unsafe { + std::mem::transmute::<*const usize, Option>(usize::MAX as *const _) + }; + let _val = unsafe { std::mem::transmute::>(0) }; + let _val = unsafe { std::mem::transmute::>(1) }; + let _val = unsafe { std::mem::transmute::>(usize::MAX) }; +} diff --git a/tests/ui/mir/enum/wrap_break.rs b/tests/ui/mir/enum/wrap_break.rs new file mode 100644 index 0000000000000..4491394ca5a34 --- /dev/null +++ b/tests/ui/mir/enum/wrap_break.rs @@ -0,0 +1,14 @@ +//@ run-fail +//@ compile-flags: -C debug-assertions +//@ error-pattern: trying to construct an enum from an invalid value 0x0 +#![feature(never_type)] +#![allow(invalid_value)] + +#[allow(dead_code)] +enum Wrap { + A(!), +} + +fn main() { + let _val: Wrap = unsafe { std::mem::transmute::<(), Wrap>(()) }; +} diff --git a/tests/ui/mir/enum/wrap_ok.rs b/tests/ui/mir/enum/wrap_ok.rs new file mode 100644 index 0000000000000..2881675c9ce6f --- /dev/null +++ b/tests/ui/mir/enum/wrap_ok.rs @@ -0,0 +1,12 @@ +//@ run-pass +//@ compile-flags: -C debug-assertions + +#[allow(dead_code)] +enum Wrap { + A(u32), +} + +fn main() { + let _val: Wrap = unsafe { std::mem::transmute::(2) }; + let _val: Wrap = unsafe { std::mem::transmute::(u32::MAX) }; +}