Wednesday, April 27, 2022

Cross-Entropy, Negative Log-Likelihood, and All That Jazz

 Численный эксперимент

Чтобы понять разницу между CrossEntropyLoss и NLLLoss (и BCELoss и т. д.), я разработал небольшой численный эксперимент следующим образом.

В двоичной настройке я сначала генерирую случайный вектор (z) размера пять из нормального распределения и вручную создаю вектор метки (y) той же формы с элементами либо нулями, либо единицами. Затем я вычисляю предсказанные вероятности (y_hat) на основе z, используя softmax (строка 8). В строке 13 я применяю формулу отрицательного логарифмического правдоподобия, полученную в предыдущем разделе, для вычисления ожидаемого отрицательного логарифмического значения правдоподобия в этом случае. Используя BCELoss с y_hat в качестве входных данных и BCEWithLogitLoss с z в качестве входных данных, я получаю те же результаты, что и вычисленные выше.

В настройке мультикласса я генерирую z2, y2 и вычисляю yhat2 с помощью функции softmax. На этот раз NLLLoss с логарифмическими вероятностями (log of yhat2) в качестве входных данных и CrossEntropyLoss с необработанными значениями прогноза (z) в качестве входных данных дают те же результаты, вычисленные с использованием формулы, полученной ранее.

(.env) [boris@fedora35server LOSS]$ cat lossNegative1.py

import torch

torch.manual_seed(0)

# Binary setting 

#########################

print(f"{'Setting up binary case':-^80}")

z = torch.randn(5)

yhat = torch.sigmoid(z)

y = torch.Tensor([0, 1, 1, 0, 1])

print(f"{z=}\n{yhat=}\n{y=}\n{'':-^80}")


# First compute the negative log likelihoods using the derived formula

l = -(y * yhat.log() + (1 - y) * (1 - yhat).log())

print(f"{l}")

# Observe that BCELoss and BCEWithLogitsLoss can produce the same results

l_BCELoss_nored = torch.nn.BCELoss(reduction="none")(yhat, y)

l_BCEWithLogitsLoss_nored = torch.nn.BCEWithLogitsLoss(reduction="none")(z, y)

print(f"{l_BCELoss_nored}\n{l_BCEWithLogitsLoss_nored}\n{'':=^80}")

# Multiclass setting 

#################

print(f"{'Setting up multiclass case':-^80}")

z2 = torch.randn(5, 3)

yhat2 = torch.softmax(z2, dim=-1)

y2 = torch.Tensor([0, 2, 1, 1, 0]).long()

print(f"{z2=}\n{yhat2=}\n{y2=}\n{'':-^80}")


# First compute the negative log likelihoods using the derived formulat

l2 = -yhat2.log()[torch.arange(5), y2]  # masking the correct entries

print(f"{l2}")

print(-torch.log_softmax(z2, dim=-1)[torch.arange(5), y2])


l2_NLLLoss_nored = torch.nn.NLLLoss(reduction="none")(yhat2.log(), y2)

l2_CrossEntropyLoss_nored = torch.nn.CrossEntropyLoss(reduction="none")(z2, y2)

print(f"{l2_NLLLoss_nored}\n{l2_CrossEntropyLoss_nored}\n{'':=^80}")

**********

Runtime

***********

(.env) [boris@fedora35server LOSS]$ python lossNegative1.py

-----------------------------Setting up binary case-----------------------------

z=tensor([ 1.5410, -0.2934, -2.1788,  0.5684, -1.0845])

yhat=tensor([0.8236, 0.4272, 0.1017, 0.6384, 0.2527])

y=tensor([0., 1., 1., 0., 1.])

--------------------------------------------------------------------------------

tensor([1.7351, 0.8506, 2.2860, 1.0172, 1.3757])

tensor([1.7351, 0.8506, 2.2860, 1.0172, 1.3757])

tensor([1.7351, 0.8506, 2.2860, 1.0172, 1.3757])

===============================================

---------------------------Setting up multiclass case---------------------------

z2=tensor([[-1.3986,  0.4033,  0.8380],

        [-0.7193, -0.4033, -0.5966],

        [ 0.1820, -0.8567,  1.1006],

        [-1.0712,  0.1227, -0.5663],

        [ 0.3731, -0.8920, -1.5091]])

yhat2=tensor([[0.0609, 0.3691, 0.5700],

        [0.2856, 0.3916, 0.3228],

        [0.2591, 0.0917, 0.6492],

        [0.1679, 0.5540, 0.2781],

        [0.6971, 0.1967, 0.1061]])

y2=tensor([0, 2, 1, 1, 0])

--------------------------------------------------------------------------------

tensor([2.7987, 1.1307, 2.3893, 0.5906, 0.3608])

tensor([2.7987, 1.1307, 2.3893, 0.5906, 0.3608])

tensor([2.7987, 1.1307, 2.3893, 0.5906, 0.3608])

tensor([2.7987, 1.1307, 2.3893, 0.5906, 0.3608])

================================================

References

No comments:

Post a Comment