fix typing and refactor
yaozengwei committed Sep 13, 2022
1 parent af2fcf2 · commit e71b541
Showing 1 changed file with 9 additions and 8 deletions.
egs/librispeech/ASR/pruned_transducer_stateless2/scaling.py (9 additions, 8 deletions)

@@ -118,7 +118,7 @@ def forward(
         x: Tensor,
         batch_dim: int,  # e.g., 1
         threshold: float,  # e.g., 10.0
-        *params: nn.Parameter,  # module parameters
+        *params: Tensor,  # module parameters
     ):
         if x.requires_grad:
             if batch_dim < 0:
@@ -131,7 +131,7 @@
     def backward(
         ctx,
         x_grad: Tensor,
-        *param_grads: nn.Parameter,
+        *param_grads: Tensor,
     ):
         dim = ctx.batch_dim
         norm_dims = [d for d in range(x_grad.ndim) if d != dim]
@@ -140,11 +140,12 @@ def backward(
         mask = norm_of_batch <= median_norm * ctx.threshold
         x_grad = x_grad * mask

-        if not torch.all(mask):
-            # If any of elements of x_grad is zeroed,
-            # the param_grads would be fully zeroed.
-            for g in param_grads:
-                g.zero_()
+        # 1 if no grad was zeroed, 0 if any was zeroed
+        all = torch.all(mask).to(x_grad.dtype)
+        # If any of elements of x_grad is zeroed,
+        # the param_grads would be fully zeroed.
+        param_grads = [all * g for g in param_grads]

         return x_grad, None, None, *param_grads
@@ -167,7 +168,7 @@ def __init__(self, batch_dim: int = 1, threshold: float = 10.0):
         self.batch_dim = batch_dim
         self.threshold = threshold

-    def forward(self, x: Tensor, *params: nn.Parameter) -> Tensor:
+    def forward(self, x: Tensor, *params: Tensor) -> Tuple[Tensor]:
         if torch.jit.is_scripting() or is_jit_tracing():
             return x, *params
         else:
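For context, below is a minimal, runnable sketch of what the gradient-filtering function looks like after this commit. The class name GradientFilterFunction and the per-batch norm computation (norm_of_batch, median_norm) are assumptions filled in for illustration; the diff only shows the signatures, the mask test, and the refactored param_grads handling. The annotation change from nn.Parameter to Tensor matches what these functions actually receive: autograd hands backward() plain Tensors, and apply() does not require its inputs to be Parameters. The diff's variable `all` is renamed all_ok here to avoid shadowing Python's builtin.

import torch
from torch import Tensor


class GradientFilterFunction(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: Tensor,
        batch_dim: int,  # e.g., 1
        threshold: float,  # e.g., 10.0
        *params: Tensor,  # module parameters
    ):
        if x.requires_grad:
            if batch_dim < 0:
                batch_dim += x.ndim
            ctx.batch_dim = batch_dim
            ctx.threshold = threshold
        return (x,) + params

    @staticmethod
    def backward(ctx, x_grad: Tensor, *param_grads: Tensor):
        dim = ctx.batch_dim
        norm_dims = [d for d in range(x_grad.ndim) if d != dim]
        # Assumed reconstruction: gradient norm per batch element, and the
        # median of those norms (the diff does not show these two lines).
        norm_of_batch = x_grad.norm(dim=norm_dims, keepdim=True)
        median_norm = norm_of_batch.median()

        # Keep only batch elements whose gradient norm is within
        # `threshold` times the median norm; zero out the outliers.
        mask = norm_of_batch <= median_norm * ctx.threshold
        x_grad = x_grad * mask

        # 1.0 if no element was filtered, 0.0 otherwise.  Scaling the
        # incoming parameter gradients out-of-place replaces the old
        # in-place `g.zero_()` loop.
        all_ok = torch.all(mask).to(x_grad.dtype)
        param_grads = [all_ok * g for g in param_grads]

        # None for the non-tensor inputs (batch_dim, threshold).
        return (x_grad, None, None, *param_grads)


# Tiny usage demo: batch is along dim 0; w's grad is gated by the filter.
x = torch.randn(3, 4, requires_grad=True)
w = torch.randn(4, requires_grad=True)
x_f, w_f = GradientFilterFunction.apply(x, 0, 10.0, w)
(x_f * w_f).sum().backward()
print(x.grad.shape, w.grad.shape)

One plausible reading of the refactor: the old g.zero_() mutated gradient tensors handed to backward() in place, which autograd may still hold references to elsewhere, whereas all * g builds new tensors; it also replaces data-dependent Python branching with plain arithmetic. The commit message says only "fix typing and refactor", so this motivation is inferred from the diff, not stated by the author.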
