Skip to content

Commit

Permalink
Accept dtype argument in numpy.array
Browse files Browse the repository at this point in the history
  • Loading branch information
czgdp1807 committed Aug 2, 2023
1 parent c3314f7 commit 90df764
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions src/lpython/semantics/python_ast_to_asr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,6 +505,10 @@ class CommonVisitor : public AST::BaseVisitor<Struct> {
// Stores the name of imported functions and the modules they are imported from
std::map<std::string, std::string> imported_functions;

std::map<std::string, std::string> numpy2lpythontypes = {
{"int8", "i8"},
};

CommonVisitor(Allocator &al, LocationManager &lm, SymbolTable *symbol_table,
diag::Diagnostics &diagnostics, bool main_module, std::string module_name,
std::map<int, ASR::symbol_t*> &ast_overload, std::string parent_dir,
Expand Down Expand Up @@ -7520,16 +7524,45 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, operand_type, value);
return;
} else if( call_name == "array" ) {
parse_args(x, args);
ASR::ttype_t* type = nullptr;
if( x.n_keywords == 0 ) {
parse_args(x, args);
} else {
args.reserve(al, 1);
visit_expr_list(x.m_args, x.n_args, args);
if( x.n_keywords > 1 ) {
throw SemanticError("More than one keyword "
"arguments aren't recognised by array",
x.base.base.loc);
}
if( std::string(x.m_keywords[0].m_arg) != "dtype" ) {
throw SemanticError("Unrecognised keyword argument, " +
std::string(x.m_keywords[0].m_arg), x.base.base.loc);
}
std::string dtype_np = "";
if( AST::is_a<AST::Name_t>(*x.m_keywords[0].m_value) ) {
AST::Name_t* name_t = AST::down_cast<AST::Name_t>(x.m_keywords[0].m_value);
dtype_np = name_t->m_id;
} else {
LCOMPILERS_ASSERT(false);
}
LCOMPILERS_ASSERT(numpy2lpythontypes.find(dtype_np) != numpy2lpythontypes.end());
Vec<ASR::dimension_t> dims;
dims.n = 0;
type = get_type_from_var_annotation(
numpy2lpythontypes[dtype_np], x.base.base.loc, dims);
}
if( args.size() != 1 ) {
throw SemanticError("array accepts only 1 argument for now, got " +
std::to_string(args.size()) + " arguments instead.",
x.base.base.loc);
}
ASR::expr_t *arg = args[0].m_value;
ASR::ttype_t *type = ASRUtils::expr_type(arg);
if( type == nullptr ) {
type = ASRUtils::expr_type(arg);
}
if(ASR::is_a<ASR::ListConstant_t>(*arg)) {
type = ASR::down_cast<ASR::List_t>(type)->m_type;
type = ASRUtils::get_contained_type(type);
ASR::ListConstant_t* list = ASR::down_cast<ASR::ListConstant_t>(arg);
ASR::expr_t **m_args = list->m_args;
size_t n_args = list->n_args;
Expand All @@ -7544,6 +7577,10 @@ class BodyVisitor : public CommonVisitor<BodyVisitor> {
dims.push_back(al, dim);
type = ASRUtils::make_Array_t_util(al, x.base.base.loc, type, dims.p, dims.size(),
ASR::abiType::Source, false, ASR::array_physical_typeType::PointerToDataArray, true);
for( size_t i = 0; i < n_args; i++ ) {
m_args[i] = CastingUtil::perform_casting(m_args[i], ASRUtils::expr_type(m_args[i]),
ASRUtils::type_get_past_array(type), al, x.base.base.loc);
}
tmp = ASR::make_ArrayConstant_t(al, x.base.base.loc, m_args, n_args, type, ASR::arraystorageType::RowMajor);
} else {
throw SemanticError("array accepts only list for now, got " +
Expand Down

0 comments on commit 90df764

Please sign in to comment.