Dtype error when trying to calculate SHAP values from a PyTorch DL model


I am getting the error "mat1 and mat2 must have the same dtype" when I try to calculate SHAP values or run LIME with a PyTorch model. The same code works for other machine learning algorithms from the sklearn package.

My goal is to extract information from the model to explain the predictions made by the deep learning algorithm, so I would also appreciate any suggestions on that.

My guess is that the problem is the dtype PyTorch uses for the weights. My data are float32, so I tried changing the model with nn.Linear(n_feat, n_feat).float(), but that doesn't work.

This is an example dataset (data), and here is my code for the model (which works). Below that I paste the code for the SHAP values and for LIME, both of which raise the same error.

import torch
import torch.nn as nn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

var = 'target'
data = pd.read_csv('data800.csv', index_col=0)

train_dataset = data.sample(frac=0.8, random_state=1)
test_dataset = data.drop(train_dataset.index)

train_labels = train_dataset.pop(var)
test_labels = test_dataset.pop(var)

#Flatten distribution by replacing each value with its percentile
train_dataset_transformed = train_dataset.copy()
test_dataset_transformed = test_dataset.copy()
for feature in train_dataset.columns:
    #Percentiles estimated from train data
    bin_res = 0.2
    eval_percentiles = np.arange(bin_res, 100, bin_res)
    percentiles = [
        np.percentile(train_dataset[feature], p)
        for p in eval_percentiles
    ]

    #Apply to both train and test data
    train_dataset_transformed[feature] = pd.cut(
        train_dataset[feature],
        bins=[-np.inf] + percentiles + [np.inf],
        labels=False
    ).astype(np.float32)
    
    test_dataset_transformed[feature] = pd.cut(
        test_dataset[feature],
        bins=[-np.inf] + percentiles + [np.inf],
        labels=False
    ).astype(np.float32)

n_feat = train_dataset.shape[1]

model = nn.Sequential(
    nn.Linear(n_feat, n_feat), nn.ReLU(), nn.BatchNorm1d(n_feat),                   
    nn.Linear(n_feat, n_feat // 2), nn.ReLU(), nn.BatchNorm1d(n_feat // 2),                   
    # nn.Linear(n_feat // 2, n_feat // 2), nn.ReLU(),  nn.BatchNorm1d(n_feat // 2),
    nn.Linear(n_feat // 2, n_feat // 4), nn.ReLU(),  nn.BatchNorm1d(n_feat // 4),
    # nn.Linear(n_feat // 4, n_feat // 4), nn.ReLU(),  nn.BatchNorm1d(n_feat // 4),
    nn.Linear(n_feat // 4, 1)
)

optim = torch.optim.Adam(model.parameters(), 0.01)

#Scale
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler().fit(train_dataset_transformed)

X_train = scaler.transform(train_dataset_transformed)
X_test = scaler.transform(test_dataset_transformed)

#Convert to tensors
X_train = torch.tensor(X_train).float()
y_train = torch.tensor(train_labels.values).float()

X_test = torch.tensor(X_test).float()
y_test = torch.tensor(test_labels.values).float()

torch.manual_seed(0)
for epoch in range(1500):
    yhat = model(X_train)

    loss = nn.MSELoss()(yhat.ravel(), y_train)
    optim.zero_grad()
    loss.backward()
    optim.step()

    with torch.no_grad():
        yhatt = model(X_test)
        score = np.corrcoef(y_test, yhatt.ravel())
        if epoch % 100 == 0:
            print('epoch', epoch, '| loss:', loss.item(), '| R:', score[0, 1])

yhat = model(X_test)
yhat = yhat.detach().numpy()
plt.scatter(test_labels, yhat)
ax_lims = plt.gca().axis()
plt.plot([0, 100], [0, 100], 'k:', label='y=x')
plt.gca().axis(ax_lims)
plt.xlabel('True Values')
plt.ylabel('Predictions')
plt.legend()

SHAP:

import shap

def model2(x):
    return model(torch.tensor(x)).detach().numpy()

explainer = shap.Explainer(model2, X_test.detach().numpy())
shap_values = explainer(X_test.detach().numpy(), max_evals=10000)

LIME:

from lime import lime_tabular

features = train_dataset.columns

explainer_lime = lime_tabular.LimeTabularExplainer(X_train.detach().numpy(), feature_names=features, verbose=True, mode='regression')

#test vector
i = 10
#top features
k = 10

def model2(x):
    return model(torch.tensor(x)).detach().numpy()

exp_lime = explainer_lime.explain_instance(X_test[i].detach().numpy(), model2, num_features=k)
 
exp_lime.show_in_notebook()

1 Answer

Answered by Anastasiia Goi

You might consider looking at the captum library, which is designed specifically for PyTorch model interpretability. It offers a range of tools similar to SHAP and LIME but might be more straightforward to integrate with PyTorch models.
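For instance, here is a minimal sketch with captum's IntegratedGradients applied to the model from your post (IntegratedGradients is just one of captum's attribution methods; picking it here, and calling model.eval() first, are my assumptions, not something from your code):

from captum.attr import IntegratedGradients

model.eval()  # put BatchNorm layers in inference mode before attribution
ig = IntegratedGradients(model)

# X_test is already a float32 tensor, so no dtype conversion is needed here
attributions = ig.attribute(X_test)  # single-output regression, so no target argument
print(attributions.shape)  # one attribution score per feature per test sample

Since the inputs never leave PyTorch, captum sidesteps the NumPy round-trip where the dtype mismatch arises.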

If you still face dtype mismatch issues, you might want to be more explicit about setting dtypes throughout your code, not just when creating initial tensors.
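Concretely, SHAP and LIME hand your wrapper NumPy arrays, which default to float64, while the model's weights are float32; that mismatch is the usual cause of "mat1 and mat2 must have the same dtype". A minimal sketch of an explicit cast in the wrapper:

def model2(x):
    # x arrives as a float64 NumPy array from SHAP/LIME; cast it to float32
    # so it matches the dtype of the model's weights before the forward pass
    x = torch.tensor(x, dtype=torch.float32)
    return model(x).detach().numpy()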