From 2dbf22911467c7e50afbab681b83d5a16c093e85 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 19 Oct 2023 16:05:17 +0900 Subject: [PATCH 1/4] reduce computational cost --- include/SPTK/math/statistics_accumulation.h | 5 ++++- src/main/vstat.cc | 6 ++++-- src/math/statistics_accumulation.cc | 6 ++++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/include/SPTK/math/statistics_accumulation.h b/include/SPTK/math/statistics_accumulation.h index df06c47e..8cc34caa 100644 --- a/include/SPTK/math/statistics_accumulation.h +++ b/include/SPTK/math/statistics_accumulation.h @@ -75,8 +75,10 @@ class StatisticsAccumulation { /** * @param[in] num_order Order of vector, @f$M@f$. * @param[in] num_statistics_order Order of statistics, @f$K@f$. + * @param[in] diagonal If true, only diagonal part is accumulated. */ - StatisticsAccumulation(int num_order, int num_statistics_order); + StatisticsAccumulation(int num_order, int num_statistics_order, + bool diagonal = false); virtual ~StatisticsAccumulation() { } @@ -186,6 +188,7 @@ class StatisticsAccumulation { private: const int num_order_; const int num_statistics_order_; + const bool diagonal_; bool is_valid_; diff --git a/src/main/vstat.cc b/src/main/vstat.cc index c51ae539..b0d41d1b 100644 --- a/src/main/vstat.cc +++ b/src/main/vstat.cc @@ -459,8 +459,10 @@ int main(int argc, char* argv[]) { } std::istream& input_stream(ifs.is_open() ? ifs : std::cin); - sptk::StatisticsAccumulation accumulation(vector_length - 1, - kMean == output_format ? 1 : 2); + sptk::StatisticsAccumulation accumulation( + vector_length - 1, kMean == output_format ? 1 : 2, + outputs_only_diagonal_elements || kStandardDeviation == output_format || + kMeanAndLowerAndUpperBounds == output_format); sptk::StatisticsAccumulation::Buffer buffer; if (!accumulation.IsValid()) { std::ostringstream error_message; diff --git a/src/math/statistics_accumulation.cc b/src/math/statistics_accumulation.cc index eb6e3db9..ef3a9099 100644 --- a/src/math/statistics_accumulation.cc +++ b/src/math/statistics_accumulation.cc @@ -24,9 +24,11 @@ namespace sptk { StatisticsAccumulation::StatisticsAccumulation(int num_order, - int num_statistics_order) + int num_statistics_order, + bool diagonal) : num_order_(num_order), num_statistics_order_(num_statistics_order), + diagonal_(diagonal), is_valid_(true) { if (num_order_ < 0 || num_statistics_order_ < 0 || 2 < num_statistics_order_) { @@ -242,7 +244,7 @@ bool StatisticsAccumulation::Run(const std::vector& data, // Accumulate 2nd order statistics. if (2 <= num_statistics_order_) { for (int i(0); i < length; ++i) { - for (int j(0); j <= i; ++j) { + for (int j(diagonal_ ? i : 0); j <= i; ++j) { buffer->second_order_statistics_[i][j] += data[i] * data[j]; } } From 51b03b0b57dc29c692d1b0f74f0b8b172ae22ddb Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 19 Oct 2023 16:58:13 +0900 Subject: [PATCH 2/4] strip -o 5 -d --- src/main/vstat.cc | 30 ++++++++++++++++-------------- test/test_vstat.bats | 4 ++-- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/main/vstat.cc b/src/main/vstat.cc index b0d41d1b..0cfaa4d0 100644 --- a/src/main/vstat.cc +++ b/src/main/vstat.cc @@ -73,6 +73,8 @@ void PrintUsage(std::ostream* stream) { *stream << " vectors (double)[stdin]" << std::endl; *stream << " stdout:" << std::endl; *stream << " statistics (double)" << std::endl; + *stream << " notice:" << std::endl; + *stream << " -d is valid only if o = 0 or o = 2" << std::endl; *stream << std::endl; *stream << " SPTK: version " << sptk::kVersion << std::endl; *stream << std::endl; @@ -155,20 +157,12 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, return false; } - if (outputs_only_diagonal_elements) { - for (int i(0); i < vector_length; ++i) { - if (!sptk::WriteStream(precision_matrix[i][i], &std::cout)) { + for (int i(0); i < vector_length; ++i) { + for (int j(0); j < vector_length; ++j) { + if (!sptk::WriteStream(precision_matrix[i][j], &std::cout)) { return false; } } - } else { - for (int i(0); i < vector_length; ++i) { - for (int j(0); j < vector_length; ++j) { - if (!sptk::WriteStream(precision_matrix[i][j], &std::cout)) { - return false; - } - } - } } } @@ -459,10 +453,18 @@ int main(int argc, char* argv[]) { } std::istream& input_stream(ifs.is_open() ? ifs : std::cin); + bool diagonal(false); + if (kMeanAndCovariance == output_format || kCovariance == output_format) { + if (outputs_only_diagonal_elements) { + diagonal = true; + } + } else if (kStandardDeviation == output_format || + kMeanAndLowerAndUpperBounds == output_format) { + diagonal = true; + } + sptk::StatisticsAccumulation accumulation( - vector_length - 1, kMean == output_format ? 1 : 2, - outputs_only_diagonal_elements || kStandardDeviation == output_format || - kMeanAndLowerAndUpperBounds == output_format); + vector_length - 1, kMean == output_format ? 1 : 2, diagonal); sptk::StatisticsAccumulation::Buffer buffer; if (!accumulation.IsValid()) { std::ostringstream error_message; diff --git a/test/test_vstat.bats b/test/test_vstat.bats index 9f0e3630..ebf316a2 100755 --- a/test/test_vstat.bats +++ b/test/test_vstat.bats @@ -30,8 +30,8 @@ teardown() { @test "vstat: compatibility" { $sptk3/nrand -l 100 > $tmp/0 - ary1=("-o 0" "-o 1" "-o 2" "-o 2 -d" "-o 2 -r" "-o 2 -i") - ary2=("-o 0" "-o 1" "-o 2" "-o 2 -d" "-o 4" "-o 5") + ary1=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 2 -r" "-o 2 -i") + ary2=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 4" "-o 5 -d") for i in $(seq 0 $((${#ary1[@]} - 1))); do # shellcheck disable=SC2086 $sptk3/vstat -l 2 $tmp/0 ${ary1[$i]} > $tmp/1 From 37beca3e0b12cf8d208b4a3f4373b5b6241dafc3 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 19 Oct 2023 17:06:27 +0900 Subject: [PATCH 3/4] minor fix --- src/main/vstat.cc | 24 ++++++------------------ test/test_vstat.bats | 2 +- 2 files changed, 7 insertions(+), 19 deletions(-) diff --git a/src/main/vstat.cc b/src/main/vstat.cc index 0cfaa4d0..1e996a0e 100644 --- a/src/main/vstat.cc +++ b/src/main/vstat.cc @@ -111,12 +111,8 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, if (!accumulation.GetFullCovariance(buffer, &variance)) { return false; } - for (int i(0); i < vector_length; ++i) { - for (int j(0); j < vector_length; ++j) { - if (!sptk::WriteStream(variance[i][j], &std::cout)) { - return false; - } - } + if (!sptk::WriteStream(variance, &std::cout)) { + return false; } } } @@ -137,12 +133,8 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, if (!accumulation.GetCorrelation(buffer, &correlation)) { return false; } - for (int i(0); i < vector_length; ++i) { - for (int j(0); j < vector_length; ++j) { - if (!sptk::WriteStream(correlation[i][j], &std::cout)) { - return false; - } - } + if (!sptk::WriteStream(correlation, &std::cout)) { + return false; } } @@ -157,12 +149,8 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, return false; } - for (int i(0); i < vector_length; ++i) { - for (int j(0); j < vector_length; ++j) { - if (!sptk::WriteStream(precision_matrix[i][j], &std::cout)) { - return false; - } - } + if (!sptk::WriteStream(precision_matrix, &std::cout)) { + return false; } } diff --git a/test/test_vstat.bats b/test/test_vstat.bats index ebf316a2..6d311ca6 100755 --- a/test/test_vstat.bats +++ b/test/test_vstat.bats @@ -31,7 +31,7 @@ teardown() { $sptk3/nrand -l 100 > $tmp/0 ary1=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 2 -r" "-o 2 -i") - ary2=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 4" "-o 5 -d") + ary2=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 4" "-o 5") for i in $(seq 0 $((${#ary1[@]} - 1))); do # shellcheck disable=SC2086 $sptk3/vstat -l 2 $tmp/0 ${ary1[$i]} > $tmp/1 From 195b986e39b0fc981a1fce59b0797848383fce29 Mon Sep 17 00:00:00 2001 From: takenori-y Date: Thu, 19 Oct 2023 17:14:48 +0900 Subject: [PATCH 4/4] add safeguard --- src/math/statistics_accumulation.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/math/statistics_accumulation.cc b/src/math/statistics_accumulation.cc index ef3a9099..a2c4750a 100644 --- a/src/math/statistics_accumulation.cc +++ b/src/math/statistics_accumulation.cc @@ -133,7 +133,7 @@ bool StatisticsAccumulation::GetStandardDeviation( bool StatisticsAccumulation::GetFullCovariance( const StatisticsAccumulation::Buffer& buffer, SymmetricMatrix* full_covariance) const { - if (!is_valid_ || num_statistics_order_ < 2 || + if (!is_valid_ || num_statistics_order_ < 2 || diagonal_ || buffer.zeroth_order_statistics_ <= 0 || NULL == full_covariance) { return false; } @@ -162,7 +162,7 @@ bool StatisticsAccumulation::GetFullCovariance( bool StatisticsAccumulation::GetUnbiasedCovariance( const StatisticsAccumulation::Buffer& buffer, SymmetricMatrix* unbiased_covariance) const { - if (!is_valid_ || num_statistics_order_ < 2 || + if (!is_valid_ || num_statistics_order_ < 2 || diagonal_ || buffer.zeroth_order_statistics_ <= 1 || NULL == unbiased_covariance) { return false; } @@ -185,7 +185,8 @@ bool StatisticsAccumulation::GetUnbiasedCovariance( bool StatisticsAccumulation::GetCorrelation( const StatisticsAccumulation::Buffer& buffer, SymmetricMatrix* correlation) const { - if (!is_valid_ || num_statistics_order_ < 2 || NULL == correlation) { + if (!is_valid_ || num_statistics_order_ < 2 || diagonal_ || + NULL == correlation) { return false; }