
Commit b642eb4

Merge pull request #199 from lucasimi/develop

Develop

2 parents: 58cb533 + f8bc68d

5 files changed (+169, -420 lines)

app/streamlit_app.py (+78, -23)
@@ -1,13 +1,14 @@
+import os
 import json
 import time
 import io
 import gzip
+import logging
 from datetime import datetime
 
 import streamlit as st
 import pandas as pd
 import numpy as np
-import plotly.express as px
 
 import networkx as nx
 from networkx.readwrite.json_graph import adjacency_data
@@ -24,13 +25,21 @@
 
 from umap import UMAP
 
-from tdamapper.core import ATTR_SIZE
 from tdamapper.learn import MapperAlgorithm
 from tdamapper.cover import CubicalCover, BallCover
 from tdamapper.plot import MapperPlot
-from tdamapper.utils.metrics import minkowski
 
 
+LIMITS_ENABLED = bool(os.environ.get('LIMITS_ENABLED', False))
+
+LIMITS_NUM_SAMPLES = int(os.environ.get('LIMITS_NUM_SAMPLES', 10000))
+
+LIMITS_NUM_FEATURES = int(os.environ.get('LIMITS_NUM_FEATURES', 1000))
+
+LIMITS_NUM_NODES = int(os.environ.get('LIMITS_NUM_NODES', 2000))
+
+LIMITS_NUM_EDGES = int(os.environ.get('LIMITS_NUM_EDGES', 3000))
+
 OPENML_URL = 'https://www.openml.org/search?type=data&sort=runs&status=active'
 
 S_RESULTS = 'stored_results'
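The new LIMITS_* constants are read from environment variables when the module is imported. A standalone illustration (not part of the diff) of how that parsing behaves; note that bool() applied to os.environ.get() is truthy for any non-empty string, so even LIMITS_ENABLED=0 or LIMITS_ENABLED=false switches the limits on:

import os

# Any non-empty string enables the limits, including '0' and 'false'.
os.environ['LIMITS_ENABLED'] = 'false'
print(bool(os.environ.get('LIMITS_ENABLED', False)))      # True

# Only an unset or empty variable leaves the limits disabled.
os.environ['LIMITS_ENABLED'] = ''
print(bool(os.environ.get('LIMITS_ENABLED', False)))      # False

# Numeric limits are parsed with int(), e.g. overriding the sample limit.
os.environ['LIMITS_NUM_SAMPLES'] = '5000'
print(int(os.environ.get('LIMITS_NUM_SAMPLES', 10000)))   # 5000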
@@ -113,24 +122,43 @@
     )
 
 
-def mode(arr):
-    unique, counts = np.unique(arr, return_counts=True)
-    max_count_index = np.argmax(counts)
-    return unique[max_count_index]
-
-
-def quantile(q):
-    return lambda agg: np.nanquantile(agg, q=q)
+def _check_limits_mapper_graph(mapper_graph):
+    if LIMITS_ENABLED:
+        num_nodes = mapper_graph.number_of_nodes()
+        if num_nodes > LIMITS_NUM_NODES:
+            logging.warn('Too many nodes.')
+            raise ValueError(
+                'Too many nodes: select different parameters or run the app '
+                'locally on your machine.'
+            )
+        num_edges = mapper_graph.number_of_edges()
+        if num_edges > LIMITS_NUM_EDGES:
+            logging.warn('Too many edges.')
+            raise ValueError(
+                'Too many edges: select different parameters or run the app '
+                'locally on your machine.'
+            )
 
 
-@st.cache_data
-def get_sample(df: pd.DataFrame, frac=SAMPLE_FRAC, max_n=MAX_SAMPLES, rand=42):
-    if frac * len(df) > max_n:
-        return df.sample(n=max_n, random_state=rand)
-    return df.sample(frac=frac, random_state=rand)
+def _check_limits_dataset(df_X, df_y):
+    if LIMITS_ENABLED:
+        num_samples = len(df_X)
+        if num_samples > LIMITS_NUM_SAMPLES:
+            logging.warn('Dataset too big.')
+            raise ValueError(
+                'Dataset too big: select a different dataset or run the app '
+                'locally on your machine.'
+            )
+        num_features = len(df_X.columns) + len(df_y.columns)
+        if num_features > LIMITS_NUM_FEATURES:
+            logging.warn('Too many features.')
+            raise ValueError(
+                'Too many features: select a different dataset or run the app '
+                'locally on your machine.'
+            )
 
 
-def fix_data(data):
+def _fix_data(data):
     df = pd.DataFrame(data)
     df = df.select_dtypes(include='number')
     df.dropna(axis=1, how='all', inplace=True)
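Both new guard functions act only when LIMITS_ENABLED is set, logging a warning and raising ValueError when a limit is exceeded. A minimal standalone sketch of the dataset check (not part of the commit; the limits are shrunk so the error path triggers, and logging.warning is used in place of the deprecated logging.warn alias):

import logging
import pandas as pd

LIMITS_ENABLED = True
LIMITS_NUM_SAMPLES = 5     # deliberately tiny; the app default is 10000
LIMITS_NUM_FEATURES = 3    # deliberately tiny; the app default is 1000


def check_limits_dataset(df_X, df_y):
    # Same structure as _check_limits_dataset above, reduced to its core.
    if LIMITS_ENABLED:
        if len(df_X) > LIMITS_NUM_SAMPLES:
            logging.warning('Dataset too big.')
            raise ValueError('Dataset too big.')
        if len(df_X.columns) + len(df_y.columns) > LIMITS_NUM_FEATURES:
            logging.warning('Too many features.')
            raise ValueError('Too many features.')


df_X = pd.DataFrame({'a': range(10), 'b': range(10)})  # 10 rows > 5 allowed
df_y = pd.DataFrame({'target': range(10)})
try:
    check_limits_dataset(df_X, df_y)
except ValueError as err:
    print(err)  # Dataset too big.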
@@ -171,6 +199,23 @@ def _get_data_summary(df_X, df_y):
     return df_summary
 
 
+def mode(arr):
+    unique, counts = np.unique(arr, return_counts=True)
+    max_count_index = np.argmax(counts)
+    return unique[max_count_index]
+
+
+def quantile(q):
+    return lambda agg: np.nanquantile(agg, q=q)
+
+
+@st.cache_data
+def get_sample(df: pd.DataFrame, frac=SAMPLE_FRAC, max_n=MAX_SAMPLES, rand=42):
+    if frac * len(df) > max_n:
+        return df.sample(n=max_n, random_state=rand)
+    return df.sample(frac=frac, random_state=rand)
+
+
 def initialize():
     st.set_page_config(
         layout='wide',
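mode, quantile, and get_sample are moved below _get_data_summary without behavioral changes. For reference, a small self-contained example of what the two aggregators compute (ties in mode resolve to the smallest value, since np.unique returns sorted values and np.argmax picks the first maximum):

import numpy as np


def mode(arr):
    # Most frequent value in arr.
    unique, counts = np.unique(arr, return_counts=True)
    max_count_index = np.argmax(counts)
    return unique[max_count_index]


def quantile(q):
    # Aggregator that ignores NaN values.
    return lambda agg: np.nanquantile(agg, q=q)


print(mode(np.array([1, 2, 2, 3])))                       # 2
print(quantile(0.5)(np.array([1.0, 2.0, np.nan, 4.0])))   # 2.0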
@@ -196,6 +241,7 @@ def load_data(source=None, name=None, csv=None):
     elif name == 'Iris':
         X, y = load_iris(return_X_y=True, as_frame=True)
     elif source == 'OpenML':
+        logging.info(f'Fetching dataset {name} from OpenML')
         X, y = fetch_openml(
             name,
             return_X_y=True,
@@ -207,7 +253,8 @@
         raise ValueError('No csv file uploaded')
     else:
         X, y = pd.read_csv(csv), pd.DataFrame()
-    df_X, df_y = fix_data(X), fix_data(y)
+    df_X, df_y = _fix_data(X), _fix_data(y)
+    _check_limits_dataset(df_X, df_y)
     return df_X, df_y
 
 
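load_data now routes both frames through _fix_data and then through the new dataset limit check. The visible part of _fix_data keeps only numeric columns and drops columns that are entirely NaN; a short pandas illustration of those two calls (the sample data is made up for the example):

import numpy as np
import pandas as pd

df = pd.DataFrame({
    'x': [1.0, 2.0, 3.0],               # numeric: kept
    'label': ['a', 'b', 'c'],           # non-numeric: dropped by select_dtypes
    'empty': [np.nan, np.nan, np.nan],  # all-NaN: dropped by dropna(how='all')
})
df = df.select_dtypes(include='number')
df.dropna(axis=1, how='all', inplace=True)
print(list(df.columns))  # ['x']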

@@ -304,11 +351,18 @@ def mapper_cover_input_section():
         'Intervals',
         value=10,
         min_value=0)
-    cubical_p = st.number_input(
-        'Overlap',
-        value=0.25,
-        min_value=0.0,
-        max_value=1.0)
+    cubical_overlap = st.checkbox(
+        'Set overlap',
+        value=False,
+        help='Uses a dimension-dependant default overlap when unchecked',
+    )
+    cubical_p = None
+    if cubical_overlap:
+        cubical_p = st.number_input(
+            'Overlap',
+            value=0.25,
+            min_value=0.0,
+            max_value=1.0)
     cover = CubicalCover(n_intervals=cubical_n, overlap_frac=cubical_p)
     return cover
 
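When the new 'Set overlap' checkbox is left unchecked, cubical_p stays None and CubicalCover is constructed with overlap_frac=None, so the cover falls back to the default described in the help text. A brief sketch of the two resulting constructions, assuming tdamapper is installed and that it interprets overlap_frac=None as "use the built-in default", which is what the app now relies on:

from tdamapper.cover import CubicalCover

# Checkbox unchecked: rely on the library's default overlap.
cover_default = CubicalCover(n_intervals=10, overlap_frac=None)

# Checkbox checked: use the value from the widget (0.25 by default).
cover_manual = CubicalCover(n_intervals=10, overlap_frac=0.25)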

@@ -547,6 +601,7 @@ def plot_color_input_section(df_X, df_y):
     show_spinner='Generating Mapper Layout',
 )
 def compute_mapper_plot(mapper_graph, dim, seed, iterations):
+    _check_limits_mapper_graph(mapper_graph)
     mapper_plot = MapperPlot(
         mapper_graph,
         dim,
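compute_mapper_plot is cached with st.cache_data, so the new graph-size guard runs before any layout computation is attempted. A standalone sketch of that guard on a plain networkx graph (not part of the commit; limits are shrunk so the failure is visible):

import logging
import networkx as nx

LIMITS_ENABLED = True
LIMITS_NUM_NODES = 100   # deliberately small; the app default is 2000
LIMITS_NUM_EDGES = 150   # deliberately small; the app default is 3000


def check_limits_mapper_graph(mapper_graph):
    # Same structure as _check_limits_mapper_graph above, reduced to its core.
    if LIMITS_ENABLED:
        if mapper_graph.number_of_nodes() > LIMITS_NUM_NODES:
            logging.warning('Too many nodes.')
            raise ValueError('Too many nodes.')
        if mapper_graph.number_of_edges() > LIMITS_NUM_EDGES:
            logging.warning('Too many edges.')
            raise ValueError('Too many edges.')


graph = nx.complete_graph(30)  # 30 nodes, 435 edges
try:
    check_limits_mapper_graph(graph)
except ValueError as err:
    print(err)  # Too many edges.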
(Diffs for the remaining changed files, 6.62 KB, 138 Bytes, and a 29 KB binary file, are not shown.)
