6 #ifndef TAPKEE_VPTREE_H_ 7 #define TAPKEE_VPTREE_H_ 20 namespace tapkee_internal
23 template<
class Type,
class RandomAccessIterator,
class DistanceCallback>
26 template<
class RandomAccessIterator,
class DistanceCallback>
30 const RandomAccessIterator
item;
32 callback(c), item(i) {}
33 inline bool operator()(
const RandomAccessIterator& a,
const RandomAccessIterator& b)
42 template<
class RandomAccessIterator,
class DistanceCallback>
46 const RandomAccessIterator& a,
const RandomAccessIterator& b)
54 template<
class RandomAccessIterator,
class DistanceCallback>
58 const RandomAccessIterator& a,
const RandomAccessIterator& b)
64 template<
class RandomAccessIterator,
class DistanceCallback>
71 begin(b), items(),
callback(c), tau(0.0), root(0)
74 for (RandomAccessIterator i=b; i!=e; ++i)
76 root = buildFromPoints(0, items.size());
86 std::vector<IndexType>
search(
const RandomAccessIterator& target,
int k)
88 std::vector<IndexType> results;
90 std::priority_queue<HeapItem> heap;
93 tau = std::numeric_limits<double>::max();
96 search(root, target, k, heap);
100 while(!heap.empty()) {
101 results.push_back(items[heap.top().index]-begin);
113 std::vector<RandomAccessIterator>
items;
125 index(0), threshold(0.),
162 if (upper - lower > 1)
165 std::swap(items[lower], items[i]);
167 int median = (upper + lower) / 2;
168 std::nth_element(items.begin() + lower + 1, items.begin() + median, items.begin() + upper,
171 node->
threshold = callback.distance(items[lower], items[median]);
173 node->
left = buildFromPoints(lower + 1, median);
174 node->
right = buildFromPoints(median, upper);
180 void search(
Node*
node,
const RandomAccessIterator& target,
int k, std::priority_queue<HeapItem>& heap)
185 double distance = callback.distance(items[node->
index], target);
189 if (heap.size() ==
static_cast<size_t>(k))
194 if (heap.size() ==
static_cast<size_t>(k))
195 tau = heap.top().distance;
198 if (node->
left == NULL && node->
right == NULL)
203 if (distance < node->threshold)
206 search(node->
left, target, k, heap);
209 search(node->
right, target, k, heap);
214 search(node->
right, target, k, heap);
217 search(node->
left, target, k, heap);
bool operator<(const HeapItem &o) const
ScalarType distance(Callback &cb, const CoverTreePoint< RandomAccessIterator > &l, const CoverTreePoint< RandomAccessIterator > &r, ScalarType upper_bound)
HeapItem(int i, double d)
DistanceCallback callback
bool operator()(DistanceCallback &callback, const RandomAccessIterator &item, const RandomAccessIterator &a, const RandomAccessIterator &b)
bool operator()(DistanceCallback &callback, const RandomAccessIterator &item, const RandomAccessIterator &a, const RandomAccessIterator &b)
DistanceComparator(const DistanceCallback &c, const RandomAccessIterator &i)
DistanceCallback callback
RandomAccessIterator begin
ScalarType uniform_random()
void search(Node *node, const RandomAccessIterator &target, int k, std::priority_queue< HeapItem > &heap)
Node * buildFromPoints(int lower, int upper)
bool operator()(const RandomAccessIterator &a, const RandomAccessIterator &b)
const RandomAccessIterator item
std::vector< RandomAccessIterator > items
std::vector< IndexType > search(const RandomAccessIterator &target, int k)
VantagePointTree(RandomAccessIterator b, RandomAccessIterator e, DistanceCallback c)