Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 18, 2024
1 parent 5cd6ccc commit 6b6abdf
Show file tree
Hide file tree
Showing 18 changed files with 229 additions and 95 deletions.
4 changes: 2 additions & 2 deletions WORKSPACE
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ load("@rules_python//python/pip_install:repositories.bzl", "pip_install_dependen

pip_install_dependencies()

ENZYME_COMMIT = "4ccab29dc691cb43d250a7c5ca612c3ff9cd23e3"
ENZYME_SHA256 = ""
ENZYME_COMMIT = "1e1c0eb1c9b4ae3fa6b0acc2394e305b3fc4e042"
ENZYME_SHA256 = "07eb58bb2b4d877f940b88b87bcb2a8e9f02a9320a23b31697c5a4105cb6f031"

http_archive(
name = "enzyme",
Expand Down
7 changes: 2 additions & 5 deletions src/enzyme_ad/jax/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -186,11 +186,8 @@ pybind_library(
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:Parser",

"@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps",
"@llvm-project//mlir:MLIRBindingsPythonCore",

# EnzymeMLIR

# EnzymeMLIR
"@enzyme//:EnzymeMLIR",

# Mosaic
Expand Down
4 changes: 2 additions & 2 deletions src/enzyme_ad/jax/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*
string dialect = dialect_;
}


class ConstantFP<string val, string dialect_, string op_> : Operation</*primal*/0, /*shadow*/0> {
class ConstantFP<string val, string dialect_, string op_, string type_=""> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
string dialect = dialect_;
string opName = op_;
string type = type_;
}

def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {
Expand Down
14 changes: 14 additions & 0 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ def Exp : HLOInst<"ExpOp">;

def Dot : HLOInst<"DotGeneralOp">;

def Compare : HLOInst<"CompareOp">;
def Select : HLOInst<"SelectOp">;


def CheckedMul : HLOInst<"MulOp">;
def CheckedDiv : HLOInst<"DivOp">;
Expand Down Expand Up @@ -80,6 +83,15 @@ def : HLODerivative<"SqrtOp", (Op $x),
]
>;

def LT : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "ComparisonDirection::LT">;
def : HLODerivative<"MaxOp", (Op $x, $y),
[
(Select (Compare $x, $y, (LT)), (HLOConstantFP<"0"> $x), (DiffeRet)),
(Select (Compare $x, $y, (LT)), (DiffeRet), (HLOConstantFP<"0"> $x))
],
(Select (Compare $x, $y, (LT)), (SelectIfActive $y, (Shadow $y), (HLOConstantFP<"0"> $y)), (SelectIfActive $x, (Shadow $x), (HLOConstantFP<"0"> $x)))
>;

def : HLOReadOnlyIdentityOp<"ReshapeOp">;
def : HLOReadOnlyIdentityOp<"SliceOp">;
def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;
Expand All @@ -97,3 +109,5 @@ def : HLODerivative<"DotGeneralOp", (Op $lhs, $rhs),
],
(Add (SelectIfActive $lhs, (Dot (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (Dot (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)))
>;

def : HLORegionTerminatorOp<"ReturnOp">;
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

using namespace mlir;
using namespace mlir::enzyme;
using namespace mlir::mhlo;

namespace {
#include "src/enzyme_ad/jax/Implementations/MHLODerivatives.inc"
Expand Down
4 changes: 3 additions & 1 deletion src/enzyme_ad/jax/Implementations/MHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnly

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"mhlo", opName_, impl_>;

class HLOConstantFP<string m> : ConstantFP<m, "mhlo", "ConstantOp">;
class HLOConstantFP<string m> : ConstantFP<m, "mhlo", "ConstantOp", "mlir::ElementsAttr">;

class HLORegionTerminatorOp<string m> : RegionTerminatorOp<"mhlo", m>;

include "HLODerivatives.td"
Original file line number Diff line number Diff line change
Expand Up @@ -89,19 +89,58 @@ static bool isEligibleForCompactPrint(ReduceOp op) {
return llvm::equal(innerOp.getResults(), retOp.getOperands());
}

template <typename OpTy>
class AutoDiffReduceCF : public ControlFlowAutoDiffOpInterface::ExternalModel<
AutoDiffReduceCF<OpTy>, OpTy> {
public:
Operation *createWithShadows(Operation *op, OpBuilder &builder,
MGradientUtils *gutils, Operation *original,
ValueRange remappedOperands,
TypeRange rettys) const {
return builder.create<OpTy>(original->getLoc(), rettys, remappedOperands,
original->getAttrs());
}
};

template <typename OpTy>
class AutoDiffReduceFwd
: public AutoDiffOpInterface::ExternalModel<AutoDiffReduceFwd<OpTy>, OpTy> {
public:
LogicalResult createForwardModeTangent(Operation *orig, OpBuilder &builder,
MGradientUtils *gutils) const {
auto red = cast<OpTy>(orig);
if (!isEligibleForCompactPrint(red))
if (!isEligibleForCompactPrint(red)) {
orig->emitError() << "Unsupported operation in reduction autodiff(1): "
<< *orig << "\n";
return failure();
}

Operation &innerOp = red.getBody().front().front();
if (!isa<AddOp>(innerOp))

if (isa<MaxOp>(innerOp) || isa<MinOp>(innerOp)) {
llvm::SmallDenseSet<unsigned> operandPositionsToShadow;
llvm::SmallDenseSet<unsigned> resultPositionsToShadow;
for (auto operand : orig->getOpResults()) {
if (!gutils->isConstantValue(operand)) {
operandPositionsToShadow.insert(
red.getInitValues().getBeginOperandIndex() +
operand.getResultNumber());
operandPositionsToShadow.insert(
red.getInputs().getBeginOperandIndex() +
operand.getResultNumber());
resultPositionsToShadow.insert(operand.getResultNumber());
}
}
return mlir::enzyme::detail::controlFlowForwardHandler(
orig, builder, gutils, operandPositionsToShadow,
resultPositionsToShadow);
}

if (!isa<AddOp>(innerOp)) {
orig->emitError() << "Unsupported operation in reduction autodiff(2): "
<< *orig << "\n";
return failure();
}

Operation *primal = gutils->getNewFromOriginal(orig);

Expand All @@ -122,19 +161,29 @@ class AutoDiffReduceFwd
continue;
}
}
orig->emitWarning() << "Unsupported constant arg to reduce forward "
"handler(opidx="
<< operand.getOperandNumber()
<< ", op=" << operand.get() << ")\n";
orig->emitError() << "Unsupported constant arg to reduce forward "
"handler(opidx="
<< operand.getOperandNumber()
<< ", op=" << operand.get() << ")\n";
return failure();
}
Operation *shadow = builder.clone(*orig, map);

Value shadowRes = shadow->getResult(0);

auto invAdd = gutils->invertedPointers.lookup(innerOp.getResult(0));
gutils->invertedPointers.erase(innerOp.getResult(0));
gutils->erase(invAdd.getDefiningOp());
BitVector baToErase(cast<OpTy>(primal).getBody().front().getNumArguments());
for (auto ba : red.getBody().front().getArguments()) {
auto invBA = cast<BlockArgument>(gutils->invertedPointers.lookup(ba));
gutils->invertedPointers.erase(ba);
baToErase.set(invBA.getArgNumber());
}
cast<OpTy>(primal).getBody().front().eraseArguments(baToErase);

gutils->setDiffe(orig->getResult(0), shadowRes, builder);
gutils->eraseIfUnused(orig);

return success();
}
};
Expand All @@ -147,5 +196,6 @@ void mlir::enzyme::registerStableHLODialectAutoDiffInterface(
+[](MLIRContext *context, stablehlo::StablehloDialect *) {
registerInterfaces(context);
ReduceOp::attachInterface<AutoDiffReduceFwd<ReduceOp>>(*context);
ReduceOp::attachInterface<AutoDiffReduceCF<ReduceOp>>(*context);
});
}
4 changes: 3 additions & 1 deletion src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class HLOReadOnlyIdentityOp<string opName_, list<int> ptrargs_ = [0]> : ReadOnly

class HLOControlFlowOp<string opName_, string impl_> : ControlFlowOp<"stablehlo", opName_, impl_>;

class HLOConstantFP<string m> : ConstantFP<m, "stablehlo", "ConstantOp">;
class HLOConstantFP<string m> : ConstantFP<m, "stablehlo", "ConstantOp", "mlir::ElementsAttr">;

class HLORegionTerminatorOp<string m> : RegionTerminatorOp<"stablehlo", m>;

include "HLODerivatives.td"
14 changes: 1 addition & 13 deletions src/enzyme_ad/jax/Implementations/XLADerivatives.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/DialectRegistry.h"
Expand All @@ -15,16 +16,3 @@ registerXLAAutoDiffInterfaces(mlir::DialectRegistry &registry) {
}
} // namespace enzyme
} // namespace mlir

static inline mlir::DenseFPElementsAttr getTensorAttr(mlir::Type type,
llvm::StringRef value) {
using namespace mlir;
auto T = cast<TensorType>(type);
size_t num = 1;
for (auto sz : T.getShape())
num *= sz;
APFloat apvalue(T.getElementType().cast<FloatType>().getFloatSemantics(),
value);
SmallVector<APFloat> supportedValues(num, apvalue);
return DenseFPElementsAttr::get(type.cast<ShapedType>(), supportedValues);
}
1 change: 1 addition & 0 deletions src/enzyme_ad/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,5 @@
enzyme_jax_ir,
NewXLAPipeline,
OldXLAPipeline,
JaXPipeline,
)
81 changes: 44 additions & 37 deletions src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,10 @@
#include "Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Implementations/XLADerivatives.h"

#include "mlir-c/IR.h"
#include "mlir/CAPI/IR.h"
#include "mlir/lib/Bindings/Python/IRModule.h"

#include "mlir/Dialect/Func/Extensions/InlinerExtension.h"

#include "pybind11/stl.h"

void prepareRegistry(mlir::DialectRegistry &registry) {
mlir::enzyme::registerCoreDialectAutodiffInterfaces(registry);
mlir::enzyme::registerXLAAutoDiffInterfaces(registry);
Expand All @@ -49,62 +47,72 @@ void prepareRegistry(mlir::DialectRegistry &registry) {
/// suffix in `lastUsedID`.
static mlir::StringAttr renameSymbol(llvm::StringRef oldSymName,
unsigned &lastUsedID,
mlir::ModuleOp module) {
std::set<std::string> &oldsym,
mlir::MLIRContext *ctx) {
using namespace llvm;
using namespace mlir;
SmallString<64> newSymName(oldSymName);
newSymName.push_back('_');

MLIRContext *ctx = module->getContext();

while (true) {
auto possible = StringAttr::get(ctx, newSymName + Twine(++lastUsedID));
if (!SymbolTable::lookupSymbolIn(module, possible))
return possible;
auto possible = newSymName + Twine(++lastUsedID);
if (!oldsym.count(possible.str())) {
oldsym.insert(possible.str());
return StringAttr::get(ctx, possible);
}
}
}

/// Checks if a symbol with the same name as `op` already exists in `source`.
/// If so, renames `op` and updates all its references in `target`.
static mlir::LogicalResult updateSymbolAndAllUses(mlir::SymbolOpInterface op,
mlir::ModuleOp target,
mlir::ModuleOp source,
unsigned &lastUsedID) {
static mlir::LogicalResult
updateSymbolAndAllUses(mlir::SymbolOpInterface op, mlir::ModuleOp target,
std::set<std::string> &oldsyms, unsigned &lastUsedID) {
using namespace llvm;
using namespace mlir;
if (!SymbolTable::lookupSymbolIn(source, op.getName()))

auto opName = op.getName().str();

if (!oldsyms.count(opName)) {
oldsyms.insert(opName);
return success();
}

StringRef oldSymName = op.getName();
StringAttr newSymName = renameSymbol(oldSymName, lastUsedID, target);
StringAttr newSymName =
renameSymbol(opName, lastUsedID, oldsyms, target.getContext());

if (failed(SymbolTable::replaceAllSymbolUses(op, newSymName, target)))
return op.emitError("unable to update all symbol uses for ")
<< oldSymName << " to " << newSymName;
<< opName << " to " << newSymName;

SymbolTable::setSymbolName(op, newSymName);
return success();
}

MlirOperation run_pass_pipeline(llvm::StringRef mlir,
const std::string &pass_pipeline) {
std::pair<std::string, std::string>
run_pass_pipeline(const std::vector<std::string> &oldsym_vec,
const std::string &mlir, const std::string &pass_pipeline) {
using namespace llvm;
using namespace mlir;

auto ins = mlir::python::PyThreadContextEntry::getTopOfStack()
->getDefaultInsertionPoint();
auto blk = unwrap(ins->getBlock().get());
std::set<std::string> oldsyms(oldsym_vec.begin(), oldsym_vec.end());

auto oldMod = blk->getParent()->getParentOfType<mlir::ModuleOp>();
// Parse MLIR.
mlir::DialectRegistry registry;
prepareRegistry(registry);
oldMod->getContext()->appendDialectRegistry(registry);
mlir::ParserConfig parser_config(oldMod->getContext());
MLIRContext context(registry);
context.loadDialect<mlir::arith::ArithDialect>();
context.loadDialect<mlir::func::FuncDialect>();
context.loadDialect<mlir::mhlo::MhloDialect>();
context.loadDialect<mlir::stablehlo::StablehloDialect>();
mlir::ParserConfig parser_config(&context);
mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
mlir::parseSourceString<mlir::ModuleOp>(mlir, parser_config);
if (!parsed_module) {
throw pybind11::value_error("Failed to parse module");
}

mlir::PassManager pm(oldMod->getContext());
mlir::PassManager pm(&context);

std::string error_message;
llvm::raw_string_ostream error_stream(error_message);
Expand All @@ -122,18 +130,14 @@ MlirOperation run_pass_pipeline(llvm::StringRef mlir,

unsigned lastUsedID = 0;

OpBuilder combinedModuleBuilder(oldMod->getContext());
combinedModuleBuilder.setInsertionPointToStart(oldMod.getBody());

Operation *resultOp = nullptr;
for (auto &op : *parsed_module->getBody()) {
auto symbolOp = dyn_cast<SymbolOpInterface>(op);
if (!symbolOp)
continue;

StringRef oldSymName = symbolOp.getName();

if (failed(updateSymbolAndAllUses(symbolOp, *parsed_module, oldMod,
if (failed(updateSymbolAndAllUses(symbolOp, *parsed_module, oldsyms,
lastUsedID)))
throw pybind11::value_error("failed to update all uses");

Expand All @@ -143,16 +147,16 @@ MlirOperation run_pass_pipeline(llvm::StringRef mlir,
entryfn = newSymName;
}
}
Operation *const cloned = op.clone();
if (newSymName == entryfn) {
resultOp = cloned;
SymbolTable::setSymbolVisibility(&op, SymbolTable::Visibility::Private);
}
combinedModuleBuilder.insert(cloned);
}

SymbolTable::setSymbolVisibility(resultOp, SymbolTable::Visibility::Private);
std::string output;
llvm::raw_string_ostream ss(output);
ss << *parsed_module;

return wrap(resultOp);
return std::make_pair(entryfn.str(), ss.str());
}

// Compile an MHLO module given as a string to LLVM IR using XLA.
Expand All @@ -171,6 +175,9 @@ compile_mhlo_to_llvm_with_xla(llvm::StringRef mhlo_text, std::string &output,
mlir::ParserConfig parser_config(&context);
mlir::OwningOpRef<mlir::ModuleOp> parsed_module =
mlir::parseSourceString<mlir::ModuleOp>(mhlo_text, parser_config);
if (!parsed_module) {
throw pybind11::value_error("Failed to parse module");
}

llvm::StringRef cur_pipeline = pass_pipeline;

Expand Down
Loading

0 comments on commit 6b6abdf

Please sign in to comment.