116 lines
3 KiB
Python
116 lines
3 KiB
Python
""" Utility functions and classes """
|
|
|
|
import jinja2 as j2
|
|
import re
|
|
|
|
|
|
# Palette of hex RGB colour codes (no leading "#"), presumably one per group
# when rendering output — confirm against the templates that use it.
group_colors: list[str] = (
    "fdffb6 caffbf 9bf6ff a0c4ff ffc6ff ebd8d0 70d6ff "
    "ff70a6 ff9770 ffd670 e9ff70 a5ffd6 d3ab9e b8e0d2"
).split()
|
|
|
|
|
|
def levenshtein_distance(s1: str, s2: str) -> int:
    """Return the Levenshtein (edit) distance between two strings.

    Classic dynamic-programming recurrence kept in a single rolling row whose
    width is the length of the shorter input, so memory is proportional to
    ``min(len(s1), len(s2))``.

    Adapted from https://stackoverflow.com/a/32558749"""
    shorter, longer = (s2, s1) if len(s1) > len(s2) else (s1, s2)

    previous = list(range(len(shorter) + 1))
    for row, ch_long in enumerate(longer, start=1):
        current = [row]
        for col, ch_short in enumerate(shorter):
            if ch_short == ch_long:
                cost = previous[col]
            else:
                # cheapest of substitution, deletion, insertion
                cost = 1 + min(previous[col], previous[col + 1], current[-1])
            current.append(cost)
        previous = current

    return previous[-1]
|
|
|
|
|
|
class UnionFind:
    """Disjoint-set forest over the integers ``0 .. elt_count - 1``.

    Merges are done by size (smaller tree under larger) and lookups compress
    the paths they walk."""

    # parent_of[x] is x's parent in the forest; roots are their own parent.
    parent_of: list[int]
    # Component size, only meaningful at roots (merged-away roots hold 0).
    _group_size: list[int]

    def __init__(self, elt_count: int):
        self.parent_of = list(range(elt_count))
        self._group_size = [1 for _ in range(elt_count)]

    def root(self, elt: int) -> int:
        """Return the representative (root) of the component containing :elt:.

        Every element visited on the way up is re-pointed directly at the
        root, flattening the tree for later lookups."""
        top = elt
        while self.parent_of[top] != top:
            top = self.parent_of[top]

        node = elt
        while node != top:
            following = self.parent_of[node]
            self.parent_of[node] = top
            node = following

        return top

    def union(self, elt1: int, elt2: int) -> None:
        """Merge the components containing :elt1: and :elt2:.

        The smaller tree is hung under the larger one; on a tie, the root of
        :elt1: is placed under the root of :elt2:."""
        small, large = self.root(elt1), self.root(elt2)
        if small == large:
            return

        if self._group_size[small] > self._group_size[large]:
            small, large = large, small

        self._group_size[large] += self._group_size[small]
        self._group_size[small] = 0
        self.parent_of[small] = large

    def group_size(self, elt: int) -> int:
        """Return how many elements share a component with :elt: (itself included)."""
        representative = self.root(elt)
        return self._group_size[representative]
|
|
|
|
|
|
def write_to_file(path: str, content: str, *, encoding: str = "utf-8") -> None:
    """Write :content: to the file at :path:, truncating it first.

    If :path: is ``-``, :content: is printed to stdout instead. Note that
    ``print`` appends a trailing newline, unlike the file branch.

    :encoding: applies to the file branch only. It defaults to UTF-8 rather
    than the platform's locale encoding, so output is identical on every OS
    and non-ASCII text (accented letters, ``·``) cannot fail to encode on a
    legacy code page."""
    if path == "-":
        print(content)
        return

    with open(path, "w", encoding=encoding) as handle:
        handle.write(content)
|
|
|
|
|
|
_MD_RE_ITAL = re.compile(r"\*(.+?)\*")
_MD_RE_BOLD = re.compile(r"\*\*(.+?)\*\*")


def md_format(val: str) -> str:
    r"""Translate minimal Markdown emphasis into LaTeX commands.

    ``**text**`` becomes ``\textbf{text}`` and ``*text*`` becomes
    ``\textit{text}``. Bold is rewritten first so that its double asterisks
    are not consumed as two italic markers."""
    for pattern, replacement in (
        (_MD_RE_BOLD, r"\\textbf{\1}"),
        (_MD_RE_ITAL, r"\\textit{\1}"),
    ):
        val = pattern.sub(replacement, val)
    return val
|
|
|
|
|
|
# A plain space before ?, ! or : (French typography) is turned into a
# non-breaking space so the punctuation never starts a new line.
_TEX_NBSP = re.compile(r" ([?!:])")

# Single-character replacements applied in one pass. None of the outputs
# contains any of the keys, so the order of application does not matter.
_TEX_ESCAPES = str.maketrans(
    {
        "·": r"$\cdot$",
        "&": r"\&",
        # An unescaped % would silently comment out the rest of the line.
        "%": r"\%",
        # A bare # is a LaTeX error outside macro definitions.
        "#": r"\#",
    }
)


def escape_latex(val: str) -> str:
    r"""Make a plain-text string safe to embed in LaTeX source.

    - ``" ?"``, ``" !"`` and ``" :"`` get a non-breaking space (``~``).
    - ``·`` becomes ``$\cdot$``.
    - ``&``, ``%`` and ``#`` are backslash-escaped.

    Other LaTeX special characters (``$ _ { } ~ ^ \``) are passed through
    unchanged."""
    val = _TEX_NBSP.sub(r"~\1", val)
    return val.translate(_TEX_ESCAPES)
|
|
|
|
|
|
def j2_environment() -> j2.Environment:
    """Build the Jinja2 environment used to render this package's templates.

    Templates are loaded through ``PackageLoader("repartir_taches")``, and the
    ``escape_latex`` and ``md_format`` helpers are registered as filters under
    their own names."""
    env = j2.Environment(loader=j2.PackageLoader("repartir_taches"))
    env.filters.update(
        escape_latex=escape_latex,
        md_format=md_format,
    )
    return env
|