WE-repartir-taches/repartir_taches/util.py

109 lines
2.8 KiB
Python
Raw Normal View History

""" Utility functions and classes """
2022-10-30 16:59:13 +01:00
import jinja2 as j2
import re
2022-10-30 18:06:02 +01:00
# Six-digit hexadecimal colour codes (no leading "#"), in a fixed order.
# NOTE(review): presumably one colour per group when rendering; verify
# against callers.
group_colors: list[str] = [
    "ffacab", "acabff", "6fe569", "83e5aa",
    "dfffbc", "fff6ae", "f4b3ff",
]
def levenshtein_distance(s1, s2):
    """Return the Levenshtein (edit) distance between ``s1`` and ``s2``.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions turning one sequence into the other.
    Classic dynamic programming, keeping only one row of the table: the
    row is indexed by the shorter input, the outer loop walks the longer one.

    Adapted from https://stackoverflow.com/a/32558749
    """
    shorter, longer = (s1, s2) if len(s1) <= len(s2) else (s2, s1)

    # previous_row[j] == distance between longer[:row] and shorter[:j]
    previous_row = list(range(len(shorter) + 1))
    for row, long_item in enumerate(longer, start=1):
        current_row = [row]
        for col, short_item in enumerate(shorter):
            if short_item == long_item:
                # Match: carry the diagonal value over unchanged.
                cell = previous_row[col]
            else:
                # Cheapest of substitution (diagonal), deletion (above)
                # and insertion (left), plus one edit.
                cell = 1 + min(previous_row[col], previous_row[col + 1], current_row[-1])
            current_row.append(cell)
        previous_row = current_row
    return previous_row[-1]
class UnionFind:
    """Disjoint-set (union-find) structure over the integers ``0 .. elt_count - 1``.

    Each element starts in its own singleton component. Merging uses union
    by size and lookups use path compression.
    """

    # parent_of[i] is i's parent in the forest; i is a root iff parent_of[i] == i.
    parent_of: list[int]
    # Component sizes; only meaningful at root indices (non-roots are zeroed on merge).
    _group_size: list[int]

    def __init__(self, elt_count: int):
        self.parent_of = list(range(elt_count))
        self._group_size = [1] * elt_count

    def root(self, elt: int) -> int:
        """Return the representative (root) of the component containing :elt:.

        Every node on the path from :elt: to the root is re-pointed directly
        at the root (path compression).
        """
        top = elt
        while self.parent_of[top] != top:
            top = self.parent_of[top]

        # Second pass: flatten the path that was just walked.
        node = elt
        while node != top:
            next_node = self.parent_of[node]
            self.parent_of[node] = top
            node = next_node
        return top

    def union(self, elt1: int, elt2: int) -> None:
        """Merge the components containing :elt1: and :elt2:.

        The smaller component is hung below the root of the larger one; on a
        size tie, :elt1:'s component goes below :elt2:'s root.
        """
        small, large = self.root(elt1), self.root(elt2)
        if small == large:
            return
        if self._group_size[small] > self._group_size[large]:
            small, large = large, small
        self._group_size[large] += self._group_size[small]
        self._group_size[small] = 0
        self.parent_of[small] = large

    def group_size(self, elt: int) -> int:
        """Return how many elements share :elt:'s component (including :elt: itself)."""
        return self._group_size[self.root(elt)]
2022-10-30 16:59:13 +01:00
def write_to_file(path: str, content: str, *, encoding: str = "utf-8") -> None:
    """Write :content: to the file at :path:, truncating it. Writes to stdout instead
    if :path: is `-`.

    :param path: destination file path, or `-` for standard output.
    :param content: text to write.
    :param encoding: text encoding used for the file (UTF-8 by default). It is
        passed explicitly so that the output does not depend on the platform's
        locale encoding, which may be unable to represent non-ASCII text.
    :raises OSError: if the file cannot be opened for writing.
    :raises UnicodeEncodeError: if :content: cannot be represented in :encoding:.

    Note: on stdout, `print` appends a trailing newline. In a file, :content: is
    written without adding one.
    """
    if path == "-":
        print(content)
        return
    with open(path, "w", encoding=encoding) as handle:
        handle.write(content)
_MD_RE_BOLD = re.compile(r"\*\*(.+?)\*\*")
_MD_RE_ITAL = re.compile(r"\*(.+?)\*")


def md_format(val: str) -> str:
    r"""Turn Markdown-style emphasis in :val: into LaTeX commands.

    ``**text**`` becomes ``\textbf{text}`` and ``*text*`` becomes
    ``\textit{text}``. Matching is non-greedy and does not cross line breaks
    (``.`` excludes newlines). Nothing else in :val: is escaped.
    """
    # Bold must be rewritten first: applied first, the single-star italic
    # pattern would match inside ``**...**`` and leave stray asterisks.
    rules = (
        (_MD_RE_BOLD, r"\\textbf{\1}"),
        (_MD_RE_ITAL, r"\\textit{\1}"),
    )
    for pattern, template in rules:
        val = pattern.sub(template, val)
    return val
_TEX_NBSP = re.compile(r" ([?!])")
# LaTeX special characters whose escaped form is just a backslash prefix.
_TEX_ESCAPES = str.maketrans({char: "\\" + char for char in "&%$#_"})


def escape_latex(val: str) -> str:
    r"""Make plain text safe to embed in a LaTeX document.

    - A single regular space before ``?`` or ``!`` becomes a non-breaking
      space (``~``), so the punctuation mark cannot be pushed onto a new line
      (presumably for French typography).
    - ``& % $ # _`` are escaped as ``\& \% \$ \# \_``. Unescaped, ``%``
      silently comments out the rest of the line, and ``#``, ``_``, ``$``
      make LaTeX fail or switch into math mode.

    ``\``, ``{``, ``}``, ``~`` and ``^`` are deliberately left untouched:
    ``md_format`` emits ``\textbf{...}`` / ``\textit{...}``, and this
    function itself emits ``~``. NOTE(review): confirm in the templates
    whether this filter runs before or after ``md_format``.
    """
    val = _TEX_NBSP.sub(r"~\1", val)
    # Single-pass translation, so an inserted backslash is never re-escaped.
    return val.translate(_TEX_ESCAPES)
def j2_environment() -> j2.Environment:
    """Create the Jinja2 environment used to render this package's templates.

    Templates are resolved through a ``PackageLoader`` bound to the
    ``repartir_taches`` package. The module's ``escape_latex`` and
    ``md_format`` helpers are exposed to templates as filters of the same
    name.
    """
    environment = j2.Environment(loader=j2.PackageLoader("repartir_taches"))
    environment.filters.update(
        {
            "escape_latex": escape_latex,
            "md_format": md_format,
        }
    )
    return environment