Skip to content

Commit

Permalink
Merge pull request #50 from sp-nitech/update_vstat
Browse files Browse the repository at this point in the history
Reduce computational cost of vstat
  • Loading branch information
takenori-y committed Oct 19, 2023
2 parents 7f97c2d + 195b986 commit 9176709
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 36 deletions.
5 changes: 4 additions & 1 deletion include/SPTK/math/statistics_accumulation.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,10 @@ class StatisticsAccumulation {
/**
* @param[in] num_order Order of vector, @f$M@f$.
* @param[in] num_statistics_order Order of statistics, @f$K@f$.
* @param[in] diagonal If true, only diagonal part is accumulated.
*/
StatisticsAccumulation(int num_order, int num_statistics_order);
StatisticsAccumulation(int num_order, int num_statistics_order,
bool diagonal = false);

virtual ~StatisticsAccumulation() {
}
Expand Down Expand Up @@ -186,6 +188,7 @@ class StatisticsAccumulation {
private:
const int num_order_;
const int num_statistics_order_;
const bool diagonal_;

bool is_valid_;

Expand Down
48 changes: 20 additions & 28 deletions src/main/vstat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ void PrintUsage(std::ostream* stream) {
*stream << " vectors (double)[stdin]" << std::endl;
*stream << " stdout:" << std::endl;
*stream << " statistics (double)" << std::endl;
*stream << " notice:" << std::endl;
*stream << " -d is valid only if o = 0 or o = 2" << std::endl;
*stream << std::endl;
*stream << " SPTK: version " << sptk::kVersion << std::endl;
*stream << std::endl;
Expand Down Expand Up @@ -109,12 +111,8 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation,
if (!accumulation.GetFullCovariance(buffer, &variance)) {
return false;
}
for (int i(0); i < vector_length; ++i) {
for (int j(0); j < vector_length; ++j) {
if (!sptk::WriteStream(variance[i][j], &std::cout)) {
return false;
}
}
if (!sptk::WriteStream(variance, &std::cout)) {
return false;
}
}
}
Expand All @@ -135,12 +133,8 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation,
if (!accumulation.GetCorrelation(buffer, &correlation)) {
return false;
}
for (int i(0); i < vector_length; ++i) {
for (int j(0); j < vector_length; ++j) {
if (!sptk::WriteStream(correlation[i][j], &std::cout)) {
return false;
}
}
if (!sptk::WriteStream(correlation, &std::cout)) {
return false;
}
}

Expand All @@ -155,20 +149,8 @@ bool OutputStatistics(const sptk::StatisticsAccumulation& accumulation,
return false;
}

if (outputs_only_diagonal_elements) {
for (int i(0); i < vector_length; ++i) {
if (!sptk::WriteStream(precision_matrix[i][i], &std::cout)) {
return false;
}
}
} else {
for (int i(0); i < vector_length; ++i) {
for (int j(0); j < vector_length; ++j) {
if (!sptk::WriteStream(precision_matrix[i][j], &std::cout)) {
return false;
}
}
}
if (!sptk::WriteStream(precision_matrix, &std::cout)) {
return false;
}
}

Expand Down Expand Up @@ -459,8 +441,18 @@ int main(int argc, char* argv[]) {
}
std::istream& input_stream(ifs.is_open() ? ifs : std::cin);

sptk::StatisticsAccumulation accumulation(vector_length - 1,
kMean == output_format ? 1 : 2);
bool diagonal(false);
if (kMeanAndCovariance == output_format || kCovariance == output_format) {
if (outputs_only_diagonal_elements) {
diagonal = true;
}
} else if (kStandardDeviation == output_format ||
kMeanAndLowerAndUpperBounds == output_format) {
diagonal = true;
}

sptk::StatisticsAccumulation accumulation(
vector_length - 1, kMean == output_format ? 1 : 2, diagonal);
sptk::StatisticsAccumulation::Buffer buffer;
if (!accumulation.IsValid()) {
std::ostringstream error_message;
Expand Down
13 changes: 8 additions & 5 deletions src/math/statistics_accumulation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,11 @@
namespace sptk {

StatisticsAccumulation::StatisticsAccumulation(int num_order,
int num_statistics_order)
int num_statistics_order,
bool diagonal)
: num_order_(num_order),
num_statistics_order_(num_statistics_order),
diagonal_(diagonal),
is_valid_(true) {
if (num_order_ < 0 || num_statistics_order_ < 0 ||
2 < num_statistics_order_) {
Expand Down Expand Up @@ -131,7 +133,7 @@ bool StatisticsAccumulation::GetStandardDeviation(
bool StatisticsAccumulation::GetFullCovariance(
const StatisticsAccumulation::Buffer& buffer,
SymmetricMatrix* full_covariance) const {
if (!is_valid_ || num_statistics_order_ < 2 ||
if (!is_valid_ || num_statistics_order_ < 2 || diagonal_ ||
buffer.zeroth_order_statistics_ <= 0 || NULL == full_covariance) {
return false;
}
Expand Down Expand Up @@ -160,7 +162,7 @@ bool StatisticsAccumulation::GetFullCovariance(
bool StatisticsAccumulation::GetUnbiasedCovariance(
const StatisticsAccumulation::Buffer& buffer,
SymmetricMatrix* unbiased_covariance) const {
if (!is_valid_ || num_statistics_order_ < 2 ||
if (!is_valid_ || num_statistics_order_ < 2 || diagonal_ ||
buffer.zeroth_order_statistics_ <= 1 || NULL == unbiased_covariance) {
return false;
}
Expand All @@ -183,7 +185,8 @@ bool StatisticsAccumulation::GetUnbiasedCovariance(
bool StatisticsAccumulation::GetCorrelation(
const StatisticsAccumulation::Buffer& buffer,
SymmetricMatrix* correlation) const {
if (!is_valid_ || num_statistics_order_ < 2 || NULL == correlation) {
if (!is_valid_ || num_statistics_order_ < 2 || diagonal_ ||
NULL == correlation) {
return false;
}

Expand Down Expand Up @@ -242,7 +245,7 @@ bool StatisticsAccumulation::Run(const std::vector<double>& data,
// Accumulate 2nd order statistics.
if (2 <= num_statistics_order_) {
for (int i(0); i < length; ++i) {
for (int j(0); j <= i; ++j) {
for (int j(diagonal_ ? i : 0); j <= i; ++j) {
buffer->second_order_statistics_[i][j] += data[i] * data[j];
}
}
Expand Down
4 changes: 2 additions & 2 deletions test/test_vstat.bats
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ teardown() {
@test "vstat: compatibility" {
$sptk3/nrand -l 100 > $tmp/0

ary1=("-o 0" "-o 1" "-o 2" "-o 2 -d" "-o 2 -r" "-o 2 -i")
ary2=("-o 0" "-o 1" "-o 2" "-o 2 -d" "-o 4" "-o 5")
ary1=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 2 -r" "-o 2 -i")
ary2=("-o 0" "-o 0 -d" "-o 1" "-o 2" "-o 2 -d" "-o 4" "-o 5")
for i in $(seq 0 $((${#ary1[@]} - 1))); do
# shellcheck disable=SC2086
$sptk3/vstat -l 2 $tmp/0 ${ary1[$i]} > $tmp/1
Expand Down

0 comments on commit 9176709

Please sign in to comment.