Make cast() a function instead of a method

This commit is contained in:
Omar Sandoval 2019-02-22 10:34:53 -08:00
parent 3b998be960
commit 9a4262b609
8 changed files with 33 additions and 32 deletions

View File

@ -47,6 +47,16 @@ from drgn.type import Type, PointerType
from drgn.typename import TypeName from drgn.typename import TypeName
def cast(type: Union[str, Type, TypeName], obj: Object) -> Object:
"""
Return a copy of the given object casted to another type. The given
type is usually a string, but it can also be a Type or TypeName object.
"""
if not isinstance(type, Type):
type = obj.prog_.type(type)
return Object(obj.prog_, type, value=obj._value, address=obj.address_)
def container_of(ptr: Object, type: Union[str, Type, TypeName], def container_of(ptr: Object, type: Union[str, Type, TypeName],
member: str) -> Object: member: str) -> Object:
""" """

View File

@ -8,7 +8,7 @@ This module provides helpers for working with Linux devices, including the
kernel encoding of dev_t. kernel encoding of dev_t.
""" """
from drgn import Object from drgn import cast, Object
__all__ = [ __all__ = [
'MAJOR', 'MAJOR',
@ -30,7 +30,7 @@ def MAJOR(dev):
""" """
major = dev >> _MINORBITS major = dev >> _MINORBITS
if isinstance(major, Object): if isinstance(major, Object):
return major.cast_('unsigned int') return cast('unsigned int', major)
return major return major
@ -42,7 +42,7 @@ def MINOR(dev):
""" """
minor = dev & _MINORMASK minor = dev & _MINORMASK
if isinstance(minor, Object): if isinstance(minor, Object):
return minor.cast_('unsigned int') return cast('unsigned int', minor)
return minor return minor
@ -54,5 +54,5 @@ def MKDEV(major, minor):
""" """
dev = (major << _MINORBITS) | minor dev = (major << _MINORBITS) | minor
if isinstance(dev, Object): if isinstance(dev, Object):
return dev.cast_('dev_t') return cast('dev_t', dev)
return dev return dev

View File

@ -8,7 +8,7 @@ This module provides helpers for working with the Linux memory management (mm)
subsystem. Only x86-64 support is currently implemented. subsystem. Only x86-64 support is currently implemented.
""" """
from drgn import Object from drgn import cast, Object
__all__ = [ __all__ = [
@ -25,7 +25,7 @@ __all__ = [
def _vmemmap(prog): def _vmemmap(prog):
try: try:
# KASAN # KASAN
return prog['vmemmap_base'].cast_('struct page *') return cast('struct page *', prog['vmemmap_base'])
except KeyError: except KeyError:
# x86-64 # x86-64
return Object(prog, 'struct page *', value=0xffffea0000000000) return Object(prog, 'struct page *', value=0xffffea0000000000)
@ -57,7 +57,7 @@ def page_to_pfn(page):
Get the page frame number (PFN) of a page. Get the page frame number (PFN) of a page.
""" """
return (page - _vmemmap(page.prog_)).cast_('unsigned long') return cast('unsigned long', page - _vmemmap(page.prog_))
def pfn_to_page(prog_or_pfn, pfn=None): def pfn_to_page(prog_or_pfn, pfn=None):

View File

@ -7,7 +7,7 @@ Linux kernel process ID helpers
This module provides helpers for looking up process IDs. This module provides helpers for looking up process IDs.
""" """
from drgn import container_of, NULL, Program from drgn import cast, container_of, NULL, Program
from drgn.helpers.kernel.idr import idr_find, idr_for_each from drgn.helpers.kernel.idr import idr_find, idr_for_each
from drgn.helpers.kernel.list import hlist_for_each_entry from drgn.helpers.kernel.list import hlist_for_each_entry
@ -34,7 +34,7 @@ def find_pid(prog_or_ns, nr):
prog = prog_or_ns.prog_ prog = prog_or_ns.prog_
ns = prog_or_ns ns = prog_or_ns
if hasattr(ns, 'idr'): if hasattr(ns, 'idr'):
return idr_find(ns.idr, nr).cast_('struct pid *') return cast('struct pid *', idr_find(ns.idr, nr))
else: else:
# We could implement pid_hashfn() and only search that bucket, but it's # We could implement pid_hashfn() and only search that bucket, but it's
# different for 32-bit and 64-bit systems, and it has changed at least # different for 32-bit and 64-bit systems, and it has changed at least
@ -67,7 +67,7 @@ def for_each_pid(prog_or_ns):
ns = prog_or_ns ns = prog_or_ns
if hasattr(ns, 'idr'): if hasattr(ns, 'idr'):
for nr, entry in idr_for_each(ns.idr): for nr, entry in idr_for_each(ns.idr):
yield entry.cast_('struct pid *') yield cast('struct pid *', entry)
else: else:
pid_hash = prog['pid_hash'] pid_hash = prog['pid_hash']
for i in range(1 << prog['pidhash_shift'].value_()): for i in range(1 << prog['pidhash_shift'].value_()):

View File

@ -8,7 +8,7 @@ This module provides helpers for working with radix trees from
"linux/radix-tree.h". "linux/radix-tree.h".
""" """
from drgn import Object from drgn import cast, Object
__all__ = [ __all__ = [
@ -33,7 +33,7 @@ def _radix_tree_root_node(root):
except AttributeError: except AttributeError:
return root.rnode.read_once_(), 1 return root.rnode.read_once_(), 1
else: else:
return node.cast_('struct xa_node *').read_once_(), 2 return cast('struct xa_node *', node).read_once_(), 2
def radix_tree_lookup(root, index): def radix_tree_lookup(root, index):
@ -50,8 +50,8 @@ def radix_tree_lookup(root, index):
break break
parent = _entry_to_node(node, RADIX_TREE_INTERNAL_NODE) parent = _entry_to_node(node, RADIX_TREE_INTERNAL_NODE)
offset = (index >> parent.shift) & RADIX_TREE_MAP_MASK offset = (index >> parent.shift) & RADIX_TREE_MAP_MASK
node = parent.slots[offset].cast_(parent.type_).read_once_() node = cast(parent.type_, parent.slots[offset]).read_once_()
return node.cast_('void *') return cast('void *', node)
def radix_tree_for_each(root): def radix_tree_for_each(root):
@ -66,8 +66,8 @@ def radix_tree_for_each(root):
if _is_internal_node(node, RADIX_TREE_INTERNAL_NODE): if _is_internal_node(node, RADIX_TREE_INTERNAL_NODE):
parent = _entry_to_node(node, RADIX_TREE_INTERNAL_NODE) parent = _entry_to_node(node, RADIX_TREE_INTERNAL_NODE)
for i, slot in enumerate(parent.slots): for i, slot in enumerate(parent.slots):
yield from aux(slot.cast_(parent.type_).read_once_(), yield from aux(cast(parent.type_, slot).read_once_(),
index + (i << parent.shift.value_())) index + (i << parent.shift.value_()))
elif node: elif node:
yield index, node.cast_('void *') yield index, cast('void *', node)
yield from aux(node, 0) yield from aux(node, 0)

View File

@ -80,7 +80,8 @@ def main() -> None:
from drgn.internal.rlcompleter import Completer from drgn.internal.rlcompleter import Completer
init_globals['drgn'] = drgn init_globals['drgn'] = drgn
for attr in ['container_of', 'Object', 'NULL']: drgn_globals = ['cast', 'container_of', 'NULL', 'Object']
for attr in drgn_globals:
init_globals[attr] = getattr(drgn, attr) init_globals[attr] = getattr(drgn, attr)
init_globals['__name__'] = '__main__' init_globals['__name__'] = '__main__'
init_globals['__doc__'] = None init_globals['__doc__'] = None
@ -102,7 +103,7 @@ def main() -> None:
banner = version + """ banner = version + """
For help, type help(drgn). For help, type help(drgn).
>>> import drgn >>> import drgn
>>> from drgn import container_of, Object, NULL""" >>> from drgn import """ + ', '.join(drgn_globals)
if prog._is_kernel(): if prog._is_kernel():
banner += '\n>>> from drgn.helpers.kernel import *' banner += '\n>>> from drgn.helpers.kernel import *'
module = importlib.import_module('drgn.helpers.kernel') module = importlib.import_module('drgn.helpers.kernel')

View File

@ -246,16 +246,6 @@ class Object:
raise ValueError(f'member access must be on a struct or union, not {self.type_.name!r}') from None raise ValueError(f'member access must be on a struct or union, not {self.type_.name!r}') from None
return Object(self.prog_, member_type, address=address + offset) return Object(self.prog_, member_type, address=address + offset)
def cast_(self, type: Union[str, Type, TypeName]) -> 'Object':
"""
Return a copy of this object casted to another type. The given type is
usually a string, but it can also be a Type or TypeName object.
"""
if not isinstance(type, Type):
type = self.prog_.type(type)
return Object(self.prog_, type, value=self._value,
address=self.address_)
def address_of_(self) -> 'Object': def address_of_(self) -> 'Object':
""" """
Return an object pointing to this object. Corresponds to the address-of Return an object pointing to this object. Corresponds to the address-of

View File

@ -2,7 +2,7 @@ import math
import operator import operator
import tempfile import tempfile
from drgn import container_of, NULL, Object, Program from drgn import cast, container_of, NULL, Object, Program
from drgn.internal.corereader import CoreReader from drgn.internal.corereader import CoreReader
from drgn.type import IntType, StructType, TypedefType from drgn.type import IntType, StructType, TypedefType
from tests.test_type import color_type, point_type from tests.test_type import color_type, point_type
@ -52,12 +52,12 @@ class TestObject(TypeIndexTestCase):
def test_cast(self): def test_cast(self):
obj = Object(self.prog, TYPES['int'], value=-1) obj = Object(self.prog, TYPES['int'], value=-1)
cast_obj = obj.cast_('unsigned int') cast_obj = cast('unsigned int', obj)
self.assertEqual(cast_obj, self.assertEqual(cast_obj,
Object(self.prog, TYPES['unsigned int'], value=2**32 - 1)) Object(self.prog, TYPES['unsigned int'], value=2**32 - 1))
obj = Object(self.prog, TYPES['double'], value=1.0) obj = Object(self.prog, TYPES['double'], value=1.0)
self.assertRaises(TypeError, obj.cast_, self.type_index.pointer(TYPES['int'])) self.assertRaises(TypeError, cast, self.type_index.pointer(TYPES['int']), obj)
def test_str(self): def test_str(self):
obj = Object(self.prog, TYPES['int'], value=1) obj = Object(self.prog, TYPES['int'], value=1)
@ -170,7 +170,7 @@ class TestObject(TypeIndexTestCase):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
obj.member_('foo') obj.member_('foo')
cast_obj = obj.cast_('unsigned long') cast_obj = cast('unsigned long', obj)
self.assertEqual(cast_obj, self.assertEqual(cast_obj,
Object(self.prog, TYPES['unsigned long'], value=0)) Object(self.prog, TYPES['unsigned long'], value=0))
self.assertRaises(TypeError, obj.__index__) self.assertRaises(TypeError, obj.__index__)