Skip to content

Commit

Permalink
Merge pull request #185 from sjrd/better-find-toplevel-api
Browse files Browse the repository at this point in the history
Better APIs to find known top-level and static things.
  • Loading branch information
bishabosha authored Nov 3, 2022
2 parents b308b25 + a4fdd33 commit aeac543
Show file tree
Hide file tree
Showing 4 changed files with 155 additions and 68 deletions.
92 changes: 92 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Contexts.scala
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,98 @@ object Contexts {
rec(RootPackage, path)
end findSymbolFromRoot

def findPackage(fullyQualifiedName: String): PackageSymbol =
findPackageFromRoot(FullyQualifiedName(fullyQualifiedName.split('.').toList.map(termName(_))))

def findTopLevelClass(fullyQualifiedName: String): ClassSymbol =
val (pkg, nameStr) = splitPackageAndName(fullyQualifiedName)
val name = typeName(nameStr)
pkg.getDecl(name) match
case Some(cls: ClassSymbol) =>
cls
case _ =>
throw MemberNotFoundException(pkg, name, s"cannot find class $nameStr in $pkg")
end findTopLevelClass

def findTopLevelModuleClass(fullyQualifiedName: String): ClassSymbol =
val (pkg, nameStr) = splitPackageAndName(fullyQualifiedName)
val name = termName(nameStr).withObjectSuffix.toTypeName
pkg.getDecl(name) match
case Some(cls: ClassSymbol) =>
cls
case _ =>
throw MemberNotFoundException(pkg, name, s"cannot find module class $nameStr in $pkg")
end findTopLevelModuleClass

def findStaticClass(fullyQualifiedName: String): ClassSymbol =
findStaticType(fullyQualifiedName) match
case cls: ClassSymbol =>
cls
case sym =>
throw InvalidProgramStructureException(s"expected class symbol but got $sym")
end findStaticClass

def findStaticModuleClass(fullyQualifiedName: String): ClassSymbol =
findStaticTerm(fullyQualifiedName) match
case sym if sym.is(Module) =>
sym.moduleClass.get
case sym =>
throw InvalidProgramStructureException(s"expected module symbol but got $sym")
end findStaticModuleClass

def findStaticType(fullyQualifiedName: String): TypeSymbol =
val (owner, nameStr) = findStaticOwnerAndName(fullyQualifiedName)
val name = typeName(nameStr)
owner
.getDecl(name)
.getOrElse {
throw MemberNotFoundException(owner, name)
}
.asType
end findStaticType

def findStaticTerm(fullyQualifiedName: String): TermSymbol =
val (owner, nameStr) = findStaticOwnerAndName(fullyQualifiedName)
val name = termName(nameStr)
owner
.getDecl(name)
.getOrElse {
throw MemberNotFoundException(owner, name)
}
.asTerm
end findStaticTerm

private def findStaticOwnerAndName(fullyQualifiedName: String): (DeclaringSymbol, String) =
val path = fullyQualifiedName.split('.').toList
(findStaticOwner(path.init), path.last)

private def findStaticOwner(path: List[String]): DeclaringSymbol =
def loop(owner: DeclaringSymbol, path: List[String]): DeclaringSymbol =
path match
case Nil =>
owner
case nameStr :: rest =>
val name = termName(nameStr)
owner.getDecl(name) match
case Some(pkg: PackageSymbol) =>
loop(pkg, rest)
case Some(moduleSymbol: TermSymbol) if moduleSymbol.is(Module) =>
loop(moduleSymbol.moduleClass.get, rest)
case Some(sym) =>
throw InvalidProgramStructureException(s"$sym is not a static owner")
case None =>
throw MemberNotFoundException(owner, name)
end loop

if path.isEmpty then EmptyPackage
else loop(RootPackage, path)
end findStaticOwner

private def splitPackageAndName(fullyQualifiedName: String): (PackageSymbol, String) =
fullyQualifiedName.split('.').toList match
case name :: Nil => (EmptyPackage, name)
case path => (findPackageFromRoot(FullyQualifiedName(path.init.map(termName(_)))), path.last)

private[tastyquery] def findPackageFromRootOrCreate(fullyQualifiedName: FullyQualifiedName): PackageSymbol =
fullyQualifiedName.path.foldLeft(RootPackage) { (owner, name) =>
owner.getPackageDeclOrCreate(name.asSimpleName)
Expand Down
7 changes: 7 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,13 @@ object Symbols {
if local != null then local
else throw new IllegalStateException(s"$this was not assigned a declared type")

/** Get the module class of this module value definition, if it exists:
* - for `object val C` => `object class C[$]`
*/
final def moduleClass(using Context): Option[ClassSymbol] =
if is(Module) then declaredType.classSymbol
else None

private[tastyquery] final def declaredTypeAsSeenFrom(prefix: Type)(using Context): Type =
declaredType.asSeenFrom(prefix, owner)

Expand Down
54 changes: 27 additions & 27 deletions tasty-query/shared/src/test/scala/tastyquery/SignatureSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package tastyquery

import munit.Location

import tastyquery.Contexts.Context
import tastyquery.Contexts.*
import tastyquery.Names.*
import tastyquery.Signatures.*
import tastyquery.Symbols.*
Expand All @@ -26,7 +26,7 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite:
assert(!clue(actual).isInstanceOf[SignedName])

testWithContext("java.lang.String") {
val StringClass = resolve(name"java" / name"lang" / tname"String").asClass
val StringClass = defn.StringClass

val charAt = StringClass.getDecl(name"charAt").get.asTerm
assertIsSignedName(charAt.signedName, "charAt", "(scala.Int):scala.Char")
Expand All @@ -39,7 +39,7 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite:
}

testWithContext("GenericClass") {
val GenericClass = resolve(name"simple_trees" / tname"GenericClass").asClass
val GenericClass = ctx.findTopLevelClass("simple_trees.GenericClass")

val field = GenericClass.getDecl(name"field").get.asTerm
assertNotSignedName(field.signedName)
Expand All @@ -52,35 +52,35 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite:
}

testWithContext("GenericMethod") {
val GenericMethod = resolve(name"simple_trees" / tname"GenericMethod").asClass
val GenericMethod = ctx.findTopLevelClass("simple_trees.GenericMethod")

val identity = GenericMethod.getDecl(name"identity").get.asTerm
assertIsSignedName(identity.signedName, "identity", "(1,java.lang.Object):java.lang.Object")
}

testWithContext("RichInt") {
val RichInt = resolve(name"scala" / name"runtime" / tname"RichInt").asClass
val RichInt = ctx.findTopLevelClass("scala.runtime.RichInt")

val toHexString = RichInt.getDecl(name"toHexString").get.asTerm
assertIsSignedName(toHexString.signedName, "toHexString", "():java.lang.String")
}

testWithContext("Product") {
val Product = resolve(name"scala" / tname"Product").asClass
val Product = ctx.findTopLevelClass("scala.Product")

val productIterator = Product.getDecl(name"productIterator").get.asTerm
assertIsSignedName(productIterator.signedName, "productIterator", "():scala.collection.Iterator")
}

testWithContext("with type") {
val RefinedTypeTree = resolve(name"simple_trees" / tname"RefinedTypeTree").asClass
val RefinedTypeTree = ctx.findTopLevelClass("simple_trees.RefinedTypeTree")

val andType = RefinedTypeTree.getDecl(name"andType").get.asTerm
intercept[UnsupportedOperationException](andType.signedName)
}

testWithContext("array types") {
val TypeRefIn = resolve(name"simple_trees" / tname"TypeRefIn").asClass
val TypeRefIn = ctx.findTopLevelClass("simple_trees.TypeRefIn")

// TODO The erasure is not actually correct here, but at least we don't crash

Expand Down Expand Up @@ -122,7 +122,7 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite:
}

testWithContext("type-member") {
val TypeMember = resolve(name"simple_trees" / tname"TypeMember").asClass
val TypeMember = ctx.findTopLevelClass("simple_trees.TypeMember")

val mTypeAlias = TypeMember.getDecl(name"mTypeAlias").get.asTerm
assertIsSignedName(mTypeAlias.signedName, "mTypeAlias", "(scala.Int):scala.Int")
Expand All @@ -141,109 +141,109 @@ class SignatureSuite extends UnrestrictedUnpicklingSuite:
}

testWithContext("scala2-case-class-varargs") {
val StringContext = resolve(name"scala" / tname"StringContext").asClass
val StringContext = ctx.findTopLevelClass("scala.StringContext")

val parts = StringContext.getDecl(name"parts").get.asTerm
assertIsSignedName(parts.signedName, "parts", "():scala.collection.immutable.Seq")
}

testWithContext("scala2-method-byname") {
val StringContext = resolve(name"scala" / tname"Option").asClass
val StringContext = ctx.findTopLevelClass("scala.Option")

val getOrElse = StringContext.getDecl(name"getOrElse").get.asTerm
assertIsSignedName(getOrElse.signedName, "getOrElse", "(1,scala.Function0):java.lang.Object")
}

testWithContext("scala2-existential-type") {
val ClassTag = resolve(name"scala" / name"reflect" / tname"ClassTag" / obj).asClass
val ClassTag = ctx.findTopLevelModuleClass("scala.reflect.ClassTag")

val apply = ClassTag.getDecl(name"apply").get.asTerm
assertIsSignedName(apply.signedName, "apply", "(1,java.lang.Class):scala.reflect.ClassTag")
}

testWithContext("iarray") {
val IArraySig = resolve(name"simple_trees" / tname"IArraySig").asClass
val IArraySig = ctx.findTopLevelClass("simple_trees.IArraySig")

val from = IArraySig.getDecl(name"from").get.asTerm
assertIsSignedName(from.signedName, "from", "():java.lang.String[]")
}

testWithContext("value-class-arrayOps-generic") {
val MyArrayOps = resolve(name"inheritance" / tname"MyArrayOps" / obj).asClass
val MyArrayOps = ctx.findTopLevelModuleClass("inheritance.MyArrayOps")
val genericArrayOps = MyArrayOps.getDecl(name"genericArrayOps").get.asTerm
assertIsSignedName(genericArrayOps.signedName, "genericArrayOps", "(1,java.lang.Object):java.lang.Object")
}

testWithContext("value-class-arrayOps-int") {
val MyArrayOps = resolve(name"inheritance" / tname"MyArrayOps" / obj).asClass
val MyArrayOps = ctx.findTopLevelModuleClass("inheritance.MyArrayOps")
val intArrayOps = MyArrayOps.getDecl(name"intArrayOps").get.asTerm
assertIsSignedName(intArrayOps.signedName, "intArrayOps", "(scala.Int[]):java.lang.Object")
}

testWithContext("scala2-value-class-arrayOps-generic") {
val Predef = resolve(name"scala" / tname"Predef" / obj).asClass
val Predef = ctx.findTopLevelModuleClass("scala.Predef")
val genericArrayOps = Predef.getDecl(name"genericArrayOps").get.asTerm
assertIsSignedName(genericArrayOps.signedName, "genericArrayOps", "(1,java.lang.Object):java.lang.Object")
}

testWithContext("scala2-value-class-arrayOps-int") {
val Predef = resolve(name"scala" / tname"Predef" / obj).asClass
val Predef = ctx.findTopLevelModuleClass("scala.Predef")
val intArrayOps = Predef.getDecl(name"intArrayOps").get.asTerm
assertIsSignedName(intArrayOps.signedName, "intArrayOps", "(scala.Int[]):java.lang.Object")
}

testWithContext("value-class-monomorphic") {
val MyFlags = resolve(name"inheritance" / tname"MyFlags").asClass
val MyFlags = ctx.findTopLevelClass("inheritance.MyFlags")
val merge = MyFlags.getDecl(name"merge").get.asTerm
assertIsSignedName(merge.signedName, "merge", "(scala.Long):scala.Long")
}

testWithContext("value-class-monomorphic-arrayOf") {
val MyFlags = resolve(name"inheritance" / tname"MyFlags" / obj).asClass
val MyFlags = ctx.findTopLevelModuleClass("inheritance.MyFlags")
val mergeAll = MyFlags.getDecl(name"mergeAll").get.asTerm
assertIsSignedName(mergeAll.signedName, "mergeAll", "(inheritance.MyFlags[]):scala.Long")
}

testWithContext("value-class-polymorphic-arrayOf") {
val MyArrayOps = resolve(name"inheritance" / tname"MyArrayOps" / obj).asClass
val MyArrayOps = ctx.findTopLevelModuleClass("inheritance.MyArrayOps")
val arrayOfIntArrayOps = MyArrayOps.getDecl(name"arrayOfIntArrayOps").get.asTerm
assertIsSignedName(arrayOfIntArrayOps.signedName, "arrayOfIntArrayOps", "(scala.Int[][]):inheritance.MyArrayOps[]")
}

testWithContext("package-ref-from-tasty") {
val LazyVals = resolve(name"javacompat" / tname"LazyVals" / obj).asClass
val LazyVals = ctx.findTopLevelModuleClass("javacompat.LazyVals")
val getOffsetStatic = LazyVals.getDecl(name"getOffsetStatic").get.asTerm
assertIsSignedName(getOffsetStatic.signedName, "getOffsetStatic", "(java.lang.reflect.Field):scala.Long")
}

testWithContext("Scala 3 special function types") {
val SpecialFunctionTypes = resolve(name"simple_trees" / tname"SpecialFunctionTypes").asClass
val SpecialFunctionTypes = ctx.findTopLevelClass("simple_trees.SpecialFunctionTypes")
val contextFunction = SpecialFunctionTypes.getDecl(name"contextFunction").get.asTerm
assertIsSignedName(contextFunction.signedName, "contextFunction", "(scala.Function1):scala.Unit")
}

testWithContext("inherited type member, same tasty") {
val SubClass = resolve(name"inheritance" / tname"SameTasty" / obj / tname"Sub").asClass
val SubClass = ctx.findStaticClass("inheritance.SameTasty.Sub")
val foo3 = SubClass.getDecl(name"foo3").get.asTerm
assertIsSignedName(foo3.signedName, "foo3", "():scala.Int")

val SubWithMixinClass = resolve(name"inheritance" / tname"SameTasty" / obj / tname"SubWithMixin").asClass
val SubWithMixinClass = ctx.findStaticClass("inheritance.SameTasty.SubWithMixin")
val bar3 = SubWithMixinClass.getDecl(name"bar3").get.asTerm
assertIsSignedName(bar3.signedName, "bar3", "():scala.Int")
}

testWithContext("inherited type member, cross tasty") {
val SubClass = resolve(name"inheritance" / name"crosstasty" / tname"Sub").asClass
val SubClass = ctx.findTopLevelClass("inheritance.crosstasty.Sub")
val foo3 = SubClass.getDecl(name"foo3").get.asTerm
assertIsSignedName(foo3.signedName, "foo3", "():scala.Int")

val SubWithMixinClass = resolve(name"inheritance" / name"crosstasty" / tname"SubWithMixin").asClass
val SubWithMixinClass = ctx.findTopLevelClass("inheritance.crosstasty.SubWithMixin")
val bar3 = SubWithMixinClass.getDecl(name"bar3").get.asTerm
assertIsSignedName(bar3.signedName, "bar3", "():scala.Int")
}

testWithContext("case class copy method") {
val CaseClass = resolve(name"synthetics" / tname"CaseClass").asClass
val CaseClass = ctx.findTopLevelClass("synthetics.CaseClass")
val copy = CaseClass.getDecl(name"copy").get.asTerm
assertIsSignedName(copy.signedName, "copy", "(java.lang.String):synthetics.CaseClass")
}
Expand Down
Loading

0 comments on commit aeac543

Please sign in to comment.