program: implement ProgramObject binary operators

This commit is contained in:
Omar Sandoval 2018-05-07 18:48:10 -07:00
parent 9560b913f3
commit 15ea6e8d97
4 changed files with 456 additions and 43 deletions

View File

@ -18,9 +18,17 @@ from drgn.typeindex import TypeIndex
import functools import functools
import itertools import itertools
import math import math
import operator
from typing import Any, Callable, Iterable, Optional, Tuple, Union from typing import Any, Callable, Iterable, Optional, Tuple, Union
def _c_modulo(a: int, b: int) -> int:
if a >= 0:
return a % abs(b)
else:
return -(-a % abs(b))
class ProgramObject: class ProgramObject:
""" """
A ProgramObject either represents an object in the memory of a program (an A ProgramObject either represents an object in the memory of a program (an
@ -41,20 +49,26 @@ class ProgramObject:
>>> print(prog['jiffies']) >>> print(prog['jiffies'])
(volatile long unsigned int)4326237045 (volatile long unsigned int)4326237045
ProgramObjects try to behave transparently like the object they represent ProgramObjects support C operators wherever possible. E.g., structure
in C. E.g., structure members can be accessed with the dot (".") operator members can be accessed with the dot (".") operator, arrays can be
and arrays can be subscripted with "[]". subscripted with "[]", arithmetic can be performed, and objects can be
compared.
>>> print(prog['init_task'].pid) >>> print(prog['init_task'].pid)
(pid_t)0 (pid_t)0
>>> print(prog['init_task'].comm[0]) >>> print(prog['init_task'].comm[0])
(char)115 (char)115
>>> prog['init_task'].nsproxy.mnt_ns.mounts + 1
ProgramObject(address=None, type=<unsigned int>, value=34)
>>> prog['init_task'].nsproxy.mnt_ns.pending_mounts > 0
False
Note that because the structure dereference operator ("->") is not valid Note that because the structure dereference operator ("->") is not valid
syntax in Python, "." is also used to access members of pointers to syntax in Python, "." is also used to access members of pointers to
structures. Similarly, the indirection operator ("*") is not valid syntax structures. Similarly, the indirection operator ("*") is not valid syntax
in Python, so pointers can be dereferenced with "[0]" (e.g., write "p[0]" in Python, so pointers can be dereferenced with "[0]" (e.g., write "p[0]"
instead of "*p"). instead of "*p"). The address-of operator ("&") is available as the
address_of_() method.
ProgramObject members and methods are named with a trailing underscore to ProgramObject members and methods are named with a trailing underscore to
avoid conflicting with structure or union members. The helper methods avoid conflicting with structure or union members. The helper methods
@ -263,70 +277,268 @@ class ProgramObject:
self._real_type.qualifiers), self._real_type.qualifiers),
address) address)
def _check_arithmetic_type(self) -> None: def _unary_operator(self, op: Callable, op_name: str,
if not isinstance(self._real_type, (ArithmeticType, BitFieldType)): integer: bool = False) -> 'ProgramObject':
raise TypeError('not an arithmetic type') if ((integer and not self._real_type.is_integer()) or
(not integer and not self._real_type.is_arithmetic())):
raise TypeError(f"invalid operand to unary {op_name} ('{self.type_}')")
type_ = self.program_._type_index.operand_type(self.type_)
type_ = self.program_._type_index.integer_promotions(type_)
return ProgramObject(self.program_, None, type_, op(self.value_()))
def _check_integer_type(self) -> None: def _binary_operands(self, lhs: Any, rhs: Any) -> Tuple[Any, Type, Any, Type]:
if not isinstance(self._real_type, (IntType, BitFieldType)): if (isinstance(lhs, ProgramObject) and isinstance(rhs, ProgramObject) and
raise TypeError('not an integer type') lhs.program_ is not rhs.program_):
raise ValueError('operands are from different programs')
if isinstance(lhs, ProgramObject):
lhs_type = lhs.type_
if isinstance(lhs._real_type, ArrayType):
lhs = lhs.address_
else:
lhs = lhs.value_()
else:
lhs_type = self.program_._type_index.literal_type(lhs)
if isinstance(rhs, ProgramObject):
rhs_type = rhs.type_
if isinstance(rhs._real_type, ArrayType):
rhs = rhs.address_
else:
rhs = rhs.value_()
else:
rhs_type = self.program_._type_index.literal_type(rhs)
return lhs, lhs_type, rhs, rhs_type
def _unary_type(self, integer: bool = False) -> Type: def _usual_arithmetic_conversions(self, lhs: Any, lhs_type: Type,
return self.program_._type_index.integer_promotions(self.type_.unqualified()) rhs: Any, rhs_type: Type) -> Tuple[Type, Any, Any]:
type_ = self.program_._type_index.common_real_type(lhs_type, rhs_type)
return type_, type_.convert(lhs), type_.convert(rhs)
def _arithmetic_operator(self, op: Callable, op_name: str,
lhs: Any, rhs: Any) -> 'ProgramObject':
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if not lhs_type.is_arithmetic() and not rhs_type.is_arithmetic():
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, op(lhs, rhs))
def _integer_operator(self, op: Callable, op_name: str,
lhs: Any, rhs: Any) -> 'ProgramObject':
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if not lhs_type.is_integer() or not rhs_type.is_integer():
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, op(lhs, rhs))
def _shift_operator(self, op: Callable, op_name: str,
lhs: Any, rhs: Any) -> 'ProgramObject':
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if not lhs_type.is_integer() or not rhs_type.is_integer():
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
lhs_type = self.program_._type_index.integer_promotions(lhs_type)
rhs_type = self.program_._type_index.integer_promotions(rhs_type)
return ProgramObject(self.program_, None, lhs_type, op(lhs, rhs))
def _relational_operator(self, op: Callable, op_name: str,
other: Any) -> bool:
lhs_pointer = isinstance(self._real_type, (ArrayType, PointerType))
rhs_pointer = (isinstance(other, ProgramObject) and
isinstance(other._real_type, (ArrayType, PointerType)))
lhs, lhs_type, rhs, rhs_type = self._binary_operands(self, other)
if ((lhs_pointer != rhs_pointer) or
(not lhs_pointer and
(not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))):
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
if not lhs_pointer:
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return op(lhs, rhs)
def _add(self, lhs: Any, rhs: Any) -> 'ProgramObject':
lhs_pointer = (isinstance(lhs, ProgramObject) and
isinstance(lhs._real_type, (ArrayType, PointerType)))
rhs_pointer = (isinstance(rhs, ProgramObject) and
isinstance(rhs._real_type, (ArrayType, PointerType)))
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if ((lhs_pointer and rhs_pointer) or
(lhs_pointer and not rhs_type.is_integer()) or
(rhs_pointer and not lhs_type.is_integer()) or
(not lhs_pointer and not rhs_pointer and
(not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))):
raise TypeError(f"invalid operands to binary + ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
if lhs_pointer:
assert isinstance(lhs_type, PointerType)
return ProgramObject(self.program_, None, lhs_type,
lhs + lhs_type.type.sizeof() * rhs)
elif rhs_pointer:
assert isinstance(rhs_type, PointerType)
return ProgramObject(self.program_, None, rhs_type,
rhs + rhs_type.type.sizeof() * lhs)
else:
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, lhs + rhs)
def _sub(self, lhs: Any, rhs: Any) -> 'ProgramObject':
lhs_pointer = (isinstance(lhs, ProgramObject) and
isinstance(lhs._real_type, (ArrayType, PointerType)))
if lhs_pointer:
lhs_sizeof = lhs._real_type.type.sizeof()
rhs_pointer = (isinstance(rhs, ProgramObject) and
isinstance(rhs._real_type, (ArrayType, PointerType)))
if rhs_pointer:
rhs_sizeof = rhs._real_type.type.sizeof()
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if ((lhs_pointer and rhs_pointer and lhs_sizeof != rhs_sizeof) or
(lhs_pointer and not rhs_pointer and not rhs_type.is_integer()) or
(rhs_pointer and not lhs_pointer) or
(not lhs_pointer and not rhs_pointer and
(not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))):
raise TypeError(f"invalid operands to binary - ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
if lhs_pointer and rhs_pointer:
return ProgramObject(self.program_, None,
self.program_._type_index.ptrdiff_t(),
(lhs - rhs) // lhs_sizeof)
elif lhs_pointer:
return ProgramObject(self.program_, None, lhs_type,
lhs - lhs_sizeof * rhs)
else:
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, lhs - rhs)
def __add__(self, other: Any) -> 'ProgramObject':
return self._add(self, other)
def __sub__(self, other: Any) -> 'ProgramObject':
return self._sub(self, other)
def __mul__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.mul, '*', self, other)
def __truediv__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.truediv, '/', self, other)
def __mod__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(_c_modulo, '%', self, other)
def __lshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.lshift, '<<', self, other)
def __rshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.rshift, '>>', self, other)
def __and__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.and_, '&', self, other)
def __xor__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.xor, '^', self, other)
def __or__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.or_, '|', self, other)
def __radd__(self, other: Any) -> 'ProgramObject':
return self._add(other, self)
def __rsub__(self, other: Any) -> 'ProgramObject':
return self._sub(other, self)
def __rmul__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.mul, '*', other, self)
def __rtruediv__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.truediv, '/', other, self)
def __rmod__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(_c_modulo, '%', other, self)
def __rlshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.lshift, '<<', other, self)
def __rrshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.rshift, '>>', other, self)
def __rand__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.and_, '&', other, self)
def __rxor__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.xor, '^', other, self)
def __ror__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.or_, '|', other, self)
def __lt__(self, other: Any) -> bool:
return self._relational_operator(operator.lt, '<', other)
def __le__(self, other: Any) -> bool:
return self._relational_operator(operator.le, '<=', other)
def __eq__(self, other: Any) -> bool:
return self._relational_operator(operator.eq, '==', other)
def __ne__(self, other: Any) -> bool:
return self._relational_operator(operator.ne, '!=', other)
def __gt__(self, other: Any) -> bool:
return self._relational_operator(operator.gt, '>', other)
def __ge__(self, other: Any) -> bool:
return self._relational_operator(operator.ge, '>=', other)
def __bool__(self) -> bool: def __bool__(self) -> bool:
if not isinstance(self._real_type, (ArithmeticType, BitFieldType, if not isinstance(self._real_type, (ArithmeticType, BitFieldType,
PointerType)): PointerType)):
raise TypeError('not an arithmetic or pointer type') raise TypeError(f"invalid operand to bool() ('{self.type_}')")
return bool(self.value_()) return bool(self.value_())
def __neg__(self) -> 'ProgramObject': def __neg__(self) -> 'ProgramObject':
self._check_arithmetic_type() return self._unary_operator(operator.neg, '-')
return ProgramObject(self.program_, None, self._unary_type(),
-self.value_())
def __pos__(self) -> 'ProgramObject': def __pos__(self) -> 'ProgramObject':
self._check_arithmetic_type() return self._unary_operator(operator.pos, '+')
return ProgramObject(self.program_, None, self._unary_type(),
+self.value_())
def __abs__(self) -> 'ProgramObject':
self._check_arithmetic_type()
return ProgramObject(self.program_, None, self._unary_type(),
abs(self.value_()))
def __invert__(self) -> 'ProgramObject': def __invert__(self) -> 'ProgramObject':
self._check_integer_type() return self._unary_operator(operator.invert, '~', True)
return ProgramObject(self.program_, None, self._unary_type(),
~self.value_())
def __int__(self) -> int: def __int__(self) -> int:
self._check_arithmetic_type() if not isinstance(self._real_type, (ArithmeticType, BitFieldType)):
raise TypeError(f"can't convert {self.type_} to int")
return int(self.value_()) return int(self.value_())
def __float__(self) -> float: def __float__(self) -> float:
self._check_arithmetic_type() if not isinstance(self._real_type, (ArithmeticType, BitFieldType)):
raise TypeError(f"can't convert {self.type_} to float")
return float(self.value_()) return float(self.value_())
def __index__(self) -> int: def __index__(self) -> int:
self._check_integer_type() if not isinstance(self._real_type, (IntType, BitFieldType)):
raise TypeError(f"can't convert {self.type_} to index")
return self.value_() return self.value_()
def __round__(self, ndigits: Optional[int] = None) -> Union[int, float]: def __round__(self, ndigits: Optional[int] = None) -> Union[int, float]:
self._check_arithmetic_type() return round(self.__float__(), ndigits)
return round(self.value_(), ndigits)
def __trunc__(self) -> int: def __trunc__(self) -> int:
self._check_arithmetic_type() return math.trunc(self.__float__())
return math.trunc(self.value_())
def __floor__(self) -> int: def __floor__(self) -> int:
self._check_arithmetic_type() return math.floor(self.__float__())
return math.floor(self.value_())
def __ceil__(self) -> int: def __ceil__(self) -> int:
self._check_arithmetic_type() return math.ceil(self.__float__())
return math.ceil(self.value_())
class Program: class Program:

View File

@ -272,7 +272,7 @@ class TypeIndex:
if (not isinstance(real_type1, (ArithmeticType, BitFieldType)) or if (not isinstance(real_type1, (ArithmeticType, BitFieldType)) or
not isinstance(real_type2, (ArithmeticType, BitFieldType))): not isinstance(real_type2, (ArithmeticType, BitFieldType))):
raise TypeError('operands must be arithmetic types or bit fields') raise TypeError('operands must have arithmetic types')
# If either operand is long double, then the result is long double. # If either operand is long double, then the result is long double.
if isinstance(real_type1, FloatType) and real_type1.name == 'long double': if isinstance(real_type1, FloatType) and real_type1.name == 'long double':

View File

@ -1,4 +1,5 @@
import math import math
import operator
import unittest import unittest
from drgn.program import Program, ProgramObject from drgn.program import Program, ProgramObject
@ -70,7 +71,6 @@ class TestProgramObject(TypeIndexTestCase):
# _Bool should be the same because of integer promotions. # _Bool should be the same because of integer promotions.
self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['int'], -1)) self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['int'], -1))
self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['int'], 1)) self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['int'], 1))
self.assertEqual(abs(obj), ProgramObject(self.program, None, TYPES['int'], 1))
self.assertEqual(~obj, ProgramObject(self.program, None, TYPES['int'], -2)) self.assertEqual(~obj, ProgramObject(self.program, None, TYPES['int'], -2))
self.assertEqual(int(obj), 1) self.assertEqual(int(obj), 1)
self.assertEqual(float(obj), 1.0) self.assertEqual(float(obj), 1.0)
@ -88,7 +88,6 @@ class TestProgramObject(TypeIndexTestCase):
self.assertTrue(bool(obj)) self.assertTrue(bool(obj))
self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['double'], -1.5)) self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['double'], -1.5))
self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['double'], 1.5)) self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['double'], 1.5))
self.assertEqual(abs(obj), ProgramObject(self.program, None, TYPES['double'], 1.5))
with self.assertRaises(TypeError): with self.assertRaises(TypeError):
~obj ~obj
self.assertEqual(int(obj), 1) self.assertEqual(int(obj), 1)
@ -202,3 +201,203 @@ class TestProgramObject(TypeIndexTestCase):
self.assertEqual(struct_obj.address_, 0xffff0000) self.assertEqual(struct_obj.address_, 0xffff0000)
self.assertEqual(struct_obj.member_('address_'), self.assertEqual(struct_obj.member_('address_'),
ProgramObject(self.program, 0xffff0000, TYPES['unsigned long'])) ProgramObject(self.program, 0xffff0000, TYPES['unsigned long']))
def test_relational(self):
one = self.program.object(None, TYPES['int'], 1)
two = self.program.object(None, TYPES['int'], 2)
three = self.program.object(None, TYPES['int'], 3)
ptr0 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0000)
ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0004)
self.assertTrue(one < two)
self.assertFalse(two < two)
self.assertFalse(three < two)
self.assertTrue(ptr0 < ptr1)
self.assertTrue(one <= two)
self.assertTrue(two <= two)
self.assertFalse(three <= two)
self.assertTrue(ptr0 <= ptr1)
self.assertTrue(one == one)
self.assertFalse(one == two)
self.assertFalse(ptr0 == ptr1)
self.assertFalse(one != one)
self.assertTrue(one != two)
self.assertTrue(ptr0 != ptr1)
self.assertFalse(one > two)
self.assertFalse(two > two)
self.assertTrue(three > two)
self.assertFalse(ptr0 > ptr1)
self.assertFalse(one >= two)
self.assertTrue(two >= two)
self.assertTrue(three >= two)
self.assertFalse(ptr0 >= ptr1)
negative_one = self.program.object(None, TYPES['int'], -1)
unsigned_zero = self.program.object(None, TYPES['unsigned int'], 0)
# The usual arithmetic conversions convert -1 to an unsigned int.
self.assertFalse(negative_one < unsigned_zero)
self.assertTrue(self.program.object(None, TYPES['int'], 1) ==
self.program.object(None, TYPES['_Bool'], 1))
self.assertRaises(TypeError, operator.lt, ptr0, one)
def _test_arithmetic(self, op, lhs, rhs, result, integral=True,
floating_point=False):
def INT(value):
return self.program.object(None, TYPES['int'], value)
def LONG(value):
return self.program.object(None, TYPES['long'], value)
def DOUBLE(value):
return self.program.object(None, TYPES['double'], value)
if integral:
self.assertEqual(op(INT(lhs), INT(rhs)), INT(result))
self.assertEqual(op(INT(lhs), LONG(rhs)), LONG(result))
self.assertEqual(op(LONG(lhs), INT(rhs)), LONG(result))
self.assertEqual(op(LONG(lhs), LONG(rhs)), LONG(result))
self.assertEqual(op(INT(lhs), rhs), INT(result))
self.assertEqual(op(LONG(lhs), rhs), LONG(result))
self.assertEqual(op(lhs, INT(rhs)), INT(result))
self.assertEqual(op(lhs, LONG(rhs)), LONG(result))
if floating_point:
self.assertEqual(op(DOUBLE(lhs), DOUBLE(rhs)), DOUBLE(result))
self.assertEqual(op(DOUBLE(lhs), INT(rhs)), DOUBLE(result))
self.assertEqual(op(INT(lhs), DOUBLE(rhs)), DOUBLE(result))
self.assertEqual(op(DOUBLE(lhs), float(rhs)), DOUBLE(result))
self.assertEqual(op(float(lhs), DOUBLE(rhs)), DOUBLE(result))
self.assertEqual(op(float(lhs), INT(rhs)), DOUBLE(result))
self.assertEqual(op(INT(lhs), float(rhs)), DOUBLE(result))
def _test_pointer_type_errors(self, op):
def INT(value):
return self.program.object(None, TYPES['int'], value)
def POINTER(value):
return self.program.object(None,
self.type_index.pointer(TYPES['int']),
value)
self.assertRaises(TypeError, op, INT(1), POINTER(1))
self.assertRaises(TypeError, op, POINTER(1), INT(1))
self.assertRaises(TypeError, op, POINTER(1), POINTER(1))
def _test_floating_type_errors(self, op):
def INT(value):
return self.program.object(None, TYPES['int'], value)
def DOUBLE(value):
return self.program.object(None, TYPES['double'], value)
self.assertRaises(TypeError, op, INT(1), DOUBLE(1))
self.assertRaises(TypeError, op, DOUBLE(1), INT(1))
self.assertRaises(TypeError, op, DOUBLE(1), DOUBLE(1))
def _test_shift(self, op, lhs, rhs, result):
def BOOL(value):
return self.program.object(None, TYPES['_Bool'], value)
def INT(value):
return self.program.object(None, TYPES['int'], value)
def LONG(value):
return self.program.object(None, TYPES['long'], value)
self.assertEqual(op(INT(lhs), INT(rhs)), INT(result))
self.assertEqual(op(INT(lhs), LONG(rhs)), INT(result))
self.assertEqual(op(LONG(lhs), INT(rhs)), LONG(result))
self.assertEqual(op(LONG(lhs), LONG(rhs)), LONG(result))
self.assertEqual(op(INT(lhs), rhs), INT(result))
self.assertEqual(op(LONG(lhs), rhs), LONG(result))
self.assertEqual(op(lhs, INT(rhs)), INT(result))
self.assertEqual(op(lhs, LONG(rhs)), INT(result))
self._test_pointer_type_errors(op)
self._test_floating_type_errors(op)
def test_add(self):
self._test_arithmetic(operator.add, 2, 2, 4, floating_point=True)
one = self.program.object(None, TYPES['int'], 1)
ptr = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0000)
ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0004)
self.assertEqual(ptr + one, ptr1)
self.assertEqual(one + ptr, ptr1)
self.assertEqual(ptr + 1, ptr1)
self.assertEqual(1 + ptr, ptr1)
self.assertRaises(TypeError, operator.add, ptr, ptr)
self.assertRaises(TypeError, operator.add, ptr, 2.0)
self.assertRaises(TypeError, operator.add, 2.0, ptr)
def test_sub(self):
self._test_arithmetic(operator.sub, 4, 2, 2, floating_point=True)
ptr = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0000)
ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0004)
self.assertEqual(ptr1 - ptr,
ProgramObject(self.program, None, TYPES['ptrdiff_t'], 1))
self.assertEqual(ptr - ptr1,
ProgramObject(self.program, None, TYPES['ptrdiff_t'], -1))
self.assertEqual(ptr - 0, ptr)
self.assertEqual(ptr1 - 1, ptr)
self.assertRaises(TypeError, operator.sub, 1, ptr)
self.assertRaises(TypeError, operator.sub, ptr, 1.0)
def test_mul(self):
self._test_arithmetic(operator.mul, 2, 3, 6, floating_point=True)
self._test_pointer_type_errors(operator.mul)
def test_div(self):
self._test_arithmetic(operator.truediv, 6, 3, 2, floating_point=True)
# Make sure we do integer division for integer operands.
self._test_arithmetic(operator.truediv, 3, 2, 1)
# Make sure we truncate towards zero (Python truncates towards negative
# infinity).
self._test_arithmetic(operator.truediv, -1, 2, 0)
self._test_arithmetic(operator.truediv, 1, -2, 0)
self._test_pointer_type_errors(operator.mul)
def test_mod(self):
self._test_arithmetic(operator.mod, 4, 2, 0)
# Make sure the modulo result has the sign of the dividend (Python uses
# the sign of the divisor).
self._test_arithmetic(operator.mod, 1, 26, 1)
self._test_arithmetic(operator.mod, 1, -26, 1)
self._test_arithmetic(operator.mod, -1, 26, -1)
self._test_arithmetic(operator.mod, -1, -26, -1)
self._test_pointer_type_errors(operator.mod)
self._test_floating_type_errors(operator.mod)
def test_lshift(self):
self._test_shift(operator.lshift, 2, 3, 16)
def test_rshift(self):
self._test_shift(operator.rshift, 16, 3, 2)
def test_and(self):
self._test_arithmetic(operator.and_, 1, 3, 1)
self._test_pointer_type_errors(operator.and_)
self._test_floating_type_errors(operator.and_)
def test_xor(self):
self._test_arithmetic(operator.xor, 1, 3, 2)
self._test_pointer_type_errors(operator.xor)
self._test_floating_type_errors(operator.xor)
def test_or(self):
self._test_arithmetic(operator.or_, 1, 3, 3)
self._test_pointer_type_errors(operator.or_)
self._test_floating_type_errors(operator.or_)

View File

@ -21,7 +21,7 @@ from drgn.type import (
VoidType, VoidType,
) )
from drgn.typeindex import DwarfTypeIndex, TypeIndex from drgn.typeindex import DwarfTypeIndex, TypeIndex
from drgn.typename import BasicTypeName, TypeName from drgn.typename import BasicTypeName, TypeName, TypedefTypeName
from tests.test_type import ( from tests.test_type import (
anonymous_point_type, anonymous_point_type,
const_anonymous_point_type, const_anonymous_point_type,
@ -48,7 +48,9 @@ TYPES = {
'float': FloatType('float', 4), 'float': FloatType('float', 4),
'double': FloatType('double', 8), 'double': FloatType('double', 8),
'long double': FloatType('long double', 16), 'long double': FloatType('long double', 16),
'ptrdiff_t': FloatType('long double', 16),
} }
TYPES['ptrdiff_t'] = TypedefType('ptrdiff_t', TYPES['long'])
class MockTypeIndex(TypeIndex): class MockTypeIndex(TypeIndex):
@ -56,7 +58,7 @@ class MockTypeIndex(TypeIndex):
super().__init__(8) super().__init__(8)
def _find_type(self, type_name: TypeName) -> Type: def _find_type(self, type_name: TypeName) -> Type:
if isinstance(type_name, BasicTypeName): if isinstance(type_name, (BasicTypeName, TypedefTypeName)):
try: try:
return TYPES[type_name.name] return TYPES[type_name.name]
except KeyError: except KeyError: