Files
gp4/Sources/Macros/Macros.swift
T
Will Hawkins d22776b018 compiler: Refactor Language Element Tags
Signed-off-by: Will Hawkins <hawkinsw@obs.cr>
2026-06-15 21:16:52 -04:00

471 lines
15 KiB
Swift

// p4rse, Copyright 2026, Will Hawkins
//
// This file is part of p4rse.
//
// This file is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
import SwiftCompilerPlugin
import SwiftSyntax
@_spi(ExperimentalLanguageFeature) import SwiftSyntaxMacros
/// Returns `from` with every double-quote character (`"`) removed.
///
/// The macros below use this to turn the source text of a string-literal
/// argument (e.g. `"Parser"`) into its bare contents (`Parser`) before
/// embedding it in a generated error message.
public func remove_embedded_quotes(_ from: String) -> String {
  from.filter { character in character != "\"" }
}
/// Error thrown by the macro implementations in this file when an invocation
/// is malformed; its `description` is exactly the message given at construction.
struct MacroError: Error, CustomStringConvertible {
  /// Human-readable explanation of what was wrong with the macro invocation.
  var message: String

  var description: String { message }

  public init(withMessage text: String) {
    message = text
  }
}
/// Expression macro that unwraps the `Ok` payload of a `Result`.
///
/// Expands to an immediately-invoked closure. The closure returns the payload
/// when its argument is `Result.Ok`. When the argument is `Result.Error`, it
/// prints the error and throws `Require.Error.UnexpectedResult`.
///
/// NOTE(review): the expansion names `Require.Error.UnexpectedResult`, so that
/// name must resolve at the macro's use site — confirm.
///
/// - Throws: `Require.Error.SyntaxError` when the macro is invoked without an argument.
public struct UseOkResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    if let subject = node.arguments.first?.expression {
      // `\\(` emits `\(` so the interpolation happens in the expanded code.
      let unwrapped: ExprSyntax = """
        {
        switch \(subject) {
        case Result.Ok(let __good): return __good
        case Result.Error(let __error):
        print("Unexpected result: \\(__error)")
        throw Require.Error.UnexpectedResult
        }
        }()
        """
      return unwrapped
    }
    throw Require.Error.SyntaxError
  }
}
/// Expression macro that unwraps the `Error` payload of a `Result`.
///
/// This is the mirror image of `UseOkResult`. The expansion is an
/// immediately-invoked closure. It returns the payload when its argument is
/// `Result.Error`. When the argument is `Result.Ok`, it prints the value and
/// throws `Require.Error.UnexpectedResult`.
///
/// - Throws: `Require.Error.SyntaxError` when the macro is invoked without an argument.
public struct UseErrorResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    switch node.arguments.first?.expression {
    case .none:
      throw Require.Error.SyntaxError
    case .some(let subject):
      let unwrapped: ExprSyntax = """
        {
        switch \(subject) {
        case Result.Error(let __error): return __error
        case Result.Ok(let __good):
        print("Unexpected result: \\(__good)")
        throw Require.Error.UnexpectedResult
        }
        }()
        """
      return unwrapped
    }
  }
}
/// Namespace for the error type used by the result-checking macros.
///
/// `SyntaxError` is thrown by macro implementations that receive no argument.
/// `UnexpectedResult` is thrown by the code that `UseOkResult` and
/// `UseErrorResult` generate when the `Result` holds the other case.
public struct Require {
  public enum Error: Swift.Error { case UnexpectedResult, SyntaxError }
}
/// Expression macro that checks whether a `Result` is `Ok`.
///
/// Expands to an immediately-invoked closure. It evaluates to `true` for
/// `Result.Ok`. For `Result.Error` it prints the error and evaluates to `false`.
///
/// - Throws: `Require.Error.SyntaxError` when the macro is invoked without an argument.
public struct RequireResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    let candidate = node.arguments.first?.expression
    guard let candidate else {
      throw Require.Error.SyntaxError
    }
    let check: ExprSyntax = """
      {
      switch \(candidate) {
      case Result.Ok(_): return true
      case Result.Error(let __error):
      print("Unexpected result: \\(__error)")
      return false
      }
      }()
      """
    return check
  }
}
/// Expression macro that checks a `Result` failed with a specific error.
///
/// Arguments, by position:
/// 1. the expected error;
/// 2. an expression producing a `Result`.
///
/// Expands to an immediately-invoked closure that evaluates to `true` only when
/// the result is `Result.Error(found)` and `expected.eq(found)` holds. It prints
/// a diagnostic and evaluates to `false` when the errors differ or when the
/// result is not an error.
///
/// NOTE(review): arguments are read by position without a count check; a call
/// with fewer than two arguments traps here — presumably the macro declaration
/// fixes the arity.
public struct RequireErrorResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    let positions = node.arguments.indices
    let first_position = positions.startIndex
    let expected = node.arguments[first_position].expression
    let producer = node.arguments[positions.index(after: first_position)].expression
    let check: ExprSyntax = """
      {
      let __expected_error = \(expected)
      let __actual_error = \(producer)
      if case Result.Error(let __found_error) = __actual_error {
      if !__expected_error.eq(__found_error) {
      print("Expected Error: \\(__expected_error) but got Error: \\(__found_error)")
      return false
      }
      return true
      } else {
      print("Expected error, but got Ok")
      return false
      }
      }()
      """
    return check
  }
}
/// Code-item macro that requires a node to have one specific `nodeType`.
///
/// Arguments, by position:
/// 1. the node expression;
/// 2. the expected node type;
/// 3. a string literal naming that type for humans.
///
/// The expansion is an `if` statement. When the node's `nodeType` differs from
/// the expected one, it returns `Result.Error(ErrorWithLocation(...))` carrying
/// the node's source location and the message "Did not find <name>".
public struct RequireNodeType: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = node.arguments.indices
    var arg_index = arguments.startIndex
    let node_to_check = node.arguments[arg_index].expression
    arg_index = arguments.index(after: arg_index)
    let expected_type = node.arguments[arg_index].expression
    arg_index = arguments.index(after: arg_index)
    let expected_type_nice_name = node.arguments[arg_index].expression
    // `description` includes the argument's trivia. If the call site is wrapped
    // (nice name on its own line) or the name has a trailing comment, that
    // newline/comment would be spliced raw into the single-line string literal
    // below and the expansion would not parse. `trimmedDescription` drops it.
    let error_message =
      "Did not find " + remove_embedded_quotes(expected_type_nice_name.trimmedDescription)
    return [
      CodeBlockItemSyntax(
        """
        if \(node_to_check).nodeType != \(expected_type) {
        return Result.Error(
        ErrorWithLocation(sourceLocation: \(node_to_check).toSourceLocation(), withError: "\(raw: error_message)"))
        }
        """)
    ]
  }
}
/// Code-item macro that requires a node's `nodeType` to be one of several types.
///
/// Arguments, by position:
/// 1. the node expression;
/// 2. an array literal of accepted node types;
/// 3. an array literal of string literals naming those types for humans.
///
/// The expansion is an `if` statement. When the node's `nodeType` differs from
/// every listed type, it returns `Result.Error(ErrorWithLocation(...))`
/// carrying the node's source location. The message is "Did not find one of the
/// expected types: " followed by the names joined with ",".
///
/// - Throws: `MacroError` when either list is not an array literal, or when the
///   list of types is empty. An empty list would expand to `if  {`, which does
///   not parse.
public struct RequireNodesType: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = node.arguments.indices
    var arg_index = arguments.startIndex
    let node_to_check = node.arguments[arg_index].expression
    arg_index = arguments.index(after: arg_index)
    guard let expected_types = node.arguments[arg_index].expression.as(ArrayExprSyntax.self) else {
      throw MacroError(withMessage: "Node(s) to check must be in an array")
    }
    arg_index = arguments.index(after: arg_index)
    guard
      let expected_type_nice_names = node.arguments[arg_index].expression.as(ArrayExprSyntax.self)
    else {
      throw MacroError(withMessage: "Node nice names must be in an array")
    }
    guard !expected_types.elements.isEmpty else {
      throw MacroError(withMessage: "At least one node type to check is required")
    }
    // Use trimmed text everywhere. An element's `description` carries its
    // trivia, so a multi-line array literal would inject newlines into the
    // single-line string literal of the error message.
    let nice_names = expected_type_nice_names.elements.map { element in
      remove_embedded_quotes(element.expression.trimmedDescription)
    }
    let error_message =
      "Did not find one of the expected types: " + nice_names.joined(separator: ",")
    let checked_node = node_to_check.trimmedDescription
    let ifs = expected_types.elements.map { element in
      "\(checked_node).nodeType != \(element.expression.trimmedDescription)"
    }.joined(separator: " && ")
    return [
      CodeBlockItemSyntax(
        """
        if \(raw: ifs) {
        return Result.Error(
        ErrorWithLocation(sourceLocation: \(node_to_check).toSourceLocation(), withError: "\(raw: error_message)"))
        }
        """)
    ]
  }
}
/// Code-item macro that skips further processing unless a node has a given type.
///
/// Arguments, by position:
/// 1. the node expression;
/// 2. the node type to look for.
///
/// The expansion is an `if` statement that returns `Result.Ok(.none)` when the
/// node's `nodeType` differs from the given type.
public struct SkipUnlessNodeType: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let positional = Array(node.arguments)
    let subject = positional[0].expression
    let wanted_type = positional[1].expression
    let early_exit = CodeBlockItemSyntax(
      """
      if \(subject).nodeType != \(wanted_type) {
      return Result.Ok(.none)
      }
      """)
    return [early_exit]
  }
}
/// Code-item macro that skips further processing unless a node has one of several types.
///
/// Arguments, by position:
/// 1. the node expression;
/// 2. an array literal of accepted node types.
///
/// The expansion is an `if` statement that returns `Result.Ok(.none)` when the
/// node's `nodeType` differs from every listed type.
///
/// - Throws: `MacroError` when the types are not given as an array literal, or
///   when that array is empty. An empty list would expand to `if  {`, which does
///   not parse.
public struct SkipUnlessNodesTypes: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = node.arguments.indices
    var arg_index = arguments.startIndex
    let node_to_check = node.arguments[arg_index].expression
    arg_index = arguments.index(after: arg_index)
    guard let expected_types = node.arguments[arg_index].expression.as(ArrayExprSyntax.self) else {
      throw MacroError(withMessage: "Node(s) to check must be in an array")
    }
    guard !expected_types.elements.isEmpty else {
      throw MacroError(withMessage: "At least one node type to check is required")
    }
    // Trimmed text keeps argument trivia (newlines, comments) out of the
    // generated condition, matching RequireNodesType.
    let checked_node = node_to_check.trimmedDescription
    let ifs = expected_types.elements.map { element in
      "\(checked_node).nodeType != \(element.expression.trimmedDescription)"
    }.joined(separator: " && ")
    return [
      CodeBlockItemSyntax(
        """
        if \(raw: ifs) {
        return Result.Ok(.none)
        }
        """)
    ]
  }
}
/// Code-item macro that assigns an optional's value or returns early.
///
/// Arguments, by position:
/// 1. the assignment target;
/// 2. an optional-producing expression;
/// 3. the value to return when that expression is `nil`.
///
/// The expansion is an `if let` that assigns the unwrapped value to the target,
/// or else returns the third argument.
public struct MustOr: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax, in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let positional = Array(node.arguments)
    let destination = positional[0].expression
    let candidate = positional[1].expression
    let fallback = positional[2].expression
    let assign_or_return = CodeBlockItemSyntax(
      """
      if let __thing = \(candidate) {
      \(destination) = __thing
      } else {
      return \(fallback)
      }
      """)
    return [assign_or_return]
  }
}
/// Peer macro that turns the function it is attached to into a swift-testing test
/// which runs a CLI check.
///
/// For the attached function it emits three peer declarations:
/// - an async throwing thunk that calls `swiftCliTestRunner`;
/// - a generator that builds the `Testing.Test`;
/// - a `Testing.__TestContentRecord` placed in a test section with `@section`/`@used`.
///
/// The content record presumably exists so swift-testing's discovery finds the
/// test — confirm. The expected output comes from the comment lines preceding
/// the attribute. When there are none, the function's name is passed with the
/// `withExpectedPath:` label instead.
///
/// NOTE(review): the generated code relies on underscored swift-testing names
/// (`__function`, `__TestContentRecord`, `__store`, `__requiring*`, ...). These
/// look like non-public implementation details; confirm they match the
/// swift-testing version in use.
public struct CliTestDeclarationMacro: PeerMacro, Sendable {
// NOTE: Taken from swift-testing.
/// Get an expression initializing an instance of ``SourceLocation`` from an
/// arbitrary syntax node.
///
/// - Parameters:
/// - node: The syntax node for which an instance of ``SourceLocation`` is
/// needed.
/// - context: The macro context in which the expression is being parsed.
///
/// - Returns: An expression value that initializes an instance of
/// ``SourceLocation`` for `node`.
static func createSourceLocationExpr(
of node: some SyntaxProtocol, context: some MacroExpansionContext
) -> ExprSyntax {
if node.isProtocol((any FreestandingMacroExpansionSyntax).self) {
// Freestanding macro expressions can just use __here()
// directly and do not need to talk to the macro context to get source
// location info.
return "Testing.SourceLocation.__here()"
}
// Get the equivalent source location in both `#fileID` and `#filePath` modes.
guard let fileIDSourceLoc: AbstractSourceLocation = context.location(of: node),
let filePathSourceLoc: AbstractSourceLocation = context.location(
of: node, at: .afterLeadingTrivia, filePathMode: .filePath)
else {
// The context could not resolve a location; fall back to `__here()`.
return "Testing.SourceLocation.__here()"
}
// The `#filePath` location contributes only the path; the file ID, line and
// column come from the `#fileID` location.
return
"Testing.SourceLocation(__uncheckedFileID: \(fileIDSourceLoc.file), filePath: \(filePathSourceLoc.file), line: \(fileIDSourceLoc.line), column: \(fileIDSourceLoc.column))"
}
/// Get an expression initializing an instance of `__SourceBounds` from two
/// arbitrary syntax nodes.
///
/// - Parameters:
/// - lowerBoundNode: The syntax node representing the lower bound. The start
/// of this node (after leading trivia) is used.
/// - upperBoundNode: The syntax node representing the upper bound. The end of
/// this node (before trailing trivia) is used.
/// - context: The macro context in which the expression is being parsed.
///
/// - Returns: An expression value that initializes an instance of
/// `__SourceBounds`.
///
/// The resulting source bounds instance represents (approximately):
///
/// ```swift
/// lowerBoundNode.positionAfterSkippingLeadingTrivia ..< upperBoundNode.endPositionBeforeTrailingTrivia
/// ```
static func createSourceBoundsExpr(
from lowerBoundNode: some SyntaxProtocol, to upperBoundNode: some SyntaxProtocol,
in context: some MacroExpansionContext
) -> ExprSyntax {
let lowerBoundExpr = createSourceLocationExpr(of: lowerBoundNode, context: context)
// The upper bound is only a (line, column) pair. `(.max, .max)` is used when
// the context cannot locate `upperBoundNode`.
let upperBoundExpr: ExprSyntax =
if let upperBoundSourceLoc = context.location(
of: upperBoundNode, at: .beforeTrailingTrivia, filePathMode: .fileID)
{
"(\(upperBoundSourceLoc.line), \(upperBoundSourceLoc.column))"
} else {
"(.max, .max)"
}
return
"Testing.__SourceBounds(__uncheckedLowerBound: \(lowerBoundExpr), upperBound: \(upperBoundExpr))"
}
// NOTE: End of what was taken from swift-testing
/// Strips everything up to and including the last `///` on a line, plus the
/// whitespace that follows it.
///
/// NOTE(review): the pattern has some edge cases; confirm they are intended.
/// - `[\s]+` needs at least one whitespace character, so a bare `///` line
///   (meant as an empty expected-output line) is returned unchanged.
/// - Because `[\s]+` is greedy, intentional leading indentation in the
///   expected output is removed.
/// - Plain `//` comments do not match and are returned unchanged.
private static func doc_shrink(_ from: String) -> String {
return from.replacing(Regex(#/^.*\/\/\/[\s]+/#), with: "")
}
/// Expands the attached attribute into the thunk, generator and content-record
/// peers described on the type.
///
/// `declaration.cast(FunctionDeclSyntax.self)` is a force cast, so attaching the
/// attribute to anything other than a function declaration traps the plugin.
public static func expansion(
of node: AttributeSyntax,
providingPeersOf declaration: some DeclSyntaxProtocol,
in context: some MacroExpansionContext
) throws -> [DeclSyntax] {
let test_name = declaration.cast(FunctionDeclSyntax.self).name
// Collect the comment trivia before the attribute and pass each piece through
// `doc_shrink`. The pieces are joined with the two characters `\` and `n`, so
// they read as newlines once spliced into the string literal generated below.
// NOTE(review): the text is not escaped. A `"` or `\` in the expected output
// will alter or break the generated literal — confirm callers avoid them.
let cli_test_expected_output = node.leadingTrivia.filter({ $0.isComment }).map({
doc_shrink("\($0)")
}).joined(separator: "\\n")
let cli_test_driver_thunk_name = context.makeUniqueName("_thunk_")
// With no expected-output comments, pass the test's name via `withExpectedPath:`;
// otherwise pass the collected text via `withExpected:`.
let (expected_decl, expected_label) =
if cli_test_expected_output.isEmpty {
("let expected = \"\(test_name)\"", "withExpectedPath:")
} else {
("let expected = \"\(cli_test_expected_output)\"", "withExpected:")
}
// Async throwing thunk that calls `swiftCliTestRunner` with the test function
// and `expected`.
let cli_test_driver_thunk: DeclSyntax = """
@Sendable private func \(cli_test_driver_thunk_name)() async throws {
\(raw: expected_decl)
_ = unsafe try await Testing.__requiringUnsafe(
Testing.__requiringTry(
Testing.__requiringAwait(swiftCliTestRunner(\(test_name), \(raw: expected_label) expected))))
}
"""
// The test's source bounds run from the attribute to the end of the function.
let source_bounds = createSourceBoundsExpr(from: node, to: declaration, in: context)
let cli_test_driver_generator_name = context.makeUniqueName("_generator_")
// Generator that builds the `Testing.Test`, using the thunk above as its body.
let cli_test_driver_generator: DeclSyntax = """
@Sendable private func \(cli_test_driver_generator_name)() async -> Testing.Test {
return .__function(
named: "\(test_name)",
in: nil as Swift.Never.Type?,
xcTestCompatibleSelector: Testing.__xcTestCompatibleSelector("\(test_name)"),
traits: [],
sourceBounds: \(source_bounds),
parameters: [],
testFunction: \(cli_test_driver_thunk_name)
)
}
"""
// Section name for the content record.
// NOTE(review): `#if os(macOS)` is resolved when this plugin is compiled,
// i.e. for the host that runs the plugin, not necessarily for the target
// being built. Confirm this is correct when cross-compiling.
#if os(macOS)
let section = "__DATA_CONST,__swift5_tests"
#else
let section = "swift5_tests"
#endif
let cli_test_driver_content_record_name = context.makeUniqueName("testContentRecord")
// Content record emitted into `section`. Its first field, `0x7465_7374`, is the
// ASCII bytes of "test". The accessor closure passes the generator to
// `Testing.Test.__store`.
let cli_test_driver_cr: DeclSyntax = """
@section("\(raw: section)")
@used
private nonisolated let \(cli_test_driver_content_record_name): Testing.__TestContentRecord = (
0x7465_7374, /* indicate a test */
0,
{ outValue, type, _, _ in
Testing.Test.__store(\(cli_test_driver_generator_name), into: outValue, asTypeAt: type)
},
0,
0
)
"""
return [
cli_test_driver_thunk, cli_test_driver_generator, cli_test_driver_cr,
]
}
}
/// Member macro that adds a static `ParseStatement(node:withContext:)` method.
///
/// The generated method calls the type's `Parse(node:withContext:)` and rewraps
/// its `.Ok` / `.Error` outcome as `Result<CST.Categories.Statement>`.
///
/// NOTE(review): this assumes the annotated type provides a static `Parse`
/// whose `Ok` payload converts to `CST.Categories.Statement` — confirm.
public enum DeriveParsableStatement: MemberMacro {
  public static func expansion(
    of attribute: AttributeSyntax, providingMembersOf type: some DeclGroupSyntax,
    conformingTo protocols: [TypeSyntax],
    in context: some MacroExpansionContext
  ) throws -> [DeclSyntax] {
    let parse_statement: DeclSyntax = """
      public static func ParseStatement(
      node: Node, withContext context: CSTCompilerContext
      ) -> Result<CST.Categories.Statement> {
      return switch Parse(node: node, withContext: context) {
      case .Ok(let res): .Ok(res)
      case .Error(let e): .Error(e)
      }
      }
      """
    return [parse_statement]
  }
}
/// Compiler-plugin entry point that registers every macro implemented in this file.
@main
struct P4Macros: CompilerPlugin {
  let providingMacros: [Macro.Type] = [
    RequireResult.self,
    RequireErrorResult.self,
    UseOkResult.self,
    UseErrorResult.self,
    RequireNodeType.self,
    SkipUnlessNodeType.self,
    SkipUnlessNodesTypes.self,
    RequireNodesType.self,
    MustOr.self,
    CliTestDeclarationMacro.self,
    DeriveParsableStatement.self,
  ]
}