mlpack  2.0.1
svd_complete_incremental_learning.hpp
Go to the documentation of this file.
1 
14 #ifndef __MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
15 #define __MLPACK_METHODS_AMF_SVD_COMPLETE_INCREMENTAL_LEARNING_HPP
16 
17 #include <mlpack/core.hpp>
18 
19 namespace mlpack
20 {
21 namespace amf
22 {
23 
46 template <class MatType>
48 {
49  public:
59  double kw = 0,
60  double kh = 0)
61  : u(u), kw(kw), kh(kh)
62  {
63  // Nothing to do.
64  }
65 
74  void Initialize(const MatType& /* dataset */, const size_t /* rank */)
75  {
76  // Initialize the current score counters.
77  currentUserIndex = 0;
78  currentItemIndex = 0;
79  }
80 
89  inline void WUpdate(const MatType& V,
90  arma::mat& W,
91  const arma::mat& H)
92  {
93  arma::mat deltaW;
94  deltaW.zeros(1, W.n_cols);
95 
96  // Loop until a non-zero entry is found.
97  while(true)
98  {
99  const double val = V(currentItemIndex, currentUserIndex);
100  // Update feature vector if current entry is non-zero and break the loop.
101  if (val != 0)
102  {
103  deltaW += (val - arma::dot(W.row(currentItemIndex),
104  H.col(currentUserIndex))) * H.col(currentUserIndex).t();
105 
106  // Add regularization.
107  if (kw != 0)
108  deltaW -= kw * W.row(currentItemIndex);
109  break;
110  }
111  }
112 
113  W.row(currentItemIndex) += u * deltaW;
114  }
115 
125  inline void HUpdate(const MatType& V,
126  const arma::mat& W,
127  arma::mat& H)
128  {
129  arma::mat deltaH;
130  deltaH.zeros(H.n_rows, 1);
131 
132  const double val = V(currentItemIndex, currentUserIndex);
133 
134  // Update H matrix based on the non-zero entry found in WUpdate function.
135  deltaH += (val - arma::dot(W.row(currentItemIndex),
136  H.col(currentUserIndex))) * W.row(currentItemIndex).t();
137  // Add regularization.
138  if (kh != 0)
139  deltaH -= kh * H.col(currentUserIndex);
140 
141  // Move on to the next entry.
143  if (currentUserIndex == V.n_rows)
144  {
145  currentUserIndex = 0;
146  currentItemIndex = (currentItemIndex + 1) % V.n_cols;
147  }
148 
149  H.col(currentUserIndex++) += u * deltaH;
150  }
151 
152  private:
154  double u;
156  double kw;
158  double kh;
159 
164 };
165 
168 
170 template<>
172 {
173  public:
175  double kw = 0,
176  double kh = 0)
177  : u(u), kw(kw), kh(kh), it(NULL)
178  {}
179 
181  {
182  delete it;
183  }
184 
185  void Initialize(const arma::sp_mat& dataset, const size_t rank)
186  {
187  (void)rank;
188  n = dataset.n_rows;
189  m = dataset.n_cols;
190 
191  it = new arma::sp_mat::const_iterator(dataset.begin());
192  isStart = true;
193  }
194 
204  inline void WUpdate(const arma::sp_mat& V,
205  arma::mat& W,
206  const arma::mat& H)
207  {
208  if(!isStart) (*it)++;
209  else isStart = false;
210 
211  if(*it == V.end())
212  {
213  delete it;
214  it = new arma::sp_mat::const_iterator(V.begin());
215  }
216 
217  size_t currentUserIndex = it->col();
218  size_t currentItemIndex = it->row();
219 
220  arma::mat deltaW(1, W.n_cols);
221  deltaW.zeros();
222 
223  deltaW += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
224  * arma::trans(H.col(currentUserIndex));
225  if(kw != 0) deltaW -= kw * W.row(currentItemIndex);
226 
227  W.row(currentItemIndex) += u*deltaW;
228  }
229 
239  inline void HUpdate(const arma::sp_mat& V,
240  const arma::mat& W,
241  arma::mat& H)
242  {
243  (void)V;
244 
245  arma::mat deltaH(H.n_rows, 1);
246  deltaH.zeros();
247 
248  size_t currentUserIndex = it->col();
249  size_t currentItemIndex = it->row();
250 
251  deltaH += (**it - arma::dot(W.row(currentItemIndex), H.col(currentUserIndex)))
252  * arma::trans(W.row(currentItemIndex));
253  if(kh != 0) deltaH -= kh * H.col(currentUserIndex);
254 
255  H.col(currentUserIndex) += u * deltaH;
256  }
257 
258  private:
259  double u;
260  double kw;
261  double kh;
262 
263  size_t n;
264  size_t m;
265 
266  arma::sp_mat dummy;
267  arma::sp_mat::const_iterator* it;
268 
269  bool isStart;
270 }; // class SVDCompleteIncrementalLearning
271 
272 } // namespace amf
273 } // namespace mlpack
274 
275 #endif
276 
void HUpdate(const arma::sp_mat &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
This class computes SVD using complete incremental batch learning, as described in the following pape...
void WUpdate(const MatType &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Linear algebra utility functions, generally performed on matrices or vectors.
void HUpdate(const MatType &V, const arma::mat &W, arma::mat &H)
The update rule for the encoding matrix H.
double kw
Regularization parameter for matrix W.
void WUpdate(const arma::sp_mat &V, arma::mat &W, const arma::mat &H)
The update rule for the basis matrix W.
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen docu...
SVDCompleteIncrementalLearning(double u=0.0001, double kw=0, double kh=0)
Initialize the SVDCompleteIncrementalLearning class with the given parameters.
double kh
Regularization parameter for matrix H.
void Initialize(const MatType &, const size_t)
Initialize parameters before factorization.
void Initialize(const arma::sp_mat &dataset, const size_t rank)