""" Implements Multiway number partitioning greedy algorithm """ import typing as t from sortedcontainers import SortedList __all__ = ["TaskId", "partition"] TaskId = t.NewType("TaskId", int) class PartitionException(Exception): """An exception occurring during partitioning""" class UnsolvableConflict(PartitionException): """Cannot partition set due to unsolvable conflicts""" class Bin: """A bin containing assigned tasks""" elts: list[TaskId] cost: int def __init__(self): self.elts = [] self.cost = 0 def add(self, task: TaskId, cost: int): assert task not in self.elts self.elts.append(task) self.cost += cost def __contains__(self, task: TaskId) -> bool: return task in self.elts def partition( bin_count: int, costs: dict[TaskId, int], multiplicity: dict[TaskId, int] ) -> list[list[TaskId]]: """Partitions the tasks, each with cost `costs[i]`, into `bin_count` bins. Each task has multiplicity `multiplicity[i]`, copies of the same task being mutually exclusive (ie. cannot be in the same bin)""" bins = SortedList([Bin() for _ in range(bin_count)], key=lambda x: x.cost) ordered_tasks: list[TaskId] = [] for t_id, reps in multiplicity.items(): for _ in range(reps): ordered_tasks.append(t_id) ordered_tasks.sort(key=lambda x: costs[x], reverse=True) for task in ordered_tasks: least_full: Bin least_full_pos: int for pos, cur_bin in enumerate(bins): if task not in cur_bin: least_full = cur_bin least_full_pos = pos break else: raise UnsolvableConflict( "Pas assez de groupes pour affecter la tâche " + f"{task} {multiplicity[task]} fois." ) del bins[least_full_pos] least_full.add(task, costs[task]) bins.add(least_full) out: list[list[TaskId]] = [] for cur_bin in bins: out.append(cur_bin.elts) return out