Skip to content

Commit 6d470cd

Browse files
committed
fix: handle multiple params lists in for infer type
1 parent 782c539 commit 6d470cd

File tree

4 files changed

+85
-43
lines changed

4 files changed

+85
-43
lines changed

presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ object ApplyExtractor:
4545

4646
object ApplyArgsExtractor:
4747
def getArgsAndParams(
48-
indexedContext: IndexedContext,
48+
optIndexedContext: Option[IndexedContext],
4949
apply: Apply,
5050
span: Span
5151
)(using Context): List[(List[Tree], List[ParamSymbol])] =
@@ -79,6 +79,7 @@ object ApplyArgsExtractor:
7979
// fallback for when multiple overloaded methods match the supplied args
8080
def fallbackFindMatchingMethods() =
8181
def matchingMethodsSymbols(
82+
indexedContext: IndexedContext,
8283
method: Tree
8384
): List[Symbol] =
8485
method match
@@ -94,11 +95,12 @@ object ApplyArgsExtractor:
9495
case single: SingleDenotation => List(single.symbol)
9596
case multi: MultiDenotation => multi.allSymbols
9697
}.getOrElse(Nil)
97-
case Apply(fun, _) => matchingMethodsSymbols(fun)
98+
case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun)
9899
case _ => Nil
99100
val matchingMethods =
100101
for
101-
potentialMatch <- matchingMethodsSymbols(method)
102+
indexedContext <- optIndexedContext.toList
103+
potentialMatch <- matchingMethodsSymbols(indexedContext, method)
102104
if potentialMatch.is(Flags.Method) &&
103105
potentialMatch.vparamss.length >= argss.length &&
104106
Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption

presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala

Lines changed: 34 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,11 @@
11
package dotty.tools.pc
22

3-
import dotty.tools.dotc.ast.tpd
43
import dotty.tools.dotc.ast.tpd.*
5-
import dotty.tools.dotc.core.Constants.Constant
64
import dotty.tools.dotc.core.Contexts.Context
75
import dotty.tools.dotc.core.Flags
86
import dotty.tools.dotc.core.StdNames
9-
import dotty.tools.dotc.core.Symbols
107
import dotty.tools.dotc.core.Symbols.defn
118
import dotty.tools.dotc.core.Types.*
12-
import dotty.tools.dotc.core.Types.Type
139
import dotty.tools.dotc.interactive.Interactive
1410
import dotty.tools.dotc.interactive.InteractiveDriver
1511
import dotty.tools.dotc.typer.Applications.UnapplyArgs
@@ -76,7 +72,7 @@ object InterCompletionType:
7672
case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span)
7773
case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span)
7874
case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span)
79-
case If(cond, _, _) :: rest if cond.span.contains(span) => Some(Symbols.defn.BooleanType)
75+
case If(cond, _, _) :: rest if cond.span.contains(span) => Some(defn.BooleanType)
8076
case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) =>
8177
inferType(rest, span)
8278
case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span)
@@ -97,39 +93,38 @@ object InterCompletionType:
9793
if ind < 0 then None
9894
else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind))
9995
// f(@@)
100-
case (app: Apply) :: rest =>
101-
val param =
102-
for {
103-
ind <- app.args.zipWithIndex.collectFirst {
104-
case (arg, id) if arg.span.contains(span) => id
105-
}
106-
params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam))
107-
param <- params.get(ind)
108-
} yield param.info
109-
param match
110-
// def f[T](a: T): T = ???
111-
// f[Int](@@)
112-
// val _: Int = f(@@)
113-
case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) =>
114-
for {
115-
(typeParams, args) <-
116-
app match
117-
case Apply(TypeApply(fun, args), _) =>
118-
val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
119-
typeParams.map((_, args.map(_.tpe)))
120-
// val f: (j: "a") => Int
121-
// f(@@)
122-
case Apply(Select(v, StdNames.nme.apply), _) =>
123-
v.symbol.info match
124-
case AppliedType(des, args) =>
125-
Some((des.typeSymbol.typeParams, args))
126-
case _ => None
127-
case _ => None
128-
ind = typeParams.indexOf(t.symbol)
129-
tpe <- args.get(ind)
130-
if !tpe.isErroneous
131-
} yield tpe
132-
case Some(tpe) => Some(tpe)
133-
case _ => None
96+
case ApplyExtractor(app) =>
97+
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption
98+
argsAndParams.flatMap:
99+
case (args, params) =>
100+
val idx = args.indexWhere(_.span.contains(span))
101+
val param =
102+
if idx >= 0 && params.length > idx then Some(params(idx).info)
103+
else None
104+
param match
105+
// def f[T](a: T): T = ???
106+
// f[Int](@@)
107+
// val _: Int = f(@@)
108+
case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) =>
109+
for
110+
(typeParams, args) <-
111+
app match
112+
case Apply(TypeApply(fun, args), _) =>
113+
val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam))
114+
typeParams.map((_, args.map(_.tpe)))
115+
// val f: (j: "a") => Int
116+
// f(@@)
117+
case Apply(Select(v, StdNames.nme.apply), _) =>
118+
v.symbol.info match
119+
case AppliedType(des, args) =>
120+
Some((des.typeSymbol.typeParams, args))
121+
case _ => None
122+
case _ => None
123+
ind = typeParams.indexOf(t.symbol)
124+
tpe <- args.get(ind)
125+
if !tpe.isErroneous
126+
yield tpe
127+
case Some(tpe) => Some(tpe)
128+
case _ => None
134129
case _ => None
135130

presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ object NamedArgCompletions:
7474
case _ => false
7575

7676
val argsAndParams = ApplyArgsExtractor.getArgsAndParams(
77-
indexedContext,
77+
Some(indexedContext),
7878
apply,
7979
ident.span
8080
)

presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,3 +290,48 @@ class InferExpectedTypeSuite extends BasePCSuite:
290290
"""|C
291291
|""".stripMargin // ideally A
292292
)
293+
294+
@Test def `multiple-args-lists` =
295+
check(
296+
"""|def m(i: Int)(s: String) = ???
297+
|val x = m(@@)
298+
|""".stripMargin,
299+
"""|Int
300+
|""".stripMargin
301+
)
302+
303+
@Test def `multiple-args-lists-2` =
304+
check(
305+
"""|def m(i: Int)(s: String) = ???
306+
|val x = m(1)(@@)
307+
|""".stripMargin,
308+
"""|String
309+
|""".stripMargin
310+
)
311+
312+
@Test def `extension-methods` =
313+
check(
314+
"""|extension (i: Int) {
315+
| def method(s: String): Unit = ()
316+
|}
317+
|
318+
|def testIt =
319+
| 7.method(@@)
320+
|""".stripMargin,
321+
"""|String
322+
|""".stripMargin
323+
)
324+
325+
@Test def `implicit-methods` =
326+
check(
327+
"""|object I
328+
|implicit class Xtension(i: I.type) {
329+
| def method(s: String): Unit = ()
330+
|}
331+
|
332+
|def testIt =
333+
| I.method(@@)
334+
|""".stripMargin,
335+
"""|String
336+
|""".stripMargin
337+
)

0 commit comments

Comments
 (0)