10
10
"""
11
11
12
12
from typing import Optional
13
+ from dataclasses import dataclass , field
13
14
14
15
16
+ @dataclass
15
17
class TreeNode :
16
18
"""
17
19
Initialize a TreeNode.
@@ -30,14 +32,19 @@ class TreeNode:
30
32
2
31
33
"""
32
34
33
- def __init__ (
34
- self , name_value : str , num_occur : int , parent_node : Optional ["TreeNode" ] = None
35
- ) -> None :
36
- self .name = name_value
37
- self .count = num_occur
38
- self .node_link = None # Initialize node_link to None
39
- self .parent = parent_node
40
- self .children : dict [str , TreeNode ] = {}
35
+ # def __init__(
36
+ # self, name_value: str, num_occur: int, parent_node: Optional["TreeNode"] = None
37
+ # ) -> None:
38
+ # self.name = name_value
39
+ # self.count = num_occur
40
+ # self.node_link = TreeNode | None # Initialize node_link to None
41
+ # self.parent = parent_node
42
+ # self.children: dict[str, TreeNode] = {}
43
+ name : str
44
+ count : int
45
+ node_link : Optional ['TreeNode' ] = None # Initialize node_link to None
46
+ parent : Optional ["TreeNode" ] = None
47
+ children : dict [str , "TreeNode" ] = field (default_factory = dict )
41
48
42
49
def inc (self , num_occur : int ) -> None :
43
50
self .count += num_occur
@@ -50,7 +57,7 @@ def disp(self, ind: int = 1) -> None:
50
57
51
58
def create_tree (data_set : list , min_sup : int = 1 ) -> tuple [TreeNode , dict ]:
52
59
"""
53
- Create FP tree
60
+ Create Frequent Pattern tree
54
61
55
62
Args:
56
63
data_set (list): A list of transactions, where each transaction
@@ -193,10 +200,7 @@ def update_header(node_to_test: TreeNode, target_node: TreeNode) -> TreeNode:
193
200
while node_to_test .node_link is not None :
194
201
node_to_test = node_to_test .node_link
195
202
if node_to_test .node_link is None :
196
- node_to_test .node_link = TreeNode (
197
- target_node .name , target_node .count , node_to_test
198
- )
199
- # Return the updated node
203
+ node_to_test .node_link = target_node
200
204
return node_to_test
201
205
202
206
@@ -298,6 +302,7 @@ def mine_tree(
298
302
>>> all(expected in frequent_itemsets for expected in expe_itm)
299
303
True
300
304
"""
305
+ new_head : Optional ['TreeNode' ] = None
301
306
sorted_items = sorted (header_table .items (), key = lambda item_info : item_info [1 ][0 ])
302
307
big_l = [item [0 ] for item in sorted_items ]
303
308
for base_pat in big_l :
@@ -311,6 +316,7 @@ def mine_tree(
311
316
header_table [base_pat ][1 ] = update_header (
312
317
header_table [base_pat ][1 ], my_cond_tree
313
318
)
319
+ my_head = new_head
314
320
mine_tree (my_cond_tree , my_head , min_sup , new_freq_set , freq_item_list )
315
321
316
322
0 commit comments