#!/usr/bin/env python3
"""
Run a batch of shell commands on several machines through SSH.

Hosts are described in a YAML hosts file and tasks in a YAML tasks file
(`hosts.yml` and `tasks.yml` by default). For every host, a number of worker
threads derived from its core count is spawned; each worker opens its own SSH
connection and pulls tasks from a shared queue until the queue is empty.
"""

import queue
import random
import socket
import sys
import threading
import time

import paramiko
import yaml

# "username": login used for every SSH connection.
# "workload": fraction of each host's cores that gets a worker thread.
CONFIG = {"username": "tbastian", "workload": 0.5}


class Task:
    """
    A task to be performed, consisting in a shell command, with some file
    redirected to its standard input and some file its standard output is
    redirected to.
    """

    # Number of bytes moved per read/write when streaming files.
    CHUNK_SIZE = 1024

    class Failure(Exception):
        """ Raised if a Task failed to complete """

        def __init__(self, msg=None):
            self.msg = msg

    def __init__(self, command, in_file, out_file, human_label):
        self.command = command
        self.in_file = in_file
        self.out_file = out_file
        self.human_label = human_label

    @staticmethod
    def _copy(source, sink, chunk_size=CHUNK_SIZE):
        """
        Read `source` chunk by chunk until it is exhausted, writing every
        chunk to `sink`. When `sink` is None the data is read and discarded.
        """
        while True:
            chunk = source.read(chunk_size)
            if not chunk:
                break
            if sink is not None:
                sink.write(chunk)

    def run(self, client):
        """
        Run the task on a given client and wait for the command to finish.

        `client` must provide `exec_command` the way a connected
        `paramiko.client.SSHClient` does. The content of `in_file` (if any) is
        sent to the command's standard input, and its standard output is
        written to `out_file` (if any).

        Raises `Task.Failure` if an SSH error occurs, if a local file cannot
        be read or written, or if the command exits with a non-zero status.
        """
        # NOTE(review): stdin is fully written before stdout is read, and
        # stderr is never read; a command producing a lot of output before
        # consuming its input (or a lot of stderr) could presumably stall.
        try:
            stdin, stdout, _stderr = client.exec_command(self.command)
            try:
                if self.in_file:
                    with open(self.in_file, "rb") as in_handle:
                        self._copy(in_handle, stdin)
            finally:
                # Always signal end-of-input, even without an input file or
                # after a local error, so the remote command is not left
                # waiting on its standard input.
                stdin.close()

            if self.out_file:
                with open(self.out_file, "wb") as out_handle:
                    self._copy(stdout, out_handle)
            else:
                # Drain the output anyway: this blocks until the command has
                # closed its stdout instead of returning while it still runs.
                self._copy(stdout, None)

            # stdout has reached EOF at this point, so waiting for the exit
            # status cannot block on unread output.
            exit_status = stdout.channel.recv_exit_status()
        except paramiko.ssh_exception.SSHException as exn:
            raise Task.Failure("SSH Exception") from exn
        except OSError as exn:
            # Covers FileNotFoundError as well as permission and other
            # local I/O errors, which would otherwise kill the worker thread.
            raise Task.Failure("{}".format(exn)) from exn

        if exit_status != 0:
            raise Task.Failure(
                "Command exited with status {}".format(exit_status)
            )


class WorkingThread(threading.Thread):
    """ A thread actually getting work done on a given machine """

    # Number of connection attempts before the thread gives up.
    CONNECT_TRIES = 3

    def __init__(self, host, addr, workqueue, failures):
        """
        `host` is the label used in log messages, `addr` the address the SSH
        client connects to, `workqueue` the queue tasks are taken from and
        `failures` the list `(task, message)` pairs are appended to.
        """
        self.host = host
        self.addr = addr
        self.client = None
        self.workqueue = workqueue
        self.failures = failures
        super().__init__()

    def _connect(self):
        """
        Try to connect the SSH client, waiting a random 3 to 7 seconds
        between attempts. Returns True on success, False once every attempt
        has failed.
        """
        for n_try in range(1, self.CONNECT_TRIES + 1):
            try:
                self.client.connect(self.addr, username=CONFIG["username"])
            except Exception as exn:  # any connection error triggers a retry
                if n_try == self.CONNECT_TRIES:
                    # Last attempt: do not announce (or wait for) a retry
                    # that will never happen.
                    print(
                        (
                            "[{}] Failed to connect, stopping thread. "
                            + "Exception:\n{}"
                        ).format(self.host, exn),
                        file=sys.stderr,
                    )
                    return False
                delay = 3 + random.random() * 4
                print(
                    (
                        "[{}] Failed to connect. Retry in {:.02f} seconds. "
                        + "Exception:\n{}"
                    ).format(self.host, delay, exn),
                    file=sys.stderr,
                )
                time.sleep(delay)
            else:
                return True
        return False

    def _process_queue(self):
        """
        Run tasks from the work queue until it is empty, recording every
        failed task in `self.failures`.
        """
        while True:
            try:
                task = self.workqueue.get_nowait()
            except queue.Empty:
                return  # No work left
            try:
                task.run(self.client)
            except Task.Failure as exn:
                print(
                    "[{}] ERROR: task failed: {} - {}".format(
                        self.host, task.human_label, exn.msg
                    ),
                    file=sys.stderr,
                )
                self.failures.append((task, exn.msg))

    def run(self):
        """ Connect to the host, then process tasks until none is left """
        self.client = paramiko.client.SSHClient()
        try:
            self.client.load_system_host_keys()
            if self._connect():
                self._process_queue()
        finally:
            # Release the connection whatever happened above.
            self.client.close()


class HostsFile:
    """
    Interface to an hosts file

    The file is a YAML list of entries, each with a `host` and a `cores`
    field. After parsing, `self.hosts` maps every host name to a dict holding
    its `cores` value and its resolved address (`addr`).
    """

    REQUIRED_FIELDS = ("host", "cores")

    def __init__(self, path):
        self.path = path
        self.hosts = {}
        self._parse()

    def _parse(self):
        """
        Load the hosts file into `self.hosts`.

        Raises `ValueError` if an entry lacks a required field.
        """
        with open(self.path, "r") as handle:
            # An empty YAML document loads as None: treat it as "no hosts".
            parsed = yaml.safe_load(handle) or []

        for entry in parsed:
            for field in self.REQUIRED_FIELDS:
                if field in entry:
                    continue
                if "host" in entry:
                    raise ValueError(
                        "Host {} has no {}".format(entry["host"], field)
                    )
                raise ValueError("Host has no {}".format(field))

            self.hosts[entry["host"]] = {
                "cores": entry["cores"],
                "addr": socket.gethostbyname(entry["host"]),
            }


class TasksFile:
    """
    Interface to a tasks file

    The file is a YAML list of entries, each with a `command` and a `data`
    list. Every `data` item yields one `Task` running `command`, using the
    item's optional `in`, `out` and `label` fields. Tasks are stored in
    `self.queue`.
    """

    REQUIRED_FIELDS = ("command", "data")

    def __init__(self, path):
        self.path = path
        self.queue = queue.Queue()
        self._parse()
        self.initial_task_count = self.queue.qsize()

    def count_completed(self):
        """
        Number of tasks taken out of the queue so far. A task is counted as
        soon as a worker picks it up, not when it finishes.
        """
        return self.initial_task_count - self.queue.qsize()

    def percent_completed(self):
        """ `count_completed` as a percentage of the initial task count """
        if not self.initial_task_count:
            # Nothing to do at all: report completion rather than dividing
            # by zero.
            return 100.0
        return (self.count_completed() / self.initial_task_count) * 100

    def _parse(self):
        """ Load the tasks file and enqueue every task it describes """
        with open(self.path, "r") as handle:
            # An empty YAML document loads as None: treat it as "no tasks".
            parsed = yaml.safe_load(handle) or []

        for cmd_entry in parsed:
            self._parse_cmd(cmd_entry)

    def _parse_cmd(self, cmd):
        """
        Enqueue one task per `data` item of the command entry `cmd`.

        Raises `ValueError` if `cmd` lacks a required field.
        """
        for field in self.REQUIRED_FIELDS:
            if field in cmd:
                continue
            if "command" in cmd:
                raise ValueError(
                    "Command {} has no {}".format(cmd["command"], field)
                )
            raise ValueError("Command has no {}".format(field))

        # A `data:` key with no value loads as None: no task for this entry.
        for data_entry in cmd["data"] or []:
            self.queue.put(
                Task(
                    cmd["command"],
                    data_entry.get("in"),
                    data_entry.get("out"),
                    # Tasks without a label are named after their input file
                    data_entry.get("label", data_entry.get("in")),
                )
            )


class Orchestrator:
    """ Combines it all and actually runs stuff """

    def __init__(self, hosts_path="hosts.yml", tasks_path="tasks.yml"):
        """
        Parse the hosts and tasks files and create (but do not start) the
        worker threads.
        """
        self.hosts = HostsFile(hosts_path)
        self.tasks = TasksFile(tasks_path)
        self.failures = []
        self.threads = []

        self._spawn_host_threads()

    def _spawn_host_threads(self):
        """
        Create `int(cores * CONFIG["workload"])` worker threads per host,
        never creating more threads than there are tasks.
        """
        for host, details in self.hosts.hosts.items():
            n_workers = int(details["cores"] * CONFIG["workload"])
            for _ in range(n_workers):
                if len(self.threads) >= self.tasks.initial_task_count:
                    return
                self.threads.append(
                    WorkingThread(
                        host, details["addr"], self.tasks.queue, self.failures
                    )
                )

    def _print_progress(self):
        """ Rewrite the progress indicator on the current terminal line """
        print(
            "\r[{:05.1f}% - {}/{}]".format(
                self.tasks.percent_completed(),
                self.tasks.count_completed(),
                self.tasks.initial_task_count,
            ),
            end="",
            # No newline is printed, so flush explicitly or the indicator
            # may stay stuck in the output buffer.
            flush=True,
        )

    def start(self):
        """
        Start every worker thread and block until all of them are done,
        refreshing the progress indicator about once per second.
        """
        for thread in self.threads:
            thread.start()

        for thread in self.threads:
            while thread.is_alive():
                thread.join(timeout=1.0)
                self._print_progress()


def main():
    """ Run every task, then print a summary of what went wrong on stderr """
    orchestrator = Orchestrator()
    orchestrator.start()
    print("")

    if orchestrator.failures:
        nb_failures = len(orchestrator.failures)
        print(
            "{} FAILURES ({:05.1f}%)".format(
                nb_failures,
                100 * nb_failures / orchestrator.tasks.initial_task_count,
            ),
            file=sys.stderr,
        )
        for task, msg in orchestrator.failures:
            print("* {}: {}".format(task.human_label, msg), file=sys.stderr)

    # Tasks still queued were never picked up by a worker (for instance when
    # no worker managed to connect): do not let that go unnoticed.
    nb_left = orchestrator.tasks.queue.qsize()
    if nb_left:
        print("{} TASKS NEVER STARTED".format(nb_left), file=sys.stderr)


if __name__ == "__main__":
    main()