diff --git a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala index b7f15299..721a84cb 100644 --- a/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala +++ b/tasty-query/shared/src/main/scala/tastyquery/Symbols.scala @@ -510,7 +510,7 @@ object Symbols { * - for `object class C[$]` => `class C` */ final def companionClass(using Context): Option[ClassSymbol] = maybeOuter match - case scope: PackageSymbol => + case scope: DeclaringSymbol => scope.getDecl(this.name.companionName).collect { case sym: ClassSymbol => sym } diff --git a/tasty-query/shared/src/test/scala/tastyquery/Paths.scala b/tasty-query/shared/src/test/scala/tastyquery/Paths.scala index 8cc44b91..048b8a66 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/Paths.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/Paths.scala @@ -66,6 +66,7 @@ object Paths: } def show: String = path.mkString(".") def debug: String = toDebugString(path) + def asObj: DeclarationPath = path.convertAsObject extension [T <: DeclarationPath](path: T) private def convertAsObject: T = diff --git a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala index 6e5edef8..722728c9 100644 --- a/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala +++ b/tasty-query/shared/src/test/scala/tastyquery/TypeSuite.scala @@ -1256,4 +1256,25 @@ class TypeSuite extends UnrestrictedUnpicklingSuite { assert(clue(childToString.allOverriddenSymbols.toList) == List(superToString, objectToString)) assert(clue(childToString.nextOverriddenSymbol) == Some(superToString)) } + + def companionClassFullCycle(path: DeclarationPath)(using Context, munit.Location): Unit = { + val cls: ClassSymbol = resolve(path).asClass + val moduleClass: ClassSymbol = resolve(path.asObj).asClass + + assert(cls == moduleClass.companionClass.get) + assert(moduleClass.companionClass.get == cls) + } + + testWithContext("companion-tests-module-value") { + companionClassFullCycle(name"companions" / tname"CompanionObject") + } + + testWithContext("companion-tests-nested-module-value") { + companionClassFullCycle(name"companions" / tname"CompanionObject" / obj / tname"NestedObject") + } + + testWithContext("companion-tests-class-nested-module-value") { + companionClassFullCycle(name"companions" / tname"CompanionObject" / tname"ClassNestedObject") + } + } diff --git a/test-sources/src/main/scala/companions/CompanionObject.scala b/test-sources/src/main/scala/companions/CompanionObject.scala new file mode 100644 index 00000000..a2b4d2d6 --- /dev/null +++ b/test-sources/src/main/scala/companions/CompanionObject.scala @@ -0,0 +1,10 @@ +package companions + +class CompanionObject { + class ClassNestedObject + object ClassNestedObject +} +object CompanionObject { + class NestedObject + object NestedObject +}