-
Notifications
You must be signed in to change notification settings - Fork 0
/
day_12.scala
142 lines (123 loc) · 4.12 KB
/
day_12.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import scala.annotation.tailrec
import scala.io.Source
import scala.io.StdIn
import scala.collection.immutable.ArraySeq
object Day12 {

  /** A grid coordinate; `row` grows downward and `col` grows rightward. */
  case class Point(row: Int, col: Int) {

    /** Component-wise sum, used to apply a neighbor offset to a point. */
    def +(other: Point): Point = Point(row + other.row, col + other.col)
  }

  /** A height map whose cells are lowercase letters plus `'S'` (start, height `'a'`) and
    * `'E'` (end, height `'z'`).
    *
    * `start` and `end` are computed eagerly, so construction throws
    * `NoSuchElementException` if either marker is missing.
    */
  case class Grid(locations: IndexedSeq[IndexedSeq[Char]]) {
    val numRows: Int = locations.size
    // Width is taken from the first row; an empty grid has zero columns instead of throwing.
    val numCols: Int = locations.headOption.fold(0)(_.size)
    val start: Point = locationOf('S')
    val end: Point = locationOf('E')

    /** Returns the first row-major position of `char`.
      *
      * @throws NoSuchElementException if `char` does not occur in the grid
      */
    def locationOf(char: Char): Point =
      locations.zipWithIndex
        .collectFirst {
          case (row, rowNum) if row.contains(char) => Point(rowNum, row.indexOf(char))
        }
        .getOrElse(throw new NoSuchElementException(s"'$char' does not occur in the grid"))

    /** Every coordinate in the `numRows` x `numCols` rectangle. */
    val allPoints: Set[Point] =
      (for {
        r <- 0 until numRows
        c <- 0 until numCols
      } yield Point(r, c)).toSet

    /** True when `point` lies inside the `numRows` x `numCols` rectangle. */
    def inBounds(point: Point): Boolean =
      point.row >= 0 && point.row < numRows && point.col >= 0 && point.col < numCols

    /** Offsets to the four orthogonal neighbors (right, left, down, up). */
    val NEIGHBOR_OFFSETS = List(
      Point(0, 1),
      Point(0, -1),
      Point(1, 0),
      Point(-1, 0)
    )

    /** True when stepping from `from` to `to` climbs at most one unit (descending any
      * amount is allowed). False when either point is off the grid.
      */
    def neighborsTraverseable(from: Point, to: Point): Boolean =
      heightAt(from).zip(heightAt(to)).exists { case (fromHeight, toHeight) =>
        toHeight - fromHeight <= 1
      }

    /** The orthogonal neighbors of `point` that lie on the grid. */
    def neighbors(point: Point): List[Point] =
      NEIGHBOR_OFFSETS.map(_ + point).filter(neighbor => heightAt(neighbor).isDefined)

    /** Height of `point`, treating `'S'` as `'a'` and `'E'` as `'z'`; `None` when off the grid. */
    def heightAt(point: Point): Option[Char] =
      locations.lift(point.row).flatMap(_.lift(point.col)).map {
        case 'S'  => 'a'
        case 'E'  => 'z'
        case char => char
      }

    /** Fewest steps from `origin` to `dest`, climbing at most one unit per step.
      *
      * @throws NoSuchElementException if `dest` is unreachable from `origin`
      */
    def shortestDistance(origin: Point, dest: Point): Int =
      shortestPathsFrom(origin, (from, to) => neighborsTraverseable(from, to))
        .getOrElse(dest, throw new NoSuchElementException(s"$dest is not reachable from $origin"))

    /** Fewest steps from any height-`'a'` cell (including `'S'`) to `dest`.
      *
      * Searches backwards from `dest` with the climbing rule reversed, so a single pass
      * covers every possible starting cell.
      *
      * @throws UnsupportedOperationException if no height-`'a'` cell can reach `dest`
      */
    def shortestStartingDistance(dest: Point): Int =
      shortestPathsFrom(dest, (from, to) => neighborsTraverseable(to, from)).iterator
        .collect { case (point, distance) if heightAt(point).contains('a') => distance }
        .min

    /** Breadth-first search from `origin`.
      *
      * Every step costs 1, so BFS yields the same shortest distances as Dijkstra in linear
      * time. The result maps each point reachable from `origin` (via `isTraverseable`
      * steps) to its distance; unreachable points are absent.
      */
    private def shortestPathsFrom(
        origin: Point,
        isTraverseable: (Point, Point) => Boolean
    ): Map[Point, Int] = {
      // `frontier` holds exactly the points whose distance is `depth`.
      @tailrec
      def expand(frontier: List[Point], depth: Int, distances: Map[Point, Int]): Map[Point, Int] =
        if (frontier.isEmpty) distances
        else {
          val nextFrontier = frontier
            .flatMap(node => neighbors(node).filter(neighbor => isTraverseable(node, neighbor)))
            .distinct
            .filterNot(point => distances.contains(point))
          expand(nextFrontier, depth + 1, distances ++ nextFrontier.map(_ -> (depth + 1)))
        }
      expand(List(origin), 0, Map(origin -> 0))
    }
  }

  /** Reads `day_12.input` and prints the answers to both parts. */
  def main: Unit = {
    // Materialize all lines before the source is closed by `Using.resource`.
    val grid = Grid(
      scala.util.Using.resource(Source.fromFile("day_12.input")) { source =>
        source.getLines().map(_.toIndexedSeq).toIndexedSeq
      }
    )
    println(s"Part 1: ${grid.shortestDistance(grid.start, grid.end)}")
    println(s"Part 2: ${grid.shortestStartingDistance(grid.end)}")
  }
}