Turing.jl support #160

Open
torfjelde opened this issue Oct 27, 2023 · 12 comments

Comments

@torfjelde

Hola amigos!

I came across the AutoMALA paper (really neat stuff) and wanted to try it out on some Turing.jl models. Then I noticed this:

# At the moment, AutoMALA assumes a :singleton_variable structure
# so use the SliceSampler.
default_explorer(::TuringLogPotential) = SliceSampler()

What does this :singleton_variable structure refer to?

Thanks!

@alexandrebouchard
Member

alexandrebouchard commented Oct 27, 2023

Thank you!

Here is an explanation for that temporary situation: one thing we really like about Turing is its ability to support models containing both discrete and continuous variables. To support this, part of the functionality in Pigeons allows the state to be tuple-like, with each key of the state mapping to a potentially different data type. Let's call this the "type heterogeneous case".

But for the first pass of writing gradient based samplers, it was convenient to begin by assuming type homogeneity. We encode that situation with the tuple-like state having a single key called :singleton_variable.

Clearly it should be possible to use gradient-based sampling in the type heterogeneous case. We might just need some guidance on a few details of the Turing API to achieve this. I forget exactly what the stumbling block was, but it did not seem serious. Since the slice sampler handles the type heterogeneous case, we left it as the default for now. (On a related note, we would also be interested in getting Turing's samplers, in particular the SMC based samplers, to work automatically in the tempering case, again we might just need a bit of guidance on the API to achieve this.)

PS: to give a fuller picture, the state interface can handle different levels of abstraction, but I am focussing here on the level of abstraction relevant to autoMALA, HMC and similar samplers.
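To make the distinction concrete, here is a minimal Julia sketch using plain `NamedTuple`s. It assumes a toy state with made-up variables `n`, `μ` and `σ`; the helper `is_singleton` is hypothetical and not part of Pigeons.jl, whose actual state types may look different.

```julia
# Hypothetical illustration of the two state layouts described above.
# The field names and `is_singleton` are made up; they are not Pigeons.jl API.

# Type-heterogeneous case: each key maps to a value of a different type
# (an integer, a scalar float, a float vector).
heterogeneous_state = (n = 3, μ = 0.5, σ = [1.0, 2.0])

# Type-homogeneous case: all continuous coordinates packed into one vector
# stored under the single key :singleton_variable.
singleton_state = (singleton_variable = [0.5, 1.0, 2.0],)

# A first-pass gradient-based explorer can only rely on states that have
# exactly one key, :singleton_variable, holding a real-valued vector.
is_singleton(state::NamedTuple) =
    keys(state) == (:singleton_variable,) &&
    state.singleton_variable isa AbstractVector{<:Real}

println(is_singleton(heterogeneous_state))  # false
println(is_singleton(singleton_state))      # true
```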

@torfjelde
Author

Ah, I see that makes sense! 👍

I also noticed:

- The LogDensityProblems interface seems to force us to keep two representations of the states,
one for the VariableInfo and one vector based. This is only partly implemented at the moment,
as a result we have the following limitations: (1) AutoMALA+Turing cannot use the
diagonal pre-conditioning at the moment. (2) AutoMALA+Turing only works if all variables are
continuous at the moment. Both could be addressed but since the autodiff is pretty slow at the
moment it seems low priority; the user can just rely on SliceSampler() at the moment.

and

- On some Turing models, gradient computation is non-deterministic,
see 4433584a044510bf9360e1e7191e59478496dc0b and associated CIs at
https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5550424683/jobs/10135522013
vs
https://github.com/Julia-Tempering/Pigeons.jl/actions/runs/5550424683/jobs/10135521940
(look for output of test_turing.jl)

Could you elaborate a bit on the former?

For the latter, that sounds like a strange error tbh, but unfortunately the logs are no longer available so I can't really look into it 😕

we would also be interested in getting Turing's samplers, in particular the SMC based samplers, to work automatically in the tempering case, again we might just need a bit of guidance on the API to achieve this

I'm uncertain how useful this will be, tbh. SMC samplers in Turing.jl are computationally very inefficient due to the nature of their implementation.

Nonetheless, I'd be very happy to help :) It would be very nice to have this easily accessible and working in Turing.jl.

@alexandrebouchard
Member

Certainly, happy to elaborate on these!

For the first blurb, here is the difficulty I had encountered: I was using DynamicPPL.getall, setall! and ADgradient; however, when the model is mixed discrete-continuous, these include the discrete variables. I was wondering if there is an equivalent to getall/setall!/ADgradient based on a view of only the continuous variables?
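A minimal sketch of what such a "continuous-only view" could look like, written against a plain `Dict` rather than a DynamicPPL `VarInfo`. The names `continuous_keys`, `get_continuous`, `set_continuous!`, `continuous_gradient` and the toy log density are all hypothetical; central finite differences stand in for an AD backend. The discrete variable is held fixed by the closure, which is the behaviour being asked about.

```julia
# Hypothetical sketch: a continuous-only view of a mixed discrete/continuous state.
# None of these names are DynamicPPL API; a Dict stands in for a VarInfo.

# Mixed state (assumption): one discrete variable k, two continuous ones μ and σ.
state = Dict{Symbol,Any}(:k => 2, :μ => 0.3, :σ => [1.0, 2.0])

# Keys holding floating-point values, in a fixed (sorted) order.
is_continuous(v) = v isa AbstractFloat || v isa AbstractArray{<:AbstractFloat}
continuous_keys(s) = sort([k for (k, v) in s if is_continuous(v)])

# Flatten the continuous values into one vector (analogue of getall).
function get_continuous(s)
    x = Float64[]
    for k in continuous_keys(s)
        v = s[k]
        v isa AbstractArray ? append!(x, vec(v)) : push!(x, v)
    end
    return x
end

# Write a flat vector back into the continuous entries (analogue of setall!).
function set_continuous!(s, x)
    i = 1
    for k in continuous_keys(s)
        v = s[k]
        if v isa AbstractArray
            n = length(v)
            s[k] = reshape(x[i:i+n-1], size(v))
            i += n
        else
            s[k] = x[i]
            i += 1
        end
    end
    return s
end

# Toy log density (assumption): the discrete k shifts the mean of μ.
toy_logdensity(s) = -0.5 * (s[:μ] - s[:k])^2 - 0.5 * sum(abs2, s[:σ])

# Gradient w.r.t. the continuous coordinates only; discrete values stay fixed.
# Central finite differences stand in for an AD backend.
function continuous_gradient(logdens, s; h = 1e-6)
    x = get_continuous(s)
    f(y) = logdens(set_continuous!(copy(s), y))  # copy: original state untouched
    g = similar(x)
    for j in eachindex(x)
        e = zeros(length(x))
        e[j] = h
        g[j] = (f(x .+ e) - f(x .- e)) / (2h)
    end
    return g
end

println(get_continuous(state))                       # [0.3, 1.0, 2.0]
println(continuous_gradient(toy_logdensity, state))  # ≈ [1.7, -1.0, -2.0]
```

Keeping the discrete entries outside the flat vector is what lets a vector-based gradient sampler operate on mixed models; the discrete coordinates would then be updated by a separate move, e.g. the slice sampler.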

For the second one, I will rerun the test, hopefully we can reproduce the non-reproducibility! I will keep you updated...

@alexandrebouchard
Member

I managed to replicate the non-reproducibility issue I mentioned in #165! This time I am attaching the logs for posterity since the CI seems to erase them after some time period.

Here is the background:

The test in question is this one.

I am attaching the logs below for two sister CI runs. They differ only in that one uses a different MPI library, and since this specific test does not use MPI, they can be considered two independent runs. If you search the logs for "Starting test_turing.jl", you will find a table below it; in the column min(αₑ) (minimum MALA acceptance probability across chains), the final entry is 0.504 in one run vs 0.518 in the other.

The problem does not arise if the slice sampler is used with a Turing model (the table immediately after), or if autoMALA is used on pure Julia or Stan models. So I suspect the non-determinism is related to gradient computation on Turing models.

Some additional info:

1.8-mac-mpich-8_Run julia-actionsjulia-runtest@v1.txt
1.8-mac-openmpi-8_Run julia-actionsjulia-runtest@v1.txt
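For reference, a minimal sketch of the reproducibility property the test relies on: a plain MALA chain (not Pigeons' AutoMALA, and with a standard normal target as an assumption) run twice from the same seed should give bit-identical results. The statistic reported here, the minimum acceptance probability over iterations of a single chain, is only loosely analogous to the min(αₑ) column in the tables; `mala_step` and `min_acceptance` are hypothetical names.

```julia
using Random

# Standard normal target (assumption for the sketch) and its gradient.
logdensity(x) = -0.5 * sum(abs2, x)
grad_logdensity(x) = -x

# One MALA step with step size ε; returns the new state and the acceptance probability.
function mala_step(rng, x, ε)
    μx = x .+ (ε^2 / 2) .* grad_logdensity(x)   # Langevin drift from x
    y = μx .+ ε .* randn(rng, length(x))        # proposal y ~ N(μx, ε² I)
    μy = y .+ (ε^2 / 2) .* grad_logdensity(y)   # drift from the proposal
    # log q(x | y) - log q(y | x) for the Gaussian Langevin proposal
    log_q_ratio = (-sum(abs2, x .- μy) + sum(abs2, y .- μx)) / (2ε^2)
    logα = logdensity(y) - logdensity(x) + log_q_ratio
    α = min(1.0, exp(logα))
    return (rand(rng) < α ? y : x), α
end

# Run one seeded chain and report the minimum acceptance probability seen.
function min_acceptance(seed; n = 1_000, ε = 0.9, d = 3)
    rng = Xoshiro(seed)
    x = zeros(d)
    αmin = 1.0
    for _ in 1:n
        x, α = mala_step(rng, x, ε)
        αmin = min(αmin, α)
    end
    return αmin
end

# Same seed, same arithmetic: the two runs must agree bit for bit.
println(min_acceptance(1) == min_acceptance(1))  # true
```

Any difference between two such runs (like 0.504 vs 0.518 above) means some input to the chain other than the seed changed between runs.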

@torfjelde
Author

torfjelde commented Nov 19, 2023

Sorry for the late reply; I was AWOL for one week and then sick the next.

But this is great; thank you!

So I suspect the non-determinism is related to gradient computation on Turing models.

Hmm, if this is the issue, then there must be something about how it's set up in Pigeons.jl, because the model is fully reproducible on current Turing.jl: if I run NUTS with the same random seed multiple times on the exact model you pointed to, I get identical results 😕

I'll have a look.

@torfjelde
Author

So I suspect the non-determinism is related to gradient computation on Turing models.

Have you observed this phenomenon concretely, btw? Non-determinism of the gradient computation, I mean? Or is this just a suspicion?
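One concrete way to check this, sketched below with a toy model and hypothetical helper names: call the gradient repeatedly at a fixed point and require bit-identical output. The second part illustrates one possible mechanism, offered only as a hypothesis: floating-point addition is not associative, so a reduction whose order varies (e.g. with scheduling) can change the last bits of a gradient even though the maths is the same.

```julia
# Returns true when `grad` gives bit-identical output on repeated calls at `x`.
# (Hypothetical helper for illustration; not part of Pigeons.jl or Turing.jl.)
function gradient_is_deterministic(grad, x; trials = 20)
    g0 = grad(x)
    return all(grad(x) == g0 for _ in 1:trials)
end

# Toy model (assumption): yᵢ ~ Normal(μ, 1), so d/dμ log p = Σᵢ (yᵢ - μ).
const ys = [0.1, 0.2, 0.3]

# Same gradient, two summation orders.
grad_left(μ)  = foldl(+, ys .- μ)   # ((t₁ + t₂) + t₃)
grad_right(μ) = foldr(+, ys .- μ)   # (t₁ + (t₂ + t₃))

println(gradient_is_deterministic(grad_left, 0.0))  # true
println(grad_left(0.0) == grad_right(0.0))          # false: 0.6000000000000001 vs 0.6
```

A gradient that is deterministic in isolation, as here, can still yield different chains if the order of operations around it changes between runs, which is why the check is worth running in the same environment where the discrepancy appeared.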

@torfjelde
Author

So when I run the exact tests from #165 locally, the results are perfectly reproducible (just running that testset twice results in exactly the same values everywhere).

julia> using Test, Pigeons, Turing

julia> @testset "Turing-gradient" begin
           target = Pigeons.toy_turing_unid_target()

           @show Threads.nthreads()

           logz_mala = Pigeons.stepping_stone_pair(pigeons(; target, explorer = AutoMALA(preconditioner = Pigeons.IdentityPreconditioner())))
           logz_slicer = Pigeons.stepping_stone_pair(pigeons(; target, explorer = SliceSampler()))

           @test abs(logz_mala[1] - logz_slicer[1]) < 0.1
       end
Threads.nthreads() = 1
┌ Info: Neither traces, disk, nor online recorders included. 
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        3.4     0.0083   7.09e+06  -3.32e+03          0      0.622          0      0.539 
        4       2.22    0.00967   7.69e+06  -1.48e+03          0      0.753      0.668      0.716 
        8       2.62     0.0212   1.72e+07      -42.8   8.09e-30      0.709      0.465       0.62 
       16       2.91     0.0662   3.79e+07      -10.8      0.077      0.677      0.439      0.606 
       32       3.29     0.0994   7.82e+07      -11.8      0.128      0.635      0.528      0.628 
       64       3.27      0.224   1.59e+08      -11.1      0.209      0.637      0.529      0.624 
      128       3.51      0.474   3.36e+08      -11.4      0.508       0.61       0.53      0.621 
      256       3.57      0.922   6.84e+08      -11.9      0.475      0.604      0.519      0.624 
      512       3.46       1.86   1.37e+09      -11.5      0.582      0.615      0.517      0.604 
 1.02e+03       3.52       3.72   2.77e+09      -11.9      0.571      0.609      0.517      0.629 
──────────────────────────────────────────────────────────────────────────────────────────────────
┌ Info: Neither traces, disk, nor online recorders included. 
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.04    0.00155   1.04e+06  -4.24e+03          0      0.885          1          1 
        4       4.06    0.00233   1.78e+06      -16.3   4.63e-06      0.549          1          1 
        8       3.49    0.00428   3.52e+06      -12.1      0.215      0.612          1          1 
       16       2.68    0.00919    7.4e+06      -10.2      0.518      0.703          1          1 
       32       4.29     0.0165   1.36e+07      -11.8      0.222      0.524          1          1 
       64       3.17     0.0366   2.84e+07      -11.5      0.529      0.648          1          1 
      128       3.56     0.0863   5.49e+07      -11.5      0.523      0.605          1          1 
      256       3.38      0.154    1.1e+08      -11.6      0.526      0.625          1          1 
      512       3.48      0.292   2.21e+08        -12      0.527      0.614          1          1 
 1.02e+03       3.55      0.611   4.43e+08      -11.8      0.571      0.605          1          1 
──────────────────────────────────────────────────────────────────────────────────────────────────
Test Summary:   | Pass  Total  Time
Turing-gradient |    1      1  8.7s
Test.DefaultTestSet("Turing-gradient", Any[], 1, false, false, true, 1.700434095336861e9, 1.700434103988216e9, false)

julia> @testset "Turing-gradient" begin
           target = Pigeons.toy_turing_unid_target()

           @show Threads.nthreads()

           logz_mala = Pigeons.stepping_stone_pair(pigeons(; target, explorer = AutoMALA(preconditioner = Pigeons.IdentityPreconditioner())))
           logz_slicer = Pigeons.stepping_stone_pair(pigeons(; target, explorer = SliceSampler()))

           @test abs(logz_mala[1] - logz_slicer[1]) < 0.1
       end
Threads.nthreads() = 1
┌ Info: Neither traces, disk, nor online recorders included. 
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2        3.4    0.00789   7.09e+06  -3.32e+03          0      0.622          0      0.539 
        4       2.22     0.0101   7.69e+06  -1.48e+03          0      0.753      0.668      0.716 
        8       2.62     0.0234   1.72e+07      -42.8   8.09e-30      0.709      0.465       0.62 
       16       2.91     0.0502   3.79e+07      -10.8      0.077      0.677      0.439      0.606 
       32       3.29      0.117   7.82e+07      -11.8      0.128      0.635      0.528      0.628 
       64       3.27      0.222   1.59e+08      -11.1      0.209      0.637      0.529      0.624 
      128       3.51       0.45   3.36e+08      -11.4      0.508       0.61       0.53      0.621 
      256       3.57      0.933   6.84e+08      -11.9      0.475      0.604      0.519      0.624 
      512       3.46       1.89   1.37e+09      -11.5      0.582      0.615      0.517      0.604 
 1.02e+03       3.52       3.77   2.77e+09      -11.9      0.571      0.609      0.517      0.629 
──────────────────────────────────────────────────────────────────────────────────────────────────
┌ Info: Neither traces, disk, nor online recorders included. 
│    You may not have access to your samples (unless you are using a custom recorder, or maybe you just want log(Z)).
└    To add recorders, use e.g. pigeons(target = ..., record = [traces; record_default()])
──────────────────────────────────────────────────────────────────────────────────────────────────
  scans        Λ        time(s)    allc(B)  log(Z₁/Z₀)   min(α)     mean(α)    min(αₑ)   mean(αₑ) 
────────── ────────── ────────── ────────── ────────── ────────── ────────── ────────── ──────────
        2       1.04    0.00121   1.04e+06  -4.24e+03          0      0.885          1          1 
        4       4.06    0.00233   1.78e+06      -16.3   4.63e-06      0.549          1          1 
        8       3.49    0.00465   3.52e+06      -12.1      0.215      0.612          1          1 
       16       2.68    0.00969    7.4e+06      -10.2      0.518      0.703          1          1 
       32       4.29     0.0171   1.36e+07      -11.8      0.222      0.524          1          1 
       64       3.17     0.0368   2.84e+07      -11.5      0.529      0.648          1          1 
      128       3.56     0.0714   5.49e+07      -11.5      0.523      0.605          1          1 
      256       3.38      0.156    1.1e+08      -11.6      0.526      0.625          1          1 
      512       3.48      0.309   2.21e+08        -12      0.527      0.614          1          1 
 1.02e+03       3.55      0.598   4.43e+08      -11.8      0.571      0.605          1          1 
──────────────────────────────────────────────────────────────────────────────────────────────────
Test Summary:   | Pass  Total   Time
Turing-gradient |    1      1  10.0s
Test.DefaultTestSet("Turing-gradient", Any[], 1, false, false, true, 1.700434107621442e9, 1.70043411758343e9, false)
Manifest.toml
(jl_Vc09RO) pkg> st --manifest
Status `/tmp/jl_Vc09RO/Manifest.toml`
  [47edcb42] ADTypes v0.2.5
  [621f4979] AbstractFFTs v1.5.0
⌅ [80f14c24] AbstractMCMC v4.4.2
⌅ [7a57a42e] AbstractPPL v0.6.2
  [1520ce14] AbstractTrees v0.4.4
  [7d9f7c33] Accessors v0.1.33
  [79e6a3ab] Adapt v3.7.1
⌅ [0bf59076] AdvancedHMC v0.5.5
⌅ [5b7e9947] AdvancedMH v0.7.5
⌅ [576499cb] AdvancedPS v0.4.3
  [b5ca4192] AdvancedVI v0.2.4
  [dce04be8] ArgCheck v2.3.0
  [ec485272] ArnoldiMethod v0.2.0
  [4fba245c] ArrayInterface v7.5.1
  [a9b6321e] Atomix v0.1.0
  [13072b0f] AxisAlgorithms v1.0.1
  [39de3d68] AxisArrays v0.4.7
  [198e06fe] BangBang v0.3.39
  [9718e550] Baselet v0.1.1
  [76274a88] Bijectors v0.13.7
  [c88b6f0a] BridgeStan v2.2.2
⌅ [fa961155] CEnum v0.4.2
  [49dc2e85] Calculus v0.5.1
  [082447d4] ChainRules v1.58.0
  [d360d2e6] ChainRulesCore v1.18.0
  [9e997f8a] ChangesOfVariables v0.1.8
  [861a8166] Combinatorics v1.0.2
  [38540f10] CommonSolve v0.2.4
  [bbf7d656] CommonSubexpressions v0.3.0
  [34da2185] Compat v4.10.0
  [a33af91c] CompositionsBase v0.1.2
  [2569d6c7] ConcreteStructs v0.2.3
  [88cd18e8] ConsoleProgressMonitor v0.1.2
  [187b0558] ConstructionBase v1.5.4
  [a8cc5b0e] Crayons v4.1.1
  [9a962f9c] DataAPI v1.15.0
  [a93c6f00] DataFrames v1.6.1
  [864edb3b] DataStructures v0.18.15
  [e2d170a0] DataValueInterfaces v1.0.0
  [244e2a9f] DefineSingletons v0.1.2
  [8bb1440f] DelimitedFiles v1.9.1
  [b429d917] DensityInterface v0.4.0
  [163ba53b] DiffResults v1.1.0
  [b552c78f] DiffRules v1.15.1
  [31c24e10] Distributions v0.25.103
  [ced4e74d] DistributionsAD v0.6.53
  [ffbed154] DocStringExtensions v0.9.3
  [fa6b7ba4] DualNumbers v0.6.8
⌅ [366bfd00] DynamicPPL v0.23.21
⌅ [cad2338a] EllipticalSliceSampling v1.1.0
  [4e289a0a] EnumX v1.0.4
  [6a31a4e8] Expect v0.3.1
  [e2ba6199] ExprTools v0.1.10
  [7a1cc6ca] FFTW v1.7.1
  [1a297f60] FillArrays v1.7.0
  [59287772] Formatting v0.4.2
  [f6369f11] ForwardDiff v0.10.36
  [069b7b12] FunctionWrappers v1.1.3
  [77dc65aa] FunctionWrappersWrappers v0.1.3
  [d9f16b24] Functors v0.4.5
  [46192b85] GPUArraysCore v0.1.5
  [86223c79] Graphs v1.9.0
  [34004b35] HypergeometricFunctions v0.3.23
  [d25df0c9] Inflate v0.1.4
  [22cec73e] InitialValues v0.3.1
  [842dd82b] InlineStrings v1.4.0
  [505f98c9] InplaceOps v0.3.0
  [18e54dd8] IntegerMathUtils v0.1.2
  [a98d9a8b] Interpolations v0.14.7
  [8197267c] IntervalSets v0.7.8
  [3587e190] InverseFunctions v0.1.12
  [41ab1584] InvertedIndices v1.3.0
  [92d709cd] IrrationalConstants v0.2.2
  [c8e1da08] IterTools v1.8.0
  [82899510] IteratorInterfaceExtensions v1.0.0
  [692b3bcd] JLLWrappers v1.5.0
  [63c18a36] KernelAbstractions v0.9.13
  [5ab0869b] KernelDensity v0.6.7
  [929cbde3] LLVM v6.4.0
  [8ac3fa9e] LRUCache v1.5.0
  [b964fa9f] LaTeXStrings v1.3.1
  [73f95e8e] LatticeRules v0.0.1
  [50d2b5c4] Lazy v0.15.1
  [1d6d02ad] LeftChildRightSiblingTrees v0.2.0
  [6f1fad26] Libtask v0.8.6
  [6fdf6af0] LogDensityProblems v2.1.1
  [996a588d] LogDensityProblemsAD v1.7.0
  [2ab3a3ac] LogExpFunctions v0.3.26
  [e6f89c97] LoggingExtras v1.0.3
  [c7f686f2] MCMCChains v6.0.4
  [be115224] MCMCDiagnosticTools v0.3.8
  [e80e1ace] MLJModelInterface v1.9.3
⌃ [da04e1cc] MPI v0.20.16
  [3da0fdf6] MPIPreferences v0.1.10
  [1914dd2f] MacroTools v0.5.11
  [dbb5928d] MappedArrays v0.4.2
  [128add7d] MicroCollections v0.1.4
  [e1d29d7a] Missings v1.1.0
  [872c559c] NNlib v0.9.7
  [77ba4419] NaNMath v1.0.2
  [86f7a689] NamedArrays v0.10.0
  [c020b1a1] NaturalSort v1.0.0
  [6fe1bfb0] OffsetArrays v1.12.10
  [a15396b6] OnlineStats v1.6.3
  [925886fa] OnlineStatsBase v1.6.1
  [3bd65402] Optimisers v0.3.1
  [bac558e1] OrderedCollections v1.6.2
  [90014a1f] PDMats v0.11.29
  [69de0a69] Parsers v2.8.0
  [0eb8d820] Pigeons v0.2.8
  [eebad327] PkgVersion v0.3.3
  [2dfb63ee] PooledArrays v1.4.3
  [aea7be01] PrecompileTools v1.2.0
  [21216c6a] Preferences v1.4.1
  [08abe8d2] PrettyTables v2.3.0
  [27ebfcd6] Primes v0.5.5
  [33c8b6b6] ProgressLogging v0.1.4
  [92933f4c] ProgressMeter v1.9.0
  [1fd47b50] QuadGK v2.9.1
  [8a4e6c94] QuasiMonteCarlo v0.3.3
  [74087812] Random123 v1.6.1
  [e6cf234a] RandomNumbers v1.5.3
  [b3c3ace0] RangeArrays v0.3.2
  [c84ed2f1] Ratios v0.4.5
  [c1ae055f] RealDot v0.1.0
  [3cdcf5f2] RecipesBase v1.3.4
  [731186ca] RecursiveArrayTools v2.38.10
  [189a3867] Reexport v1.2.2
  [ae029012] Requires v1.3.0
  [79098fc4] Rmath v0.7.1
  [f2b01f46] Roots v2.0.22
  [7e49a35a] RuntimeGeneratedFunctions v0.5.12
  [0bca4576] SciMLBase v2.8.2
  [c0aeaf25] SciMLOperators v0.3.7
  [30f210dd] ScientificTypesBase v3.0.0
  [91c51154] SentinelArrays v1.4.1
  [efcf1570] Setfield v1.1.1
  [699a6c99] SimpleTraits v0.9.4
  [ce78b400] SimpleUnPack v1.1.0
  [ed01d8cd] Sobol v1.5.0
  [a2af1166] SortingAlgorithms v1.2.0
  [dc90abb0] SparseInverseSubset v0.1.1
  [276daf66] SpecialFunctions v2.3.1
  [8efc31e9] SplittableRandoms v0.1.2
  [171d559e] SplittablesBase v0.1.15
  [90137ffa] StaticArrays v1.7.0
  [1e83bf80] StaticArraysCore v1.4.2
  [64bff920] StatisticalTraits v3.2.0
  [82ae8749] StatsAPI v1.7.0
  [2913bbd2] StatsBase v0.34.2
  [4c63d2b9] StatsFuns v1.3.0
  [892a3eda] StringManipulation v0.3.4
  [09ab397b] StructArrays v0.6.16
  [2efcf032] SymbolicIndexingInterface v0.2.2
  [3783bdb8] TableTraits v1.0.1
  [bd369af6] Tables v1.11.1
  [5d786b92] TerminalLoggers v0.1.7
  [9f7883ad] Tracker v0.2.29
  [28d57a85] Transducers v0.4.79
  [410a4b4d] Tricks v0.1.8
  [781d530d] TruncatedStacktraces v1.4.0
  [fce5fe82] Turing v0.29.3
  [013be700] UnsafeAtomics v0.2.1
  [d80eeb9a] UnsafeAtomicsLLVM v0.1.3
  [efce3f68] WoodburyMatrices v0.5.6
  [a5390f91] ZipFile v0.10.1
  [700de1a5] ZygoteRules v0.2.4
  [f5851436] FFTW_jll v3.3.10+0
  [e33a78d0] Hwloc_jll v2.9.3+0
  [1d5cc7b8] IntelOpenMP_jll v2023.2.0+0
  [dad2f222] LLVMExtra_jll v0.0.27+1
  [856f044c] MKL_jll v2023.2.0+0
  [7cb0a576] MPICH_jll v4.1.2+0
  [f1f71cc9] MPItrampoline_jll v5.3.1+0
  [9237b28f] MicrosoftMPI_jll v10.1.4+1
  [fe0851c0] OpenMPI_jll v5.0.0+0
  [458c3c95] OpenSSL_jll v3.0.12+0
  [efe28fd5] OpenSpecFun_jll v0.5.5+0
  [32165bc3] PMIx_jll v4.2.7+0
  [f50d1b31] Rmath_jll v0.4.0+0
  [1080aeaf] libevent_jll v2.1.13+1
  [eb928a42] prrte_jll v3.0.2+0
  [0dad84c5] ArgTools v1.1.1
  [56f22d72] Artifacts
  [2a0f44e3] Base64
  [ade2ca70] Dates
  [8ba89e20] Distributed
  [f43a241f] Downloads v1.6.0
  [7b1f6079] FileWatching
  [9fa8497b] Future
  [b77e0a4c] InteractiveUtils
  [4af54fe1] LazyArtifacts
  [b27032c2] LibCURL v0.6.3
  [76f85450] LibGit2
  [8f399da3] Libdl
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [d6f4376e] Markdown
  [a63ad114] Mmap
  [ca575930] NetworkOptions v1.2.0
  [44cfe95a] Pkg v1.9.2
  [de0858da] Printf
  [3fa0cd96] REPL
  [9a3f8284] Random
  [ea8e919c] SHA v0.7.0
  [9e88b42a] Serialization
  [1a1011a3] SharedArrays
  [6462fe0b] Sockets
  [2f01184e] SparseArrays
  [10745b16] Statistics v1.9.0
  [4607b0f0] SuiteSparse
  [fa267f1f] TOML v1.0.3
  [a4e569a6] Tar v1.10.0
  [8dfed614] Test
  [cf7118a7] UUIDs
  [4ec0a83e] Unicode
  [e66e0078] CompilerSupportLibraries_jll v1.0.5+0
  [deac9b47] LibCURL_jll v7.84.0+0
  [29816b5a] LibSSH2_jll v1.10.2+0
  [c8ffd9c3] MbedTLS_jll v2.28.2+0
  [14a3606d] MozillaCACerts_jll v2022.10.11
  [4536629a] OpenBLAS_jll v0.3.21+4
  [05823500] OpenLibm_jll v0.8.1+0
  [bea87d4a] SuiteSparse_jll v5.10.1+6
  [83775a58] Zlib_jll v1.2.13+0
  [8e850b90] libblastrampoline_jll v5.8.0+0
  [8e850ede] nghttp2_jll v1.48.0+0
  [3f19e933] p7zip_jll v17.4.0+0
Info Packages marked with ⌃ and ⌅ have new versions available, but those with ⌅ are restricted by compatibility constraints from upgrading. To see why use `status --outdated -m`

@torfjelde
Author

Is it possible that using a different MPI version affects the RNG somehow? That seems quite strange, but I don't have much experience with MPI.

@alexandrebouchard
Member

Thanks for checking! Yes, we did observe the non-reproducibility; it first showed up when CI checks were failing non-deterministically.

Regarding the MPI hypothesis, I would be quite surprised if the MPI implementation would affect the RNGs. MPI should not be aware of rngs.

What we have often observed with non-reproducibility issues is that things can appear reproducible in one computing setup but not in another; e.g., if a race condition depends on the timing of events, it might only trigger in certain setups. Here I agree it seems to show up only on the CI instances (see the logs saved above). We like to have reproducibility on the CI instances because we rely on it to check a property we call "parallelism invariance" for our distributed algorithms (https://pigeons.run/dev/distributed/#distributed). But we can always rely on other PPLs (or non-gradient Turing) to check that property for the core distributed algorithms, so narrowing down this tricky quirk may not be necessary.
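A minimal sketch of the parallelism-invariance idea, under stated assumptions: each replica gets its own RNG derived only from (seed, replica index), so the output is identical whether replicas run serially or on threads, and no replica's draws depend on timing. Pigeons uses splittable RNGs for this (SplittableRandoms appears in the manifest above); seeding a `Xoshiro` per replica via `hash` is just a simpler stand-in, and `run_replica`/`run_all` are hypothetical.

```julia
using Random

# Toy per-replica work (assumption): a short random walk; keep the final position.
function run_replica(rng, n)
    x = 0.0
    for _ in 1:n
        x += randn(rng)
    end
    return x
end

# Each replica's RNG depends only on (seed, replica index), never on
# which thread runs it or when it finishes.
replica_rng(seed, r) = Xoshiro(hash((seed, r)))

function run_all(seed, n_replicas; n = 100, threaded = true)
    out = zeros(n_replicas)
    if threaded
        Threads.@threads for r in 1:n_replicas
            out[r] = run_replica(replica_rng(seed, r), n)  # each task writes its own slot
        end
    else
        for r in 1:n_replicas
            out[r] = run_replica(replica_rng(seed, r), n)
        end
    end
    return out
end

println(run_all(42, 8; threaded = true) == run_all(42, 8; threaded = false))  # true
```

By contrast, sharing a single RNG across threads would make each replica's draws depend on scheduling, which is exactly the kind of setup-dependent behaviour described above (and is also not thread-safe).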

@torfjelde
Author

Sorry for the very late reply here. A conference and the Christmas holidays happened, and I've been working on a convenient way to represent a mix of variable types, as you mentioned. I wanted to have that done before replying, but it's taking much longer than originally intended, so I'll have to defer it for now.

Regarding the MPI hypothesis, I would be quite surprised if the MPI implementation would affect the RNGs. MPI should not be aware of rngs.

Very much agree; that would seem very surprising.

So I suspect the non-determinism is related to gradient computation on Turing models.

Are you constructing a separate model for each process? As in, is

function LogDensityProblemsAD.ADgradient(kind::Symbol, log_potential::TuringLogPotential, buffers::Pigeons.Augmentation)

called for each worker?

@alexandrebouchard
Member

No worries, I have been slow in everything lately too!! :)

Good question! This specific test is single threaded. But if it had been multi-threaded, the way it is set up at the moment is for each replica to have a distinct VarInfo, while the model is shared by several threads. I assumed the mutation happens in VarInfos and not in models. I guess it's orthogonal to this issue, but I am curious whether this is the right mental model?

@torfjelde
Author

Whoops, I completely missed the reply! I just came across it now because I was trying out Pigeons.jl on a problem of my own and figured I'd check back on this issue.

But if it would have been multi-threaded, then the way it is setup at the moment is to have each replica having a distinct VarInfo, but the model is shared by several threads. I assumed the mutability happens in VarInfo's and not in models.

Mutation shouldn't happen in the model unless the arguments passed to the model are themselves mutated, e.g. passing in missing in an array will lead to it being sampled rather than "observed". So your understanding is indeed correct :)
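The ownership pattern described above, sketched with hypothetical plain-Julia types rather than Turing's `Model`/`VarInfo`: the model (and its data) is shared read-only across threads, and each replica writes only to its own mutable state, so there are no concurrent writes to shared memory. `ToyModel`, `ReplicaState` and `update!` are made-up names for illustration.

```julia
# Shared across threads; its fields are only read, never written.
struct ToyModel
    data::Vector{Float64}   # observed data; must not be mutated while shared
end

# Per-replica mutable state (rough analogue of a per-replica VarInfo).
mutable struct ReplicaState
    μ::Float64
    logp::Float64
end

# Reads the shared model, writes only to the replica's own state.
function update!(state::ReplicaState, model::ToyModel)
    state.logp = -0.5 * sum(abs2, model.data .- state.μ)
    return state
end

model = ToyModel([0.5, 1.5, 1.0])
states = [ReplicaState(μ, 0.0) for μ in (0.0, 1.0, 2.0, 3.0)]

Threads.@threads for i in eachindex(states)
    update!(states[i], model)   # no shared writes, so no data race
end

println(states[2].logp)  # -0.25
```

If `model.data` were modified during sampling (the analogue of an argument being mutated), every thread would see the change and the pattern would no longer be race-free.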

Hmm, any ideas on how best to go about debugging this? We on the Turing side are pretty keen to help out, but I, at least, lack knowledge of how all of the moving parts here interact 😕 Maybe @devmotion or @yebai have some thoughts or ideas?


2 participants