Skip to content

Commit

Permalink
PtLMBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
KexinFeng committed Jun 13, 2023
1 parent 30f6598 commit 8d91ef7
Showing 1 changed file with 6 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,22 +53,13 @@ public CausalLMOutput forward(NDList input, NDList pastKeyValues, NDManager mana
input = new NDList(input.get(0), input.get(1), attentionMask);
}

// forward call
IValue[] inputNative =
input.stream()
.map(object -> IValue.from((PtNDArray) object))
.toArray(IValue[]::new);
IValue resultIValue =
((PtSymbolBlock) blocks[0])
.forward(
inputNative[0],
inputNative[1],
inputNative[2],
IValueUtils.toTupleIValue(
pastKeyValues, new long[] {config.getNumLayers(), 2}));
String tupleName = "past_key_values(" + config.getNumLayers() + ',' + 2 + ')';
for (NDArray array : pastKeyValues) {
array.setName(tupleName);
}
input.addAll(pastKeyValues);

NDList output = resultIValue.toNDList((PtNDManager) manager);
Arrays.stream(inputNative).forEach(IValue::close);
NDList output = blocks[0].forward(null, input, false, null);

NDArray logitsOutput = output.get(0);
NDList pastKeyValuesOutput = output.subNDList(1, config.getNumLayers() * 2 + 1);
Expand Down

0 comments on commit 8d91ef7

Please sign in to comment.