diff --git a/polars/polars-ops/src/chunked_array/strings/namespace.rs b/polars/polars-ops/src/chunked_array/strings/namespace.rs index 360b37706958..1643540af7e5 100644 --- a/polars/polars-ops/src/chunked_array/strings/namespace.rs +++ b/polars/polars-ops/src/chunked_array/strings/namespace.rs @@ -8,7 +8,7 @@ use polars_arrow::export::arrow::compute::substring::substring; use polars_arrow::export::arrow::{self}; use polars_arrow::kernels::string::*; use polars_core::export::num::Num; -use polars_core::export::regex::{escape, Regex}; +use polars_core::export::regex::{escape, NoExpand, Regex}; use super::*; #[cfg(feature = "string_encoding")] @@ -207,8 +207,8 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { /// Replace the leftmost regex-matched (sub)string with another string; take /// fast-path for small (<= 32 chars) strings (otherwise regex faster). fn replace<'a>(&'a self, pat: &str, val: &str) -> PolarsResult { - let lit = pat.chars().all(|c| !c.is_ascii_punctuation()); - let ca = self.as_utf8(); + let lit = !(pat.chars().any(|c| c.is_ascii_punctuation()) + | val.chars().any(|c| c.is_ascii_punctuation())); let reg = Regex::new(pat)?; let f = |s: &'a str| { if lit && (s.len() <= 32) { @@ -217,25 +217,36 @@ pub trait Utf8NameSpaceImpl: AsUtf8 { reg.replace(s, val) } }; + let ca = self.as_utf8(); Ok(ca.apply(f)) } /// Replace the leftmost literal (sub)string with another string - fn replace_literal(&self, pat: &str, val: &str) -> PolarsResult { - self.replace(escape(pat).as_str(), val) + fn replace_literal<'a>(&'a self, pat: &str, val: &str) -> PolarsResult { + let reg = Regex::new(escape(pat).as_str())?; + let f = |s: &'a str| { + if s.len() <= 32 { + Cow::Owned(s.replacen(pat, val, 1)) + } else { + reg.replace(s, NoExpand(val)) + } + }; + let ca = self.as_utf8(); + Ok(ca.apply(f)) } /// Replace all regex-matched (sub)strings with another string fn replace_all(&self, pat: &str, val: &str) -> PolarsResult { let ca = self.as_utf8(); let reg = Regex::new(pat)?; - let f = |s| reg.replace_all(s, val); - Ok(ca.apply(f)) + Ok(ca.apply(|s| reg.replace_all(s, val))) } /// Replace all matching literal (sub)strings with another string fn replace_literal_all(&self, pat: &str, val: &str) -> PolarsResult { - self.replace_all(escape(pat).as_str(), val) + let ca = self.as_utf8(); + let reg = Regex::new(escape(pat).as_str())?; + Ok(ca.apply(|s| reg.replace_all(s, NoExpand(val)))) } /// Extract the nth capture group from pattern diff --git a/py-polars/tests/unit/namespaces/test_string.py b/py-polars/tests/unit/namespaces/test_string.py index 366b24ac12ae..312f8b043359 100644 --- a/py-polars/tests/unit/namespaces/test_string.py +++ b/py-polars/tests/unit/namespaces/test_string.py @@ -299,6 +299,13 @@ def test_replace() -> None: (r"^\(", "[", True, ["* * text", "(with) special\n * chars **etc...?$"]), (r"t$", "an", False, ["* * texan", "(with) special\n * chars **etc...?$"]), (r"t$", "an", True, ["* * text", "(with) special\n * chars **etc...?$"]), + (r"(with) special", "$1", True, ["* * text", "$1\n * chars **etc...?$"]), + ( + r"\((with)\) special", + ":$1:", + False, + ["* * text", ":with:\n * chars **etc...?$"], + ), ): # series assert ( @@ -315,23 +322,38 @@ def test_replace() -> None: )["text"].to_list() ) + assert pl.Series(["."]).str.replace(".", "$0", literal=True)[0] == "$0" + assert pl.Series(["(.)(?)"]).str.replace(".", "$1", literal=True)[0] == "($1)(?)" + def test_replace_all() -> None: df = pl.DataFrame( - data=[(1, "* * text"), (2, "(with) special * chars **etc...?$")], + data=[(1, "* * text"), (2, "(with) special\n * chars **etc...?$")], schema=["idx", "text"], orient="row", ) for pattern, replacement, as_literal, expected in ( - (r"\*", "-", False, ["- - text", "(with) special - chars --etc...?$"]), - (r"*", "-", True, ["- - text", "(with) special - chars --etc...?$"]), + (r"\*", "-", False, ["- - text", "(with) special\n - chars --etc...?$"]), + (r"*", "-", True, ["- - text", "(with) special\n - chars --etc...?$"]), (r"\W", "", False, ["text", "withspecialcharsetc"]), - (r".?$", "", True, ["* * text", "(with) special * chars **etc.."]), + (r".?$", "", True, ["* * text", "(with) special\n * chars **etc.."]), + ( + r"(with) special", + "$1", + True, + ["* * text", "$1\n * chars **etc...?$"], + ), + ( + r"\((with)\) special", + ":$1:", + False, + ["* * text", ":with:\n * chars **etc...?$"], + ), ( r"(\b)[\w\s]{2,}(\b)", "$1(blah)$3", False, - ["* * (blah)", "((blah)) (blah) * (blah) **(blah)...?$"], + ["* * (blah)", "((blah)) (blah)\n * (blah) **(blah)...?$"], ), ): # series @@ -352,6 +374,15 @@ def test_replace_all() -> None: with pytest.raises(pl.ComputeError): df["text"].str.replace_all("*", "") + assert ( + pl.Series([r"(.)(\?)(\?)"]).str.replace_all("\?", "$0", literal=True)[0] + == "(.)($0)($0)" + ) + assert ( + pl.Series([r"(.)(\?)(\?)"]).str.replace_all("\?", "$0", literal=False)[0] + == "(.)(\?)(\?)" + ) + def test_replace_expressions() -> None: df = pl.DataFrame({"foo": ["123 bla 45 asd", "xyz 678 910t"], "value": ["A", "B"]})