Skip to content

Commit

Permalink
Merge pull request #389 from sjrd/java-annotations
Browse files Browse the repository at this point in the history
Read Java annotations.
  • Loading branch information
bishabosha authored Nov 22, 2023
2 parents 976a049 + d7aede0 commit 2c65b9e
Show file tree
Hide file tree
Showing 20 changed files with 618 additions and 105 deletions.
1 change: 1 addition & 0 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ lazy val testSources = crossProject(JSPlatform, JVMPlatform)
.settings(
publish / skip := true,
scalacOptions += "-Xfatal-warnings",
javacOptions += "-parameters",
)

lazy val tastyQuery =
Expand Down
28 changes: 24 additions & 4 deletions tasty-query/shared/src/main/scala/tastyquery/Annotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ object Annotations:

/** The symbol of the constructor used in the annotation.
*
* This operation is not supported for annotations read from Scala 2.
* This operation is not supported for annotations read from Java or Scala 2.
* It will throw an `UnsupportedOperationException`.
*/
def annotConstructor(using Context): TermSymbol =
Expand Down Expand Up @@ -111,6 +111,24 @@ object Annotations:

new Annotation(tree)
end apply

private[tastyquery] def fromAnnotTypeAndArgs(annotationType: Type, args: List[TermTree]): Annotation =
val pos = SourcePosition.NoPosition

/* Create a TermTree for the annotation that is "good enough" for the main
* methods of `Annotation` to work, notably `symbol` and `arguments`.
* We have to cheat for the constructor, as we do not have its Signature.
* Instead we use an unsigned `nme.Constructor`. This is invalid and will
* cause `Annotation.annotConstructor` to fail, but we do not really have
* a choice.
*/
val annotationTree: TermTree =
val newNode = New(TypeWrapper(annotationType)(pos))(pos)
val selectCtorNode = Select(newNode, nme.Constructor)(None)(pos) // cheating here
Apply(selectCtorNode, args)(pos)

Annotation(annotationTree)
end fromAnnotTypeAndArgs
end Annotation

private def computeAnnotSymbol(tree: TermTree)(using Context): ClassSymbol =
Expand Down Expand Up @@ -140,13 +158,15 @@ object Annotations:
def invalid(): Nothing =
throw InvalidProgramStructureException(s"Cannot find annotation constructor in $tree")

def unsupportedScala2(): Nothing =
throw UnsupportedOperationException(s"Cannot compute the annotation constructor of a Scala 2 annotation: $tree")
def unsupported(): Nothing =
throw UnsupportedOperationException(
s"Cannot compute the annotation constructor of a Java or Scala 2 annotation: $tree"
)

@tailrec
def loop(tree: TermTree): TermSymbol = tree match
case Apply(fun, _) => loop(fun)
case tree @ Select(New(tpt), name) => if name == nme.Constructor then unsupportedScala2() else tree.symbol.asTerm
case tree @ Select(New(tpt), name) => if name == nme.Constructor then unsupported() else tree.symbol.asTerm
case TypeApply(fun, _) => loop(fun)
case Block(_, expr) => loop(expr)
case _ => invalid()
Expand Down
7 changes: 7 additions & 0 deletions tasty-query/shared/src/main/scala/tastyquery/Constants.scala
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ object Constants {

def stringValue: String = if tag == NullTag then "null" else value.toString

/** The class type value of a `classOf` constant.
*
* This must be a "possibly parametrized class type" according to the
* specification of the language. If the class is polymorphic, it may be
* applied (making it a proper type) or not (in which case it is not a
* proper type).
*/
def typeValue: Type = value.asInstanceOf[Type]

override def hashCode: Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
.withDeclaredType(tpe)
.setAnnotations(Nil)
.autoFillParamSymss()
sym.paramSymss.foreach(_.merge.foreach(_.setAnnotations(Nil)))
sym.checkCompleted()
sym
end createSpecialMethod
Expand Down Expand Up @@ -347,6 +348,7 @@ final class Definitions private[tastyquery] (ctx: Context, rootPackage: PackageS
)
applyMethod.autoFillParamSymss()
applyMethod.setAnnotations(Nil)
applyMethod.paramSymss.foreach(_.merge.foreach(_.setAnnotations(Nil)))
applyMethod.checkCompleted()

cls.checkCompleted()
Expand Down
2 changes: 0 additions & 2 deletions tasty-query/shared/src/main/scala/tastyquery/Symbols.scala
Original file line number Diff line number Diff line change
Expand Up @@ -489,7 +489,6 @@ object Symbols {
.createNotDeclaration(name, this)
.withFlags(EmptyFlagSet, privateWithin = None)
.withDeclaredType(paramType)
.setAnnotations(Nil)
}
Left(paramSyms) :: autoComputeParamSymss(tpe.resultType)

Expand All @@ -498,7 +497,6 @@ object Symbols {
LocalTypeParamSymbol
.create(name, this)
.withFlags(EmptyFlagSet, privateWithin = None)
.setAnnotations(Nil)
}
val paramSymRefs = paramSyms.map(_.localRef)
def subst(t: TypeOrMethodic): t.ThisTypeMappableType =
Expand Down
Original file line number Diff line number Diff line change
@@ -1,23 +1,28 @@
package tastyquery.reader.classfiles

import scala.annotation.switch

import scala.collection.mutable

import tastyquery.Annotations.Annotation as TQAnnotation
import tastyquery.Annotations.*
import tastyquery.Classpaths.*
import tastyquery.Contexts.*
import tastyquery.Constants.*
import tastyquery.Exceptions.*
import tastyquery.Flags
import tastyquery.Flags.*
import tastyquery.Names.*
import tastyquery.SourceLanguage
import tastyquery.SourcePosition
import tastyquery.Symbols.*
import tastyquery.Trees.*
import tastyquery.Types.*

import tastyquery.reader.ReaderContext
import tastyquery.reader.ReaderContext.rctx
import tastyquery.reader.pickles.{Unpickler, PickleReader}

import ClassfileReader.*
import ClassfileReader.{Annotation as CFAnnotation, *}
import ClassfileReader.Access.AccessFlags
import Constants.*

Expand Down Expand Up @@ -63,6 +68,8 @@ private[reader] object ClassfileParser {
def resolve(binaryName: SimpleName)(using ReaderContext, InnerClasses): TypeRef =
lookup(binaryName, isStatic = false).asTypeRef

def resolveStatic(binaryName: SimpleName)(using ReaderContext, InnerClasses): TermRef =
lookup(binaryName, isStatic = true).asTermRef
end Resolver

/** The inner classes local to a class file */
Expand Down Expand Up @@ -127,10 +134,10 @@ private[reader] object ClassfileParser {

val sigBytes = scalaSigAnnotation.tpe match {
case annot.ScalaSignature =>
val bytesArg = scalaSigAnnotation.values.head.asInstanceOf[AnnotationValue.Const]
val bytesArg = scalaSigAnnotation.values.head._2.asInstanceOf[AnnotationValue.Const]
pool.sigbytes(bytesArg.valueIdx)
case annot.ScalaLongSignature =>
val bytesArrArg = scalaSigAnnotation.values.head.asInstanceOf[AnnotationValue.Arr]
val bytesArrArg = scalaSigAnnotation.values.head._2.asInstanceOf[AnnotationValue.Arr]
val idxs = bytesArrArg.values.map(_.asInstanceOf[AnnotationValue.Const].valueIdx)
pool.sigbytes(idxs)
}
Expand Down Expand Up @@ -200,15 +207,31 @@ private[reader] object ClassfileParser {
else false
end isSignaturePolymorphic

def createMember(name: SimpleName, isMethod: Boolean, javaFlags: AccessFlags, memberSig: MemberSig): TermSymbol =
def createMember(
name: SimpleName,
isMethod: Boolean,
javaFlags: AccessFlags,
descriptor: String,
attributes: Map[SimpleName, Forked[DataStream]]
): TermSymbol =
// Select the right owner and create the symbol
val owner = if javaFlags.isStatic then moduleClass else cls
val sym = TermSymbol.create(name, owner)
allRegisteredSymbols += sym

// Read parameter names
val methodParamNames =
if isMethod then readMethodParameters(attributes).map(_._1)
else Nil

// Find the signature, or fall back to the descriptor
val memberSig = attributes.get(attr.Signature) match
case Some(stream) => stream.use(ClassfileReader.readSignature)
case None => descriptor

// Parse the signature into a declared type for the symbol
val declaredType =
val parsedType = JavaSignatures.parseSignature(sym, isMethod, memberSig, allRegisteredSymbols)
val parsedType = JavaSignatures.parseSignature(sym, isMethod, methodParamNames, memberSig, allRegisteredSymbols)
val adaptedType =
if isMethod && sym.name == nme.Constructor then cls.makePolyConstructorType(parsedType)
else if isMethod && javaFlags.isVarargsIfMethod then patchForVarargs(sym, parsedType)
Expand All @@ -226,23 +249,46 @@ private[reader] object ClassfileParser {
end flags
sym.withFlags(flags, privateWithin(javaFlags))

// Auto fill the param symbols from the declared type
sym.autoFillParamSymss()

sym.setAnnotations(Nil) // TODO Read Java annotations on fields and methods
// Read and fill annotations
val annots = readAnnotations(sym, attributes)
sym.setAnnotations(annots)

// Handle parameters
if isMethod then
// Auto fill the param symbols from the declared type
sym.autoFillParamSymss()

val termParamAnnots = readTermParamAnnotations(attributes)
if termParamAnnots.isEmpty then
// fast path
sym.paramSymss.foreach(_.merge.foreach(_.setAnnotations(Nil)))
else
val termParamAnnotsIter = termParamAnnots.iterator

for paramSyms <- sym.paramSymss do
paramSyms match
case Left(termParamSyms) =>
for termParamSym <- termParamSyms do
val annots = if termParamAnnotsIter.hasNext then termParamAnnotsIter.next() else Nil
termParamSym.setAnnotations(annots)
case Right(typeParamSyms) =>
// TODO Maybe one day we also read type annotations
typeParamSyms.foreach(_.setAnnotations(Nil))
end if
end if

sym
end createMember

def loadMembers(): Unit =
structure.fields.use {
ClassfileReader.readFields { (name, sigOrDesc, javaFlags) =>
createMember(name, isMethod = false, javaFlags, sigOrDesc)
ClassfileReader.readMembers { (javaFlags, name, descriptor, attributes) =>
createMember(name, isMethod = false, javaFlags, descriptor, attributes)
}
}
structure.methods.use {
ClassfileReader.readMethods { (name, sigOrDesc, javaFlags) =>
createMember(name, isMethod = true, javaFlags, sigOrDesc)
ClassfileReader.readMembers { (javaFlags, name, descriptor, attributes) =>
createMember(name, isMethod = true, javaFlags, descriptor, attributes)
}
}
end loadMembers
Expand All @@ -254,7 +300,9 @@ private[reader] object ClassfileParser {
val parents = attributes.get(attr.Signature) match
case Some(stream) =>
val sig = stream.use(ClassfileReader.readSignature)
JavaSignatures.parseSignature(cls, isMethod = false, sig, allRegisteredSymbols).requireType match
val parsedSig =
JavaSignatures.parseSignature(cls, isMethod = false, methodParameterNames = Nil, sig, allRegisteredSymbols)
parsedSig.requireType match
case mix: AndType => mix.parts
case sup => sup :: Nil
case None =>
Expand All @@ -273,7 +321,6 @@ private[reader] object ClassfileParser {

cls.withGivenSelfType(None)
cls.withFlags(clsFlags, clsPrivateWithin)
cls.setAnnotations(Nil) // TODO Read Java annotations on classes
initParents()

// Intercept special classes to create their magic methods
Expand All @@ -284,6 +331,9 @@ private[reader] object ClassfileParser {

loadMembers()

val annotations = readAnnotations(cls, attributes)
cls.setAnnotations(annotations)

for sym <- allRegisteredSymbols do
sym.checkCompleted()
assert(sym.sourceLanguage == SourceLanguage.Java, s"$sym of ${sym.getClass()}")
Expand Down Expand Up @@ -324,6 +374,115 @@ private[reader] object ClassfileParser {
None
end ArrayTypeExtractor

private def readMethodParameters(attributes: AttributeMap)(
using ConstantPool
): List[(UnsignedTermName, AccessFlags)] =
attributes.get(attr.MethodParameters) match
case Some(stream) => stream.use(ClassfileReader.readMethodParameters())
case None => Nil
end readMethodParameters

private def readAnnotations(
sym: TermOrTypeSymbol,
attributes: AttributeMap
)(using ConstantPool, ReaderContext, InnerClasses, Resolver): List[Annotation] =
readAnnotations(sym, attributes.get(attr.RuntimeVisibleAnnotations))
::: readAnnotations(sym, attributes.get(attr.RuntimeInvisibleAnnotations))
end readAnnotations

private def readAnnotations(
sym: TermOrTypeSymbol,
annotationsStream: Option[Forked[DataStream]]
)(using ConstantPool, ReaderContext, InnerClasses, Resolver): List[Annotation] =
annotationsStream.fold(Nil)(readAnnotations(sym, _))
end readAnnotations

private def readAnnotations(
sym: TermOrTypeSymbol,
annotationsStream: Forked[DataStream]
)(using ConstantPool, ReaderContext, InnerClasses, Resolver): List[Annotation] =
val classfileAnnots = annotationsStream.use(ClassfileReader.readAllAnnotations())
classfileAnnots.map(classfileAnnotToAnnot(_))

private def readTermParamAnnotations(
attributes: AttributeMap
)(using ConstantPool, ReaderContext, InnerClasses, Resolver): List[List[Annotation]] =
val runtimeVisible = attributes.get(attr.RuntimeVisibleParameterAnnotations) match
case None => Nil
case Some(stream) => stream.use(ClassfileReader.readAllParameterAnnotations())

val runtimeInvisible = attributes.get(attr.RuntimeInvisibleParameterAnnotations) match
case None => Nil
case Some(stream) => stream.use(ClassfileReader.readAllParameterAnnotations())

if runtimeVisible.isEmpty && runtimeInvisible.isEmpty then
// fast path
Nil
else
for (rtVisible, rtInvisible) <- runtimeVisible.zipAll(runtimeInvisible, Nil, Nil)
yield (rtVisible ::: rtInvisible).map(classfileAnnotToAnnot(_))
end readTermParamAnnotations

private def classfileAnnotToAnnot(
classfileAnnot: CFAnnotation
)(using ConstantPool, ReaderContext, InnerClasses, Resolver): Annotation =
val annotationType = JavaSignatures.parseFieldDescriptor(classfileAnnot.tpe.name)

val args: List[TermTree] =
for (name, value) <- classfileAnnot.values.toList yield
val valueTree = annotationValueToTree(value)
NamedArg(name, valueTree)(SourcePosition.NoPosition)

Annotation.fromAnnotTypeAndArgs(annotationType, args)
end classfileAnnotToAnnot

private def annotationValueToTree(
value: AnnotationValue
)(using ConstantPool, ReaderContext, InnerClasses, Resolver): TermTree =
import AnnotationValue.Tags

val pool = summon[ConstantPool]
val pos = SourcePosition.NoPosition

value match
case AnnotationValue.Const(tag, valueIdx) =>
val constant = (tag: @switch) match
case Tags.Byte => Constant(pool.integer(valueIdx).toByte)
case Tags.Char => Constant(pool.integer(valueIdx).toChar)
case Tags.Double => Constant(pool.double(valueIdx))
case Tags.Float => Constant(pool.float(valueIdx))
case Tags.Int => Constant(pool.integer(valueIdx))
case Tags.Long => Constant(pool.long(valueIdx))
case Tags.Short => Constant(pool.integer(valueIdx).toShort)
case Tags.Boolean => Constant(pool.integer(valueIdx) != 0)
case Tags.String => Constant(pool.utf8(valueIdx).name)
Literal(constant)(pos)

case AnnotationValue.EnumConst(descriptor, constName) =>
/* JVMS says that it can be any field descriptor,
* but I don't see what we would do with a base type or array type.
*/
val binaryName = descriptor.name match
case s"L$binaryName;" => binaryName
case other => throw ClassfileFormatException(s"unexpected non-class field descriptor: $other")
val enumClassStaticRef = resolver.resolveStatic(termName(binaryName))
val constRef = TermRef(enumClassStaticRef, constName)
Ident(constName)(constRef)(pos)

case AnnotationValue.ClassConst(descriptor) =>
val classType = JavaSignatures.parseReturnDescriptor(descriptor.name)
Literal(Constant(classType))(pos)

case AnnotationValue.NestedAnnotation(annotation) =>
val nestedAnnot = classfileAnnotToAnnot(annotation)
nestedAnnot.tree

case AnnotationValue.Arr(values) =>
val valueTrees = values.map(annotationValueToTree(_)).toList
val elemType = rctx.AnyType // TODO This will not be type-correct
SeqLiteral(valueTrees, TypeWrapper(elemType)(pos))(pos)
end annotationValueToTree

def detectClassKind(structure: Structure): ClassKind =
import structure.given

Expand Down
Loading

0 comments on commit 2c65b9e

Please sign in to comment.