"""
Implementation of SimCLR, a framework for Self-Supervised Learning (SSL) of
visual representations. SimCLR learns without labels by maximizing the
agreement between differently augmented views of the same image.
"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn.metrics import ConfusionMatrixDisplay
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras import layers
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam


def data_handling(data: dict) -> tuple:
    """
    Split a dataset dictionary into its features and targets.

    >>> data_handling({'data': np.array([[0.1, 0.2], [0.3, 0.4]]), 'target': np.array([0, 1])})  # doctest: +NORMALIZE_WHITESPACE
    (array([[0.1, 0.2], [0.3, 0.4]]), array([0, 1]))
    """
    return (data["data"], data["target"])


def simclr_model(input_shape=(32, 32, 3), projection_dim=64) -> Model:
    """
    Build a SimCLR encoder: a ResNet50 backbone followed by a small
    projection head that maps representations to the space where the
    contrastive loss is applied.

    >>> model = simclr_model()
    >>> model.output_shape
    (None, 64)
    """
    # Start from random weights: SimCLR pretrains without any labels,
    # so supervised ImageNet weights are not used here.
    base_model = ResNet50(
        include_top=False, weights=None, input_shape=input_shape, pooling="avg"
    )
    base_model.trainable = True

    inputs = layers.Input(shape=input_shape)
    # Do not hard-code training=True in the graph; the training flag is
    # passed per call so batch normalization behaves correctly at inference.
    x = base_model(inputs)
    x = layers.Dense(projection_dim, activation="relu")(x)
    outputs = layers.Dense(projection_dim)(x)
    return Model(inputs, outputs)


def contrastive_loss(projection_1, projection_2, temperature=0.1):
    """
    NT-Xent (normalized temperature-scaled cross-entropy) loss for SimCLR:
    the two projections of each image form a positive pair, and every other
    sample in the batch acts as a negative.

    >>> p1 = np.array([[1.0, 0.0]], dtype=np.float32)
    >>> p2 = np.array([[1.0, 0.0]], dtype=np.float32)
    >>> round(float(contrastive_loss(p1, p2)), 3)
    0.0
    """
    projections = tf.math.l2_normalize(
        tf.concat([projection_1, projection_2], axis=0), axis=1
    )
    # Cosine similarities scaled by temperature; mask out self-similarity
    similarity = tf.matmul(projections, projections, transpose_b=True) / temperature
    similarity -= tf.eye(tf.shape(projections)[0], dtype=similarity.dtype) * 1e9
    # The positive for sample i is the other view of the same image
    batch_size = tf.shape(projection_1)[0]
    labels = tf.concat([tf.range(batch_size) + batch_size, tf.range(batch_size)], 0)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, similarity)
    return tf.reduce_mean(loss)


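# The following helper is a minimal sketch of SimCLR's data augmentation
# (random crop, horizontal flip and light color distortion), assuming
# 32x32 RGB inputs scaled to [0, 1]; the paper's full recipe also uses
# stronger color jitter and Gaussian blur.
def simclr_augment(images: tf.Tensor) -> tf.Tensor:
    """
    Produce one randomly augmented view of a batch of images.

    >>> simclr_augment(tf.zeros((2, 32, 32, 3))).shape
    TensorShape([2, 32, 32, 3])
    """
    images = tf.image.random_flip_left_right(images)
    # Pad to 36x36, then randomly crop back to 32x32
    images = tf.image.resize_with_crop_or_pad(images, 36, 36)
    images = tf.image.random_crop(images, size=[tf.shape(images)[0], 32, 32, 3])
    # Light color distortion
    images = tf.image.random_brightness(images, max_delta=0.2)
    images = tf.image.random_contrast(images, lower=0.8, upper=1.2)
    return tf.clip_by_value(images, 0.0, 1.0)

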
def main() -> None:
    """
    >>> main()  # doctest: +SKIP
    """
    # Load a small dataset (using CIFAR-10 as an example)
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()
    # Scale to [0, 1] and cast to float32 for TensorFlow
    x_train = (x_train / 255.0).astype("float32")
    x_test = (x_test / 255.0).astype("float32")

    # Use a label encoder to convert labels into numerical form
    le = LabelEncoder()
    y_train = le.fit_transform(y_train.flatten())
    y_test = le.transform(y_test.flatten())

    # Split data into train and validation sets
    x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=0.2)

    # Build the SimCLR model
    model = simclr_model()
    optimizer = Adam()
    # The classifier head below ends in a softmax, so the loss receives
    # probabilities, not logits
    loss_fn = SparseCategoricalCrossentropy(from_logits=False)

    # Self-supervised pretraining: contrast two augmented views of each
    # mini-batch (a single gradient step over the full dataset would not
    # fit in memory)
    batch_size = 256
    for epoch in range(10):
        for start in range(0, len(x_train), batch_size):
            batch = x_train[start : start + batch_size]
            with tf.GradientTape() as tape:
                projections_1 = model(simclr_augment(batch), training=True)
                projections_2 = model(simclr_augment(batch), training=True)
                loss = contrastive_loss(projections_1, projections_2)
            gradients = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        print(f"Epoch {epoch + 1}: Contrastive Loss = {loss.numpy():.4f}")

    # Linear evaluation: freeze the encoder and train a classification head
    # on top of the learned representations
    model.trainable = False
    classifier = layers.Dense(10, activation="softmax")(model.output)
    classifier_model = Model(model.input, classifier)
    classifier_model.compile(optimizer=Adam(), loss=loss_fn, metrics=["accuracy"])
    classifier_model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=5)

    # Display the confusion matrix of the classifier; from_predictions is
    # used because a Keras model is not a scikit-learn estimator, which
    # from_estimator would require
    y_pred = np.argmax(classifier_model.predict(x_test), axis=1)
    ConfusionMatrixDisplay.from_predictions(
        y_test,
        y_pred,
        display_labels=le.classes_,
        cmap="Blues",
        normalize="true",
    )
    plt.title("Normalized Confusion Matrix - CIFAR-10")
    plt.show()


if __name__ == "__main__":
    import doctest

    doctest.testmod(verbose=True)
    main()