kmeans: allow parameterised distance metrics
This commit is contained in:
parent
44d5b655ab
commit
337e53b725
20
kmeans.hpp
20
kmeans.hpp
@ -26,9 +26,9 @@ namespace util {
|
|||||||
// a simplistic implementation of Lloyd's algorithm
|
// a simplistic implementation of Lloyd's algorithm
|
||||||
//
|
//
|
||||||
// returns index of the closest output for each input
|
// returns index of the closest output for each input
|
||||||
template <typename OutputT, typename InputT>
|
template <typename OutputT, typename InputT, typename FunctionT>
|
||||||
std::vector<size_t>
|
std::vector<size_t>
|
||||||
kmeans (util::view<InputT> src, util::view<OutputT> dst)
|
kmeans (util::view<InputT> src, util::view<OutputT> dst, FunctionT const &&metric)
|
||||||
{
|
{
|
||||||
CHECK_GE (src.size (), dst.size ());
|
CHECK_GE (src.size (), dst.size ());
|
||||||
|
|
||||||
@ -48,7 +48,7 @@ namespace util {
|
|||||||
size_t bucket = 0;
|
size_t bucket = 0;
|
||||||
|
|
||||||
for (size_t k = 1; k < dst.size (); ++k) {
|
for (size_t k = 1; k < dst.size (); ++k) {
|
||||||
if (norm2 (p - means[k]) < norm2 (p - means[bucket]))
|
if (metric (p, means[k]) < metric (p, means[bucket]))
|
||||||
bucket = k;
|
bucket = k;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -65,4 +65,18 @@ namespace util {
|
|||||||
|
|
||||||
return closest;
|
return closest;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Convenience overload of kmeans for callers that do not supply a metric.
//
// Perfect-forwards `src` and `dst` to the three-argument kmeans and passes
// a generic lambda that evaluates `distance (a, b)` as the metric. The
// return type is deduced from that forwarded call.
//
// NOTE(review): the lambda receives both operands by value (`auto a,
// auto b`), so each metric evaluation copies its arguments -- confirm this
// is acceptable for the point types in use.
// NOTE(review): `distance` is resolved by unqualified lookup from inside
// the lambda; presumably a util/ADL overload for the point type -- confirm.
template <typename OutputT, typename InputT>
|
||||||
|
auto
|
||||||
|
kmeans (InputT &&src, OutputT &&dst)
|
||||||
|
{
|
||||||
|
return kmeans (
|
||||||
|
std::forward<InputT> (src),
|
||||||
|
std::forward<OutputT> (dst),
|
||||||
|
[] (auto a, auto b) {
|
||||||
|
return distance (a, b);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user