@@ -8,7 +8,7 @@ def __init__(self, max_bit_len=31):
8
8
def add (self , a ):
9
9
u = 0
10
10
self .cc [u ] += 1
11
- for i in range (self .mb - 1 , - 1 , - 1 ):
11
+ for i in range (self .mb - 1 , - 1 , - 1 ):
12
12
d = a >> i & 1
13
13
if self .to [d ][u ] == - 1 :
14
14
self .to [d ][u ] = len (self .cc )
@@ -19,29 +19,34 @@ def add(self, a):
19
19
self .cc [u ] += 1
20
20
21
21
def remove (self , a ):
22
- if self .cc [0 ] == 0 : return False
22
+ if self .cc [0 ] == 0 :
23
+ return False
23
24
uu = [0 ]
24
25
u = 0
25
- for i in range (self .mb - 1 , - 1 , - 1 ):
26
+ for i in range (self .mb - 1 , - 1 , - 1 ):
26
27
d = a >> i & 1
27
28
u = self .to [d ][u ]
28
- if u == - 1 or self .cc [u ] == 0 : return False
29
+ if u == - 1 or self .cc [u ] == 0 :
30
+ return False
29
31
uu .append (u )
30
- for u in uu : self .cc [u ] -= 1
32
+ for u in uu :
33
+ self .cc [u ] -= 1
31
34
return True
32
35
33
36
def cnt (self , a ):
34
37
u = 0
35
- for i in range (self .mb - 1 , - 1 , - 1 ):
38
+ for i in range (self .mb - 1 , - 1 , - 1 ):
36
39
d = a >> i & 1
37
40
u = self .to [d ][u ]
38
- if u == - 1 or self .cc [u ] == 0 : return 0
41
+ if u == - 1 or self .cc [u ] == 0 :
42
+ return 0
39
43
return self .cc [u ]
40
44
41
45
def min_xor (self , a ):
42
- if self .cc [0 ] == 0 : return self .inf
46
+ if self .cc [0 ] == 0 :
47
+ return self .inf
43
48
u , res = 0 , 0
44
- for i in range (self .mb - 1 , - 1 , - 1 ):
49
+ for i in range (self .mb - 1 , - 1 , - 1 ):
45
50
d = a >> i & 1
46
51
v = self .to [d ][u ]
47
52
if v == - 1 or self .cc [v ] == 0 :
@@ -52,9 +57,10 @@ def min_xor(self, a):
52
57
return res
53
58
54
59
def max_xor (self , a ):
55
- if self .cc [0 ] == 0 : return - self .inf
60
+ if self .cc [0 ] == 0 :
61
+ return - self .inf
56
62
u , res = 0 , 0
57
- for i in range (self .mb - 1 , - 1 , - 1 ):
63
+ for i in range (self .mb - 1 , - 1 , - 1 ):
58
64
d = a >> i & 1
59
65
v = self .to [d ^ 1 ][u ]
60
66
if v == - 1 or self .cc [v ] == 0 :
0 commit comments