Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ def _pre_inference(
def predict_entities(
self, doc: MutableDocument, ents: list[MutableEntity] | None = None
) -> list[MutableEntity]:
if self.cnf_l.train and self.comp_name == "embedding_linker":
if self.cnf_l.train and self.name == "embedding_linker":
logger.warning(
"Attemping to train a static embedding linker. "
"This is not possible / required."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,35 +249,27 @@ def _build_mention_mask_from_char_spans(
mention_char_spans: list[tuple[int, int]],
device: torch.device,
) -> Tensor:
"""
Convert character-level mention spans into a token-level mask.

Args:
batch_dict: tokenizer output with 'offset_mapping'
mention_char_spans: list of (start_char, end_char) per example
device: torch device

Returns:
mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens,
0.0 otherwise
"""
offset_mapping = batch_dict["offset_mapping"] # [B, max_token_length, 2]
batch_size, seq_len, _ = offset_mapping.shape
mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device)

for i, (mention_start, mention_end) in enumerate(mention_char_spans):
# For each token in the sequence
for j in range(seq_len):
token_start, token_end = offset_mapping[i, j].tolist()
# Skip padding tokens
if token_end == 0 and token_start == 0:
continue
# Check if token overlaps mention span
if token_end > mention_start and token_start < mention_end:
mask[i, j] = 1.0

"""Convert character-level mention spans into a token-level mask."""
offset_mapping = batch_dict["offset_mapping"] # [B, seq_len, 2]
token_starts = offset_mapping[:, :, 0] # [B, seq_len]
token_ends = offset_mapping[:, :, 1] # [B, seq_len]

spans_tensor = torch.tensor(
mention_char_spans, dtype=torch.long, device=device
) # [B, 2]
mention_starts = spans_tensor[:, 0].unsqueeze(1) # [B, 1]
mention_ends = spans_tensor[:, 1].unsqueeze(1) # [B, 1]

# Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
is_special = (token_starts == 0) & (token_ends == 0)
overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
mask = (overlaps & ~is_special).float()
return mask

def _full_text_spans(self, texts: list[str]) -> list[tuple[int, int]]:
"""Build mention spans that cover each full text entry."""
return [(0, len(text)) for text in texts]

def embed(
self,
to_embed: list[str],
Expand All @@ -301,7 +293,7 @@ def embed(
).to(target_device)

mention_mask = None
if mention_spans is not None:
if self.cnf_l.use_mention_attention and mention_spans is not None:
mention_mask = self._build_mention_mask_from_char_spans(
batch_dict,
mention_spans,
Expand All @@ -322,7 +314,11 @@ def embed_cuis(self) -> None:
"""
self._refresh_cdb_keys() # ensure _cui_keys is up to date before embedding

cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
# cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
cui_names = [
max(self.cdb.cui2info[cui].get("names"), key=len)
for cui in self._cui_keys
]
total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
all_embeddings = []
for names in tqdm(
Expand All @@ -332,7 +328,14 @@ def embed_cuis(self) -> None:
):
with torch.no_grad():
names_to_embed = [name.replace(self.separator, " ") for name in names]
embeddings = self.embed(names_to_embed, device=self.device)
mention_spans = None
if self.cnf_l.use_mention_attention:
mention_spans = self._full_text_spans(names_to_embed)
embeddings = self.embed(
names_to_embed,
mention_spans=mention_spans,
device=self.device,
)
all_embeddings.append(embeddings.cpu())

all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
Expand All @@ -358,7 +361,14 @@ def embed_names(self) -> None:
names_to_embed = [
name.replace(self.separator, " ") for name in batch_names
]
embeddings = self.embed(names_to_embed, device=self.device)
mention_spans = None
if self.cnf_l.use_mention_attention:
mention_spans = self._full_text_spans(names_to_embed)
embeddings = self.embed(
names_to_embed,
mention_spans=mention_spans,
device=self.device,
)
all_embeddings.append(embeddings.cpu())

all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
Expand Down
Empty file.
117 changes: 117 additions & 0 deletions medcat-plugins/transformer-ner/pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
[project]
name = "medcat-transformer-ner"

dynamic = ["version"]

description = "Transformer based NER for MedCAT"

readme = "README.md"

requires-python = ">=3.10"

license = {text = "Apache-2.0"}

keywords = ["ML", "NLP", "NER+L"]

authors = [
{name = "A. Sutton"},
{name = "T. Searle"},
{name = "M. Ratas"},
]

# This should be your name or the names of the organization who currently
# maintains the project, and a valid email address corresponding to the name
# listed.
maintainers = [
{name = "CogStack", email = "contact@cogstack.org" }
]

classifiers = [
# How mature is this project? Common values are
# 3 - Alpha
# 4 - Beta
# 5 - Production/Stable
"Development Status :: 3 - Alpha",

"Intended Audience :: Healthcare Industry",
# "Topic :: Natural Language Processing :: Named Entity Recognition and Linking",

# Specify the Python versions you support here. In particular, ensure
# that you indicate you support Python 3. These classifiers are *not*
# checked by "pip install". See instead "python_requires" below.
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3 :: Only",
"Operating System :: OS Independent",
]

# This field lists other packages that your project depends on to run.
# Any package you put here will be installed by pip when your project is
# installed, so they must be valid existing projects.
#
# For an analysis of this field vs pip's requirements files see:
# https://packaging.python.org/discussions/install-requires-vs-requirements/
dependencies = [
"medcat[spacy]>=2.7",
"transformers>=4.41.0,<5.0", # avoid major bump
"torch>=2.4.0,<3.0",
"tqdm",
]

# List additional groups of dependencies here (e.g. development
# dependencies). Users will be able to install these using the "extras"
# syntax, for example:
#
# $ pip install sampleproject[dev]
#
# Similar to `dependencies` above, these must be valid existing
# projects.
[project.optional-dependencies] # Optional
dev = [
"ruff~=0.1.7",
"mypy",
"types-tqdm",
"types-setuptools",
"types-PyYAML",
]

# entry-points to add onto medcat
[project.entry-points."medcat.plugins"]
medcat_transformer_ner = "medcat_transformer_ner"

[project.urls]
"Homepage" = "https://cogstack.org/"
"Bug Reports" = "https://discourse.cogstack.org/"
"Source" = "https://github.com/CogStack/cogstack-nlp/tree/main/medcat-plugins/transformer-ner"

[build-system]
# These are the assumed default build requirements from pip:
# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
requires = ["setuptools>=43.0.0", "setuptools_scm>=8", "wheel"]
build-backend = "setuptools.build_meta"

[tool.setuptools]
package-dir = {"" = "src"}

[tool.setuptools.packages.find]
where = ["src"]

[tool.setuptools.package-data]
"medcat_ner_transformer" = ["py.typed"]

[tool.setuptools_scm]
# look for .git folder in root of repo
root = "../.."
version_scheme = "post-release"
local_scheme = "no-local-version"
tag_regex = "^medcat-transformer-ner/v(?P<version>\\d+(?:\\.\\d+)*)(?:[ab]\\d+|rc\\d+)?$"
git_describe_command = "git describe --dirty --tags --long --match 'medcat-transformer-ner/v*'"

[tool.ruff.lint]
# 1. Enable some extra checks for ruff
select = ["E", "F"]
# ignore unused local variables
ignore = ["F841"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Importing this package registers the transformer NER component with
# medcat as a side effect; the dunder alias keeps the helper out of the
# package's public namespace.
from .registration import do_registration as __register

__register()
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Optional, Any
from medcat.config import Ner

class TransformerNER(Ner):
    """The config exclusively used for the transformer NER."""
    language_model_name: str = "nlpie/distil-clinicalbert"
    """Name/path of the language model. It must be downloadable from
    huggingface linked from an appropriate file directory"""
    training_batch_size: int = 32
    """The size of the batch to be used for training."""
    max_token_length: int = 512
    """Max number of tokens to be passed to the language model.
    Longer sequences will be chunked"""
    overlap_chunking: float = 0.2
    """Fraction of overlap between consecutive chunks when a sequence
    longer than max_token_length is split."""
    gpu_device: Optional[Any] = None
    """Choose a device for the model to be stored / computed on. If None
    then an appropriate GPU device that is available will be chosen"""
    require_link_candidates: bool = True
    """Generate ent.link_candidates based on detected names. This requires
    checking the CDB.name2info, and is required for vocab based linking.
    Set to true because even if you don't use it, what's the harm?"""
    learning_rate: float = 2e-5
    """The learning rate to be used for training the model"""
    weight_decay: float = 0.01
    """The weight decay to be used for training the model"""
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import logging

from medcat.components.types import CoreComponentType
from medcat.components.types import lazy_register_core_component


logger = logging.getLogger(__name__)


def do_registration():
    """Register the transformer NER as a lazily-loaded core NER component."""
    component_name = "transformer_ner"
    module_path = "medcat_transformer_ner.transformer_ner"
    factory = "NER.create_new_component"
    lazy_register_core_component(
        CoreComponentType.ner,
        component_name,
        module_path,
        factory,
    )
Loading
Loading