diff --git a/src/main/merge.cc b/src/main/merge.cc index 34c1335..0e0e37c 100644 --- a/src/main/merge.cc +++ b/src/main/merge.cc @@ -14,14 +14,15 @@ // limitations under the License. // // ------------------------------------------------------------------------ // -#include // int8_t, int16_t, int32_t, int64_t, etc. -#include // std::strncmp -#include // std::ifstream -#include // std::setw -#include // std::cerr, std::cin, std::cout, std::endl, etc. -#include // std::ostringstream -#include // std::string -#include // std::vector +#include // std::copy +#include // int8_t, int16_t, int32_t, int64_t, etc. +#include // std::strncmp +#include // std::ifstream +#include // std::setw +#include // std::cerr, std::cin, std::cout, std::endl, etc. +#include // std::ostringstream +#include // std::string +#include // std::vector #include "GETOPT/ya_getopt.h" #include "SPTK/utils/int24_t.h" @@ -30,9 +31,12 @@ namespace { +enum InputFormats { kNaive = 0, kRecursive, kNumInputFormats }; + const int kDefaultInsertPoint(0); const int kDefaultFrameLengthOfInputData(25); const int kDefaultFrameLengthOfInsertData(10); +const InputFormats kDefaultInputFormat(kNaive); const bool kDefaultOverwriteMode(false); const char* kDefaultDataType("d"); @@ -49,6 +53,13 @@ void PrintUsage(std::ostream* stream) { *stream << " -m m : order of input data ( int)[" << std::setw(5) << std::right << "l-1" << "][ 0 <= m <= ]" << std::endl; // NOLINT *stream << " -L L : frame length of insert data ( int)[" << std::setw(5) << std::right << kDefaultFrameLengthOfInsertData << "][ 1 <= L <= ]" << std::endl; // NOLINT *stream << " -M M : order of insert data ( int)[" << std::setw(5) << std::right << "L-1" << "][ 0 <= M <= ]" << std::endl; // NOLINT + *stream << " -q q : input format ( int)[" << std::setw(5) << std::right << kDefaultInputFormat << "][ 0 <= q <= 1 ]" << std::endl; // NOLINT + *stream << " 0 (naive)" << std::endl; + *stream << " infile: a11 a12 .. a1l a21 a22 .. a2l a31 a32 .. a3l a41 a42 .. a4l" << std::endl; // NOLINT + *stream << " file1 : b11 b12 .. b1L b21 b22 .. b2L b31 b32 .. b3L b41 b42 .. b4L" << std::endl; // NOLINT + *stream << " 1 (recursive)" << std::endl; + *stream << " infile: a11 a12 .. a1l a21 a22 .. a2l a31 a32 .. a3l a41 a42 .. a4l" << std::endl; // NOLINT + *stream << " file1 : b11 b12 .. b1L" << std::endl; *stream << " -w : overwrite mode ( bool)[" << std::setw(5) << std::right << sptk::ConvertBooleanToString(kDefaultOverwriteMode) << "]" << std::endl; // NOLINT *stream << " +type : data type [" << std::setw(5) << std::right << kDefaultDataType << "]" << std::endl; // NOLINT *stream << " "; sptk::PrintDataType("c", stream); sptk::PrintDataType("C", stream); *stream << std::endl; // NOLINT @@ -77,29 +88,41 @@ class VectorMergeInterface { } virtual bool Run(std::istream* input_stream, std::istream* insert_stream, - bool* all_merged) const = 0; + bool* eof_reached) const = 0; }; template class VectorMerge : public VectorMergeInterface { public: VectorMerge(int insert_point, int input_length, int insert_length, - bool overwrite_mode) + bool recursive, bool overwrite_mode, + std::istream* insert_stream = NULL) : insert_point_(insert_point), input_length_(input_length), insert_length_(insert_length), merged_length_(overwrite_mode ? input_length_ : input_length_ + insert_length_), input_rest_length_(merged_length_ - insert_point_ - insert_length_), - input_skip_length_(overwrite_mode ? insert_length_ : 0) { + input_skip_length_(overwrite_mode ? insert_length_ : 0), + recursive_(recursive), + has_vector_(false) { + if (recursive_ && sptk::ReadStream(false, 0, 0, insert_length_, + &insert_vector_, insert_stream, NULL)) { + has_vector_ = true; + } } ~VectorMerge() { } virtual bool Run(std::istream* input_stream, std::istream* insert_stream, - bool* all_merged) const { + bool* eof_reached) const { + if (recursive_ && !has_vector_) { + return true; + } + std::vector merged_vector(merged_length_); + std::vector garbage(input_skip_length_); for (;;) { if (0 < insert_point_) { if (!sptk::ReadStream(false, 0, 0, insert_point_, &merged_vector, @@ -107,8 +130,11 @@ class VectorMerge : public VectorMergeInterface { break; } } - if (!sptk::ReadStream(false, 0, insert_point_, insert_length_, - &merged_vector, insert_stream, NULL)) { + if (recursive_) { + std::copy(insert_vector_.begin(), insert_vector_.end(), + merged_vector.begin() + insert_point_); + } else if (!sptk::ReadStream(false, 0, insert_point_, insert_length_, + &merged_vector, insert_stream, NULL)) { break; } if (0 < input_rest_length_) { @@ -117,6 +143,11 @@ class VectorMerge : public VectorMergeInterface { input_rest_length_, &merged_vector, input_stream, NULL)) { break; } + } else if (0 < input_skip_length_) { + if (!sptk::ReadStream(false, 0, 0, input_skip_length_, &garbage, + input_stream, NULL)) { + break; + } } if (!sptk::WriteStream(0, merged_length_, merged_vector, &std::cout, NULL)) { @@ -124,9 +155,10 @@ class VectorMerge : public VectorMergeInterface { } } - if (all_merged) { - *all_merged = (input_stream->peek() == std::istream::traits_type::eof() && - insert_stream->peek() == std::istream::traits_type::eof()); + if (NULL != eof_reached) { + *eof_reached = + (input_stream->peek() == std::istream::traits_type::eof() && + insert_stream->peek() == std::istream::traits_type::eof()); } return true; } @@ -138,6 +170,10 @@ class VectorMerge : public VectorMergeInterface { const int merged_length_; const int input_rest_length_; const int input_skip_length_; + const bool recursive_; + + bool has_vector_; + std::vector insert_vector_; DISALLOW_COPY_AND_ASSIGN(VectorMerge); }; @@ -145,47 +181,60 @@ class VectorMerge : public VectorMergeInterface { class VectorMergeWrapper { public: VectorMergeWrapper(const std::string& data_type, int insert_point, - int input_length, int insert_length, bool overwrite_mode) + int input_length, int insert_length, bool recursive, + bool overwrite_mode, std::istream* insert_stream) : merge_(NULL) { if ("c" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("s" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("h" == data_type) { merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + insert_length, recursive, + overwrite_mode, insert_stream); } else if ("i" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("l" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("C" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("S" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("H" == data_type) { merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + insert_length, recursive, + overwrite_mode, insert_stream); } else if ("I" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("L" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("f" == data_type) { merge_ = new VectorMerge(insert_point, input_length, insert_length, - overwrite_mode); + recursive, overwrite_mode, insert_stream); } else if ("d" == data_type) { - merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + merge_ = + new VectorMerge(insert_point, input_length, insert_length, + recursive, overwrite_mode, insert_stream); } else if ("e" == data_type) { merge_ = new VectorMerge(insert_point, input_length, - insert_length, overwrite_mode); + insert_length, recursive, + overwrite_mode, insert_stream); } } @@ -198,8 +247,8 @@ class VectorMergeWrapper { } bool Run(std::istream* input_stream, std::istream* insert_stream, - bool* all_merged) const { - return IsValid() && merge_->Run(input_stream, insert_stream, all_merged); + bool* eof_reached) const { + return IsValid() && merge_->Run(input_stream, insert_stream, eof_reached); } private: @@ -223,6 +272,10 @@ class VectorMergeWrapper { * - frame length of output data @f$(1 \le L_2)@f$ * - @b -M @e int * - order of output data @f$(0 \le L_2 - 1)@f$ + * - @b -q @e int + * - input format + * \arg @c 0 naive + * \arg @c 1 recursive * - @b -w * - overwrite mode * - @b +type @e char @@ -270,6 +323,15 @@ class VectorMergeWrapper { * # 4, 1, 5, 2, 6, 3 * @endcode * + * Recursive mode example: + * + * @code{.sh} + * echo 1 1 2 2 3 3 | x2x +as > input.s + * echo 4 | x2x +as > insert.s + * merge -q 1 -s 0 -l 2 -L 1 +s insert.s < input.s | x2x +sa + * # 4, 1, 1, 4, 2, 2, 4, 3, 3 + * @endcode + * * @param[in] argc Number of arguments. * @param[in] argv Argument vector. * @return 0 on success, 1 on failure. @@ -278,11 +340,13 @@ int main(int argc, char* argv[]) { int insert_point(kDefaultInsertPoint); int input_length(kDefaultFrameLengthOfInputData); int insert_length(kDefaultFrameLengthOfInsertData); + InputFormats input_format(kDefaultInputFormat); bool overwrite_mode(kDefaultOverwriteMode); std::string data_type(kDefaultDataType); for (;;) { - const int option_char(getopt_long(argc, argv, "s:l:m:L:M:wh", NULL, NULL)); + const int option_char( + getopt_long(argc, argv, "s:l:m:L:M:q:wh", NULL, NULL)); if (-1 == option_char) break; switch (option_char) { @@ -343,6 +407,21 @@ int main(int argc, char* argv[]) { ++insert_length; break; } + case 'q': { + const int min(0); + const int max(static_cast(kNumInputFormats) - 1); + int tmp; + if (!sptk::ConvertStringToInteger(optarg, &tmp) || + !sptk::IsInRange(tmp, min, max)) { + std::ostringstream error_message; + error_message << "The argument for the -q option must be an integer " + << "in the range of " << min << " to " << max; + sptk::PrintErrorMessage("merge", error_message); + return 1; + } + input_format = static_cast(tmp); + break; + } case 'w': { overwrite_mode = true; break; @@ -429,7 +508,8 @@ int main(int argc, char* argv[]) { std::istream& input_stream(ifs2.is_open() ? ifs2 : std::cin); VectorMergeWrapper merge(data_type, insert_point, input_length, insert_length, - overwrite_mode); + kRecursive == input_format, overwrite_mode, + &insert_stream); if (!merge.IsValid()) { std::ostringstream error_message; diff --git a/test/test_merge.bats b/test/test_merge.bats index 03fec41..ecace4c 100755 --- a/test/test_merge.bats +++ b/test/test_merge.bats @@ -47,11 +47,29 @@ teardown() { $sptk3/x2x +"${ary1[$t]}"d > $tmp/4 run $sptk4/aeq -L $tmp/3 $tmp/4 [ "$status" -eq 0 ] + + $sptk3/merge +"${ary1[$t]}" $tmp/1 $tmp/2 -s 2 -l 6 -L 4 -o | + $sptk3/x2x +"${ary1[$t]}"d > $tmp/3 + $sptk4/merge +"${ary2[$t]}" $tmp/1 $tmp/2 -s 2 -l 6 -L 4 -w | + $sptk3/x2x +"${ary1[$t]}"d > $tmp/4 + run $sptk4/aeq -L $tmp/3 $tmp/4 + [ "$status" -eq 0 ] done } +@test "merge: recursive" { + echo 4 | $sptk3/x2x +ad > $tmp/1 + echo 1 1 2 2 3 3 | $sptk3/x2x +ad > $tmp/2 + $sptk4/merge +d $tmp/1 $tmp/2 -q 1 -s 0 -l 2 -L 1 > $tmp/3 + echo 4 1 1 4 2 2 4 3 3 | $sptk3/x2x +ad > $tmp/4 + run $sptk4/aeq $tmp/3 $tmp/4 + [ "$status" -eq 0 ] +} + @test "merge: valgrind" { $sptk3/nrand -l 20 > $tmp/1 - run valgrind $sptk4/merge +d $tmp/1 $tmp/1 -l 1 -L 1 - [ "$(echo "${lines[-1]}" | sed -r 's/.*SUMMARY: ([0-9]*) .*/\1/')" -eq 0 ] + for q in $(seq 0 1); do + run valgrind $sptk4/merge +d -q "$q" $tmp/1 $tmp/1 -l 1 -L 1 + [ "$(echo "${lines[-1]}" | sed -r 's/.*SUMMARY: ([0-9]*) .*/\1/')" -eq 0 ] + done } diff --git a/tools/Makefile b/tools/Makefile index 6417ae5..d63ee3c 100644 --- a/tools/Makefile +++ b/tools/Makefile @@ -16,7 +16,7 @@ PYTHON_VERSION := 3.8 SHELLCHECK_VERSION := 0.9.0 -SHFMT_VERSION := 3.6.0 +SHFMT_VERSION := 3.7.0 JOBS := 4