180 lines
5.7 KiB
Python
180 lines
5.7 KiB
Python
""" Tools to compute smart (minimal-cost) edits for applying a hunk """
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
|
class Levenshtein:
    """Weighted Levenshtein (edit) distance between two sequences.

    ``pre`` is the sequence being edited and ``post`` the target; both must
    support ``len()`` and indexing (e.g. strings, or lists of lines).  The
    price of each operation comes from the ``insert_cost``, ``del_cost`` and
    ``subst_cost`` hooks, which subclasses override; the defaults give the
    classic unit-cost distance.

    :meth:`compute` results are memoised in the ``cache`` class attribute,
    keyed by the concrete class *and* by :meth:`cache_key`, so subclasses that
    price operations differently never read each other's results.
    """

    # Memo of compute() results.  Subclasses share this dict until
    # clean_cache() is called on them, which rebinds ``cache`` on that class.
    cache = {}

    # Backtracking step (row delta, column delta) in the DP matrices for each
    # operation: "D" deletes an element of `pre`, "I" inserts an element of
    # `post`, "S" substitutes one for the other (or keeps it when equal).
    moves = {
        "D": (-1, 0),
        "S": (-1, -1),
        "I": (0, -1),
    }

    def __init__(self, pre, post):
        self.pre = pre
        self.post = post
        # Both matrices are filled by compute(); they stay None on a cache hit.
        # work_matrix[i][j]: cheapest cost of turning pre[:i] into post[:j].
        self.work_matrix = None
        # op_matrix[i][j]: operation ("D", "I" or "S") chosen for that cell.
        self.op_matrix = None

    @classmethod
    def get_full_cache(cls):
        """Return the memo dictionary visible from this class."""
        return cls.cache

    def cache_key(self):
        """Return a hashable key for the (pre, post) pair.

        When either side is a list (e.g. a list of lines) both sides are
        converted to tuples so the key is hashable; otherwise the sequences
        are used as-is.
        """
        if isinstance(self.pre, list) or isinstance(self.post, list):
            return (tuple(self.pre), tuple(self.post))
        return (self.pre, self.post)

    def _memo_key(self):
        """Return the key under which compute() memoises this instance.

        The concrete class is part of the key because subclasses price
        operations differently: equal (pre, post) pairs must not share an
        entry across classes.
        """
        return (type(self), self.cache_key())

    def getcache(self):
        """Return the memoised compute() result for this instance, or None."""
        return self.get_full_cache().get(self._memo_key())

    @classmethod
    def clean_cache(cls):
        """Drop memoised results by rebinding ``cache`` on ``cls``.

        Called on the base class, this replaces the dict shared by every
        subclass that has not rebound its own; called on a subclass, it only
        gives that subclass (and subclasses inheriting from it) a fresh dict.
        """
        cls.cache = {}

    # Cost hooks.  ``pre_idx`` / ``post_idx`` are the (row, column) of the DP
    # cell being priced; row 0 and column 0 stand for the empty prefixes.

    def insert_cost(self, to_insert, pre_idx, post_idx):
        """Price of inserting ``to_insert`` (default: 1)."""
        return 1

    def del_cost(self, to_delete, pre_idx, post_idx):
        """Price of deleting ``to_delete`` (default: 1)."""
        return 1

    def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
        """Price of replacing ``subst_from`` by ``subst_to`` (default: 0 if equal, else 1)."""
        return int(subst_from != subst_to)

    @staticmethod
    def _argmin(**kwargs):
        """Return the first encountered pair ``(key, value)`` with the lowest value.

        Values are compared with ``<`` only, so on ties the earliest keyword
        argument wins.

        Raises:
            ValueError: if called without any keyword argument.
        """
        if not kwargs:
            raise ValueError("No arguments")
        items = iter(kwargs.items())
        best_key, best_val = next(items)
        for key, val in items:
            if val < best_val:
                best_key, best_val = key, val
        return (best_key, best_val)

    def _fill_matrices(self):
        """Fill ``work_matrix`` (cumulative costs) and ``op_matrix`` (chosen ops)."""
        pre, post = self.pre, self.post
        n_pre, n_post = len(pre), len(post)

        # Row 0 can only be reached by insertions and column 0 by deletions.
        self.op_matrix = [[None] + ["I"] * n_post] + [
            ["D"] + [None] * n_post for _ in range(n_pre)
        ]
        work = [[0] * (n_post + 1) for _ in range(n_pre + 1)]
        self.work_matrix = work

        for j in range(n_post):
            work[0][j + 1] = work[0][j] + self.insert_cost(post[j], 0, j + 1)
        for i in range(n_pre):
            work[i + 1][0] = work[i][0] + self.del_cost(pre[i], i + 1, 0)

        for i in range(n_pre):
            row = i + 1
            for j in range(n_post):
                col = j + 1
                c_insert = self.insert_cost(post[j], row, col)
                c_delete = self.del_cost(pre[i], row, col)
                c_subst = self.subst_cost(pre[i], post[j], row, col)
                # Keyword order is the tie-break order (D, then I, then S):
                # _argmin keeps the first minimum and **kwargs keep call order.
                op, cost = self._argmin(
                    D=work[i][col] + c_delete,
                    I=work[row][j] + c_insert,
                    S=work[i][j] + c_subst,
                )
                work[row][col] = cost
                self.op_matrix[row][col] = op
                # Last operation chosen; not read within this class, kept for
                # compatibility with code that may inspect it.
                self.cur_mode = op

    def _backtrack(self):
        """Return the operations along the chosen path, in forward order."""
        ops = []
        pre_pos, post_pos = len(self.pre), len(self.post)
        while pre_pos > 0 or post_pos > 0:
            cur_op = self.op_matrix[pre_pos][post_pos]
            # Hooks are re-priced with the coordinates of the cell the op was
            # chosen in, exactly as during the fill.
            cell_row, cell_col = pre_pos, post_pos
            row_m, col_m = self.moves[cur_op]
            pre_pos += row_m
            post_pos += col_m
            if cur_op == "S":
                pre_val, post_val = self.pre[pre_pos], self.post[post_pos]
                cost = self.subst_cost(pre_val, post_val, cell_row, cell_col)
                ops.append((cur_op, (pre_val, post_val), cost))
            elif cur_op == "I":
                post_val = self.post[post_pos]
                cost = self.insert_cost(post_val, cell_row, cell_col)
                ops.append((cur_op, post_val, cost))
            elif cur_op == "D":
                pre_val = self.pre[pre_pos]
                cost = self.del_cost(pre_val, cell_row, cell_col)
                ops.append((cur_op, pre_val, cost))
        ops.reverse()
        return ops

    def compute(self):
        """Return the edit distance from ``pre`` to ``post`` and an optimal script.

        The result is ``{"count": total_cost, "ops": [...]}``, with ``ops`` in
        forward order, each entry being one of:

        * ``("S", (pre_item, post_item), cost)`` -- substitution (also used
          to keep an item unchanged);
        * ``("I", post_item, cost)`` -- insertion of an item of ``post``;
        * ``("D", pre_item, cost)`` -- deletion of an item of ``pre``.

        Results are memoised: a repeated call for the same class and inputs
        returns the very same dict, so callers should not mutate it.
        """
        cached = self.getcache()
        if cached is not None:
            return cached

        self._fill_matrices()
        result = {
            "count": self.work_matrix[-1][-1],
            "ops": self._backtrack(),
        }
        self.get_full_cache()[self._memo_key()] = result
        return result
|
|
|
|
|
|
class InlineLevenshtein(Levenshtein):
    """Element-wise (character-wise for strings) edit distance within a single line.

    Every insertion and every deletion costs 1; a substitution costs 0 when
    both elements are equal and 1 otherwise.
    """

    def insert_cost(self, to_insert, pre_idx, post_idx):
        # Flat price, independent of the element and of its position.
        return 1

    def del_cost(self, to_delete, pre_idx, post_idx):
        # Flat price.  NOTE(review): the original "'x'" remark presumably
        # refers to a one-key character delete; confirm.
        return 1

    def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
        # Equal elements are kept for free; any change costs one unit.
        return 0 if subst_from == subst_to else 1
|
|
|
|
|
|
class HunkLevenshtein(Levenshtein):
    """Line-level edit distance for applying a hunk.

    Each step deletes a whole line, inserts a whole line, or substitutes one
    line for another; a substitution is priced by the InlineLevenshtein
    distance between the two lines.
    """

    def insert_cost(self, to_insert, pre_idx, post_idx):
        # Flat surcharge of 3 when the new line starts with a space
        # (indentation); a line starting with a tab gets no surcharge.
        indent_surcharge = 3 if to_insert[:1] == " " else 0
        # One unit per character left after stripping surrounding whitespace,
        # plus one for the line itself.
        return len(to_insert.strip()) + indent_surcharge + 1

    def del_cost(self, to_delete, pre_idx, post_idx):
        # Removing a line has a flat price, whatever its content.
        return 2

    def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
        # Price of rewriting one line into the other in place.
        return InlineLevenshtein(subst_from, subst_to).compute()["count"]
|