diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 752e630..065663c 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -83,3 +83,16 @@ jobs:
            export CPATH=$CONDA_PREFIX/include:$CPATH
            ./integration_tests/run_tests.py -b gcc llvm wasm c
            ./integration_tests/run_tests.py -b gcc llvm wasm c -f
+
+      - name: Test4 (Linux)
+        shell: bash -l -e {0}
+        if: contains(matrix.os, 'ubuntu')
+        run: |
+            export CPATH=$CONDA_PREFIX/include:$CPATH
+            conda install --yes pytorch::pytorch
+            cp -r $CONDA_PREFIX/lib/python3.12/site-packages/torch/include/* $CONDA_PREFIX/include/
+            cp -r $CONDA_PREFIX/lib/python3.12/site-packages/torch/lib/* $CONDA_PREFIX/lib/
+            cp -r $CONDA_PREFIX/lib/python3.12/site-packages/torch/share/* $CONDA_PREFIX/share/
+            ./integration_tests/run_tests.py -b pytorch
+            export CPATH=$CPATH:$CONDA_PREFIX/include/torch/csrc/api/include
+            ./integration_tests/run_tests.py -b llvmPytorch
diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt
index 0ed90ee..c7b8ea8 100644
--- a/integration_tests/CMakeLists.txt
+++ b/integration_tests/CMakeLists.txt
@@ -84,7 +84,7 @@ macro(RUN_UTIL RUN_FAIL RUN_NAME RUN_FILE_NAME RUN_LABELS RUN_EXTRAFILES RUN_EXT
     if (ADD_TEST)
         if ((LC_BACKEND STREQUAL "llvm") OR (LC_BACKEND STREQUAL "cpp") OR (LC_BACKEND STREQUAL "x86")
-            OR (LC_BACKEND STREQUAL "c") OR (LC_BACKEND STREQUAL "fortran"))
+            OR (LC_BACKEND STREQUAL "c") OR (LC_BACKEND STREQUAL "fortran") OR (LC_BACKEND STREQUAL "llvmPytorch"))
            add_custom_command(
                OUTPUT ${name}.o
                COMMAND ${LC} -c ${CMAKE_CURRENT_SOURCE_DIR}/${file_name} -o ${name}.o ${extra_args}
@@ -110,6 +110,16 @@ macro(RUN_UTIL RUN_FAIL RUN_NAME RUN_FILE_NAME RUN_LABELS RUN_EXTRAFILES RUN_EXT
            endif()
            set(WASM_EXEC_FLAGS ${WASM_EXEC_FLAGS} "--experimental-wasi-unstable-preview1")
            add_test(${name} ${WASM_EXEC_RUNTIME} ${WASM_EXEC_FLAGS} ${CURRENT_BINARY_DIR}/${name}.js)
+        elseif (LC_BACKEND STREQUAL "pytorch")
+            # PyTorch C++ API testing
+            find_package(Torch REQUIRED)
+            set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
+
+            add_executable(${name} ${file_name} ${extra_files})
+            target_link_libraries(${name} "${TORCH_LIBRARIES}")
+            target_compile_options(${name} PUBLIC ${gcc_args})
+            set_property(TARGET ${name} PROPERTY CXX_STANDARD 17)
+            add_test(${name} ${CURRENT_BINARY_DIR}/${name})
        else ()
            add_executable(${name} ${file_name} ${extra_files})
            target_compile_options(${name} PUBLIC ${gcc_args})
@@ -138,7 +148,7 @@ macro(RUN)
        "${multiValueArgs}" ${ARGN}
    )
    foreach(b ${RUN_LABELS})
-        if (NOT (b MATCHES "^(llvm|llvm2|llvm_rtlib|gcc|c|cpp|x86|wasm|gfortran|llvmImplicit|llvmStackArray|fortran|c_nopragma|llvm_nopragma)$"))
+        if (NOT (b MATCHES "^(llvm|llvm2|llvm_rtlib|gcc|c|cpp|x86|wasm|gfortran|llvmImplicit|llvmStackArray|fortran|c_nopragma|llvm_nopragma|pytorch|llvmPytorch)$"))
            message(FATAL_ERROR "Unsupported backend: ${b}")
        endif()
    endforeach()
@@ -241,3 +251,5 @@
 RUN(NAME vector_02.cpp LABELS gcc llvm)
 RUN(NAME loop_01.cpp LABELS gcc llvm NOFAST)
 RUN(NAME test_pkg_lnn_01.cpp LABELS gcc llvm NOFAST)
+
+RUN(NAME pytorch_01.cpp LABELS pytorch llvmPytorch)
diff --git a/integration_tests/pytorch_01.cpp b/integration_tests/pytorch_01.cpp
new file mode 100644
index 0000000..786e5cf
--- /dev/null
+++ b/integration_tests/pytorch_01.cpp
@@ -0,0 +1,19 @@
+#include <torch/torch.h>
+#include <iostream>
+
+void check(const torch::Tensor& tensor=torch::empty({1})) {
+    float array[5] = {4.0, 2.0, 2.0, 12.0, 2.0};
+    std::cout << tensor << "\n";
+    if( torch::any(torch::abs(tensor - torch::from_blob(array, {5})) > 1e-8).item<bool>() ) {
+        exit(2);
+    }
+}
+
+int main() {
+    torch::Tensor tensor = torch::ones(5);
+    tensor[0] = 2.0;
+    tensor[3] = 6.0;
+    tensor = 2 * tensor;
+    check(tensor);
+    std::cout << tensor << "\n";
+}
diff --git a/integration_tests/run_tests.py b/integration_tests/run_tests.py
index 9fbeac8..0541bf8 100755
--- a/integration_tests/run_tests.py
+++ b/integration_tests/run_tests.py
@@ -8,7 +8,7 @@
 NO_OF_THREADS = 8 # default no of threads is 8
 SUPPORTED_BACKENDS = ['llvm', 'llvm2', 'llvm_rtlib', 'c', 'cpp', 'x86', 'wasm',
                       'gcc', 'llvmImplicit', 'llvmStackArray', 'fortran',
-                      'c_nopragma', 'llvm_nopragma']
+                      'c_nopragma', 'llvm_nopragma', 'pytorch', 'llvmPytorch']

 BASE_DIR = os.path.dirname(os.path.realpath(__file__))
 LC_PATH = f"{BASE_DIR}/../src/bin:$PATH"
diff --git a/src/lc/clang_ast_to_asr.cpp b/src/lc/clang_ast_to_asr.cpp
index 6212cf8..f984155 100644
--- a/src/lc/clang_ast_to_asr.cpp
+++ b/src/lc/clang_ast_to_asr.cpp
@@ -44,6 +44,11 @@ enum SpecialFunc {
     Clear,
     Data,
     Reserve,
+
+    TorchOnes,
+    TorchEmpty,
+    TorchFromBlob,
+    TorchTensorItem,
 };

 std::map<std::string, SpecialFunc> special_function_map = {
@@ -77,6 +82,13 @@ std::map<std::string, SpecialFunc> special_function_map = {
     {"clear", SpecialFunc::Clear},
     {"data", SpecialFunc::Data},
     {"reserve", SpecialFunc::Reserve},
+
+    {"torch::ones", SpecialFunc::TorchOnes},
+    {"torch::empty", SpecialFunc::TorchEmpty},
+    {"torch::from_blob", SpecialFunc::TorchFromBlob},
+    {"torch::abs", SpecialFunc::Abs},
+    {"torch::any", SpecialFunc::Any},
+    {"item", SpecialFunc::TorchTensorItem},
 };

 class OneTimeUseString {
@@ -155,6 +167,7 @@ class OneTimeUseASRNode {
 enum ThirdPartyCPPArrayTypes {
     XTensorArray,
     MDSpanArray,
+    PyTorchArray,
 };

 class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor> {
@@ -587,19 +600,37 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
         if( clang_type->getTypeClass() == clang::Type::TypeClass::Record ) {
             const clang::CXXRecordDecl* record_type = clang_type->getAsCXXRecordDecl();
             std::string name = record_type->getNameAsString();
+            std::string qualified_name = record_type->getQualifiedNameAsString();
             if( name == "xtensor_container" || name == "vector" || name == "mdspan" ) {
                 return nullptr;
             }
             ASR::symbol_t* type_t = current_scope->resolve_symbol(name);
             if( !type_t ) {
-                throw std::runtime_error(name + " not defined.");
+                if( qualified_name == "at::Tensor" ) {
+                    if( array_type == nullptr || is_third_party_cpp_array == nullptr ) {
+                        throw std::runtime_error("ICE: array_type and is_third_party_cpp_array couldn't be set.");
+                    }
+                    *is_third_party_cpp_array = true;
+                    *array_type = ThirdPartyCPPArrayTypes::PyTorchArray;
+                    type = ASRUtils::TYPE(ASR::make_Array_t(al, l, ASRUtils::TYPE(ASR::make_Real_t(al, l, 8)),
+                        nullptr, 0, ASR::array_physical_typeType::DescriptorArray));
+                    return type;
+                } else if( qualified_name == "c10::Scalar" ) {
+                    return nullptr;
+                }
+                throw std::runtime_error(qualified_name + " not defined.");
             }
             if( clang_type->isUnionType() ) {
                 type = ASRUtils::TYPE(ASR::make_Union_t(al, l, type_t));
             } else {
                 type = ASRUtils::TYPE(ASR::make_Struct_t(al, l, type_t));
             }
-        } else if( clang_type->getTypeClass() == clang::Type::TypeClass::SubstTemplateTypeParm ) {
+        } else if( clang_type->getTypeClass() == clang::Type::TypeClass::Using ) {
+            const clang::UsingType* using_type = clang_type->getAs<clang::UsingType>();
+            return ClangTypeToASRType(using_type->getUnderlyingType(), xshape_result,
+                array_type, is_third_party_cpp_array);
+        } else if( clang_type->getTypeClass() == clang::Type::TypeClass::SubstTemplateTypeParm ||
+                   clang_type->getTypeClass() == clang::Type::TypeClass::Typedef ) {
             return nullptr;
         } else {
             throw std::runtime_error("clang::QualType not yet supported " +
@@ -771,7 +802,18 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
             name = current_scope->get_unique_name("param");
         }
-        ASR::ttype_t* type = ClangTypeToASRType(x->getType());
+
+        bool is_third_party_array_type = false;
+        ThirdPartyCPPArrayTypes array_type;
+        Vec shape_result; shape_result.reserve(al, 1);
+        ASR::ttype_t* type = ClangTypeToASRType(x->getType(), &shape_result,
+            &array_type, &is_third_party_array_type);
+        if( is_third_party_array_type &&
+            array_type == ThirdPartyCPPArrayTypes::PyTorchArray ) {
+            if( !x->getDefaultArg() ) {
+                throw std::runtime_error("torch::Tensor type arguments must have default arguments.");
+            }
+        }
         ASR::intentType intent_type = ASR::intentType::InOut;
         if( ASR::is_a(*type) ) {
             intent_type = ASR::intentType::In;
@@ -784,20 +826,30 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
+        tmp = ASR::make_Variable_t(al, Lloc(x), current_scope, s2c(al, name),
+            nullptr, 0, ASR::intentType::InOut, nullptr, nullptr,
+            ASR::storage_typeType::Default, type, nullptr, ASR::abiType::Source,
+            ASR::accessType::Public, ASR::presenceType::Required, false);
+        ASR::symbol_t* tmp_sym = ASR::down_cast<ASR::symbol_t>(tmp.get());
+        current_scope->add_symbol(name, tmp_sym);
+        ASR::asr_t* tmp_ = ASR::make_Var_t(al, Lloc(x), tmp_sym);
+
         clang::Expr *init = x->getDefaultArg();
         ASR::expr_t* asr_init = nullptr;
         if (init) {
+            ASR::expr_t* assignment_target_copy = assignment_target;
+            assignment_target = ASRUtils::EXPR(tmp_);
             TraverseStmt(init);
-            asr_init = ASRUtils::EXPR(tmp.get());
+            if( tmp != nullptr && !is_stmt_created ) {
+                asr_init = ASRUtils::EXPR(tmp.get());
+            }
         }
-        tmp = ASR::make_Variable_t(al, Lloc(x), current_scope, s2c(al, name),
-            nullptr, 0, ASR::intentType::InOut, asr_init, nullptr,
-            ASR::storage_typeType::Default, type, nullptr, ASR::abiType::Source,
-            ASR::accessType::Public, ASR::presenceType::Required, false);
-        ASR::symbol_t* tmp_sym = ASR::down_cast<ASR::symbol_t>(tmp.get());
-        current_scope->add_symbol(name, tmp_sym);
-        tmp = ASR::make_Var_t(al, Lloc(x), tmp_sym);
+        // TODO: For PyTorch tensor create an intrinsic empty
+        // and then fill the initialiser value with a call
+        // to that intrinsic.
+
+        tmp = tmp_;
         is_stmt_created = false;
         return true;
     }
@@ -1042,7 +1094,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
             if( ASR::is_a(*ASRUtils::extract_type(
-                ASRUtils::expr_type(obj))) ) {
+                ASRUtils::expr_type(obj))) ||
+                ASR::is_a(*obj) ) {
                 TraverseStmt(args[1]);
                 if( !is_stmt_created ) {
                     ASR::expr_t* value = ASRUtils::EXPR(tmp.get());
@@ -1112,6 +1165,9 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
                 m_v));
         }
         if (special_function_map.find(func_name) == special_function_map.end()) {
+            if( current_scope->resolve_symbol(func_name) == nullptr ) {
+                throw std::runtime_error("ICE: " + func_name + " is not handled yet in LC.");
+            }
             return false;
         }
         SpecialFunc sf = special_function_map[func_name];
@@ -1310,6 +1366,47 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
+                (callee);
+            is_stmt_created = false;
+        } else if (sf == SpecialFunc::TorchOnes) {
+            if( args.size() != 2 ) { // second one is TensorOptions, to be ignored
+                throw std::runtime_error("torch::ones should be called with only one argument.");
+            }
+
+            ASR::expr_t* shape_arg = args.p[0];
+            ASR::expr_t* one = ASRUtils::get_constant_one_with_given_type(
+                al, ASRUtils::TYPE(ASR::make_Real_t(al, Lloc(x), 8)));
+            Vec<ASR::dimension_t> expr_dims; expr_dims.reserve(al, 1);
+            if( ASR::is_a<ASR::IntegerConstant_t>(*shape_arg) ) {
+                ASR::dimension_t expr_dim;
+                expr_dim.loc = Lloc(x);
+                expr_dim.m_start = ASRUtils::get_constant_zero_with_given_type(
+                    al, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)));
+                expr_dim.m_length = shape_arg;
+                expr_dims.push_back(al, expr_dim);
+                ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x),
+                    ASRUtils::TYPE(ASR::make_Real_t(al, Lloc(x), 8)), expr_dims.p,
+                    expr_dims.size(), ASR::array_physical_typeType::FixedSizeArray));
+                int num_ones = ASR::down_cast<ASR::IntegerConstant_t>(shape_arg)->m_n;
+                Vec<ASR::expr_t*> ones_vec; ones_vec.reserve(al, num_ones);
+                for( size_t onei = 0; onei < num_ones; onei++ ) {
+                    ones_vec.push_back(al, one);
+                }
+                tmp = ASR::make_ArrayConstant_t(al, Lloc(x), ones_vec.p, ones_vec.size(),
+                    type, ASR::arraystorageType::RowMajor);
+                is_stmt_created = false;
+            } else if( ASR::is_a<ASR::ArrayConstant_t>(*shape_arg) ) {
+                throw std::runtime_error("{...} not yet supported in torch::ones");
+            }
+            is_stmt_created = false;
         } else if( sf == SpecialFunc::All ) {
             // Handles xt::all() - no arguments
             // Handle with argument case later.
@@ -1443,6 +1540,59 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
+            if( ASR::is_a<ASR::ArrayConstant_t>(*args.p[0]) ) {
+                ASR::ArrayConstant_t* array_constant = ASR::down_cast<ASR::ArrayConstant_t>(args.p[0]);
+
+                Vec<ASR::dimension_t> empty_dims; empty_dims.reserve(al, array_constant->n_args);
+                for( size_t idim = 0; idim < array_constant->n_args; idim++ ) {
+                    ASR::dimension_t empty_dim;
+                    empty_dim.loc = Lloc(x);
+                    empty_dim.m_start = ASRUtils::get_constant_zero_with_given_type(
+                        al, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)));
+                    empty_dim.m_length = nullptr;
+                    empty_dims.push_back(al, empty_dim);
+                }
+                ASR::ttype_t* type = ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x),
+                    ASRUtils::extract_type(ASRUtils::expr_type(assignment_target)),
+                    empty_dims.p, empty_dims.size(), ASR::array_physical_typeType::DescriptorArray));
+                type = ASRUtils::TYPE(ASR::make_Allocatable_t(al, Lloc(x), type));
+                ASR::down_cast<ASR::Variable_t>(
+                    ASR::down_cast<ASR::Var_t>(assignment_target)->m_v)->m_type = type;
+                tmp = nullptr;
+                is_stmt_created = false;
+            } else {
+                throw std::runtime_error("Only {...} is allowed for supplying shape to xt::empty.");
+            }
+        } else if (sf == SpecialFunc::TorchFromBlob) {
+            if( args.size() < 2 ) { // Ignore the last one
+                throw std::runtime_error("torch::from_blob must be provided with a C array and its shape.");
+            }
+
+            ASR::dimension_t* m_dims = nullptr;
+            size_t n_dims = ASRUtils::extract_dimensions_from_ttype(ASRUtils::expr_type(args.p[0]), m_dims);
+            if( ASR::is_a<ASR::ArrayConstant_t>(*args.p[1]) ) {
+                ASR::ArrayConstant_t* array_constant = ASR::down_cast<ASR::ArrayConstant_t>(args.p[1]);
+
+                Vec<ASR::dimension_t> empty_dims; empty_dims.reserve(al, array_constant->n_args);
+                for( size_t idim = 0; idim < array_constant->n_args; idim++ ) {
+                    if( !ASRUtils::is_value_equal(array_constant->m_args[idim], m_dims[idim].m_length) ) {
+                        throw std::runtime_error("ICE: Could not verify that the supplied shape matches "
+                            "the shape of the array provided in torch::from_blob");
+                    }
+                }
+                tmp = reinterpret_cast<ASR::asr_t*>(args.p[0]);
+                is_stmt_created = false;
+            } else {
+                throw std::runtime_error("Only {...} is allowed for supplying shape to torch::from_blob.");
+            }
         } else if (sf == SpecialFunc::Iota) {
             tmp = ASR::make_ComplexConstant_t(al, Lloc(x), 0.0, 1.0,
                 ASRUtils::TYPE(ASR::make_Complex_t(al, Lloc(x), 8)));
@@ -1874,7 +2024,16 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
-        ASR::ttype_t *asr_type = ClangTypeToASRType(x->getType());
+        Vec xshape_result; xshape_result.reserve(al, 0);
+        ThirdPartyCPPArrayTypes array_type; bool is_third_party_array_type = false;
+        ASR::ttype_t *asr_type = ClangTypeToASRType(x->getType(), &xshape_result,
+            &array_type, &is_third_party_array_type);
+        if( is_third_party_array_type &&
+            array_type == ThirdPartyCPPArrayTypes::PyTorchArray ) {
+            if( !x->hasInit() ) {
+                throw std::runtime_error("torch::Tensor variables must have an initialiser value.");
+            }
+        }
         ASR::symbol_t *v = ASR::down_cast<ASR::symbol_t>(ASR::make_Variable_t(al, Lloc(x),
             current_scope, s2c(al, name), nullptr, 0, ASR::intentType::Local, nullptr, nullptr,
             ASR::storage_typeType::Default, asr_type, nullptr, ASR::abiType::Source,
@@ -1905,6 +2064,38 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
             variable_t->m_value = ASRUtils::expr_value(init_val);
             variable_t->m_storage = ASR::storage_typeType::Parameter;
         } else {
+            if( is_third_party_array_type ) {
+                if( array_type == ThirdPartyCPPArrayTypes::PyTorchArray ) {
+                    ASR::dimension_t* dims = nullptr;
+                    size_t n_dims = ASRUtils::extract_dimensions_from_ttype(
+                        ASRUtils::expr_type(init_val), dims);
+
+                    Vec<ASR::dimension_t> empty_dims; empty_dims.reserve(al, n_dims);
+                    for( size_t dimi = 0; dimi < n_dims; dimi++ ) {
+                        ASR::dimension_t empty_dim;
+                        empty_dim.loc = Lloc(x);
+                        empty_dim.m_start = ASRUtils::get_constant_zero_with_given_type(
+                            al, ASRUtils::TYPE(ASR::make_Integer_t(al, Lloc(x), 4)));
+                        empty_dim.m_length = nullptr;
+                        empty_dims.push_back(al, empty_dim);
+                    }
+                    ASR::Variable_t* variable_t = ASR::down_cast<ASR::Variable_t>(v);
+                    variable_t->m_type = ASRUtils::TYPE(ASR::make_Array_t(al, Lloc(x),
+                        ASRUtils::extract_type(variable_t->m_type), empty_dims.p, empty_dims.size(),
+                        ASR::array_physical_typeType::DescriptorArray));
+                    variable_t->m_type = ASRUtils::TYPE(ASR::make_Allocatable_t(
+                        al, Lloc(x), variable_t->m_type));
+
+                    Vec<ASR::alloc_arg_t> alloc_args; alloc_args.reserve(al, 1);
+                    ASR::alloc_arg_t alloc_arg; alloc_arg.loc = Lloc(x);
+                    alloc_arg.m_a = var;
+                    alloc_arg.m_dims = dims; alloc_arg.n_dims = n_dims;
+                    alloc_arg.m_len_expr = nullptr; alloc_arg.m_type = nullptr;
+                    alloc_args.push_back(al, alloc_arg);
+                    current_body->push_back(al, ASRUtils::STMT(ASR::make_Allocate_t(
+                        al, Lloc(x), alloc_args.p, alloc_args.size(), nullptr, nullptr, nullptr)));
+                }
+            }
             add_reshape_if_needed(init_val, var);
             tmp = ASR::make_Assignment_t(al, Lloc(x), var, init_val, nullptr);
             is_stmt_created = true;
@@ -2153,10 +2344,10 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
         std::string name = x->getNameInfo().getAsString();
+        std::string namespace_name = "";
+        if( x->getQualifier() ) {
+            namespace_name = x->getQualifier()->getAsNamespace()->getNameAsString();
+        }
         ASR::symbol_t* sym = resolve_symbol(name);
         if( name == "operator<<" || name == "cout" || name == "endl" ||
             name == "operator()" || name == "operator+" || name == "operator=" ||
@@ -2476,7 +2671,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor<ClangASTtoASRVisitor>
             name == "operator>=" || name == "operator!=" || name == "operator\"\"i" || name == "sin" ||
-            name == "cos" || name == "amin" || name == "operator[]" || name == "sqrt" ) {
+            name == "cos" || name == "amin" || name == "operator[]" || name == "sqrt" ||
+            name == "ones" || name == "from_blob" ) {
             if( sym != nullptr && ASR::is_a(
                 *ASRUtils::symbol_get_past_external(sym)) ) {
                 throw std::runtime_error("Special function " + name + " cannot be overshadowed yet.");
@@ -2484,6 +2680,8 @@ class ClangASTtoASRVisitor: public clang::RecursiveASTVisitor