Detecting all cycles in a graph

后端 未结 2 626
太阳男子
太阳男子 2021-02-10 07:49

I have a directed graph stored in a Map data structure, where the key is a node's ID and the value is an array of the IDs of the nodes that the key node points to.

2条回答
  •  误落风尘
    2021-02-10 08:29

    package neo4j
    
    import java.net.URI
    import org.apache.spark.graphx.{Edge, EdgeRDD, Graph, Pregel, VertexId, VertexRDD}
    import org.apache.spark.rdd.RDD
    import org.opencypher.spark.api.CAPSSession
    import org.apache.spark.sql.types._
    import org.apache.spark.sql.functions._
    import org.apache.spark.sql.{DataFrame, Row}
    import org.opencypher.okapi.api.graph.GraphName
    import org.opencypher.okapi.api.io.conversion.{NodeMapping, RelationshipMapping}
    import org.opencypher.spark.api.io.neo4j.{Neo4jConfig, Neo4jPropertyGraphDataSource}
    import org.opencypher.spark.api.io.{CAPSNodeTable, CAPSRelationshipTable}
    import org.opencypher.spark.impl.io.neo4j.external.Neo4j
    import scala.collection.mutable
    import scala.collection.mutable.ListBuffer
    import scala.util.Random
    import scala.util.control.Breaks
    
    /**
     * Demo: builds a small directed test graph, stores it in Neo4j through Cypher for
     * Apache Spark (CAPS), reads it back, and runs a GraphX Pregel job that propagates
     * candidate paths along edges and records, per vertex, the cycles it takes part in.
     *
     * Requires a Neo4j instance at bolt://localhost:7687 (user "neo4j").
     * The graph named "neo4jgraph" is deleted and re-written on every run.
     *
     * An explicit `main` is used instead of `extends App`: Spark's documentation warns that
     * subclasses of scala.App may not work correctly, because App defers initialisation of
     * its fields (DelayedInit).
     */
    object TestV {
      def main(args: Array[String]): Unit = {
        //1) Create CAPS session and retrieve Spark session
        implicit val session: CAPSSession = CAPSSession.local()
        val spark = session.sparkSession
        //11)  Connect to Neo4j
        val boltWriteURI: URI = new URI("bolt://localhost:7687")
        val neo4jWriteConfig: Neo4jConfig = new Neo4jConfig(boltWriteURI, "neo4j", Some("123abc"), true)
        val neo4jResult: Neo4jPropertyGraphDataSource = new Neo4jPropertyGraphDataSource(neo4jWriteConfig)(session)
        val neo4jConnection = Neo4j(neo4jWriteConfig, session.sparkSession)

        val neo4jResultName: GraphName = new GraphName("neo4jgraph")
        // Drop the graph written by a previous run before storing the new one.
        neo4jResult.delete(neo4jResultName)

        // Test data. node: (no, name); relation: (source, target, amount).
        // `amount` is a random edge weight; the cycle search below never reads it.
        val len = 9
        val node_seq = collection.mutable.ListBuffer[(Long, String)]()
        val relation_seq = collection.mutable.ListBuffer[(Long, Long, Double)]()
        // Nodes 0..len-1 form one ring: (i+1) -> i for i < len-1, closed by 0 -> len-1.
        for (i <- 0 until len) {
          node_seq += ((i.toLong, s"name$i"))
          if (i == len - 1) relation_seq += ((0L, i.toLong, Random.nextDouble() * len))
          else relation_seq += ((i + 1L, i.toLong, Random.nextDouble() * len))
        }
        // Extra nodes and edges: self-loops (dropped later by removeSelfEdges), parallel
        // duplicate edges, and edges that create several shorter cycles.
        node_seq += ((9L, "name9"))
        node_seq += ((10L, "name10"))
        node_seq += ((11L, "name11"))
        node_seq += ((12L, "name12"))
        relation_seq += ((0L, 0L, Random.nextDouble() * len))
        relation_seq += ((3L, 3L, Random.nextDouble() * len))
        relation_seq += ((0L, 3L, Random.nextDouble() * len))
        relation_seq += ((0L, 9L, Random.nextDouble() * len))
        relation_seq += ((7L, 10L, Random.nextDouble() * len))
        relation_seq += ((10L, 7L, Random.nextDouble() * len))
        relation_seq += ((7L, 10L, Random.nextDouble() * len))
        relation_seq += ((10L, 7L, Random.nextDouble() * len))
        relation_seq += ((10L, 11L, Random.nextDouble() * len))
        relation_seq += ((11L, 12L, Random.nextDouble() * len))
        relation_seq += ((12L, 10L, Random.nextDouble() * len))

        relation_seq += ((0L, 0L, Random.nextDouble() * len))
        relation_seq += ((3L, 3L, Random.nextDouble() * len))
        relation_seq += ((3L, 0L, Random.nextDouble() * len))
        relation_seq += ((9L, 0L, Random.nextDouble() * len))
        relation_seq += ((10L, 7L, Random.nextDouble() * len))
        relation_seq += ((7L, 10L, Random.nextDouble() * len))
        relation_seq += ((10L, 7L, Random.nextDouble() * len))
        relation_seq += ((7L, 10L, Random.nextDouble() * len))
        relation_seq += ((11L, 10L, Random.nextDouble() * len))
        relation_seq += ((12L, 11L, Random.nextDouble() * len))
        relation_seq += ((10L, 12L, Random.nextDouble() * len))

        //3) cache the dataframe
        // The node id (`id1`) must equal `no`, because relationships refer to their endpoints
        // by `no` (columns `source` / `target`). monotonically_increasing_id() is unique but not
        // consecutive (the partition id sits in its upper bits), so it cannot be used here.
        val nodesDF: DataFrame = spark.createDataFrame(node_seq).toDF("no", "name").
          withColumn("id1", col("no")).select("id1", "name", "no").cache()
        nodesDF.count() // action: materialises the cached DataFrame
        //nodesDF.show
        // Relationship ids only need to be unique, so monotonically_increasing_id() is fine here.
        val relsDF: DataFrame = spark.createDataFrame(relation_seq).toDF("source", "target", "amount").
          withColumn("id2", monotonically_increasing_id()).select("id2", "source", "target", "amount").cache()
        relsDF.count()
        //relsDF.show

        import spark.implicits._

        //8) mapping the columns
        val node_mapping = NodeMapping.withSourceIdKey("id1").withImpliedLabel("Person").withPropertyKeys("no", "name")
        val rel_mapping = RelationshipMapping.withSourceIdKey("id2").withSourceStartNodeKey("source")
          .withSourceEndNodeKey("target").withRelType("KNOWS").withPropertyKeys("amount")

        //9)  create tables
        val node_table = CAPSNodeTable(node_mapping, nodesDF)
        val rel_table = CAPSRelationshipTable(rel_mapping, relsDF)

        //10) Create graph
        val graph = session.readFrom(node_table, rel_table)

        //12) Store graph in neo4j, then read nodes and relationships back as RDDs of Row
        neo4jResult.store(neo4jResultName, graph)
        val node_result = neo4jConnection.cypher("MATCH (n:Person) RETURN n.no as no, n.name as name").loadNodeRdds
        // The schemas below are only used by the commented-out debug `show` calls.
        val node_fields: Array[StructField] = Array(new StructField("no", LongType, true), StructField("name", StringType, true))
        val nodeSchema = new StructType().add(node_fields(0)).add(node_fields(1))
        //session.sparkSession.createDataFrame(node_result, nodeSchema).show(10)
        val rel_result = neo4jConnection.rels("MATCH (m:Person)-[r:KNOWS]->(n:Person) RETURN m.no as source, n.no as target, r.amount as amount").loadRelRdd
        val rel_fields: Array[StructField] = Array(new StructField("source", LongType, true), StructField("target", LongType, true),
          StructField("amount", DoubleType, true))
        val relSchema = new StructType().add(rel_fields(0)).add(rel_fields(1)).add(rel_fields(2))
        //session.sparkSession.createDataFrame(rel_result, relSchema).show(10)

        // GraphX graph: vertex attr = (no, name), edge attr = amount.
        // NOTE(review): these Row patterns throw MatchError if Neo4j returns other types or nulls.
        val edges: EdgeRDD[Double] = EdgeRDD.fromEdges(rel_result.map { case Row(a: Long, b: Long, c: Double) => Edge(a, b, c) })
        val vertices: VertexRDD[(Long, String)] = VertexRDD(node_result.map { case Row(a: Long, b: String) => (a, (a, b)) })
        val graph_spark = Graph[(Long, String), Double](vertices, edges)

        // Pregel vertex state (paths, cycles, step):
        //   paths  - map key -> list of candidate paths; a path is an insertion-ordered set of
        //            vertex ids. Initially the vertex holds only the path {id} under its own id.
        //   cycles - closed paths this vertex has recorded.
        //   step   - number of times vprog has run for this vertex.
        val final_graph_fit_tmp1 = graph_spark.removeSelfEdges().mapVertices { (id, _) =>
          val paths = new collection.mutable.HashMap[Long, collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]]()
          paths += (id -> collection.mutable.ListBuffer(collection.mutable.LinkedHashSet[Long](id)))
          (paths, new collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]](), 0)
        }

        // Messages map a sender id to the candidate paths it forwards. No maxIterations is
        // passed, so Pregel stops only once no messages are sent.
        val find_circle = Pregel(final_graph_fit_tmp1,
          new collection.mutable.HashMap[Long, collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]]())(
          // vprog: merge incoming paths into this vertex's state and record closed cycles.
          (id, attr, msg) => {
            // Debug throttle and trace output.
            Thread.sleep(10)
            if (id == 4) {
              println("================================")
              println("                " + attr._3 + "                ")
              println("================================")
            }
            println("id: " + id + " msg: " + msg + " attr: " + attr)
            val ss: collection.mutable.HashMap[Long, collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]] =
              new collection.mutable.HashMap[Long, collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]]()
            for (ss_a <- attr._1) {
              // Copy the list; the path sets inside it are shared, not copied.
              val cpy = new collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]()
              for (e <- ss_a._2) {
                cpy.+=(e)
              }
              ss.+=(ss_a._1 -> cpy)
              if (msg.contains(ss_a._1)) {
                for (s_m <- msg(ss_a._1)) {
                  for (s_a <- ss_a._2) {
                    if (attr._3 == 3 && id == 11)
                      println("order operate: " + s_a + " " + s_m)
                    // Concatenate when s_a without its first element equals s_m without its last.
                    if (s_a.-(s_a.head) == s_m.-(s_m.last))
                      ss(ss_a._1).+=(s_a.++(s_m))
                    // From step 2 on, drop every occurrence of s_a from this key's list.
                    if (attr._3 > 1) {
                      val set = new collection.mutable.HashSet[Int]()
                      for (i <- ss(ss_a._1).indices) {
                        if (ss(ss_a._1)(i) == (s_a)) {
                          set.add(i)
                        }
                      }
                      val set2 = ss(ss_a._1).indices.toSet.diff(set)
                      val listBuf = new collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]()
                      for (i <- set2) {
                        listBuf.append(ss(ss_a._1)(i))
                      }
                      ss(ss_a._1).clear()
                      ss(ss_a._1).append(listBuf:_*)
                    }
                  }
                }
              }
            }
            // Adopt message entries whose keys this vertex does not hold yet.
            for (key <- msg.keys.toSet.diff(ss.keys.toSet)) {
              val cpy = new collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]()
              for (e <- msg(key)) {
                cpy.+=(e)
              }
              ss.+=((key, cpy))
            }
            // Collect (key, path) pairs to drop. A path that contains this vertex and has exactly
            // `step` elements is also recorded as a cycle (attr._2 is mutated in place).
            val set_remove: collection.mutable.ListBuffer[(Long, collection.mutable.LinkedHashSet[Long])] =
              new collection.mutable.ListBuffer[(Long, collection.mutable.LinkedHashSet[Long])]()
            for (s <- ss) {
              for (e <- s._2) {
                if (attr._3 > 0) {
                  if (e.isEmpty) {
                    set_remove.+=((s._1, e))
                  } else if (attr._3 == 1) {
                    // NOTE(review): removes from `ss` while iterating over it; this relies on the
                    // mutable.HashMap iterator tolerating removal — confirm for the Scala version used.
                    ss.remove(id)
                  } else if (e.contains(id) || e.size != attr._3) {
                    set_remove.+=((s._1, e))
                    if (e.size == attr._3) {
                      attr._2.+=(e)
                    }
                  }
                }
              }
            }
            // Remove every list entry equal to each collected path.
            for (s <- set_remove) {
              val set = new collection.mutable.HashSet[Int]()
              for (i <- ss(s._1).indices) {
                if (ss(s._1)(i) == s._2) {
                  set.add(i)
                }
              }
              val set2 = ss(s._1).indices.toSet.diff(set)
              val listBuf = new collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]()
              for (i <- set2) {
                listBuf.append(ss(s._1)(i))
              }
              ss(s._1).clear()
              ss(s._1).append(listBuf:_*)
            }
            (ss.filter(s => s._2.nonEmpty), attr._2.filter(s => s.nonEmpty), attr._3 + 1)
          },
          // sendMsg: forward the source's paths whose size is within 1 of both endpoints' step,
          // keyed by the source id.
          trp => if (trp.srcAttr._1.keys.nonEmpty) {
            val ss = new collection.mutable.HashMap[Long, collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]]()
            for (s <- trp.srcAttr._1) {
              if (ss.contains(trp.srcId)) {
                // NOTE(review): once the key exists, later lists are appended WITHOUT the size
                // filter applied below — confirm this is intended.
                ss(trp.srcId).++=(s._2)
              } else {
                val cpy = new collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]()
                for (e <- s._2 if e.size - trp.srcAttr._3 < 2 && e.size - trp.srcAttr._3 > -2 && e.size - trp.dstAttr._3 < 2 && e.size - trp.dstAttr._3 > -2) {
                  cpy.+=(e)
                }
                if (cpy.nonEmpty)
                  ss.+=(trp.srcId -> cpy)
              }
            }
            if (ss.nonEmpty)
              Iterator((trp.dstId, ss))
            else
              Iterator.empty
          } else Iterator.empty,
          // mergeMsg: union of both maps; for a key present in both, the entry from `b` wins.
          (a, b) => {
            val ss = new collection.mutable.HashMap[Long, collection.mutable.ListBuffer[collection.mutable.LinkedHashSet[Long]]]()
            ss.++=(a)
            ss.++=(b)
            ss
          }).cache()
        println("================================")
        println("                vertices                ")
        println("================================")
        // Per vertex: the set of cycles it recorded.
        println(find_circle.vertices.mapValues(v => v._2.toSet).take(15).mkString("##;\r\n"))
        System.exit(0)
      }
    }
    

    output:

    ================================
                    0                
    ================================
    id: 4 msg: Map() attr: (Map(4 -> ListBuffer(Set(4))),ListBuffer(),0)
    id: 11 msg: Map() attr: (Map(11 -> ListBuffer(Set(11))),ListBuffer(),0)
    id: 0 msg: Map() attr: (Map(0 -> ListBuffer(Set(0))),ListBuffer(),0)
    id: 1 msg: Map() attr: (Map(1 -> ListBuffer(Set(1))),ListBuffer(),0)
    id: 6 msg: Map() attr: (Map(6 -> ListBuffer(Set(6))),ListBuffer(),0)
    id: 3 msg: Map() attr: (Map(3 -> ListBuffer(Set(3))),ListBuffer(),0)
    id: 12 msg: Map() attr: (Map(12 -> ListBuffer(Set(12))),ListBuffer(),0)
    id: 9 msg: Map() attr: (Map(9 -> ListBuffer(Set(9))),ListBuffer(),0)
    id: 7 msg: Map() attr: (Map(7 -> ListBuffer(Set(7))),ListBuffer(),0)
    id: 8 msg: Map() attr: (Map(8 -> ListBuffer(Set(8))),ListBuffer(),0)
    id: 10 msg: Map() attr: (Map(10 -> ListBuffer(Set(10))),ListBuffer(),0)
    id: 5 msg: Map() attr: (Map(5 -> ListBuffer(Set(5))),ListBuffer(),0)
    id: 2 msg: Map() attr: (Map(2 -> ListBuffer(Set(2))),ListBuffer(),0)
    ================================
                    1                
    ================================
    id: 4 msg: Map(5 -> ListBuffer(Set(5))) attr: (Map(4 -> ListBuffer(Set(4))),ListBuffer(),1)
    id: 11 msg: Map(10 -> ListBuffer(Set(10)), 12 -> ListBuffer(Set(12))) attr: (Map(11 -> ListBuffer(Set(11))),ListBuffer(),1)
    id: 0 msg: Map(1 -> ListBuffer(Set(1)), 3 -> ListBuffer(Set(3)), 9 -> ListBuffer(Set(9))) attr: (Map(0 -> ListBuffer(Set(0))),ListBuffer(),1)
    id: 1 msg: Map(2 -> ListBuffer(Set(2))) attr: (Map(1 -> ListBuffer(Set(1))),ListBuffer(),1)
    id: 6 msg: Map(7 -> ListBuffer(Set(7))) attr: (Map(6 -> ListBuffer(Set(6))),ListBuffer(),1)
    id: 3 msg: Map(4 -> ListBuffer(Set(4)), 0 -> ListBuffer(Set(0))) attr: (Map(3 -> ListBuffer(Set(3))),ListBuffer(),1)
    id: 12 msg: Map(11 -> ListBuffer(Set(11)), 10 -> ListBuffer(Set(10))) attr: (Map(12 -> ListBuffer(Set(12))),ListBuffer(),1)
    id: 9 msg: Map(0 -> ListBuffer(Set(0))) attr: (Map(9 -> ListBuffer(Set(9))),ListBuffer(),1)
    id: 7 msg: Map(8 -> ListBuffer(Set(8)), 10 -> ListBuffer(Set(10))) attr: (Map(7 -> ListBuffer(Set(7))),ListBuffer(),1)
    id: 8 msg: Map(0 -> ListBuffer(Set(0))) attr: (Map(8 -> ListBuffer(Set(8))),ListBuffer(),1)
    id: 10 msg: Map(11 -> ListBuffer(Set(11)), 7 -> ListBuffer(Set(7)), 12 -> ListBuffer(Set(12))) attr: (Map(10 -> ListBuffer(Set(10))),ListBuffer(),1)
    id: 5 msg: Map(6 -> ListBuffer(Set(6))) attr: (Map(5 -> ListBuffer(Set(5))),ListBuffer(),1)
    id: 2 msg: Map(3 -> ListBuffer(Set(3))) attr: (Map(2 -> ListBuffer(Set(2))),ListBuffer(),1)
    ....
    ....
    ================================
                        9                
    ================================
    ....
    ....
    ================================
                    vertices                
    ================================
    (4,Set(Set(5, 6, 7, 8, 0, 3, 4), Set(5, 6, 7, 8, 0, 1, 2, 3, 4)))##;
    (11,Set(Set(10, 11), Set(12, 11), Set(10, 12, 11)))##;
    (0,Set(Set(1, 2, 3, 4, 5, 6, 7, 8, 0), Set(3, 4, 5, 6, 7, 8, 0), Set(9, 0), Set(1, 2, 3, 0), Set(3, 0)))##;
    (1,Set(Set(2, 3, 0, 1), Set(2, 3, 4, 5, 6, 7, 8, 0, 1)))##;
    (6,Set(Set(7, 8, 0, 3, 4, 5, 6), Set(7, 8, 0, 1, 2, 3, 4, 5, 6)))##;
    (3,Set(Set(0, 3), Set(0, 1, 2, 3), Set(4, 5, 6, 7, 8, 0, 3), Set(4, 5, 6, 7, 8, 0, 1, 2, 3)))##;
    (12,Set(Set(11, 12), Set(10, 12), Set(11, 10, 12)))##;
    (9,Set(Set(0, 9)))##;
    (7,Set(Set(10, 7), Set(8, 0, 3, 4, 5, 6, 7), Set(8, 0, 1, 2, 3, 4, 5, 6, 7)))##;
    (8,Set(Set(0, 3, 4, 5, 6, 7, 8), Set(0, 1, 2, 3, 4, 5, 6, 7, 8)))##;
    (10,Set(Set(11, 10), Set(7, 10), Set(12, 10), Set(11, 12, 10)))##;
    (5,Set(Set(6, 7, 8, 0, 3, 4, 5), Set(6, 7, 8, 0, 1, 2, 3, 4, 5)))##;
    (2,Set(Set(3, 0, 1, 2), Set(3, 4, 5, 6, 7, 8, 0, 1, 2)))
    
    Process finished with exit code 0
    

提交回复
热议问题