You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
235 lines
7.1 KiB
235 lines
7.1 KiB
#!/usr/bin/env python3 |
|
|
|
import threading |
|
import queue |
|
import sys |
|
import time |
|
import socket |
|
import random |
|
import yaml |
|
import paramiko |
|
|
|
|
|
# Global settings: the user name for every SSH connection, and the fraction of
# each host's cores that receives a worker thread.
CONFIG = dict(username="tbastian", workload=0.5)
|
|
|
|
|
class Task:
    """ A task to be performed, consisting in a shell command, with some file
    redirected to its standard input and some file its standard output is
    redirected to. Both files are local paths, and either may be None. """

    class Failure(Exception):
        """ Raised if a Task failed to complete """

        def __init__(self, msg=None):
            # Also hand the message to `Exception`, so that `str(exn)` and
            # `exn.args` carry it, not only the `msg` attribute.
            super().__init__(msg)
            self.msg = msg

    # Size of the blocks copied between local files and the SSH streams.
    CHUNK_SIZE = 32 * 1024

    def __init__(self, command, in_file, out_file, human_label):
        self.command = command  # shell command executed on the remote host
        self.in_file = in_file  # local file fed to the command's stdin, or None
        self.out_file = out_file  # local file receiving its stdout, or None
        self.human_label = human_label  # name used when reporting the task

    def run(self, client):
        """ Run the task on a given client and wait for the command to exit.

        `client` is expected to be a connected `paramiko.client.SSHClient`.

        Raises `Task.Failure` on SSH errors, when a local file cannot be
        opened, read or written, or when the remote command does not exit
        with status 0. """
        try:
            if self.in_file:
                # Opened before starting the command: a missing input file
                # fails the task without running anything remotely.
                with open(self.in_file, "rb") as in_handle:
                    status = self._execute(client, in_handle)
            else:
                status = self._execute(client, None)
        except paramiko.ssh_exception.SSHException as exn:
            raise Task.Failure("SSH Exception") from exn
        except OSError as exn:  # includes FileNotFoundError and socket errors
            raise Task.Failure("{}".format(exn)) from exn

        if status != 0:
            # paramiko reports -1 when the server sent no exit status at all.
            raise Task.Failure(
                "remote command exited with status {}".format(status)
            )

    def _execute(self, client, in_handle):
        """ Start the command, stream its standard streams, and return its
        exit status. `in_handle` is a binary file fed to stdin, or None. """
        stdin, stdout, stderr = client.exec_command(self.command)
        channel = stdout.channel
        try:
            # stdin is fed and stderr drained in background threads while
            # stdout is read here. Handling the streams one after the other
            # deadlocks once the remote command blocks on a full SSH window
            # (e.g. a filter producing output before consuming all its input).
            feeder, feed_errors = self._in_background(
                self._feed_stdin, stdin, in_handle
            )
            drainer, drain_errors = self._in_background(
                self._copy, stderr, None, self.CHUNK_SIZE
            )
            if self.out_file:
                with open(self.out_file, "wb") as out_handle:
                    self._copy(stdout, out_handle, self.CHUNK_SIZE)
            else:
                # Still consume stdout, or the command may block writing it.
                self._copy(stdout, None, self.CHUNK_SIZE)
            feeder.join()
            drainer.join()
            for errors in (feed_errors, drain_errors):
                if errors:
                    raise errors[0]
            return channel.recv_exit_status()
        finally:
            # Also unblocks the background threads if something failed above.
            channel.close()

    def _feed_stdin(self, stdin, in_handle):
        """ Copy `in_handle` (if any) to the remote stdin, then send EOF. """
        try:
            if in_handle is not None:
                self._copy(in_handle, stdin, self.CHUNK_SIZE)
        finally:
            # Always signal EOF so that a command reading stdin cannot wait
            # forever. With some (older) paramiko versions closing the file
            # alone does not send EOF, hence the explicit `shutdown_write`.
            stdin.close()
            stdin.channel.shutdown_write()

    @staticmethod
    def _copy(src, dst, chunk_size):
        """ Copy `src` into `dst` until `src` is exhausted. When `dst` is None
        the data is read and discarded. """
        buf = src.read(chunk_size)
        while buf:
            if dst is not None:
                dst.write(buf)
            buf = src.read(chunk_size)

    @staticmethod
    def _in_background(func, *args):
        """ Run `func(*args)` in a daemon thread. Return the thread and a list
        that receives the exception `func` raised, if any, so the caller can
        re-raise it after `join()`. """
        errors = []

        def target():
            try:
                func(*args)
            except Exception as exn:  # re-raised by the caller
                errors.append(exn)

        thread = threading.Thread(target=target, daemon=True)
        thread.start()
        return thread, errors
|
|
|
|
|
class WorkingThread(threading.Thread):
    """ A thread actually getting work done on a given machine: it opens an
    SSH connection to `addr`, then runs tasks taken from `workqueue` until the
    queue is empty, appending failed tasks to `failures`. """

    # Number of connection attempts before the thread gives up.
    CONNECT_ATTEMPTS = 3

    def __init__(self, host, addr, workqueue, failures):
        self.host = host  # host name, used in log messages
        self.addr = addr  # address the SSH connection is made to
        self.client = None  # SSH client, created by `run`
        self.workqueue = workqueue  # queue.Queue of tasks, shared by workers
        self.failures = failures  # shared list of (task, message) pairs
        super().__init__()

    def run(self):
        self.client = paramiko.client.SSHClient()
        try:
            self.client.load_system_host_keys()
            if self._connect():
                self._work()
        finally:
            # Release the SSH connection once this worker is done.
            self.client.close()

    def _connect(self):
        """ Connect `self.client` to the host, retrying after a random delay.
        Return whether the connection succeeded. """
        for attempt in range(1, self.CONNECT_ATTEMPTS + 1):
            try:
                self.client.connect(self.addr, username=CONFIG["username"])
                return True
            except Exception as exn:  # any connection error is retried
                if attempt < self.CONNECT_ATTEMPTS:
                    delay = 3 + random.random() * 4
                    print(
                        (
                            "[{}] Failed to connect. Retry in {:.02f} seconds. "
                            + "Exception:\n{}"
                        ).format(self.host, delay, exn),
                        file=sys.stderr,
                    )
                    time.sleep(delay)
                else:
                    # Last attempt: no point in announcing or waiting a retry.
                    print(
                        "[{}] Failed to connect. Exception:\n{}".format(
                            self.host, exn
                        ),
                        file=sys.stderr,
                    )
        print(
            "[{}] Failed to connect, stopping thread.".format(self.host),
            file=sys.stderr,
        )
        return False

    def _work(self):
        """ Run queued tasks until the queue is empty, recording failures. """
        while True:
            try:
                task = self.workqueue.get_nowait()
            except queue.Empty:
                return  # no work left
            try:
                task.run(self.client)
            except Task.Failure as exn:
                self._record_failure(task, exn.msg)
            except Exception as exn:
                # Anything else leaves the connection in an unknown state:
                # record the task as failed instead of dropping it from the
                # failure report along with the thread, and stop this worker.
                self._record_failure(task, "unexpected error: {!r}".format(exn))
                return

    def _record_failure(self, task, msg):
        """ Log a failed task to stderr and add it to the shared failures. """
        print(
            "[{}] ERROR: task failed: {} - {}".format(
                self.host, task.human_label, msg
            ),
            file=sys.stderr,
        )
        self.failures.append((task, msg))
|
|
|
|
|
class HostsFile:
    """ Interface to a hosts file.

    The file is a YAML list of mappings, each with at least a `host` and a
    `cores` key. After parsing, `hosts` maps every host name to a dict with
    its `cores` value and `addr`, the IPv4 address the name resolves to. """

    def __init__(self, path):
        self.path = path  # path to the YAML hosts file
        self.hosts = {}  # host name -> {"cores": ..., "addr": ...}
        self._parse()

    def _parse(self):
        with open(self.path, "r") as handle:
            parsed = yaml.safe_load(handle)

        REQUIRED_FIELDS = ["host", "cores"]
        # An empty file loads as None: treat it like an empty list of hosts.
        for entry in parsed or []:
            # Without this check, a bare string entry would go through the
            # substring `in` tests below and fail with an obscure TypeError.
            if not isinstance(entry, dict):
                raise ValueError("Invalid host entry: {!r}".format(entry))
            for field in REQUIRED_FIELDS:
                if field not in entry:
                    if "host" in entry:
                        raise ValueError(
                            "Host {} has no {}".format(entry["host"], field)
                        )
                    raise ValueError("Host has no {}".format(field))
            self.hosts[entry["host"]] = {
                "cores": entry["cores"],
                "addr": socket.gethostbyname(entry["host"]),
            }
|
|
|
|
|
class TasksFile:
    """ Interface to a tasks file.

    The file is a YAML list of mappings, each with a `command` and a `data`
    list. Every `data` entry becomes one Task: its optional `in`/`out` keys
    name the local files bound to the command's stdin/stdout, and its
    optional `label` names the task in reports. All tasks go into `queue`. """

    REQUIRED_FIELDS = ("command", "data")

    def __init__(self, path):
        self.path = path  # path to the YAML tasks file
        self.queue = queue.Queue()  # tasks not yet taken by a worker
        self._parse()
        self.initial_task_count = self.queue.qsize()

    def count_completed(self):
        """ Number of tasks taken from the queue so far. Tasks still running
        are included: the queue only knows which tasks were handed out. """
        return self.initial_task_count - self.queue.qsize()

    def percent_completed(self):
        """ `count_completed` as a percentage of all tasks; 100 when there are
        no tasks at all, instead of dividing by zero. """
        if not self.initial_task_count:
            return 100.0
        return (self.count_completed() / self.initial_task_count) * 100

    def _parse(self):
        with open(self.path, "r") as handle:
            parsed = yaml.safe_load(handle)

        # An empty file loads as None: treat it like an empty list.
        for cmd_entry in parsed or []:
            self._parse_cmd(cmd_entry)

    def _parse_cmd(self, cmd):
        """ Validate one command entry and queue one Task per data entry. """
        # Without this check, a bare string would go through substring `in`
        # tests below instead of being reported as malformed.
        if not isinstance(cmd, dict):
            raise ValueError("Invalid command entry: {!r}".format(cmd))
        for field in self.REQUIRED_FIELDS:
            if field not in cmd:
                if "command" in cmd:
                    raise ValueError(
                        "Command {} has no {}".format(cmd["command"], field)
                    )
                raise ValueError("Command has no {}".format(field))

        for data_entry in cmd["data"]:
            if not isinstance(data_entry, dict):
                raise ValueError(
                    "Command {} has an invalid data entry: {!r}".format(
                        cmd["command"], data_entry
                    )
                )
            label = data_entry.get("label", data_entry.get("in"))
            if label is None:
                # Neither a label nor an input file: name the task after its
                # command, so that failure reports do not show "None".
                label = cmd["command"]
            self.queue.put(
                Task(
                    cmd["command"],
                    data_entry.get("in"),
                    data_entry.get("out"),
                    label,
                )
            )
|
|
|
|
|
class Orchestrator:
    """ Combines it all and actually runs stuff: loads the hosts and tasks
    files, creates the worker threads, then runs them while showing the
    progress on stdout. """

    def __init__(self, hosts_path="hosts.yml", tasks_path="tasks.yml"):
        self.hosts = HostsFile(hosts_path)
        self.tasks = TasksFile(tasks_path)
        self.failures = []  # (task, message) pairs, filled by the workers
        self.threads = []

        self._spawn_host_threads()

    def _spawn_host_threads(self):
        """ Create (without starting) the worker threads: for each host,
        `cores * CONFIG["workload"]` threads (truncated to an int), and never
        more threads in total than there are tasks. """
        for host, host_details in self.hosts.hosts.items():
            for _ in range(int(host_details["cores"] * CONFIG["workload"])):
                if len(self.threads) >= self.tasks.initial_task_count:
                    return
                self.threads.append(
                    WorkingThread(
                        host, host_details["addr"], self.tasks.queue, self.failures
                    )
                )

    def start(self):
        """ Start every worker and wait for all of them, refreshing the
        progress line after each (at most one second long) join. """
        for thread in self.threads:
            thread.start()

        for thread in self.threads:
            while thread.is_alive():
                # A join with a timeout gives a chance to refresh the progress.
                thread.join(timeout=1.0)
                self._print_progress()

    def _print_progress(self):
        """ Overwrite the current terminal line with the progress so far. """
        print(
            "\r[{:05.1f}% - {}/{}]".format(
                self.tasks.percent_completed(),
                self.tasks.count_completed(),
                self.tasks.initial_task_count,
            ),
            end="",
            # No newline is printed, so flush explicitly: a line-buffered
            # stdout would otherwise hold the progress line back.
            flush=True,
        )
|
|
|
|
|
if __name__ == "__main__":
    orchestrator = Orchestrator()
    orchestrator.start()
    print("")

    # Once every worker has stopped, tasks still queued were never run (for
    # instance because no worker could connect to its host): report them
    # instead of ending as if everything had been processed.
    never_run = []
    while True:
        try:
            never_run.append(orchestrator.tasks.queue.get_nowait())
        except queue.Empty:
            break
    if never_run:
        print("{} TASKS NEVER RUN".format(len(never_run)), file=sys.stderr)
        for task in never_run:
            print("* {}".format(task.human_label), file=sys.stderr)

    if orchestrator.failures:
        nb_failures = len(orchestrator.failures)
        print(
            "{} FAILURES ({:05.1f}%)".format(
                nb_failures, 100 * nb_failures / orchestrator.tasks.initial_task_count
            ),
            file=sys.stderr,
        )
        for task, msg in orchestrator.failures:
            print("* {}: {}".format(task.human_label, msg), file=sys.stderr)

    # Non-zero exit status when anything went wrong, for use from scripts.
    if never_run or orchestrator.failures:
        sys.exit(1)
|
|
|