drgndoc: handle forward references

String annotations (i.e., forward references) need to be parsed into an
ast node. Do it as a transformation step immediately after parsing the
source. We can also squash the constant node transformation into this
one.
This commit is contained in:
Omar Sandoval 2020-02-26 16:04:32 -08:00
parent 376979d25a
commit 8be7ae5299
3 changed files with 50 additions and 33 deletions

View File

@ -117,7 +117,7 @@ class DrgnDocDirective(sphinx.util.docutils.SphinxDirective):
name = ".".join(parts)
resolved = self.env.drgndoc_namespace.resolve_global_name(name)
if not isinstance(resolved, ResolvedNode):
logger.warning("name %r not found", name, resolved)
logger.warning("name %r not found", resolved)
return []
docnode = docutils.nodes.section()

View File

@ -15,9 +15,55 @@ from typing import (
Tuple,
Union,
cast,
overload,
)
from drgndoc.visitor import NodeVisitor, transform_constant_nodes
from drgndoc.visitor import NodeVisitor
class _PreTransformer(ast.NodeTransformer):
# Replace string forward references with the parsed expression.
def _visit_annotation(self, node):
if isinstance(node, ast.Constant) and isinstance(node.value, str):
node = self.visit(ast.parse(node.value, "<string>", "eval"))
return node
def visit_arg(self, node):
node = self.generic_visit(node)
node.annotation = self._visit_annotation(node.annotation)
return node
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
node.returns = self._visit_annotation(node.returns)
return node
def visit_AsyncFunctionDef(self, node):
node = self.generic_visit(node)
node.returns = self._visit_annotation(node.returns)
return node
def visit_AnnAssign(self, node):
node = self.generic_visit(node)
node.annotation = self._visit_annotation(node.annotation)
return node
# Replace the old constant nodes produced by ast.parse() before Python 3.8
# with Constant.
def visit_Num(self, node: ast.Num) -> ast.Constant:
return ast.copy_location(ast.Constant(node.n), node)
def visit_Str(self, node: ast.Str) -> ast.Constant:
return ast.copy_location(ast.Constant(node.s), node)
def visit_Bytes(self, node: ast.Bytes) -> ast.Constant:
return ast.copy_location(ast.Constant(node.s), node)
def visit_Ellipsis(self, node: ast.Ellipsis) -> ast.Constant:
return ast.copy_location(ast.Constant(...), node)
def visit_NameConstant(self, node: ast.NameConstant) -> ast.Constant:
return ast.copy_location(ast.Constant(node.value), node)
# Once we don't care about Python 3.6, we can replace all of this boilerplate
@ -227,8 +273,8 @@ class _ModuleVisitor(NodeVisitor):
def parse_source(
source: str, filename: str
) -> Tuple[Optional[str], Dict[str, NonModuleNode]]:
node = transform_constant_nodes(ast.parse(source, filename))
return _ModuleVisitor().visit(node)
node = ast.parse(source, filename)
return _ModuleVisitor().visit(_PreTransformer().visit(node))
def _default_handle_err(e: Exception) -> None:

View File

@ -38,32 +38,3 @@ class NodeVisitor:
self._visit(prev, node, None)
elif isinstance(value, ast.AST):
self._visit(value, node, None)
class _ConstantNodeTransformer(ast.NodeTransformer):
def visit_Num(self, node: ast.Num) -> ast.Constant:
return ast.copy_location(ast.Constant(node.n), node)
def visit_Str(self, node: ast.Str) -> ast.Constant:
return ast.copy_location(ast.Constant(node.s), node)
def visit_Bytes(self, node: ast.Bytes) -> ast.Constant:
return ast.copy_location(ast.Constant(node.s), node)
def visit_Ellipsis(self, node: ast.Ellipsis) -> ast.Constant:
return ast.copy_location(ast.Constant(...), node)
def visit_NameConstant(self, node: ast.NameConstant) -> ast.Constant:
return ast.copy_location(ast.Constant(node.value), node)
def transform_constant_nodes(node: ast.AST) -> ast.AST:
"""
Since Python 3.8, ast.parse() and friends produce Constant nodes instead of
the more specific constant classes. This replaces occurrences of the old
nodes with Constant to simplify consumers.
"""
if sys.version_info >= (3, 8):
return node
else:
return _ConstantNodeTransformer().visit(node)