I am experimenting with a machine learning method. My goal is to define a loss function that measures how close the model is to convergence, based on a statistic that should be uniformly distributed once the model is well trained.
For each real observation in the training set, the model generates K random hypothetical observations (simulations), and I count how many of them are smaller than the real observation. If the model is well trained, this count, which takes the values 0, 1, …, K, should be uniformly distributed.
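To make the target concrete, here is a minimal, non-differentiable sketch of the statistic in plain Julia (Base and `Random` only). `rank_histogram`, `simulate`, and `observe` are hypothetical names, and both the simulator and the "real" data are stand-ins built from `randn`, not the model from this question. When the simulator matches the data distribution, each of the K+1 frequencies should be close to 1/(K+1); a shifted simulator piles the mass into one bin.

```julia
using Random

# Hypothetical helper: hard (non-differentiable) rank histogram.
# `simulate(K)` returns K simulated observations, `observe()` one real observation.
function rank_histogram(simulate, observe, K, n_samples)
    counts = zeros(Int, K + 1)            # bins for counts 0, 1, ..., K
    for _ in 1:n_samples
        sims = simulate(K)
        y = observe()
        r = count(<(y), sims)             # simulations smaller than the real observation
        counts[r + 1] += 1                # Julia arrays are 1-based
    end
    return counts ./ n_samples            # empirical frequencies
end

Random.seed!(1)
# Simulator matches the data: every frequency should be close to 1/(K+1).
freqs = rank_histogram(K -> randn(K), () -> randn(), 2, 30_000)
# Simulator shifted upwards: the real observation is usually below all
# simulations, so the mass concentrates in the first bin (count 0).
biased = rank_histogram(K -> randn(K) .+ 1, () -> randn(), 2, 30_000)
```

The loss in this question replaces the hard `count` and the integer bin index with smooth approximations so that the same histogram can be differentiated.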
I have the following code written in Julia.
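```julia
using Flux, Distributions, ProgressMeter

# Learning with the custom loss.
# `model` (the generator) and `realModel` (called with a uniform draw and
# returning one real observation) are defined elsewhere; `generate_aₖ` and
# `scalar_diff` are defined below.
μ = 0; stddev = 1
η = 0.1; num_epochs = 2000; n_samples = 100; K = 2
optim = Flux.setup(Flux.Adam(η), model)
losses = Float64[]
@showprogress for epoch in 1:num_epochs
    loss, grads = Flux.withgradient(model) do m
        aₖ = zeros(K + 1)                    # soft histogram over the K+1 rank bins
        for _ in 1:n_samples
            x = rand(Normal(μ, stddev), K)   # K latent noise draws
            yₖ = m(x')                       # K simulated observations
            y = realModel(rand(Float64))     # one real observation
            aₖ += generate_aₖ(yₖ, y)         # add this sample's soft rank indicator
        end
        scalar_diff(aₖ ./ sum(aₖ))           # squared distance to the uniform histogram
    end
    Flux.update!(optim, model, grads[1])
    push!(losses, loss)
end;
```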
`generate_aₖ` is constructed as follows. To simplify, you can assume that each call to `generate_aₖ` is independent of the others.
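For reference, the quantities computed by the helpers below and by the training loop above can be written as follows, with N = `n_samples`, $\hat y^{(n)} \in \mathbb{R}^K$ the K simulations and $y^{(n)}$ the real observation in the n-th inner-loop iteration, and $\sigma$ the logistic function (`sigmoid_fast`):

$$
\phi\big(\hat y, y\big) = \sum_{k=1}^{K} \sigma\big(10\,(y - \hat y_k)\big),
\qquad
\psi_m(s) = \exp\!\left(-\frac{1}{2}\left(\frac{s - m}{0.1}\right)^{2}\right),
$$

$$
a_m = \sum_{n=1}^{N} \psi_m\Big(\phi\big(\hat y^{(n)}, y^{(n)}\big)\Big), \quad m = 0, \dots, K,
\qquad
\mathcal{L} = \sum_{m=0}^{K} \left(\frac{a_m}{\sum_{j=0}^{K} a_j} - \frac{1}{K+1}\right)^{2}.
$$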
```julia
using StatsBase: kldivergence   # KL divergence between two probability vectors
using StaticArrays               # SVector, used only by γ_fast
# `sigmoid_fast` is provided by NNlib (available after `using Flux`).

# Squared Euclidean distance between a normalized histogram and the uniform one.
scalar_diff(aₖ) = sum((aₖ .- 1 / length(aₖ)) .^ 2)

# Jensen-Shannon divergence between a normalized histogram and the uniform one.
jensen_shannon_∇(aₖ) = jensen_shannon_divergence(aₖ, fill(1 / length(aₖ), length(aₖ)))

function jensen_shannon_divergence(p, q)
    ϵ = 1e-3                     # smoothing to avoid log(0)
    p̃, q̃ = p .+ ϵ, q .+ ϵ
    m = (p̃ .+ q̃) ./ 2            # mixture distribution
    return 0.5 * (kldivergence(p̃, m) + kldivergence(q̃, m))
end;

"""
    sigmoid(ŷ, y)

Steep sigmoid (slope factor 10) centered at `y`: close to 1 when `ŷ > y` and close to 0 when `ŷ < y`.
"""
function sigmoid(ŷ, y)
    return sigmoid_fast.((ŷ - y) * 10)
end;

"""
    ψₘ(y, m)

Bump function centered at `m`, implemented as a Gaussian with standard deviation 0.1.
"""
function ψₘ(y, m)
    stddev = 0.1
    return exp.(-0.5 .* ((y .- m) ./ stddev) .^ 2)
end

"""
    ϕ(yₖ, yₙ)

Soft count of the entries of `yₖ` that are smaller than `yₙ`, i.e. the sum of
`sigmoid(yₙ, yₖ[i])` over all entries of `yₖ`.
"""
function ϕ(yₖ, yₙ)
    return sum(sigmoid.(yₙ, yₖ))
end;

"""
    γ(yₖ, yₙ, m)

Contribution of `ψₘ ∘ ϕ(yₖ, yₙ)` to bin `m` of the histogram (a `Vector{Float64}` with bins 0, …, K).
"""
function γ(yₖ, yₙ::Float64, m::Int64)
    eₘ(m) = [j == m ? 1.0 : 0.0 for j in 0:length(yₖ)]   # one-hot vector for bin m
    return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
end;

"""
    γ_fast(yₖ, yₙ, m)

Same as `γ`, but builds the one-hot vector with StaticArrays, which makes it faster than `γ`.
However, because Zygote does not support StaticArrays, this function cannot be used in the training process.
"""
function γ_fast(yₖ, yₙ::Float64, m::Int64)
    eₘ(m) = SVector{length(yₖ)+1, Float64}(j == m ? 1.0 : 0.0 for j in 0:length(yₖ))
    return eₘ(m) * ψₘ(ϕ(yₖ, yₙ), m)
end;

"""
    generate_aₖ(ŷ, y)

Generate a one-step histogram (`Vector{Float64}`) from the vector `ŷ` of K simulated
observations and the real observation `y`:
generate_aₖ(ŷ, y) = ∑ₖ γ(ŷ, y, k)
"""
generate_aₖ(ŷ, y::Float64) = sum([γ(ŷ, y, k) for k in 0:length(ŷ)])
```
As I mentioned, I need to turn the counting step (the histogram) into a differentiable operation. I did this as follows. First, I use a sigmoid to check whether a fictitious observation (one of the K generated) is smaller than the real observation. Repeating this for all K observations and summing the results (the ϕ function) gives a soft count of the simulations that fall below the real observation. Next, I need differentiable histogram bins. For this it is sufficient to use a bump function (I tried mollifiers, but in the end, when running this example in PyTorch, I found that a normalized Gaussian worked better) whose value is nearly 1 when the soft count equals the bin's center and nearly 0 otherwise. There are K+1 bump functions, centered at 0, 1, …, K. The γ function places the value of the m-th bump function in component m of a vector whose components are the histogram bins. Finally, `generate_aₖ` simply iterates over all the components to build every bin.
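The one-hot construction in `γ` and `generate_aₖ` puts, in bin m, exactly the value of the m-th bump evaluated at the soft count, so all K+1 bins can also be computed in a single broadcast. The Base-only sketch below checks this on one example. `logistic`, `soft_count`, `bump`, `soft_hist_onehot`, and `soft_hist` are hypothetical stand-ins that use the same constants as the question's code (slope 10, bump standard deviation 0.1), with a hand-written logistic function in place of `sigmoid_fast`.

```julia
# Hypothetical stand-ins for the question's sigmoid/ϕ/ψₘ, using the same constants.
logistic(x) = 1 / (1 + exp(-x))

# Soft count of the simulations ŷ that are smaller than y (same form as ϕ).
soft_count(ŷ, y) = sum(logistic.(10 .* (y .- ŷ)))

# Gaussian bump with peak 1 and standard deviation 0.1, centered at m (same form as ψₘ).
bump(s, m) = exp(-0.5 * ((s - m) / 0.1)^2)

# Mirrors γ / generate_aₖ: one scaled one-hot vector per bin, then summed.
function soft_hist_onehot(ŷ, y)
    K = length(ŷ)
    s = soft_count(ŷ, y)
    return sum([[j == m ? 1.0 : 0.0 for j in 0:K] .* bump(s, m) for m in 0:K])
end

# Vectorized equivalent: evaluate all K+1 bumps at the soft count in one broadcast.
soft_hist(ŷ, y) = bump.(soft_count(ŷ, y), 0:length(ŷ))

ŷ = [-1.0, 0.0, 2.0]   # K = 3 simulations
y = 0.5                # two simulations lie below y
h = soft_hist(ŷ, y)    # bin for count 2 (h[3]) is close to 1, the others close to 0
```

With the question's own helpers, the vectorized form would read `ψₘ(ϕ(ŷ, y), 0:length(ŷ))`, since `ψₘ` already broadcasts over `m`; skipping the K+1 temporary one-hot vectors per sample may reduce allocations inside the training loop. With a bump this narrow, the bin value also drops noticeably when a simulated value lies within a few tenths of the real observation, because the soft count then moves away from an integer.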
This code aims to learn a bimodal distribution using the method described above. The inner loop is "morally parallelizable", since neither the order of the simulations nor the order of their summation matters. I would therefore like to run it in parallel, but when I use `@distributed`, multithreading, etc., I get the following Zygote error:
Can't differentiate foreigncall expression $(Expr(:foreigncall, :(:jl_cpu_wake), Nothing, svec(), 0, :(:ccall))). You might want to check the Zygote limitations documentation.
Is there a way to parallelize this while respecting the limitations of Zygote's automatic differentiation?
Any other advice to speed up this process?
Zygote does not support using `@distributed` directly, as you have discovered. The solution would be to move the call to `Flux.withgradient` inside the loop and manually sum the loss and gradients. Now the outer loop over `partitions` should be amenable to `@distributed`.
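One point is worth making explicit for this particular loss: `scalar_diff(aₖ ./ sum(aₖ))` is a nonlinear function of the histogram summed over all samples, so adding up losses and gradients computed independently per partition optimizes a different objective than the original loop. The sketch below keeps the exact gradient by applying the chain rule: with A = Σ_c a_c the sum of the partition histograms, ∂L/∂θ = Σ_c (∂a_c/∂θ)ᵀ ∂L/∂A. It uses `Threads.@threads` in place of `@distributed` so that all pullbacks stay in one process, and it only calls `Zygote.pullback` on a serial function, so Zygote never has to differentiate threading code. `chunk_histogram` and `threaded_step!` are hypothetical names; the sketch assumes Flux, Zygote, and Distributions are installed, that `realModel`, `generate_aₖ`, and `scalar_diff` from the question are defined, and that running Zygote pullbacks concurrently on separate threads is safe in your setup.

```julia
using Flux, Zygote, Distributions

# Hypothetical helper: the question's inner loop for one chunk of `n` samples.
# It runs serially, so Zygote never sees any threading code.
# Assumes `realModel` and `generate_aₖ` from the question are in scope.
function chunk_histogram(m, n, K, μ, stddev)
    aₖ = zeros(K + 1)
    for _ in 1:n
        x = rand(Normal(μ, stddev), K)
        yₖ = m(x')
        y = realModel(rand(Float64))
        aₖ += generate_aₖ(yₖ, y)
    end
    return aₖ
end

# Hypothetical training step: each chunk's forward and backward pass runs on its
# own thread, and the summed gradient equals that of the full-batch loss.
# Assumes `scalar_diff` from the question is in scope.
function threaded_step!(optim, model, n_samples, K; μ = 0.0, stddev = 1.0)
    chunks = collect(Iterators.partition(1:n_samples, cld(n_samples, Threads.nthreads())))
    hists = Vector{Vector{Float64}}(undef, length(chunks))
    backs = Vector{Any}(undef, length(chunks))

    # Forward pass per chunk: partial histogram plus its pullback closure.
    Threads.@threads for c in eachindex(chunks)
        hists[c], backs[c] = Zygote.pullback(m -> chunk_histogram(m, length(chunks[c]), K, μ, stddev), model)
    end

    # Loss of the total histogram and its gradient with respect to that histogram.
    A = sum(hists)
    loss, back_loss = Zygote.pullback(a -> scalar_diff(a ./ sum(a)), A)
    ΔA = back_loss(1.0)[1]

    # Backward pass per chunk, all seeded with the same ∂loss/∂A.
    grads = Vector{Any}(undef, length(chunks))
    Threads.@threads for c in eachindex(chunks)
        grads[c] = backs[c](ΔA)[1]
    end

    # Sum the structured (NamedTuple) gradients and take one optimizer step.
    Flux.update!(optim, model, reduce(Zygote.accum, grads))
    return loss
end
```

A training loop could then be `losses = [threaded_step!(optim, model, n_samples, K) for _ in 1:num_epochs]`, with Julia started on several threads (for example `julia --threads=4`).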