Implement homemade partitioning to handle conflicts

This commit is contained in:
Théophile Bastian 2023-03-04 13:03:29 +01:00
parent 0e7a6c54ae
commit 1869e644e1
4 changed files with 100 additions and 19 deletions

View file

@ -1,6 +1,6 @@
[mypy] [mypy]
check_untyped_defs = True check_untyped_defs = True
[mypy-prtpy.*] [mypy-sortedcontainers.*]
follow_imports = skip follow_imports = skip
ignore_missing_imports = True ignore_missing_imports = True

View file

@ -4,9 +4,9 @@ import random
from pathlib import Path from pathlib import Path
import logging import logging
import jinja2 as j2 import jinja2 as j2
import prtpy
from .config import Task, Category, Config from .config import Task, Category, Config
from .partition import TaskId, partition
from . import util from . import util
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -49,9 +49,6 @@ class AssignError(Exception):
def assigner_taches(root_task: Category | Task, group_count: int): def assigner_taches(root_task: Category | Task, group_count: int):
"""Assigne les tâches aux groupes (multiway number partitioning)""" """Assigne les tâches aux groupes (multiway number partitioning)"""
TaskId = t.NewType("TaskId", int)
UniqueTask: t.TypeAlias = tuple[TaskId, int]
def flatten(task: Category | Task) -> list[Task]: def flatten(task: Category | Task) -> list[Task]:
if isinstance(task, Task): if isinstance(task, Task):
return [task] return [task]
@ -62,28 +59,36 @@ def assigner_taches(root_task: Category | Task, group_count: int):
all_tasks = flatten(root_task) all_tasks = flatten(root_task)
def pp_assigned_toughness(repart: list[list[UniqueTask]]) -> str: def pp_assigned_toughness(repart: list[list[TaskId]]) -> str:
"""Pretty-print the assigned toughness for each group""" """Pretty-print the assigned toughness for each group"""
out = [] out = []
for grp_id, grp in enumerate(repart): for grp_id, grp in enumerate(repart):
toughness: int = sum(map(lambda x: all_tasks[x[0]].tough, grp)) toughness: int = sum(map(lambda x: all_tasks[x].tough, grp))
out.append(f"{grp_id:2d}: {toughness:>3d}") out.append(f"{grp_id+1:2d}: {toughness:>3d}")
return "\n".join(out) return "\n".join(out)
opt_input: dict[UniqueTask, int] = {} costs: dict[TaskId, int] = {}
multiplicity: dict[TaskId, int] = {}
for task_id, task in enumerate(all_tasks): for task_id, task in enumerate(all_tasks):
for rep in range(task.nb_groups): t_id: TaskId = TaskId(task_id)
opt_input[(TaskId(task_id), rep)] = task.tough costs[t_id] = task.tough
repart: list[list[UniqueTask]] = prtpy.partition( multiplicity[t_id] = task.nb_groups
algorithm=prtpy.partitioning.greedy, repart: list[list[TaskId]] = partition(
numbins=group_count, bin_count=group_count,
items=opt_input, costs=costs,
multiplicity=multiplicity,
) )
# Sanity-check # Sanity-check
assigned_count = sum(map(len, repart))
task_count = sum(multiplicity.values())
if task_count != assigned_count:
raise AssignError(
f"{assigned_count} tâches ont été attribuées, mais il y en a {task_count} !"
)
for g_id, grp in enumerate(repart): for g_id, grp in enumerate(repart):
taskset: set[TaskId] = set() taskset: set[TaskId] = set()
for (task_id, _) in grp: for task_id in grp:
if task_id in taskset: if task_id in taskset:
raise AssignError( raise AssignError(
f"Le groupe {g_id + 1} a deux fois la tâche {task.qualified_name}" f"Le groupe {g_id + 1} a deux fois la tâche {task.qualified_name}"
@ -92,7 +97,7 @@ def assigner_taches(root_task: Category | Task, group_count: int):
# Actually assign # Actually assign
for g_id, grp in enumerate(repart): for g_id, grp in enumerate(repart):
for (task_id, _) in grp: for task_id in grp:
task = all_tasks[task_id] task = all_tasks[task_id]
if task.assigned is None: if task.assigned is None:
task.assigned = [g_id] task.assigned = [g_id]
@ -111,7 +116,10 @@ def export_short_md(config: Config, groupes: list[list[str]]) -> str:
def export_taskcat(grp: Task | Category) -> str: def export_taskcat(grp: Task | Category) -> str:
if isinstance(grp, Task): if isinstance(grp, Task):
assert grp.assigned is not None assert grp.assigned is not None
return f'* {grp.qualified_name}: {", ".join(map(lambda x: str(x+1), grp.assigned))}' return (
f"* {grp.qualified_name}: "
+ f'{", ".join(map(lambda x: str(x+1), grp.assigned))}'
)
out = "\n" + "#" * (2 + grp.depth) + f" {grp.name}" out = "\n" + "#" * (2 + grp.depth) + f" {grp.name}"
if grp.time: if grp.time:
out += f" ({grp.time})" out += f" ({grp.time})"

View file

@ -0,0 +1,73 @@
""" Implements Multiway number partitioning greedy algorithm """
import typing as t
from sortedcontainers import SortedList
__all__ = ["TaskId", "partition"]
TaskId = t.NewType("TaskId", int)
class PartitionException(Exception):
"""An exception occurring during partitioning"""
class UnsolvableConflict(PartitionException):
"""Cannot partition set due to unsolvable conflicts"""
class Bin:
"""A bin containing assigned tasks"""
elts: list[TaskId]
cost: int
def __init__(self):
self.elts = []
self.cost = 0
def add(self, task: TaskId, cost: int):
assert task not in self.elts
self.elts.append(task)
self.cost += cost
def __contains__(self, task: TaskId) -> bool:
return task in self.elts
def partition(
bin_count: int, costs: dict[TaskId, int], multiplicity: dict[TaskId, int]
) -> list[list[TaskId]]:
"""Partitions the tasks, each with cost `costs[i]`, into `bin_count` bins. Each
task has multiplicity `multiplicity[i]`, copies of the same task being mutually
exclusive (ie. cannot be in the same bin)"""
bins = SortedList([Bin() for _ in range(bin_count)], key=lambda x: x.cost)
ordered_tasks: list[TaskId] = []
for t_id, reps in multiplicity.items():
for _ in range(reps):
ordered_tasks.append(t_id)
ordered_tasks.sort(key=lambda x: costs[x], reverse=True)
for task in ordered_tasks:
least_full: Bin
least_full_pos: int
for pos, cur_bin in enumerate(bins):
if task not in cur_bin:
least_full = cur_bin
least_full_pos = pos
break
else:
raise UnsolvableConflict(
"Pas assez de groupes pour affecter la tâche "
+ f"{task} {multiplicity[task]} fois."
)
del bins[least_full_pos]
least_full.add(task, costs[task])
bins.add(least_full)
out: list[list[TaskId]] = []
for cur_bin in bins:
out.append(cur_bin.elts)
return out

View file

@ -1,3 +1,3 @@
ruamel.yaml ruamel.yaml
Jinja2 Jinja2
prtpy sortedcontainers