diff --git a/src/main/java/org/springframework/session/data/mongo/AbstractMongoSessionConverter.java b/src/main/java/org/springframework/session/data/mongo/AbstractMongoSessionConverter.java index 651e558e..d6e156e1 100644 --- a/src/main/java/org/springframework/session/data/mongo/AbstractMongoSessionConverter.java +++ b/src/main/java/org/springframework/session/data/mongo/AbstractMongoSessionConverter.java @@ -47,7 +47,7 @@ */ public abstract class AbstractMongoSessionConverter implements GenericConverter { - static final String EXPIRE_AT_FIELD_NAME = "expireAt"; + private String expireAtFieldName = "expireAt"; private static final Log LOG = LogFactory.getLog(AbstractMongoSessionConverter.class); private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; @@ -73,16 +73,16 @@ public abstract class AbstractMongoSessionConverter implements GenericConverter protected void ensureIndexes(IndexOperations sessionCollectionIndexes) { for (IndexInfo info : sessionCollectionIndexes.getIndexInfo()) { - if (EXPIRE_AT_FIELD_NAME.equals(info.getName())) { - LOG.debug("TTL index on field " + EXPIRE_AT_FIELD_NAME + " already exists"); + if (expireAtFieldName.equals(info.getName())) { + LOG.debug("TTL index on field " + expireAtFieldName + " already exists"); return; } } - LOG.info("Creating TTL index on field " + EXPIRE_AT_FIELD_NAME); + LOG.info("Creating TTL index on field " + expireAtFieldName); sessionCollectionIndexes - .ensureIndex(new Index(EXPIRE_AT_FIELD_NAME, Sort.Direction.ASC).named(EXPIRE_AT_FIELD_NAME).expire(0)); + .ensureIndex(new Index(expireAtFieldName, Sort.Direction.ASC).named(expireAtFieldName).expire(0)); } protected String extractPrincipal(MongoSession expiringSession) { @@ -91,11 +91,13 @@ protected String extractPrincipal(MongoSession expiringSession) { .get(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); } + @Override public Set getConvertibleTypes() { return Collections.singleton(new ConvertiblePair(DBObject.class, MongoSession.class)); } + @Override @SuppressWarnings("unchecked") @Nullable public Object convert(Object source, TypeDescriptor sourceType, TypeDescriptor targetType) { @@ -120,4 +122,12 @@ public Object convert(Object source, TypeDescriptor sourceType, TypeDescriptor t public void setIndexResolver(IndexResolver indexResolver) { this.indexResolver = Assert.requireNonNull(indexResolver, "indexResolver must not be null!"); } + + public void setExpireAtFieldName(String expireAtFieldName) { + this.expireAtFieldName = expireAtFieldName; + } + + String getExpireAtFieldName() { + return expireAtFieldName; + } } diff --git a/src/main/java/org/springframework/session/data/mongo/JacksonMongoSessionConverter.java b/src/main/java/org/springframework/session/data/mongo/JacksonMongoSessionConverter.java index f4f4adfa..95b0f83a 100644 --- a/src/main/java/org/springframework/session/data/mongo/JacksonMongoSessionConverter.java +++ b/src/main/java/org/springframework/session/data/mongo/JacksonMongoSessionConverter.java @@ -56,9 +56,8 @@ public class JacksonMongoSessionConverter extends AbstractMongoSessionConverter private static final Log LOG = LogFactory.getLog(JacksonMongoSessionConverter.class); - private static final String ATTRS_FIELD_NAME = "attrs."; - private static final String PRINCIPAL_FIELD_NAME = "principal"; - private static final String EXPIRE_AT_FIELD_NAME = "expireAt"; + private String attrsFieldName = "attrs."; + private String pricipalFieldName = "principal"; private final ObjectMapper objectMapper; @@ -78,13 +77,14 @@ public JacksonMongoSessionConverter(ObjectMapper objectMapper) { this.objectMapper = objectMapper; } + @Override @Nullable protected Query getQueryForIndex(String indexName, Object indexValue) { if (FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME.equals(indexName)) { - return Query.query(Criteria.where(PRINCIPAL_FIELD_NAME).is(indexValue)); + return Query.query(Criteria.where(pricipalFieldName).is(indexValue)); } else { - return Query.query(Criteria.where(ATTRS_FIELD_NAME + MongoSession.coverDot(indexName)).is(indexValue)); + return Query.query(Criteria.where(attrsFieldName + MongoSession.coverDot(indexName)).is(indexValue)); } } @@ -115,8 +115,8 @@ protected DBObject convert(MongoSession source) { DBObject dbSession = BasicDBObject.parse(this.objectMapper.writeValueAsString(source)); // Override default serialization with proper values. - dbSession.put(PRINCIPAL_FIELD_NAME, extractPrincipal(source)); - dbSession.put(EXPIRE_AT_FIELD_NAME, source.getExpireAt()); + dbSession.put(pricipalFieldName, extractPrincipal(source)); + dbSession.put(getExpireAtFieldName(), source.getExpireAt()); return dbSession; } catch (JsonProcessingException e) { throw new IllegalStateException("Cannot convert MongoExpiringSession", e); @@ -127,7 +127,7 @@ protected DBObject convert(MongoSession source) { @Nullable protected MongoSession convert(Document source) { - Date expireAt = (Date) source.remove(EXPIRE_AT_FIELD_NAME); + Date expireAt = (Date) source.remove(getExpireAtFieldName()); source.remove("originalSessionId"); String json = source.toJson(JsonWriterSettings.builder().outputMode(JsonMode.RELAXED).build()); diff --git a/src/main/java/org/springframework/session/data/mongo/JdkMongoSessionConverter.java b/src/main/java/org/springframework/session/data/mongo/JdkMongoSessionConverter.java index 056dd381..151d18cc 100644 --- a/src/main/java/org/springframework/session/data/mongo/JdkMongoSessionConverter.java +++ b/src/main/java/org/springframework/session/data/mongo/JdkMongoSessionConverter.java @@ -47,12 +47,12 @@ */ public class JdkMongoSessionConverter extends AbstractMongoSessionConverter { - private static final String ID = "_id"; - private static final String CREATION_TIME = "created"; - private static final String LAST_ACCESSED_TIME = "accessed"; - private static final String MAX_INTERVAL = "interval"; - private static final String ATTRIBUTES = "attr"; - private static final String PRINCIPAL_FIELD_NAME = "principal"; + private String idFieldName = "_id"; + private String creationTimeFieldName = "created"; + private String lastAccessedTimeFieldName = "accessed"; + private String maxIntervalFieldName = "interval"; + private String attributesFieldName = "attr"; + private String principalFieldName = "principal"; private final Converter serializer; private final Converter deserializer; @@ -80,7 +80,7 @@ public JdkMongoSessionConverter(Converter serializer, Converter< public Query getQueryForIndex(String indexName, Object indexValue) { if (FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME.equals(indexName)) { - return Query.query(Criteria.where(PRINCIPAL_FIELD_NAME).is(indexValue)); + return Query.query(Criteria.where(principalFieldName).is(indexValue)); } else { return null; } @@ -91,13 +91,13 @@ protected DBObject convert(MongoSession session) { BasicDBObject basicDBObject = new BasicDBObject(); - basicDBObject.put(ID, session.getId()); - basicDBObject.put(CREATION_TIME, session.getCreationTime()); - basicDBObject.put(LAST_ACCESSED_TIME, session.getLastAccessedTime()); - basicDBObject.put(MAX_INTERVAL, session.getMaxInactiveInterval()); - basicDBObject.put(PRINCIPAL_FIELD_NAME, extractPrincipal(session)); - basicDBObject.put(EXPIRE_AT_FIELD_NAME, session.getExpireAt()); - basicDBObject.put(ATTRIBUTES, serializeAttributes(session)); + basicDBObject.put(idFieldName, session.getId()); + basicDBObject.put(creationTimeFieldName, session.getCreationTime()); + basicDBObject.put(lastAccessedTimeFieldName, session.getLastAccessedTime()); + basicDBObject.put(maxIntervalFieldName, session.getMaxInactiveInterval()); + basicDBObject.put(principalFieldName, extractPrincipal(session)); + basicDBObject.put(getExpireAtFieldName(), session.getExpireAt()); + basicDBObject.put(attributesFieldName, serializeattributesFieldName(session)); return basicDBObject; } @@ -105,60 +105,112 @@ protected DBObject convert(MongoSession session) { @Override protected MongoSession convert(Document sessionWrapper) { - Object maxInterval = sessionWrapper.getOrDefault(MAX_INTERVAL, this.maxInactiveInterval); + Object maxInterval = sessionWrapper.getOrDefault(maxIntervalFieldName, this.maxInactiveInterval); Duration maxIntervalDuration = (maxInterval instanceof Duration) ? (Duration) maxInterval : Duration.parse(maxInterval.toString()); - MongoSession session = new MongoSession(sessionWrapper.getString(ID), maxIntervalDuration.getSeconds()); + MongoSession session = new MongoSession(sessionWrapper.getString(idFieldName), + maxIntervalDuration.getSeconds()); - Object creationTime = sessionWrapper.get(CREATION_TIME); + Object creationTime = sessionWrapper.get(creationTimeFieldName); if (creationTime instanceof Instant) { session.setCreationTime(((Instant) creationTime).toEpochMilli()); } else if (creationTime instanceof Date) { session.setCreationTime(((Date) creationTime).getTime()); } - Object lastAccessedTime = sessionWrapper.get(LAST_ACCESSED_TIME); + Object lastAccessedTime = sessionWrapper.get(lastAccessedTimeFieldName); if (lastAccessedTime instanceof Instant) { session.setLastAccessedTime((Instant) lastAccessedTime); } else if (lastAccessedTime instanceof Date) { session.setLastAccessedTime(Instant.ofEpochMilli(((Date) lastAccessedTime).getTime())); } - session.setExpireAt((Date) sessionWrapper.get(EXPIRE_AT_FIELD_NAME)); + session.setExpireAt((Date) sessionWrapper.get(getExpireAtFieldName())); - deserializeAttributes(sessionWrapper, session); + deserializeattributesFieldName(sessionWrapper, session); return session; } @Nullable - private byte[] serializeAttributes(Session session) { + private byte[] serializeattributesFieldName(Session session) { - Map attributes = new HashMap<>(); + Map attributesFieldName = new HashMap<>(); for (String attrName : session.getAttributeNames()) { - attributes.put(attrName, session.getAttribute(attrName)); + attributesFieldName.put(attrName, session.getAttribute(attrName)); } - return this.serializer.convert(attributes); + return this.serializer.convert(attributesFieldName); } @SuppressWarnings("unchecked") - private void deserializeAttributes(Document sessionWrapper, Session session) { + private void deserializeattributesFieldName(Document sessionWrapper, Session session) { - Object sessionAttributes = sessionWrapper.get(ATTRIBUTES); + Object sessionattributesFieldName = sessionWrapper.get(attributesFieldName); - byte[] attributesBytes = (sessionAttributes instanceof Binary ? ((Binary) sessionAttributes).getData() - : (byte[]) sessionAttributes); + byte[] attributesFieldNameBytes = (sessionattributesFieldName instanceof Binary + ? ((Binary) sessionattributesFieldName).getData() + : (byte[]) sessionattributesFieldName); - Map attributes = (Map) this.deserializer.convert(attributesBytes); + Map attributesFieldName = (Map) this.deserializer + .convert(attributesFieldNameBytes); - if (attributes != null) { - for (Map.Entry entry : attributes.entrySet()) { + if (attributesFieldName != null) { + for (Map.Entry entry : attributesFieldName.entrySet()) { session.setAttribute(entry.getKey(), entry.getValue()); } } } + + public String getIdFieldName() { + return idFieldName; + } + + public void setIdFieldName(String idFieldName) { + this.idFieldName = idFieldName; + } + + public String getCreationTimeFieldName() { + return creationTimeFieldName; + } + + public void setCreationTimeFieldName(String creationTimeFieldName) { + this.creationTimeFieldName = creationTimeFieldName; + } + + public String getLastAccessedTimeFieldName() { + return lastAccessedTimeFieldName; + } + + public void setLastAccessedTimeFieldName(String lastAccessedTimeFieldName) { + this.lastAccessedTimeFieldName = lastAccessedTimeFieldName; + } + + public String getMaxIntervalFieldName() { + return maxIntervalFieldName; + } + + public void setMaxIntervalFieldName(String maxIntervalFieldName) { + this.maxIntervalFieldName = maxIntervalFieldName; + } + + public String getPrincipalFieldName() { + return principalFieldName; + } + + public void setPrincipalFieldName(String principalFieldName) { + this.principalFieldName = principalFieldName; + } + + public String getAttributesFieldName() { + return attributesFieldName; + } + + public void setAttributesFieldName(String attributesFieldName) { + this.attributesFieldName = attributesFieldName; + } + }