diff --git a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp index 609f9deb..e35db2fd 100644 --- a/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp +++ b/src/enzyme_ad/jax/Passes/EnzymeHLOOpt.cpp @@ -2627,13 +2627,15 @@ struct MulSimplify : public OpRewritePattern { } // 1 * x -> x - if (matchPattern(op.getLhs(), m_One())) { + if (matchPattern(op.getLhs(), m_One()) || + matchPattern(op.getLhs(), m_OneFloat())) { rewriter.replaceOp(op, op.getRhs()); return success(); } // x * 1 -> x - if (matchPattern(op.getRhs(), m_One())) { + if (matchPattern(op.getRhs(), m_One()) || + matchPattern(op.getRhs(), m_OneFloat())) { rewriter.replaceOp(op, op.getLhs()); return success(); } diff --git a/workspace.bzl b/workspace.bzl index 92318c7f..c575e0af 100644 --- a/workspace.bzl +++ b/workspace.bzl @@ -1,8 +1,8 @@ JAX_COMMIT = "493698e6e053641aa8c51bca657cbd763a3ced19" JAX_SHA256 = "f8bbcc40cdee9d8d83a7f6e197ce111f1c01ee00341eab83ddd9367e48519665" -ENZYME_COMMIT = "0246d1a57ceb1e1d24215e3d660d54e6ac4f1a0d" -ENZYME_SHA256 = "4109afb1784d52bed965c133578796e7a98f2b99cf83e641eb4647734222fe8e" +ENZYME_COMMIT = "d5eac0fc9b2f0a4054f7bcd815cc5698661ae112" +ENZYME_SHA256 = "7ac6047d15358434ec77833ebfd96704f5784dea0c79aa2408d2b4bc08183777" PYRULES_COMMIT = "fe33a4582c37499f3caeb49a07a78fc7948a8949" PYRULES_SHA256 = "cfa6957832ae0e0c7ee2ccf455a888a291e8419ed8faf45f4420dd7414d5dd96"