18
18
import java .lang .reflect .Modifier ;
19
19
import java .util .Collection ;
20
20
import java .util .Collections ;
21
- import java .util .Comparator ;
22
21
import java .util .HashMap ;
23
22
import java .util .HashSet ;
24
23
import java .util .List ;
25
24
import java .util .Map ;
26
- import java .util .Optional ;
27
25
import java .util .Set ;
28
- import java .util .function .Function ;
29
- import java .util .stream .Collectors ;
26
+ import java .util .function .BiFunction ;
30
27
31
28
import org .springframework .data .mapping .context .AbstractMappingContext ;
32
29
import org .springframework .lang .Nullable ;
@@ -46,6 +43,24 @@ final class NodeDescriptionStore {
46
43
*/
47
44
private final Map <String , NodeDescription <?>> nodeDescriptionsByPrimaryLabel = new HashMap <>();
48
45
46
+ private final Map <NodeDescription <?>, Map <List <String >, NodeDescriptionAndLabels >> nodeDescriptionAndLabelsCache = new HashMap <>();
47
+
48
+ private final BiFunction <NodeDescription <?>, List <String >, NodeDescriptionAndLabels > nodeDescriptionAndLabels =
49
+ (nodeDescription , labels ) -> {
50
+ Map <List <String >, NodeDescriptionAndLabels > listNodeDescriptionAndLabelsMap = nodeDescriptionAndLabelsCache .get (nodeDescription );
51
+ if (listNodeDescriptionAndLabelsMap == null ) {
52
+ nodeDescriptionAndLabelsCache .put (nodeDescription , new HashMap <>());
53
+ listNodeDescriptionAndLabelsMap = nodeDescriptionAndLabelsCache .get (nodeDescription );
54
+ }
55
+
56
+ NodeDescriptionAndLabels cachedNodeDescriptionAndLabels = listNodeDescriptionAndLabelsMap .get (labels );
57
+ if (cachedNodeDescriptionAndLabels == null ) {
58
+ cachedNodeDescriptionAndLabels = computeConcreteNodeDescription (nodeDescription , labels );
59
+ listNodeDescriptionAndLabelsMap .put (labels , cachedNodeDescriptionAndLabels );
60
+ }
61
+ return cachedNodeDescriptionAndLabels ;
62
+ };
63
+
49
64
public boolean containsKey (String primaryLabel ) {
50
65
return nodeDescriptionsByPrimaryLabel .containsKey (primaryLabel );
51
66
}
@@ -81,7 +96,11 @@ public NodeDescription<?> getNodeDescription(Class<?> targetType) {
81
96
return null ;
82
97
}
83
98
84
- public NodeDescriptionAndLabels deriveConcreteNodeDescription (Neo4jPersistentEntity <?> entityDescription , List <String > labels ) {
99
+ public NodeDescriptionAndLabels deriveConcreteNodeDescription (NodeDescription <?> entityDescription , List <String > labels ) {
100
+ return nodeDescriptionAndLabels .apply (entityDescription , labels );
101
+ }
102
+
103
+ private NodeDescriptionAndLabels computeConcreteNodeDescription (NodeDescription <?> entityDescription , List <String > labels ) {
85
104
86
105
boolean isConcreteClassThatFulfillsEverything = !Modifier .isAbstract (entityDescription .getUnderlyingClass ().getModifiers ()) && entityDescription .getStaticLabels ().containsAll (labels );
87
106
@@ -97,25 +116,48 @@ public NodeDescriptionAndLabels deriveConcreteNodeDescription(Neo4jPersistentEnt
97
116
}
98
117
99
118
if (!haystack .isEmpty ()) {
100
- Function <NodeDescription <?>, Integer > count = (nodeDescription ) -> Math .toIntExact (nodeDescription .getStaticLabels ().stream ().filter (labels ::contains ).count ());
101
- Optional <Map .Entry <NodeDescription <?>, Integer >> mostMatchingNodeDescription = haystack .stream ()
102
- .filter (nd -> labels .containsAll (nd .getStaticLabels ())) // remove candidates having more mandatory labels
103
- .collect (Collectors .toMap (Function .identity (), nodeDescription -> count .apply (nodeDescription )))
104
- .entrySet ().stream ()
105
- .max (Comparator .comparingInt (Map .Entry ::getValue ));
106
-
107
- if (mostMatchingNodeDescription .isPresent ()) {
108
- NodeDescription <?> childNodeDescription = mostMatchingNodeDescription .get ().getKey ();
109
- List <String > staticLabels = childNodeDescription .getStaticLabels ();
110
- Set <String > surplusLabels = new HashSet <>(labels );
111
- surplusLabels .removeAll (staticLabels );
112
- return new NodeDescriptionAndLabels (childNodeDescription , surplusLabels );
119
+
120
+ NodeDescription <?> mostMatchingNodeDescription = null ;
121
+ Map <NodeDescription <?>, Integer > unmatchedLabelsCache = new HashMap <>();
122
+ List <String > mostMatchingStaticLabels = null ;
123
+
124
+ // Remove is faster than "stream, filter, count".
125
+ BiFunction <NodeDescription <?>, List <String >, Integer > unmatchedLabelsCount =
126
+ (nodeDescription , staticLabels ) -> {
127
+ Set <String > staticLabelsClone = new HashSet <>(staticLabels );
128
+ labels .forEach (staticLabelsClone ::remove );
129
+ return staticLabelsClone .size ();
130
+ };
131
+
132
+ for (NodeDescription <?> nd : haystack ) {
133
+ List <String > staticLabels = nd .getStaticLabels ();
134
+
135
+ if (staticLabels .containsAll (labels )) {
136
+ Set <String > surplusLabels = new HashSet <>(labels );
137
+ staticLabels .forEach (surplusLabels ::remove );
138
+ return new NodeDescriptionAndLabels (nd , surplusLabels );
139
+ }
140
+
141
+ unmatchedLabelsCache .put (nd , unmatchedLabelsCount .apply (nd , staticLabels ));
142
+ if (mostMatchingNodeDescription == null ) {
143
+ mostMatchingNodeDescription = nd ;
144
+ mostMatchingStaticLabels = staticLabels ;
145
+ continue ;
146
+ }
147
+
148
+ if (unmatchedLabelsCache .get (nd ) < unmatchedLabelsCache .get (mostMatchingNodeDescription )) {
149
+ mostMatchingNodeDescription = nd ;
150
+ }
113
151
}
152
+
153
+ Set <String > surplusLabels = new HashSet <>(labels );
154
+ mostMatchingStaticLabels .forEach (surplusLabels ::remove );
155
+ return new NodeDescriptionAndLabels (mostMatchingNodeDescription , surplusLabels );
114
156
}
115
157
116
158
Set <String > surplusLabels = new HashSet <>(labels );
117
159
surplusLabels .remove (entityDescription .getPrimaryLabel ());
118
- surplusLabels . removeAll ( entityDescription .getAdditionalLabels ());
160
+ entityDescription .getAdditionalLabels (). forEach ( surplusLabels :: remove );
119
161
return new NodeDescriptionAndLabels (entityDescription , surplusLabels );
120
162
}
121
163
}
0 commit comments