import importlib
import sys
from .modules import BackendModule, PyTorchBackend, PyGBackend, DGLBackend
__all__ = [
    "backend",
    "set_backend",
    "allowed_backends",
    "backend_dict",
    "file_ext",
    "set_file_ext",
]

# Filename extension used to store the models or weights for `PyTorch`.
_EXT = ".pth"

# Backend instance created at import time; `set_to_default_backend()` restores it.
_DEFAULT_BACKEND = PyTorchBackend()
# The currently active backend instance, returned by `backend()` with no argument.
_BACKEND = _DEFAULT_BACKEND

# Every backend class that can be selected by alias. This is a tuple rather than
# a set so that `_set_backend_dict()` visits the classes in a fixed order: should
# two backends ever declare the same alias, the later one wins deterministically
# instead of depending on set iteration order (which follows class hashes).
_ALL_BACKENDS = (PyTorchBackend, PyGBackend, DGLBackend)

# Maps each alias name to its backend class; (re)built by `_set_backend_dict()`.
_BACKEND_DICT = {}
def allowed_backends():
    """Return the backend names that `backend()` and `set_backend()` accept.

    Returns
    -------
    tuple
        The alias keys currently registered in the backend dictionary.
    """
    registry = backend_dict()
    # Iterating a dict yields its keys, in insertion order.
    return tuple(registry)
def backend_dict():
    """Return the registry mapping each backend alias to its backend class.

    The live module-level dictionary is returned, not a copy, so mutating the
    result also affects later lookups made by `backend()`.
    """
    registry = _BACKEND_DICT
    return registry
def _set_backend_dict():
    """Rebuild the alias -> backend-class registry from ``_ALL_BACKENDS``.

    Every name listed in a backend class's ``alias`` attribute becomes a key
    pointing at that class. If one alias appears on several classes, the class
    visited last overwrites the earlier ones.
    """
    global _BACKEND_DICT
    _BACKEND_DICT = {
        alias: backend_cls
        for backend_cls in _ALL_BACKENDS
        for alias in backend_cls.alias
    }
def backend(module_name=None):
    """Publicly accessible method for determining the current backend.

    Parameters
    ----------
    module_name : str, BackendModule instance or BackendModule subclass, optional
        One of:

        * a backend alias such as ``'torch'`` or ``'pyg'`` (looked up
          case-insensitively) -- a new instance of that backend is returned;
        * an already constructed backend -- returned unchanged;
        * a backend class such as ``PyTorchBackend`` -- instantiated and
          returned.

        If not specified, return the current backend module.

    Returns
    -------
    BackendModule
        The backend module, e.g. ``'PyTorch 1.6.0+cpu Backend'``.

    Raises
    ------
    ValueError
        If ``module_name`` is a name that matches no registered backend.

    Example
    -------
    >>> graphgallery.backend()
    'PyTorch 1.6.0+cpu Backend'
    """
    if module_name is None:
        return _BACKEND

    if isinstance(module_name, BackendModule):
        return module_name

    # Passing a backend *class* (e.g. `PyTorchBackend`) has always been
    # documented, but `isinstance` is False for a class, so it used to fall
    # through to `str()` -> "<class '...'>" and be rejected as unknown.
    if isinstance(module_name, type) and issubclass(module_name, BackendModule):
        return module_name()

    module_name = str(module_name)
    module = _BACKEND_DICT.get(module_name.lower(), None)
    if module is None:
        raise ValueError(
            f"Unsupported backend module name: '{module_name}', expected one of {allowed_backends()}."
        )
    return module()
def set_to_default_backend():
    """Make the default backend (the instance created at import time) active.

    Unlike `set_backend`, this only rebinds the active backend; it does not
    reload the gallery model modules.

    Returns
    -------
    BackendModule
        The default backend, which is now the active one.
    """
    global _BACKEND
    _BACKEND = _DEFAULT_BACKEND
    return _DEFAULT_BACKEND
def set_backend(module_name=None):
    """Set the default backend module.

    Parameters
    ----------
    module_name : str or BackendModule, optional
        Anything accepted by `backend()`, e.g. ``'th'``, ``'torch'``,
        ``'pytorch'`` or ``'pyg'``. If not specified, the current backend
        is kept.

    Returns
    -------
    BackendModule
        The backend that is active after the call.

    Raises
    ------
    ValueError
        In case of invalid value.

    Example
    -------
    >>> graphgallery.set_backend('torch')
    'PyTorch 1.6.0+cpu Backend'
    """
    global _BACKEND
    new_backend = backend(module_name)
    if new_backend != _BACKEND:
        _BACKEND = new_backend
        try:
            # Gallery models are reloaded after switching, presumably because
            # they pick backend-specific implementations at import time.
            # Import/reload order is kept as nodeclas, graphclas, linkpred.
            from graphgallery.gallery import nodeclas, graphclas, linkpred
            for gallery_module in (nodeclas, graphclas, linkpred):
                importlib.reload(gallery_module)
        except Exception:
            print(
                f"Something went wrong when switching to backend '{new_backend}'.",
                file=sys.stderr,
            )
            # Bare `raise` re-raises the active exception with its traceback intact.
            raise
    return _BACKEND
def file_ext():
    """Return the filename suffix (extension) used for model checkpoints.

    Returns
    -------
    str
        The current suffix; ``".pth"`` unless it has been changed, e.g. via
        `set_file_ext`.
    """
    current = _EXT
    return current
def set_file_ext(ext: str):
    """Change the filename suffix (extension) used for model checkpoints.

    Parameters
    ----------
    ext : str
        The new suffix, stored verbatim (no leading dot is added), e.g. ``".pth"``.

    Returns
    -------
    str
        The suffix now in effect.
    """
    global _EXT
    _EXT = ext
    return ext
# Populate the alias -> backend-class registry once at import time so that
# `backend()` / `allowed_backends()` can resolve names immediately.
_set_backend_dict()