diff --git a/kibot/mcpyrate/__init__.py b/kibot/mcpyrate/__init__.py index f677b11d..0f20a494 100644 --- a/kibot/mcpyrate/__init__.py +++ b/kibot/mcpyrate/__init__.py @@ -1,3 +1,5 @@ +"""mcpyrate: Advanced macro expander and language lab for Python.""" + from .astdumper import dump # noqa: F401 from .core import MacroExpansionError # noqa: F401 from .dialects import Dialect # noqa: F401 @@ -5,4 +7,12 @@ from .expander import namemacro, parametricmacro # noqa: F401 from .unparser import unparse # noqa: F401 from .utils import gensym # noqa: F401 -__version__ = "3.5.4" +# For public API inspection, import modules that wouldn't otherwise get imported. +from . import ansi # noqa: F401 +from . import debug # noqa: F401 +from . import metatools # noqa: F401 +from . import pycachecleaner # noqa: F401 +from . import quotes # noqa: F401 +from . import splicing # noqa: F401 + +__version__ = "3.6.0" diff --git a/kibot/mcpyrate/coreutils.py b/kibot/mcpyrate/coreutils.py index 824f71be..248abf88 100644 --- a/kibot/mcpyrate/coreutils.py +++ b/kibot/mcpyrate/coreutils.py @@ -2,7 +2,8 @@ """Utilities related to writing macro expanders and similar meta-metaprogramming tasks.""" __all__ = ["resolve_package", "relativize", "match_syspath", - "ismacroimport", "get_macros"] + "ismacroimport", "get_macros", + "isfutureimport", "split_futureimports", "inject_after_futureimports"] import ast import importlib @@ -12,7 +13,7 @@ import pathlib import sys from .unparser import unparse_with_fallbacks -from .utils import format_location +from .utils import format_location, getdocstring def resolve_package(filename): # TODO: for now, `guess_package`, really. Check the docs again. @@ -90,6 +91,15 @@ def ismacroimport(statement, magicname='macros'): def get_macros(macroimport, *, filename, reload=False, allow_asname=True, self_module=None): """Get absolute module name, macro names and macro functions from a macro-import. + A macro-import is a statement of the form:: + + from ... import macros, ... 
+ + where `macros` is the magic name that your actual macro expander uses to recognize + a macro-import (see `ismacroimport`). This function does not care about what the + actual magic name is, and simply ignores the first name that is imported by the + import statement. + As a side effect, import the macro definition module. Return value is `module_absname, {macroname0: macrofunction0, ...}`. @@ -166,7 +176,7 @@ def get_macros(macroimport, *, filename, reload=False, allow_asname=True, self_m module = importlib.reload(module) bindings = {} - for name in macroimport.names[1:]: + for name in macroimport.names[1:]: # skip the "macros" in `from ... import macros, ...` if not allow_asname and name.asname is not None: approx_sourcecode = unparse_with_fallbacks(macroimport, debug=True, color=True) loc = format_location(filename, macroimport, approx_sourcecode) @@ -241,3 +251,54 @@ def _mcpyrate_attr(dotted_name, *, force_import=False): value = ast.Attribute(value=value, attr=name) return value + +# -------------------------------------------------------------------------------- + +def isfutureimport(tree): + """Return whether `tree` is a `from __future__ import ...`.""" + return isinstance(tree, ast.ImportFrom) and tree.module == "__future__" + +def split_futureimports(body): + """Split `body` into `__future__` imports and everything else. + + `body`: list of `ast.stmt`, the suite representing a module top level. + + Return value is `[docstring, future_imports, the_rest]`, where each item + is a list of `ast.stmt` (possibly empty). 
+ """ + if getdocstring(body): + docstring, *body = body + docstring = [docstring] + else: + docstring = [] + + k = -1 # ensure `k` gets defined even if `body` is empty + for k, bstmt in enumerate(body): + if not isfutureimport(bstmt): + break + if k >= 0: + return docstring, body[:k], body[k:] + return docstring, [], body + +def inject_after_futureimports(stmts, body): + """Inject one or more statements into `body` after its `__future__` imports. + + `stmts`: `ast.stmt` or list of `ast.stmt`, the statement(s) to inject. + `body`: list of `ast.stmt`, the suite representing a module top level. + + Return value is the list `[docstring] + futureimports + stmts + the_rest`. + + If `body` has no docstring node at its beginning, the docstring part is + automatically omitted. + + If `body` has no `__future__` imports at the beginning just after the + optional docstring, the `futureimports` part is automatically omitted. + """ + if not isinstance(body, list): + raise TypeError(f"`body` must be a `list`, got {type(body)} with value {repr(body)}") + if not isinstance(stmts, list): + if not isinstance(stmts, ast.stmt): + raise TypeError(f"`stmts` must be `ast.stmt` or a `list` of `ast.stmt`, got {type(stmts)} with value {repr(stmts)}") + stmts = [stmts] + docstring, futureimports, body = split_futureimports(body) + return docstring + futureimports + stmts + body diff --git a/kibot/mcpyrate/dialects.py b/kibot/mcpyrate/dialects.py index 8dd69da5..0f53c46a 100644 --- a/kibot/mcpyrate/dialects.py +++ b/kibot/mcpyrate/dialects.py @@ -15,10 +15,22 @@ from .utils import format_macrofunction class Dialect: - """Base class for dialects.""" + """Base class for dialects. + + `expander`: the `DialectExpander` instance. The expander provides this automatically. + Stored as `self.expander`. + + During dialect expansion, the source location info of the dialect-import statement + that invoked this dialect-import is available as `self.lineno` and `self.col_offset`. 
+ + You can pass those to `mcpyrate.splicing.splice_dialect` to automatically mark the + lines from your dialect template as coming from that dialect-import in the user + source code. + """ def __init__(self, expander): - """`expander`: the `DialectExpander` instance. The expander provides this automatically.""" self.expander = expander + self.lineno = None + self.col_offset = None def transform_source(self, text): """Override this to add a whole-module source transformer to your dialect. @@ -27,10 +39,10 @@ class Dialect: tells the expander this dialect does not provide a source transformer. Rarely needed. Because we don't (yet) have a generic, extensible - tokenizer for "Python-plus" with extended surface syntax, this is - currently essentially a per-module hook to plug in a transpiler - that compiles source code from some other programming language - into macro-enabled Python. + tokenizer for "Python-plus" with extended surface syntax, not to mention that + none of the available Python dev tools support any such, this is currently + essentially a per-module hook to plug in a transpiler that compiles + source code from some other programming language into macro-enabled Python. The dialect system autodetects the text encoding the same way Python itself does. That is, it reads the magic comment at the top of the source file @@ -80,8 +92,8 @@ class Dialect: Output should be the transformed AST. To easily splice `tree.body` into your template, see the utility - `mcpyrate.splicing.splice_dialect` (it automatically handles macro-imports, - dialect-imports, the magic `__all__`, and the module docstring). + `mcpyrate.splicing.splice_dialect` (it automatically handles future-imports, + macro-imports, dialect-imports, the magic `__all__`, and the module docstring). As an example, see the `dialects` module in `unpythonic` for example dialects. 
@@ -159,7 +171,7 @@ class DialectExpander: Due to modularity requirements introduced by `mcpyrate`'s support for multi-phase compilation (see the module `mcpyrate.multiphase`), this - class is a bit technical to use. See `mcpyrate.importer`. Roughly, + class is a bit technical to use. See `mcpyrate.compiler`. Roughly, for a single-phase compile:: dexpander = DialectExpander(filename=...) @@ -221,11 +233,11 @@ class DialectExpander: # state is so last decade. dialect_instances = [] while True: - module_absname, bindings = find_dialectimport(content) - if not module_absname: # no more dialects + theimport = find_dialectimport(content) + if theimport: + module_absname, bindings, lineno, col_offset = theimport + else: # no more dialects break - if not bindings: - continue for dialectname, cls in bindings.items(): if not (isinstance(cls, type) and issubclass(cls, Dialect)): @@ -235,6 +247,9 @@ class DialectExpander: dialect = cls(expander=self) except Exception as err: raise ImportError(f"Unexpected exception while instantiating dialect `{module_absname}.{dialectname}`") from err + # make the dialect-import source location info available to the transformers + dialect.lineno = lineno + dialect.col_offset = col_offset try: transformer_method = getattr(dialect, transform) @@ -338,9 +353,24 @@ class DialectExpander: So we can only rely on the literal text "from ... import dialects, ...", similarly to how Racket heavily constrains the format of its `#lang` line. - Return value is a dict `{dialectname: class, ...}` with all collected bindings - from that one dialect-import. Each binding is a dialect, so usually there is - just one. + Return value is the tuple `(module_absname, bindings, lineno, col_offset)`: + + - `module_absname` is the absolute module name referred to by the import + - `bindings` is a dict `{dialectname: class, ...}`, with all bindings + collected from that one dialect-import statement. Each binding is a + dialect, so usually there is just one. 
+ - `lineno` is the line number of the import statement, determined by + counting the lines of `text`. + - `col_offset` is the corresponding column offset. + Currently not extracted; it is always set to 0. + + The return value refers to the first not-yet-seen dialect-import (according + to the private cache `self._seen`). Note that this does not transform away + the dialect-imports, because the expander still needs to see them in the + AST transformation step. + + If there are no more dialect-imports that have not been seen already, + the return value is `None`. """ matches = _dialectimport.finditer(text) try: @@ -349,15 +379,17 @@ statement = match.group(0).strip() if statement not in self._seen: # apply each unique dialect-import once self._seen.add(statement) + lineno = 1 + text[0:match.start()].count("\n") # https://stackoverflow.com/a/48647994 + col_offset = 0 # TODO: extract the correct column offset break except StopIteration: - return "", {} + return None dummy_module = ast.parse(statement, filename=self.filename, mode="exec") dialectimport = dummy_module.body[0] module_absname, bindings = get_macros(dialectimport, filename=self.filename, reload=False, allow_asname=False) - return module_absname, bindings + return module_absname, bindings, lineno, col_offset def find_dialectimport_ast(self, tree): """Find the first dialect-import statement by scanning the AST `tree`. @@ -374,20 +406,38 @@ class DialectExpander: from ... import dialects, ... - Return value is a dict `{dialectname: class, ...}` with all collected bindings - from that one dialect-import. Each binding is a dialect, so usually there is - just one. + Return value is the tuple `(module_absname, bindings, lineno, col_offset)`, where: + + - `module_absname` is the absolute module name referred to by the import + - `bindings` is a dict `{dialectname: class, ...}`, with all bindings + collected from that one dialect-import statement. Each binding is a + dialect, so usually there is just one. 
+ - `lineno` is the line number from the import statement node, + or `None` if the statement had no `lineno` attribute. + - `col_offset` is the corresponding column offset. + It is also taken from the same import statement node. + + The return value refers to the first dialect-import that has not yet been + transformed away. If there are no more dialect-imports, the return value + is `None`. """ for index, statement in enumerate(tree.body): if ismacroimport(statement, magicname="dialects"): break else: - return "", {} + return None module_absname, bindings = get_macros(statement, filename=self.filename, reload=False, allow_asname=False) # Remove all names to prevent dialects being used as regular run-time objects. # Always use an absolute import, for the unhygienic expose API guarantee. - tree.body[index] = ast.copy_location(ast.Import(names=[ast.alias(name=module_absname, asname=None)]), + thealias = ast.copy_location(ast.alias(name=module_absname, asname=None), + statement) + tree.body[index] = ast.copy_location(ast.Import(names=[thealias]), statement) - return module_absname, bindings + + # Get source location info + lineno = statement.lineno if hasattr(statement, "lineno") else None + col_offset = statement.col_offset if hasattr(statement, "col_offset") else None + + return module_absname, bindings, lineno, col_offset diff --git a/kibot/mcpyrate/expander.py b/kibot/mcpyrate/expander.py index d4c081db..ebd6727e 100644 --- a/kibot/mcpyrate/expander.py +++ b/kibot/mcpyrate/expander.py @@ -649,12 +649,10 @@ def find_macros(tree, *, filename, reload=False, self_module=None, transform=Tru else: # Remove all names to prevent macros being used as regular run-time objects. # Always use an absolute import, for the unhygienic expose API guarantee. 
- tree.body[index] = copy_location(Import(names=[ - alias(name=module_absname, - asname=None, - lineno=getattr(statement, 'lineno', 0), - col_offset=getattr(statement, 'col_offset', 0))]), - statement) + thealias = copy_location(alias(name=module_absname, asname=None), + statement) + tree.body[index] = copy_location(Import(names=[thealias]), + statement) for index in reversed(stmts_to_delete): tree.body.pop(index) return bindings diff --git a/kibot/mcpyrate/importer.py b/kibot/mcpyrate/importer.py index 230381be..4e58e963 100644 --- a/kibot/mcpyrate/importer.py +++ b/kibot/mcpyrate/importer.py @@ -103,6 +103,7 @@ def path_stats(path, _stats_cache=None): Beside the source file `path` itself, we look at any macro definition files the source file imports macros from, recursively, in a `make`-like fashion. + Dialect-imports, if any, are treated the same way. `_stats_cache` is used internally to speed up the computation, in case the dependency graph hits the same source file multiple times. diff --git a/kibot/mcpyrate/metatools.py b/kibot/mcpyrate/metatools.py index 9ae4711f..1eddfce3 100644 --- a/kibot/mcpyrate/metatools.py +++ b/kibot/mcpyrate/metatools.py @@ -38,7 +38,8 @@ __all__ = ["macro_bindings", "expand1s", "expands", "expand1rq", "expandrq", "expand1r", "expandr", - "stepr"] + "stepr", + "expand_first"] import ast @@ -48,6 +49,8 @@ from .coreutils import _mcpyrate_attr from .debug import step_expansion # noqa: F401, used in macro output. 
from .expander import MacroExpander, namemacro, parametricmacro from .quotes import astify, capture_value, q, unastify +from .unparser import unparse_with_fallbacks +from .utils import extract_bindings def _mcpyrate_metatools_attr(attr): @@ -447,3 +450,78 @@ def stepr(tree, *, args, syntax, expander, **kw): [ast.keyword("args", astify(args)), ast.keyword("syntax", astify(syntax)), ast.keyword("expander", expander_node)]) + +# -------------------------------------------------------------------------------- + +@parametricmacro +def expand_first(tree, *, args, syntax, expander, **kw): + """[syntax, block] Force given macros to expand before other macros. + + Usage:: + + with expand_first[macro0, ...]: + ... + + Each argument can be either a bare macro name, e.g. `macro0`, or a + hygienically unquoted macro, e.g. `q[h[macro0]]`. + + As an example, consider:: + + with your_block_macro: + macro0[expr] + + In this case, if `your_block_macro` expands outside-in, it will transform the + `expr` inside the `macro0[expr]` before `macro0` even sees the AST. If the test + fails or errors, the error message will contain the expanded version of `expr`, + not the original one. Now, if we change the example to:: + + with expand_first[macro0]: + with your_block_macro: + macro0[expr] + + In this case, `expand_first` arranges things so that `macro0[expr]` expands first + (even if `your_block_macro` expands outside-in), so it will see the original, + unexpanded AST. + + This does imply that `your_block_macro` will then receive the expanded form of + `macro0[expr]` as input, but that's macros for you. + + There is no particular ordering in which the given set of macros expands; + they are handled by one expander with bindings registered for all of them. + (So the expansion order is determined by the order their use sites are + encountered; the ordering of the names in the argument list does not matter.) 
+ """ + if syntax != "block": + raise SyntaxError("expand_first is a block macro only") # pragma: no cover + if syntax == "block" and kw['optional_vars'] is not None: + raise SyntaxError("expand_first does not take an as-part") # pragma: no cover + if not args: + raise SyntaxError("expected a comma-separated list of `macroname` or `q[h[macroname]]` in `with expand_first[macro0, ...]:`; no macro arguments were given") + + # Expand macros in `args` to handle `q[h[somemacro]]`. + # + # We must use the full `bindings`, not only those of the quasiquote operators + # (see `mcpyrate.quotes._expand_quasiquotes`), so that the expander recognizes + # which hygienic captures are macros. + args = MacroExpander(expander.bindings, filename=expander.filename).visit(args) + + # In the arguments, we should now have `Name` nodes only. + invalid_args = [node for node in args if type(node) is not ast.Name] + if invalid_args: + invalid_args_str = ", ".join(unparse_with_fallbacks(node, color=True, debug=True) + for node in invalid_args) + raise SyntaxError(f"expected a comma-separated list of `macroname` or `q[h[macroname]]` in `with expand_first[macro0, ...]:`; invalid args: {invalid_args_str}") + + # Furthermore, all the specified names must be bound as macros in the current expander. + invalid_args = [name_node for name_node in args if name_node.id not in expander.bindings] + if invalid_args: + invalid_args_str = ", ".join(unparse_with_fallbacks(node, color=True, debug=True) + for node in invalid_args) + raise SyntaxError(f"all macro names in `with expand_first[macro0, ...]:` must be bound in the current expander; the following are not: {invalid_args_str}") + + # All ok. 
First map the names to macro functions: + macros = [expander.bindings[name_node.id] for name_node in args] + + # Then map the macro functions to *all* their names in `expander`: + macro_bindings = extract_bindings(expander.bindings, *macros) + return MacroExpander(macro_bindings, filename=expander.filename).visit(tree) diff --git a/kibot/mcpyrate/multiphase.py b/kibot/mcpyrate/multiphase.py index fbb73d61..a68ab5e4 100644 --- a/kibot/mcpyrate/multiphase.py +++ b/kibot/mcpyrate/multiphase.py @@ -26,10 +26,9 @@ from copy import copy, deepcopy from . import compiler from .colorizer import ColorScheme, setcolor -from .coreutils import ismacroimport +from .coreutils import ismacroimport, isfutureimport, inject_after_futureimports from .expander import destructure_candidate, namemacro, parametricmacro from .unparser import unparse_with_fallbacks -from .utils import getdocstring # -------------------------------------------------------------------------------- # Private utilities. @@ -71,10 +70,6 @@ def iswithphase(stmt, *, filename): return n -def isfutureimport(tree): - """Return whether `tree` is a `from __future__ import ...`.""" - return isinstance(tree, ast.ImportFrom) and tree.module == "__future__" - def extract_phase(tree, *, filename, phase=0): """Split `tree` into given `phase` and remaining parts. @@ -142,35 +137,6 @@ def extract_phase(tree, *, filename, phase=0): newmodule.body = thisphase return newmodule -def split_futureimports(body): - """Split `body` into `__future__` imports and everything else. - - `body`: list of `ast.stmt`, the suite representing a module top level. - - Returns `[future_imports, the_rest]`. - """ - k = -1 # ensure `k` gets defined even if `body` is empty - for k, bstmt in enumerate(body): - if not isfutureimport(bstmt): - break - if k >= 0: - return body[:k], body[k:] - return [], body - -def inject_after_futureimports(stmt, body): - """Inject a statement into `body` after `__future__` imports. 
- - `body`: list of `ast.stmt`, the suite representing a module top level. - `stmt`: `ast.stmt`, the statement to inject. - """ - if getdocstring(body): - docstring, *body = body - futureimports, body = split_futureimports(body) - return [docstring] + futureimports + [stmt] + body - else: # no docstring - futureimports, body = split_futureimports(body) - return futureimports + [stmt] + body - # -------------------------------------------------------------------------------- # Public utilities. diff --git a/kibot/mcpyrate/quotes.py b/kibot/mcpyrate/quotes.py index 8e1278e8..504b9285 100644 --- a/kibot/mcpyrate/quotes.py +++ b/kibot/mcpyrate/quotes.py @@ -536,7 +536,7 @@ def is_captured_macro(tree): the next one thousand years"; see `mcpyrate.gensym`, which links to the UUID spec used by the implementation.) - - `frozen_macro` is either `bytes` object that stores a reference to the + - `frozen_macro` is a `bytes` object that stores a reference to the frozen macro function as opaque binary data. The `bytes` object can be decoded by passing the whole return value as `key` diff --git a/kibot/mcpyrate/splicing.py b/kibot/mcpyrate/splicing.py index 06c28fe5..7ff3e197 100644 --- a/kibot/mcpyrate/splicing.py +++ b/kibot/mcpyrate/splicing.py @@ -7,7 +7,7 @@ import ast from copy import deepcopy from .astfixers import fix_locations -from .coreutils import ismacroimport +from .coreutils import ismacroimport, split_futureimports from .markers import ASTMarker from .utils import getdocstring from .walkers import ASTTransformer @@ -157,29 +157,64 @@ def splice_statements(body, template, tag="__paste_here__"): return StatementSplicer().visit(template) -def splice_dialect(body, template, tag="__paste_here__"): +def splice_dialect(body, template, tag="__paste_here__", lineno=None, col_offset=None): """In a dialect AST transformer, splice module `body` into `template`. 
On top of what `splice_statements` does, this function handles macro-imports and dialect-imports specially, gathering them all at the top level of the final module body, so that mcpyrate sees them when the module is sent to - the macro expander. + the macro expander. This is to allow a dialect template to splice the body + into the inside of a `with` block (e.g. to invoke some code-walking macro + that changes the language semantics, such as an auto-TCO or a lazifier), + without breaking macro-imports (and further dialect-imports) introduced + by user code in the body. Any dialect-imports in the template are placed first (in the order they appear in the template), followed by any dialect-imports in the user code (in the order they appear in the user code), followed by macro-imports in the template, then macro-imports in the user code. - This also handles the module docstring and the magic `__all__` (if any) - from `body`. The docstring comes first, before dialect-imports. The magic - `__all__` is placed after dialect-imports, before macro-imports. + We also handle the module docstring, future-imports, and the magic `__all__`. + + The optional `lineno` and `col_offset` parameters can be used to tell + `splice_dialect` the source location info of the dialect-import (in the + unexpanded source code) that triggered this template. If specified, they + are used to mark all the lines coming from the template as having come + from that dialect-import statement. During dialect expansion, you can + get these from the `lineno` and `col_offset` attributes of your dialect + instance (these attributes are filled in by `DialectExpander`). + + If both `body` and `template` have a module docstring, they are concatenated + to produce the module docstring for the result. If only one of them has a + module docstring, that docstring is used as-is. If neither has a module docstring, + the docstring is omitted. 
+ + The primary use of a module docstring in a dialect template is to be able to say + that the program was written in dialect X, more information on which can be found at... + + Future-imports from `template` and `body` are concatenated. + + The magic `__all__` is taken from `body`; if `body` does not define it, + it is omitted. + + In the result, the ordering is:: + + docstring + template future-imports + body future-imports + __all__ (if defined in body) + template dialect-imports + body dialect-imports + template macro-imports + body macro-imports + the rest Parameters: - `body`: `list` of statements + `body`: `list` of `ast.stmt`, or a single `ast.stmt` Original module body from the user code (input). - `template`: `list` of statements + `template`: `list` of `ast.stmt`, or a single `ast.stmt` Template for the final module body (output). Must contain a paste-here indicator as in `splice_statements`. @@ -187,16 +222,19 @@ def splice_dialect(body, template, tag="__paste_here__"): `tag`: `str` The name of the paste-here indicator in `template`. - Returns `template` with `body` spliced in. Note `template` is **not** copied, - and will be mutated in-place. + `lineno`: optional `int` + `col_offset`: optional `int` + Source location info of the dialect-import that triggered this template. - Also `body` is mutated, to remove macro-imports, `__all__` and the module - docstring; these are pasted into the final result. + Return value is `template` with `body` spliced in. + + Note `template` and `body` are **not** copied, and **both** will be mutated + during the splicing process. 
""" if isinstance(body, ast.AST): body = [body] if isinstance(template, ast.AST): - body = [template] + template = [template] if not body: raise ValueError("expected at least one statement in `body`") if not template: @@ -207,20 +245,34 @@ def splice_dialect(body, template, tag="__paste_here__"): # Even if they have location info, it's for a different file compared # to the use site where `body` comes from. # - # Pretend the template code appears at the beginning of the user module. - # - # TODO: It would be better to pretend it appears at the line that has the dialect-import. - # TODO: Requires a `lineno` parameter here, and `DialectExpander` must be modified to supply it. - # TODO: We could extract the `lineno` in `find_dialectimport_ast` and then pass it to the - # TODO: user-defined dialect AST transformer, so it could pass it to us if it wants to. - for stmt in template: - fix_locations(stmt, body[0], mode="overwrite") - - if getdocstring(body): - docstring, *body = body - docstring = [docstring] + # Pretend the template code appears at the given source location, + # or if not given, at the beginning of `body`. + if lineno is not None and col_offset is not None: + srcloc_dummynode = ast.Constant(value=None) + srcloc_dummynode.lineno = lineno + srcloc_dummynode.col_offset = col_offset else: - docstring = [] + srcloc_dummynode = body[0] + for stmt in template: + fix_locations(stmt, srcloc_dummynode, mode="overwrite") + + user_docstring, user_futureimports, body = split_futureimports(body) + template_docstring, template_futureimports, template = split_futureimports(template) + + # Combine user and template docstrings if both are defined. + if user_docstring and template_docstring: + # We must extract the bare strings, combine them, and then pack the result into an AST node. 
+ user_doc = getdocstring(user_docstring) + template_doc = getdocstring(template_docstring) + sep = "\n" + ("-" * 79) + "\n" + new_doc = user_doc + sep + template_doc + new_docstring = ast.copy_location(ast.Constant(value=new_doc), + user_docstring[0]) + docstring = [new_docstring] + else: + docstring = user_docstring or template_docstring + + futureimports = template_futureimports + user_futureimports def extract_magic_all(tree): def ismagicall(tree): @@ -239,6 +291,7 @@ def splice_dialect(body, template, tag="__paste_here__"): w.visit(tree) return tree, w.collected body, user_magic_all = extract_magic_all(body) + template, ignored_template_magic_all = extract_magic_all(template) def extract_macroimports(tree, *, magicname="macros"): class MacroImportExtractor(ASTTransformer): @@ -257,6 +310,7 @@ def splice_dialect(body, template, tag="__paste_here__"): finalbody = splice_statements(body, template, tag) return (docstring + + futureimports + user_magic_all + template_dialect_imports + user_dialect_imports + template_macro_imports + user_macro_imports + diff --git a/kibot/mcpyrate/utils.py b/kibot/mcpyrate/utils.py index 688cf91d..504feb9d 100644 --- a/kibot/mcpyrate/utils.py +++ b/kibot/mcpyrate/utils.py @@ -2,7 +2,7 @@ """General utilities. Can be useful for writing both macros as well as macro expanders.""" __all__ = ["gensym", "scrub_uuid", "flatten", "rename", "extract_bindings", "getdocstring", - "format_location", "format_macrofunction", "format_context", + "get_lineno", "format_location", "format_macrofunction", "format_context", "NestingLevelTracker"] import ast @@ -188,34 +188,66 @@ def getdocstring(body): # -------------------------------------------------------------------------------- +def get_lineno(tree): + """Extract the source line number from `tree`. + + `tree`: AST node, list of AST nodes, or an AST marker. + + `tree` is searched recursively (depth first) until a `lineno` attribute is found; + its value is then returned. 
+ + If no `lineno` attribute is found anywhere inside `tree`, the return value is `None`. + """ + if hasattr(tree, "lineno"): + return tree.lineno + elif isinstance(tree, markers.ASTMarker) and hasattr(tree, "body"): # look inside AST markers + return get_lineno(tree.body) + elif isinstance(tree, ast.AST): # look inside AST nodes + # Note `iter_fields` ignores attribute fields such as line numbers and column offsets, + # so we don't recurse into those. + for fieldname, node in ast.iter_fields(tree): + lineno = get_lineno(node) + if lineno: + return lineno + elif isinstance(tree, list): # look inside statement suites + for node in tree: + lineno = get_lineno(node) + if lineno: + return lineno + return None + + def format_location(filename, tree, sourcecode): """Format a source code location in a standard way, for error messages. `filename`: full path to `.py` file. - `tree`: AST node to get source line number from. (Looks inside AST markers.) + `tree`: AST node to get source line number from. (Looks inside automatically if needed.) `sourcecode`: source code (typically, to get this, `unparse(tree)` before expanding it), or `None` to omit it. - """ - lineno = None - if hasattr(tree, "lineno"): - lineno = tree.lineno - elif isinstance(tree, markers.ASTMarker) and hasattr(tree, "body"): - if hasattr(tree.body, "lineno"): - lineno = tree.body.lineno - elif isinstance(tree.body, list) and tree.body and hasattr(tree.body[0], "lineno"): # e.g. `SpliceNodes` - lineno = tree.body[0].lineno + Return value is an `str` containing colored text, suitable for terminal output. 
+ Example outputs for single-line and multiline source code:: + + /path/to/hello.py:42: print("hello!") + + /path/to/hello.py:1337: + if helloing: + print("hello!") + """ if sourcecode: sep = " " if "\n" not in sourcecode else "\n" source_with_sep = f"{sep}{sourcecode}" else: source_with_sep = "" - return f'{colorize(filename, ColorScheme.SOURCEFILENAME)}:{lineno}:{source_with_sep}' + return f'{colorize(filename, ColorScheme.SOURCEFILENAME)}:{get_lineno(tree)}:{source_with_sep}' def format_macrofunction(function): - """Format the fully qualified name of a macro function, for error messages.""" + """Format the fully qualified name of a macro function, for error messages. + + Return value is an `str` with the fully qualified name. + """ # Catch broken bindings due to erroneous imports in user code # (e.g. accidentally to a module object instead of to a function object) if not (hasattr(function, "__module__") and hasattr(function, "__qualname__")): @@ -226,7 +258,13 @@ def format_macrofunction(function): def format_context(tree, *, n=5): - """Format up to the first `n` lines of source code of `tree`.""" + """Format up to the first `n` lines of source code of `tree`. + + The source code is produced from `tree` by unparsing. + + Return value is an `str` containing colored text with syntax highlighting, + suitable for terminal output. + """ code_lines = unparser.unparse_with_fallbacks(tree, debug=True, color=True).split("\n") code = "\n".join(code_lines[:n]) if len(code_lines) > n: @@ -255,7 +293,17 @@ class NestingLevelTracker: value = property(fget=_get_value, doc="The current level. Read-only. Use `set_to` or `change_by` to change.") def set_to(self, value): - """Context manager. Run a section of code with the level set to `value`.""" + """Context manager. Run a section of code with the level set to `value`. + + Example:: + + t = NestingLevelTracker() + assert t.value == 0 + with t.set_to(42): + assert t.value == 42 + ... 
+ assert t.value == 0 + """ if not isinstance(value, int): raise TypeError(f"Expected integer `value`, got {type(value)} with value {repr(value)}") if value < 0: @@ -271,5 +319,18 @@ class NestingLevelTracker: return _set_to() def changed_by(self, delta): - """Context manager. Run a section of code with the level incremented by `delta`.""" + """Context manager. Run a section of code with the level incremented by `delta`. + + Example:: + + t = NestingLevelTracker() + assert t.value == 0 + with t.changed_by(+21): + assert t.value == 21 + with t.changed_by(+21): + assert t.value == 42 + ... + assert t.value == 21 + assert t.value == 0 + """ return self.set_to(self.value + delta)