Skip to content

Rust: Type inference for impl trait types with type parameters #20119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jul 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rust/ql/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion rust/ql/.gitattributes

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
// generated by codegen, remove this comment if you wish to edit this file
/**
* This module provides a hand-modifiable wrapper around the generated class `ImplTraitTypeRepr`.
*
* INTERNAL: Do not use.
*/

private import codeql.rust.elements.internal.generated.ImplTraitTypeRepr
private import rust

/**
* INTERNAL: This module contains the customizable definition of `ImplTraitTypeRepr` and should not
* be referenced directly.
*/
module Impl {
// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* An `impl Trait` type.
*
Expand All @@ -21,5 +22,15 @@ module Impl {
* // ^^^^^^^^^^^^^^^^^^^^^^^^^^
* ```
*/
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr { }
class ImplTraitTypeRepr extends Generated::ImplTraitTypeRepr {
/** Gets the function for which this impl trait type occurs, if any. */
Function getFunction() {
this.getParentNode*() = [result.getRetType().getTypeRepr(), result.getAParam().getTypeRepr()]
}

/** Holds if this impl trait type occurs in the return type of a function. */
predicate isInReturnPos() {
this.getParentNode*() = this.getFunction().getRetType().getTypeRepr()
}
}
}
39 changes: 36 additions & 3 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,21 @@ newtype TType =
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TArrayTypeParameter() or
TDynTraitTypeParameter(AstNode n) { dynTraitTypeParameter(_, n) } or
TImplTraitTypeParameter(ImplTraitTypeRepr implTrait, TypeParam tp) {
implTraitTypeParam(implTrait, _, tp)
} or
TRefTypeParameter() or
TSelfTypeParameter(Trait t) or
TSliceTypeParameter()

predicate implTraitTypeParam(ImplTraitTypeRepr implTrait, int i, TypeParam tp) {
implTrait.isInReturnPos() and
tp = implTrait.getFunction().getGenericParamList().getTypeParam(i) and
// Only include type parameters of the function that occur inside the impl
// trait type.
exists(Path path | path.getParentNode*() = implTrait and resolvePath(path) = tp)
}

/**
* A type without type arguments.
*
Expand Down Expand Up @@ -263,7 +274,12 @@ class ImplTraitType extends Type, TImplTraitType {

override TupleField getTupleField(int i) { none() }

override TypeParameter getTypeParameter(int i) { none() }
override TypeParameter getTypeParameter(int i) {
exists(TypeParam tp |
implTraitTypeParam(impl, i, tp) and
result = TImplTraitTypeParameter(impl, tp)
)
}

override string toString() { result = impl.toString() }

Expand Down Expand Up @@ -302,7 +318,7 @@ class DynTraitType extends Type, TDynTraitType {
class ImplTraitReturnType extends ImplTraitType {
private Function function;

ImplTraitReturnType() { impl = function.getRetType().getTypeRepr() }
ImplTraitReturnType() { impl.isInReturnPos() and function = impl.getFunction() }

override Function getFunction() { result = function }
}
Expand Down Expand Up @@ -456,6 +472,21 @@ class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
override Location getLocation() { result = n.getLocation() }
}

class ImplTraitTypeParameter extends TypeParameter, TImplTraitTypeParameter {
private TypeParam typeParam;
private ImplTraitTypeRepr implTrait;

ImplTraitTypeParameter() { this = TImplTraitTypeParameter(implTrait, typeParam) }

TypeParam getTypeParam() { result = typeParam }

ImplTraitTypeRepr getImplTraitTypeRepr() { result = implTrait }

override string toString() { result = "impl(" + typeParam.toString() + ")" }

override Location getLocation() { result = typeParam.getLocation() }
}

/** An implicit reference type parameter. */
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
override string toString() { result = "&T" }
Expand Down Expand Up @@ -569,5 +600,7 @@ final class SelfTypeBoundTypeAbstraction extends TypeAbstraction, Name {
}

final class ImplTraitTypeReprAbstraction extends TypeAbstraction, ImplTraitTypeRepr {
override TypeParameter getATypeParameter() { none() }
override TypeParameter getATypeParameter() {
implTraitTypeParam(this, _, result.(TypeParamTypeParameter).getTypeParam())
}
}
32 changes: 19 additions & 13 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -83,42 +83,48 @@ private module Input1 implements InputSig1<Location> {

int getTypeParameterId(TypeParameter tp) {
tp =
rank[result](TypeParameter tp0, int kind, int id |
rank[result](TypeParameter tp0, int kind, int id1, int id2 |
tp0 instanceof ArrayTypeParameter and
kind = 0 and
id = 0
id1 = 0 and
id2 = 0
or
tp0 instanceof RefTypeParameter and
kind = 0 and
id = 1
id1 = 0 and
id2 = 1
or
tp0 instanceof SliceTypeParameter and
kind = 0 and
id = 2
id1 = 0 and
id2 = 2
or
kind = 1 and
id =
id1 = 0 and
id2 =
idOfTypeParameterAstNode([
tp0.(DynTraitTypeParameter).getTypeParam().(AstNode),
tp0.(DynTraitTypeParameter).getTypeAlias()
])
or
kind = 2 and
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
id1 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getImplTraitTypeRepr()) and
id2 = idOfTypeParameterAstNode(tp0.(ImplTraitTypeParameter).getTypeParam())
or
kind = 3 and
id1 = 0 and
exists(AstNode node | id2 = idOfTypeParameterAstNode(node) |
node = tp0.(TypeParamTypeParameter).getTypeParam() or
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
node = tp0.(SelfTypeParameter).getTrait() or
node = tp0.(ImplTraitTypeTypeParameter).getImplTraitTypeRepr()
)
or
exists(TupleTypeParameter ttp, int maxArity |
maxArity = max(int i | i = any(TupleType tt).getArity()) and
tp0 = ttp and
kind = 3 and
id = ttp.getTupleType().getArity() * maxArity + ttp.getIndex()
)
kind = 4 and
id1 = tp0.(TupleTypeParameter).getTupleType().getArity() and
id2 = tp0.(TupleTypeParameter).getIndex()
|
tp0 order by kind, id
tp0 order by kind, id1, id2
)
}
}
Expand Down
6 changes: 6 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ class ImplTraitTypeReprMention extends TypeMention instanceof ImplTraitTypeRepr
override Type resolveTypeAt(TypePath typePath) {
typePath.isEmpty() and
result.(ImplTraitType).getImplTraitTypeRepr() = this
or
exists(ImplTraitTypeParameter tp |
this = tp.getImplTraitTypeRepr() and
typePath = TypePath::singleton(tp) and
result = TTypeParamTypeParameter(tp.getTypeParam())
)
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
multipleCallTargets
| dereference.rs:61:15:61:24 | e1.deref() |
| main.rs:2253:13:2253:31 | ...::from(...) |
| main.rs:2254:13:2254:31 | ...::from(...) |
| main.rs:2255:13:2255:31 | ...::from(...) |
| main.rs:2261:13:2261:31 | ...::from(...) |
| main.rs:2262:13:2262:31 | ...::from(...) |
| main.rs:2263:13:2263:31 | ...::from(...) |
| main.rs:2278:13:2278:31 | ...::from(...) |
| main.rs:2279:13:2279:31 | ...::from(...) |
| main.rs:2280:13:2280:31 | ...::from(...) |
| main.rs:2286:13:2286:31 | ...::from(...) |
| main.rs:2287:13:2287:31 | ...::from(...) |
| main.rs:2288:13:2288:31 | ...::from(...) |
27 changes: 26 additions & 1 deletion rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1913,8 +1913,10 @@ mod async_ {
}

mod impl_trait {
#[derive(Copy, Clone)]
struct S1;
struct S2;
struct S3<T3>(T3);

trait Trait1 {
fn f1(&self) {} // Trait1f1
Expand Down Expand Up @@ -1946,6 +1948,13 @@ mod impl_trait {
}
}

impl<T: Clone> MyTrait<T> for S3<T> {
fn get_a(&self) -> T {
let S3(t) = self;
t.clone()
}
}

fn get_a_my_trait() -> impl MyTrait<S2> {
S1
}
Expand All @@ -1954,6 +1963,18 @@ mod impl_trait {
t.get_a() // $ target=MyTrait::get_a
}

fn get_a_my_trait2<T: Clone>(x: T) -> impl MyTrait<T> {
S3(x)
}

fn get_a_my_trait3<T: Clone>(x: T) -> Option<impl MyTrait<T>> {
Some(S3(x))
}

fn get_a_my_trait4<T: Clone>(x: T) -> (impl MyTrait<T>, impl MyTrait<T>) {
(S3(x.clone()), S3(x)) // $ target=clone
}

fn uses_my_trait2<A>(t: impl MyTrait<A>) -> A {
t.get_a() // $ target=MyTrait::get_a
}
Expand All @@ -1967,6 +1988,10 @@ mod impl_trait {
let a = get_a_my_trait(); // $ target=get_a_my_trait
let c = uses_my_trait2(a); // $ type=c:S2 target=uses_my_trait2
let d = uses_my_trait2(S1); // $ type=d:S2 target=uses_my_trait2
let e = get_a_my_trait2(S1).get_a(); // $ target=get_a_my_trait2 target=MyTrait::get_a type=e:S1
// For this function the `impl` type does not appear in the root of the return type
let f = get_a_my_trait3(S1).unwrap().get_a(); // $ target=get_a_my_trait3 target=unwrap target=MyTrait::get_a type=f:S1
let g = get_a_my_trait4(S1).0.get_a(); // $ target=get_a_my_trait4 target=MyTrait::get_a type=g:S1
}
}

Expand Down Expand Up @@ -2425,7 +2450,7 @@ mod tuples {

let pair = [1, 1].into(); // $ type=pair:(T_2) type=pair:0(2).i32 type=pair:1(2).i32 MISSING: target=into
match pair {
(0,0) => print!("unexpected"),
(0, 0) => print!("unexpected"),
_ => print!("expected"),
}
let x = pair.0; // $ type=x:i32
Expand Down
Loading