Skip to content

Commit f1a3b67

Browse files
committed
Least Square Discriminant
1 parent f8df544 commit f1a3b67

File tree

5 files changed

+63
-32
lines changed

5 files changed

+63
-32
lines changed

.cache

945 Bytes
Binary file not shown.

src/main/scala/App.scala

+20-1
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,29 @@ object App {
1010
var classData = ClassReader.trainRead("./data/classes.csv")
1111
var testData = ClassReader.testRead("./data/test_cat.csv")
1212

13+
/*
14+
val ld = new LeastSquareDiscriminant(2, 2)
15+
val W = ld.train(classData)
16+
println(W)
17+
val a = - (W(0,1) - W(1,1)) / (W(0,2) - W(1,2))
18+
val b = - (W(0,0) - W(1,0)) / (W(0,2) - W(1,2))
19+
var estimate_cat = new PrintWriter("./data/estimate_cat.csv")
20+
testData.foreach {
21+
x => {
22+
val y = a * x + b
23+
estimate_cat.println(x + "," + y)
24+
}
25+
}
26+
estimate_cat.flush()
27+
estimate_cat.close()
28+
*/
29+
30+
31+
32+
//* For Perceptron discriminant
1333
val ld = new LinearDiscriminant(2)
1434
val w = ld.train(classData)
1535
println(w)
16-
1736
val a = -w(1)/w(2)
1837
val b = -w(0)/w(2)
1938
var estimate_cat = new PrintWriter("./data/estimate_cat.csv")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package com.PhysicsEngine.cpistats
2+
import breeze.linalg._
3+
4+
class LeastSquareDiscriminant(D: Int, K: Int) {
5+
private var _w = DenseVector.ones[Double](D+1)
6+
7+
def train(plist: List[DenseVector[Double]]): DenseMatrix[Double] = {
8+
val n = plist.length
9+
val T = DenseMatrix.zeros[Double](n, K)
10+
val X = DenseMatrix.zeros[Double](n, D+1)
11+
12+
for (i <- 0 until n) {
13+
val x = plist(i)
14+
X(i, 0) = 1.0
15+
for (j <- 1 until D+1) {
16+
X(i, j) = x(j-1)
17+
}
18+
}
19+
20+
for (i <- 0 until n) {
21+
val t = plist(i)
22+
if (t(2) > 0.0) {
23+
T(i, 0) = 1
24+
T(i, 1) = 0
25+
} else {
26+
T(i, 0) = 0
27+
T(i, 1) = 1
28+
}
29+
}
30+
31+
//println(inv(X.t * X) * X.t * T)
32+
val W = inv(X.t * X) * X.t * T
33+
W.t
34+
}
35+
}

src/main/scala/LinearDiscriminant.scala

+1-27
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import java.io._
44

55
class LinearDiscriminant(val D: Int) {
66
private var _w = DenseVector.ones[Double](D+1)
7-
_w = _w * -5.0
87
private var _eta = 0.5
98
def train(plist: List[DenseVector[Double]]): DenseVector[Double] = {
109
var correct = 0
@@ -13,17 +12,6 @@ class LinearDiscriminant(val D: Int) {
1312
step += 1
1413
correct = 0
1514

16-
/*
17-
val index = Math.floor(Math.random()*plist.length).toInt
18-
val point = plist(index)
19-
val phi = DenseVector(1.0, point(0), point(1))
20-
if ((_w.t * phi).apply(0) * point(2) < 0) {
21-
_w += phi * point(2) * _eta
22-
}
23-
*
24-
*/
25-
26-
2715
plist.foreach {
2816
point => {
2917
val phi = DenseVector(1.0, point(0), point(1))
@@ -41,22 +29,8 @@ class LinearDiscriminant(val D: Int) {
4129
}
4230
}
4331
}
44-
/*
45-
plist.foreach {
46-
point => {
47-
val phi = DenseVector(1, point(0), point(1))
48-
if ((_w.t * phi).apply(0) * point(2) < 0) {
49-
_w += phi * point(2) * _eta
50-
}
51-
else {
52-
correct += 1
53-
}
54-
}
55-
}
56-
*
57-
*/
5832
}
59-
println(step)
33+
println("Step: " + step)
6034
_w
6135
}
6236
}

tools/genTrainForCat.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,21 @@
66
from pylab import *
77

88
K = 2
9-
N = 200
9+
N = 50
1010

1111
if __name__ == "__main__":
1212
cls1 = []
1313
cls2 = []
1414

15-
mean1 = [-0.5, 0.7]
16-
mean2 = [0.8, -0.8]
17-
cov = [[0.2,0.1], [0.1,0.2]]
15+
16+
mean1 = [-1.2, 1.4]
17+
mean2 = [-0.5, 0.5]
18+
mean3 = [1.6, -1.4]
19+
cov = [[0.2,0.1], [0.1,0.1]]
1820

1921
cls1.extend(np.random.multivariate_normal(mean1, cov, N/2))
2022
cls2.extend(np.random.multivariate_normal(mean2, cov, N/2))
23+
cls2.extend(np.random.multivariate_normal(mean3, cov, N/2))
2124

2225

2326
f = open("/Users/sasakiumi/MyWorks/cpi-stats/data/classes.csv", "w")

0 commit comments

Comments
 (0)