TZWZ's personal page
Supercharging Ecto transactions for better integration with PubSub
06.12.2025
Ecto
Elixir
Phoenix
Phoenix LiveView

Ecto has tooling for transactions in the Ecto.Multi module. The usual way to send a PubSub message after a transaction is to match on the transaction result with a case statement, take the values stored by the transaction out of the returned map, and broadcast them once the transaction has succeeded. For example:

alias Ecto.Multi

Multi.new()
|> Multi.run(:field1, fn _, _ ->
  DifferentModule.do_update(some_attrs)
end)
|> Multi.insert(:field2, changeset(%__MODULE__{}, attrs))
|> Multi.update(:field3, SomeModule.update_changeset(%SomeModule{}, update_attrs))
|> Repo.transaction()
|> case do
  {:ok, values} ->
    DifferentModule.broadcast(values.field1)
    broadcast(values.field2)
    SomeModule.broadcast(values.field3)
    {:ok, values.field2}

  {:error, _, error, _} ->
    {:error, error}
end

This is fine as long as the transaction is simple. Otherwise we need to remember, in every transaction, what data got changed and needs to be broadcast, as the example above shows. This quickly becomes a problem when parts of the data are updated in different modules, which may not know whether they are running inside a transaction. They shouldn't even have to care. Nor should the calling module have to care whether a given change needs to be broadcast when that isn't its responsibility.

For example, when you write an exchange of a product (real or in some game) between users, you may have an Exchange module that, when the exchange happens, runs a transaction that takes money from user A and gives it to user B, then takes the product from user B and gives it to user A. The taking/giving part may, and probably will, be used in more places than just exchanges between users.

We need a better solution. One that doesn't require us to remember what to broadcast every time we do a transaction.

GenServer solution

The first way I solved this was with a GenServer that listens for messages telling it what to broadcast. The solution I ended up with looks like this:

defmodule MyApp.Repo.Transaction do
  require Logger
  alias MyApp.Repo
  use GenServer

  defstruct callbacks: [], return_key: nil

  @type transaction_calls() :: list({name :: atom(), transaction_cb_t()})
  @type return_key() :: atom() | nil
  @type transaction_result() :: {:ok, any()} | :ok | {:error, term()}
  @type transaction_cb_t() ::
          (map() ->
             {:ok, any(), (any() -> nil)} | {:ok, any()} | :ok | nil | {:error, term()})

  @spec new() :: %__MODULE__{}
  def new(), do: %__MODULE__{}

  @spec add(%__MODULE__{}, atom(), transaction_cb_t()) :: %__MODULE__{}
  def add(%__MODULE__{} = transaction, key, cb),
    do: put_in(transaction.callbacks, [{key, cb} | transaction.callbacks])

  @spec return_key(%__MODULE__{}, atom()) :: %__MODULE__{}
  def return_key(%__MODULE__{} = transaction, key), do: put_in(transaction.return_key, key)

  @doc """
  `:ok` is returned when `return_key` is `nil`,
  otherwise it's `{:ok, any()}` where `any()` contains whatever was under `return_key`
  """
  @spec run(%__MODULE__{}) :: transaction_result()
  def run(%__MODULE__{} = transaction) do
    if Repo.in_transaction?() do
      transaction.callbacks
      |> process_callbacks()
      |> handle_callbacks_result(transaction.return_key)
    else
      {:ok, pid} = GenServer.start_link(__MODULE__, nil)
      send(pid, {:start, transaction})
      # 24h timeout, just in case
      GenServer.call(pid, :get_result, 86_400_000)
    end
  end

  def init(nil) do
    {:ok, %{reply_to: nil, reply: nil, results: [], finished: false}}
  end

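  def handle_info({:start, transaction}, state) do
    result_msg =
      transaction.callbacks
      |> run_transaction()
      |> handle_callbacks_result(transaction.return_key)

    GenServer.cast(self(), {:transaction_end, result_msg})
    {:noreply, state}
  end

  # Sent by the 10 s timeout returned from `handle_cast({:transaction_end, _}, _)`
  # when nobody has asked for the result yet; without this clause the
  # `:timeout` message would crash the process with a FunctionClauseError.
  def handle_info(:timeout, state), do: {:stop, :normal, state}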

  defp run_transaction(callbacks) do
    Repo.transaction(fn ->
      case process_callbacks(callbacks) do
        {:ok, values, results} -> {values, results}
        {:error, reason} -> Repo.rollback(reason)
      end
    end)
    |> case do
      {:ok, {values, results}} -> {:ok, values, results}
      err -> err
    end
  end

  def process_callbacks(callbacks) do
    callbacks
    |> Enum.reverse()
    |> Enum.map_reduce(%{}, fn
      _, {:error, reason} ->
        {nil, {:error, reason}}

      {name, cb}, values ->
        try do
          case cb.(values) do
            {:error, reason} ->
              {nil, {:error, reason}}

            val when val in [:ok, nil] ->
              values = update_values(values, name, nil)
              {nil, values}

            {:ok, data} ->
              values = update_values(values, name, data)
              {nil, values}

            {:ok, data, cb} ->
              values = update_values(values, name, data)
              {{data, cb}, values}
          end
        rescue
          err ->
            Logger.error("Transaction exception for field #{inspect(name)}")
            reraise(err, __STACKTRACE__)
        end
    end)
    |> case do
      {_, {:error, reason}} -> {:error, reason}
      {results, values} -> {:ok, values, Enum.reject(results, &is_nil/1)}
    end
  end

  defp update_values(values, name, result) do
    Map.put(values, name, result)
  end

  defp handle_callbacks_result(result, return_key) do
    case result do
      {:ok, values, results} ->
        send_results(results)
        return_value(values, return_key)

      {:exception, reason} ->
        Logger.error("Transaction raised because: #{inspect(reason, pretty: true)}")
        {:error, "Unexpected error"}

      {:error, reason} ->
        {:error, reason}
    end
  end

  defp send_results(results) do
    GenServer.cast(self(), {:results, results})
  end

  defp return_value(_values, nil), do: :ok
  defp return_value(values, key), do: {:ok, values[key]}

  def handle_cast({:results, results}, state) do
    # basically overensuring, because we have sort of guarantee that
    # we get start_link -> :start -> :get_result -> <results>
    #   -> :transaction_end -> {:stop, _, _}
    # only possible uncertainty is between :get_result and <results>
    # but in both cases :transaction_end will be later anyway
    # because we get all <results> while doing :start,
    # which only after finishing processing sends :transaction_end
    case state.finished do
      true ->
        # `results` is a single list here, while `call_results/1` expects a list of lists
        call_results([results])
        {:stop, :normal, state}

      :failed ->
        {:stop, :normal, state}

      false ->
        {:noreply, put_in(state.results, [results | state.results])}
    end
  end

  def handle_cast({:transaction_end, reply}, state) do
    finish =
      if not is_tuple(reply) or elem(reply, 0) == :ok do
        call_results(state.results)
        true
      else
        :failed
      end

    state = put_in(state.finished, finish)

    case state.reply_to do
      nil ->
        {:noreply, put_in(state.reply, reply), 10_000}

      from ->
        GenServer.reply(from, reply)
        {:stop, :normal, state}
    end
  end

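  # `results` is a list of result lists, newest first; call them oldest first,
  # keeping the callback order within each list
  defp call_results(results) do
    for {r_val, cb} <- results |> Enum.reverse() |> Enum.concat() do
      cb.(r_val)
    end
  end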

  def handle_call(:get_result, from, state) when is_nil(state.reply) do
    {:noreply, put_in(state.reply_to, from)}
  end

  def handle_call(:get_result, from, state) do
    GenServer.reply(from, state.reply)
    {:stop, :normal, state}
  end
end

Ecto gives us Repo.in_transaction?(). With it, we know whether we're in a nested transaction, in which case there's no need to start a new GenServer or a new transaction: the callbacks simply run in the one that is already open. Otherwise we start a GenServer and send it the callbacks we added, to be run in a new transaction.
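To make the branching concrete, here is a minimal sketch that runs without a database. FakeRepo is a hypothetical stand-in that only mimics the names in_transaction?/0 and transaction/1 and tracks nesting in the process dictionary (Ecto's own transactions are also tied to the calling process); Runner is likewise illustrative. Only the outermost call opens a transaction, and the nested call just runs inside it.

```elixir
# Hypothetical stand-in for an Ecto repo: it only mimics `in_transaction?/0`
# and `transaction/1`, tracking nesting depth in the process dictionary.
defmodule FakeRepo do
  def in_transaction?, do: Process.get(:fake_repo_depth, 0) > 0

  def transaction(fun) do
    depth = Process.get(:fake_repo_depth, 0)
    Process.put(:fake_repo_depth, depth + 1)

    try do
      {:ok, fun.()}
    after
      # restore the previous depth even if `fun` raises
      Process.put(:fake_repo_depth, depth)
    end
  end
end

# Illustrative runner: open a transaction only when none is open yet.
defmodule Runner do
  def run(fun) do
    if FakeRepo.in_transaction?() do
      {:nested, fun.()}
    else
      FakeRepo.transaction(fn -> {:top_level, fun.()} end)
    end
  end
end

result = Runner.run(fn -> Runner.run(fn -> :inner end) end)
# result is {:ok, {:top_level, {:nested, :inner}}}
```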

Using these transactions

Using the API looks like this:

Repo.Transaction.new()
|> Repo.Transaction.add(:field1, fn _ ->
  DifferentModule.do_update(some_attrs)
end)
|> Repo.Transaction.add(:field2, fn _ ->
  {:ok, :field2_value}
end)
|> Repo.Transaction.add(:field3, fn values ->
  {:ok, {:field3_value, values.field2}, &broadcast/1}
end)
|> Repo.Transaction.run()

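This works much like Multi.run, except that the callback receives only the map of values collected so far (Multi.run's callback also gets the repo), and it may return a three-element tuple: :ok, the return value, and a callback to be called with that value after the transaction commits. It also accepts what Multi.run accepts, {:ok, value} and {:error, reason}, plus :ok and nil. If the transaction should return the value of one of the added callbacks, pipe it through Repo.Transaction.return_key/2 before running it, e.g. |> Repo.Transaction.return_key(:field2).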
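The accepted return shapes can be summarized in a small, runnable sketch. ReturnShapes.normalize/1 is a hypothetical helper, not part of the module above; it only shows how each shape maps to a value to store, an optional after-commit callback, or a halt.

```elixir
defmodule ReturnShapes do
  # Hypothetical helper: classifies what an `add/3` callback returned.
  #   {:store, value, after_commit_fun_or_nil} -> keep going
  #   {:halt, reason}                          -> stop and roll back
  def normalize({:error, reason}), do: {:halt, reason}
  def normalize(val) when val in [:ok, nil], do: {:store, nil, nil}
  def normalize({:ok, value}), do: {:store, value, nil}
  def normalize({:ok, value, cb}) when is_function(cb, 1), do: {:store, value, cb}
end

broadcast = fn value -> {:broadcasted, value} end

{:store, {:field3_value, :field2_value}, cb} =
  ReturnShapes.normalize({:ok, {:field3_value, :field2_value}, broadcast})

# the callback is only called later, with the stored value
cb.({:field3_value, :field2_value})
```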

How it works

When a function needs a transaction and isn't in one yet, it starts a new GenServer and sends it everything added through the public API. It then uses GenServer.call to wait for the transaction to finish and return its value (if any). The GenServer runs all callbacks in order (process_callbacks), stores the returned values in a map under their keys, and sends the broadcast callbacks to itself so that they are called only after the transaction finishes successfully. A callback can return one of several values: :ok, nil, {:ok, value}, {:ok, value, callback} or {:error, reason}. The :ok and nil values are for cases where an operation produces no value and we only need its side effects. They also let us write code like this:

Repo.Transaction.new()
|> Repo.Transaction.add(:field1, fn _ ->
  SomeModule.do_update(some_attrs)
end)
|> Repo.Transaction.add(:field2, fn _ ->
  OtherModule.do_update(other_attrs)
end)
|> Repo.Transaction.add(:check, fn values ->
  if values.field1.a_value + values.field2.some_value < 10 do
    {:error, "Values too small"}
  end
end)
|> Repo.Transaction.run()

We only need to handle the error case; when everything is fine, the if expression returns nil, which counts as success. Returning {:error, reason} from a callback stops processing of the remaining callbacks and makes the transaction return that error.
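The in-order processing and the early stop on {:error, reason} can be sketched in plain Elixir with Enum.reduce_while/3. This is a simplified, database-free illustration of the idea, not the post's process_callbacks (which uses Enum.map_reduce/3); CallbackRunner is a hypothetical name.

```elixir
defmodule CallbackRunner do
  # Runs `{key, cb}` pairs in order; each `cb` receives the values stored so far.
  # Returns {:ok, values, after_commit} or the first {:error, reason}.
  def run(callbacks) do
    callbacks
    |> Enum.reduce_while({:ok, %{}, []}, fn {key, cb}, {:ok, values, after_commit} ->
      case cb.(values) do
        {:error, reason} ->
          # stop here: later callbacks never run
          {:halt, {:error, reason}}

        val when val in [:ok, nil] ->
          {:cont, {:ok, Map.put(values, key, nil), after_commit}}

        {:ok, data} ->
          {:cont, {:ok, Map.put(values, key, data), after_commit}}

        {:ok, data, notify} ->
          {:cont, {:ok, Map.put(values, key, data), [{data, notify} | after_commit]}}
      end
    end)
    |> case do
      {:ok, values, after_commit} -> {:ok, values, Enum.reverse(after_commit)}
      {:error, _reason} = error -> error
    end
  end
end

# same shape as the `:check` callback above: returns nil when all is fine
check = fn values ->
  if values.field1 + values.field2 < 10 do
    {:error, "Values too small"}
  end
end

too_small =
  CallbackRunner.run([field1: fn _ -> {:ok, 2} end, field2: fn _ -> {:ok, 3} end, check: check])

big_enough =
  CallbackRunner.run([field1: fn _ -> {:ok, 7} end, field2: fn _ -> {:ok, 5} end, check: check])
```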

Each nested transaction, however deeply nested, does essentially the same thing as the GenServer that is already running: it runs its callbacks in order and casts the callbacks to be called after successful completion to self(), which is the GenServer process, since nested transactions run inside it.

After the main transaction finishes, the GenServer calls all after-commit callbacks (notifications) and returns a value. Since messages from one Erlang process to another (here from the GenServer to itself) are received in the order they were sent, all results should arrive before the :transaction_end message. Even if that weren't the case, the case statement in the handle_cast clause that receives the broadcast callbacks handles results arriving after :transaction_end.
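The ordering argument relies on a guarantee the BEAM does give: messages sent from one process to another (including to itself) are received in the order they were sent. The snippet below shows that with plain send/2 and receive; MailboxOrder is a hypothetical name, and the message shapes only imitate the casts used by the GenServer above.

```elixir
defmodule MailboxOrder do
  # Collects {:results, _} messages until {:transaction_end, _} arrives,
  # mirroring how the GenServer accumulates results before finishing.
  def collect(acc \\ []) do
    receive do
      {:results, results} -> collect([results | acc])
      {:transaction_end, reply} -> {reply, Enum.reverse(acc)}
    after
      1_000 -> :timeout
    end
  end
end

# a nested transaction's results, then the top-level results, then the end marker
send(self(), {:results, [:nested]})
send(self(), {:results, [:top_level]})
send(self(), {:transaction_end, :ok})

ordered = MailboxOrder.collect()
```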

Each nested transaction keeps the return values of its callbacks separately, because each run builds its own map of values; the only thing connecting any two transactions is that they run in the same process.

Version without GenServer

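The version above is fairly simple and fast, but it needs to start a new process and send data between processes. Things can fail along the way: the GenServer is started with start_link and GenServer.call monitors it, so an exception in a callback takes down the caller along with the GenServer, and if the GenServer ever gets stuck instead, the caller waits until GenServer.call reaches its 24-hour timeout. We also waste memory, because data has to be copied to be sent between processes.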

To avoid these problems, I've made a second version, one that doesn't require a GenServer. With this, we improve performance and reduce risks in one move.

The solution I ended up with is below. It mostly takes parts of the original GenServer solution and adapts them so they don't need a new process.

defmodule MyApp.Repo.Transaction do
  @moduledoc """
  A `notification` returned from an action (`add`) can have `fn` and `reload` keys.
  `fn` gets called only after the top-level transaction finishes.
  This way there won't be, e.g., a PubSub broadcast before the entire real
  transaction gets committed into the DB.

  The `reload` function gets called immediately after your callback returns.
  This way all later `add` calls can assume this value to be reloaded.

  Both of these functions get the returned value (`action_cb_value`) as their only argument.
  A plain one-arity function is also accepted and is treated like `fn`.
  """
  require Logger
  alias MyApp.Repo

  defstruct callbacks: []

  @type action_key() :: atom()
  @type action_cb_value() :: any()
  @type final_result() :: {:ok, any()} | {:ok, any(), list(term())} | {:error, term()}
  @type collected_values() :: %{action_key() => any()}
  @type notification() ::
          [
            fn: (action_cb_value() -> any()),
            reload: (action_cb_value() -> any())
          ]
          | (action_cb_value() -> any())
  @type action_cb() ::
          (collected_values() ->
             {:ok, action_cb_value(), notification()}
             | {:ok, action_cb_value()}
             | :ok
             | nil
             | {:error, term()})

  @type t() :: %__MODULE__{callbacks: list({action_key(), action_cb()})}

  @spec new() :: t()
  def new(), do: %__MODULE__{}

  @spec add(t(), action_key(), action_cb()) :: t()
  def add(%__MODULE__{} = transaction, key, cb),
    do: put_in(transaction.callbacks, [{key, cb} | transaction.callbacks])

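  @doc """
  Runs all added callbacks in order, opening a transaction only if none is open yet.

  - If run as a nested transaction, it returns `{:ok, value, notifications}`,
    where `value` is `nil` if there's no return value. Return this tuple unchanged
    from the enclosing callback so the notifications reach the top level.
  - If run as the top-level transaction, it sends all notifications after the commit
    and returns `{:ok, value}`, where `value` is `nil` if no return value is needed.
  - Either way, it returns `{:error, reason}` as soon as a callback fails.
  """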
  @spec run(t(), return_key :: action_key() | nil) :: final_result()
  def run(%__MODULE__{} = transaction, return_key \\ nil) do
    if Repo.in_transaction?() do
      transaction.callbacks
      |> process_callbacks()
      |> transaction_result(return_key)
    else
      transaction.callbacks
      |> run_transaction()
      |> transaction_result(return_key)
      |> final_result()
    end
  end

  defp run_transaction(callbacks) do
    Repo.transaction(fn ->
      case process_callbacks(callbacks) do
        {:ok, values, results} -> {values, results}
        {:error, reason} -> Repo.rollback(reason)
      end
    end)
    |> case do
      {:ok, {values, results}} -> {:ok, values, results}
      err -> err
    end
  end

  def process_callbacks(callbacks) do
    callbacks
    |> Enum.reverse()
    |> Enum.map_reduce(%{}, fn
      _, {:error, reason} ->
        {nil, {:error, reason}}

      {key, cb}, values ->
        try do
          cb.(values)
        rescue
          err ->
            Logger.error("Transaction exception for field #{inspect(key)}")
            reraise(err, __STACKTRACE__)
        end
        |> process_action_result(key, values)
    end)
    |> case do
      {_, {:error, reason}} -> {:error, reason}
      {notifs, values} -> {:ok, values, notifs}
    end
  end

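  defp process_action_result(result, key, values) do
    case result do
      {:error, reason} ->
        {nil, {:error, reason}}

      val when val in [:ok, nil] ->
        values = update_values(values, key, nil)
        {nil, values}

      {:ok, data} ->
        values = update_values(values, key, data)
        {nil, values}

      {:ok, data, notifs} when is_list(notifs) and notifs != [] ->
        if Keyword.keyword?(notifs) do
          # keyword notification: `reload` runs right away, so later callbacks
          # see the reloaded value; `fn` waits for the top-level commit
          stored =
            case Keyword.get(notifs, :reload) do
              nil -> data
              reload -> reload.(data)
            end

          {{:result, data, notifs}, update_values(values, key, stored)}
        else
          # a nested `run/2` returned its own pending notifications; keep the
          # list itself so `send_notifs/2` walks it after the top-level commit
          {notifs, update_values(values, key, data)}
        end

      {:ok, data, notif} ->
        values = update_values(values, key, data)
        {{:result, data, notif}, values}
    end
  end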

  defp update_values(values, name, result) do
    Map.put(values, name, result)
  end

  defp transaction_result(result, return_key) do
    case result do
      {:ok, values, results} -> {:ok, values[return_key], results}
      {:error, _} = err -> err
    end
  end

  defp final_result({:error, _} = err), do: err

  defp final_result({:ok, result, notifs}) do
    send_notifs(notifs)
    {:ok, result}
  end

  defp send_notifs(notifs, backlog \\ [])

  defp send_notifs([{:result, data, notif} | notifs], backlog) do
    call_notification(data, notif)
    send_notifs(notifs, backlog)
  end

  defp send_notifs([nil | notifs], backlog) do
    send_notifs(notifs, backlog)
  end

  defp send_notifs([notif | notifs], backlog) when is_list(notif) do
    send_notifs(notif, [notifs | backlog])
  end

  defp send_notifs([], [notifs | backlog]) do
    send_notifs(notifs, backlog)
  end

  defp send_notifs([], []) do
    :ok
  end

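  defp call_notification(data, notify_fn) when is_function(notify_fn) do
    notify_fn.(data)
  end

  # keyword-list notification (see @moduledoc): only `fn` runs here, after the commit
  defp call_notification(data, notif) when is_list(notif) do
    if Keyword.keyword?(notif), do: call_notification(data, Keyword.get(notif, :fn))
  end

  defp call_notification(_data, _notify_fn) do
    nil
  end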
end

The outer API didn't change much. There's no return_key function anymore; instead, you pass the key whose value you want returned as the second argument of run.

This time we can't send notifications to our own process, as there's no GenServer to receive them. Instead, every run returns {:ok, return_value, notifications} when it's a nested transaction, or sends the notifications and returns {:ok, return_value} when it's the outermost one. Both can return {:error, reason}: an error means the whole transaction is aborted anyway, so there are no after-commit callbacks to worry about.
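The return contract can be exercised without a database in a small sketch. FakeRepo and DeferredTx are hypothetical names; FakeRepo only mimics in_transaction?/0 and transaction/1 via the process dictionary, notifications here are zero-arity functions rather than the module's one-arity ones, and the sketch flattens them with List.flatten/1 for brevity instead of reproducing send_notifs. Each notification records whether it fired inside a transaction.

```elixir
# Hypothetical stand-in for an Ecto repo (mimics only the two functions used),
# tracking transaction nesting in the process dictionary.
defmodule FakeRepo do
  def in_transaction?, do: Process.get(:fake_repo_depth, 0) > 0

  def transaction(fun) do
    depth = Process.get(:fake_repo_depth, 0)
    Process.put(:fake_repo_depth, depth + 1)

    try do
      {:ok, fun.()}
    after
      Process.put(:fake_repo_depth, depth)
    end
  end
end

# Illustrative sketch of the return contract only: a nested run hands its
# pending notifications back to the caller; only the top-level run fires them,
# after the (fake) commit.
defmodule DeferredTx do
  def run(work, notify) do
    if FakeRepo.in_transaction?() do
      with_notify(work, notify)
    else
      {:ok, {:ok, value, notifs}} = FakeRepo.transaction(fn -> with_notify(work, notify) end)
      notifs |> List.flatten() |> Enum.each(fn notify_fun -> notify_fun.() end)
      {:ok, value}
    end
  end

  # `work` returns {:ok, value}, or a nested run's {:ok, value, notifs} unchanged.
  defp with_notify(work, notify) do
    case work.() do
      {:ok, value, nested} -> {:ok, value, [notify | nested]}
      {:ok, value} -> {:ok, value, [notify]}
    end
  end
end

# each notification reports whether it fired inside a transaction
log = fn label -> fn -> send(self(), {label, FakeRepo.in_transaction?()}) end end

result =
  DeferredTx.run(
    fn -> DeferredTx.run(fn -> {:ok, :exchanged} end, log.(:inner)) end,
    log.(:outer)
  )
```

Because the nested run's notifications travel back as a return value, the enclosing callback has to return that tuple unchanged; discarding it would also discard the notifications.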

Another change in how notifications are passed around is that a nested transaction's notifications are stored inside the outer transaction's list of notifications. This leads to structures like [[[notifs_2], notifs_1], notifs], and Enum.flat_map removes only one level of nesting, so the handling for this had to be written from scratch in the send_notifs functions.

In short, send_notifs calls the notification function with its data when the current element is a {:result, data, notif} entry, and skips nil entries. When the element is a list, it pushes the tail of the current list onto a backlog to be processed later and continues with the nested list. When the current list runs out, it takes the next list from the backlog. It repeats this until both lists are empty, at which point all notifications have been sent.
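To try the traversal in isolation, here is a standalone, hypothetical re-creation of the send_notifs/2 clauses (NotifWalk is not part of the module above), fed a hand-built nested list. Each notification sends its data back to the current process, so the firing order can be checked: entries before the nested list, then the nested ones in place, then whatever follows the nested list.

```elixir
defmodule NotifWalk do
  # Standalone re-creation of the `send_notifs/2` traversal (hypothetical module):
  #   {:result, data, fun} -> call fun.(data)
  #   nil                  -> skip
  #   nested list          -> walk it, pushing the rest of the current list onto the backlog
  def walk(notifs, backlog \\ [])

  def walk([{:result, data, fun} | rest], backlog) do
    fun.(data)
    walk(rest, backlog)
  end

  def walk([nil | rest], backlog), do: walk(rest, backlog)

  def walk([nested | rest], backlog) when is_list(nested), do: walk(nested, [rest | backlog])

  # current list exhausted: resume the most recently postponed one
  def walk([], [next | backlog]), do: walk(next, backlog)

  def walk([], []), do: :ok
end

# each notification reports its data back to this process
record = fn data -> send(self(), {:notified, data}) end

# shape produced by nested transactions: a nested list inside the outer list
notifs = [
  {:result, :outer, record},
  [{:result, :inner, record}, nil, [{:result, :innermost, record}]],
  {:result, :last, record}
]

:ok = NotifWalk.walk(notifs)
```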

Conclusion

In this post we learned a way to make transactions more useful and featureful than they normally are. This way we can make our lives easier without much hassle, as the module is pretty short.

You could also use something like the Ash framework. It has far more functionality, but also a learning curve. Whether you use this solution or the framework is up to you and your needs.
