I'm trying to use TensorFlow Probability to learn the alpha and beta parameters of a beta distribution. I can't get it to work for some reason - the loss is all NaN values. Here's what I have:
from scipy.stats import beta
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
beta_sample_data = beta1 = beta.rvs(5,5,size=1000)
beta_train = tfd.Beta(concentration1=tf.Variable(1.,name='alpha'),concentration0=tf.Variable(1.,name='beta'),name='beta_train')
def nll(x_train,distribution):
return -tf.reduce_mean(distribution.log_prob(x_train))
# Define a function to compute the loss and gradients
@tf.function
def get_loss_and_grads(x_train,distribution):
with tf.GradientTape() as tape:
tape.watch(distribution.trainable_variables)
loss = nll(x_train, distribution)
grads = tape.gradient(loss,distribution.trainable_variables)
return loss,grads
def beta_dist_optimisation(data, distribution):
# Keep results for plotting
train_loss_results = []
train_rate_results = []
optimizer = tf.keras.optimizers.SGD(learning_rate=0.005)
num_steps = 10
for i in range(num_steps):
loss,grads = get_loss_and_grads(data,distribution)
print(loss,grads)
optimizer.apply_gradients(zip(grads,distribution.trainable_variables))
alpha_value = distribution.concentration1.value()
beta_value = distribution.concentration0.value()
train_loss_results.append(loss)
train_rate_results.append((alpha_value,beta_value))
print("Step {:03d}: Loss: {:.3f}: Alpha: {:.3f} Beta: {:.3f}".format(i,loss,alpha_value,beta_value))
return train_loss_results, train_rate_results
sample_data = tf.cast(beta_sample_data, tf.float32)
train_loss_results, train_rate_results = beta_dist_optimisation(sample_data,beta_train)
I'm trying to use maximum likelihood to learn the alpha and beta parameters of 5,5.
You need to constrain the variables to be positive before using them. Otherwise gradient steps may take them negative and give nans. You can just call softplus on them before passing in. You could also check out tfp.util.TransformedVariable. It should have some examples in the docs.