Add support for single-line regex anchors ^/$ in contains_re (#9482)
Closes #9439 

The `^` (begin anchor) and `$` (end anchor) currently apply to the beginning of a line (BOL) and the end of a line (EOL), respectively. As a result, they cannot be used to match only the beginning and end of the string as a whole when the string contains embedded new-line (`'\n'`) characters.

Many regex engines support a flag for controlling the behavior of the BOL/EOL anchors: [Python](https://docs.python.org/3/library/re.html#re.MULTILINE), [Java](https://docs.oracle.com/javase/7/docs/api/java/util/regex/Pattern.html#MULTILINE), [C++](https://en.cppreference.com/w/cpp/regex/basic_regex/constants). This PR adds a similar `flags` parameter to the `cudf::strings::contains_re`, `cudf::strings::matches_re`, and `cudf::strings::count_re` APIs to tell the regex engine how to interpret the anchor characters in the given regex pattern.

Additional information about these anchors can also be found here: https://www.regular-expressions.info/anchors.html

By default, the libcudf regex engine currently interprets BOL/EOL as if the `MULTILINE` flag were set. This does not match the default behavior of the engines/languages listed above. For consistency, this PR reverses the default, which makes it a breaking change.

The new `flags` parameter added to the above APIs also makes this a breaking change. An additional flag, `DOTALL`, is included in this PR because the internal regex code already supports it and only needed a way for the caller to request it. A `DOTALL` flag is also available in Python and Java. When specified, the dot (`.`) pattern also matches embedded new-line characters.
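
The effect of the two flags can be summarized with a small stand-alone model. This is only a simplified sketch of the semantics described above, not the libcudf implementation: the helper names `bol_matches`, `eol_matches`, and `dot_matches` are hypothetical, and engine-specific edge cases (for example, how a trailing new-line is treated) are ignored.

```cpp
#include <cstddef>
#include <string_view>

// Simplified model of the anchor and dot semantics; NOT the libcudf code.
// `pos` is a character offset in the range [0, s.size()].

// '^': always matches at offset 0. With `multiline` set, it also matches
// immediately after an embedded '\n'.
constexpr bool bol_matches(std::string_view s, std::size_t pos, bool multiline)
{
  if (pos == 0) return true;
  return multiline && s[pos - 1] == '\n';
}

// '$': always matches at the end of the string. With `multiline` set, it
// also matches immediately before an embedded '\n'.
constexpr bool eol_matches(std::string_view s, std::size_t pos, bool multiline)
{
  if (pos == s.size()) return true;
  return multiline && s[pos] == '\n';
}

// '.': matches any character except '\n' unless `dotall` is set.
constexpr bool dot_matches(char c, bool dotall) { return dotall || c != '\n'; }

// "ab\ncd": offset 3 is the start of the second line.
static_assert(bol_matches("ab\ncd", 3, true), "multiline: '^' matches after '\\n'");
static_assert(!bol_matches("ab\ncd", 3, false), "default: '^' matches only at offset 0");
```

Under this model, a pattern such as `^cd$` would find a match in `"ab\ncd"` only when the multi-line behavior is requested, which is the distinction the new default makes visible.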

Authors:
  - David Wendt (https://github.com/davidwendt)

Approvers:
  - AJ Schmidt (https://github.com/ajschmidt8)
  - Bradley Dice (https://github.com/bdice)
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #9482
davidwendt authored Nov 1, 2021
1 parent 237b0ce commit d073ecb
Showing 13 changed files with 321 additions and 60 deletions.
1 change: 1 addition & 0 deletions conda/recipes/libcudf/meta.yaml
@@ -203,6 +203,7 @@ test:
- test -f $PREFIX/include/cudf/strings/find_multiple.hpp
- test -f $PREFIX/include/cudf/strings/json.hpp
- test -f $PREFIX/include/cudf/strings/padding.hpp
- test -f $PREFIX/include/cudf/strings/regex/flags.hpp
- test -f $PREFIX/include/cudf/strings/repeat_strings.hpp
- test -f $PREFIX/include/cudf/strings/replace.hpp
- test -f $PREFIX/include/cudf/strings/replace_re.hpp
7 changes: 7 additions & 0 deletions cpp/include/cudf/strings/contains.hpp
@@ -16,6 +16,7 @@
#pragma once

#include <cudf/column/column.hpp>
#include <cudf/strings/regex/flags.hpp>
#include <cudf/strings/strings_column_view.hpp>

namespace cudf {
@@ -44,12 +45,14 @@ namespace strings {
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match to each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New column of boolean results for each string.
*/
std::unique_ptr<column> contains_re(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
@@ -69,12 +72,14 @@ std::unique_ptr<column> contains_re(
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match to each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New column of boolean results for each string.
*/
std::unique_ptr<column> matches_re(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/**
@@ -94,12 +99,14 @@ std::unique_ptr<column> matches_re(
*
* @param strings Strings instance for this operation.
* @param pattern Regex pattern to match within each string.
* @param flags Regex flags for interpreting special characters in the pattern.
* @param mr Device memory resource used to allocate the returned column's device memory.
* @return New INT32 column with counts for each string.
*/
std::unique_ptr<column> count_re(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags = regex_flags::DEFAULT,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource());

/** @} */ // end of doxygen group
65 changes: 65 additions & 0 deletions cpp/include/cudf/strings/regex/flags.hpp
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <cstdint>

namespace cudf {
namespace strings {

/**
* @addtogroup strings_contains
* @{
*/

/**
* @brief Regex flags.
*
* These types can be or'd to combine them.
* The values are chosen to leave room for future flags
* and to match the Python flag values.
*/
enum regex_flags : uint32_t {
DEFAULT = 0, /// default
MULTILINE = 8, /// the '^' and '$' honor new-line characters
DOTALL = 16 /// the '.' matching includes new-line characters
};

/**
* @brief Returns true if the given flags contain MULTILINE.
*
* @param f Regex flags to check
* @return true if `f` includes MULTILINE
*/
constexpr bool is_multiline(regex_flags const f)
{
return (f & regex_flags::MULTILINE) == regex_flags::MULTILINE;
}

/**
* @brief Returns true if the given flags contain DOTALL.
*
* @param f Regex flags to check
* @return true if `f` includes DOTALL
*/
constexpr bool is_dotall(regex_flags const f)
{
return (f & regex_flags::DOTALL) == regex_flags::DOTALL;
}

/** @} */ // end of doxygen group
} // namespace strings
} // namespace cudf
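
The doc comment on `regex_flags` says the values can be OR'd together. Because the enum is unscoped with a fixed `uint32_t` underlying type, `MULTILINE | DOTALL` evaluates to an integer rather than a `regex_flags`, so a cast is needed before passing the result to a `regex_flags` parameter. The sketch below illustrates this with a local copy of the enum and helpers so that it compiles without libcudf; the `sketch` namespace and the `combine` helper are hypothetical and not part of this commit.

```cpp
#include <cstdint>

namespace sketch {

// Local copy of the enum from flags.hpp (values match Python's re.M and re.S).
enum regex_flags : std::uint32_t {
  DEFAULT   = 0,
  MULTILINE = 8,
  DOTALL    = 16
};

// Local copies of the helper predicates from flags.hpp.
constexpr bool is_multiline(regex_flags const f)
{
  return (f & regex_flags::MULTILINE) == regex_flags::MULTILINE;
}

constexpr bool is_dotall(regex_flags const f)
{
  return (f & regex_flags::DOTALL) == regex_flags::DOTALL;
}

// Hypothetical helper (not in the commit): OR two flags and cast the
// integer result back to the enum type.
constexpr regex_flags combine(regex_flags const a, regex_flags const b)
{
  return static_cast<regex_flags>(static_cast<std::uint32_t>(a) |
                                  static_cast<std::uint32_t>(b));
}

// Both predicates see their bit in the combined value.
constexpr regex_flags both = combine(MULTILINE, DOTALL);
static_assert(is_multiline(both) && is_dotall(both), "both flags are set");
static_assert(!is_multiline(DEFAULT) && !is_dotall(DEFAULT), "DEFAULT sets neither");

}  // namespace sketch
```

With the real header, the equivalent inline spelling would be a `static_cast<cudf::strings::regex_flags>(...)` around the OR'd values at the call site.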
23 changes: 16 additions & 7 deletions cpp/src/strings/contains.cu
@@ -66,6 +66,7 @@ struct contains_fn {
std::unique_ptr<column> contains_util(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
bool beginning_only = false,
rmm::cuda_stream_view stream = rmm::cuda_stream_default,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
@@ -75,7 +76,8 @@ std::unique_ptr<column> contains_util(
auto d_column = *strings_column;

// compile regex into device object
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;

// create the output column
@@ -123,19 +125,21 @@ std::unique_ptr<column> contains_util(
std::unique_ptr<column> contains_re(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
return contains_util(strings, pattern, false, stream, mr);
return contains_util(strings, pattern, flags, false, stream, mr);
}

std::unique_ptr<column> matches_re(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
return contains_util(strings, pattern, true, stream, mr);
return contains_util(strings, pattern, flags, true, stream, mr);
}

} // namespace detail
@@ -144,18 +148,20 @@ std::unique_ptr<column> matches_re(

std::unique_ptr<column> contains_re(strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::contains_re(strings, pattern, rmm::cuda_stream_default, mr);
return detail::contains_re(strings, pattern, flags, rmm::cuda_stream_default, mr);
}

std::unique_ptr<column> matches_re(strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::matches_re(strings, pattern, rmm::cuda_stream_default, mr);
return detail::matches_re(strings, pattern, flags, rmm::cuda_stream_default, mr);
}

namespace detail {
@@ -190,6 +196,7 @@ struct count_fn {
std::unique_ptr<column> count_re(
strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr = rmm::mr::get_current_device_resource())
{
@@ -198,7 +205,8 @@ std::unique_ptr<column> count_re(
auto d_column = *strings_column;

// compile regex into device object
auto prog = reprog_device::create(pattern, get_character_flags_table(), strings_count, stream);
auto prog =
reprog_device::create(pattern, flags, get_character_flags_table(), strings_count, stream);
auto d_prog = *prog;

// create the output column
@@ -247,10 +255,11 @@ std::unique_ptr<column> count_re(

std::unique_ptr<column> count_re(strings_column_view const& strings,
std::string const& pattern,
regex_flags const flags,
rmm::mr::device_memory_resource* mr)
{
CUDF_FUNC_RANGE();
return detail::count_re(strings, pattern, rmm::cuda_stream_default, mr);
return detail::count_re(strings, pattern, flags, rmm::cuda_stream_default, mr);
}

} // namespace strings
56 changes: 44 additions & 12 deletions cpp/src/strings/regex/regcomp.cpp
@@ -15,6 +15,7 @@
*/

#include <strings/regex/regcomp.h>

#include <cudf/utilities/error.hpp>

#include <string.h>
@@ -523,6 +524,8 @@ class regex_compiler {
bool lastwasand;
int nbra;

regex_flags flags;

inline void pushand(int f, int l) { andstack.push_back({f, l}); }

inline Node popand(int op)
@@ -664,10 +667,13 @@ class regex_compiler {
{
if (lastwasand) Operator(CAT); /* catenate is implicit */
int inst_id = m_prog.add_inst(t);
if (t == CCLASS || t == NCCLASS)
if (t == CCLASS || t == NCCLASS) {
m_prog.inst_at(inst_id).u1.cls_id = yyclass_id;
else if (t == CHAR || t == BOL || t == EOL)
} else if (t == CHAR) {
m_prog.inst_at(inst_id).u1.c = yy;
} else if (t == BOL || t == EOL) {
m_prog.inst_at(inst_id).u1.c = is_multiline(flags) ? yy : '\n';
}
pushand(inst_id, inst_id);
lastwasand = true;
}
@@ -766,13 +772,20 @@ class regex_compiler {
}

public:
regex_compiler(const char32_t* pattern, int dot_type, reprog& prog)
: m_prog(prog), cursubid(0), pushsubid(0), lastwasand(false), nbra(0), yy(0), yyclass_id(0)
regex_compiler(const char32_t* pattern, regex_flags const flags, reprog& prog)
: m_prog(prog),
cursubid(0),
pushsubid(0),
lastwasand(false),
nbra(0),
flags(flags),
yy(0),
yyclass_id(0)
{
// Parse
std::vector<regex_parser::Item> items;
{
regex_parser parser(pattern, dot_type, m_prog);
regex_parser parser(pattern, is_dotall(flags) ? ANYNL : ANY, m_prog);

// Expand counted repetitions
if (parser.m_has_counted)
@@ -822,11 +835,11 @@ class regex_compiler {
};

// Convert pattern into program
reprog reprog::create_from(const char32_t* pattern)
reprog reprog::create_from(const char32_t* pattern, regex_flags const flags)
{
reprog rtn;
regex_compiler compiler(pattern, ANY, rtn); // future feature: ANYNL
// for debugging, it can be helpful to call rtn.print() here to dump
regex_compiler compiler(pattern, flags, rtn);
// for debugging, it can be helpful to call rtn.print(flags) here to dump
// out the instructions that have been created from the given pattern
return rtn;
}
@@ -913,9 +926,10 @@ void reprog::optimize2()
_startinst_ids.push_back(-1); // terminator mark
}

#ifndef NDEBUG
void reprog::print()
#ifndef NDBUG
void reprog::print(regex_flags const flags)
{
printf("Flags = 0x%08x\n", static_cast<uint32_t>(flags));
printf("Instructions:\n");
for (std::size_t i = 0; i < _insts.size(); i++) {
const reinst& inst = _insts[i];
@@ -943,8 +957,26 @@ void reprog::print()
case ANY: printf("ANY, nextid= %d", inst.u2.next_id); break;
case ANYNL: printf("ANYNL, nextid= %d", inst.u2.next_id); break;
case NOP: printf("NOP, nextid= %d", inst.u2.next_id); break;
case BOL: printf("BOL, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); break;
case EOL: printf("EOL, c = '%c', nextid= %d", inst.u1.c, inst.u2.next_id); break;
case BOL: {
printf("BOL, c = ");
if (inst.u1.c == '\n') {
printf("'\\n'");
} else {
printf("'%c'", inst.u1.c);
}
printf(", nextid= %d", inst.u2.next_id);
break;
}
case EOL: {
printf("EOL, c = ");
if (inst.u1.c == '\n') {
printf("'\\n'");
} else {
printf("'%c'", inst.u1.c);
}
printf(", nextid= %d", inst.u2.next_id);
break;
}
case CCLASS: printf("CCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id); break;
case NCCLASS:
printf("NCCLASS, cls_id=%d , nextid= %d", inst.u1.cls_id, inst.u2.next_id);
7 changes: 5 additions & 2 deletions cpp/src/strings/regex/regcomp.h
@@ -14,6 +14,9 @@
* limitations under the License.
*/
#pragma once

#include <cudf/strings/regex/flags.hpp>

#include <string>
#include <vector>

@@ -89,7 +92,7 @@ class reprog {
* @brief Parses the given regex pattern and compiles
* into a list of chained instructions.
*/
static reprog create_from(const char32_t* pattern);
static reprog create_from(const char32_t* pattern, regex_flags const flags);

int32_t add_inst(int32_t type);
int32_t add_inst(reinst inst);
@@ -113,7 +116,7 @@ class reprog {

void optimize1();
void optimize2();
void print(); // for debugging
void print(regex_flags const flags);

private:
std::vector<reinst> _insts;