@@ -2828,9 +2828,10 @@ class SoftMax(TransferFunction):
<SoftMax.gain>` parametrically based on the `variable <SoftMax.variable>`:

- *mask_threshold* -- setting the **mask_threshold** argument to a scalar value causes the `variable
- <SoftMax.variable>` to be thresholded by that value before applying the SoftMax function; any elements of
- `variable <SoftMax.variable>` with an absolute value below the threshold are set to 0; all others are scaled
- by the specified `gain <SoftMax.gain>` and then passed through the SoftMax function. This only applies if the
+ <SoftMax.variable>` to be thresholded by that value before applying the SoftMax function: each element in
+ `variable <SoftMax.variable>` is first scaled by `gain <SoftMax.gain>`. Then, any elements with an absolute
+ value below *mask_threshold* are set to negative infinity (``-inf``), effectively masking them since
+ ``exp(-inf) = 0``. The remaining values are then passed through the SoftMax function. This only applies if the
**gain** argument is specified as a scalar; if it is specified as *ADAPTIVE*, then the **mask_threshold**
argument is ignored.
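For reference, a minimal NumPy sketch of the behavior described in the revised docstring above (the `masked_softmax` helper and the sample values are illustrative, not part of the library API):

```python
import numpy as np

def masked_softmax(x, gain=1.0, mask_threshold=None):
    # Illustrative helper: scale by gain first, then mask by absolute value.
    v = gain * np.asarray(x, dtype=float)
    if mask_threshold is not None:
        # Gain-scaled elements at or below the threshold become -inf,
        # so exp(-inf) == 0 removes them from the normalization.
        v = np.where(np.abs(v) > mask_threshold, v, -np.inf)
    if np.any(v != -np.inf):
        v = v - np.max(v)          # shift by the max for numerical stability
    e = np.exp(v)
    # If everything was masked, return all zeros instead of dividing by zero.
    return e if not np.any(e) else e / np.sum(e)

print(masked_softmax([0.2, 1.5, -2.0], gain=2.0, mask_threshold=1.0))
# The first entry (|2 * 0.2| = 0.4 <= 1.0) comes out exactly 0; the rest sum to 1.
```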
@@ -2920,10 +2921,11 @@ class SoftMax(TransferFunction):
mask_threshold : scalar or None
determines whether the `variable <SoftMax.variable>` is thresholded before applying the SoftMax function;
- if it is a scalar, only elements of `variable <SoftMax.variable>` with an absolute value greater than that
- value are considered when applying the SoftMax function (which are then scaled by the `gain <SoftMax.gain>`
- parameter; all other elements are assigned 0. This only applies if `gain <SoftMax.gain>` is specified as a
- scalar; otherwise it is ignored (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
+ if it is a scalar, each element of `variable <SoftMax.variable>` is first scaled by `gain <SoftMax.gain>`. Then,
+ only elements with an absolute value greater than *mask_threshold* are considered when applying the SoftMax
+ function, while all other elements are set to ``-inf``, effectively masking them since ``exp(-inf) = 0``.
+ This only applies if `gain <SoftMax.gain>` is specified as a scalar; otherwise it is ignored
+ (see `Thresholding and Adaptive Gain <SoftMax_AdaptGain>` for details).
adapt_scale : scalar
determines the *scale* parameter used by the `adapt_gain <SoftMax.adapt_gain>` method (see method for details).
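As the revised `mask_threshold` entry above notes, thresholding is applied to the gain-scaled values, so whether an element is masked depends on `gain` as well as `mask_threshold`; a small illustrative check with arbitrary numbers:

```python
import numpy as np

x, mask_threshold = np.array([0.5, 3.0]), 1.0
for gain in (1.0, 4.0):
    kept = np.abs(gain * x) > mask_threshold   # the mask is computed on gain * x
    print(gain, kept)
# gain=1.0 -> [False  True]   (0.5 is masked)
# gain=4.0 -> [ True  True]   (4.0 * 0.5 = 2.0 clears the threshold)
```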
@@ -3149,22 +3151,31 @@ def _validate_variable(self, variable, context=None):
        return np.asarray(variable)

    def apply_softmax(self, input_value, gain, mask_threshold, output_type):
-
        # Modulate input_value by gain
        v = gain * input_value
-        # Shift by max to avoid extreme values:
-        v = v - np.max(v)
+
+        # Mask threshold
+        if mask_threshold is not None:
+            if np.any(v < 0):
+                warnings.warn(f"SoftMax function: mask_threshold is set "
+                              f"to {mask_threshold} but input_value contains negative values. "
+                              f"Masking will be applied to the magnitude of the input.")
+
+            v = np.where(np.abs(v) > mask_threshold, v, -np.inf)
+
+        # Make numerically stable by shifting by max value
+        if np.any(v != -np.inf):
+            v = v - np.max(v)
+
        # Exponentiate
        v = np.exp(v)
-        # Threshold if specified:
-        if mask_threshold:
-            v = v * np.where(input_value > mask_threshold, v, 0)
+
        # Normalize (to sum to 1)
-        if not any(v):
+        if not np.any(v):
            # If v is all zeros, avoid divide by zero in normalize and return all zeros for softmax
            sm = v
        else:
-            sm = v / np.sum(v, axis=0)
+            sm = v / np.sum(v)

        # Generate one-hot encoding based on selected output_type
        if output_type in {ARG_MAX, ARG_MAX_INDICATOR, MAX_VAL, MAX_INDICATOR}:
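A short standalone check of the all-masked edge case that the updated `apply_softmax` guards against above: every element becomes `-inf`, `exp` turns the vector to zeros, and the normalization is skipped, so the result is all zeros rather than a divide-by-zero (variable names and values are illustrative):

```python
import numpy as np

gain, mask_threshold = 1.0, 10.0
v = gain * np.array([0.1, -0.2, 0.3])
v = np.where(np.abs(v) > mask_threshold, v, -np.inf)   # everything is masked
if np.any(v != -np.inf):
    v = v - np.max(v)                                  # not reached here
v = np.exp(v)                                          # -> [0., 0., 0.]
sm = v if not np.any(v) else v / np.sum(v)
print(sm)                                              # [0. 0. 0.]
```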
@@ -3472,15 +3483,34 @@ def _gen_pytorch_fct(self, device, context=None):
        if isinstance(gain, str) and gain == ADAPTIVE:
            return lambda x: (torch.softmax(self._gen_pytorch_adapt_gain_fct(device, context)(x) * x, -1))

-        elif mask_threshold:
+        elif mask_threshold is not None:
            def pytorch_thresholded_softmax(_input: torch.Tensor) -> torch.Tensor:
-                # Mask elements of input below threshold
-                _mask = (torch.abs(_input) > mask_threshold)
-                # Subtract off the max value in the input to eliminate extreme values, exponentiate, and apply mask
-                masked_exp = _mask * torch.exp(gain * (_input - torch.max(_input, -1, keepdim=True)[0]))
-                if (masked_exp == 0).all():
-                    return masked_exp
-                return masked_exp / torch.sum(masked_exp, -1, keepdim=True)
+                v = gain * _input
+
+                # Apply threshold-based masking
+                if mask_threshold is not None:
+                    if torch.any(_input < 0):
+                        warnings.warn(f"Softmax function: mask_threshold is set to {mask_threshold}, "
+                                      f"but input contains negative values. "
+                                      f"Masking will be applied to the magnitude of the input.")
+
+                    # Create a mask where values below threshold are set to -inf
+                    mask = torch.abs(v) > mask_threshold
+                    v = v.masked_fill(~mask, float('-inf'))  # More stable than torch.where()
+
+                    # Handle case where all values are masked (return tensor with gradient support)
+                    if torch.all(~mask):
+                        return torch.full_like(v, 0.0, requires_grad=True)
+
+                    # Make numerically stable by shifting max value
+                    max_v = torch.max(v[mask])  # Avoid computing max over -inf
+                    v = v - max_v
+
+                # Compute softmax (PyTorch handles -inf correctly)
+                exp_v = torch.exp(v)
+                sm = exp_v / torch.sum(exp_v, dim=-1, keepdim=True)
+
+                return sm

            # Return the function
            return pytorch_thresholded_softmax
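A standalone sketch checking the PyTorch path added above: the masked element comes out as exactly 0, the remaining entries sum to 1, and gradients still flow through the unmasked values (tensor values are arbitrary, and the snippet inlines the same steps rather than calling the library):

```python
import torch

gain, mask_threshold = 2.0, 1.0
x = torch.tensor([0.2, 1.5, -2.0], requires_grad=True)

v = gain * x
mask = torch.abs(v) > mask_threshold
v = v.masked_fill(~mask, float('-inf'))          # masked entries -> -inf
v = v - torch.max(v[mask])                       # stable shift over unmasked values
sm = torch.exp(v) / torch.sum(torch.exp(v), dim=-1, keepdim=True)

print(sm)            # approximately tensor([0.0000, 0.9991, 0.0009])
sm[1].backward()     # gradient is zero for the masked element
print(x.grad)
```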