diff --git a/compiler/rustc_mir_transform/src/lib.rs b/compiler/rustc_mir_transform/src/lib.rs index bc2c6bd81aca9..5860101771a69 100644 --- a/compiler/rustc_mir_transform/src/lib.rs +++ b/compiler/rustc_mir_transform/src/lib.rs @@ -6,6 +6,7 @@ #![feature(file_buffered)] #![feature(if_let_guard)] #![feature(impl_trait_in_assoc_type)] +#![feature(iterator_try_collect)] #![feature(try_blocks)] #![feature(yeet_expr)] // tidy-alphabetical-end diff --git a/compiler/rustc_mir_transform/src/match_branches.rs b/compiler/rustc_mir_transform/src/match_branches.rs index 5e511f1a418b6..05d085fafe937 100644 --- a/compiler/rustc_mir_transform/src/match_branches.rs +++ b/compiler/rustc_mir_transform/src/match_branches.rs @@ -1,50 +1,34 @@ -use std::iter; - use rustc_abi::Integer; -use rustc_index::IndexSlice; +use rustc_const_eval::const_eval::mk_eval_cx_for_const_val; use rustc_middle::mir::*; use rustc_middle::ty::layout::{IntegerExt, TyAndLayout}; +use rustc_middle::ty::util::Discr; use rustc_middle::ty::{self, ScalarInt, Ty, TyCtxt}; -use tracing::instrument; use super::simplify::simplify_cfg; use crate::patch::MirPatch; +use crate::unreachable_prop::remove_successors_from_switch; +/// Unifies all targets into one basic block if each statement can have the same statement. pub(super) struct MatchBranchSimplification; impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { fn is_enabled(&self, sess: &rustc_session::Session) -> bool { - sess.mir_opt_level() >= 1 + // Enable only under -Zmir-opt-level=2 as this can make programs less debuggable. + sess.mir_opt_level() >= 2 } fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { let typing_env = body.typing_env(tcx); - let mut apply_patch = false; - let mut patch = MirPatch::new(body); - for (bb, bb_data) in body.basic_blocks.iter_enumerated() { - match &bb_data.terminator().kind { - TerminatorKind::SwitchInt { - discr: Operand::Copy(_) | Operand::Move(_), - targets, - .. - // We require that the possible target blocks don't contain this block. - } if !targets.all_targets().contains(&bb) => {} - // Only optimize switch int statements - _ => continue, - }; - - if SimplifyToIf.simplify(tcx, body, &mut patch, bb, typing_env).is_some() { - apply_patch = true; - continue; - } - if SimplifyToExp::default().simplify(tcx, body, &mut patch, bb, typing_env).is_some() { - apply_patch = true; + let mut changed = false; + for bb in body.basic_blocks.indices() { + if !candidate_match(body, bb) { continue; - } + }; + changed |= simplify_match(tcx, typing_env, body, bb) } - if apply_patch { - patch.apply(body); + if changed { simplify_cfg(tcx, body); } } @@ -54,222 +38,406 @@ impl<'tcx> crate::MirPass<'tcx> for MatchBranchSimplification { } } -trait SimplifyMatch<'tcx> { - /// Simplifies a match statement, returning `Some` if the simplification succeeds, `None` - /// otherwise. Generic code is written here, and we generally don't need a custom - /// implementation. - fn simplify( - &mut self, - tcx: TyCtxt<'tcx>, - body: &Body<'tcx>, - patch: &mut MirPatch<'tcx>, - switch_bb_idx: BasicBlock, - typing_env: ty::TypingEnv<'tcx>, - ) -> Option<()> { - let bbs = &body.basic_blocks; - let TerminatorKind::SwitchInt { discr, targets, .. } = - &bbs[switch_bb_idx].terminator().kind - else { - unreachable!(); - }; - - let discr_ty = discr.ty(body.local_decls(), tcx); - self.can_simplify(tcx, targets, typing_env, bbs, discr_ty)?; +struct SimplifyMatch<'tcx, 'a> { + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + patch: MirPatch<'tcx>, + body: &'a Body<'tcx>, + switch_bb: BasicBlock, + discr: &'a Operand<'tcx>, + discr_local: Option, + discr_ty: Ty<'tcx>, +} - // Take ownership of items now that we know we can optimize. - let discr = discr.clone(); +impl<'tcx, 'a> SimplifyMatch<'tcx, 'a> { + fn discr_local(&mut self) -> Local { + *self.discr_local.get_or_insert_with(|| { + // Introduce a temporary for the discriminant value. + let source_info = self.body.basic_blocks[self.switch_bb].terminator().source_info; + self.patch.new_temp(self.discr_ty, source_info.span) + }) + } - // Introduce a temporary for the discriminant value. - let source_info = bbs[switch_bb_idx].terminator().source_info; - let discr_local = patch.new_temp(discr_ty, source_info.span); + /// Unifies the assignments if all rvalues are constants and equal. + fn unify_if_equal_const( + &self, + dest: Place<'tcx>, + consts: &[(u128, &ConstOperand<'tcx>)], + otherwise: Option<&ConstOperand<'tcx>>, + ) -> Option> { + let (_, first_const, mut others) = split_first_case(consts, otherwise); + let first_scalar_int = first_const.const_.try_eval_scalar_int(self.tcx, self.typing_env)?; + if others.all(|const_| { + const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) == Some(first_scalar_int) + }) { + Some(StatementKind::Assign(Box::new(( + dest, + Rvalue::Use(Operand::Constant(Box::new(first_const.clone()))), + )))) + } else { + None + } + } - let (_, first) = targets.iter().next().unwrap(); - let statement_index = bbs[switch_bb_idx].statements.len(); - let parent_end = Location { block: switch_bb_idx, statement_index }; - patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); - patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); - self.new_stmts(tcx, targets, typing_env, patch, parent_end, bbs, discr_local, discr_ty); - patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); - patch.patch_terminator(switch_bb_idx, bbs[first].terminator().kind.clone()); - Some(()) + /// If a source block is found that switches between two blocks that are exactly + /// the same modulo const bool assignments (e.g., one assigns true another false + /// to the same place), unify a target block statements into the source block, + /// using Eq / Ne comparison with switch value where const bools value differ. + /// + /// For example: + /// + /// ```ignore (MIR) + /// bb0: { + /// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; + /// } + /// + /// bb1: { + /// _2 = const true; + /// goto -> bb3; + /// } + /// + /// bb2: { + /// _2 = const false; + /// goto -> bb3; + /// } + /// ``` + /// + /// into: + /// + /// ```ignore (MIR) + /// bb0: { + /// _2 = Eq(move _3, const 42_isize); + /// goto -> bb3; + /// } + /// ``` + fn unify_by_eq_op( + &mut self, + dest: Place<'tcx>, + consts: &[(u128, &ConstOperand<'tcx>)], + otherwise: Option<&ConstOperand<'tcx>>, + ) -> Option> { + // FIXME: extend to any case. + let (first_case, first_const, mut others) = split_first_case(consts, otherwise); + if !first_const.ty().is_bool() { + return None; + } + let first_bool = first_const.const_.try_eval_bool(self.tcx, self.typing_env)?; + if others.all(|const_| { + const_.const_.try_eval_bool(self.tcx, self.typing_env) == Some(!first_bool) + }) { + // Make value conditional on switch condition. + let size = + self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap().size; + let const_cmp = Operand::const_from_scalar( + self.tcx, + self.discr_ty, + rustc_const_eval::interpret::Scalar::from_uint(first_case, size), + rustc_span::DUMMY_SP, + ); + let op = if first_bool { BinOp::Eq } else { BinOp::Ne }; + let rval = Rvalue::BinaryOp( + op, + Box::new((Operand::Copy(Place::from(self.discr_local())), const_cmp)), + ); + Some(StatementKind::Assign(Box::new((dest, rval)))) + } else { + None + } } - /// Check that the BBs to be simplified satisfies all distinct and - /// that the terminator are the same. - /// There are also conditions for different ways of simplification. - fn can_simplify( + /// Unifies the assignments if all rvalues can be cast from the discriminant value by IntToInt. + /// + /// For example: + /// + /// ```ignore (MIR) + /// bb0: { + /// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; + /// } + /// + /// bb1: { + /// unreachable; + /// } + /// + /// bb2: { + /// _0 = const 1_i16; + /// goto -> bb5; + /// } + /// + /// bb3: { + /// _0 = const 2_i16; + /// goto -> bb5; + /// } + /// + /// bb4: { + /// _0 = const 3_i16; + /// goto -> bb5; + /// } + /// ``` + /// + /// into: + /// + /// ```ignore (MIR) + /// bb0: { + /// _0 = _1 as i16 (IntToInt); + /// goto -> bb5; + /// } + /// ``` + fn unify_by_int_to_int( &mut self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - bbs: &IndexSlice>, - discr_ty: Ty<'tcx>, - ) -> Option<()>; + dest: Place<'tcx>, + consts: &[(u128, &ConstOperand<'tcx>)], + ) -> Option> { + let (_, first_const) = consts[0]; + if !first_const.ty().is_integral() { + return None; + } + let discr_layout = + self.tcx.layout_of(self.typing_env.as_query_input(self.discr_ty)).unwrap(); + if consts.iter().all(|&(case, const_)| { + let Some(scalar_int) = const_.const_.try_eval_scalar_int(self.tcx, self.typing_env) + else { + return false; + }; + can_cast(self.tcx, case, discr_layout, const_.ty(), scalar_int) + }) { + let operand = Operand::Copy(Place::from(self.discr_local())); + let rval = if first_const.ty() == self.discr_ty { + Rvalue::Use(operand) + } else { + Rvalue::Cast(CastKind::IntToInt, operand, first_const.ty()) + }; + Some(StatementKind::Assign(Box::new((dest, rval)))) + } else { + None + } + } - fn new_stmts( + /// This is primarily used to unify these copy statements that simplified the canonical enum clone method by GVN. + /// The GVN simplified + /// ```ignore (syntax-highlighting-only) + /// match a { + /// Foo::A(x) => Foo::A(*x), + /// Foo::B => Foo::B + /// } + /// ``` + /// to + /// ```ignore (syntax-highlighting-only) + /// match a { + /// Foo::A(_x) => a, // copy a + /// Foo::B => Foo::B + /// } + /// ``` + /// This will simplify into a copy statement. + fn unify_by_copy( &self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - patch: &mut MirPatch<'tcx>, - parent_end: Location, - bbs: &IndexSlice>, - discr_local: Local, - discr_ty: Ty<'tcx>, - ); -} - -struct SimplifyToIf; - -/// If a source block is found that switches between two blocks that are exactly -/// the same modulo const bool assignments (e.g., one assigns true another false -/// to the same place), merge a target block statements into the source block, -/// using Eq / Ne comparison with switch value where const bools value differ. -/// -/// For example: -/// -/// ```ignore (MIR) -/// bb0: { -/// switchInt(move _3) -> [42_isize: bb1, otherwise: bb2]; -/// } -/// -/// bb1: { -/// _2 = const true; -/// goto -> bb3; -/// } -/// -/// bb2: { -/// _2 = const false; -/// goto -> bb3; -/// } -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// bb0: { -/// _2 = Eq(move _3, const 42_isize); -/// goto -> bb3; -/// } -/// ``` -impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf { - #[instrument(level = "debug", skip(self, tcx), ret)] - fn can_simplify( - &mut self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - bbs: &IndexSlice>, - _discr_ty: Ty<'tcx>, - ) -> Option<()> { - let (first, second) = match targets.all_targets() { - &[first, otherwise] => (first, otherwise), - &[first, second, otherwise] if bbs[otherwise].is_empty_unreachable() => (first, second), - _ => { - return None; - } + dest: Place<'tcx>, + rvals: &[(u128, &Rvalue<'tcx>)], + ) -> Option> { + let bbs = &self.body.basic_blocks; + // Check if the copy source matches the following pattern. + // _2 = discriminant(*_1); // "*_1" is the expected the copy source. + // switchInt(move _2) -> [0: bb3, 1: bb2, otherwise: bb1]; + let &Statement { + kind: StatementKind::Assign(box (discr_place, Rvalue::Discriminant(copy_src_place))), + .. + } = bbs[self.switch_bb].statements.last()? + else { + return None; }; - - // We require that the possible target blocks all be distinct. - if first == second { + if self.discr.place() != Some(discr_place) { return None; } - // Check that destinations are identical, and if not, then don't optimize this block - if bbs[first].terminator().kind != bbs[second].terminator().kind { + let src_ty = copy_src_place.ty(self.body.local_decls(), self.tcx); + if !src_ty.ty.is_enum() || src_ty.variant_index.is_some() { return None; } - - // Check that blocks are assignments of consts to the same place or same statement, - // and match up 1-1, if not don't optimize this block. - let first_stmts = &bbs[first].statements; - let second_stmts = &bbs[second].statements; - if first_stmts.len() != second_stmts.len() { + let dest_ty = dest.ty(self.body.local_decls(), self.tcx); + if dest_ty.ty != src_ty.ty || dest_ty.variant_index.is_some() { return None; } - for (f, s) in iter::zip(first_stmts, second_stmts) { - match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => {} - - // If two statements are const bool assignments to the same place, we can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty().is_bool() - && s_c.const_.ty().is_bool() - && f_c.const_.try_eval_bool(tcx, typing_env).is_some() - && s_c.const_.try_eval_bool(tcx, typing_env).is_some() => {} + let ty::Adt(def, _) = dest_ty.ty.kind() else { + return None; + }; - // Otherwise we cannot optimize. Try another block. + for &(case, rvalue) in rvals.iter() { + match rvalue { + // Check if `_3 = const Foo::B` can be transformed to `_3 = copy *_1`. + Rvalue::Use(Operand::Constant(box constant)) + if let Const::Val(const_, ty) = constant.const_ => + { + let (ecx, op) = mk_eval_cx_for_const_val( + self.tcx.at(constant.span), + self.typing_env, + const_, + ty, + )?; + let variant = ecx.read_discriminant(&op).discard_err()?; + if !def.variants()[variant].fields.is_empty() { + return None; + } + let Discr { val, .. } = ty.discriminant_for_variant(self.tcx, variant)?; + if val != case { + return None; + } + } + Rvalue::Use(Operand::Copy(src_place)) if *src_place == copy_src_place => {} + // Check if `_3 = Foo::B` can be transformed to `_3 = copy *_1`. + Rvalue::Aggregate(box AggregateKind::Adt(_, variant_index, _, _, None), fields) + if fields.is_empty() + && let Some(Discr { val, .. }) = + src_ty.ty.discriminant_for_variant(self.tcx, *variant_index) + && val == case => {} _ => return None, } } - Some(()) + Some(StatementKind::Assign(Box::new((dest, Rvalue::Use(Operand::Copy(copy_src_place)))))) } - fn new_stmts( - &self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - patch: &mut MirPatch<'tcx>, - parent_end: Location, - bbs: &IndexSlice>, - discr_local: Local, - discr_ty: Ty<'tcx>, - ) { - let ((val, first), second) = match (targets.all_targets(), targets.all_values()) { - (&[first, otherwise], &[val]) => ((val, first), otherwise), - (&[first, second, otherwise], &[val, _]) if bbs[otherwise].is_empty_unreachable() => { - ((val, first), second) + /// Returns a new statement if we can use the statement replace all statements. + fn try_unify_stmts( + &mut self, + index: usize, + stmts: &[(u128, &StatementKind<'tcx>)], + otherwise: Option<&StatementKind<'tcx>>, + ) -> Option> { + if let Some(new_stmt) = identical_stmts(stmts, otherwise) { + return Some(new_stmt); + } + + let (dest, rvals, otherwise) = candidate_assign(stmts, otherwise)?; + if let Some((consts, otherwise)) = candidate_const(&rvals, otherwise) { + if let Some(new_stmt) = self.unify_if_equal_const(dest, &consts, otherwise) { + return Some(new_stmt); } - _ => unreachable!(), - }; + if let Some(new_stmt) = self.unify_by_eq_op(dest, &consts, otherwise) { + return Some(new_stmt); + } + // Requires the otherwise is unreachable. + if otherwise.is_none() + && let Some(new_stmt) = self.unify_by_int_to_int(dest, &consts) + { + return Some(new_stmt); + } + } - // We already checked that first and second are different blocks, - // and bb_idx has a different terminator from both of them. - let first = &bbs[first]; - let second = &bbs[second]; - for (f, s) in iter::zip(&first.statements, &second.statements) { - match (&f.kind, &s.kind) { - (f_s, s_s) if f_s == s_s => { - patch.add_statement(parent_end, f.kind.clone()); - } + // We only know the first statement is safe to introduce new dereferences. + if index == 0 + // We cannot create overlapping assignments. + && dest.is_stable_offset() + // Requires the otherwise is unreachable. + && otherwise.is_none() + && let Some(new_stmt) = self.unify_by_copy(dest, &rvals) + { + return Some(new_stmt); + } + None + } +} - ( - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(s_c)))), - ) => { - // From earlier loop we know that we are dealing with bool constants only: - let f_b = f_c.const_.try_eval_bool(tcx, typing_env).unwrap(); - let s_b = s_c.const_.try_eval_bool(tcx, typing_env).unwrap(); - if f_b == s_b { - // Same value in both blocks. Use statement as is. - patch.add_statement(parent_end, f.kind.clone()); - } else { - // Different value between blocks. Make value conditional on switch - // condition. - let size = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap().size; - let const_cmp = Operand::const_from_scalar( - tcx, - discr_ty, - rustc_const_eval::interpret::Scalar::from_uint(val, size), - rustc_span::DUMMY_SP, - ); - let op = if f_b { BinOp::Eq } else { BinOp::Ne }; - let rhs = Rvalue::BinaryOp( - op, - Box::new((Operand::Copy(Place::from(discr_local)), const_cmp)), - ); - patch.add_assign(parent_end, *lhs, rhs); - } - } +/// Returns the first case target if all targets have an equal number of statements and identical destination. +fn candidate_match<'tcx>(body: &Body<'tcx>, switch_bb: BasicBlock) -> bool { + use itertools::Itertools; + let targets = match &body.basic_blocks[switch_bb].terminator().kind { + TerminatorKind::SwitchInt { + discr: Operand::Copy(_) | Operand::Move(_), targets, .. + } => targets, + // Only optimize switch int statements + _ => return false, + }; + // We require that the possible target blocks don't contain this block. + if targets.all_targets().contains(&switch_bb) { + return false; + } + // We require that the possible target blocks all be distinct. + if !targets.is_distinct() { + return false; + } + // Check that destinations are identical, and if not, then don't optimize this block + targets + .all_targets() + .iter() + .map(|&bb| &body.basic_blocks[bb]) + .filter(|bb| !bb.is_empty_unreachable()) + .map(|bb| (bb.statements.len(), &bb.terminator().kind)) + .all_equal() +} - _ => unreachable!(), - } +fn simplify_match<'tcx>( + tcx: TyCtxt<'tcx>, + typing_env: ty::TypingEnv<'tcx>, + body: &mut Body<'tcx>, + switch_bb: BasicBlock, +) -> bool { + let (discr, targets) = match &body.basic_blocks[switch_bb].terminator().kind { + TerminatorKind::SwitchInt { discr, targets, .. } => (discr, targets), + _ => unreachable!(), + }; + let mut simplify_match = SimplifyMatch { + tcx, + typing_env, + patch: MirPatch::new(body), + body, + switch_bb, + discr, + discr_local: None, + discr_ty: discr.ty(body.local_decls(), tcx), + }; + let reachable_cases: Vec<_> = + targets.iter().filter(|&(_, bb)| !body.basic_blocks[bb].is_empty_unreachable()).collect(); + let mut new_stmts = Vec::new(); + let otherwise = if body.basic_blocks[targets.otherwise()].is_empty_unreachable() { + None + } else { + Some(targets.otherwise()) + }; + // We can patch the terminator to goto because there is a single target. + match (reachable_cases.len(), otherwise.is_none()) { + (1, true) | (0, false) => { + let mut patch = simplify_match.patch; + remove_successors_from_switch(tcx, switch_bb, body, &mut patch, |bb| { + body.basic_blocks[bb].is_empty_unreachable() + }); + patch.apply(body); + return true; } + _ => {} } + let Some(&(_, first_case_bb)) = reachable_cases.first() else { + return false; + }; + let stmt_len = body.basic_blocks[first_case_bb].statements.len(); + let mut cases = Vec::with_capacity(stmt_len); + // Check at each position in the basic blocks whether these statements can be unified. + for index in 0..stmt_len { + cases.clear(); + let otherwise = otherwise.map(|bb| &body.basic_blocks[bb].statements[index].kind); + for &(case, bb) in &reachable_cases { + cases.push((case, &body.basic_blocks[bb].statements[index].kind)); + } + let Some(new_stmt) = simplify_match.try_unify_stmts(index, &cases, otherwise) else { + return false; + }; + new_stmts.push(new_stmt); + } + // Take ownership of items now that we know we can optimize. + let discr = discr.clone(); + + let statement_index = body.basic_blocks[switch_bb].statements.len(); + let parent_end = Location { block: switch_bb, statement_index }; + let mut patch = simplify_match.patch; + if let Some(discr_local) = simplify_match.discr_local { + patch.add_statement(parent_end, StatementKind::StorageLive(discr_local)); + patch.add_assign(parent_end, Place::from(discr_local), Rvalue::Use(discr)); + } + for new_stmt in new_stmts { + patch.add_statement(parent_end, new_stmt); + } + if let Some(discr_local) = simplify_match.discr_local { + patch.add_statement(parent_end, StatementKind::StorageDead(discr_local)); + } + patch.patch_terminator(switch_bb, body.basic_blocks[first_case_bb].terminator().kind.clone()); + patch.apply(body); + true } /// Check if the cast constant using `IntToInt` is equal to the target constant. @@ -298,234 +466,77 @@ fn can_cast( cast_scalar == target_scalar } -#[derive(Default)] -struct SimplifyToExp { - transform_kinds: Vec, -} - -#[derive(Clone, Copy, Debug)] -enum ExpectedTransformKind<'a, 'tcx> { - /// Identical statements. - Same(&'a StatementKind<'tcx>), - /// Assignment statements have the same value. - SameByEq { place: &'a Place<'tcx>, ty: Ty<'tcx>, scalar: ScalarInt }, - /// Enum variant comparison type. - Cast { place: &'a Place<'tcx>, ty: Ty<'tcx> }, -} - -enum TransformKind { - Same, - Cast, -} - -impl From> for TransformKind { - fn from(compare_type: ExpectedTransformKind<'_, '_>) -> Self { - match compare_type { - ExpectedTransformKind::Same(_) => TransformKind::Same, - ExpectedTransformKind::SameByEq { .. } => TransformKind::Same, - ExpectedTransformKind::Cast { .. } => TransformKind::Cast, - } - } -} - -/// If we find that the value of match is the same as the assignment, -/// merge a target block statements into the source block, -/// using cast to transform different integer types. -/// -/// For example: -/// -/// ```ignore (MIR) -/// bb0: { -/// switchInt(_1) -> [1: bb2, 2: bb3, 3: bb4, otherwise: bb1]; -/// } -/// -/// bb1: { -/// unreachable; -/// } -/// -/// bb2: { -/// _0 = const 1_i16; -/// goto -> bb5; -/// } -/// -/// bb3: { -/// _0 = const 2_i16; -/// goto -> bb5; -/// } -/// -/// bb4: { -/// _0 = const 3_i16; -/// goto -> bb5; -/// } -/// ``` -/// -/// into: -/// -/// ```ignore (MIR) -/// bb0: { -/// _0 = _3 as i16 (IntToInt); -/// goto -> bb5; -/// } -/// ``` -impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp { - #[instrument(level = "debug", skip(self, tcx), ret)] - fn can_simplify( - &mut self, - tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - typing_env: ty::TypingEnv<'tcx>, - bbs: &IndexSlice>, - discr_ty: Ty<'tcx>, - ) -> Option<()> { - if targets.iter().len() < 2 || targets.iter().len() > 64 { - return None; - } - // We require that the possible target blocks all be distinct. - if !targets.is_distinct() { - return None; - } - if !bbs[targets.otherwise()].is_empty_unreachable() { +fn candidate_assign<'tcx, 'a>( + stmts: &'a [(u128, &'a StatementKind<'tcx>)], + otherwise: Option<&'a StatementKind<'tcx>>, +) -> Option<(Place<'tcx>, Vec<(u128, &'a Rvalue<'tcx>)>, Option<&'a Rvalue<'tcx>>)> { + let (_, first_stmt) = stmts[0]; + let (dest, _) = first_stmt.as_assign()?; + let otherwise = if let Some(otherwise) = otherwise { + let Some((otherwise_dest, rval)) = otherwise.as_assign() else { return None; - } - let mut target_iter = targets.iter(); - let (first_case_val, first_target) = target_iter.next().unwrap(); - let first_terminator_kind = &bbs[first_target].terminator().kind; - // Check that destinations are identical, and if not, then don't optimize this block - if !targets - .iter() - .all(|(_, other_target)| first_terminator_kind == &bbs[other_target].terminator().kind) - { - return None; - } - - let discr_layout = tcx.layout_of(typing_env.as_query_input(discr_ty)).unwrap(); - let first_stmts = &bbs[first_target].statements; - let (second_case_val, second_target) = target_iter.next().unwrap(); - let second_stmts = &bbs[second_target].statements; - if first_stmts.len() != second_stmts.len() { + }; + if otherwise_dest != dest { return None; } - - // We first compare the two branches, and then the other branches need to fulfill the same - // conditions. - let mut expected_transform_kinds = Vec::new(); - for (f, s) in iter::zip(first_stmts, second_stmts) { - let compare_type = match (&f.kind, &s.kind) { - // If two statements are exactly the same, we can optimize. - (f_s, s_s) if f_s == s_s => ExpectedTransformKind::Same(f_s), - - // If two statements are assignments with the match values to the same place, we - // can optimize. - ( - StatementKind::Assign(box (lhs_f, Rvalue::Use(Operand::Constant(f_c)))), - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && f_c.const_.ty() == s_c.const_.ty() - && f_c.const_.ty().is_integral() => - { - match ( - f_c.const_.try_eval_scalar_int(tcx, typing_env), - s_c.const_.try_eval_scalar_int(tcx, typing_env), - ) { - (Some(f), Some(s)) if f == s => ExpectedTransformKind::SameByEq { - place: lhs_f, - ty: f_c.const_.ty(), - scalar: f, - }, - // Enum variants can also be simplified to an assignment statement, - // if we can use `IntToInt` cast to get an equal value. - (Some(f), Some(s)) - if (can_cast( - tcx, - first_case_val, - discr_layout, - f_c.const_.ty(), - f, - ) && can_cast( - tcx, - second_case_val, - discr_layout, - f_c.const_.ty(), - s, - )) => - { - ExpectedTransformKind::Cast { place: lhs_f, ty: f_c.const_.ty() } - } - _ => { - return None; - } - } - } - - // Otherwise we cannot optimize. Try another block. - _ => return None, - }; - expected_transform_kinds.push(compare_type); - } - - // All remaining BBs need to fulfill the same pattern as the two BBs from the previous step. - for (other_val, other_target) in target_iter { - let other_stmts = &bbs[other_target].statements; - if expected_transform_kinds.len() != other_stmts.len() { + Some(rval) + } else { + None + }; + let rvals = stmts + .into_iter() + .map(|&(case, stmt)| { + let (other_dest, rval) = stmt.as_assign()?; + if other_dest != dest { return None; } - for (f, s) in iter::zip(&expected_transform_kinds, other_stmts) { - match (*f, &s.kind) { - (ExpectedTransformKind::Same(f_s), s_s) if f_s == s_s => {} - ( - ExpectedTransformKind::SameByEq { place: lhs_f, ty: f_ty, scalar }, - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if lhs_f == lhs_s - && s_c.const_.ty() == f_ty - && s_c.const_.try_eval_scalar_int(tcx, typing_env) == Some(scalar) => {} - ( - ExpectedTransformKind::Cast { place: lhs_f, ty: f_ty }, - StatementKind::Assign(box (lhs_s, Rvalue::Use(Operand::Constant(s_c)))), - ) if let Some(f) = s_c.const_.try_eval_scalar_int(tcx, typing_env) - && lhs_f == lhs_s - && s_c.const_.ty() == f_ty - && can_cast(tcx, other_val, discr_layout, f_ty, f) => {} - _ => return None, - } - } - } - self.transform_kinds = expected_transform_kinds.into_iter().map(|c| c.into()).collect(); - Some(()) - } + Some((case, rval)) + }) + .try_collect()?; + Some((*dest, rvals, otherwise)) +} - fn new_stmts( - &self, - _tcx: TyCtxt<'tcx>, - targets: &SwitchTargets, - _typing_env: ty::TypingEnv<'tcx>, - patch: &mut MirPatch<'tcx>, - parent_end: Location, - bbs: &IndexSlice>, - discr_local: Local, - discr_ty: Ty<'tcx>, - ) { - let (_, first) = targets.iter().next().unwrap(); - let first = &bbs[first]; +// Returns all ConstOperands if all Rvalues are ConstOperands. +fn candidate_const<'tcx, 'a>( + rvals: &'a [(u128, &'a Rvalue<'tcx>)], + otherwise: Option<&'a Rvalue<'tcx>>, +) -> Option<(Vec<(u128, &'a ConstOperand<'tcx>)>, Option<&'a ConstOperand<'tcx>>)> { + let otherwise = if let Some(otherwise) = otherwise { + let Rvalue::Use(Operand::Constant(box const_)) = otherwise else { + return None; + }; + Some(const_) + } else { + None + }; + let consts = rvals + .into_iter() + .map(|&(case, rval)| { + let Rvalue::Use(Operand::Constant(box const_)) = rval else { return None }; + Some((case, const_)) + }) + .try_collect()?; + Some((consts, otherwise)) +} - for (t, s) in iter::zip(&self.transform_kinds, &first.statements) { - match (t, &s.kind) { - (TransformKind::Same, _) => { - patch.add_statement(parent_end, s.kind.clone()); - } - ( - TransformKind::Cast, - StatementKind::Assign(box (lhs, Rvalue::Use(Operand::Constant(f_c)))), - ) => { - let operand = Operand::Copy(Place::from(discr_local)); - let r_val = if f_c.const_.ty() == discr_ty { - Rvalue::Use(operand) - } else { - Rvalue::Cast(CastKind::IntToInt, operand, f_c.const_.ty()) - }; - patch.add_assign(parent_end, *lhs, r_val); - } - _ => unreachable!(), - } - } +// Returns the first case and others (including otherwise if present). +fn split_first_case<'a, T>( + stmts: &'a [(u128, &'a T)], + otherwise: Option<&'a T>, +) -> (u128, &'a T, impl Iterator) { + let (first_case, first) = stmts[0]; + (first_case, first, stmts[1..].into_iter().map(|&(_, val)| val).chain(otherwise)) +} + +// If all statements are identical, we can optimize. +fn identical_stmts<'tcx>( + stmts: &[(u128, &StatementKind<'tcx>)], + otherwise: Option<&StatementKind<'tcx>>, +) -> Option> { + use itertools::Itertools; + let (_, first_stmt, others) = split_first_case(stmts, otherwise); + if std::iter::once(first_stmt).chain(others).all_equal() { + return Some(first_stmt.clone()); } + None } diff --git a/compiler/rustc_mir_transform/src/unreachable_prop.rs b/compiler/rustc_mir_transform/src/unreachable_prop.rs index c417a9272f2a9..ddc33eafc9138 100644 --- a/compiler/rustc_mir_transform/src/unreachable_prop.rs +++ b/compiler/rustc_mir_transform/src/unreachable_prop.rs @@ -35,7 +35,9 @@ impl crate::MirPass<'_> for UnreachablePropagation { } // Try to remove unreachable targets from the switch. TerminatorKind::SwitchInt { .. } => { - remove_successors_from_switch(tcx, bb, &unreachable_blocks, body, &mut patch) + remove_successors_from_switch(tcx, bb, body, &mut patch, |bb| { + unreachable_blocks.contains(&bb) + }) } _ => false, }; @@ -60,20 +62,18 @@ impl crate::MirPass<'_> for UnreachablePropagation { } /// Return whether the current terminator is fully unreachable. -fn remove_successors_from_switch<'tcx>( +pub(crate) fn remove_successors_from_switch<'tcx>( tcx: TyCtxt<'tcx>, bb: BasicBlock, - unreachable_blocks: &FxHashSet, body: &Body<'tcx>, patch: &mut MirPatch<'tcx>, + is_unreachable_block: impl Fn(BasicBlock) -> bool, ) -> bool { let terminator = body.basic_blocks[bb].terminator(); let TerminatorKind::SwitchInt { discr, targets } = &terminator.kind else { bug!() }; let source_info = terminator.source_info; let location = body.terminator_loc(bb); - let is_unreachable = |bb| unreachable_blocks.contains(&bb); - // If there are multiple targets, we want to keep information about reachability for codegen. // For example (see tests/codegen-llvm/match-optimizes-away.rs) // @@ -116,10 +116,10 @@ fn remove_successors_from_switch<'tcx>( }; let otherwise = targets.otherwise(); - let otherwise_unreachable = is_unreachable(otherwise); + let otherwise_unreachable = is_unreachable_block(otherwise); let reachable_iter = targets.iter().filter(|&(value, bb)| { - let is_unreachable = is_unreachable(bb); + let is_unreachable = is_unreachable_block(bb); // We remove this target from the switch, so record the inequality using `Assume`. if is_unreachable && !otherwise_unreachable { add_assumption(BinOp::Ne, value); diff --git a/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs b/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs index b8b9ea7436f33..5834255f3d313 100644 --- a/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs +++ b/tests/codegen-llvm/issues/issue-107681-unwrap_unchecked.rs @@ -14,6 +14,7 @@ pub unsafe fn foo(x: &mut Copied>) -> u32 { // CHECK-NOT: br {{.*}} // CHECK-NOT: select // CHECK: [[RET:%.*]] = load i32, ptr + // CHECK-NEXT: assume // CHECK-NEXT: ret i32 [[RET]] x.next().unwrap_unchecked() } diff --git a/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs b/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs index a0b453fac8e93..5b100d2cdc381 100644 --- a/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs +++ b/tests/codegen-llvm/issues/issue-122600-ptr-discriminant-update.rs @@ -26,7 +26,7 @@ pub unsafe fn update(s: *mut State) { // CHECK-NOT: 75{{3|4}} // old: %[[TAG:.+]] = load i8, ptr %s, align 1 - // old-NEXT: trunc nuw i8 %[[TAG]] to i1 + // old-NEXT: and i8 %[[TAG]], 1 // CHECK-NOT: load // CHECK-NOT: store diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir index b7b892c177c3e..e0fcd5c92247c 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-abort.mir @@ -3,6 +3,7 @@ fn unwrap_unchecked(_1: Option) -> T { debug slf => _1; let mut _0: T; + let mut _3: bool; scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { let mut _2: isize; scope 2 { @@ -18,16 +19,10 @@ fn unwrap_unchecked(_1: Option) -> T { bb0: { StorageLive(_2); _2 = discriminant(_1); - switchInt(move _2) -> [0: bb2, 1: bb1, otherwise: bb2]; - } - - bb1: { + _3 = Eq(copy _2, const 1_isize); + assume(move _3); _0 = copy ((_1 as Some).0: T); StorageDead(_2); return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir index b7b892c177c3e..e0fcd5c92247c 100644 --- a/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir +++ b/tests/mir-opt/inline/unwrap_unchecked.unwrap_unchecked.PreCodegen.after.panic-unwind.mir @@ -3,6 +3,7 @@ fn unwrap_unchecked(_1: Option) -> T { debug slf => _1; let mut _0: T; + let mut _3: bool; scope 1 (inlined #[track_caller] Option::::unwrap_unchecked) { let mut _2: isize; scope 2 { @@ -18,16 +19,10 @@ fn unwrap_unchecked(_1: Option) -> T { bb0: { StorageLive(_2); _2 = discriminant(_1); - switchInt(move _2) -> [0: bb2, 1: bb1, otherwise: bb2]; - } - - bb1: { + _3 = Eq(copy _2, const 1_isize); + assume(move _3); _0 = copy ((_1 as Some).0: T); StorageDead(_2); return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..a896f66866e62 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff @@ -0,0 +1,49 @@ +- // MIR for `match_eq_bool` before MatchBranchSimplification ++ // MIR for `match_eq_bool` after MatchBranchSimplification + + fn match_eq_bool(_1: i32) -> bool { + debug i => _1; + let mut _0: bool; + let _2: bool; + let _3: (); ++ let mut _4: i32; + scope 1 { + debug a => _2; + } + + bb0: { + StorageLive(_2); + StorageLive(_3); +- switchInt(copy _1) -> [7: bb3, 8: bb2, otherwise: bb1]; +- } +- +- bb1: { +- _2 = const true; ++ StorageLive(_4); ++ _4 = copy _1; ++ _2 = Ne(copy _4, const 7_i32); + _3 = (); +- goto -> bb4; +- } +- +- bb2: { +- _2 = const true; +- _3 = (); +- goto -> bb4; +- } +- +- bb3: { +- _2 = const false; +- _3 = (); +- goto -> bb4; +- } +- +- bb4: { ++ StorageDead(_4); + StorageDead(_3); + _0 = copy _2; + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..3ed8e13304172 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff @@ -0,0 +1,44 @@ +- // MIR for `match_eq_bool_2` before MatchBranchSimplification ++ // MIR for `match_eq_bool_2` after MatchBranchSimplification + + fn match_eq_bool_2(_1: i32) -> bool { + debug i => _1; + let mut _0: bool; + let _2: bool; + let _3: (); + scope 1 { + debug a => _2; + } + + bb0: { + StorageLive(_2); + StorageLive(_3); + switchInt(copy _1) -> [7: bb3, 8: bb2, otherwise: bb1]; + } + + bb1: { + _2 = const true; + _3 = (); + goto -> bb4; + } + + bb2: { + _2 = const false; + _3 = (); + goto -> bb4; + } + + bb3: { + _2 = const false; + _3 = (); + goto -> bb4; + } + + bb4: { + StorageDead(_3); + _0 = copy _2; + StorageDead(_2); + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..76148eb9bd436 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_option.MatchBranchSimplification.diff @@ -0,0 +1,32 @@ +- // MIR for `match_option` before MatchBranchSimplification ++ // MIR for `match_option` after MatchBranchSimplification + + fn match_option(_1: &Option) -> Option { + debug i => _1; + let mut _0: std::option::Option; + let mut _2: isize; + + bb0: { + _2 = discriminant((*_1)); +- switchInt(move _2) -> [0: bb2, 1: bb3, otherwise: bb1]; +- } +- +- bb1: { +- unreachable; +- } +- +- bb2: { +- _0 = Option::::None; +- goto -> bb4; +- } +- +- bb3: { + _0 = copy (*_1); +- goto -> bb4; +- } +- +- bb4: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..e6f273d086417 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff @@ -0,0 +1,39 @@ +- // MIR for `match_option2_mut` before MatchBranchSimplification ++ // MIR for `match_option2_mut` after MatchBranchSimplification + + fn match_option2_mut(_1: &mut Option2) -> Option2 { + let mut _0: Option2; + let mut _2: isize; + + bb0: { + _2 = discriminant((*_1)); + switchInt(copy _2) -> [0: bb1, 1: bb2, 2: bb3, otherwise: bb4]; + } + + bb1: { + (*_1) = Option2::::None2; + _0 = Option2::::None1; + goto -> bb5; + } + + bb2: { + (*_1) = Option2::::None2; + _0 = Option2::::None2; + goto -> bb5; + } + + bb3: { + (*_1) = Option2::::None2; + _0 = copy (*_1); + goto -> bb5; + } + + bb4: { + unreachable; + } + + bb5: { + return; + } + } + diff --git a/tests/mir-opt/matches_reduce_branches.rs b/tests/mir-opt/matches_reduce_branches.rs index 89ef3bfb30857..1766e77cf664b 100644 --- a/tests/mir-opt/matches_reduce_branches.rs +++ b/tests/mir-opt/matches_reduce_branches.rs @@ -81,6 +81,54 @@ fn match_nested_if() -> bool { val } +// EMIT_MIR matches_reduce_branches.match_eq_bool.MatchBranchSimplification.diff +fn match_eq_bool(i: i32) -> bool { + // CHECK-LABEL: fn match_eq_bool( + // CHECK: = Ne( + // CHECK-NOT: switchInt + // CHECK: return + let a; + match i { + 7 => { + a = false; + () + } + 8 => { + a = true; + () + } + _ => { + a = true; + () + } + }; + a +} + +// EMIT_MIR matches_reduce_branches.match_eq_bool_2.MatchBranchSimplification.diff +fn match_eq_bool_2(i: i32) -> bool { + // CHECK-LABEL: fn match_eq_bool_2( + // CHECK-NOT: = Ne( + // CHECK: switchInt + // CHECK: return + let a; + match i { + 7 => { + a = false; + () + } + 8 => { + a = false; + () + } + _ => { + a = true; + () + } + }; + a +} + // # Fold switchInt into IntToInt. // To simplify writing and checking these test cases, I use the first character of // each case to distinguish the sign of the number: @@ -627,6 +675,87 @@ fn match_i128_u128(i: EnumAi128) -> u128 { } } +// EMIT_MIR matches_reduce_branches.match_option.MatchBranchSimplification.diff +fn match_option(i: &Option) -> Option { + // CHECK-LABEL: fn match_option( + // CHECK-NOT: switchInt + // CHECK: _0 = copy (*_1); + match i { + Some(_) => *i, + None => None, + } +} + +enum Option2 { + None1, + None2, + Some(T), +} + +// EMIT_MIR matches_reduce_branches.single_case.MatchBranchSimplification.diff +#[custom_mir(dialect = "runtime")] +fn single_case(i: Option) -> i32 { + // CHECK-LABEL: fn single_case( + // CHECK-NOT: switchInt + mir! { + { + let discr = Discriminant(i); + match discr { + 0 => none, + _ => unreachable_bb, + } + } + none = { + RET = 1; + Return() + } + unreachable_bb = { + Unreachable() + } + } +} + +// We cannot dereference `i` after the value has been changed. +// EMIT_MIR matches_reduce_branches.match_option2_mut.MatchBranchSimplification.diff +#[custom_mir(dialect = "runtime")] +fn match_option2_mut(i: &mut Option2) -> Option2 { + // CHECK-LABEL: fn match_option2_mut( + // CHECK: switchInt + // CHECK: return + mir! { + { + let discr = Discriminant(*i); + match discr { + 0 => none1_bb, + 1 => none2_bb, + 2 => some_bb, + _ => unreachable_bb, + } + } + none1_bb = { + *i = Option2::None2; + RET = Option2::None1; + Goto(ret) + } + none2_bb = { + *i = Option2::None2; + RET = Option2::None2; + Goto(ret) + } + some_bb = { + *i = Option2::None2; + RET = *i; + Goto(ret) + } + unreachable_bb = { + Unreachable() + } + ret = { + Return() + } + } +} + // EMIT_MIR matches_reduce_branches.match_non_int_failed.MatchBranchSimplification.diff #[custom_mir(dialect = "runtime")] fn match_non_int_failed(i: char) -> u8 { @@ -696,4 +825,6 @@ fn main() { let _ = my_is_some(None); let _ = match_non_int_failed('a'); + let _ = match_option(&None); + let _ = match_option2_mut(&mut Option2::None1); } diff --git a/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff b/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff new file mode 100644 index 0000000000000..ba99ab0229497 --- /dev/null +++ b/tests/mir-opt/matches_reduce_branches.single_case.MatchBranchSimplification.diff @@ -0,0 +1,25 @@ +- // MIR for `single_case` before MatchBranchSimplification ++ // MIR for `single_case` after MatchBranchSimplification + + fn single_case(_1: Option) -> i32 { + let mut _0: i32; + let mut _2: isize; ++ let mut _3: bool; + + bb0: { + _2 = discriminant(_1); +- switchInt(copy _2) -> [0: bb1, otherwise: bb2]; +- } +- +- bb1: { ++ _3 = Eq(copy _2, const 0_isize); ++ assume(move _3); + _0 = const 1_i32; + return; +- } +- +- bb2: { +- unreachable; + } + } + diff --git a/tests/mir-opt/pre-codegen/copy_and_clone.rs b/tests/mir-opt/pre-codegen/copy_and_clone.rs new file mode 100644 index 0000000000000..05da25afa2a39 --- /dev/null +++ b/tests/mir-opt/pre-codegen/copy_and_clone.rs @@ -0,0 +1,250 @@ +//@ [COPY] compile-flags: --cfg=copy +//@ revisions: COPY CLONE + +// Test case from https://github.com/rust-lang/rust/issues/128081. +// Ensure both Copy and Clone get optimized copy. + +#[unsafe(no_mangle)] +pub fn intra_clone(intra: &Av1BlockIntra) -> Av1BlockIntraInter { + // CHECK-LABEL: fn intra_clone( + // CHECK: [[C:_.*]] = copy (*_1); + // CHECK: _0 = Av1BlockIntraInter::Intra(move [[C]]); + Av1BlockIntraInter::Intra(intra.clone()) +} + +#[unsafe(no_mangle)] +pub fn inter_clone(inter: &Av1BlockInter) -> Av1BlockIntraInter { + // CHECK-LABEL: fn inter_clone( + // CHECK: [[C:_.*]] = copy (*_1); + // CHECK: _0 = Av1BlockIntraInter::Inter(move [[C]]); + Av1BlockIntraInter::Inter(inter.clone()) +} + +#[unsafe(no_mangle)] +pub fn dav1dsequenceheader_copy(v: &Dav1dSequenceHeader) -> Dav1dSequenceHeader { + // CHECK-LABEL: fn dav1dsequenceheader_copy( + // CHECK: _0 = copy (*_1); + v.clone() +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct mv { + pub y: i16, + pub x: i16, +} + +#[derive(Clone, Copy)] +#[repr(transparent)] +pub struct MaskedInterIntraPredMode(u8); + +#[derive(Clone)] +#[cfg_attr(copy, derive(Copy))] +#[repr(C)] +pub struct Av1BlockInter1d { + pub mv: [mv; 2], + pub wedge_idx: u8, + pub mask_sign: u8, + pub interintra_mode: MaskedInterIntraPredMode, + pub _padding: u8, +} + +#[derive(Clone)] +#[cfg_attr(copy, derive(Copy))] +#[repr(C)] +pub struct Av1BlockInterNd { + pub one_d: Av1BlockInter1d, +} + +#[derive(Clone, Copy)] +pub enum CompInterType { + WeightedAvg = 1, + Avg = 2, + Seg = 3, + Wedge = 4, +} + +#[derive(Clone, Copy)] +pub enum MotionMode { + Translation = 0, + Obmc = 1, + Warp = 2, +} + +#[derive(Clone, Copy)] +pub enum DrlProximity { + Nearest, + Nearer, + Near, + Nearish, +} + +#[derive(Clone, Copy)] +pub enum TxfmSize { + S4x4 = 0, + S8x8 = 1, + S16x16 = 2, + S32x32 = 3, + S64x64 = 4, + R4x8 = 5, + R8x4 = 6, + R8x16 = 7, + R16x8 = 8, + R16x32 = 9, + R32x16 = 10, + R32x64 = 11, + R64x32 = 12, + R4x16 = 13, + R16x4 = 14, + R8x32 = 15, + R32x8 = 16, + R16x64 = 17, + R64x16 = 18, +} + +#[derive(Clone, Copy)] +pub enum Filter2d { + Regular8Tap = 0, + RegularSmooth8Tap = 1, + RegularSharp8Tap = 2, + SharpRegular8Tap = 3, + SharpSmooth8Tap = 4, + Sharp8Tap = 5, + SmoothRegular8Tap = 6, + Smooth8Tap = 7, + SmoothSharp8Tap = 8, + Bilinear = 9, +} + +#[derive(Clone, Copy)] +pub enum InterIntraType { + Blend, + Wedge, +} + +#[cfg_attr(copy, derive(Copy))] +#[derive(Clone)] +#[repr(C)] +pub struct Av1BlockInter { + pub nd: Av1BlockInterNd, + pub comp_type: Option, + pub inter_mode: u8, + pub motion_mode: MotionMode, + pub drl_idx: DrlProximity, + pub r#ref: [i8; 2], + pub max_ytx: TxfmSize, + pub filter2d: Filter2d, + pub interintra_type: Option, + pub tx_split0: u8, + pub tx_split1: u16, +} + +#[cfg_attr(copy, derive(Copy))] +#[derive(Clone)] +#[repr(C)] +pub struct Av1BlockIntra { + pub y_mode: u8, + pub uv_mode: u8, + pub tx: TxfmSize, + pub pal_sz: [u8; 2], + pub y_angle: i8, + pub uv_angle: i8, + pub cfl_alpha: [i8; 2], +} + +#[repr(C)] +pub enum Av1BlockIntraInter { + Intra(Av1BlockIntra), + Inter(Av1BlockInter), +} + +use std::ffi::{c_int, c_uint}; + +pub type Dav1dPixelLayout = c_uint; +pub type Dav1dColorPrimaries = c_uint; +pub type Dav1dTransferCharacteristics = c_uint; +pub type Dav1dMatrixCoefficients = c_uint; +pub type Dav1dChromaSamplePosition = c_uint; +pub type Dav1dAdaptiveBoolean = c_uint; + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Dav1dSequenceHeaderOperatingPoint { + pub major_level: u8, + pub minor_level: u8, + pub initial_display_delay: u8, + pub idc: u16, + pub tier: u8, + pub decoder_model_param_present: u8, + pub display_model_param_present: u8, +} + +#[derive(Clone, Copy)] +#[repr(C)] +pub struct Dav1dSequenceHeaderOperatingParameterInfo { + pub decoder_buffer_delay: u32, + pub encoder_buffer_delay: u32, + pub low_delay_mode: u8, +} + +pub const DAV1D_MAX_OPERATING_POINTS: usize = 32; + +#[cfg_attr(copy, derive(Copy))] +#[derive(Clone)] +#[repr(C)] +pub struct Dav1dSequenceHeader { + pub profile: u8, + pub max_width: c_int, + pub max_height: c_int, + pub layout: Dav1dPixelLayout, + pub pri: Dav1dColorPrimaries, + pub trc: Dav1dTransferCharacteristics, + pub mtrx: Dav1dMatrixCoefficients, + pub chr: Dav1dChromaSamplePosition, + pub hbd: u8, + pub color_range: u8, + pub num_operating_points: u8, + pub operating_points: [Dav1dSequenceHeaderOperatingPoint; DAV1D_MAX_OPERATING_POINTS], + pub still_picture: u8, + pub reduced_still_picture_header: u8, + pub timing_info_present: u8, + pub num_units_in_tick: u32, + pub time_scale: u32, + pub equal_picture_interval: u8, + pub num_ticks_per_picture: u32, + pub decoder_model_info_present: u8, + pub encoder_decoder_buffer_delay_length: u8, + pub num_units_in_decoding_tick: u32, + pub buffer_removal_delay_length: u8, + pub frame_presentation_delay_length: u8, + pub display_model_info_present: u8, + pub width_n_bits: u8, + pub height_n_bits: u8, + pub frame_id_numbers_present: u8, + pub delta_frame_id_n_bits: u8, + pub frame_id_n_bits: u8, + pub sb128: u8, + pub filter_intra: u8, + pub intra_edge_filter: u8, + pub inter_intra: u8, + pub masked_compound: u8, + pub warped_motion: u8, + pub dual_filter: u8, + pub order_hint: u8, + pub jnt_comp: u8, + pub ref_frame_mvs: u8, + pub screen_content_tools: Dav1dAdaptiveBoolean, + pub force_integer_mv: Dav1dAdaptiveBoolean, + pub order_hint_n_bits: u8, + pub super_res: u8, + pub cdef: u8, + pub restoration: u8, + pub ss_hor: u8, + pub ss_ver: u8, + pub monochrome: u8, + pub color_description_present: u8, + pub separate_uv_delta_q: u8, + pub film_grain_present: u8, + pub operating_parameter_info: + [Dav1dSequenceHeaderOperatingParameterInfo; DAV1D_MAX_OPERATING_POINTS], +} diff --git a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir index 8a6732d5f745a..d08aa8456e7f9 100644 --- a/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/duplicate_switch_targets.ub_if_b.PreCodegen.after.mir @@ -4,6 +4,7 @@ fn ub_if_b(_1: Thing) -> Thing { debug t => _1; let mut _0: Thing; let mut _2: isize; + let mut _3: bool; scope 1 (inlined #[track_caller] unreachable_unchecked) { scope 2 (inlined core::ub_checks::check_language_ub) { scope 3 (inlined core::ub_checks::check_language_ub::runtime) { @@ -13,15 +14,9 @@ fn ub_if_b(_1: Thing) -> Thing { bb0: { _2 = discriminant(_1); - switchInt(move _2) -> [0: bb1, 1: bb2, otherwise: bb2]; - } - - bb1: { + _3 = Eq(copy _2, const 0_isize); + assume(move _3); _0 = move _1; return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir b/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir index b2b7f88d8534b..c0f3978663960 100644 --- a/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir +++ b/tests/mir-opt/pre-codegen/two_unwrap_unchecked.two_unwrap_unchecked.PreCodegen.after.mir @@ -4,11 +4,12 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { debug v => _1; let mut _0: i32; let mut _2: std::option::Option; - let _4: i32; + let mut _4: bool; + let _5: i32; scope 1 { - debug v1 => _4; + debug v1 => _5; scope 2 { - debug v2 => _4; + debug v2 => _5; } scope 8 (inlined #[track_caller] Option::::unwrap_unchecked) { scope 9 { @@ -36,16 +37,10 @@ fn two_unwrap_unchecked(_1: &Option) -> i32 { bb0: { _2 = copy (*_1); _3 = discriminant(_2); - switchInt(copy _3) -> [0: bb2, 1: bb1, otherwise: bb2]; - } - - bb1: { - _4 = copy ((_2 as Some).0: i32); - _0 = Add(copy _4, copy _4); + _4 = Eq(copy _3, const 1_isize); + assume(move _4); + _5 = copy ((_2 as Some).0: i32); + _0 = Add(copy _5, copy _5); return; } - - bb2: { - unreachable; - } } diff --git a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff index ff1bc58524bc2..dd21719adb656 100644 --- a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff +++ b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-abort.diff @@ -10,7 +10,6 @@ let mut _5: isize; - let mut _7: bool; - let mut _8: u8; -- let mut _9: bool; scope 1 { debug a => _6; let _6: u8; diff --git a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff index 2c289c664754a..6e50b615030f9 100644 --- a/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff +++ b/tests/mir-opt/simplify_locals_fixedpoint.foo.SimplifyLocals-final.panic-unwind.diff @@ -10,7 +10,6 @@ let mut _5: isize; - let mut _7: bool; - let mut _8: u8; -- let mut _9: bool; scope 1 { debug a => _6; let _6: u8; diff --git a/tests/mir-opt/simplify_locals_fixedpoint.rs b/tests/mir-opt/simplify_locals_fixedpoint.rs index 0b6c95630c0a7..01aa7df0716f0 100644 --- a/tests/mir-opt/simplify_locals_fixedpoint.rs +++ b/tests/mir-opt/simplify_locals_fixedpoint.rs @@ -1,8 +1,12 @@ -// skip-filecheck // EMIT_MIR_FOR_EACH_PANIC_STRATEGY -//@ compile-flags: -Zmir-opt-level=1 +//@ compile-flags: -Zmir-opt-level=1 -Zmir-enable-passes=+MatchBranchSimplification +// EMIT_MIR simplify_locals_fixedpoint.foo.SimplifyLocals-final.diff fn foo() { + // CHECK-LABEL: fn foo( + // CHECK-NOT: let mut {{.*}}: bool; + // CHECK-NOT: let mut {{.*}}: u8; + // CHECK-NOT: let mut {{.*}}: bool; if let (Some(a), None) = (Option::::None, Option::::None) { if a > 42u8 {} } @@ -11,5 +15,3 @@ fn foo() { fn main() { foo::<()>(); } - -// EMIT_MIR simplify_locals_fixedpoint.foo.SimplifyLocals-final.diff