kmeans: allow parameterised distance metrics

This commit is contained in:
Danny Robson 2018-04-23 23:19:14 +10:00
parent 44d5b655ab
commit 337e53b725

View File

@ -26,9 +26,9 @@ namespace util {
// a simplistic implementation of Lloyd's algorithm // a simplistic implementation of Lloyd's algorithm
// //
// returns index of the closest output for each input // returns index of the closest output for each input
template <typename OutputT, typename InputT> template <typename OutputT, typename InputT, typename FunctionT>
std::vector<size_t> std::vector<size_t>
kmeans (util::view<InputT> src, util::view<OutputT> dst) kmeans (util::view<InputT> src, util::view<OutputT> dst, FunctionT const &&metric)
{ {
CHECK_GE (src.size (), dst.size ()); CHECK_GE (src.size (), dst.size ());
@ -48,7 +48,7 @@ namespace util {
size_t bucket = 0; size_t bucket = 0;
for (size_t k = 1; k < dst.size (); ++k) { for (size_t k = 1; k < dst.size (); ++k) {
if (norm2 (p - means[k]) < norm2 (p - means[bucket])) if (metric (p, means[k]) < metric (p, means[bucket]))
bucket = k; bucket = k;
} }
@ -65,4 +65,18 @@ namespace util {
return closest; return closest;
} }
// convenience overload of kmeans that supplies a default metric
//
// `src` and `dst` are perfectly forwarded to the three-argument overload;
// pairs of values are compared using an unqualified call to `distance`.
//
// returns whatever the three-argument overload returns for these arguments.
template <typename OutputT, typename InputT>
auto
kmeans (InputT &&src, OutputT &&dst)
{
    return kmeans (
        std::forward<InputT> (src),
        std::forward<OutputT> (dst),
        // the metric parameter of the three-argument overload is an rvalue
        // reference, so the lambda must be passed as a temporary here.
        //
        // operands are taken by const reference so that the two values
        // aren't copied every time the metric is evaluated.
        [] (auto const &a, auto const &b) {
            return distance (a, b);
        }
    );
}
} }