Skip to content

Commit

Permalink
Merge pull request #390 from sjrd/default-params
Browse files Browse the repository at this point in the history
Add methods to detect parameters with default values.
  • Loading branch information
sjrd authored Nov 21, 2023
2 parents 01f46c9 + fe1398c commit e2f9476
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 3 deletions.
1 change: 1 addition & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ private[tastyquery] object Flags:
val Extension: Flag = newFlag("Extension")
val Final: Flag = newFlag("Final")
val Given: Flag = newFlag("Given")
val HasDefault: Flag = newFlag("HasDefault")
val Implicit: Flag = newFlag("Implicit")
val Infix: Flag = newFlag("Infix")
val Inline: Flag = newFlag("Inline")
Expand Down
11 changes: 11 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -583,6 +583,17 @@ object Symbols {
/** Is this symbol an exporter generated by an `export` statement? */
final def isExport: Boolean = flags.is(Exported)

/** Is this symbol a method with at least one parameter with a default value? */
final def hasParamWithDefault: Boolean =
paramSymss.exists {
case Left(termParams) => termParams.exists(_.isParamWithDefault)
case Right(value) => false
}
end hasParamWithDefault

/** Is this symbol a method parameter with a default value? */
final def isParamWithDefault: Boolean = flags.isAllOf(HasDefault)

/** Get the module class of this module value definition, if it exists:
* - for `object val C` => `object class C[$]`
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ private[pickles] class PickleReader {
if pickleFlags.isStable then flags |= StableRealizable
if pickleFlags.isStatic then flags |= Static
if pickleFlags.isCaseAccessor then flags |= CaseAccessor
// if pickleFlags.hasDefault then flags |= HasDefault
if pickleFlags.hasDefault then flags |= HasDefault
if pickleFlags.isTrait then flags |= Trait
// if pickleFlags.isBridge then flags |= Bridge
if pickleFlags.isAccessor then flags |= Accessor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ private[tasties] class TreeUnpickler private (
case CASEaccessor => addFlag(CaseAccessor)
case COVARIANT => addFlag(Covariant)
case CONTRAVARIANT => addFlag(Contravariant)
case HASDEFAULT => ignoreFlag()
case HASDEFAULT => addFlag(HasDefault)
case STABLE => addFlag(StableRealizable)
case EXTENSION => addFlag(Extension)
case GIVEN => addFlag(Given)
Expand Down
29 changes: 28 additions & 1 deletion tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1051,18 +1051,45 @@ class TypeSuite extends UnrestrictedUnpicklingSuite {
throw AssertionError(s"unexpected type $tpe")
}

testWithContext("scala-2-default-params") {
testWithContext("default-params") {
extension (sym: TermSymbol) def paramCount: Int = sym.declaredType.asInstanceOf[MethodType].paramNames.size

val DefaultParamsClass = ctx.findTopLevelClass("simple_trees.DefaultParams")
assert(clue(DefaultParamsClass.getNonOverloadedDecl(DefaultGetterName(termName("foo"), 0))).isEmpty)
DefaultParamsClass.findNonOverloadedDecl(DefaultGetterName(termName("foo"), 1))
DefaultParamsClass.findNonOverloadedDecl(DefaultGetterName(termName("foo"), 2))
assert(clue(DefaultParamsClass.getNonOverloadedDecl(DefaultGetterName(termName("foo"), 3))).isEmpty)

val fooOverloads = DefaultParamsClass.findAllOverloadedDecls(termName("foo"))

val fooWithDefaults = fooOverloads.find(_.paramCount == 3).get
assert(clue(fooWithDefaults.hasParamWithDefault) && !clue(fooWithDefaults.isParamWithDefault))
val List(Left(fooWithDefaultsParams)) = fooWithDefaults.paramSymss: @unchecked
assert(clue(fooWithDefaultsParams.map(_.isParamWithDefault)) == List(false, true, true))
assert(clue(fooWithDefaultsParams.map(_.hasParamWithDefault)).forall(_ == false))

for fooOverload <- fooOverloads if fooOverload ne fooWithDefaults do
assert(!clue(fooOverload.hasParamWithDefault) && !clue(fooOverload.isParamWithDefault))
val List(Left(params)) = fooOverload.paramSymss: @unchecked
assert(clue(params.map(_.isParamWithDefault)).forall(_ == false))
assert(clue(params.map(_.hasParamWithDefault)).forall(_ == false))
}

testWithContext("default-params-scala-2") {
val IteratorClass = ctx.findTopLevelClass("scala.collection.Iterator")
assert(clue(IteratorClass.getNonOverloadedDecl(DefaultGetterName(termName("indexWhere"), 0))).isEmpty)
IteratorClass.findNonOverloadedDecl(DefaultGetterName(termName("indexWhere"), 1))
assert(clue(IteratorClass.getNonOverloadedDecl(DefaultGetterName(termName("indexWhere"), 2))).isEmpty)

locally {
val indexWhere = IteratorClass.findNonOverloadedDecl(termName("indexWhere"))
assert(clue(indexWhere.hasParamWithDefault) && !clue(indexWhere.isParamWithDefault))

val List(Left(List(p, from))) = indexWhere.paramSymss: @unchecked
assert(!clue(p.hasParamWithDefault) && !clue(p.isParamWithDefault))
assert(!clue(from.hasParamWithDefault) && clue(from.isParamWithDefault))
}

val ArrayDequeModClass = ctx.findTopLevelModuleClass("scala.collection.mutable.ArrayDeque")
ArrayDequeModClass.findNonOverloadedDecl(DefaultGetterName(nme.Constructor, 0))
assert(clue(ArrayDequeModClass.getNonOverloadedDecl(DefaultGetterName(nme.Constructor, 1))).isEmpty)
Expand Down
2 changes: 2 additions & 0 deletions test-sources/src/main/scala/simple_trees/DefaultParams.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,6 @@ package simple_trees

class DefaultParams:
def foo(x: Int, y: Int = 1, z: String = "hello"): String = s"$x $y $z"

def foo(x: Int, y: String): String = s"$x $y"
end DefaultParams

0 comments on commit e2f9476

Please sign in to comment.