"""Base class for dispatch helpers."""
from __future__ import division
from abc import ABCMeta, abstractmethod
import logging
import os
from threading import RLock
from typing import List, Mapping, Any, Callable, Dict
from .processing import dispatch_processing
from .model import normalize_model_config, update_normalize_model
from .constants import DispatchStage
[docs]class DispatchException(Exception):
"""To be raised if a process runs into a problem.
:param stderr: The stderr output of the process, if any.
"""
def __init__(self, msg, stderr: str = None):
super().__init__(msg)
self.stderr = stderr
[docs]class DispatchBase(metaclass=ABCMeta):
"""Base class for dispatching to models.
This class cannot actually be instantiated; users are intended to instantiate a subclass that
dispatches using the appropriate method (e.g. :class:`ablinfer.docker.DispatchDocker`).
The general procedure when running is:
1. :meth:`_validate_config`
2. :meth:`_validate_model_config`
3. :meth:`_make_fmap`
4. :meth:`_make_flags`
5. :meth:`_make_command`
6. :meth:`_run_processing` (dispatches required nodes to :meth:`_clone` to copy them)
7. :meth:`_save_input`
8. :meth:`_run_command`
9. :meth:`_load_output`
10. :meth:`_run_processing`
11. :meth:`_cleanup`
Functions are expected to raise exceptions if they encounter any unfixable errors. If an
exception is raised, :meth:`_cleanup_all` will be called.
"""
def __init__(self, config: Mapping = None) -> None:
if config is None:
config = {}
self.config = config
self._created_files: List[str] = []
self._output_files: List[str] = []
self._pre_nodes: Dict[str, Any] = {}
self._lock = RLock()
self.model: Mapping[str, Any] = None
self.model_config: Mapping[str, Any] = None
self.progress: Callable[[DispatchStage, float, float, str], None] = None
self._validate_config()
[docs] def _validate_config(self) -> None:
"""Validate the non-model configuration."""
[docs] def _validate_model_config(self) -> None:
"""Validate the model configuration.
This first normalizes the model config, which also validates it. It then ensures that no
inputs are null, no two outputs are the same, and no outputs are the same as inputs.
Note that when overriding this method without delegating here, the new method **must**
normalize the model config by calling :func:``ablinfer.model.normalize_model_config``.
"""
self.model_config = normalize_model_config(self.model, self.model_config)
in_nodes = {i["value"] for i in self.model_config["inputs"].values() if i["enabled"]}
if None in in_nodes:
raise ValueError("Make sure that all inputs are selected!")
out_nodes = {i["value"] for i in self.model_config["outputs"].values() if i["enabled"] and i["value"] is not None}
if in_nodes.intersection(out_nodes).difference((None,)) or len(out_nodes) != sum((1 for i in self.model_config["outputs"].values() if i["enabled"] and i["value"] is not None)):
raise ValueError("Inputs and outputs must be different")
[docs] def _make_fmap_helper(self, actual_path: str) -> Mapping[str, str]:
## Strip the trailing slash of actual_path, if it has one
if len(actual_path) > 1 and actual_path[-1] in ('/', '\\'):
actual_path = actual_path[:-1]
fmap = {}
for s in ("inputs", "outputs",):
for k, v in self.model[s].items():
fmap[k] = actual_path + '/' + k + v["extension"]
return fmap
[docs] @abstractmethod
def _make_fmap(self) -> Mapping[str, str]:
"""Make the file map.
:returns: The created filemap, mapping the input/output name to its path as seen by
whatever the command will be dispatched to (e.g. for Docker, this would be the
path inside the container).
"""
def _clone(self, node: Any) -> Any:
"""Clone the given node for pre-processing.
This should be overwritten for implementations based on the format your nodes are in. Any
cloned node distinct from the input node is added to `self._pre_nodes`, which maps the
input name to the result of this function. These nodes should be cleaned up at in the
:meth:`_cleanup` function.
"""
return node
[docs] def _run_processing(self, inp: bool) -> None:
"""Run pre/post processing.
:param inp: Whether to do the input or output section.
"""
section = "inputs" if inp else "outputs"
process = "pre" if inp else "post"
stage = DispatchStage.Preprocess if inp else DispatchStage.Postprocess
if inp:
self._pre_nodes = {}
total = len(self.model[section])
for count, (o, member) in enumerate(self.model[section].items()):
node = self.model_config[section][o]["value"]
if process not in member or not self.model_config[section][o]["enabled"]:
continue
string = "Running %sprocessing for %s..." % (process, member["name"])
logging.info(string)
self.progress(stage, count/total, 0, string)
## Clone nodes if we're pre-processing
if inp and True in (i["enabled"] for i in self.model_config[section][o][process]):
## Make sure we have at least one pre-processing to run
cnode = self._clone(node)
if cnode != node:
self._pre_nodes[o] = cnode
node = cnode
total_ops = len(member[process])
for n, op in enumerate(member[process], 1):
logging.info("%d. %s" % (n, op["name"]))
self.progress(stage, count/total+(n-1)/(total*total_ops), (n-1)/total_ops, string+" %d. %s" % (n, op["name"]))
if not self.model_config[section][o][process][n-1]["enabled"]:
logging.info("- Not enabled: skipping")
continue
## TODO: Possibly keep going if a processing step fails?
dispatch_processing(op, self.model_config[section][o][process][n-1], node, member, self.model, self.model_config, inp=inp)
[docs] def _make_flags(self, fmap: Mapping[str, str]) -> List[str]:
"""Construct the flags to pass to the command.
:param fmap: The generated fmap.
"""
cmd = {}
def format_flag(f, v):
if not f:
return (v,)
elif f[-1] == '=':
return (f+v,)
return (f, v,)
for n in ("inputs", "outputs"):
for k, v in self.model_config[n].items():
if not v["enabled"]:
continue
cmd[k] = format_flag(self.model[n][k]["flag"], fmap[k])
for k in self.model_config["params"]:
if self.model["params"][k]["type"] == "bool": ## Handle boolean switches
if self.model_config["params"][k]["value"]:
cmd[k] = (self.model["params"][k]["flag"],)
else:
cmd[k] = format_flag(self.model["params"][k]["flag"], str(self.model_config["params"][k]["value"]))
## Set the flags in the proper order
if "order" in self.model:
order = self.model["order"]
else:
order = sorted(list(cmd.keys()))
acmd = []
for i in order:
if i not in cmd:
continue
acmd.extend(cmd[i])
return acmd
[docs] def _make_command(self, flags: List[str]) -> List[str]:
"""Make the final command.
:param flags: A list of the flags to be included.
:returns: The final command, to be passed to :meth:`_run_command_`.
"""
return flags
[docs] @abstractmethod
def _run_command(self, cmd: List[str]) -> None:
"""Run the actual command.
Must raise :class:`DispatchException` if the called process runs into a problem.
:param cmd: The command to execute.
"""
[docs] @abstractmethod
def _load_output(self, fmap: Mapping[str, str]) -> None:
"""Put the output files into the appropriate locations.
It is this function's responsibility to set self._output_files as the files are created, so
that they can be removed properly if the command fails.
:param fmap: The file map.
"""
[docs] def _cleanup(self, error: Exception = None) -> None:
"""Run successful cleanup.
This function should conduct all of the cleanup necessary for a successful run. This will
also be called if the run fails. This means that this function should run checks of its
own to see what actually needs to be cleaned up.
:param error: The exception, if one occurred.
"""
for f in self._created_files:
try:
os.remove(f)
except:
pass
self._created_files = []
[docs] def _cleanup_all(self, error: Exception) -> None:
"""Cleanup everything.
This will only be called if an error is raised.
:param error: The exception that occurred.
"""
self._cleanup(error=error)
for f in self._output_files:
try:
os.remove(f)
except:
pass
self._output_files = []
def _run(self) -> None:
"""Actual run method.
This is just to make sure the run method can't be called twice at once.
"""
logging.info("Processing started")
try:
logging.info("Validating model configuration...")
self.progress(DispatchStage.Validate, 0, 0, "Validating model configuration...")
self._validate_model_config()
self.progress(DispatchStage.Validate, 0.5, 0, "Assembling command...")
fmap = self._make_fmap()
cmd = self._make_command(self._make_flags(fmap))
self.progress(DispatchStage.Validate, 1, 0, "Done validation")
self.progress(DispatchStage.Preprocess, 0, 0, "Starting pre-processing...")
logging.info("Running pre-processing...")
self._run_processing(inp=True)
self.progress(DispatchStage.Preprocess, 1, 0, "Done pre-processing")
self.progress(DispatchStage.Save, 0, 0, "Saving inputs...")
self._save_input(fmap)
self.progress(DispatchStage.Save, 1, 0, "Done saving inputs")
logging.info("Command: "+' '.join(cmd))
logging.info("Running model...")
self.progress(DispatchStage.Run, 0, 0, "Running model...")
self._run_command(cmd)
self.progress(DispatchStage.Run, 1, 0, "Model finished")
self.progress(DispatchStage.Load, 0, 0, "Loading outputs...")
logging.info("Loading output...")
self._load_output(fmap)
self.progress(DispatchStage.Load, 1, 0, "Done loading outputs")
self.progress(DispatchStage.Postprocess, 0, 0, "Starting post-processing...")
logging.info("Running post-processing...")
self._run_processing(inp=False)
self.progress(DispatchStage.Postprocess, 1, 0, "Done post-processing")
logging.info("Cleaning up inputs...")
self._cleanup(error=None)
logging.info("Processing finished")
except Exception as e:
self._cleanup_all(e)
raise e
[docs] def run(self, model: Mapping, model_config: Mapping, progress: Callable[[DispatchStage, float, float, str], None] = print) -> None:
"""Run the model.
This is the entry point
:param model: The model specification.
:param model_config: The model configuration.
:param progress: an optional function accepting a ``DispatchStage``, a float on [0,1]
representing progress in the current stage, a float on [0,1] representing
progress in the current operation, and a string with detailed info.
"""
## Ensure only one run is underway
with self._lock:
self.model = update_normalize_model(model)[0]
self.model_config = normalize_model_config(self.model, model_config)
self.progress = progress
self._run()