How to use pykan to fit a piecewise function #200

Closed
zhongjingjogy opened this issue May 15, 2024 · 4 comments

Comments

@zhongjingjogy

I've been experimenting with pykan and found its regression capabilities impressive, particularly for modeling nonlinear functional relationships.

I tried to go a step further and fit a piecewise function. However, the fit could be improved significantly. My implementation is below; I'd appreciate any suggestions for improving the goodness of fit. Thanks very much!

```python
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt

model = KAN(width=[1,10,1], grid=5, k=3, seed=0)

def f(x):
    # piecewise function: f = 1 for x < 0.5, f = 0 for x >= 0.5
    return torch.where(x[:,[0]]<0.5, torch.ones_like(x[:,[0]]), torch.zeros_like(x[:,[0]]))

dataset = create_dataset(f, n_var=1, ranges=[0,1])

# train the model (newer pykan releases name this method model.fit)
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
predictions = model(inputs)

plt.plot(dataset['train_input'], dataset['train_label'], 'r', label='True', linestyle='', marker='o', markerfacecolor='white', markevery=10)
plt.plot(dataset['train_input'], predictions.detach().numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()
```

The outcome:
[Figure: true labels (red) vs. predictions (blue)]

@KindXiaoming
Owner

In this case, it may be more reasonable to try `model = KAN(width=[1,10,1], grid=5, k=1, seed=0)` (possibly with a larger grid as well). This is only a workaround, though: it may not be ideal, and handling discontinuities properly would require more careful development.
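A rough way to see why a lower spline order plus a finer grid helps: no spline can reproduce a jump exactly, so what matters is how narrow a transition it can make. The sketch below is an illustration under simplifying assumptions, not pykan code. It fits a single 1D B-spline curve (no KAN layers, no residual base function) to the step function by linear least squares, on a uniform knot grid over [0, 1] extended by k knots on each side; pykan's actual grid placement may differ. The helpers `bspline_basis` and `spline_fit_rmse` are hypothetical names.

```python
import numpy as np

def bspline_basis(x, grid_size, k):
    """Evaluate all degree-k B-spline basis functions at the points x.

    Hypothetical helper: uniform knots on [0, 1], extended by k knots on
    each side, evaluated with the Cox-de Boor recursion.
    Returns an array of shape (len(x), grid_size + k).
    """
    h = 1.0 / grid_size
    t = np.arange(-k, grid_size + k + 1) * h  # extended uniform knot vector
    xc = x[:, None]
    # Degree 0: indicator of the half-open knot interval [t_i, t_{i+1})
    B = ((xc >= t[None, :-1]) & (xc < t[None, 1:])).astype(float)
    for p in range(1, k + 1):
        # Cox-de Boor: blend neighbouring degree-(p-1) basis functions
        left = (xc - t[None, :-(p + 1)]) / (t[None, p:-1] - t[None, :-(p + 1)])
        right = (t[None, p + 1:] - xc) / (t[None, p + 1:] - t[None, 1:-p])
        B = left * B[:, :-1] + right * B[:, 1:]
    return B

def spline_fit_rmse(x, y, grid_size, k):
    """Least-squares fit of one B-spline curve; returns the RMSE on x."""
    B = bspline_basis(x, grid_size, k)
    coef, *_ = np.linalg.lstsq(B, y, rcond=None)
    return float(np.sqrt(np.mean((B @ coef - y) ** 2)))

x = np.linspace(0.0, 1.0, 2001)
y = np.where(x < 0.5, 1.0, 0.0)  # the step function from the original post

rmse_cubic = spline_fit_rmse(x, y, grid_size=5, k=3)     # original setting
rmse_linear = spline_fit_rmse(x, y, grid_size=100, k=1)  # suggested setting
print(f"k=3, grid=5:   RMSE = {rmse_cubic:.4f}")
print(f"k=1, grid=100: RMSE = {rmse_linear:.4f}")
```

With k=3 and grid=5, the jump falls inside the knot interval [0.4, 0.6], where the spline is a single cubic polynomial, so the transition is spread over that whole width. With k=1 and grid=100, the transition can be confined to the two cells around x = 0.5, and the error drops accordingly.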

@zhongjingjogy
Author

With a target that is continuous but not differentiable (a kink at x = 0.5), it works nicely:

```python
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt

# model = KAN(width=[1,10,1], grid=5, k=3, seed=0)
model = KAN(width=[1,10,1], grid=100, k=1, seed=0)

def f(x):
    # continuous, with a kink at x = 0.5: f = 0.5 for x <= 0.5, f = x otherwise
    return torch.where(x > 0.5, x, 0.5)

dataset = create_dataset(f, n_var=1, ranges=[0,1])

# train the model
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
predictions = model(inputs)

plt.plot(dataset['train_input'], dataset['train_label'], 'r', label='True', linestyle='', marker='o', markerfacecolor='white', markevery=10)
plt.plot(dataset['train_input'], predictions.detach().numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()
```

[Figure: true labels (red) vs. predictions (blue)]
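One plausible reason the kinked target fits so well: with k=1, the spline part of each activation is piecewise linear between grid points, so a continuous function whose only kink lies on a grid point can be represented exactly. The quick check below uses numpy's `np.interp` through the nodes of a uniform grid on [0, 1] as a stand-in for a k=1 spline. That is an assumption for illustration, not pykan code, and `target` and `max_interp_error` are hypothetical helpers.

```python
import numpy as np

def target(x):
    # The continuous, non-differentiable target from the post: kink at x = 0.5
    return np.where(x > 0.5, x, 0.5)

def max_interp_error(grid_size, n_samples=2001):
    """Max error of piecewise-linear interpolation on a uniform grid over [0, 1]."""
    nodes = np.linspace(0.0, 1.0, grid_size + 1)
    x = np.linspace(0.0, 1.0, n_samples)
    approx = np.interp(x, nodes, target(nodes))
    return float(np.max(np.abs(approx - target(x))))

err_coarse = max_interp_error(5)   # nodes at 0.4 and 0.6 straddle the kink
err_fine = max_interp_error(100)   # x = 0.5 is itself a grid node
print(f"grid=5:   max error = {err_coarse:.3g}")
print(f"grid=100: max error = {err_fine:.3g}")
```

On the coarse grid the kink sits between two nodes, so the interpolant cuts the corner (by 0.05 at x = 0.5). On the fine grid the kink is a node and the error is at floating-point level. A jump, by contrast, is not representable at any grid size.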

@zhongjingjogy
Author

Another try, with a constant for x ≤ 0.5 that joins sin(20x) continuously beyond it. It looks nice.

```python
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt
import numpy as np

# model = KAN(width=[1,10,1], grid=5, k=3, seed=0)
model = KAN(width=[1,10,1], grid=100, k=1, seed=0)

def f(x):
    return torch.where(x > 0.5, torch.sin(20.0*x), np.sin(20.0*0.5))

dataset = create_dataset(f, n_var=1, ranges=[0,1], train_num=2000, test_num=2000)

# train the model
model.train(dataset, opt="LBFGS", steps=20)

inputs = dataset['train_input']
predictions = model(inputs)

plt.plot(dataset['train_input'], dataset['train_label'], 'r', label='True', linestyle='', marker='o', markerfacecolor='white', markevery=10)
plt.plot(dataset['train_input'], predictions.detach().numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()
```

[Figure: true labels (red) vs. predictions (blue)]

@zhongjingjogy
Author

Great, with k=1 and grid=100 it works on the original step function too. Thank you so much.

```python
from kan import KAN, create_dataset
import torch
import matplotlib.pyplot as plt

model = KAN(width=[1,4,1], grid=100, k=1, seed=0)

def f(x):
    # piecewise function, x < 0.5, f = 1, x >= 0.5, f = 0
    return torch.where(x[:,[0]]<0.5, torch.ones_like(x[:,[0]]), torch.zeros_like(x[:,[0]]))

dataset = create_dataset(f, n_var=1, ranges=[0,1])

# train the model
model.train(dataset, opt="LBFGS", steps=100)

inputs = dataset['train_input']
predictions = model(inputs)

plt.plot(dataset['train_input'], dataset['train_label'], 'r', label='True', linestyle='', marker='o', markerfacecolor='white', markevery=10)
plt.plot(dataset['train_input'], predictions.detach().numpy(), 'b', label='Predictions', linestyle='none', marker='o', markerfacecolor='white', markevery=10)

plt.legend()
plt.show()
```

[Figure: true labels (red) vs. predictions (blue)]
