Skip to content

Commit 69ee353

Browse files
add threshold testing for softmax
1 parent 509e62d commit 69ee353

File tree

2 files changed

+85
-3
lines changed

2 files changed

+85
-3
lines changed

psyneulink/core/components/functions/nonstateful/transferfunctions.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -3154,7 +3154,7 @@ def apply_softmax(self, input_value, gain, mask_threshold, output_type):
31543154

31553155
# Mask threshold
31563156
if mask_threshold is not None:
3157-
if np.any(input_value < 0):
3157+
if np.any(v < 0):
31583158
warnings.warn(f"SoftMax function: mask_threshold is set "
31593159
f"to {mask_threshold} but input_value contains negative values. "
31603160
f"Masking will be applied to the magnitude of the input.")
@@ -3169,11 +3169,11 @@ def apply_softmax(self, input_value, gain, mask_threshold, output_type):
31693169
v = np.exp(v)
31703170

31713171
# Normalize (to sum to 1)
3172-
if not any(v):
3172+
if not np.any(v):
31733173
# If v is all zeros, avoid divide by zero in normalize and return all zeros for softmax
31743174
sm = v
31753175
else:
3176-
sm = v / np.sum(v, axis=0)
3176+
sm = v / np.sum(v)
31773177

31783178
# Generate one-hot encoding based on selected output_type
31793179
if output_type in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}:

tests/functions/test_transfer.py

+82
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,17 @@
2323
softmax_helper = np.exp(softmax_helper) / np.sum(np.exp(softmax_helper))
2424
softmax_helper2 = np.array((softmax_helper, softmax_helper)).reshape(2, -1)
2525

26+
# Here, we use RAND1 * .5 as the threshold so that the expected fraction of masked inputs is 50%.
27+
softmax_threshold_helper = RAND1 * test_var
28+
softmax_threshold_helper = np.where(np.abs(softmax_threshold_helper) > RAND1 * .5, softmax_threshold_helper, -np.inf)
29+
if np.any(softmax_threshold_helper != -np.inf):
30+
softmax_threshold_helper = softmax_threshold_helper - np.max(softmax_threshold_helper)
31+
softmax_threshold_helper = np.exp(softmax_threshold_helper)
32+
if np.any(softmax_threshold_helper):
33+
softmax_threshold_helper = softmax_threshold_helper / np.sum(softmax_threshold_helper)
34+
softmax_threshold_helper2 = np.array((softmax_threshold_helper, softmax_threshold_helper)).reshape(2, -1)
35+
36+
2637
tanh_helper = (RAND1 * (test_var + RAND2 - RAND3) + RAND4)
2738
tanh_helper = np.tanh(tanh_helper)
2839

@@ -106,6 +117,77 @@ def binomial_distort_helper(seed):
106117
pytest.param(pnl.SoftMax, [test_var, test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM: True},
107118
np.where(softmax_helper2 == np.max(softmax_helper2), 1, 0), id="SOFT_MAX MAX_INDICATOR PER_ITEM"),
108119

120+
# SoftMax with mask_threshold 1D input
121+
pytest.param(pnl.SoftMax, test_var,
122+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.PER_ITEM:False},
123+
softmax_threshold_helper, id="SOFT_MAX MASK_THRESHOLD ALL",
124+
marks=pytest.mark.llvm_not_implemented),
125+
pytest.param(pnl.SoftMax, test_var,
126+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX, kw.PER_ITEM:False},
127+
np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0),
128+
id="SOFT_MAX MASK_THRESHOLD ARG_MAX", marks=pytest.mark.llvm_not_implemented),
129+
pytest.param(pnl.SoftMax, test_var,
130+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX_INDICATOR, kw.PER_ITEM:False},
131+
np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0),
132+
id="SOFT_MAX MASK_THRESHOLD ARG_MAX_INDICATOR", marks=pytest.mark.llvm_not_implemented),
133+
pytest.param(pnl.SoftMax, test_var,
134+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:False},
135+
np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0),
136+
id="SOFT_MAX MASK_THRESHOLD MAX_VAL", marks=pytest.mark.llvm_not_implemented),
137+
pytest.param(pnl.SoftMax, test_var,
138+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:False},
139+
np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0),
140+
id="SOFT_MAX MASK_THRESHOLD MAX_INDICATOR", marks=pytest.mark.llvm_not_implemented),
141+
pytest.param(pnl.SoftMax, test_var,
142+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.PROB, kw.PER_ITEM:False},
143+
[0.0, 0.0, 0.0, test_var[3], test_var[4], 0.0, 0.0, 0.0, 0.0, 0.0],
144+
id="SOFT_MAX MASK_THRESHOLD PROB", marks=pytest.mark.llvm_not_implemented),
145+
#
146+
# # SoftMax 2D threshold testing per-item
147+
pytest.param(pnl.SoftMax, [test_var],
148+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.PER_ITEM:True}, [softmax_threshold_helper],
149+
id="SOFT_MAX MASK_THRESHOLD ALL 2D", marks=pytest.mark.llvm_not_implemented),
150+
pytest.param(pnl.SoftMax, [test_var],
151+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX, kw.PER_ITEM:True},
152+
[np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0)],
153+
id="SOFT_MAX MASK_THRESHOLD ARG_MAX 2D", marks=pytest.mark.llvm_not_implemented),
154+
pytest.param(pnl.SoftMax, [test_var],
155+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX_INDICATOR, kw.PER_ITEM:True},
156+
[np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0)],
157+
id="SOFT_MAX MASK_THRESHOLD ARG_MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
158+
pytest.param(pnl.SoftMax, [test_var],
159+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
160+
[np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0)],
161+
id="SOFT_MAX MASK_THRESHOLD MAX_VAL 2D", marks=pytest.mark.llvm_not_implemented),
162+
pytest.param(pnl.SoftMax, [test_var],
163+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True},
164+
[np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0)],
165+
id="SOFT_MAX MASK_THRESHOLD MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
166+
pytest.param(pnl.SoftMax, [test_var],
167+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.PROB, kw.PER_ITEM:True},
168+
[[0.0, 0.0, 0.0, test_var[3], test_var[4], 0.0, 0.0, 0.0, 0.0, 0.0]],
169+
id="SOFT_MAX MASK_THRESHOLD PROB 2D", marks=pytest.mark.llvm_not_implemented),
170+
171+
# SoftMax threshold per-item with 2 elements in input
172+
pytest.param(pnl.SoftMax, [test_var, test_var],
173+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.PER_ITEM:True}, softmax_threshold_helper2,
174+
id="SOFT_MAX MASK_THRESHOLD ALL 2D", marks=pytest.mark.llvm_not_implemented),
175+
pytest.param(pnl.SoftMax, [test_var, test_var], {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX, kw.PER_ITEM:True},
176+
np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), softmax_threshold_helper2, 0),
177+
id="SOFT_MAX MASK_THRESHOLD ARG_MAX 2D", marks=pytest.mark.llvm_not_implemented),
178+
pytest.param(pnl.SoftMax, [test_var, test_var],
179+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX_INDICATOR, kw.PER_ITEM:True},
180+
np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), 1, 0),
181+
id="SOFT_MAX MASK_THRESHOLD ARG_MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
182+
pytest.param(pnl.SoftMax, [test_var, test_var],
183+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
184+
np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), softmax_threshold_helper2, 0),
185+
id="SOFT_MAX MASK_THRESHOLD MAX_VAL 2D", marks=pytest.mark.llvm_not_implemented),
186+
pytest.param(pnl.SoftMax, [test_var, test_var],
187+
{kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True},
188+
np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), 1, 0),
189+
id="SOFT_MAX MASK_THRESHOLD MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
190+
109191
# Linear Matrix
110192
pytest.param(pnl.MatrixTransform, test_var, {kw.MATRIX:test_matrix}, np.dot(test_var, test_matrix), id="LINEAR_MATRIX SQUARE"),
111193
pytest.param(pnl.MatrixTransform, test_var, {kw.MATRIX:test_matrix_l}, np.dot(test_var, test_matrix_l), id="LINEAR_MATRIX WIDE"),

0 commit comments

Comments
 (0)