So here's what I want to do: I have a list that contains several equivalence relations:
```python
l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
```
And I want to union the sets that share at least one element. Here is a sample implementation:
```python
def union(lis):
    lis = [set(e) for e in lis]
    res = []
    while True:
        for i in range(len(lis)):
            a = lis[i]
            if res == []:
                res.append(a)
            else:
                pointer = 0
                while pointer < len(res):
                    if a & res[pointer] != set([]) :
                        res[pointer] = res[pointer].union(a)
                        break
                    pointer +=1
                if pointer == len(res):
                    res.append(a)
        if res == lis:
            break
        lis,res = res,[]
    return res
```
And `print(union(l))` prints
[set([1, 2, 3, 6, 7]), set([4, 5])]
This does the right thing but is far too slow when the list of equivalence relations is large. I looked up the description of the union-find algorithm (http://en.wikipedia.org/wiki/Disjoint-set_data_structure), but I'm still having trouble coding a Python implementation.
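For reference, here is a minimal sketch of the disjoint-set (union-find) structure from that Wikipedia page, with path compression and union by size. The names `DisjointSet` and `union_pairs` are my own, not from any library, and elements are registered lazily the first time `find` sees them:

```python
from collections import defaultdict


class DisjointSet:
    """Union-find with path compression and union by size (a sketch)."""

    def __init__(self):
        self.parent = {}
        self.size = {}

    def find(self, x):
        # Unseen elements start out as singleton sets.
        if x not in self.parent:
            self.parent[x] = x
            self.size[x] = 1
            return x
        # First pass: walk up to the root.
        root = x
        while self.parent[root] != root:
            root = self.parent[root]
        # Second pass: path compression, point every node on the path at the root.
        while x != root:
            nxt = self.parent[x]
            self.parent[x] = root
            x = nxt
        return root

    def union(self, a, b):
        ra, rb = self.find(a), self.find(b)
        if ra == rb:
            return
        # Union by size: attach the smaller tree under the larger one.
        if self.size[ra] < self.size[rb]:
            ra, rb = rb, ra
        self.parent[rb] = ra
        self.size[ra] += self.size[rb]

    def groups(self):
        # Collect every registered element under its root.
        out = defaultdict(set)
        for x in self.parent:
            out[self.find(x)].add(x)
        return list(out.values())


def union_pairs(pairs):
    ds = DisjointSet()
    for a, b in pairs:
        ds.union(a, b)
    return ds.groups()


l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
print(union_pairs(l))  # [{1, 2, 3, 6, 7}, {4, 5}]
```

With both optimizations each `find`/`union` runs in near-constant amortized time, so processing n pairs is close to linear.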
Solution that runs in O(n) time
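```python
from collections import defaultdict

def indices_dict(lis):
    # Map each element to the indices of the pairs in lis that contain it.
    d = defaultdict(list)
    for i, (a, b) in enumerate(lis):
        d[a].append(i)
        d[b].append(i)
    return d

def disjoint_indices(lis):
    d = indices_dict(lis)
    sets = []
    while len(d):
        # Start a new class from the indices of an arbitrary remaining element.
        que = set(d.popitem()[1])
        ind = set()
        while len(que):
            ind |= que
            # Pop every element of the pairs just reached; their indices that
            # have not been seen yet form the next frontier.
            que = set([y for i in que
                         for x in lis[i]
                         for y in d.pop(x, [])]) - ind
        sets += [ind]
    return sets

def disjoint_sets(lis):
    # Replace each set of pair indices by the set of elements in those pairs.
    return [set([x for i in s for x in lis[i]]) for s in disjoint_indices(lis)]
```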
How it works:
```
>>> lis = [(1,2),(2,3),(4,5),(6,7),(1,7)]
>>> dict(indices_dict(lis))
{1: [0, 4], 2: [0, 1], 3: [1], 4: [2], 5: [2], 6: [3], 7: [3, 4]}
```
`indices_dict` gives a map from each equivalence element to the indices in `lis` where it appears. E.g. `1` is mapped to indices `0` and `4` in `lis`.
```
>>> disjoint_indices(lis)
[set([0, 1, 3, 4]), set([2])]
```
`disjoint_indices` gives a list of disjoint sets of indices. Each set contains the indices of the pairs in one equivalence class. E.g. `lis[0]` and `lis[3]` are in the same equivalence class, but `lis[2]` is not.
```
>>> disjoint_sets(lis)
[set([1, 2, 3, 6, 7]), set([4, 5])]
```
`disjoint_sets` converts the disjoint index sets into their actual equivalence classes.
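The index bookkeeping above amounts to a breadth-first search over a graph whose vertices are the elements and whose edges are the pairs. A sketch of the same idea written directly on an adjacency map may be easier to follow; `components` is a hypothetical name, not from the answer:

```python
from collections import defaultdict, deque


def components(pairs):
    # Undirected adjacency map: element -> set of neighbouring elements.
    adj = defaultdict(set)
    for a, b in pairs:
        adj[a].add(b)
        adj[b].add(a)

    seen = set()
    result = []
    for start in adj:
        if start in seen:
            continue
        # Breadth-first search collects everything reachable from `start`.
        seen.add(start)
        comp = {start}
        queue = deque([start])
        while queue:
            node = queue.popleft()
            for nb in adj[node]:
                if nb not in seen:
                    seen.add(nb)
                    comp.add(nb)
                    queue.append(nb)
        result.append(comp)
    return result


print(components([(1, 2), (2, 3), (4, 5), (6, 7), (1, 7)]))  # [{1, 2, 3, 6, 7}, {4, 5}]
```

Every vertex is enqueued once and every edge is looked at from both ends, so this is also linear in the number of pairs.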
Time complexity
The O(n) time complexity is difficult to see, but I'll try to explain. Here I will use n = len(lis), and I assume O(1) average-time dict and set operations.

- `indices_dict` clearly runs in O(n) time because it has only one for-loop over `lis`.
- `disjoint_indices` is the hardest to see. The outer loop stops when `d` is empty, every key of `d` is popped exactly once (by `popitem` or `pop`), and each index of `lis` enters `que` at most once, so the function does O(len(d) + n) work. Now, len(d) <= 2n, since `d` maps equivalence elements to indices and there are at most 2n distinct elements in `lis`. Therefore, the function runs in O(n).
- `disjoint_sets` is difficult to see because of the 3 combined for-loops. However, since the index sets are disjoint, `i` runs over at most all n indices of `lis` in total, and `x` runs over each 2-tuple, so the total work is 2n = O(n).
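One way to make the counting argument for `disjoint_indices` concrete is to instrument it. The sketch below is a copy of `disjoint_indices` with the comprehension unrolled into loops and a hypothetical `pops` counter added. Because every index of `lis` enters `que` exactly once and each index triggers two `d.pop` calls, the counter should equal 2n for any input:

```python
from collections import defaultdict


def disjoint_indices_counted(lis):
    """Copy of disjoint_indices that also returns how many d.pop calls it made."""
    d = defaultdict(list)
    for i, (a, b) in enumerate(lis):
        d[a].append(i)
        d[b].append(i)

    pops = 0
    sets = []
    while d:
        que = set(d.popitem()[1])
        ind = set()
        while que:
            ind |= que
            found = set()
            for i in que:
                for x in lis[i]:
                    pops += 1  # one d.pop per element of each pair reached
                    found.update(d.pop(x, []))
            que = found - ind
        sets.append(ind)
    return sets, pops


chain = [(k, k + 1) for k in range(1000)]
sets, pops = disjoint_indices_counted(chain)
print(len(sets), pops)  # 1 2000
```

A long chain is the input where the number of BFS rounds is largest, yet the pop count stays at exactly 2n.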
I think this is an elegant solution, using the built-in set functions:
```python
#!/usr/bin/python3
def union_find(lis):
    lis = map(set, lis)
    unions = []
    for item in lis:
        temp = []
        for s in unions:
            if not s.isdisjoint(item):
                item = s.union(item)
            else:
                temp.append(s)
        temp.append(item)
        unions = temp
    return unions

if __name__ == '__main__':
    l = [[1, 2], [2, 3], [4, 5], [6, 7], [1, 7]]
    print(union_find(l))
```
It returns a list of sets.
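Because the answers here return the classes in different orders and containers, a small helper makes it easy to check that two outputs describe the same partition. `canonical` is my own name for it; this is a sketch, not part of the answer:

```python
def canonical(partition):
    """Order-independent form of a partition: a frozenset of frozensets."""
    return frozenset(frozenset(block) for block in partition)


a = [{1, 2, 3, 6, 7}, {4, 5}]
b = [[5, 4], (7, 6, 3, 2, 1)]
print(canonical(a) == canonical(b))  # True
```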
Perhaps something like this?
```python
#!/usr/local/cpython-3.3/bin/python

import copy
import pprint
import collections

def union(list_):
    dict_ = collections.defaultdict(set)
    for sublist in list_:
        dict_[sublist[0]].add(sublist[1])
        dict_[sublist[1]].add(sublist[0])
    change_made = True
    while change_made:
        change_made = False
        for key, values in dict_.items():
            for value in copy.copy(values):
                for element in dict_[value]:
                    if element not in dict_[key]:
                        dict_[key].add(element)
                        change_made = True
    return dict_

list_ = [ [1, 2], [2, 3], [4, 5], [6, 7], [1, 7] ]
pprint.pprint(union(list_))
```
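`union(list_)` returns a dict that maps every element to the full set of elements equivalent to it, so each class appears once per member. Assuming that shape, the distinct classes can be recovered by deduplicating the values. `classes_from_closure` is a hypothetical helper, and the `closure` dict below is written by hand to match that shape:

```python
def classes_from_closure(closure):
    """Turn {element: set of equivalent elements} into a list of distinct classes."""
    seen = set()
    classes = []
    for key, members in closure.items():
        cls = frozenset(members) | {key}  # include the key in case it is missing
        if cls not in seen:
            seen.add(cls)
            classes.append(set(cls))
    return classes


closure = {1: {1, 2, 3, 6, 7}, 2: {1, 2, 3, 6, 7}, 3: {1, 2, 3, 6, 7},
           4: {4, 5}, 5: {4, 5},
           6: {1, 2, 3, 6, 7}, 7: {1, 2, 3, 6, 7}}
print(classes_from_closure(closure))  # [{1, 2, 3, 6, 7}, {4, 5}]
```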
This works by completely exhausting one equivalence class at a time. When a pair is found to belong to the current class, it is removed from the original set and no longer searched.
```python
def equiv_sets(lis):
    s = set(lis)
    sets = []
    # loop while there are still items in the original set
    while len(s):
        s1 = set(s.pop())
        length = 0
        # loop while there are still equivalences to s1
        while len(s1) != length:
            length = len(s1)
            for v in list(s):
                if v[0] in s1 or v[1] in s1:
                    s1 |= set(v)
                    s -= set([v])
        sets += [s1]
    return sets

print(equiv_sets([(1,2),(2,3),(4,5),(6,7),(1,7)]))
```
OUTPUT: [set([1, 2, 3, 6, 7]), set([4, 5])]
Source: https://stackoverflow.com/questions/20154368/union-find-implementation-using-python