diff --git a/scrapegraphai/utils/sys_dynamic_import.py b/scrapegraphai/utils/sys_dynamic_import.py new file mode 100644 index 00000000..30f75d15 --- /dev/null +++ b/scrapegraphai/utils/sys_dynamic_import.py @@ -0,0 +1,67 @@ +"""high-level module for dynamic importing of python modules at runtime + +source code inspired by https://gist.github.com/DiTo97/46f4b733396b8d7a8f1d4d22db902cfc +""" + +import sys +import typing + + +if typing.TYPE_CHECKING: + import types + + +def srcfile_import(modpath: str, modname: str) -> "types.ModuleType": + """imports a python module from its srcfile + + Args: + modpath: The srcfile absolute path + modname: The module name in the scope + + Returns: + The imported module + + Raises: + ImportError: If the module cannot be imported from the srcfile + """ + import importlib.util # noqa: F401 + + # + spec = importlib.util.spec_from_file_location(modname, modpath) + + if spec is None: + message = f"missing spec for module at {modpath}" + raise ImportError(message) + + if spec.loader is None: + message = f"missing spec loader for module at {modpath}" + raise ImportError(message) + + module = importlib.util.module_from_spec(spec) + + # adds the module to the global scope + sys.modules[modname] = module + + spec.loader.exec_module(module) + + return module + + +def dynamic_import(modname: str, message: str = "") -> None: + """imports a python module at runtime + + Args: + modname: The module name in the scope + message: The display message in case of error + + Raises: + ImportError: If the module cannot be imported at runtime + """ + if modname not in sys.modules: + try: + import importlib # noqa: F401 + + module = importlib.import_module(modname) + sys.modules[modname] = module + except ImportError as x: + raise ImportError(message) from x diff --git a/tests/utils/test_sys_dynamic_import.py b/tests/utils/test_sys_dynamic_import.py new file mode 100644 index 00000000..5f544de2 --- /dev/null +++ b/tests/utils/test_sys_dynamic_import.py @@ -0,0 +1,94 @@ +import os +import sys + +import pytest + +from scrapegraphai.utils.sys_dynamic_import import dynamic_import, srcfile_import + + +def _create_sample_file(filepath: str, content: str): + """creates a sample file at some path with some content""" + with open(filepath, "w", encoding="utf-8") as f: + f.write(content) + + +def _delete_sample_file(filepath: str): + """deletes a sample file at some path""" + if os.path.exists(filepath): + os.remove(filepath) + + +def test_srcfile_import_success(): + modpath = "example1.py" + modname = "example1" + + _create_sample_file(modpath, "def foo(): return 'bar'") + + module = srcfile_import(modpath, modname) + + assert hasattr(module, "foo") + assert module.foo() == "bar" + assert modname in sys.modules + + _delete_sample_file(modpath) + + +def test_srcfile_import_missing_spec(): + modpath = "nonexistent1.py" + modname = "nonexistent1" + + with pytest.raises(FileNotFoundError): + srcfile_import(modpath, modname) + + +def test_srcfile_import_missing_spec_loader(mocker): + modpath = "example2.py" + modname = "example2" + + _create_sample_file(modpath, "") + + mock_spec = mocker.Mock(loader=None) + + mocker.patch("importlib.util.spec_from_file_location", return_value=mock_spec) + + with pytest.raises(ImportError) as error_info: + srcfile_import(modpath, modname) + + assert "missing spec loader for module at" in str(error_info.value) + + _delete_sample_file(modpath) + + +def test_dynamic_import_success(): + print(sys.modules) + modname = "playwright" + assert modname not in sys.modules + + dynamic_import(modname) + + assert modname in sys.modules + + import playwright # noqa: F401 + + +def test_dynamic_import_module_already_imported(): + modname = "json" + + import json # noqa: F401 + + assert modname in sys.modules + + dynamic_import(modname) + + assert modname in sys.modules + + +def test_dynamic_import_import_error_with_custom_message(): + modname = "nonexistent2" + message = "could not import module" + + with pytest.raises(ImportError) as error_info: + dynamic_import(modname, message=message) + + assert str(error_info.value) == message + assert modname not in sys.modules