Skip to content

Commit 2af5a0c

Browse files
committed
trainsetup tested
1 parent 2eba06b commit 2af5a0c

File tree

6 files changed

+28
-21
lines changed

6 files changed

+28
-21
lines changed

Manifest.toml

+17-3
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,12 @@ uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
55
version = "0.3.2"
66

77
[[AutoGrad]]
8-
deps = ["LinearAlgebra", "Pkg", "Printf", "Random", "SpecialFunctions", "Statistics", "Test"]
9-
git-tree-sha1 = "bd5bab7aa42b9ce1d5a054495b58c8f928227b8e"
8+
deps = ["LinearAlgebra", "Printf", "Random", "SpecialFunctions", "Statistics", "Test", "TimerOutputs"]
9+
git-tree-sha1 = "24a4f2faa56dcf1ebf73194be086ebca5782ee6d"
10+
repo-rev = "dy/1.1"
11+
repo-url = "https://github.com/denizyuret/AutoGrad.jl.git"
1012
uuid = "6710c13c-97f1-543f-91c5-74e8f7d95b35"
11-
version = "1.1.0"
13+
version = "1.1.0+"
1214

1315
[[AxisAlgorithms]]
1416
deps = ["Compat", "WoodburyMatrices"]
@@ -97,6 +99,12 @@ git-tree-sha1 = "47f05d0b7f4999609f92e657147df000818c1f24"
9799
uuid = "150eb455-5306-5404-9cee-2592286d6298"
98100
version = "0.5.0"
99101

102+
[[Crayons]]
103+
deps = ["Pkg", "Test"]
104+
git-tree-sha1 = "3017c662a988bcb8a3f43306a793617c6524d476"
105+
uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f"
106+
version = "1.0.0"
107+
100108
[[CustomUnitRanges]]
101109
deps = ["Pkg", "Test"]
102110
git-tree-sha1 = "0a106457a1831555857e18ac9617279c22fc393b"
@@ -452,6 +460,12 @@ git-tree-sha1 = "58f6f07d3b54a363ec283a8f5fc9fb4ecebde656"
452460
uuid = "06e1c1a7-607b-532d-9fad-de7d9aa2abac"
453461
version = "0.2.3"
454462

463+
[[TimerOutputs]]
464+
deps = ["Crayons", "Printf", "Test", "Unicode"]
465+
git-tree-sha1 = "89a9bd610d6bfd62a7c2b85112762b99b979fe5f"
466+
uuid = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
467+
version = "0.4.0"
468+
455469
[[Tqdm]]
456470
deps = ["Printf"]
457471
git-tree-sha1 = "59dcc6bf73d11e91b6d8ddbe77546a822e827098"

Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[deps]
2+
AutoGrad = "6710c13c-97f1-543f-91c5-74e8f7d95b35"
23
Images = "916415d5-f1e6-5110-898d-aaa5f9f070e0"
3-
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
44
JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
55
Knet = "1902f260-5fb4-5aff-8c31-6271790ab950"
66
Tqdm = "a73858fe-bd08-11e8-065c-3dfff236a6cd"

src/main.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
using Pkg; Pkg.activate("../")
12
#This implementation is very similar to original implementation in https://github.com/stanfordnlp/mac-network
2-
using JSON,Knet,Images
3+
using JSON,Knet,Images,HDF5
34
using Printf,Random,Tqdm
45
include("model.jl")
56
savemodel(filename,m,mrun,o) = Knet.save(filename,"m",m,"mrun",mrun,"o",o)
@@ -152,7 +153,7 @@ function modelrun(M,data,feats,o,Mrun=nothing;train=false)
152153
if train
153154
J = @diff M(questions,batchSizes,xS,xB,xP;answers=answers,p=o[:p],selfattn=o[:selfattn],gating=o[:gating])
154155
flush(Base.stdout)
155-
cnt += value(J)*B; total += B;
156+
cnt += value(J)*B; total += B;
156157
if acc===nothing
157158
acc = atype[];
158159
end

src/model.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ using Knet, Random
22
if !isdefined(Main,:atype)
33
global atype = gpu() < 0 ? Array{Float32} : KnetArray{Float32}
44
end
5-
include("loss2.jl")
5+
include("loss.jl")
66
abstract type Model;end;
77
struct ResNet <: Model; w; end
88
function (M::ResNet)(m,imgurl::String,avgimg;stage=3)
@@ -272,7 +272,7 @@ function (M::MACNetwork)(qs,batchSizes,xS,xB,xP;answers=nothing,p=12,selfattn=fa
272272
if selfattn; push!(cj,ci); push!(mj,mi); end
273273
tap!=nothing && (tap["cnt"]+=1)
274274
end
275-
275+
276276
y = M.output(q,mi;train=train)
277277

278278
if answers==nothing

train.jl

+2-7
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,13 @@
11
include(ARGS[1])
22
include(ARGS[2])
3-
# o=Dict(:h5=>false,:mfile=>nothing,:epochs=>16,
4-
# :lr=>0.0001,:p=>12,:ema=>0.999f0,:batchsize=>32,
5-
# :selfattn=>false,:gating=>false,
6-
# :shuffle=>true,:sorted=>false,:prefix=>string(now())[1:10],
7-
# :vocab_size=>90,:embed_size=>300, :dhome=>"data/", :loadresnet=>false,:d=>512)
83
println("Loading questions ...")
94
trnqstns = getQdata(o[:dhome],"train")
105
valqstns = getQdata(o[:dhome],"val")
116
println("Loading dictionaries ... ")
127
qvoc,avoc,i2w,i2a = getDicts(o[:dhome],"dic")
138
sets = []
14-
push!(sets,shuffle!(miniBatch(trnqstns;B=48,srtd=true)))
15-
push!(sets,shuffle!(miniBatch(valqstns;B=48,srtd=true)))
9+
push!(sets,shuffle!(miniBatch(trnqstns;B=32,srtd=true)))
10+
push!(sets,shuffle!(miniBatch(valqstns;B=32,srtd=true)))
1611
trnqstns=nothing;
1712
valqstns=nothing;
1813

trainsetup.jl

+3-6
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,22 @@
11
using Pkg; Pkg.activate("."); Pkg.instantiate(); #install required packages
22
server = "ai.ku.edu.tr/"
33
if length(ARGS)==0
4-
error("clevr home folder is not specified. Pre-processed data will be downloaded from the servers(70GB)? (yes or no)")
4+
println("clevr home folder is not specified. Pre-processed data will be downloaded from the servers(70GB) if they are not in the data folder? (yes or no)")
55
if readline() == "yes"
66
!isfile("data/train.bin") && download(server*"data/mac-network/train.bin","data/train.bin")
77
!isfile("data/val.bin") && download(server*"data/mac-network/val.bin","data/val.bin")
88
!isfile("data/train.json") && download(server*"data/mac-network/train.json","data/train.json")
99
!isfile("data/val.json") && download(server*"data/mac-network/val.json","data/val.json")
1010
!isfile("data/dic.json") && download(server*"data/mac-network/dic.json","data/dic.json")
11+
exit()
1112
else
1213
error("No data available!")
1314
end
1415
else
1516
CLEVR_HOME = ARGS[1]
17+
println(CLEVR_HOME)
1618
end
1719

18-
println(CLEVR_HOME)
19-
20-
println("Checking dependencies...")
21-
include("requirements.jl")
22-
2320
include("preprocess.jl")
2421
println("Starting to preprocess CLEVR questions @: $(CLEVR_HOME)/questions ")
2522
println("Training and validation questions will be processed...")

0 commit comments

Comments
 (0)