From 90df76469ebd33018251283506f7878bb72d0d8f Mon Sep 17 00:00:00 2001 From: Gagandeep Singh Date: Fri, 7 Jul 2023 20:30:54 +0530 Subject: [PATCH] Accept dtype argument in numpy.array --- src/lpython/semantics/python_ast_to_asr.cpp | 43 +++++++++++++++++++-- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/src/lpython/semantics/python_ast_to_asr.cpp b/src/lpython/semantics/python_ast_to_asr.cpp index 2d41145d2c..40f032e93a 100644 --- a/src/lpython/semantics/python_ast_to_asr.cpp +++ b/src/lpython/semantics/python_ast_to_asr.cpp @@ -505,6 +505,10 @@ class CommonVisitor : public AST::BaseVisitor { // Stores the name of imported functions and the modules they are imported from std::map imported_functions; + std::map numpy2lpythontypes = { + {"int8", "i8"}, + }; + CommonVisitor(Allocator &al, LocationManager &lm, SymbolTable *symbol_table, diag::Diagnostics &diagnostics, bool main_module, std::string module_name, std::map &ast_overload, std::string parent_dir, @@ -7520,16 +7524,45 @@ class BodyVisitor : public CommonVisitor { tmp = ASR::make_UnsignedIntegerBitNot_t(al, x.base.base.loc, operand, operand_type, value); return; } else if( call_name == "array" ) { - parse_args(x, args); + ASR::ttype_t* type = nullptr; + if( x.n_keywords == 0 ) { + parse_args(x, args); + } else { + args.reserve(al, 1); + visit_expr_list(x.m_args, x.n_args, args); + if( x.n_keywords > 1 ) { + throw SemanticError("More than one keyword " + "arguments aren't recognised by array", + x.base.base.loc); + } + if( std::string(x.m_keywords[0].m_arg) != "dtype" ) { + throw SemanticError("Unrecognised keyword argument, " + + std::string(x.m_keywords[0].m_arg), x.base.base.loc); + } + std::string dtype_np = ""; + if( AST::is_a(*x.m_keywords[0].m_value) ) { + AST::Name_t* name_t = AST::down_cast(x.m_keywords[0].m_value); + dtype_np = name_t->m_id; + } else { + LCOMPILERS_ASSERT(false); + } + LCOMPILERS_ASSERT(numpy2lpythontypes.find(dtype_np) != numpy2lpythontypes.end()); + Vec dims; + dims.n = 0; + type = get_type_from_var_annotation( + numpy2lpythontypes[dtype_np], x.base.base.loc, dims); + } if( args.size() != 1 ) { throw SemanticError("array accepts only 1 argument for now, got " + std::to_string(args.size()) + " arguments instead.", x.base.base.loc); } ASR::expr_t *arg = args[0].m_value; - ASR::ttype_t *type = ASRUtils::expr_type(arg); + if( type == nullptr ) { + type = ASRUtils::expr_type(arg); + } if(ASR::is_a(*arg)) { - type = ASR::down_cast(type)->m_type; + type = ASRUtils::get_contained_type(type); ASR::ListConstant_t* list = ASR::down_cast(arg); ASR::expr_t **m_args = list->m_args; size_t n_args = list->n_args; @@ -7544,6 +7577,10 @@ class BodyVisitor : public CommonVisitor { dims.push_back(al, dim); type = ASRUtils::make_Array_t_util(al, x.base.base.loc, type, dims.p, dims.size(), ASR::abiType::Source, false, ASR::array_physical_typeType::PointerToDataArray, true); + for( size_t i = 0; i < n_args; i++ ) { + m_args[i] = CastingUtil::perform_casting(m_args[i], ASRUtils::expr_type(m_args[i]), + ASRUtils::type_get_past_array(type), al, x.base.base.loc); + } tmp = ASR::make_ArrayConstant_t(al, x.base.base.loc, m_args, n_args, type, ASR::arraystorageType::RowMajor); } else { throw SemanticError("array accepts only list for now, got " +