Skip to content

Commit

Permalink
Merge pull request #68 from Vaibhavdixit02/master
Browse files Browse the repository at this point in the history
Update turing_inference #46
  • Loading branch information
ChrisRackauckas committed Feb 22, 2019
2 parents 46ba3b5 + 36e63a3 commit 3f67e7c
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 53 deletions.
2 changes: 1 addition & 1 deletion REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ DiffEqBase 5.0.0
#Mamba
Stan
Distributions
Turing 0.5.0
Turing
MacroTools
Optim
RecursiveArrayTools
Expand Down
54 changes: 16 additions & 38 deletions src/turing_inference.jl
Original file line number Diff line number Diff line change
@@ -1,54 +1,32 @@
function turing_inference(prob::DiffEqBase.DEProblem,alg,t,data,priors = nothing;
num_samples=1000, delta=0.65, kwargs...)

function bif(vi, sampler, x=data)
_lp = 0.0
N = length(priors)
_theta = Vector(undef,N)
num_samples=1000, delta=0.65, kwargs...)

N = length(priors)
_theta = Vector{Real}(undef, N)
@model bif(x) = begin
for i in 1:length(priors)
_theta[i], __lp = Turing.assume(sampler,
priors[i],
Turing.VarName(vi, [:bif, Symbol("theta$i")], ""),
vi)
_lp += __lp
_theta[i] ~ priors[i]
end
σ ~ InverseGamma(2, 3)

theta = convert(Array{typeof(first(_theta))},_theta)

σ, __lp = Turing.assume(sampler,
InverseGamma(2, 3),
Turing.VarName(vi, [:bif, ], ""),
vi)
_lp += __lp
theta = convert(Array{typeof(first(_theta))}, _theta)
p_tmp = remake(prob, u0 = convert.(eltype(theta), (prob.u0)), p = theta)
sol_tmp = solve(p_tmp, alg; saveat = t, kwargs...)
fill_length = length(t) - length(sol_tmp.u)

p_tmp = remake(prob, u0=convert.(eltype(theta),(prob.u0)),p=theta)
sol_tmp = solve(p_tmp,alg;saveat=t,kwargs...)
fill_length = length(t)-length(sol_tmp.u)
for i in 1:fill_length
if eltype(sol_tmp.u) <: Number
push!(sol_tmp.u,Inf)
push!(sol_tmp.u, Inf)
else
push!(sol_tmp.u,fill(Inf,size(sol_tmp[1])))
push!(sol_tmp.u, fill(Inf, size(sol_tmp[1])))
end
end
for i = 1:length(t)
res = sol_tmp.u[i]
# x[:,i] ~ MvNormal(res, σ*ones(2))
__lp = Turing.observe(
sampler,
MvNormal(res, σ*ones(length(prob.u0))), # Distribution
x[:,i], # Data point
vi
)
_lp += __lp
x[:,i] ~ MvNormal(res, σ*ones(length(prob.u0)))
end

vi.logp = _lp
vi
end

bif() = bif(Turing.VarInfo(), nothing)

chn = sample(bif, Turing.NUTS(num_samples, delta))
end
model = bif(data)
chn = sample(model, Turing.IS(num_samples))
end
25 changes: 11 additions & 14 deletions test/turing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,34 +16,31 @@ priors = [Normal(1.5,0.01)]

bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500)

@show mean(bayesian_result[:theta1][50:end])
@show mean(bayesian_result[:_theta][50:end])

@test mean(bayesian_result[:theta1][50:end]) 1.5 atol=0.1
@test mean(bayesian_result[:_theta][50:end])[1] 1.5 atol=0.1

println("Four parameter case")
f1 = @ode_def begin
f2 = @ode_def begin
dx = a*x - b*x*y
dy = -c*y + d*x*y
end a b c d
u0 = [1.0,1.0]
tspan = (0.0,10.0)
p = [1.5,1.0,3.0,1.0]
prob1 = ODEProblem(f1,u0,tspan,p)
sol = solve(prob1,Tsit5())
prob2 = ODEProblem(f2,u0,tspan,p)
sol = solve(prob2,Tsit5())
t = collect(range(1,stop=10,length=10))
randomized = VectorOfArray([(sol(t[i]) + .01randn(2)) for i in 1:length(t)])
data = convert(Array,randomized)
priors = [Truncated(Normal(1.5,0.01),0,2),Truncated(Normal(1.0,0.01),0,1.5),
Truncated(Normal(3.0,0.01),0,4),Truncated(Normal(1.0,0.01),0,2)]

bayesian_result = turing_inference(prob1,Tsit5(),t,data,priors;num_samples=500)
bayesian_result = turing_inference(prob2,Tsit5(),t,data,priors;num_samples=500)

@show mean(bayesian_result[:theta1][50:end])
@show mean(bayesian_result[:theta2][50:end])
@show mean(bayesian_result[:theta3][50:end])
@show mean(bayesian_result[:theta4][50:end])
@show mean(bayesian_result[:_theta][:,:][50:end])

@test mean(bayesian_result[:theta1][50:end]) 1.5 atol=3e-1
@test mean(bayesian_result[:theta2][50:end]) 1.0 atol=3e-1
@test mean(bayesian_result[:theta3][50:end]) 3.0 atol=3e-1
@test mean(bayesian_result[:theta4][50:end]) 1.0 atol=3e-1
@test mean(bayesian_result[:_theta][:,1][50:end]) 1.5 atol=3e-1
@test mean(bayesian_result[:_theta][:,2][50:end]) 1.0 atol=3e-1
@test mean(bayesian_result[:_theta][:,3][50:end]) 3.0 atol=3e-1
@test mean(bayesian_result[:_theta][:,4][50:end]) 1.0 atol=3e-1

0 comments on commit 3f67e7c

Please sign in to comment.