Source code for ablinfer.remote.remote

#!/usr/bin/env python3

from collections import OrderedDict as OD
import logging
import os
from urllib.parse import urljoin
import time
from typing import List

import requests as r

from ..base import DispatchBase, DispatchException
from ..constants import DispatchStage
from .util import ProgressReporter, save_resp, FObjReadWrapper

def urljoin_b(*args):
    """Join URL components the way :func:`os.path.join` joins paths.

    Each component after the first is resolved with :func:`urllib.parse.urljoin`
    against the running result plus a trailing ``'/'``. This keeps the last path
    segment of the base instead of replacing it. As with ``urljoin`` itself, a
    component that is absolute (a leading ``'/'`` or a scheme such as ``'x:'``)
    replaces what came before it.

    :param args: The base URL followed by zero or more components.
    :returns: The joined URL; a single argument is returned unchanged.
    """
    joined = args[0]
    for component in args[1:]:
        joined = urljoin(joined + '/', component)
    return joined

class DispatchRemote(DispatchBase):
    """Class for dispatching to an ABLInfer server.

    A required ``base_url`` key is added to ``config``, which must be the server's base URL,
    which will be passed to :func:`urllib.parse.urljoin` to construct the query URLs. In
    addition, a ``session`` parameter is added to ``config`` which allows the user to provide
    a :class:`requests.Session` instance for SSL verification or authentication.
    """

    def __init__(self, config=None):
        self.base_url = None
        self.session = None
        # True only when ``self.session`` was created by this instance (not supplied through
        # ``config``), so that it is safe for us to close it.
        self._owns_session = False
        self.remote_session = None
        self.model_id = None

        super().__init__(config=config)

    def get_model_list(self) -> List[str]:
        """Retrieve the list of models available on the site.

        :returns: The ``data`` field of the server's ``models`` response.
        :raises requests.HTTPError: If the server responds with an error status.
        """
        with self._lock:
            resp = self.session.get(urljoin_b(self.base_url, "models"))
            resp.raise_for_status()
            return resp.json()["data"]

    def get_model(self, model_id: str):
        """Retrieve a model from the site.

        This function assumes that the model received is normalized.

        :param model_id: The model's ID.
        :returns: The ``data`` field of the response, with key order preserved
            (parsed into :class:`collections.OrderedDict`).
        :raises requests.HTTPError: If the server responds with an error status.
        """
        with self._lock:
            resp = self.session.get(urljoin_b(self.base_url, "models", model_id))
            resp.raise_for_status()
            return resp.json(object_pairs_hook=OD)["data"]

    def _close_owned_session(self):
        """Close ``self.session`` if this instance created it; never close a caller's session."""
        if self._owns_session and self.session is not None:
            self.session.close()
        self._owns_session = False

    def _validate_config(self):
        """Read ``base_url``/``session`` from the config and check that the server is reachable.

        :raises ValueError: If the server does not identify itself as ``inferserver``.
        :raises requests.HTTPError: If the server's base URL responds with an error status.
        """
        super()._validate_config()

        self.base_url = self.config["base_url"]
        if not self.base_url.endswith('/'):
            self.base_url += '/'

        # Don't leak a session we created on an earlier validation.
        self._close_owned_session()
        if "session" in self.config:
            self.session = self.config["session"]
            self._owns_session = False
        else:
            self.session = r.Session()
            self._owns_session = True

        ## Check the server
        logging.info("Trying server at %s...", self.base_url)
        resp = self.session.get(self.base_url)
        resp.raise_for_status()
        resp = resp.json()
        if resp["data"]["server"] != "inferserver":
            raise ValueError("Unknown server %s" % repr(resp["data"]["server"]))

    def _validate_model_config(self):
        """Check that the server's copy of the model has the same version as the local one.

        :raises DispatchException: If the model cannot be fetched or the versions differ.
        """
        ## We need to check that the server has the correct version of the model first
        self.model_id = self.model["id"]
        try:
            model = self.get_model(self.model_id)
        except Exception as e:
            raise DispatchException("Unable to retrieve model from the server: %s" % repr(e)) from e
        if self.model["version"] != model["version"]:
            raise DispatchException("Version mismatch between server model and local model: server has v%s, we have v%s" % (model["version"], self.model["version"]))

        super()._validate_model_config()

    def _make_fmap(self):
        """Return an empty file map; inputs are uploaded directly from their configured paths."""
        return {}

    def _make_flags(self, fmap):
        """Return no flags; the run is configured by the session-creation request instead."""
        return []

    def _make_command(self, flags):
        """Create a session on the server for this model and remember its ID.

        The enabled state of every input/output and the parameters are sent in the request.
        The returned command is empty because the server, not this process, runs the model.
        """
        resp = self.session.post(
            urljoin_b(self.base_url, "models", self.model_id),
            json={
                "inputs": {n: {"enabled": v["enabled"]} for n, v in self.model_config["inputs"].items()},
                "params": self.model_config["params"],
                "outputs": {n: {"enabled": v["enabled"]} for n, v in self.model_config["outputs"].items()},
            },
        )
        resp.raise_for_status()
        self.remote_session = resp.json()["data"]["session_id"]
        self.progress(DispatchStage.Validate, 0.5, 1, "Session ID is %s" % (self.remote_session))

        return []

    def _get_status(self):
        """Return the server-reported status string of the current session.

        :raises requests.HTTPError: If the status request fails.
        """
        resp = self.session.get(urljoin_b(self.base_url, "sessions", self.remote_session))
        resp.raise_for_status()
        return resp.json()["data"]["status"]

    def _save_input(self, fmap):
        """Upload every enabled input file to the current session.

        Before each upload, the first 1024 bytes are sent to the model's ``check`` endpoint.
        Files the server reports as not acceptable are rejected.

        :raises DispatchException: If an input has an unacceptable file type, or the session
            is still waiting for input after all uploads.
        """
        total = len(self.model_config["inputs"])
        for n, (name, v) in enumerate(self.model_config["inputs"].items()):
            if not v["enabled"]:
                logging.info("Skipping disabled input %s", name)
                continue

            string = "Uploading %s..." % name
            logging.info(string)
            with open(v["value"], "rb") as f:
                header = f.read(1024)
                f.seek(0)
                resp = self.session.post(
                    urljoin_b(self.base_url, "models", self.model["id"], "inputs", name, "check"),
                    data=header,
                )
                resp.raise_for_status()
                j = resp.json()
                if not j["data"]["acceptable"]:
                    # Report the file type the server detected, falling back to "unknown".
                    ft = j["data"].get("filetype")
                    ft = ft if ft is not None else "unknown"
                    raise DispatchException("Invalid filetype for input %s; expected %s, got %s" % (name, self.model["inputs"][name]["extension"], ft))

                # The lambda's ``f`` is the wrapper's progress value, not the file object.
                # NOTE(review): the third progress() argument is that value, whereas other calls
                # pass a total (e.g. 1) there -- confirm against DispatchBase.progress.
                fwrap = FObjReadWrapper(
                    f,
                    os.path.getsize(v["value"]),
                    string,
                    lambda f, s: self.progress(DispatchStage.Save, n/total + f/total, f, s),
                    period=0.1,
                )
                resp = self.session.put(
                    urljoin_b(self.base_url, "sessions", self.remote_session, "inputs", name),
                    headers={"Content-Type": "application/octet-stream"},
                    data=fwrap,
                )
                resp.raise_for_status()

        if self._get_status() == "waiting":
            raise DispatchException("Session ID %s is still waiting for input, but all input has been provided, please report this" % self.remote_session)

    def _run_command(self, cmd):
        """Relay the session's log stream as progress messages, then wait for the session to end.

        The status is polled once per second until it is ``complete`` or ``failed``.

        :raises requests.HTTPError: If the log stream cannot be opened.
        :raises DispatchException: If the session ends in the ``failed`` state.
        """
        logging.info("Starting run...")
        resp = self.session.get(urljoin_b(self.base_url, "sessions", self.remote_session, "logs"), stream=True)
        try:
            resp.raise_for_status()
            for line in resp.iter_lines(5):
                # Single NUL lines are skipped; presumably a server keep-alive -- TODO confirm.
                if line == b'\0':
                    continue
                # Log output comes from the remote model; don't let one bad byte abort the run.
                self.progress(DispatchStage.Run, 0, 0, line.decode("utf-8", errors="replace"))
        finally:
            # Release the streamed connection even if relaying the logs fails.
            resp.close()

        ## Now the run is over
        logging.info("Logs ended, waiting for the session to finish...")
        while True:
            status = self._get_status()
            if status in ("complete", "failed"):
                break
            time.sleep(1)

        if status == "failed":
            raise DispatchException("Session ID %s failed, please report this" % self.remote_session)

    def _load_output(self, fmap):
        """Download every enabled output of the session to its configured path.

        Each path is recorded in ``self._output_files`` before its download starts.
        """
        total = len(self.model_config["outputs"])
        for n, (name, v) in enumerate(self.model_config["outputs"].items()):
            if not v["enabled"]:
                logging.info("Skipping disabled output %s", name)
                continue

            logging.info("Saving output %s", name)
            self._output_files.append(v["value"])
            resp = self.session.get(urljoin_b(self.base_url, "sessions", self.remote_session, "outputs", name), stream=True)
            try:
                resp.raise_for_status()
                save_resp(resp, v["value"], "Saving output %s..." % name, lambda f, s: self.progress(DispatchStage.Load, n/total + f/total, f, s), period=0.1)
            finally:
                # Streamed responses hold their connection until consumed or closed.
                resp.close()

    def _cleanup(self, error=None):
        """Run base cleanup, close any session we created, and forget the remote state."""
        super()._cleanup(error=error)

        self._close_owned_session()
        self.remote_session = None
        self.session = None
        self.model_id = None