@@ -50,6 +50,7 @@ static vector<shared_ptr<Net<float> > > nets_;
50
50
static vector< vector<int > > gpu_groups_;
51
51
static int now_solver_ = -1 ;
52
52
static bool is_log_inited = false ;
53
+ static bool is_init = 0 ;
53
54
// init_key is generated at the beginning and everytime you call reset
54
55
#ifndef _MSC_VER // We are not using MSVC.
55
56
static double init_key = static_cast <double >(caffe_rng_rand());
@@ -244,6 +245,8 @@ static mxArray* ptr_vec_to_handle_vec(const vector<shared_ptr<T> >& ptr_vec) {
244
245
static void get_solver (MEX_ARGS) {
245
246
mxCHECK (nrhs >= 1 && mxIsChar (prhs[0 ]),
246
247
" Usage: caffe_('get_solver', solver_file, [gpu_id])" );
248
+ mxCHECK (!is_init, " Solver has already init, for now only support single multi-gpu solver." );
249
+ is_init = true ;
247
250
// if ( nrhs == 2 )
248
251
// {
249
252
// mxCHECK(mxIsNumeric(prhs[ 1 ]), "Device_ids only supports double vector~");
@@ -397,19 +400,6 @@ static void solver_solve(MEX_ARGS) {
397
400
mexPrintf (" Solver %d finished.\n " , now_solver_);
398
401
}
399
402
400
- static void do_set_phase (int ID, enum Phase phase_t ){
401
- vector<boost::shared_ptr<Layer<float > > >& layers = syncSolvers_[ID]->solver ()->net ()->getlayers ();
402
- for ( int i = 0 ; i < layers.size (); ++i ){
403
- layers[ i ]->set_phase (phase_t );
404
- }
405
-
406
- for ( int i = 1 ; i < int (gpu_groups_[ ID ].size ()); ++i ){
407
- vector<boost::shared_ptr<Layer<float > > >& layers = syncSolvers_[ ID ]->workers ()[ i ]->solver ()->net ()->getlayers ();
408
- for ( int j = 0 ; j < layers.size (); ++j ){
409
- layers[ j ]->set_phase (phase_t );
410
- }
411
- }
412
- }
413
403
414
404
static void do_set_phase (enum Phase phase_t ){
415
405
int ID = now_solver_;
@@ -508,6 +498,12 @@ static void net_set_phase(MEX_ARGS) {
508
498
mxERROR (" Unknown phase" );
509
499
}
510
500
net->SetPhase (phase);
501
+ vector<boost::shared_ptr<Layer<float > > >& layers = net->getlayers ();
502
+ for ( int i = 0 ; i < layers.size (); ++i ){
503
+ layers[ i ]->set_phase (phase);
504
+ }
505
+
506
+
511
507
mxFree (phase_name);
512
508
}
513
509
@@ -731,11 +727,12 @@ static void reset(MEX_ARGS) {
731
727
gpu_groups_.clear ();
732
728
// Generate new init_key, so that handles created before becomes invalid
733
729
init_key = static_cast <double >(caffe_rng_rand ());
734
- if ( is_log_inited )
730
+ /* if ( is_log_inited )
735
731
{
736
732
is_log_inited = false;
737
733
::google::ShutdownGoogleLogging();
738
- }
734
+ }*/
735
+ is_init = false ;
739
736
}
740
737
741
738
// Usage: caffe_('set_random_seed', random_seed)
@@ -773,8 +770,11 @@ static void init_log(MEX_ARGS) {
773
770
774
771
mxCHECK (nrhs == 1 && mxIsChar (prhs[0 ]),
775
772
" Usage: caffe_('init_log', log_dir)" );
776
- if (is_log_inited)
777
- ::google::ShutdownGoogleLogging ();
773
+ if ( is_log_inited ){
774
+ // ::google::ShutdownGoogleLogging();
775
+ return ;
776
+
777
+ }
778
778
char * log_base_filename = mxArrayToString (prhs[0 ]);
779
779
::google::SetLogDestination (0 , log_base_filename);
780
780
mxFree (log_base_filename);
@@ -823,6 +823,21 @@ static void write_mean(MEX_ARGS) {
823
823
mxFree (mean_proto_file);
824
824
}
825
825
826
+ // Usage: caffe_('solver_test')
827
+ static void solver_test (MEX_ARGS) {
828
+ mxCHECK (nrhs == 1 && mxIsChar (prhs[ 0 ]),
829
+ " Usage: caffe_('read_mean', mean_proto_file)" );
830
+ char * mean_proto_file = mxArrayToString (prhs[ 0 ]);
831
+ mxCHECK_FILE_EXIST (mean_proto_file);
832
+ Blob<float > data_mean;
833
+ BlobProto blob_proto;
834
+ bool result = ReadProtoFromBinaryFile (mean_proto_file, &blob_proto);
835
+ mxCHECK (result, " Could not read your mean file" );
836
+ data_mean.FromProto (blob_proto);
837
+ plhs[ 0 ] = blob_to_mx_mat (&data_mean, DATA);
838
+ mxFree (mean_proto_file);
839
+ }
840
+
826
841
// Usage: caffe_('version')
827
842
static void version (MEX_ARGS) {
828
843
mxCHECK (nrhs == 0 , " Usage: caffe_('version')" );
@@ -847,7 +862,8 @@ static handler_registry handlers[] = {
847
862
{ " solver_get_max_iter" , solver_get_max_iter },
848
863
{ " solver_restore" , solver_restore },
849
864
{ " solver_solve" , solver_solve },
850
- { " solver_step" , solver_step },
865
+ { " solver_step" , solver_step },
866
+ { " solver_test" , solver_test },
851
867
{ " get_net" , get_net },
852
868
{ " net_get_attr" , net_get_attr },
853
869
{ " net_set_phase" , net_set_phase },
0 commit comments