Skip to content

Commit 06e10be

Browse files
authored
Merge pull request #197 from lucasimi/develop
Improved app performance with smart caching
2 parents 8cd58f3 + 8cff2bc commit 06e10be

File tree

1 file changed

+100
-25
lines changed

1 file changed

+100
-25
lines changed

app/streamlit_app.py

+100-25
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939

4040
SAMPLE_FRAC = 0.1
4141

42-
V_DATA_SUMMARY_FEAT = 'feature'
42+
V_DATA_SUMMARY_FEAT = 'Feat'
4343

44-
V_DATA_SUMMARY_HIST = 'histogram'
44+
V_DATA_SUMMARY_HIST = 'Hist'
4545

4646
V_DATA_SUMMARY_BINS = 15
4747

@@ -136,6 +136,27 @@ def fix_data(data):
136136
return df
137137

138138

139+
def _get_dim(fig):
    """Return the plot dimensionality (2 or 3) inferred from a figure.

    Every trace in ``fig.data`` is inspected; the figure counts as
    three-dimensional as soon as one trace type contains the substring
    ``'3d'``. A figure with no traces is reported as two-dimensional.
    """
    # Evaluate the membership test for every trace, as a plain loop would.
    is_3d_flags = ['3d' in trace.type for trace in fig.data]
    return 3 if any(is_3d_flags) else 2
145+
146+
147+
def _get_graph_no_attribs(graph):
    """Build a fresh ``nx.Graph`` from the nodes and edges of ``graph``.

    Only what ``graph.nodes()`` and ``graph.edges()`` yield is copied into
    the new graph; the input graph is not modified.
    """
    bare_graph = nx.Graph()
    # Nodes first, so isolated nodes are kept and node order is preserved.
    bare_graph.add_nodes_from(graph.nodes())
    bare_graph.add_edges_from(graph.edges())
    return bare_graph
152+
153+
154+
def _encode_graph(graph):
    """Encode a graph as a canonical ``(nodes, edges)`` pair of tuples.

    ``nodes`` is the sorted tuple of node labels converted with ``int``.
    ``edges`` is the sorted tuple of edges, where the endpoints of each
    edge are themselves sorted, so the result does not depend on the
    iteration order of ``graph.nodes()`` / ``graph.edges()`` or on the
    orientation in which an edge is reported.
    """
    node_ids = [int(node) for node in graph.nodes()]
    node_ids.sort()
    # Normalise each edge's endpoint order before ordering the edges.
    edge_pairs = [tuple(sorted(edge)) for edge in graph.edges()]
    edge_pairs.sort()
    return (tuple(node_ids), tuple(edge_pairs))
158+
159+
139160
def _get_data_summary(df_X, df_y):
140161
df = pd.concat([get_sample(df_y), get_sample(df_X)], axis=1)
141162
df_hist = pd.DataFrame({
@@ -417,6 +438,15 @@ def mapper_clustering_input_section():
417438
return clustering
418439

419440

441+
@st.cache_data(
    hash_funcs={'tdamapper.learn.MapperAlgorithm': MapperAlgorithm.__repr__},
    show_spinner='Computing Mapper',
)
def compute_mapper(mapper, X, y):
    """Run the Mapper algorithm and return the resulting graph.

    The call is wrapped in ``st.cache_data``. The ``mapper`` argument is
    registered in ``hash_funcs`` with ``MapperAlgorithm.__repr__`` as its
    hash function, so the cache entry for that argument is derived from
    its ``repr``.

    Parameters
    ----------
    mapper
        Object exposing ``fit_transform(X, y)``.
    X
        First positional argument forwarded to ``mapper.fit_transform``.
    y
        Second positional argument forwarded to ``mapper.fit_transform``.

    Returns
    -------
    Whatever ``mapper.fit_transform(X, y)`` returns (the Mapper graph).
    """
    return mapper.fit_transform(X, y)
448+
449+
420450
def mapper_input_section(X):
421451
lens = mapper_lens_input_section(X)
422452
st.divider()
@@ -429,7 +459,7 @@ def mapper_input_section(X):
429459
verbose=True,
430460
n_jobs=1,
431461
)
432-
mapper_graph = mapper_algo.fit_transform(X, lens)
462+
mapper_graph = compute_mapper(mapper_algo, X, lens)
433463
return mapper_graph
434464

435465

@@ -528,14 +558,28 @@ def plot_dim_input_section():
528558
return dim
529559

530560

531-
def plot_input_section(df_X, df_y, mapper_graph):
532-
st.header('🎨 Plot')
561+
@st.cache_data(
    hash_funcs={
        'networkx.classes.graph.Graph': lambda g: _encode_graph(_get_graph_no_attribs(g)),
    },
    show_spinner='Generating Mapper Embedding',
)
def compute_mapper_plot(mapper_graph, dim, seed):
    """Create the ``MapperPlot`` for a Mapper graph.

    The call is wrapped in ``st.cache_data``. Graph arguments are hashed
    through ``_encode_graph(_get_graph_no_attribs(g))``, i.e. from the
    encoded nodes and edges of the graph.

    ``dim`` is passed positionally and ``seed`` by keyword to
    ``MapperPlot``.
    """
    return MapperPlot(mapper_graph, dim, seed=seed)
568+
569+
570+
def mapper_plot_section(mapper_graph):
    """Render the layout controls and return the Mapper plot.

    Shows the section header, reads the dimension and seed from their
    input sections (in that order), and delegates to
    ``compute_mapper_plot`` to build the plot for ``mapper_graph``.
    """
    st.header('🗺️ Layout')
    layout_dim = plot_dim_input_section()
    layout_seed = plot_seed_input_section()
    return compute_mapper_plot(mapper_graph, layout_dim, layout_seed)
576+
577+
578+
@st.cache_data(
579+
hash_funcs={'tdamapper.plot.MapperPlot': lambda mp: mp.positions},
580+
show_spinner='Rendering Mapper',
581+
)
582+
def compute_mapper_fig(mapper_plot, colors, cmap, agg, agg_name, colors_feat):
539583
mapper_fig = mapper_plot.plot_plotly(
540584
colors,
541585
agg=agg,
@@ -544,8 +588,25 @@ def plot_input_section(df_X, df_y, mapper_graph):
544588
width=600,
545589
height=600,
546590
)
591+
return mapper_fig
592+
593+
594+
def mapper_figure_section(df_X, df_y, mapper_plot):
595+
st.header('🎨 Plot')
596+
agg, agg_name = plot_agg_input_section()
597+
cmap = plot_cmap_input_section()
598+
colors, colors_feat = plot_color_input_section(df_X, df_y)
599+
mapper_fig = compute_mapper_fig(
600+
mapper_plot,
601+
colors=colors,
602+
agg=agg,
603+
cmap=cmap,
604+
agg_name=agg_name,
605+
colors_feat=colors_feat,
606+
)
607+
dim = _get_dim(mapper_fig)
547608
mapper_fig.update_layout(
548-
dragmode='pan' if dim == 2 else 'orbit',
609+
dragmode='orbit' if dim == 3 else 'pan',
549610
uirevision='constant',
550611
margin=dict(b=0, l=0, r=0, t=0),
551612
)
@@ -557,7 +618,7 @@ def plot_input_section(df_X, df_y, mapper_graph):
557618
scaleanchor='x',
558619
scaleratio=1,
559620
)
560-
return mapper_plot, mapper_fig
621+
return mapper_fig
561622

562623

563624
def mapper_rendering_section(mapper_graph, mapper_fig):
@@ -570,26 +631,38 @@ def mapper_rendering_section(mapper_graph, mapper_fig):
570631

571632

572633
def data_summary_section(df_X, df_y, mapper_graph):
573-
df_data = pd.DataFrame({
574-
'samples': [len(df_X)],
575-
'input features': [len(df_X.columns)],
576-
'target features': [len(df_y.columns)],
577-
})
578-
df_graph = pd.DataFrame({
579-
'nodes': [mapper_graph.number_of_nodes()],
580-
'edges': [mapper_graph.number_of_edges()],
581-
'connected components': [nx.number_connected_components(mapper_graph)],
634+
df_stats = pd.DataFrame({
635+
'Stat': [
636+
'Samples',
637+
'Input Feats',
638+
'Target Feats',
639+
'Nodes',
640+
'Edges',
641+
'Conn. Comp.'
642+
],
643+
'Value': [
644+
len(df_X),
645+
len(df_X.columns),
646+
len(df_y.columns),
647+
mapper_graph.number_of_nodes(),
648+
mapper_graph.number_of_edges(),
649+
nx.number_connected_components(mapper_graph),
650+
]
582651
})
583-
st.dataframe(df_graph, hide_index=True, use_container_width=True)
584-
st.dataframe(df_data, hide_index=True, use_container_width=True)
652+
st.dataframe(
653+
df_stats,
654+
hide_index=True,
655+
use_container_width=True,
656+
height=250,
657+
)
585658
df_summary = _get_data_summary(df_X, df_y)
586659
st.dataframe(
587660
df_summary,
588661
hide_index=True,
589-
height=400,
662+
height=330,
590663
column_config={
591664
V_DATA_SUMMARY_HIST: st.column_config.AreaChartColumn(
592-
width='large',
665+
width='small',
593666
),
594667
V_DATA_SUMMARY_FEAT: st.column_config.TextColumn(
595668
width='small',
@@ -663,7 +736,9 @@ def main():
663736
st.divider()
664737
mapper_graph = mapper_input_section(df_X.to_numpy())
665738
st.divider()
666-
mapper_plot, mapper_fig = plot_input_section(df_X, df_y, mapper_graph)
739+
mapper_plot = mapper_plot_section(mapper_graph)
740+
st.divider()
741+
mapper_fig = mapper_figure_section(df_X, df_y, mapper_plot)
667742
col_0, col_1 = st.columns([1, 3])
668743
with col_0:
669744
data_summary_section(df_X, df_y, mapper_graph)

0 commit comments

Comments
 (0)