I'm trying to figure out a neat way of traversing a graph Scala-style, preferably with vals and immutable data types.
Given the following graph,
val graph = Map(0 -> Set(1),
                1 -> Set(2),
                2 -> Set(0, 3, 4),
                3 -> Set[Int](),
                4 -> Set(3))
I'd like the output to be the depth-first traversal starting at a given node. Starting at 1, for instance, it should yield something like 1 2 3 0 4.
I can't seem to figure out a nice way of doing this without mutable collections or vars. Any help would be appreciated.
Tail-recursive solution:
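def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
  // Neighbours of `parent` that are not in `visited` yet
  def childrenNotVisited(parent: Int, visited: List[Int]) =
    graph(parent) filter (x => !visited.contains(x))

  // `stack` holds the nodes still to visit; `visited` is built in reverse.
  // Since `stack` is a Set, pop order follows Set iteration order, which
  // matches insertion order only for small sets (Scala's Set1..Set4).
  @annotation.tailrec
  def loop(stack: Set[Int], visited: List[Int]): List[Int] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
              stack.head :: visited)
  }

  loop(Set(start), Nil).reverse
}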
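Since `stack` above is a `Set`, the order in which nodes are popped depends on `Set` iteration order rather than being strictly last-in-first-out. A minimal sketch of the same tail-recursive loop with an immutable `List` as the stack and a `Set` for the visited check (the object name `ListStackDfs` is only illustrative) might look like this:

```scala
object ListStackDfs {
  // Tail-recursive DFS using an immutable List as an explicit LIFO stack.
  // A node may be pushed more than once; it is skipped when popped if it
  // was already visited, so each node appears exactly once in the result.
  def traverse[T](graph: Map[T, Set[T]], start: T): List[T] = {
    @annotation.tailrec
    def loop(stack: List[T], visited: Set[T], order: List[T]): List[T] =
      stack match {
        case Nil => order.reverse // `order` was built by prepending
        case node :: rest if visited(node) =>
          loop(rest, visited, order) // stale entry: already visited, just pop it
        case node :: rest =>
          // Push unvisited children on top, first child first, so it is explored next
          val children = graph.getOrElse(node, Set.empty[T]).toList.filterNot(visited)
          loop(children ++ rest, visited + node, node :: order)
      }

    loop(List(start), Set.empty[T], Nil)
  }
}
```

With the question's graph (declaring node 3 as `Set[Int]()`), `ListStackDfs.traverse(graph, 1)` returns `List(1, 2, 0, 3, 4)`, because the small `Set(0, 3, 4)` iterates in insertion order.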
This is one variant I guess:
graph.foldLeft((List[Int](), 1)) {
  (s, e) => if (e._2.size == 0) (0 :: s._1, s._2) else (s._2 :: s._1, (s._2 + 1))
}._1.reverse
Updated: This is an expanded version. Here I fold left over the elements of the map, starting out with a tuple of an empty list and the number 1. For each element I check the size of its neighbour set and create a new tuple accordingly. The resulting list comes out in reverse order.
val init = (List[Int](), 1)
val (result, _) = graph.foldLeft(init) {
  (s, elem) =>
    val (stack, count) = s
    if (elem._2.size == 0)
      (0 :: stack, count)
    else
      (count :: stack, count + 1)
}
result.reverse
Here is a recursive solution (I hope I understood your requirements correctly):
def traverse(graph: Map[Int, Set[Int]], node: Int, visited: Set[Int] = Set()): List[Int] =
  List(node) ++ (graph(node) -- visited flatMap(traverse(graph, _, visited + node)))

traverse(graph, 1)
Also note that this function is NOT tail-recursive.
Don't know if you are still looking for an answer after 6 years, but here it is :)
It also returns a topological ordering and whether the graph is cyclic:
case class Node(label: Int)

case class Graph(adj: Map[Node, Set[Node]]) {
  case class DfsState(discovered: Set[Node] = Set(), activeNodes: Set[Node] = Set(), tsOrder: List[Node] = List(),
                      isCylic: Boolean = false)

  // Assumes every node appears as a key in `adj`
  def dfs: (List[Node], Boolean) = {
    def dfsVisit(currState: DfsState, src: Node): DfsState = {
      // src is now on the current DFS path; an edge to any active node
      // (including src itself) is a back edge, i.e. the graph has a cycle
      val active = currState.activeNodes + src
      val newState = currState.copy(discovered = currState.discovered + src, activeNodes = active,
        isCylic = currState.isCylic || adj(src).exists(active))
      // Re-check `discovered` per child: visiting an earlier child may already have discovered a later one
      val finalState = adj(src).foldLeft(newState) { (state, n) =>
        if (state.discovered(n)) state else dfsVisit(state, n)
      }
      // src is finished: prepend it to the topological order and leave the active path
      finalState.copy(tsOrder = src :: finalState.tsOrder, activeNodes = finalState.activeNodes - src)
    }

    val stateAfterSearch = adj.keys.foldLeft(DfsState()) { (state, n) => if (state.discovered(n)) state else dfsVisit(state, n) }
    (stateAfterSearch.tsOrder, stateAfterSearch.isCylic)
  }
}
It seems that this question is more involved than I originally thought. I wrote another recursive solution; it's still not tail-recursive. I also tried hard to make it a one-liner, but readability would suffer a lot in this case, so I decided to declare several vals this time:
def traverse(graph: Map[Int, Set[Int]], node: Int, result: List[Int] = Nil): List[Int] = {
  val newResult = result :+ node
  val currentEdges = graph(node) -- newResult
  // If there are no unvisited neighbours, continue from any unvisited vertex
  val realEdges = if (currentEdges.isEmpty) graph.keySet -- newResult else currentEdges
  realEdges.foldLeft(newResult)((r, n) => if (r contains n) r else traverse(graph, n, r))
}
In my previous answer I tried to find all paths from the given node in a directed graph, but that did not match the requirements. This answer tries to follow directed edges, but if it can't, it just takes some unvisited node and continues from there.
Tenshi,
I haven't fully understood your solution, but if I am not mistaken its time complexity is at least O(|V|^2), since the following line is O(|V|):
val newResult = result :+ node
Namely, it appends an element to the end of a List, which takes time proportional to the list's length.
Furthermore, the code is not tail-recursive, which might be a problem if, for example, the recursion depth is limited by the environment you are using.
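The usual way around the `:+` cost is to prepend with `::` (constant time) and reverse once at the end, and to keep visited vertices in a `Set` instead of searching a `List`. A minimal sketch of that pattern, assuming the same `Map[T, Set[T]]` representation; like the code above it restarts from an unvisited vertex when it gets stuck (the name `PrependDfs.traverseAll` is hypothetical), and it is still not tail-recursive:

```scala
object PrependDfs {
  // DFS forest: follows edges from `start`, then restarts from any key of
  // `graph` that is still unvisited. The order is accumulated by prepending
  // and reversed once at the end.
  def traverseAll[T](graph: Map[T, Set[T]], start: T): List[T] = {
    type State = (Set[T], List[T]) // (visited, order in reverse)

    // Plain (non-tail) recursion: depth is bounded by the longest DFS path
    def visit(state: State, node: T): State = {
      val (visited, order) = state
      if (visited(node)) state
      else graph.getOrElse(node, Set.empty[T]).foldLeft((visited + node, node :: order))(visit)
    }

    val afterStart = visit((Set.empty[T], Nil), start)
    // Restart from vertices that were not reachable from `start`
    val (_, reversedOrder) = graph.keys.foldLeft(afterStart)(visit)
    reversedOrder.reverse
  }
}
```

For the question's graph, `PrependDfs.traverseAll(graph, 1)` returns `List(1, 2, 0, 3, 4)`. Each vertex is entered once and each edge list is folded once, so apart from the cost of the `Set` operations the work is linear in |V| + |E|.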
The following code solves a few DFS-related problems on directed graphs. It is not the most elegant code, but if I am not mistaken it:
- Is tail-recursive.
- Uses only immutable collections (and iterators on them).
- Has optimal time complexity O(|V| + |E|) and space complexity O(|V|).
The code:
import scala.annotation.tailrec
import scala.util.Try

/**
 * Created with IntelliJ IDEA.
 * User: mishaelr
 * Date: 5/14/14
 * Time: 5:18 PM
 */
object DirectedGraphTraversals {

  type Graph[Vertex] = Map[Vertex, Set[Vertex]]

  // Returns the vertices reachable from initialVertex in reverse post-order.
  // initialVertex is marked as entered up front so a cycle back to it does not re-enter it.
  def dfs[Vertex](graph: Graph[Vertex], initialVertex: Vertex) =
    dfsRec(DfsNeighbours)(graph, List(DfsNeighbours(graph, initialVertex, Set(), Set())), Set(initialVertex), Set(), List())

  // Throws IllegalArgumentException if the graph contains a cycle
  def topologicalSort[Vertex](graph: Graph[Vertex]) =
    graphDfsRec(TopologicalSortNeighbours)(graph, graph.keySet, Set(), Set(), List())

  // Same idea as Kosaraju's algorithm: exit order on the graph, then DFS on the reversed graph
  def stronglyConnectedComponents[Vertex](graph: Graph[Vertex]) = {
    val exitOrder = graphDfsRec(DfsNeighbours)(graph, graph.keySet, Set(), Set(), List())
    val reversedGraph = reverse(graph)
    // Start from an empty list of components (not a list holding one empty set)
    exitOrder.foldLeft((Set[Vertex](), List[Set[Vertex]]())) {
      case (acc @ (visitedAcc, connectedComponentsAcc), vertex) =>
        if (visitedAcc(vertex))
          acc
        else {
          val connectedComponent = dfsRec(DfsNeighbours)(reversedGraph,
            List(DfsNeighbours(reversedGraph, vertex, visitedAcc, visitedAcc)),
            visitedAcc + vertex, visitedAcc, List()).toSet
          (visitedAcc ++ connectedComponent, connectedComponent :: connectedComponentsAcc)
        }
    }._2
  }

  def reverse[Vertex](graph: Graph[Vertex]): Graph[Vertex] = {
    val reverseList = for {
      (vertex, neighbours) <- graph.toList
      neighbour <- neighbours
    } yield (neighbour, vertex)
    // groupBy + map (instead of mapValues) yields a strict Map on both Scala 2.12 and 2.13
    reverseList.groupBy(_._1).map { case (vertex, pairs) => vertex -> pairs.map(_._2).toSet }
  }

  private sealed trait NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex])
  }

  private object DfsNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]) =
      (vertex, graph.getOrElse(vertex, Set.empty[Vertex]).iterator)
  }

  private object TopologicalSortNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]) = {
      val neighbours = graph.getOrElse(vertex, Set.empty[Vertex])
      // An edge to a vertex that was entered but not yet exited is a back edge, i.e. a cycle
      if (neighbours.exists(neighbour => entered(neighbour) && !exited(neighbour)))
        throw new IllegalArgumentException("The graph is not a DAG, it contains cycles: " + graph)
      else
        (vertex, neighbours.iterator)
    }
  }

  @tailrec
  private def dfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], toVisit: List[(Vertex, Iterator[Vertex])],
                                                             entered: Set[Vertex], exited: Set[Vertex],
                                                             exitStack: List[Vertex]): List[Vertex] = {
    toVisit match {
      case List() => exitStack
      case (currentVertex, neighbours) :: tl =>
        // Skips already-entered neighbours; this advances the shared iterator
        val filtered = neighbours.filterNot(entered)
        if (filtered.hasNext) {
          val nextNeighbour = filtered.next()
          // Descend: push the neighbour on top of the stack
          dfsRec(neighboursFunc)(graph, neighboursFunc(graph, nextNeighbour, entered, exited) :: toVisit,
            entered + nextNeighbour, exited, exitStack)
        } else
          // No neighbours left: pop the vertex and record its exit
          dfsRec(neighboursFunc)(graph, tl, entered, exited + currentVertex, currentVertex :: exitStack)
    }
  }

  @tailrec
  private def graphDfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], notVisited: Set[Vertex],
                                                                  entered: Set[Vertex], exited: Set[Vertex], order: List[Vertex]): List[Vertex] = {
    if (notVisited.isEmpty)
      order
    else {
      val start = notVisited.head
      // Run a DFS from an unvisited vertex, marking it as entered first
      val orderSuffix = dfsRec(neighboursFunc)(graph, List(neighboursFunc(graph, start, entered, exited)),
        entered + start, exited, List())
      graphDfsRec(neighboursFunc)(graph, notVisited -- orderSuffix, entered ++ orderSuffix, exited ++ orderSuffix, orderSuffix ::: order)
    }
  }
}

object DirectedGraphTraversalsExamples extends App {
  import DirectedGraphTraversals._

  val graph = Map(
    "B" -> Set("D", "C"),
    "A" -> Set("B", "D"),
    "D" -> Set("E"),
    "E" -> Set("C"))

  println("dfs A " + dfs(graph, "A"))
  println("dfs B " + dfs(graph, "B"))
  println("topologicalSort " + topologicalSort(graph))
  println("reverse " + reverse(graph))
  println("stronglyConnectedComponents graph " + stronglyConnectedComponents(graph))

  val graph2 = graph + ("C" -> Set("D"))
  println("stronglyConnectedComponents graph2 " + stronglyConnectedComponents(graph2))
  println("topologicalSort graph2 " + Try(topologicalSort(graph2)))
}
Marimuthu Madasamy's answer does indeed work.
Here is a generic version of it:
val graph = Map(0 -> Set(1),
                1 -> Set(2),
                2 -> Set(0, 3, 4),
                3 -> Set[Int](),
                4 -> Set(3))

def traverse[T](graph: Map[T, Set[T]], start: T): List[T] = {
  def childrenNotVisited(parent: T, visited: List[T]) =
    graph(parent) filter (x => !visited.contains(x))

  @annotation.tailrec
  def loop(stack: Set[T], visited: List[T]): List[T] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
              stack.head :: visited)
  }

  loop(Set(start), Nil).reverse
}

traverse(graph, 0)
Note: you have to make sure that the instances of T correctly implement equals and hashCode. Using case classes with primitive values is an easy way to get there.
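A minimal illustration of why this matters, using two hypothetical vertex types (`PlainVertex`, `CaseVertex`) and a hypothetical helper `alreadyVisited`: the traversals above decide whether a node was already seen through collection membership (`contains`, `Set` lookups, `Map` keys), and that goes through `equals` and `hashCode`.

```scala
object VertexEqualityDemo {
  // Hypothetical vertex types for illustration only
  class PlainVertex(val id: Int) // no equals/hashCode override: reference equality
  case class CaseVertex(id: Int) // case class: structural equals/hashCode are generated

  // Membership test of the kind the traversals use to detect visited vertices
  def alreadyVisited[T](visited: Set[T], v: T): Boolean = visited.contains(v)

  // A structurally identical but distinct PlainVertex is NOT recognised as visited
  val plainSeen: Boolean = alreadyVisited(Set(new PlainVertex(1)), new PlainVertex(1))
  // An equal CaseVertex is recognised, so the node would not be visited twice
  val caseSeen: Boolean = alreadyVisited(Set(CaseVertex(1)), CaseVertex(1))
}
```

Here `plainSeen` is `false` and `caseSeen` is `true`.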
Source: https://stackoverflow.com/questions/5471234/how-to-implement-a-dfs-with-immutable-data-types