type: add Type.unqualified()

This commit is contained in:
Omar Sandoval 2018-07-11 21:51:48 -07:00
parent 9b5b721838
commit 6fa2d68c0c
2 changed files with 59 additions and 24 deletions

View File

@ -164,12 +164,18 @@ class Type:
""" """
return self return self
def unqualified(self) -> 'Type':
"""
Return this type without qualifiers.
"""
raise NotImplementedError()
def operand_type(self) -> 'Type': def operand_type(self) -> 'Type':
""" """
Return the type that this type is converted to when used in an Return the type that this type is converted to when used in an
expression. expression.
""" """
raise NotImplementedError() return self.unqualified()
def is_arithmetic(self) -> bool: def is_arithmetic(self) -> bool:
""" """
@ -217,7 +223,7 @@ class VoidType(Type):
def convert(self, value: Any) -> None: def convert(self, value: Any) -> None:
return None return None
def operand_type(self) -> 'VoidType': def unqualified(self) -> 'VoidType':
if self.qualifiers: if self.qualifiers:
return VoidType() return VoidType()
return self return self
@ -310,7 +316,7 @@ class IntType(ArithmeticType):
raise TypeError(f'cannot convert to {self}') raise TypeError(f'cannot convert to {self}')
return _int_convert(math.trunc(value), 8 * self.size, self.signed) return _int_convert(math.trunc(value), 8 * self.size, self.signed)
def operand_type(self) -> 'IntType': def unqualified(self) -> 'IntType':
if self.qualifiers: if self.qualifiers:
return IntType(self.name, self.size, self.signed) return IntType(self.name, self.size, self.signed)
return self return self
@ -346,7 +352,7 @@ class BoolType(IntType):
raise TypeError(f'cannot convert to {self}') raise TypeError(f'cannot convert to {self}')
return bool(value) return bool(value)
def operand_type(self) -> 'BoolType': def unqualified(self) -> 'BoolType':
if self.qualifiers: if self.qualifiers:
return BoolType(self.name, self.size) return BoolType(self.name, self.size)
return self return self
@ -375,7 +381,7 @@ class FloatType(ArithmeticType):
else: else:
raise ValueError(f"can't convert to float of size {self.size}") raise ValueError(f"can't convert to float of size {self.size}")
def operand_type(self) -> 'FloatType': def unqualified(self) -> 'FloatType':
if self.qualifiers: if self.qualifiers:
return FloatType(self.name, self.size) return FloatType(self.name, self.size)
return self return self
@ -457,9 +463,9 @@ class BitFieldType(Type):
raise TypeError(f'cannot convert to {self}') raise TypeError(f'cannot convert to {self}')
return _int_convert(math.trunc(value), self.bit_size, self.type.signed) return _int_convert(math.trunc(value), self.bit_size, self.type.signed)
def operand_type(self) -> 'BitFieldType': def unqualified(self) -> 'BitFieldType':
if self.type.qualifiers: if self.type.qualifiers:
return BitFieldType(self.type.operand_type(), self.bit_offset, return BitFieldType(self.type.unqualified(), self.bit_offset,
self.bit_size) self.bit_size)
return self return self
@ -656,7 +662,7 @@ class StructType(CompoundType):
def type_name(self) -> StructTypeName: def type_name(self) -> StructTypeName:
return StructTypeName(self.name, self.qualifiers) return StructTypeName(self.name, self.qualifiers)
def operand_type(self) -> 'StructType': def unqualified(self) -> 'StructType':
if self.qualifiers: if self.qualifiers:
return StructType(self.name, self.size, self._members) return StructType(self.name, self.size, self._members)
return self return self
@ -684,7 +690,7 @@ class UnionType(CompoundType):
def type_name(self) -> UnionTypeName: def type_name(self) -> UnionTypeName:
return UnionTypeName(self.name, self.qualifiers) return UnionTypeName(self.name, self.qualifiers)
def operand_type(self) -> 'UnionType': def unqualified(self) -> 'UnionType':
if self.qualifiers: if self.qualifiers:
return UnionType(self.name, self.size, self._members) return UnionType(self.name, self.size, self._members)
return self return self
@ -806,7 +812,7 @@ class EnumType(Type):
pass pass
return value return value
def operand_type(self) -> 'EnumType': def unqualified(self) -> 'EnumType':
if self.qualifiers: if self.qualifiers:
return EnumType(self.name, self.type, return EnumType(self.name, self.type,
None if self.enum is None else self.enum.__members__) None if self.enum is None else self.enum.__members__)
@ -883,6 +889,11 @@ class TypedefType(Type):
type_ = type_.type type_ = type_.type
return type_ return type_
def unqualified(self) -> 'TypedefType':
if self.qualifiers:
return TypedefType(self.name, self.type)
return self
def operand_type(self) -> Type: def operand_type(self) -> Type:
type_ = self.type type_ = self.type
while isinstance(type_, TypedefType): while isinstance(type_, TypedefType):
@ -957,7 +968,7 @@ class PointerType(Type):
raise TypeError(f'cannot convert to {self}') raise TypeError(f'cannot convert to {self}')
return _int_convert(int(value), 8 * self.size, False) return _int_convert(int(value), 8 * self.size, False)
def operand_type(self) -> 'PointerType': def unqualified(self) -> 'PointerType':
if self.qualifiers: if self.qualifiers:
return PointerType(self.size, self.type) return PointerType(self.size, self.type)
return self return self
@ -1041,6 +1052,9 @@ class ArrayType(Type):
parts.append('}') parts.append('}')
return ''.join(parts) return ''.join(parts)
def unqualified(self) -> 'ArrayType':
return self
def operand_type(self) -> 'PointerType': def operand_type(self) -> 'PointerType':
return PointerType(self.pointer_size, self.type) return PointerType(self.pointer_size, self.type)
@ -1103,6 +1117,9 @@ class FunctionType(Type):
def pretty(self, value: Any, cast: bool = True) -> str: def pretty(self, value: Any, cast: bool = True) -> str:
raise ValueError("can't format function") raise ValueError("can't format function")
def unqualified(self) -> 'FunctionType':
return self
def operand_type(self) -> 'PointerType': def operand_type(self) -> 'PointerType':
return PointerType(self.pointer_size, self) return PointerType(self.pointer_size, self)

View File

@ -545,29 +545,38 @@ class TestConvert(unittest.TestCase):
self.assertEqual(type_.convert(2**64 + 1), 1) self.assertEqual(type_.convert(2**64 + 1), 1)
class TestOperandType(TypeTestCase): class TestUnqualifiedAndOperandType(TypeTestCase):
def assertUnqualifiedType(self, type_, expected):
for i in range(2):
type_ = type_.unqualified()
self.assertEqual(type_, expected)
def assertOperandType(self, type_, expected): def assertOperandType(self, type_, expected):
for i in range(2): for i in range(2):
type_ = type_.operand_type() type_ = type_.operand_type()
self.assertEqual(type_, expected) self.assertEqual(type_, expected)
def assertBoth(self, type_, expected):
self.assertUnqualifiedType(type_, expected)
self.assertOperandType(type_, expected)
def test_void(self): def test_void(self):
self.assertOperandType(VoidType(frozenset({'const'})), VoidType()) self.assertBoth(VoidType(frozenset({'const'})), VoidType())
def test_int(self): def test_int(self):
self.assertOperandType(IntType('int', 4, True, frozenset({'const'})), self.assertBoth(IntType('int', 4, True, frozenset({'const'})),
IntType('int', 4, True)) IntType('int', 4, True))
def test_bool(self): def test_bool(self):
self.assertOperandType(BoolType('_Bool', 1, frozenset({'const'})), self.assertBoth(BoolType('_Bool', 1, frozenset({'const'})),
BoolType('_Bool', 1)) BoolType('_Bool', 1))
def test_float(self): def test_float(self):
self.assertOperandType(FloatType('double', 8, frozenset({'const'})), self.assertBoth(FloatType('double', 8, frozenset({'const'})),
FloatType('double', 8)) FloatType('double', 8))
def test_bit_field(self): def test_bit_field(self):
self.assertOperandType(BitFieldType(IntType('int', 4, True, frozenset({'const'})), 0, 4), self.assertBoth(BitFieldType(IntType('int', 4, True, frozenset({'const'})), 0, 4),
BitFieldType(IntType('int', 4, True), 0, 4)) BitFieldType(IntType('int', 4, True), 0, 4))
def test_struct(self): def test_struct(self):
@ -575,7 +584,7 @@ class TestOperandType(TypeTestCase):
('x', 0, lambda: IntType('int', 4, True)), ('x', 0, lambda: IntType('int', 4, True)),
('y', 4, lambda: IntType('int', 4, True)), ('y', 4, lambda: IntType('int', 4, True)),
], frozenset({'const'})) ], frozenset({'const'}))
self.assertOperandType(const_point_type, point_type) self.assertBoth(const_point_type, point_type)
def test_union(self): def test_union(self):
union_type = UnionType('value', 4, [ union_type = UnionType('value', 4, [
@ -586,7 +595,7 @@ class TestOperandType(TypeTestCase):
('i', 0, lambda: IntType('int', 4, True)), ('i', 0, lambda: IntType('int', 4, True)),
('f', 0, lambda: FloatType('float', 4)), ('f', 0, lambda: FloatType('float', 4)),
], frozenset({'const'})) ], frozenset({'const'}))
self.assertOperandType(const_union_type, union_type) self.assertBoth(const_union_type, union_type)
def test_enum(self): def test_enum(self):
enum_type = EnumType(None, IntType('int', 4, True), [ enum_type = EnumType(None, IntType('int', 4, True), [
@ -599,7 +608,7 @@ class TestOperandType(TypeTestCase):
('GREEN', 11), ('GREEN', 11),
('BLUE', -1) ('BLUE', -1)
], frozenset({'const'})) ], frozenset({'const'}))
self.assertOperandType(const_enum_type, enum_type) self.assertBoth(const_enum_type, enum_type)
def test_typedef(self): def test_typedef(self):
const_typedef_type = TypedefType( const_typedef_type = TypedefType(
@ -610,9 +619,14 @@ class TestOperandType(TypeTestCase):
frozenset({'const'})) frozenset({'const'}))
typedef_type = TypedefType('u32', IntType('unsigned int', 4, False)) typedef_type = TypedefType('u32', IntType('unsigned int', 4, False))
self.assertUnqualifiedType(const_typedef_type, typedef_type)
self.assertOperandType(const_typedef_type, typedef_type) self.assertOperandType(const_typedef_type, typedef_type)
self.assertUnqualifiedType(typedef_const_type, typedef_const_type)
self.assertOperandType(typedef_const_type, self.assertOperandType(typedef_const_type,
IntType('unsigned int', 4, False)) IntType('unsigned int', 4, False))
self.assertUnqualifiedType(const_typedef_const_type, typedef_const_type)
self.assertOperandType(const_typedef_const_type, self.assertOperandType(const_typedef_const_type,
IntType('unsigned int', 4, False)) IntType('unsigned int', 4, False))
@ -630,16 +644,20 @@ class TestOperandType(TypeTestCase):
def test_array(self): def test_array(self):
type_ = ArrayType(IntType('int', 4, True), 2, pointer_size) type_ = ArrayType(IntType('int', 4, True), 2, pointer_size)
self.assertUnqualifiedType(type_, type_)
self.assertOperandType(type_, PointerType(pointer_size, type_.type)) self.assertOperandType(type_, PointerType(pointer_size, type_.type))
typedef_type = TypedefType('pair_t', type_) typedef_type = TypedefType('pair_t', type_)
self.assertUnqualifiedType(typedef_type, typedef_type)
self.assertOperandType(typedef_type, PointerType(pointer_size, type_.type)) self.assertOperandType(typedef_type, PointerType(pointer_size, type_.type))
def test_function(self): def test_function(self):
type_ = FunctionType(pointer_size, VoidType, []) type_ = FunctionType(pointer_size, VoidType, [])
self.assertUnqualifiedType(type_, type_)
self.assertOperandType(type_, PointerType(pointer_size, type_)) self.assertOperandType(type_, PointerType(pointer_size, type_))
typedef_type = TypedefType('callback_t', type_) typedef_type = TypedefType('callback_t', type_)
self.assertUnqualifiedType(typedef_type, typedef_type)
self.assertOperandType(typedef_type, PointerType(pointer_size, type_)) self.assertOperandType(typedef_type, PointerType(pointer_size, type_))