scikeras.wrappers.KerasClassifier returning ValueError: Could not interpret metric identifier: loss

I was looking into KerasClassifier because I would like to plug it into a scikit-learn pipeline, but I'm getting the aforementioned ValueError.

The following code should be able to reproduce the error I'm getting:

from sklearn.model_selection import KFold, cross_val_score
from sklearn.pipeline import Pipeline  # required for the Pipeline used below
from sklearn.preprocessing import StandardScaler
from scikeras.wrappers import KerasClassifier
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.datasets import load_iris
import numpy as np

data = load_iris()
X = data.data
y = data.target

def create_model():
    model = Sequential()
    model.add(Dense(8, input_dim=4, activation='relu'))
    model.add(Dense(3, activation='softmax'))
    model.compile(loss='sparse_categorical_crossentropy',
                  optimizer='adam',
                  metrics=['accuracy'])
    return model

clf = KerasClassifier(build_fn=create_model,
                      epochs=100,
                      batch_size=10,
                      verbose=1)

pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('clf', clf)
])

kf = KFold(n_splits=5, shuffle=True, random_state=42)
results = cross_val_score(pipeline, X, y, cv=kf)
print("Cross-Validation Accuracy:", np.mean(results))

The model appears to compile, and all of the training epochs run. However, afterwards I get the error:

ValueError: Could not interpret metric identifier: loss

The versions for the tensorflow and scikeras libraries are:

scikeras==0.12.0
tensorflow==2.15.0
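
The scikit-learn version is not listed above, and the rest of this thread suggests it matters. A minimal sketch for reporting all three installed versions at once, using only the standard library `importlib.metadata` (Python 3.8+); the helper name `installed_versions` is made up for illustration:

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Return {distribution name: installed version, or None if not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # The distribution is not installed in the current environment.
            found[name] = None
    return found


if __name__ == "__main__":
    # Use pip distribution names ("scikit-learn", not "sklearn").
    for name, ver in installed_versions(["scikeras", "tensorflow", "scikit-learn"]).items():
        print(f"{name}=={ver}" if ver else f"{name}: not installed")
```

Printing in `name==version` form makes the output easy to paste straight into a question like this one.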

EDIT: After experimenting with different library versions, I found that the following combination runs the code successfully, so the issue seems to have been caused by the scikit-learn version:

scikeras==0.12.0
tensorflow==2.15.0
scikit-learn==1.4.1
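
To check that an environment matches the combination reported above before running the pipeline, a small guard like the following could be used. It is only a sketch: `release_tuple`, `pin_mismatches` and `REPORTED_WORKING` are hypothetical names, and the pins are just the versions reported in this edit, not an official compatibility matrix.

```python
def release_tuple(ver):
    """Leading numeric release components, e.g. '1.4.1.post1' -> (1, 4, 1)."""
    parts = []
    for piece in ver.split("."):
        if not piece.isdigit():
            break  # stop at suffixes such as 'post1' or 'rc1'
        parts.append(int(piece))
    return tuple(parts)


def pin_mismatches(installed, pins):
    """Return {name: (installed, pinned)} for every package that does not match.

    A pin matches when the installed release starts with the pinned numeric
    components, so '2.15' accepts '2.15.0' and '1.4.1' accepts '1.4.1.post1'.
    """
    bad = {}
    for name, pinned in pins.items():
        have = installed.get(name)
        want = release_tuple(pinned)
        if have is None or release_tuple(have)[: len(want)] != want:
            bad[name] = (have, pinned)
    return bad


# Combination reported as working in the EDIT above (assumption: exact releases).
REPORTED_WORKING = {"scikeras": "0.12.0", "tensorflow": "2.15.0", "scikit-learn": "1.4.1"}
```

Comparing numeric components rather than raw string prefixes avoids treating `1.4.10` as a match for `1.4.1`.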

There are 2 answers

Answer by shadow

This is just a problem with the tensorflow version. It can be solved with tensorflow==2.15.0. It has nothing to do with the scikit-learn, scikeras, or Python versions.

Answer by Datagniel

Downgrading tensorflow to version 2.15 did the trick.

tensorflow==2.15
scikit-learn==1.4.1.post1
scikeras==0.12