I am running a regression task on a dataset composed of both authentic and augmented samples. The augmented samples are generated by jittering the authentic ones. I would like to select the best-performing model using cross-validation in scikit-learn.
For this I would like to:
- Train the model on a set comprising both the authentic and the augmented samples. The fitting procedure should not take the origin of a sample into account, i.e. it should be equivalent to running `estimator.fit(..., sample_weight=[1, 1, ..., 1])`.
- Score the models based on their performance on the authentic samples only. For this, I thought of setting the weight of augmented (resp. authentic) samples to 0 (resp. 1).
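A quick sanity check of the 0/1-weight idea, on made-up data (the arrays below are random placeholders, not the question's dataset): with 0/1 weights, the weighted R2 and MSE equal the plain metrics computed on the weight-1 subset alone, so zero-weighting the augmented samples amounts to scoring on the authentic ones only.

```python
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score

rng = np.random.default_rng(0)
y_true = rng.random(50)
y_pred = y_true + rng.normal(scale=0.1, size=50)

# Hypothetical 0/1 mask: 1 marks an authentic sample, 0 an augmented one
weight = rng.integers(0, 2, size=50)
weight[0] = 1  # make sure at least one sample is authentic
authentic = weight == 1

# Weighted metrics computed over all samples...
r2_w = r2_score(y_true, y_pred, sample_weight=weight)
mse_w = mean_squared_error(y_true, y_pred, sample_weight=weight)

# ...match the unweighted metrics computed on the authentic subset only
r2_sub = r2_score(y_true[authentic], y_pred[authentic])
mse_sub = mean_squared_error(y_true[authentic], y_pred[authentic])

print(np.isclose(r2_w, r2_sub), np.isclose(mse_w, mse_sub))
```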
How can I achieve this with sklearn's `cross_validate`?
I tried the following:
```python
from sklearn import model_selection
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import r2_score, mean_squared_error, make_scorer
import numpy as np
n_smpl, n_feats = 100, 5
arr_source = np.random.random((n_smpl, n_feats))
arr_target = np.random.random((n_smpl, n_feats))
arr_weight = np.random.randint(0, 2, n_smpl) # 0 for augmented, 1 for authentic
model = RandomForestRegressor()
kfold_splitter = model_selection.KFold(n_splits=5, random_state=7, shuffle=True)
my_scorers = {
"r2_weighted": make_scorer(r2_score, sample_weight=arr_weight),
"mse_weighted": make_scorer(mean_squared_error, greater_is_better=False, sample_weight=arr_weight)
}
cv_results = model_selection.cross_validate(model, arr_source, arr_target, scoring = my_scorers, cv=kfold_splitter)
```
But this raises `ValueError: Found input variables with inconsistent numbers of samples: [20, 20, 100]`. I understand that this happens because `cross_validate` splits `X` and `y` into folds but cannot split the weights bound to the scorers: each scorer receives 20 true values and 20 predictions, but all 100 weights.
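For comparison, a sketch of the same evaluation done by hand on the same kind of random data: the weights are indexed with each fold's test indices, exactly like `X` and `y`, which is the step `cross_validate` cannot perform for weights fixed inside `make_scorer`. The seeds are arbitrary, and the loop assumes every test fold contains at least one authentic sample.

```python
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import KFold

rng = np.random.default_rng(0)
n_smpl, n_feats = 100, 5
arr_source = rng.random((n_smpl, n_feats))
arr_target = rng.random((n_smpl, n_feats))
arr_weight = rng.integers(0, 2, n_smpl)  # 0 for augmented, 1 for authentic

kfold_splitter = KFold(n_splits=5, random_state=7, shuffle=True)
r2_scores, mse_scores = [], []
for train_idx, test_idx in kfold_splitter.split(arr_source):
    # Fit on all training samples, ignoring whether they are augmented
    model = RandomForestRegressor(random_state=0)
    model.fit(arr_source[train_idx], arr_target[train_idx])
    pred = model.predict(arr_source[test_idx])

    # Slice the weights with the same test indices as X and y
    # (assumes the fold holds at least one authentic sample)
    w_test = arr_weight[test_idx]
    r2_scores.append(r2_score(arr_target[test_idx], pred, sample_weight=w_test))
    mse_scores.append(
        mean_squared_error(arr_target[test_idx], pred, sample_weight=w_test)
    )

print(np.mean(r2_scores), np.mean(mse_scores))
```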
Is there any way to get this to run through `cross_validate`? Or any other method?
I found what I was looking for in the metadata routing feature, which makes it possible to pass a parameter such as `sample_weight` to the scorer while withholding it from the estimator.
The steps to use it are:
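- Enable metadata routing: `sklearn.set_config(enable_metadata_routing=True)`
- Tell the estimator not to use the weights when fitting: `RandomForestRegressor().set_fit_request(sample_weight=False)`
- Tell the scorer to use them: `make_scorer(r2_score).set_score_request(sample_weight=True)`
- Pass the weights through `params`: `model_selection.cross_validate(..., params={"sample_weight": arr_weight})`

The full code looks like: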