diff --git a/tasty-query/shared/src/main/scala/tastyquery/Flags.scala b/tasty-query/shared/src/main/scala/tastyquery/Flags.scala index 6ead2781..ef857437 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Flags.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Flags.scala @@ -55,6 +55,7 @@ private[tastyquery] object Flags: val Extension: Flag = newFlag("Extension") val Final: Flag = newFlag("Final") val Given: Flag = newFlag("Given") + val HasDefault: Flag = newFlag("HasDefault") val Implicit: Flag = newFlag("Implicit") val Infix: Flag = newFlag("Infix") val Inline: Flag = newFlag("Inline") diff --git a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala index 870b48b3..6e742b25 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala @@ -583,6 +583,17 @@ object Symbols { /** Is this symbol an exporter generated by an `export` statement? */ final def isExport: Boolean = flags.is(Exported) + /** Is this symbol a method with at least one parameter with a default value? */ + final def hasParamWithDefault: Boolean = + paramSymss.exists { + case Left(termParams) => termParams.exists(_.isParamWithDefault) + case Right(value) => false + } + end hasParamWithDefault + + /** Is this symbol a method parameter with a default value? */ + final def isParamWithDefault: Boolean = flags.isAllOf(HasDefault) + /** Get the module class of this module value definition, if it exists: * - for `object val C` => `object class C[$]` */ diff --git a/tasty-query/shared/src/main/scala/tastyquery/reader/pickles/PickleReader.scala b/tasty-query/shared/src/main/scala/tastyquery/reader/pickles/PickleReader.scala index ef2d988b..0b45ddaf 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/reader/pickles/PickleReader.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/reader/pickles/PickleReader.scala @@ -487,7 +487,7 @@ private[pickles] class PickleReader { if pickleFlags.isStable then flags |= StableRealizable if pickleFlags.isStatic then flags |= Static if pickleFlags.isCaseAccessor then flags |= CaseAccessor - // if pickleFlags.hasDefault then flags |= HasDefault + if pickleFlags.hasDefault then flags |= HasDefault if pickleFlags.isTrait then flags |= Trait // if pickleFlags.isBridge then flags |= Bridge if pickleFlags.isAccessor then flags |= Accessor diff --git a/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala b/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala index 3ab6c05c..cfead58b 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/reader/tasties/TreeUnpickler.scala @@ -312,7 +312,7 @@ private[tasties] class TreeUnpickler private ( case CASEaccessor => addFlag(CaseAccessor) case COVARIANT => addFlag(Covariant) case CONTRAVARIANT => addFlag(Contravariant) - case HASDEFAULT => ignoreFlag() + case HASDEFAULT => addFlag(HasDefault) case STABLE => addFlag(StableRealizable) case EXTENSION => addFlag(Extension) case GIVEN => addFlag(Given) diff --git a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala index db20727c..4be6a06e 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala @@ -1051,18 +1051,45 @@ class TypeSuite extends UnrestrictedUnpicklingSuite { throw AssertionError(s"unexpected type $tpe") } - testWithContext("scala-2-default-params") { + testWithContext("default-params") { + extension (sym: TermSymbol) def paramCount: Int = sym.declaredType.asInstanceOf[MethodType].paramNames.size + val DefaultParamsClass = ctx.findTopLevelClass("simple_trees.DefaultParams") assert(clue(DefaultParamsClass.getNonOverloadedDecl(DefaultGetterName(termName("foo"), 0))).isEmpty) DefaultParamsClass.findNonOverloadedDecl(DefaultGetterName(termName("foo"), 1)) DefaultParamsClass.findNonOverloadedDecl(DefaultGetterName(termName("foo"), 2)) assert(clue(DefaultParamsClass.getNonOverloadedDecl(DefaultGetterName(termName("foo"), 3))).isEmpty) + val fooOverloads = DefaultParamsClass.findAllOverloadedDecls(termName("foo")) + + val fooWithDefaults = fooOverloads.find(_.paramCount == 3).get + assert(clue(fooWithDefaults.hasParamWithDefault) && !clue(fooWithDefaults.isParamWithDefault)) + val List(Left(fooWithDefaultsParams)) = fooWithDefaults.paramSymss: @unchecked + assert(clue(fooWithDefaultsParams.map(_.isParamWithDefault)) == List(false, true, true)) + assert(clue(fooWithDefaultsParams.map(_.hasParamWithDefault)).forall(_ == false)) + + for fooOverload <- fooOverloads if fooOverload ne fooWithDefaults do + assert(!clue(fooOverload.hasParamWithDefault) && !clue(fooOverload.isParamWithDefault)) + val List(Left(params)) = fooOverload.paramSymss: @unchecked + assert(clue(params.map(_.isParamWithDefault)).forall(_ == false)) + assert(clue(params.map(_.hasParamWithDefault)).forall(_ == false)) + } + + testWithContext("default-params-scala-2") { val IteratorClass = ctx.findTopLevelClass("scala.collection.Iterator") assert(clue(IteratorClass.getNonOverloadedDecl(DefaultGetterName(termName("indexWhere"), 0))).isEmpty) IteratorClass.findNonOverloadedDecl(DefaultGetterName(termName("indexWhere"), 1)) assert(clue(IteratorClass.getNonOverloadedDecl(DefaultGetterName(termName("indexWhere"), 2))).isEmpty) + locally { + val indexWhere = IteratorClass.findNonOverloadedDecl(termName("indexWhere")) + assert(clue(indexWhere.hasParamWithDefault) && !clue(indexWhere.isParamWithDefault)) + + val List(Left(List(p, from))) = indexWhere.paramSymss: @unchecked + assert(!clue(p.hasParamWithDefault) && !clue(p.isParamWithDefault)) + assert(!clue(from.hasParamWithDefault) && clue(from.isParamWithDefault)) + } + val ArrayDequeModClass = ctx.findTopLevelModuleClass("scala.collection.mutable.ArrayDeque") ArrayDequeModClass.findNonOverloadedDecl(DefaultGetterName(nme.Constructor, 0)) assert(clue(ArrayDequeModClass.getNonOverloadedDecl(DefaultGetterName(nme.Constructor, 1))).isEmpty) diff --git a/test-sources/src/main/scala/simple_trees/DefaultParams.scala b/test-sources/src/main/scala/simple_trees/DefaultParams.scala index 75e205d6..853f83d9 100644 --- a/test-sources/src/main/scala/simple_trees/DefaultParams.scala +++ b/test-sources/src/main/scala/simple_trees/DefaultParams.scala @@ -2,4 +2,6 @@ package simple_trees class DefaultParams: def foo(x: Int, y: Int = 1, z: String = "hello"): String = s"$x $y $z" + + def foo(x: Int, y: String): String = s"$x $y" end DefaultParams