Thank you very much for your excellent work. I have run into a problem: attn and attn1 in the following code should be equivalent, yet there is a very small difference between them. Could you tell me whether na2d_qk uses a different precision internally, or whether something else causes these small differences?

from natten.functional import na2d_qk

# attn1: dense QK^T logits, shape (B, heads, N, 1, N).
# NOTE: this line was truncated in the original post; the trailing
# .transpose(-2, -1).unsqueeze(2) is a reconstructed completion.
attn1 = q.view(B, self.num_heads, -1, 1, self.head_dim) @ k.view(B, self.num_heads, -1, self.head_dim).transpose(-2, -1).unsqueeze(2)
# attn: neighborhood attention logits, shape (B, heads, H, W, kernel_size**2).
attn = na2d_qk(q, k, self.kernel_size, self.dilation)
Thank you for your interest.
I'll note three things that could contribute to the issue:
1. All BMM-style implementations in NATTEN scatter attention weights according to window position, whereas in self attention the weights are laid out in 1D. This means the order of a given token's attention weights can differ between the two.
2. Numerical differences between different implementations, and even between the same implementation run with different configurations, are normal: floating point addition is not associative (precision is limited on hardware), so changing the order of accumulation changes the result slightly. This is why most of our unit tests check outputs against a reference with a tolerance for small errors.
3. Accumulation precision can be another issue. If you're running in fp16, the torch BMM might accumulate in fp16 while the NATTEN implementation accumulates in fp32, or vice versa. This also produces small errors like the ones you're seeing, and likely contributes much more significantly than the previous item. The sketch after this list demonstrates both this effect and the non-associativity in point 2.
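To make points 2 and 3 concrete, here is a minimal sketch in plain PyTorch, independent of NATTEN; the tensor sizes, the fp64 reference, and the tolerance values are arbitrary choices for illustration, not anything taken from the library.

import torch

torch.manual_seed(0)

# Point 2: floating point addition is not associative, so summing the
# same values in a different order gives a slightly different result.
x = torch.randn(100000)
s_forward = x.sum()
s_shuffled = x[torch.randperm(x.numel())].sum()
print(abs(s_forward.item() - s_shuffled.item()))  # tiny, but usually nonzero

# Point 3: accumulation precision. Compare an fp32 matmul against an
# fp64 reference; the gap grows further if math/accumulation is fp16.
q = torch.randn(128, 64)
k = torch.randn(128, 64)
ref = q.double() @ k.double().transpose(-2, -1)  # high-precision reference
out = (q @ k.transpose(-2, -1)).double()         # fp32 computation
print((ref - out).abs().max().item())            # small but nonzero

# This is why unit tests compare against a reference with a tolerance:
print(torch.allclose(out, ref, atol=1e-4, rtol=1e-5))  # True

In the same spirit, the appropriate check for attn against attn1 is a tolerance-based comparison such as torch.allclose, rather than exact equality.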