Based on the comments, I found a working way to collect the plots generated in a loop, without having to access the plotting function, and save them as an animation.
The original loop I was using was the following:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)  # opens a new figure on every call
```
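Everything below relies on heatmap opening a brand-new figure each time it is called. The original heatmap isn't shown, so the runnable sketches here use a hypothetical stand-in built on plt.subplots and imshow (an assumption; substitute your own function) together with the headless Agg backend. A quick check that the stand-in behaves the way the loop needs:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    """Hypothetical stand-in for the original heatmap(): opens a new figure per call."""
    fig, ax = plt.subplots()
    ax.imshow(arr, cmap="viridis")
    return ax


plt.close("all")  # start from a clean slate
first = heatmap(np.random.rand(10, 10))
second = heatmap(np.random.rand(10, 10))
print(len(plt.get_fignums()))  # 2: one figure per call
```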
I'll first give the solution, and then a step-by-step explanation of the logic.
Final Solution
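```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# heatmap() is the plotting function from the original loop; each call opens a new figure.
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)

    plots.append([ax])
ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")  # .mp4 output uses matplotlib's default ffmpeg writer, so ffmpeg must be installed
```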
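To try the approach without the original heatmap, here is a self-contained sketch under a few assumptions: a hypothetical heatmap stand-in that draws with imshow on a new figure, the headless Agg backend, 10 frames instead of 100 to keep it fast, and a GIF written with matplotlib's "pillow" writer so ffmpeg isn't needed. It omits ax.set(animated=True); otherwise the loop is the one above.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation


def heatmap(arr):
    """Hypothetical stand-in: draws arr on a brand-new figure."""
    fig, ax = plt.subplots()
    ax.imshow(arr, cmap="viridis")


plt.close("all")
plots = []
for i in range(10):  # fewer frames than the original 100
    heatmap(np.random.rand(10, 10))
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()  # the first figure becomes the canvas
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.remove()           # detach the Axes from its own figure
        ax.figure = fig       # point it at the canvas figure
        fig.add_axes(ax)      # register it with the canvas
        plt.close(dummy_fig)  # the temporary figure is no longer needed
    plots.append([ax])        # one frame = a list of artists to show

ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
out = os.path.join(tempfile.mkdtemp(), "video.gif")
ani.save(out, writer="pillow")  # Pillow-based GIF writer, no ffmpeg required
```

After the loop only the canvas figure is open, and it holds all ten Axes; ArtistAnimation shows one of them per frame.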
Step-by-step explanation
To save the plots and animate them later, I had to make the following modifications:
- get handles to the figure and axes that each heatmap call creates:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    fig, ax = plt.gcf(), plt.gca()  # add this
```
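A quick check of what this first step gives you, again with a hypothetical heatmap stand-in that opens a new figure per call: plt.gcf() and plt.gca() return the figure and Axes the latest call drew on, and since nothing is closed yet, one figure piles up per iteration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    """Hypothetical stand-in for the original heatmap(): new figure per call."""
    fig, ax = plt.subplots()
    ax.imshow(arr)


plt.close("all")
handles = []
for i in range(3):
    heatmap(np.random.rand(10, 10))
    fig, ax = plt.gcf(), plt.gca()  # the figure/Axes the call just drew on
    handles.append((fig, ax))

# Nothing has been closed, so every iteration left a figure open.
print(len(plt.get_fignums()))  # 3
```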
- use the very first figure as the drawing canvas for all later axes:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:  # fig is the one we'll use for our animation canvas.
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()  # we will ignore dummy_fig
        plt.close(dummy_fig)
```
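At this intermediate stage only the canvas stays open, and the later heatmaps are thrown away together with their figures; that loss is why the axes have to be moved before closing. A sketch with the same hypothetical heatmap stand-in:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    """Hypothetical stand-in for the original heatmap(): new figure per call."""
    fig, ax = plt.subplots()
    ax.imshow(arr)


plt.close("all")
for i in range(5):
    heatmap(np.random.rand(10, 10))
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()  # keep the first figure as the canvas
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        plt.close(dummy_fig)  # discards this figure and the Axes on it

print(plt.get_fignums() == [fig.number])  # True: only the canvas is left
```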
- before closing the other figures, move their axes to our main canvas:
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.remove()  # remove ax from dummy_fig
        ax.figure = fig  # now assign it to our canvas fig
        fig.add_axes(ax)  # register it with fig so it shows up in fig.axes
        plt.close(dummy_fig)
```
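The three lines that move an Axes can be wrapped in a small helper; move_axes is a hypothetical name, not a matplotlib function. As far as I know, matplotlib does not officially support moving an Axes between figures, so this relies on the same trick as the loop and may behave differently across matplotlib versions.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def move_axes(ax, target_fig):
    """Hypothetical helper: move ax from its figure onto target_fig, then close the source."""
    source_fig = ax.get_figure()
    ax.remove()              # detach from the source figure
    ax.figure = target_fig   # re-point the Axes at the target figure
    target_fig.add_axes(ax)  # register it in target_fig.axes
    plt.close(source_fig)    # the source figure is now empty
    return ax


plt.close("all")
canvas, first_ax = plt.subplots()
first_ax.imshow(np.random.rand(10, 10))
other_fig, other_ax = plt.subplots()
other_ax.imshow(np.random.rand(10, 10))

move_axes(other_ax, canvas)
print(len(canvas.axes), len(plt.get_fignums()))  # 2 1
```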
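- set the axes to be animated (this doesn't seem to be strictly necessary; as far as I can tell, ArtistAnimation resets every artist's animated flag to match its blit argument when it initializes, so with the default blit=False the call has no effect):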
```python
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)  # from the matplotlib animation example, but doesn't seem needed
        # we could however add info to each plot here, e.g.
        # ax.set(xlabel=f"image {i}")  # this could be done in the i == 0 branch too
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
```
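The idea of adding per-frame info can be sketched like this, again with a hypothetical heatmap stand-in. Setting the label before the if/else covers frame 0 and the moved frames alike, and the label travels with the Axes when it moves to the canvas.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def heatmap(arr):
    """Hypothetical stand-in for the original heatmap(): new figure per call."""
    fig, ax = plt.subplots()
    ax.imshow(arr)


plt.close("all")
axes = []
for i in range(4):
    heatmap(np.random.rand(10, 10))
    ax = plt.gca()
    ax.set(xlabel=f"image {i}")  # label every frame, including the first
    if i == 0:
        fig = plt.gcf()
    else:
        dummy_fig = plt.gcf()
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)
    axes.append(ax)

print([a.get_xlabel() for a in axes])  # ['image 0', 'image 1', 'image 2', 'image 3']
```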
- Finally, collect all of these axes in a list, one single-Axes frame per iteration, and pass the list to ArtistAnimation:
```python
plots = []
for i in range(100):
    arr = np.random.rand(10, 10)
    heatmap(arr)
    if i == 0:
        fig, ax = plt.gcf(), plt.gca()
    else:
        dummy_fig, ax = plt.gcf(), plt.gca()
        ax.set(animated=True)
        ax.remove()
        ax.figure = fig
        fig.add_axes(ax)
        plt.close(dummy_fig)

    plots.append([ax])
ani = animation.ArtistAnimation(fig, plots, interval=50, repeat_delay=200)
ani.save("video.mp4")
```
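On the ArtistAnimation arguments: interval is the delay between frames in milliseconds, and when save() is called without an fps argument, matplotlib derives the frame rate as 1000 / interval. With 100 frames at interval=50, the saved clip should come out at 20 fps and 5 seconds long; repeat_delay is the pause between repeats of the looping animation and doesn't enter this calculation. A tiny helper for the arithmetic (video_timing is a hypothetical name for illustration):

```python
def video_timing(n_frames, interval_ms):
    """Frame rate and length of a saved animation when save() gets no fps.

    Assumes matplotlib's fallback of fps = 1000 / interval.
    """
    fps = 1000 / interval_ms
    duration_s = n_frames * interval_ms / 1000
    return fps, duration_s


fps, duration_s = video_timing(n_frames=100, interval_ms=50)
print(fps, duration_s)  # 20.0 5.0
```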