Skip to content

Commit

Permalink
add onehot
Browse files Browse the repository at this point in the history
  • Loading branch information
takenori-y committed May 23, 2024
1 parent 3dc6937 commit 57df254
Show file tree
Hide file tree
Showing 3 changed files with 226 additions and 0 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,7 @@ set(MAIN_SOURCES
${SOURCE_DIR}/main/ndps2c.cc
${SOURCE_DIR}/main/norm0.cc
${SOURCE_DIR}/main/nrand.cc
${SOURCE_DIR}/main/onehot.cc
${SOURCE_DIR}/main/par2lar.cc
${SOURCE_DIR}/main/par2lpc.cc
${SOURCE_DIR}/main/pca.cc
Expand Down
6 changes: 6 additions & 0 deletions doc/main/onehot.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
.. _onehot:

onehot
======

.. doxygenfile:: onehot.cc
219 changes: 219 additions & 0 deletions src/main/onehot.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// ------------------------------------------------------------------------ //
// Copyright 2021 SPTK Working Group //
// //
// Licensed under the Apache License, Version 2.0 (the "License"); //
// you may not use this file except in compliance with the License. //
// You may obtain a copy of the License at //
// //
// http://www.apache.org/licenses/LICENSE-2.0 //
// //
// Unless required by applicable law or agreed to in writing, software //
// distributed under the License is distributed on an "AS IS" BASIS, //
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. //
// See the License for the specific language governing permissions and //
// limitations under the License. //
// ------------------------------------------------------------------------ //

#include <fstream> // std::ifstream
#include <iomanip> // std::setw
#include <iostream> // std::cerr, std::cin, std::cout, std::endl, etc.
#include <sstream> // std::ostringstream
#include <vector> // std::vector

#include "GETOPT/ya_getopt.h"
#include "SPTK/utils/sptk_utils.h"

namespace {

enum WarningType { kIgnore = 0, kWarn, kExit, kNumWarningTypes };

const int kDefaultVectorLength(10);
const WarningType kDefaultWarningType(kIgnore);

void PrintUsage(std::ostream* stream) {
// clang-format off
*stream << std::endl;
*stream << " onehot - generate one-hot vector sequence" << std::endl;
*stream << std::endl;
*stream << " usage:" << std::endl;
*stream << " onehot [ options ] [ infile ] > stdout" << std::endl;
*stream << " options:" << std::endl;
*stream << " -l l : length of vector (double)[" << std::setw(5) << std::right << kDefaultVectorLength << "][ 1 <= l <= ]" << std::endl; // NOLINT
*stream << " -m m : order of vector (double)[" << std::setw(5) << std::right << "l-1" << "][ 0 <= m <= ]" << std::endl; // NOLINT
*stream << " -e e : out-of-range error ( int)[" << std::setw(5) << std::right << kDefaultWarningType << "][ 0 <= e <= 2 ]" << std::endl; // NOLINT
*stream << " 0 (no warning)" << std::endl;
*stream << " 1 (output the index to stderr)" << std::endl;
*stream << " 2 (output the index to stderr and" << std::endl;
*stream << " exit immediately)" << std::endl;
*stream << " -h : print this message" << std::endl;
*stream << " infile:" << std::endl;
*stream << " 0-based index ( int)[stdin]" << std::endl;
*stream << " stdout:" << std::endl;
*stream << " one-hot vector (double)" << std::endl;
*stream << std::endl;
*stream << " SPTK: version " << sptk::kVersion << std::endl;
*stream << std::endl;
// clang-format on
}

} // namespace

/**
* @a onehot [ @e option ] [ @e infile ]
*
* - @b -l @e int
* - length of vector @f$(1 \le L)@f$
* - @b -m @e int
* - order of vector @f$(0 \le L - 1)@f$
* - @b -e @e int
* - warning type
* \arg @c 0 no warning
* \arg @c 1 output index
* \arg @c 2 output index and exit immediately
* - @b infile @e str
* - int-type 0-based index
* - @b stdout
* - double-type one-hot vector
*
* @code{.sh}
* step -l 3 | x2x +di | onehot -l 3 | x2x +da
* # 1, 0, 0, 0, 1, 0, 0, 0, 1
* @endcode
*
* @param[in] argc Number of arguments.
* @param[in] argv Argument vector.
* @return 0 on success, 1 on failure.
*/
int main(int argc, char* argv[]) {
int vector_length(kDefaultVectorLength);
WarningType warning_type(kDefaultWarningType);

for (;;) {
const int option_char(getopt_long(argc, argv, "l:m:e:h", NULL, NULL));
if (-1 == option_char) break;

switch (option_char) {
case 'l': {
if (!sptk::ConvertStringToInteger(optarg, &vector_length) ||
vector_length <= 0) {
std::ostringstream error_message;
error_message
<< "The argument for the -l option must be a positive integer";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
break;
}
case 'm': {
if (!sptk::ConvertStringToInteger(optarg, &vector_length) ||
vector_length < 0) {
std::ostringstream error_message;
error_message << "The argument for the -m option must be a "
<< "non-negative integer";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
++vector_length;
break;
}
case 'e': {
const int min(0);
const int max(static_cast<int>(kNumWarningTypes) - 1);
int tmp;
if (!sptk::ConvertStringToInteger(optarg, &tmp) ||
!sptk::IsInRange(tmp, min, max)) {
std::ostringstream error_message;
error_message << "The argument for the -e option must be an integer "
<< "in the range of " << min << " to " << max;
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
warning_type = static_cast<WarningType>(tmp);
break;
}
case 'h': {
PrintUsage(&std::cout);
return 0;
}
default: {
PrintUsage(&std::cerr);
return 1;
}
}
}

const int num_input_files(argc - optind);
if (1 < num_input_files) {
std::ostringstream error_message;
error_message << "Too many input files";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
const char* input_file(0 == num_input_files ? NULL : argv[optind]);

if (!sptk::SetBinaryMode()) {
std::ostringstream error_message;
error_message << "Cannot set translation mode";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}

std::ifstream ifs;
if (NULL != input_file) {
ifs.open(input_file, std::ios::in | std::ios::binary);
if (ifs.fail()) {
std::ostringstream error_message;
error_message << "Cannot open file " << input_file;
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
}
std::istream& input_stream(ifs.is_open() ? ifs : std::cin);

int index;
std::vector<double> onehot_vector(vector_length, 0.0);

for (int sample_index(0); sptk::ReadStream(&index, &input_stream);
++sample_index) {
const bool is_valid(0 <= index && index < vector_length);
if (is_valid) {
onehot_vector[index] = 1.0;
} else {
switch (warning_type) {
case kIgnore: {
// nothing to do
break;
}
case kWarn: {
std::ostringstream error_message;
error_message << sample_index << "th sample is out of range";
sptk::PrintErrorMessage("onehot", error_message);
break;
}
case kExit: {
std::ostringstream error_message;
error_message << sample_index << "th sample is out of range";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
default: {
std::ostringstream error_message;
error_message << "Unknown warning type";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
}
}
if (!sptk::WriteStream(0, vector_length, onehot_vector, &std::cout, NULL)) {
std::ostringstream error_message;
error_message << "Failed to write one-hot vector";
sptk::PrintErrorMessage("onehot", error_message);
return 1;
}
if (is_valid) {
onehot_vector[index] = 0.0;
}
}

return 0;
}

0 comments on commit 57df254

Please sign in to comment.