Skip to content

Commit

Permalink
Added implementation for constant strings and tests for str.count
Browse files Browse the repository at this point in the history
  • Loading branch information
advikkabra authored and certik committed Feb 17, 2024
1 parent f1e2e3a commit 12c88e7
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 2 deletions.
21 changes: 21 additions & 0 deletions integration_tests/test_str_attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,26 @@ def find():
assert s2.find("we") == -1
assert "".find("") == 0

def count():
s: str
sub: str
s = "ABC ABCDAB ABCDABCDABDE"
sub = "ABC"
assert s.count(sub) == 4
assert s.count("ABC") == 4

sub = "AB"
assert s.count(sub) == 6
assert s.count("AB") == 6

sub = "ABC"
assert "ABC ABCDAB ABCDABCDABDE".count(sub) == 4
assert "ABC ABCDAB ABCDABCDABDE".count("ABC") == 4

sub = "AB"
assert "ABC ABCDAB ABCDABCDABDE".count(sub) == 6
assert "ABC ABCDAB ABCDABCDABDE".count("AB") == 6


def startswith():
s: str
Expand Down Expand Up @@ -307,6 +327,7 @@ def check():
strip()
swapcase()
find()
count()
startswith()
endswith()
partition()
Expand Down
38 changes: 38 additions & 0 deletions src/libasr/asr_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4242,6 +4242,44 @@ static inline int KMP_string_match(std::string &s_var, std::string &sub) {
return res;
}

static inline int KMP_string_match_count(std::string &s_var, std::string &sub) {
int str_len = s_var.size();
int sub_len = sub.size();
int count = 0;
std::vector<int> lps(sub_len, 0);
if (sub_len == 0) {
count = str_len + 1;
} else {
for(int i = 1, len = 0; i < sub_len;) {
if (sub[i] == sub[len]) {
lps[i++] = ++len;
} else {
if (len != 0) {
len = lps[len - 1];
} else {
lps[i++] = 0;
}
}
}
for (int i = 0, j = 0; (str_len - i) >= (sub_len - j);) {
if (sub[j] == s_var[i]) {
j++, i++;
}
if (j == sub_len) {
count++;
j = lps[j - 1];
} else if (i < str_len && sub[j] != s_var[i]) {
if (j != 0) {
j = lps[j - 1];
} else {
i = i + 1;
}
}
}
}
return count;
}

static inline void visit_expr_list(Allocator &al, Vec<ASR::call_arg_t>& exprs,
Vec<ASR::expr_t*>& exprs_vec) {
LCOMPILERS_ASSERT(exprs_vec.reserve_called);
Expand Down
39 changes: 37 additions & 2 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6872,13 +6872,13 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
}
} else if (attr_name == "find") {
if (args.size() != 1) {
throw SemanticError("str.find() takes one arguments",
throw SemanticError("str.find() takes one argument",
loc);
}
ASR::expr_t *arg = args[0].m_value;
ASR::ttype_t *type = ASRUtils::expr_type(arg);
if (!ASRUtils::is_character(*type)) {
throw SemanticError("str.find() takes one arguments of type: str",
throw SemanticError("str.find() takes one argument of type: str",
arg->base.loc);
}
if (ASRUtils::expr_value(arg) != nullptr) {
Expand All @@ -6905,6 +6905,41 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_find", loc);
}
return;
} else if (attr_name == "count") {
if (args.size() != 1) {
throw SemanticError("str.count() takes one argument",
loc);
}
ASR::expr_t *arg = args[0].m_value;
ASR::ttype_t *type = ASRUtils::expr_type(arg);
if (!ASRUtils::is_character(*type)) {
throw SemanticError("str.count() takes one argument of type: str",
arg->base.loc);
}
if (ASRUtils::expr_value(arg) != nullptr) {
ASR::StringConstant_t* sub_str_con = ASR::down_cast<ASR::StringConstant_t>(arg);
std::string sub = sub_str_con->m_s;
int res = ASRUtils::KMP_string_match_count(s_var, sub);
tmp = ASR::make_IntegerConstant_t(al, loc, res,
ASRUtils::TYPE(ASR::make_Integer_t(al,loc, 4)));
} else {
ASR::symbol_t *fn_div = resolve_intrinsic_function(loc, "_lpython_str_count");
Vec<ASR::call_arg_t> args;
args.reserve(al, 1);
ASR::call_arg_t str_arg;
str_arg.loc = loc;
ASR::ttype_t *str_type = ASRUtils::TYPE(ASR::make_Character_t(al, loc,
1, s_var.size(), nullptr));
str_arg.m_value = ASRUtils::EXPR(
ASR::make_StringConstant_t(al, loc, s2c(al, s_var), str_type));
ASR::call_arg_t sub_arg;
sub_arg.loc = loc;
sub_arg.m_value = arg;
args.push_back(al, str_arg);
args.push_back(al, sub_arg);
tmp = make_call_helper(al, fn_div, current_scope, args, "_lpython_str_count", loc);
}
return;
} else if (attr_name == "rstrip") {
if (args.size() != 0) {
throw SemanticError("str.rstrip() takes no arguments",
Expand Down

0 comments on commit 12c88e7

Please sign in to comment.