39
39
40
40
SAMPLE_FRAC = 0.1
41
41
42
- V_DATA_SUMMARY_FEAT = 'feature '
42
+ V_DATA_SUMMARY_FEAT = 'Feat '
43
43
44
- V_DATA_SUMMARY_HIST = 'histogram '
44
+ V_DATA_SUMMARY_HIST = 'Hist '
45
45
46
46
V_DATA_SUMMARY_BINS = 15
47
47
@@ -136,6 +136,27 @@ def fix_data(data):
136
136
return df
137
137
138
138
139
+ def _get_dim (fig ):
140
+ dim = 2
141
+ for trace in fig .data :
142
+ if '3d' in trace .type :
143
+ dim = 3
144
+ return dim
145
+
146
+
147
def _get_graph_no_attribs(graph):
    """Build a new ``nx.Graph`` holding only the nodes and edges of ``graph``.

    Only what ``graph.nodes()`` and ``graph.edges()`` yield is copied over;
    the input graph itself is not modified.

    Args:
        graph: Object exposing ``nodes()`` and ``edges()``.

    Returns:
        A freshly created ``nx.Graph`` populated with those nodes and edges.
    """
    stripped = nx.Graph()
    # Nodes first, then edges, so node insertion order follows the input.
    stripped.add_nodes_from(graph.nodes())
    stripped.add_edges_from(graph.edges())
    return stripped
152
+
153
+
154
+ def _encode_graph (graph ):
155
+ nodes = tuple (sorted ([int (v ) for v in graph .nodes ()]))
156
+ edges = tuple (sorted (tuple (sorted (e )) for e in graph .edges ()))
157
+ return (nodes , edges )
158
+
159
+
139
160
def _get_data_summary (df_X , df_y ):
140
161
df = pd .concat ([get_sample (df_y ), get_sample (df_X )], axis = 1 )
141
162
df_hist = pd .DataFrame ({
@@ -417,6 +438,15 @@ def mapper_clustering_input_section():
417
438
return clustering
418
439
419
440
441
@st.cache_data(
    # The `mapper` argument is hashed through its repr.
    # NOTE(review): this assumes the repr reflects every parameter that
    # influences the result — confirm against MapperAlgorithm.
    hash_funcs={'tdamapper.learn.MapperAlgorithm': MapperAlgorithm.__repr__},
    show_spinner='Computing Mapper',
)
def compute_mapper(mapper, X, y):
    """Run ``mapper.fit_transform`` on the inputs and return its result.

    The function is wrapped in ``st.cache_data``, with a spinner message
    shown while it runs.

    Args:
        mapper: Object exposing ``fit_transform(X, y)``; hashed for the
            cache through ``MapperAlgorithm.__repr__``.
        X: First argument forwarded to ``fit_transform``.
        y: Second argument forwarded to ``fit_transform`` (the visible
            caller passes the lens values here).

    Returns:
        Whatever ``mapper.fit_transform(X, y)`` returns.
    """
    return mapper.fit_transform(X, y)
448
+
449
+
420
450
def mapper_input_section (X ):
421
451
lens = mapper_lens_input_section (X )
422
452
st .divider ()
@@ -429,7 +459,7 @@ def mapper_input_section(X):
429
459
verbose = True ,
430
460
n_jobs = 1 ,
431
461
)
432
- mapper_graph = mapper_algo . fit_transform ( X , lens )
462
+ mapper_graph = compute_mapper ( mapper_algo , X , lens )
433
463
return mapper_graph
434
464
435
465
@@ -528,14 +558,28 @@ def plot_dim_input_section():
528
558
return dim
529
559
530
560
531
- def plot_input_section (df_X , df_y , mapper_graph ):
532
- st .header ('🎨 Plot' )
561
@st.cache_data(
    hash_funcs={
        # Graphs are hashed by their node/edge structure only: attributes
        # are dropped before encoding.
        'networkx.classes.graph.Graph': lambda graph: _encode_graph(
            _get_graph_no_attribs(graph)
        ),
    },
    show_spinner='Generating Mapper Embedding',
)
def compute_mapper_plot(mapper_graph, dim, seed):
    """Create a ``MapperPlot`` for the given graph.

    The function is wrapped in ``st.cache_data``, with a spinner message
    shown while it runs.

    Args:
        mapper_graph: Graph passed as the first argument to ``MapperPlot``.
        dim: Passed as the second positional argument to ``MapperPlot``.
        seed: Passed to ``MapperPlot`` as the ``seed`` keyword argument.

    Returns:
        The constructed ``MapperPlot`` instance.
    """
    return MapperPlot(mapper_graph, dim, seed=seed)
568
+
569
+
570
def mapper_plot_section(mapper_graph):
    """Render the layout section and return the resulting mapper plot.

    Writes the section header, reads the dimension and seed from their
    input sections (in that order), and builds the plot through
    ``compute_mapper_plot``.

    Args:
        mapper_graph: Graph forwarded to ``compute_mapper_plot``.

    Returns:
        The value returned by ``compute_mapper_plot``.
    """
    st.header('🗺️ Layout')
    # Keep the input sections in this order: dimension first, then seed.
    layout_dim = plot_dim_input_section()
    layout_seed = plot_seed_input_section()
    return compute_mapper_plot(mapper_graph, layout_dim, layout_seed)
576
+
577
+
578
+ @st .cache_data (
579
+ hash_funcs = {'tdamapper.plot.MapperPlot' : lambda mp : mp .positions },
580
+ show_spinner = 'Rendering Mapper' ,
581
+ )
582
+ def compute_mapper_fig (mapper_plot , colors , cmap , agg , agg_name , colors_feat ):
539
583
mapper_fig = mapper_plot .plot_plotly (
540
584
colors ,
541
585
agg = agg ,
@@ -544,8 +588,25 @@ def plot_input_section(df_X, df_y, mapper_graph):
544
588
width = 600 ,
545
589
height = 600 ,
546
590
)
591
+ return mapper_fig
592
+
593
+
594
+ def mapper_figure_section (df_X , df_y , mapper_plot ):
595
+ st .header ('🎨 Plot' )
596
+ agg , agg_name = plot_agg_input_section ()
597
+ cmap = plot_cmap_input_section ()
598
+ colors , colors_feat = plot_color_input_section (df_X , df_y )
599
+ mapper_fig = compute_mapper_fig (
600
+ mapper_plot ,
601
+ colors = colors ,
602
+ agg = agg ,
603
+ cmap = cmap ,
604
+ agg_name = agg_name ,
605
+ colors_feat = colors_feat ,
606
+ )
607
+ dim = _get_dim (mapper_fig )
547
608
mapper_fig .update_layout (
548
- dragmode = 'pan ' if dim == 2 else 'orbit ' ,
609
+ dragmode = 'orbit ' if dim == 3 else 'pan ' ,
549
610
uirevision = 'constant' ,
550
611
margin = dict (b = 0 , l = 0 , r = 0 , t = 0 ),
551
612
)
@@ -557,7 +618,7 @@ def plot_input_section(df_X, df_y, mapper_graph):
557
618
scaleanchor = 'x' ,
558
619
scaleratio = 1 ,
559
620
)
560
- return mapper_plot , mapper_fig
621
+ return mapper_fig
561
622
562
623
563
624
def mapper_rendering_section (mapper_graph , mapper_fig ):
@@ -570,26 +631,38 @@ def mapper_rendering_section(mapper_graph, mapper_fig):
570
631
571
632
572
633
def data_summary_section (df_X , df_y , mapper_graph ):
573
- df_data = pd .DataFrame ({
574
- 'samples' : [len (df_X )],
575
- 'input features' : [len (df_X .columns )],
576
- 'target features' : [len (df_y .columns )],
577
- })
578
- df_graph = pd .DataFrame ({
579
- 'nodes' : [mapper_graph .number_of_nodes ()],
580
- 'edges' : [mapper_graph .number_of_edges ()],
581
- 'connected components' : [nx .number_connected_components (mapper_graph )],
634
+ df_stats = pd .DataFrame ({
635
+ 'Stat' : [
636
+ 'Samples' ,
637
+ 'Input Feats' ,
638
+ 'Target Feats' ,
639
+ 'Nodes' ,
640
+ 'Edges' ,
641
+ 'Conn. Comp.'
642
+ ],
643
+ 'Value' : [
644
+ len (df_X ),
645
+ len (df_X .columns ),
646
+ len (df_y .columns ),
647
+ mapper_graph .number_of_nodes (),
648
+ mapper_graph .number_of_edges (),
649
+ nx .number_connected_components (mapper_graph ),
650
+ ]
582
651
})
583
- st .dataframe (df_graph , hide_index = True , use_container_width = True )
584
- st .dataframe (df_data , hide_index = True , use_container_width = True )
652
+ st .dataframe (
653
+ df_stats ,
654
+ hide_index = True ,
655
+ use_container_width = True ,
656
+ height = 250 ,
657
+ )
585
658
df_summary = _get_data_summary (df_X , df_y )
586
659
st .dataframe (
587
660
df_summary ,
588
661
hide_index = True ,
589
- height = 400 ,
662
+ height = 330 ,
590
663
column_config = {
591
664
V_DATA_SUMMARY_HIST : st .column_config .AreaChartColumn (
592
- width = 'large ' ,
665
+ width = 'small ' ,
593
666
),
594
667
V_DATA_SUMMARY_FEAT : st .column_config .TextColumn (
595
668
width = 'small' ,
@@ -663,7 +736,9 @@ def main():
663
736
st .divider ()
664
737
mapper_graph = mapper_input_section (df_X .to_numpy ())
665
738
st .divider ()
666
- mapper_plot , mapper_fig = plot_input_section (df_X , df_y , mapper_graph )
739
+ mapper_plot = mapper_plot_section (mapper_graph )
740
+ st .divider ()
741
+ mapper_fig = mapper_figure_section (df_X , df_y , mapper_plot )
667
742
col_0 , col_1 = st .columns ([1 , 3 ])
668
743
with col_0 :
669
744
data_summary_section (df_X , df_y , mapper_graph )
0 commit comments