Runtime Error in hellokan.ipynb #173

Merged (1 commit, May 12, 2024)
10 changes: 6 additions & 4 deletions kan/spline.py
@@ -57,7 +57,8 @@ def extend_grid(grid, k_extend=0):
         value = (x >= grid[:, :-1]) * (x < grid[:, 1:])
     else:
         B_km1 = B_batch(x[:, 0], grid=grid[:, :, 0], k=k - 1, extend=False, device=device)
-        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
+        value = (x - grid[:, :-(k + 1)]) / (grid[:, k:-1] - grid[:, :-(k + 1)]) * B_km1[:, :-1] + (
+            grid[:, k + 1:] - x) / (grid[:, k + 1:] - grid[:, 1:(-k)]) * B_km1[:, 1:]
     return value
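The changed line is the recursive step of the Cox-de Boor formula, only re-wrapped across two lines. A minimal standalone sketch of the same recursion (hypothetical name `b_batch_sketch`, reduced to a single unbatched spline for readability; not the repo's actual `B_batch` signature):

```python
import torch

def b_batch_sketch(x, grid, k):
    """Evaluate all degree-k B-spline bases on `grid` at points `x`.

    x: (num_samples,) sample locations
    grid: (num_points,) monotone knot vector
    Returns: (num_samples, num_points - k - 1) basis values.
    """
    xs = x.unsqueeze(1)    # (num_samples, 1) for broadcasting
    g = grid.unsqueeze(0)  # (1, num_points)
    if k == 0:
        # degree 0: indicator of each half-open grid interval
        return ((xs >= g[:, :-1]) & (xs < g[:, 1:])).float()
    b_km1 = b_batch_sketch(x, grid, k - 1)
    # Cox-de Boor: linearly blend two neighbouring lower-degree bases
    left = (xs - g[:, :-(k + 1)]) / (g[:, k:-1] - g[:, :-(k + 1)])
    right = (g[:, k + 1:] - xs) / (g[:, k + 1:] - g[:, 1:(-k)])
    return left * b_km1[:, :-1] + right * b_km1[:, 1:]

x = torch.linspace(-0.99, 0.99, 5)
grid = torch.linspace(-1, 1, 11)   # 10 intervals
basis = b_batch_sketch(x, grid, k=2)
print(basis.shape)                 # (5, 8): num_points - k - 1 bases
# Away from the boundary the bases form a partition of unity
print(basis[2].sum())              # x = 0 lies in the interior
```

In the repo, `B_batch` carries an extra leading spline dimension (hence the `grid[:, ...]` indexing in the diff), but the index arithmetic on the knot axis is the same as above.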


@@ -129,10 +130,11 @@ def curve2coef(x_eval, y_eval, grid, k, device="cpu"):
     >>> x_eval = torch.normal(0,1,size=(num_spline, num_sample))
     >>> y_eval = torch.normal(0,1,size=(num_spline, num_sample))
     >>> grids = torch.einsum('i,j->ij', torch.ones(num_spline,), torch.linspace(-1,1,steps=num_grid_interval+1))
     >>> curve2coef(x_eval, y_eval, grids, k=k).shape
     torch.Size([5, 13])
     '''
     # x_eval: (size, batch); y_eval: (size, batch); grid: (size, grid); k: scalar
-    mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1).to(y_eval.dtype)
-    coef = torch.linalg.lstsq(mat.to('cpu'), y_eval.unsqueeze(dim=2).to('cpu')).solution[:, :, 0]  # sometimes 'cuda' version may diverge
+    mat = B_batch(x_eval, grid, k, device=device).permute(0, 2, 1)
Contributor:
Could you please confirm if it is compatible with float64? Please check the context of the changes made in PR #148.

Reply:
It works with float64 inputs. I agree with Ziming's reply to #146: the singularity problem matters. In fact, torch.linalg.lstsq may raise errors on a singular problem when using gels as the driver on the CUDA version, because gels only works with a full-rank mat.

My other concern about this modification is that the original code carried the comment # sometimes 'cuda' version may diverge. I don't know whether this modification would trigger the 'diverge' situation again when running on 'cuda'. For now, it at least fixes the Runtime Error.
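Regarding the float64 question, a quick standalone check (random data and hypothetical shapes, not the repo's test suite) suggests torch.linalg.lstsq preserves the input dtype, so dropping the explicit .to(y_eval.dtype) cast is safe as long as mat and y_eval already share a dtype:

```python
import torch

# Batched least squares in float64:
# A: (num_spline, batch, n_coef), B: (num_spline, batch, 1)
mat = torch.rand(3, 10, 8, dtype=torch.float64)
y = torch.rand(3, 10, 1, dtype=torch.float64)

# The solution comes back in the same dtype as the inputs
coef = torch.linalg.lstsq(mat, y, driver='gelsy').solution
print(coef.dtype, coef.shape)   # torch.float64 torch.Size([3, 8, 1])
```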

Contributor Author:

I actually tested it thoroughly against all the singularity problems I could think of. Of course I don't expect it to work in every situation, but at least it now works in most :)
We'll have to understand better when a non-full-rank mat actually appears and how to mitigate the issue. From a theoretical point of view, I can only see this kind of problem arising when a feature is totally excluded by the graph, which would itself be valuable information.

+    # coef = torch.linalg.lstsq(mat, y_eval.unsqueeze(dim=2)).solution[:, :, 0]
+    coef = torch.linalg.lstsq(mat.to(device), y_eval.unsqueeze(dim=2).to(device),
+                              driver='gelsy' if device == 'cpu' else 'gels').solution[:, :, 0]
     return coef.to(device)