208 lines
6.2 KiB
Python
208 lines
6.2 KiB
Python
#!/usr/bin/env python3
|
|
|
|
import threading
|
|
import queue
|
|
import sys
|
|
import yaml
|
|
import paramiko
|
|
|
|
|
|
# Settings shared by every worker thread.
CONFIG = dict(
    username="tbastian",  # login name passed to every SSH connect()
    workload=0.5,  # multiplied by a host's core count to size its worker pool
)
|
|
|
|
|
|
class Task:
    """A task to be performed: a shell command run through an SSH client,
    optionally with a local file fed to its standard input and a local file
    receiving its standard output."""

    class Failure(Exception):
        """Raised if a Task failed to complete."""

        def __init__(self, msg=None):
            # Forward msg to Exception as well, so str(exn) and tracebacks
            # show it even when it is passed as a keyword argument.
            super().__init__(msg)
            self.msg = msg

    def __init__(self, command, in_file, out_file, human_label):
        """
        :param command: shell command executed on the remote host.
        :param in_file: local path streamed to the command's stdin, or None.
        :param out_file: local path receiving the command's stdout, or None
            to discard it.
        :param human_label: name used when reporting this task.
        """
        self.command = command
        self.in_file = in_file
        self.out_file = out_file
        self.human_label = human_label

    def run(self, client, chunk_size=1024):
        """Run the task on ``client`` and wait for the remote command to exit.

        :param client: connected SSH client (used through ``exec_command``).
        :param chunk_size: block size, in bytes, used to copy file data.
        :raises Task.Failure: on an SSH error, a local file error, or when the
            remote command exits with a non-zero status.
        """
        try:
            status, err_output = self._execute(client, chunk_size)
        except OSError as exn:
            # Any local I/O problem (missing input, unwritable output, ...)
            # is a task failure, not a reason to kill the worker thread.
            raise Task.Failure("{}".format(exn)) from exn
        except paramiko.ssh_exception.SSHException as exn:
            raise Task.Failure("SSH Exception") from exn

        if status != 0:
            err_lines = err_output.decode("utf-8", errors="replace").strip().splitlines()
            detail = ": " + err_lines[-1] if err_lines else ""
            raise Task.Failure("exit status {}{}".format(status, detail))

    def _execute(self, client, chunk_size):
        """Launch the command and stream the task's files through it.

        :returns: ``(exit_status, stderr_bytes)`` of the remote command.
        """
        stdin, stdout, stderr = client.exec_command(self.command)
        try:
            if self.in_file:
                with open(self.in_file, "rb") as in_handle:
                    self._pump(in_handle, stdin, chunk_size)
            stdin.close()
            # Send EOF explicitly: on older paramiko versions closing the
            # stdin file object does not, and a command reading its whole
            # input would never terminate.
            stdin.channel.shutdown_write()

            # Drain stdout even when it is discarded, so that we really wait
            # for the command and it cannot stall on unread output.
            if self.out_file:
                with open(self.out_file, "wb") as out_handle:
                    self._pump(stdout, out_handle, chunk_size)
            else:
                self._pump(stdout, None, chunk_size)

            # NOTE(review): all of stdin is sent before stdout is read, so a
            # command streaming large outputs while still consuming large
            # inputs may deadlock once the SSH flow-control windows fill up.
            err_output = stderr.read()
            return stdout.channel.recv_exit_status(), err_output
        finally:
            # Release the channel even on error, so the remote command is not
            # left waiting on a still-open stdin.
            stdout.channel.close()

    @staticmethod
    def _pump(source, sink, chunk_size):
        """Copy ``source`` into ``sink`` by ``chunk_size`` blocks until EOF.

        A ``sink`` of None means the data is read and discarded.
        """
        buf = source.read(chunk_size)
        while buf:
            if sink is not None:
                sink.write(buf)
            buf = source.read(chunk_size)
|
|
|
|
|
|
class WorkingThread(threading.Thread):
    """A thread actually getting work done on a given machine.

    It opens one SSH connection to ``host``, then runs tasks taken from the
    shared ``workqueue`` until the queue is empty, appending
    ``(task, message)`` pairs to ``failures`` for the tasks that failed.
    """

    def __init__(self, host, workqueue, failures):
        self.host = host
        self.client = None  # SSH client, created in run()
        self.workqueue = workqueue
        self.failures = failures
        super().__init__()

    def _log_error(self, message):
        """Print ``message`` on stderr, tagged with this thread's host."""
        print("[{}] ERROR: {}".format(self.host, message), file=sys.stderr)

    def run(self):
        self.client = paramiko.client.SSHClient()
        try:
            self.client.load_system_host_keys()
            self.client.connect(self.host, username=CONFIG["username"])
        except (paramiko.ssh_exception.SSHException, OSError) as exn:
            # Give up on this host only: its tasks stay queued for the
            # threads connected to other hosts.
            self._log_error("cannot connect: {}".format(exn))
            self.client.close()
            return

        try:
            self._process_queue()
        finally:
            # Always release the SSH connection, even if a task crashed.
            self.client.close()

    def _process_queue(self):
        """Run queued tasks on ``self.client`` until the queue is empty."""
        while True:
            try:
                task = self.workqueue.get_nowait()
            except queue.Empty:
                return

            try:
                task.run(self.client)
            except Task.Failure as exn:
                self._log_error(
                    "task failed: {} - {}".format(task.human_label, exn.msg)
                )
                self.failures.append((task, exn.msg))
            except Exception as exn:
                # Unexpected error: record the task instead of losing it
                # silently, then stop this worker (its connection may be
                # unusable) and let the traceback surface.
                msg = "unexpected error: {!r}".format(exn)
                self._log_error("task failed: {} - {}".format(task.human_label, msg))
                self.failures.append((task, msg))
                raise
|
|
|
|
|
|
class HostsFile:
    """Interface to a hosts file.

    The file is a YAML list of mappings, each with at least a ``host`` and a
    ``cores`` field. After parsing, ``self.hosts`` maps each host name to
    ``{"cores": <cores>}``.
    """

    def __init__(self, path):
        self.path = path
        self.hosts = {}
        self._parse()

    def _parse(self):
        with open(self.path, "r") as handle:
            parsed = yaml.safe_load(handle)
        self._add_entries(parsed)

    def _add_entries(self, entries):
        """Validate the loaded ``entries`` and record them in ``self.hosts``.

        :param entries: list of host mappings, or None for an empty file.
        :raises ValueError: if ``entries`` is not a list, or an entry is not a
            mapping or lacks a required field.
        """
        if entries is None:  # an empty YAML file loads as None
            return
        if not isinstance(entries, list):
            raise ValueError("Hosts file must contain a list of hosts")

        required_fields = ("host", "cores")
        for entry in entries:
            # Without this check a bare string entry would pass the
            # `field in entry` test by substring match.
            if not isinstance(entry, dict):
                raise ValueError("Invalid host entry: {!r}".format(entry))
            for field in required_fields:
                if field not in entry:
                    if "host" in entry:
                        raise ValueError(
                            "Host {} has no {}".format(entry["host"], field)
                        )
                    raise ValueError("Host has no {}".format(field))
            self.hosts[entry["host"]] = {"cores": entry["cores"]}
|
|
|
|
|
|
class TasksFile:
    """Interface to a tasks file.

    The file is a YAML list of commands. Each command has a ``command`` and a
    ``data`` list whose entries may carry ``in``, ``out`` and ``label`` keys;
    one Task per data entry is put on ``self.queue``.
    """

    def __init__(self, path):
        self.path = path
        self.queue = queue.Queue()
        self._parse()
        self.initial_task_count = self.queue.qsize()

    def count_completed(self):
        """Number of tasks taken off the queue so far.

        NOTE: this also counts tasks that are still running, so it runs
        slightly ahead of the number of tasks actually finished.
        """
        return self.initial_task_count - self.queue.qsize()

    def percent_completed(self):
        """``count_completed()`` as a percentage (100.0 when there are no tasks)."""
        if not self.initial_task_count:
            return 100.0
        return (self.count_completed() / self.initial_task_count) * 100

    def _parse(self):
        with open(self.path, "r") as handle:
            parsed = yaml.safe_load(handle)

        if parsed is None:  # an empty YAML file loads as None: no tasks
            return
        if not isinstance(parsed, list):
            raise ValueError("Tasks file must contain a list of commands")
        for cmd_entry in parsed:
            self._parse_cmd(cmd_entry)

    def _parse_cmd(self, cmd):
        """Validate one command entry and enqueue one Task per data entry.

        A task's label is its ``label`` key, else its ``in`` file, else the
        command itself.

        :raises ValueError: if the entry or one of its data entries is malformed.
        """
        if not isinstance(cmd, dict):
            raise ValueError("Invalid command entry: {!r}".format(cmd))
        for field in ("command", "data"):
            if field not in cmd:
                if "command" in cmd:
                    raise ValueError(
                        "Command {} has no {}".format(cmd["command"], field)
                    )
                raise ValueError("Command has no {}".format(field))
        if not isinstance(cmd["data"], list):
            raise ValueError("Command {} has invalid data".format(cmd["command"]))

        for data_entry in cmd["data"]:
            if not isinstance(data_entry, dict):
                raise ValueError(
                    "Invalid data entry for command {}: {!r}".format(
                        cmd["command"], data_entry
                    )
                )
            label = data_entry.get("label", data_entry.get("in"))
            if label is None:
                # Avoid reporting failures as "None".
                label = cmd["command"]
            task = Task(
                cmd["command"],
                data_entry.get("in"),
                data_entry.get("out"),
                label,
            )
            self.queue.put(task)
|
|
|
|
|
|
class Orchestrator:
    """Combines it all and actually runs stuff.

    Loads the hosts and tasks files, creates the worker threads, then runs
    them while showing a progress counter on stdout.
    """

    def __init__(self, hosts_path="hosts.yml", tasks_path="tasks.yml"):
        """
        :param hosts_path: path of the YAML hosts file.
        :param tasks_path: path of the YAML tasks file.
        """
        self.hosts = HostsFile(hosts_path)
        self.tasks = TasksFile(tasks_path)
        self.failures = []  # (task, message) pairs appended by the workers
        self.threads = []

        self._spawn_host_threads()

    def _spawn_host_threads(self):
        """Create (without starting) the worker threads.

        Each host gets ``int(cores * CONFIG["workload"])`` workers, at least
        one when that product is positive, and there are never more workers
        overall than tasks.
        """
        for host, host_details in self.hosts.hosts.items():
            wanted = host_details["cores"] * CONFIG["workload"]
            # int() truncates: without the floor of 1, a small host (e.g. one
            # core at workload 0.5) would silently get no worker at all.
            nb_workers = max(1, int(wanted)) if wanted > 0 else 0
            for _ in range(nb_workers):
                if len(self.threads) >= self.tasks.initial_task_count:
                    return
                self.threads.append(
                    WorkingThread(host, self.tasks.queue, self.failures)
                )

    def start(self):
        """Start every worker and wait for all of them, refreshing progress."""
        for thread in self.threads:
            thread.start()

        for thread in self.threads:
            while thread.is_alive():
                thread.join(timeout=1.0)
                self._print_progress()

        if self.threads:
            # Final refresh: workers may all have finished before any
            # in-loop print could show the last tasks.
            self._print_progress()

    def _print_progress(self):
        """Rewrite the current terminal line with the progress counter."""
        # No newline is printed, so flush explicitly: a line-buffered stdout
        # would otherwise hold the progress back until the end of the run.
        print(
            "\r[{:05.1f}% - {}/{}]".format(
                self.tasks.percent_completed(),
                self.tasks.count_completed(),
                self.tasks.initial_task_count,
            ),
            end="",
            flush=True,
        )
|
|
|
|
|
|
def main():
    """Run the tasks of tasks.yml on the hosts of hosts.yml.

    Failed tasks, and tasks that were never picked up, are reported on stderr.

    :returns: the process exit status: 0 if every task ran successfully,
        1 otherwise.
    """
    orchestrator = Orchestrator()
    orchestrator.start()
    print("")  # terminate the "\r"-refreshed progress line

    tasks = orchestrator.tasks
    if orchestrator.failures:
        nb_failures = len(orchestrator.failures)
        print(
            "{} FAILURES ({:05.1f}%)".format(
                nb_failures, 100 * nb_failures / tasks.initial_task_count
            ),
            file=sys.stderr,
        )
        for task, msg in orchestrator.failures:
            print("* {}: {}".format(task.human_label, msg), file=sys.stderr)

    # Tasks still queued were never picked up, e.g. because every worker
    # thread died or none could connect.
    not_run = tasks.initial_task_count - tasks.count_completed()
    if not_run:
        print("{} TASKS NOT RUN".format(not_run), file=sys.stderr)

    # Non-zero status so that calling scripts can detect an incomplete run.
    return 1 if orchestrator.failures or not_run else 0


if __name__ == "__main__":
    sys.exit(main())
|