Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions deerlab/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
# __init__.py
from .dd_models import *
from .bg_models import *
from . import dd_models as _dd_models_mod
from . import bg_models as _bg_models_mod

# Define __getattr__ early so submodules that do `from deerlab import bg_*`
# during their own import (e.g. dipolarmodel) can resolve names via this hook.
def __getattr__(name):
if name in _dd_models_mod.__all__:
return _dd_models_mod.__getattr__(name)
if name in _bg_models_mod.__all__:
return _bg_models_mod.__getattr__(name)
raise AttributeError(f"module 'deerlab' has no attribute {name!r}")

from .model import Model, Penalty, Parameter, link, lincombine, merge, relate
from .deerload import deerload
from .selregparam import selregparam
Expand Down
17 changes: 17 additions & 0 deletions deerlab/bg_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import math as m
from numpy import pi
import inspect
from copy import deepcopy as _deepcopy
from deerlab.dipolarkernel import dipolarkernel
from deerlab.utils import formatted_table
from deerlab.model import Model
Expand Down Expand Up @@ -513,3 +514,19 @@ def _poly3(t,p0,p1,p2,p3):
bg_poly3.p3.set(description='3rd order weight', lb=-200, ub=200, par0=-1, unit=r'μs\ :sup:`-3`')
# Add documentation
bg_poly3.__doc__ = _docstring(bg_poly3,notes)


# ---------------------------------------------------------------------------
# Return a fresh deepcopy on every attribute access so that modifications
# to a retrieved model never affect the global template.
# ---------------------------------------------------------------------------
_templates = {name: obj for name, obj in list(globals().items()) if name.startswith('bg_')}
for _name in list(_templates):
del globals()[_name]

__all__ = list(_templates.keys())

def __getattr__(name):
if name in _templates:
return _deepcopy(_templates[name])
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
17 changes: 17 additions & 0 deletions deerlab/dd_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import inspect
import numpy as np
import scipy.special as spc
from copy import deepcopy as _deepcopy
from deerlab.model import Model
from deerlab.utils import formatted_table

Expand Down Expand Up @@ -1042,3 +1043,19 @@ def _wormgauss(r,contour,persistence,std):
dd_wormgauss.std.set(description='Gaussian standard deviation', lb=0.01, ub=5, par0=0.2, unit='nm')
# Add documentation
dd_wormgauss.__doc__ = _dd_docstring(dd_wormgauss,notes) + docstr_example('dd_wormgauss')


# ---------------------------------------------------------------------------
# Return a fresh deepcopy on every attribute access so that modifications
# to a retrieved model never affect the global template.
# ---------------------------------------------------------------------------
_templates = {name: obj for name, obj in list(globals().items()) if name.startswith('dd_')}
for _name in list(_templates):
del globals()[_name]

__all__ = list(_templates.keys())

def __getattr__(name):
if name in _templates:
return _deepcopy(_templates[name])
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
4 changes: 3 additions & 1 deletion deerlab/dipolarmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def _importparameter(parameter):
'par0' : parameter.par0,
'description' : parameter.description,
'unit' : parameter.unit,
'linear' : parameter.linear
'linear' : parameter.linear,
'frozen' : parameter.frozen,
'value' : parameter.value
}
#------------------------------------------------------------------------

Expand Down
15 changes: 14 additions & 1 deletion test/test_ddmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,4 +104,17 @@ def test_dd_wormchain():
assert_ddmodel(dl.dd_wormchain)

def test_dd_wormgauss():
assert_ddmodel(dl.dd_wormgauss)
assert_ddmodel(dl.dd_wormgauss)


def test_freezing_model():
"Check that freezing parameters of a model works as expected"

# Create model and freeze parameters
model = dl.dd_gauss.copy()
model.mean.freeze(3)
model.std.freeze(0.2)

# Check that the frozen parameters are correctly set
assert model.mean.frozen and model.mean.value == 3
assert model.std.frozen and model.std.value == 0.2
15 changes: 15 additions & 0 deletions test/test_dipolarmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,21 @@ def test_fit_3pathways(V3path):
assert np.allclose(result.model,V3path)
# ======================================================================

# ======================================================================
def test_freeze_fit_linear(V1path):
"Check that the model can be correctly fitted with a frozen linear parameter"

dd_model = dd_gauss
dd_model.std.freeze(0.25)

assert dd_model.std.frozen and dd_model.std.value == 0.25
Vmodel = dipolarmodel(t,r,dd_gauss,bg_hom3d,npathways=1)
assert Vmodel.std.frozen and Vmodel.std.value == 0.25

result = fit(Vmodel,V1path,ftol=1e-4)

assert np.allclose(result.std,0.25)
# ======================================================================
# Fixtures
# ----------------------------------------------------------------------
@fixture(scope='module')
Expand Down
3 changes: 3 additions & 0 deletions test/test_model_penalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def test_weight_freeze(penalty_fcn):
penaltyobj = Penalty(penalty_fcn,'icc')
penaltyobj.weight.freeze(0.5)
assert penaltyobj.weight.frozen==True and penaltyobj.weight.value==0.5
penaltyobj.weight.unfreeze()
# ======================================================================

# ======================================================================
Expand All @@ -74,6 +75,7 @@ def test_fit(penalty_fcn, model, mock_data, selection):
"Check fitting with a penalty with ICC-selected weight"
penaltyobj = Penalty(penalty_fcn, selection)
penaltyobj.weight.set(lb=1e-6,ub=1e1)
assert not penaltyobj.weight.frozen
result = fit(model,mock_data,x,penalties=penaltyobj)
assert ovl(result.model,mock_data)>0.975
# ======================================================================
Expand All @@ -89,6 +91,7 @@ def test_fit_with_penalty_weight(penalty_fcn, model, mock_data, case):
penaltyobj.weight.freeze(0.00001)
result = fit(model,mock_data,x,penalties=penaltyobj)
assert ovl(result.model,mock_data)>0.975
penaltyobj.weight.unfreeze()
# ======================================================================

# ======================================================================
Expand Down
Loading