diff --git a/neuralcoref/train/dataset.py b/neuralcoref/train/dataset.py index 0831768..850c516 100644 --- a/neuralcoref/train/dataset.py +++ b/neuralcoref/train/dataset.py @@ -228,7 +228,7 @@ def __getitem__(self, mention_idx, debug=False): [self.mentions[idx.item()][0][np.newaxis, :] for idx in pairs_ant_index] ) ant_features = np.zeros((pairs_length, SIZE_FS - SIZE_GENRE)) - ant_features[:, ant_features_raw[:, 0]] = 1 + ant_features[np.arange(pairs_length), ant_features_raw[:, 0]] = 1 ant_features[:, 4:15] = encode_distance(ant_features_raw[:, 1]) ant_features[:, 15] = ant_features_raw[:, 2].astype(float) / ant_features_raw[ :, 3