How to do a stratified split of a multi-label, melted dataframe by unique IDs instead of rows


Note: I am not asking how to pivot a dataframe. I'm asking how to split a dataset into train, validation and test sets when the dataset is in a multi-label, melted format and also has to be split by unique customer ids, with stratification so that the label proportions are similar in each split.

The dataset that I have

I have a df with a shape of around (200000, 700). It is a multi-label dataset (not multi-class, as the labels are not mutually exclusive) with four labels, and it is in a melted format.

To illustrate my issue, here's a simplified example of the df:

customer_id  month_year  label    target
customer_1   jan_2022    label_1  1
customer_2   jan_2022    label_1  1
customer_2   jan_2022    label_2  0
customer_2   jan_2022    label_3  1
customer_2   jan_2022    label_4  1
customer_3   jan_2022    label_1  1
customer_3   feb_2022    label_1  0
customer_3   mar_2022    label_1  1
customer_3   apr_2022    label_1  1
customer_3   feb_2022    label_3  0

There are 4 distinct labels in the df: label_1, label_2, label_3, label_4, with the 'target' column indicating 0 or 1 for that label. It essentially represents the same information as a df with four separate label columns, but in a melted format.

The df captures monthly snapshots of up to four target labels for each customer. This means that not every customer has all four labels, and each label might have more than one monthly snapshot, e.g. for Jan 2022, Feb 2022, Mar 2022, and so on.

This variability in both labels and frequency per customer means that:

  • a customer like customer_1 might only have a jan_2022 snapshot of label_1.
  • another customer like customer_2 might have jan_2022 snapshots of label_1, label_2, label_3 & label_4.
  • other customers like customer_3 might have four months of snapshots for label_1 (jan_2022, feb_2022, mar_2022, apr_2022), as well as a feb_2022 snapshot of label_3.
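To make the variability described above concrete, here is a minimal sketch that rebuilds the simplified example table in pandas and counts the snapshots per customer and label. It only uses the ten example rows shown above, not the real data, and the name `snapshot_counts` is just an illustrative choice.

```python
import pandas as pd

# Toy reconstruction of the simplified example table (not the real data).
rows = [
    ("customer_1", "jan_2022", "label_1", 1),
    ("customer_2", "jan_2022", "label_1", 1),
    ("customer_2", "jan_2022", "label_2", 0),
    ("customer_2", "jan_2022", "label_3", 1),
    ("customer_2", "jan_2022", "label_4", 1),
    ("customer_3", "jan_2022", "label_1", 1),
    ("customer_3", "feb_2022", "label_1", 0),
    ("customer_3", "mar_2022", "label_1", 1),
    ("customer_3", "apr_2022", "label_1", 1),
    ("customer_3", "feb_2022", "label_3", 0),
]
df = pd.DataFrame(rows, columns=["customer_id", "month_year", "label", "target"])

# Number of snapshots per customer and label; combinations that never occur show as 0.
snapshot_counts = pd.crosstab(df["customer_id"], df["label"])
print(snapshot_counts)
```

In this toy table, customer_3 contributes four label_1 rows while customer_1 contributes only one, which is why splitting by rows and splitting by customers behave differently.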

Why the df is in a melted format

If we were to unmelt the df, it would look like this:

customer_id  month_year  label_1  label_2  label_3  label_4
customer_3   jan_2022    1
customer_3   feb_2022    0                 0
customer_3   mar_2022    1
customer_3   apr_2022    1

The absence of certain labels for certain months (no 1/0 value at all) is exactly why the df has to stay in a melted format; otherwise the model cannot be trained on it.
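As an illustration of those gaps, the sketch below unmelts only the customer_3 rows from the example and measures how much of each label column would be missing. The variable names (`melted`, `wide`, `missing_share`) are illustrative, and the fixed list of four label columns is an assumption taken from the example.

```python
import pandas as pd

# customer_3's rows from the simplified example, still in melted form.
melted = pd.DataFrame(
    {
        "customer_id": ["customer_3"] * 5,
        "month_year": ["jan_2022", "feb_2022", "mar_2022", "apr_2022", "feb_2022"],
        "label": ["label_1", "label_1", "label_1", "label_1", "label_3"],
        "target": [1, 0, 1, 1, 0],
    }
)

# Unmelt: one row per (customer, month), one column per label.
wide = (
    melted.set_index(["customer_id", "month_year", "label"])["target"]
    .unstack("label")
    # Labels that never occur for this customer become all-NaN columns.
    .reindex(columns=["label_1", "label_2", "label_3", "label_4"])
)

# Fraction of missing values per label column.
missing_share = wide.isna().mean()
print(wide)
print(missing_share)
```

Here label_2 and label_4 are entirely missing and label_3 is missing for three of the four months, which is the sparsity the melted format avoids.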

The problem

My problem is two-fold:

  1. I first need to split this dataset into training, validation & test sets in such a way that all of a customer's records end up in exactly one of the three sets, with no overlap.

  2. The proportions of the multi-labels across each split should be consistent. I understand that it might not be feasible for each split to be completely proportionate, but it's important to attempt to have stratification.

What I've tried so far:

I've managed to solve the first part of the problem by using sklearn's train_test_split to split the list of unique customer ids into the train, val & test lists and then filtering the df by those lists:

from sklearn.model_selection import train_test_split

train_size = 0.7
val_size = 0.15
test_size = 0.15

# 1. Get the unique customer ids from the df as a list
unique_cust_list = df["customer_id"].unique().tolist()

# 2. Split the unique ids into train, val & test lists
train_id_list, remainder = train_test_split(unique_cust_list, test_size=val_size + test_size)
val_id_list, test_id_list = train_test_split(remainder, test_size=test_size / (val_size + test_size))

# 3. Filter the df using the 3 lists so each customer lands in exactly one split
train_df = df[df["customer_id"].isin(train_id_list)]
val_df = df[df["customer_id"].isin(val_id_list)]
test_df = df[df["customer_id"].isin(test_id_list)]

I'm stuck with the second problem of ensuring proportionality in each split.

I've looked into sklearn's StratifiedShuffleSplit and StratifiedGroupKFold, but neither of them natively supports splitting a multi-label melted dataset by unique id.

In fact, I cannot seem to find any solution online that caters to the unusual nature of my dataset. I might possibly have to implement a custom stratification solution, but I have no idea how to begin with this.
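For reference, one possible shape of such a custom solution is a greedy heuristic: count each customer's rows per (label, target) cell, then assign whole customers one at a time to whichever split is proportionally furthest from its desired share in those cells. The sketch below is only an illustration under assumptions: the function name `stratified_group_split`, the `fractions` and `seed` parameters, and the synthetic `demo` data are all hypothetical, and the heuristic gives approximate balance, not a guarantee.

```python
import numpy as np
import pandas as pd

def stratified_group_split(df, fractions, seed=0):
    """Greedy heuristic (hypothetical helper): assign whole customers to splits
    while tracking row counts per (label, target) cell."""
    # One row per customer, one column per (label, target) cell, values = row counts.
    counts = pd.crosstab(df["customer_id"], [df["label"], df["target"]])
    mat = counts.to_numpy()
    names = list(fractions)
    frac = np.array([fractions[n] for n in names], dtype=float)
    # Desired number of rows of every cell in every split.
    desired = np.outer(frac, mat.sum(axis=0))
    filled = np.zeros_like(desired)
    # Shuffle, then process the customers with the most rows first.
    rng = np.random.default_rng(seed)
    order = rng.permutation(len(mat))
    sizes = mat.sum(axis=1)
    order = order[np.argsort(-sizes[order], kind="stable")]
    assignment = {}
    for i in order:
        row = mat[i]
        # Proportional shortfall of each split, weighted by this customer's cells.
        need = ((desired - filled) / desired) @ row
        s = int(np.argmax(need))
        filled[s] += row
        assignment[counts.index[i]] = names[s]
    # Return a split name for every row of df.
    return df["customer_id"].map(assignment)

# Synthetic data standing in for the real df (assumption: roughly 10% positives).
rng = np.random.default_rng(42)
n = 5000
demo = pd.DataFrame(
    {
        "customer_id": rng.integers(0, 600, n),
        "label": rng.choice(["label_1", "label_2", "label_3", "label_4"], n),
        "target": rng.binomial(1, 0.1, n),
    }
)
demo["split"] = stratified_group_split(demo, {"train": 0.7, "val": 0.15, "test": 0.15})

# Positive rate per label in each split, to eyeball how close the splits are.
print(demo.groupby(["split", "label"])["target"].mean().unstack("label"))
```

Because customers are assigned as whole units, the no-overlap requirement holds by construction; how closely the label proportions match depends on how unevenly rows are spread across customers.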

End Result

The ideal result is:

  1. Each customer's records should appear in exactly one of the train, validation or test sets.

  2. The label distributions among the sets should be fairly consistent. For example, assuming the train dataset has these label distributions:

Labels   0    1
label_1  91%  9%
label_2  88%  12%
label_3  95%  5%
label_4  89%  11%

The validation and test datasets should follow similar distributions.
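A table like the one above can be computed per split and compared side by side. The following is a minimal sketch, assuming each split is a melted df with `label` and `target` columns; the helper name `label_distribution` and the tiny `train_df` / `val_df` frames are toy stand-ins, not real data.

```python
import pandas as pd

def label_distribution(split_df):
    """Share of target 0 and 1 within each label (rows sum to 1)."""
    dist = pd.crosstab(split_df["label"], split_df["target"], normalize="index")
    # Make sure both target columns exist even if a split has no positives.
    return dist.reindex(columns=[0, 1], fill_value=0.0)

# Toy stand-ins for two of the splits.
train_df = pd.DataFrame(
    {"label": ["label_1"] * 4 + ["label_2"] * 2, "target": [0, 0, 0, 1, 0, 1]}
)
val_df = pd.DataFrame(
    {"label": ["label_1"] * 2 + ["label_2"] * 2, "target": [0, 1, 0, 0]}
)

# Positive rate per label, one column per split.
positive_rate = pd.concat(
    {name: label_distribution(part)[1] for name, part in [("train", train_df), ("val", val_df)]},
    axis=1,
)
# Largest difference in positive rate between splits, per label.
gap = positive_rate.max(axis=1) - positive_rate.min(axis=1)
print(positive_rate)
print(gap)
```

A small `gap` for every label would indicate that the splits follow similar distributions.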
