Compare commits

...

8 commits

4 changed files with 492 additions and 21 deletions

View file

@ -0,0 +1,236 @@
""" Tools to make smart edition to apply a hunk """
from collections import OrderedDict
import logging
logger = logging.getLogger(__name__)
class Levenshtein:
cache = {}
moves = {
"D": (-1, 0),
"S": (-1, -1),
"I": (0, -1),
}
auto_coalesce = False
def __init__(self, pre, post):
self.pre = pre
self.post = post
self.work_matrix = None
self.op_matrix = None
@classmethod
def get_full_cache(cls):
return cls.cache
def cache_key(self):
if isinstance(self.pre, list):
return (tuple(self.pre), tuple(self.post))
return (self.pre, self.post)
def getcache(self):
return self.get_full_cache().get(self.cache_key(), None)
@classmethod
def clean_cache(cls):
cls.cache = {}
def insert_cost(self, to_insert, pre_idx, post_idx):
return 1
def del_cost(self, to_delete, pre_idx, post_idx):
return 1
def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
return int(subst_from != subst_to)
@staticmethod
def _argmin(**kwargs):
""" Returns the first encountered pair (k, v) with lowest v wrt. `<`. """
if not kwargs:
raise Exception("No arguments")
argmin = next(iter(kwargs))
valmin = kwargs[argmin]
for key in kwargs:
val = kwargs[key]
if val < valmin:
argmin = key
valmin = val
return (argmin, valmin)
def compute(self):
cached = self.getcache()
if cached:
return cached
self.op_matrix = [[None] + ["I"] * len(self.post)] + [
["D"] + [None] * (len(self.post)) for _ in range(len(self.pre))
]
self.work_matrix = [
[0] * (len(self.post) + 1) for _ in range(len(self.pre) + 1)
]
for post_idx in range(len(self.post)):
self.work_matrix[0][post_idx + 1] = self.work_matrix[0][
post_idx
] + self.insert_cost(self.post[post_idx], 0, post_idx + 1)
for pre_idx in range(len(self.pre)):
self.work_matrix[pre_idx + 1][0] = self.work_matrix[pre_idx][
0
] + self.del_cost(self.pre[pre_idx], pre_idx + 1, 0)
for pre_idx in range(len(self.pre)):
for post_idx in range(len(self.post)):
c_insert_cost = self.insert_cost(
self.post[post_idx], pre_idx + 1, post_idx + 1
)
c_del_cost = self.del_cost(self.pre[pre_idx], pre_idx + 1, post_idx + 1)
c_subst_cost = self.subst_cost(
self.pre[pre_idx], self.post[post_idx], pre_idx + 1, post_idx + 1
)
opmin, costmin = self._argmin(
**OrderedDict(
[
("D", self.work_matrix[pre_idx][post_idx + 1] + c_del_cost),
(
"I",
self.work_matrix[pre_idx + 1][post_idx] + c_insert_cost,
),
("S", self.work_matrix[pre_idx][post_idx] + c_subst_cost),
]
)
)
self.work_matrix[pre_idx + 1][post_idx + 1] = costmin
self.op_matrix[pre_idx + 1][post_idx + 1] = opmin
self.cur_mode = opmin
ops = []
pre_pos = len(self.pre)
post_pos = len(self.post)
while pre_pos > 0 or post_pos > 0:
cur_op = self.op_matrix[pre_pos][post_pos]
cost_pre = pre_pos
cost_post = post_pos
row_m, col_m = self.moves[cur_op]
pre_pos += row_m
post_pos += col_m
if cur_op == "S":
pre_val = self.pre[pre_pos]
post_val = self.post[post_pos]
if pre_val == post_val:
cur_op = "L"
ops.append(
(
cur_op,
(pre_pos, post_pos),
(pre_val, post_val),
self.subst_cost(pre_val, post_val, cost_pre, cost_post),
)
)
if cur_op == "I":
post_val = self.post[post_pos]
ops.append(
(
cur_op,
post_pos,
post_val,
self.insert_cost(post_val, cost_pre, cost_post),
)
)
if cur_op == "D":
pre_val = self.pre[pre_pos]
ops.append(
(
cur_op,
pre_pos,
pre_val,
self.del_cost(pre_val, cost_pre, cost_post),
)
)
ops = ops[::-1]
if self.auto_coalesce:
ops = self.coalesce_ops(ops)
cached_result = {
"count": self.work_matrix[-1][-1],
"ops": ops,
}
self.get_full_cache()[self.cache_key()] = cached_result
return cached_result
@staticmethod
def coalesce_ops(in_ops):
out_ops = []
coal_op = None
coal_pos = None
coal_span = 0
coal_vals = None
coal_cost = 0
for (op, pos, val, cost) in in_ops:
if op == coal_op:
coal_span += 1
if op == "S":
pre, post = val
coal_pre, coal_post = coal_vals
coal_vals = (coal_pre + pre, coal_post + post)
else:
coal_vals += val
coal_cost += cost
else:
out_ops.append((coal_op, (coal_pos, coal_span), coal_vals, coal_cost))
coal_op = op
coal_pos = pos
coal_span = 1
coal_vals = val
coal_cost = cost
out_ops.append((coal_op, (coal_pos, coal_span), coal_vals, coal_cost))
return out_ops
class InlineLevenshtein(Levenshtein):
""" Levenshtein distance for edition of a single line """
auto_coalesce = True
def insert_cost(self, to_insert, pre_idx, post_idx):
return 1
def del_cost(self, to_delete, pre_idx, post_idx):
return 1 # 'x'
def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
if subst_from == subst_to:
return 0
return 1
class HunkLevenshtein(Levenshtein):
""" Levenshtein distance for hunk edition, ie. should we delete a whole line,
insert a whole line or substitute one line with another """
def insert_cost(self, to_insert, pre_idx, post_idx):
indent = 0
while indent < len(to_insert) and to_insert[indent] == " ":
indent += 1
indent_cost = 3 if indent else 0
return indent_cost + len(to_insert.strip()) + 1
def del_cost(self, to_delete, pre_idx, post_idx):
return 2
def subst_cost(self, subst_from, subst_to, pre_idx, post_idx):
res = InlineLevenshtein(subst_from, subst_to).compute()
return res["count"] * 1.5

View file

@ -8,11 +8,21 @@ import tempfile
import patch
from tmux import TmuxSession
from vim_session import VimSession
import logging
logger = logging.getLogger(__name__)
def parse_args():
""" Parse command-line arguments """
parser = argparse.ArgumentParser(prog="patch2vimedit")
parser.add_argument(
"-g",
"--debug",
action="store_true",
help="Enable debug logging to /tmp/patch2log",
)
parser.add_argument(
"-C",
"--directory",
@ -46,7 +56,19 @@ def configure_home(home):
tmux_conf.write("set -g status off\n")
with vim_conf_path.open("w") as vim_conf:
vim_conf.write(
"syntax on\nset bg=dark\nset number\nset ts=4\nset sw=4\nset et\n"
"""
syntax on
set bg=dark
set number
set ts=4
set sw=4
set et
set so=5
set noautoindent
set nosmartindent
set nocindent
set indentexpr&
"""
)
return tmux_conf_path, vim_conf_path
@ -73,10 +95,22 @@ def apply_patchset(patchset):
tmux_session.session.kill_session()
def configure_log(args):
if args.debug:
logging.basicConfig(filename="/tmp/patch2log", level=logging.DEBUG)
else:
logging.basicConfig(level=logging.CRITICAL)
logging.getLogger("libtmux").setLevel(logging.INFO)
logging.getLogger("tmux").setLevel(logging.INFO)
def main():
""" Entry-point function """
args = parse_args()
configure_log(args)
patchset = get_patch(args.patch)
if args.directory:

View file

@ -1,9 +1,11 @@
import subprocess
import multiprocessing
import sys
import random
import libtmux
import time
import logging
import libtmux
logger = logging.getLogger(__name__)
class TmuxSession:
@ -15,10 +17,13 @@ class TmuxSession:
def __init__(self, session):
if self.tmux_server is None:
raise Exception("Server not initialized")
self.session = session
self.session_id = session.id
self.keyboard_speed = 0.00001
self.sendkey_log_buffer = []
@classmethod
def initialize_server(cls, socket_name=None, config_file=None):
""" Initialize the tmux server """
@ -67,6 +72,14 @@ class TmuxSession:
# Weirdly, `tmux send-keys 'blah;'` doesn't send the semicolon; and so
# does `tmux send-keys ';'`. We must escape it with a backslash.
arg = arg[:-1] + r"\;"
if arg in ["escape", "enter"]:
logline = "".join(self.sendkey_log_buffer) + " " + arg
logger.debug(logline)
self.sendkey_log_buffer = []
else:
self.sendkey_log_buffer.append(arg)
self.session.attached_pane.send_keys(
arg, suppress_history=False, enter=False
)

View file

@ -1,9 +1,40 @@
import sys
import logging
from hunk_changes import InlineLevenshtein, HunkLevenshtein
logger = logging.getLogger(__name__)
class LineMovement:
""" A movement to a given line, absolute or relative """
def __init__(self, absolute=None, relative=None):
self.absolute = absolute
self.relative = relative
if self.absolute and self.relative:
raise Exception("Cannot move both absolutely and relatively")
def __add__(self, num):
if not isinstance(num, type(0)):
raise Exception("Can only add an integer")
if self.absolute is not None:
return self.__class__(absolute=self.absolute + num)
if self.relative is not None:
return self.__class__(relative=self.relative + num)
def do(self, tmux_session):
if self.relative:
tmux_session.type_keys("escape", "{}j".format(self.relative))
elif self.absolute:
tmux_session.type_keys("escape", "{}G".format(self.absolute))
self.relative = 0
self.absolute = None
class VimSession:
""" A Vim session instrumented through tmux """
debug = False
def __init__(self, tmux_session, file_path=None):
self.tmux_session = tmux_session
self.file_path = file_path
@ -13,10 +44,65 @@ class VimSession:
self.tmux_session.type_keys(" {}".format(self.file_path))
self.tmux_session.type_keys("enter")
self.tmux_session.send_keys(":set paste", "enter")
self.mode = "command"
self.log_open = False
def log(self, msg):
if not self.debug:
return
self.set_mode("command")
if not self.log_open:
self.tmux_session.send_keys(":new", "enter", "C-w", "j")
self.log_open = True
self.tmux_session.send_keys("C-w", "k")
self.tmux_session.send_keys("o", msg.rstrip(), "escape")
self.tmux_session.send_keys("C-w", "j")
def set_mode(self, new_mode, dry_run=False, ofs_balance=True):
""" Sets Vim to mode `new_mode`. Returns the resulting cursor movement, in
column offset. If `dry_run`, do not actually change the mode, just compute the
offset. If `ofs_balance`, balance the induced offset with appropriate cursor
movement. """
if new_mode == self.mode:
return 0
if dry_run and ofs_balance:
return 0
ofs = 0
if self.mode != "command":
ofs -= 1
if not dry_run:
self.tmux_session.type_keys("escape")
self.mode = "command"
if ofs_balance and ofs: # implies `not dry_run`
if ofs < 0:
self.tmux_session.type_keys("{}l".format(-ofs))
else:
self.tmux_session.type_keys("{}h".format(ofs))
ofs = 0
if new_mode == "insert":
if not dry_run:
self.tmux_session.type_keys("i")
self.mode = "insert"
elif new_mode == "replace":
if not dry_run:
self.tmux_session.type_keys("R")
self.mode = "replace"
return ofs
def edit_file(self, file_path):
self.tmux_session.type_keys("escape", ":e ", file_path, "enter")
self.set_mode("command")
self.tmux_session.type_keys(":e ", file_path, "enter")
self.file_path = file_path
def apply_patchset(self, patchset):
@ -27,28 +113,130 @@ class VimSession:
source = patch.source.decode("utf8")
target = patch.target.decode("utf8")
if source != target:
self.tmux_session.type_keys(
"escape", ":!mv ", source, " ", target, "enter",
)
self.set_mode("command")
self.tmux_session.type_keys(":!mv ", source, " ", target, "enter")
if self.file_path != target:
self.edit_file(target)
for hunk in patch:
self.apply_hunk(hunk)
self.tmux_session.type_keys("escape", ":w", "enter")
self.set_mode("command")
self.tmux_session.type_keys(":w", "enter")
@staticmethod
def tabify(text):
""" Substitute groups of four spaces by a tabulation, as much as possible. """
return text.replace(" ", "\t")
def write_line(self, line):
""" Write a line to the vim buffer, assuming everything is set up for it and it
must be insterted above. """
line = line.rstrip()
if line.startswith(" "):
lead_spaces = 0
while lead_spaces < len(line) and line[lead_spaces] == " ":
lead_spaces += 1
line = line.strip()
self.set_mode("command")
self.tmux_session.type_keys(
"O", "escape", "{}a ".format(lead_spaces), "escape", "A", line, "escape"
)
else:
self.tmux_session.type_keys("O", line, "escape")
def subst_line(self, pre, post):
""" Substitute the current line of the vim buffer, assuming everything is set
up for it """
line_levenshtein = InlineLevenshtein(pre, post).compute()
ops = line_levenshtein["ops"]
edit_pos = 1
rel_pos = 0
for op, (_, span), values, _ in ops:
if op == "L":
edit_pos += span
rel_pos += span
else:
if rel_pos > 0:
self.set_mode("command")
self.tmux_session.type_keys("{}|".format(edit_pos))
rel_pos = 0
if op == "I":
self.set_mode("insert")
self.tmux_session.type_keys(self.tabify(values))
edit_pos += span
elif op == "D":
self.set_mode("command")
self.tmux_session.type_keys("x")
elif op == "S":
self.set_mode("replace")
self.tmux_session.type_keys(self.tabify(values[1]))
edit_pos += span
self.set_mode("command")
def apply_hunk(self, hunk):
# So far, very naive.
self.tmux_session.type_keys("escape", "{}G0".format(hunk.starttgt))
logger.debug("Applying hunk @{}/{}".format(hunk.startsrc, hunk.starttgt))
pre_lines = []
post_lines = []
cur_subhunk_line = hunk.starttgt
cur_target_line = hunk.starttgt
for b_line in hunk.text:
line = b_line.decode("utf8")
if line[0] == " ":
self.tmux_session.type_keys("j")
elif line[0] == "-":
self.tmux_session.type_keys("dd")
elif line[0] == "+":
self.tmux_session.type_keys("O", line.strip()[1:], "escape", "j")
u_line = b_line.decode("utf8")
op = u_line[0]
line = u_line[1:]
if op == "+":
post_lines.append(line)
cur_target_line += 1
elif op == "-":
pre_lines.append(line)
elif op == " ":
if pre_lines or post_lines:
logger.debug(
"\tApplying subhunk @{} span {}/{}".format(
cur_subhunk_line, len(pre_lines), len(post_lines)
)
)
self.apply_subhunk(pre_lines, post_lines, cur_subhunk_line)
cur_target_line += 1
cur_subhunk_line = cur_target_line
pre_lines = []
post_lines = []
if pre_lines or post_lines:
self.apply_subhunk(pre_lines, post_lines, cur_subhunk_line)
def apply_subhunk(self, pre_lines, post_lines, startline_target):
hunk_levenshtein = HunkLevenshtein(pre_lines, post_lines).compute()
line_ops = hunk_levenshtein["ops"]
line_mvt = LineMovement(absolute=startline_target)
for op, positions, values, cost in line_ops:
if op == "L":
self.log("LEAVE {} -- L{}".format(values[0].strip(), line_mvt.absolute))
line_mvt += 1
else:
line_mvt.do(self.tmux_session)
if op == "I":
self.log("INSERT {}".format(values.strip()))
self.write_line(values)
line_mvt += 1
elif op == "D":
self.log("DELETE {}".format(values.strip()))
self.tmux_session.type_keys("dd")
elif op == "S":
self.log("SUBST {}/{}".format(values[0].strip(), values[1].strip()))
self.subst_line(values[0].rstrip(), values[1].rstrip())
line_mvt += 1
def quit(self):
self.tmux_session.type_keys("escape", ":q", "enter")
self.set_mode("command")
self.tmux_session.type_keys(":qa!", "enter")