diff --git a/drgn/program.py b/drgn/program.py index e08be749..ea38b0df 100644 --- a/drgn/program.py +++ b/drgn/program.py @@ -18,9 +18,17 @@ from drgn.typeindex import TypeIndex import functools import itertools import math +import operator from typing import Any, Callable, Iterable, Optional, Tuple, Union +def _c_modulo(a: int, b: int) -> int: + if a >= 0: + return a % abs(b) + else: + return -(-a % abs(b)) + + class ProgramObject: """ A ProgramObject either represents an object in the memory of a program (an @@ -41,20 +49,26 @@ class ProgramObject: >>> print(prog['jiffies']) (volatile long unsigned int)4326237045 - ProgramObjects try to behave transparently like the object they represent - in C. E.g., structure members can be accessed with the dot (".") operator - and arrays can be subscripted with "[]". + ProgramObjects support C operators wherever possible. E.g., structure + members can be accessed with the dot (".") operator, arrays can be + subscripted with "[]", arithmetic can be performed, and objects can be + compared. >>> print(prog['init_task'].pid) (pid_t)0 >>> print(prog['init_task'].comm[0]) (char)115 + >>> prog['init_task'].nsproxy.mnt_ns.mounts + 1 + ProgramObject(address=None, type=, value=34) + >>> prog['init_task'].nsproxy.mnt_ns.pending_mounts > 0 + False Note that because the structure dereference operator ("->") is not valid syntax in Python, "." is also used to access members of pointers to structures. Similarly, the indirection operator ("*") is not valid syntax in Python, so pointers can be dereferenced with "[0]" (e.g., write "p[0]" - instead of "*p"). + instead of "*p"). The address-of operator ("&") is available as the + address_of_() method. ProgramObject members and methods are named with a trailing underscore to avoid conflicting with structure or union members. The helper methods @@ -263,70 +277,268 @@ class ProgramObject: self._real_type.qualifiers), address) - def _check_arithmetic_type(self) -> None: - if not isinstance(self._real_type, (ArithmeticType, BitFieldType)): - raise TypeError('not an arithmetic type') + def _unary_operator(self, op: Callable, op_name: str, + integer: bool = False) -> 'ProgramObject': + if ((integer and not self._real_type.is_integer()) or + (not integer and not self._real_type.is_arithmetic())): + raise TypeError(f"invalid operand to unary {op_name} ('{self.type_}')") + type_ = self.program_._type_index.operand_type(self.type_) + type_ = self.program_._type_index.integer_promotions(type_) + return ProgramObject(self.program_, None, type_, op(self.value_())) - def _check_integer_type(self) -> None: - if not isinstance(self._real_type, (IntType, BitFieldType)): - raise TypeError('not an integer type') + def _binary_operands(self, lhs: Any, rhs: Any) -> Tuple[Any, Type, Any, Type]: + if (isinstance(lhs, ProgramObject) and isinstance(rhs, ProgramObject) and + lhs.program_ is not rhs.program_): + raise ValueError('operands are from different programs') + if isinstance(lhs, ProgramObject): + lhs_type = lhs.type_ + if isinstance(lhs._real_type, ArrayType): + lhs = lhs.address_ + else: + lhs = lhs.value_() + else: + lhs_type = self.program_._type_index.literal_type(lhs) + if isinstance(rhs, ProgramObject): + rhs_type = rhs.type_ + if isinstance(rhs._real_type, ArrayType): + rhs = rhs.address_ + else: + rhs = rhs.value_() + else: + rhs_type = self.program_._type_index.literal_type(rhs) + return lhs, lhs_type, rhs, rhs_type - def _unary_type(self, integer: bool = False) -> Type: - return self.program_._type_index.integer_promotions(self.type_.unqualified()) + def _usual_arithmetic_conversions(self, lhs: Any, lhs_type: Type, + rhs: Any, rhs_type: Type) -> Tuple[Type, Any, Any]: + type_ = self.program_._type_index.common_real_type(lhs_type, rhs_type) + return type_, type_.convert(lhs), type_.convert(rhs) + + def _arithmetic_operator(self, op: Callable, op_name: str, + lhs: Any, rhs: Any) -> 'ProgramObject': + lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs) + if not lhs_type.is_arithmetic() and not rhs_type.is_arithmetic(): + raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')") + lhs_type = self.program_._type_index.operand_type(lhs_type) + rhs_type = self.program_._type_index.operand_type(rhs_type) + type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type, + rhs, rhs_type) + return ProgramObject(self.program_, None, type_, op(lhs, rhs)) + + def _integer_operator(self, op: Callable, op_name: str, + lhs: Any, rhs: Any) -> 'ProgramObject': + lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs) + if not lhs_type.is_integer() or not rhs_type.is_integer(): + raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')") + lhs_type = self.program_._type_index.operand_type(lhs_type) + rhs_type = self.program_._type_index.operand_type(rhs_type) + type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type, + rhs, rhs_type) + return ProgramObject(self.program_, None, type_, op(lhs, rhs)) + + def _shift_operator(self, op: Callable, op_name: str, + lhs: Any, rhs: Any) -> 'ProgramObject': + lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs) + if not lhs_type.is_integer() or not rhs_type.is_integer(): + raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')") + lhs_type = self.program_._type_index.operand_type(lhs_type) + rhs_type = self.program_._type_index.operand_type(rhs_type) + lhs_type = self.program_._type_index.integer_promotions(lhs_type) + rhs_type = self.program_._type_index.integer_promotions(rhs_type) + return ProgramObject(self.program_, None, lhs_type, op(lhs, rhs)) + + def _relational_operator(self, op: Callable, op_name: str, + other: Any) -> bool: + lhs_pointer = isinstance(self._real_type, (ArrayType, PointerType)) + rhs_pointer = (isinstance(other, ProgramObject) and + isinstance(other._real_type, (ArrayType, PointerType))) + lhs, lhs_type, rhs, rhs_type = self._binary_operands(self, other) + if ((lhs_pointer != rhs_pointer) or + (not lhs_pointer and + (not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))): + raise TypeError(f"invalid operands to binary {op_name} ('{lhs_type}' and '{rhs_type}')") + lhs_type = self.program_._type_index.operand_type(lhs_type) + rhs_type = self.program_._type_index.operand_type(rhs_type) + if not lhs_pointer: + type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type, + rhs, rhs_type) + return op(lhs, rhs) + + def _add(self, lhs: Any, rhs: Any) -> 'ProgramObject': + lhs_pointer = (isinstance(lhs, ProgramObject) and + isinstance(lhs._real_type, (ArrayType, PointerType))) + rhs_pointer = (isinstance(rhs, ProgramObject) and + isinstance(rhs._real_type, (ArrayType, PointerType))) + lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs) + if ((lhs_pointer and rhs_pointer) or + (lhs_pointer and not rhs_type.is_integer()) or + (rhs_pointer and not lhs_type.is_integer()) or + (not lhs_pointer and not rhs_pointer and + (not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))): + raise TypeError(f"invalid operands to binary + ('{lhs_type}' and '{rhs_type}')") + lhs_type = self.program_._type_index.operand_type(lhs_type) + rhs_type = self.program_._type_index.operand_type(rhs_type) + if lhs_pointer: + assert isinstance(lhs_type, PointerType) + return ProgramObject(self.program_, None, lhs_type, + lhs + lhs_type.type.sizeof() * rhs) + elif rhs_pointer: + assert isinstance(rhs_type, PointerType) + return ProgramObject(self.program_, None, rhs_type, + rhs + rhs_type.type.sizeof() * lhs) + else: + type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type, + rhs, rhs_type) + return ProgramObject(self.program_, None, type_, lhs + rhs) + + def _sub(self, lhs: Any, rhs: Any) -> 'ProgramObject': + lhs_pointer = (isinstance(lhs, ProgramObject) and + isinstance(lhs._real_type, (ArrayType, PointerType))) + if lhs_pointer: + lhs_sizeof = lhs._real_type.type.sizeof() + rhs_pointer = (isinstance(rhs, ProgramObject) and + isinstance(rhs._real_type, (ArrayType, PointerType))) + if rhs_pointer: + rhs_sizeof = rhs._real_type.type.sizeof() + lhs, lhs_type, rhs, rhs_type = self._binary_operands(lhs, rhs) + if ((lhs_pointer and rhs_pointer and lhs_sizeof != rhs_sizeof) or + (lhs_pointer and not rhs_pointer and not rhs_type.is_integer()) or + (rhs_pointer and not lhs_pointer) or + (not lhs_pointer and not rhs_pointer and + (not lhs_type.is_arithmetic() or not rhs_type.is_arithmetic()))): + raise TypeError(f"invalid operands to binary - ('{lhs_type}' and '{rhs_type}')") + lhs_type = self.program_._type_index.operand_type(lhs_type) + rhs_type = self.program_._type_index.operand_type(rhs_type) + if lhs_pointer and rhs_pointer: + return ProgramObject(self.program_, None, + self.program_._type_index.ptrdiff_t(), + (lhs - rhs) // lhs_sizeof) + elif lhs_pointer: + return ProgramObject(self.program_, None, lhs_type, + lhs - lhs_sizeof * rhs) + else: + type_, lhs, rhs = self._usual_arithmetic_conversions(lhs, lhs_type, + rhs, rhs_type) + return ProgramObject(self.program_, None, type_, lhs - rhs) + + def __add__(self, other: Any) -> 'ProgramObject': + return self._add(self, other) + + def __sub__(self, other: Any) -> 'ProgramObject': + return self._sub(self, other) + + def __mul__(self, other: Any) -> 'ProgramObject': + return self._arithmetic_operator(operator.mul, '*', self, other) + + def __truediv__(self, other: Any) -> 'ProgramObject': + return self._arithmetic_operator(operator.truediv, '/', self, other) + + def __mod__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(_c_modulo, '%', self, other) + + def __lshift__(self, other: Any) -> 'ProgramObject': + return self._shift_operator(operator.lshift, '<<', self, other) + + def __rshift__(self, other: Any) -> 'ProgramObject': + return self._shift_operator(operator.rshift, '>>', self, other) + + def __and__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(operator.and_, '&', self, other) + + def __xor__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(operator.xor, '^', self, other) + + def __or__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(operator.or_, '|', self, other) + + def __radd__(self, other: Any) -> 'ProgramObject': + return self._add(other, self) + + def __rsub__(self, other: Any) -> 'ProgramObject': + return self._sub(other, self) + + def __rmul__(self, other: Any) -> 'ProgramObject': + return self._arithmetic_operator(operator.mul, '*', other, self) + + def __rtruediv__(self, other: Any) -> 'ProgramObject': + return self._arithmetic_operator(operator.truediv, '/', other, self) + + def __rmod__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(_c_modulo, '%', other, self) + + def __rlshift__(self, other: Any) -> 'ProgramObject': + return self._shift_operator(operator.lshift, '<<', other, self) + + def __rrshift__(self, other: Any) -> 'ProgramObject': + return self._shift_operator(operator.rshift, '>>', other, self) + + def __rand__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(operator.and_, '&', other, self) + + def __rxor__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(operator.xor, '^', other, self) + + def __ror__(self, other: Any) -> 'ProgramObject': + return self._integer_operator(operator.or_, '|', other, self) + + def __lt__(self, other: Any) -> bool: + return self._relational_operator(operator.lt, '<', other) + + def __le__(self, other: Any) -> bool: + return self._relational_operator(operator.le, '<=', other) + + def __eq__(self, other: Any) -> bool: + return self._relational_operator(operator.eq, '==', other) + + def __ne__(self, other: Any) -> bool: + return self._relational_operator(operator.ne, '!=', other) + + def __gt__(self, other: Any) -> bool: + return self._relational_operator(operator.gt, '>', other) + + def __ge__(self, other: Any) -> bool: + return self._relational_operator(operator.ge, '>=', other) def __bool__(self) -> bool: if not isinstance(self._real_type, (ArithmeticType, BitFieldType, PointerType)): - raise TypeError('not an arithmetic or pointer type') + raise TypeError(f"invalid operand to bool() ('{self.type_}')") return bool(self.value_()) def __neg__(self) -> 'ProgramObject': - self._check_arithmetic_type() - return ProgramObject(self.program_, None, self._unary_type(), - -self.value_()) + return self._unary_operator(operator.neg, '-') def __pos__(self) -> 'ProgramObject': - self._check_arithmetic_type() - return ProgramObject(self.program_, None, self._unary_type(), - +self.value_()) - - def __abs__(self) -> 'ProgramObject': - self._check_arithmetic_type() - return ProgramObject(self.program_, None, self._unary_type(), - abs(self.value_())) + return self._unary_operator(operator.pos, '+') def __invert__(self) -> 'ProgramObject': - self._check_integer_type() - return ProgramObject(self.program_, None, self._unary_type(), - ~self.value_()) + return self._unary_operator(operator.invert, '~', True) def __int__(self) -> int: - self._check_arithmetic_type() + if not isinstance(self._real_type, (ArithmeticType, BitFieldType)): + raise TypeError(f"can't convert {self.type_} to int") return int(self.value_()) def __float__(self) -> float: - self._check_arithmetic_type() + if not isinstance(self._real_type, (ArithmeticType, BitFieldType)): + raise TypeError(f"can't convert {self.type_} to float") return float(self.value_()) def __index__(self) -> int: - self._check_integer_type() + if not isinstance(self._real_type, (IntType, BitFieldType)): + raise TypeError(f"can't convert {self.type_} to index") return self.value_() def __round__(self, ndigits: Optional[int] = None) -> Union[int, float]: - self._check_arithmetic_type() - return round(self.value_(), ndigits) + return round(self.__float__(), ndigits) def __trunc__(self) -> int: - self._check_arithmetic_type() - return math.trunc(self.value_()) + return math.trunc(self.__float__()) def __floor__(self) -> int: - self._check_arithmetic_type() - return math.floor(self.value_()) + return math.floor(self.__float__()) def __ceil__(self) -> int: - self._check_arithmetic_type() - return math.ceil(self.value_()) + return math.ceil(self.__float__()) class Program: diff --git a/drgn/typeindex.py b/drgn/typeindex.py index f257cb9a..499b11cb 100644 --- a/drgn/typeindex.py +++ b/drgn/typeindex.py @@ -272,7 +272,7 @@ class TypeIndex: if (not isinstance(real_type1, (ArithmeticType, BitFieldType)) or not isinstance(real_type2, (ArithmeticType, BitFieldType))): - raise TypeError('operands must be arithmetic types or bit fields') + raise TypeError('operands must have arithmetic types') # If either operand is long double, then the result is long double. if isinstance(real_type1, FloatType) and real_type1.name == 'long double': diff --git a/tests/test_program.py b/tests/test_program.py index 24ce3de7..760ab4d5 100644 --- a/tests/test_program.py +++ b/tests/test_program.py @@ -1,4 +1,5 @@ import math +import operator import unittest from drgn.program import Program, ProgramObject @@ -70,7 +71,6 @@ class TestProgramObject(TypeIndexTestCase): # _Bool should be the same because of integer promotions. self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['int'], -1)) self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['int'], 1)) - self.assertEqual(abs(obj), ProgramObject(self.program, None, TYPES['int'], 1)) self.assertEqual(~obj, ProgramObject(self.program, None, TYPES['int'], -2)) self.assertEqual(int(obj), 1) self.assertEqual(float(obj), 1.0) @@ -88,7 +88,6 @@ class TestProgramObject(TypeIndexTestCase): self.assertTrue(bool(obj)) self.assertEqual(-obj, ProgramObject(self.program, None, TYPES['double'], -1.5)) self.assertEqual(+obj, ProgramObject(self.program, None, TYPES['double'], 1.5)) - self.assertEqual(abs(obj), ProgramObject(self.program, None, TYPES['double'], 1.5)) with self.assertRaises(TypeError): ~obj self.assertEqual(int(obj), 1) @@ -202,3 +201,203 @@ class TestProgramObject(TypeIndexTestCase): self.assertEqual(struct_obj.address_, 0xffff0000) self.assertEqual(struct_obj.member_('address_'), ProgramObject(self.program, 0xffff0000, TYPES['unsigned long'])) + + def test_relational(self): + one = self.program.object(None, TYPES['int'], 1) + two = self.program.object(None, TYPES['int'], 2) + three = self.program.object(None, TYPES['int'], 3) + ptr0 = self.program.object(None, self.type_index.pointer(TYPES['int']), + 0xffff0000) + ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']), + 0xffff0004) + + self.assertTrue(one < two) + self.assertFalse(two < two) + self.assertFalse(three < two) + self.assertTrue(ptr0 < ptr1) + + self.assertTrue(one <= two) + self.assertTrue(two <= two) + self.assertFalse(three <= two) + self.assertTrue(ptr0 <= ptr1) + + self.assertTrue(one == one) + self.assertFalse(one == two) + self.assertFalse(ptr0 == ptr1) + + self.assertFalse(one != one) + self.assertTrue(one != two) + self.assertTrue(ptr0 != ptr1) + + self.assertFalse(one > two) + self.assertFalse(two > two) + self.assertTrue(three > two) + self.assertFalse(ptr0 > ptr1) + + self.assertFalse(one >= two) + self.assertTrue(two >= two) + self.assertTrue(three >= two) + self.assertFalse(ptr0 >= ptr1) + + negative_one = self.program.object(None, TYPES['int'], -1) + unsigned_zero = self.program.object(None, TYPES['unsigned int'], 0) + # The usual arithmetic conversions convert -1 to an unsigned int. + self.assertFalse(negative_one < unsigned_zero) + + self.assertTrue(self.program.object(None, TYPES['int'], 1) == + self.program.object(None, TYPES['_Bool'], 1)) + + self.assertRaises(TypeError, operator.lt, ptr0, one) + + def _test_arithmetic(self, op, lhs, rhs, result, integral=True, + floating_point=False): + def INT(value): + return self.program.object(None, TYPES['int'], value) + def LONG(value): + return self.program.object(None, TYPES['long'], value) + def DOUBLE(value): + return self.program.object(None, TYPES['double'], value) + + if integral: + self.assertEqual(op(INT(lhs), INT(rhs)), INT(result)) + self.assertEqual(op(INT(lhs), LONG(rhs)), LONG(result)) + self.assertEqual(op(LONG(lhs), INT(rhs)), LONG(result)) + self.assertEqual(op(LONG(lhs), LONG(rhs)), LONG(result)) + self.assertEqual(op(INT(lhs), rhs), INT(result)) + self.assertEqual(op(LONG(lhs), rhs), LONG(result)) + self.assertEqual(op(lhs, INT(rhs)), INT(result)) + self.assertEqual(op(lhs, LONG(rhs)), LONG(result)) + + if floating_point: + self.assertEqual(op(DOUBLE(lhs), DOUBLE(rhs)), DOUBLE(result)) + self.assertEqual(op(DOUBLE(lhs), INT(rhs)), DOUBLE(result)) + self.assertEqual(op(INT(lhs), DOUBLE(rhs)), DOUBLE(result)) + self.assertEqual(op(DOUBLE(lhs), float(rhs)), DOUBLE(result)) + self.assertEqual(op(float(lhs), DOUBLE(rhs)), DOUBLE(result)) + self.assertEqual(op(float(lhs), INT(rhs)), DOUBLE(result)) + self.assertEqual(op(INT(lhs), float(rhs)), DOUBLE(result)) + + def _test_pointer_type_errors(self, op): + def INT(value): + return self.program.object(None, TYPES['int'], value) + def POINTER(value): + return self.program.object(None, + self.type_index.pointer(TYPES['int']), + value) + + self.assertRaises(TypeError, op, INT(1), POINTER(1)) + self.assertRaises(TypeError, op, POINTER(1), INT(1)) + self.assertRaises(TypeError, op, POINTER(1), POINTER(1)) + + def _test_floating_type_errors(self, op): + def INT(value): + return self.program.object(None, TYPES['int'], value) + def DOUBLE(value): + return self.program.object(None, TYPES['double'], value) + + self.assertRaises(TypeError, op, INT(1), DOUBLE(1)) + self.assertRaises(TypeError, op, DOUBLE(1), INT(1)) + self.assertRaises(TypeError, op, DOUBLE(1), DOUBLE(1)) + + def _test_shift(self, op, lhs, rhs, result): + def BOOL(value): + return self.program.object(None, TYPES['_Bool'], value) + def INT(value): + return self.program.object(None, TYPES['int'], value) + def LONG(value): + return self.program.object(None, TYPES['long'], value) + + self.assertEqual(op(INT(lhs), INT(rhs)), INT(result)) + self.assertEqual(op(INT(lhs), LONG(rhs)), INT(result)) + self.assertEqual(op(LONG(lhs), INT(rhs)), LONG(result)) + self.assertEqual(op(LONG(lhs), LONG(rhs)), LONG(result)) + self.assertEqual(op(INT(lhs), rhs), INT(result)) + self.assertEqual(op(LONG(lhs), rhs), LONG(result)) + self.assertEqual(op(lhs, INT(rhs)), INT(result)) + self.assertEqual(op(lhs, LONG(rhs)), INT(result)) + + self._test_pointer_type_errors(op) + self._test_floating_type_errors(op) + + def test_add(self): + self._test_arithmetic(operator.add, 2, 2, 4, floating_point=True) + + one = self.program.object(None, TYPES['int'], 1) + ptr = self.program.object(None, self.type_index.pointer(TYPES['int']), + 0xffff0000) + ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']), + 0xffff0004) + self.assertEqual(ptr + one, ptr1) + self.assertEqual(one + ptr, ptr1) + self.assertEqual(ptr + 1, ptr1) + self.assertEqual(1 + ptr, ptr1) + self.assertRaises(TypeError, operator.add, ptr, ptr) + self.assertRaises(TypeError, operator.add, ptr, 2.0) + self.assertRaises(TypeError, operator.add, 2.0, ptr) + + def test_sub(self): + self._test_arithmetic(operator.sub, 4, 2, 2, floating_point=True) + + ptr = self.program.object(None, self.type_index.pointer(TYPES['int']), + 0xffff0000) + ptr1 = self.program.object(None, self.type_index.pointer(TYPES['int']), + 0xffff0004) + self.assertEqual(ptr1 - ptr, + ProgramObject(self.program, None, TYPES['ptrdiff_t'], 1)) + self.assertEqual(ptr - ptr1, + ProgramObject(self.program, None, TYPES['ptrdiff_t'], -1)) + self.assertEqual(ptr - 0, ptr) + self.assertEqual(ptr1 - 1, ptr) + self.assertRaises(TypeError, operator.sub, 1, ptr) + self.assertRaises(TypeError, operator.sub, ptr, 1.0) + + def test_mul(self): + self._test_arithmetic(operator.mul, 2, 3, 6, floating_point=True) + self._test_pointer_type_errors(operator.mul) + + def test_div(self): + self._test_arithmetic(operator.truediv, 6, 3, 2, floating_point=True) + + # Make sure we do integer division for integer operands. + self._test_arithmetic(operator.truediv, 3, 2, 1) + + # Make sure we truncate towards zero (Python truncates towards negative + # infinity). + self._test_arithmetic(operator.truediv, -1, 2, 0) + self._test_arithmetic(operator.truediv, 1, -2, 0) + + self._test_pointer_type_errors(operator.mul) + + def test_mod(self): + self._test_arithmetic(operator.mod, 4, 2, 0) + + # Make sure the modulo result has the sign of the dividend (Python uses + # the sign of the divisor). + self._test_arithmetic(operator.mod, 1, 26, 1) + self._test_arithmetic(operator.mod, 1, -26, 1) + self._test_arithmetic(operator.mod, -1, 26, -1) + self._test_arithmetic(operator.mod, -1, -26, -1) + + self._test_pointer_type_errors(operator.mod) + self._test_floating_type_errors(operator.mod) + + def test_lshift(self): + self._test_shift(operator.lshift, 2, 3, 16) + + def test_rshift(self): + self._test_shift(operator.rshift, 16, 3, 2) + + def test_and(self): + self._test_arithmetic(operator.and_, 1, 3, 1) + self._test_pointer_type_errors(operator.and_) + self._test_floating_type_errors(operator.and_) + + def test_xor(self): + self._test_arithmetic(operator.xor, 1, 3, 2) + self._test_pointer_type_errors(operator.xor) + self._test_floating_type_errors(operator.xor) + + def test_or(self): + self._test_arithmetic(operator.or_, 1, 3, 3) + self._test_pointer_type_errors(operator.or_) + self._test_floating_type_errors(operator.or_) diff --git a/tests/test_typeindex.py b/tests/test_typeindex.py index 93d71e6b..e5ba201b 100644 --- a/tests/test_typeindex.py +++ b/tests/test_typeindex.py @@ -21,7 +21,7 @@ from drgn.type import ( VoidType, ) from drgn.typeindex import DwarfTypeIndex, TypeIndex -from drgn.typename import BasicTypeName, TypeName +from drgn.typename import BasicTypeName, TypeName, TypedefTypeName from tests.test_type import ( anonymous_point_type, const_anonymous_point_type, @@ -48,7 +48,9 @@ TYPES = { 'float': FloatType('float', 4), 'double': FloatType('double', 8), 'long double': FloatType('long double', 16), + 'ptrdiff_t': FloatType('long double', 16), } +TYPES['ptrdiff_t'] = TypedefType('ptrdiff_t', TYPES['long']) class MockTypeIndex(TypeIndex): @@ -56,7 +58,7 @@ class MockTypeIndex(TypeIndex): super().__init__(8) def _find_type(self, type_name: TypeName) -> Type: - if isinstance(type_name, BasicTypeName): + if isinstance(type_name, (BasicTypeName, TypedefTypeName)): try: return TYPES[type_name.name] except KeyError: