diff --git a/drgn/cli.py b/drgn/cli.py index 22bc54c2..9389f236 100644 --- a/drgn/cli.py +++ b/drgn/cli.py @@ -179,7 +179,7 @@ def main() -> None: else: max_address = 2**32 - 1 segments = [(0, 0, 0, max_address, max_address)] - core_reader = CoreReader(core_file.fileno(), segments) + core_reader = CoreReader(core_file, segments) if args.pid is None: if os.path.abspath(args.core) == '/proc/kcore': diff --git a/drgn/corereader.c b/drgn/corereader.c index c5a85f48..77df2171 100644 --- a/drgn/corereader.c +++ b/drgn/corereader.c @@ -24,6 +24,7 @@ struct segment { typedef struct { PyObject_HEAD + PyObject *file; int fd; int num_segments; struct segment *segments; @@ -55,22 +56,49 @@ static int pread_all(int fd, void *buf, size_t count, off_t offset) static void CoreReader_dealloc(CoreReader *self) { free(self->segments); + Py_XDECREF(self->file); Py_TYPE(self)->tp_free((PyObject *)self); } +static int CoreReader_traverse(CoreReader *self, visitproc visit, void *arg) +{ + Py_VISIT(self->file); + return 0; +} + +static int CoreReader_clear(CoreReader *self) +{ + Py_CLEAR(self->file); + return 0; +} + static int CoreReader_init(CoreReader *self, PyObject *args, PyObject *kwds) { static const char *errmsg = "segment must be (offset, vaddr, paddr, filesz, memsz)"; - static char *keywords[] = {"fd", "segments", NULL}; - int fd; + static char *keywords[] = {"file", "segments", NULL}; + PyObject *file, *fd_obj; + long fd; PyObject *segments_list; struct segment *segments; int num_segments, i; - if (!PyArg_ParseTupleAndKeywords(args, kwds, "iO!:CoreReader", keywords, - &fd, &PyList_Type, &segments_list)) + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OO!:CoreReader", keywords, + &file, &PyList_Type, &segments_list)) return -1; + fd_obj = PyObject_CallMethod(file, "fileno", "()"); + if (!fd_obj) + return -1; + + fd = PyLong_AsLong(fd_obj); + Py_DECREF(fd_obj); + if (fd == -1 && PyErr_Occurred()) + return -1; + if (fd < 0 || fd > INT_MAX) { + PyErr_SetString(PyExc_ValueError, "invalid file descriptor"); + return -1; + } + if (PyList_GET_SIZE(segments_list) > INT_MAX) { PyErr_SetString(PyExc_OverflowError, "too many segments"); return -1; @@ -121,6 +149,9 @@ static int CoreReader_init(CoreReader *self, PyObject *args, PyObject *kwds) } free(self->segments); + Py_XDECREF(self->file); + Py_INCREF(file); + self->file = file; self->fd = fd; self->segments = segments; self->num_segments = num_segments; @@ -143,11 +174,51 @@ static CoreReader *CoreReader_new(PyTypeObject *subtype, PyObject *args, return reader; } +static PyObject *CoreReader_close(CoreReader *self) +{ + PyObject *ret; + + if (!self->file) + Py_RETURN_NONE; + + ret = PyObject_CallMethod(self->file, "close", "()"); + if (ret) { + Py_DECREF(self->file); + self->file = NULL; + self->fd = -1; + } + return ret; +} + +static PyObject *CoreReader_enter(PyObject *self) +{ + Py_INCREF(self); + return self; +} + +static PyObject *CoreReader_exit(CoreReader *self, PyObject *args, + PyObject *kwds) +{ + static char *keywords[] = {"exc_type", "exc_value", "traceback", NULL}; + PyObject *exc_type, *exc_value, *traceback; + + if (!PyArg_ParseTupleAndKeywords(args, kwds, "OOO:__exit__", keywords, + &exc_type, &exc_value, &traceback)) + return NULL; + + return CoreReader_close(self); +} + static int read_core(CoreReader *self, void *buf, uint64_t address, uint64_t count, int physical) { char *p = buf; + if (self->fd == -1) { + PyErr_SetString(PyExc_ValueError, "read on closed CoreReader"); + return -1; + } + while (count) { struct segment *segment; uint64_t segment_address; @@ -289,9 +360,17 @@ CoreReader_READ(long_double, long double, PyFloat_FromDouble) "physical -- whether address is a physical memory address"} #define CoreReader_DOC \ - "CoreReader(fd, segments) -> new core file reader" + "CoreReader(file, segments) -> new core file reader" static PyMethodDef CoreReader_methods[] = { + {"close", (PyCFunction)CoreReader_close, + METH_NOARGS, + "close()\n\n" + "Close the file underlying this reader."}, + {"__enter__", (PyCFunction)CoreReader_enter, + METH_NOARGS}, + {"__exit__", (PyCFunction)CoreReader_exit, + METH_VARARGS | METH_KEYWORDS}, {"read", (PyCFunction)CoreReader_read, METH_VARARGS | METH_KEYWORDS, "read(address, size, physical=False)\n\n" @@ -338,10 +417,10 @@ static PyTypeObject CoreReader_type = { NULL, /* tp_getattro */ NULL, /* tp_setattro */ NULL, /* tp_as_buffer */ - Py_TPFLAGS_DEFAULT, /* tp_flags */ + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_GC,/* tp_flags */ CoreReader_DOC, /* tp_doc */ - NULL, /* tp_traverse */ - NULL, /* tp_clear */ + (traverseproc)CoreReader_traverse,/* tp_traverse */ + (inquiry)CoreReader_clear, /* tp_clear */ NULL, /* tp_richcompare */ 0, /* tp_weaklistoffset */ NULL, /* tp_iter */ diff --git a/drgn/corereader.pyi b/drgn/corereader.pyi index e9dc2f00..ce34c640 100644 --- a/drgn/corereader.pyi +++ b/drgn/corereader.pyi @@ -1,9 +1,9 @@ from os import PathLike -from typing import Any, List, Sequence, Tuple, Union +from typing import Any, BinaryIO, List, Sequence, Tuple, Union class CoreReader: - def __init__(self, fd: int, + def __init__(self, file: BinaryIO, segments: Sequence[Tuple[int, int, int, int, int]]) -> None: ... def close(self) -> None: ... def __enter__(self) -> CoreReader: ... diff --git a/drgn/program.py b/drgn/program.py index b8811e43..811e5ed7 100644 --- a/drgn/program.py +++ b/drgn/program.py @@ -591,6 +591,25 @@ class Program: self._reader = reader self._type_index = type_index self._variable_index = variable_index + # Ugly hack for KernelVariableIndex. + try: + set_program = variable_index.set_program # type: ignore + except AttributeError: + pass + else: + set_program(self) + + def __enter__(self) -> 'Program': + return self + + def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None: + self.close() + + def close(self) -> None: + """ + Close resources associated with this Program. + """ + self._reader.close() def __getitem__(self, name: str) -> ProgramObject: """ diff --git a/tests/test_corereader.py b/tests/test_corereader.py index acfc15ed..318d1677 100644 --- a/tests/test_corereader.py +++ b/tests/test_corereader.py @@ -18,23 +18,25 @@ def tmpfile(data): class TestCoreReader(unittest.TestCase): def test_bad_segments(self): - self.assertRaises(TypeError, CoreReader, 0, 0) - self.assertRaises(TypeError, CoreReader, 0, [0]) - self.assertRaises(ValueError, CoreReader, 0, [()]) - self.assertRaises(OverflowError, CoreReader, 0, [(2**64, 0, 0, 0, 0)]) + with tmpfile(b'') as file: + self.assertRaises(TypeError, CoreReader, file, 0) + self.assertRaises(TypeError, CoreReader, file, [0]) + self.assertRaises(ValueError, CoreReader, file, [()]) + self.assertRaises(OverflowError, CoreReader, file, + [(2**64, 0, 0, 0, 0)]) def test_simple_read(self): data = b'hello, world!' segments = [(0, 0xffff0000, 0x0, len(data), len(data))] with tmpfile(data) as file: - core_reader = CoreReader(file.fileno(), segments) + core_reader = CoreReader(file, segments) self.assertEqual(core_reader.read(0xffff0000, len(data)), data) def test_bad_address(self): data = b'hello, world!' segments = [(0, 0xffff0000, 0x0, len(data), len(data))] with tmpfile(data) as file: - core_reader = CoreReader(file.fileno(), segments) + core_reader = CoreReader(file, segments) self.assertRaisesRegex(ValueError, 'could not find memory segment', core_reader.read, 0xdeadbeef, 4) @@ -42,7 +44,7 @@ class TestCoreReader(unittest.TestCase): data = b'hello, world!' segments = [(0, 0xffff0000, 0x0, len(data), len(data))] with tmpfile(data) as file: - core_reader = CoreReader(file.fileno(), segments) + core_reader = CoreReader(file, segments) self.assertRaisesRegex(ValueError, 'could not find memory segment', core_reader.read, 0xffff0000, len(data) + 1) @@ -54,7 +56,7 @@ class TestCoreReader(unittest.TestCase): (4, 0xffff0004, 0x0, 10, 10), ] with tmpfile(data) as file: - core_reader = CoreReader(file.fileno(), segments) + core_reader = CoreReader(file, segments) self.assertEqual(core_reader.read(0xffff0000, 14), data[:14]) def test_zero_filled_segment(self): @@ -63,7 +65,7 @@ class TestCoreReader(unittest.TestCase): (0, 0xffff0000, 0x0, 13, 17), ] with tmpfile(data) as file: - core_reader = CoreReader(file.fileno(), segments) + core_reader = CoreReader(file, segments) self.assertEqual(core_reader.read(0xffff0000, len(data) + 4), data + bytes(4)) self.assertEqual(core_reader.read(0xffff0000 + len(data), 4), diff --git a/tests/test_program.py b/tests/test_program.py index 045357ab..452c4027 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -27,14 +27,19 @@ class TestProgramObject(TypeIndexTestCase): self.addTypeEqualityFunc(ProgramObject, program_object_equality_func) buffer = b'\x01\x00\x00\x00\x02\x00\x00\x00hello\x00\x00\x00' segments = [(0, 0xffff0000, 0x0, len(buffer), len(buffer))] - self.tmpfile = tempfile.TemporaryFile() - self.tmpfile.write(buffer) - self.tmpfile.flush() - core_reader = CoreReader(self.tmpfile.fileno(), segments) - self.program = Program(reader=core_reader, type_index=self.type_index, - variable_index=None) + tmpfile = tempfile.TemporaryFile() + try: + tmpfile.write(buffer) + tmpfile.flush() + core_reader = CoreReader(tmpfile, segments) + self.program = Program(reader=core_reader, type_index=self.type_index, + variable_index=None) + except: + tmpfile.close() + raise def tearDown(self): - self.tmpfile.close() + if hasattr(self, 'program'): + self.program.close() super().tearDown() def test_constructor(self): diff --git a/tests/test_type.py b/tests/test_type.py index 316d5e4c..32f95234 100644 --- a/tests/test_type.py +++ b/tests/test_type.py @@ -638,7 +638,7 @@ class TestTypeRead(unittest.TestCase): def test_void(self): type_ = VoidType() with tmpfile(b'') as file: - reader = CoreReader(file.fileno(), []) + reader = CoreReader(file, []) self.assertRaises(ValueError, type_.read, reader, 0x0) self.assertRaises(ValueError, type_.read_pretty, reader, 0x0) @@ -646,7 +646,7 @@ class TestTypeRead(unittest.TestCase): expected_pretty_cast, expected_pretty_nocast): segments = [(0, 0xffff0000, 0x0, len(buffer), len(buffer))] with tmpfile(buffer) as file: - reader = CoreReader(file.fileno(), segments) + reader = CoreReader(file, segments) self.assertEqual(type_.read(reader, 0xffff0000), expected_value) self.assertEqual(type_.read_pretty(reader, 0xffff0000), expected_pretty_cast)