Skip to content

Commit

Permalink
Add dot [fwd]
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 16, 2024
1 parent 8c74fa0 commit 5cd6ccc
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 6 deletions.
24 changes: 23 additions & 1 deletion src/enzyme_ad/jax/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@ class RegionTerminatorOp<string dialect_, string opName_> {
string opName = opName_;
}

class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps> {
class ForwardFromSummedReverseInternal<int unused_> {
int unused = unused_;
}
def ForwardFromSummedReverse : ForwardFromSummedReverseInternal<0>;


class MLIRDerivative<string dialect_, string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> {
string dialect = dialect_;
string opName = opName_;
dag PatternToMatch = patternToMatch;
list<dag> ArgDerivatives = resultOps;
dag ArgDuals = forwardOps;
}

class Operation<bit usesPrimal_, bit usesShadow_, bit usesCustom_=0> {
Expand All @@ -51,6 +58,13 @@ class DiffeRetIndex<list<int> indices_> {
}
def DiffeRet : DiffeRetIndex<[-1]>;

def Shadow : Operation</*primal*/0, /*shadow*/1> {
}

class GlobalExpr<bit uses_primal, bit uses_shadow, string val> : Operation<uses_primal, uses_shadow>{
string value = val;
}

class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
Expand All @@ -63,5 +77,13 @@ class ConstantFP<string val, string dialect_, string op_> : Operation</*primal*/
string opName = op_;
}

def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

def Op {
}

def ResultTypes : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op->getResultTypes()">;


17 changes: 14 additions & 3 deletions src/enzyme_ad/jax/Implementations/HLODerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ def Sin : HLOInst<"SineOp">;
def Sqrt : HLOInst<"SqrtOp">;
def Exp : HLOInst<"ExpOp">;

def Dot : HLOInst<"DotGeneralOp">;


def CheckedMul : HLOInst<"MulOp">;
def CheckedDiv : HLOInst<"DivOp">;
Expand Down Expand Up @@ -83,6 +85,15 @@ def : HLOReadOnlyIdentityOp<"SliceOp">;
def : HLOReadOnlyIdentityOp<"BroadcastInDimOp">;
def : HLOReadOnlyIdentityOp<"ConcatenateOp">;
// convert
// cos
// sin
// sqrt


def ResultDotDim : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op.getDotDimensionNumbersAttr()">;
def ResultDotPrec : GlobalExpr</*needsprimal*/0, /*needsshadow*/0, "op.getPrecisionConfigAttr()">;

def : HLODerivative<"DotGeneralOp", (Op $lhs, $rhs),
[
(Dot (ResultTypes), (DiffeRet), $rhs, (ResultDotDim), (ResultDotPrec)),
(Dot (ResultTypes), $lhs, (DiffeRet), (ResultDotDim), (ResultDotPrec))
],
(Add (SelectIfActive $lhs, (Dot (ResultTypes), (Shadow $lhs), $rhs, (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)), (SelectIfActive $rhs, (Dot (ResultTypes), $lhs, (Shadow $rhs), (ResultDotDim), (ResultDotPrec)), (HLOConstantFP<"0">)))
>;
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/Implementations/MHLODerivatives.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
include "Common.td"

class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps>;
class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"mhlo", opName_, patternToMatch, resultOps, forwardOps>;

class HLOInst<string m> : Inst<m, "mhlo">;

Expand Down
2 changes: 1 addition & 1 deletion src/enzyme_ad/jax/Implementations/StableHLODerivatives.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
include "Common.td"

class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps>;
class HLODerivative<string opName_, dag patternToMatch, list<dag> resultOps, dag forwardOps=(ForwardFromSummedReverse)> : MLIRDerivative<"stablehlo", opName_, patternToMatch, resultOps, forwardOps>;

class HLOInst<string m> : Inst<m, "stablehlo">;

Expand Down

0 comments on commit 5cd6ccc

Please sign in to comment.