Skip to content

Added SimCLR (Deep Learning Framework) #11900

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions machine_learning/simclr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
"""
Implementation of SimCLR.
Self-Supervised Learning (SSL) with SimCLR. SimCLR is a framework for learning visual representations without labels by maximizing the agreement between different augmented views of the same image.

Check failure on line 3 in machine_learning/simclr.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

machine_learning/simclr.py:3:89: E501 Line too long (197 > 88)
"""

import numpy as np

Check failure on line 6 in machine_learning/simclr.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (F401)

machine_learning/simclr.py:6:17: F401 `numpy` imported but unused
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.applications import ResNet50
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt


def data_handling(data: dict) -> tuple:

Check failure on line 19 in machine_learning/simclr.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

machine_learning/simclr.py:6:1: I001 Import block is un-sorted or un-formatted
"""
Handles the data by splitting features and targets.

>>> data_handling({'data': np.array([[0.1, 0.2], [0.3, 0.4]]), 'target': np.array([0, 1])})

Check failure on line 23 in machine_learning/simclr.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (E501)

machine_learning/simclr.py:23:89: E501 Line too long (95 > 88)
(array([[0.1, 0.2], [0.3, 0.4]]), array([0, 1]))
"""
return (data["data"], data["target"])


def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide type hint for the parameter: input_shape

Please provide type hint for the parameter: projection_dim

"""
Builds a SimCLR model based on ResNet50.

>>> simclr_model().summary() # doctest: +ELLIPSIS
Model: "model"
_________________________________________________________________
...
"""
base_model = ResNet50(include_top=False, input_shape=input_shape, pooling="avg")
base_model.trainable = True

inputs = layers.Input(shape=input_shape)
x = base_model(inputs, training=True)
x = layers.Dense(projection_dim, activation="relu")(x)
outputs = layers.Dense(projection_dim)(x)
return Model(inputs, outputs)


def contrastive_loss(projection_1, projection_2, temperature=0.1):

Check failure on line 48 in machine_learning/simclr.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ARG001)

machine_learning/simclr.py:48:50: ARG001 Unused function argument: `temperature`

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: contrastive_loss. If the function does not return a value, please provide the type hint as: def function() -> None:

Please provide type hint for the parameter: projection_1

Please provide type hint for the parameter: projection_2

Please provide type hint for the parameter: temperature

"""
Contrastive loss function for self-supervised learning.

>>> contrastive_loss(np.array([0.1]), np.array([0.2]))
0.0
"""
projections = tf.concat([projection_1, projection_2], axis=0)
similarity_matrix = tf.matmul(projections, projections, transpose_b=True)
labels = tf.range(tf.shape(projections)[0])
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, similarity_matrix)
return tf.reduce_mean(loss)


def main() -> None:
"""
>>> main()
"""
# Load a small dataset (using CIFAR-10 as an example)
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Use label encoder to convert labels into numerical form
le = LabelEncoder()
y_train = le.fit_transform(y_train.flatten())
y_test = le.transform(y_test.flatten())

# Split data into train and validation sets
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2)

# Build the SimCLR model
model = simclr_model()
optimizer = Adam()
loss_fn = SparseCategoricalCrossentropy(from_logits=True)

# Training the SimCLR model
for epoch in range(10):
with tf.GradientTape() as tape:
projections_1 = model(x_train)
projections_2 = model(x_train) # Normally, this would use augmented views
loss = contrastive_loss(projections_1, projections_2)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
print(f"Epoch {epoch+1}: Contrastive Loss = {loss.numpy()}")

# Create a new model with a classification head for evaluation
classifier = layers.Dense(10, activation="softmax")(model.output)
classifier_model = Model(model.input, classifier)
classifier_model.compile(optimizer=Adam(), loss=loss_fn, metrics=["accuracy"])
classifier_model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5)

# Display the confusion matrix of the classifier
ConfusionMatrixDisplay.from_estimator(
classifier_model,
x_test,
y_test,
display_labels=le.classes_,
cmap="Blues",
normalize="true",
)
plt.title("Normalized Confusion Matrix - CIFAR-10")
plt.show()


if __name__ == "__main__":
import doctest

doctest.testmod(verbose=True)
main()
Loading