diff --git a/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala b/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala new file mode 100644 index 000000000000..9384a0b43e8b --- /dev/null +++ b/presentation-compiler/src/main/dotty/tools/pc/ApplyArgsExtractor.scala @@ -0,0 +1,269 @@ +package dotty.tools.pc + +import scala.util.Try + +import dotty.tools.dotc.ast.Trees.ValDef +import dotty.tools.dotc.ast.tpd.* +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Flags.Method +import dotty.tools.dotc.core.Names.Name +import dotty.tools.dotc.core.StdNames.* +import dotty.tools.dotc.core.SymDenotations.NoDenotation +import dotty.tools.dotc.core.Symbols.defn +import dotty.tools.dotc.core.Symbols.NoSymbol +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.* +import dotty.tools.pc.IndexedContext +import dotty.tools.pc.utils.InteractiveEnrichments.* +import scala.annotation.tailrec +import dotty.tools.dotc.core.Denotations.SingleDenotation +import dotty.tools.dotc.core.Denotations.MultiDenotation +import dotty.tools.dotc.util.Spans.Span + +object ApplyExtractor: + def unapply(path: List[Tree])(using Context): Option[Apply] = + path match + case ValDef(_, _, _) :: Block(_, app: Apply) :: _ + if !app.fun.isInfix => Some(app) + case rest => + def getApplyForContextFunctionParam(path: List[Tree]): Option[Apply] = + path match + // fun(arg@@) + case (app: Apply) :: _ => Some(app) + // fun(arg@@), where fun(argn: Context ?=> SomeType) + // recursively matched for multiple context arguments, e.g. Context1 ?=> Context2 ?=> SomeType + case (_: DefDef) :: Block(List(_), _: Closure) :: rest => + getApplyForContextFunctionParam(rest) + case _ => None + for + app <- getApplyForContextFunctionParam(rest) + if !app.fun.isInfix + yield app + end match + + +object ApplyArgsExtractor: + def getArgsAndParams( + optIndexedContext: Option[IndexedContext], + apply: Apply, + span: Span + )(using Context): List[(List[Tree], List[ParamSymbol])] = + def collectArgss(a: Apply): List[List[Tree]] = + def stripContextFuntionArgument(argument: Tree): List[Tree] = + argument match + case Block(List(d: DefDef), _: Closure) => + d.rhs match + case app: Apply => + app.args + case b @ Block(List(_: DefDef), _: Closure) => + stripContextFuntionArgument(b) + case _ => Nil + case v => List(v) + + val args = a.args.flatMap(stripContextFuntionArgument) + a.fun match + case app: Apply => collectArgss(app) :+ args + case _ => List(args) + end collectArgss + + val method = apply.fun + + val argss = collectArgss(apply) + + def fallbackFindApply(sym: Symbol) = + sym.info.member(nme.apply) match + case NoDenotation => Nil + case den => List(den.symbol) + + // fallback for when multiple overloaded methods match the supplied args + def fallbackFindMatchingMethods() = + def matchingMethodsSymbols( + indexedContext: IndexedContext, + method: Tree + ): List[Symbol] = + method match + case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil) + case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil) + case sel @ Select(from, name) => + val symbol = from.symbol + val ownerSymbol = + if symbol.is(Method) && symbol.owner.isClass then + Some(symbol.owner) + else Try(symbol.info.classSymbol).toOption + ownerSymbol.map(sym => sym.info.member(name)).collect{ + case single: SingleDenotation => List(single.symbol) + case multi: MultiDenotation => multi.allSymbols + }.getOrElse(Nil) + case Apply(fun, _) => matchingMethodsSymbols(indexedContext, fun) + case _ => Nil + val matchingMethods = + for + indexedContext <- optIndexedContext.toList + potentialMatch <- matchingMethodsSymbols(indexedContext, method) + if potentialMatch.is(Flags.Method) && + potentialMatch.vparamss.length >= argss.length && + Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption + .getOrElse(false) && + potentialMatch.vparamss + .zip(argss) + .reverse + .zipWithIndex + .forall { case (pair, index) => + FuzzyArgMatcher(potentialMatch.tparams) + .doMatch(allArgsProvided = index != 0, span) + .tupled(pair) + } + yield potentialMatch + matchingMethods + end fallbackFindMatchingMethods + + val matchingMethods: List[Symbol] = + if method.symbol.paramSymss.nonEmpty then + val allArgsAreSupplied = + val vparamss = method.symbol.vparamss + vparamss.length == argss.length && vparamss + .zip(argss) + .lastOption + .exists { case (baseParams, baseArgs) => + baseArgs.length == baseParams.length + } + // ``` + // m(arg : Int) + // m(arg : Int, anotherArg : Int) + // m(a@@) + // ``` + // complier will choose the first `m`, so we need to manually look for the other one + if allArgsAreSupplied then + val foundPotential = fallbackFindMatchingMethods() + if foundPotential.contains(method.symbol) then foundPotential + else method.symbol :: foundPotential + else List(method.symbol) + else if method.symbol.is(Method) || method.symbol == NoSymbol then + fallbackFindMatchingMethods() + else fallbackFindApply(method.symbol) + end if + end matchingMethods + + matchingMethods.map { methodSym => + val vparamss = methodSym.vparamss + + // get params and args we are interested in + // e.g. + // in the following case, the interesting args and params are + // - params: [apple, banana] + // - args: [apple, b] + // ``` + // def curry(x: Int)(apple: String, banana: String) = ??? + // curry(1)(apple = "test", b@@) + // ``` + val (baseParams0, baseArgs) = + vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) + + val baseParams: List[ParamSymbol] = + def defaultBaseParams = baseParams0.map(JustSymbol(_)) + @tailrec + def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] = + if level > 0 then + val resultTypeOpt = + refinedType match + case RefinedType(AppliedType(_, args), _, _) => args.lastOption + case AppliedType(_, args) => args.lastOption + case _ => None + resultTypeOpt match + case Some(resultType) => getRefinedParams(resultType, level - 1) + case _ => defaultBaseParams + else + refinedType match + case RefinedType(AppliedType(_, args), _, MethodType(ri)) => + baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) => + RefinedSymbol(sym, name, arg) + } + case _ => defaultBaseParams + // finds param refinements for lambda expressions + // val hello: (x: Int, y: Int) => Unit = (x, _) => println(x) + @tailrec + def refineParams(method: Tree, level: Int): List[ParamSymbol] = + method match + case Select(Apply(f, _), _) => refineParams(f, level + 1) + case Select(h, name) => + // for Select(foo, name = apply) we want `foo.symbol` + if name == nme.apply then getRefinedParams(h.symbol.info, level) + else getRefinedParams(method.symbol.info, level) + case Apply(f, _) => + refineParams(f, level + 1) + case _ => getRefinedParams(method.symbol.info, level) + refineParams(method, 0) + end baseParams + (baseArgs, baseParams) + } + + extension (method: Symbol) + def vparamss(using Context) = method.filteredParamss(_.isTerm) + def tparams(using Context) = method.filteredParamss(_.isType).flatten + def filteredParamss(f: Symbol => Boolean)(using Context) = + method.paramSymss.filter(params => params.forall(f)) +sealed trait ParamSymbol: + def name: Name + def info: Type + def symbol: Symbol + def nameBackticked(using Context) = name.decoded.backticked + +case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol: + def name: Name = symbol.name + def info: Type = symbol.info + +case class RefinedSymbol(symbol: Symbol, name: Name, info: Type) + extends ParamSymbol + + +class FuzzyArgMatcher(tparams: List[Symbol])(using Context): + + /** + * A heuristic for checking if the passed arguments match the method's arguments' types. + * For non-polymorphic methods we use the subtype relation (`<:<`) + * and for polymorphic methods we use a heuristic. + * We check the args types not the result type. + */ + def doMatch( + allArgsProvided: Boolean, + span: Span + )(expectedArgs: List[Symbol], actualArgs: List[Tree]) = + (expectedArgs.length == actualArgs.length || + (!allArgsProvided && expectedArgs.length >= actualArgs.length)) && + actualArgs.zipWithIndex.forall { + case (arg: Ident, _) if arg.span.contains(span) => true + case (NamedArg(name, arg), _) => + expectedArgs.exists { expected => + expected.name == name && (!arg.hasType || arg.typeOpt.unfold + .fuzzyArg_<:<(expected.info)) + } + case (arg, i) => + !arg.hasType || arg.typeOpt.unfold.fuzzyArg_<:<(expectedArgs(i).info) + } + + extension (arg: Type) + def fuzzyArg_<:<(expected: Type) = + if tparams.isEmpty then arg <:< expected + else arg <:< substituteTypeParams(expected) + def unfold = + arg match + case arg: TermRef => arg.underlying + case e => e + + private def substituteTypeParams(t: Type): Type = + t match + case e if tparams.exists(_ == e.typeSymbol) => + val matchingParam = tparams.find(_ == e.typeSymbol).get + matchingParam.info match + case b @ TypeBounds(_, _) => WildcardType(b) + case _ => WildcardType + case o @ OrType(e1, e2) => + OrType(substituteTypeParams(e1), substituteTypeParams(e2), o.isSoft) + case AndType(e1, e2) => + AndType(substituteTypeParams(e1), substituteTypeParams(e2)) + case AppliedType(et, eparams) => + AppliedType(et, eparams.map(substituteTypeParams)) + case _ => t + +end FuzzyArgMatcher diff --git a/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala b/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala index 075167f3f5c1..63ca7fdf8641 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/InferExpectedType.scala @@ -1,15 +1,11 @@ package dotty.tools.pc -import dotty.tools.dotc.ast.tpd import dotty.tools.dotc.ast.tpd.* -import dotty.tools.dotc.core.Constants.Constant import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Flags import dotty.tools.dotc.core.StdNames -import dotty.tools.dotc.core.Symbols import dotty.tools.dotc.core.Symbols.defn import dotty.tools.dotc.core.Types.* -import dotty.tools.dotc.core.Types.Type import dotty.tools.dotc.interactive.Interactive import dotty.tools.dotc.interactive.InteractiveDriver import dotty.tools.dotc.typer.Applications.UnapplyArgs @@ -76,7 +72,7 @@ object InterCompletionType: case Try(block, _, _) :: rest if block.span.contains(span) => inferType(rest, span) case CaseDef(_, _, body) :: Try(_, cases, _) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span) case If(cond, _, _) :: rest if !cond.span.contains(span) => inferType(rest, span) - case If(cond, _, _) :: rest if cond.span.contains(span) => Some(Symbols.defn.BooleanType) + case If(cond, _, _) :: rest if cond.span.contains(span) => Some(defn.BooleanType) case CaseDef(_, _, body) :: Match(_, cases) :: rest if body.span.contains(span) && cases.exists(_.span.contains(span)) => inferType(rest, span) case NamedArg(_, arg) :: rest if arg.span.contains(span) => inferType(rest, span) @@ -97,39 +93,38 @@ object InterCompletionType: if ind < 0 then None else Some(UnapplyArgs(fun.tpe.finalResultType, fun, pats, NoSourcePosition).argTypes(ind)) // f(@@) - case (app: Apply) :: rest => - val param = - for { - ind <- app.args.zipWithIndex.collectFirst { - case (arg, id) if arg.span.contains(span) => id - } - params <- app.symbol.paramSymss.find(!_.exists(_.isTypeParam)) - param <- params.get(ind) - } yield param.info - param match - // def f[T](a: T): T = ??? - // f[Int](@@) - // val _: Int = f(@@) - case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) => - for { - (typeParams, args) <- - app match - case Apply(TypeApply(fun, args), _) => - val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam)) - typeParams.map((_, args.map(_.tpe))) - // val f: (j: "a") => Int - // f(@@) - case Apply(Select(v, StdNames.nme.apply), _) => - v.symbol.info match - case AppliedType(des, args) => - Some((des.typeSymbol.typeParams, args)) - case _ => None - case _ => None - ind = typeParams.indexOf(t.symbol) - tpe <- args.get(ind) - if !tpe.isErroneous - } yield tpe - case Some(tpe) => Some(tpe) - case _ => None + case ApplyExtractor(app) => + val argsAndParams = ApplyArgsExtractor.getArgsAndParams(None, app, span).headOption + argsAndParams.flatMap: + case (args, params) => + val idx = args.indexWhere(_.span.contains(span)) + val param = + if idx >= 0 && params.length > idx then Some(params(idx).info) + else None + param match + // def f[T](a: T): T = ??? + // f[Int](@@) + // val _: Int = f(@@) + case Some(t : TypeRef) if t.symbol.is(Flags.TypeParam) => + for + (typeParams, args) <- + app match + case Apply(TypeApply(fun, args), _) => + val typeParams = fun.symbol.paramSymss.headOption.filter(_.forall(_.isTypeParam)) + typeParams.map((_, args.map(_.tpe))) + // val f: (j: "a") => Int + // f(@@) + case Apply(Select(v, StdNames.nme.apply), _) => + v.symbol.info match + case AppliedType(des, args) => + Some((des.typeSymbol.typeParams, args)) + case _ => None + case _ => None + ind = typeParams.indexOf(t.symbol) + tpe <- args.get(ind) + if !tpe.isErroneous + yield tpe + case Some(tpe) => Some(tpe) + case _ => None case _ => None diff --git a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala index 7b88d8edfbc8..faf6d715d8cf 100644 --- a/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala +++ b/presentation-compiler/src/main/dotty/tools/pc/completions/NamedArgCompletions.scala @@ -2,37 +2,23 @@ package dotty.tools.pc.completions import scala.util.Try -import dotty.tools.dotc.ast.Trees.ValDef import dotty.tools.dotc.ast.tpd.* import dotty.tools.dotc.ast.untpd import dotty.tools.dotc.core.Constants.Constant -import dotty.tools.dotc.core.ContextOps.localContext import dotty.tools.dotc.core.Contexts.Context import dotty.tools.dotc.core.Flags -import dotty.tools.dotc.core.Flags.Method import dotty.tools.dotc.core.NameKinds.DefaultGetterName import dotty.tools.dotc.core.Names.Name import dotty.tools.dotc.core.StdNames.* -import dotty.tools.dotc.core.SymDenotations.NoDenotation -import dotty.tools.dotc.core.Symbols import dotty.tools.dotc.core.Symbols.defn -import dotty.tools.dotc.core.Symbols.NoSymbol import dotty.tools.dotc.core.Symbols.Symbol -import dotty.tools.dotc.core.Types.AndType -import dotty.tools.dotc.core.Types.AppliedType -import dotty.tools.dotc.core.Types.MethodType -import dotty.tools.dotc.core.Types.OrType -import dotty.tools.dotc.core.Types.RefinedType -import dotty.tools.dotc.core.Types.TermRef -import dotty.tools.dotc.core.Types.Type -import dotty.tools.dotc.core.Types.TypeBounds -import dotty.tools.dotc.core.Types.WildcardType +import dotty.tools.dotc.core.Types.* import dotty.tools.pc.IndexedContext import dotty.tools.pc.utils.InteractiveEnrichments.* import scala.annotation.tailrec -import dotty.tools.dotc.core.Denotations.Denotation -import dotty.tools.dotc.core.Denotations.MultiDenotation -import dotty.tools.dotc.core.Denotations.SingleDenotation +import dotty.tools.pc.ApplyArgsExtractor +import dotty.tools.pc.ParamSymbol +import dotty.tools.pc.ApplyExtractor object NamedArgCompletions: @@ -43,36 +29,13 @@ object NamedArgCompletions: clientSupportsSnippets: Boolean, )(using ctx: Context): List[CompletionValue] = path match - case (ident: Ident) :: ValDef(_, _, _) :: Block(_, app: Apply) :: _ - if !app.fun.isInfix => + case (ident: Ident) :: ApplyExtractor(app) => contribute( - Some(ident), + ident, app, indexedContext, clientSupportsSnippets, ) - case (ident: Ident) :: rest => - def getApplyForContextFunctionParam(path: List[Tree]): Option[Apply] = - path match - // fun(arg@@) - case (app: Apply) :: _ => Some(app) - // fun(arg@@), where fun(argn: Context ?=> SomeType) - // recursively matched for multiple context arguments, e.g. Context1 ?=> Context2 ?=> SomeType - case (_: DefDef) :: Block(List(_), _: Closure) :: rest => - getApplyForContextFunctionParam(rest) - case _ => None - val contribution = - for - app <- getApplyForContextFunctionParam(rest) - if !app.fun.isInfix - yield - contribute( - Some(ident), - app, - indexedContext, - clientSupportsSnippets, - ) - contribution.getOrElse(Nil) case (app: Apply) :: _ => /** * def foo(aaa: Int, bbb: Int, ccc: Int) = ??? @@ -86,7 +49,7 @@ object NamedArgCompletions: untypedPath match case (ident: Ident) :: (app: Apply) :: _ => contribute( - Some(ident), + ident, app, indexedContext, clientSupportsSnippets, @@ -99,7 +62,7 @@ object NamedArgCompletions: end contribute private def contribute( - ident: Option[Ident], + ident: Ident, apply: Apply, indexedContext: IndexedContext, clientSupportsSnippets: Boolean, @@ -110,155 +73,14 @@ object NamedArgCompletions: case Literal(Constant(null)) => true // nullLiteral case _ => false - def collectArgss(a: Apply): List[List[Tree]] = - def stripContextFuntionArgument(argument: Tree): List[Tree] = - argument match - case Block(List(d: DefDef), _: Closure) => - d.rhs match - case app: Apply => - app.args - case b @ Block(List(_: DefDef), _: Closure) => - stripContextFuntionArgument(b) - case _ => Nil - case v => List(v) - - val args = a.args.flatMap(stripContextFuntionArgument) - a.fun match - case app: Apply => collectArgss(app) :+ args - case _ => List(args) - end collectArgss - - val method = apply.fun - - val argss = collectArgss(apply) - - def fallbackFindApply(sym: Symbol) = - sym.info.member(nme.apply) match - case NoDenotation => Nil - case den => List(den.symbol) - - // fallback for when multiple overloaded methods match the supplied args - def fallbackFindMatchingMethods() = - def matchingMethodsSymbols( - method: Tree - ): List[Symbol] = - method match - case Ident(name) => indexedContext.findSymbol(name).getOrElse(Nil) - case Select(This(_), name) => indexedContext.findSymbol(name).getOrElse(Nil) - case sel @ Select(from, name) => - val symbol = from.symbol - val ownerSymbol = - if symbol.is(Method) && symbol.owner.isClass then - Some(symbol.owner) - else Try(symbol.info.classSymbol).toOption - ownerSymbol.map(sym => sym.info.member(name)).collect{ - case single: SingleDenotation => List(single.symbol) - case multi: MultiDenotation => multi.allSymbols - }.getOrElse(Nil) - case Apply(fun, _) => matchingMethodsSymbols(fun) - case _ => Nil - val matchingMethods = - for - potentialMatch <- matchingMethodsSymbols(method) - if potentialMatch.is(Flags.Method) && - potentialMatch.vparamss.length >= argss.length && - Try(potentialMatch.isAccessibleFrom(apply.symbol.info)).toOption - .getOrElse(false) && - potentialMatch.vparamss - .zip(argss) - .reverse - .zipWithIndex - .forall { case (pair, index) => - FuzzyArgMatcher(potentialMatch.tparams) - .doMatch(allArgsProvided = index != 0, ident) - .tupled(pair) - } - yield potentialMatch - matchingMethods - end fallbackFindMatchingMethods - - val matchingMethods: List[Symbols.Symbol] = - if method.symbol.paramSymss.nonEmpty then - val allArgsAreSupplied = - val vparamss = method.symbol.vparamss - vparamss.length == argss.length && vparamss - .zip(argss) - .lastOption - .exists { case (baseParams, baseArgs) => - baseArgs.length == baseParams.length - } - // ``` - // m(arg : Int) - // m(arg : Int, anotherArg : Int) - // m(a@@) - // ``` - // complier will choose the first `m`, so we need to manually look for the other one - if allArgsAreSupplied then - val foundPotential = fallbackFindMatchingMethods() - if foundPotential.contains(method.symbol) then foundPotential - else method.symbol :: foundPotential - else List(method.symbol) - else if method.symbol.is(Method) || method.symbol == NoSymbol then - fallbackFindMatchingMethods() - else fallbackFindApply(method.symbol) - end if - end matchingMethods - - val allParams = matchingMethods.flatMap { methodSym => - val vparamss = methodSym.vparamss + val argsAndParams = ApplyArgsExtractor.getArgsAndParams( + Some(indexedContext), + apply, + ident.span + ) - // get params and args we are interested in - // e.g. - // in the following case, the interesting args and params are - // - params: [apple, banana] - // - args: [apple, b] - // ``` - // def curry(x: Int)(apple: String, banana: String) = ??? - // curry(1)(apple = "test", b@@) - // ``` - val (baseParams0, baseArgs) = - vparamss.zip(argss).lastOption.getOrElse((Nil, Nil)) - - val baseParams: List[ParamSymbol] = - def defaultBaseParams = baseParams0.map(JustSymbol(_)) - @tailrec - def getRefinedParams(refinedType: Type, level: Int): List[ParamSymbol] = - if level > 0 then - val resultTypeOpt = - refinedType match - case RefinedType(AppliedType(_, args), _, _) => args.lastOption - case AppliedType(_, args) => args.lastOption - case _ => None - resultTypeOpt match - case Some(resultType) => getRefinedParams(resultType, level - 1) - case _ => defaultBaseParams - else - refinedType match - case RefinedType(AppliedType(_, args), _, MethodType(ri)) => - baseParams0.zip(ri).zip(args).map { case ((sym, name), arg) => - RefinedSymbol(sym, name, arg) - } - case _ => defaultBaseParams - // finds param refinements for lambda expressions - // val hello: (x: Int, y: Int) => Unit = (x, _) => println(x) - @tailrec - def refineParams(method: Tree, level: Int): List[ParamSymbol] = - method match - case Select(Apply(f, _), _) => refineParams(f, level + 1) - case Select(h, name) => - // for Select(foo, name = apply) we want `foo.symbol` - if name == nme.apply then getRefinedParams(h.symbol.info, level) - else getRefinedParams(method.symbol.info, level) - case Apply(f, _) => - refineParams(f, level + 1) - case _ => getRefinedParams(method.symbol.info, level) - refineParams(method, 0) - end baseParams - - val args = ident - .map(i => baseArgs.filterNot(_ == i)) - .getOrElse(baseArgs) - .filterNot(isUselessLiteral) + val allParams = argsAndParams.flatMap { case (baseArgs, baseParams) => + val args = baseArgs.filterNot( a => a == ident || isUselessLiteral(a)) @tailrec def isDefaultArg(t: Tree): Boolean = t match @@ -293,9 +115,8 @@ object NamedArgCompletions: ) } - val prefix = ident - .map(_.name.toString) - .getOrElse("") + val prefix = + ident.name.toString .replace(Cursor.value, "") .nn @@ -330,7 +151,7 @@ object NamedArgCompletions: allParams.exists(param => param.name.startsWith(prefix)) def isExplicitlyCalled = suffix.startsWith(prefix) def hasParamsToFill = allParams.count(!_.symbol.is(Flags.HasDefault)) > 1 - if clientSupportsSnippets && matchingMethods.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill + if clientSupportsSnippets && argsAndParams.length == 1 && (shouldShow || isExplicitlyCalled) && hasParamsToFill then val editText = allParams.zipWithIndex .collect { @@ -375,73 +196,4 @@ object NamedArgCompletions: ) ::: findPossibleDefaults() ::: fillAllFields() end contribute - extension (method: Symbols.Symbol) - def vparamss(using Context) = method.filteredParamss(_.isTerm) - def tparams(using Context) = method.filteredParamss(_.isType).flatten - def filteredParamss(f: Symbols.Symbol => Boolean)(using Context) = - method.paramSymss.filter(params => params.forall(f)) end NamedArgCompletions - -class FuzzyArgMatcher(tparams: List[Symbols.Symbol])(using Context): - - /** - * A heuristic for checking if the passed arguments match the method's arguments' types. - * For non-polymorphic methods we use the subtype relation (`<:<`) - * and for polymorphic methods we use a heuristic. - * We check the args types not the result type. - */ - def doMatch( - allArgsProvided: Boolean, - ident: Option[Ident] - )(expectedArgs: List[Symbols.Symbol], actualArgs: List[Tree]) = - (expectedArgs.length == actualArgs.length || - (!allArgsProvided && expectedArgs.length >= actualArgs.length)) && - actualArgs.zipWithIndex.forall { - case (arg: Ident, _) if ident.contains(arg) => true - case (NamedArg(name, arg), _) => - expectedArgs.exists { expected => - expected.name == name && (!arg.hasType || arg.typeOpt.unfold - .fuzzyArg_<:<(expected.info)) - } - case (arg, i) => - !arg.hasType || arg.typeOpt.unfold.fuzzyArg_<:<(expectedArgs(i).info) - } - - extension (arg: Type) - def fuzzyArg_<:<(expected: Type) = - if tparams.isEmpty then arg <:< expected - else arg <:< substituteTypeParams(expected) - def unfold = - arg match - case arg: TermRef => arg.underlying - case e => e - - private def substituteTypeParams(t: Type): Type = - t match - case e if tparams.exists(_ == e.typeSymbol) => - val matchingParam = tparams.find(_ == e.typeSymbol).get - matchingParam.info match - case b @ TypeBounds(_, _) => WildcardType(b) - case _ => WildcardType - case o @ OrType(e1, e2) => - OrType(substituteTypeParams(e1), substituteTypeParams(e2), o.isSoft) - case AndType(e1, e2) => - AndType(substituteTypeParams(e1), substituteTypeParams(e2)) - case AppliedType(et, eparams) => - AppliedType(et, eparams.map(substituteTypeParams)) - case _ => t - -end FuzzyArgMatcher - -sealed trait ParamSymbol: - def name: Name - def info: Type - def symbol: Symbol - def nameBackticked(using Context) = name.decoded.backticked - -case class JustSymbol(symbol: Symbol)(using Context) extends ParamSymbol: - def name: Name = symbol.name - def info: Type = symbol.info - -case class RefinedSymbol(symbol: Symbol, name: Name, info: Type) - extends ParamSymbol diff --git a/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala b/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala index ccdc68ef1cad..ba96488471b6 100644 --- a/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala +++ b/presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala @@ -290,3 +290,48 @@ class InferExpectedTypeSuite extends BasePCSuite: """|C |""".stripMargin // ideally A ) + + @Test def `multiple-args-lists` = + check( + """|def m(i: Int)(s: String) = ??? + |val x = m(@@) + |""".stripMargin, + """|Int + |""".stripMargin + ) + + @Test def `multiple-args-lists-2` = + check( + """|def m(i: Int)(s: String) = ??? + |val x = m(1)(@@) + |""".stripMargin, + """|String + |""".stripMargin + ) + + @Test def `extension-methods` = + check( + """|extension (i: Int) { + | def method(s: String): Unit = () + |} + | + |def testIt = + | 7.method(@@) + |""".stripMargin, + """|String + |""".stripMargin + ) + + @Test def `implicit-methods` = + check( + """|object I + |implicit class Xtension(i: I.type) { + | def method(s: String): Unit = () + |} + | + |def testIt = + | I.method(@@) + |""".stripMargin, + """|String + |""".stripMargin + )