Upgraded to mcpyrate 3.6.0 (latest release)

This commit is contained in:
Salvador E. Tropea 2022-01-29 14:38:00 -03:00
parent fbfee8771d
commit 8c3bb74eb7
10 changed files with 392 additions and 113 deletions

View File

@ -1,3 +1,5 @@
"""mcpyrate: Advanced macro expander and language lab for Python."""
from .astdumper import dump # noqa: F401 from .astdumper import dump # noqa: F401
from .core import MacroExpansionError # noqa: F401 from .core import MacroExpansionError # noqa: F401
from .dialects import Dialect # noqa: F401 from .dialects import Dialect # noqa: F401
@ -5,4 +7,12 @@ from .expander import namemacro, parametricmacro # noqa: F401
from .unparser import unparse # noqa: F401 from .unparser import unparse # noqa: F401
from .utils import gensym # noqa: F401 from .utils import gensym # noqa: F401
__version__ = "3.5.4" # For public API inspection, import modules that wouldn't otherwise get imported.
from . import ansi # noqa: F401
from . import debug # noqa: F401
from . import metatools # noqa: F401
from . import pycachecleaner # noqa: F401
from . import quotes # noqa: F401
from . import splicing # noqa: F401
__version__ = "3.6.0"

View File

@ -2,7 +2,8 @@
"""Utilities related to writing macro expanders and similar meta-metaprogramming tasks.""" """Utilities related to writing macro expanders and similar meta-metaprogramming tasks."""
__all__ = ["resolve_package", "relativize", "match_syspath", __all__ = ["resolve_package", "relativize", "match_syspath",
"ismacroimport", "get_macros"] "ismacroimport", "get_macros",
"isfutureimport", "split_futureimports", "inject_after_futureimports"]
import ast import ast
import importlib import importlib
@ -12,7 +13,7 @@ import pathlib
import sys import sys
from .unparser import unparse_with_fallbacks from .unparser import unparse_with_fallbacks
from .utils import format_location from .utils import format_location, getdocstring
def resolve_package(filename): # TODO: for now, `guess_package`, really. Check the docs again. def resolve_package(filename): # TODO: for now, `guess_package`, really. Check the docs again.
@ -90,6 +91,15 @@ def ismacroimport(statement, magicname='macros'):
def get_macros(macroimport, *, filename, reload=False, allow_asname=True, self_module=None): def get_macros(macroimport, *, filename, reload=False, allow_asname=True, self_module=None):
"""Get absolute module name, macro names and macro functions from a macro-import. """Get absolute module name, macro names and macro functions from a macro-import.
A macro-import is a statement of the form::
from ... import macros, ...
where `macros` is the magic name that your actual macro expander uses to recognize
a macro-import (see `ismacroimport`). This function does not care about what the
actual magic name is, and simply ignores the first name that is imported by the
import statement.
As a side effect, import the macro definition module. As a side effect, import the macro definition module.
Return value is `module_absname, {macroname0: macrofunction0, ...}`. Return value is `module_absname, {macroname0: macrofunction0, ...}`.
@ -166,7 +176,7 @@ def get_macros(macroimport, *, filename, reload=False, allow_asname=True, self_m
module = importlib.reload(module) module = importlib.reload(module)
bindings = {} bindings = {}
for name in macroimport.names[1:]: for name in macroimport.names[1:]: # skip the "macros" in `from ... import macros, ...`
if not allow_asname and name.asname is not None: if not allow_asname and name.asname is not None:
approx_sourcecode = unparse_with_fallbacks(macroimport, debug=True, color=True) approx_sourcecode = unparse_with_fallbacks(macroimport, debug=True, color=True)
loc = format_location(filename, macroimport, approx_sourcecode) loc = format_location(filename, macroimport, approx_sourcecode)
@ -241,3 +251,54 @@ def _mcpyrate_attr(dotted_name, *, force_import=False):
value = ast.Attribute(value=value, attr=name) value = ast.Attribute(value=value, attr=name)
return value return value
# --------------------------------------------------------------------------------
def isfutureimport(tree):
    """Return whether `tree` is a `from __future__ import ...`."""
    # Only an `ast.ImportFrom` node can be a future-import; anything else
    # (including a plain `import __future__`) does not count.
    if not isinstance(tree, ast.ImportFrom):
        return False
    return tree.module == "__future__"
def split_futureimports(body):
    """Split `body` into `__future__` imports and everything else.

    `body`: list of `ast.stmt`, the suite representing a module top level.

    Return value is `docstring, future_imports, the_rest`, where each item
    is a list of `ast.stmt` (possibly empty).
    """
    # Detach the module docstring (if any) first, so the scan below only
    # ever sees real statements.
    if getdocstring(body):
        docstring, *body = body
        docstring = [docstring]
    else:
        docstring = []
    # Count how many statements at the start of `body` are future-imports.
    #
    # Bug fix: the previous `enumerate`-with-`break` version misclassified
    # the final statement when *every* statement was a future-import,
    # because the loop then completed without breaking and `k` pointed at
    # the last future-import instead of one past it.
    k = 0
    for stmt in body:
        if not isfutureimport(stmt):
            break
        k += 1
    return docstring, body[:k], body[k:]
def inject_after_futureimports(stmts, body):
    """Inject one or more statements into `body` after its `__future__` imports.

    `stmts`: `ast.stmt` or list of `ast.stmt`, the statement(s) to inject.
    `body`: list of `ast.stmt`, the suite representing a module top level.

    Return value is the list `[docstring] + futureimports + stmts + the_rest`.

    If `body` has no docstring node at its beginning, the docstring part is
    automatically omitted.

    If `body` has no `__future__` imports at the beginning just after the
    optional docstring, the `futureimports` part is automatically omitted.
    """
    if not isinstance(body, list):
        raise TypeError(f"`body` must be a `list`, got {type(body)} with value {repr(body)}")
    # Normalize a single statement into a one-element list.
    if isinstance(stmts, ast.stmt):
        stmts = [stmts]
    elif not isinstance(stmts, list):
        raise TypeError(f"`stmts` must be `ast.stmt` or a `list` of `ast.stmt`, got {type(stmts)} with value {repr(stmts)}")
    docstring, futureimports, rest = split_futureimports(body)
    return docstring + futureimports + stmts + rest

View File

@ -15,10 +15,22 @@ from .utils import format_macrofunction
class Dialect: class Dialect:
"""Base class for dialects.""" """Base class for dialects.
`expander`: the `DialectExpander` instance. The expander provides this automatically.
Stored as `self.expander`.
During dialect expansion, the source location info of the dialect-import statement
that invoked this dialect-import is available as `self.lineno` and `self.col_offset`.
You can pass those to `mcpyrate.splicing.splice_dialect` to automatically mark the
lines from your dialect template as coming from that dialect-import in the user
source code.
"""
def __init__(self, expander): def __init__(self, expander):
"""`expander`: the `DialectExpander` instance. The expander provides this automatically."""
self.expander = expander self.expander = expander
self.lineno = None
self.col_offset = None
def transform_source(self, text): def transform_source(self, text):
"""Override this to add a whole-module source transformer to your dialect. """Override this to add a whole-module source transformer to your dialect.
@ -27,10 +39,10 @@ class Dialect:
tells the expander this dialect does not provide a source transformer. tells the expander this dialect does not provide a source transformer.
Rarely needed. Because we don't (yet) have a generic, extensible Rarely needed. Because we don't (yet) have a generic, extensible
tokenizer for "Python-plus" with extended surface syntax, this is tokenizer for "Python-plus" with extended surface syntax, not to mention that
currently essentially a per-module hook to plug in a transpiler none of the available Python dev tools support any such, this is currently
that compiles source code from some other programming language essentially a per-module hook to plug in a transpiler that compiles
into macro-enabled Python. source code from some other programming language into macro-enabled Python.
The dialect system autodetects the text encoding the same way Python itself The dialect system autodetects the text encoding the same way Python itself
does. That is, it reads the magic comment at the top of the source file does. That is, it reads the magic comment at the top of the source file
@ -80,8 +92,8 @@ class Dialect:
Output should be the transformed AST. Output should be the transformed AST.
To easily splice `tree.body` into your template, see the utility To easily splice `tree.body` into your template, see the utility
`mcpyrate.splicing.splice_dialect` (it automatically handles macro-imports, `mcpyrate.splicing.splice_dialect` (it automatically handles future-imports,
dialect-imports, the magic `__all__`, and the module docstring). macro-imports, dialect-imports, the magic `__all__`, and the module docstring).
As an example, see the `dialects` module in `unpythonic` for example dialects. As an example, see the `dialects` module in `unpythonic` for example dialects.
@ -159,7 +171,7 @@ class DialectExpander:
Due to modularity requirements introduced by `mcpyrate`'s support for Due to modularity requirements introduced by `mcpyrate`'s support for
multi-phase compilation (see the module `mcpyrate.multiphase`), this multi-phase compilation (see the module `mcpyrate.multiphase`), this
class is a bit technical to use. See `mcpyrate.importer`. Roughly, class is a bit technical to use. See `mcpyrate.compiler`. Roughly,
for a single-phase compile:: for a single-phase compile::
dexpander = DialectExpander(filename=...) dexpander = DialectExpander(filename=...)
@ -221,11 +233,11 @@ class DialectExpander:
# state is so last decade. # state is so last decade.
dialect_instances = [] dialect_instances = []
while True: while True:
module_absname, bindings = find_dialectimport(content) theimport = find_dialectimport(content)
if not module_absname: # no more dialects if theimport:
module_absname, bindings, lineno, col_offset = theimport
else: # no more dialects
break break
if not bindings:
continue
for dialectname, cls in bindings.items(): for dialectname, cls in bindings.items():
if not (isinstance(cls, type) and issubclass(cls, Dialect)): if not (isinstance(cls, type) and issubclass(cls, Dialect)):
@ -235,6 +247,9 @@ class DialectExpander:
dialect = cls(expander=self) dialect = cls(expander=self)
except Exception as err: except Exception as err:
raise ImportError(f"Unexpected exception while instantiating dialect `{module_absname}.{dialectname}`") from err raise ImportError(f"Unexpected exception while instantiating dialect `{module_absname}.{dialectname}`") from err
# make the dialect-import source location info available to the transformers
dialect.lineno = lineno
dialect.col_offset = col_offset
try: try:
transformer_method = getattr(dialect, transform) transformer_method = getattr(dialect, transform)
@ -338,9 +353,24 @@ class DialectExpander:
So we can only rely on the literal text "from ... import dialects, ...", So we can only rely on the literal text "from ... import dialects, ...",
similarly to how Racket heavily constrains the format of its `#lang` line. similarly to how Racket heavily constrains the format of its `#lang` line.
Return value is a dict `{dialectname: class, ...}` with all collected bindings Return value is the tuple `(module_absname, bindings, lineno, col_offset)`:
from that one dialect-import. Each binding is a dialect, so usually there is
just one. - `module_absname` is the absolute module name referred to by the import
- `bindings` is a dict `{dialectname: class, ...}`, with all bindings
collected from that one dialect-import statement. Each binding is a
dialect, so usually there is just one.
- `lineno` is the line number of the import statement, determined by
counting the lines of `text`.
- `col_offset` is the corresponding column offset.
Currently not extracted; is always set to 0.
The return value refers to the first not-yet-seen dialect-import (according
to the private cache `self._seen`). Note that this does not transform away
the dialect-imports, because the expander still needs to see them in the
AST transformation step.
If there are no more dialect-imports that have not been seen already,
the return value is `None`.
""" """
matches = _dialectimport.finditer(text) matches = _dialectimport.finditer(text)
try: try:
@ -349,15 +379,17 @@ class DialectExpander:
statement = match.group(0).strip() statement = match.group(0).strip()
if statement not in self._seen: # apply each unique dialect-import once if statement not in self._seen: # apply each unique dialect-import once
self._seen.add(statement) self._seen.add(statement)
lineno = 1 + text[0:match.start()].count("\n") # https://stackoverflow.com/a/48647994
col_offset = 0 # TODO: extract the correct column offset
break break
except StopIteration: except StopIteration:
return "", {} return None
dummy_module = ast.parse(statement, filename=self.filename, mode="exec") dummy_module = ast.parse(statement, filename=self.filename, mode="exec")
dialectimport = dummy_module.body[0] dialectimport = dummy_module.body[0]
module_absname, bindings = get_macros(dialectimport, filename=self.filename, module_absname, bindings = get_macros(dialectimport, filename=self.filename,
reload=False, allow_asname=False) reload=False, allow_asname=False)
return module_absname, bindings return module_absname, bindings, lineno, col_offset
def find_dialectimport_ast(self, tree): def find_dialectimport_ast(self, tree):
"""Find the first dialect-import statement by scanning the AST `tree`. """Find the first dialect-import statement by scanning the AST `tree`.
@ -374,20 +406,38 @@ class DialectExpander:
from ... import dialects, ... from ... import dialects, ...
Return value is a dict `{dialectname: class, ...}` with all collected bindings Return value is the tuple `(module_absname, bindings, lineno)`, where:
from that one dialect-import. Each binding is a dialect, so usually there is
just one. - `module_absname` is the absolute module name referred to by the import
- `bindings` is a dict `{dialectname: class, ...}`, with all bindings
collected from that one dialect-import statement. Each binding is a
dialect, so usually there is just one.
- `lineno` is the line number from the import statement node,
or `None` if the statement had no `lineno` attribute.
- `col_offset` is the corresponding column offset.
It is also taken from the same import statement node.
The return value refers to the first dialect-import that has not yet been
transformed away. If there are no more dialect-imports, the return value
is `None`.
""" """
for index, statement in enumerate(tree.body): for index, statement in enumerate(tree.body):
if ismacroimport(statement, magicname="dialects"): if ismacroimport(statement, magicname="dialects"):
break break
else: else:
return "", {} return None
module_absname, bindings = get_macros(statement, filename=self.filename, module_absname, bindings = get_macros(statement, filename=self.filename,
reload=False, allow_asname=False) reload=False, allow_asname=False)
# Remove all names to prevent dialects being used as regular run-time objects. # Remove all names to prevent dialects being used as regular run-time objects.
# Always use an absolute import, for the unhygienic expose API guarantee. # Always use an absolute import, for the unhygienic expose API guarantee.
tree.body[index] = ast.copy_location(ast.Import(names=[ast.alias(name=module_absname, asname=None)]), thealias = ast.copy_location(ast.alias(name=module_absname, asname=None),
statement)
tree.body[index] = ast.copy_location(ast.Import(names=[thealias]),
statement) statement)
return module_absname, bindings
# Get source location info
lineno = statement.lineno if hasattr(statement, "lineno") else None
col_offset = statement.col_offset if hasattr(statement, "col_offset") else None
return module_absname, bindings, lineno, col_offset

View File

@ -649,12 +649,10 @@ def find_macros(tree, *, filename, reload=False, self_module=None, transform=Tru
else: else:
# Remove all names to prevent macros being used as regular run-time objects. # Remove all names to prevent macros being used as regular run-time objects.
# Always use an absolute import, for the unhygienic expose API guarantee. # Always use an absolute import, for the unhygienic expose API guarantee.
tree.body[index] = copy_location(Import(names=[ thealias = copy_location(alias(name=module_absname, asname=None),
alias(name=module_absname, statement)
asname=None, tree.body[index] = copy_location(Import(names=[thealias]),
lineno=getattr(statement, 'lineno', 0), statement)
col_offset=getattr(statement, 'col_offset', 0))]),
statement)
for index in reversed(stmts_to_delete): for index in reversed(stmts_to_delete):
tree.body.pop(index) tree.body.pop(index)
return bindings return bindings

View File

@ -103,6 +103,7 @@ def path_stats(path, _stats_cache=None):
Beside the source file `path` itself, we look at any macro definition files Beside the source file `path` itself, we look at any macro definition files
the source file imports macros from, recursively, in a `make`-like fashion. the source file imports macros from, recursively, in a `make`-like fashion.
Dialect-imports, if any, are treated the same way.
`_stats_cache` is used internally to speed up the computation, in case the `_stats_cache` is used internally to speed up the computation, in case the
dependency graph hits the same source file multiple times. dependency graph hits the same source file multiple times.

View File

@ -38,7 +38,8 @@ __all__ = ["macro_bindings",
"expand1s", "expands", "expand1s", "expands",
"expand1rq", "expandrq", "expand1rq", "expandrq",
"expand1r", "expandr", "expand1r", "expandr",
"stepr"] "stepr",
"expand_first"]
import ast import ast
@ -48,6 +49,8 @@ from .coreutils import _mcpyrate_attr
from .debug import step_expansion # noqa: F401, used in macro output. from .debug import step_expansion # noqa: F401, used in macro output.
from .expander import MacroExpander, namemacro, parametricmacro from .expander import MacroExpander, namemacro, parametricmacro
from .quotes import astify, capture_value, q, unastify from .quotes import astify, capture_value, q, unastify
from .unparser import unparse_with_fallbacks
from .utils import extract_bindings
def _mcpyrate_metatools_attr(attr): def _mcpyrate_metatools_attr(attr):
@ -447,3 +450,78 @@ def stepr(tree, *, args, syntax, expander, **kw):
[ast.keyword("args", astify(args)), [ast.keyword("args", astify(args)),
ast.keyword("syntax", astify(syntax)), ast.keyword("syntax", astify(syntax)),
ast.keyword("expander", expander_node)]) ast.keyword("expander", expander_node)])
# --------------------------------------------------------------------------------
@parametricmacro
def expand_first(tree, *, args, syntax, expander, **kw):
    """[syntax, block] Force given macros to expand before other macros.

    Usage::

        with expand_first[macro0, ...]:
            ...

    Each argument can be either a bare macro name, e.g. `macro0`, or a
    hygienically unquoted macro, e.g. `q[h[macro0]]`.

    As an example, consider::

        with your_block_macro:
            macro0[expr]

    In this case, if `your_block_macro` expands outside-in, it will transform the
    `expr` inside the `macro0[expr]` before `macro0` even sees the AST, so any
    diagnostics will mention the expanded version of `expr`, not the original one.
    Now, if we change the example to::

        with expand_first[macro0]:
            with your_block_macro:
                macro0[expr]

    In this case, `expand_first` arranges things so that `macro0[expr]` expands first
    (even if `your_block_macro` expands outside-in), so it will see the original,
    unexpanded AST.

    This does imply that `your_block_macro` will then receive the expanded form of
    `macro0[expr]` as input, but that's macros for you.

    There is no particular ordering in which the given set of macros expands;
    they are handled by one expander with bindings registered for all of them.
    (So the expansion order is determined by the order their use sites are
    encountered; the ordering of the names in the argument list does not matter.)
    """
    if syntax != "block":
        raise SyntaxError("expand_first is a block macro only")  # pragma: no cover
    # `syntax` is guaranteed to be "block" here (we just raised otherwise),
    # so only the as-part needs checking.
    if kw['optional_vars'] is not None:
        raise SyntaxError("expand_first does not take an as-part")  # pragma: no cover

    if not args:
        raise SyntaxError("expected a comma-separated list of `macroname` or `q[h[macroname]]` in `with expand_first[macro0, ...]:`; no macro arguments were given")

    # Expand macros in `args` to handle `q[h[somemacro]]`.
    #
    # We must use the full `bindings`, not only those of the quasiquote operators
    # (see `mcpyrate.quotes._expand_quasiquotes`), so that the expander recognizes
    # which hygienic captures are macros.
    args = MacroExpander(expander.bindings, filename=expander.filename).visit(args)

    # In the arguments, we should now have `Name` nodes only.
    invalid_args = [node for node in args if type(node) is not ast.Name]
    if invalid_args:
        invalid_args_str = ", ".join(unparse_with_fallbacks(node, color=True, debug=True)
                                     for node in invalid_args)
        raise SyntaxError(f"expected a comma-separated list of `macroname` or `q[h[macroname]]` in `with expand_first[macro0, ...]:`; invalid args: {invalid_args_str}")

    # Furthermore, all the specified names must be bound as macros in the current expander.
    invalid_args = [name_node for name_node in args if name_node.id not in expander.bindings]
    if invalid_args:
        invalid_args_str = ", ".join(unparse_with_fallbacks(node, color=True, debug=True)
                                     for node in invalid_args)
        raise SyntaxError(f"all macro names in `with expand_first[macro0, ...]:` must be bound in the current expander; the following are not: {invalid_args_str}")

    # All ok. First map the names to macro functions:
    macros = [expander.bindings[name_node.id] for name_node in args]
    # Then map the macro functions to *all* their names in `expander`,
    # so that aliases of the given macros are forced to expand first, too:
    macro_bindings = extract_bindings(expander.bindings, *macros)
    # Finally, expand `tree` with an expander that knows *only* those bindings.
    return MacroExpander(macro_bindings, filename=expander.filename).visit(tree)

View File

@ -26,10 +26,9 @@ from copy import copy, deepcopy
from . import compiler from . import compiler
from .colorizer import ColorScheme, setcolor from .colorizer import ColorScheme, setcolor
from .coreutils import ismacroimport from .coreutils import ismacroimport, isfutureimport, inject_after_futureimports
from .expander import destructure_candidate, namemacro, parametricmacro from .expander import destructure_candidate, namemacro, parametricmacro
from .unparser import unparse_with_fallbacks from .unparser import unparse_with_fallbacks
from .utils import getdocstring
# -------------------------------------------------------------------------------- # --------------------------------------------------------------------------------
# Private utilities. # Private utilities.
@ -71,10 +70,6 @@ def iswithphase(stmt, *, filename):
return n return n
def isfutureimport(tree):
"""Return whether `tree` is a `from __future__ import ...`."""
return isinstance(tree, ast.ImportFrom) and tree.module == "__future__"
def extract_phase(tree, *, filename, phase=0): def extract_phase(tree, *, filename, phase=0):
"""Split `tree` into given `phase` and remaining parts. """Split `tree` into given `phase` and remaining parts.
@ -142,35 +137,6 @@ def extract_phase(tree, *, filename, phase=0):
newmodule.body = thisphase newmodule.body = thisphase
return newmodule return newmodule
def split_futureimports(body):
"""Split `body` into `__future__` imports and everything else.
`body`: list of `ast.stmt`, the suite representing a module top level.
Returns `[future_imports, the_rest]`.
"""
k = -1 # ensure `k` gets defined even if `body` is empty
for k, bstmt in enumerate(body):
if not isfutureimport(bstmt):
break
if k >= 0:
return body[:k], body[k:]
return [], body
def inject_after_futureimports(stmt, body):
"""Inject a statement into `body` after `__future__` imports.
`body`: list of `ast.stmt`, the suite representing a module top level.
`stmt`: `ast.stmt`, the statement to inject.
"""
if getdocstring(body):
docstring, *body = body
futureimports, body = split_futureimports(body)
return [docstring] + futureimports + [stmt] + body
else: # no docstring
futureimports, body = split_futureimports(body)
return futureimports + [stmt] + body
# -------------------------------------------------------------------------------- # --------------------------------------------------------------------------------
# Public utilities. # Public utilities.

View File

@ -536,7 +536,7 @@ def is_captured_macro(tree):
the next one thousand years"; see `mcpyrate.gensym`, which links to the next one thousand years"; see `mcpyrate.gensym`, which links to
the UUID spec used by the implementation.) the UUID spec used by the implementation.)
- `frozen_macro` is either `bytes` object that stores a reference to the - `frozen_macro` is a `bytes` object that stores a reference to the
frozen macro function as opaque binary data. frozen macro function as opaque binary data.
The `bytes` object can be decoded by passing the whole return value as `key` The `bytes` object can be decoded by passing the whole return value as `key`

View File

@ -7,7 +7,7 @@ import ast
from copy import deepcopy from copy import deepcopy
from .astfixers import fix_locations from .astfixers import fix_locations
from .coreutils import ismacroimport from .coreutils import ismacroimport, split_futureimports
from .markers import ASTMarker from .markers import ASTMarker
from .utils import getdocstring from .utils import getdocstring
from .walkers import ASTTransformer from .walkers import ASTTransformer
@ -157,29 +157,64 @@ def splice_statements(body, template, tag="__paste_here__"):
return StatementSplicer().visit(template) return StatementSplicer().visit(template)
def splice_dialect(body, template, tag="__paste_here__"): def splice_dialect(body, template, tag="__paste_here__", lineno=None, col_offset=None):
"""In a dialect AST transformer, splice module `body` into `template`. """In a dialect AST transformer, splice module `body` into `template`.
On top of what `splice_statements` does, this function handles macro-imports On top of what `splice_statements` does, this function handles macro-imports
and dialect-imports specially, gathering them all at the top level of the and dialect-imports specially, gathering them all at the top level of the
final module body, so that mcpyrate sees them when the module is sent to final module body, so that mcpyrate sees them when the module is sent to
the macro expander. the macro expander. This is to allow a dialect template to splice the body
into the inside of a `with` block (e.g. to invoke some code-walking macro
that changes the language semantics, such as an auto-TCO or a lazifier),
without breaking macro-imports (and further dialect-imports) introduced
by user code in the body.
Any dialect-imports in the template are placed first (in the order they Any dialect-imports in the template are placed first (in the order they
appear in the template), followed by any dialect-imports in the user code appear in the template), followed by any dialect-imports in the user code
(in the order they appear in the user code), followed by macro-imports in (in the order they appear in the user code), followed by macro-imports in
the template, then macro-imports in the user code. the template, then macro-imports in the user code.
This also handles the module docstring and the magic `__all__` (if any) We also handle the module docstring, future-imports, and the magic `__all__`.
from `body`. The docstring comes first, before dialect-imports. The magic
`__all__` is placed after dialect-imports, before macro-imports. The optional `lineno` and `col_offset` parameters can be used to tell
`splice_dialect` the source location info of the dialect-import (in the
unexpanded source code) that triggered this template. If specified, they
are used to mark all the lines coming from the template as having come
from that dialect-import statement. During dialect expansion, you can
get these from the `lineno` and `col_offset` attributes of your dialect
instance (these attributes are filled in by `DialectExpander`).
If both `body` and `template` have a module docstring, they are concatenated
to produce the module docstring for the result. If only one of them has a
module docstring, that docstring is used as-is. If neither has a module docstring,
the docstring is omitted.
The primary use of a module docstring in a dialect template is to be able to say
that the program was written in dialect X, more information on which can be found at...
Future-imports from `template` and `body` are concatenated.
The magic `__all__` is taken from `body`; if `body` does not define it,
it is omitted.
In the result, the ordering is::
docstring
template future-imports
body future-imports
__all__ (if defined in body)
template dialect-imports
body dialect-imports
template macro-imports
body macro-imports
the rest
Parameters: Parameters:
`body`: `list` of statements `body`: `list` of `ast.stmt`, or a single `ast.stmt`
Original module body from the user code (input). Original module body from the user code (input).
`template`: `list` of statements `template`: `list` of `ast.stmt`, or a single `ast.stmt`
Template for the final module body (output). Template for the final module body (output).
Must contain a paste-here indicator as in `splice_statements`. Must contain a paste-here indicator as in `splice_statements`.
@ -187,16 +222,19 @@ def splice_dialect(body, template, tag="__paste_here__"):
`tag`: `str` `tag`: `str`
The name of the paste-here indicator in `template`. The name of the paste-here indicator in `template`.
Returns `template` with `body` spliced in. Note `template` is **not** copied, `lineno`: optional `int`
and will be mutated in-place. `col_offset`: optional `int`
Source location info of the dialect-import that triggered this template.
Also `body` is mutated, to remove macro-imports, `__all__` and the module Return value is `template` with `body` spliced in.
docstring; these are pasted into the final result.
Note `template` and `body` are **not** copied, and **both** will be mutated
during the splicing process.
""" """
if isinstance(body, ast.AST): if isinstance(body, ast.AST):
body = [body] body = [body]
if isinstance(template, ast.AST): if isinstance(template, ast.AST):
body = [template] template = [template]
if not body: if not body:
raise ValueError("expected at least one statement in `body`") raise ValueError("expected at least one statement in `body`")
if not template: if not template:
@ -207,20 +245,34 @@ def splice_dialect(body, template, tag="__paste_here__"):
# Even if they have location info, it's for a different file compared # Even if they have location info, it's for a different file compared
# to the use site where `body` comes from. # to the use site where `body` comes from.
# #
# Pretend the template code appears at the beginning of the user module. # Pretend the template code appears at the given source location,
# # or if not given, at the beginning of `body`.
# TODO: It would be better to pretend it appears at the line that has the dialect-import. if lineno is not None and col_offset is not None:
# TODO: Requires a `lineno` parameter here, and `DialectExpander` must be modified to supply it. srcloc_dummynode = ast.Constant(value=None)
# TODO: We could extract the `lineno` in `find_dialectimport_ast` and then pass it to the srcloc_dummynode.lineno = lineno
# TODO: user-defined dialect AST transformer, so it could pass it to us if it wants to. srcloc_dummynode.col_offset = col_offset
for stmt in template:
fix_locations(stmt, body[0], mode="overwrite")
if getdocstring(body):
docstring, *body = body
docstring = [docstring]
else: else:
docstring = [] srcloc_dummynode = body[0]
for stmt in template:
fix_locations(stmt, srcloc_dummynode, mode="overwrite")
user_docstring, user_futureimports, body = split_futureimports(body)
template_docstring, template_futureimports, template = split_futureimports(template)
# Combine user and template docstrings if both are defined.
if user_docstring and template_docstring:
# We must extract the bare strings, combine them, and then pack the result into an AST node.
user_doc = getdocstring(user_docstring)
template_doc = getdocstring(template_docstring)
sep = "\n" + ("-" * 79) + "\n"
new_doc = user_doc + sep + template_doc
new_docstring = ast.copy_location(ast.Constant(value=new_doc),
user_docstring[0])
docstring = [new_docstring]
else:
docstring = user_docstring or template_docstring
futureimports = template_futureimports + user_futureimports
def extract_magic_all(tree): def extract_magic_all(tree):
def ismagicall(tree): def ismagicall(tree):
@ -239,6 +291,7 @@ def splice_dialect(body, template, tag="__paste_here__"):
w.visit(tree) w.visit(tree)
return tree, w.collected return tree, w.collected
body, user_magic_all = extract_magic_all(body) body, user_magic_all = extract_magic_all(body)
template, ignored_template_magic_all = extract_magic_all(template)
def extract_macroimports(tree, *, magicname="macros"): def extract_macroimports(tree, *, magicname="macros"):
class MacroImportExtractor(ASTTransformer): class MacroImportExtractor(ASTTransformer):
@ -257,6 +310,7 @@ def splice_dialect(body, template, tag="__paste_here__"):
finalbody = splice_statements(body, template, tag) finalbody = splice_statements(body, template, tag)
return (docstring + return (docstring +
futureimports +
user_magic_all + user_magic_all +
template_dialect_imports + user_dialect_imports + template_dialect_imports + user_dialect_imports +
template_macro_imports + user_macro_imports + template_macro_imports + user_macro_imports +

View File

@ -2,7 +2,7 @@
"""General utilities. Can be useful for writing both macros as well as macro expanders.""" """General utilities. Can be useful for writing both macros as well as macro expanders."""
__all__ = ["gensym", "scrub_uuid", "flatten", "rename", "extract_bindings", "getdocstring", __all__ = ["gensym", "scrub_uuid", "flatten", "rename", "extract_bindings", "getdocstring",
"format_location", "format_macrofunction", "format_context", "get_lineno", "format_location", "format_macrofunction", "format_context",
"NestingLevelTracker"] "NestingLevelTracker"]
import ast import ast
@ -188,34 +188,66 @@ def getdocstring(body):
# -------------------------------------------------------------------------------- # --------------------------------------------------------------------------------
def get_lineno(tree):
"""Extract the source line number from `tree`.
`tree`: AST node, list of AST nodes, or an AST marker.
`tree` is searched recursively (depth first) until a `lineno` attribute is found;
its value is then returned.
If no `lineno` attribute is found anywhere inside `tree`, the return value is `None`.
"""
if hasattr(tree, "lineno"):
return tree.lineno
elif isinstance(tree, markers.ASTMarker) and hasattr(tree, "body"): # look inside AST markers
return get_lineno(tree.body)
elif isinstance(tree, ast.AST): # look inside AST nodes
# Note `iter_fields` ignores attribute fields such as line numbers and column offsets,
# so we don't recurse into those.
for fieldname, node in ast.iter_fields(tree):
lineno = get_lineno(node)
if lineno:
return lineno
elif isinstance(tree, list): # look inside statement suites
for node in tree:
lineno = get_lineno(node)
if lineno:
return lineno
return None
def format_location(filename, tree, sourcecode): def format_location(filename, tree, sourcecode):
"""Format a source code location in a standard way, for error messages. """Format a source code location in a standard way, for error messages.
`filename`: full path to `.py` file. `filename`: full path to `.py` file.
`tree`: AST node to get source line number from. (Looks inside AST markers.) `tree`: AST node to get source line number from. (Looks inside automatically if needed.)
`sourcecode`: source code (typically, to get this, `unparse(tree)` `sourcecode`: source code (typically, to get this, `unparse(tree)`
before expanding it), or `None` to omit it. before expanding it), or `None` to omit it.
"""
lineno = None
if hasattr(tree, "lineno"):
lineno = tree.lineno
elif isinstance(tree, markers.ASTMarker) and hasattr(tree, "body"):
if hasattr(tree.body, "lineno"):
lineno = tree.body.lineno
elif isinstance(tree.body, list) and tree.body and hasattr(tree.body[0], "lineno"): # e.g. `SpliceNodes`
lineno = tree.body[0].lineno
Return value is an `str` containing colored text, suitable for terminal output.
Example outputs for single-line and multiline source code::
/path/to/hello.py:42: print("hello!")
/path/to/hello.py:1337:
if helloing:
print("hello!")
"""
if sourcecode: if sourcecode:
sep = " " if "\n" not in sourcecode else "\n" sep = " " if "\n" not in sourcecode else "\n"
source_with_sep = f"{sep}{sourcecode}" source_with_sep = f"{sep}{sourcecode}"
else: else:
source_with_sep = "" source_with_sep = ""
return f'{colorize(filename, ColorScheme.SOURCEFILENAME)}:{lineno}:{source_with_sep}' return f'{colorize(filename, ColorScheme.SOURCEFILENAME)}:{get_lineno(tree)}:{source_with_sep}'
def format_macrofunction(function): def format_macrofunction(function):
"""Format the fully qualified name of a macro function, for error messages.""" """Format the fully qualified name of a macro function, for error messages.
Return value is an `str` with the fully qualified name.
"""
# Catch broken bindings due to erroneous imports in user code # Catch broken bindings due to erroneous imports in user code
# (e.g. accidentally to a module object instead of to a function object) # (e.g. accidentally to a module object instead of to a function object)
if not (hasattr(function, "__module__") and hasattr(function, "__qualname__")): if not (hasattr(function, "__module__") and hasattr(function, "__qualname__")):
@ -226,7 +258,13 @@ def format_macrofunction(function):
def format_context(tree, *, n=5): def format_context(tree, *, n=5):
"""Format up to the first `n` lines of source code of `tree`.""" """Format up to the first `n` lines of source code of `tree`.
The source code is produced from `tree` by unparsing.
Return value is an `str` containing colored text with syntax highlighting,
suitable for terminal output.
"""
code_lines = unparser.unparse_with_fallbacks(tree, debug=True, color=True).split("\n") code_lines = unparser.unparse_with_fallbacks(tree, debug=True, color=True).split("\n")
code = "\n".join(code_lines[:n]) code = "\n".join(code_lines[:n])
if len(code_lines) > n: if len(code_lines) > n:
@ -255,7 +293,17 @@ class NestingLevelTracker:
value = property(fget=_get_value, doc="The current level. Read-only. Use `set_to` or `change_by` to change.") value = property(fget=_get_value, doc="The current level. Read-only. Use `set_to` or `change_by` to change.")
def set_to(self, value): def set_to(self, value):
"""Context manager. Run a section of code with the level set to `value`.""" """Context manager. Run a section of code with the level set to `value`.
Example::
t = NestingLevelTracker()
assert t.value == 0
with t.set_to(42):
assert t.value == 42
...
assert t.value == 0
"""
if not isinstance(value, int): if not isinstance(value, int):
raise TypeError(f"Expected integer `value`, got {type(value)} with value {repr(value)}") raise TypeError(f"Expected integer `value`, got {type(value)} with value {repr(value)}")
if value < 0: if value < 0:
@ -271,5 +319,18 @@ class NestingLevelTracker:
return _set_to() return _set_to()
def changed_by(self, delta): def changed_by(self, delta):
"""Context manager. Run a section of code with the level incremented by `delta`.""" """Context manager. Run a section of code with the level incremented by `delta`.
Example::
t = NestingLevelTracker()
assert t.value == 0
with t.changed_by(+21):
assert t.value == 21
with t.changed_by(+21):
assert t.value == 42
...
assert t.value == 21
assert t.value == 0
"""
return self.set_to(self.value + delta) return self.set_to(self.value + delta)