type: add Type.is_arithmetic() and Type.is_integer()

This will be used for ProgramObject operators as a shortcut for
isinstance(type.real_type(), (ArithmeticType, BitFieldType).
This commit is contained in:
Omar Sandoval 2018-05-06 00:28:07 -07:00
parent da83e0adb3
commit 987f9be6db
2 changed files with 45 additions and 0 deletions

View File

@ -135,6 +135,22 @@ class Type:
""" """
return self return self
def is_arithmetic(self) -> bool:
"""
Return whether this type is an arithmetic type. This is true for
instances of ArithmeticType, BitFieldType, and TypedefType if the
underlying type is one of those.
"""
return False
def is_integer(self) -> bool:
"""
Return whether this type is an integer type. This is true for instances
of IntType, BitFieldType, and TypedefType if the underlying type is one
of those.
"""
return False
class VoidType(Type): class VoidType(Type):
""" """
@ -208,6 +224,9 @@ class ArithmeticType(Type):
def unqualified(self) -> 'ArithmeticType': def unqualified(self) -> 'ArithmeticType':
raise NotImplementedError() raise NotImplementedError()
def is_arithmetic(self) -> bool:
return True
def _int_convert(value: int, bit_size: int, signed: bool) -> int: def _int_convert(value: int, bit_size: int, signed: bool) -> int:
value %= 1 << bit_size value %= 1 << bit_size
@ -260,6 +279,9 @@ class IntType(ArithmeticType):
return self return self
return IntType(self.name, self.size, self.signed) return IntType(self.name, self.size, self.signed)
def is_integer(self) -> bool:
return True
class BoolType(IntType): class BoolType(IntType):
""" """
@ -412,6 +434,12 @@ class BitFieldType(Type):
return BitFieldType(self.type.unqualified(), self.bit_offset, return BitFieldType(self.type.unqualified(), self.bit_offset,
self.bit_size) self.bit_size)
def is_arithmetic(self) -> bool:
return True
def is_integer(self) -> bool:
return True
_TypeThunk = Callable[[], Type] _TypeThunk = Callable[[], Type]
@ -813,6 +841,12 @@ class TypedefType(Type):
type_ = type_.type type_ = type_.type
return type_ return type_
def is_arithmetic(self) -> bool:
return self.type.is_arithmetic()
def is_integer(self) -> bool:
return self.type.is_integer()
class PointerType(Type): class PointerType(Type):
""" """

View File

@ -85,6 +85,8 @@ class TestType(TypeTestCase):
self.assertRaises(ValueError, type_.sizeof) self.assertRaises(ValueError, type_.sizeof)
self.assertRaises(ValueError, type_.read, b'') self.assertRaises(ValueError, type_.read, b'')
self.assertRaises(ValueError, type_.read_pretty, b'') self.assertRaises(ValueError, type_.read_pretty, b'')
self.assertFalse(type_.is_arithmetic())
self.assertFalse(type_.is_integer())
def test_int(self): def test_int(self):
type_ = IntType('int', 4, True) type_ = IntType('int', 4, True)
@ -98,6 +100,8 @@ class TestType(TypeTestCase):
self.assertEqual(type_.read(buffer, 2), -1) self.assertEqual(type_.read(buffer, 2), -1)
self.assertRaises(ValueError, type_.read, buffer, 3) self.assertRaises(ValueError, type_.read, buffer, 3)
self.assertEqual(type_.real_type(), type_) self.assertEqual(type_.real_type(), type_)
self.assertTrue(type_.is_arithmetic())
self.assertTrue(type_.is_integer())
type_ = IntType('unsigned long', 8, False) type_ = IntType('unsigned long', 8, False)
buffer = b'\0' + (99).to_bytes(8, sys.byteorder) buffer = b'\0' + (99).to_bytes(8, sys.byteorder)
@ -113,6 +117,8 @@ class TestType(TypeTestCase):
self.assertEqual(type_.read(buffer), 3.14) self.assertEqual(type_.read(buffer), 3.14)
self.assertEqual(type_.read(b'\0' + buffer, 1), 3.14) self.assertEqual(type_.read(b'\0' + buffer, 1), 3.14)
self.assertRaises(ValueError, type_.read, buffer, 1) self.assertRaises(ValueError, type_.read, buffer, 1)
self.assertTrue(type_.is_arithmetic())
self.assertFalse(type_.is_integer())
type_ = FloatType('float', 4) type_ = FloatType('float', 4)
buffer = struct.pack('f', 1.5) buffer = struct.pack('f', 1.5)
@ -147,10 +153,14 @@ class TestType(TypeTestCase):
self.assertEqual(type_.sizeof(), 4) self.assertEqual(type_.sizeof(), 4)
self.assertEqual(type_.read(b'\0\0\0\0'), 0) self.assertEqual(type_.read(b'\0\0\0\0'), 0)
self.assertEqual(type_.read_pretty(b'\0\0\0\0'), '(INT)0') self.assertEqual(type_.read_pretty(b'\0\0\0\0'), '(INT)0')
self.assertTrue(type_.is_arithmetic())
self.assertTrue(type_.is_integer())
type_ = TypedefType('string', PointerType(pointer_size, IntType('char', 1, True))) type_ = TypedefType('string', PointerType(pointer_size, IntType('char', 1, True)))
self.assertEqual(str(type_), 'typedef char *string') self.assertEqual(str(type_), 'typedef char *string')
self.assertEqual(type_.sizeof(), pointer_size) self.assertEqual(type_.sizeof(), pointer_size)
self.assertFalse(type_.is_arithmetic())
self.assertFalse(type_.is_integer())
type_ = TypedefType('CINT', IntType('int', 4, True, {'const'})) type_ = TypedefType('CINT', IntType('int', 4, True, {'const'}))
self.assertEqual(str(type_), 'typedef const int CINT') self.assertEqual(str(type_), 'typedef const int CINT')
@ -286,6 +296,7 @@ struct {
type_ = BitFieldType(IntType('int', 4, True), 0, 4) type_ = BitFieldType(IntType('int', 4, True), 0, 4)
self.assertEqual(str(type_), 'int : 4') self.assertEqual(str(type_), 'int : 4')
self.assertRaises(ValueError, type_.type_name) self.assertRaises(ValueError, type_.type_name)
self.assertTrue(type_.is_arithmetic())
def test_union(self): def test_union(self):
type_ = UnionType('value', 4, [ type_ = UnionType('value', 4, [