diff --git a/docs/exts/drgndoc/ext.py b/docs/exts/drgndoc/ext.py index c5426b39..a19ef8f5 100644 --- a/docs/exts/drgndoc/ext.py +++ b/docs/exts/drgndoc/ext.py @@ -117,7 +117,7 @@ class DrgnDocDirective(sphinx.util.docutils.SphinxDirective): name = ".".join(parts) resolved = self.env.drgndoc_namespace.resolve_global_name(name) if not isinstance(resolved, ResolvedNode): - logger.warning("name %r not found", name, resolved) + logger.warning("name %r not found", resolved) return [] docnode = docutils.nodes.section() diff --git a/docs/exts/drgndoc/parse.py b/docs/exts/drgndoc/parse.py index dd1e5b8c..460d2307 100644 --- a/docs/exts/drgndoc/parse.py +++ b/docs/exts/drgndoc/parse.py @@ -15,9 +15,55 @@ from typing import ( Tuple, Union, cast, + overload, ) -from drgndoc.visitor import NodeVisitor, transform_constant_nodes +from drgndoc.visitor import NodeVisitor + + +class _PreTransformer(ast.NodeTransformer): + # Replace string forward references with the parsed expression. + def _visit_annotation(self, node): + if isinstance(node, ast.Constant) and isinstance(node.value, str): + node = self.visit(ast.parse(node.value, "", "eval")) + return node + + def visit_arg(self, node): + node = self.generic_visit(node) + node.annotation = self._visit_annotation(node.annotation) + return node + + def visit_FunctionDef(self, node): + node = self.generic_visit(node) + node.returns = self._visit_annotation(node.returns) + return node + + def visit_AsyncFunctionDef(self, node): + node = self.generic_visit(node) + node.returns = self._visit_annotation(node.returns) + return node + + def visit_AnnAssign(self, node): + node = self.generic_visit(node) + node.annotation = self._visit_annotation(node.annotation) + return node + + # Replace the old constant nodes produced by ast.parse() before Python 3.8 + # with Constant. + def visit_Num(self, node: ast.Num) -> ast.Constant: + return ast.copy_location(ast.Constant(node.n), node) + + def visit_Str(self, node: ast.Str) -> ast.Constant: + return ast.copy_location(ast.Constant(node.s), node) + + def visit_Bytes(self, node: ast.Bytes) -> ast.Constant: + return ast.copy_location(ast.Constant(node.s), node) + + def visit_Ellipsis(self, node: ast.Ellipsis) -> ast.Constant: + return ast.copy_location(ast.Constant(...), node) + + def visit_NameConstant(self, node: ast.NameConstant) -> ast.Constant: + return ast.copy_location(ast.Constant(node.value), node) # Once we don't care about Python 3.6, we can replace all of this boilerplate @@ -227,8 +273,8 @@ class _ModuleVisitor(NodeVisitor): def parse_source( source: str, filename: str ) -> Tuple[Optional[str], Dict[str, NonModuleNode]]: - node = transform_constant_nodes(ast.parse(source, filename)) - return _ModuleVisitor().visit(node) + node = ast.parse(source, filename) + return _ModuleVisitor().visit(_PreTransformer().visit(node)) def _default_handle_err(e: Exception) -> None: diff --git a/docs/exts/drgndoc/visitor.py b/docs/exts/drgndoc/visitor.py index 26581a31..d2c0cb3e 100644 --- a/docs/exts/drgndoc/visitor.py +++ b/docs/exts/drgndoc/visitor.py @@ -38,32 +38,3 @@ class NodeVisitor: self._visit(prev, node, None) elif isinstance(value, ast.AST): self._visit(value, node, None) - - -class _ConstantNodeTransformer(ast.NodeTransformer): - def visit_Num(self, node: ast.Num) -> ast.Constant: - return ast.copy_location(ast.Constant(node.n), node) - - def visit_Str(self, node: ast.Str) -> ast.Constant: - return ast.copy_location(ast.Constant(node.s), node) - - def visit_Bytes(self, node: ast.Bytes) -> ast.Constant: - return ast.copy_location(ast.Constant(node.s), node) - - def visit_Ellipsis(self, node: ast.Ellipsis) -> ast.Constant: - return ast.copy_location(ast.Constant(...), node) - - def visit_NameConstant(self, node: ast.NameConstant) -> ast.Constant: - return ast.copy_location(ast.Constant(node.value), node) - - -def transform_constant_nodes(node: ast.AST) -> ast.AST: - """ - Since Python 3.8, ast.parse() and friends produce Constant nodes instead of - the more specific constant classes. This replaces occurrences of the old - nodes with Constant to simplify consumers. - """ - if sys.version_info >= (3, 8): - return node - else: - return _ConstantNodeTransformer().visit(node)