diff --git a/include/SPTK/math/statistics_accumulation.h b/include/SPTK/math/statistics_accumulation.h index de6dfad..d89a219 100644 --- a/include/SPTK/math/statistics_accumulation.h +++ b/include/SPTK/math/statistics_accumulation.h @@ -122,7 +122,7 @@ class StatisticsAccumulation { * @return True on success, false on failure. */ bool GetFirst(const StatisticsAccumulation::Buffer& buffer, - std::vector* first) const; + std::vector* first) const; /** * @param[in] buffer Buffer. @@ -205,6 +205,19 @@ class StatisticsAccumulation { bool Run(const std::vector& data, StatisticsAccumulation::Buffer* buffer) const; + /** + * Merge statistics. + * + * @param[in] num_data Number of data. + * @param[in] first First order statistics. + * @param[in] second Second order statistics. + * @param[in,out] buffer Buffer. + * @return True on success, false on failure. + */ + bool Merge(int num_data, const std::vector& first, + const SymmetricMatrix& second, + StatisticsAccumulation::Buffer* buffer) const; + private: const int num_order_; const int num_statistics_order_; diff --git a/include/SPTK/utils/misc_utils.h b/include/SPTK/utils/misc_utils.h index 79e487d..11a3c54 100644 --- a/include/SPTK/utils/misc_utils.h +++ b/include/SPTK/utils/misc_utils.h @@ -128,6 +128,23 @@ bool ComputeFirstOrderRegressionCoefficients(int n, bool ComputeSecondOrderRegressionCoefficients( int n, std::vector* coefficients); +/** + * Compute lower and upper bounds. + * + * @param[in] confidence_level Confidence level. + * @param[in] num_data Number of data. + * @param[in] mean Mean vector. + * @param[in] variance Variance vector. + * @param[out] lower_bound Lower bound. + * @param[out] upper_bound Upper bound. + * @return True on success, false on failure. + */ +bool ComputeLowerAndUpperBounds(double confidence_level, int num_data, + const std::vector mean, + const std::vector variance, + std::vector* lower_bound, + std::vector* upper_bound); + } // namespace sptk #endif // SPTK_UTILS_MISC_UTILS_H_ diff --git a/src/main/imsvq.cc b/src/main/imsvq.cc index aa423bc..131569a 100644 --- a/src/main/imsvq.cc +++ b/src/main/imsvq.cc @@ -81,7 +81,7 @@ void PrintUsage(std::ostream* stream) { */ int main(int argc, char* argv[]) { int num_order(kDefaultNumOrder); - std::vector codebook_vectors_file; + std::vector codebook_vectors_file; for (;;) { const int option_char(getopt_long(argc, argv, "l:m:s:h", NULL, NULL)); diff --git a/src/main/msvq.cc b/src/main/msvq.cc index eb5b4ee..a7b9251 100644 --- a/src/main/msvq.cc +++ b/src/main/msvq.cc @@ -80,7 +80,7 @@ void PrintUsage(std::ostream* stream) { */ int main(int argc, char* argv[]) { int num_order(kDefaultNumOrder); - std::vector codebook_vectors_file; + std::vector codebook_vectors_file; for (;;) { const int option_char(getopt_long(argc, argv, "l:m:s:h", NULL, NULL)); diff --git a/src/main/vstat.cc b/src/main/vstat.cc index a46a313..b63dd3a 100644 --- a/src/main/vstat.cc +++ b/src/main/vstat.cc @@ -37,7 +37,7 @@ enum OutputFormats { kCorrelation, kPrecision, kMeanAndLowerAndUpperBounds, - kForMerge, + kSufficientStatistics, kNumOutputFormats }; @@ -68,7 +68,8 @@ void PrintUsage(std::ostream* stream) { *stream << " 4 (correlation)" << std::endl; *stream << " 5 (precision)" << std::endl; *stream << " 6 (mean and lower/upper bounds)" << std::endl; - *stream << " 7 (statistics for merge)" << std::endl; + *stream << " 7 (sufficient statistics)" << std::endl; + *stream << " -s s : statistics file (string)[" << std::setw(5) << std::right << "N/A" << "]" << std::endl; // NOLINT *stream << " -d : output only diagonal ( bool)[" << std::setw(5) << std::right << sptk::ConvertBooleanToString(kDefaultOutputOnlyDiagonalElementsFlag) << "]" << std::endl; // NOLINT *stream << " elements" << std::endl; *stream << " -e : use a neumerically ( bool)[" << std::setw(5) << std::right << sptk::ConvertBooleanToString(kDefaultNeumericallyStableFlag) << "]" << std::endl; // NOLINT @@ -80,6 +81,7 @@ void PrintUsage(std::ostream* stream) { *stream << " statistics (double)" << std::endl; *stream << " notice:" << std::endl; *stream << " -d is valid only if o = 0 or o = 2" << std::endl; + *stream << " -s can be specified multiple times" << std::endl; *stream << std::endl; *stream << " SPTK: version " << sptk::kVersion << std::endl; *stream << std::endl; @@ -160,37 +162,27 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, } if (kMeanAndLowerAndUpperBounds == output_format) { - int num_vector; - if (!accumulation.GetNumData(buffer, &num_vector)) { - return false; - } - - const int degrees_of_freedom(num_vector - 1); - if (0 == degrees_of_freedom) { - return false; - } - double t; - if (!sptk::ComputePercentagePointOfTDistribution( - 0.5 * (1.0 - confidence_level / 100.0), degrees_of_freedom, &t)) { + int num_data; + if (!accumulation.GetNumData(buffer, &num_data)) { return false; } std::vector mean(vector_length); - std::vector variance(vector_length); if (!accumulation.GetMean(buffer, &mean)) { return false; } + std::vector variance(vector_length); if (!accumulation.GetDiagonalCovariance(buffer, &variance)) { return false; } - const double inverse_degrees_of_freedom(1.0 / degrees_of_freedom); std::vector lower_bound(vector_length); std::vector upper_bound(vector_length); - for (int i(0); i < vector_length; ++i) { - const double error(std::sqrt(variance[i] * inverse_degrees_of_freedom)); - lower_bound[i] = mean[i] - t * error; - upper_bound[i] = mean[i] + t * error; + if (!sptk::ComputeLowerAndUpperBounds(confidence_level, num_data, mean, + variance, &lower_bound, + &upper_bound)) { + return false; } + if (!sptk::WriteStream(0, vector_length, lower_bound, &std::cout, NULL)) { return false; } @@ -199,12 +191,12 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, } } - if (kForMerge == output_format) { - int zero; - if (!accumulation.GetNumData(buffer, &zero)) { + if (kSufficientStatistics == output_format) { + int num_data; + if (!accumulation.GetNumData(buffer, &num_data)) { return false; } - if (!sptk::WriteStream(static_cast(zero), &std::cout)) { + if (!sptk::WriteStream(static_cast(num_data), &std::cout)) { return false; } @@ -250,7 +242,9 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, * \arg @c 4 correlation * \arg @c 5 precision * \arg @c 6 mean and lower/upper bounds - * \arg @c 7 stats for merge + * \arg @c 7 sufficient statistics + * - @b -s @e str + * - statistics file * - @b -d * - output only diagonal elements * - @b -e @@ -356,6 +350,14 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation, * # 2, 7 * @endcode * + * @code{.sh} + * cat data1.d data2.d | vstat -o 7 > data12.stat + * cat data3.d data4.d | vstat -o 7 > data34.stat + * echo | vstat -s data12.stat -s data34.stat -o 1 > data.mean + * # equivalent to the following line + * cat data?.d | vstat -o 1 > data.mean + * @endcode + * * @param[in] argc Number of arguments. * @param[in] argv Argument vector. * @return 0 on success, 1 on failure. @@ -365,11 +367,13 @@ int main(int argc, char* argv[]) { int output_interval(kMagicNumberForEndOfFile); double confidence_level(kDefaultConfidenceLevel); OutputFormats output_format(kDefaultOutputFormat); + std::vector statistics_file; bool outputs_only_diagonal_elements(kDefaultOutputOnlyDiagonalElementsFlag); bool neumerically_stable(kDefaultNeumericallyStableFlag); for (;;) { - const int option_char(getopt_long(argc, argv, "l:m:t:c:o:deh", NULL, NULL)); + const int option_char( + getopt_long(argc, argv, "l:m:t:c:o:s:deh", NULL, NULL)); if (-1 == option_char) break; switch (option_char) { @@ -433,6 +437,10 @@ int main(int argc, char* argv[]) { output_format = static_cast(tmp); break; } + case 's': { + statistics_file.push_back(optarg); + break; + } case 'd': { outputs_only_diagonal_elements = true; break; @@ -452,6 +460,13 @@ int main(int argc, char* argv[]) { } } + if (kMagicNumberForEndOfFile != output_interval && !statistics_file.empty()) { + std::ostringstream error_message; + error_message << "Cannot specify -t option and -s option at the same time"; + sptk::PrintErrorMessage("vstat", error_message); + return 1; + } + const int num_input_files(argc - optind); if (1 < num_input_files) { std::ostringstream error_message; @@ -501,6 +516,51 @@ int main(int argc, char* argv[]) { return 1; } + for (const char* file : statistics_file) { + std::ifstream ifs; + ifs.open(file, std::ios::in | std::ios::binary); + if (ifs.fail()) { + std::ostringstream error_message; + error_message << "Cannot open file " << file; + sptk::PrintErrorMessage("vstat", error_message); + return 1; + } + std::istream& input_stream(ifs); + + double num_data; + if (!sptk::ReadStream(&num_data, &input_stream)) { + std::ostringstream error_message; + error_message << "Failed to read statistics (zeroth order)"; + sptk::PrintErrorMessage("vstat", error_message); + return 1; + } + + std::vector first(vector_length); + if (!sptk::ReadStream(false, 0, 0, vector_length, &first, &input_stream, + NULL)) { + std::ostringstream error_message; + error_message << "Failed to read statistics (first order)"; + sptk::PrintErrorMessage("vstat", error_message); + return 1; + } + + sptk::SymmetricMatrix second(vector_length); + if (!sptk::ReadStream(&second, &input_stream)) { + std::ostringstream error_message; + error_message << "Failed to read statistics (second order)"; + sptk::PrintErrorMessage("vstat", error_message); + return 1; + } + + if (!accumulation.Merge(static_cast(num_data), first, second, + &buffer)) { + std::ostringstream error_message; + error_message << "Failed to merge statistics"; + sptk::PrintErrorMessage("vstat_merge", error_message); + return 1; + } + } + std::vector data(vector_length); for (int vector_index(1); sptk::ReadStream(false, 0, 0, vector_length, &data, &input_stream, NULL); @@ -525,15 +585,15 @@ int main(int argc, char* argv[]) { } } - int num_actual_vector; - if (!accumulation.GetNumData(buffer, &num_actual_vector)) { + int num_data; + if (!accumulation.GetNumData(buffer, &num_data)) { std::ostringstream error_message; error_message << "Failed to accumulate statistics"; sptk::PrintErrorMessage("vstat", error_message); return 1; } - if (kMagicNumberForEndOfFile == output_interval && 0 < num_actual_vector) { + if (kMagicNumberForEndOfFile == output_interval && 0 < num_data) { if (!OutputStatistics(accumulation, buffer, vector_length, output_format, confidence_level, outputs_only_diagonal_elements)) { std::ostringstream error_message; diff --git a/src/math/statistics_accumulation.cc b/src/math/statistics_accumulation.cc index a68d39b..09bbf3b 100644 --- a/src/math/statistics_accumulation.cc +++ b/src/math/statistics_accumulation.cc @@ -329,4 +329,73 @@ bool StatisticsAccumulation::Run(const std::vector& data, return true; } +bool StatisticsAccumulation::Merge( + int num_data, const std::vector& first, + const SymmetricMatrix& second, + StatisticsAccumulation::Buffer* buffer) const { + if (!is_valid_ || NULL == buffer) { + return false; + } + + if (0 == buffer->zeroth_order_statistics_) { + buffer->zeroth_order_statistics_ = num_data; + buffer->first_order_statistics_ = first; + buffer->second_order_statistics_ = second; + return true; + } + + if (num_data <= 0 || first.size() != buffer->first_order_statistics_.size() || + second.GetNumDimension() != + buffer->second_order_statistics_.GetNumDimension()) { + return false; + } + + // Save the current statistics. + const int m(buffer->zeroth_order_statistics_); + std::vector prev_first_order_statistics( + buffer->first_order_statistics_); + + const int n(num_data); + const int mn(m + n); + buffer->zeroth_order_statistics_ = mn; + + if (1 <= num_statistics_order_) { + if (numerically_stable_) { + const double a(static_cast(n) / mn); + const double b(static_cast(m) / mn); + std::transform(first.begin(), first.end(), + buffer->first_order_statistics_.begin(), + buffer->first_order_statistics_.begin(), + [a, b](double x, double y) { return a * x + b * y; }); + } else { + std::transform( + first.begin(), first.end(), buffer->first_order_statistics_.begin(), + buffer->first_order_statistics_.begin(), std::plus()); + } + } + + if (2 <= num_statistics_order_) { + if (numerically_stable_) { + const double* mu1(&(prev_first_order_statistics[0])); + const double* mu2(&(first[0])); + const double c(static_cast(m * n) / mn); + for (int i(0); i <= num_order_; ++i) { + for (int j(diagonal_ ? i : 0); j <= i; ++j) { + buffer->second_order_statistics_[i][j] += + second[i][j] + c * ((mu1[i] * mu1[j] + mu2[i] * mu2[j]) - + (mu1[i] * mu2[j] + mu2[i] * mu1[j])); + } + } + } else { + for (int i(0); i <= num_order_; ++i) { + for (int j(diagonal_ ? i : 0); j <= i; ++j) { + buffer->second_order_statistics_[i][j] += second[i][j]; + } + } + } + } + + return true; +} + } // namespace sptk diff --git a/src/utils/misc_utils.cc b/src/utils/misc_utils.cc index c00e74b..9a7626c 100644 --- a/src/utils/misc_utils.cc +++ b/src/utils/misc_utils.cc @@ -310,4 +310,47 @@ bool ComputeSecondOrderRegressionCoefficients( return true; } +bool ComputeLowerAndUpperBounds(double confidence_level, int num_data, + const std::vector mean, + const std::vector variance, + std::vector* lower_bound, + std::vector* upper_bound) { + if (confidence_level <= 0.0 || + 100.0 <= confidence_level || + num_data <= 0 || mean.size() != variance.size() || + NULL == lower_bound || NULL == upper_bound) { + return false; + } + + if (lower_bound->size() != mean.size()) { + lower_bound->resize(mean.size()); + } + if (upper_bound->size() != mean.size()) { + upper_bound->resize(mean.size()); + } + + const int degrees_of_freedom(num_data - 1); + if (0 == degrees_of_freedom) { + return false; + } + + double t; + if (!sptk::ComputePercentagePointOfTDistribution( + 0.5 * (1.0 - confidence_level / 100.0), degrees_of_freedom, &t)) { + return false; + } + + const double inverse_degrees_of_freedom(1.0 / degrees_of_freedom); + const int vector_length(static_cast(mean.size())); + double* l(&((*lower_bound)[0])); + double* u(&((*upper_bound)[0])); + for (int i(0); i < vector_length; ++i) { + const double error(std::sqrt(variance[i] * inverse_degrees_of_freedom)); + l[i] = mean[i] - t * error; + u[i] = mean[i] + t * error; + } + + return true; +} + } // namespace sptk diff --git a/test/test_vstat.bats b/test/test_vstat.bats index c261232..a48df1c 100755 --- a/test/test_vstat.bats +++ b/test/test_vstat.bats @@ -63,11 +63,15 @@ teardown() { run $sptk4/aeq $tmp/1 $tmp/2 [ "$status" -eq 0 ] - # Neumerically stable algorithm: - $sptk3/vstat -l 2 $tmp/0 -o 0 > $tmp/1 - $sptk4/vstat -l 2 $tmp/0 -o 0 -e > $tmp/2 - run $sptk4/aeq $tmp/1 $tmp/2 - [ "$status" -eq 0 ] + # Merge: + for opt in "" "-e"; do + $sptk3/bcut +d -l 2 -e 29 $tmp/0 | $sptk4/vstat -l 2 -o 7 $opt > $tmp/1 + $sptk3/bcut +d -l 2 -s 30 $tmp/0 | $sptk4/vstat -l 2 -o 7 $opt > $tmp/2 + echo | $sptk4/vstat -l 2 -s $tmp/1 -s $tmp/2 -o 0 $opt > $tmp/3 + $sptk4/vstat -l 2 $tmp/0 -o 0 > $tmp/4 + run $sptk4/aeq $tmp/3 $tmp/4 + [ "$status" -eq 0 ] + done } @test "vstat: valgrind" {