Implement Levenshtein distance, to be used to apply hunks

This commit is contained in:
Théophile Bastian 2020-05-09 19:33:39 +02:00
parent 0a9c03d1e1
commit 50f54b1389
1 changed files with 180 additions and 0 deletions

View File

@ -0,0 +1,180 @@
""" Tools to make smart edition to apply a hunk """
from collections import OrderedDict
class Levenshtein:
cache = {}
moves = {
"D": (-1, 0),
"S": (-1, -1),
"I": (0, -1),
}
def __init__(self, pre, post):
self.pre = pre
self.post = post
self.work_matrix = None
self.op_matrix = None
@classmethod
def get_full_cache(cls):
return cls.cache
def cache_key(self):
if isinstance(self.pre, list):
return (tuple(self.pre), tuple(self.post))
return (self.pre, self.post)
def getcache(self):
return self.get_full_cache().get(self.cache_key(), None)
@classmethod
def clean_cache(cls):
cls.cache = {}
def insert_cost(self, to_insert, pre_idx, post_idx):
return 1
def del_cost(self, to_delete, pre_idx, post_idx):
return 1
def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
return int(subst_from != subst_to)
@staticmethod
def _argmin(**kwargs):
""" Returns the first encountered pair (k, v) with lowest v wrt. `<`. """
if not kwargs:
raise Exception("No arguments")
argmin = next(iter(kwargs))
valmin = kwargs[argmin]
for key in kwargs:
val = kwargs[key]
if val < valmin:
argmin = key
valmin = val
return (argmin, valmin)
def compute(self):
cached = self.getcache()
if cached:
return cached
self.op_matrix = [[None] + ["I"] * len(self.post)] + [
["D"] + [None] * (len(self.post)) for _ in range(len(self.pre))
]
self.work_matrix = [
[0] * (len(self.post) + 1) for _ in range(len(self.pre) + 1)
]
for post_idx in range(len(self.post)):
self.work_matrix[0][post_idx + 1] = self.work_matrix[0][
post_idx
] + self.insert_cost(self.post[post_idx], 0, post_idx + 1)
for pre_idx in range(len(self.pre)):
self.work_matrix[pre_idx + 1][0] = self.work_matrix[pre_idx][
0
] + self.del_cost(self.pre[pre_idx], pre_idx + 1, 0)
for pre_idx in range(len(self.pre)):
for post_idx in range(len(self.post)):
c_insert_cost = self.insert_cost(
self.post[post_idx], pre_idx + 1, post_idx + 1
)
c_del_cost = self.del_cost(self.pre[pre_idx], pre_idx + 1, post_idx + 1)
c_subst_cost = self.subst_cost(
self.pre[pre_idx], self.post[post_idx], pre_idx + 1, post_idx + 1
)
opmin, costmin = self._argmin(
**OrderedDict(
[
("D", self.work_matrix[pre_idx][post_idx + 1] + c_del_cost),
(
"I",
self.work_matrix[pre_idx + 1][post_idx] + c_insert_cost,
),
("S", self.work_matrix[pre_idx][post_idx] + c_subst_cost),
]
)
)
self.work_matrix[pre_idx + 1][post_idx + 1] = costmin
self.op_matrix[pre_idx + 1][post_idx + 1] = opmin
self.cur_mode = opmin
ops = []
pre_pos = len(self.pre)
post_pos = len(self.post)
while pre_pos > 0 or post_pos > 0:
cur_op = self.op_matrix[pre_pos][post_pos]
cost_pre = pre_pos
cost_post = post_pos
row_m, col_m = self.moves[cur_op]
pre_pos += row_m
post_pos += col_m
if cur_op == "S":
pre_val = self.pre[pre_pos]
post_val = self.post[post_pos]
ops.append(
(
cur_op,
(pre_val, post_val),
self.subst_cost(pre_val, post_val, cost_pre, cost_post),
)
)
if cur_op == "I":
post_val = self.post[post_pos]
ops.append(
(cur_op, post_val, self.insert_cost(post_val, cost_pre, cost_post))
)
if cur_op == "D":
pre_val = self.pre[pre_pos]
ops.append(
(cur_op, pre_val, self.del_cost(pre_val, cost_pre, cost_post))
)
cached_result = {
"count": self.work_matrix[-1][-1],
"ops": ops[::-1],
}
self.get_full_cache()[self.cache_key()] = cached_result
return cached_result
class InlineLevenshtein(Levenshtein):
""" Levenshtein distance for edition of a single line """
def insert_cost(self, to_insert, pre_idx, post_idx):
return 1
def del_cost(self, to_delete, pre_idx, post_idx):
return 1 # 'x'
def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
if subst_from == subst_to:
return 0
return 1
class HunkLevenshtein(Levenshtein):
""" Levenshtein distance for hunk edition, ie. should we delete a whole line,
insert a whole line or substitute one line with another """
def insert_cost(self, to_insert, pre_idx, post_idx):
indent = 0
while indent < len(to_insert) and to_insert[indent] == " ":
indent += 1
indent_cost = 3 if indent else 0
return indent_cost + len(to_insert.strip()) + 1
def del_cost(self, to_delete, pre_idx, post_idx):
return 2
def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
res = InlineLevenshtein(subst_from, subst_to).compute()
return res["count"]