Skip to content

fix: In "Wrap return type" assist, don't wrap exit points if they already have the right type #20061

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions crates/hir/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1727,10 +1727,10 @@ impl Adt {
pub fn ty_with_args<'db>(
self,
db: &'db dyn HirDatabase,
args: impl Iterator<Item = Type<'db>>,
args: impl IntoIterator<Item = Type<'db>>,
) -> Type<'db> {
let id = AdtId::from(self);
let mut it = args.map(|t| t.ty);
let mut it = args.into_iter().map(|t| t.ty);
let ty = TyBuilder::def_ty(db, id.into(), None)
.fill(|x| {
let r = it.next().unwrap_or_else(|| TyKind::Error.intern(Interner));
Expand Down
163 changes: 131 additions & 32 deletions crates/ide-assists/src/handlers/wrap_return_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,15 +56,16 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
};

let type_ref = &ret_type.ty()?;
let ty = ctx.sema.resolve_type(type_ref)?.as_adt();
let ty = ctx.sema.resolve_type(type_ref)?;
let ty_adt = ty.as_adt();
let famous_defs = FamousDefs(&ctx.sema, ctx.sema.scope(type_ref.syntax())?.krate());

for kind in WrapperKind::ALL {
let Some(core_wrapper) = kind.core_type(&famous_defs) else {
continue;
};

if matches!(ty, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
if matches!(ty_adt, Some(hir::Adt::Enum(ret_type)) if ret_type == core_wrapper) {
// The return type is already wrapped
cov_mark::hit!(wrap_return_type_simple_return_type_already_wrapped);
continue;
Expand All @@ -78,10 +79,23 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
|builder| {
let mut editor = builder.make_editor(&parent);
let make = SyntaxFactory::with_mappings();
let alias = wrapper_alias(ctx, &make, &core_wrapper, type_ref, kind.symbol());
let new_return_ty = alias.unwrap_or_else(|| match kind {
WrapperKind::Option => make.ty_option(type_ref.clone()),
WrapperKind::Result => make.ty_result(type_ref.clone(), make.ty_infer().into()),
let alias = wrapper_alias(ctx, &make, core_wrapper, type_ref, &ty, kind.symbol());
let (ast_new_return_ty, semantic_new_return_ty) = alias.unwrap_or_else(|| {
let (ast_ty, ty_constructor) = match kind {
WrapperKind::Option => {
(make.ty_option(type_ref.clone()), famous_defs.core_option_Option())
}
WrapperKind::Result => (
make.ty_result(type_ref.clone(), make.ty_infer().into()),
famous_defs.core_result_Result(),
),
};
let semantic_ty = ty_constructor
.map(|ty_constructor| {
hir::Adt::from(ty_constructor).ty_with_args(ctx.db(), [ty.clone()])
})
.unwrap_or_else(|| ty.clone());
(ast_ty, semantic_ty)
});

let mut exprs_to_wrap = Vec::new();
Expand All @@ -96,19 +110,30 @@ pub(crate) fn wrap_return_type(acc: &mut Assists, ctx: &AssistContext<'_>) -> Op
for_each_tail_expr(&body_expr, tail_cb);

for ret_expr_arg in exprs_to_wrap {
if let Some(ty) = ctx.sema.type_of_expr(&ret_expr_arg) {
if ty.adjusted().could_unify_with(ctx.db(), &semantic_new_return_ty) {
// The type is already correct, don't wrap it.
// We deliberately don't use `could_unify_with_deeply()`, because as long as the outer
// enum matches it's okay for us, as we don't trigger the assist if the return type
// is already `Option`/`Result`, so mismatched exact type is more likely a mistake
// than something intended.
continue;
}
}

let happy_wrapped = make.expr_call(
make.expr_path(make.ident_path(kind.happy_ident())),
make.arg_list(iter::once(ret_expr_arg.clone())),
);
editor.replace(ret_expr_arg.syntax(), happy_wrapped.syntax());
}

editor.replace(type_ref.syntax(), new_return_ty.syntax());
editor.replace(type_ref.syntax(), ast_new_return_ty.syntax());

if let WrapperKind::Result = kind {
// Add a placeholder snippet at the first generic argument that doesn't equal the return type.
// This is normally the error type, but that may not be the case when we inserted a type alias.
let args = new_return_ty
let args = ast_new_return_ty
.path()
.unwrap()
.segment()
Expand Down Expand Up @@ -188,35 +213,36 @@ impl WrapperKind {
}

// Try to find an wrapper type alias in the current scope (shadowing the default).
fn wrapper_alias(
ctx: &AssistContext<'_>,
fn wrapper_alias<'db>(
ctx: &AssistContext<'db>,
make: &SyntaxFactory,
core_wrapper: &hir::Enum,
ret_type: &ast::Type,
core_wrapper: hir::Enum,
ast_ret_type: &ast::Type,
semantic_ret_type: &hir::Type<'db>,
wrapper: hir::Symbol,
) -> Option<ast::PathType> {
) -> Option<(ast::PathType, hir::Type<'db>)> {
let wrapper_path = hir::ModPath::from_segments(
hir::PathKind::Plain,
iter::once(hir::Name::new_symbol_root(wrapper)),
);

ctx.sema.resolve_mod_path(ret_type.syntax(), &wrapper_path).and_then(|def| {
ctx.sema.resolve_mod_path(ast_ret_type.syntax(), &wrapper_path).and_then(|def| {
def.filter_map(|def| match def.into_module_def() {
hir::ModuleDef::TypeAlias(alias) => {
let enum_ty = alias.ty(ctx.db()).as_adt()?.as_enum()?;
(&enum_ty == core_wrapper).then_some(alias)
(enum_ty == core_wrapper).then_some((alias, enum_ty))
}
_ => None,
})
.find_map(|alias| {
.find_map(|(alias, enum_ty)| {
let mut inserted_ret_type = false;
let generic_args =
alias.source(ctx.db())?.value.generic_param_list()?.generic_params().map(|param| {
match param {
// Replace the very first type parameter with the function's return type.
ast::GenericParam::TypeParam(_) if !inserted_ret_type => {
inserted_ret_type = true;
make.type_arg(ret_type.clone()).into()
make.type_arg(ast_ret_type.clone()).into()
}
ast::GenericParam::LifetimeParam(_) => {
make.lifetime_arg(make.lifetime("'_")).into()
Expand All @@ -231,7 +257,10 @@ fn wrapper_alias(
make.path_segment_generics(make.name_ref(name.as_str()), generic_arg_list),
);

Some(make.ty_path(path))
let new_ty =
hir::Adt::from(enum_ty).ty_with_args(ctx.db(), [semantic_ret_type.clone()]);

Some((make.ty_path(path), new_ty))
})
})
}
Expand Down Expand Up @@ -605,29 +634,39 @@ fn foo() -> Option<i32> {
check_assist_by_label(
wrap_return_type,
r#"
//- minicore: option
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needed to change because previously the type was the error type, which unifies with anything.

//- minicore: option, future
struct F(i32);
impl core::future::Future for F {
type Output = i32;
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
}
async fn foo() -> i$032 {
if true {
if false {
1.await
F(1).await
} else {
2.await
F(2).await
}
} else {
24i32.await
F(24i32).await
}
}
"#,
r#"
struct F(i32);
impl core::future::Future for F {
type Output = i32;
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
}
async fn foo() -> Option<i32> {
if true {
if false {
Some(1.await)
Some(F(1).await)
} else {
Some(2.await)
Some(F(2).await)
}
} else {
Some(24i32.await)
Some(F(24i32).await)
}
}
"#,
Expand Down Expand Up @@ -1666,29 +1705,39 @@ fn foo() -> Result<i32, ${0:_}> {
check_assist_by_label(
wrap_return_type,
r#"
//- minicore: result
//- minicore: result, future
struct F(i32);
impl core::future::Future for F {
type Output = i32;
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
}
async fn foo() -> i$032 {
if true {
if false {
1.await
F(1).await
} else {
2.await
F(2).await
}
} else {
24i32.await
F(24i32).await
}
}
"#,
r#"
struct F(i32);
impl core::future::Future for F {
type Output = i32;
fn poll(self: core::pin::Pin<&mut Self>, cx: &mut core::task::Context<'_>) -> core::task::Poll<Self::Output> { 0 }
}
async fn foo() -> Result<i32, ${0:_}> {
if true {
if false {
Ok(1.await)
Ok(F(1).await)
} else {
Ok(2.await)
Ok(F(2).await)
}
} else {
Ok(24i32.await)
Ok(F(24i32).await)
}
}
"#,
Expand Down Expand Up @@ -2455,6 +2504,56 @@ type Result<T, const N: usize> = core::result::Result<Foo<T>, Bar<N>>;

fn foo() -> Result<i32, ${0:_}> {
Ok(0)
}
"#,
WrapperKind::Result.label(),
);
}

#[test]
fn already_wrapped() {
check_assist_by_label(
wrap_return_type,
r#"
//- minicore: option
fn foo() -> i32$0 {
if false {
0
} else {
Some(1)
}
}
"#,
r#"
fn foo() -> Option<i32> {
if false {
Some(0)
} else {
Some(1)
}
}
"#,
WrapperKind::Option.label(),
);
check_assist_by_label(
wrap_return_type,
r#"
//- minicore: result
fn foo() -> i32$0 {
if false {
0
} else {
Ok(1)
}
}
"#,
r#"
fn foo() -> Result<i32, ${0:_}> {
if false {
Ok(0)
} else {
Ok(1)
}
}
"#,
WrapperKind::Result.label(),
Expand Down