Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cmap option to specify the colormap of plot (Sourcery refactored) #1798

Closed
wants to merge 5 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 66 additions & 49 deletions tardis/visualization/widgets/custom_abundance.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@

BASE_DIR = tardis.__path__[0]
YAML_DELIMITER = "---"
COLORMAP = "jet"


class CustomAbundanceWidgetData:
Expand Down Expand Up @@ -313,8 +312,7 @@ def __init__(
self.name = name
self.model_density_time_0 = d_time_0
self.model_isotope_time_0 = i_time_0
self.datatype = {}
self.datatype["fields"] = []
self.datatype = {'fields': []}
self.v_inner_boundary = v_inner_boundary
self.v_outer_boundary = v_outer_boundary

Expand Down Expand Up @@ -362,6 +360,10 @@ class CustomAbundanceWidget:
checked_list : list of bool
A list of bool to record whether the checkbox is checked.
The index of the bool corresponds to the index of checkbox.
fig : plotly.graph_objs._figurewidget.FigureWidget
The figure object of abundance density plot.
plot_cmap : str, default: "jet", optional
String defines the colormap used in abundance density plot.
_trigger : bool
If False, disable the callback when abundance input is changed.
"""
Expand All @@ -377,6 +379,7 @@ def __init__(self, widget_data):
widget_data : CustomAbundanceWidgetData
"""
self.data = widget_data
self.fig = go.FigureWidget()
self._trigger = True

self.create_widgets()
Expand Down Expand Up @@ -405,11 +408,7 @@ def no_of_elements(self):

@property
def checked_list(self): # A boolean list to store the value of checkboxes.
_checked_list = []
for check in self.checks:
_checked_list.append(check.value)

return _checked_list
return [check.value for check in self.checks]

def create_widgets(self):
"""Create widget components in GUI and register callbacks for widgets."""
Expand Down Expand Up @@ -651,6 +650,12 @@ def update_front_end(self):
width[0] = x_outer - x_inner
self.fig.data[0].width = width

def update_line_color(self):
"""Update line color in the plot according to colormap."""
colorscale = transition_colors(self.no_of_elements, self.plot_cmap)
for i in range(self.no_of_elements):
self.fig.data[2 + i].line.color = colorscale[i]

def overwrite_existing_shells(self, v_0, v_1):
"""Judge whether the existing shell(s) will be overwritten when
inserting a new shell within the entered velocity range.
Expand Down Expand Up @@ -682,13 +687,10 @@ def overwrite_existing_shells(self, v_0, v_1):
else position_1
)

if (index_1 - index_0 > 1) or (
return bool((index_1 - index_0 > 1) or (
(index_1 < len(v_vals) and np.isclose(v_vals[index_1], v_1))
or (index_1 - index_0 == 1 and not np.isclose(v_vals[index_0], v_0))
):
return True
else:
return False
))

def on_btn_add_shell(self, obj):
"""Add new shell with given boundary velocities. Triggered if
Expand Down Expand Up @@ -745,22 +747,21 @@ def on_btn_add_shell(self, obj):
1,
inplace=True,
)
elif start_index == 0:
self.data.abundance.insert(end_index, "new", 0)
self.data.abundance.insert(
end_index, "gap", 0
) # Add a shell to fill the gap.
else:
if start_index == 0:
self.data.abundance.insert(end_index, "new", 0)
self.data.abundance.insert(
end_index, "gap", 0
) # Add a shell to fill the gap.
self.data.abundance.insert(end_index - 1, "new", 0)
if start_index == self.no_of_shells:
self.data.abundance.insert(end_index - 1, "gap", 0)
else:
self.data.abundance.insert(end_index - 1, "new", 0)
if start_index == self.no_of_shells:
self.data.abundance.insert(end_index - 1, "gap", 0)
else:
self.data.abundance.insert(
end_index - 1,
"gap",
self.data.abundance.iloc[:, end_index],
) # Add a shell to fill the gap with original abundances
self.data.abundance.insert(
end_index - 1,
"gap",
self.data.abundance.iloc[:, end_index],
) # Add a shell to fill the gap with original abundances

self.data.abundance.columns = range(self.no_of_shells)

Expand Down Expand Up @@ -848,9 +849,9 @@ def check_eventhandler(self, obj):
obj : traitlets.utils.bunch.Bunch
A dictionary holding the information about the change.
"""
item_index = obj.owner.index

if obj.new == True:
item_index = obj.owner.index

self.bound_locked_sum_to_1(item_index)

def dpd_shell_no_eventhandler(self, obj):
Expand All @@ -863,16 +864,8 @@ def dpd_shell_no_eventhandler(self, obj):
A dictionary holding the information about the change.
"""
# Disable "previous" and "next" buttons when shell no comes to boundaries.
if obj.new == 1:
self.btn_prev.disabled = True
else:
self.btn_prev.disabled = False

if obj.new == self.no_of_shells:
self.btn_next.disabled = True
else:
self.btn_next.disabled = False

self.btn_prev.disabled = obj.new == 1
self.btn_next.disabled = obj.new == self.no_of_shells
self.update_front_end()

def on_btn_prev(self, obj):
Expand Down Expand Up @@ -1024,9 +1017,7 @@ def on_btn_add_element(self, obj):
)
self.fig.data = fig_data_lst[:-1]

colorscale = transition_colors(self.no_of_elements, COLORMAP)
for i in range(self.no_of_elements):
self.fig.data[2 + i].line.color = colorscale[i]
self.update_line_color()

self.read_abundance()

Expand Down Expand Up @@ -1156,7 +1147,6 @@ def irs_shell_range_eventhandler(self, obj):

def generate_abundance_density_plot(self):
"""Generate abundance and density plot in different shells."""
self.fig = go.FigureWidget()
title = "Abundance/Density vs Velocity"
abundance = self.data.abundance
velocity = self.data.velocity
Expand Down Expand Up @@ -1188,14 +1178,13 @@ def generate_abundance_density_plot(self):
),
)

colorscale = transition_colors(self.no_of_elements, COLORMAP)
for i in range(self.no_of_elements):
self.fig.add_trace(
go.Scatter(
x=velocity,
y=np.append(abundance.iloc[i], abundance.iloc[i, -1]),
mode="lines+markers",
line=dict(shape="hv", color=colorscale[i]),
line=dict(shape="hv"),
name=self.data.elements[i],
),
)
Expand Down Expand Up @@ -1224,14 +1213,20 @@ def generate_abundance_density_plot(self):
),
)

def display(self):
def display(self, cmap="jet"):
"""Display the GUI.

Parameters
----------
cmap : str, default: "jet", optional
String defines the colormap used in abundance density plot.

Returns
-------
ipywidgets.widgets.widget_box.VBox
A box that contains all the widgets in the GUI.
"""
# --------------Combine widget components--------------
self.box_editor = ipw.HBox(
[
ipw.VBox(self.input_items),
Expand Down Expand Up @@ -1303,6 +1298,10 @@ def display(self):
),
]
)

# Initialize the widget and plot colormap
self.plot_cmap = cmap
self.update_line_color()
self.read_abundance()
self.density_editor.read_density()

Expand Down Expand Up @@ -1335,9 +1334,8 @@ def to_csvy(self, path, overwrite):
raise FileExistsError(
"The file already exists. Click the 'overwrite' checkbox to overwrite it."
)
else:
self.write_yaml_portion(posix_path)
self.write_csv_portion(posix_path)
self.write_yaml_portion(posix_path)
self.write_csv_portion(posix_path)

@error_view.capture(clear_output=True)
def write_yaml_portion(self, path):
Expand Down Expand Up @@ -1522,6 +1520,14 @@ def create_widgets(self):
)
self.input_d_time_0.observe(self.input_d_time_0_eventhandler, "value")

self.input_d_time_0 = ipw.FloatText(
value=self.data.density_t_0.value,
description="Density time_0 (day): ",
style={"description_width": "initial"},
layout=ipw.Layout(margin="0 0 20px 0"),
)
self.input_d_time_0.observe(self.input_d_time_0_eventhandler, "value")

self.dpd_dtype = ipw.Dropdown(
options=["-", "uniform", "exponential", "power_law"],
description="Density type: ",
Expand Down Expand Up @@ -1625,6 +1631,17 @@ def input_d_time_0_eventhandler(self, obj):
new_value = obj.new
self.data.density_t_0 = new_value * self.data.density_t_0.unit

def input_d_time_0_eventhandler(self, obj):
"""Update density time 0 data when the widget gets new input.

Parameters
----------
obj : traitlets.utils.bunch.Bunch
A dictionary holding the information about the change.
"""
new_value = obj.new
self.data.density_t_0 = new_value * self.data.density_t_0.unit

dtype_out = ipw.Output()

@dtype_out.capture(clear_output=True)
Expand Down