I am playing with this repo (https://github.com/SnailWalkerYC/LeNet-5_Speed_Up) and trying to learn NN details. The repo implements LeNet-5 in C and CUDA. I am focusing on the CPU part for now, whose code lives in seq/. One particular place where I am getting lost is this function in seq/lenet.c:
static inline void softmax(double input[OUTPUT], double loss[OUTPUT], int label, int count){
    double inner = 0;
    for (int i = 0; i < count; ++i){
        double res = 0;
        for (int j = 0; j < count; ++j){
            res += exp(input[j] - input[i]);
        }
        loss[i] = 1. / res;
        inner -= loss[i] * loss[i];
    }
    inner += loss[label];
    for (int i = 0; i < count; ++i){
        loss[i] *= (i == label) - loss[i] - inner;
    }
}
Because there are no comments, I spent some time understanding this function. I finally figured out that it is computing the derivatives of an MSE loss function with respect to the input of the softmax layer.
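For reference, here is the working I used to convince myself of that (my own derivation, so double-check it). Write $z$ for the input, $p_i$ for the softmax output, and $y_i = [i = \mathrm{label}]$ for the one-hot target. The inner loop computes $\mathtt{res} = \sum_j e^{z_j - z_i}$, so

$$\mathtt{loss[i]} = \frac{1}{\mathtt{res}} = \frac{e^{z_i}}{\sum_j e^{z_j}} = p_i, \qquad \mathtt{inner} = p_{\mathrm{label}} - \sum_k p_k^2 .$$

For the MSE loss $L = \tfrac{1}{2}\sum_k (y_k - p_k)^2$, using $\partial p_k / \partial z_i = p_k(\delta_{ki} - p_i)$,

$$\frac{\partial L}{\partial z_i} = \sum_k (p_k - y_k)\, p_k (\delta_{ki} - p_i) = p_i\Big(p_i - y_i - \sum_k p_k^2 + p_{\mathrm{label}}\Big),$$

so the last loop, loss[i] *= (i == label) - loss[i] - inner, stores exactly $-\partial L/\partial z_i$: the function leaves the negative of the MSE gradient in loss[].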
Then I wanted to use a cross-entropy loss together with the softmax, so I came up with the following function to replace the one above:
static inline void softmax(double input[OUTPUT], double loss[OUTPUT], int label, int count)
{
    double inner = 0;
    double max_input = -INFINITY;
    // Find the maximum input value to prevent numerical instability
    for (int i = 0; i < count; ++i)
    {
        if (input[i] > max_input)
            max_input = input[i];
    }
    // Compute softmax and cross-entropy loss
    double sum_exp = 0;
    for (int i = 0; i < count; ++i)
    {
        double exp_val = exp(input[i] - max_input);
        sum_exp += exp_val;
        loss[i] = exp_val;
    }
    double softmax_output[OUTPUT];
    for (int i = 0; i < count; ++i)
    {
        loss[i] /= sum_exp;
        softmax_output[i] = loss[i];
    }
    // Compute cross-entropy loss and derivatives
    inner = -log(softmax_output[label]);
    for (int i = 0; i < count; ++i)
    {
        loss[i] = softmax_output[i] - (i == label);
    }
}
However, with my version of the softmax() function, the MNIST recognition didn't work, while the original version achieved >96% accuracy. What's wrong with my code for the cross-entropy loss?
Ok, I'll answer my own question.
I managed to make the cross-entropy loss work with the softmax. There are two places that need to be tweaked:
1) The sign of the derivative. Instead of
loss[i] = softmax_output[i] - (i == label);
it has to be
loss[i] = (i == label) - softmax_output[i];
This is because the way this CNN updates the weights and biases is counterintuitive: the deltas accumulated from loss[] are added to the parameters rather than subtracted, so loss[] has to hold the negative gradient (which is exactly what the original MSE version stores). So either use the opposite form of the loss derivative, or change the weight update so that it subtracts the accumulated deltas instead of adding them.
2) The learning rate. The original code uses 0.5, which turned out to be too aggressive with cross-entropy. Most places I've seen use 0.1.
With 1) and 2) above, the model now predicts at a >96% level consistently.
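For completeness, here is a sketch of the fixed function: basically the version from the question, slightly tidied, with only the final sign flipped. Treat it as a sketch rather than a tested drop-in; it still assumes the OUTPUT macro and <math.h> from lenet.c.

static inline void softmax(double input[OUTPUT], double loss[OUTPUT], int label, int count)
{
    // Find the maximum input value to keep exp() numerically stable
    double max_input = -INFINITY;
    for (int i = 0; i < count; ++i)
    {
        if (input[i] > max_input)
            max_input = input[i];
    }
    // loss[i] temporarily holds exp(input[i] - max_input)
    double sum_exp = 0;
    for (int i = 0; i < count; ++i)
    {
        loss[i] = exp(input[i] - max_input);
        sum_exp += loss[i];
    }
    // Store the NEGATIVE cross-entropy gradient, because this CNN's
    // weight update adds the accumulated deltas rather than subtracting them
    for (int i = 0; i < count; ++i)
    {
        double p = loss[i] / sum_exp;   // softmax probability
        loss[i] = (i == label) - p;     // -dL/dz_i for cross-entropy
    }
}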
Well, I have learned quite a lot during this process.