Skip to content

Rust: Implement type inference for trait objects/dyn types #20084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion rust/ql/.generated.list

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion rust/ql/.gitattributes

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
// generated by codegen, remove this comment if you wish to edit this file
/**
* This module provides a hand-modifiable wrapper around the generated class `DynTraitTypeRepr`.
*
Expand All @@ -12,6 +11,10 @@ private import codeql.rust.elements.internal.generated.DynTraitTypeRepr
* be referenced directly.
*/
module Impl {
private import rust
private import codeql.rust.internal.PathResolution as PathResolution

// the following QLdoc is generated: if you need to edit it, do it in the schema file
/**
* A dynamic trait object type.
*
Expand All @@ -21,5 +24,16 @@ module Impl {
* // ^^^^^^^^^
* ```
*/
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr { }
class DynTraitTypeRepr extends Generated::DynTraitTypeRepr {
/** Gets the trait that this trait object refers to. */
pragma[nomagic]
Trait getTrait() {
result =
PathResolution::resolvePath(this.getTypeBoundList()
.getBound(0)
.getTypeRepr()
.(PathTypeRepr)
.getPath())
}
}
}
43 changes: 43 additions & 0 deletions rust/ql/lib/codeql/rust/internal/Type.qll
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,14 @@ newtype TType =
TArrayType() or // todo: add size?
TRefType() or // todo: add mut?
TImplTraitType(ImplTraitTypeRepr impl) or
TDynTraitType(Trait t) { t = any(DynTraitTypeRepr dt).getTrait() } or
TSliceType() or
TTypeParamTypeParameter(TypeParam t) or
TAssociatedTypeTypeParameter(TypeAlias t) { any(TraitItemNode trait).getAnAssocItem() = t } or
TArrayTypeParameter() or
TDynTraitTypeParameter(TypeParam tp) {
tp = any(DynTraitTypeRepr dt).getTrait().getGenericParamList().getAGenericParam()
} or
TRefTypeParameter() or
TSelfTypeParameter(Trait t) or
TSliceTypeParameter()
Expand Down Expand Up @@ -226,6 +230,26 @@ class ImplTraitType extends Type, TImplTraitType {
override Location getLocation() { result = impl.getLocation() }
}

class DynTraitType extends Type, TDynTraitType {
Trait trait;

DynTraitType() { this = TDynTraitType(trait) }

override StructField getStructField(string name) { none() }

override TupleField getTupleField(int i) { none() }

override DynTraitTypeParameter getTypeParameter(int i) {
result = TDynTraitTypeParameter(trait.getGenericParamList().getTypeParam(i))
}

Trait getTrait() { result = trait }

override string toString() { result = "dyn " + trait.getName().toString() }

override Location getLocation() { result = trait.getLocation() }
}

/**
* An [impl Trait in return position][1] type, for example:
*
Expand Down Expand Up @@ -336,6 +360,18 @@ class ArrayTypeParameter extends TypeParameter, TArrayTypeParameter {
override Location getLocation() { result instanceof EmptyLocation }
}

class DynTraitTypeParameter extends TypeParameter, TDynTraitTypeParameter {
private TypeParam typeParam;

DynTraitTypeParameter() { this = TDynTraitTypeParameter(typeParam) }

TypeParam getTypeParam() { result = typeParam }

override string toString() { result = "dyn(" + typeParam.toString() + ")" }

override Location getLocation() { result = typeParam.getLocation() }
}

/** An implicit reference type parameter. */
class RefTypeParameter extends TypeParameter, TRefTypeParameter {
override string toString() { result = "&T" }
Expand Down Expand Up @@ -420,6 +456,13 @@ final class ImplTypeAbstraction extends TypeAbstraction, Impl {
}
}

final class DynTypeAbstraction extends TypeAbstraction, DynTraitTypeRepr {
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() =
this.getTrait().getGenericParamList().getATypeParam()
}
}

final class TraitTypeAbstraction extends TypeAbstraction, Trait {
override TypeParameter getATypeParameter() {
result.(TypeParamTypeParameter).getTypeParam() = this.getGenericParamList().getATypeParam()
Expand Down
31 changes: 29 additions & 2 deletions rust/ql/lib/codeql/rust/internal/TypeInference.qll
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ private module Input1 implements InputSig1<Location> {
id = 2
or
kind = 1 and
id = idOfTypeParameterAstNode(tp0.(DynTraitTypeParameter).getTypeParam())
or
kind = 2 and
exists(AstNode node | id = idOfTypeParameterAstNode(node) |
node = tp0.(TypeParamTypeParameter).getTypeParam() or
node = tp0.(AssociatedTypeTypeParameter).getTypeAlias() or
Expand Down Expand Up @@ -182,6 +185,14 @@ private module Input2 implements InputSig2 {
condition = impl and
constraint = impl.getTypeBoundList().getABound().getTypeRepr()
)
or
// a `dyn Trait` type implements `Trait`. See the comment on
// `DynTypeBoundListMention` for further details.
exists(DynTraitTypeRepr object |
abs = object and
condition = object.getTypeBoundList() and
constraint = object.getTrait()
)
}
}

Expand Down Expand Up @@ -1655,10 +1666,16 @@ private Function getMethodFromImpl(MethodCall mc) {

bindingset[trait, name]
pragma[inline_late]
private Function getTraitMethod(ImplTraitReturnType trait, string name) {
private Function getImplTraitMethod(ImplTraitReturnType trait, string name) {
result = getMethodSuccessor(trait.getImplTraitTypeRepr(), name)
}

bindingset[traitObject, name]
pragma[inline_late]
private Function getDynTraitMethod(DynTraitType traitObject, string name) {
result = getMethodSuccessor(traitObject.getTrait(), name)
}

pragma[nomagic]
private Function resolveMethodCallTarget(MethodCall mc) {
// The method comes from an `impl` block targeting the type of the receiver.
Expand All @@ -1669,7 +1686,10 @@ private Function resolveMethodCallTarget(MethodCall mc) {
result = getTypeParameterMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is an `impl Trait` type.
result = getTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
result = getImplTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
or
// The type of the receiver is a trait object `dyn Trait` type.
result = getDynTraitMethod(mc.getTypeAt(TypePath::nil()), mc.getMethodName())
}

pragma[nomagic]
Expand Down Expand Up @@ -2006,6 +2026,13 @@ private module Debug {
result = resolveCallTarget(c)
}

predicate debugConditionSatisfiesConstraint(
TypeAbstraction abs, TypeMention condition, TypeMention constraint
) {
abs = getRelevantLocatable() and
Input2::conditionSatisfiesConstraint(abs, condition, constraint)
}

predicate debugInferImplicitSelfType(SelfParam self, TypePath path, Type t) {
self = getRelevantLocatable() and
t = inferImplicitSelfType(self, path)
Expand Down
61 changes: 61 additions & 0 deletions rust/ql/lib/codeql/rust/internal/TypeMention.qll
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,64 @@ class SelfTypeParameterMention extends TypeMention instanceof Name {
result = TSelfTypeParameter(trait)
}
}

class DynTraitTypeReprMention extends TypeMention instanceof DynTraitTypeRepr {
private DynTraitType dynType;

DynTraitTypeReprMention() {
// This excludes `DynTraitTypeRepr` elements where `getTrait` is not
// defined, i.e., where path resolution can't find a trait.
dynType.getTrait() = super.getTrait()
}

override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result = dynType
or
exists(DynTraitTypeParameter tp, TypePath path0, TypePath suffix |
tp = dynType.getTypeParameter(_) and
path = TypePath::cons(tp, suffix) and
result = super.getTypeBoundList().getBound(0).getTypeRepr().(TypeMention).resolveTypeAt(path0) and
path0.isCons(TTypeParamTypeParameter(tp.getTypeParam()), suffix)
)
}
}

// We want a type of the form `dyn Trait` to implement `Trait`. If `Trait` has
// type parameters then `dyn Trait` has equivalent type parameters and the
// implementation should be abstracted over them.
//
// Intuitively we want something to the effect of:
// ```
// impl<A, B, ..> Trait<A, B, ..> for (dyn Trait)<A, B, ..>
// ```
// To achieve this:
// - `DynTypeAbstraction` is an abstraction over type parameters of the trait.
// - `DynTypeBoundListMention` (this class) is a type mention which has `dyn
// Trait` at the root and which for every type parameter of `dyn Trait` has the
// corresponding type parameter of the trait.
// - `TraitMention` (which is used for other things as well) is a type mention
// for the trait applied to its own type parameters.
//
// We arbitrarily use the `TypeBoundList` inside `DynTraitTypeRepr` to encode
// this type mention, since it doesn't syntactically appear in the AST. This
// works because there is a one-to-one correspondence between a trait object and
// its list of type bounds.
class DynTypeBoundListMention extends TypeMention instanceof TypeBoundList {
private Trait trait;

DynTypeBoundListMention() {
exists(DynTraitTypeRepr dyn | this = dyn.getTypeBoundList() and trait = dyn.getTrait())
}

override Type resolveTypeAt(TypePath path) {
path.isEmpty() and
result.(DynTraitType).getTrait() = trait
or
exists(TypeParam param |
param = trait.getGenericParamList().getATypeParam() and
path = TypePath::singleton(TDynTraitTypeParameter(param)) and
result = TTypeParamTypeParameter(param)
)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
category: minorAnalysis
---
* Type inference now supports trait objects, i.e., `dyn Trait` types.
67 changes: 67 additions & 0 deletions rust/ql/test/library-tests/type-inference/dyn_type.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Test cases for type inference and method resolution with `dyn` types

use std::fmt::Debug;

trait MyTrait1 {
// MyTrait1::m
fn m(&self) -> String;
}

trait GenericGet<A> {
// GenericGet::get
fn get(&self) -> A;
}

#[derive(Clone, Debug)]
struct MyStruct {
value: i32,
}

impl MyTrait1 for MyStruct {
// MyStruct1::m
fn m(&self) -> String {
format!("MyTrait1: {}", self.value) // $ fieldof=MyStruct
}
}

#[derive(Clone, Debug)]
struct GenStruct<A: Clone + Debug> {
value: A,
}

impl<A: Clone + Debug> GenericGet<A> for GenStruct<A> {
// GenStruct<A>::get
fn get(&self) -> A {
self.value.clone() // $ fieldof=GenStruct target=clone
}
}

fn get_a<A, G: GenericGet<A> + ?Sized>(a: &G) -> A {
a.get() // $ target=GenericGet::get
}

fn get_box_trait<A: Clone + Debug + 'static>(a: A) -> Box<dyn GenericGet<A>> {
Box::new(GenStruct { value: a }) // $ target=new
}

fn test_basic_dyn_trait(obj: &dyn MyTrait1) {
let _result = (*obj).m(); // $ target=deref target=MyTrait1::m type=_result:String
}

fn test_generic_dyn_trait(obj: &dyn GenericGet<String>) {
let _result1 = (*obj).get(); // $ target=deref target=GenericGet::get type=_result1:String
let _result2 = get_a(obj); // $ target=get_a type=_result2:String
}

fn test_poly_dyn_trait() {
let obj = get_box_trait(true); // $ target=get_box_trait
let _result = (*obj).get(); // $ target=deref target=GenericGet::get type=_result:bool
}

pub fn test() {
test_basic_dyn_trait(&MyStruct { value: 42 }); // $ target=test_basic_dyn_trait
test_generic_dyn_trait(&GenStruct {
value: "".to_string(),
}); // $ target=test_generic_dyn_trait
test_poly_dyn_trait(); // $ target=test_poly_dyn_trait
}
8 changes: 5 additions & 3 deletions rust/ql/test/library-tests/type-inference/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2265,8 +2265,6 @@ mod loops {
}
}

mod dereference;

mod explicit_type_args {
struct S1<T>(T);

Expand Down Expand Up @@ -2418,6 +2416,9 @@ mod closures {
}
}

mod dereference;
mod dyn_type;

fn main() {
field_access::f(); // $ target=f
method_impl::f(); // $ target=f
Expand Down Expand Up @@ -2448,5 +2449,6 @@ fn main() {
dereference::test(); // $ target=test
pattern_matching::test_all_patterns(); // $ target=test_all_patterns
pattern_matching_experimental::box_patterns(); // $ target=box_patterns
closures::f() // $ target=f
closures::f(); // $ target=f
dyn_type::test(); // $ target=test
}
Loading