kmeans: allow parameterised distance metrics
This commit is contained in:
parent
44d5b655ab
commit
337e53b725
20
kmeans.hpp
20
kmeans.hpp
@ -26,9 +26,9 @@ namespace util {
|
||||
// a simplistic implementation of Lloyd's algorithm
|
||||
//
|
||||
// returns index of the closest output for each input
|
||||
template <typename OutputT, typename InputT>
|
||||
template <typename OutputT, typename InputT, typename FunctionT>
|
||||
std::vector<size_t>
|
||||
kmeans (util::view<InputT> src, util::view<OutputT> dst)
|
||||
kmeans (util::view<InputT> src, util::view<OutputT> dst, FunctionT const &&metric)
|
||||
{
|
||||
CHECK_GE (src.size (), dst.size ());
|
||||
|
||||
@ -48,7 +48,7 @@ namespace util {
|
||||
size_t bucket = 0;
|
||||
|
||||
for (size_t k = 1; k < dst.size (); ++k) {
|
||||
if (norm2 (p - means[k]) < norm2 (p - means[bucket]))
|
||||
if (metric (p, means[k]) < metric (p, means[bucket]))
|
||||
bucket = k;
|
||||
}
|
||||
|
||||
@ -65,4 +65,18 @@ namespace util {
|
||||
|
||||
return closest;
|
||||
}
|
||||
|
||||
|
||||
template <typename OutputT, typename InputT>
|
||||
auto
|
||||
kmeans (InputT &&src, OutputT &&dst)
|
||||
{
|
||||
return kmeans (
|
||||
std::forward<InputT> (src),
|
||||
std::forward<OutputT> (dst),
|
||||
[] (auto a, auto b) {
|
||||
return distance (a, b);
|
||||
});
|
||||
}
|
||||
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user