Caffe
solver.hpp
1 #ifndef CAFFE_SOLVER_HPP_
2 #define CAFFE_SOLVER_HPP_
3 #include <boost/function.hpp>
4 #include <string>
5 #include <vector>
6 
7 #include "caffe/net.hpp"
8 #include "caffe/solver_factory.hpp"
9 
10 namespace caffe {
11 
// Actions that a client of the Solver can request through the solver's
// action callback (see the callback type declared right after this namespace).
// The enum is wrapped in a namespace so enumerators are spelled
// SolverAction::NONE, SolverAction::STOP, SolverAction::SNAPSHOT.
namespace SolverAction {

enum Enum {
  // Take no special action.
  NONE = 0,
  // Stop training. snapshot_after_train controls whether a snapshot is
  // created.
  STOP = 1,
  // Take a snapshot, and keep training.
  SNAPSHOT = 2
};

}  // namespace SolverAction
28 
32 typedef boost::function<SolverAction::Enum()> ActionCallback;
33 
40 template <typename Dtype>
41 class Solver {
42  public:
43  explicit Solver(const SolverParameter& param,
44  const Solver* root_solver = NULL);
45  explicit Solver(const string& param_file, const Solver* root_solver = NULL);
46  void Init(const SolverParameter& param);
47  void InitTrainNet();
48  void InitTestNets();
49 
50  // Client of the Solver optionally may call this in order to set the function
51  // that the solver uses to see what action it should take (e.g. snapshot or
52  // exit training early).
53  void SetActionFunction(ActionCallback func);
54  SolverAction::Enum GetRequestedAction();
55  // The main entry of the solver function. In default, iter will be zero. Pass
56  // in a non-zero iter number to resume training for a pre-trained net.
57  virtual void Solve(const char* resume_file = NULL);
58  inline void Solve(const string resume_file) { Solve(resume_file.c_str()); }
59  void Step(int iters);
60  // The Restore method simply dispatches to one of the
61  // RestoreSolverStateFrom___ protected methods. You should implement these
62  // methods to restore the state from the appropriate snapshot type.
63  void Restore(const char* resume_file);
64  // The Solver::Snapshot function implements the basic snapshotting utility
65  // that stores the learned net. You should implement the SnapshotSolverState()
66  // function that produces a SolverState protocol buffer that needs to be
67  // written to disk together with the learned net.
68  void Snapshot();
69  virtual ~Solver() {}
70  inline const SolverParameter& param() const { return param_; }
71  inline shared_ptr<Net<Dtype> > net() { return net_; }
72  inline const vector<shared_ptr<Net<Dtype> > >& test_nets() {
73  return test_nets_;
74  }
75  int iter() { return iter_; }
76 
77  // Invoked at specific points during an iteration
78  class Callback {
79  protected:
80  virtual void on_start() = 0;
81  virtual void on_gradients_ready() = 0;
82 
83  template <typename T>
84  friend class Solver;
85  };
86  const vector<Callback*>& callbacks() const { return callbacks_; }
87  void add_callback(Callback* value) {
88  callbacks_.push_back(value);
89  }
90 
91  void CheckSnapshotWritePermissions();
95  virtual inline const char* type() const { return ""; }
96 
97  protected:
98  // Make and apply the update value for the current iteration.
99  virtual void ApplyUpdate() = 0;
100  string SnapshotFilename(const string extension);
101  string SnapshotToBinaryProto();
102  string SnapshotToHDF5();
103  // The test routine
104  void TestAll();
105  void Test(const int test_net_id = 0);
106  virtual void SnapshotSolverState(const string& model_filename) = 0;
107  virtual void RestoreSolverStateFromHDF5(const string& state_file) = 0;
108  virtual void RestoreSolverStateFromBinaryProto(const string& state_file) = 0;
109  void DisplayOutputBlobs(const int net_id);
110  void UpdateSmoothedLoss(Dtype loss, int start_iter, int average_loss);
111 
112  SolverParameter param_;
113  int iter_;
114  int current_step_;
115  shared_ptr<Net<Dtype> > net_;
116  vector<shared_ptr<Net<Dtype> > > test_nets_;
117  vector<Callback*> callbacks_;
118  vector<Dtype> losses_;
119  Dtype smoothed_loss_;
120 
121  // The root solver that holds root nets (actually containing shared layers)
122  // in data parallelism
123  const Solver* const root_solver_;
124 
125  // A function that can be set by a client of the Solver to provide indication
126  // that it wants a snapshot saved and/or to exit early.
127  ActionCallback action_request_function_;
128 
129  // True iff a request to stop early was received.
130  bool requested_early_exit_;
131 
132  DISABLE_COPY_AND_ASSIGN(Solver);
133 };
134 
139 template <typename Dtype>
140 class WorkerSolver : public Solver<Dtype> {
141  public:
142  explicit WorkerSolver(const SolverParameter& param,
143  const Solver<Dtype>* root_solver = NULL)
144  : Solver<Dtype>(param, root_solver) {}
145 
146  protected:
147  void ApplyUpdate() {}
148  void SnapshotSolverState(const string& model_filename) {
149  LOG(FATAL) << "Should not be called on worker solver.";
150  }
151  void RestoreSolverStateFromBinaryProto(const string& state_file) {
152  LOG(FATAL) << "Should not be called on worker solver.";
153  }
154  void RestoreSolverStateFromHDF5(const string& state_file) {
155  LOG(FATAL) << "Should not be called on worker solver.";
156  }
157 };
158 
159 } // namespace caffe
160 
161 #endif // CAFFE_SOLVER_HPP_
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
Solver that only computes gradients, used as worker for multi-GPU training.
Definition: solver.hpp:140
Definition: solver.hpp:78
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:41
virtual const char * type() const
Returns the solver type.
Definition: solver.hpp:95
boost::function< SolverAction::Enum()> ActionCallback
Type of a function that returns a Solver Action enumeration.
Definition: solver.hpp:32