Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions posteriors/sgmcmc/sglrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,13 @@ def update(
# Resolve schedules
lr_val = lr(state.step) if callable(lr) else lr
T_val = temperature(state.step) if callable(temperature) else temperature
lr_val = torch.as_tensor(
lr_val, dtype=state.params.dtype, device=state.params.device
)
T_val = torch.as_tensor(T_val, dtype=state.params.dtype, device=state.params.device)

# Spatial stepsize to make update binary
diffusion_val = torch.sqrt(2.0 * T_val)
delta_x = torch.sqrt(lr_val) * diffusion_val
diffusion_val = (2.0 * T_val) ** 0.5
delta_x = lr_val**0.5 * diffusion_val

# Per-parameter binary LRW transform
def transform_params(p, g):
p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[:, 2]
p_plus = ternary_probs(g, diffusion_val, lr_val, delta_x)[..., 2]

u = torch.rand_like(p_plus)
step_sign = torch.where(
Expand Down Expand Up @@ -139,6 +134,11 @@ def ternary_probs(
Returns:
Update probabilities as a tensor, with last axis being [p_minus, p_zero, p_plus].
"""
diffusion_val = torch.as_tensor(
diffusion_val, dtype=drift_val.dtype, device=drift_val.device
)
stepsize = torch.as_tensor(stepsize, dtype=drift_val.dtype, device=drift_val.device)
delta_x = torch.as_tensor(delta_x, dtype=drift_val.dtype, device=drift_val.device)
desired_mean = stepsize * drift_val
desired_var = stepsize * diffusion_val**2
scaled_mean = desired_mean / delta_x
Expand Down
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_baoa.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def lr(step):
transform = baoa.build(log_prob, lr)

# Initialise
params = torch.randn(dim)
params = {"w": torch.randn(2, 2), "b": torch.randn(1)}

# Verify inplace update
verify_inplace_update(transform, params, None)
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sghmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def lr(step):
transform = sghmc.build(log_prob, lr)

# Initialise
params = torch.randn(dim)
params = {"w": torch.randn(2, 2), "b": torch.randn(1)}

# Verify inplace update
verify_inplace_update(transform, params, None)
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sgld.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def lr(step):
transform = sgld.build(log_prob, lr)

# Initialise
params = torch.randn(dim)
params = {"w": torch.randn(2, 2), "b": torch.randn(1)}

# Verify inplace update
verify_inplace_update(transform, params, None)
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sglrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def lr(step):
transform = sglrw.build(log_prob, lr)

# Initialise
params = torch.randn(dim)
params = {"w": torch.randn(2, 2), "b": torch.randn(1)}

# Verify inplace update
verify_inplace_update(transform, params, None)
2 changes: 1 addition & 1 deletion tests/sgmcmc/test_sgnht.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def lr(step):
transform = sgnht.build(log_prob, lr)

# Initialise
params = torch.randn(dim)
params = {"w": torch.randn(2, 2), "b": torch.randn(1)}

# Verify inplace update
verify_inplace_update(transform, params, None)