
Commit bbb3fb0

Merge pull request #3196 from PrincetonUniversity/fix-softmax
Fix softmax
2 parents c777ff0 + eba49a9 commit bbb3fb0

File tree: 3 files changed (+157 −37 lines)

psyneulink/core/components/functions/nonstateful/transferfunctions.py (+53 −23)
@@ -2828,9 +2828,10 @@ class SoftMax(TransferFunction):
     <SoftMax.gain>` parametrically based on the `variable <SoftMax.variable>`:
 
     - *mask_threshold* -- setting the **mask_threshold** argument to a scalar value causes the `variable
-      <SoftMax.variable>` to be thresholded by that value before applying the SoftMax function; any elements of
-      `variable <SoftMax.variable>` with an absolute value below the threshold are set to 0; all others are scaled
-      by the specified `gain <SoftMax.gain>` and then passed through the SoftMax function. This only applies if the
+      <SoftMax.variable>` to be thresholded by that value before applying the SoftMax function; each element of
+      `variable <SoftMax.variable>` is first scaled by `gain <SoftMax.gain>`. Then, any elements with an absolute
+      value below *mask_threshold* are set to negative infinity (``-inf``), effectively masking them since
+      ``exp(-inf) = 0``. The remaining values are then passed through the SoftMax function. This only applies if the
       **gain** argument is specified as a scalar; if it is specified as *ADAPTIVE*, then the **mask_threshold**
       argument is ignored.
 
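To make the new behavior concrete, the masking described above boils down to the following computation. This is a minimal NumPy sketch with made-up gain, threshold, and input values; it is not code from this commit:

```python
import numpy as np

# Hypothetical values, chosen only for illustration
gain, mask_threshold = 2.0, 1.0
x = np.array([0.1, 0.8, -0.9, 0.05])

v = gain * x                                           # scale by gain first
v = np.where(np.abs(v) > mask_threshold, v, -np.inf)   # mask small magnitudes with -inf
v = v - np.max(v[v != -np.inf])                        # shift by the unmasked max for numerical stability
e = np.exp(v)                                          # exp(-inf) == 0, so masked entries drop out
softmax = e / e.sum()                                  # only the unmasked entries share the probability mass
```

Masked elements therefore receive exactly zero probability, and the surviving elements are renormalized among themselves.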
@@ -2920,10 +2921,11 @@ class SoftMax(TransferFunction):
 
     mask_threshold : scalar or None
         determines whether the `variable <SoftMax.variable>` is thresholded before applying the SoftMax function;
-        if it is a scalar, only elements of `variable <SoftMax.variable>` with an absolute value greater than that
-        value are considered when applying the SoftMax function (which are then scaled by the `gain <SoftMax.gain>`
-        parameter; all other elements are assigned 0. This only applies if `gain <SoftMax.gain>` is specified as a
-        scalar; otherwise it is ignored (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
+        if it is a scalar, each element of `variable <SoftMax.variable>` is first scaled by `gain <SoftMax.gain>`.
+        Then, only elements with an absolute value greater than *mask_threshold* are considered when applying the
+        SoftMax function, while all other elements are set to ``-inf``, effectively masking them since
+        ``exp(-inf) = 0``. This only applies if `gain <SoftMax.gain>` is specified as a scalar; otherwise it is
+        ignored (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
 
     adapt_scale : scalar
         determines the *scale* parameter using by the `adapt_gain <SoftMax.adapt_gain>` method (see method for details).
@@ -3149,22 +3151,31 @@ def _validate_variable(self, variable, context=None):
         return np.asarray(variable)
 
     def apply_softmax(self, input_value, gain, mask_threshold, output_type):
-
         # Modulate input_value by gain
         v = gain * input_value
-        # Shift by max to avoid extreme values:
-        v = v - np.max(v)
+
+        # Mask threshold
+        if mask_threshold is not None:
+            if np.any(v < 0):
+                warnings.warn(f"SoftMax function: mask_threshold is set "
+                              f"to {mask_threshold} but input_value contains negative values. "
+                              f"Masking will be applied to the magnitude of the input.")
+
+            v = np.where(np.abs(v) > mask_threshold, v, -np.inf)
+
+        # Make numerically stable by shifting by max value
+        if np.any(v != -np.inf):
+            v = v - np.max(v)
+
         # Exponentiate
         v = np.exp(v)
-        # Threshold if specified:
-        if mask_threshold:
-            v = v * np.where(input_value > mask_threshold, v, 0)
+
         # Normalize (to sum to 1)
-        if not any(v):
+        if not np.any(v):
             # If v is all zeros, avoid divide by zero in normalize and return all zeros for softmax
             sm = v
         else:
-            sm = v / np.sum(v, axis=0)
+            sm = v / np.sum(v)
 
         # Generate one-hot encoding based on selected output_type
         if output_type in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}:
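A quick way to sanity-check the NumPy path above is to compare a `SoftMax` instance against the same computation done by hand. This is a hedged sketch rather than one of the repository's tests: the gain, threshold, and input values are arbitrary, and it assumes the default output type returns the full distribution (as the `SOFT_MAX MASK_THRESHOLD ALL` test case added below does):

```python
import numpy as np
import psyneulink as pnl

x = np.array([0.1, 0.8, -0.9, 0.05])
sm = pnl.SoftMax(gain=2.0, mask_threshold=1.0, per_item=False)

# Manual computation mirroring apply_softmax with a scalar gain and mask_threshold
v = 2.0 * x
v = np.where(np.abs(v) > 1.0, v, -np.inf)
e = np.exp(v - np.max(v[v != -np.inf]))
expected = e / e.sum()

np.testing.assert_allclose(np.squeeze(sm(x)), expected)
```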
@@ -3472,15 +3483,34 @@ def _gen_pytorch_fct(self, device, context=None):
         if isinstance(gain, str) and gain == ADAPTIVE:
             return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, -1))
 
-        elif mask_threshold:
+        elif mask_threshold is not None:
             def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
-                # Mask elements of input below threshold
-                _mask = (torch.abs(_input) > mask_threshold)
-                # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
-                masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
-                if (masked_exp == 0).all():
-                    return masked_exp
-                return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
+                v = gain * _input
+
+                # Apply threshold-based masking
+                if mask_threshold is not None:
+                    if torch.any(_input < 0):
+                        warnings.warn(f"Softmax function: mask_threshold is set to {mask_threshold}, "
+                                      f"but input contains negative values. "
+                                      f"Masking will be applied to the magnitude of the input.")
+
+                    # Create a mask where values below threshold are set to -inf
+                    mask = torch.abs(v) > mask_threshold
+                    v = v.masked_fill(~mask, float('-inf'))  # More stable than torch.where()
+
+                    # Handle case where all values are masked (return tensor with gradient support)
+                    if torch.all(~mask):
+                        return torch.full_like(v, 0.0, requires_grad=True)
+
+                    # Make numerically stable by shifting max value
+                    max_v = torch.max(v[mask])  # Avoid computing max over -inf
+                    v = v - max_v
+
+                # Compute softmax (PyTorch handles -inf correctly)
+                exp_v = torch.exp(v)
+                sm = exp_v / torch.sum(exp_v, dim=-1, keepdim=True)
+
+                return sm
             # Return the function
             return pytorch_thresholded_softmax
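The PyTorch branch now leans on `masked_fill` with `-inf` and on the fact that `torch.softmax` assigns zero probability to `-inf` logits. The following standalone sketch (hypothetical values, independent of PsyNeuLink) checks that the manual normalization used above and `torch.softmax` agree once sub-threshold logits are masked:

```python
import torch

gain, mask_threshold = 2.0, 1.0           # hypothetical values for illustration
x = torch.tensor([0.1, 0.8, -0.9, 0.05])

v = gain * x
mask = torch.abs(v) > mask_threshold
v = v.masked_fill(~mask, float('-inf'))

manual = torch.exp(v - v[mask].max())     # shift by the unmasked max, then exponentiate
manual = manual / manual.sum()

assert torch.allclose(manual, torch.softmax(v, dim=-1))
```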

tests/composition/test_emcomposition.py (+22 −14)
@@ -225,14 +225,17 @@ def test_memory_fill(start, memory_fill):
             test_memory_fill(start=repeat, memory_fill=memory_fill)
 
     @pytest.mark.parametrize("softmax_choice, expected",
-                             [(pnl.WEIGHTED_AVG, [[0.93016008, 0.1, 0.16983992]]),
+                             [(pnl.WEIGHTED_AVG, [[0.8479525858370621, 0.1, 0.25204741416293786]]),
                               (pnl.ARG_MAX, [[1, .1, .1]]),
                               (pnl.PROBABILISTIC, [[1, .1, .1]]),  # NOTE: actual stochasticity not tested here
                               ])
     def test_softmax_choice(self, softmax_choice, expected):
         em = EMComposition(memory_template=[[[1,.1,.1]], [[1,.1,.1]], [[.1,.1,1]]],
                            softmax_choice=softmax_choice,
-                           enable_learning=False)
+                           enable_learning=False,
+                           softmax_threshold=None,
+                           memory_decay_rate=0,
+                           normalize_memories=False)
         result = em.run(inputs={em.query_input_nodes[0]:[[1,0,0]]})
 
         np.testing.assert_allclose(result, expected)
@@ -739,11 +742,12 @@ def test_assign_field_weights_and_0_vs_None(self,
 
         em = pnl.EMComposition(memory_template=memory_template,
                                memory_capacity=4,
-                               memory_decay_rate= 0,
-                               learn_field_weights = False,
+                               memory_decay_rate=0,
+                               learn_field_weights=False,
                                softmax_choice=softmax_choice,
                                field_weights=field_weights,
-                               field_names=['A','B','C'])
+                               field_names=['A', 'B', 'C'],
+                               )
         # Confirm initial weight assignments (that favor A)
         assert em.nodes['A [WEIGHT]'].input_port.defaults.variable == [.75]
         assert em.nodes['B [WEIGHT]'].input_port.defaults.variable == [.25]
@@ -774,9 +778,9 @@ def test_assign_field_weights_and_0_vs_None(self,
         # Note: field_weights favors A
         if softmax_choice == pnl.MAX_VAL:
             if operation == pnl.L0:
-                expected = [[1.70381182], [0.], [3.40762364]]
+                expected = [[1.467373], [0.], [2.934746]]
             else:
-                expected = [[1.56081243, 0.0], [0.0, 1.56081243], [3.12162487, 3.12162487]]
+                expected = [[1.419423, 0.0], [0.0, 1.419423], [2.838846, 2.838846]]
         else:
             expected = memory_template[0]
         np.testing.assert_allclose(result, expected)
@@ -899,7 +903,7 @@ def test_backpropagation_of_error_in_learning(self):
                            memory_capacity=50,
                            memory_decay_rate=0,
                            softmax_gain=10,
-                           softmax_threshold=.001,
+                           softmax_threshold=0.001,
                            fields = {'STATE': {pnl.FIELD_WEIGHT: None,
                                                pnl.LEARN_FIELD_WEIGHT: False,
                                                pnl.TARGET_FIELD: True},
@@ -1026,12 +1030,16 @@ def test_backpropagation_of_error_in_learning(self):
                   [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]]
 
         result = EGO.learn(inputs={'STATE':INPUTS}, learning_rate=.5, execution_mode=pnl.ExecutionMode.PyTorch)
-        expected = [[ 0.00000000e+00, 1.35476414e-03, 1.13669378e-03, 2.20434260e-03, 6.61008388e-04, 9.88672202e-01,
-                      6.52088276e-04, 1.74149507e-03, 1.09769133e-03, 2.47971436e-03, 0.00000000e+00],
-                    [ 0.00000000e+00, -6.75284069e-02, -1.28930436e-03, -2.10726610e-01, -1.41050716e-03, -5.92286989e-01,
-                      -2.75196416e-03, -2.21010605e-03, -7.14369243e-03, -2.05167374e-02, 0.00000000e+00],
-                    [ 0.00000000e+00, 1.18578255e-03, 1.29393181e-03, 1.35476414e-03, 1.13669378e-03, 2.20434260e-03,
-                      6.61008388e-04, 9.88672202e-01, 6.52088276e-04, 2.83918640e-03, 0.00000000e+00]]
+        expected = [
+            [0.00000000e+00, 1.35933540e-03, 1.13114366e-03, 2.20590015e-03,
+             1.09314885e-03, 9.87722281e-01, 1.10371450e-03, 1.72925210e-03,
+             1.17352360e-03, 2.48170027e-03, 0.00000000e+00],
+            [0.00000000e+00, -6.54396065e-02, 1.41905061e-03, -2.08500295e-01,
+             -5.03985394e-05, -5.90196484e-01, -5.33017075e-03, -2.33024404e-03,
+             -2.02730870e-02, -1.58091223e-02, 0.00000000e+00],
+            [0.00000000e+00, 1.19576382e-03, 1.28593645e-03, 1.35933540e-03,
+             1.13114366e-03, 2.20590015e-03, 1.09314885e-03, 9.87722281e-01,
+             1.10371450e-03, 2.90277570e-03, 0.00000000e+00]]
         np.testing.assert_allclose(result, expected)
 
         # Plot (for during debugging):

tests/functions/test_transfer.py (+82)
@@ -23,6 +23,17 @@
 softmax_helper = np.exp(softmax_helper) / np.sum(np.exp(softmax_helper))
 softmax_helper2 = np.array((softmax_helper, softmax_helper)).reshape(2, -1)
 
+# Here, we use RAND1 * .5 as the threshold so that roughly 50% of the inputs are expected to be masked.
+softmax_threshold_helper = RAND1 * test_var
+softmax_threshold_helper = np.where(np.abs(softmax_threshold_helper) > RAND1 * .5, softmax_threshold_helper, -np.inf)
+if np.any(softmax_threshold_helper != -np.inf):
+    softmax_threshold_helper = softmax_threshold_helper - np.max(softmax_threshold_helper)
+softmax_threshold_helper = np.exp(softmax_threshold_helper)
+if np.any(softmax_threshold_helper):
+    softmax_threshold_helper = softmax_threshold_helper / np.sum(softmax_threshold_helper)
+softmax_threshold_helper2 = np.array((softmax_threshold_helper, softmax_threshold_helper)).reshape(2, -1)
+
+
 tanh_helper = (RAND1 * (test_var + RAND2 - RAND3) + RAND4)
 tanh_helper = np.tanh(tanh_helper)

@@ -106,6 +117,77 @@ def binomial_distort_helper(seed):
     pytest.param(pnl.SoftMax, [test_var, test_var], {kw.GAIN:RAND1, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM: True},
                  np.where(softmax_helper2 == np.max(softmax_helper2), 1, 0), id="SOFT_MAX MAX_INDICATOR PER_ITEM"),
 
+    # SoftMax with mask_threshold 1D input
+    pytest.param(pnl.SoftMax, test_var,
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.PER_ITEM:False},
+                 softmax_threshold_helper, id="SOFT_MAX MASK_THRESHOLD ALL",
+                 marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, test_var,
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX, kw.PER_ITEM:False},
+                 np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0),
+                 id="SOFT_MAX MASK_THRESHOLD ARG_MAX", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, test_var,
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX_INDICATOR, kw.PER_ITEM:False},
+                 np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0),
+                 id="SOFT_MAX MASK_THRESHOLD ARG_MAX_INDICATOR", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, test_var,
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:False},
+                 np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0),
+                 id="SOFT_MAX MASK_THRESHOLD MAX_VAL", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, test_var,
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:False},
+                 np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0),
+                 id="SOFT_MAX MASK_THRESHOLD MAX_INDICATOR", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, test_var,
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.PROB, kw.PER_ITEM:False},
+                 [0.0, 0.0, 0.0, test_var[3], test_var[4], 0.0, 0.0, 0.0, 0.0, 0.0],
+                 id="SOFT_MAX MASK_THRESHOLD PROB", marks=pytest.mark.llvm_not_implemented),
+    #
+    # # SoftMax 2D threshold testing per-item
+    pytest.param(pnl.SoftMax, [test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.PER_ITEM:True}, [softmax_threshold_helper],
+                 id="SOFT_MAX MASK_THRESHOLD ALL 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX, kw.PER_ITEM:True},
+                 [np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0)],
+                 id="SOFT_MAX MASK_THRESHOLD ARG_MAX 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX_INDICATOR, kw.PER_ITEM:True},
+                 [np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0)],
+                 id="SOFT_MAX MASK_THRESHOLD ARG_MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
+                 [np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), softmax_threshold_helper, 0)],
+                 id="SOFT_MAX MASK_THRESHOLD MAX_VAL 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True},
+                 [np.where(softmax_threshold_helper == np.max(softmax_threshold_helper), 1, 0)],
+                 id="SOFT_MAX MASK_THRESHOLD MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.PROB, kw.PER_ITEM:True},
+                 [[0.0, 0.0, 0.0, test_var[3], test_var[4], 0.0, 0.0, 0.0, 0.0, 0.0]],
+                 id="SOFT_MAX MASK_THRESHOLD PROB 2D", marks=pytest.mark.llvm_not_implemented),
+
+    # SoftMax threshold per-item with 2 elements in input
+    pytest.param(pnl.SoftMax, [test_var, test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.PER_ITEM:True}, softmax_threshold_helper2,
+                 id="SOFT_MAX MASK_THRESHOLD ALL 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var, test_var], {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX, kw.PER_ITEM:True},
+                 np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), softmax_threshold_helper2, 0),
+                 id="SOFT_MAX MASK_THRESHOLD ARG_MAX 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var, test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:pnl.ARG_MAX_INDICATOR, kw.PER_ITEM:True},
+                 np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), 1, 0),
+                 id="SOFT_MAX MASK_THRESHOLD ARG_MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var, test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_VAL, kw.PER_ITEM:True},
+                 np.where(softmax_threshold_helper == np.max(softmax_threshold_helper2), softmax_threshold_helper2, 0),
+                 id="SOFT_MAX MASK_THRESHOLD MAX_VAL 2D", marks=pytest.mark.llvm_not_implemented),
+    pytest.param(pnl.SoftMax, [test_var, test_var],
+                 {kw.GAIN:RAND1, 'mask_threshold': RAND1 * .5, kw.OUTPUT_TYPE:kw.MAX_INDICATOR, kw.PER_ITEM:True},
+                 np.where(softmax_threshold_helper2 == np.max(softmax_threshold_helper2), 1, 0),
+                 id="SOFT_MAX MASK_THRESHOLD MAX_INDICATOR 2D", marks=pytest.mark.llvm_not_implemented),
+
     # Linear Matrix
     pytest.param(pnl.MatrixTransform, test_var, {kw.MATRIX:test_matrix}, np.dot(test_var, test_matrix), id="LINEAR_MATRIX SQUARE"),
     pytest.param(pnl.MatrixTransform, test_var, {kw.MATRIX:test_matrix_l}, np.dot(test_var, test_matrix_l), id="LINEAR_MATRIX WIDE"),
