diff --git a/main/resolve/src/mill/resolve/ResolveCore.scala b/main/resolve/src/mill/resolve/ResolveCore.scala index 5c3a7396a74..86b5e4b1605 100644 --- a/main/resolve/src/mill/resolve/ResolveCore.scala +++ b/main/resolve/src/mill/resolve/ResolveCore.scala @@ -100,7 +100,8 @@ private object ResolveCore { m.cls, None, current.segments, - Nil + Nil, + Set.empty ) transitiveOrErr.map(transitive => self ++ transitive) @@ -122,7 +123,8 @@ private object ResolveCore { m.cls, None, current.segments, - typePattern + typePattern, + Set.empty ) transitiveOrErr.map(transitive => self ++ transitive) @@ -240,29 +242,39 @@ private object ResolveCore { } def resolveTransitiveChildren( - rootModule: BaseModule, - cls: Class[_], - nameOpt: Option[String], - segments: Segments, - typePattern: Seq[String] + rootModule: BaseModule, + cls: Class[_], + nameOpt: Option[String], + segments: Segments, + typePattern: Seq[String], + seenModules: Set[Class[_]], ): Either[String, Set[Resolved]] = { - val direct = - resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern) - direct.flatMap { direct => - for { - directTraverse <- - resolveDirectChildren(rootModule, cls, nameOpt, segments, Nil) - indirect0 = directTraverse - .collect { case m: Resolved.Module => - resolveTransitiveChildren( - rootModule, - m.cls, - nameOpt, - m.segments, - typePattern - ) + val errOrDirect = resolveDirectChildren(rootModule, cls, nameOpt, segments, typePattern) + val directTraverse = resolveDirectChildren(rootModule, cls, nameOpt, segments, Nil) + + val errOrModules = directTraverse.map { modules => + modules.flatMap { + case m: Resolved.Module => Some(m) + case _ => None + } + } + + if (seenModules.contains(cls)) { + Left(s"Cyclic module reference detected: ${cls.getName}, it's required to wrap it in ModuleRef.") + } else { + val errOrIndirect0 = errOrModules match { + case Right(modules) => + modules.flatMap { m => + Some(resolveTransitiveChildren(rootModule, m.cls, nameOpt, m.segments, typePattern, seenModules + cls)) } - indirect <- EitherOps.sequence(indirect0).map(_.flatten) + case Left(err) => Seq(Left(err)) + } + + val errOrIndirect = EitherOps.sequence(errOrIndirect0).map(_.flatten) + + for { + direct <- errOrDirect + indirect <- errOrIndirect } yield direct ++ indirect } } @@ -318,21 +330,19 @@ private object ResolveCore { } } else Right(Nil) - crossesOrErr.flatMap { crosses => - val filteredCrosses = crosses.filter { c => + for { + crosses <- crossesOrErr + filteredCrosses = crosses.filter { c => classMatchesTypePred(typePattern)(c.cls) } - - resolveDirectChildren0(rootModule, segments, cls, nameOpt, typePattern) - .map( - _.map { - case (Resolved.Module(s, cls), _) => Resolved.Module(segments ++ s, cls) - case (Resolved.NamedTask(s), _) => Resolved.NamedTask(segments ++ s) - case (Resolved.Command(s), _) => Resolved.Command(segments ++ s) - } - .toSet - .++(filteredCrosses) - ) + direct <- resolveDirectChildren0(rootModule, segments, cls, nameOpt, typePattern) + } yield { + direct.map { + case (Resolved.Module(s, cls), _) => Resolved.Module(segments ++ s, cls) + case (Resolved.NamedTask(s), _) => Resolved.NamedTask(segments ++ s) + case (Resolved.Command(s), _) => Resolved.Command(segments ++ s) + } + .toSet ++ filteredCrosses } } diff --git a/main/resolve/test/src/mill/main/ResolveTests.scala b/main/resolve/test/src/mill/main/ResolveTests.scala index 8c3c1718bac..7ffd361063c 100644 --- a/main/resolve/test/src/mill/main/ResolveTests.scala +++ b/main/resolve/test/src/mill/main/ResolveTests.scala @@ -1110,5 +1110,54 @@ object ResolveTests extends TestSuite { Right(Set(_.concrete.tests.inner.foo, _.concrete.tests.inner.innerer.bar)) ) } + test("cyclicModuleRefInitError") { + val check = new Checker(TestGraphs.CyclicModuleRefInitError) + test - check.checkSeq0( + Seq("__"), + isShortError(_, "Cyclic module reference detected") + ) + } + test("cyclicModuleRefInitError2") { + val check = new Checker(TestGraphs.CyclicModuleRefInitError2) + test - check.checkSeq0( + Seq("__"), + isShortError(_, "Cyclic module reference detected") + ) + } + test("cyclicModuleRefInitError3") { + val check = new Checker(TestGraphs.CyclicModuleRefInitError3) + test - check.checkSeq0( + Seq("__"), + isShortError(_, "Cyclic module reference detected") + ) + } + test("crossedCyclicModuleRefInitError") { + val check = new Checker(TestGraphs.CrossedCyclicModuleRefInitError) + test - check.checkSeq0( + Seq("__"), + isShortError(_, "Cyclic module reference detected") + ) + } + test("nonCyclicModules") { + val check = new Checker(TestGraphs.NonCyclicModules) + test - check( + "__", + Right(Set(_.foo)) + ) + } + test("moduleRefWithNonModuleRefChild") { + val check = new Checker(TestGraphs.ModuleRefWithNonModuleRefChild) + test - check( + "__", + Right(Set(_.foo)) + ) + } + test("moduleRefCycle") { + val check = new Checker(TestGraphs.ModuleRefCycle) + test - check( + "__", + Right(Set(_.foo)) + ) + } } } diff --git a/main/test/src/mill/util/TestGraphs.scala b/main/test/src/mill/util/TestGraphs.scala index 8ca8c563c87..be6ac674de9 100644 --- a/main/test/src/mill/util/TestGraphs.scala +++ b/main/test/src/mill/util/TestGraphs.scala @@ -667,4 +667,96 @@ object TestGraphs { } } + object CyclicModuleRefInitError extends TestBaseModule { + import mill.Agg + + // See issue: https://github.com/com-lihaoyi/mill/issues/3715 + trait CommonModule extends TestBaseModule { + def moduleDeps: Seq[CommonModule] = Seq.empty + def a = myA + def b = myB + } + + object myA extends A + trait A extends CommonModule + object myB extends B + trait B extends CommonModule { + override def moduleDeps = super.moduleDeps ++ Agg(a) + } + } + + object CyclicModuleRefInitError2 extends TestBaseModule { + // The cycle is in the child + def A = CyclicModuleRefInitError + } + + object CyclicModuleRefInitError3 extends TestBaseModule { + // The cycle is in directly here + object A extends Module { + def b = B + } + object B extends Module { + def a = A + } + } + + object CrossedCyclicModuleRefInitError extends TestBaseModule { + object cross extends mill.Cross[Cross]("210", "211", "212") + trait Cross extends Cross.Module[String] { + def suffix = Task { crossValue } + def c2 = cross2 + } + + object cross2 extends mill.Cross[Cross2]("210", "211", "212") + trait Cross2 extends Cross.Module[String] { + override def millSourcePath = super.millSourcePath / crossValue + def suffix = Task { crossValue } + def c1 = cross + } + } + + // The module names repeat, but it's not actually cyclic and is meant to confuse the cycle detection. + object NonCyclicModules extends TestBaseModule { + def foo = Task { "foo" } + + object A extends Module { + def b = B + } + object B extends Module { + object A extends Module { + def b = B + } + def a = A + + object B extends Module { + object B extends Module {} + object A extends Module { + def b = B + } + def a = A + } + } + } + + // This edge case shouldn't be an error + object ModuleRefWithNonModuleRefChild extends TestBaseModule { + def foo = Task { "foo" } + + def aRef = A + def a = ModuleRef(A) + + object A extends TestBaseModule {} + } + + object ModuleRefCycle extends TestBaseModule { + def foo = Task { "foo" } + + // The cycle is in directly here + object A extends Module { + def b = ModuleRef(B) + } + object B extends Module { + def a = ModuleRef(A) + } + } }