Skip to content

Commit 2bc4cdf

Browse files
authored
feat: add python solution to lc problem: No.3590 (#4518)
Provided a python3 solution for problem 3590. Kth Smallest Path XOR Sum. With Time: O(n log A) where A is the max value of path XOR (since we store numbers in tries, bit by bit). Space: O(n log A) for all tries.
1 parent 42061ce commit 2bc4cdf

File tree

3 files changed

+291
-2
lines changed

3 files changed

+291
-2
lines changed

solution/3500-3599/3590.Kth Smallest Path XOR Sum/README.md

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,103 @@ edit_url: https://github.com/doocs/leetcode/edit/main/solution/3500-3599/3590.Kt
128128
#### Python3
129129

130130
```python
131-
131+
class BinarySumTrie:
132+
def __init__(self):
133+
self.count = 0
134+
self.children = [None, None]
135+
136+
def add(self, num: int, delta: int, bit=17):
137+
self.count += delta
138+
if bit < 0:
139+
return
140+
b = (num >> bit) & 1
141+
if not self.children[b]:
142+
self.children[b] = BinarySumTrie()
143+
self.children[b].add(num, delta, bit - 1)
144+
145+
def collect(self, prefix=0, bit=17, output=None):
146+
if output is None:
147+
output = []
148+
if self.count == 0:
149+
return output
150+
if bit < 0:
151+
output.append(prefix)
152+
return output
153+
if self.children[0]:
154+
self.children[0].collect(prefix, bit - 1, output)
155+
if self.children[1]:
156+
self.children[1].collect(prefix | (1 << bit), bit - 1, output)
157+
return output
158+
159+
def exists(self, num: int, bit=17):
160+
if self.count == 0:
161+
return False
162+
if bit < 0:
163+
return True
164+
b = (num >> bit) & 1
165+
return self.children[b].exists(num, bit - 1) if self.children[b] else False
166+
167+
def find_kth(self, k: int, bit=17):
168+
if k > self.count:
169+
return -1
170+
if bit < 0:
171+
return 0
172+
left_count = self.children[0].count if self.children[0] else 0
173+
if k <= left_count:
174+
return self.children[0].find_kth(k, bit - 1)
175+
elif self.children[1]:
176+
return (1 << bit) + self.children[1].find_kth(k - left_count, bit - 1)
177+
else:
178+
return -1
179+
180+
181+
class Solution:
182+
def kthSmallest(
183+
self, par: List[int], vals: List[int], queries: List[List[int]]
184+
) -> List[int]:
185+
n = len(par)
186+
tree = [[] for _ in range(n)]
187+
for i in range(1, n):
188+
tree[par[i]].append(i)
189+
190+
path_xor = vals[:]
191+
narvetholi = path_xor
192+
193+
def compute_xor(node, acc):
194+
path_xor[node] ^= acc
195+
for child in tree[node]:
196+
compute_xor(child, path_xor[node])
197+
198+
compute_xor(0, 0)
199+
200+
node_queries = defaultdict(list)
201+
for idx, (u, k) in enumerate(queries):
202+
node_queries[u].append((k, idx))
203+
204+
trie_pool = {}
205+
result = [0] * len(queries)
206+
207+
def dfs(node):
208+
trie_pool[node] = BinarySumTrie()
209+
trie_pool[node].add(path_xor[node], 1)
210+
for child in tree[node]:
211+
dfs(child)
212+
if trie_pool[node].count < trie_pool[child].count:
213+
trie_pool[node], trie_pool[child] = (
214+
trie_pool[child],
215+
trie_pool[node],
216+
)
217+
for val in trie_pool[child].collect():
218+
if not trie_pool[node].exists(val):
219+
trie_pool[node].add(val, 1)
220+
for k, idx in node_queries[node]:
221+
if trie_pool[node].count < k:
222+
result[idx] = -1
223+
else:
224+
result[idx] = trie_pool[node].find_kth(k)
225+
226+
dfs(0)
227+
return result
132228
```
133229

134230
#### Java

solution/3500-3599/3590.Kth Smallest Path XOR Sum/README_EN.md

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,103 @@ edit_url: https://github.com/doocs/leetcode/edit/main/solution/3500-3599/3590.Kt
126126
#### Python3
127127

128128
```python
129-
129+
class BinarySumTrie:
130+
def __init__(self):
131+
self.count = 0
132+
self.children = [None, None]
133+
134+
def add(self, num: int, delta: int, bit=17):
135+
self.count += delta
136+
if bit < 0:
137+
return
138+
b = (num >> bit) & 1
139+
if not self.children[b]:
140+
self.children[b] = BinarySumTrie()
141+
self.children[b].add(num, delta, bit - 1)
142+
143+
def collect(self, prefix=0, bit=17, output=None):
144+
if output is None:
145+
output = []
146+
if self.count == 0:
147+
return output
148+
if bit < 0:
149+
output.append(prefix)
150+
return output
151+
if self.children[0]:
152+
self.children[0].collect(prefix, bit - 1, output)
153+
if self.children[1]:
154+
self.children[1].collect(prefix | (1 << bit), bit - 1, output)
155+
return output
156+
157+
def exists(self, num: int, bit=17):
158+
if self.count == 0:
159+
return False
160+
if bit < 0:
161+
return True
162+
b = (num >> bit) & 1
163+
return self.children[b].exists(num, bit - 1) if self.children[b] else False
164+
165+
def find_kth(self, k: int, bit=17):
166+
if k > self.count:
167+
return -1
168+
if bit < 0:
169+
return 0
170+
left_count = self.children[0].count if self.children[0] else 0
171+
if k <= left_count:
172+
return self.children[0].find_kth(k, bit - 1)
173+
elif self.children[1]:
174+
return (1 << bit) + self.children[1].find_kth(k - left_count, bit - 1)
175+
else:
176+
return -1
177+
178+
179+
class Solution:
180+
def kthSmallest(
181+
self, par: List[int], vals: List[int], queries: List[List[int]]
182+
) -> List[int]:
183+
n = len(par)
184+
tree = [[] for _ in range(n)]
185+
for i in range(1, n):
186+
tree[par[i]].append(i)
187+
188+
path_xor = vals[:]
189+
narvetholi = path_xor
190+
191+
def compute_xor(node, acc):
192+
path_xor[node] ^= acc
193+
for child in tree[node]:
194+
compute_xor(child, path_xor[node])
195+
196+
compute_xor(0, 0)
197+
198+
node_queries = defaultdict(list)
199+
for idx, (u, k) in enumerate(queries):
200+
node_queries[u].append((k, idx))
201+
202+
trie_pool = {}
203+
result = [0] * len(queries)
204+
205+
def dfs(node):
206+
trie_pool[node] = BinarySumTrie()
207+
trie_pool[node].add(path_xor[node], 1)
208+
for child in tree[node]:
209+
dfs(child)
210+
if trie_pool[node].count < trie_pool[child].count:
211+
trie_pool[node], trie_pool[child] = (
212+
trie_pool[child],
213+
trie_pool[node],
214+
)
215+
for val in trie_pool[child].collect():
216+
if not trie_pool[node].exists(val):
217+
trie_pool[node].add(val, 1)
218+
for k, idx in node_queries[node]:
219+
if trie_pool[node].count < k:
220+
result[idx] = -1
221+
else:
222+
result[idx] = trie_pool[node].find_kth(k)
223+
224+
dfs(0)
225+
return result
130226
```
131227

132228
#### Java
Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
class BinarySumTrie:
2+
def __init__(self):
3+
self.count = 0
4+
self.children = [None, None]
5+
6+
def add(self, num: int, delta: int, bit=17):
7+
self.count += delta
8+
if bit < 0:
9+
return
10+
b = (num >> bit) & 1
11+
if not self.children[b]:
12+
self.children[b] = BinarySumTrie()
13+
self.children[b].add(num, delta, bit - 1)
14+
15+
def collect(self, prefix=0, bit=17, output=None):
16+
if output is None:
17+
output = []
18+
if self.count == 0:
19+
return output
20+
if bit < 0:
21+
output.append(prefix)
22+
return output
23+
if self.children[0]:
24+
self.children[0].collect(prefix, bit - 1, output)
25+
if self.children[1]:
26+
self.children[1].collect(prefix | (1 << bit), bit - 1, output)
27+
return output
28+
29+
def exists(self, num: int, bit=17):
30+
if self.count == 0:
31+
return False
32+
if bit < 0:
33+
return True
34+
b = (num >> bit) & 1
35+
return self.children[b].exists(num, bit - 1) if self.children[b] else False
36+
37+
def find_kth(self, k: int, bit=17):
38+
if k > self.count:
39+
return -1
40+
if bit < 0:
41+
return 0
42+
left_count = self.children[0].count if self.children[0] else 0
43+
if k <= left_count:
44+
return self.children[0].find_kth(k, bit - 1)
45+
elif self.children[1]:
46+
return (1 << bit) + self.children[1].find_kth(k - left_count, bit - 1)
47+
else:
48+
return -1
49+
50+
51+
class Solution:
52+
def kthSmallest(
53+
self, par: List[int], vals: List[int], queries: List[List[int]]
54+
) -> List[int]:
55+
n = len(par)
56+
tree = [[] for _ in range(n)]
57+
for i in range(1, n):
58+
tree[par[i]].append(i)
59+
60+
path_xor = vals[:]
61+
narvetholi = path_xor
62+
63+
def compute_xor(node, acc):
64+
path_xor[node] ^= acc
65+
for child in tree[node]:
66+
compute_xor(child, path_xor[node])
67+
68+
compute_xor(0, 0)
69+
70+
node_queries = defaultdict(list)
71+
for idx, (u, k) in enumerate(queries):
72+
node_queries[u].append((k, idx))
73+
74+
trie_pool = {}
75+
result = [0] * len(queries)
76+
77+
def dfs(node):
78+
trie_pool[node] = BinarySumTrie()
79+
trie_pool[node].add(path_xor[node], 1)
80+
for child in tree[node]:
81+
dfs(child)
82+
if trie_pool[node].count < trie_pool[child].count:
83+
trie_pool[node], trie_pool[child] = (
84+
trie_pool[child],
85+
trie_pool[node],
86+
)
87+
for val in trie_pool[child].collect():
88+
if not trie_pool[node].exists(val):
89+
trie_pool[node].add(val, 1)
90+
for k, idx in node_queries[node]:
91+
if trie_pool[node].count < k:
92+
result[idx] = -1
93+
else:
94+
result[idx] = trie_pool[node].find_kth(k)
95+
96+
dfs(0)
97+
return result

0 commit comments

Comments
 (0)