Skip to content

Commit 1126df0

Browse files
committed
bias-variance
1 parent 1d8205d commit 1126df0

9 files changed

+568
-540
lines changed

.cache

128 Bytes
Binary file not shown.

.gitignore

+3-1
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,6 @@ project/boot/
1010
project/plugins/project/
1111

1212
# Scala-IDE specific
13-
.scala_dependencies
13+
.scala_dependencies
14+
15+
*.csv

estimate.csv

+500-500
Large diffs are not rendered by default.

genTrain.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import os
4+
import sys
5+
import numpy as np
6+
7+
if __name__ == "__main__":
8+
for k in range(0, 20):
9+
xlist = np.linspace(0, 1, 25)
10+
tlist = np.sin(2*np.pi*xlist) + np.random.normal(0, 0.2, xlist.size)
11+
f = open("train" + str(k) + ".csv", "w")
12+
for i in range(0, len(xlist)):
13+
f.write(str(xlist[i]) + "," + str(tlist[i]) + "\n")
14+
15+
f.close()

results/lambda0.0000000001.png

78.9 KB
Loading

results/lambda1.1.png

47.5 KB
Loading

src/main/scala/App.scala

+19-19
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,11 @@ import com.github.tototoshi.csv._
77

88
object App {
99
def main(args: Array[String]) {
10-
//val data = EmployeeData.read()
10+
/*
1111
val train = DataReader.trainRead("train.csv")
1212
13-
//val data = List[Pair[Double,Double]]((0.07218849, 1.47822651), (1.61854596, 11.57308266), (2.74441742, 23.53791146), (3.36337206, 36.71027063), (4.96067527, 40.60451491), (5.13515737, 58.41530339), ( 6.43154354, 64.22937428), (7.68320999, 79.51190066), (8.3692042, 91.04207353), ( 9.85696038, 97.7546571))
14-
//val data = List[Pair[Double,Double]]((1.0,1.0),(2.0,2.0),(3.0,3.0))
15-
16-
//val data = List[Pair[Double,Double]]((1.0, Math.sin(1.0)), (2.0, Math.sin(2.0)), (3.0, Math.sin(3.0)))
17-
val wlist = LeastSquares.estimate(train, 0.0)
18-
val test = DataReader.testRead("test.csv")
19-
20-
2113
val estimate = new PrintWriter("estimate.csv")
14+
2215
val sums = DenseMatrix.zeros[Double](BayesEstimation.M+1, BayesEstimation.M+1)
2316
for(i <- 0 to train.length-1) {
2417
sums += BayesEstimation.phi(train(i)._1) * BayesEstimation.phi(train(i)._1).t
@@ -36,18 +29,25 @@ object App {
3629
estimate.println(x + "," + m + "," + u + "," + l)
3730
}
3831
}
39-
40-
/*
41-
test.foreach {
42-
x => {
43-
estimate.println(x + "," + LinearEquation.y(x, wlist))
44-
}
45-
}
4632
*
4733
*/
48-
49-
estimate.flush()
50-
estimate.close()
34+
35+
for(i <- 0 to 19) {
36+
val tfilename = "train" + i + ".csv"
37+
val efilename = "estimate" + i + ".csv"
38+
val train = DataReader.trainRead(tfilename)
39+
val wlist = LeastSquares.estimate(train, 10.0)
40+
val test = DataReader.testRead("test.csv")
41+
val estimate = new PrintWriter(efilename)
42+
43+
test.foreach {
44+
x => {
45+
estimate.println(x + "," + LinearEquation.y(x, wlist))
46+
}
47+
}
48+
estimate.flush()
49+
estimate.close()
50+
}
5151

5252

5353
}

train.csv

-20
Original file line numberDiff line numberDiff line change
@@ -1,20 +0,0 @@
1-
0.0,-0.3064923
2-
0.05263158,0.57885623
3-
0.10526316,0.62210423
4-
0.15789474,0.68922288
5-
0.21052632,1.10374508
6-
0.26315789,0.98969203
7-
0.31578947,1.30612747
8-
0.36842105,1.00316323
9-
0.42105263,0.6029487
10-
0.47368421,0.04878376
11-
0.52631579,-0.49498364
12-
0.57894737,-0.87668114
13-
0.63157895,-0.65494625
14-
0.68421053,-1.10763935
15-
0.73684211,-1.13152455
16-
0.78947368,-0.72751465
17-
0.84210526,-0.98385029
18-
0.89473684,-0.755865
19-
0.94736842,-0.47920383
20-
1.0,-0.0789461

vizMulti.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import os
4+
import sys
5+
import csv
6+
from numpy import *
7+
from pylab import *
8+
9+
if __name__ == "__main__":
10+
for i in range(0, 19):
11+
estimate = open('estimate' + str(i) + '.csv', 'rb')
12+
13+
estreader = csv.reader(estimate)
14+
estxs = []
15+
estts = []
16+
for row in estreader:
17+
estxs.append(row[0])
18+
estts.append(row[1])
19+
estxs = array(estxs)
20+
estts = array(estts)
21+
22+
23+
# plot(trainxs, traints, 'bo')
24+
plot(estxs, estts, 'g-')
25+
xlim(0.0, 1.0)
26+
ylim(-1.5, 1.5)
27+
estimate.close()
28+
29+
show()
30+
31+

0 commit comments

Comments
 (0)