Skip to content

Commit 0e1da4b

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f95efce commit 0e1da4b

File tree

2 files changed

+14
-11
lines changed

2 files changed

+14
-11
lines changed

machine_learning/simCLR.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
def data_handling(data: dict) -> tuple:
1919
"""
2020
Handles the data by splitting features and targets.
21-
21+
2222
>>> data_handling({'data': np.array([[0.1, 0.2], [0.3, 0.4]]), 'target': np.array([0, 1])})
2323
(array([[0.1, 0.2], [0.3, 0.4]]), array([0, 1]))
2424
"""
@@ -28,15 +28,15 @@ def data_handling(data: dict) -> tuple:
2828
def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:
2929
"""
3030
Builds a SimCLR model based on ResNet50.
31-
31+
3232
>>> simclr_model().summary() # doctest: +ELLIPSIS
3333
Model: "model"
3434
_________________________________________________________________
3535
...
3636
"""
3737
base_model = ResNet50(include_top=False, input_shape=input_shape, pooling="avg")
3838
base_model.trainable = True
39-
39+
4040
inputs = layers.Input(shape=input_shape)
4141
x = base_model(inputs, training=True)
4242
x = layers.Dense(projection_dim, activation="relu")(x)
@@ -47,7 +47,7 @@ def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:
4747
def contrastive_loss(projection_1, projection_2, temperature=0.1):
4848
"""
4949
Contrastive loss function for self-supervised learning.
50-
50+
5151
>>> contrastive_loss(np.array([0.1]), np.array([0.2]))
5252
0.0
5353
"""
@@ -70,7 +70,7 @@ def main() -> None:
7070
le = LabelEncoder()
7171
y_train = le.fit_transform(y_train.flatten())
7272
y_test = le.transform(y_test.flatten())
73-
73+
7474
# Split data into train and validation sets
7575
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2)
7676

@@ -110,5 +110,6 @@ def main() -> None:
110110

111111
if __name__ == "__main__":
112112
import doctest
113+
113114
doctest.testmod(verbose=True)
114115
main()

machine_learning/simclr.py

+8-6
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""
2-
Implementation of SimCLR.
2+
Implementation of SimCLR.
33
Self-Supervised Learning (SSL) with SimCLR. SimCLR is a framework for learning visual representations without labels by maximizing the agreement between different augmented views of the same image.
44
"""
5+
56
import numpy as np
67
import tensorflow as tf
78
from tensorflow.keras import layers
@@ -18,7 +19,7 @@
1819
def data_handling(data: dict) -> tuple:
1920
"""
2021
Handles the data by splitting features and targets.
21-
22+
2223
>>> data_handling({'data': np.array([[0.1, 0.2], [0.3, 0.4]]), 'target': np.array([0, 1])})
2324
(array([[0.1, 0.2], [0.3, 0.4]]), array([0, 1]))
2425
"""
@@ -28,15 +29,15 @@ def data_handling(data: dict) -> tuple:
2829
def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:
2930
"""
3031
Builds a SimCLR model based on ResNet50.
31-
32+
3233
>>> simclr_model().summary() # doctest: +ELLIPSIS
3334
Model: "model"
3435
_________________________________________________________________
3536
...
3637
"""
3738
base_model = ResNet50(include_top=False, input_shape=input_shape, pooling="avg")
3839
base_model.trainable = True
39-
40+
4041
inputs = layers.Input(shape=input_shape)
4142
x = base_model(inputs, training=True)
4243
x = layers.Dense(projection_dim, activation="relu")(x)
@@ -47,7 +48,7 @@ def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:
4748
def contrastive_loss(projection_1, projection_2, temperature=0.1):
4849
"""
4950
Contrastive loss function for self-supervised learning.
50-
51+
5152
>>> contrastive_loss(np.array([0.1]), np.array([0.2]))
5253
0.0
5354
"""
@@ -70,7 +71,7 @@ def main() -> None:
7071
le = LabelEncoder()
7172
y_train = le.fit_transform(y_train.flatten())
7273
y_test = le.transform(y_test.flatten())
73-
74+
7475
# Split data into train and validation sets
7576
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2)
7677

@@ -110,5 +111,6 @@ def main() -> None:
110111

111112
if __name__ == "__main__":
112113
import doctest
114+
113115
doctest.testmod(verbose=True)
114116
main()

0 commit comments

Comments
 (0)