diff --git a/src/enzyme_ad/jax/compile_with_xla.cc b/src/enzyme_ad/jax/compile_with_xla.cc index 4dfc6bf6..b14a9e48 100644 --- a/src/enzyme_ad/jax/compile_with_xla.cc +++ b/src/enzyme_ad/jax/compile_with_xla.cc @@ -168,7 +168,8 @@ run_pass_pipeline(const std::vector &oldsym_vec, std::string output; llvm::raw_string_ostream ss(output); - ss << *parsed_module; + parsed_module->getOperation()->print( + ss, mlir::OpPrintingFlags().enableDebugInfo()); return std::make_pair(entryfn.str(), ss.str()); } diff --git a/src/enzyme_ad/jax/primitives.py b/src/enzyme_ad/jax/primitives.py index 17bdc6cc..3ac3f6fd 100644 --- a/src/enzyme_ad/jax/primitives.py +++ b/src/enzyme_ad/jax/primitives.py @@ -413,7 +413,7 @@ def _enzyme_aug_abstract_eval( avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat) lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") - source = str(mhlo) + source = mhlo.operation.get_asm(enable_debug_info=True) kept = lowered_func.compile()._executable._kept_var_idx in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] @@ -553,7 +553,7 @@ def _enzyme_primal_lowering( avals_in = jax.tree_util.tree_unflatten(in_tree, avals) lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") - source = str(mhlo) + source = mhlo.operation.get_asm(enable_debug_info=True) kept = lowered_func.compile()._executable._kept_var_idx in_args = tuple( arg @@ -701,7 +701,7 @@ def _enzyme_fwd_lowering( avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2]) lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") - source = str(mhlo) + source = mhlo.operation.get_asm(enable_debug_info=True) kept = lowered_func.compile()._executable._kept_var_idx in_args = tuple(arg for (i, arg) in enumerate(in_args) if i // 2 in kept) in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] @@ -763,7 +763,7 @@ def _enzyme_aug_lowering( avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in) lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") - source = str(mhlo) + source = mhlo.operation.get_asm(enable_debug_info=True) kept = lowered_func.compile()._executable._kept_var_idx in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept] @@ -829,7 +829,7 @@ def _enzyme_rev_lowering( avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out) lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in) mhlo = lowered_func.compiler_ir(dialect="stablehlo") - source = str(mhlo) + source = mhlo.operation.get_asm(enable_debug_info=True) kept = lowered_func.compile()._executable._kept_var_idx # in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept) in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]