jannisborn committed Jan 12, 2024
1 parent 6887c80 commit 7653dae
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -82,7 +82,7 @@ where `-checkpoint` specifies which `.pt` file to pick for the evaluation (based

## Attention visualization
The model uses a self-attention mechanism that can highlight chemical motifs used for the predictions.
-In [notebooks/toxicity_attention.ipynb](notebooks/toxicity_attention.ipynb) we share a tutorial on how to create such plots:
+In [notebooks/toxicity_attention_plot.ipynb](notebooks/toxicity_attention_plot.ipynb) we share a tutorial on how to create such plots:
![Attention](assets/attention.gif "toxicophore attention")
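
The notebook itself is not part of this commit, but the kind of token-level attention plot the README points to can be sketched in a few lines. The character tokenization and the `plot_smiles_attention` helper below are hypothetical illustrations, not the toxsmi API; a real plot would use the tokens and attention weights returned by the trained model.

```python
# Minimal sketch (assumed, not repo code): render per-token attention over a SMILES string.
import matplotlib.pyplot as plt
import numpy as np


def plot_smiles_attention(tokens, weights, title="toxicophore attention"):
    """One bar per SMILES token, scaled by its (normalized) attention weight."""
    weights = np.asarray(weights, dtype=float)
    weights = weights / weights.sum()  # normalize so bars are comparable across molecules
    fig, ax = plt.subplots(figsize=(max(6, 0.4 * len(tokens)), 2.5))
    ax.bar(range(len(tokens)), weights, color="tab:red")
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90, fontsize=8)
    ax.set_ylabel("attention")
    ax.set_title(title)
    fig.tight_layout()
    return fig


# Dummy example: aspirin with naive character tokenization and random weights.
tokens = list("CC(=O)Oc1ccccc1C(=O)O")
weights = np.random.rand(len(tokens))
plot_smiles_attention(tokens, weights)
plt.show()
```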


2 changes: 1 addition & 1 deletion toxsmi/models/mca.py
@@ -11,7 +11,6 @@
smiles_projection,
)
from paccmann_predictor.utils.utils import get_device
-
from toxsmi.utils.hyperparams import ACTIVATION_FN_FACTORY, LOSS_FN_FACTORY
from toxsmi.utils.layers import EnsembleLayer

@@ -231,6 +230,7 @@ def __init__(self, params: dict, *args, **kwargs):
self.loss_name = params.get(
"loss_fn", "binary_cross_entropy_ignore_nan_and_sum"
)
+
final_activation = (
ACTIVATION_FN_FACTORY["sigmoid"]
if "cross" in self.loss_name
18 changes: 9 additions & 9 deletions toxsmi/utils/performance.py
@@ -106,7 +106,7 @@ def process_data(self, labels: np.array, preds: np.array):
preds = preds[~np.isnan(labels)]
labels = labels[~np.isnan(labels)]

-return labels, preds
+return labels.astype(float), preds.astype(float)

def performance_report_binary_classification(
self, labels: np.array, preds: np.array, loss: float, model: Callable
@@ -125,10 +125,10 @@ def performance_report_binary_classification(

bin_preds, youden = binarize_predictions(preds, labels, return_youden=True)
report = classification_report(labels, bin_preds, output_dict=True)
negative_precision = report["0.0"]["precision"]
negative_recall = report["0.0"]["recall"]
positive_precision = report["1.0"]["precision"]
positive_recall = report["1.0"]["recall"]
negative_precision = report.get("0.0", {}).get("precision", -1)
negative_recall = report.get("0.0", {}).get("recall", -1)
positive_precision = report.get("1.0", {}).get("precision", -1)
positive_recall = report.get("1.0", {}).get("recall", -1)
f1 = fbeta_score(
labels, bin_preds, beta=self.beta, pos_label=1, average="binary"
)
@@ -237,10 +237,10 @@ def inference_report_binary_classification(
precision, recall, _ = precision_recall_curve(labels, preds)
precision_recall = average_precision_score(labels, preds)
report = classification_report(labels, bin_preds, output_dict=True)
negative_precision = report["0.0"]["precision"]
negative_recall = report["0.0"]["recall"]
positive_precision = report["1.0"]["precision"]
positive_recall = report["1.0"]["recall"]
negative_precision = report.get("0.0", {}).get("precision", -1)
negative_recall = report.get("0.0", {}).get("recall", -1)
positive_precision = report.get("1.0", {}).get("precision", -1)
positive_recall = report.get("1.0", {}).get("recall", -1)
accuracy = accuracy_score(labels, bin_preds)
bal_accuracy = balanced_accuracy_score(labels, bin_preds)
f1 = fbeta_score(labels, bin_preds, beta=0.5, pos_label=1, average="binary")
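
The changes in `performance.py` are defensive: `classification_report(..., output_dict=True)` only creates an entry for a class that actually occurs in the labels or predictions, so on an evaluation split containing a single class the old `report["0.0"][...]` indexing raises a `KeyError`. The new `.get(...)` chain falls back to the sentinel `-1` instead, and the `astype(float)` cast in `process_data` presumably keeps the dictionary keys at `"0.0"`/`"1.0"` rather than `"0"`/`"1"`. A minimal sketch of the failure mode (not repo code):

```python
# Illustration of why the .get fallback is needed: sklearn only emits report
# entries for classes that actually appear in labels or predictions.
import numpy as np
from sklearn.metrics import classification_report

labels = np.array([1.0, 1.0, 1.0])     # evaluation split with only positives
bin_preds = np.array([1.0, 1.0, 1.0])
report = classification_report(labels, bin_preds, output_dict=True)

print("0.0" in report)                              # False -- no negative class entry
# report["0.0"]["precision"] would raise KeyError here; the commit's
# .get chain degrades to the sentinel -1 instead.
print(report.get("0.0", {}).get("precision", -1))   # -1
```

Downstream consumers of these metrics should treat `-1` as "undefined for this split" rather than as a real precision or recall value.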
