diff --git a/kan/KAN.py b/kan/KAN.py index 46216210..68f9e436 100644 --- a/kan/KAN.py +++ b/kan/KAN.py @@ -202,7 +202,7 @@ def initialize_from_another_model(self, another_model, x): # spb = spb_parent preacts = another_model.spline_preacts[l] postsplines = another_model.spline_postsplines[l] - self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0), postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid, k=spb.k) + self.act_fun[l].coef.data = curve2coef(preacts.reshape(batch, spb.size).permute(1, 0), postsplines.reshape(batch, spb.size).permute(1, 0), spb.grid, k=spb.k, device=self.device) spb.scale_base.data = spb_parent.scale_base.data spb.scale_sp.data = spb_parent.scale_sp.data spb.mask.data = spb_parent.mask.data diff --git a/kan/KANLayer.py b/kan/KANLayer.py index 8f0eb218..fb35d1f1 100644 --- a/kan/KANLayer.py +++ b/kan/KANLayer.py @@ -123,7 +123,7 @@ def __init__(self, in_dim=3, out_dim=2, num=5, k=3, noise_scale=0.1, scale_base= if isinstance(scale_base, float): self.scale_base = torch.nn.Parameter(torch.ones(size, device=device) * scale_base).requires_grad_(sb_trainable) # make scale trainable else: - self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).cuda()).requires_grad_(sb_trainable) + self.scale_base = torch.nn.Parameter(torch.FloatTensor(scale_base).to(device)).requires_grad_(sb_trainable) self.scale_sp = torch.nn.Parameter(torch.ones(size, device=device) * scale_sp).requires_grad_(sp_trainable) # make scale trainable self.base_fun = base_fun @@ -249,8 +249,8 @@ def initialize_grid_from_parent(self, parent, x): # preacts: shape (batch, in_dim) => shape (size, batch) (size = out_dim * in_dim) x_eval = torch.einsum('ij,k->ikj', x, torch.ones(self.out_dim, ).to(self.device)).reshape(batch, self.size).permute(1, 0) x_pos = parent.grid - sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0.).to(self.device) - sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1) + sp2 = KANLayer(in_dim=1, out_dim=self.size, k=1, num=x_pos.shape[1] - 1, scale_base=0., device=self.device) + sp2.coef.data = curve2coef(sp2.grid, x_pos, sp2.grid, k=1, device=self.device) y_eval = coef2curve(x_eval, parent.grid, parent.coef, parent.k, device=self.device) percentile = torch.linspace(-1, 1, self.num + 1).to(self.device) self.grid.data = sp2(percentile.unsqueeze(dim=1))[0].permute(1, 0)