Simple Python script to distribute shell tasks over several machines.
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

235 lines
7.1 KiB

#!/usr/bin/env python3
import threading
import queue
import sys
import time
import socket
import random
import yaml
import paramiko
# Global settings: "username" is the SSH login used for every connection;
# "workload" is the fraction of each host's cores to occupy with workers.
CONFIG = {"username": "tbastian", "workload": 0.5}
class Task:
    """A task to be performed: a shell command, optionally with a local file
    streamed to its standard input and its standard output captured into a
    local file."""

    class Failure(Exception):
        """Raised if a Task failed to complete."""

        def __init__(self, msg=None):
            # Forward the message to Exception so str(exn) is meaningful
            # (the original left it empty and only set .msg).
            super().__init__(msg)
            self.msg = msg

    def __init__(self, command, in_file, out_file, human_label):
        # command: shell command line executed on the remote host
        # in_file: local path streamed to the command's stdin, or None
        # out_file: local path receiving the command's stdout, or None
        # human_label: short name used in progress/error reports
        self.command = command
        self.in_file = in_file
        self.out_file = out_file
        self.human_label = human_label

    def run(self, client):
        """Run the task on a given (already connected) SSH client.

        Raises Task.Failure on SSH errors or on a missing local file.
        """
        try:
            stdin, stdout, stderr = client.exec_command(self.command)
            if self.in_file:
                with open(self.in_file, "rb") as in_handle:
                    # Stream the input file in 1 KiB chunks.
                    buf = in_handle.read(1024)
                    while buf:
                        stdin.write(buf)
                        buf = in_handle.read(1024)
            # Close stdin unconditionally so a command that reads its
            # standard input sees EOF even when no in_file was given.
            stdin.close()
            if self.out_file:
                with open(self.out_file, "wb") as out_handle:
                    buf = stdout.read(1024)
                    while buf:
                        out_handle.write(buf)
                        buf = stdout.read(1024)
        except paramiko.ssh_exception.SSHException as exn:
            raise Task.Failure("SSH Exception") from exn
        except FileNotFoundError as exn:
            raise Task.Failure("{}".format(exn)) from exn
class WorkingThread(threading.Thread):
    """A thread pulling tasks from a shared queue and running them over SSH
    on a single machine."""

    def __init__(self, host, addr, workqueue, failures):
        # host: human-readable host name, used in log messages
        # addr: resolved address to connect to
        # workqueue: shared queue.Queue of Task objects
        # failures: shared list collecting (task, message) tuples
        self.host = host
        self.addr = addr
        self.client = None  # paramiko SSHClient, created in run()
        self.workqueue = workqueue
        self.failures = failures
        super().__init__()

    def run(self):
        self.client = paramiko.client.SSHClient()
        self.client.load_system_host_keys()
        # Try to connect a few times with a small random back-off.
        for n_try in range(3):
            try:
                self.client.connect(self.addr, username=CONFIG["username"])
                break
            except Exception as exn:
                delay = 3 + random.random() * 4
                print(
                    (
                        "[{}] Failed to connect. Retry in {:.02f} seconds."
                        + "Exception:\n{}"
                    ).format(self.host, delay, exn),
                    file=sys.stderr,
                )
                time.sleep(delay)
        else:
            # All retries exhausted; give up on this host.
            print(
                "[{}] Failed to connect, stopping thread.".format(self.host),
                file=sys.stderr,
            )
            return
        try:
            while True:
                task = self.workqueue.get_nowait()  # Raises `Empty` when done
                try:
                    task.run(self.client)
                except Task.Failure as exn:
                    print(
                        "[{}] ERROR: task failed: {} - {}".format(
                            self.host, task.human_label, exn.msg
                        ),
                        file=sys.stderr,
                    )
                    self.failures.append((task, exn.msg))
        except queue.Empty:
            pass
        finally:
            # Release the SSH connection (the original leaked it).
            self.client.close()
class HostsFile:
    """Interface to a hosts file: a YAML list of entries, each carrying at
    least a "host" name and a "cores" count."""

    REQUIRED_FIELDS = ("host", "cores")

    def __init__(self, path):
        self.path = path
        self.hosts = {}
        self._parse()

    def _parse(self):
        """Load the YAML file and populate self.hosts, resolving each host
        name to an address up front."""
        with open(self.path, "r") as handle:
            entries = yaml.safe_load(handle)
        for entry in entries:
            self._check_entry(entry)
            self.hosts[entry["host"]] = {
                "cores": entry["cores"],
                "addr": socket.gethostbyname(entry["host"]),
            }

    def _check_entry(self, entry):
        """Reject an entry that lacks a required field."""
        for field in self.REQUIRED_FIELDS:
            if field in entry:
                continue
            if "host" in entry:
                raise Exception(
                    "Host {} has no {}".format(entry["host"], field)
                )
            raise Exception("Host has no {}".format(field))
class TasksFile:
    """Interface to a tasks file: a YAML list of entries, each carrying a
    "command" and a list of "data" items (in/out/label)."""

    def __init__(self, path):
        self.path = path
        self.queue = queue.Queue()
        self._parse()
        # Snapshot of the queue size before any worker starts consuming;
        # used for progress reporting.
        self.initial_task_count = self.queue.qsize()

    def count_completed(self):
        """Number of tasks already taken off the queue."""
        return self.initial_task_count - self.queue.qsize()

    def percent_completed(self):
        """Completion percentage; 100.0 for an empty tasks file."""
        if self.initial_task_count == 0:
            # Guard against ZeroDivisionError on an empty tasks file.
            return 100.0
        return (self.count_completed() / self.initial_task_count) * 100

    def _parse(self):
        with open(self.path, "r") as handle:
            parsed = yaml.safe_load(handle)
        for cmd_entry in parsed:
            self._parse_cmd(cmd_entry)

    def _parse_cmd(self, cmd):
        """Validate one command entry and queue one Task per data item."""
        REQUIRED_FIELDS = ["command", "data"]
        for field in REQUIRED_FIELDS:
            if field not in cmd:
                if "command" in cmd:
                    raise Exception(
                        "Command {} has no {}".format(cmd["command"], field)
                    )
                raise Exception("Command has no {}".format(field))
        for data_entry in cmd["data"]:
            # The label falls back to the input file name for reporting.
            task = Task(
                cmd["command"],
                data_entry.get("in"),
                data_entry.get("out"),
                data_entry.get("label", data_entry.get("in")),
            )
            self.queue.put(task)
class Orchestrator:
    """Combines hosts, tasks and worker threads, and actually runs stuff."""

    def __init__(self, hosts_path="hosts.yml", tasks_path="tasks.yml"):
        # Paths are now parameters; the defaults keep the historical
        # hard-coded file names so existing callers are unaffected.
        self.hosts = HostsFile(hosts_path)
        self.tasks = TasksFile(tasks_path)
        self.failures = []  # (task, message) tuples filled in by workers
        self.threads = []
        self._spawn_host_threads()

    def _spawn_host_threads(self):
        """Create one WorkingThread per usable core (cores scaled by the
        CONFIG workload factor), never more threads than tasks."""
        for host, details in self.hosts.hosts.items():
            for _core_id in range(int(details["cores"] * CONFIG["workload"])):
                # No point spawning more workers than there are tasks.
                if len(self.threads) >= self.tasks.initial_task_count:
                    return
                self.threads.append(
                    WorkingThread(
                        host, details["addr"], self.tasks.queue, self.failures
                    )
                )

    def start(self):
        """Start all workers and print progress until they are all done."""
        for thread in self.threads:
            thread.start()
        for thread in self.threads:
            while thread.is_alive():
                thread.join(timeout=1.0)
                print(
                    "\r[{:05.1f}% - {}/{}]".format(
                        self.tasks.percent_completed(),
                        self.tasks.count_completed(),
                        self.tasks.initial_task_count,
                    ),
                    end="",
                )
if __name__ == "__main__":
    # Run every task, then summarize failures on stderr.
    orchestrator = Orchestrator()
    orchestrator.start()
    print("")
    failures = orchestrator.failures
    if failures:
        total = orchestrator.tasks.initial_task_count
        print(
            "{} FAILURES ({:05.1f}%)".format(
                len(failures), 100 * len(failures) / total
            ),
            file=sys.stderr,
        )
        for failed_task, reason in failures:
            print(
                "* {}: {}".format(failed_task.human_label, reason),
                file=sys.stderr,
            )