Skip to content

Commit 85d2e16

Browse files
authored
Optimize map unions to avoid building long lists (#14215)
1 parent 2ab8a54 commit 85d2e16

File tree

2 files changed

+157
-38
lines changed

2 files changed

+157
-38
lines changed

lib/elixir/lib/module/types/descr.ex

Lines changed: 117 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,8 +1307,115 @@ defmodule Module.Types.Descr do
13071307

13081308
defp map_only?(descr), do: empty?(Map.delete(descr, :map))
13091309

1310-
# Union is list concatenation
1311-
defp map_union(dnf1, dnf2), do: dnf1 ++ (dnf2 -- dnf1)
1310+
defp map_union(dnf1, dnf2) do
1311+
# Union is just concatenation, but we rely on some optimization strategies to
1312+
# keep the list from growing when possible
1313+
1314+
# first pass: try to identify patterns where two maps can be fused into one
1315+
with [map1] <- dnf1,
1316+
[map2] <- dnf2,
1317+
optimized when optimized != nil <- maybe_optimize_map_union(map1, map2) do
1318+
[optimized]
1319+
else
1320+
# otherwise we just concatenate and remove structural duplicates
1321+
_ -> dnf1 ++ (dnf2 -- dnf1)
1322+
end
1323+
end
1324+
1325+
defp maybe_optimize_map_union({tag1, pos1, []} = map1, {tag2, pos2, []} = map2) do
1326+
case map_union_optimization_strategy(tag1, pos1, tag2, pos2) do
1327+
:all_equal ->
1328+
map1
1329+
1330+
:any_map ->
1331+
{:open, %{}, []}
1332+
1333+
{:one_key_difference, key, v1, v2} ->
1334+
new_pos = Map.put(pos1, key, union(v1, v2))
1335+
{tag1, new_pos, []}
1336+
1337+
:left_subtype_of_right ->
1338+
map2
1339+
1340+
:right_subtype_of_left ->
1341+
map1
1342+
1343+
nil ->
1344+
nil
1345+
end
1346+
end
1347+
1348+
defp maybe_optimize_map_union(_, _), do: nil
1349+
1350+
defp map_union_optimization_strategy(tag1, pos1, tag2, pos2)
1351+
defp map_union_optimization_strategy(tag, pos, tag, pos), do: :all_equal
1352+
defp map_union_optimization_strategy(:open, empty, _, _) when empty == %{}, do: :any_map
1353+
defp map_union_optimization_strategy(_, _, :open, empty) when empty == %{}, do: :any_map
1354+
1355+
defp map_union_optimization_strategy(tag, pos1, tag, pos2)
1356+
when map_size(pos1) == map_size(pos2) do
1357+
:maps.iterator(pos1)
1358+
|> :maps.next()
1359+
|> do_map_union_optimization_strategy(pos2, :all_equal)
1360+
end
1361+
1362+
defp map_union_optimization_strategy(:open, pos1, _, pos2)
1363+
when map_size(pos1) <= map_size(pos2) do
1364+
:maps.iterator(pos1)
1365+
|> :maps.next()
1366+
|> do_map_union_optimization_strategy(pos2, :right_subtype_of_left)
1367+
end
1368+
1369+
defp map_union_optimization_strategy(_, pos1, :open, pos2)
1370+
when map_size(pos1) >= map_size(pos2) do
1371+
:maps.iterator(pos2)
1372+
|> :maps.next()
1373+
|> do_map_union_optimization_strategy(pos1, :right_subtype_of_left)
1374+
|> case do
1375+
:right_subtype_of_left -> :left_subtype_of_right
1376+
nil -> nil
1377+
end
1378+
end
1379+
1380+
defp map_union_optimization_strategy(_, _, _, _), do: nil
1381+
1382+
defp do_map_union_optimization_strategy(:none, _, status), do: status
1383+
1384+
defp do_map_union_optimization_strategy({key, v1, iterator}, pos2, status) do
1385+
with %{^key => v2} <- pos2,
1386+
next_status when next_status != nil <- map_union_next_strategy(key, v1, v2, status) do
1387+
do_map_union_optimization_strategy(:maps.next(iterator), pos2, next_status)
1388+
else
1389+
_ -> nil
1390+
end
1391+
end
1392+
1393+
defp map_union_next_strategy(key, v1, v2, status)
1394+
1395+
# structurally equal values do not impact the ongoing strategy
1396+
defp map_union_next_strategy(_key, same, same, status), do: status
1397+
1398+
defp map_union_next_strategy(key, v1, v2, :all_equal) do
1399+
if key != :__struct__, do: {:one_key_difference, key, v1, v2}
1400+
end
1401+
1402+
defp map_union_next_strategy(_key, v1, v2, {:one_key_difference, _, d1, d2}) do
1403+
# we now have at least two key differences, so we switch strategy:
1404+
# if both differences are subtypes in the same direction, keep checking
1405+
cond do
1406+
subtype?(d1, d2) and subtype?(v1, v2) -> :left_subtype_of_right
1407+
subtype?(d2, d1) and subtype?(v2, v1) -> :right_subtype_of_left
1408+
true -> nil
1409+
end
1410+
end
1411+
1412+
defp map_union_next_strategy(_key, v1, v2, :left_subtype_of_right) do
1413+
if subtype?(v1, v2), do: :left_subtype_of_right
1414+
end
1415+
1416+
defp map_union_next_strategy(_key, v1, v2, :right_subtype_of_left) do
1417+
if subtype?(v2, v1), do: :right_subtype_of_left
1418+
end
13121419

13131420
# Given two unions of maps, intersects each pair of maps.
13141421
defp map_intersection(dnf1, dnf2) do
@@ -1790,49 +1897,21 @@ defmodule Module.Types.Descr do
17901897

17911898
defp map_non_negated_fuse(maps) do
17921899
Enum.reduce(maps, [], fn map, acc ->
1793-
case Enum.split_while(acc, &non_fusible_maps?(map, &1)) do
1794-
{_, []} ->
1795-
[map | acc]
1796-
1797-
{others, [match | rest]} ->
1798-
fused = map_non_negated_fuse_pair(map, match)
1799-
others ++ [fused | rest]
1800-
end
1900+
fuse_with_first_fusible(map, acc)
18011901
end)
18021902
end
18031903

1804-
# Two maps are fusible if they differ in at most one element.
1805-
# Given they are of the same size, the side you traverse is not important.
1806-
defp non_fusible_maps?({_, fields1, []}, {_, fields2, []}) do
1807-
not fusible_maps?(Map.to_list(fields1), fields2, 0)
1808-
end
1809-
1810-
defp fusible_maps?([{:__struct__, value} | rest], fields, count) do
1811-
case Map.fetch!(fields, :__struct__) do
1812-
^value -> fusible_maps?(rest, fields, count)
1813-
_ -> false
1814-
end
1815-
end
1904+
defp fuse_with_first_fusible(map, []), do: [map]
18161905

1817-
defp fusible_maps?([{key, value} | rest], fields, count) do
1818-
case Map.fetch!(fields, key) do
1819-
^value -> fusible_maps?(rest, fields, count)
1820-
_ when count == 1 -> false
1821-
_ when count == 0 -> fusible_maps?(rest, fields, count + 1)
1906+
defp fuse_with_first_fusible(map, [candidate | rest]) do
1907+
if fused = maybe_optimize_map_union(map, candidate) do
1908+
# we found a fusible candidate, we're done
1909+
[fused | rest]
1910+
else
1911+
[candidate | fuse_with_first_fusible(map, rest)]
18221912
end
18231913
end
18241914

1825-
defp fusible_maps?([], _fields, _count), do: true
1826-
1827-
defp map_non_negated_fuse_pair({tag, fields1, []}, {_, fields2, []}) do
1828-
fields =
1829-
symmetrical_merge(fields1, fields2, fn _k, v1, v2 ->
1830-
if v1 == v2, do: v1, else: union(v1, v2)
1831-
end)
1832-
1833-
{tag, fields, []}
1834-
end
1835-
18361915
# If all fields are the same except one, we can optimize map difference.
18371916
defp map_all_but_one?(tag1, fields1, tag2, fields2) do
18381917
keys1 = Map.keys(fields1)

lib/elixir/test/elixir/module/types/descr_test.exs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,46 @@ defmodule Module.Types.DescrTest do
105105
assert union(difference(list(term()), list(integer())), list(integer()))
106106
|> equal?(list(term()))
107107
end
108+
109+
test "optimizations" do
110+
# The tests are checking the actual implementation, not the semantics.
111+
# This is why we are using structural comparisons.
112+
# It's fine to remove these if the implementation changes, but breaking
113+
# these might have an important impact on compile times.
114+
115+
# Optimization one: same tags, all but one key are structurally equal
116+
assert union(
117+
open_map(a: float(), b: atom()),
118+
open_map(a: integer(), b: atom())
119+
) == open_map(a: union(float(), integer()), b: atom())
120+
121+
assert union(
122+
closed_map(a: float(), b: atom()),
123+
closed_map(a: integer(), b: atom())
124+
) == closed_map(a: union(float(), integer()), b: atom())
125+
126+
# Optimization two: we can tell that one map is a trivial subtype of the other:
127+
128+
assert union(
129+
closed_map(a: term(), b: term()),
130+
closed_map(a: float(), b: binary())
131+
) == closed_map(a: term(), b: term())
132+
133+
assert union(
134+
open_map(a: term()),
135+
closed_map(a: float(), b: binary())
136+
) == open_map(a: term())
137+
138+
assert union(
139+
closed_map(a: float(), b: binary()),
140+
open_map(a: term())
141+
) == open_map(a: term())
142+
143+
assert union(
144+
closed_map(a: term(), b: tuple([term(), term()])),
145+
closed_map(a: float(), b: tuple([atom(), binary()]))
146+
) == closed_map(a: term(), b: tuple([term(), term()]))
147+
end
108148
end
109149

110150
describe "intersection" do

0 commit comments

Comments
 (0)