Skip to content

Commit 8ff437e

Browse files
Michael TsangMichael Tsang
Michael Tsang
authored and
Michael Tsang
committed
add guards for corner cases
1 parent a8e11c8 commit 8ff437e

File tree

1 file changed

+20
-10
lines changed

1 file changed

+20
-10
lines changed

src/explainer.py

+20-10
Original file line numberDiff line numberDiff line change
@@ -247,16 +247,26 @@ def search_feature_sets(
247247
# interaction detection
248248
ell_i = np.abs(context[i].item() - insertion_target[i].item())
249249
ell_j = np.abs(context[j].item() - insertion_target[j].item())
250-
inter_scores[(i, j)] = (
251-
1
252-
/ (ell_i * ell_j)
253-
* (
254-
context_score
255-
- idv_scores[(i,)]
256-
- idv_scores[(j,)]
257-
+ pair_scores[(i, j)]
258-
)
259-
)
250+
f_a = context_score
251+
f_b = idv_scores[(i,)]
252+
f_c = idv_scores[(j,)]
253+
f_d = pair_scores[(i, j)]
254+
255+
numerator = f_a - f_b - f_c + f_d
256+
denominator = ell_i * ell_j
257+
258+
# The numerator should theorecially be zero when there aren't interactions
259+
# in the function f. However, it is possible that the numerator is not
260+
# exactly zero due to precision issues in the function call. Here, if all
261+
# f_x are much larger than and the existing numerator value in magitude,
262+
# then we set the numerator to zero
263+
if np.abs(numerator) / np.min(np.abs(np.array([f_a, f_b, f_c, f_d]))) < 1e-5:
264+
numerator = 0.0
265+
266+
if denominator == 0.0:
267+
inter_scores[(i, j)] = 0.0
268+
else:
269+
inter_scores[(i, j)] = numerator / denominator
260270

261271
if (
262272
get_pairwise_effects

0 commit comments

Comments
 (0)