# Source code for ablinfer.slicer.dispatchslicer
"""Module for integrating ABLInfer into Slicer."""
import logging
import os
import slicer
from .. import DispatchBase
from ..docker import DispatchDocker
from ..remote import DispatchRemote
from .processing import __name__ as _
class SlicerDispatchMixin(DispatchBase):
    """Mixin for dispatching from Slicer.

    A ``tmp_path`` key is added to ``config``, which must contain the location to store temporary
    files for dispatching.

    This class does not implement any actual dispatching; this must be combined with the
    appropriate dispatcher to function properly, e.g.::

        class SlicerDispatchActual(SlicerDispatchMixin, DispatchActual):
            pass
    """

    def __init__(self, config):
        """Initialise the Slicer-specific state, then defer to the next class in the MRO.

        @param config the dispatcher configuration; must contain a ``tmp_path`` key
        """
        ## Define our attributes before calling up, so they exist whatever the parent
        ## constructor does (presumably it triggers _validate_config -- TODO confirm)
        self.tmp_path = None
        self._input_nodes = {}
        super().__init__(config)

    def _validate_config(self):
        """Validate the configuration and make sure the temporary directory exists.

        Reads ``config["tmp_path"]`` into ``self.tmp_path`` and creates that directory if needed.
        """
        super()._validate_config()
        self.tmp_path = self.config["tmp_path"]
        ## exist_ok avoids the check-then-create race of a separate isdir() test
        os.makedirs(self.tmp_path, exist_ok=True)

    @staticmethod
    def _clone(a, b=None):
        """Try to clone a node.

        This is copied from qSlicerSubjectHierarchyModuleLogic.cxx, I can't figure out how to call
        that directly and it doesn't support cloning into another node anyways.

        The display and storage nodes of ``a`` are copied into ``b``'s own display/storage nodes
        (created on demand), so that ``b`` never ends up sharing them with ``a``. Either may be
        missing on ``a``, in which case that part is skipped.

        @param a the original node
        @param b the node to clone into; if None, creates one
        @returns the created node
        """
        ## TODO: Move this in to an external library?
        scene = slicer.mrmlScene
        if b is None:
            b = scene.AddNewNodeByClass(a.GetClassName())
            name = a.GetName()
        else:
            ## Cloning into an existing node keeps that node's name
            name = b.GetName()

        ## Clone the display node (not every node has one)
        a_dn = a.GetDisplayNode()
        b_dn = b.GetDisplayNode()
        if a_dn is not None:
            if b_dn is None:
                b_dn = scene.AddNewNodeByClass(a_dn.GetClassName())
            b_dn.Copy(a_dn)
            b_dn.SetName(name + "_Display")
            b.SetAndObserveDisplayNodeID(b_dn.GetID())

        ## Clone the storage node (absent for nodes that were never loaded or saved)
        a_sn = a.GetStorageNode()
        b_sn = b.GetStorageNode()
        if a_sn is not None:
            if b_sn is None:
                b_sn = scene.AddNewNodeByClass(a_sn.GetClassName())
            b_sn.Copy(a_sn)
            if a_sn.GetFileName():
                b_sn.SetFileName(a_sn.GetFileName())
            b.SetAndObserveStorageNodeID(b_sn.GetID())

        ## Finally, do the copy
        b.Copy(a)
        b.SetName(name)
        ## Re-attach b's own helper nodes afterwards (presumably Copy can overwrite the
        ## references with a's); the setters take node IDs, not node objects
        if b_dn is not None:
            b.SetAndObserveDisplayNodeID(b_dn.GetID())
        if b_sn is not None:
            b.SetAndObserveStorageNodeID(b_sn.GetID())

        ## Trigger update
        b_ptn = b.GetParentTransformNode()
        if b_ptn:
            b_ptn.Modified()

        return b

    def _save_input(self, fmap):
        """Write every enabled input node to ``tmp_path`` and hand over to the real dispatcher.

        For each enabled input, the node (or its pre-processed stand-in from ``self._pre_nodes``,
        when there is one) is saved to ``<tmp_path>/<key><extension>``. The configured value is
        then replaced by that file path, and the original value is remembered in
        ``self._input_nodes`` so that ``_load_output`` can restore it.

        @param fmap passed through unchanged to the next ``_save_input`` in the MRO
        @raises Exception if Slicer reports that a node could not be saved
        """
        self._input_nodes = {}
        inputs_config = self.model_config["inputs"]
        for k, v in self.model["inputs"].items():
            entry = inputs_config[k]
            if not entry["enabled"]:
                continue

            ## This path lives on the machine running Slicer, so the native separator is
            ## correct here
            lpath = os.path.join(self.tmp_path, k + v["extension"])

            ## Write it to the path on the local machine
            actual_node = self._pre_nodes[k] if k in self._pre_nodes else entry["value"]
            ret = slicer.util.saveNode(actual_node, lpath)
            ## Track the file even on failure so a partial write still gets cleaned up
            self._created_files.append(lpath)
            if not ret:
                ## Fail before touching the configuration for this input
                raise Exception("Unable to save input \"%s\" to file. THIS SHOULD NOT HAPPEN" % v["name"])

            self._input_nodes[k] = entry["value"]
            entry["value"] = lpath

        ## Now let the actual dispatcher put the files into the container
        super()._save_input(fmap)

    @staticmethod
    def _slicer_load_node(of, member):
        """Load one output file into the Slicer scene.

        @param of path of the output file on the local machine
        @param member the model's description of this output (``type``, ``labelmap``, ``name``)
        @returns the loaded node
        @raises Exception if the file could not be loaded or the output type is unknown
        """
        missing = "Missing output %s! Please report this." % member["name"]
        kind = member["type"]
        if kind == "segmentation":
            if not member["labelmap"]:
                _, node = slicer.util.loadSegmentation(of, returnNode=True)
                if node is None:
                    raise Exception(missing)
                return node

            logging.info("- Loading as LabelVolume")
            _, lvnode = slicer.util.loadLabelVolume(of, returnNode=True)
            if lvnode is None:
                raise Exception(missing)
            node = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")
            logging.info("- Converting to segmentation")
            slicer.vtkSlicerSegmentationsModuleLogic.ImportLabelmapToSegmentationNode(lvnode, node, "")
            ## The intermediate label volume is no longer needed
            slicer.mrmlScene.RemoveNode(lvnode)
            return node

        if kind == "volume":
            if member["labelmap"]:
                _, node = slicer.util.loadLabelVolume(of, returnNode=True)
            else:
                _, node = slicer.util.loadVolume(of, returnNode=True)
            if node is None:
                raise Exception(missing)
            return node

        raise Exception("Unknown output type %s" % repr(member["type"]))

    @staticmethod
    def _slicer_label_segments(node, member):
        """Apply the model's segment names and colours to a loaded segmentation node.

        Labels that are not integers, or that match no segment, are skipped with a warning.
        A fourth colour component is applied as 3D opacity when a display node is available.

        @param node the segmentation node to update
        @param member the model's description of this output (``colours``, ``names``)
        """
        segmentation = node.GetSegmentation()
        display_node = node.GetDisplayNode()
        if display_node is None:
            logging.warning("Can't find display node for segmentation, opacity will be skipped")

        ## First, map the label values to the actual segment objects
        segmap = {}
        for i in range(segmentation.GetNumberOfSegments()):
            thisseg = segmentation.GetNthSegment(i)
            segmap[thisseg.GetLabelValue()] = (thisseg, segmentation.GetNthSegmentID(i))

        ## Now set the names and colours
        colours, names = member["colours"], member["names"]
        for label in set(colours).union(names):
            try:
                ilabel = int(label)
            except ValueError:
                logging.warning("Invalid label %s, ignoring", repr(label))
                continue
            if ilabel not in segmap:
                ## Format the parsed integer: the raw key may be a string
                logging.warning("Couldn't find segment matching label %d, ignoring", ilabel)
                continue

            theseg, theid = segmap[ilabel]
            if label in colours:
                thecolour = colours[label]
                theseg.SetColor(tuple(thecolour[:3]))
                if len(thecolour) == 4 and display_node is not None:  ## Opacity
                    display_node.SetSegmentOpacity3D(theid, thecolour[3])
            if label in names:
                theseg.SetName(names[label])

    def _load_output(self, fmap):
        """Fetch the output files through the real dispatcher and load them into Slicer.

        The configured output values are temporarily pointed at files in ``tmp_path`` so the
        underlying dispatcher writes there. Afterwards the input values saved by ``_save_input``
        are restored and every enabled output is loaded into the scene. If an output had a
        target node configured, the loaded data is cloned into it and the temporary node removed.
        Each enabled output's configured value ends up being the resulting node.

        @param fmap passed through unchanged to the next ``_load_output`` in the MRO
        @raises Exception if an output cannot be loaded or has an unknown type
        """
        outputs_config = self.model_config["outputs"]

        ## Point the dispatcher to the local files, remembering what was configured
        output_nodes = {}
        for k, v in outputs_config.items():
            output_nodes[k] = v.copy()
            v["value"] = os.path.join(self.tmp_path, k + self.model["outputs"][k]["extension"])

        ## Load them to the local disk
        super()._load_output(fmap=fmap)

        ## Restore the input nodes
        inputs_config = self.model_config["inputs"]
        for k, v in self._input_nodes.items():
            if inputs_config[k]["enabled"]:
                inputs_config[k]["value"] = v

        ## Now load them into Slicer
        for k, member in self.model["outputs"].items():
            if not outputs_config[k]["enabled"]:
                continue

            logging.info("Loading \"%s\"...", member["name"])
            node = self._slicer_load_node(outputs_config[k]["value"], member)
            if member["type"] == "segmentation":
                ## Fill in the colours and names
                self._slicer_label_segments(node, member)

            ## Make sure the output is where it should be
            target = output_nodes[k]["value"]
            if target is not None:
                ## Clone it into the requested node and drop the freshly loaded one
                self._clone(node, target)
                slicer.mrmlScene.RemoveNode(node)
                node = target

            outputs_config[k]["value"] = node

    def _cleanup(self, error=None):
        """Clean up temporary files and nodes after a run.

        The output files are added to the list of created files before deferring to the next
        ``_cleanup`` in the MRO (presumably that is what deletes them -- TODO confirm), then the
        pre-processing clones are removed from the scene and the bookkeeping is reset.

        @param error passed through unchanged to the next ``_cleanup`` in the MRO
        """
        ## Get rid of output too
        self._created_files.extend(self._output_files)
        super()._cleanup(error=error)

        ## Remove the pre-processing clones
        for v in self._pre_nodes.values():
            slicer.mrmlScene.RemoveNode(v)
        self._pre_nodes = {}
        self._input_nodes = {}
class SlicerDispatchDocker(SlicerDispatchMixin, DispatchDocker):
    """Convenience class for dispatching to Docker from Slicer.

    Combines the Slicer input/output handling of ``SlicerDispatchMixin`` with the
    ``DispatchDocker`` dispatcher; it adds no behaviour of its own.
    """
class SlicerDispatchRemote(SlicerDispatchMixin, DispatchRemote):
    """Convenience class for dispatching to an ABLInfer server from Slicer.

    Combines the Slicer input/output handling of ``SlicerDispatchMixin`` with the
    ``DispatchRemote`` dispatcher; it adds no behaviour of its own.
    """