Skip to content

Commit 871f620

Browse files
committed
minor cleanups
1 parent 961d9cb commit 871f620

File tree

5 files changed

+8
-7
lines changed

5 files changed

+8
-7
lines changed

JavaPostevanka/GPTMul/GPTMul.java

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public GPTMul(Module model, Tokenizer tokenizer) {
1515
}
1616

1717
public int mul(int a, int b) {
18-
String prompt1 = String.format("%d*%d=0", a, b);
18+
String prompt1 = String.format("%d*%d=", a, b);
1919
Matrix input1 = tokenizer.transform(prompt1);
2020
Matrix output1 = model.forward(new Matrix[] {input1})[0];
2121
int new_token1 = output1.rowArgMax()[3];
@@ -29,6 +29,5 @@ public int mul(int a, int b) {
2929
String out = tokenizer.transformInverse(pred);
3030
return Integer.parseInt(out);
3131
}
32-
33-
32+
3433
}

JavaPostevanka/Matrix/Matrix.java

+3-3
Original file line numberDiff line numberDiff line change
@@ -309,15 +309,15 @@ public Matrix mul(Matrix other) {
309309
}
310310

311311
public Matrix mul(float other) {
312-
return mul(new Matrix(new float[][] {{other}}));
312+
return applyUnary((x) -> x * other);
313313
}
314314

315315
public Matrix div(Matrix other) {
316-
return applyBinary((x, y) -> x/y, other);
316+
return applyBinary((x, y) -> x / y, other);
317317
}
318318

319319
public Matrix div(float other) {
320-
return div(new Matrix(new float[][] {{other}}));
320+
return applyUnary((x) -> x / other);
321321
}
322322

323323
public Matrix reciprocal(float other) {

JavaPostevanka/NN/Layer/SoftMax.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public class SoftMax extends Module {
1010
@Override
1111
public Matrix[] forward(Matrix[] inputs) {
1212
Matrix X = inputs[0];
13-
Matrix expX = (X.add(X.rowMax().mul(-1))).exp();
13+
Matrix expX = X.sub(X.rowMax()).exp();
1414
activation = expX.div(expX.rowSum());
1515
return new Matrix[] {activation};
1616
}

JavaPostevanka/NN/Optim/SGD.java

+1
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,5 @@ public void step() {
1818
p.data = p.data.sub(p.grad.mul(lr));
1919
}
2020
}
21+
2122
}

JavaPostevanka/Postevanka.java

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import JavaPostevanka.Trainer.Trainer;
1111

1212
public class Postevanka {
13+
1314
public static void main(String[] args) {
1415
Random rng = new Random(1337);
1516

0 commit comments

Comments
 (0)