From 366457638328d42fd17f1618ac9a70cbbba7f85e Mon Sep 17 00:00:00 2001
From: Adam Sutton
Date: Sun, 26 Apr 2026 18:19:54 +0100
Subject: [PATCH 1/2] changed cui embedding method and fixed mention_mask generation

---
 .../transformer_context_model.py             | 72 +++++++++++--------
 1 file changed, 41 insertions(+), 31 deletions(-)

diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
index 73924e911..b02083e08 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
@@ -249,35 +249,27 @@ def _build_mention_mask_from_char_spans(
         mention_char_spans: list[tuple[int, int]],
         device: torch.device,
     ) -> Tensor:
-        """
-        Convert character-level mention spans into a token-level mask.
-
-        Args:
-            batch_dict: tokenizer output with 'offset_mapping'
-            mention_char_spans: list of (start_char, end_char) per example
-            device: torch device
-
-        Returns:
-            mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens,
-                0.0 otherwise
-        """
-        offset_mapping = batch_dict["offset_mapping"]  # [B, max_token_length, 2]
-        batch_size, seq_len, _ = offset_mapping.shape
-        mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device)
-
-        for i, (mention_start, mention_end) in enumerate(mention_char_spans):
-            # For each token in the sequence
-            for j in range(seq_len):
-                token_start, token_end = offset_mapping[i, j].tolist()
-                # Skip padding tokens
-                if token_end == 0 and token_start == 0:
-                    continue
-                # Check if token overlaps mention span
-                if token_end > mention_start and token_start < mention_end:
-                    mask[i, j] = 1.0
-
+        """Convert character-level mention spans into a token-level mask."""
+        offset_mapping = batch_dict["offset_mapping"]  # [B, seq_len, 2]
+        token_starts = offset_mapping[:, :, 0]  # [B, seq_len]
+        token_ends = offset_mapping[:, :, 1]  # [B, seq_len] 
+
+        spans_tensor = torch.tensor(
+            mention_char_spans, dtype=torch.long, device=device
+        )  # [B, 2]
+        mention_starts = spans_tensor[:, 0].unsqueeze(1)  # [B, 1]
+        mention_ends = spans_tensor[:, 1].unsqueeze(1)  # [B, 1] 
+
+        # Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
+        is_special = (token_starts == 0) & (token_ends == 0)
+        overlaps = (token_ends > mention_starts) & (token_starts < mention_ends) 
+        mask = (overlaps & ~is_special).float()
         return mask
 
+    def _full_text_spans(self, texts: list[str]) -> list[tuple[int, int]]:
+        """Build mention spans that cover each full text entry."""
+        return [(0, len(text)) for text in texts]
+
     def embed(
         self,
         to_embed: list[str],
@@ -301,7 +293,7 @@ def embed(
         ).to(target_device)
 
         mention_mask = None
-        if mention_spans is not None:
+        if self.cnf_l.use_mention_attention and mention_spans is not None:
             mention_mask = self._build_mention_mask_from_char_spans(
                 batch_dict,
                 mention_spans,
@@ -322,7 +314,11 @@ def embed_cuis(self) -> None:
         """
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
 
-        cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
+        # cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
+        cui_names = [
+            max(self.cdb.cui2info[cui].get("names"), key=len)
+            for cui in self._cui_keys
+        ]
         total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
         all_embeddings = []
         for names in tqdm(
@@ -332,7 +328,14 @@ def embed_cuis(self) -> None:
         ):
             with torch.no_grad():
                 names_to_embed = [name.replace(self.separator, " ") for name in names]
-                embeddings = self.embed(names_to_embed, device=self.device)
+                mention_spans = None
+                if self.cnf_l.use_mention_attention:
+                    mention_spans = self._full_text_spans(names_to_embed)
+                embeddings = self.embed(
+                    names_to_embed,
+                    mention_spans=mention_spans,
+                    device=self.device,
+                )
                 all_embeddings.append(embeddings.cpu())
 
         all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
@@ -358,7 +361,14 @@ def embed_names(self) -> None:
                 names_to_embed = [
                     name.replace(self.separator, " ") for name in batch_names
                 ]
-                embeddings = self.embed(names_to_embed, device=self.device)
+                mention_spans = None
+                if self.cnf_l.use_mention_attention:
+                    mention_spans = self._full_text_spans(names_to_embed)
+                embeddings = self.embed(
+                    names_to_embed,
+                    mention_spans=mention_spans,
+                    device=self.device,
+                )
                 all_embeddings.append(embeddings.cpu())
 
         all_embeddings_matrix = torch.cat(all_embeddings, dim=0)

From 65ea764d88d25c284d55419eb12a2b2826dd7eb5 Mon Sep 17 00:00:00 2001
From: Adam Sutton
Date: Sun, 26 Apr 2026 20:18:39 +0100
Subject: [PATCH 2/2] fixed spacing

---
 .../medcat_embedding_linker/transformer_context_model.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
index b02083e08..fb08af4a7 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
@@ -252,17 +252,17 @@ def _build_mention_mask_from_char_spans(
         """Convert character-level mention spans into a token-level mask."""
         offset_mapping = batch_dict["offset_mapping"]  # [B, seq_len, 2]
         token_starts = offset_mapping[:, :, 0]  # [B, seq_len]
-        token_ends = offset_mapping[:, :, 1]  # [B, seq_len] 
+        token_ends = offset_mapping[:, :, 1]  # [B, seq_len]
 
         spans_tensor = torch.tensor(
             mention_char_spans, dtype=torch.long, device=device
         )  # [B, 2]
         mention_starts = spans_tensor[:, 0].unsqueeze(1)  # [B, 1]
-        mention_ends = spans_tensor[:, 1].unsqueeze(1)  # [B, 1] 
+        mention_ends = spans_tensor[:, 1].unsqueeze(1)  # [B, 1]
 
         # Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
         is_special = (token_starts == 0) & (token_ends == 0)
-        overlaps = (token_ends > mention_starts) & (token_starts < mention_ends) 
+        overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
         mask = (overlaps & ~is_special).float()
         return mask
 
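
Note (not part of the patch series): below is a minimal standalone sketch of the vectorised mention-mask construction introduced in PATCH 1/2, for reviewers who want to try the logic in isolation. The build_mention_mask helper name, the bert-base-uncased checkpoint, and the example text and span are assumptions made only for this demo; any HuggingFace fast tokenizer that returns offset_mapping should behave the same way.

# sketch_mention_mask.py -- illustration only, not part of the patches above
import torch
from transformers import AutoTokenizer


def build_mention_mask(offset_mapping: torch.Tensor,
                       mention_char_spans: list[tuple[int, int]]) -> torch.Tensor:
    """Token-level mask (1.0 = token overlaps its mention span), mirroring the
    vectorised logic added in PATCH 1/2."""
    token_starts = offset_mapping[:, :, 0]                      # [B, seq_len]
    token_ends = offset_mapping[:, :, 1]                        # [B, seq_len]
    spans = torch.tensor(mention_char_spans, dtype=torch.long)  # [B, 2]
    mention_starts = spans[:, 0].unsqueeze(1)                   # [B, 1]
    mention_ends = spans[:, 1].unsqueeze(1)                     # [B, 1]
    # Fast tokenizers report (0, 0) offsets for special tokens and padding.
    is_special = (token_starts == 0) & (token_ends == 0)
    overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
    return (overlaps & ~is_special).float()


tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
texts = ["acute myocardial infarction of the anterior wall"]
spans = [(0, 27)]  # character span of "acute myocardial infarction"

batch = tokenizer(texts, return_tensors="pt", return_offsets_mapping=True,
                  padding=True, truncation=True)
mask = build_mention_mask(batch["offset_mapping"], spans)
print(mask)  # 1.0 over word pieces inside the span, 0.0 elsewhere (incl. CLS/SEP)

Running this prints a [1, seq_len] tensor in which only the word pieces of "acute myocardial infarction" are 1.0; this is the same kind of mask that embed() builds when use_mention_attention is enabled and mention spans are supplied, and that embed_cuis()/embed_names() now request over the full name text via _full_text_spans().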