Skip to content

Commit 9e8bfdf

Browse files
author
liuyu
committed
Add some interface n solver
1 parent 94b828f commit 9e8bfdf

File tree

3 files changed

+54
-11
lines changed

3 files changed

+54
-11
lines changed

clip

Whitespace-only changes.

matlab/+caffe/Solver.m

+44
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,50 @@
4444
function iter = iter(self)
4545
iter = caffe_('solver_get_iter', self.hSolver_self);
4646
end
47+
% add new interfaces for solver
48+
function use_caffemodel(self, model_dir)
49+
for t = 1 : length(self.gpu_ids)
50+
self.nets{t}.copy_from(model_dir);
51+
end
52+
end
53+
function set_input_data(self, inputs)
54+
for t = 1 : length(self.gpu_ids)
55+
input_data = inputs{t};
56+
self.nets{t}.set_input_data(input_data);
57+
end
58+
end
59+
function output = get_output(self)
60+
output = cell(length(self.gpu_ids), 1);
61+
for t = 1 : length(self.gpu_ids)
62+
output{t} = self.nets{t}.get_output();
63+
end
64+
end
65+
function set_phase(self, phase)
66+
for t = 1 : length(self.gpu_ids)
67+
self.nets{t}.set_phase(phase);
68+
end
69+
end
70+
function reshape_as_input(self, inputs)
71+
for t = 1 : length(self.gpu_ids)
72+
input_data = inputs{t};
73+
self.nets{t}.reshape_as_input(input_data);
74+
end
75+
end
76+
function forward(self, inputs)
77+
for t = 1 : length(self.gpu_ids)
78+
input_data = inputs{t};
79+
self.nets{t}.set_input_data(input_data);
80+
end
81+
caffe_('solver_test');
82+
end
83+
function forward_prefilled(self)
84+
caffe_('solver_test');
85+
end
86+
function snapshot(self, path)
87+
self.nets{1}.save(path);
88+
end
89+
% add done
90+
4791
function max_iter = max_iter(self)
4892
max_iter = caffe_('solver_get_max_iter', self.hSolver_self);
4993
end

matlab/+caffe/private/caffe_.cpp

+10-11
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ static void solver_step(MEX_ARGS) {
442442
#pragma omp parallel num_threads(int(gpu_groups_[now_solver_].size()))
443443
{
444444
int ID = omp_get_thread_num();
445+
Caffe::SetDevice(gpu_groups_[ now_solver_ ][ ID ]);
445446
if ( ID == 0 ){
446447
#ifdef DEBUG
447448
LOG(INFO) << "Card " << ID << " at point 0\n";
@@ -825,17 +826,15 @@ static void write_mean(MEX_ARGS) {
825826

826827
// Usage: caffe_('solver_test')
827828
static void solver_test(MEX_ARGS) {
828-
mxCHECK(nrhs == 1 && mxIsChar(prhs[ 0 ]),
829-
"Usage: caffe_('read_mean', mean_proto_file)");
830-
char* mean_proto_file = mxArrayToString(prhs[ 0 ]);
831-
mxCHECK_FILE_EXIST(mean_proto_file);
832-
Blob<float> data_mean;
833-
BlobProto blob_proto;
834-
bool result = ReadProtoFromBinaryFile(mean_proto_file, &blob_proto);
835-
mxCHECK(result, "Could not read your mean file");
836-
data_mean.FromProto(blob_proto);
837-
plhs[ 0 ] = blob_to_mx_mat(&data_mean, DATA);
838-
mxFree(mean_proto_file);
829+
#pragma omp parallel num_threads(int(gpu_groups_[ now_solver_ ].size()))
830+
{
831+
int ID = omp_get_thread_num();
832+
Caffe::SetDevice(gpu_groups_[ now_solver_ ][ ID ]);
833+
if ( ID == 0 )
834+
syncSolvers_[ now_solver_ ]->solver()->net()->ForwardPrefilled();
835+
else
836+
syncSolvers_[ now_solver_ ]->workers()[ ID ]->solver()->net()->ForwardPrefilled();
837+
}
839838
}
840839

841840
// Usage: caffe_('version')

0 commit comments

Comments
 (0)