diff --git a/.github/workflows/github-action.yml b/.github/workflows/github-action.yml index 007cd1cc7..2a9e11206 100644 --- a/.github/workflows/github-action.yml +++ b/.github/workflows/github-action.yml @@ -48,6 +48,11 @@ jobs: git clone "https://github.com/SVF-tools/Test-Suite.git"; source ${{github.workspace}}/build.sh + - name: ctest objtype inference + working-directory: ${{github.workspace}}/Release-build + run: + ctest -R objtype -VV + - name: ctest wpa working-directory: ${{github.workspace}}/Release-build run: diff --git a/svf-llvm/include/SVF-LLVM/LLVMModule.h b/svf-llvm/include/SVF-LLVM/LLVMModule.h index c3b6f450a..571c62b17 100644 --- a/svf-llvm/include/SVF-LLVM/LLVMModule.h +++ b/svf-llvm/include/SVF-LLVM/LLVMModule.h @@ -39,6 +39,7 @@ namespace SVF { class SymbolTableInfo; +class ObjTypeInference; class LLVMModuleSet { @@ -88,6 +89,7 @@ class LLVMModuleSet SVFValue2LLVMValueMap SVFValue2LLVMValue; LLVMType2SVFTypeMap LLVMType2SVFType; Type2TypeInfoMap Type2TypeInfo; + ObjTypeInference* typeInference; /// Constructor LLVMModuleSet(); @@ -95,7 +97,7 @@ class LLVMModuleSet void build(); public: - ~LLVMModuleSet() = default; + ~LLVMModuleSet(); static inline LLVMModuleSet* getLLVMModuleSet() { @@ -343,6 +345,8 @@ class LLVMModuleSet /// Get LLVM Type const Type* getLLVMType(const SVFType* T) const; + ObjTypeInference* getTypeInference(); + private: /// Create SVFTypes SVFType* addSVFTypeInfo(const Type* t); diff --git a/svf-llvm/include/SVF-LLVM/LLVMUtil.h b/svf-llvm/include/SVF-LLVM/LLVMUtil.h index 429eed26a..8df0e2668 100644 --- a/svf-llvm/include/SVF-LLVM/LLVMUtil.h +++ b/svf-llvm/include/SVF-LLVM/LLVMUtil.h @@ -106,10 +106,9 @@ static inline Type* getPtrElementType(const PointerType* pty) #endif } -/// Get the reference type of heap/static object from an allocation site. -//@{ -const Type *inferTypeOfHeapObjOrStaticObj(const Instruction* inst); -//@} +/// Return size of this object based on LLVM value +u32_t getNumOfElements(const Type* ety); + /// Return true if this value refers to a object bool isObject(const Value* ref); @@ -362,6 +361,7 @@ std::string dumpValue(const Value* val); std::string dumpType(const Type* type); +std::string dumpValueAndDbgInfo(const Value* val); /** * See more: https://github.com/SVF-tools/SVF/pull/1191 diff --git a/svf-llvm/include/SVF-LLVM/ObjTypeInference.h b/svf-llvm/include/SVF-LLVM/ObjTypeInference.h new file mode 100644 index 000000000..c76674e34 --- /dev/null +++ b/svf-llvm/include/SVF-LLVM/ObjTypeInference.h @@ -0,0 +1,106 @@ +//===- ObjTypeInference.h -- Type inference----------------------------// +// +// SVF: Static Value-Flow Analysis +// +// Copyright (C) <2013-> +// + +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . +// +//===----------------------------------------------------------------------===// + +/* + * ObjTypeInference.h + * + * Created by Xiao Cheng on 10/01/24. + * + */ + +#ifndef SVF_OBJTYPEINFERENCE_H +#define SVF_OBJTYPEINFERENCE_H + +#include "Util/SVFUtil.h" +#include "SVF-LLVM/BasicTypes.h" +#include "SVFIR/SVFValue.h" +#include "Util/ThreadAPI.h" + +namespace SVF { +class ObjTypeInference { + +public: + typedef Set ValueSet; + typedef Map ValueToValueSet; + typedef ValueToValueSet ValueToInferSites; + typedef ValueToValueSet ValueToSources; + typedef Map ValueToType; + typedef std::pair ValueBoolPair; + + +private: + ValueToInferSites _valueToInferSites; // value inference site cache + ValueToType _valueToType; // value type cache + ValueToSources _valueToAllocs; // value allocations (stack, static, heap) cache + + +public: + + explicit ObjTypeInference() = default; + + ~ObjTypeInference() = default; + + + /// get or infer the type of a value + const Type *inferObjType(const Value *startValue); + + /// Validate type inference + void validateTypeCheck(const CallBase *cs); + + void typeSizeDiffTest(const PointerType *oPTy, const Type *iTy, const Value *val); + + /// Default type + const Type *defaultType(const Value *val); + + /// Pointer type + inline const Type *ptrType() { + return PointerType::getUnqual(getLLVMCtx()); + } + + /// Int8 type + inline const IntegerType *int8Type() { + return Type::getInt8Ty(getLLVMCtx()); + } + + LLVMContext &getLLVMCtx(); + +private: + + /// Forward collect all possible infer sites starting from a value + const Type *fwInferObjType(const Value *startValue); + + /// Backward collect all possible allocation sites (stack, static, heap) starting from a value + Set bwfindAllocations(const Value *startValue); + + bool isAllocation(const Value *val); + +public: + /// Select the largest (conservative) type from all types + const Type *selectLargestType(Set &objTys); + + u32_t objTyToNumFields(const Type *objTy); + + u32_t getArgPosInCall(const CallBase *callBase, const Value *arg); + +}; +} +#endif //SVF_OBJTYPEINFERENCE_H diff --git a/svf-llvm/include/SVF-LLVM/SVFIRBuilder.h b/svf-llvm/include/SVF-LLVM/SVFIRBuilder.h index e5eb82e26..653be87d2 100644 --- a/svf-llvm/include/SVF-LLVM/SVFIRBuilder.h +++ b/svf-llvm/include/SVF-LLVM/SVFIRBuilder.h @@ -271,7 +271,7 @@ class SVFIRBuilder: public llvm::InstVisitor inline NodeID addNullPtrNode() { LLVMContext& cxt = LLVMModuleSet::getLLVMModuleSet()->getContext(); - ConstantPointerNull* constNull = ConstantPointerNull::get(Type::getInt8PtrTy(cxt)); + ConstantPointerNull* constNull = ConstantPointerNull::get(PointerType::getUnqual(cxt)); NodeID nullPtr = pag->addValNode(LLVMModuleSet::getLLVMModuleSet()->getSVFValue(constNull),pag->getNullPtr()); setCurrentLocation(constNull, nullptr); addBlackHoleAddrEdge(pag->getBlkPtr()); diff --git a/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h b/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h index e013da5b6..6785ef75e 100644 --- a/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h +++ b/svf-llvm/include/SVF-LLVM/SymbolTableBuilder.h @@ -39,6 +39,8 @@ namespace SVF { +class ObjTypeInference; + class SymbolTableBuilder { friend class SVFIRBuilder; @@ -82,6 +84,18 @@ class SymbolTableBuilder void handleCE(const Value* val); // @} + + ObjTypeInference* getTypeInference(); + + /// Forward collect all possible infer sites starting from a value + const Type* inferObjType(const Value *startValue); + + /// Get the reference type of heap/static object from an allocation site. + //@{ + const Type *inferTypeOfHeapObjOrStaticObj(const Instruction* inst); + //@} + + /// Create an objectInfo based on LLVM value ObjTypeInfo* createObjTypeInfo(const Value* val); diff --git a/svf-llvm/lib/CHGBuilder.cpp b/svf-llvm/lib/CHGBuilder.cpp index c8417671d..08c15a05a 100644 --- a/svf-llvm/lib/CHGBuilder.cpp +++ b/svf-llvm/lib/CHGBuilder.cpp @@ -45,6 +45,7 @@ #include "SVFIR/SVFModule.h" #include "Util/PTAStat.h" #include "SVF-LLVM/LLVMModule.h" +#include "SVF-LLVM/ObjTypeInference.h" using namespace SVF; using namespace SVFUtil; @@ -671,6 +672,7 @@ void CHGBuilder::buildCSToCHAVtblsAndVfnsMap() } } + const CHGraph::CHNodeSetTy& CHGBuilder::getCSClasses(const CallBase* cs) { assert(cppUtil::isVirtualCallSite(cs) && "not virtual callsite!"); diff --git a/svf-llvm/lib/CppUtil.cpp b/svf-llvm/lib/CppUtil.cpp index ca998d084..ffb4cba5c 100644 --- a/svf-llvm/lib/CppUtil.cpp +++ b/svf-llvm/lib/CppUtil.cpp @@ -514,18 +514,20 @@ std::string cppUtil::getClassNameOfThisPtr(const CallBase* inst) } if (thisPtrClassName.size() == 0) { - const Value* thisPtr = cppUtil::getVCallThisPtr(inst); - if(const PointerType* ptrTy = SVFUtil::dyn_cast(thisPtr->getType())) + const Value* thisPtr = getVCallThisPtr(inst); + if (const PointerType *ptrTy = SVFUtil::dyn_cast(thisPtr->getType())) { // TODO: getPtrElementType need type inference - if(const StructType* st = SVFUtil::dyn_cast(LLVMUtil::getPtrElementType(ptrTy))) + if (const StructType *st = SVFUtil::dyn_cast(LLVMUtil::getPtrElementType(ptrTy))) { thisPtrClassName = getClassNameFromType(st); + } + } } size_t found = thisPtrClassName.find_last_not_of("0123456789"); if (found != std::string::npos) { if (found != thisPtrClassName.size() - 1 && - thisPtrClassName[found] == '.') + thisPtrClassName[found] == '.') { return thisPtrClassName.substr(0, found); } diff --git a/svf-llvm/lib/LLVMModule.cpp b/svf-llvm/lib/LLVMModule.cpp index 07b46b1a0..4367fa0f5 100644 --- a/svf-llvm/lib/LLVMModule.cpp +++ b/svf-llvm/lib/LLVMModule.cpp @@ -38,6 +38,7 @@ #include "SVF-LLVM/SymbolTableBuilder.h" #include "MSSA/SVFGBuilder.h" #include "llvm/Support/FileSystem.h" +#include "SVF-LLVM/ObjTypeInference.h" using namespace std; using namespace SVF; @@ -74,10 +75,19 @@ bool LLVMModuleSet::preProcessed = false; LLVMModuleSet::LLVMModuleSet() : symInfo(SymbolTableInfo::SymbolInfo()), - svfModule(SVFModule::getSVFModule()) + svfModule(SVFModule::getSVFModule()), typeInference(new ObjTypeInference()) { } +LLVMModuleSet::~LLVMModuleSet() { + delete typeInference; + typeInference = nullptr; +} + +ObjTypeInference* LLVMModuleSet::getTypeInference() { + return typeInference; +} + SVFModule* LLVMModuleSet::buildSVFModule(Module &mod) { LLVMModuleSet* mset = getLLVMModuleSet(); @@ -152,8 +162,8 @@ void LLVMModuleSet::build() void LLVMModuleSet::createSVFDataStructure() { - SVFType::i8Ty = getSVFType(IntegerType::getInt8Ty(getContext())); - SVFType::ptrTy = getSVFType(PointerType::getUnqual(getContext())); + SVFType::svfI8Ty = getSVFType(getTypeInference()->int8Type()); + SVFType::svfPtrTy = getSVFType(getTypeInference()->ptrType()); // Functions need to be retrieved in the order of insertion // candidateDefs is the vector for all used defined functions // candidateDecls is the vector for all used declared functions @@ -729,14 +739,14 @@ void LLVMModuleSet::addSVFMain() assert(mainMod && "Module with main function not found."); Module& M = *mainMod; // char ** - Type* i8ptr2 = PointerType::getInt8PtrTy(M.getContext())->getPointerTo(); + Type* ptr = PointerType::getUnqual(M.getContext()); Type* i32 = IntegerType::getInt32Ty(M.getContext()); // define void @svf.main(i32, i8**, i8**) #if (LLVM_VERSION_MAJOR >= 9) FunctionCallee svfmainFn = M.getOrInsertFunction( SVF_MAIN_FUNC_NAME, Type::getVoidTy(M.getContext()), - i32,i8ptr2,i8ptr2 + i32,ptr,ptr ); Function* svfmain = SVFUtil::dyn_cast(svfmainFn.getCallee()); #else @@ -1284,13 +1294,6 @@ SVFType* LLVMModuleSet::getSVFType(const Type* T) SVFType* svfType = addSVFTypeInfo(T); StInfo* stinfo = collectTypeInfo(T); svfType->setTypeInfo(stinfo); - /// TODO: set the void* to every element for now (imprecise) - /// For example, - /// [getPointerTo(): char ----> i8*] - /// [getPointerTo(): int ----> i8*] - /// [getPointerTo(): struct ----> i8*] - PointerType* ptrTy = PointerType::getInt8PtrTy(getContext()); - svfType->setPointerTo(SVFUtil::cast(getSVFType(ptrTy))); return svfType; } @@ -1381,15 +1384,6 @@ SVFType* LLVMModuleSet::addSVFTypeInfo(const Type* T) symInfo->addTypeInfo(svftype); LLVMType2SVFType[T] = svftype; - if (const PointerType* pt = SVFUtil::dyn_cast(T)) - { - //cast svftype to SVFPointerType - SVFPointerType* svfPtrType = SVFUtil::dyn_cast(svftype); - assert(svfPtrType && "this is not SVFPointerType"); - // TODO: getPtrElementType to be removed - if(!pt->isOpaque()) - svfPtrType->setPtrElementType(getSVFType(LLVMUtil::getPtrElementType(pt))); - } return svftype; } diff --git a/svf-llvm/lib/LLVMUtil.cpp b/svf-llvm/lib/LLVMUtil.cpp index 23256323e..b77642c1c 100644 --- a/svf-llvm/lib/LLVMUtil.cpp +++ b/svf-llvm/lib/LLVMUtil.cpp @@ -378,35 +378,20 @@ const Value* LLVMUtil::getFirstUseViaCastInst(const Value* val) } /*! - * Return the type of the object from a heap allocation + * Return size of this Object */ -const Type* LLVMUtil::inferTypeOfHeapObjOrStaticObj(const Instruction *inst) +u32_t LLVMUtil::getNumOfElements(const Type* ety) { - const PointerType* type = SVFUtil::dyn_cast(inst->getType()); - const SVFInstruction* svfinst = LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction(inst); - if(SVFUtil::isHeapAllocExtCallViaRet(svfinst)) + assert(ety && "type is null?"); + u32_t numOfFields = 1; + if (SVFUtil::isa(ety)) { - if(const Value* v = getFirstUseViaCastInst(inst)) - { - if(const PointerType* newTy = SVFUtil::dyn_cast(v->getType())) - type = newTy; - } - } - else if(SVFUtil::isHeapAllocExtCallViaArg(svfinst)) - { - const CallBase* cs = LLVMUtil::getLLVMCallSite(inst); - int arg_pos = SVFUtil::getHeapAllocHoldingArgPosition(SVFUtil::getSVFCallSite(svfinst)); - const Value* arg = cs->getArgOperand(arg_pos); - type = SVFUtil::dyn_cast(arg->getType()); - } - else - { - assert( false && "not a heap allocation instruction?"); + if(Options::ModelArrays()) + return LLVMModuleSet::getLLVMModuleSet()->getSVFType(ety)->getTypeInfo()->getNumOfFlattenElements(); + else + return LLVMModuleSet::getLLVMModuleSet()->getSVFType(ety)->getTypeInfo()->getNumOfFlattenFields(); } - - assert(type && "not a pointer type?"); - // TODO: getPtrElementType need type inference - return getPtrElementType(type); + return numOfFields; } /*! @@ -946,6 +931,15 @@ std::string LLVMUtil::dumpType(const Type* type) return rawstr.str(); } +std::string LLVMUtil::dumpValueAndDbgInfo(const Value *val) { + std::string str; + llvm::raw_string_ostream rawstr(str); + if (val) + rawstr << dumpValue(val) << getSourceLoc(val); + else + rawstr << " llvm Value is null"; + return rawstr.str(); +} namespace SVF { diff --git a/svf-llvm/lib/ObjTypeInference.cpp b/svf-llvm/lib/ObjTypeInference.cpp new file mode 100644 index 000000000..e757da0ec --- /dev/null +++ b/svf-llvm/lib/ObjTypeInference.cpp @@ -0,0 +1,486 @@ +//===- ObjTypeInference.cpp -- Type inference----------------------------// +// +// SVF: Static Value-Flow Analysis +// +// Copyright (C) <2013-> +// + +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. + +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. + +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . +// +//===----------------------------------------------------------------------===// + +/* + * ObjTypeInference.cpp + * + * Created by Xiao Cheng on 10/01/24. + * + */ + +#include "SVF-LLVM/ObjTypeInference.h" +#include "SVF-LLVM/LLVMModule.h" +#include "SVF-LLVM/LLVMUtil.h" +#include "SVF-LLVM/CppUtil.h" + +#define TYPE_DEBUG 0 /* Turn this on if you're debugging type inference */ +#define ERR_MSG(msg) \ + do \ + { \ + SVFUtil::errs() << SVFUtil::errMsg("Error ") << __FILE__ << ':' \ + << __LINE__ << ": " << msg << '\n'; \ + } while (0) +#define ABORT_MSG(msg) \ + do \ + { \ + ERR_MSG(msg); \ + abort(); \ + } while (0) +#define ABORT_IFNOT(condition, msg) \ + do \ + { \ + if (!(condition)) \ + ABORT_MSG(msg); \ + } while (0) + +#if TYPE_DEBUG +#define WARN_MSG(msg) \ + do \ + { \ + SVFUtil::outs() << SVFUtil::wrnMsg("Warning ") << __FILE__ << ':' \ + << __LINE__ << ": " << msg << '\n'; \ + } while (0) +#define WARN_IFNOT(condition, msg) \ + do \ + { \ + if (!(condition)) \ + WARN_MSG(msg); \ + } while (0) +#else +#define WARN_MSG(msg) +#define WARN_IFNOT(condition, msg) +#endif + +using namespace SVF; +using namespace SVFUtil; +using namespace LLVMUtil; +using namespace cppUtil; + + +const std::string TYPEMALLOC = "TYPE_MALLOC"; + +/// Determine type based on infer site +/// https://llvm.org/docs/OpaquePointers.html#migration-instructions +const Type *infersiteToType(const Value *val) { + assert(val && "value cannot be empty"); + if (SVFUtil::isa(val)) { + return llvm::getLoadStoreType(const_cast(val)); + } else if (const GetElementPtrInst *gepInst = SVFUtil::dyn_cast(val)) { + return gepInst->getSourceElementType(); + } else if (const CallBase *call = SVFUtil::dyn_cast(val)) { + return call->getFunctionType(); + } else if (const AllocaInst *allocaInst = SVFUtil::dyn_cast(val)) { + return allocaInst->getAllocatedType(); + } else if (const GlobalValue *globalValue = SVFUtil::dyn_cast(val)) { + return globalValue->getValueType(); + } else { + ABORT_MSG("unknown value:" + dumpValueAndDbgInfo(val)); + } +} + +const Type *ObjTypeInference::defaultType(const Value *val) { + ABORT_IFNOT(val, "val cannot be null"); + // heap has a default type of 8-bit integer type + if (SVFUtil::isa(val) && SVFUtil::isHeapAllocExtCallViaRet( + LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction(SVFUtil::cast(val)))) + return int8Type(); + // otherwise we return a pointer type in the default address space + return ptrType(); +} + +LLVMContext &ObjTypeInference::getLLVMCtx() { + return LLVMModuleSet::getLLVMModuleSet()->getContext(); +} + +/*! + * get or infer type of a value + * if the start value is a source (alloc/global, heap, static), call fwInferObjType + * if not, find sources and then forward get or infer types + * @param startValue + */ +const Type *ObjTypeInference::inferObjType(const Value *startValue) { + if (isAllocation(startValue)) return fwInferObjType(startValue); + Set sources = bwfindAllocations(startValue); + Set types; + for (const auto &source: sources) { + types.insert(fwInferObjType(source)); + } + const Type *largestTy = selectLargestType(types); + ABORT_IFNOT(largestTy, "return type cannot be null"); + return largestTy; +} + +/*! + * Forward collect all possible infer sites starting from a value + * @param startValue + */ +const Type *ObjTypeInference::fwInferObjType(const Value *startValue) { + // consult cache + auto tIt = _valueToType.find(startValue); + if (tIt != _valueToType.end()) { + return tIt->second ? tIt->second : defaultType(startValue); + } + + // simulate the call stack, the second element indicates whether we should update valueTypes for current value + FILOWorkList workList; + Set visited; + workList.push({startValue, false}); + + while (!workList.empty()) { + auto curPair = workList.pop(); + if (visited.count(curPair)) continue; + visited.insert(curPair); + const Value *curValue = curPair.first; + bool canUpdate = curPair.second; + Set infersites; + + auto insertInferSite = [&infersites, &canUpdate](const Value *infersite) { + if (canUpdate) infersites.insert(infersite); + }; + auto insertInferSitesOrPushWorklist = [this, &infersites, &workList, &canUpdate](const auto &pUser) { + auto vIt = _valueToInferSites.find(pUser); + if (canUpdate) { + if (vIt != _valueToInferSites.end() && !vIt->second.empty()) { + infersites.insert(vIt->second.begin(), vIt->second.end()); + } + } else { + if (vIt == _valueToInferSites.end()) workList.push({pUser, false}); + } + }; + if (!canUpdate && !_valueToInferSites.count(curValue)) { + workList.push({curValue, true}); + } + if (const GetElementPtrInst *gepInst = SVFUtil::dyn_cast(curValue)) + insertInferSite(gepInst); + for (const auto &it: curValue->uses()) { + if (LoadInst *loadInst = SVFUtil::dyn_cast(it.getUser())) { + /* + * infer based on load, e.g., + %call = call i8* malloc() + %1 = bitcast i8* %call to %struct.MyStruct* + %q = load %struct.MyStruct, %struct.MyStruct* %1 + */ + insertInferSite(loadInst); + } else if (StoreInst *storeInst = SVFUtil::dyn_cast(it.getUser())) { + if (storeInst->getPointerOperand() == curValue) { + /* + * infer based on store (pointer operand), e.g., + %call = call i8* malloc() + %1 = bitcast i8* %call to %struct.MyStruct* + store %struct.MyStruct .., %struct.MyStruct* %1 + */ + insertInferSite(storeInst); + } else { + for (const auto &nit: storeInst->getPointerOperand()->uses()) { + /* + * propagate across store (value operand) and load + %call = call i8* malloc() + store i8* %call, i8** %p + %q = load i8*, i8** %p + ..infer based on %q.. + */ + if (SVFUtil::isa(nit.getUser())) + insertInferSitesOrPushWorklist(nit.getUser()); + } + /* + * infer based on store (value operand) <- gep (result element) + %call1 = call i8* @TYPE_MALLOC(i32 noundef 16, i32 noundef 2), !dbg !39 + %2 = bitcast i8* %call1 to %struct.MyStruct*, !dbg !41 + %3 = load %struct.MyStruct*, %struct.MyStruct** %p, align 8, !dbg !42 + %next = getelementptr inbounds %struct.MyStruct, %struct.MyStruct* %3, i32 0, i32 1, !dbg !43 + store %struct.MyStruct* %2, %struct.MyStruct** %next, align 8, !dbg !44 + %5 = load %struct.MyStruct*, %struct.MyStruct** %p, align 8, !dbg !48 + %next3 = getelementptr inbounds %struct.MyStruct, %struct.MyStruct* %5, i32 0, i32 1, !dbg !49 + %6 = load %struct.MyStruct*, %struct.MyStruct** %next3, align 8, !dbg !49 + infer site -> %f1 = getelementptr inbounds %struct.MyStruct, %struct.MyStruct* %6, i32 0, i32 0, !dbg !50 + */ + if (GetElementPtrInst *gepInst = SVFUtil::dyn_cast( + storeInst->getPointerOperand())) { + const Value *gepBase = gepInst->getPointerOperand(); + if (!SVFUtil::isa(gepBase)) continue; + const LoadInst *load = SVFUtil::dyn_cast(gepBase); + for (const auto &loadUse: load->getPointerOperand()->uses()) { + if (loadUse.getUser() == load || !SVFUtil::isa(loadUse.getUser())) + continue; + for (const auto &gepUse: loadUse.getUser()->uses()) { + if (!SVFUtil::isa(gepUse.getUser())) continue; + for (const auto &loadUse2: gepUse.getUser()->uses()) { + if (SVFUtil::isa(loadUse2.getUser())) { + insertInferSitesOrPushWorklist(loadUse2.getUser()); + } + } + } + } + + } + } + + } else if (GetElementPtrInst *gepInst = SVFUtil::dyn_cast(it.getUser())) { + /* + * infer based on gep (pointer operand) + %call = call i8* malloc() + %1 = bitcast i8* %call to %struct.MyStruct* + %next = getelementptr inbounds %struct.MyStruct, %struct.MyStruct* %1, i32 0.. + */ + if (gepInst->getPointerOperand() == curValue) + insertInferSite(gepInst); + } else if (BitCastInst *bitcast = SVFUtil::dyn_cast(it.getUser())) { + // continue on bitcast + insertInferSitesOrPushWorklist(bitcast); + } else if (PHINode *phiNode = SVFUtil::dyn_cast(it.getUser())) { + // continue on bitcast + insertInferSitesOrPushWorklist(phiNode); + } else if (ReturnInst *retInst = SVFUtil::dyn_cast(it.getUser())) { + /* + * propagate from return to caller + Function Attrs: noinline nounwind optnone uwtable + define dso_local i8* @malloc_wrapper() #0 !dbg !22 { + entry: + %call = call i8* @malloc(i32 noundef 16), !dbg !25 + ret i8* %call, !dbg !26 + } + %call = call i8* @malloc_wrapper() + ..infer based on %call.. + */ + for (const auto &callsite: retInst->getFunction()->uses()) { + if (CallBase *callBase = SVFUtil::dyn_cast(callsite.getUser())) { + // skip function as parameter + // e.g., call void @foo(%struct.ssl_ctx_st* %9, i32 (i8*, i32, i32, i8*)* @passwd_callback) + if (callBase->getCalledFunction() != retInst->getFunction()) continue; + insertInferSitesOrPushWorklist(callBase); + } + } + } else if (CallBase *callBase = SVFUtil::dyn_cast(it.getUser())) { + /* + * propagate from callsite to callee + %call = call i8* @malloc(i32 noundef 16) + %0 = bitcast i8* %call to %struct.Node*, !dbg !43 + call void @foo(%struct.Node* noundef %0), !dbg !45 + + define dso_local void @foo(%struct.Node* noundef %param) #0 !dbg !22 {...} + ..infer based on the formal param %param.. + */ + // skip global function value -> callsite + // e.g., def @foo() -> call @foo() + // we don't skip function as parameter, e.g., def @foo() -> call @bar(..., @foo) + if (SVFUtil::isa(curValue) && curValue == callBase->getCalledFunction()) continue; + // skip indirect call + // e.g., %0 = ... -> call %0(...) + if (!callBase->hasArgument(curValue)) continue; + if (Function *calleeFunc = callBase->getCalledFunction()) { + u32_t pos = getArgPosInCall(callBase, curValue); + // for variable argument, conservatively collect all params + if (calleeFunc->isVarArg()) pos = 0; + if (!calleeFunc->isDeclaration()) { + insertInferSitesOrPushWorklist(calleeFunc->getArg(pos)); + } + } + } + } + if (canUpdate) { + Set types; + std::transform(infersites.begin(), infersites.end(), std::inserter(types, types.begin()), + infersiteToType); + _valueToInferSites[curValue] = infersites; + _valueToType[curValue] = selectLargestType(types); + } + } + const Type *type = _valueToType[startValue]; + if (type == nullptr) { + type = defaultType(startValue); + WARN_MSG("Using default type, trace ID is " + std::to_string(traceId) + ":" + dumpValueAndDbgInfo(startValue)); + } + ABORT_IFNOT(type, "type cannot be a null ptr"); + return type; +} + +/*! + * Backward collect all possible allocation sites (stack, static, heap) starting from a value + * @param startValue + * @return + */ +Set ObjTypeInference::bwfindAllocations(const Value *startValue) { + + // consult cache + auto tIt = _valueToAllocs.find(startValue); + if (tIt != _valueToAllocs.end()) { + WARN_IFNOT(!tIt->second.empty(), "empty type:" + dumpValueAndDbgInfo(startValue)); + return !tIt->second.empty() ? tIt->second : Set({startValue}); + } + + // simulate the call stack, the second element indicates whether we should update sources for current value + FILOWorkList workList; + Set visited; + workList.push({startValue, false}); + while (!workList.empty()) { + auto curPair = workList.pop(); + if (visited.count(curPair)) continue; + visited.insert(curPair); + const Value *curValue = curPair.first; + bool canUpdate = curPair.second; + + Set sources; + auto insertAllocs = [&sources, &canUpdate](const Value *source) { + if (canUpdate) sources.insert(source); + }; + auto insertAllocsOrPushWorklist = [this, &sources, &workList, &canUpdate](const auto &pUser) { + auto vIt = _valueToAllocs.find(pUser); + if (canUpdate) { + if (vIt != _valueToAllocs.end() && !vIt->second.empty()) { + sources.insert(vIt->second.begin(), vIt->second.end()); + } + } else { + if (vIt == _valueToAllocs.end()) workList.push({pUser, false}); + } + }; + + if (!canUpdate && !_valueToAllocs.count(curValue)) { + workList.push({curValue, true}); + } + + if (isAllocation(curValue)) { + insertAllocs(curValue); + } else if (const BitCastInst *bitCastInst = SVFUtil::dyn_cast(curValue)) { + Value *prevVal = bitCastInst->getOperand(0); + insertAllocsOrPushWorklist(prevVal); + } else if (const PHINode *phiNode = SVFUtil::dyn_cast(curValue)) { + for (u32_t i = 0; i < phiNode->getNumOperands(); ++i) { + insertAllocsOrPushWorklist(phiNode->getOperand(i)); + } + } else if (const LoadInst *loadInst = SVFUtil::dyn_cast(curValue)) { + for (const auto &use: loadInst->getPointerOperand()->uses()) { + if (const StoreInst *storeInst = SVFUtil::dyn_cast(use.getUser())) { + if (storeInst->getPointerOperand() == loadInst->getPointerOperand()) { + insertAllocsOrPushWorklist(storeInst->getValueOperand()); + } + } + } + } else if (const Argument *argument = SVFUtil::dyn_cast(curValue)) { + for (const auto &use: argument->getParent()->uses()) { + if (const CallBase *callBase = SVFUtil::dyn_cast(use.getUser())) { + // skip function as parameter + // e.g., call void @foo(%struct.ssl_ctx_st* %9, i32 (i8*, i32, i32, i8*)* @passwd_callback) + if (callBase->getCalledFunction() != argument->getParent()) continue; + u32_t pos = argument->getParent()->isVarArg() ? 0 : argument->getArgNo(); + insertAllocsOrPushWorklist(callBase->getArgOperand(pos)); + } + } + } else if (const CallBase *callBase = SVFUtil::dyn_cast(curValue)) { + ABORT_IFNOT(!callBase->doesNotReturn(), "callbase does not return:" + dumpValueAndDbgInfo(callBase)); + if (Function *callee = callBase->getCalledFunction()) { + if (!callee->isDeclaration()) { + const SVFFunction *svfFunc = LLVMModuleSet::getLLVMModuleSet()->getSVFFunction(callee); + const Value *pValue = LLVMModuleSet::getLLVMModuleSet()->getLLVMValue(svfFunc->getExitBB()->back()); + const ReturnInst *retInst = SVFUtil::dyn_cast(pValue); + ABORT_IFNOT(retInst && retInst->getReturnValue(), "not return inst?"); + insertAllocsOrPushWorklist(retInst->getReturnValue()); + } + } + } + if (canUpdate) { + _valueToAllocs[curValue] = sources; + } + } + Set srcs = _valueToAllocs[startValue]; + if (srcs.empty()) { + srcs = {startValue}; + WARN_MSG("Using default type, trace ID is " + std::to_string(traceId) + ":" + dumpValueAndDbgInfo(startValue)); + } + ABORT_IFNOT(!srcs.empty(), "sources cannot be empty"); + return srcs; +} + +bool ObjTypeInference::isAllocation(const SVF::Value *val) { + return LLVMUtil::isObject(val); +} + +/*! + * Validate type inference + * @param cs : stub malloc function with element number label + */ +void ObjTypeInference::validateTypeCheck(const CallBase *cs) { + if (const Function *func = cs->getCalledFunction()) { + if (func->getName().find(TYPEMALLOC) != std::string::npos) { + const Type *objType = fwInferObjType(cs); + ConstantInt *pInt = + SVFUtil::dyn_cast(cs->getOperand(1)); + assert(pInt && "the second argument is a integer"); + u32_t iTyNum = objTyToNumFields(objType); + if (iTyNum >= pInt->getZExtValue()) + SVFUtil::outs() << SVFUtil::sucMsg("\t SUCCESS :") << dumpValueAndDbgInfo(cs) + << SVFUtil::pasMsg(" TYPE: ") + << dumpType(objType) << "\n"; + else { + SVFUtil::errs() << SVFUtil::errMsg("\t FAILURE :") << ":" << dumpValueAndDbgInfo(cs) << " TYPE: " + << dumpType(objType) << "\n"; + abort(); + } + } + } +} + +void ObjTypeInference::typeSizeDiffTest(const PointerType *oPTy, const Type *iTy, const Value *val) { +#if TYPE_DEBUG + Type *oTy = getPtrElementType(oPTy); + u32_t iTyNum = objTyToNumFields(iTy); + if (getNumOfElements(oTy) > iTyNum) { + ERR_MSG("original type is:" + dumpType(oTy)); + ERR_MSG("infered type is:" + dumpType(iTy)); + ABORT_MSG("wrong type, trace ID is " + std::to_string(traceId) + ":" + dumpValueAndDbgInfo(val)); + } +#endif +} + +u32_t ObjTypeInference::getArgPosInCall(const CallBase *callBase, const Value *arg) { + assert(callBase->hasArgument(arg) && "callInst does not have argument arg?"); + auto it = std::find(callBase->arg_begin(), callBase->arg_end(), arg); + assert(it != callBase->arg_end() && "Didn't find argument?"); + return std::distance(callBase->arg_begin(), it); +} + + +const Type *ObjTypeInference::selectLargestType(Set &objTys) { + if (objTys.empty()) return nullptr; + // map type size to types from with key in descending order + OrderedMap, std::greater> typeSzToTypes; + for (const Type *ty: objTys) { + typeSzToTypes[objTyToNumFields(ty)].insert(ty); + } + assert(!typeSzToTypes.empty() && "typeSzToTypes cannot be empty"); + Set largestTypes; + std::tie(std::ignore, largestTypes) = *typeSzToTypes.begin(); + assert(!largestTypes.empty() && "largest element cannot be empty"); + return *largestTypes.begin(); +} + +u32_t ObjTypeInference::objTyToNumFields(const Type *objTy) { + u32_t num = Options::MaxFieldLimit(); + if (SVFUtil::isa(objTy)) + num = getNumOfElements(objTy); + else if (const StructType *st = SVFUtil::dyn_cast(objTy)) { + /// For an C++ class, it can have variant elements depending on the vtable size, + /// Hence we only handle non-cpp-class object, the type of the cpp class is treated as default PointerType + if (!classTyHasVTable(st)) + num = getNumOfElements(st); + } + return num; +} diff --git a/svf-llvm/lib/SVFIRBuilder.cpp b/svf-llvm/lib/SVFIRBuilder.cpp index ed6646f32..e5e32ec01 100644 --- a/svf-llvm/lib/SVFIRBuilder.cpp +++ b/svf-llvm/lib/SVFIRBuilder.cpp @@ -1083,42 +1083,6 @@ const Value* SVFIRBuilder::getBaseValueForExtArg(const Value* V) if(totalidx == 0 && !SVFUtil::isa(value->getType())) value = gep->getPointerOperand(); } - - // if the argument of memcpy is the result of an allocation (1) or a casted load instruction (2), - // further steps are necessary to find the correct base value - // - // (1) - // %call = malloc 80 - // %0 = bitcast i8* %call to %struct.A* - // %1 = bitcast %struct.B* %param to i8* - // call void memcpy(%call, %1, 80) - // - // (2) - // %0 = bitcast %struct.A* %param to i8* - // %2 = bitcast %struct.B** %arrayidx to i8** - // %3 = load i8*, i8** %2 - // call void @memcpy(%0, %3, 80) - LLVMContext &cxt = LLVMModuleSet::getLLVMModuleSet()->getContext(); - if (value->getType() == PointerType::getInt8PtrTy(cxt)) - { - // (1) - if (const CallBase* cb = SVFUtil::dyn_cast(value)) - { - const SVFInstruction* svfInst = LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction(cb); - if (SVFUtil::isHeapAllocExtCallViaRet(svfInst)) - { - if (const Value* bitCast = getFirstUseViaCastInst(cb)) - return bitCast; - } - } - // (2) - else if (const LoadInst* load = SVFUtil::dyn_cast(value)) - { - if (const BitCastInst* bitCast = SVFUtil::dyn_cast(load->getPointerOperand())) - return bitCast->getOperand(0); - } - } - return value; } @@ -1219,7 +1183,9 @@ NodeID SVFIRBuilder::getGepValVar(const Value* val, const AccessPath& ap, const const SVFBasicBlock* cbb = getCurrentBB(); setCurrentLocation(curVal, nullptr); LLVMModuleSet* llvmmodule = LLVMModuleSet::getLLVMModuleSet(); - NodeID gepNode= pag->addGepValNode(curVal, llvmmodule->getSVFValue(val),ap, NodeIDAllocator::get()->allocateValueId(),elementType->getPointerTo()); + NodeID gepNode = pag->addGepValNode(curVal, llvmmodule->getSVFValue(val), ap, + NodeIDAllocator::get()->allocateValueId(), + llvmmodule->getSVFType(PointerType::getUnqual(llvmmodule->getContext()))); addGepEdge(base, gepNode, ap, true); setCurrentLocation(cval, cbb); return gepNode; diff --git a/svf-llvm/lib/SVFIRExtAPI.cpp b/svf-llvm/lib/SVFIRExtAPI.cpp index 085d82cb1..4752659df 100644 --- a/svf-llvm/lib/SVFIRExtAPI.cpp +++ b/svf-llvm/lib/SVFIRExtAPI.cpp @@ -30,6 +30,7 @@ #include "SVF-LLVM/SVFIRBuilder.h" #include "Util/SVFUtil.h" #include "SVF-LLVM/SymbolTableBuilder.h" +#include "SVF-LLVM/ObjTypeInference.h" using namespace std; using namespace SVF; @@ -43,12 +44,8 @@ const Type* SVFIRBuilder::getBaseTypeAndFlattenedFields(const Value* V, std::vec { assert(V); const Value* value = getBaseValueForExtArg(V); - const Type* T = value->getType(); - // TODO: getPtrElementType need type inference - while (const PointerType *ptype = SVFUtil::dyn_cast(T)) - T = getPtrElementType(ptype); - - u32_t numOfElems = pag->getSymbolInfo()->getNumOfFlattenElements(LLVMModuleSet::getLLVMModuleSet()->getSVFType(T)); + const Type *objType = LLVMModuleSet::getLLVMModuleSet()->getTypeInference()->inferObjType(value); + u32_t numOfElems = pag->getSymbolInfo()->getNumOfFlattenElements(LLVMModuleSet::getLLVMModuleSet()->getSVFType(objType)); /// use user-specified size for this copy operation if the size is a constaint int if(szValue && SVFUtil::isa(szValue)) { @@ -71,7 +68,7 @@ const Type* SVFIRBuilder::getBaseTypeAndFlattenedFields(const Value* V, std::vec ls.addOffsetVarAndGepTypePair(getPAG()->getGNode(getPAG()->getValueNode(svfOffset)), nullptr); fields.push_back(ls); } - return T; + return objType; } /*! diff --git a/svf-llvm/lib/SymbolTableBuilder.cpp b/svf-llvm/lib/SymbolTableBuilder.cpp index 45f917346..3cd1bdb67 100644 --- a/svf-llvm/lib/SymbolTableBuilder.cpp +++ b/svf-llvm/lib/SymbolTableBuilder.cpp @@ -38,6 +38,7 @@ #include "Util/NodeIDAllocator.h" #include "Util/Options.h" #include "Util/SVFUtil.h" +#include "SVF-LLVM/ObjTypeInference.h" using namespace SVF; using namespace SVFUtil; @@ -219,6 +220,9 @@ void SymbolTableBuilder::buildMemModel(SVFModule* svfModule) // TODO handle inlineAsm /// if (SVFUtil::isa(Callee)) + if (Options::EnableTypeCheck()) { + getTypeInference()->validateTypeCheck(cs); + } } //@} } @@ -566,6 +570,53 @@ void SymbolTableBuilder::handleGlobalInitializerCE(const Constant* C) } } +ObjTypeInference *SymbolTableBuilder::getTypeInference() { + return LLVMModuleSet::getLLVMModuleSet()->getTypeInference(); +} + + +const Type* SymbolTableBuilder::inferObjType(const Value *startValue) { + return getTypeInference()->inferObjType(startValue); +} + +/*! + * Return the type of the object from a heap allocation + */ +const Type* SymbolTableBuilder::inferTypeOfHeapObjOrStaticObj(const Instruction *inst) +{ + const Value* startValue = inst; + const PointerType *originalPType = SVFUtil::dyn_cast(inst->getType()); + const Type* inferedType = nullptr; + assert(originalPType && "empty type?"); + const SVFInstruction* svfinst = LLVMModuleSet::getLLVMModuleSet()->getSVFInstruction(inst); + if(SVFUtil::isHeapAllocExtCallViaRet(svfinst)) + { + if(const Value* v = getFirstUseViaCastInst(inst)) + { + if (const PointerType *newTy = SVFUtil::dyn_cast(v->getType())) { + originalPType = newTy; + } + } + inferedType = inferObjType(startValue); + } + else if(SVFUtil::isHeapAllocExtCallViaArg(svfinst)) + { + const CallBase* cs = LLVMUtil::getLLVMCallSite(inst); + int arg_pos = SVFUtil::getHeapAllocHoldingArgPosition(SVFUtil::getSVFCallSite(svfinst)); + const Value* arg = cs->getArgOperand(arg_pos); + originalPType = SVFUtil::dyn_cast(arg->getType()); + inferedType = inferObjType(startValue = arg); + } + else + { + assert( false && "not a heap allocation instruction?"); + } + + getTypeInference()->typeSizeDiffTest(originalPType, inferedType, startValue); + + return inferedType; +} + /* * Initial the memory object here */ @@ -599,7 +650,7 @@ ObjTypeInfo* SymbolTableBuilder::createObjTypeInfo(const Value* val) } else { - SVFUtil::errs() << dumpValue(val) << "\n"; + SVFUtil::errs() << dumpValueAndDbgInfo(val) << "\n"; assert(false && "not an allocation or global?"); } } @@ -751,26 +802,20 @@ u32_t SymbolTableBuilder::analyzeHeapAllocByteSize(const Value* val) */ u32_t SymbolTableBuilder::analyzeHeapObjType(ObjTypeInfo* typeinfo, const Value* val) { - if(const Value* castUse = getFirstUseViaCastInst(val)) - { - typeinfo->setFlag(ObjTypeInfo::HEAP_OBJ); - analyzeObjType(typeinfo,castUse); - const Type* objTy = LLVMModuleSet::getLLVMModuleSet()->getLLVMType(typeinfo->getType()); - if(SVFUtil::isa(objTy)) + typeinfo->setFlag(ObjTypeInfo::HEAP_OBJ); + analyzeObjType(typeinfo, val); + const Type* objTy = LLVMModuleSet::getLLVMModuleSet()->getLLVMType(typeinfo->getType()); + if(SVFUtil::isa(objTy)) + return getNumOfElements(objTy); + else if(const StructType* st = SVFUtil::dyn_cast(objTy)) + { + /// For an C++ class, it can have variant elements depending on the vtable size, + /// Hence we only handle non-cpp-class object, the type of the cpp class is treated as default PointerType + if(cppUtil::classTyHasVTable(st)) + typeinfo->resetTypeForHeapStaticObj(LLVMModuleSet::getLLVMModuleSet()->getSVFType( + LLVMModuleSet::getLLVMModuleSet()->getTypeInference()->ptrType())); + else return getNumOfElements(objTy); - else if(const StructType* st = SVFUtil::dyn_cast(objTy)) - { - /// For an C++ class, it can have variant elements depending on the vtable size, - /// Hence we only handle non-cpp-class object, the type of the cpp class is treated as PointerType at the cast site - if(cppUtil::classTyHasVTable(st)) - typeinfo->resetTypeForHeapStaticObj(LLVMModuleSet::getLLVMModuleSet()->getSVFType(castUse->getType())); - else - return getNumOfElements(objTy); - } - } - else - { - typeinfo->setFlag(ObjTypeInfo::HEAP_OBJ); } return typeinfo->getMaxFieldOffsetLimit(); } diff --git a/svf-llvm/tools/CFL/cfl.cpp b/svf-llvm/tools/CFL/cfl.cpp index 96dc84212..1300fc0f0 100644 --- a/svf-llvm/tools/CFL/cfl.cpp +++ b/svf-llvm/tools/CFL/cfl.cpp @@ -35,6 +35,7 @@ #include "CFL/CFLAlias.h" #include "CFL/CFLVF.h" + using namespace llvm; using namespace SVF; @@ -83,6 +84,7 @@ int main(int argc, char ** argv) SVFIR::releaseSVFIR(); SVF::LLVMModuleSet::releaseLLVMModuleSet(); + return 0; } diff --git a/svf-llvm/tools/Example/svf-ex.cpp b/svf-llvm/tools/Example/svf-ex.cpp index e2b71c3fa..b2de6f5b3 100644 --- a/svf-llvm/tools/Example/svf-ex.cpp +++ b/svf-llvm/tools/Example/svf-ex.cpp @@ -34,6 +34,7 @@ #include "Util/CommandLine.h" #include "Util/Options.h" + using namespace std; using namespace SVF; @@ -221,6 +222,7 @@ int main(int argc, char ** argv) SVFIRBuilder builder(svfModule); SVFIR* pag = builder.build(); + /// Create Andersen's pointer analysis Andersen* ander = AndersenWaveDiff::createAndersenWaveDiff(pag); @@ -263,7 +265,6 @@ int main(int argc, char ** argv) LLVMModuleSet::getLLVMModuleSet()->dumpModulesToFile(".svf.bc"); SVF::LLVMModuleSet::releaseLLVMModuleSet(); - llvm::llvm_shutdown(); return 0; } diff --git a/svf-llvm/tools/LLVM2SVF/llvm2svf.cpp b/svf-llvm/tools/LLVM2SVF/llvm2svf.cpp index 4156f6e6c..0b765de74 100644 --- a/svf-llvm/tools/LLVM2SVF/llvm2svf.cpp +++ b/svf-llvm/tools/LLVM2SVF/llvm2svf.cpp @@ -65,9 +65,9 @@ int main(int argc, char** argv) const std::string jsonPath = replaceExtension(moduleNameVec.front()); // PAG is borrowed from a unique_ptr, so we don't need to delete it. const SVFIR* pag = SVFIRBuilder(svfModule).build(); + SVFIRWriter::writeJsonToPath(pag, jsonPath); SVFUtil::outs() << "SVF IR is written to '" << jsonPath << "'\n"; - LLVMModuleSet::releaseLLVMModuleSet(); return 0; } diff --git a/svf-llvm/tools/MTA/mta.cpp b/svf-llvm/tools/MTA/mta.cpp index b5c2d16b6..74e726e49 100644 --- a/svf-llvm/tools/MTA/mta.cpp +++ b/svf-llvm/tools/MTA/mta.cpp @@ -27,6 +27,7 @@ #include "Util/Options.h" #include "MTAResultValidator.h" #include "LockResultValidator.h" + using namespace llvm; using namespace std; using namespace SVF; @@ -48,6 +49,7 @@ int main(int argc, char ** argv) SVFIRBuilder builder(svfModule); SVFIR* pag = builder.build(); + MTA mta; mta.runOnModule(pag); @@ -57,7 +59,8 @@ int main(int argc, char ** argv) // Initialize the validator and perform validation. LockResultValidator lockvalidator(mta.getLockAnalysis()); lockvalidator.analyze(); - LLVMModuleSet::releaseLLVMModuleSet(); + + return 0; } diff --git a/svf-llvm/tools/SABER/saber.cpp b/svf-llvm/tools/SABER/saber.cpp index 5118311a4..de5981ae3 100644 --- a/svf-llvm/tools/SABER/saber.cpp +++ b/svf-llvm/tools/SABER/saber.cpp @@ -74,6 +74,7 @@ int main(int argc, char ** argv) SVFIRBuilder builder(svfModule); SVFIR* pag = builder.build(); + std::unique_ptr saber; if(LEAKCHECKER()) @@ -86,8 +87,8 @@ int main(int argc, char ** argv) saber = std::make_unique(); // if no checker is specified, we use leak checker as the default one. saber->runOnModule(pag); - LLVMModuleSet::releaseLLVMModuleSet(); + return 0; diff --git a/svf-llvm/tools/WPA/wpa.cpp b/svf-llvm/tools/WPA/wpa.cpp index 7af37df1f..afa86d0f7 100644 --- a/svf-llvm/tools/WPA/wpa.cpp +++ b/svf-llvm/tools/WPA/wpa.cpp @@ -63,12 +63,12 @@ int main(int argc, char** argv) /// Build SVFIR SVFIRBuilder builder(svfModule); pag = builder.build(); + } WPAPass wpa; wpa.runOnModule(pag); LLVMModuleSet::releaseLLVMModuleSet(); - return 0; } diff --git a/svf/include/SVFIR/SVFType.h b/svf/include/SVFIR/SVFType.h index b14632533..d00f342ff 100644 --- a/svf/include/SVFIR/SVFType.h +++ b/svf/include/SVFIR/SVFType.h @@ -261,27 +261,25 @@ class SVFType public: - inline static SVFType* getPtrTy() + inline static SVFType* getSVFPtrType() { - assert(ptrTy && "ptr type not set?"); - return ptrTy; + assert(svfPtrTy && "ptr type not set?"); + return svfPtrTy; } - inline static SVFType* getI8Ty() + inline static SVFType* getSVFInt8Type() { - assert(i8Ty && "int8 type not set?"); - return i8Ty; + assert(svfI8Ty && "int8 type not set?"); + return svfI8Ty; } private: - static SVFType* ptrTy; ///< ptr type - static SVFType* i8Ty; ///< 8-bit int type + static SVFType* svfPtrTy; ///< ptr type + static SVFType* svfI8Ty; ///< 8-bit int type private: GNodeK kind; ///< used for classof - const SVFPointerType* - getPointerToTy; /// Return a pointer to the current type StInfo* typeinfo; ///< SVF's TypeInfo bool isSingleValTy; ///< The type represents a single value, not struct or u32_t byteSize; ///< LLVM Byte Size @@ -289,7 +287,7 @@ class SVFType protected: SVFType(bool svt, SVFTyKind k, u32_t Sz = 1) - : kind(k), getPointerToTy(nullptr), typeinfo(nullptr), + : kind(k), typeinfo(nullptr), isSingleValTy(svt), byteSize(Sz) { } @@ -309,17 +307,6 @@ class SVFType virtual void print(std::ostream& os) const = 0; - inline void setPointerTo(const SVFPointerType* ty) - { - getPointerToTy = ty; - } - - inline const SVFPointerType* getPointerTo() const - { - assert(getPointerToTy && "set the getPointerToTy first"); - return getPointerToTy; - } - inline void setTypeInfo(StInfo* ti) { @@ -373,12 +360,9 @@ class SVFPointerType : public SVFType friend class SVFIRWriter; friend class SVFIRReader; -private: - const SVFType* ptrElementType; - public: SVFPointerType(u32_t byteSize = 1) - : SVFType(true, SVFPointerTy, byteSize), ptrElementType(nullptr) + : SVFType(true, SVFPointerTy, byteSize) { } @@ -386,16 +370,6 @@ class SVFPointerType : public SVFType { return node->getKind() == SVFPointerTy; } - inline const SVFType* getPtrElementType() const - { - assert(ptrElementType && "ptrElementType is nullptr"); - return ptrElementType; - } - - inline void setPtrElementType(SVFType* _ptrElementType) - { - ptrElementType = _ptrElementType; - } void print(std::ostream& os) const override; }; diff --git a/svf/include/Util/Options.h b/svf/include/Util/Options.h index 87edcbb1a..f2f7e024d 100644 --- a/svf/include/Util/Options.h +++ b/svf/include/Util/Options.h @@ -128,6 +128,7 @@ class Options static const Option IndirectCallLimit; static const Option UsePreCompFieldSensitive; static const Option EnableAliasCheck; + static const Option EnableTypeCheck; static const Option EnableThreadCallGraph; static const Option ConnectVCallOnCHA; diff --git a/svf/lib/SVFIR/SVFFileSystem.cpp b/svf/lib/SVFIR/SVFFileSystem.cpp index 370f56b41..d93c33454 100644 --- a/svf/lib/SVFIR/SVFFileSystem.cpp +++ b/svf/lib/SVFIR/SVFFileSystem.cpp @@ -478,7 +478,6 @@ cJSON* SVFIRWriter::contentToJson(const SVFType* type) cJSON* root = jsonCreateObject(); JSON_WRITE_FIELD(root, type, kind); JSON_WRITE_FIELD(root, type, isSingleValTy); - JSON_WRITE_FIELD(root, type, getPointerToTy); JSON_WRITE_FIELD(root, type, typeinfo); return root; } @@ -486,7 +485,6 @@ cJSON* SVFIRWriter::contentToJson(const SVFType* type) cJSON* SVFIRWriter::contentToJson(const SVFPointerType* type) { cJSON* root = contentToJson(static_cast(type)); - JSON_WRITE_FIELD(root, type, ptrElementType); return root; } @@ -2496,14 +2494,12 @@ void SVFIRReader::virtFill(const cJSON*& fieldJson, SVFType* type) void SVFIRReader::fill(const cJSON*& fieldJson, SVFType* type) { // kind has already been read - JSON_READ_FIELD_FWD(fieldJson, type, getPointerToTy); JSON_READ_FIELD_FWD(fieldJson, type, typeinfo); } void SVFIRReader::fill(const cJSON*& fieldJson, SVFPointerType* type) { fill(fieldJson, static_cast(type)); - JSON_READ_FIELD_FWD(fieldJson, type, ptrElementType); } void SVFIRReader::fill(const cJSON*& fieldJson, SVFIntegerType* type) diff --git a/svf/lib/SVFIR/SVFType.cpp b/svf/lib/SVFIR/SVFType.cpp index 651f20dcf..f262b4764 100644 --- a/svf/lib/SVFIR/SVFType.cpp +++ b/svf/lib/SVFIR/SVFType.cpp @@ -4,8 +4,8 @@ namespace SVF { -SVFType* SVFType::i8Ty = nullptr; -SVFType* SVFType::ptrTy = nullptr; +SVFType* SVFType::svfI8Ty = nullptr; +SVFType* SVFType::svfPtrTy = nullptr; __attribute__((weak)) std::string SVFType::toString() const @@ -23,7 +23,7 @@ std::ostream& operator<<(std::ostream& os, const SVFType& type) void SVFPointerType::print(std::ostream& os) const { - os << *ptrElementType << '*'; + os << "ptr"; } void SVFIntegerType::print(std::ostream& os) const diff --git a/svf/lib/Util/Options.cpp b/svf/lib/Util/Options.cpp index 58e009932..9a29abdd8 100644 --- a/svf/lib/Util/Options.cpp +++ b/svf/lib/Util/Options.cpp @@ -301,6 +301,12 @@ const Option Options::EnableAliasCheck( true ); +const Option Options::EnableTypeCheck( + "type-check", + "Enable type check functions", + true +); + const Option Options::EnableThreadCallGraph( "enable-tcg", "Enable pointer analysis to use thread call graph",