""" Utility functions and classes """ import jinja2 as j2 import re group_colors: list[str] = [ "ffacab", "acabff", "6fe569", "83e5aa", "dfffbc", "fff6ae", "f4b3ff", ] def levenshtein_distance(s1, s2): """Compute the Levenshtein distance (edit distance) between two strings Shamelessly stolen from https://stackoverflow.com/a/32558749""" if len(s1) > len(s2): s1, s2 = s2, s1 distances = range(len(s1) + 1) for i2, c2 in enumerate(s2): distances_ = [i2 + 1] for i1, c1 in enumerate(s1): if c1 == c2: distances_.append(distances[i1]) else: distances_.append( 1 + min((distances[i1], distances[i1 + 1], distances_[-1])) ) distances = distances_ return distances[-1] class UnionFind: """A union-find implementation""" parent_of: list[int] _group_size: list[int] def __init__(self, elt_count: int): self.parent_of = list(range(elt_count)) self._group_size = [1] * elt_count def root(self, elt: int) -> int: """Find the element representing :elt: (root of component) Compresses paths along the way""" if self.parent_of[elt] == elt: return elt self.parent_of[elt] = self.root(self.parent_of[elt]) return self.parent_of[elt] def union(self, elt1: int, elt2: int) -> None: """Unites two components""" elt1 = self.root(elt1) elt2 = self.root(elt2) if elt1 == elt2: return if self._group_size[elt1] > self._group_size[elt2]: self.union(elt2, elt1) else: self._group_size[elt2] += self._group_size[elt1] self._group_size[elt1] = 0 self.parent_of[self.root(elt1)] = self.root(elt2) def group_size(self, elt: int) -> int: """Get the number of elements in the component of :elt:""" return self._group_size[self.root(elt)] def write_to_file(path: str, content: str) -> None: """Write :content: to the file at :path:, truncating it. Writes to stdout instead if :path: is `-`.""" if path == "-": print(content) else: with open(path, "w") as handle: handle.write(content) _MD_RE_ITAL = re.compile(r"\*(.+?)\*") _MD_RE_BOLD = re.compile(r"\*\*(.+?)\*\*") def md_format(val: str) -> str: val = _MD_RE_BOLD.sub(r"\\textbf{\1}", val) val = _MD_RE_ITAL.sub(r"\\textit{\1}", val) return val _TEX_NBSP = re.compile(r" ([?!])") def escape_latex(val: str) -> str: val = _TEX_NBSP.sub(r"~\1", val) return val.replace("&", r"\&") def j2_environment() -> j2.Environment: env = j2.Environment(loader=j2.PackageLoader("repartir_taches")) env.filters["escape_latex"] = escape_latex env.filters["md_format"] = md_format return env