diff --git a/_drgn.pyi b/_drgn.pyi index 22fdb0b4..dceac723 100644 --- a/_drgn.pyi +++ b/_drgn.pyi @@ -259,6 +259,7 @@ class Program: ``struct task_struct *`` object. """ ... + @overload def type(self, name: str, filename: Optional[str] = None) -> Type: """ Get the type with the given name. @@ -273,6 +274,28 @@ class Program: the given file """ ... + @overload + # type is positional-only. + def type(self, type: Type) -> Type: + """ + Return the given type. + + This is mainly useful so that helpers can use ``prog.type()`` to get a + :class:`Type` regardless of whether they were given a :class:`str` or a + :class:`Type`. For example: + + .. code-block:: python3 + + def my_helper(obj: Object, type: Union[str, Type]) -> bool: + # type may be str or Type. + type = obj.prog_.type(type) + # type is now always Type. + return sizeof(obj) > sizeof(type) + + :param type: Type. + :return: The exact same type. + """ + ... def threads(self) -> Iterator[Thread]: """Get an iterator over all of the threads in the program.""" ... diff --git a/drgn/helpers/linux/list.py b/drgn/helpers/linux/list.py index 2357389a..fc504fa0 100644 --- a/drgn/helpers/linux/list.py +++ b/drgn/helpers/linux/list.py @@ -85,9 +85,7 @@ def list_first_entry_or_null( head = head.read_() pos = head.next.read_() if pos == head: - if isinstance(type, str): - type = head.prog_.type(type) - return NULL(head.prog_, head.prog_.pointer_type(type)) + return NULL(head.prog_, head.prog_.pointer_type(head.prog_.type(type))) else: return container_of(pos, type, member) diff --git a/libdrgn/python/program.c b/libdrgn/python/program.c index f263557e..07249d4e 100644 --- a/libdrgn/python/program.c +++ b/libdrgn/python/program.c @@ -589,16 +589,32 @@ static PyObject *Program_find_type(Program *self, PyObject *args, PyObject *kwds { static char *keywords[] = {"name", "filename", NULL}; struct drgn_error *err; - const char *name; + PyObject *name_or_type; struct path_arg filename = {.allow_none = true}; - struct drgn_qualified_type qualified_type; - bool clear; - - if (!PyArg_ParseTupleAndKeywords(args, kwds, "s|O&:type", keywords, - &name, path_converter, &filename)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "O|O&:type", keywords, + &name_or_type, path_converter, + &filename)) return NULL; - clear = set_drgn_in_python(); + if (PyObject_TypeCheck(name_or_type, &DrgnType_type)) { + if (DrgnType_prog((DrgnType *)name_or_type) != self) { + PyErr_SetString(PyExc_ValueError, + "type is from different program"); + return NULL; + } + Py_INCREF(name_or_type); + return name_or_type; + } else if (!PyUnicode_Check(name_or_type)) { + PyErr_SetString(PyExc_TypeError, + "type() argument 1 must be str or Type"); + return NULL; + } + + const char *name = PyUnicode_AsUTF8(name_or_type); + if (!name) + return NULL; + bool clear = set_drgn_in_python(); + struct drgn_qualified_type qualified_type; err = drgn_program_find_type(&self->prog, name, filename.path, &qualified_type); if (clear) diff --git a/tests/test_program.py b/tests/test_program.py index a957cb5c..f68ffd20 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -397,6 +397,15 @@ class TestTypes(MockProgramTestCase): self.prog.add_type_finder(lambda kind, name, filename: None) self.assertRaises(LookupError, self.prog.type, "struct foo") + def test_already_type(self): + self.assertIdentical( + self.prog.type(self.prog.pointer_type(self.prog.void_type())), + self.prog.pointer_type(self.prog.void_type()), + ) + + def test_invalid_argument_type(self): + self.assertRaises(TypeError, self.prog.type, 1) + def test_default_primitive_types(self): def spellings(tokens, num_optional=0): for i in range(len(tokens) - num_optional, len(tokens) + 1):