from pathlib import Path
import sys
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import torch
from .fast_mapper import LemmatizerBMP
import pickle
[docs]
class CharConfusionMatrix:
[docs]
@staticmethod
def edit_distance(s1: np.ndarray, s2: np.ndarray) -> Tuple[int, np.ndarray]:
"""Compute the Levenshtein edit distance between two sequences.
This function calculates the minimum number of single-character edits
(insertions, deletions, or substitutions) required to change one sequence
into the other. It also returns the dynamic programming (DP) matrix used
to compute the distance.
Parameters
----------
s1 : np.ndarray
The first sequence as a NumPy array.
s2 : np.ndarray
The second sequence as a NumPy array.
The Levenshtein edit distance between `s1` and `s2`.
The DP matrix used to compute the distance, where `dp[i, j]` represents
the edit distance between the first `i` characters of `s1` and the first
`j` characters of `s2`.
Examples
--------
>>> import numpy as np
>>> s1 = np.array(['a', 'b', 'c'])
>>> s2 = np.array(['a', 'c', 'd'])
>>> distance, dp = edit_distance(s1, s2)
>>> distance
2
>>> dp
array([[0, 1, 2, 3],
[1, 0, 1, 2],
[2, 1, 1, 2],
[3, 2, 2, 2]])
"""
n, m = len(s1), len(s2)
dp = np.zeros((n + 1, m + 1), dtype=int)
for i in range(n + 1):
dp[i, 0] = i
for j in range(m + 1):
dp[0, j] = j
for i in range(1, n + 1):
for j in range(1, m + 1):
cost_sub = 0 if s1[i - 1] == s2[j - 1] else 1
diag = dp[i - 1, j - 1] + cost_sub
up = dp[i - 1, j] + 1
left = dp[i, j - 1] + 1
dp[i, j] = min(diag, up, left)
distance = dp[n, m]
return distance, dp
[docs]
def backtrace_ed_matrix(self, input_seq: np.ndarray, gt_seq: np.ndarray, dp: np.ndarray) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
"""
Backtraces the edit distance matrix to compute the alignment path, operation types,
ground truth substitutions, and confusion matrix.
Parameters
----------
input_seq : np.ndarray
The input sequence represented as an array of indices.
gt_seq : np.ndarray
The ground truth sequence represented as an array of indices.
dp : np.ndarray
The dynamic programming matrix containing the edit distances.
Returns
-------
path : np.ndarray
The alignment path as an array of (input_index, gt_index) pairs.
operation_type : np.ndarray
The sequence of operation types:
- 0: Match
- 1: Substitution
- 2: Deletion
- 3: Insertion
gt_sub_input : np.ndarray
The ground truth sequence with substitutions applied.
cm : np.ndarray
The confusion matrix representing the counts of matches, substitutions,
insertions, and deletions. The matrix has dimensions
(len(alphabet), len(alphabet)), where the first row/column represents
insertions/deletions.
"""
cm = np.zeros((len(self.alphabet), len(self.alphabet)), dtype=np.int32)
inp_idx = len(input_seq)
gt_idx = len(gt_seq)
path = []
operation_type = []
gt_sub_input = []
while inp_idx > 0 and gt_idx > 0:
choice = ((dp[inp_idx - 1, gt_idx - 1]-.00001, (-1, -1), 0),
(dp[inp_idx - 1, gt_idx], (-1, 0), 2),
(dp[inp_idx, gt_idx - 1], (0, -1), 3))
_, (di, dj), op_type = min(choice, key=lambda x: x[0])
inp_idx += di
gt_idx += dj
path.append((inp_idx, gt_idx))
if op_type == 0:
gt_sub_input.append(input_seq[inp_idx])
op_type = 0 if input_seq[inp_idx] == gt_seq[gt_idx] else 1
cm[input_seq[inp_idx], gt_seq[gt_idx]] += 1
elif op_type == 3: # Insertion
gt_sub_input.append(gt_seq[gt_idx])
cm[0, gt_seq[gt_idx]] += 1
elif op_type == 2: # Deletion
cm[input_seq[inp_idx], 0] += 1
operation_type.append(op_type)
while gt_idx > 0:
gt_idx -= 1
path.append((inp_idx, gt_idx))
operation_type.append(3)
gt_sub_input.append(gt_seq[gt_idx])
cm[0, gt_seq[gt_idx]] += 1
while inp_idx > 0:
inp_idx -= 1
path.append((inp_idx, gt_idx))
operation_type.append(2)
cm[input_seq[inp_idx], 0] += 1
return np.array(path)[::-1, :], np.array(operation_type)[::-1], np.array(gt_sub_input[::-1]), cm
[docs]
def ingest_textline_observation(self, pred_line: str, gt_line: str) -> Tuple[str, int]:
"""
Processes a pair of predicted and ground truth text lines, computes the edit distance,
and updates the confusion matrix.
This method performs the following steps:
1. Converts the predicted and ground truth text lines into dense integer label sequences.
2. Computes the edit distance and dynamic programming matrix between the two sequences.
3. Performs a backtrace on the edit distance matrix to generate the ground truth substitution input and updates the confusion matrix.
4. Returns the ground truth substitution input and the computed edit distance.
Parameters
----------
pred_line : str
The predicted text line as a string.
gt_line : str
The ground truth text line as a string.
Returns
-------
Tuple[str, int]
A tuple containing:
- The ground truth substitution input as a string.
- The edit distance between the predicted and ground truth text lines.
"""
dense_pred = self.alphabet.str_to_intlabel_seq(pred_line)
dense_gt = self.alphabet.str_to_intlabel_seq(gt_line)
distance, dp = self.edit_distance(dense_pred, dense_gt)
_, _, gt_sub_input, cm = self.backtrace_ed_matrix(dense_pred, dense_gt, dp)
self.cm += cm
return self.alphabet.intlabel_seq_to_str(gt_sub_input), distance
[docs]
def generate_random_substitution_sequences(self, seq) -> np.ndarray:
"""
Generate random substitution sequences based on a conditional probability matrix.
This method generates a sequence of random substitutions for the input
sequence ``seq`` using the confusion matrix as conditional probability. Each
output symbol is sampled from the conditional probabilities of the
corresponding input symbol.
Parameters
----------
seq : np.ndarray
Input sequence represented as a NumPy array of integers. Each integer
corresponds to a symbol in the vocabulary.
Returns
-------
np.ndarray
A NumPy array of the same shape as ``seq``, where each element is a
randomly substituted symbol based on the conditional probability matrix.
Examples
--------
>>> import numpy as np
>>> cm = np.array([[0, 0.5, 0.5],
... [0, 0.7, 0.3],
... [0, 0.4, 0.6]])
>>> seq = np.array([1, 2, 1])
>>> augmenter = SubstitutionAugmenter(cm)
>>> out = augmenter.generate_random_substitution_sequences(seq)
>>> out.shape == seq.shape
True
"""
e = 1e-8
cm = self.cm.copy()
cm[:, 0] = 0 # Zero out deletions
cm[0, :] = 0 # Zero out insertions
cm = cm / (cm.sum(axis=1, keepdims=True) + e) # Normalize to probabilities
all_cond_probabilities_cdf = np.cumsum(cm, axis=1)
rnd_vals = np.random.rand(seq.shape[0])
res = np.zeros_like(seq)
for n in range(seq.shape[0]):
input_symbol = seq[n]
res[n] = np.searchsorted(all_cond_probabilities_cdf[input_symbol, :], rnd_vals[n])
return res
[docs]
def get_self_supervision_textline(self, input_line: str) -> str:
dense_input = self.alphabet.str_to_intlabel_seq(input_line)
mutated = self.generate_random_substitution_sequences(dense_input)
return self.alphabet.intlabel_seq_to_str(mutated)
[docs]
def save(self, file_path: Union[str, Path]):
pickle.dump([self.alphabet.dst_alphabet_str, self.cm], open(file_path, "wb"))
[docs]
@staticmethod
def load(file_path: Union[str, Path]) -> "CharConfusionMatrix":
dst_alphabet_str, cm = pickle.load(open(file_path, "rb"))
lemmatizer = LemmatizerBMP.from_alphabet_mapping(dst_alphabet_str, dst_alphabet_str)
char_cm = CharConfusionMatrix(lemmatizer)
char_cm.cm = cm
return char_cm
[docs]
def __init__(self, alphabet: Union[LemmatizerBMP, str]):
if isinstance(alphabet, str):
alphabet = LemmatizerBMP.from_alphabet_mapping(alphabet, alphabet)
self.alphabet = alphabet
self.cm = np.zeros((len(self.alphabet), len(self.alphabet)), dtype=int)
[docs]
def get_matrix(self) -> np.ndarray:
return self.cm
[docs]
def distort_np_sequence(self, input_seq: np.ndarray) -> np.ndarray:
return self.generate_random_substitution_sequences(input_seq)
[docs]
def distort_pt_sequence(self, input_seq: torch.Tensor) -> torch.Tensor:
input_np = input_seq.cpu().numpy()
distorted_np = self.generate_random_substitution_sequences(input_np)
return torch.from_numpy(distorted_np).to(input_seq.device)
[docs]
def distort_string(self, input_str: str) -> str:
dense_input = self.alphabet.str_to_intlabel_seq(input_str)
mutated = self.generate_random_substitution_sequences(dense_input)
return self.alphabet.intlabel_seq_to_str(mutated)
[docs]
def __call__(self, seq: Union[np.ndarray, torch.Tensor, str]) -> Union[np.ndarray, torch.Tensor, str]:
if isinstance(seq, np.ndarray):
return self.distort_np_sequence(seq)
elif isinstance(seq, torch.Tensor):
return self.distort_pt_sequence(seq)
elif isinstance(seq, str):
return self.distort_string(seq)
else:
raise ValueError(f"Unsupported input type: {type(seq)}")
def main_get_augmented_substitutiononly_parallel_corpus():
import fargv
p = {
"gt_txt": "",
"src_txt": "",
"alphabet_str": "",
"out_txt": "",
}
args, _ = fargv.fargv(p)
if args.src_txt == "" and args.gt_txt == "":
#all_lines = sys.stdin.readlines()
#return
pass
class AugmentedSeq2SeqDsSubOnly():
def __init__(self, cm: Union[LemmatizerBMP, str, CharConfusionMatrix]):
if isinstance(cm, CharConfusionMatrix):
self.conf_mat = cm
elif isinstance(cm, (str, LemmatizerBMP)):
self.cm = CharConfusionMatrix(alphabet= cm)
else:
raise ValueError(f"Wrong type for cm")
def ingest_textline_observation(self, pred_line: str, gt_line: str) -> Tuple[str, int]:
if self.input_is_output:
self.textlines.append((gt_line, pred_line))
else:
self.textlines.append((pred_line, gt_line))
return super().ingest_textline_observation(pred_line, gt_line)
def __getitem__(self, index: int, as_string: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[str, str]]:
src_str, tgt_str = self.textlines[index]
src_np = self.alphabet.str_to_intlabel_seq(src_str)
tgt_np = self.alphabet.str_to_intlabel_seq(tgt_str)
augmented_src_np = self.generate_random_substitution_sequences(src_np)
src_dense_labels = torch.from_numpy(augmented_src_np).long()
tgt_dense_labels = torch.from_numpy(tgt_np).long()
src_dense_labels = torch.tensor(src_dense_labels.astype(np.int64), dtype=torch.int64)
tgt_dense_labels = torch.tensor(tgt_dense_labels.astype(np.int64), dtype=torch.int64)
if self.input_is_onehot:
res_src = self.__labels_to_onehot(src_dense_labels, self.input_mapper.len()-1)
else:
res_src = src_dense_labels
if self.output_is_onehot:
res_tgt = self.__labels_to_onehot(tgt_dense_labels, self.output_mapper.len()-1)
else:
res_tgt = tgt_dense_labels
return res_src, res_tgt
def __len__(self) -> int:
return len(self.textlines)
def main_textline_full_cer():
"""
Compute the full CER (including substitutions) between two textline files.
"""
import fargv
import sys
import tqdm
from pylelemmatize.demapper_lstm import DemapperLSTM
p = {
"src_tsv": "",
"src1_txt": "",
"src2_txt": "",
"ignore_lines_with_cer_above": 1.,
"verbose": False,
}
args, _ = fargv.fargv(p)
if args.src_tsv != "":
assert args.src1_txt == "" and args.src2_txt == "", "If src_tsv is provided, src1_txt and src2_txt must be empty"
textlines1, textlines2 = zip(*[line.split("\t")[:2] for line in open(args.src_tsv,"r").readlines()])
elif args.src1_txt != "" and args.src2_txt != "":
assert args.src_tsv == "", "If src1_txt and src2_txt are provided, src_tsv must be empty"
textlines1 = open(args.src1_txt,"r").readlines()
textlines2 = open(args.src2_txt,"r").readlines()
elif args.src1_txt == "" and args.src2_txt == "" and args.src_tsv == "":
textlines1 = []
textlines2 = []
for line in sys.stdin.readlines():
p1, p2 = line.split("\t")[:2]
textlines1.append(p1)
textlines2.append(p2)
assert len(textlines1) == len(textlines2), "Input files must have the same number of lines."
alphabet_str ="".join(sorted(set("".join(textlines1 + textlines2))))
alphabet = LemmatizerBMP.from_alphabet_mapping(alphabet_str, alphabet_str)
all_dist = 0
conf_acc = np.zeros([len(alphabet), len(alphabet)])
dropped_lines = 0
dropped_chars = 0
total_chars = 0
if args.verbose:
progress = tqdm.tqdm(total=len(textlines1), desc="Processing lines")
cm = CharConfusionMatrix(alphabet)
for l1, l2 in zip(textlines1, textlines2):
np_l1 = alphabet.str_to_intlabel_seq(l1)
np_l2 = alphabet.str_to_intlabel_seq(l2)
dist, dp = cm.edit_distance(np_l1, np_l2)
_, _, _, conf = cm.backtrace_ed_matrix(np_l1, np_l2, dp)
if (dist / len(l2.strip())) > args.ignore_lines_with_cer_above:
dropped_lines += 1
dropped_chars += len(l2)
continue
total_chars += len(l2)
all_dist += dist
conf_acc += conf
if args.verbose:
progress.update(1)
if args.verbose:
progress.close()
cer = all_dist / total_chars
insertions = conf_acc[:, 0].sum()
deletions = conf_acc[0, :].sum()
invalids = conf_acc[0, 0]
substitutions = conf_acc.sum() - np.trace(conf_acc) - (insertions + deletions + invalids)
print(f"Dropped lines: {dropped_lines}, Dropped characters: {dropped_chars}")
print(f"Total characters (after dropping): {total_chars}")
print(f"Total edit distance: {all_dist}")
print(f"CER: {cer:.4f} (Insertions: {insertions}, Deletions: {deletions}, Substitutions: {substitutions})")