Efficient influence function constructed from influence_fn errors if called multiple times #483

Open
SamWitty opened this issue Jan 7, 2024 · 7 comments
Labels: bug, module:robust


@SamWitty
Collaborator

SamWitty commented Jan 7, 2024

chirho.robust.ops.influence_fn (https://github.com/BasisResearch/chirho/blob/staging-robust/chirho/robust/ops.py#L18) returns a function for approximating the efficient influence function at a collection of data points. Calling the returned function more than once raises the following (somewhat opaque) error:

RuntimeError: !at::functionalization::impl::isFunctionalTensor(base_) INTERNAL ASSERT FAILED at "/Users/runner/work/pytorch/pytorch/pytorch/aten/src/ATen/FunctionalStorageImpl.cpp":101, please report a bug to PyTorch.

For example, this error can be seen in the current tests by adding multiple calls to the generated function as follows:
https://github.com/BasisResearch/chirho/blob/sw-multiple-eif/tests/robust/test_ops.py#L80
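
A minimal sketch of the repeated-call check, assuming only that the object returned by influence_fn is a callable that accepts a batch of points. `toy_eif` and `call_repeatedly` are hypothetical stand-ins for illustration, not part of the ChiRho API; in the real test the stand-in would be replaced by the function constructed in tests/robust/test_ops.py.

```python
from typing import Callable, List, Sequence


def call_repeatedly(
    fn: Callable[[Sequence[float]], List[float]],
    points: Sequence[float],
    n_calls: int = 2,
) -> List[List[float]]:
    """Call fn on the same points n_calls times and collect every result.

    Any exception raised on a later call (as reported in this issue)
    propagates to the caller, so a test using this helper fails loudly.
    """
    return [fn(points) for _ in range(n_calls)]


def toy_eif(points: Sequence[float]) -> List[float]:
    # Hypothetical stand-in for the function returned by influence_fn:
    # it centers each point around the batch mean.
    mean = sum(points) / len(points)
    return [p - mean for p in points]


results = call_repeatedly(toy_eif, [1.0, 2.0, 3.0], n_calls=3)
# A deterministic function should return the same values on every call.
assert all(r == results[0] for r in results)
```

With the real function, the second call is where the RuntimeError above would be expected to surface.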

SamWitty added the bug and module:robust labels on Jan 7, 2024
@SamWitty
Collaborator Author

SamWitty commented Jan 9, 2024

@agrawalraj, this is now somewhat blocking #484, as (our version of) TMLE involves evaluating the influence function on test_points and then evaluating it again on samples drawn from the model and guide. I say "somewhat" blocking because I could work around it by batching these two datasets into a single call, so that the influence function is evaluated only once.
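
The batching workaround could look roughly like the sketch below. It assumes the influence function acts pointwise, so that the value for each input does not depend on which other points share the batch; that assumption is not confirmed anywhere in this thread. `eval_once_and_split` and `counting_eif` are hypothetical names, and plain Python lists stand in for tensors.

```python
from typing import Callable, List, Sequence, Tuple


def eval_once_and_split(
    eif: Callable[[Sequence[float]], List[float]],
    test_points: Sequence[float],
    model_samples: Sequence[float],
) -> Tuple[List[float], List[float]]:
    """Evaluate eif a single time on both datasets, then split the output."""
    combined = list(test_points) + list(model_samples)
    values = eif(combined)  # the only call to eif
    n_test = len(test_points)
    return values[:n_test], values[n_test:]


calls = 0


def counting_eif(points: Sequence[float]) -> List[float]:
    # Hypothetical pointwise stand-in that also counts how often it is called.
    global calls
    calls += 1
    return [2.0 * p for p in points]


test_vals, sample_vals = eval_once_and_split(counting_eif, [1.0, 2.0], [3.0])
assert calls == 1  # both datasets were handled by one evaluation
```

With tensors, the concatenation and slicing would be done along the batch dimension instead of with list operations.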

@eb8680
Contributor

eb8680 commented Jan 11, 2024

I can't reproduce this on my machine with torch==2.1.2. What version of PyTorch were you using?

@SamWitty
Collaborator Author

I'm using torch==2.0.1

@SamWitty
Collaborator Author

I believe that is consistent with the (implied) requirements of ChiRho via Pyro.

https://github.com/pyro-ppl/pyro/blob/dev/setup.py#L105

@eb8680
Contributor

eb8680 commented Jan 11, 2024

You're right, but I think the easiest short-term fix for this would just be to add a torch>=2.1.0 requirement specifically for chirho.robust.
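
Besides a packaging requirement, a runtime guard for the suggested minimum version could be sketched as follows. This is an illustration under stated assumptions, not how chirho.robust handles it: `parse_release` and `meets_minimum` are hypothetical helpers, and pre-release or local build suffixes are simply ignored.

```python
import re
from typing import Tuple


def parse_release(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release from a version string.

    "2.1.2+cu118" -> (2, 1, 2). Suffixes such as local build tags or
    pre-release markers are ignored, and short versions are zero-padded
    to three components so that "2.1" compares equal to "2.1.0".
    """
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    parts = [int(part) for part in match.group().split(".")]
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def meets_minimum(installed: str, minimum: str = "2.1.0") -> bool:
    """Return True if installed is at least minimum (tuple comparison)."""
    return parse_release(installed) >= parse_release(minimum)


assert meets_minimum("2.1.2")
assert not meets_minimum("2.0.1")
```

In practice the installed version string would come from torch.__version__, for example `meets_minimum(torch.__version__)`.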

@SamWitty
Collaborator Author

When I upgrade torch to 2.1 I still get the same error. I made this "dummy" PR #494 to test this in a fresh environment with the bumped versions. Hopefully this tells us if/where the problem is.

@SamWitty
Collaborator Author

It looks like this error doesn't appear in the CI builds even with torch==2.0.1. I'll see if I can reproduce the error another way...
