mlpack  2.0.1
mean_shift.hpp
Go to the documentation of this file.
1 
15 #ifndef __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
16 #define __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
17 
18 #include <mlpack/core.hpp>
22 #include <boost/utility.hpp>
23 
24 namespace mlpack {
25 namespace meanshift {
26 
48 template<bool UseKernel = false,
49  typename KernelType = kernel::GaussianKernel,
50  typename MatType = arma::mat>
51 class MeanShift
52 {
53  public:
65  MeanShift(const double radius = 0,
66  const size_t maxIterations = 1000,
67  const KernelType kernel = KernelType());
68 
75  double EstimateRadius(const MatType& data, const double ratio = 0.2);
76 
86  void Cluster(const MatType& data,
87  arma::Col<size_t>& assignments,
88  arma::mat& centroids,
89  bool useSeeds = true);
90 
92  size_t MaxIterations() const { return maxIterations; }
94  size_t& MaxIterations() { return maxIterations; }
95 
97  double Radius() const { return radius; }
99  void Radius(double radius);
100 
102  const KernelType& Kernel() const { return kernel; }
104  KernelType& Kernel() { return kernel; }
105 
106  private:
120  void GenSeeds(const MatType& data,
121  const double binSize,
122  const int minFreq,
123  MatType& seeds);
124 
133  template<bool ApplyKernel = UseKernel>
134  typename std::enable_if<ApplyKernel, bool>::type
135  CalculateCentroid(const MatType& data,
136  const std::vector<size_t>& neighbors,
137  const std::vector<double>& distances,
138  arma::colvec& centroid);
139 
148  template<bool ApplyKernel = UseKernel>
149  typename std::enable_if<!ApplyKernel, bool>::type
150  CalculateCentroid(const MatType& data,
151  const std::vector<size_t>& neighbors,
152  const std::vector<double>&, /*unused*/
153  arma::colvec& centroid);
154 
160  double radius;
161 
164 
166  KernelType kernel;
167 };
168 
169 } // namespace meanshift
170 } // namespace mlpack
171 
172 // Include implementation.
173 #include "mean_shift_impl.hpp"
174 
175 #endif // __MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
const KernelType & Kernel() const
Get the kernel.
Definition: mean_shift.hpp:102
double Radius() const
Get the radius.
Definition: mean_shift.hpp:97
double EstimateRadius(const MatType &data, const double ratio=0.2)
Estimate a suitable radius from the given dataset.
Linear algebra utility functions, generally performed on matrices or vectors.
KernelType & Kernel()
Modify the kernel.
Definition: mean_shift.hpp:104
This class implements mean shift clustering.
Definition: mean_shift.hpp:51
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters with which mean shift will be run.
void Cluster(const MatType &data, arma::Col< size_t > &assignments, arma::mat &centroids, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids...
std::enable_if< ApplyKernel, bool >::type CalculateCentroid(const MatType &data, const std::vector< size_t > &neighbors, const std::vector< double > &distances, arma::colvec &centroid)
Use the kernel to calculate a new centroid from the dataset and the valid neighbors.
size_t MaxIterations() const
Get the maximum number of iterations.
Definition: mean_shift.hpp:92
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen docu...
size_t & MaxIterations()
Set the maximum number of iterations.
Definition: mean_shift.hpp:94
double radius
If the distance between two centroids is less than the radius, one of them will be removed.
Definition: mean_shift.hpp:160
KernelType kernel
Instantiated kernel.
Definition: mean_shift.hpp:166
The standard Gaussian kernel.
size_t maxIterations
Maximum number of iterations before giving up.
Definition: mean_shift.hpp:163
void GenSeeds(const MatType &data, const double binSize, const int minFreq, MatType &seeds)
To speed up clustering, we can generate some seeds from the dataset and use them as initial centroids rather than a...