diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
index a056a4b51..60ac8e6fd 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
@@ -585,7 +585,7 @@ def _pre_inference(
     def predict_entities(
         self, doc: MutableDocument, ents: list[MutableEntity] | None = None
     ) -> list[MutableEntity]:
-        if self.cnf_l.train and self.comp_name == "embedding_linker":
+        if self.cnf_l.train and self.name == "embedding_linker":
             logger.warning(
                 "Attemping to train a static embedding linker. "
                 "This is not possible / required."
diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
index 73924e911..b02083e08 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
@@ -249,35 +249,27 @@ def _build_mention_mask_from_char_spans(
         mention_char_spans: list[tuple[int, int]],
         device: torch.device,
     ) -> Tensor:
-        """
-        Convert character-level mention spans into a token-level mask.
-
-        Args:
-            batch_dict: tokenizer output with 'offset_mapping'
-            mention_char_spans: list of (start_char, end_char) per example
-            device: torch device
-
-        Returns:
-            mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens,
-                0.0 otherwise
-        """
-        offset_mapping = batch_dict["offset_mapping"]  # [B, max_token_length, 2]
-        batch_size, seq_len, _ = offset_mapping.shape
-        mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device)
-
-        for i, (mention_start, mention_end) in enumerate(mention_char_spans):
-            # For each token in the sequence
-            for j in range(seq_len):
-                token_start, token_end = offset_mapping[i, j].tolist()
-                # Skip padding tokens
-                if token_end == 0 and token_start == 0:
-                    continue
-                # Check if token overlaps mention span
-                if token_end > mention_start and token_start < mention_end:
-                    mask[i, j] = 1.0
-
+        """Convert character-level mention spans into a token-level mask."""
+        offset_mapping = batch_dict["offset_mapping"]  # [B, seq_len, 2]
+        token_starts = offset_mapping[:, :, 0]  # [B, seq_len]
+        token_ends = offset_mapping[:, :, 1]  # [B, seq_len]
+
+        spans_tensor = torch.tensor(
+            mention_char_spans, dtype=torch.long, device=device
+        )  # [B, 2]
+        mention_starts = spans_tensor[:, 0].unsqueeze(1)  # [B, 1]
+        mention_ends = spans_tensor[:, 1].unsqueeze(1)  # [B, 1]
+
+        # Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
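+        # Broadcasting the [B, seq_len] token offsets against the [B, 1]
+        # mention bounds marks every token whose character range intersects
+        # its example's mention span in a single vectorised comparison.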
+        is_special = (token_starts == 0) & (token_ends == 0)
+        overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
+        mask = (overlaps & ~is_special).float()
         return mask
 
+    def _full_text_spans(self, texts: list[str]) -> list[tuple[int, int]]:
+        """Build mention spans that cover each full text entry."""
+        return [(0, len(text)) for text in texts]
+
     def embed(
         self,
         to_embed: list[str],
@@ -301,7 +293,7 @@ def embed(
         ).to(target_device)
 
         mention_mask = None
-        if mention_spans is not None:
+        if self.cnf_l.use_mention_attention and mention_spans is not None:
             mention_mask = self._build_mention_mask_from_char_spans(
                 batch_dict,
                 mention_spans,
@@ -322,7 +314,11 @@ def embed_cuis(self) -> None:
         """
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
 
-        cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
+        cui_names = [
+            max(self.cdb.cui2info[cui].get("names"), key=len)
+            for cui in self._cui_keys
+        ]
         total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
         all_embeddings = []
         for names in tqdm(
@@ -332,7 +328,14 @@
         ):
             with torch.no_grad():
                 names_to_embed = [name.replace(self.separator, " ") for name in names]
-                embeddings = self.embed(names_to_embed, device=self.device)
+                mention_spans = None
+                if self.cnf_l.use_mention_attention:
+                    mention_spans = self._full_text_spans(names_to_embed)
+                embeddings = self.embed(
+                    names_to_embed,
+                    mention_spans=mention_spans,
+                    device=self.device,
+                )
                 all_embeddings.append(embeddings.cpu())
 
         all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
@@ -358,7 +361,14 @@ def embed_names(self) -> None:
                 names_to_embed = [
                     name.replace(self.separator, " ") for name in batch_names
                 ]
-                embeddings = self.embed(names_to_embed, device=self.device)
+                mention_spans = None
+                if self.cnf_l.use_mention_attention:
+                    mention_spans = self._full_text_spans(names_to_embed)
+                embeddings = self.embed(
+                    names_to_embed,
+                    mention_spans=mention_spans,
+                    device=self.device,
+                )
                 all_embeddings.append(embeddings.cpu())
 
         all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
diff --git a/medcat-plugins/transformer-ner/README.md b/medcat-plugins/transformer-ner/README.md
new file mode 100644
index 000000000..e69de29bb
diff --git a/medcat-plugins/transformer-ner/pyproject.toml b/medcat-plugins/transformer-ner/pyproject.toml
new file mode 100644
index 000000000..9c0ab6e1f
--- /dev/null
+++ b/medcat-plugins/transformer-ner/pyproject.toml
@@ -0,0 +1,117 @@
+[project]
+name = "medcat-transformer-ner"
+
+dynamic = ["version"]
+
+description = "Transformer based NER for MedCAT"
+
+readme = "README.md"
+
+requires-python = ">=3.10"
+
+license = {text = "Apache-2.0"}
+
+keywords = ["ML", "NLP", "NER+L"]
+
+authors = [
+    {name = "A. Sutton"},
+    {name = "T. Searle"},
+    {name = "M. Ratas"},
+]
+
+# This should be your name or the names of the organization who currently
+# maintains the project, and a valid email address corresponding to the name
+# listed.
+maintainers = [
+    {name = "CogStack", email = "contact@cogstack.org"}
+]
+
+classifiers = [
+    # How mature is this project? Common values are
+    #   3 - Alpha
+    #   4 - Beta
+    #   5 - Production/Stable
+    "Development Status :: 3 - Alpha",
+
+    "Intended Audience :: Healthcare Industry",
+    # "Topic :: Natural Language Processing :: Named Entity Recognition and Linking",
+
+    # Specify the Python versions you support here. In particular, ensure
+    # that you indicate you support Python 3. These classifiers are *not*
+    # checked by "pip install". See instead "python_requires" below.
+    "Programming Language :: Python :: 3",
+    "Programming Language :: Python :: 3.10",
+    "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
+    "Programming Language :: Python :: 3 :: Only",
+    "Operating System :: OS Independent",
+]
+
+# This field lists other packages that your project depends on to run.
+# Any package you put here will be installed by pip when your project is
+# installed, so they must be valid existing projects.
+#
+# For an analysis of this field vs pip's requirements files see:
+# https://packaging.python.org/discussions/install-requires-vs-requirements/
+dependencies = [
+    "medcat[spacy]>=2.7",
+    "transformers>=4.41.0,<5.0",  # avoid major bump
+    "torch>=2.4.0,<3.0",
+    "tqdm",
+]
+
+# List additional groups of dependencies here (e.g. development
+# dependencies). Users will be able to install these using the "extras"
+# syntax, for example:
+#
+#   $ pip install sampleproject[dev]
+#
+# Similar to `dependencies` above, these must be valid existing
+# projects.
+[project.optional-dependencies] # Optional
+dev = [
+    "ruff~=0.1.7",
+    "mypy",
+    "types-tqdm",
+    "types-setuptools",
+    "types-PyYAML",
+]
+
+# entry-points to add onto medcat
+[project.entry-points."medcat.plugins"]
+medcat_transformer_ner = "medcat_transformer_ner"
+
+[project.urls]
+"Homepage" = "https://cogstack.org/"
+"Bug Reports" = "https://discourse.cogstack.org/"
+"Source" = "https://github.com/CogStack/cogstack-nlp/tree/main/medcat-plugins/transformer-ner"
+
+[build-system]
+# These are the assumed default build requirements from pip:
+# https://pip.pypa.io/en/stable/reference/pip/#pep-517-and-518-support
+requires = ["setuptools>=43.0.0", "setuptools_scm>=8", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[tool.setuptools]
+package-dir = {"" = "src"}
+
+[tool.setuptools.packages.find]
+where = ["src"]
+
+[tool.setuptools.package-data]
+"medcat_transformer_ner" = ["py.typed"]
+
+[tool.setuptools_scm]
+# look for .git folder in root of repo
+root = "../.."
+version_scheme = "post-release"
+local_scheme = "no-local-version"
+tag_regex = "^medcat-transformer-ner/v(?P<version>\\d+(?:\\.\\d+)*)(?:[ab]\\d+|rc\\d+)?$"
+git_describe_command = "git describe --dirty --tags --long --match 'medcat-transformer-ner/v*'"
+
+[tool.ruff.lint]
+# Enable some extra checks for ruff
+select = ["E", "F"]
+# ignore unused local variables
+ignore = ["F841"]
diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py
new file mode 100644
index 000000000..1f1d2b174
--- /dev/null
+++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/__init__.py
@@ -0,0 +1,3 @@
+from .registration import do_registration as __register
+
+__register()
diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py
new file mode 100644
index 000000000..ec90eedf3
--- /dev/null
+++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/config.py
@@ -0,0 +1,27 @@
+from typing import Optional, Any
+from medcat.config import Ner
+
+
+class TransformerNER(Ner):
+    """The config exclusively used for the transformer NER"""
+    language_model_name: str = "nlpie/distil-clinicalbert"
+    """Name or path of the language model. It must be a Hugging Face
+    model id or a path to a local model directory."""
+    training_batch_size: int = 32
+    """The size of the batch to be used for training."""
+    max_token_length: int = 512
+    """Max number of tokens to be passed to the language model.
+    Longer sequences will be chunked"""
+    overlap_chunking: float = 0.2
+    """Fraction of token overlap between consecutive chunks when a long
+    sequence is split (0.2 means 20% of each chunk is shared with the
+    previous one)."""
+    gpu_device: Optional[Any] = None
+    """Choose a device for the model to be stored / computed on. If None
+    then an appropriate GPU device that is available will be chosen"""
+    require_link_candidates: bool = True
+    """Generate ent.link_candidates based on detected names. This requires
+    checking the CDB.name2info and is required for vocab-based linking.
+    Defaults to True because the overhead is small even when unused."""
+    learning_rate: float = 2e-5
+    """The learning rate to be used for training the model"""
+    weight_decay: float = 0.01
+    """The weight decay to be used for training the model"""
\ No newline at end of file
diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py
new file mode 100644
index 000000000..71b42af43
--- /dev/null
+++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/registration.py
@@ -0,0 +1,16 @@
+import logging
+
+from medcat.components.types import CoreComponentType
+from medcat.components.types import lazy_register_core_component
+
+
+logger = logging.getLogger(__name__)
+
+
+def do_registration():
+    lazy_register_core_component(
+        CoreComponentType.ner,
+        "transformer_ner",
+        "medcat_transformer_ner.transformer_ner",
+        "NER.create_new_component",
+    )
diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py
new file mode 100644
index 000000000..ec1ce5988
--- /dev/null
+++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner.py
@@ -0,0 +1,533 @@
+from pathlib import Path
+from typing import Optional, Union
+from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken
+from medcat.components.types import CoreComponentType, TrainableComponent
+from medcat.components.types import AbstractEntityProvidingComponent
+from medcat.components.ner.vocab_based_annotator import annotate_name
+from medcat.tokenizing.tokenizers import BaseTokenizer
+from medcat.vocab import Vocab
+from medcat.cdb import CDB
+from medcat.config.config import ComponentConfig
+from medcat.storage.serialisables import AbstractManualSerialisable
+from transformers import AutoTokenizer, AutoModelForTokenClassification
+from medcat_transformer_ner.config import TransformerNER
+import logging
+import os
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class NER(AbstractEntityProvidingComponent, TrainableComponent, AbstractManualSerialisable):
+    name = "transformer_ner"
+
+    comp_name = "transformer_ner"
+    _MODEL_FOLDER_NAME = "trainable_embedding_model"
+    _MODEL_STATE_FILE_NAME = "model_state.pt"
+
+    def __init__(self, tokenizer: BaseTokenizer,
+                 cdb: CDB) -> None:
+        super().__init__()
+        self.tokenizer = tokenizer
+        self.cdb = cdb
+        self.config = self.cdb.config
+
+        # NER model setup
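+        # A minimal binary BIO scheme: every mention is tagged with the one
+        # generic ENT type; deciding *which* concept it is falls to the linker.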
+        self.cnf_ner: TransformerNER = self.config.components.ner
+        self.label2id = {
+            "O": 0,
+            "B-ENT": 1,
+            "I-ENT": 2
+        }
+        self.id2label = {v: k for k, v in self.label2id.items()}
+        self._model_init_kwargs = dict()
+        self.load_transformers(self.cnf_ner.language_model_name)
+        self.max_token_length = self.cnf_ner.max_token_length
+        self.overlap_chunking = self.cnf_ner.overlap_chunking
+        # class_weights = torch.tensor([
+        #     0.2,  # O
+        #     1.0,  # B-ENT
+        #     1.0,  # I-ENT
+        # ], device=self.device)
+        # self.loss_fct = torch.nn.CrossEntropyLoss(
+        #     weight=class_weights,
+        #     ignore_index=-100
+        # )
+
+    @staticmethod
+    def _resolve_model_source(path_or_model_name: Union[str, Path]) -> str:
+        """Return local absolute path if it exists, otherwise keep HF model id."""
+        candidate = Path(path_or_model_name).expanduser()
+        if candidate.exists():
+            return str(candidate.resolve())
+        return str(path_or_model_name)
+
+    def _get_model_init_kwargs(self) -> dict[str, Any]:
+        """Build kwargs passed to AutoModelForTokenClassification.from_pretrained."""
+        return dict(self._model_init_kwargs)
+
+    def load_transformers(self, language_model_name: Union[str, Path]) -> None:
+        """Load tokenizer/model from local path or Hugging Face model id."""
+        model_source = self._resolve_model_source(language_model_name)
+        model_init_kwargs = self._get_model_init_kwargs()
+
+        if (
+            not hasattr(self, "model")
+            or not hasattr(self, "transformer_tokenizer")
+            or model_source != self._loaded_model_source
+            or model_init_kwargs != self._loaded_model_init_kwargs
+        ):
+            self.cnf_ner.language_model_name = str(language_model_name)
+
+            self.transformer_tokenizer = AutoTokenizer.from_pretrained(
+                model_source,
+                clean_up_tokenization_spaces=False  # might be an issue
+            )
+            self.model = AutoModelForTokenClassification.from_pretrained(
+                model_source,
+                num_labels=3,
+                id2label=self.id2label,
+                label2id=self.label2id,
+            )
+            self.model.eval()
+            self.device = torch.device(
+                self.cnf_ner.gpu_device
+                or ("cuda" if torch.cuda.is_available() else "cpu")
+            )
+            self.model.to(self.device)
+            self._loaded_model_source = model_source
+            self._loaded_model_init_kwargs = model_init_kwargs
+            self.optimizer = torch.optim.AdamW(
+                self.model.parameters(),
+                lr=self.cnf_ner.learning_rate,
+                weight_decay=self.cnf_ner.weight_decay,
+            )
+            logger.debug(
+                "Loaded NER model: %s (resolved source: %s) with kwargs=%s "
+                "on device: %s",
+                language_model_name,
+                model_source,
+                model_init_kwargs,
+                self.device,
+            )
+
+    def get_type(self) -> CoreComponentType:
+        return CoreComponentType.ner
+
+    def _chunk_and_encode(self,
+                          text: str,
+                          entities: Optional[list[MutableEntity]] = None
+                          ) -> tuple[torch.Tensor, torch.Tensor, list, list,
+                                     Optional[torch.Tensor]]:
+        labels_enabled = entities is not None
+        # First pass: tokenize full text to get offsets for chunking and label alignment
+        base_encoding = self.transformer_tokenizer(
+            text,
+            return_offsets_mapping=True,
+            add_special_tokens=False
+        )
+
+        offsets = base_encoding["offset_mapping"]
+
+        stride = self.max_token_length - int(self.max_token_length * self.overlap_chunking)
+
+        n_tokens = len(base_encoding["input_ids"])
+        start_idx = 0
+
+        input_ids = []
+        attention_masks = []
+        all_labels = [] if labels_enabled else None
+        offset_mappings = []
+        chunk_char_starts = []
+        while start_idx < n_tokens:
+            end_idx = min(start_idx + self.max_token_length, n_tokens)
+
+            chunk_offsets = offsets[start_idx:end_idx]
+
+            char_start = chunk_offsets[0][0]
+            char_end = chunk_offsets[-1][1]
+            chunk_text = text[char_start:char_end]
+
+            # Rebase entities to the chunk, but only if this
+            # is a training example
+            if labels_enabled:
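+                # Shift entity character offsets into chunk-local coordinates.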
+                chunk_entities = []
+                for ent in entities:
+                    ent_start = ent.base.start_char_index
+                    ent_end = ent.base.end_char_index
+
+                    if ent_end > char_start and ent_start < char_end:
+                        chunk_entities.append({
+                            "start": ent_start - char_start,
+                            "end": ent_end - char_start
+                        })
+
+            # Tokenize chunk
+            encoding = self.transformer_tokenizer(
+                chunk_text,
+                return_offsets_mapping=True,
+                truncation=True,
+                padding="max_length",
+                max_length=self.max_token_length
+            )
+
+            offsets_chunk = encoding["offset_mapping"]
+
+            # Label alignment to relevant chunks
+            if labels_enabled:
+                labels = [
+                    -100 if (start == end) else self.label2id["O"]
+                    for start, end in offsets_chunk
+                ]
+
+                for ent in chunk_entities:
+                    started = False
+                    for i, (token_start, token_end) in enumerate(offsets_chunk):
+                        if token_start < ent["end"] and token_end > ent["start"]:
+                            if not started:
+                                labels[i] = self.label2id["B-ENT"]
+                                started = True
+                            else:
+                                labels[i] = self.label2id["I-ENT"]
+
+                all_labels.append(labels)
+
+            input_ids.append(encoding["input_ids"])
+            attention_masks.append(encoding["attention_mask"])
+            offset_mappings.append(offsets_chunk)
+            chunk_char_starts.append(char_start)
+
+            if end_idx == n_tokens:
+                break
+
+            start_idx += stride
+
+        input_ids = torch.tensor(input_ids, dtype=torch.long).to(self.device)
+        attention_masks = torch.tensor(attention_masks, dtype=torch.long).to(self.device)
+        if labels_enabled:
+            all_labels = torch.tensor(all_labels, dtype=torch.long).to(self.device)
+        return input_ids, attention_masks, offset_mappings, chunk_char_starts, all_labels
+
+    def _focal_loss(self, logits, labels, gamma=2.0, ignore_index=-100):
+        # flatten
+        logits = logits.view(-1, logits.size(-1))
+        labels = labels.view(-1)
+
+        # mask ignored positions
+        valid_mask = labels != ignore_index
+        logits = logits[valid_mask]
+        labels = labels[valid_mask]
+
+        # standard cross-entropy per token
+        ce_loss = torch.nn.functional.cross_entropy(logits, labels, reduction='none')
+
+        # pt = probability of the correct class
+        pt = torch.exp(-ce_loss)
+
+        # focal scaling: down-weight easy (high-pt) tokens
+        loss = ((1 - pt) ** gamma) * ce_loss
+
+        return loss.mean()
+
+    def train(self, cui: str,
+              entity: MutableEntity,
+              doc: MutableDocument,
+              negative: bool = False,
+              names: Union[list[str], dict] = []) -> None:
+        """Train the NER component on a given document. This is used in the
+        supervised training loop of the MedCAT trainer.
+        """
+        # Only train once per document, on its last entity: a slightly hacky
+        # workaround, but one with minimal impact on the CAT trainer.
+        if entity is doc.ner_ents[-1]:
+            text = doc.base.text
+            entities = doc.ner_ents
+            input_ids, attention_masks, _, _, labels = self._chunk_and_encode(text, entities)
+            self.optimizer.zero_grad()
+            self.model.train()
+
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_masks,
+                labels=labels
+            )
+
+            loss = outputs.loss
+            loss.backward()
+
+            logger.debug("NER training step - loss: %s", loss.item())
+
+            self.optimizer.step()
+
+    def _decode_chunk(self, preds, offsets_chunk, chunk_char_start):
+        """For inference only. Decode a single chunk of predictions into
+        entity spans; merging across chunks happens in `_merge_spans`."""
+        spans = []
+        current = None
+
+        for pred_id, (tok_start, tok_end) in zip(preds, offsets_chunk):
+
+            # skip padding / special tokens
+            if (tok_start, tok_end) == (0, 0):
+                continue
+
+            label = self.id2label[pred_id]
+
+            # if label is "O", we close any open entity span and move on
+            if label == "O":
+                if current is not None:
+                    spans.append(current)
+                    current = None
+                continue
+
+            # This is a bit too general for a binary ENT / non-ENT scheme,
+            # but it keeps the decoder extendable to multiple entity types.
+            prefix, ent_type = label.split("-", 1)
+
+            abs_start = chunk_char_start + tok_start
+            abs_end = chunk_char_start + tok_end
+
+            # "B" starts a new entity span, closing any open one first.
+            if prefix == "B":
+                if current is not None:
+                    spans.append(current)
+                current = {
+                    "start": abs_start,
+                    "end": abs_end,
+                    "label": ent_type
+                }
+
+            # "I" continues the current span if it has the same entity type;
+            # otherwise we treat it as a new "B" span (this handles broken
+            # BIO sequences).
+            # TODO: other methods of handling broken BIO?
+            elif prefix == "I":
+                if current is not None and current["label"] == ent_type:
+                    current["end"] = abs_end
+                else:
+                    # broken BIO -> treat as B
+                    current = {
+                        "start": abs_start,
+                        "end": abs_end,
+                        "label": ent_type
+                    }
+
+        if current is not None:
+            spans.append(current)
+
+        return spans
+
+    def _merge_spans(self, spans):
+        """Merge spans across chunk boundaries. This is required before creating
+        entities in the doc, otherwise we might get duplicates for the same
+        entity when it was split across chunks. Used in inference only."""
+        if not spans:
+            return []
+
+        spans = sorted(spans, key=lambda x: (x["start"], x["end"]))
+        merged = [spans[0]]
+
+        for span in spans[1:]:
+            last = merged[-1]
+
+            if span["label"] == last["label"] and span["start"] <= last["end"]:
+                last["end"] = max(last["end"], span["end"])
+            else:
+                merged.append(span)
+
+        return merged
+
+    def _char_span_to_token_span(self,
+                                 doc: MutableDocument,
+                                 start_char: int,
+                                 end_char: int) -> Optional[tuple[int, int]]:
+        """Convert a character span to a token span, for compatibility with
+        spaCy tokenization. Used in inference only."""
+        spacy_doc = doc._delegate
+        # Prefer strict/inner alignment first
+        span = spacy_doc.char_span(start_char, end_char, alignment_mode="contract")
+        # This very rarely fails. If it does, try "expand" alignment, with
+        # manual token offset checking as a final fallback.
+        if span is None:
+            span = spacy_doc.char_span(start_char, end_char, alignment_mode="expand")
+        if span is not None:
+            return span.start, span.end
+
+        # derive token indices from token character offsets.
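+        # Walk tokens left to right: the first token that ends past
+        # start_char opens the span; every token that starts before
+        # end_char extends it.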
+        token_start = None
+        token_end = None
+        for tok in spacy_doc:
+            tok_start = tok.idx
+            tok_end = tok.idx + len(tok)
+
+            if tok_end <= start_char:
+                continue
+            if tok_start >= end_char and token_end is not None:
+                break
+
+            if token_start is None and tok_end > start_char:
+                token_start = tok.i
+            if tok_start < end_char:
+                token_end = tok.i + 1
+
+        if token_start is None or token_end is None or token_start >= token_end:
+            return None
+        return token_start, token_end
+
+    def _preprocess_tokens(self, tokens: list[MutableToken]) -> str:
+        tokens_raw = ' '.join(tkn.text.lower() for tkn in tokens).strip()
+        return tokens_raw.replace(' ', self.config.general.separator)
+
+    def predict_entities(self, doc: MutableDocument,
+                         ents: list[MutableEntity] | None = None
+                         ) -> list[MutableEntity]:
+        """Detect candidates for concepts - the linker will then be able
+        to do the rest. It adds entities to doc.ner_ents, and each entity
+        can carry entity.link_candidates for the linker to resolve.
+
+        Args:
+            doc (MutableDocument):
+                Spacy document to be annotated with named entities.
+            ents (list[MutableEntity] | None):
+                The entities given. This should be None.
+
+        Returns:
+            list[MutableEntity]:
+                The NER'ed entities.
+        """
+        # Keep offset generation in the same coordinate space as spaCy char_span.
+        text = doc._delegate.text
+        input_ids, attention_masks, offset_mappings, chunk_char_starts, _ = self._chunk_and_encode(text)
+
+        self.model.eval()
+        with torch.no_grad():
+            input_ids = input_ids.to(self.device)
+            attention_masks = attention_masks.to(self.device)
+
+            outputs = self.model(
+                input_ids=input_ids,
+                attention_mask=attention_masks
+            )
+            predictions = outputs.logits.argmax(dim=-1).cpu().tolist()
+
+        all_spans = []
+        for preds, offsets_chunk, char_start in zip(
+            predictions,
+            offset_mappings,
+            chunk_char_starts
+        ):
+            spans = self._decode_chunk(preds, offsets_chunk, char_start)
+            all_spans.extend(spans)
+        final_spans = self._merge_spans(all_spans)
+
+        ner_ents = []
+        seen_token_spans = set()
+        for span in final_spans:
+            token_char_end = max(span["start"], span["end"] - 1)
+            tokens = doc.get_tokens(span["start"], token_char_end)
+            if not tokens:
+                continue
+
+            # NOTE: it is unclear whether this branch is required or
+            # beneficial. When link candidates are not required, we only
+            # need the detected name, no candidates: the span detected by
+            # the model can then be linked directly.
+            if not self.cnf_ner.require_link_candidates:
+                token_start = tokens[0].base.index
+                token_end = tokens[-1].base.index + 1
+                span_key = (token_start, token_end)
+                if span_key not in seen_token_spans:
+                    ent = self.tokenizer.create_entity(
+                        doc,
+                        token_start,
+                        token_end,
+                        text[span["start"]:span["end"]]
+                    )
+                    if ent:
+                        ner_ents.append(ent)
+                        seen_token_spans.add(span_key)
+
+            for i in range(len(tokens)):
+                for j in range(i + 1, len(tokens) + 1):
+                    sub_tokens = tokens[i:j]
+                    preprocessed_sub_name = self._preprocess_tokens(sub_tokens)
+                    if preprocessed_sub_name not in self.cdb.name2info:
+                        continue
+
+                    token_start = sub_tokens[0].base.index
+                    token_end = sub_tokens[-1].base.index + 1
+                    span_key = (token_start, token_end)
+                    if span_key in seen_token_spans:
+                        continue
+
+                    ent = None
+                    if not self.cnf_ner.require_link_candidates:
+                        detected_name = text[
+                            sub_tokens[0].base.char_index:
+                            sub_tokens[-1].base.char_index + len(sub_tokens[-1].text)
+                        ]
+                        ent = self.tokenizer.create_entity(
+                            doc,
+                            token_start,
+                            token_end,
+                            detected_name
+                        )
+                    else:
+                        ent = annotate_name(
+                            self.tokenizer,
+                            preprocessed_sub_name,
+                            sub_tokens,
+                            doc,
+                            self.cdb,
+                            len(ner_ents),
+                            'concept'
+                        )
+
+                    if ent:
+                        ner_ents.append(ent)
+                        seen_token_spans.add(span_key)
+        return ner_ents
+
+    @classmethod
+    def create_new_component(
+            cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
+            cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> "NER":
+        return cls(tokenizer, cdb)
+
+    def serialise_to(self, folder_path: str) -> None:
+        os.makedirs(folder_path, exist_ok=True)
+        model_folder = os.path.join(folder_path, self._MODEL_FOLDER_NAME)
+        os.makedirs(model_folder, exist_ok=True)
+
+        torch.save(
+            # TODO: save gracefully when NER model done
+            self.model.state_dict(),
+            os.path.join(model_folder, self._MODEL_STATE_FILE_NAME),
+        )
+
+    @classmethod
+    def deserialise_from(
+        cls, folder_path: str, **init_kwargs
+    ) -> "NER":
+        cdb = init_kwargs["cdb"]
+        tokenizer = init_kwargs["tokenizer"]
+        ner = cls(tokenizer, cdb)
+
+        model_state_path = os.path.join(
+            folder_path, cls._MODEL_FOLDER_NAME, cls._MODEL_STATE_FILE_NAME
+        )
+
+        # TODO: handle this gracefully when NER model done
+        if os.path.exists(model_state_path):
+            state_dict = torch.load(model_state_path, map_location=ner.device)
+            ner.model.load_state_dict(state_dict)
+
+        return ner
\ No newline at end of file
diff --git a/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py
new file mode 100644
index 000000000..ef46c035d
--- /dev/null
+++ b/medcat-plugins/transformer-ner/src/medcat_transformer_ner/transformer_ner_model.py
@@ -0,0 +1,158 @@
+from pathlib import Path
+from typing import Optional, Union
+from torch import Tensor, nn
+from transformers import AutoModel
+import json
+import logging
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class ModelForBinaryNER(nn.Module):
+    """Wrapper around a Hugging Face transformer for transformer-based NER.
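+
+    Work in progress: masked mean pooling over either a mention mask or
+    the attention mask is implemented, while the token classification
+    head and the forward output are still TODO.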
+    """
+
+    def __init__(
+        self,
+        embedding_model_name: str,
+        top_n_layers_to_unfreeze: int = -1,
+        device: Optional[Union[str, torch.device]] = None,
+    ) -> None:
+        super().__init__()
+        self.language_model = AutoModel.from_pretrained(embedding_model_name)
+        self.base_model_name = self.language_model.name_or_path
+
+        # TODO: rest of logic here
+
+        self.top_n_layers_to_unfreeze = top_n_layers_to_unfreeze
+
+        hidden_size = self.language_model.config.hidden_size
+
+        self._freeze_all_parameters()
+        self.unfreeze_top_n_lm_layers(self.top_n_layers_to_unfreeze)
+
+        target_device = self._resolve_device(device)
+        self.to(target_device)
+
+    @staticmethod
+    def _resolve_device(device: Optional[Union[str, torch.device]]) -> torch.device:
+        if device is None:
+            return torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        return torch.device(device)
+
+    @property
+    def device(self) -> torch.device:
+        return next(self.parameters()).device
+
+    @staticmethod
+    def masked_mean_pooling(token_embeddings: Tensor, mask: Tensor) -> Tensor:
+        mask = mask.unsqueeze(-1).float()
+        summed = torch.sum(token_embeddings * mask, dim=1)
+        counts = torch.clamp(mask.sum(dim=1), min=1e-9)
+        return summed / counts
+
+    def forward(self, **inputs) -> Tensor:
+        # Pop the mention_mask, if given, so it is not passed to the
+        # language model
+        mention_mask = inputs.pop("mention_mask", None)
+        model_output = self.language_model(**inputs)
+
+        pooling_mask = (
+            mention_mask if mention_mask is not None else inputs["attention_mask"]
+        )
+        sentence_embeddings = self.masked_mean_pooling(
+            model_output.last_hidden_state, pooling_mask
+        )
+
+        # TODO: logic required
+
+        pass
+
+    def _freeze_all_parameters(self) -> None:
+        for param in self.language_model.parameters():
+            param.requires_grad = False
+
+        # The projection layer is not created yet (see TODO in __init__);
+        # guard the lookup so freezing works on the current model.
+        if getattr(self, "use_projection_layer", False):
+            for param in self.projection_layer.parameters():
+                param.requires_grad = True
+
+    def unfreeze_top_n_lm_layers(self, n: int) -> None:
+        # train all LM layers - each layer requires more data
+        if n == -1:
+            for param in self.language_model.parameters():
+                param.requires_grad = True
+            return
+
+        # keep LM fully frozen - better with less data
+        if n == 0:
+            return
+
+        # BERT-likes
+        if hasattr(self.language_model, "encoder") and hasattr(
+            self.language_model.encoder, "layer"
+        ):
+            layers = self.language_model.encoder.layer
+        # DistilBERT-likes
+        elif hasattr(self.language_model, "transformer") and hasattr(
+            self.language_model.transformer, "layer"
+        ):
+            layers = self.language_model.transformer.layer
+        else:
+            raise ValueError("Unsupported LM architecture for layer unfreezing.")
+
+        total_layers = len(layers)
+        n = min(n, total_layers)
+        for layer in layers[-n:]:
+            for param in layer.parameters():
+                param.requires_grad = True
+
+    def save_pretrained(self, save_directory: Union[str, Path]) -> None:
+        save_path = Path(save_directory)
+        save_path.mkdir(parents=True, exist_ok=True)
+
+        torch.save(self.state_dict(), save_path / "pytorch_model.bin")
+
+        config = {
+            "embedding_model_name": self.base_model_name,
+            "top_n_layers_to_unfreeze": self.top_n_layers_to_unfreeze,
+        }
+        with open(save_path / "config.json", "w", encoding="utf-8") as f:
+            json.dump(config, f, indent=2)
+
+    @classmethod
+    def from_pretrained(
+        cls,
+        path_or_model_name: Union[str, Path],
+        device: Optional[Union[str, torch.device]] = None,
+        **kwargs,
+    ) -> "ModelForBinaryNER":
+        path = Path(path_or_model_name)
+        config_path = path / "config.json"
+        weights_path = path / "pytorch_model.bin"
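+        # If both files exist this is a locally saved wrapper checkpoint;
+        # otherwise the input is treated as a Hugging Face model id.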
"pytorch_model.bin" + target_device = cls._resolve_device(device) + + # Local saved wrapper model. + if config_path.exists() and weights_path.exists(): + with open(config_path, encoding="utf-8") as f: + config = json.load(f) + + config.update(kwargs) + model = cls(**config) + state_dict = torch.load(weights_path, map_location="cpu") + model.load_state_dict(state_dict) + model.to(target_device) + return model + + # Hugging Face model id/path. + model = cls( + embedding_model_name=str(path_or_model_name), + device=target_device, + **kwargs, + ) + return model \ No newline at end of file diff --git a/medcat-plugins/transformer-ner/tests/__init__.py b/medcat-plugins/transformer-ner/tests/__init__.py new file mode 100644 index 000000000..b40364e1c --- /dev/null +++ b/medcat-plugins/transformer-ner/tests/__init__.py @@ -0,0 +1,26 @@ +# NOTE: mostly copied from medcat tests +import atexit +import os +import shutil + + +RESOURCES_PATH = os.path.join(os.path.dirname(__file__), "resources") +EXAMPLE_MODEL_PACK_ZIP = os.path.join(RESOURCES_PATH, "mct2_model_pack.zip") +UNPACKED_EXAMPLE_MODEL_PACK_PATH = os.path.join( + RESOURCES_PATH, "mct2_model_pack") + + +# unpack model pack at start so we can access stuff like Vocab +print("Unpacking included test model pack") +shutil.unpack_archive(EXAMPLE_MODEL_PACK_ZIP, UNPACKED_EXAMPLE_MODEL_PACK_PATH) + + +def _del_unpacked_model(): + print( + "Cleaning up! Removing unpacked exmaple model pack:", + UNPACKED_EXAMPLE_MODEL_PACK_PATH, + ) + shutil.rmtree(UNPACKED_EXAMPLE_MODEL_PACK_PATH) + + +atexit.register(_del_unpacked_model) diff --git a/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip b/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip new file mode 100644 index 000000000..b6bc74e49 Binary files /dev/null and b/medcat-plugins/transformer-ner/tests/resources/mct2_model_pack.zip differ