mlpack  2.0.1
validation_RMSE_termination.hpp
Go to the documentation of this file.
1 
14 #ifndef _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
15 #define _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
16 
17 #include <mlpack/core.hpp>
18 
19 namespace mlpack
20 {
21 namespace amf
22 {
23 
38 template <class MatType>
40 {
41  public:
52  size_t num_test_points,
53  double tolerance = 1e-5,
54  size_t maxIterations = 10000,
55  size_t reverseStepTolerance = 3)
58  num_test_points(num_test_points),
60  {
61  size_t n = V.n_rows;
62  size_t m = V.n_cols;
63 
64  // initialize validation set matrix
65  test_points.zeros(num_test_points, 3);
66 
67  // fill validation set matrix with random chosen entries
68  for(size_t i = 0; i < num_test_points; i++)
69  {
70  double t_val;
71  size_t t_row;
72  size_t t_col;
73 
74  // pick a random non-zero entry
75  do
76  {
77  t_row = rand() % n;
78  t_col = rand() % m;
79  } while((t_val = V(t_row, t_col)) == 0);
80 
81  // add the entry to the validation set
82  test_points(i, 0) = t_row;
83  test_points(i, 1) = t_col;
84  test_points(i, 2) = t_val;
85 
86  // nullify the added entry from data matrix (training set)
87  V(t_row, t_col) = 0;
88  }
89  }
90 
96  void Initialize(const MatType& /* V */)
97  {
98  iteration = 1;
99 
100  rmse = DBL_MAX;
101  rmseOld = DBL_MAX;
102 
103  c_index = 0;
104  c_indexOld = 0;
105 
106  reverseStepCount = 0;
107  isCopy = false;
108  }
109 
116  bool IsConverged(arma::mat& W, arma::mat& H)
117  {
118  arma::mat WH;
119 
120  WH = W * H;
121 
122  // compute validation RMSE
123  if (iteration != 0)
124  {
125  rmseOld = rmse;
126  rmse = 0;
127  for(size_t i = 0; i < num_test_points; i++)
128  {
129  size_t t_row = test_points(i, 0);
130  size_t t_col = test_points(i, 1);
131  double t_val = test_points(i, 2);
132  double temp = (t_val - WH(t_row, t_col));
133  temp *= temp;
134  rmse += temp;
135  }
137  rmse = sqrt(rmse);
138  }
139 
140  // increment iteration count
141  iteration++;
142 
143  // if RMSE tolerance is not satisfied
144  if((rmseOld - rmse) / rmseOld < tolerance && iteration > 4)
145  {
146  // check if this is a first of successive drops
147  if(reverseStepCount == 0 && isCopy == false)
148  {
149  // store a copy of W and H matrix
150  isCopy = true;
151  this->W = W;
152  this->H = H;
153  // store residue values
155  c_index = rmse;
156  }
157  // increase successive drop count
159  }
160  // if tolerance is satisfied
161  else
162  {
163  // initialize successive drop count
164  reverseStepCount = 0;
165  // if residue is droped below minimum scrap stored values
166  if(rmse <= c_indexOld && isCopy == true)
167  {
168  isCopy = false;
169  }
170  }
171 
172  // check if termination criterion is met
174  {
175  // if stored values are present replace them with current value as they
176  // represent the minimum residue point
177  if(isCopy)
178  {
179  W = this->W;
180  H = this->H;
181  rmse = c_index;
182  }
183  return true;
184  }
185  else return false;
186  }
187 
189  const double& Index() const { return rmse; }
190 
192  const size_t& Iteration() const { return iteration; }
193 
195  const size_t& NumTestPoints() const { return num_test_points; }
196 
198  const size_t& MaxIterations() const { return maxIterations; }
199  size_t& MaxIterations() { return maxIterations; }
200 
202  const double& Tolerance() const { return tolerance; }
203  double& Tolerance() { return tolerance; }
204 
205  private:
207  double tolerance;
212 
214  size_t iteration;
215 
217  arma::mat test_points;
218 
220  double rmseOld;
221  double rmse;
222 
227 
230  bool isCopy;
231 
233  arma::mat W;
234  arma::mat H;
235  double c_indexOld;
236  double c_index;
237 }; // class ValidationRMSETermination
238 
239 } // namespace amf
240 } // namespace mlpack
241 
242 
243 #endif // _MLPACK_METHODS_AMF_VALIDATIONRMSETERMINATION_HPP_INCLUDED
size_t reverseStepTolerance
Tolerance on the number of successive residue drops.
bool isCopy
Indicates whether a stored copy of W and H corresponding to the minimum-residue point is available.
const double & Tolerance() const
Access tolerance value.
Linear algebra utility functions, generally performed on matrices or vectors.
size_t num_test_points
number of validation test points
const size_t & Iteration() const
Get current iteration count.
This class implements validation termination policy based on RMSE index.
const size_t & NumTestPoints() const
Get number of validation points.
const double & Index() const
Get the current value of the residue (validation RMSE).
arma::mat W
Variables that store the state (W, H and residue values) at the minimum-residue point.
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
ValidationRMSETermination(MatType &V, size_t num_test_points, double tolerance=1e-5, size_t maxIterations=10000, size_t reverseStepTolerance=3)
Creates a validation set according to the given parameters and nullifies this set in the data matrix (training set).
const size_t & MaxIterations() const
Access upper limit of iteration count.
void Initialize(const MatType &)
Initializes the termination policy before starting the factorization.
bool IsConverged(arma::mat &W, arma::mat &H)
Check if the termination criterion is met.