Skip to content

Commit 25c303e

Browse files
authored
Added Condition(ABC) for better typehinting
1 parent 846cfe7 commit 25c303e

File tree

1 file changed

+23
-14
lines changed

1 file changed

+23
-14
lines changed
Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,36 @@
1-
21
from dataclasses import dataclass, field
2+
from abc import ABC, abstractmethod
33
from typing import Dict, Callable
44

5+
56
@dataclass
6-
class PositiveCondition:
7+
class Condition(ABC):
78
factor: float
89
barrier: float
910

11+
@abstractmethod
1012
def evaluate(self, value: float) -> bool:
11-
return value > self.factor * self.barrier
13+
pass
14+
15+
1216
@dataclass
13-
class NegativeCondition:
14-
factor: float
15-
barrier: float
17+
class PositiveCondition(Condition):
18+
def evaluate(self, value: float) -> bool:
19+
return value > self.factor * self.barrier
1620

21+
22+
@dataclass
23+
class NegativeCondition(Condition):
1724
def evaluate(self, value: float) -> bool:
1825
return value < -1 * self.factor * self.barrier
19-
26+
27+
2028
@dataclass
2129
class BarrierConditions:
2230
"""
2331
A class that generates and manages barrier conditions used for different labeling techniques.
2432
Those conditions can be used to generate labels like this. Example for n=1:
25-
y =
33+
y =
2634
-1 if r_{t,t+n} < -barrier,
2735
1 if r_{t,t+n} > -barrier,
2836
0 else
@@ -31,26 +39,27 @@ class BarrierConditions:
3139
Attributes:
3240
n (int): The number of barrier conditions to be generated for negative and positive barriers.
3341
barrier (float): The threshold value for the barrier.
34-
conditions (Dict[int, Callable[[float], bool]]): A dictionary holding condition functions for various barrier levels. Keys are sorted numerically.
42+
conditions (Dict[int, Condition): A dictionary holding condition functions for various barrier levels.
43+
Keys are sorted numerically.
3544
"""
3645
n: int
3746
barrier: float
38-
conditions: Dict[int, Callable[[float], bool]] = field(default_factory=dict)
47+
conditions: Dict[int, Condition] = field(default_factory=dict)
3948

4049
def __post_init__(self):
4150
"""
4251
Calculate the conditions after the instance has been initialized.
4352
"""
4453
self.generate_conditions()
4554
self.sort_conditions()
46-
55+
4756
def generate_conditions(self):
4857
"""
4958
Generates barrier conditions based on the specified number of conditions and threshold values.
5059
"""
51-
for i in range(1, self.n+1):
60+
for i in range(1, self.n + 1):
5261
self.conditions[-i] = NegativeCondition(factor=i, barrier=self.barrier)
53-
self.conditions[i] = PositiveCondition(factor=i, barrier=self.barrier)
62+
self.conditions[i] = PositiveCondition(factor=i, barrier=self.barrier)
5463

5564
def sort_conditions(self):
5665
"""
@@ -67,4 +76,4 @@ def __str__(self):
6776
condition_strings.append(f"\t{key}: \t{condition}")
6877

6978
conditions_str = "\n\t".join(condition_strings)
70-
return f"BarrierConditions(n={self.n}, barrier={self.barrier}):\n\tConditions={{\n\t{conditions_str}\n\t}}"
79+
return f"BarrierConditions(n={self.n}, barrier={self.barrier}):\n\tConditions={{\n\t{conditions_str}\n\t}}"

0 commit comments

Comments
 (0)