How to implement a DFS with immutable data types

Submitted by 爱⌒轻易说出口 on 2019-12-03 12:54:41

Tail Recursive solution:

  def traverse(graph: Map[Int, Set[Int]], start: Int): List[Int] = {
    def childrenNotVisited(parent: Int, visited: List[Int]) =
      graph(parent) filter (x => !visited.contains(x))

    @annotation.tailrec
    def loop(stack: Set[Int], visited: List[Int]): List[Int] = {
      if (stack.isEmpty) visited
      else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
        stack.head :: visited)
    }
    loop(Set(start), Nil).reverse
  }

This is one variant I guess:

graph.foldLeft((List[Int](), 1)){
  (s, e) => if (e._2.size == 0) (0 :: s._1, s._2) else (s._2 :: s._1, (s._2 + 1))
}._1.reverse

Updated: This is an expanded version. Here I fold left over the elements of the map, starting with a tuple of an empty list and the number 1. For each element I check the size of its adjacency set and create a new tuple accordingly. The resulting list comes out in reverse order, so it is reversed at the end.

val init = (List[Int](), 1)
val (result, _) = graph.foldLeft(init) {
  (s, elem) => 
    val (stack, count) = s
    if (elem._2.size == 0) 
      (0 :: stack, count) 
    else 
      (count :: stack, count + 1)
}
result.reverse

Here is a recursive solution (I hope I understood your requirements correctly):

def traverse(graph: Map[Int, Set[Int]], node: Int, visited: Set[Int] = Set()): List[Int] = 
    List(node) ++ (graph(node) -- visited flatMap(traverse(graph, _, visited + node)))

traverse(graph, 1)

Also, please note that this function is NOT tail recursive.
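For comparison, here is a minimal sketch of how a depth-first traversal could be made tail recursive while staying immutable: an explicit `List` serves as the stack and a `Set` records visited nodes. The names `DfsSketch` and `dfsPreorder` are illustrative and not from the answers here. Iteration order within a `Set` is not guaranteed in general, so the visit order among siblings may vary.

```scala
import scala.annotation.tailrec

object DfsSketch {
  // Hypothetical helper: pre-order DFS using only immutable collections.
  def dfsPreorder[T](graph: Map[T, Set[T]], start: T): List[T] = {
    @tailrec
    def loop(stack: List[T], visited: Set[T], acc: List[T]): List[T] =
      stack match {
        case Nil => acc.reverse
        // A node can be pushed more than once; skip it if already visited.
        case node :: rest if visited(node) => loop(rest, visited, acc)
        case node :: rest =>
          // Push unvisited neighbours on top of the stack so they are explored first.
          // Nodes without an entry in the map are treated as having no neighbours.
          val next = graph.getOrElse(node, Set.empty[T]).filterNot(visited).toList
          loop(next ::: rest, visited + node, node :: acc)
      }

    loop(List(start), Set.empty[T], Nil)
  }
}
```

Because `visited` is a `Set`, each membership check is effectively constant time, and the result list is built by prepending and reversed once at the end.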

Don't know if you are still looking for an answer after 6 years, but here it is :)

It also returns a topological ordering and whether the graph is cyclic:

case class Node(label: Int)

case class Graph(adj: Map[Node, Set[Node]]) {
  case class DfsState(discovered: Set[Node] = Set(), activeNodes: Set[Node] = Set(),
                      tsOrder: List[Node] = List(), isCylic: Boolean = false)

  def dfs: (List[Node], Boolean) = {
    def dfsVisit(currState: DfsState, src: Node): DfsState = {
      // Mark src as discovered and active; a neighbour that is still active means a back edge.
      val newState = currState.copy(
        discovered = currState.discovered + src,
        activeNodes = currState.activeNodes + src,
        isCylic = currState.isCylic || adj(src).exists(currState.activeNodes))

      // Check `discovered` inside the fold, so a neighbour reached while visiting
      // an earlier sibling is not visited (and added to tsOrder) a second time.
      val finalState = adj(src).foldLeft(newState) { (state, n) =>
        if (state.discovered(n)) state else dfsVisit(state, n)
      }
      finalState.copy(tsOrder = src :: finalState.tsOrder, activeNodes = finalState.activeNodes - src)
    }

    val stateAfterSearch = adj.keys.foldLeft(DfsState()) { (state, n) =>
      if (state.discovered(n)) state else dfsVisit(state, n)
    }
    (stateAfterSearch.tsOrder, stateAfterSearch.isCylic)
  }
}

It seems that this question is more involved than I originally thought. I wrote another recursive solution. It's still not tail recursive. I also tried hard to make it a one-liner, but readability would suffer a lot in this case, so I decided to declare several vals this time:

def traverse(graph: Map[Int, Set[Int]], node: Int, result: List[Int] = Nil): List[Int] = {
  val newResult = result :+ node
  val currentEdges = graph(node) -- newResult
  val realEdges = if (currentEdges.isEmpty) graph.keySet -- newResult else currentEdges

  realEdges.foldLeft(newResult)((r, n) => if (r contains n) r else traverse(graph, n, r))
}

In my previous answer I tried to find all paths from the given node in a directed graph, but that was wrong according to the requirements. This answer follows directed edges where it can; when it can't, it takes some unvisited node and continues from there.

Tenshi,

I haven't fully understood your solution, but if I am not mistaken its time complexity is at least O(|V|^2), since the following line has complexity O(|V|):

val newResult = result :+ node

Namely, appending an element to the right of a list.
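To illustrate the cost difference, here is a small sketch comparing the two ways of building a result list. The object and method names are made up for this example; the point is only that `:+` on an immutable `List` copies the list each time, while `::` is constant time.

```scala
object AppendVsPrepend {
  // Quadratic overall: each :+ copies the whole list built so far.
  def collectByAppend(xs: Seq[Int]): List[Int] =
    xs.foldLeft(List.empty[Int])((acc, x) => acc :+ x)

  // Linear overall: :: is constant time, followed by one reverse at the end.
  def collectByPrepend(xs: Seq[Int]): List[Int] =
    xs.foldLeft(List.empty[Int])((acc, x) => x :: acc).reverse
}
```

Both methods return the same list; only the amount of work differs.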

Furthermore, the code is not tail recursive, which might be a problem if, for example, the recursion depth is limited by the environment you are using.

The following code solves a few DFS related graph problems on directed graphs. It is not the most elegant code, but if I am not mistaken it is:

  1. Tail recursive.
  2. Uses only immutable collections (and iterators on them).
  3. Has optimal time complexity O(|V| + |E|) and space complexity O(|V|).

The code:

import scala.annotation.tailrec
import scala.util.Try

/**
 * Created with IntelliJ IDEA.
 * User: mishaelr
 * Date: 5/14/14
 * Time: 5:18 PM
 */
object DirectedGraphTraversals {

  type Graph[Vertex] = Map[Vertex, Set[Vertex]]

  def dfs[Vertex](graph: Graph[Vertex], initialVertex: Vertex) =
    dfsRec(DfsNeighbours)(graph, List(DfsNeighbours(graph, initialVertex, Set(), Set())), Set(), Set(), List())

  def topologicalSort[Vertex](graph: Graph[Vertex]) =
    graphDfsRec(TopologicalSortNeighbours)(graph, graph.keySet, Set(), Set(), List())

  def stronglyConnectedComponents[Vertex](graph: Graph[Vertex]) = {
    val exitOrder = graphDfsRec(DfsNeighbours)(graph, graph.keySet, Set(), Set(), List())
    val reversedGraph = reverse(graph)

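    // Start with no components; seeding the list with an empty set would leave a spurious empty component in the result.
    exitOrder.foldLeft((Set[Vertex](), List[Set[Vertex]]())){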
      case (acc @(visitedAcc, connectedComponentsAcc), vertex) =>
        if(visitedAcc(vertex))
          acc
        else {
          val connectedComponent = dfsRec(DfsNeighbours)(reversedGraph, List(DfsNeighbours(reversedGraph, vertex, visitedAcc, visitedAcc)),
            visitedAcc, visitedAcc,List()).toSet
          (visitedAcc ++ connectedComponent, connectedComponent :: connectedComponentsAcc)
        }
    }._2
  }

  def reverse[Vertex](graph: Graph[Vertex]) = {
    val reverseList = for {
      (vertex, neighbours) <- graph.toList
      neighbour <- neighbours
    } yield (neighbour, vertex)

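    // Use map rather than mapValues so the result is a strict Map on Scala 2.13 as well (mapValues returns a view there).
    reverseList.groupBy(_._1).map { case (vertex, pairs) => vertex -> pairs.map(_._2).toSet }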
  }

  private sealed trait NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]): (Vertex, Iterator[Vertex])
  }

  private object DfsNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]) =
      (vertex, graph.getOrElse(vertex, Set()).iterator)
  }

  private object TopologicalSortNeighbours extends NeighboursFunc {
    def apply[Vertex](graph: Graph[Vertex], vertex: Vertex, entered: Set[Vertex], exited: Set[Vertex]) = {
      val neighbours = graph.getOrElse(vertex, Set())
      if(neighbours.exists(neighbour => entered(neighbour) && !exited(neighbour)))
        throw new IllegalArgumentException("The graph is not a DAG, it contains cycles: " + graph)
      else
        (vertex, neighbours.iterator)
    }
  }

  @tailrec
  private def dfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], toVisit: List[(Vertex, Iterator[Vertex])],
                                                             entered: Set[Vertex], exited: Set[Vertex],
                                                             exitStack: List[Vertex]): List[Vertex] = {
    toVisit match {
      case List() => exitStack
      case (currentVertex, neighbours) :: tl =>
        val filtered = neighbours.filterNot(entered)
        if(filtered.hasNext) {
          val nextNeighbour = filtered.next()
          dfsRec(neighboursFunc)(graph, neighboursFunc(graph, nextNeighbour, entered, exited) :: toVisit,
            entered + nextNeighbour, exited, exitStack)
        } else
          dfsRec(neighboursFunc)(graph, tl, entered, exited + currentVertex, currentVertex :: exitStack)
    }
  }

  @tailrec
  private def graphDfsRec[Vertex](neighboursFunc: NeighboursFunc)(graph: Graph[Vertex], notVisited: Set[Vertex],
                                                                  entered: Set[Vertex], exited: Set[Vertex], order: List[Vertex]): List[Vertex] = {
    if(notVisited.isEmpty)
      order
    else {
      val orderSuffix = dfsRec(neighboursFunc)(graph, List(neighboursFunc(graph, notVisited.head, entered, exited)), entered, exited, List())
      graphDfsRec(neighboursFunc)(graph, notVisited -- orderSuffix, entered ++ orderSuffix, exited ++ orderSuffix, orderSuffix ::: order)
    }
  }
}

object DirectedGraphTraversalsExamples extends App {
  import DirectedGraphTraversals._

  val graph = Map(
    "B" -> Set("D", "C"),
    "A" -> Set("B", "D"),
    "D" -> Set("E"),
    "E" -> Set("C"))

  println("dfs A " +  dfs(graph, "A"))
  println("dfs B " +  dfs(graph, "B"))

  println("topologicalSort " +  topologicalSort(graph))

  println("reverse " + reverse(graph))
  println("stronglyConnectedComponents graph " + stronglyConnectedComponents(graph))

  val graph2 = graph + ("C" -> Set("D"))
  println("stronglyConnectedComponents graph2 " + stronglyConnectedComponents(graph2))
  println("topologicalSort graph2 " + Try(topologicalSort(graph2)))
}

Marimuthu Madasamy's answer does indeed work.

Here is the generic version of it:

val graph = Map(0 -> Set(1),
  1 -> Set(2),
  2 -> Set(0, 3, 4),
  3 -> Set[Int](),
  4 -> Set(3))

def traverse[T](graph: Map[T, Set[T]], start: T): List[T] = {
  def childrenNotVisited(parent: T, visited: List[T]) =
    graph(parent) filter (x => !visited.contains(x))

  @annotation.tailrec
  def loop(stack: Set[T], visited: List[T]): List[T] = {
    if (stack.isEmpty) visited
    else loop(childrenNotVisited(stack.head, visited) ++ stack.tail,
      stack.head :: visited)
  }
  loop(Set(start), Nil).reverse
}

traverse(graph,0)

Note: You have to make sure that instances of T implement equals and hashCode correctly. Using case classes with primitive values is an easy way to get there.
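As a small illustration of that note, the sketch below contrasts a case class vertex with a plain class vertex. `City` and `PlainCity` are hypothetical types invented for this example.

```scala
object VertexEquality {
  // The compiler generates structural equals/hashCode for case classes.
  final case class City(name: String)

  // A plain class inherits reference equality from AnyRef.
  final class PlainCity(val name: String)

  val byCaseClass: Map[City, Set[City]] =
    Map(City("A") -> Set(City("B")), City("B") -> Set.empty[City])

  // A fresh, structurally equal key finds the entry.
  val found: Boolean = byCaseClass.contains(City("A"))

  // Two separately constructed plain instances are not equal,
  // so they would not work as lookup keys or in a visited set.
  val plainEqual: Boolean = new PlainCity("A") == new PlainCity("A")
}
```

With the plain class, graph lookups and visited checks would treat every new instance as a different vertex.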
