-
-
Notifications
You must be signed in to change notification settings - Fork 46.6k
Added SimCLR (Deep Learning Framework) #11900
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
""" | ||
Implementation of SimCLR. | ||
Self-Supervised Learning (SSL) with SimCLR. SimCLR is a framework for learning visual representations without labels by maximizing the agreement between different augmented views of the same image. | ||
""" | ||
|
||
import numpy as np | ||
import tensorflow as tf | ||
from tensorflow.keras import layers | ||
from tensorflow.keras.models import Model | ||
from tensorflow.keras.optimizers import Adam | ||
from tensorflow.keras.losses import SparseCategoricalCrossentropy | ||
from tensorflow.keras.applications import ResNet50 | ||
from sklearn.metrics import ConfusionMatrixDisplay | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.preprocessing import LabelEncoder | ||
import matplotlib.pyplot as plt | ||
|
||
|
||
def data_handling(data: dict) -> tuple: | ||
""" | ||
Handles the data by splitting features and targets. | ||
|
||
>>> data_handling({'data': np.array([[0.1, 0.2], [0.3, 0.4]]), 'target': np.array([0, 1])}) | ||
(array([[0.1, 0.2], [0.3, 0.4]]), array([0, 1])) | ||
""" | ||
return (data["data"], data["target"]) | ||
|
||
|
||
def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model: | ||
""" | ||
Builds a SimCLR model based on ResNet50. | ||
|
||
>>> simclr_model().summary() # doctest: +ELLIPSIS | ||
Model: "model" | ||
_________________________________________________________________ | ||
... | ||
""" | ||
base_model = ResNet50(include_top=False, input_shape=input_shape, pooling="avg") | ||
base_model.trainable = True | ||
|
||
inputs = layers.Input(shape=input_shape) | ||
x = base_model(inputs, training=True) | ||
x = layers.Dense(projection_dim, activation="relu")(x) | ||
outputs = layers.Dense(projection_dim)(x) | ||
return Model(inputs, outputs) | ||
|
||
|
||
def contrastive_loss(projection_1, projection_2, temperature=0.1): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please provide return type hint for the function: Please provide type hint for the parameter: Please provide type hint for the parameter: Please provide type hint for the parameter: |
||
""" | ||
Contrastive loss function for self-supervised learning. | ||
|
||
>>> contrastive_loss(np.array([0.1]), np.array([0.2])) | ||
0.0 | ||
""" | ||
projections = tf.concat([projection_1, projection_2], axis=0) | ||
similarity_matrix = tf.matmul(projections, projections, transpose_b=True) | ||
labels = tf.range(tf.shape(projections)[0]) | ||
loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, similarity_matrix) | ||
return tf.reduce_mean(loss) | ||
|
||
|
||
def main() -> None: | ||
""" | ||
>>> main() | ||
""" | ||
# Load a small dataset (using CIFAR-10 as an example) | ||
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data() | ||
x_train, x_test = x_train / 255.0, x_test / 255.0 | ||
|
||
# Use label encoder to convert labels into numerical form | ||
le = LabelEncoder() | ||
y_train = le.fit_transform(y_train.flatten()) | ||
y_test = le.transform(y_test.flatten()) | ||
|
||
# Split data into train and validation sets | ||
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2) | ||
|
||
# Build the SimCLR model | ||
model = simclr_model() | ||
optimizer = Adam() | ||
loss_fn = SparseCategoricalCrossentropy(from_logits=True) | ||
|
||
# Training the SimCLR model | ||
for epoch in range(10): | ||
with tf.GradientTape() as tape: | ||
projections_1 = model(x_train) | ||
projections_2 = model(x_train) # Normally, this would use augmented views | ||
loss = contrastive_loss(projections_1, projections_2) | ||
gradients = tape.gradient(loss, model.trainable_variables) | ||
optimizer.apply_gradients(zip(gradients, model.trainable_variables)) | ||
print(f"Epoch {epoch+1}: Contrastive Loss = {loss.numpy()}") | ||
|
||
# Create a new model with a classification head for evaluation | ||
classifier = layers.Dense(10, activation="softmax")(model.output) | ||
classifier_model = Model(model.input, classifier) | ||
classifier_model.compile(optimizer=Adam(), loss=loss_fn, metrics=["accuracy"]) | ||
classifier_model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5) | ||
|
||
# Display the confusion matrix of the classifier | ||
ConfusionMatrixDisplay.from_estimator( | ||
classifier_model, | ||
x_test, | ||
y_test, | ||
display_labels=le.classes_, | ||
cmap="Blues", | ||
normalize="true", | ||
) | ||
plt.title("Normalized Confusion Matrix - CIFAR-10") | ||
plt.show() | ||
|
||
|
||
if __name__ == "__main__": | ||
import doctest | ||
|
||
doctest.testmod(verbose=True) | ||
main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please provide type hint for the parameter:
input_shape
Please provide type hint for the parameter:
projection_dim