Sklearn Gaussian Mixture predict_proba: difficulties to understand resulting probabilities


We have two two-dimensional, well-separated clusters, as seen in the figure below:

[Figure: scatter plot of the two well-separated clusters]

Running sklearn's GaussianMixture on that dataset:

import numpy as np
from sklearn.mixture import GaussianMixture

gm = GaussianMixture(n_components=2, random_state=42, n_init=100, init_params='random_from_data').fit(simpleblobs)

it learns two Gaussians whose means correspond to the 'centroids' of our dataset:

gm.means_

array([[11.70833308, 13.83333333],
       [29.52666641, 36.19      ]])

So far it seems like it learned what it is supposed to learn.
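
The other fitted parameters can be inspected the same way; their exact values depend on simpleblobs, so no output is shown here:

gm.weights_        # mixing proportions of the two components
gm.covariances_    # one 2x2 covariance matrix per component
gm.converged_      # whether EM converged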

However, things get interesting if we add an object that lies roughly between the two clusters:

middle_guy = (25,25)

Naively, I would expect the probability for this middle guy (which is not part of the training data; it is basically a new sample) to be roughly [0.5, 0.5], indicating that it is about equally likely to be assigned to either cluster, or that both probabilities are equally low.

However, after calling what I believe is the intended function, we get:

gm.predict_proba([middle_guy])

array([[1.00000000e+00, 8.36394939e-66]])

This indicates that the 'middle' object is assigned to the bottom-left cluster with a probability of 1 and to the top-right cluster with a probability near zero.

My question at this point is: why do the results differ from what I would have expected? What have I overlooked or misinterpreted?

Thank you for your help!

1 Answer

addi.howe:

A couple of things are going on here. First, the fit is not "perfect": even if your data came from two isotropic Gaussians (identity matrices for the covariances), the estimated means and covariances after fitting would differ from the ground truth, especially given a limited number of points.

That matters a lot for the issue you see, because it means the covariance structure of each Gaussian component is different. Remember that the density of a Gaussian decays exponentially as you move away from its center. If your point were exactly the midpoint between the centers, and if your Gaussians were perfectly isotropic with equal covariances, then the density of each component would be the same, and you would get the 50/50 split you expected.

In this case, however, the difference between the covariance matrices of the two Gaussians can make one component many times more likely than the other, especially because your point is so far away from both centers. To sacrifice some rigor for intuition: at that distance, the density of each component has decayed exponentially, but at different rates (due to the different covariances). Both components may be extremely unlikely to have generated the point (in the sense that the score, i.e. the log-likelihood, is low), but after normalization one is many times more likely than the other, as the small sketch below illustrates.
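
To make that concrete, here is a small sketch (the centers, covariances, weights and the test point are made up for illustration, not taken from your data): both component densities at a far-away point are astronomically small, but one is many orders of magnitude larger than the other, so the normalized responsibilities that predict_proba returns end up at essentially [1, 0].

import numpy as np
from scipy.stats import multivariate_normal

# Two hypothetical components with equal weights but different covariances
# (made-up numbers for illustration only).
means = [np.array([0.0, 0.0]), np.array([20.0, 0.0])]
covs = [np.eye(2), 0.5 * np.eye(2)]    # the second component is "tighter"
weights = np.array([0.5, 0.5])

x = np.array([10.0, 0.0])              # a point midway between the two centers

# Per-component log densities: both are very negative, i.e. both densities
# are astronomically small at this distance from the centers...
log_dens = np.array([multivariate_normal(m, c).logpdf(x) for m, c in zip(means, covs)])
print(log_dens)        # roughly [-51.8, -101.1]

# ...but Bayes' rule only cares about their ratio, so after normalization
# the (relatively) larger one takes essentially all of the posterior mass.
log_post = np.log(weights) + log_dens
post = np.exp(log_post - log_post.max())
post /= post.sum()
print(post)            # essentially [1, 0], even though both densities are ~0

Note that which component "wins" depends on the covariances: here the broader Gaussian decays more slowly, so it dominates at the midpoint even though the Euclidean distance to both centers is the same.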

Note that if you mess with the fitted GMM and set the means and covariances to be the "true" centers and the identity matrix, respectively, you will be able to get that 50/50 value you expect. You can see this if you run the script below.

import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs
from sklearn.mixture import GaussianMixture

# 20 points drawn around two well-separated centers.
data, labels = make_blobs(
    20,
    n_features=2,
    centers=[[0, 0], [20, 0]],
    random_state=12345
)

gm = GaussianMixture(
    n_components=2,
    random_state=54321,
    n_init=100,
    init_params='random_from_data'
).fit(data)


print("GM means:\n", gm.means_)
print("GM covs:\n", gm.covariances_)
midpt = [10, 0]

midpt_proba = gm.predict_proba([midpt])[0]
print(f"proba({midpt}): {midpt_proba}")
midpt_loglike = gm.score_samples([midpt])[0]
print(f"score({midpt}): {midpt_loglike}")

# Overwrite the fitted parameters with the "true" centers, equal weights,
# and identity covariances for both components.
sigma = np.array([[[1, 0], [0, 1]], [[1, 0], [0, 1]]])

print("Modifying the GMM...")
gm.weights_ = np.array([0.5, 0.5])
gm.means_ = np.array([[0, 0], [20, 0]])
gm.covariances_ = sigma
# predict_proba/score_samples use the Cholesky factors of the precision
# matrices internally, so these must be updated to match the new covariances.
gm.precisions_cholesky_ = np.linalg.cholesky(np.linalg.inv(sigma)).transpose((0, 2, 1))
gm.sample(10)  # draw from the modified model (return value unused here)
print("GM NEW means:\n", gm.means_)
print("GM NEW covs:\n", gm.covariances_)

midpt_proba = gm.predict_proba([midpt])[0]
print(f"proba({midpt}) (NEW): {midpt_proba}")
midpt_loglike = gm.score_samples([midpt])[0]
print(f"score({midpt}) (NEW): {midpt_loglike}")


plt.plot(data[:,0], data[:,1], '.')
plt.plot(midpt[0], midpt[1], 'r*')
plt.ylim([-5, 5])
plt.show()
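
If you want to run the same diagnosis on your own fitted model, one possible sketch (using only public attributes of a GaussianMixture with the default 'full' covariance type, not any sklearn internals) is to recompute the weighted per-component log densities with scipy; normalizing their exponentials reproduces predict_proba:

import numpy as np
from scipy.stats import multivariate_normal

def component_log_scores(gm, x):
    """Weighted log density of each component at point x, recomputed from the
    public attributes of a fitted GaussianMixture with 'full' covariances."""
    x = np.asarray(x, dtype=float)
    return np.array([
        np.log(w) + multivariate_normal(mean=m, cov=c).logpdf(x)
        for w, m, c in zip(gm.weights_, gm.means_, gm.covariances_)
    ])

# On the model fitted in the question this would show one score dwarfing the
# other for middle_guy; on the modified model above the two scores are equal.
scores = component_log_scores(gm, midpt)
print(scores)

# Normalizing the exponentiated scores reproduces the responsibilities.
resp = np.exp(scores - scores.max())
print(resp / resp.sum())        # matches gm.predict_proba([midpt])[0]

Comparing the raw scores (both very negative) with the normalized responsibilities makes the original observation less surprising: predict_proba only reports the relative weight of each component, not how plausible the point is in absolute terms.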