diff --git a/examples/WiFiAdvancedCallback/WiFiAdvancedCallback.ino b/examples/WiFiAdvancedCallback/WiFiAdvancedCallback.ino
index 0259b64..fde35d6 100644
--- a/examples/WiFiAdvancedCallback/WiFiAdvancedCallback.ino
+++ b/examples/WiFiAdvancedCallback/WiFiAdvancedCallback.ino
@@ -69,6 +69,10 @@ void setup() {
   // You can provide a username and password for authentication
   // mqttClient.setUsernamePassword("username", "password");
 
+  // By default the library connects with the "clean session" flag set,
+  // you can disable this behaviour by using
+  // mqttClient.setCleanSession(false);
+
   // set a will message, used by the broker when the connection dies unexpectantly
   // you must know the size of the message before hand, and it must be set before connecting
   String willPayload = "oh no!";
diff --git a/keywords.txt b/keywords.txt
index 4adedb0..6c414b6 100644
--- a/keywords.txt
+++ b/keywords.txt
@@ -40,7 +40,7 @@ connected	KEYWORD2
 
 setId	KEYWORD2
 setUsernamePassword	KEYWORD2
-
+setCleanSession	KEYWORD2
 setKeepAliveInterval 	KEYWORD2
 setConnectionTimeout	KEYWORD2
 
diff --git a/src/MqttClient.cpp b/src/MqttClient.cpp
index af93a03..e6b4d22 100644
--- a/src/MqttClient.cpp
+++ b/src/MqttClient.cpp
@@ -64,6 +64,7 @@ enum {
 MqttClient::MqttClient(Client& client) :
   _client(&client),
   _onMessage(NULL),
+  _cleanSession(true),
   _keepAliveInterval(60 * 1000L),
   _connectionTimeout(30 * 1000L),
   _connectError(MQTT_SUCCESS),
@@ -772,6 +773,11 @@ void MqttClient::setUsernamePassword(const String& username, const String& passw
   _password = password;
 }
 
+void MqttClient::setCleanSession(bool cleanSession)
+{
+  _cleanSession = cleanSession;
+}
+
 void MqttClient::setKeepAliveInterval(unsigned long interval)
 {
   _keepAliveInterval = interval;
@@ -855,7 +861,10 @@ int MqttClient::connect(IPAddress ip, const char* host, uint16_t port)
   }
 
   flags |= _willFlags;
-  flags |= 0x02; // clean session
+
+  if (_cleanSession) {
+    flags |= 0x02; // clean session
+  }
 
   connectVariableHeader.protocolName.length = htons(sizeof(connectVariableHeader.protocolName.value));
   memcpy(connectVariableHeader.protocolName.value, "MQTT", sizeof(connectVariableHeader.protocolName.value));
diff --git a/src/MqttClient.h b/src/MqttClient.h
index 82d4bf0..d7c052f 100644
--- a/src/MqttClient.h
+++ b/src/MqttClient.h
@@ -84,6 +84,8 @@ class MqttClient : public Client {
   void setUsernamePassword(const char* username, const char* password);
   void setUsernamePassword(const String& username, const String& password);
 
+  void setCleanSession(bool cleanSession);
+
   void setKeepAliveInterval(unsigned long interval);
   void setConnectionTimeout(unsigned long timeout);
 
@@ -124,6 +126,7 @@ class MqttClient : public Client {
   String _id;
   String _username;
   String _password;
+  bool _cleanSession;
 
   unsigned long _keepAliveInterval;
   unsigned long _connectionTimeout;