Skip to content

Commit

Permalink
[tfjs-core] fix gather gradient when batchDims is 1 (#7942)
Browse files Browse the repository at this point in the history
BUG
* feat: reproduce error

* feat: working code

* feat: remove logging

* feat: minor fix

* feat: refactor derX function

* Fix lint

---------

Co-authored-by: Matthew Soulanille <matthew@soulanille.net>
  • Loading branch information
paradite and mattsoulanille authored Sep 12, 2023
1 parent cfb217b commit f44e224
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 25 deletions.
66 changes: 41 additions & 25 deletions tfjs-core/src/gradients/GatherV2_grad.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {GatherV2, GatherV2Attrs} from '../kernel_names';
import {GradConfig, NamedAttrMap} from '../kernel_registry';
import {getUndoAxesPermutation} from '../ops/axis_util';
import {reshape} from '../ops/reshape';
import {stack} from '../ops/stack';
import {transpose} from '../ops/transpose';
import {unsortedSegmentSum} from '../ops/unsorted_segment_sum';
import {Tensor, Tensor1D} from '../tensor';
Expand All @@ -29,40 +30,55 @@ export const gatherGradConfig: GradConfig = {
inputsToSave: ['x', 'indices'],
gradFunc: (dy: Tensor, saved: Tensor[], attrs: NamedAttrMap) => {
const [x, indices] = saved;
const {axis} = attrs as unknown as GatherV2Attrs;
const {axis, batchDims} = attrs as unknown as GatherV2Attrs;

const parsedAxis = parseAxisParam(axis, x.shape)[0];

const derX = () => {
const paramsShape = x.shape;
const indicesSize = indices.size;
const derXBatch = (x: Tensor, indices: Tensor, dy: Tensor) => {
return (): Tensor => {
const paramsShape = x.shape;
const indicesSize = indices.size;

const outerShape = paramsShape.slice(0, parsedAxis);
const outerDims = outerShape.length;
const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
const innerDims = innerShape.length;
const outerShape = paramsShape.slice(0, parsedAxis);
const outerDims = outerShape.length;
const innerShape = paramsShape.slice(axis, paramsShape.length).slice(1);
const innerDims = innerShape.length;

const outerAxesIndices = arrayRange(0, outerDims);
const innerAxesIndices =
arrayRange(outerDims + 1, outerDims + 1 + innerDims);
const outerAxesIndices = arrayRange(0, outerDims);
const innerAxesIndices =
arrayRange(outerDims + 1, outerDims + 1 + innerDims);

const valuesShape = arrayConcat([outerShape, [indicesSize], innerShape]);
const valuesShape = arrayConcat([outerShape, [indicesSize],
innerShape]);

const values = reshape(dy, valuesShape);
const reshapedIndices = reshape(indices, [indicesSize]);
const values = reshape(dy, valuesShape);
const reshapedIndices = reshape(indices, [indicesSize]);

const transposeDims =
arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
const valuesTranspose = transpose(values, transposeDims);
let paramsGrad = unsortedSegmentSum(
valuesTranspose, reshapedIndices as Tensor1D, x.shape[parsedAxis]);

const invertTransposeDims = getUndoAxesPermutation(transposeDims);
paramsGrad = transpose(paramsGrad, invertTransposeDims);

return paramsGrad;
const transposeDims =
arrayConcat([[outerDims], outerAxesIndices, innerAxesIndices]);
const valuesTranspose = transpose(values, transposeDims);
let paramsGrad = unsortedSegmentSum(
valuesTranspose, reshapedIndices as Tensor1D, x.shape[parsedAxis]);
const invertTransposeDims = getUndoAxesPermutation(transposeDims);
paramsGrad = transpose(paramsGrad, invertTransposeDims);
return paramsGrad;
};
};
return {x: derX, indices: () => indices};

if (batchDims === 1) {
const batchSize = x.shape[0];
const xBatch = x.split(batchSize, 0);
const derXBatched = () => {
const stacked = stack(
xBatch.map((x, i) => {
return derXBatch(x, indices.slice(i,1), dy.slice(i,1))();
}));
return stacked.reshape(x.shape);
};
return {x: derXBatched, indices: () => indices};
} else {
return {x: derXBatch(x, indices, dy), indices: () => indices};
}
}
};

Expand Down
16 changes: 16 additions & 0 deletions tfjs-core/src/ops/gather_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,22 @@ describeWithFlags('gather', ALL_ENVS, (env) => {
expectArraysClose(await gradients.data(), [26, 36, 0, 0]);
});

it('gradient 2D (gather) axis=1 shape=[4, 2] 1D indices batchDims 1',
async () => {
const t = tf.variable(tf.tensor([[0, 1],
[1, 2],
[2, 3],
[3, 4]]));
const indices = tf.tensor([0, 1, 0, 1], [4, 1], 'int32');
const dy = tf.tensor([1, 1, 1, 1], [4, 1]);
const axis = 1;

const gradients = tf.grad(t => tf.gather(t, indices, axis, 1))(t, dy);

expect(gradients.shape).toEqual(t.shape);
expectArraysClose(await gradients.data(), [1, 0, 0, 1, 1, 0, 0, 1]);
});

it('gradient 2D (gather) axis=1 shape=[2, 2] 1D indices', async () => {
const t = tf.tensor2d([1, 11, 2, 22], [2, 2]);
const indices = tf.tensor1d([1, 0, 0, 1], 'int32');
Expand Down

0 comments on commit f44e224

Please sign in to comment.