// p4rse, Copyright 2026, Will Hawkins
//
// This file is part of p4rse.
//
// This file is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program.  If not, see <https://www.gnu.org/licenses/>.

import SwiftCompilerPlugin
import SwiftSyntax
@_spi(ExperimentalLanguageFeature) import SwiftSyntaxMacros

/// Strips every double-quote character from `from`.
///
/// The node-type macros use this to turn the source text of a string-literal
/// argument (e.g. `"Parser"`) into bare text (`Parser`) that is then spliced
/// into a generated error message.
///
/// - Parameter from: The text to clean.
/// - Returns: `from` with all `"` characters removed.
public func remove_embedded_quotes(_ from: String) -> String {
  let result = from.replacing("\"", with: [])
  return result
}

/// An error thrown by a macro expansion when its arguments are malformed.
struct MacroError: Error, CustomStringConvertible {
  var message: String

  var description: String {
    return message
  }

  public init(withMessage _message: String) {
    message = _message
  }
}

/// Returns the first `count` argument expressions of a freestanding macro,
/// with their surrounding trivia removed.
///
/// The macros below read their arguments positionally. Indexing past the end
/// of `node.arguments` would trap inside the compiler plugin, so a short
/// argument list is reported as a `MacroError` instead. Trivia is trimmed so
/// that newlines or comments around an argument in a multi-line macro call
/// cannot leak into (and break) the generated code.
///
/// - Parameters:
///   - node: The macro expansion whose arguments should be read.
///   - count: How many leading arguments the caller needs.
///   - macro_name: The macro's name, used only in the error message.
/// - Returns: The first `count` argument expressions, in source order.
/// - Throws: `MacroError` when fewer than `count` arguments are present.
private func leading_arguments(
  of node: some FreestandingMacroExpansionSyntax,
  count: Int,
  macro_name: String
) throws -> [ExprSyntax] {
  let expressions = node.arguments.prefix(count).map { $0.expression.trimmed }
  guard expressions.count == count else {
    throw MacroError(
      withMessage:
        "#\(macro_name) expects at least \(count) argument(s) but got \(node.arguments.count)")
  }
  return expressions
}

/// Builds the source text `s.nodeType != T1 && s.nodeType != T2 && ...`,
/// which is true when `subject`'s node type matches none of `types`.
///
/// - Parameters:
///   - subject: The expression whose `nodeType` is compared.
///   - types: The array literal of node types to compare against.
/// - Returns: The condition as raw Swift source text.
/// - Throws: `MacroError` when `types` is empty; otherwise the generated
///   `if` would have no condition and fail to parse.
private func none_of_node_types_condition(
  _ subject: ExprSyntax,
  matching types: ArrayExprSyntax
) throws -> String {
  let subject_text = subject.trimmedDescription
  let comparisons = types.elements.map { element in
    "\(subject_text).nodeType != \(element.expression.trimmedDescription)"
  }
  guard !comparisons.isEmpty else {
    throw MacroError(withMessage: "The array of node types to check must not be empty")
  }
  return comparisons.joined(separator: " && ")
}

/// `#UseOkResult(result)`: evaluates to the payload of `result` when it is
/// `Result.Ok`. Otherwise it prints the error and throws
/// `Require.Error.UnexpectedResult`.
///
/// The expansion is an immediately-invoked closure containing a `throw`, so
/// the expanded expression is a throwing expression.
public struct UseOkResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    guard let argument = node.arguments.first?.expression.trimmed else {
      throw Require.Error.SyntaxError
    }
    return """
      {
        switch \(argument) {
        case Result.Ok(let __good):
          return __good
        case Result.Error(let __error):
          print("Unexpected result: \\(__error)")
          throw Require.Error.UnexpectedResult
        }
      }()
      """
  }
}

/// `#UseErrorResult(result)`: evaluates to the error payload of `result` when
/// it is `Result.Error`. Otherwise it prints the value and throws
/// `Require.Error.UnexpectedResult`.
///
/// Like `UseOkResult`, the expansion is an immediately-invoked closure
/// containing a `throw`.
public struct UseErrorResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    guard let argument = node.arguments.first?.expression.trimmed else {
      throw Require.Error.SyntaxError
    }
    return """
      {
        switch \(argument) {
        case Result.Error(let __error):
          return __error
        case Result.Ok(let __good):
          print("Unexpected result: \\(__good)")
          throw Require.Error.UnexpectedResult
        }
      }()
      """
  }
}

/// Namespace for errors thrown by the result macros and by their expansions.
public struct Require {
  public enum Error: Swift.Error {
    /// Thrown by generated code when a `Result` holds the unexpected case.
    case UnexpectedResult
    /// Thrown during expansion when a required macro argument is missing.
    case SyntaxError
  }
}

/// `#RequireResult(result)`: expands to a `Bool` that is `true` when `result`
/// is `Result.Ok`. Otherwise it prints the error and yields `false`.
public struct RequireResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    guard let argument = node.arguments.first?.expression.trimmed else {
      throw Require.Error.SyntaxError
    }
    return """
      {
        switch \(argument) {
        case Result.Ok(_):
          return true
        case Result.Error(let __error):
          print("Unexpected result: \\(__error)")
          return false
        }
      }()
      """
  }
}

/// `#RequireErrorResult(expectedError, producer)`: expands to a `Bool` that is
/// `true` only when `producer` is `Result.Error` and its payload satisfies
/// `expectedError.eq(payload)`. Otherwise it prints a message and yields
/// `false`.
public struct RequireErrorResult: ExpressionMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> ExprSyntax {
    let arguments = try leading_arguments(
      of: node, count: 2, macro_name: "RequireErrorResult")
    let expected_error = arguments[0]
    let error_producer = arguments[1]
    return ExprSyntax(
      """
      {
        let __expected_error = \(expected_error)
        let __actual_error = \(error_producer)
        if case Result.Error(let __found_error) = __actual_error {
          if !__expected_error.eq(__found_error) {
            print("Expected Error: \\(__expected_error) but got Error: \\(__found_error)")
            return false
          }
          return true
        } else {
          print("Expected error, but got Ok")
          return false
        }
      }()
      """)
  }
}

/// `#RequireNodeType(node, expectedType, "Nice Name")`: expands to a statement
/// that, when `node.nodeType != expectedType`, returns
/// `Result.Error(ErrorWithLocation(...))` from the enclosing function. The
/// error message is "Did not find Nice Name".
public struct RequireNodeType: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = try leading_arguments(
      of: node, count: 3, macro_name: "RequireNodeType")
    let node_to_check = arguments[0]
    let expected_type = arguments[1]
    let expected_type_nice_name = arguments[2]

    // Trimmed text: any trivia (e.g. a newline from a multi-line call) would
    // otherwise land inside the single-line string literal generated below.
    let error_message =
      "Did not find " + remove_embedded_quotes(expected_type_nice_name.trimmedDescription)

    return [
      CodeBlockItemSyntax(
        """
        if \(node_to_check).nodeType != \(expected_type) {
          return Result.Error(
            ErrorWithLocation(sourceLocation: \(node_to_check).toSourceLocation(), withError: "\(raw: error_message)"))
        }
        """)
    ]
  }
}

/// `#RequireNodesType(node, [T1, T2, ...], ["Name1", "Name2", ...])`: expands
/// to a statement that, when `node.nodeType` is none of the listed types,
/// returns `Result.Error(ErrorWithLocation(...))` from the enclosing function.
/// The message lists the nice names, separated by commas.
public struct RequireNodesType: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = try leading_arguments(
      of: node, count: 3, macro_name: "RequireNodesType")
    let node_to_check = arguments[0]
    guard let expected_types = arguments[1].as(ArrayExprSyntax.self) else {
      throw MacroError(withMessage: "Node(s) to check must be in an array")
    }
    guard let expected_type_nice_names = arguments[2].as(ArrayExprSyntax.self) else {
      throw MacroError(withMessage: "Node nice names must be in an array")
    }

    // Trimmed text for the same reason as in `RequireNodeType`: element trivia
    // in a multi-line array must not end up inside the generated literal.
    let nice_names = expected_type_nice_names.elements.map { element in
      remove_embedded_quotes(element.expression.trimmedDescription)
    }
    let error_message =
      "Did not find one of the expected types: " + nice_names.joined(separator: ",")
    let ifs = try none_of_node_types_condition(node_to_check, matching: expected_types)

    return [
      CodeBlockItemSyntax(
        """
        if \(raw: ifs) {
          return Result.Error(
            ErrorWithLocation(sourceLocation: \(node_to_check).toSourceLocation(), withError: "\(raw: error_message)"))
        }
        """)
    ]
  }
}

/// `#SkipUnlessNodeType(node, expectedType)`: expands to a statement that
/// returns `Result.Ok(.none)` from the enclosing function when
/// `node.nodeType != expectedType`.
public struct SkipUnlessNodeType: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = try leading_arguments(
      of: node, count: 2, macro_name: "SkipUnlessNodeType")
    let node_to_check = arguments[0]
    let expected_type = arguments[1]
    return [
      CodeBlockItemSyntax(
        """
        if \(node_to_check).nodeType != \(expected_type) {
          return Result.Ok(.none)
        }
        """)
    ]
  }
}

/// `#SkipUnlessNodesTypes(node, [T1, T2, ...])`: expands to a statement that
/// returns `Result.Ok(.none)` from the enclosing function when `node.nodeType`
/// is none of the listed types.
public struct SkipUnlessNodesTypes: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = try leading_arguments(
      of: node, count: 2, macro_name: "SkipUnlessNodesTypes")
    let node_to_check = arguments[0]
    guard let expected_types = arguments[1].as(ArrayExprSyntax.self) else {
      throw MacroError(withMessage: "Node(s) to check must be in an array")
    }
    let ifs = try none_of_node_types_condition(node_to_check, matching: expected_types)
    return [
      CodeBlockItemSyntax(
        """
        if \(raw: ifs) {
          return Result.Ok(.none)
        }
        """)
    ]
  }
}

/// `#MustOr(target, optionalValue, fallback)`: expands to a statement that
/// assigns the unwrapped `optionalValue` to `target`, or returns `fallback`
/// from the enclosing function when `optionalValue` is `nil`.
public struct MustOr: CodeItemMacro {
  public static func expansion(
    of node: some FreestandingMacroExpansionSyntax,
    in context: some MacroExpansionContext
  ) throws -> [CodeBlockItemSyntax] {
    let arguments = try leading_arguments(of: node, count: 3, macro_name: "MustOr")
    let result = arguments[0]
    let thing = arguments[1]
    let or = arguments[2]
    return [
      CodeBlockItemSyntax(
        """
        if let __thing = \(thing) {
          \(result) = __thing
        } else {
          return \(or)
        }
        """)
    ]
  }
}

/// Peer macro for a function that registers it as a swift-testing CLI test.
///
/// It generates three declarations:
/// - a thunk that calls `swiftCliTestRunner(function, withExpected: ...)` or
///   `swiftCliTestRunner(function, withExpectedPath: ...)`,
/// - a generator that returns a `Testing.Test` wrapping that thunk, and
/// - a `Testing.__TestContentRecord` placed in the platform's test section.
///
/// When comments precede the attribute, their text (with the `///` marker
/// stripped) becomes the expected output. Otherwise the function's name is
/// passed as the expected path.
public struct CliTestDeclarationMacro: PeerMacro, Sendable {
  // NOTE: Taken from swift-testing.

  /// Get an expression initializing an instance of ``SourceLocation`` from an
  /// arbitrary syntax node.
  ///
  /// - Parameters:
  ///   - node: The syntax node for which an instance of ``SourceLocation`` is
  ///     needed.
  ///   - context: The macro context in which the expression is being parsed.
  ///
  /// - Returns: An expression value that initializes an instance of
  ///   ``SourceLocation`` for `node`.
  static func createSourceLocationExpr(
    of node: some SyntaxProtocol,
    context: some MacroExpansionContext
  ) -> ExprSyntax {
    if node.isProtocol((any FreestandingMacroExpansionSyntax).self) {
      // Freestanding macro expressions can just use __here()
      // directly and do not need to talk to the macro context to get source
      // location info.
      return "Testing.SourceLocation.__here()"
    }

    // Get the equivalent source location in both `#fileID` and `#filePath` modes.
    guard let fileIDSourceLoc: AbstractSourceLocation = context.location(of: node),
      let filePathSourceLoc: AbstractSourceLocation = context.location(
        of: node, at: .afterLeadingTrivia, filePathMode: .filePath)
    else {
      return "Testing.SourceLocation.__here()"
    }

    return
      "Testing.SourceLocation(__uncheckedFileID: \(fileIDSourceLoc.file), filePath: \(filePathSourceLoc.file), line: \(fileIDSourceLoc.line), column: \(fileIDSourceLoc.column))"
  }

  /// Get an expression initializing an instance of `__SourceBounds` from two
  /// arbitrary syntax nodes.
  ///
  /// - Parameters:
  ///   - lowerBoundNode: The syntax node representing the lower bound. The start
  ///     of this node (after leading trivia) is used.
  ///   - upperBoundNode: The syntax node representing the upper bound. The end of
  ///     this node (before trailing trivia) is used.
  ///   - context: The macro context in which the expression is being parsed.
  ///
  /// - Returns: An expression value that initializes an instance of
  ///   `__SourceBounds`.
  ///
  /// The resulting source bounds instance represents (approximately):
  ///
  /// ```swift
  /// lowerBoundNode.positionAfterSkippingLeadingTrivia ..< upperBoundNode.endPositionBeforeTrailingTrivia
  /// ```
  static func createSourceBoundsExpr(
    from lowerBoundNode: some SyntaxProtocol,
    to upperBoundNode: some SyntaxProtocol,
    in context: some MacroExpansionContext
  ) -> ExprSyntax {
    let lowerBoundExpr = createSourceLocationExpr(of: lowerBoundNode, context: context)
    let upperBoundExpr: ExprSyntax =
      if let upperBoundSourceLoc = context.location(
        of: upperBoundNode, at: .beforeTrailingTrivia, filePathMode: .fileID)
      {
        "(\(upperBoundSourceLoc.line), \(upperBoundSourceLoc.column))"
      } else {
        "(.max, .max)"
      }
    return
      "Testing.__SourceBounds(__uncheckedLowerBound: \(lowerBoundExpr), upperBound: \(upperBoundExpr))"
  }

  // NOTE: End of what was taken from swift-testing

  /// Strips everything up to and including the last `///` marker, plus any
  /// whitespace after it, from the text of one comment trivia piece.
  ///
  /// Uses `\s*` rather than `\s+` so that a bare `///` line (a blank line of
  /// expected output) becomes "" instead of remaining "///".
  private static func doc_shrink(_ from: String) -> String {
    return from.replacing(#/^.*\/\/\/\s*/#, with: "")
  }

  public static func expansion(
    of node: AttributeSyntax,
    providingPeersOf declaration: some DeclSyntaxProtocol,
    in context: some MacroExpansionContext
  ) throws -> [DeclSyntax] {
    // `cast` would trap the plugin on a non-function; report it instead.
    guard let function = declaration.as(FunctionDeclSyntax.self) else {
      throw MacroError(withMessage: "This attribute can only be attached to a function declaration")
    }
    let test_name = function.name.trimmed

    // Comments written immediately before the attribute are its leading trivia.
    // Lines are joined with the two-character escape `\n`, which the generated
    // string literal turns back into a newline.
    // NOTE(review): the text is spliced verbatim between double quotes below,
    // so a `"` in the comment ends the literal early and a `\` starts a Swift
    // escape sequence. Escape expected output accordingly, or consider
    // escaping here.
    let cli_test_expected_output = node.leadingTrivia
      .filter { $0.isComment }
      .map { doc_shrink("\($0)") }
      .joined(separator: "\\n")

    let cli_test_driver_thunk_name = context.makeUniqueName("_thunk_")

    let (expected_decl, expected_label) =
      if cli_test_expected_output.isEmpty {
        ("let expected = \"\(test_name)\"", "withExpectedPath:")
      } else {
        ("let expected = \"\(cli_test_expected_output)\"", "withExpected:")
      }

    let cli_test_driver_thunk: DeclSyntax = """
      @Sendable private func \(cli_test_driver_thunk_name)() async throws {
        \(raw: expected_decl)
        _ = unsafe try await Testing.__requiringUnsafe(
          Testing.__requiringTry(
            Testing.__requiringAwait(swiftCliTestRunner(\(test_name), \(raw: expected_label) expected))))
      }
      """

    let source_bounds = createSourceBoundsExpr(from: node, to: declaration, in: context)

    let cli_test_driver_generator_name = context.makeUniqueName("_generator_")
    let cli_test_driver_generator: DeclSyntax = """
      @Sendable private func \(cli_test_driver_generator_name)() async -> Testing.Test {
        return .__function(
          named: "\(test_name)",
          in: nil as Swift.Never.Type?,
          xcTestCompatibleSelector: Testing.__xcTestCompatibleSelector("\(test_name)"),
          traits: [],
          sourceBounds: \(source_bounds),
          parameters: [],
          testFunction: \(cli_test_driver_thunk_name)
        )
      }
      """

    // NOTE(review): `#if os(macOS)` is evaluated when this plugin is compiled,
    // i.e. for the host, not for the target being built. Presumably fine while
    // host and target share an object format; confirm for cross-compilation.
    #if os(macOS)
      let section = "__DATA_CONST,__swift5_tests"
    #else
      let section = "swift5_tests"
    #endif

    let cli_test_driver_content_record_name = context.makeUniqueName("testContentRecord")
    let cli_test_driver_cr: DeclSyntax = """
      @section("\(raw: section)")
      @used
      private nonisolated let \(cli_test_driver_content_record_name): Testing.__TestContentRecord = (
        0x7465_7374, /* indicate a test */
        0,
        { outValue, type, _, _ in
          Testing.Test.__store(\(cli_test_driver_generator_name), into: outValue, asTypeAt: type)
        },
        0,
        0
      )
      """

    return [
      cli_test_driver_thunk,
      cli_test_driver_generator,
      cli_test_driver_cr,
    ]
  }
}

/// Compiler-plugin entry point that lists every macro this module provides.
@main
struct P4Macros: CompilerPlugin {
  var providingMacros: [Macro.Type] = [
    RequireResult.self,
    RequireErrorResult.self,
    UseOkResult.self,
    UseErrorResult.self,
    RequireNodeType.self,
    SkipUnlessNodeType.self,
    SkipUnlessNodesTypes.self,
    RequireNodesType.self,
    MustOr.self,
    CliTestDeclarationMacro.self,
  ]
}