Source code for pylelemmatize.mapper_ds

import os
import random
import re
import sys
from collections import defaultdict
from glob import glob
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch

from .abstract_mapper import AbstractLemmatizer, fast_str_to_numpy, fast_numpy_to_str
from .fast_mapper import LemmatizerBMP


class Seq2SeqDs:
    """Indexable dataset of parallel (source, target) text blocks.

    Each item is a pair of label tensors: the source block encoded with
    ``input_mapper`` and the matching target block encoded with
    ``output_mapper``.  Items can optionally be cropped to a random window of
    fixed length and/or returned one-hot encoded.
    """

    @staticmethod
    def __labels_to_onehot(labels: torch.Tensor, max_label: int) -> torch.Tensor:
        """Convert a 1D tensor of integer labels to a one-hot float32 tensor.

        Args:
            labels: 1D tensor of label indices.
            max_label: Largest label value; the result has ``max_label + 1``
                columns.

        Returns:
            Tensor of shape ``(len(labels), max_label + 1)``.
        """
        assert labels.ndim == 1, "Labels must be a 1D tensor."
        onehot = torch.zeros([labels.size(0), max_label + 1], dtype=torch.float32)
        onehot[torch.arange(len(labels)), labels] = 1.0
        return onehot

    @staticmethod
    def load_icdar2019_parallel_txt_corpus(input_paths: Union[str, List[str]],
                                           max_insertions: int,
                                           min_length: int,
                                           max_length: int) -> List[Tuple[List[str], List[str]]]:
        """Scan three-line aligned OCR / ground-truth files and validate them.

        Each file is expected to contain exactly three lines: the raw OCR
        line, a line containing the ``[OCR_aligned]`` marker and a line
        containing the ``[ GS_aligned]`` marker.  Files failing either
        consistency check are reported on stderr and skipped.

        NOTE(review): this loader looks unfinished.  It computes the spans
        between runs of at least ``max_insertions`` ``'@'`` characters but
        never uses them, ``min_length`` / ``max_length`` are unused, and the
        function currently returns ``None`` despite its annotation.  The two
        warning texts also appear to be swapped relative to the conditions
        that trigger them.  Both are left as they were until the intended
        behaviour is confirmed.

        Args:
            input_paths: A glob pattern or an explicit list of file paths.
            max_insertions: Minimum length of an ``'@'`` run that separates
                two pieces.
            min_length: Currently unused.
            max_length: Currently unused.

        Raises:
            ValueError: If a file does not contain exactly three lines.
            IndexError: If an expected alignment marker is missing.
        """
        # A plain string used to be iterated character by character; treat it
        # as a glob pattern, like load_parallel_txt_corpus does.
        if isinstance(input_paths, str):
            input_paths = list(glob(input_paths))
        # Loop-invariant, so compile it once.
        gap_pattern = re.compile(r'@' * (max_insertions - 1) + r'@+')
        for file in input_paths:
            with open(file, 'r', encoding='utf-8') as f:
                original_ocr, aligned_ocr, aligned_gt = f.read().strip().split("\n")
            aligned_ocr = aligned_ocr.split("[OCR_aligned]")[1]
            aligned_gt = aligned_gt.split("[ GS_aligned]")[1]
            ocr_without_padding = aligned_ocr.replace("@", "")
            if len(aligned_gt) != len(ocr_without_padding):
                print(f"Warning: Original input and OCR aligned string have different lengths after removing '@' characters in file {file}.", file=sys.stderr)
                continue
            if ocr_without_padding != original_ocr:
                print(f"Warning: OCR aligned string and GT aligned string have different lengths in file {file}.", file=sys.stderr)
                continue
            pieces_idx = []
            last_end = 0
            for match in gap_pattern.finditer(aligned_ocr):
                start, end = match.span()
                pieces_idx.append((last_end, start))
                last_end = end
            pieces_idx.append((last_end, len(aligned_ocr)))
            # TODO(review): pieces_idx is discarded; nothing is returned yet.

    @staticmethod
    def load_parallel_txt_corpus(input_glob: Union[str, List[str]],
                                 output_glob: Union[str, List[str]],
                                 check_integrity: Literal["cleanup", "raise", "ignore"] = "cleanup"
                                 ) -> Tuple[List[List[str]], List[List[str]]]:
        """Load input/output text files paired by base file name.

        Files are grouped by the part of their base name before the first
        ``'.'``.  Each file is split into lines; lines are stripped and empty
        lines dropped.  Files are read as UTF-8.

        Args:
            input_glob: Glob pattern or explicit list of input file paths.
            output_glob: Glob pattern or explicit list of output file paths.
            check_integrity: ``"ignore"`` returns the groups unchecked;
                ``"raise"`` raises on the first kind of problem found;
                ``"cleanup"`` prints a summary.  Unless ``"ignore"`` is
                given, groups that do not have exactly two files, whose files
                have different line counts, or that are empty are dropped.

        Returns:
            ``(inputs, outputs)``: two lists with one entry per kept group,
            each entry being the list of lines of that file.

        Raises:
            ValueError: On a bad ``input_glob`` / ``output_glob`` type, or on
                integrity problems when ``check_integrity == "raise"``.
        """
        if isinstance(input_glob, str):
            input_paths = list(glob(input_glob))
        elif isinstance(input_glob, list):
            input_paths = input_glob
        else:
            raise ValueError("input_glob must be a string or a list of strings")
        if isinstance(output_glob, str):
            output_paths = list(glob(output_glob))
        elif isinstance(output_glob, list):
            output_paths = output_glob
        else:
            raise ValueError("output_glob must be a string or a list of strings")

        data = defaultdict(list)
        # Inputs are read first, so data[k][0] is the input file and
        # data[k][1] the output file of each pair.
        for fname in input_paths + output_paths:
            # basename() honours the platform path separator.
            k = os.path.basename(fname).split(".")[0]
            with open(fname, "r", encoding="utf-8") as f:
                raw_lines = f.read().split("\n")
            data[k].append([line.strip() for line in raw_lines if len(line.strip())])

        if check_integrity == "ignore":
            values = list(data.values())
            inputs = [v[0] for v in values]
            outputs = [v[1] for v in values]
            return inputs, outputs

        remove_ids = []
        non_2 = 0
        non_eq = 0
        empty = 0
        for k, v in data.items():
            if len(v) != 2:
                print(f"No two {k}, found {len(v)}")
                non_2 += 1
                remove_ids.append(k)
                continue
            if len(v[0]) != len(v[1]):
                non_eq += 1
                print(f"Not equal lines {k}, found {len(v[0])} and {len(v[1])}")
                remove_ids.append(k)
                continue
            if len(v[0]) == 0:
                empty += 1
                remove_ids.append(k)
                continue

        if check_integrity == "raise":
            if non_2 > 0:
                raise ValueError(f"Found {non_2} files with not exactly two lines in the corpus")
            if non_eq > 0:
                raise ValueError(f"Found {non_eq} files with not equal number of lines in the corpus")
            if empty > 0:
                raise ValueError(f"Found {empty} files with empty lines in the corpus")
        elif check_integrity == "cleanup":
            print(f"Found {non_2} files with not exactly two lines in the corpus")
            print(f"Found {non_eq} files with not equal number of lines in the corpus")
            print(f"Found {empty} files with empty lines in the corpus")

        erased_lines = 0
        for k in remove_ids:
            erased_lines += sum(len(v) for v in data[k])
            del data[k]
        kept_lines = sum(len(v[0]) for v in data.values())
        print(f"Kept: {kept_lines} double lines, erased {erased_lines} lines. Non 2 files: {non_2}, Not matching #: {non_eq}, Empty: {empty}")
        values = list(data.values())
        inputs = [v[0] for v in values]
        outputs = [v[1] for v in values]
        return inputs, outputs

    @staticmethod
    def from_parallel_txt_corpus(input_glob: Union[str, List[str]],
                                 output_glob: Union[str, List[str]],
                                 **kwargs) -> 'Seq2SeqDs':
        """Build a dataset from paired text files, one sample per line pair.

        The files are loaded with :meth:`load_parallel_txt_corpus` using its
        default integrity handling, which guarantees that paired files have
        the same number of lines.  The per-file line lists are then
        flattened, because the constructor expects flat lists of strings
        (the nested lists used to be passed through unchanged).

        Args:
            input_glob: Glob pattern or list of input file paths.
            output_glob: Glob pattern or list of output file paths.
            **kwargs: Forwarded to the constructor.
        """
        inputs, outputs = Seq2SeqDs.load_parallel_txt_corpus(input_glob, output_glob)
        src_lines = [line for file_lines in inputs for line in file_lines]
        tgt_lines = [line for file_lines in outputs for line in file_lines]
        return Seq2SeqDs((src_lines, tgt_lines), **kwargs)

    @staticmethod
    def create_selfsupervised_ds(corpus: List[str],
                                 mapper: "LemmatizerBMP",
                                 mapped_is_input: bool = True,
                                 add_all_occuring_to_input: bool = True,
                                 **kwargs) -> 'Seq2SeqDs':
        """Build a dataset that pairs each text with its mapped version.

        Args:
            corpus: Raw texts.
            mapper: Callable mapper applied to every text; its
                ``src_alphabet_str``, ``dst_alphabet_str`` and
                ``unknown_chr`` define the alphabets of the new mappers.
            mapped_is_input: If True the mapped text is the source and the
                raw text the target; otherwise the other way round.
            add_all_occuring_to_input: If True, every character occurring in
                ``corpus`` is added to the mapper's source alphabet.
            **kwargs: Forwarded to the constructor.
        """
        mapped_corpus = [mapper(text) for text in corpus]
        mapsrc_alphabet_str = mapper.src_alphabet_str
        if add_all_occuring_to_input:
            corpus_occ = fast_numpy_to_str(np.unique(fast_str_to_numpy(''.join(corpus))))
            # set() removes the characters present in both strings; plain
            # concatenation used to leave duplicates in the alphabet.
            mapsrc_alphabet_str = ''.join(sorted(set(mapper.src_alphabet_str + corpus_occ)))
        raw_side_mapper = LemmatizerBMP.from_alphabet_mapping(mapsrc_alphabet_str, unknown_chr=mapper.unknown_chr)
        mapped_side_mapper = LemmatizerBMP.from_alphabet_mapping(mapper.dst_alphabet_str, unknown_chr=mapper.unknown_chr)
        if mapped_is_input:
            text_blocks = (mapped_corpus, corpus)
            in_mapper, out_mapper = mapped_side_mapper, raw_side_mapper
        else:
            text_blocks = (corpus, mapped_corpus)
            in_mapper, out_mapper = raw_side_mapper, mapped_side_mapper
        return Seq2SeqDs(text_blocks, input_mapper=in_mapper, output_mapper=out_mapper, **kwargs)

    def __init__(self,
                 text_blocks: Tuple[List[str], List[str]],
                 input_mapper: Optional["LemmatizerBMP"] = None,
                 output_mapper: Optional["LemmatizerBMP"] = None,
                 min_input_seqlen: int = 50,
                 min_output_seqlen: int = 50,
                 one2one_mapping: Optional[bool] = None,
                 crop_to_seqlen: Optional[int] = None,
                 input_is_onehot: bool = False,
                 output_is_onehot: bool = False):
        """Create the dataset, keeping only sufficiently long pairs.

        Args:
            text_blocks: ``(sources, targets)``, two equally long lists of
                strings.
            input_mapper: Encoder for source text; built from the characters
                of the kept source blocks when None.
            output_mapper: Encoder for target text; built from the characters
                of the kept target blocks when None.
            min_input_seqlen: Pairs with a shorter source are dropped.
            min_output_seqlen: Pairs with a shorter target are dropped.
            one2one_mapping: None detects whether every kept pair has equal
                lengths; True asserts it; False disables it.
            crop_to_seqlen: If given, items are cropped to a random window of
                this length; requires a one-to-one mapping.
            input_is_onehot: Return source items one-hot encoded.
            output_is_onehot: Return target items one-hot encoded.

        Raises:
            ValueError: If the two lists in ``text_blocks`` differ in length.
            AssertionError: If ``one2one_mapping`` is True but lengths
                differ, or if cropping is requested without a one-to-one
                mapping.
        """
        sources, targets = text_blocks[0], text_blocks[1]
        if len(sources) != len(targets):
            raise ValueError(
                f"text_blocks must hold two lists of equal length, got {len(sources)} and {len(targets)}")

        self.src_text_blocks = []
        self.tgt_text_blocks = []
        for src, tgt in zip(sources, targets):
            if len(src) >= min_input_seqlen and len(tgt) >= min_output_seqlen:
                self.src_text_blocks.append(src)
                self.tgt_text_blocks.append(tgt)

        if input_mapper is None:
            self.input_mapper = LemmatizerBMP.from_alphabet_mapping(
                AbstractLemmatizer.fast_alphabet_extraction(''.join(self.src_text_blocks)))
        else:
            self.input_mapper = input_mapper
        if output_mapper is None:
            self.output_mapper = LemmatizerBMP.from_alphabet_mapping(
                AbstractLemmatizer.fast_alphabet_extraction(''.join(self.tgt_text_blocks)))
        else:
            self.output_mapper = output_mapper

        self.min_input_seqlen = min_input_seqlen
        self.min_output_seqlen = min_output_seqlen
        # Historical, misnamed attribute: it holds the *minimum* output
        # length.  Kept so existing code reading it keeps working.
        self.max_output_seqlen = min_output_seqlen

        if one2one_mapping is None:
            self.one2one_mapping = self._all_lengths_match()
        elif isinstance(one2one_mapping, bool) and one2one_mapping:
            assert self._all_lengths_match(), "one2one_mapping requires equally long source/target blocks"
            self.one2one_mapping = True
        else:
            self.one2one_mapping = False

        if crop_to_seqlen is not None:
            assert self.one2one_mapping, "Cannot crop to seqlen if one2one_mapping is False"
            self.crop_seqlen = crop_to_seqlen
        else:
            self.crop_seqlen = None
        self.input_is_onehot = input_is_onehot
        self.output_is_onehot = output_is_onehot

    def _all_lengths_match(self) -> bool:
        """Return True if every kept source block is as long as its target."""
        return all(len(src) == len(tgt)
                   for src, tgt in zip(self.src_text_blocks, self.tgt_text_blocks))

    def __len__(self) -> int:
        """Return the number of kept (source, target) pairs."""
        return len(self.src_text_blocks)

    def __getitem__(self, n: int, as_string: bool = False) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[str, str]]:
        """Return the n-th pair, encoded as tensors unless ``as_string``.

        Args:
            n: Index of the pair.
            as_string: If True, return the raw, uncropped strings.

        Returns:
            ``(source, target)``: int64 label tensors, or float32 one-hot
            tensors for the sides flagged as one-hot.  When cropping is
            enabled both sides are cut to the same random window; blocks not
            longer than the crop length are returned whole.
        """
        src_txt = self.src_text_blocks[n]
        tgt_txt = self.tgt_text_blocks[n]
        if as_string:
            return src_txt, tgt_txt
        if self.crop_seqlen is not None and len(src_txt) > self.crop_seqlen:
            # randint is inclusive on both ends, so the last valid start is
            # len - crop.  The former "- 1" never sampled the final window
            # and raised ValueError for blocks exactly crop_seqlen long.
            start_pos = random.randint(0, len(src_txt) - self.crop_seqlen)
            end_pos = start_pos + self.crop_seqlen
            src_txt = src_txt[start_pos:end_pos]
            tgt_txt = tgt_txt[start_pos:end_pos]
        src_dense_labels = self.input_mapper.str_to_intlabel_seq(src_txt)
        tgt_dense_labels = self.output_mapper.str_to_intlabel_seq(tgt_txt)
        src_dense_labels = torch.tensor(src_dense_labels.astype(np.int64), dtype=torch.int64)
        tgt_dense_labels = torch.tensor(tgt_dense_labels.astype(np.int64), dtype=torch.int64)
        # NOTE(review): the mappers are queried through a ``.len()`` method;
        # confirm that this (rather than ``len(mapper)``) is their API.
        if self.input_is_onehot:
            res_src = self.__labels_to_onehot(src_dense_labels, self.input_mapper.len() - 1)
        else:
            res_src = src_dense_labels
        if self.output_is_onehot:
            res_tgt = self.__labels_to_onehot(tgt_dense_labels, self.output_mapper.len() - 1)
        else:
            res_tgt = tgt_dense_labels
        return res_src, res_tgt

    def shuffle(self) -> None:
        """Shuffle the pairs in place, keeping sources and targets aligned."""
        idx = list(range(len(self)))
        random.shuffle(idx)
        self.src_text_blocks = [self.src_text_blocks[i] for i in idx]
        self.tgt_text_blocks = [self.tgt_text_blocks[i] for i in idx]

    def split(self, train_ratio: float = 0.8, shuffle: bool = True) -> Tuple['Seq2SeqDs', 'Seq2SeqDs']:
        """Split into a training and a validation dataset.

        Args:
            train_ratio: Fraction of pairs for the training set, strictly
                between 0 and 1.
            shuffle: If True, this dataset is shuffled in place first.

        Returns:
            ``(train_ds, val_ds)`` sharing this dataset's mappers and
            settings, including the one-hot flags (which used to be dropped).
        """
        assert 0 < train_ratio < 1, "Ratio must be between 0 and 1"
        if shuffle:
            self.shuffle()
        split_idx = int(len(self) * train_ratio)
        shared_settings = dict(
            input_mapper=self.input_mapper,
            output_mapper=self.output_mapper,
            min_input_seqlen=self.min_input_seqlen,
            min_output_seqlen=self.max_output_seqlen,
            one2one_mapping=self.one2one_mapping,
            crop_to_seqlen=self.crop_seqlen,
            input_is_onehot=self.input_is_onehot,
            output_is_onehot=self.output_is_onehot,
        )
        train_ds = Seq2SeqDs((self.src_text_blocks[:split_idx], self.tgt_text_blocks[:split_idx]), **shared_settings)
        val_ds = Seq2SeqDs((self.src_text_blocks[split_idx:], self.tgt_text_blocks[split_idx:]), **shared_settings)
        return train_ds, val_ds

    def compute_ds_CER(self, use_editdistance: bool = False) -> float:
        """Compute the Character Error Rate (CER) of the dataset.

        Characters are compared position by position.  A pair whose source
        and target differ in length contributes no correct characters.  An
        empty dataset yields 1.0.

        Args:
            use_editdistance: Not supported yet.

        Raises:
            NotImplementedError: If ``use_editdistance`` is True.
        """
        if use_editdistance:
            raise NotImplementedError("Edit distance is not implemented yet.")
        total_correct = 0
        total_length = 0
        for src_txt, tgt_txt in zip(self.src_text_blocks, self.tgt_text_blocks):
            total_length += max(len(tgt_txt), len(src_txt))
            if len(src_txt) == len(tgt_txt):
                total_correct += sum(s == t for s, t in zip(src_txt, tgt_txt))
        accuracy = total_correct / total_length if total_length > 0 else 0.0
        return 1 - accuracy

    def render_sample(self, n: int = 0, include_alphabet: bool = False) -> str:
        """Return a human-readable dump of the n-th sample.

        Args:
            n: Index of the sample.
            include_alphabet: If True, prepend the source alphabets of both
                mappers.

        Returns:
            Multi-line text with the raw strings, their decoded encodings and
            the encoded tensors.
        """
        src_txt = self.src_text_blocks[n]
        tgt_txt = self.tgt_text_blocks[n]
        # Encode once: with random cropping each self[n] call may yield a
        # different window, which made the lines below disagree.
        src_encoded, tgt_encoded = self[n]
        parts = []
        if include_alphabet:
            parts.append(f"Input Alphabet: {self.input_mapper.src_alphabet_str}\n")
            parts.append(f"Output Alphabet: {self.output_mapper.src_alphabet_str}\n")
        parts.append(f"Sample {n}:\n")
        parts.append(f"Source : {src_txt}\n")
        parts.append(f"Source decoded: {self.input_mapper.intlabel_seq_to_str(src_encoded)}\n")
        parts.append(f"Target : {tgt_txt}\n")
        parts.append(f"Target decoded: {self.output_mapper.intlabel_seq_to_str(tgt_encoded)}\n")
        parts.append(f"Source Tensor: {src_encoded}\n")
        parts.append(f"Target Tensor: {tgt_encoded}\n")
        return "".join(parts)