Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

propagate location information #66

Merged
merged 1 commit into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/enzyme_ad/jax/compile_with_xla.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,8 @@ run_pass_pipeline(const std::vector<std::string> &oldsym_vec,

std::string output;
llvm::raw_string_ostream ss(output);
ss << *parsed_module;
parsed_module->getOperation()->print(
ss, mlir::OpPrintingFlags().enableDebugInfo());

return std::make_pair(entryfn.str(), ss.str());
}
Expand Down
10 changes: 5 additions & 5 deletions src/enzyme_ad/jax/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def _enzyme_aug_abstract_eval(
avals_in = jax.tree_util.tree_unflatten(in_tree, args_flat)
lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
source = mhlo.operation.get_asm(enable_debug_info=True)
kept = lowered_func.compile()._executable._kept_var_idx
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]

Expand Down Expand Up @@ -553,7 +553,7 @@ def _enzyme_primal_lowering(
avals_in = jax.tree_util.tree_unflatten(in_tree, avals)
lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
source = mhlo.operation.get_asm(enable_debug_info=True)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = tuple(
arg
Expand Down Expand Up @@ -701,7 +701,7 @@ def _enzyme_fwd_lowering(
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in[::2])
lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
source = mhlo.operation.get_asm(enable_debug_info=True)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = tuple(arg for (i, arg) in enumerate(in_args) if i // 2 in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
Expand Down Expand Up @@ -763,7 +763,7 @@ def _enzyme_aug_lowering(
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_in)
lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
source = mhlo.operation.get_asm(enable_debug_info=True)
kept = lowered_func.compile()._executable._kept_var_idx
in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
Expand Down Expand Up @@ -829,7 +829,7 @@ def _enzyme_rev_lowering(
avals_in = jax.tree_util.tree_unflatten(in_tree, ctx.avals_out)
lowered_func = jax.jit(mfunc, **jit_options).lower(*avals_in)
mhlo = lowered_func.compiler_ir(dialect="stablehlo")
source = str(mhlo)
source = mhlo.operation.get_asm(enable_debug_info=True)
kept = lowered_func.compile()._executable._kept_var_idx
# in_args = tuple(arg for (i, arg) in enumerate(in_args) if i in kept)
in_shapes = [shape for (i, shape) in enumerate(in_shapes) if i in kept]
Expand Down
Loading