Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -249,35 +249,27 @@ def _build_mention_mask_from_char_spans(
mention_char_spans: list[tuple[int, int]],
device: torch.device,
) -> Tensor:
"""
Convert character-level mention spans into a token-level mask.

Args:
batch_dict: tokenizer output with 'offset_mapping'
mention_char_spans: list of (start_char, end_char) per example
device: torch device

Returns:
mask: [batch_size, seq_len] float Tensor, 1.0 for mention tokens,
0.0 otherwise
"""
offset_mapping = batch_dict["offset_mapping"] # [B, max_token_length, 2]
batch_size, seq_len, _ = offset_mapping.shape
mask = torch.zeros((batch_size, seq_len), dtype=torch.float32, device=device)

for i, (mention_start, mention_end) in enumerate(mention_char_spans):
# For each token in the sequence
for j in range(seq_len):
token_start, token_end = offset_mapping[i, j].tolist()
# Skip padding tokens
if token_end == 0 and token_start == 0:
continue
# Check if token overlaps mention span
if token_end > mention_start and token_start < mention_end:
mask[i, j] = 1.0

"""Convert character-level mention spans into a token-level mask."""
offset_mapping = batch_dict["offset_mapping"] # [B, seq_len, 2]
token_starts = offset_mapping[:, :, 0] # [B, seq_len]
token_ends = offset_mapping[:, :, 1] # [B, seq_len]

spans_tensor = torch.tensor(
mention_char_spans, dtype=torch.long, device=device
) # [B, 2]
mention_starts = spans_tensor[:, 0].unsqueeze(1) # [B, 1]
mention_ends = spans_tensor[:, 1].unsqueeze(1) # [B, 1]

# Tokens with offset (0, 0) are special tokens (CLS, SEP) or padding.
is_special = (token_starts == 0) & (token_ends == 0)
overlaps = (token_ends > mention_starts) & (token_starts < mention_ends)
mask = (overlaps & ~is_special).float()
return mask

def _full_text_spans(self, texts: list[str]) -> list[tuple[int, int]]:
"""Build mention spans that cover each full text entry."""
return [(0, len(text)) for text in texts]

def embed(
self,
to_embed: list[str],
Expand All @@ -301,7 +293,7 @@ def embed(
).to(target_device)

mention_mask = None
if mention_spans is not None:
if self.cnf_l.use_mention_attention and mention_spans is not None:
mention_mask = self._build_mention_mask_from_char_spans(
batch_dict,
mention_spans,
Expand All @@ -322,7 +314,11 @@ def embed_cuis(self) -> None:
"""
self._refresh_cdb_keys() # ensure _cui_keys is up to date before embedding

cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
# cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
cui_names = [
max(self.cdb.cui2info[cui].get("names"), key=len)
for cui in self._cui_keys
]
total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
all_embeddings = []
for names in tqdm(
Expand All @@ -332,7 +328,14 @@ def embed_cuis(self) -> None:
):
with torch.no_grad():
names_to_embed = [name.replace(self.separator, " ") for name in names]
embeddings = self.embed(names_to_embed, device=self.device)
mention_spans = None
if self.cnf_l.use_mention_attention:
mention_spans = self._full_text_spans(names_to_embed)
embeddings = self.embed(
names_to_embed,
mention_spans=mention_spans,
device=self.device,
)
all_embeddings.append(embeddings.cpu())

all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
Expand All @@ -358,7 +361,14 @@ def embed_names(self) -> None:
names_to_embed = [
name.replace(self.separator, " ") for name in batch_names
]
embeddings = self.embed(names_to_embed, device=self.device)
mention_spans = None
if self.cnf_l.use_mention_attention:
mention_spans = self._full_text_spans(names_to_embed)
embeddings = self.embed(
names_to_embed,
mention_spans=mention_spans,
device=self.device,
)
all_embeddings.append(embeddings.cpu())

all_embeddings_matrix = torch.cat(all_embeddings, dim=0)
Expand Down
Loading