Commit ba129de

Fixes: 6216 | Support vector machines (#6240)
* initial commit
* first implementation of hard margin
* remove debugging print
* many commits squashed because pre-commit was buggy
* more kernels and improved kernel management
* remove unnecessary code + fix names + formatting + doctests
* rename to fit initial naming
* better naming and documentation
* better naming and documentation
1 parent b75a7c7 commit ba129de

File tree

1 file changed: +205 -0 lines changed


Diff for: machine_learning/support_vector_machines.py

@@ -0,0 +1,205 @@
import numpy as np
from numpy import ndarray
from scipy.optimize import Bounds, LinearConstraint, minimize


def norm_squared(vector: ndarray) -> float:
    """
    Return the squared second (Euclidean) norm of vector

    norm_squared(v) = sum(x * x for x in v)

    Args:
        vector (ndarray): input vector

    Returns:
        float: squared second norm of vector

    >>> norm_squared([1, 2])
    5
    >>> norm_squared(np.asarray([1, 2]))
    5
    >>> norm_squared([0, 0])
    0
    """
    return np.dot(vector, vector)


class SVC:
    """
    Support Vector Classifier

    Args:
        kernel (str): kernel to use. Default: linear
            Possible choices:
                - linear
                - rbf (requires gamma)
        regularization: constraint for soft margin (data not linearly separable)
            Default: unbound
        gamma (float): parameter of the rbf kernel; must be > 0 when used.
            Default: 0 (unset)

    >>> SVC(kernel="asdf")
    Traceback (most recent call last):
        ...
    ValueError: Unknown kernel: asdf

    >>> SVC(kernel="rbf")
    Traceback (most recent call last):
        ...
    ValueError: rbf kernel requires gamma

    >>> SVC(kernel="rbf", gamma=-1)
    Traceback (most recent call last):
        ...
    ValueError: gamma must be > 0
    """

    def __init__(
        self,
        *,
        regularization: float = np.inf,
        kernel: str = "linear",
        gamma: float = 0,
    ) -> None:
        self.regularization = regularization
        self.gamma = gamma
        if kernel == "linear":
            self.kernel = self.__linear
        elif kernel == "rbf":
            if self.gamma == 0:
                raise ValueError("rbf kernel requires gamma")
            if not isinstance(self.gamma, (float, int)):
                raise ValueError("gamma must be float or int")
            if not self.gamma > 0:
                raise ValueError("gamma must be > 0")
            self.kernel = self.__rbf
            # in the future, there could be a default value like in sklearn
            # sklearn: def_gamma = 1/(n_features * X.var()) (wiki)
            # previously it was 1/(n_features)
        else:
            raise ValueError(f"Unknown kernel: {kernel}")

    # kernels
    def __linear(self, vector1: ndarray, vector2: ndarray) -> float:
        """Linear kernel (as if no kernel used at all)"""
        return np.dot(vector1, vector2)

    def __rbf(self, vector1: ndarray, vector2: ndarray) -> float:
        """
        RBF: Radial Basis Function Kernel

        Note: for more information see:
            https://en.wikipedia.org/wiki/Radial_basis_function_kernel

        Args:
            vector1 (ndarray): first vector
            vector2 (ndarray): second vector

        Returns:
            float: exp(-(gamma * norm_squared(vector1 - vector2)))
        """
        return np.exp(-(self.gamma * norm_squared(vector1 - vector2)))
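
    # Illustrative values (an assumption for intuition, not from the diff):
    # with gamma=1, identical vectors give exp(0) = 1.0, while [0, 0] and
    # [1, 1] give exp(-2) ~= 0.135; similarity decays with squared distance.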

    def fit(self, observations: list[ndarray], classes: ndarray) -> None:
        """
        Fits the SVC with a set of observations.

        Args:
            observations (list[ndarray]): list of observations
            classes (ndarray): classification of each observation (in {1, -1})
        """

        self.observations = observations
        self.classes = classes

        # using Wolfe's Dual to calculate w.
        # Primal problem: minimize 1/2*norm_squared(w)
        #   constraint: yn(w . xn + b) >= 1
        #
        # With l a vector
        # Dual problem: maximize sum_n(ln) -
        #       1/2 * sum_n(sum_m(ln*lm*yn*ym*xn . xm))
        #   constraint: self.regularization >= ln >= 0
        #   and sum_n(ln*yn) = 0
        # Then we get w using w = sum_n(ln*yn*xn)
        # At the end we can get b ~= mean(yn - w . xn)
        #
        # Since we use kernels, we only need l_star to calculate b
        # and to classify observations
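        # The resulting classifier evaluates
        #   f(x) = sign(sum_n(ln*yn*K(xn, x)) + b)
        # which is what predict() computes below.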

        (n,) = np.shape(classes)

        def to_minimize(candidate: ndarray) -> float:
            """
            Opposite of the function to maximize

            Args:
                candidate (ndarray): candidate array to test

            Returns:
                float: Wolfe's Dual result to minimize
            """
            s = 0
            (n,) = np.shape(candidate)
            for i in range(n):
                for j in range(n):
                    s += (
                        candidate[i]
                        * candidate[j]
                        * classes[i]
                        * classes[j]
                        * self.kernel(observations[i], observations[j])
                    )
            return 1 / 2 * s - sum(candidate)

        ly_constraint = LinearConstraint(classes, 0, 0)
        l_bounds = Bounds(0, self.regularization)

        l_star = minimize(
            to_minimize, np.ones(n), bounds=l_bounds, constraints=[ly_constraint]
        ).x
        self.optimum = l_star

        # calculating mean offset of separation plane to points:
        # b ~= mean_n(yn - sum_m(lm*ym*K(xm, xn))), where the inner sum
        # stands in for w . xn since only kernel evaluations are available
        s = 0
        for i in range(n):
            s += classes[i]
            for j in range(n):
                s -= self.optimum[j] * classes[j] * self.kernel(
                    observations[j], observations[i]
                )
        self.offset = s / n

    def predict(self, observation: ndarray) -> int:
        """
        Get the expected class of an observation

        Args:
            observation (ndarray): observation to classify

        Returns:
            int {1, -1}: expected class

        >>> xs = [
        ...     np.asarray([0, 1]), np.asarray([0, 2]),
        ...     np.asarray([1, 1]), np.asarray([1, 2])
        ... ]
        >>> y = np.asarray([1, 1, -1, -1])
        >>> s = SVC()
        >>> s.fit(xs, y)
        >>> s.predict(np.asarray([0, 1]))
        1
        >>> s.predict(np.asarray([1, 1]))
        -1
        >>> s.predict(np.asarray([2, 2]))
        -1
        """
        s = sum(
            self.optimum[n]
            * self.classes[n]
            * self.kernel(self.observations[n], observation)
            for n in range(len(self.classes))
        )
        return 1 if s + self.offset >= 0 else -1


if __name__ == "__main__":
    import doctest

    doctest.testmod()
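
A minimal usage sketch (not part of the committed file; the XOR-style data, the gamma value, and all names below are illustrative assumptions) showing where the rbf kernel helps, since no linear separator exists for this data:

import numpy as np

# assumes SVC from machine_learning/support_vector_machines.py is in scope
xor_observations = [
    np.asarray([0, 0]), np.asarray([1, 1]),  # class 1
    np.asarray([0, 1]), np.asarray([1, 0]),  # class -1
]
xor_classes = np.asarray([1, 1, -1, -1])

classifier = SVC(kernel="rbf", gamma=2.0)
classifier.fit(xor_observations, xor_classes)
# the training points should be classified back as 1 and -1 respectively
print(classifier.predict(np.asarray([0, 0])))
print(classifier.predict(np.asarray([1, 0])))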
