diff --git a/posteriors/sgmcmc/sglrw.py b/posteriors/sgmcmc/sglrw.py index 7555bf5..52ce96b 100644 --- a/posteriors/sgmcmc/sglrw.py +++ b/posteriors/sgmcmc/sglrw.py @@ -91,18 +91,13 @@ def update( # Resolve schedules lr_val = lr(state.step) if callable(lr) else lr T_val = temperature(state.step) if callable(temperature) else temperature - lr_val = torch.as_tensor( - lr_val, dtype=state.params.dtype, device=state.params.device - ) - T_val = torch.as_tensor(T_val, dtype=state.params.dtype, device=state.params.device) - # Spatial stepsize to make update binary - diffusion_val = torch.sqrt(2.0 * T_val) - delta_x = torch.sqrt(lr_val) * diffusion_val + diffusion_val = (2.0 * T_val) ** 0.5 + delta_x = lr_val**0.5 * diffusion_val # Per-parameter binary LRW transform def transform_params(p, g): - p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[:, 2] + p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[..., 2] u = torch.rand_like(p_plus) step_sign = torch.where( @@ -139,6 +134,11 @@ def ternary_probs( Returns: Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus]. """ + diffusion_val = torch.as_tensor( + diffusion_val, dtype=drift_val.dtype, device=drift_val.device + ) + stepsize = torch.as_tensor(stepsize, dtype=drift_val.dtype, device=drift_val.device) + delta_x = torch.as_tensor(delta_x, dtype=drift_val.dtype, device=drift_val.device) desired_mean = stepsize * drift_val desired_var = stepsize * diffusion_val**2 scaled_mean = desired_mean / delta_x diff --git a/tests/sgmcmc/test_baoa.py b/tests/sgmcmc/test_baoa.py index 468f4f9..131c0c8 100644 --- a/tests/sgmcmc/test_baoa.py +++ b/tests/sgmcmc/test_baoa.py @@ -41,7 +41,7 @@ def lr(step): transform = baoa.build(log_prob, lr) # Initialise - params = torch.randn(dim) + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sghmc.py b/tests/sgmcmc/test_sghmc.py index f0bdc24..94573e1 100644 --- a/tests/sgmcmc/test_sghmc.py +++ b/tests/sgmcmc/test_sghmc.py @@ -49,7 +49,7 @@ def lr(step): transform = sghmc.build(log_prob, lr) # Initialise - params = torch.randn(dim) + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sgld.py b/tests/sgmcmc/test_sgld.py index 348853e..4dac258 100644 --- a/tests/sgmcmc/test_sgld.py +++ b/tests/sgmcmc/test_sgld.py @@ -34,7 +34,7 @@ def lr(step): transform = sgld.build(log_prob, lr) # Initialise - params = torch.randn(dim) + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sglrw.py b/tests/sgmcmc/test_sglrw.py index 93c3a0c..458012f 100644 --- a/tests/sgmcmc/test_sglrw.py +++ b/tests/sgmcmc/test_sglrw.py @@ -33,7 +33,7 @@ def lr(step): transform = sglrw.build(log_prob, lr) # Initialise - params = torch.randn(dim) + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None) diff --git a/tests/sgmcmc/test_sgnht.py b/tests/sgmcmc/test_sgnht.py index 74b0d9a..0707cb3 100644 --- a/tests/sgmcmc/test_sgnht.py +++ b/tests/sgmcmc/test_sgnht.py @@ -49,7 +49,7 @@ def lr(step): transform = sgnht.build(log_prob, lr) # Initialise - params = torch.randn(dim) + params = {"w": torch.randn(2, 2), "b": torch.randn(1)} # Verify inplace update verify_inplace_update(transform, params, None)