Skip to content

Commit f95efce

Browse files
authored
Create simclr.py
1 parent f3fb504 commit f95efce

File tree

1 file changed

+114
-0
lines changed

1 file changed

+114
-0
lines changed

machine_learning/simclr.py

+114
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
Implementation of SimCLR.
3+
Self-Supervised Learning (SSL) with SimCLR. SimCLR is a framework for learning visual representations without labels by maximizing the agreement between different augmented views of the same image.
4+
"""
5+
import numpy as np
6+
import tensorflow as tf
7+
from tensorflow.keras import layers
8+
from tensorflow.keras.models import Model
9+
from tensorflow.keras.optimizers import Adam
10+
from tensorflow.keras.losses import SparseCategoricalCrossentropy
11+
from tensorflow.keras.applications import ResNet50
12+
from sklearn.metrics import ConfusionMatrixDisplay
13+
from sklearn.model_selection import train_test_split
14+
from sklearn.preprocessing import LabelEncoder
15+
import matplotlib.pyplot as plt
16+
17+
18+
def data_handling(data: dict) -> tuple:
19+
"""
20+
Handles the data by splitting features and targets.
21+
22+
>>> data_handling({'data': np.array([[0.1, 0.2], [0.3, 0.4]]), 'target': np.array([0, 1])})
23+
(array([[0.1, 0.2], [0.3, 0.4]]), array([0, 1]))
24+
"""
25+
return (data["data"], data["target"])
26+
27+
28+
def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:
29+
"""
30+
Builds a SimCLR model based on ResNet50.
31+
32+
>>> simclr_model().summary() # doctest: +ELLIPSIS
33+
Model: "model"
34+
_________________________________________________________________
35+
...
36+
"""
37+
base_model = ResNet50(include_top=False, input_shape=input_shape, pooling="avg")
38+
base_model.trainable = True
39+
40+
inputs = layers.Input(shape=input_shape)
41+
x = base_model(inputs, training=True)
42+
x = layers.Dense(projection_dim, activation="relu")(x)
43+
outputs = layers.Dense(projection_dim)(x)
44+
return Model(inputs, outputs)
45+
46+
47+
def contrastive_loss(projection_1, projection_2, temperature=0.1):
48+
"""
49+
Contrastive loss function for self-supervised learning.
50+
51+
>>> contrastive_loss(np.array([0.1]), np.array([0.2]))
52+
0.0
53+
"""
54+
projections = tf.concat([projection_1, projection_2], axis=0)
55+
similarity_matrix = tf.matmul(projections, projections, transpose_b=True)
56+
labels = tf.range(tf.shape(projections)[0])
57+
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, similarity_matrix)
58+
return tf.reduce_mean(loss)
59+
60+
61+
def main() -> None:
62+
"""
63+
>>> main()
64+
"""
65+
# Load a small dataset (using CIFAR-10 as an example)
66+
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
67+
x_train, x_test = x_train / 255.0, x_test / 255.0
68+
69+
# Use label encoder to convert labels into numerical form
70+
le = LabelEncoder()
71+
y_train = le.fit_transform(y_train.flatten())
72+
y_test = le.transform(y_test.flatten())
73+
74+
# Split data into train and validation sets
75+
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2)
76+
77+
# Build the SimCLR model
78+
model = simclr_model()
79+
optimizer = Adam()
80+
loss_fn = SparseCategoricalCrossentropy(from_logits=True)
81+
82+
# Training the SimCLR model
83+
for epoch in range(10):
84+
with tf.GradientTape() as tape:
85+
projections_1 = model(x_train)
86+
projections_2 = model(x_train) # Normally, this would use augmented views
87+
loss = contrastive_loss(projections_1, projections_2)
88+
gradients = tape.gradient(loss, model.trainable_variables)
89+
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
90+
print(f"Epoch {epoch+1}: Contrastive Loss = {loss.numpy()}")
91+
92+
# Create a new model with a classification head for evaluation
93+
classifier = layers.Dense(10, activation="softmax")(model.output)
94+
classifier_model = Model(model.input, classifier)
95+
classifier_model.compile(optimizer=Adam(), loss=loss_fn, metrics=["accuracy"])
96+
classifier_model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5)
97+
98+
# Display the confusion matrix of the classifier
99+
ConfusionMatrixDisplay.from_estimator(
100+
classifier_model,
101+
x_test,
102+
y_test,
103+
display_labels=le.classes_,
104+
cmap="Blues",
105+
normalize="true",
106+
)
107+
plt.title("Normalized Confusion Matrix - CIFAR-10")
108+
plt.show()
109+
110+
111+
if __name__ == "__main__":
112+
import doctest
113+
doctest.testmod(verbose=True)
114+
main()

0 commit comments

Comments
 (0)