Skip to content

Commit

Permalink
Merge pull request kazewong#106 from kazewong/main
Browse files Browse the repository at this point in the history
Syncing branch
  • Loading branch information
kazewong authored Mar 30, 2023
2 parents 998bc23 + 5471f92 commit 3f37b2e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 10 deletions.
11 changes: 9 additions & 2 deletions src/flowMC/sampler/Gaussian_random_walk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@ class GaussianRandomWalk(LocalSamplerBase):
jit: whether to jit the sampler
params: dictionary of parameters for the sampler
"""
def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable:
def __init__(self, logpdf: Callable, jit: bool, params: dict, verbose: bool = False) -> Callable:
super().__init__(logpdf, jit, params)
self.params = params
self.logpdf = logpdf
self.verbose = verbose

def make_kernel(self, return_aux = False) -> Callable:
"""
Expand Down Expand Up @@ -79,9 +80,15 @@ def rw_sampler(rng_key, n_steps, initial_position):
all_positions = (jnp.zeros((n_chains, n_steps) + initial_position.shape[-1:])) + initial_position[:, None]
all_logp = (jnp.zeros((n_chains, n_steps)) + logp[:, None])
state = (rng_key, all_positions, all_logp, acceptance, self.params)
for i in tqdm(range(1, n_steps)):
if self.verbose:
iterator_loop = tqdm(range(1, n_steps), desc="Sampling Locally", miniters=int(n_steps / 10))
else:
iterator_loop = range(1, n_steps)

for i in iterator_loop:
state = rw_update(i, state)
return state[:-1]


return rw_sampler

13 changes: 9 additions & 4 deletions src/flowMC/sampler/HMC.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ class HMC(LocalSamplerBase):
params: dictionary of parameters for the sampler
"""

def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable:
def __init__(self, logpdf: Callable, jit: bool, params: dict, verbose: bool = False) -> Callable:
super().__init__(logpdf, jit, params)

self.potential = lambda x: -self.logpdf(x)
Expand All @@ -38,6 +38,7 @@ def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable:

self.kinetic = lambda p, params: 0.5*(p**2 * params['inverse_metric']).sum()
self.grad_kinetic = jax.grad(self.kinetic)
self.verbose = verbose

def get_initial_hamiltonian(self, rng_key: jax.random.PRNGKey, position: jnp.array,
params: dict):
Expand Down Expand Up @@ -167,9 +168,13 @@ def hmc_sampler(rng_key, n_steps, initial_position):
+ logp[:, None]
)
state = (rng_key, all_positions, all_logp, acceptance, self.params)
for i in tqdm(
range(1, n_steps), desc="Sampling Locally", miniters=int(n_steps / 10)
):

if self.verbose:
iterator_loop = tqdm(range(1, n_steps), desc="Sampling Locally", miniters=int(n_steps / 10))
else:
iterator_loop = range(1, n_steps)

for i in iterator_loop:
state = hmc_update(i, state)

state = (state[0], state[1], -state[2], state[3])
Expand Down
9 changes: 7 additions & 2 deletions src/flowMC/sampler/MALA.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@ class MALA(LocalSamplerBase):
params: dictionary of parameters for the sampler
"""

def __init__(self, logpdf: Callable, jit: bool, params: dict) -> Callable:
def __init__(self, logpdf: Callable, jit: bool, params: dict, verbose: bool = False) -> Callable:
super().__init__(logpdf, jit, params)
self.params = params
self.logpdf = logpdf
self.verbose = verbose

def make_kernel(self, return_aux = False) -> Callable:
"""
Expand Down Expand Up @@ -119,7 +120,11 @@ def mala_sampler(rng_key, n_steps, initial_position):
all_positions = (jnp.zeros((n_chains, n_steps) + initial_position.shape[-1:])) + initial_position[:, None]
all_logp = (jnp.zeros((n_chains, n_steps)) + logp[:, None])
state = (rng_key, all_positions, all_logp, acceptance, self.params)
for i in tqdm(range(1, n_steps)):
if self.verbose:
iterator_loop = tqdm(range(1, n_steps), desc="Sampling Locally", miniters=int(n_steps / 10))
else:
iterator_loop = range(1, n_steps)
for i in iterator_loop:
state = mala_update(i, state)
return state[:-1]

Expand Down
11 changes: 9 additions & 2 deletions src/flowMC/sampler/Sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import flax
import optax
from flowMC.sampler.LocalSampler_Base import LocalSamplerBase
from tqdm import tqdm


class Sampler():
Expand Down Expand Up @@ -297,7 +298,10 @@ def global_sampler_tuning(self, initial_position: jnp.ndarray) -> jnp.array:
"""
print("Training normalizing flow")
last_step = initial_position
for _ in range(self.n_loop_training):
for _ in tqdm(
range(self.n_loop_training),
desc="Tuning global sampler",
):
last_step = self.sampling_loop(last_step, training=True)
return last_step

Expand All @@ -314,7 +318,10 @@ def production_run(self, initial_position: jnp.ndarray) -> jnp.array:
"""
print("Starting Production run")
last_step = initial_position
for _ in range(self.n_loop_production):
for _ in tqdm(
range(self.n_loop_production),
desc="Production run",
):
last_step = self.sampling_loop(last_step)
return last_step

Expand Down

0 comments on commit 3f37b2e

Please sign in to comment.