I am having trouble plotting a grid of images in Dash. In particular, I want to create a grid of n_rows x n_cols, where n_cols can be anywhere between 5 and 10 and n_rows is specified by the user. Every image is the same size.
So far, I have used px.imshow() with the facet_col option. However, with a large number of images, each image becomes very small.
I have also used plotly.subplots.make_subplots() directly. However, when I change n_rows within the app, the images do not stay in their respective positions and get moved around. In particular, when I first choose to plot 5 images beneath each other, the plot looks fine, with the five images plotted closely together. However, when the user subsequently chooses to plot 2 images, they are plotted in the same space that the five images occupied. Similarly, when the user first chooses to plot 2 images, the plot looks fine, but when the user subsequently plots 5 images, they all get messed up.
See the following code to reproduce:
from dash import Dash, dcc, html
from dash.dependencies import Input, Output
import plotly.graph_objects as go
import plotly.express as px
import dash.exceptions
import numpy as np
from plotly.subplots import make_subplots
def create_fig(n_rows):
    fig = make_subplots(n_rows, 1)
    for i in range(n_rows):
        # Single-column grid: each image goes in row i+1, column 1.
        fig.add_trace(go.Image(z=np.random.randint(0, 255, size=(60, 60, 3))), row=i+1, col=1)
    fig.update_layout(autosize=True,
                      width=100,
                      height=n_rows*100,
                      margin=dict(l=0, r=0, b=0, t=0)
                      )
    return fig
def create_fig_using_pximshow(n_rows):
    images = np.random.randint(0, 255, size=(n_rows, 60, 60, 3))
    fig = px.imshow(images, facet_col=0, facet_col_wrap=1, facet_row_spacing=0.3/n_rows)
    fig.update_layout(autosize=True,
                      width=100,
                      height=n_rows*100,
                      margin=dict(l=0, r=0, b=0, t=0)
                      )
    return fig
app = Dash(__name__)
app.layout = html.Div(
    [dcc.Graph(id='graph', style={'overflow':'scroll'}),
     dcc.Input(value=0, id='input', type='number')])
@app.callback(Output('graph', 'figure'), Input('input', 'value'))
def create_graph(n_rows):
    # Skip the update while the input is empty (None) or 0.
    if not n_rows:
        raise dash.exceptions.PreventUpdate("Prevent update")
    return create_fig(int(n_rows))
    # return create_fig_using_pximshow(int(n_rows))
if __name__ == '__main__':
    app.run(debug=True)  # app.run_server(debug=True) on older Dash versions
Note that changing go.Image to go.Scatter makes the code work: the dashboard graph then adapts to the number of rows present. Is there a better/easier way to plot a grid of images?
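One direction I have considered, since every image is the same size, is to avoid subplots altogether and tile all images into one large array that is shown as a single image trace. Below is a minimal sketch of that idea, not a tested fix for the behaviour above. The helper name tile_images and the choice of white padding for unused grid cells are my own assumptions and not part of plotly.

```python
import numpy as np


def tile_images(images, n_cols, fill_value=255):
    """Arrange equally sized images into one (n_rows*h, n_cols*w, c) array.

    tile_images is a hypothetical helper; it assumes `images` has shape
    (n_images, h, w, c) and that all images share the same size.
    """
    images = np.asarray(images)
    n_images, h, w, c = images.shape
    n_rows = -(-n_images // n_cols)  # ceiling division

    # Pad with blank tiles so that the last row of the grid is complete.
    padded = np.full((n_rows * n_cols, h, w, c), fill_value, dtype=images.dtype)
    padded[:n_images] = images

    # (n_rows, n_cols, h, w, c) -> (n_rows, h, n_cols, w, c) -> 2D grid of pixels.
    grid = padded.reshape(n_rows, n_cols, h, w, c).transpose(0, 2, 1, 3, 4)
    return grid.reshape(n_rows * h, n_cols * w, c)


images = np.random.randint(0, 256, size=(7, 60, 60, 3), dtype=np.uint8)
grid = tile_images(images, n_cols=5)  # 7 images in 5 columns -> 2 rows
```

The resulting array could then be passed to a single go.Image(z=grid) or px.imshow(grid), with the figure width and height derived from grid.shape. The trade-off is that per-image titles and axes are lost, because the whole grid becomes one trace.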
