Skip to content

Commit

Permalink
Fixed hierarchical clustering with Ward linkage (#998)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniyal Aliev <daniial.aliev@abbyy.com>
Co-authored-by: Valeriy Fedyunin <valery.fedyunin@abbyy.com>
  • Loading branch information
daniyalaliev and Valeriy Fedyunin authored Nov 29, 2023
1 parent 73c9fb1 commit edc6049
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 11 deletions.
11 changes: 6 additions & 5 deletions NeoML/src/TraditionalML/NaiveHierarchicalClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,8 @@ void CNaiveHierarchicalClustering::mergeClusters( int first, int newClusterIndex
if( i == first || clusters[i] == nullptr ) {
continue;
}
const float distance = recalcDistance( *clusters[i], *clusters[first], firstSize, secondSize,
const float distance = recalcDistance( *clusters[i], *clusters[first], firstSize,
secondSize, clusters[i]->GetElementsCount(),
i < first ? distances[i][first] : prevDistances[i],
i < second ? distances[i][second] : distances[second][i],
mergeDistance );
Expand All @@ -275,7 +276,7 @@ void CNaiveHierarchicalClustering::mergeClusters( int first, int newClusterIndex

// Calculates distance between current cluster and merged cluster based on
float CNaiveHierarchicalClustering::recalcDistance( const CCommonCluster& currCluster, const CCommonCluster& mergedCluster,
int firstSize, int secondSize, float currToFirst, float currToSecond, float firstToSecond ) const
int firstSize, int secondSize, int currSize, float currToFirst, float currToSecond, float firstToSecond ) const
{
static_assert( CHierarchicalClustering::L_Count == 5, "L_Count != 5" );

Expand All @@ -298,9 +299,9 @@ float CNaiveHierarchicalClustering::recalcDistance( const CCommonCluster& currCl
return ::fmaxf( currToFirst, currToSecond );
case CHierarchicalClustering::L_Ward:
{
const int mergedSize = firstSize + secondSize;
return ( firstSize * currToFirst + secondSize * currToSecond
- ( firstSize * secondSize * firstToSecond ) / mergedSize ) / mergedSize;
const int sumSize = firstSize + secondSize + currSize;
return ( firstSize + currSize ) * currToFirst / sumSize + ( secondSize + currSize ) * currToSecond / sumSize -
currSize * firstToSecond / sumSize;
}
default:
NeoAssert( false );
Expand Down
2 changes: 1 addition & 1 deletion NeoML/src/TraditionalML/NaiveHierarchicalClustering.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class CNaiveHierarchicalClustering {
void findNearestClusters( int& first ) const;
void mergeClusters( int first, int newClusterIndex, CArray<CMergeInfo>* dendrogram );
float recalcDistance( const CCommonCluster& currCluster, const CCommonCluster& mergedCluster,
int firstSize, int secondSize, float currToFirst, float currToSecond, float firstToSecond ) const;
int firstSize, int secondSize, int currSize, float currToFirst, float currToSecond, float firstToSecond ) const;
void fillResult( const CFloatMatrixDesc& matrix, CClusteringResult& result, CArray<int>* dendrogramIndices ) const;
};

Expand Down
11 changes: 6 additions & 5 deletions NeoML/src/TraditionalML/NnChainHierarchicalClustering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ namespace NeoML {
// Recalculates the distance between the current cluster and the cluster obtained by merging the first and second clusters
// Doesn't support centroid linkage!
static float recalcDistance( CHierarchicalClustering::TLinkage linkage, TDistanceFunc distance, int firstSize, int secondSize,
float currToFirst, float currToSecond, float firstToSecond )
int currSize, float currToFirst, float currToSecond, float firstToSecond )
{
switch( linkage ) {
case CHierarchicalClustering::L_Single:
Expand All @@ -44,9 +44,9 @@ static float recalcDistance( CHierarchicalClustering::TLinkage linkage, TDistanc
return ::fmaxf( currToFirst, currToSecond );
case CHierarchicalClustering::L_Ward:
{
const int mergedSize = firstSize + secondSize;
return ( firstSize * currToFirst + secondSize * currToSecond
- ( firstSize * secondSize * firstToSecond ) / mergedSize ) / mergedSize;
const int sumSize = firstSize + secondSize + currSize;
return ( firstSize + currSize ) * currToFirst / sumSize + ( secondSize + currSize ) * currToSecond / sumSize -
currSize * firstToSecond / sumSize;
}
case CHierarchicalClustering::L_Centroid:
default:
Expand Down Expand Up @@ -211,7 +211,8 @@ void CNnChainHierarchicalClustering::mergeClusters( int first, int second )
continue;
}
// We can pass ref to any cluster here because linkage isn't centroid
const float distance = recalcDistance( params.Linkage, params.DistanceType, firstSize, secondSize,
const float distance = recalcDistance( params.Linkage, params.DistanceType, firstSize,
secondSize, clusterSizes[i],
i < first ? distances[i][first] : distances[first][i],
i < second ? distances[i][second] : distances[second][i],
mergeDistance );
Expand Down

0 comments on commit edc6049

Please sign in to comment.