# patch2vimedit/hunk_changes.py
"""Tools to compute the cheapest sequence of edits that applies a hunk.

The module provides a generic, memoized Levenshtein (edit-distance) solver
whose insertion/deletion/substitution costs are overridable hooks, plus two
specializations: one working on the items of a single line and one working
on the lines of a hunk.
"""

from collections import OrderedDict


class Levenshtein:
    """Edit distance between two sequences, with pluggable operation costs.

    ``pre`` is the sequence before the edit and ``post`` the sequence after
    it.  Both may be any indexable sequences with a length (strings, lists,
    tuples); their items are compared with ``!=``.

    Subclasses customize the metric by overriding :meth:`insert_cost`,
    :meth:`del_cost` and :meth:`subst_cost`.
    """

    # Memoized results of compute(), keyed by cache_key().  Every subclass
    # gets its own dict (see __init_subclass__): classes use different cost
    # models, so a result computed by one must never be served to another.
    cache = {}

    # Backtracking step, as (row delta, column delta) in the matrices, taken
    # after a Deletion, a Substitution or an Insertion.
    moves = {
        "D": (-1, 0),
        "S": (-1, -1),
        "I": (0, -1),
    }

    def __init_subclass__(cls, **kwargs):
        """Give each subclass a private result cache."""
        super().__init_subclass__(**kwargs)
        cls.cache = {}

    def __init__(self, pre, post):
        self.pre = pre
        self.post = post
        # Both matrices have (len(pre) + 1) rows and (len(post) + 1) columns
        # and are only populated by compute() on a cache miss.
        self.work_matrix = None  # cheapest cost to turn pre[:row] into post[:col]
        self.op_matrix = None  # operation ("D", "I" or "S") chosen for each cell
        # Operation chosen for the most recently filled cell; None until the
        # first inner cell has been filled.
        self.cur_mode = None

    @classmethod
    def get_full_cache(cls):
        """Return the dict memoizing the results computed by this class."""
        return cls.cache

    @staticmethod
    def _freeze(seq):
        """Return a hashable stand-in for `seq`: lists become tuples."""
        return tuple(seq) if isinstance(seq, list) else seq

    def cache_key(self):
        """Return the hashable key identifying the ``(pre, post)`` pair."""
        return (self._freeze(self.pre), self._freeze(self.post))

    def getcache(self):
        """Return the memoized result for this pair, or None if absent."""
        return self.get_full_cache().get(self.cache_key())

    @classmethod
    def clean_cache(cls):
        """Forget every memoized result of this class and of its subclasses."""
        cls.cache = {}
        for subclass in cls.__subclasses__():
            subclass.clean_cache()

    def insert_cost(self, to_insert, pre_idx, post_idx):
        """Cost of inserting `to_insert` when reaching cell (pre_idx, post_idx)."""
        return 1

    def del_cost(self, to_delete, pre_idx, post_idx):
        """Cost of deleting `to_delete` when reaching cell (pre_idx, post_idx)."""
        return 1

    def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
        """Cost of turning `subst_from` into `subst_to`: 0 if equal, else 1."""
        return int(subst_from != subst_to)

    @staticmethod
    def _argmin(**kwargs):
        """Return the first encountered pair (k, v) with lowest v wrt. `<`.

        Raises:
            ValueError: if no keyword argument is given.
        """
        if not kwargs:
            raise ValueError("No arguments")
        # min() keeps the first minimal item, so ties are resolved in favour
        # of the earliest keyword argument.
        return min(kwargs.items(), key=lambda item: item[1])

    def _fill_matrices(self):
        """Populate `work_matrix` and `op_matrix` for the whole (pre, post) grid."""
        pre, post = self.pre, self.post
        n_pre, n_post = len(pre), len(post)

        # First row can only be reached by insertions, first column only by
        # deletions; the origin has no operation.
        self.op_matrix = [[None] + ["I"] * n_post] + [
            ["D"] + [None] * n_post for _ in range(n_pre)
        ]
        self.work_matrix = [[0] * (n_post + 1) for _ in range(n_pre + 1)]
        work = self.work_matrix
        ops = self.op_matrix

        for col, item in enumerate(post, start=1):
            work[0][col] = work[0][col - 1] + self.insert_cost(item, 0, col)
        for row, item in enumerate(pre, start=1):
            work[row][0] = work[row - 1][0] + self.del_cost(item, row, 0)

        for row, pre_item in enumerate(pre, start=1):
            for col, post_item in enumerate(post, start=1):
                c_insert_cost = self.insert_cost(post_item, row, col)
                c_del_cost = self.del_cost(pre_item, row, col)
                c_subst_cost = self.subst_cost(pre_item, post_item, row, col)

                # The order below is the tie-breaking order: a deletion is
                # preferred over an insertion, itself over a substitution.
                opmin, costmin = self._argmin(
                    **OrderedDict(
                        [
                            ("D", work[row - 1][col] + c_del_cost),
                            ("I", work[row][col - 1] + c_insert_cost),
                            ("S", work[row - 1][col - 1] + c_subst_cost),
                        ]
                    )
                )

                work[row][col] = costmin
                ops[row][col] = opmin
                self.cur_mode = opmin

    def _backtrace(self):
        """Walk `op_matrix` back from the last cell; return ops in forward order."""
        ops = []
        row, col = len(self.pre), len(self.post)
        while row > 0 or col > 0:
            cur_op = self.op_matrix[row][col]
            row_m, col_m = self.moves[cur_op]
            prev_row, prev_col = row + row_m, col + col_m
            # Costs are re-evaluated with the indices of the cell being left,
            # i.e. the same indices that were used when the cell was filled.
            if cur_op == "S":
                pre_val = self.pre[prev_row]
                post_val = self.post[prev_col]
                cost = self.subst_cost(pre_val, post_val, row, col)
                ops.append((cur_op, (pre_val, post_val), cost))
            elif cur_op == "I":
                post_val = self.post[prev_col]
                ops.append((cur_op, post_val, self.insert_cost(post_val, row, col)))
            else:  # "D"
                pre_val = self.pre[prev_row]
                ops.append((cur_op, pre_val, self.del_cost(pre_val, row, col)))
            row, col = prev_row, prev_col
        ops.reverse()
        return ops

    def compute(self):
        """Compute (or fetch from the cache) the cheapest edit script.

        Returns:
            A dict with two keys:

            * ``"count"``: total cost of the cheapest edit script;
            * ``"ops"``: the script, in application order, as a list of
              ``(op, value, cost)`` tuples where ``op`` is ``"D"``, ``"I"``
              or ``"S"`` and ``value`` is the deleted item, the inserted
              item, or a ``(from, to)`` pair for a substitution.
              Substitutions of equal items are listed too, with their
              (usually zero) cost.

            The returned dict is the cached object itself: do not mutate it.
        """
        cached = self.getcache()
        if cached is not None:
            return cached

        self._fill_matrices()
        cached_result = {
            "count": self.work_matrix[-1][-1],
            "ops": self._backtrace(),
        }

        self.get_full_cache()[self.cache_key()] = cached_result

        return cached_result


class InlineLevenshtein(Levenshtein):
    """Levenshtein distance for the edition of a single line.

    Every insertion, deletion or effective substitution costs 1.
    """

    def insert_cost(self, to_insert, pre_idx, post_idx):
        """Cost of inserting one item: always 1."""
        return 1

    def del_cost(self, to_delete, pre_idx, post_idx):
        """Cost of deleting one item: always 1."""
        return 1  # 'x' -- presumably the single Vim keystroke deleting a char

    def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
        """Cost of a substitution: 0 when both items are equal, else 1."""
        if subst_from == subst_to:
            return 0
        return 1


class HunkLevenshtein(Levenshtein):
    """Levenshtein distance for hunk edition, ie. should we delete a whole
    line, insert a whole line or substitute one line with another.

    The items of ``pre`` and ``post`` are expected to be strings (lines).
    """

    def insert_cost(self, to_insert, pre_idx, post_idx):
        """Cost of inserting the line `to_insert`.

        A flat 3 if the line starts with at least one space, plus the length
        of the line stripped of surrounding whitespace, plus 1.
        """
        indent = len(to_insert) - len(to_insert.lstrip(" "))
        indent_cost = 3 if indent else 0
        return indent_cost + len(to_insert.strip()) + 1

    def del_cost(self, to_delete, pre_idx, post_idx):
        """Cost of deleting a whole line: always 2, whatever its content."""
        return 2

    def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
        """Cost of rewriting one line into another: their inline distance."""
        res = InlineLevenshtein(subst_from, subst_to).compute()
        return res["count"]