Cross Entropy Loss Explained

A high level overview of PyTorch’s torch.nn.CrossEntropyLoss
Categories: ml, notes
Author: Seongbin Park
Published: August 6, 2022

Cross entropy loss is a loss function that can be used for multi-class classification using neural networks. Chapter 5 of the fast.ai textbook outlines the use of cross entropy loss for binary classification, so in this post, we will take a look at classification for 3 classes.

Required libraries:
from fastai.vision.all import *
torch.random.manual_seed(42);

Softmax

The softmax function ensures two things:
- each activation is between 0 and 1
- the activations for each item sum to 1.

For multi-class classification, we need one activation per class in the final layer. Each activation indicates the relative confidence that its class is the true label. We can therefore get the predicted probability that each class is the true label by applying the softmax function to each item's row of final-layer activations (one column per class).

Given \(C\) total classes, let \(x_k\) represent the activation for class \(k\), where \(k = 1, \dots, C\). Then, the softmax activation for an arbitrary class \(c\) is equal to

\[\frac{e^{x_c}}{\sum^C_{k=1}e^{x_k}}.\]
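As a quick worked example (activations chosen only for illustration), suppose a single item has activations \(x = (2, 1, 0)\) for \(C = 3\) classes. Rounded to three decimals:

\[\sum^3_{k=1}e^{x_k} = e^{2} + e^{1} + e^{0} \approx 7.389 + 2.718 + 1 = 11.107,\]

\[\left(\frac{e^{2}}{11.107}, \frac{e^{1}}{11.107}, \frac{e^{0}}{11.107}\right) \approx (0.665, 0.245, 0.090).\]

All three values lie between 0 and 1, they sum to 1, and the largest activation receives the largest probability.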

In Python code, this would be

# exponentiate each activation, then divide by the sum of its row (dim=1)
def softmax(x): return torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)

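Note that the code version works on a whole batch at once: x is a tensor with one row per item and one column per class, and the function returns a tensor of the same shape in which each row has been normalized along dim=1.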
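To sanity-check the one-line softmax, here is a minimal sketch (assuming plain `import torch` instead of the fastai star import) that compares it against torch.softmax on a small hand-made tensor. The `stable_softmax` helper is hypothetical and not part of the original post: it shows the common trick of subtracting each row's maximum before exponentiating, which gives the same result mathematically but avoids overflow for large activations.

```python
import torch

def softmax(x):
    # Row-wise softmax: exponentiate, then divide by each row's sum
    return torch.exp(x) / torch.exp(x).sum(dim=1, keepdim=True)

def stable_softmax(x):
    # Hypothetical helper: subtracting each row's max before exp() leaves the
    # result unchanged mathematically but keeps exp() from overflowing
    shifted = x - x.max(dim=1, keepdim=True).values
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=1, keepdim=True)

x = torch.tensor([[2.0, 1.0, 0.0],
                  [0.5, 0.5, 0.5]])
print(softmax(x))  # ≈ [[0.665, 0.245, 0.090], [0.333, 0.333, 0.333]]
print(torch.allclose(softmax(x), torch.softmax(x, dim=1)))         # True
print(torch.allclose(stable_softmax(x), torch.softmax(x, dim=1)))  # True

big = torch.tensor([[1000.0, 0.0]])
print(softmax(big))         # exp(1000) overflows to inf, so the first entry is nan
print(stable_softmax(big))  # ≈ [[1., 0.]]
```

The first row reproduces the worked \((2, 1, 0)\) example; the second shows that equal activations get equal probabilities.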

For demonstration purposes, let’s first create a set of activations using torch.randn, assuming we have 6 objects to classify into 3 classes.

acts = torch.randn((6,3))*2
acts
tensor([[ 3.8538,  2.9746, -0.9948],
        [ 0.8792, -1.5163,  2.1566],
        [ 1.6016,  3.3612,  0.7117],
        [-1.3732,  1.2209,  2.6695],
        [-0.4632,  0.0835, -0.5032],
        [ 1.7197, -0.6195, -0.7914]])

Let’s also set our target labels:

targ = tensor([0,1,0,2,2,0])

To take the softmax of our initial (random) activations, we need to pass acts into torch.softmax:

sm_acts = torch.softmax(acts, dim=1)
sm_acts
tensor([[0.7028, 0.2917, 0.0055],
        [0.2137, 0.0195, 0.7668],
        [0.1385, 0.8046, 0.0569],
        [0.0140, 0.1876, 0.7984],
        [0.2711, 0.4684, 0.2605],
        [0.8492, 0.0819, 0.0689]])

Perfect! We can check that each row adds up to 1 as expected.
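A minimal sketch of that check, assuming a standalone script with plain `import torch` (the seed is reused from the post, but the sketch does not rely on reproducing the exact values shown above):

```python
import torch

torch.manual_seed(42)
acts = torch.randn((6, 3)) * 2          # 6 items, 3 classes, as in the post
sm_acts = torch.softmax(acts, dim=1)

row_sums = sm_acts.sum(dim=1)           # one sum per item
print(torch.allclose(row_sums, torch.ones(6)))      # True
print(bool(((sm_acts > 0) & (sm_acts < 1)).all()))  # True: every entry is in (0, 1)
```

torch.allclose is used instead of an exact equality test because the float32 sums can differ from 1 by a rounding error.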

Log Likelihood

To calculate our loss, for each item of targ, we need to select the appropriate column of sm_acts using tensor indexing:

idx = range(6)
sm_acts[idx, targ]
tensor([0.7028, 0.0195, 0.1385, 0.7984, 0.2605, 0.8492])
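The indexing above picks, for row i, the column given by targ[i]. As a sketch on a small made-up tensor, the same selection can also be written with torch.gather, which takes one column index per row; the probabilities here are illustrative only.

```python
import torch

sm_acts = torch.tensor([[0.7, 0.2, 0.1],
                        [0.1, 0.3, 0.6]])
targ = torch.tensor([0, 2])

# Fancy indexing: element (i, targ[i]) for every row i
picked = sm_acts[torch.arange(len(targ)), targ]

# Same selection with gather: column indices shaped (N, 1), then drop the extra dim
picked_gather = sm_acts.gather(1, targ.unsqueeze(1)).squeeze(1)

print(picked)                              # tensor([0.7000, 0.6000])
print(torch.equal(picked, picked_gather))  # True
```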

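F.nll_loss does the same thing, but flips the sign of each selected value. Despite its name, F.nll_loss does not take a logarithm; it expects its input to already contain log-probabilities, which is why we apply the log ourselves in the next step. PyTorch also defaults to taking the mean of the losses; to get one value per item instead, we pass reduction='none' as a parameter.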

result = -F.nll_loss(sm_acts, targ, reduction='none')
result
tensor([0.7028, 0.0195, 0.1385, 0.7984, 0.2605, 0.8492])
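To see that F.nll_loss only selects and negates (no logarithm involved), here is a small sketch on made-up values comparing it with manual indexing, for both reduction='none' and the default mean reduction:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6]])
targ = torch.tensor([0, 2])

manual = -x[torch.arange(len(targ)), targ]        # negate the selected entries
per_item = F.nll_loss(x, targ, reduction='none')  # tensor([-0.7000, -0.6000])
mean = F.nll_loss(x, targ)                        # default reduction='mean'

print(torch.allclose(per_item, manual))       # True
print(torch.allclose(mean, manual.mean()))    # True
```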

Taking the Logarithm

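We take the (natural) logarithm of result for two reasons:
- it helps prevent underflow and overflow when performing mathematical operations on probabilities (for example, a product of many small probabilities becomes a sum of logs)
- differences between small numbers are amplified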
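Both reasons can be illustrated with a short sketch. The setup (1000 predictions that each give the true class probability 0.01) is hypothetical and chosen only to force float32 underflow:

```python
import torch

# 1000 hypothetical predictions, each assigning probability 0.01 to the true class
probs = torch.full((1000,), 0.01)

# The product 0.01**1000 = 1e-2000 is far below the smallest positive float32,
# so multiplying the probabilities underflows to exactly 0
print(probs.prod())            # tensor(0.)

# Summing log-probabilities stays in a comfortable range
print(torch.log(probs).sum())  # ≈ -4605.17 (= 1000 * ln(0.01))

# Small probabilities that are close on a linear scale are spread apart by the log
print(torch.log(torch.tensor([0.01, 0.001])))  # ≈ [-4.6052, -6.9078]
```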

In our case, result reflects the predicted probability of the correct label, so when the prediction is “good” (closer to 1), we want our loss function to return a small value, and when it is “bad” (closer to 0), a large one. We can achieve this by taking the negative of the log, since \(-\log p\) equals 0 at \(p = 1\) and grows without bound as \(p\) approaches 0:

loss = -torch.log(result)
loss
tensor([0.3527, 3.9384, 1.9770, 0.2251, 1.3451, 0.1635])
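A quick sketch of how \(-\log p\) maps a range of illustrative probabilities to losses: confident correct predictions cost almost nothing, while confident wrong ones are penalized heavily.

```python
import torch

p = torch.tensor([0.99, 0.9, 0.5, 0.1, 0.01])  # probability given to the true class
nll = -torch.log(p)
print(nll)  # ≈ [0.0101, 0.1054, 0.6931, 2.3026, 4.6052]
```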

And there we go! We just computed the cross entropy loss for each item in our example; by default, PyTorch would report the mean of these values as the overall loss.
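The three steps can be collected into a single function. This is a minimal sketch: `cross_entropy_from_scratch` is a hypothetical name, the activations come from a different seed than the post, and the result is checked against PyTorch's F.cross_entropy with reduction='none'.

```python
import torch
import torch.nn.functional as F

def cross_entropy_from_scratch(acts, targ):
    # Hypothetical helper following the steps above
    sm_acts = torch.softmax(acts, dim=1)                # 1. probabilities per class
    picked = sm_acts[torch.arange(len(targ)), targ]     # 2. probability of the true class
    return -torch.log(picked)                           # 3. negative log

torch.manual_seed(0)
acts = torch.randn(6, 3) * 2
targ = torch.tensor([0, 1, 0, 2, 2, 0])

ours = cross_entropy_from_scratch(acts, targ)
ref = F.cross_entropy(acts, targ, reduction='none')
print(torch.allclose(ours, ref))  # True
```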

Using Modules

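We can simplify the code above by using log_softmax followed by nll_loss. F.log_softmax computes the logarithm of the softmax in a single step, which is exactly the kind of input F.nll_loss expects: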

lsm_acts = F.log_softmax(acts, dim=1)
loss = F.nll_loss(lsm_acts, targ, reduction='none')
loss
tensor([0.3527, 3.9384, 1.9770, 0.2251, 1.3451, 0.1635])
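Beyond brevity, computing the log directly can matter numerically. In this sketch with a deliberately extreme, made-up activation, taking the log of an already-computed softmax hits log(0), while F.log_softmax stays finite:

```python
import torch
import torch.nn.functional as F

acts = torch.tensor([[0.0, 200.0]])   # one extremely confident prediction

# exp(-200) is too small for float32, so softmax rounds class 0 to exactly 0
naive = torch.log(torch.softmax(acts, dim=1))
stable = F.log_softmax(acts, dim=1)   # computed directly in log space

print(naive)   # first entry is -inf
print(stable)  # ≈ [[-200., 0.]]

# With target class 0, only the log_softmax route gives a finite loss
print(F.nll_loss(stable, torch.tensor([0])))  # ≈ 200
```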

In practice, this is exactly what nn.CrossEntropyLoss does: it combines log_softmax and nll_loss in one step, so it takes the raw activations directly:

nn.CrossEntropyLoss(reduction='none')(acts, targ)
tensor([0.3527, 3.9384, 1.9770, 0.2251, 1.3451, 0.1635])

The output loss tensors for all three approaches are equivalent as expected!
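During training you would usually keep the default reduction, which averages the per-item losses. A short sketch (with its own random activations) of how the reduction options and the functional form F.cross_entropy relate:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
acts = torch.randn(6, 3) * 2
targ = torch.tensor([0, 1, 0, 2, 2, 0])

per_item = nn.CrossEntropyLoss(reduction='none')(acts, targ)
mean_loss = nn.CrossEntropyLoss()(acts, targ)            # default reduction='mean'
sum_loss = F.cross_entropy(acts, targ, reduction='sum')  # functional form

print(torch.allclose(mean_loss, per_item.mean()))  # True
print(torch.allclose(sum_loss, per_item.sum()))    # True
```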