From 6b6abdfa54200fc4b73a6ef4214556299102ddd8 Mon Sep 17 00:00:00 2001
From: "William S. Moses"
Date: Sun, 18 Feb 2024 18:22:25 -0500
Subject: [PATCH] fix

---
 WORKSPACE                                          |  4 +-
 src/enzyme_ad/jax/BUILD                            |  7 +-
 src/enzyme_ad/jax/Implementations/Common.td        |  4 +-
 .../jax/Implementations/HLODerivatives.td          | 14 ++++
 .../MHLOAutoDiffOpInterfaceImpl.cpp                |  1 +
 .../jax/Implementations/MHLODerivatives.td         |  4 +-
 .../StableHLOAutoDiffOpInterfaceImpl.cpp           | 64 +++++++++++++--
 .../Implementations/StableHLODerivatives.td        |  4 +-
 .../jax/Implementations/XLADerivatives.h           | 14 +---
 src/enzyme_ad/jax/__init__.py                      |  1 +
 src/enzyme_ad/jax/compile_with_xla.cc              | 81 ++++++++++---------
 src/enzyme_ad/jax/compile_with_xla.h               |  8 +-
 src/enzyme_ad/jax/enzyme_call.cc                   | 26 +++++-
 src/enzyme_ad/jax/primitives.py                    | 36 ++++++---
 test/bench_vs_xla.py                               | 25 +++---
 test/lit_tests/ir.pyt                              |  2 +-
 test/llama.py                                      | 22 ++++-
 test/test.py                                       |  7 +-
 18 files changed, 229 insertions(+), 95 deletions(-)

diff --git a/WORKSPACE b/WORKSPACE
index 3dd2a612..7c08cafd 100644
--- a/WORKSPACE
+++ b/WORKSPACE
@@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen
 pip_install_dependencies()
 
-ENZYME_COMMIT = "4ccab29dc691cb43d250a7c5ca612c3ff9cd23e3"
-ENZYME_SHA256 = ""
+ENZYME_COMMIT = "1e1c0eb1c9b4ae3fa6b0acc2394e305b3fc4e042"
+ENZYME_SHA256 = "07eb58bb2b4d877f940b88b87bcb2a8e9f02a9320a23b31697c5a4105cb6f031"
 
 http_archive(
     name = "enzyme",
diff --git a/src/enzyme_ad/jax/BUILD b/src/enzyme_ad/jax/BUILD
index 7da18a96..66cac64b 100644
--- a/src/enzyme_ad/jax/BUILD
+++ b/src/enzyme_ad/jax/BUILD
@@ -186,11 +186,8 @@ pybind_library(
         "@llvm-project//mlir:ArithDialect",
         "@llvm-project//mlir:FuncDialect",
        "@llvm-project//mlir:Parser",
-
-        "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
-        "@llvm-project//mlir:MLIRBindingsPythonCore",
-
-        # EnzymeMLIR
+
+        # EnzymeMLIR
         "@enzyme//:EnzymeMLIR",
 
         # Mosaic
diff --git a/src/enzyme_ad/jax/Implementations/Common.td b/src/enzyme_ad/jax/Implementations/Common.td
index 06551416..c6876436 100644
--- a/src/enzyme_ad/jax/Implementations/Common.td
+++ b/src/enzyme_ad/jax/Implementations/Common.td
@@ -70,11 +70,11 @@ class Inst<string mnemonic, string dialect_> : Operation
 
-class ConstantFP<string val, string dialect_, string op_> : Operation {
+class ConstantFP<string val, string dialect_, string op_, string type_> : Operation {
   string value = val;
   string dialect = dialect_;
   string opName = op_;
+  string type = type_;
 }
 
 def SelectIfActive : Operation {
diff --git a/src/enzyme_ad/jax/Implementations/HLODerivatives.td b/src/enzyme_ad/jax/Implementations/HLODerivatives.td
index 875af33b..0ed39661 100644
--- a/src/enzyme_ad/jax/Implementations/HLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/HLODerivatives.td
@@ -14,6 +14,9 @@ def Exp : HLOInst<"ExpOp">;
 
 def Dot : HLOInst<"DotGeneralOp">;
 
+def Compare : HLOInst<"CompareOp">;
+def Select : HLOInst<"SelectOp">;
+
 def CheckedMul : HLOInst<"MulOp">;
 def CheckedDiv : HLOInst<"DivOp">;
 
@@ -80,6 +83,15 @@ def : HLODerivative<"SqrtOp", (Op $x),
   ]
 >;
 
+def LT : GlobalExpr;
+def : HLODerivative<"MaxOp", (Op $x, $y),
+  [
+    (Select (Compare $x, $y, (LT)), (HLOConstantFP<"0"> $x), (DiffeRet)),
+    (Select (Compare $x, $y, (LT)), (DiffeRet), (HLOConstantFP<"0"> $x))
+  ],
+  (Select (Compare $x, $y, (LT)), (SelectIfActive $y, (Shadow $y), (HLOConstantFP<"0"> $y)), (SelectIfActive $x, (Shadow $x), (HLOConstantFP<"0"> $x)))
+  >;
+
 def : HLOReadOnlyIdentityOp<"ReshapeOp">;
 def : HLOReadOnlyIdentityOp<"SliceOp">;
 def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;
@@ -97,3 +109,5 @@ def : HLODerivative<"DotGeneralOp", (Op $lhs, $rhs),
   ],
   (Add (SelectIfActive $lhs, (Dot (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (Dot (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)))
 >;
+
+def : HLORegionTerminatorOp<"ReturnOp">;
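The MaxOp derivative added above encodes d/dt max(x, y) as a compare-and-select: the tangent of the larger operand is propagated and the other tangent is replaced by a zero constant. As a sanity check of that semantics, here is the same rule observed through stock JAX (plain jax.jvp, no Enzyme involved):

    import jax
    import jax.numpy as jnp

    # d/dt max(x, y) selects the tangent of the larger operand, mirroring
    # (Select (Compare $x, $y, (LT)), dy, dx) in the TableGen rule above.
    x, y = jnp.float32(2.0), jnp.float32(5.0)
    dx, dy = jnp.float32(1.0), jnp.float32(10.0)
    primal, tangent = jax.jvp(jnp.maximum, (x, y), (dx, dy))
    print(primal, tangent)  # 5.0 10.0 (y is larger, so dy propagates)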
diff --git a/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
index 8542248c..10fc5b8f 100644
--- a/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/MHLOAutoDiffOpInterfaceImpl.cpp
@@ -27,6 +27,7 @@
 
 using namespace mlir;
 using namespace mlir::enzyme;
+using namespace mlir::mhlo;
 
 namespace {
 #include "src/enzyme_ad/jax/Implementations/MHLODerivatives.inc"
diff --git a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td
index 03ea7e0c..5ad475ca 100644
--- a/src/enzyme_ad/jax/Implementations/MHLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/MHLODerivatives.td
@@ -8,6 +8,8 @@ class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnly
 
 class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"mhlo", opName_, impl_>;
 
-class HLOConstantFP<string val> : ConstantFP<val, "mhlo", "ConstantOp">;
+class HLOConstantFP<string val, string type_ = ""> : ConstantFP<val, "mhlo", "ConstantOp", type_>;
+
+class HLORegionTerminatorOp<string m> : RegionTerminatorOp<"mhlo", m>;
 
 include "HLODerivatives.td"
diff --git a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
index 815bac0b..d6109170 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
+++ b/src/enzyme_ad/jax/Implementations/StableHLOAutoDiffOpInterfaceImpl.cpp
@@ -89,6 +89,19 @@ static bool isEligibleForCompactPrint(ReduceOp op) {
   return llvm::equal(innerOp.getResults(), retOp.getOperands());
 }
 
+template <typename OpTy>
+class AutoDiffReduceCF : public ControlFlowAutoDiffOpInterface::ExternalModel<
+                             AutoDiffReduceCF<OpTy>, OpTy> {
+public:
+  Operation *createWithShadows(Operation *op, OpBuilder &builder,
+                               MGradientUtils *gutils, Operation *original,
+                               ValueRange remappedOperands,
+                               TypeRange rettys) const {
+    return builder.create<OpTy>(original->getLoc(), rettys, remappedOperands,
+                                original->getAttrs());
+  }
+};
+
 template <typename OpTy>
 class AutoDiffReduceFwd
     : public AutoDiffOpInterface::ExternalModel<AutoDiffReduceFwd<OpTy>, OpTy> {
@@ -96,12 +109,38 @@ class AutoDiffReduceFwd
   LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
                                          MGradientUtils *gutils) const {
     auto red = cast<OpTy>(orig);
-    if (!isEligibleForCompactPrint(red))
+    if (!isEligibleForCompactPrint(red)) {
+      orig->emitError() << "Unsupported operation in reduction autodiff(1): "
+                        << *orig << "\n";
       return failure();
+    }
 
     Operation &innerOp = red.getBody().front().front();
-    if (!isa<AddOp>(innerOp))
+
+    if (isa<MaxOp>(innerOp) || isa<MinOp>(innerOp)) {
+      llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
+      llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
+      for (auto operand : orig->getOpResults()) {
+        if (!gutils->isConstantValue(operand)) {
+          operandPositionsToShadow.insert(
+              red.getInitValues().getBeginOperandIndex() +
+              operand.getResultNumber());
+          operandPositionsToShadow.insert(
+              red.getInputs().getBeginOperandIndex() +
+              operand.getResultNumber());
+          resultPositionsToShadow.insert(operand.getResultNumber());
+        }
+      }
+      return mlir::enzyme::detail::controlFlowForwardHandler(
+          orig, builder, gutils, operandPositionsToShadow,
+          resultPositionsToShadow);
+    }
+
+    if (!isa<AddOp>(innerOp)) {
+      orig->emitError() << "Unsupported operation in reduction autodiff(2): "
+                        << *orig << "\n";
       return failure();
+    }
 
     Operation *primal = gutils->getNewFromOriginal(orig);
@@ -122,19 +161,29 @@ class AutoDiffReduceFwd
         continue;
       }
     }
-      orig->emitWarning() << "Unsupported constant arg to reduce forward "
-                             "handler(opidx="
-                          << operand.getOperandNumber()
-                          << ", op=" << operand.get() << ")\n";
+      orig->emitError() << "Unsupported constant arg to reduce forward "
+                           "handler(opidx="
+                        << operand.getOperandNumber()
+                        << ", op=" << operand.get() << ")\n";
       return failure();
     }
     Operation *shadow = builder.clone(*orig, map);
     Value shadowRes = shadow->getResult(0);
 
+    auto invAdd = gutils->invertedPointers.lookup(innerOp.getResult(0));
+    gutils->invertedPointers.erase(innerOp.getResult(0));
+    gutils->erase(invAdd.getDefiningOp());
+    BitVector baToErase(cast<OpTy>(primal).getBody().front().getNumArguments());
+    for (auto ba : red.getBody().front().getArguments()) {
+      auto invBA = cast<BlockArgument>(gutils->invertedPointers.lookup(ba));
+      gutils->invertedPointers.erase(ba);
+      baToErase.set(invBA.getArgNumber());
+    }
+    cast<OpTy>(primal).getBody().front().eraseArguments(baToErase);
+
     gutils->setDiffe(orig->getResult(0), shadowRes, builder);
     gutils->eraseIfUnused(orig);
-
     return success();
   }
 };
@@ -147,5 +196,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
       +[](MLIRContext *context, stablehlo::StablehloDialect *) {
         registerInterfaces(context);
         ReduceOp::attachInterface<AutoDiffReduceFwd<ReduceOp>>(*context);
+        ReduceOp::attachInterface<AutoDiffReduceCF<ReduceOp>>(*context);
       });
 }
diff --git a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
index 100e3179..40ee8ce0 100644
--- a/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
+++ b/src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
@@ -8,6 +8,8 @@ class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnly
 
 class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"stablehlo", opName_, impl_>;
 
-class HLOConstantFP<string val> : ConstantFP<val, "stablehlo", "ConstantOp">;
+class HLOConstantFP<string val, string type_ = ""> : ConstantFP<val, "stablehlo", "ConstantOp", type_>;
+
+class HLORegionTerminatorOp<string m> : RegionTerminatorOp<"stablehlo", m>;
 
 include "HLODerivatives.td"
diff --git a/src/enzyme_ad/jax/Implementations/XLADerivatives.h b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
index 5df9a769..122ea455 100644
--- a/src/enzyme_ad/jax/Implementations/XLADerivatives.h
+++ b/src/enzyme_ad/jax/Implementations/XLADerivatives.h
@@ -1,3 +1,4 @@
+#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/DialectRegistry.h"
@@ -15,16 +16,3 @@ registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
 }
 } // namespace enzyme
 } // namespace mlir
-
-static inline mlir::DenseFPElementsAttr getTensorAttr(mlir::Type type,
-                                                      llvm::StringRef value) {
-  using namespace mlir;
-  auto T = cast<TensorType>(type);
-  size_t num = 1;
-  for (auto sz : T.getShape())
-    num *= sz;
-  APFloat apvalue(T.getElementType().cast<FloatType>().getFloatSemantics(),
-                  value);
-  SmallVector<APFloat> supportedValues(num, apvalue);
-  return DenseFPElementsAttr::get(type.cast<ShapedType>(), supportedValues);
-}
diff --git a/src/enzyme_ad/jax/__init__.py b/src/enzyme_ad/jax/__init__.py
index 09df85e5..c52415aa 100644
--- a/src/enzyme_ad/jax/__init__.py
+++ b/src/enzyme_ad/jax/__init__.py
@@ -3,4 +3,5 @@
     enzyme_jax_ir,
     NewXLAPipeline,
     OldXLAPipeline,
+    JaXPipeline,
 )
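The AutoDiffReduceCF interface and the max/min branch registered above shadow a reduce's inputs, init values, and results whenever a result is active, then defer to the generic control-flow forward handler. The behavior this is meant to reproduce, sketched with stock JAX for a max-reduction (no Enzyme involved):

    import jax
    import jax.numpy as jnp

    # Forward mode through a max-reduction: only the tangent entry at the
    # argmax survives, which is what shadowing the reduce's operands yields.
    x = jnp.array([1.0, 7.0, 3.0])
    dx = jnp.array([10.0, 20.0, 30.0])
    primal, tangent = jax.jvp(jnp.max, (x,), (dx,))
    print(primal, tangent)  # 7.0 20.0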
"Implementations/XLADerivatives.h" -#include "mlir-c/IR.h" -#include "mlir/CAPI/IR.h" -#include "mlir/lib/Bindings/Python/IRModule.h" - #include "mlir/Dialect/Func/Extensions/InlinerExtension.h" +#include "pybind11/stl.h" + void prepareRegistry(mlir::DialectRegistry ®istry) { mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry); mlir::enzyme::registerXLAAutoDiffInterfaces(registry); @@ -49,62 +47,72 @@ void prepareRegistry(mlir::DialectRegistry ®istry) { /// suffix in `lastUsedID`. static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName, unsigned &lastUsedID, - mlir::ModuleOp module) { + std::set &oldsym, + mlir::MLIRContext *ctx) { using namespace llvm; using namespace mlir; SmallString<64> newSymName(oldSymName); newSymName.push_back('_'); - MLIRContext *ctx = module->getContext(); - while (true) { - auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID)); - if (!SymbolTable::lookupSymbolIn(module, possible)) - return possible; + auto possible = newSymName + Twine(++lastUsedID); + if (!oldsym.count(possible.str())) { + oldsym.insert(possible.str()); + return StringAttr::get(ctx, possible); + } } } /// Checks if a symbol with the same name as `op` already exists in `source`. /// If so, renames `op` and updates all its references in `target`. -static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op, - mlir::ModuleOp target, - mlir::ModuleOp source, - unsigned &lastUsedID) { +static mlir::LogicalResult +updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp target, + std::set &oldsyms, unsigned &lastUsedID) { using namespace llvm; using namespace mlir; - if (!SymbolTable::lookupSymbolIn(source, op.getName())) + + auto opName = op.getName().str(); + + if (!oldsyms.count(opName)) { + oldsyms.insert(opName); return success(); + } - StringRef oldSymName = op.getName(); - StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target); + StringAttr newSymName = + renameSymbol(opName, lastUsedID, oldsyms, target.getContext()); if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target))) return op.emitError("unable to update all symbol uses for ") - << oldSymName << " to " << newSymName; + << opName << " to " << newSymName; SymbolTable::setSymbolName(op, newSymName); return success(); } -MlirOperation run_pass_pipeline(llvm::StringRef mlir, - const std::string &pass_pipeline) { +std::pair +run_pass_pipeline(const std::vector &oldsym_vec, + const std::string &mlir, const std::string &pass_pipeline) { using namespace llvm; using namespace mlir; - auto ins = mlir::python::PyThreadContextEntry::getTopOfStack() - ->getDefaultInsertionPoint(); - auto blk = unwrap(ins->getBlock().get()); + std::set oldsyms(oldsym_vec.begin(), oldsym_vec.end()); - auto oldMod = blk->getParent()->getParentOfType(); // Parse MLIR. 
mlir::DialectRegistry registry; prepareRegistry(registry); - oldMod->getContext()->appendDialectRegistry(registry); - mlir::ParserConfig parser_config(oldMod->getContext()); + MLIRContext context(registry); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + context.loadDialect(); + mlir::ParserConfig parser_config(&context); mlir::OwningOpRef parsed_module = mlir::parseSourceString(mlir, parser_config); + if (!parsed_module) { + throw pybind11::value_error("Failed to parse module"); + } - mlir::PassManager pm(oldMod->getContext()); + mlir::PassManager pm(&context); std::string error_message; llvm::raw_string_ostream error_stream(error_message); @@ -122,10 +130,6 @@ MlirOperation run_pass_pipeline(llvm::StringRef mlir, unsigned lastUsedID = 0; - OpBuilder combinedModuleBuilder(oldMod->getContext()); - combinedModuleBuilder.setInsertionPointToStart(oldMod.getBody()); - - Operation *resultOp = nullptr; for (auto &op : *parsed_module->getBody()) { auto symbolOp = dyn_cast(op); if (!symbolOp) @@ -133,7 +137,7 @@ MlirOperation run_pass_pipeline(llvm::StringRef mlir, StringRef oldSymName = symbolOp.getName(); - if (failed(updateSymbolAndAllUses(symbolOp, *parsed_module, oldMod, + if (failed(updateSymbolAndAllUses(symbolOp, *parsed_module, oldsyms, lastUsedID))) throw pybind11::value_error("failed to update all uses"); @@ -143,16 +147,16 @@ MlirOperation run_pass_pipeline(llvm::StringRef mlir, entryfn = newSymName; } } - Operation *const cloned = op.clone(); if (newSymName == entryfn) { - resultOp = cloned; + SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private); } - combinedModuleBuilder.insert(cloned); } - SymbolTable::setSymbolVisibility(resultOp, SymbolTable::Visibility::Private); + std::string output; + llvm::raw_string_ostream ss(output); + ss << *parsed_module; - return wrap(resultOp); + return std::make_pair(entryfn.str(), ss.str()); } // Compile an MHLO module given as a string to LLVM IR using XLA. @@ -171,6 +175,9 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output, mlir::ParserConfig parser_config(&context); mlir::OwningOpRef parsed_module = mlir::parseSourceString(mhlo_text, parser_config); + if (!parsed_module) { + throw pybind11::value_error("Failed to parse module"); + } llvm::StringRef cur_pipeline = pass_pipeline; diff --git a/src/enzyme_ad/jax/compile_with_xla.h b/src/enzyme_ad/jax/compile_with_xla.h index 532c6c85..d7deabfe 100644 --- a/src/enzyme_ad/jax/compile_with_xla.h +++ b/src/enzyme_ad/jax/compile_with_xla.h @@ -2,11 +2,15 @@ #include "xla/client/local_client.h" #include "llvm/ADT/StringRef.h" #include -#include "mlir-c/IR.h" + +#include + // Compile an MHLO module given as a string to LLVM IR using XLA. 
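run_pass_pipeline now resolves name collisions against a plain set of symbol names supplied by the caller instead of walking a live module's symbol table. A hypothetical pure-Python model of renameSymbol's scheme (illustrative only; rename_symbol is not a real API in this repository):

    # Append "_<id>" with a shared counter until the name is free, then claim it.
    def rename_symbol(old_name, last_used_id, oldsyms):
        while True:
            last_used_id += 1
            candidate = f"{old_name}_{last_used_id}"
            if candidate not in oldsyms:
                oldsyms.add(candidate)
                return candidate, last_used_id

    syms = {"main", "main_1"}
    print(rename_symbol("main", 0, syms))  # ('main_2', 2): 'main_1' was taken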
diff --git a/src/enzyme_ad/jax/compile_with_xla.h b/src/enzyme_ad/jax/compile_with_xla.h
index 532c6c85..d7deabfe 100644
--- a/src/enzyme_ad/jax/compile_with_xla.h
+++ b/src/enzyme_ad/jax/compile_with_xla.h
@@ -2,11 +2,15 @@
 #include "xla/client/local_client.h"
 #include "llvm/ADT/StringRef.h"
 #include <string>
-#include "mlir-c/IR.h"
+
+#include <vector>
+
 
 // Compile an MHLO module given as a string to LLVM IR using XLA.
 std::unique_ptr<xla::LocalExecutable>
 compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
                               bool xla_runtime,
                               const std::string &pass_pipeline);
-MlirOperation run_pass_pipeline(llvm::StringRef mlir, const std::string &pass_pipeline);
+std::pair<std::string, std::string>
+run_pass_pipeline(const std::vector<std::string> &oldsyms,
+                  const std::string &mlir, const std::string &pass_pipeline);
diff --git a/src/enzyme_ad/jax/enzyme_call.cc b/src/enzyme_ad/jax/enzyme_call.cc
index 8f5c007e..9898ffbe 100644
--- a/src/enzyme_ad/jax/enzyme_call.cc
+++ b/src/enzyme_ad/jax/enzyme_call.cc
@@ -1146,7 +1146,31 @@ PYBIND11_MODULE(enzyme_call, m) {
             "xla._CUSTOM_CALL_TARGET");
   });
 
-  m.def("run_pass_pipeline", run_pass_pipeline);
+  m.def("run_pass_pipeline",
+        [](pybind11::object pyoldsyms, const std::string &mlir,
+           const std::string &pass_pipeline) {
+          auto pyargv = pyoldsyms.ptr();
+          std::vector<std::string> oldsyms;
+          assert(PySequence_Check(pyargv));
+          auto sz = PySequence_Size(pyargv);
+          for (Py_ssize_t i = 0; i < sz; ++i) {
+            PyObject *item = PySequence_GetItem(pyargv, i);
+#if PY_VERSION_HEX < 0x03000000
+            auto argv = PyString_AsString(item);
+#else
+            auto argv = PyUnicode_AsUTF8(item);
+#endif
+            Py_DECREF(item);
+            assert(argv);
+            oldsyms.emplace_back(argv);
+#if PY_VERSION_HEX < 0x03000000
+            free(argv);
+#else
+            // should not free py3+
+#endif
+          }
+          return run_pass_pipeline(oldsyms, mlir, pass_pipeline);
+        });
 
   m.def("compile_mhlo_to_llvm_with_xla",
         [](const std::string &mhlo_text, bool xla_runtime,
diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py
index 8c14d9ed..3376702a 100644
--- a/src/enzyme_ad/jax/primitives.py
+++ b/src/enzyme_ad/jax/primitives.py
@@ -75,7 +75,6 @@ class NewXLAPipeline:
     def __init__(self, passes=None, mlirad=False):
         if passes is None:
             passes = """
-            print,
             stablehlo-legalize-to-hlo,
             inline{default-pipeline=canonicalize max-iterations=4},
             expand-hlo-tuples{entry-function=main},
@@ -487,7 +486,7 @@ def _enzyme_primal_lowering(
     pass_pipeline = pipeline_options.pass_pipeline()
     if lang == LANG_MHLO:
         (in_tree, in_idx_map, mfunc) = source
-        
+
         orig_shapes = []
         seen = {}
         for i, shape in enumerate(in_shapes):
@@ -511,10 +510,21 @@ def _enzyme_primal_lowering(
             shape for (i, shape) in enumerate(in_shapes) if in_idx_map[i] in kept
         ]
         if pipeline_options.stablehlo_inject():
-            fn = enzyme_call.run_pass_pipeline(source, pass_pipeline)
-            print(fn)
-            results = func.CallOp(fn.name, out_types, in_args)
-            print(results)
+            ins = ir.InsertionPoint.current
+            mod = ins.block.region.owner.parent
+            fns = []
+            for f in mod.regions[0].blocks[0]:
+                fns.append(f.sym_name.value)
+
+            name, nmod = enzyme_call.run_pass_pipeline(fns, source, pass_pipeline)
+            nmod = ir.Module.parse(nmod)
+            fn = None
+            for f in nmod.body:
+                mod.regions[0].blocks[0].append(f)
+                if f.sym_name.value == name:
+                    fn = f
+            results = func.CallOp(fn, list(in_args))
+            results = results.results
             return results
 
     argv = argv + ("-resource-dir", resource_dir()) + cflags()
@@ -835,14 +845,18 @@ def make_zero(tan, prim):
     if pipeline_options.mlir_ad() and kwargs["lang"] == LANG_MHLO:
         act_tup = ",".join(["enzyme_dup" for a in arg_primals])
         newpasses = (
-            "inline{default-pipeline=canonicalize max-iterations=4},"
-            + "func.func(stablehlo-aggressive-simplification),cse,print,enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys="
+            "inline{default-pipeline=canonicalize max-iterations=4},"
+            + "func.func(stablehlo-aggressive-simplification),cse,enzyme-wrap{infn=main outfn= retTy=enzyme_dup argTys="
             + act_tup
            + " mode=ForwardMode},"
-            + "arith-raise{stablehlo=true}, func.func(stablehlo-aggressive-simplification), cse, canonicalize, print,"
-            + pipeline_options.pass_pipeline()
+            + "arith-raise{stablehlo=true}, func.func(stablehlo-aggressive-simplification), cse, canonicalize"
         )
-        pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
+        if pipeline_options.pass_pipeline() != "":
+            newpasses = newpasses + "," + pipeline_options.pass_pipeline()
+        if pipeline_options.stablehlo_inject():
+            pipeline_options = JaXPipeline(newpasses)
+        else:
+            pipeline_options = NewXLAPipeline(newpasses, pipeline_options.mlir_ad())
         outshapes2 = []
         for o in kwargs["out_shapes"]:
             outshapes2.append(o)
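The stablehlo_inject branch above hands run_pass_pipeline the symbol names already present in the caller's module, then splices the returned module's functions back in and calls the returned entry symbol. A sketch of that contract in isolation, assuming enzyme_call is built from this patch and importable from enzyme_ad.jax (module text and pipeline string are illustrative):

    from enzyme_ad.jax import enzyme_call  # assumed import path

    existing = ["main"]  # symbols already in the caller's module
    source = "module { func.func @main() { return } }"
    name, module_text = enzyme_call.run_pass_pipeline(existing, source, "canonicalize")
    # name: the (possibly renamed) entry symbol, e.g. "main_1" on collision;
    # module_text: the transformed module, ready for ir.Module.parse(...).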
diff --git a/test/bench_vs_xla.py b/test/bench_vs_xla.py
index 865edad5..53adb168 100644
--- a/test/bench_vs_xla.py
+++ b/test/bench_vs_xla.py
@@ -1,6 +1,6 @@
 import jax
 import jax.numpy as jnp
-from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline
+from enzyme_ad.jax import enzyme_jax_ir, NewXLAPipeline, OldXLAPipeline, JaXPipeline
 from absl.testing import absltest
 import timeit
 
@@ -8,13 +8,14 @@
 number = 1000
 
 AllPipelines = [
+    ("JaXPipeline", JaXPipeline()),
     ("NewXLAMLIR", NewXLAPipeline(mlirad=True)),
     ("NewXLA", NewXLAPipeline()),
     ("OldXLA", OldXLAPipeline()),
 ]
-PrimalPipelines = AllPipelines[1:]
+PrimalPipelines = AllPipelines
 FwdPipelines = AllPipelines
-RevPipelines = AllPipelines[1:]
+RevPipelines = AllPipelines[2:]
 
 
 # @jax.jit
@@ -71,7 +72,8 @@ def harness(self, name, in_fn, ins, dins, douts):
                     "fn": rfn_jax,
                 }
                 | primalins,
-            ).timeit(number) / number,
+            ).timeit(number)
+            / number,
         )
 
         fwd_jax = jax.jit(splatjvp(rfn_jax))
@@ -97,7 +99,8 @@ def harness(self, name, in_fn, ins, dins, douts):
                     "fwd": fwd_jax,
                 }
                 | fwdins,
-            ).timeit(number) / number,
+            ).timeit(number)
+            / number,
         )
 
         assert len(douts) == 1
@@ -123,7 +126,8 @@ def harness(self, name, in_fn, ins, dins, douts):
                     "rev": rev_jax,
                 }
                 | revins,
-            ).timeit(number) / number,
+            ).timeit(number)
+            / number,
         )
 
         for name, pipeline in AllPipelines:
@@ -145,7 +149,8 @@ def harness(self, name, in_fn, ins, dins, douts):
                        "fn": rfn_enzyme,
                     }
                     | primalins,
-                ).timeit(number) / number,
+                ).timeit(number)
+                / number,
             )
 
             if (name, pipeline) in FwdPipelines:
@@ -171,7 +176,8 @@ def harness(self, name, in_fn, ins, dins, douts):
                         "fwd": fwd_enzyme,
                     }
                     | fwdins,
-                ).timeit(number) / number,
+                ).timeit(number)
+                / number,
             )
 
             if (name, pipeline) in RevPipelines:
@@ -194,7 +200,8 @@ def harness(self, name, in_fn, ins, dins, douts):
                         "rev": rev_enzyme,
                     }
                     | revins,
-                ).timeit(number) / number,
+                ).timeit(number)
+                / number,
             )
diff --git a/test/lit_tests/ir.pyt b/test/lit_tests/ir.pyt
index 6971756c..c9849a84 100644
--- a/test/lit_tests/ir.pyt
+++ b/test/lit_tests/ir.pyt
@@ -46,7 +46,7 @@ print(fwdmode.lower(ones, twos, ones, twos).compiler_ir(dialect="stablehlo"))
 # CHECK: module @jit_fwdmode attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
 # CHECK-NEXT: func.func public @main(%arg0: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg1: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg2: tensor<2x3xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}, %arg3: tensor<5x7xf32> {mhlo.layout_mode = "default", mhlo.sharding = "{replicated}"}) -> (tensor<6x9xf32> {jax.result_info = "[0][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[0][1]", mhlo.layout_mode = "default"}, tensor<6x9xf32> {jax.result_info = "[1][0]", mhlo.layout_mode = "default"}, tensor<4x6xf32> {jax.result_info = "[1][1]", mhlo.layout_mode = "default"}) {
 # CHECK-NEXT: %0 = stablehlo.constant dense<1> : tensor<1xi64>
-# CHECK-NEXT %1:4 = stablehlo.custom_call @jaxzyme.fwd(%0, %arg0, %arg2, %arg1, %arg3) : (tensor<1xi64>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<5x7xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<6x9xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
+# CHECK-NEXT: %1:4 = stablehlo.custom_call @jaxzyme.fwd(%0, %arg0, %arg2, %arg1, %arg3) : (tensor<1xi64>, tensor<2x3xf32>, tensor<2x3xf32>, tensor<5x7xf32>, tensor<5x7xf32>) -> (tensor<6x9xf32>, tensor<6x9xf32>, tensor<4x6xf32>, tensor<4x6xf32>)
 # CHECK-NEXT: return %1#0, %1#2, %1#1, %1#3 : tensor<6x9xf32>, tensor<4x6xf32>, tensor<6x9xf32>, tensor<4x6xf32>
 # CHECK-NEXT: }
 # CHECK-NEXT: }
diff --git a/test/llama.py b/test/llama.py
index a0ffb356..90eb1606 100644
--- a/test/llama.py
+++ b/test/llama.py
@@ -8,6 +8,7 @@
 argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
 
+
 def rmsnorm(x, weight):
     ss = 1 / jnp.sqrt(x.dot(x) / x.shape[0] + 1e-5)
     return weight * x * ss
@@ -31,6 +32,9 @@ def silu(x):
 asserts = True
 
 pipeline = enzyme_jax.NewXLAPipeline(mlirad=True)
+pipeline = enzyme_jax.JaXPipeline()
+pipeline = enzyme_jax.NewXLAPipeline(mlirad=False)
+
 
 def forward(x, config, weights, key_cache, value_cache):
     pos = key_cache.shape[1]
@@ -62,7 +66,23 @@ def forward(x, config, weights, key_cache, value_cache):
     wo = weights["wo"]
     if asserts:
         if wo.shape != (n_layers, dim, dim):
-            print(wo.shape, weights, (n_layers, dim, kv_dim, kv_mul, head_size, hidden_dim, n_kv_heads, vocab_size, n_heads, seq_len, n_layers))
+            print(
+                wo.shape,
+                weights,
+                (
+                    n_layers,
+                    dim,
+                    kv_dim,
+                    kv_mul,
+                    head_size,
+                    hidden_dim,
+                    n_kv_heads,
+                    vocab_size,
+                    n_heads,
+                    seq_len,
+                    n_layers,
+                ),
+            )
         assert wo.shape == (n_layers, dim, dim)
     rms_ffn_weight = weights["rms_ffn_weight"]
     if asserts:
diff --git a/test/test.py b/test/test.py
index e90198d9..0dcc4897 100644
--- a/test/test.py
+++ b/test/test.py
@@ -5,6 +5,7 @@
 argv = ("-I/usr/include/c++/11", "-I/usr/include/x86_64-linux-gnu/c++/11")
 
+
 class EnzymeJax(absltest.TestCase):
     def test_custom_cpp_kernel(self):
         @jax.jit
@@ -30,7 +31,8 @@ def do_something(ones):
         }
         }
         """,
-            fn="myfn", argv=argv
+            fn="myfn",
+            argv=argv,
         )
         c = cpp_call(
             a,
@@ -40,7 +42,8 @@ def do_something(ones):
         template<typename T1, typename T2>
         void f(T1& out0, const T2& in1) {
             out0 = 56.0f;
         }
-        """, argv=argv
+        """,
+            argv=argv,
         )
         return a, b, c
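Taken together, these changes make JaXPipeline usable end to end: forward-mode AD runs through enzyme-wrap on StableHLO and the differentiated functions are injected back into the caller's module. A minimal sketch of that usage, assuming the pipeline_options keyword matches how the tests above pass their pipelines:

    import jax
    import jax.numpy as jnp
    from enzyme_ad.jax import enzyme_jax_ir, JaXPipeline

    @enzyme_jax_ir(pipeline_options=JaXPipeline())
    def fn(x, y):
        return jnp.maximum(x, y).sum()

    x, y = jnp.ones((3,)), 2 * jnp.ones((3,))
    # Forward mode is handled by the enzyme-wrap pass configured above.
    primal, tangent = jax.jvp(fn, (x, y), (x, y))
    print(primal, tangent)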