program: implement ProgramObject binary operators

This commit is contained in:
Omar Sandoval 2018-05-07 18:48:10 -07:00
parent 9560b913f3
commit 15ea6e8d97
4 changed files with 456 additions and 43 deletions

View File

@ -18,9 +18,17 @@ from drgn.typeindex import TypeIndex
import functools
import itertools
import math
import operator
from typing import Any, Callable, Iterable, Optional, Tuple, Union
def _c_modulo(a: int, b: int) -> int:
if a >= 0:
return a % abs(b)
else:
return -(-a % abs(b))
class ProgramObject:
"""
A ProgramObject either represents an object in the memory of a program (an
@ -41,20 +49,26 @@ class ProgramObject:
>>> print(prog['jiffies'])
(volatile long unsigned int)4326237045
ProgramObjects try to behave transparently like the object they represent
in C. E.g., structure members can be accessed with the dot (".") operator
and arrays can be subscripted with "[]".
ProgramObjects support C operators wherever possible. E.g., structure
members can be accessed with the dot (".") operator, arrays can be
subscripted with "[]", arithmetic can be performed, and objects can be
compared.
>>> print(prog['init_task'].pid)
(pid_t)0
>>> print(prog['init_task'].comm[0])
(char)115
>>> prog['init_task'].nsproxy.mnt_ns.mounts + 1
ProgramObject(address=None, type=<unsigned int>, value=34)
>>> prog['init_task'].nsproxy.mnt_ns.pending_mounts > 0
False
Note that because the structure dereference operator ("->") is not valid
syntax in Python, "." is also used to access members of pointers to
structures. Similarly, the indirection operator ("*") is not valid syntax
in Python, so pointers can be dereferenced with "[0]" (e.g., write "p[0]"
instead of "*p").
instead of "*p"). The address-of operator ("&") is available as the
address_of_() method.
ProgramObject members and methods are named with a trailing underscore to
avoid conflicting with structure or union members. The helper methods
@ -263,70 +277,268 @@ class ProgramObject:
self._real_type.qualifiers),
address)
def _check_arithmetic_type(self) -> None:
if not isinstance(self._real_type, (ArithmeticType, BitFieldType)):
raise TypeError('not an arithmetic type')
def _unary_operator(self, op: Callable, op_name: str,
integer: bool = False) -> 'ProgramObject':
if ((integer and not self._real_type.is_integer()) or
(not integer and not self._real_type.is_arithmetic())):
raise TypeError(f"invalid operand to unary {op_name} ('{self.type_}')")
type_ = self.program_._type_index.operand_type(self.type_)
type_ = self.program_._type_index.integer_promotions(type_)
return ProgramObject(self.program_, None, type_, op(self.value_()))
def _check_integer_type(self) -> None:
if not isinstance(self._real_type, (IntType, BitFieldType)):
raise TypeError('not an integer type')
def _binary_operands(self, lhs: Any, rhs: Any) -> Tuple[Any, Type, Any, Type]:
if (isinstance(lhs, ProgramObject) and isinstance(rhs, ProgramObject) and
lhs.program_ is not rhs.program_):
raise ValueError('operands are from different programs')
if isinstance(lhs, ProgramObject):
lhs_type = lhs.type_
if isinstance(lhs._real_type, ArrayType):
lhs = lhs.address_
else:
lhs = lhs.value_()
else:
lhs_type = self.program_._type_index.literal_type(lhs)
if isinstance(rhs, ProgramObject):
rhs_type = rhs.type_
if isinstance(rhs._real_type, ArrayType):
rhs = rhs.address_
else:
rhs = rhs.value_()
else:
rhs_type = self.program_._type_index.literal_type(rhs)
return lhs, lhs_type, rhs, rhs_type
def _unary_type(self, integer: bool = False) -> Type:
return self.program_._type_index.integer_promotions(self.type_.unqualified())
def _usual_arithmetic_conversions(self, lhs: Any, lhs_type: Type,
rhs: Any, rhs_type: Type) -> Tuple[Type, Any, Any]:
type_ = self.program_._type_index.common_real_type(lhs_type, rhs_type)
return type_, type_.convert(lhs), type_.convert(rhs)
def _arithmetic_operator(self, op: Callable, op_name: str,
lhs: Any, rhs: Any) -> 'ProgramObject':
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if not lhs_type.is_arithmetic() and not rhs_type.is_arithmetic():
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, op(lhs, rhs))
def _integer_operator(self, op: Callable, op_name: str,
lhs: Any, rhs: Any) -> 'ProgramObject':
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if not lhs_type.is_integer() or not rhs_type.is_integer():
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, op(lhs, rhs))
def _shift_operator(self, op: Callable, op_name: str,
lhs: Any, rhs: Any) -> 'ProgramObject':
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if not lhs_type.is_integer() or not rhs_type.is_integer():
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
lhs_type = self.program_._type_index.integer_promotions(lhs_type)
rhs_type = self.program_._type_index.integer_promotions(rhs_type)
return ProgramObject(self.program_, None, lhs_type, op(lhs, rhs))
def _relational_operator(self, op: Callable, op_name: str,
other: Any) -> bool:
lhs_pointer = isinstance(self._real_type, (ArrayType, PointerType))
rhs_pointer = (isinstance(other, ProgramObject) and
isinstance(other._real_type, (ArrayType, PointerType)))
lhs, lhs_type, rhs, rhs_type = self._binary_operands(self, other)
if ((lhs_pointer != rhs_pointer) or
(not lhs_pointer and
(not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))):
raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
if not lhs_pointer:
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return op(lhs, rhs)
def _add(self, lhs: Any, rhs: Any) -> 'ProgramObject':
lhs_pointer = (isinstance(lhs, ProgramObject) and
isinstance(lhs._real_type, (ArrayType, PointerType)))
rhs_pointer = (isinstance(rhs, ProgramObject) and
isinstance(rhs._real_type, (ArrayType, PointerType)))
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if ((lhs_pointer and rhs_pointer) or
(lhs_pointer and not rhs_type.is_integer()) or
(rhs_pointer and not lhs_type.is_integer()) or
(not lhs_pointer and not rhs_pointer and
(not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))):
raise TypeError(f"invalid operands to binary + ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
if lhs_pointer:
assert isinstance(lhs_type, PointerType)
return ProgramObject(self.program_, None, lhs_type,
lhs + lhs_type.type.sizeof() * rhs)
elif rhs_pointer:
assert isinstance(rhs_type, PointerType)
return ProgramObject(self.program_, None, rhs_type,
rhs + rhs_type.type.sizeof() * lhs)
else:
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, lhs + rhs)
def _sub(self, lhs: Any, rhs: Any) -> 'ProgramObject':
lhs_pointer = (isinstance(lhs, ProgramObject) and
isinstance(lhs._real_type, (ArrayType, PointerType)))
if lhs_pointer:
lhs_sizeof = lhs._real_type.type.sizeof()
rhs_pointer = (isinstance(rhs, ProgramObject) and
isinstance(rhs._real_type, (ArrayType, PointerType)))
if rhs_pointer:
rhs_sizeof = rhs._real_type.type.sizeof()
lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs)
if ((lhs_pointer and rhs_pointer and lhs_sizeof != rhs_sizeof) or
(lhs_pointer and not rhs_pointer and not rhs_type.is_integer()) or
(rhs_pointer and not lhs_pointer) or
(not lhs_pointer and not rhs_pointer and
(not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))):
raise TypeError(f"invalid operands to binary - ('{lhs_type}' and '{rhs_type}')")
lhs_type = self.program_._type_index.operand_type(lhs_type)
rhs_type = self.program_._type_index.operand_type(rhs_type)
if lhs_pointer and rhs_pointer:
return ProgramObject(self.program_, None,
self.program_._type_index.ptrdiff_t(),
(lhs - rhs) // lhs_sizeof)
elif lhs_pointer:
return ProgramObject(self.program_, None, lhs_type,
lhs - lhs_sizeof * rhs)
else:
type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type,
rhs, rhs_type)
return ProgramObject(self.program_, None, type_, lhs - rhs)
def __add__(self, other: Any) -> 'ProgramObject':
return self._add(self, other)
def __sub__(self, other: Any) -> 'ProgramObject':
return self._sub(self, other)
def __mul__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.mul, '*', self, other)
def __truediv__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.truediv, '/', self, other)
def __mod__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(_c_modulo, '%', self, other)
def __lshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.lshift, '<<', self, other)
def __rshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.rshift, '>>', self, other)
def __and__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.and_, '&', self, other)
def __xor__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.xor, '^', self, other)
def __or__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.or_, '|', self, other)
def __radd__(self, other: Any) -> 'ProgramObject':
return self._add(other, self)
def __rsub__(self, other: Any) -> 'ProgramObject':
return self._sub(other, self)
def __rmul__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.mul, '*', other, self)
def __rtruediv__(self, other: Any) -> 'ProgramObject':
return self._arithmetic_operator(operator.truediv, '/', other, self)
def __rmod__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(_c_modulo, '%', other, self)
def __rlshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.lshift, '<<', other, self)
def __rrshift__(self, other: Any) -> 'ProgramObject':
return self._shift_operator(operator.rshift, '>>', other, self)
def __rand__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.and_, '&', other, self)
def __rxor__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.xor, '^', other, self)
def __ror__(self, other: Any) -> 'ProgramObject':
return self._integer_operator(operator.or_, '|', other, self)
def __lt__(self, other: Any) -> bool:
return self._relational_operator(operator.lt, '<', other)
def __le__(self, other: Any) -> bool:
return self._relational_operator(operator.le, '<=', other)
def __eq__(self, other: Any) -> bool:
return self._relational_operator(operator.eq, '==', other)
def __ne__(self, other: Any) -> bool:
return self._relational_operator(operator.ne, '!=', other)
def __gt__(self, other: Any) -> bool:
return self._relational_operator(operator.gt, '>', other)
def __ge__(self, other: Any) -> bool:
return self._relational_operator(operator.ge, '>=', other)
def __bool__(self) -> bool:
if not isinstance(self._real_type, (ArithmeticType, BitFieldType,
PointerType)):
raise TypeError('not an arithmetic or pointer type')
raise TypeError(f"invalid operand to bool() ('{self.type_}')")
return bool(self.value_())
def __neg__(self) -> 'ProgramObject':
self._check_arithmetic_type()
return ProgramObject(self.program_, None, self._unary_type(),
-self.value_())
return self._unary_operator(operator.neg, '-')
def __pos__(self) -> 'ProgramObject':
self._check_arithmetic_type()
return ProgramObject(self.program_, None, self._unary_type(),
+self.value_())
def __abs__(self) -> 'ProgramObject':
self._check_arithmetic_type()
return ProgramObject(self.program_, None, self._unary_type(),
abs(self.value_()))
return self._unary_operator(operator.pos, '+')
def __invert__(self) -> 'ProgramObject':
self._check_integer_type()
return ProgramObject(self.program_, None, self._unary_type(),
~self.value_())
return self._unary_operator(operator.invert, '~', True)
def __int__(self) -> int:
self._check_arithmetic_type()
if not isinstance(self._real_type, (ArithmeticType, BitFieldType)):
raise TypeError(f"can't convert {self.type_} to int")
return int(self.value_())
def __float__(self) -> float:
self._check_arithmetic_type()
if not isinstance(self._real_type, (ArithmeticType, BitFieldType)):
raise TypeError(f"can't convert {self.type_} to float")
return float(self.value_())
def __index__(self) -> int:
self._check_integer_type()
if not isinstance(self._real_type, (IntType, BitFieldType)):
raise TypeError(f"can't convert {self.type_} to index")
return self.value_()
def __round__(self, ndigits: Optional[int] = None) -> Union[int, float]:
self._check_arithmetic_type()
return round(self.value_(), ndigits)
return round(self.__float__(), ndigits)
def __trunc__(self) -> int:
self._check_arithmetic_type()
return math.trunc(self.value_())
return math.trunc(self.__float__())
def __floor__(self) -> int:
self._check_arithmetic_type()
return math.floor(self.value_())
return math.floor(self.__float__())
def __ceil__(self) -> int:
self._check_arithmetic_type()
return math.ceil(self.value_())
return math.ceil(self.__float__())
class Program:

View File

@ -272,7 +272,7 @@ class TypeIndex:
if (not isinstance(real_type1, (ArithmeticType, BitFieldType)) or
not isinstance(real_type2, (ArithmeticType, BitFieldType))):
raise TypeError('operands must be arithmetic types or bit fields')
raise TypeError('operands must have arithmetic types')
# If either operand is long double, then the result is long double.
if isinstance(real_type1, FloatType) and real_type1.name == 'long double':

View File

@ -1,4 +1,5 @@
import math
import operator
import unittest
from drgn.program import Program, ProgramObject
@ -70,7 +71,6 @@ class TestProgramObject(TypeIndexTestCase):
# _Bool should be the same because of integer promotions.
self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['int'], -1))
self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['int'], 1))
self.assertEqual(abs(obj), ProgramObject(self.program, None, TYPES['int'], 1))
self.assertEqual(~obj, ProgramObject(self.program, None, TYPES['int'], -2))
self.assertEqual(int(obj), 1)
self.assertEqual(float(obj), 1.0)
@ -88,7 +88,6 @@ class TestProgramObject(TypeIndexTestCase):
self.assertTrue(bool(obj))
self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['double'], -1.5))
self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['double'], 1.5))
self.assertEqual(abs(obj), ProgramObject(self.program, None, TYPES['double'], 1.5))
with self.assertRaises(TypeError):
~obj
self.assertEqual(int(obj), 1)
@ -202,3 +201,203 @@ class TestProgramObject(TypeIndexTestCase):
self.assertEqual(struct_obj.address_, 0xffff0000)
self.assertEqual(struct_obj.member_('address_'),
ProgramObject(self.program, 0xffff0000, TYPES['unsigned long']))
def test_relational(self):
one = self.program.object(None, TYPES['int'], 1)
two = self.program.object(None, TYPES['int'], 2)
three = self.program.object(None, TYPES['int'], 3)
ptr0 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0000)
ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0004)
self.assertTrue(one < two)
self.assertFalse(two < two)
self.assertFalse(three < two)
self.assertTrue(ptr0 < ptr1)
self.assertTrue(one <= two)
self.assertTrue(two <= two)
self.assertFalse(three <= two)
self.assertTrue(ptr0 <= ptr1)
self.assertTrue(one == one)
self.assertFalse(one == two)
self.assertFalse(ptr0 == ptr1)
self.assertFalse(one != one)
self.assertTrue(one != two)
self.assertTrue(ptr0 != ptr1)
self.assertFalse(one > two)
self.assertFalse(two > two)
self.assertTrue(three > two)
self.assertFalse(ptr0 > ptr1)
self.assertFalse(one >= two)
self.assertTrue(two >= two)
self.assertTrue(three >= two)
self.assertFalse(ptr0 >= ptr1)
negative_one = self.program.object(None, TYPES['int'], -1)
unsigned_zero = self.program.object(None, TYPES['unsigned int'], 0)
# The usual arithmetic conversions convert -1 to an unsigned int.
self.assertFalse(negative_one < unsigned_zero)
self.assertTrue(self.program.object(None, TYPES['int'], 1) ==
self.program.object(None, TYPES['_Bool'], 1))
self.assertRaises(TypeError, operator.lt, ptr0, one)
def _test_arithmetic(self, op, lhs, rhs, result, integral=True,
floating_point=False):
def INT(value):
return self.program.object(None, TYPES['int'], value)
def LONG(value):
return self.program.object(None, TYPES['long'], value)
def DOUBLE(value):
return self.program.object(None, TYPES['double'], value)
if integral:
self.assertEqual(op(INT(lhs), INT(rhs)), INT(result))
self.assertEqual(op(INT(lhs), LONG(rhs)), LONG(result))
self.assertEqual(op(LONG(lhs), INT(rhs)), LONG(result))
self.assertEqual(op(LONG(lhs), LONG(rhs)), LONG(result))
self.assertEqual(op(INT(lhs), rhs), INT(result))
self.assertEqual(op(LONG(lhs), rhs), LONG(result))
self.assertEqual(op(lhs, INT(rhs)), INT(result))
self.assertEqual(op(lhs, LONG(rhs)), LONG(result))
if floating_point:
self.assertEqual(op(DOUBLE(lhs), DOUBLE(rhs)), DOUBLE(result))
self.assertEqual(op(DOUBLE(lhs), INT(rhs)), DOUBLE(result))
self.assertEqual(op(INT(lhs), DOUBLE(rhs)), DOUBLE(result))
self.assertEqual(op(DOUBLE(lhs), float(rhs)), DOUBLE(result))
self.assertEqual(op(float(lhs), DOUBLE(rhs)), DOUBLE(result))
self.assertEqual(op(float(lhs), INT(rhs)), DOUBLE(result))
self.assertEqual(op(INT(lhs), float(rhs)), DOUBLE(result))
def _test_pointer_type_errors(self, op):
def INT(value):
return self.program.object(None, TYPES['int'], value)
def POINTER(value):
return self.program.object(None,
self.type_index.pointer(TYPES['int']),
value)
self.assertRaises(TypeError, op, INT(1), POINTER(1))
self.assertRaises(TypeError, op, POINTER(1), INT(1))
self.assertRaises(TypeError, op, POINTER(1), POINTER(1))
def _test_floating_type_errors(self, op):
def INT(value):
return self.program.object(None, TYPES['int'], value)
def DOUBLE(value):
return self.program.object(None, TYPES['double'], value)
self.assertRaises(TypeError, op, INT(1), DOUBLE(1))
self.assertRaises(TypeError, op, DOUBLE(1), INT(1))
self.assertRaises(TypeError, op, DOUBLE(1), DOUBLE(1))
def _test_shift(self, op, lhs, rhs, result):
def BOOL(value):
return self.program.object(None, TYPES['_Bool'], value)
def INT(value):
return self.program.object(None, TYPES['int'], value)
def LONG(value):
return self.program.object(None, TYPES['long'], value)
self.assertEqual(op(INT(lhs), INT(rhs)), INT(result))
self.assertEqual(op(INT(lhs), LONG(rhs)), INT(result))
self.assertEqual(op(LONG(lhs), INT(rhs)), LONG(result))
self.assertEqual(op(LONG(lhs), LONG(rhs)), LONG(result))
self.assertEqual(op(INT(lhs), rhs), INT(result))
self.assertEqual(op(LONG(lhs), rhs), LONG(result))
self.assertEqual(op(lhs, INT(rhs)), INT(result))
self.assertEqual(op(lhs, LONG(rhs)), INT(result))
self._test_pointer_type_errors(op)
self._test_floating_type_errors(op)
def test_add(self):
self._test_arithmetic(operator.add, 2, 2, 4, floating_point=True)
one = self.program.object(None, TYPES['int'], 1)
ptr = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0000)
ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0004)
self.assertEqual(ptr + one, ptr1)
self.assertEqual(one + ptr, ptr1)
self.assertEqual(ptr + 1, ptr1)
self.assertEqual(1 + ptr, ptr1)
self.assertRaises(TypeError, operator.add, ptr, ptr)
self.assertRaises(TypeError, operator.add, ptr, 2.0)
self.assertRaises(TypeError, operator.add, 2.0, ptr)
def test_sub(self):
self._test_arithmetic(operator.sub, 4, 2, 2, floating_point=True)
ptr = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0000)
ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']),
0xffff0004)
self.assertEqual(ptr1 - ptr,
ProgramObject(self.program, None, TYPES['ptrdiff_t'], 1))
self.assertEqual(ptr - ptr1,
ProgramObject(self.program, None, TYPES['ptrdiff_t'], -1))
self.assertEqual(ptr - 0, ptr)
self.assertEqual(ptr1 - 1, ptr)
self.assertRaises(TypeError, operator.sub, 1, ptr)
self.assertRaises(TypeError, operator.sub, ptr, 1.0)
def test_mul(self):
self._test_arithmetic(operator.mul, 2, 3, 6, floating_point=True)
self._test_pointer_type_errors(operator.mul)
def test_div(self):
self._test_arithmetic(operator.truediv, 6, 3, 2, floating_point=True)
# Make sure we do integer division for integer operands.
self._test_arithmetic(operator.truediv, 3, 2, 1)
# Make sure we truncate towards zero (Python truncates towards negative
# infinity).
self._test_arithmetic(operator.truediv, -1, 2, 0)
self._test_arithmetic(operator.truediv, 1, -2, 0)
self._test_pointer_type_errors(operator.mul)
def test_mod(self):
self._test_arithmetic(operator.mod, 4, 2, 0)
# Make sure the modulo result has the sign of the dividend (Python uses
# the sign of the divisor).
self._test_arithmetic(operator.mod, 1, 26, 1)
self._test_arithmetic(operator.mod, 1, -26, 1)
self._test_arithmetic(operator.mod, -1, 26, -1)
self._test_arithmetic(operator.mod, -1, -26, -1)
self._test_pointer_type_errors(operator.mod)
self._test_floating_type_errors(operator.mod)
def test_lshift(self):
self._test_shift(operator.lshift, 2, 3, 16)
def test_rshift(self):
self._test_shift(operator.rshift, 16, 3, 2)
def test_and(self):
self._test_arithmetic(operator.and_, 1, 3, 1)
self._test_pointer_type_errors(operator.and_)
self._test_floating_type_errors(operator.and_)
def test_xor(self):
self._test_arithmetic(operator.xor, 1, 3, 2)
self._test_pointer_type_errors(operator.xor)
self._test_floating_type_errors(operator.xor)
def test_or(self):
self._test_arithmetic(operator.or_, 1, 3, 3)
self._test_pointer_type_errors(operator.or_)
self._test_floating_type_errors(operator.or_)

View File

@ -21,7 +21,7 @@ from drgn.type import (
VoidType,
)
from drgn.typeindex import DwarfTypeIndex, TypeIndex
from drgn.typename import BasicTypeName, TypeName
from drgn.typename import BasicTypeName, TypeName, TypedefTypeName
from tests.test_type import (
anonymous_point_type,
const_anonymous_point_type,
@ -48,7 +48,9 @@ TYPES = {
'float': FloatType('float', 4),
'double': FloatType('double', 8),
'long double': FloatType('long double', 16),
'ptrdiff_t': FloatType('long double', 16),
}
TYPES['ptrdiff_t'] = TypedefType('ptrdiff_t', TYPES['long'])
class MockTypeIndex(TypeIndex):
@ -56,7 +58,7 @@ class MockTypeIndex(TypeIndex):
super().__init__(8)
def _find_type(self, type_name: TypeName) -> Type:
if isinstance(type_name, BasicTypeName):
if isinstance(type_name, (BasicTypeName, TypedefTypeName)):
try:
return TYPES[type_name.name]
except KeyError: