Skip to content

Commit

Permalink
Merge pull request #44 from sp-nitech/update_merge
Browse files Browse the repository at this point in the history
Support recursively merging
  • Loading branch information
takenori-y committed Aug 2, 2023
2 parents ecf5821 + a1b3493 commit 26f849c
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 47 deletions.
168 changes: 124 additions & 44 deletions src/main/merge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,15 @@
// limitations under the License. //
// ------------------------------------------------------------------------ //

#include <cstdint> // int8_t, int16_t, int32_t, int64_t, etc.
#include <cstring> // std::strncmp
#include <fstream> // std::ifstream
#include <iomanip> // std::setw
#include <iostream> // std::cerr, std::cin, std::cout, std::endl, etc.
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <vector> // std::vector
#include <algorithm> // std::copy
#include <cstdint> // int8_t, int16_t, int32_t, int64_t, etc.
#include <cstring> // std::strncmp
#include <fstream> // std::ifstream
#include <iomanip> // std::setw
#include <iostream> // std::cerr, std::cin, std::cout, std::endl, etc.
#include <sstream> // std::ostringstream
#include <string> // std::string
#include <vector> // std::vector

#include "GETOPT/ya_getopt.h"
#include "SPTK/utils/int24_t.h"
Expand All @@ -30,9 +31,12 @@

namespace {

// Layout of the insert data (file1), selected with the -q option.
//   kNaive     : file1 holds one insert vector for every frame of the input
//                data; a fresh vector is read from it for each frame.
//   kRecursive : file1 holds a single insert vector, which is read once and
//                merged into every frame of the input data.
// kNumInputFormats is not a format itself; it counts the formats above and is
// used to derive the upper bound accepted by -q.
enum InputFormats { kNaive = 0, kRecursive, kNumInputFormats };

// Initial values of the command-line settings, used when the corresponding
// option is not given.
// Insert point (-s).
const int kDefaultInsertPoint(0);
// Frame length of input data (-l).
const int kDefaultFrameLengthOfInputData(25);
// Frame length of insert data (-L).
const int kDefaultFrameLengthOfInsertData(10);
// Input format (-q).
const InputFormats kDefaultInputFormat(kNaive);
// Overwrite mode (-w); off unless the flag is passed.
const bool kDefaultOverwriteMode(false);
// Data type (+type); "d" selects double.
const char* kDefaultDataType("d");

Expand All @@ -49,6 +53,13 @@ void PrintUsage(std::ostream* stream) {
*stream << " -m m : order of input data ( int)[" << std::setw(5) << std::right << "l-1" << "][ 0 <= m <= ]" << std::endl; // NOLINT
*stream << " -L L : frame length of insert data ( int)[" << std::setw(5) << std::right << kDefaultFrameLengthOfInsertData << "][ 1 <= L <= ]" << std::endl; // NOLINT
*stream << " -M M : order of insert data ( int)[" << std::setw(5) << std::right << "L-1" << "][ 0 <= M <= ]" << std::endl; // NOLINT
*stream << " -q q : input format ( int)[" << std::setw(5) << std::right << kDefaultInputFormat << "][ 0 <= q <= 1 ]" << std::endl; // NOLINT
*stream << " 0 (naive)" << std::endl;
*stream << " infile: a11 a12 .. a1l a21 a22 .. a2l a31 a32 .. a3l a41 a42 .. a4l" << std::endl; // NOLINT
*stream << " file1 : b11 b12 .. b1L b21 b22 .. b2L b31 b32 .. b3L b41 b42 .. b4L" << std::endl; // NOLINT
*stream << " 1 (recursive)" << std::endl;
*stream << " infile: a11 a12 .. a1l a21 a22 .. a2l a31 a32 .. a3l a41 a42 .. a4l" << std::endl; // NOLINT
*stream << " file1 : b11 b12 .. b1L" << std::endl;
*stream << " -w : overwrite mode ( bool)[" << std::setw(5) << std::right << sptk::ConvertBooleanToString(kDefaultOverwriteMode) << "]" << std::endl; // NOLINT
*stream << " +type : data type [" << std::setw(5) << std::right << kDefaultDataType << "]" << std::endl; // NOLINT
*stream << " "; sptk::PrintDataType("c", stream); sptk::PrintDataType("C", stream); *stream << std::endl; // NOLINT
Expand Down Expand Up @@ -77,38 +88,53 @@ class VectorMergeInterface {
}

virtual bool Run(std::istream* input_stream, std::istream* insert_stream,
bool* all_merged) const = 0;
bool* eof_reached) const = 0;
};

template <typename T>
class VectorMerge : public VectorMergeInterface {
public:
VectorMerge(int insert_point, int input_length, int insert_length,
bool overwrite_mode)
bool recursive, bool overwrite_mode,
std::istream* insert_stream = NULL)
: insert_point_(insert_point),
input_length_(input_length),
insert_length_(insert_length),
merged_length_(overwrite_mode ? input_length_
: input_length_ + insert_length_),
input_rest_length_(merged_length_ - insert_point_ - insert_length_),
input_skip_length_(overwrite_mode ? insert_length_ : 0) {
input_skip_length_(overwrite_mode ? insert_length_ : 0),
recursive_(recursive),
has_vector_(false) {
if (recursive_ && sptk::ReadStream(false, 0, 0, insert_length_,
&insert_vector_, insert_stream, NULL)) {
has_vector_ = true;
}
}

~VectorMerge() {
}

virtual bool Run(std::istream* input_stream, std::istream* insert_stream,
bool* all_merged) const {
bool* eof_reached) const {
if (recursive_ && !has_vector_) {
return true;
}

std::vector<T> merged_vector(merged_length_);
std::vector<T> garbage(input_skip_length_);
for (;;) {
if (0 < insert_point_) {
if (!sptk::ReadStream(false, 0, 0, insert_point_, &merged_vector,
input_stream, NULL)) {
break;
}
}
if (!sptk::ReadStream(false, 0, insert_point_, insert_length_,
&merged_vector, insert_stream, NULL)) {
if (recursive_) {
std::copy(insert_vector_.begin(), insert_vector_.end(),
merged_vector.begin() + insert_point_);
} else if (!sptk::ReadStream(false, 0, insert_point_, insert_length_,
&merged_vector, insert_stream, NULL)) {
break;
}
if (0 < input_rest_length_) {
Expand All @@ -117,16 +143,22 @@ class VectorMerge : public VectorMergeInterface {
input_rest_length_, &merged_vector, input_stream, NULL)) {
break;
}
} else if (0 < input_skip_length_) {
if (!sptk::ReadStream(false, 0, 0, input_skip_length_, &garbage,
input_stream, NULL)) {
break;
}
}
if (!sptk::WriteStream(0, merged_length_, merged_vector, &std::cout,
NULL)) {
return false;
}
}

if (all_merged) {
*all_merged = (input_stream->peek() == std::istream::traits_type::eof() &&
insert_stream->peek() == std::istream::traits_type::eof());
if (NULL != eof_reached) {
*eof_reached =
(input_stream->peek() == std::istream::traits_type::eof() &&
insert_stream->peek() == std::istream::traits_type::eof());
}
return true;
}
Expand All @@ -138,54 +170,71 @@ class VectorMerge : public VectorMergeInterface {
const int merged_length_;
const int input_rest_length_;
const int input_skip_length_;
const bool recursive_;

bool has_vector_;
std::vector<T> insert_vector_;

DISALLOW_COPY_AND_ASSIGN(VectorMerge<T>);
};

class VectorMergeWrapper {
public:
VectorMergeWrapper(const std::string& data_type, int insert_point,
int input_length, int insert_length, bool overwrite_mode)
int input_length, int insert_length, bool recursive,
bool overwrite_mode, std::istream* insert_stream)
: merge_(NULL) {
if ("c" == data_type) {
merge_ = new VectorMerge<int8_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<int8_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("s" == data_type) {
merge_ = new VectorMerge<int16_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<int16_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("h" == data_type) {
merge_ = new VectorMerge<sptk::int24_t>(insert_point, input_length,
insert_length, overwrite_mode);
insert_length, recursive,
overwrite_mode, insert_stream);
} else if ("i" == data_type) {
merge_ = new VectorMerge<int32_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<int32_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("l" == data_type) {
merge_ = new VectorMerge<int64_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<int64_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("C" == data_type) {
merge_ = new VectorMerge<uint8_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<uint8_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("S" == data_type) {
merge_ = new VectorMerge<uint16_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<uint16_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("H" == data_type) {
merge_ = new VectorMerge<sptk::uint24_t>(insert_point, input_length,
insert_length, overwrite_mode);
insert_length, recursive,
overwrite_mode, insert_stream);
} else if ("I" == data_type) {
merge_ = new VectorMerge<uint32_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<uint32_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("L" == data_type) {
merge_ = new VectorMerge<uint64_t>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<uint64_t>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("f" == data_type) {
merge_ = new VectorMerge<float>(insert_point, input_length, insert_length,
overwrite_mode);
recursive, overwrite_mode, insert_stream);
} else if ("d" == data_type) {
merge_ = new VectorMerge<double>(insert_point, input_length,
insert_length, overwrite_mode);
merge_ =
new VectorMerge<double>(insert_point, input_length, insert_length,
recursive, overwrite_mode, insert_stream);
} else if ("e" == data_type) {
merge_ = new VectorMerge<long double>(insert_point, input_length,
insert_length, overwrite_mode);
insert_length, recursive,
overwrite_mode, insert_stream);
}
}

Expand All @@ -198,8 +247,8 @@ class VectorMergeWrapper {
}

bool Run(std::istream* input_stream, std::istream* insert_stream,
bool* all_merged) const {
return IsValid() && merge_->Run(input_stream, insert_stream, all_merged);
bool* eof_reached) const {
return IsValid() && merge_->Run(input_stream, insert_stream, eof_reached);
}

private:
Expand All @@ -223,6 +272,10 @@ class VectorMergeWrapper {
* - frame length of insert data @f$(1 \le L_2)@f$
* - @b -M @e int
* - order of insert data @f$(0 \le L_2 - 1)@f$
* - @b -q @e int
* - input format
* \arg @c 0 naive
* \arg @c 1 recursive
* - @b -w
* - overwrite mode
* - @b +type @e char
Expand Down Expand Up @@ -270,6 +323,15 @@ class VectorMergeWrapper {
* # 4, 1, 5, 2, 6, 3
* @endcode
*
* Recursive mode example:
*
* @code{.sh}
* echo 1 1 2 2 3 3 | x2x +as > input.s
* echo 4 | x2x +as > insert.s
* merge -q 1 -s 0 -l 2 -L 1 +s insert.s < input.s | x2x +sa
* # 4, 1, 1, 4, 2, 2, 4, 3, 3
* @endcode
*
* @param[in] argc Number of arguments.
* @param[in] argv Argument vector.
* @return 0 on success, 1 on failure.
Expand All @@ -278,11 +340,13 @@ int main(int argc, char* argv[]) {
int insert_point(kDefaultInsertPoint);
int input_length(kDefaultFrameLengthOfInputData);
int insert_length(kDefaultFrameLengthOfInsertData);
InputFormats input_format(kDefaultInputFormat);
bool overwrite_mode(kDefaultOverwriteMode);
std::string data_type(kDefaultDataType);

for (;;) {
const int option_char(getopt_long(argc, argv, "s:l:m:L:M:wh", NULL, NULL));
const int option_char(
getopt_long(argc, argv, "s:l:m:L:M:q:wh", NULL, NULL));
if (-1 == option_char) break;

switch (option_char) {
Expand Down Expand Up @@ -343,6 +407,21 @@ int main(int argc, char* argv[]) {
++insert_length;
break;
}
case 'q': {
const int min(0);
const int max(static_cast<int>(kNumInputFormats) - 1);
int tmp;
if (!sptk::ConvertStringToInteger(optarg, &tmp) ||
!sptk::IsInRange(tmp, min, max)) {
std::ostringstream error_message;
error_message << "The argument for the -q option must be an integer "
<< "in the range of " << min << " to " << max;
sptk::PrintErrorMessage("merge", error_message);
return 1;
}
input_format = static_cast<InputFormats>(tmp);
break;
}
case 'w': {
overwrite_mode = true;
break;
Expand Down Expand Up @@ -429,7 +508,8 @@ int main(int argc, char* argv[]) {
std::istream& input_stream(ifs2.is_open() ? ifs2 : std::cin);

VectorMergeWrapper merge(data_type, insert_point, input_length, insert_length,
overwrite_mode);
kRecursive == input_format, overwrite_mode,
&insert_stream);

if (!merge.IsValid()) {
std::ostringstream error_message;
Expand Down
22 changes: 20 additions & 2 deletions test/test_merge.bats
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,29 @@ teardown() {
$sptk3/x2x +"${ary1[$t]}"d > $tmp/4
run $sptk4/aeq -L $tmp/3 $tmp/4
[ "$status" -eq 0 ]

$sptk3/merge +"${ary1[$t]}" $tmp/1 $tmp/2 -s 2 -l 6 -L 4 -o |
$sptk3/x2x +"${ary1[$t]}"d > $tmp/3
$sptk4/merge +"${ary2[$t]}" $tmp/1 $tmp/2 -s 2 -l 6 -L 4 -w |
$sptk3/x2x +"${ary1[$t]}"d > $tmp/4
run $sptk4/aeq -L $tmp/3 $tmp/4
[ "$status" -eq 0 ]
done
}

# Checks the recursive input format (-q 1): the single value stored in the
# insert file must be merged at position 0 (-s 0) of every two-element frame
# (-l 2) of the input file.
@test "merge: recursive" {
# Insert data (file1): a single vector of length 1 (-L 1).
echo 4 | $sptk3/x2x +ad > $tmp/1
# Input data (infile): three frames of length 2.
echo 1 1 2 2 3 3 | $sptk3/x2x +ad > $tmp/2
$sptk4/merge +d $tmp/1 $tmp/2 -q 1 -s 0 -l 2 -L 1 > $tmp/3
# Expected result: the insert value placed in front of each input frame.
echo 4 1 1 4 2 2 4 3 3 | $sptk3/x2x +ad > $tmp/4
# The test passes only if aeq exits with status 0 for the two files.
run $sptk4/aeq $tmp/3 $tmp/4
[ "$status" -eq 0 ]
}

@test "merge: valgrind" {
$sptk3/nrand -l 20 > $tmp/1
run valgrind $sptk4/merge +d $tmp/1 $tmp/1 -l 1 -L 1
[ "$(echo "${lines[-1]}" | sed -r 's/.*SUMMARY: ([0-9]*) .*/\1/')" -eq 0 ]
for q in $(seq 0 1); do
run valgrind $sptk4/merge +d -q "$q" $tmp/1 $tmp/1 -l 1 -L 1
[ "$(echo "${lines[-1]}" | sed -r 's/.*SUMMARY: ([0-9]*) .*/\1/')" -eq 0 ]
done
}
2 changes: 1 addition & 1 deletion tools/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

PYTHON_VERSION := 3.8
SHELLCHECK_VERSION := 0.9.0
SHFMT_VERSION := 3.6.0
SHFMT_VERSION := 3.7.0
JOBS := 4


Expand Down

0 comments on commit 26f849c

Please sign in to comment.