diff --git a/lib/elixir/lib/enum.ex b/lib/elixir/lib/enum.ex index 66851044b6f..5b919bf45b5 100644 --- a/lib/elixir/lib/enum.ex +++ b/lib/elixir/lib/enum.ex @@ -3447,6 +3447,8 @@ defmodule Enum do Raises `ArithmeticError` if `enumerable` contains a non-numeric value. + If you need to apply a transformation first, consider using `Enum.sum_by/2` instead. + ## Examples iex> Enum.sum([1, 2, 3]) @@ -3473,6 +3475,40 @@ defmodule Enum do reduce(enumerable, 0, &+/2) end + @doc """ + Maps and sums the given enumerable in one pass. + + Raises `ArithmeticError` if `fun` returns a non-numeric value. + + ## Examples + + iex> Enum.sum_by([%{count: 1}, %{count: 2}, %{count: 3}], fn x -> x.count end) + 6 + + iex> Enum.sum_by(1..3, fn x -> x ** 2 end) + 14 + + iex> Enum.sum_by([], fn x -> x.count end) + 0 + + Filtering can be achieved by returning `0` to ignore elements: + + iex> Enum.sum_by([1, -2, 3], fn x -> if x > 0, do: x, else: 0 end) + 4 + + """ + @doc since: "1.18.0" + @spec sum_by(t, (element -> number)) :: number + def sum_by(enumerable, mapper) + + def sum_by(list, mapper) when is_list(list) and is_function(mapper, 1) do + sum_by_list(list, mapper, 0) + end + + def sum_by(enumerable, mapper) when is_function(mapper, 1) do + reduce(enumerable, 0, fn x, acc -> acc + mapper.(x) end) + end + @doc """ Returns the product of all elements. @@ -4770,6 +4806,11 @@ defmodule Enum do {:lists.reverse(acc), []} end + ## sum_by + + defp sum_by_list([], _, acc), do: acc + defp sum_by_list([h | t], mapper, acc), do: sum_by_list(t, mapper, acc + mapper.(h)) + ## take defp take_list(_list, 0), do: [] diff --git a/lib/elixir/pages/cheatsheets/enum-cheat.cheatmd b/lib/elixir/pages/cheatsheets/enum-cheat.cheatmd index 73255de6c11..d742abeb748 100644 --- a/lib/elixir/pages/cheatsheets/enum-cheat.cheatmd +++ b/lib/elixir/pages/cheatsheets/enum-cheat.cheatmd @@ -297,6 +297,15 @@ iex> cart |> Enum.map(& &1.count) |> Enum.sum() 10 ``` +Note: this should typically be done in one pass using `Enum.sum_by/2`. + +### [`sum_by(enum, mapper)`](`Enum.sum_by/2`) + +```elixir +iex> Enum.sum_by(cart, & &1.count) +10 +``` + ### [`product(enum)`](`Enum.product/1`) ```elixir diff --git a/lib/elixir/test/elixir/enum_test.exs b/lib/elixir/test/elixir/enum_test.exs index 42fb0f7d8ba..a2aba41ab92 100644 --- a/lib/elixir/test/elixir/enum_test.exs +++ b/lib/elixir/test/elixir/enum_test.exs @@ -1321,6 +1321,24 @@ defmodule EnumTest do end end + test "sum_by/2" do + assert Enum.sum_by([], &hd/1) == 0 + assert Enum.sum_by([[1]], &hd/1) == 1 + assert Enum.sum_by([[1], [2], [3]], &hd/1) == 6 + assert Enum.sum_by([[1.1], [2.2], [3.3]], &hd/1) == 6.6 + assert Enum.sum_by([[-3], [-2], [-1], [0], [1], [2], [3]], &hd/1) == 0 + + assert Enum.sum_by(1..3, &(&1 ** 2)) == 14 + + assert_raise ArithmeticError, fn -> + Enum.sum_by([[{}]], &hd/1) + end + + assert_raise ArithmeticError, fn -> + Enum.sum_by([[1], [{}]], &hd/1) + end + end + test "product/1" do assert Enum.product([]) == 1 assert Enum.product([1]) == 1