1 #ifndef CAFFE_SGD_SOLVERS_HPP_ 2 #define CAFFE_SGD_SOLVERS_HPP_ 7 #include "caffe/solver.hpp" 15 template <
typename Dtype>
18 explicit SGDSolver(
const SolverParameter& param)
20 explicit SGDSolver(
const string& param_file)
22 virtual inline const char*
type()
const {
return "SGD"; }
24 const vector<shared_ptr<Blob<Dtype> > >& history() {
return history_; }
26 virtual void ApplyUpdate();
27 Dtype GetLearningRate();
31 virtual void Normalize(
int param_id);
32 virtual void Regularize(
int param_id);
33 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
34 virtual void ClipGradients();
35 virtual void SnapshotSolverState(
const string& model_filename);
36 virtual void SnapshotSolverStateToBinaryProto(
const string& model_filename);
37 virtual void SnapshotSolverStateToHDF5(
const string& model_filename);
38 virtual void RestoreSolverStateFromHDF5(
const string& state_file);
39 virtual void RestoreSolverStateFromBinaryProto(
const string& state_file);
44 vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;
49 template <
typename Dtype>
56 virtual inline const char*
type()
const {
return "Nesterov"; }
59 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
64 template <
typename Dtype>
71 virtual inline const char*
type()
const {
return "AdaGrad"; }
74 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
75 void constructor_sanity_check() {
76 CHECK_EQ(0, this->param_.momentum())
77 <<
"Momentum cannot be used with AdaGrad.";
84 template <
typename Dtype>
91 virtual inline const char*
type()
const {
return "RMSProp"; }
94 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
95 void constructor_sanity_check() {
96 CHECK_EQ(0, this->param_.momentum())
97 <<
"Momentum cannot be used with RMSProp.";
98 CHECK_GE(this->param_.rms_decay(), 0)
99 <<
"rms_decay should lie between 0 and 1.";
100 CHECK_LT(this->param_.rms_decay(), 1)
101 <<
"rms_decay should lie between 0 and 1.";
107 template <
typename Dtype>
114 virtual inline const char*
type()
const {
return "AdaDelta"; }
117 void AdaDeltaPreSolve();
118 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
131 template <
typename Dtype>
134 explicit AdamSolver(
const SolverParameter& param)
138 virtual inline const char*
type()
const {
return "Adam"; }
142 virtual void ComputeUpdateValue(
int param_id, Dtype rate);
149 #endif // CAFFE_SGD_SOLVERS_HPP_ virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:56
A layer factory that allows one to register layers. During runtime, registered layers can be called b...
Definition: blob.hpp:14
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:22
Optimizes the parameters of a Net using stochastic gradient descent (SGD) with momentum.
Definition: sgd_solvers.hpp:16
An interface for classes that perform optimization on Nets.
Definition: solver.hpp:42
AdamSolver, an algorithm for first-order gradient-based optimization of stochastic objective function...
Definition: sgd_solvers.hpp:132
Definition: sgd_solvers.hpp:108
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:114
Definition: sgd_solvers.hpp:50
Definition: sgd_solvers.hpp:85
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:91
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:71
Definition: sgd_solvers.hpp:65
virtual const char * type() const
Returns the solver type.
Definition: sgd_solvers.hpp:138