Custom Log Loss does not have the same loss curves #5373

Closed
sushmit-goyal opened this issue Jul 14, 2022 · 9 comments
@sushmit-goyal commented Jul 14, 2022

If I use a custom implementation of binary log loss for evaluation in place of the default one, I get very different loss curves and values. Why is that?

Plots: [screenshots of the default binary_logloss curve and the custom_loss curve]

Code:

import numpy as np
import lightgbm

def custom_loss(logits, train_data):
    # Numerically stable binary log loss, treating preds as raw logits
    labels = train_data.get_label()
    logsigmoid = np.where(logits >= 0,
                          -np.log(1 + np.exp(-logits)),
                          logits - np.log(1 + np.exp(logits)))
    loss = (-logsigmoid + logits * (1 - labels)).mean()
    return "custom_loss", loss, False

lgb_trainData = lightgbm.Dataset(trainData, label = trainLabel)
lgb_validData = lightgbm.Dataset(validData, validLabel, reference = lgb_trainData)

params = {
    'seed': 0,
    'num_leaves': 100,
    'learning_rate': 0.001,
    'is_unbalance': 'true',
    'boosting_type' : 'gbdt',
    'application' : 'binary',
    'feature_fraction' : 1,
    'bagging_fraction' : 1,
    'n_estimators': 10,
    'min_split_gain' : 1.0,
    'min_data_in_leaf': 40,
    'max_depth' : 7,
    'verbose' : -1
    }

eval_result = {}
model = lightgbm.train(params, lgb_trainData, valid_sets = [lgb_validData], valid_names = ['validation_loss'], feval = custom_loss, callbacks=[lightgbm.record_evaluation(eval_result)])

lightgbm.plot_metric(eval_result,metric = 'binary_logloss', title='Default Loss')
lightgbm.plot_metric(eval_result,metric = 'custom_loss', title='Custom Loss')
@jmoralez (Collaborator) commented:

Hi @sushmit-goyal, thank you for your interest in LightGBM. What the custom objective function receives are the probabilities, not the logits. Changing your custom_loss to:

from sklearn.metrics import log_loss

def custom_loss(probs, train_data):
    loss = log_loss(train_data.get_label(), probs)
    return 'custom_loss', loss, False

I get the same results as the built-in metric.

Please let us know if you have further doubts.

@sushmit-goyal (Author) commented Jul 18, 2022

> Hi @sushmit-goyal, thank you for your interest in LightGBM. What the custom objective function receives are the probabilities, not the logits. Changing your custom_loss to:
>
> from sklearn.metrics import log_loss
>
> def custom_loss(probs, train_data):
>     loss = log_loss(train_data.get_label(), probs)
>     return 'custom_loss', loss, False
>
> I get the same results as the built-in metric.
>
> Please let us know if you have further doubts.

Hey @jmoralez, I'm using LightGBM 3.3.2 and the documentation in basic.py reads:

fobj : callable or None, optional (default=None)
            Customized objective function.
            Should accept two parameters: preds, train_data,
            and return (grad, hess).

                preds : numpy 1-D array or numpy 2-D array (for multi-class task)
                    The predicted values.
                    Predicted values are returned before any transformation,
                    e.g. they are raw margin instead of probability of positive class for binary task.

@jmoralez (Collaborator) commented Jul 18, 2022

That's for fobj, but you're using feval, for which the preds argument says:

The predicted values. If fobj is specified, predicted values are returned before any transformation, e.g. they are raw margin instead of probability of positive class for binary task in this case. (docs)

Since you're not using fobj, the preds that feval gets are the probabilities. Please let us know if you have further doubts.
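
For illustration, a minimal sketch of the two cases (the function names here are only illustrative, and numpy/scipy are assumed): with the built-in objective, feval receives probabilities; with a custom objective, it receives raw scores that need a sigmoid first.

import numpy as np
from scipy.special import expit

def feval_with_builtin_objective(preds, eval_data):
    # No fobj: preds are already probabilities of the positive class
    y = eval_data.get_label()
    p = np.clip(preds, 1e-15, 1 - 1e-15)
    loss = -(y * np.log(p) + (1 - y) * np.log(1 - p)).mean()
    return 'custom_loss', loss, False

def feval_with_custom_objective(preds, eval_data):
    # With fobj (or a callable 'objective'): preds are raw scores, so
    # map them to probabilities before computing the log loss
    y = eval_data.get_label()
    p = np.clip(expit(preds), 1e-15, 1 - 1e-15)
    loss = -(y * np.log(p) + (1 - y) * np.log(1 - p)).mean()
    return 'custom_loss', loss, False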

@sushmit-goyal (Author) commented Jul 18, 2022

Thanks @jmoralez! I was able to reproduce the evaluation loss, but I'm seeing the same problem when using a custom fobj to train the model. I'd like to get the same validation loss with both the default and the custom objective function, but I'm observing different values.

Code:

from scipy import special

def logloss_objective(preds, train_data):
    # Gradient and hessian of the binary log loss w.r.t. the raw scores
    y = train_data.get_label()
    p = special.expit(preds)
    grad = p - y
    hess = p * (1 - p)
    return grad, hess

lgb_trainData = lightgbm.Dataset(trainData, label = trainLabel)
lgb_validData = lightgbm.Dataset(validData, validLabel, reference = lgb_trainData)

params = {
    'seed': 0,
    'num_leaves': 110,
    'learning_rate': 0.01,
    'is_unbalance': 'true',
    'boosting_type' : 'gbdt',
    'application' : 'binary',
    'metric': 'binary_logloss',
    'feature_fraction' : 0.6,
    'bagging_fraction' : 1,
    'n_estimators': 10,
    'min_split_gain' : 1.5,
    'min_data_in_leaf': 50,
    'max_depth' : 7,
    'verbose' : -1    
    }

eval_result = {}
model = lightgbm.train(params, lgb_trainData, valid_sets = [lgb_validData], valid_names = ['validation_loss'], callbacks=[lightgbm.record_evaluation(eval_result)])

lightgbm.plot_metric(eval_result,title='Loss with default objective')

params['objective'] = logloss_objective

eval_result = {}
model = lightgbm.train(params, lgb_trainData, valid_sets = [lgb_validData], valid_names = ['validation_loss'], callbacks=[lightgbm.record_evaluation(eval_result)])

lightgbm.plot_metric(eval_result,title='Loss with custom objective')

[screenshots: validation loss curves with the default objective and with the custom objective]

The loss curves and values are totally different. Any help would be greatly appreciated!

Update: I tried the solution in #5114 but it did not work. The loss starts from approximately the same value as with the default objective and then explodes.
Code:

lgb_trainData = lightgbm.Dataset(trainData, label = trainLabel , init_score=np.full_like(trainLabel, trainLabel.mean()))
lgb_validData = lightgbm.Dataset(validData, validLabel,init_score=np.full_like(validLabel, validLabel.mean()), reference = lgb_trainData)  # initialise datasets with the init_scores

params['objective'] = logloss_objective

eval_result = {}
model = lightgbm.train(params, lgb_trainData, valid_sets = [lgb_validData], valid_names = ['validation_loss'], callbacks=[lightgbm.record_evaluation(eval_result)])

lightgbm.plot_metric(eval_result,title='Loss with custom objective')

New loss curve: [screenshot]

@jmoralez (Collaborator) commented:

Here's an example where I get pretty much the same curves:

import lightgbm as lgb
import matplotlib.pyplot as plt
import numpy as np
from scipy.special import expit, logit
from sklearn.datasets import make_classification
from sklearn.metrics import log_loss
from sklearn.model_selection import train_test_split

params = {
    'seed': 0,
    'num_leaves': 110,
    'learning_rate': 0.01,
    'is_unbalance': True,
    'boosting_type' : 'gbdt',
    'application' : 'binary',
    'metric': 'binary_logloss',
    'feature_fraction' : 0.6,
    'bagging_fraction' : 1,
    'n_estimators': 10,
    'min_split_gain' : 1.5,
    'min_data_in_leaf': 50,
    'max_depth' : 7,
    'verbose' : -1
}

X, y = make_classification(n_samples=10_000, n_features=10, weights=[0.9, 0.1], random_state=0)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, random_state=0)

# built-in
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = train_ds.create_valid(X_valid, y_valid)
eval_result = {}
bst = lgb.train(params, train_ds, valid_sets=[valid_ds], callbacks=[lgb.record_evaluation(eval_result)])
eval_result2 = {}

# custom
cnt_positive = y_train.sum()
cnt_negative = y_train.size - cnt_positive
if cnt_positive > cnt_negative:
    pos_weight = 1.0
    neg_weight = cnt_positive / cnt_negative
else:
    pos_weight = cnt_negative / cnt_positive
    neg_weight = 1.0

def get_weights(labels, pos_weight, neg_weight):
    positives = labels == 1
    weights = np.empty(labels.size)
    weights[positives] = pos_weight
    weights[~positives] = neg_weight
    return weights

def logloss_objective(preds, train_data):
    y = train_data.get_label()
    p = expit(preds)
    weights = train_data.get_weight()
    grad = p - y
    hess = p * (1 - p)
    grad *= weights
    hess *= weights
    return grad, hess

def binary_metric(raw_score, ds):
    y_true = ds.get_label()
    probs = 1.0 / (1.0 + np.exp(-raw_score))
    val = log_loss(y_true, probs)
    return 'custom_loss', val, False

train_weights = get_weights(y_train, pos_weight, neg_weight)
valid_weights = get_weights(y_valid, pos_weight, neg_weight)
init_score = logit(y_train.mean())
train_init = np.full(y_train.size, init_score)
valid_init = np.full(y_valid.size, init_score)
train_ds2 = lgb.Dataset(X_train, y_train, weight=train_weights, init_score=train_init)
valid_ds2 = train_ds.create_valid(X_valid, y_valid, weight=valid_weights, init_score=valid_init)
params2 = {k: v for k, v in params.items() if k != 'metric'}
params2['objective'] = logloss_objective
params2['verbose'] = 1
eval_result2 = {}
bst2 = lgb.train(params2, train_ds2, valid_sets=[valid_ds2], feval=binary_metric, callbacks=[lgb.record_evaluation(eval_result2)])

# plot results
fig, ax = plt.subplots()

ax.plot(eval_result['valid_0']['binary_logloss'], label='built-in')
ax.plot(eval_result2['valid_0']['custom_loss'], label='custom')
ax.legend();

[plot: the built-in and custom loss curves, nearly identical]

When you set is_unbalance=True, LightGBM performs this scaling internally:

if (is_unbalance_ && cnt_positive > 0 && cnt_negative > 0) {
  if (cnt_positive > cnt_negative) {
    label_weights_[1] = 1.0f;
    label_weights_[0] = static_cast<double>(cnt_positive) / cnt_negative;
  } else {
    label_weights_[1] = static_cast<double>(cnt_negative) / cnt_positive;
    label_weights_[0] = 1.0f;
  }
}

So I believe the sample weights and the init_scores (as logits) were the parts that were missing.
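
As a small side sketch (assuming scipy) of why the init score is the logit of the base rate rather than the base rate itself: the binary objective works in raw-score space, and logit(p) is the raw score whose sigmoid is p, so it is the natural constant starting point.

import numpy as np
from scipy.special import expit, logit

y = np.array([0, 0, 0, 1])      # toy labels with base rate 0.25
base_rate = y.mean()
raw_init = logit(base_rate)     # log(0.25 / 0.75) ≈ -1.0986
print(expit(raw_init))          # ~0.25: the starting prediction equals the base rate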

@sushmit-goyal (Author) commented Jul 19, 2022

@jmoralez The loss values are turning out to be even higher now after incorporating the weights and init_scores.
Code:

from scipy.special import expit, logit

params = {
    'seed': 0,
    'num_leaves': 110,
    'learning_rate': 0.01,
    'is_unbalance': True,
    'boosting_type' : 'gbdt',
    'application' : 'binary',
    'metric': 'binary_logloss',
    'feature_fraction' : 0.6,
    'bagging_fraction' : 1,
    'n_estimators': 10,
    'min_split_gain' : 1.5,
    'min_data_in_leaf': 50,
    'max_depth' : 7,
    'verbose' : -1
}

def logloss_objective(preds, train_data):
    y = train_data.get_label()
    p = expit(preds)
    weights = train_data.get_weight()
    grad = p - y
    hess = p * (1 - p)
    grad *= weights
    hess *= weights
    return grad, hess

cnt_positive = trainLabel.sum()
cnt_negative = trainLabel.size - cnt_positive

if cnt_positive > cnt_negative:
    pos_weight = 1.0
    neg_weight = cnt_positive / cnt_negative
else:
    pos_weight = cnt_negative / cnt_positive
    neg_weight = 1.0

def get_weights(labels, pos_weight, neg_weight):
    positives = labels == 1
    weights = np.empty(labels.size)
    weights[positives] = pos_weight
    weights[~positives] = neg_weight
    return weights

train_weights = get_weights(trainLabel, pos_weight, neg_weight)
valid_weights = get_weights(validLabel, pos_weight, neg_weight)

init_score = logit(trainLabel.mean())
train_init = np.full(trainLabel.size, init_score)
valid_init = np.full(validLabel.size, init_score)

lgb_trainData = lightgbm.Dataset(trainData, label = trainLabel , weight=train_weights, init_score=train_init)
lgb_validData = lightgbm.Dataset(validData, label = validLabel, weight=valid_weights, init_score=valid_init, reference = lgb_trainData)

params['objective'] = logloss_objective

eval_result = {}
model = lightgbm.train(params, lgb_trainData, valid_sets = [lgb_validData], valid_names = ['validation_loss'], callbacks=[lightgbm.record_evaluation(eval_result)])

lightgbm.plot_metric(eval_result,title='Loss with custom objective')

[screenshot: validation loss curve with the custom objective, weights, and init_score]

What am I doing wrong?

@jmoralez (Collaborator) commented:

Can you try with feval? I think the built-in metric may not be adding the init score.
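
A sketch of what that suggestion could look like, reusing the binary_metric helper from the example above (assumes params, lgb_trainData and lgb_validData from the previous snippets are still in scope):

eval_result = {}
model = lightgbm.train(
    params,                      # still contains the custom 'objective'
    lgb_trainData,
    valid_sets=[lgb_validData],
    valid_names=['validation_loss'],
    feval=binary_metric,         # applies the sigmoid to the raw scores itself
    callbacks=[lightgbm.record_evaluation(eval_result)],
)
lightgbm.plot_metric(eval_result, metric='custom_loss', title='Loss with custom objective (feval)')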

@github-actions bot commented:

This issue has been automatically closed because it has been awaiting a response for too long. When you have time to work with the maintainers to resolve this issue, please post a new comment and it will be re-opened. If the issue has been locked for editing by the time you return to it, please open a new issue and reference this one. Thank you for taking the time to improve LightGBM!

@github-actions bot commented:

This issue has been automatically locked since there has not been any recent activity since it was closed. To start a new related discussion, open a new issue at https://github.com/microsoft/LightGBM/issues including a reference to this.

github-actions bot locked as resolved and limited conversation to collaborators Aug 19, 2023