From 2861a82777994369bf87253b95872f58441e8aa9 Mon Sep 17 00:00:00 2001 From: Will Hawkins Date: Mon, 23 Mar 2026 07:55:11 -0400 Subject: [PATCH] Implement Binary Operators and Grouping (in Expressions) Signed-off-by: Will Hawkins --- Sources/Common/ProgramTypes.swift | 240 +++++++++++++++++++++++++++- Sources/Common/Protocols.swift | 4 + Sources/Common/Support.swift | 4 + Sources/Macros/Macros.swift | 3 +- Sources/P4Compiler/Expression.swift | 71 +++++++- Sources/P4Lang/Parser.swift | 58 ++++++- Sources/P4Runtime/Expressions.swift | 55 ++++++- 7 files changed, 420 insertions(+), 15 deletions(-) diff --git a/Sources/Common/ProgramTypes.swift b/Sources/Common/ProgramTypes.swift index 626fc27..801dee3 100644 --- a/Sources/Common/ProgramTypes.swift +++ b/Sources/Common/ProgramTypes.swift @@ -16,7 +16,7 @@ // along with this program. If not, see . /// A P4 identifier -public class Identifier: CustomStringConvertible, Equatable, Hashable { +public class Identifier: CustomStringConvertible, Comparable, Hashable { public func hash(into hasher: inout Hasher) { hasher.combine(self.name) } @@ -38,6 +38,10 @@ public class Identifier: CustomStringConvertible, Equatable, Hashable { public static func == (lhs: Identifier, rhs: Identifier) -> Bool { return lhs.name == rhs.name } + + public static func < (lhs: Identifier, rhs: Identifier) -> Bool { + return lhs.name < rhs.name + } } /// A P4 identifier @@ -171,10 +175,121 @@ public class P4StructValue: P4Value { return self.stype } - public func eq(rhs: any P4Value) -> Bool { + func bin_op_impl(lhs: P4StructValue, rhs: P4StructValue, op: (P4Value?, P4Value?) -> Bool) -> Bool { + if lhs.stype.fields.count() != rhs.stype.fields.count() { + // If there are a different number of fields, then we cannot + // possibly be equal. + return false + } + + // Note: Because the number of values _always_ matches the number of fields, there + // is no need to check there! + + for xx in zip(zip(lhs.stype.fields, lhs.values), zip(rhs.stype.fields, rhs.values)) { + let left = xx.0 + let right = xx.1 + + let left_field = left.0 + let left_value = left.1 + + let right_field = right.0 + let right_value = right.1 + + // If the field names do not match, then there is a problem. + if left_field != right_field { + return false + } + + // Now that we know that the field names match, do the values match? + if !op(left_value, right_value) { + return false + } + } return true } + public func eq(rhs: any P4Value) -> Bool { + guard let rrhs = rhs as? P4StructValue else { + return false + } + return bin_op_impl(lhs: self, rhs: rrhs) { ilhs, irhs in + if ilhs == nil && irhs == nil { + return true + } + guard let llhs = ilhs, + let rrhs = irhs + else { + return false + } + return llhs.eq(rhs: rrhs) + } + } + public func lt(rhs: any P4Value) -> Bool { + guard let rrhs = rhs as? P4StructValue else { + return false + } + return bin_op_impl(lhs: self, rhs: rrhs) { ilhs, irhs in + if ilhs == nil && irhs == nil { + return true + } + guard let llhs = ilhs, + let rrhs = irhs + else { + return false + } + return llhs.lt(rhs: rrhs) + } + } + public func lte(rhs: any P4Value) -> Bool { + guard let rrhs = rhs as? P4StructValue else { + return false + } + return bin_op_impl(lhs: self, rhs: rrhs) { ilhs, irhs in + if ilhs == nil && irhs == nil { + return true + } + guard let llhs = ilhs, + let rrhs = irhs + else { + return false + } + return llhs.lte(rhs: rrhs) + } + } + public func gt(rhs: any P4Value) -> Bool { + guard let rrhs = rhs as? P4StructValue else { + return false + } + return bin_op_impl(lhs: self, rhs: rrhs) { ilhs, irhs in + if ilhs == nil && irhs == nil { + return true + } + guard let llhs = ilhs, + let rrhs = irhs + else { + return false + } + return llhs.gt(rhs: rrhs) + } + } + + public func gte(rhs: any P4Value) -> Bool { + guard let rrhs = rhs as? P4StructValue else { + return false + } + return bin_op_impl(lhs: self, rhs: rrhs) { ilhs, irhs in + if ilhs == nil && irhs == nil { + return true + } + guard let llhs = ilhs, + let rrhs = irhs + else { + return false + } + return llhs.gte(rhs: rrhs) + } + } + public var description: String { return "Struct: \(self.stype.fields.describe_with_values(values: self.values))" } @@ -253,6 +368,10 @@ public class P4BooleanValue: P4Value { let value: Bool + public func access() -> Bool { + return self.value + } + public init(withValue value: Bool) { self.value = value } @@ -263,6 +382,34 @@ public class P4BooleanValue: P4Value { return self.value == bool_rhs.value } + public func lt(rhs: P4Value) -> Bool { + guard let bool_rhs = rhs as? P4BooleanValue else { + return false + } + return (self.value ? 1 : 0 ) < (bool_rhs.value ? 1 : 0) + } + + public func lte(rhs: P4Value) -> Bool { + guard let bool_rhs = rhs as? P4BooleanValue else { + return false + } + return (self.value ? 1 : 0 ) <= (bool_rhs.value ? 1 : 0) + } + + public func gt(rhs: P4Value) -> Bool { + guard let bool_rhs = rhs as? P4BooleanValue else { + return false + } + return (self.value ? 1 : 0 ) > (bool_rhs.value ? 1 : 0) + } + + public func gte(rhs: P4Value) -> Bool { + guard let bool_rhs = rhs as? P4BooleanValue else { + return false + } + return (self.value ? 1 : 0 ) >= (bool_rhs.value ? 1 : 0) + } + public var description: String { "\(self.value ? "true" : "false") of \(self.type()) type" } @@ -307,6 +454,35 @@ public class P4IntValue: P4Value { } return self.value == int_rhs.value } + + public func lt(rhs: P4Value) -> Bool { + guard let int_rhs = rhs as? P4IntValue else { + return false + } + return self.value < int_rhs.value + } + + public func lte(rhs: P4Value) -> Bool { + guard let int_rhs = rhs as? P4IntValue else { + return false + } + return self.value <= int_rhs.value + } + + public func gt(rhs: P4Value) -> Bool { + guard let int_rhs = rhs as? P4IntValue else { + return false + } + return self.value > int_rhs.value + } + + public func gte(rhs: P4Value) -> Bool { + guard let int_rhs = rhs as? P4IntValue else { + return false + } + return self.value >= int_rhs.value + } + public var description: String { "\(self.value) of \(self.type()) type" } @@ -345,6 +521,34 @@ public class P4StringValue: P4Value { return self.value == string_rhs.value } + public func lt(rhs: P4Value) -> Bool { + guard let string_rhs = rhs as? P4StringValue else { + return false + } + return self.value < string_rhs.value + } + + public func lte(rhs: P4Value) -> Bool { + guard let string_rhs = rhs as? P4StringValue else { + return false + } + return self.value <= string_rhs.value + } + + public func gt(rhs: P4Value) -> Bool { + guard let string_rhs = rhs as? P4StringValue else { + return false + } + return self.value > string_rhs.value + } + + public func gte(rhs: P4Value) -> Bool { + guard let string_rhs = rhs as? P4StringValue else { + return false + } + return self.value >= string_rhs.value + } + public var description: String { "\(self.value) of \(self.type()) type" } @@ -415,6 +619,38 @@ public class P4ArrayValue: P4Value { return true } + public func lt(rhs: P4Value) -> Bool { + guard rhs as? P4ArrayValue != nil else { + return false + } + // TODO!! + return true + } + + public func lte(rhs: P4Value) -> Bool { + guard rhs as? P4ArrayValue != nil else { + return false + } + // TODO!! + return true + } + + public func gt(rhs: P4Value) -> Bool { + guard rhs as? P4ArrayValue != nil else { + return false + } + // TODO!! + return true + } + + public func gte(rhs: P4Value) -> Bool { + guard rhs as? P4ArrayValue != nil else { + return false + } + // TODO!! + return true + } + public var description: String { "\(self.value) of \(self.type()) type" } diff --git a/Sources/Common/Protocols.swift b/Sources/Common/Protocols.swift index 2f5780f..cc802f5 100644 --- a/Sources/Common/Protocols.swift +++ b/Sources/Common/Protocols.swift @@ -40,6 +40,10 @@ public protocol P4Type: CustomStringConvertible { public protocol P4Value: EvaluatableExpression, CustomStringConvertible { func type() -> any P4Type func eq(rhs: P4Value) -> Bool + func lt(rhs: P4Value) -> Bool + func lte(rhs: P4Value) -> Bool + func gt(rhs: P4Value) -> Bool + func gte(rhs: P4Value) -> Bool } extension P4Value { diff --git a/Sources/Common/Support.swift b/Sources/Common/Support.swift index 446705c..fb86088 100644 --- a/Sources/Common/Support.swift +++ b/Sources/Common/Support.swift @@ -133,6 +133,10 @@ extension Result: CustomStringConvertible { } } +public func Map(input: T, block: (T) -> U) -> U { + return block(input) +} + @freestanding(expression) public macro RequireOkResult(_: Result) -> Bool = #externalMacro(module: "Macros", type: "RequireResult") @freestanding(expression) public macro RequireErrorResult(_: Error, _: Result) -> Bool = diff --git a/Sources/Macros/Macros.swift b/Sources/Macros/Macros.swift index 6ef008d..75e3c27 100644 --- a/Sources/Macros/Macros.swift +++ b/Sources/Macros/Macros.swift @@ -20,7 +20,8 @@ import SwiftSyntax @_spi(ExperimentalLanguageFeature) import SwiftSyntaxMacros public func remove_embedded_quotes(_ from: String) -> String { - return from.replacing("\"", with: []) + let result = from.replacing("\"", with: []) + return result } struct MacroError: Error, CustomStringConvertible { diff --git a/Sources/P4Compiler/Expression.swift b/Sources/P4Compiler/Expression.swift index 5f2ddf8..ec42ba4 100644 --- a/Sources/P4Compiler/Expression.swift +++ b/Sources/P4Compiler/Expression.swift @@ -121,19 +121,30 @@ struct Expression { public static func Compile( node: Node, withContext: CompilerContext ) -> Result { - #RequireNodesType( nodes: node, type: ["expression", "keysetExpression"], nice_type_names: ["expression", "keyset expression"]) // If the node is a keyset expression, then dig out the expression: - let node = + var expression_node = if node.nodeType == "keysetExpression" { node.child(at: 0)! } else { node } + #RequireNodeType(node: expression_node, type: "expression", nice_type_name: "expression") + + expression_node = expression_node.child(at: 0)! + #RequireNodesType( + nodes: expression_node, type: ["grouped_expression", "simple_expression"], + nice_type_names: ["grouped expression", "simple expression"]) + + // If this is a grouped expression, recurse! + if expression_node.nodeType == "grouped_expression" { + return Expression.Compile(node: expression_node.child(at: 1)!, withContext: withContext) + } + let localElementsParsers: [CompilableExpression.Type] = [ P4BooleanValue.self, P4StringValue.self, P4IntValue.self, TypedIdentifier.self, BinaryOperatorExpression.self, ArrayAccessExpression.self, FieldAccessExpression.self, @@ -141,7 +152,7 @@ struct Expression { for le_parser in localElementsParsers { switch le_parser.compile( - node: node, withContext: withContext) + node: expression_node, withContext: withContext) { case .Ok(.some(let parsed)): return .Ok(parsed) case .Error(let e): return .Error(e) @@ -157,8 +168,29 @@ struct LValue { public static func Compile( node: Node, withContext: CompilerContext ) -> Result { + #RequireNodesType( + nodes: node, type: ["expression", "keysetExpression"], + nice_type_names: ["expression", "keyset expression"]) - // Try to compile all the expressions that are LValuable! + // If the node is a keyset expression, then dig out the expression: + var expression_node = + if node.nodeType == "keysetExpression" { + node.child(at: 0)! + } else { + node + } + + #RequireNodeType(node: expression_node, type: "expression", nice_type_name: "expression") + + expression_node = expression_node.child(at: 0)! + #RequireNodesType( + nodes: expression_node, type: ["grouped_expression", "simple_expression"], + nice_type_names: ["grouped expression", "simple expression"]) + + // If this is a grouped expression, recurse! + if expression_node.nodeType == "grouped_expression" { + return LValue.Compile(node: expression_node.child(at: 1)!, withContext: withContext) + } let lvalueParsers: [CompilableLValueExpression.Type] = [ TypedIdentifier.self, FieldAccessExpression.self, ArrayAccessExpression.self, @@ -166,7 +198,7 @@ struct LValue { for lvalue_parser in lvalueParsers { switch lvalue_parser.compile_as_lvalue( - node: node, withContext: withContext) + node: expression_node, withContext: withContext) { case .Ok(.some(let parsed)): return .Ok(parsed) case .Error(let e): return .Error(e) @@ -304,9 +336,13 @@ extension BinaryOperatorExpression: CompilableExpression { currentChild = expression.child(at: currentChildIdx) let binary_operator_expression_node = currentChild! + + // TODO: This macro cannot handle new lines in the arrays + // swift-format-ignore #RequireNodesType( - nodes: binary_operator_expression_node, type: ["binaryEqualOperatorExpression"], - nice_type_names: ["binary equal operator"]) + nodes: binary_operator_expression_node, + type: ["binaryEqualOperatorExpression", "binaryLessThanOperatorExpression", "binaryLessThanEqualOperatorExpression", "binaryGreaterThanOperatorExpression", "binaryGreaterThanEqualOperatorExpression", "binaryAndOperatorExpression", "binaryOrOperatorExpression"], + nice_type_names: [ "binary equal operator", "binary less than operator", "binary less than or equal to operator", "binary greater than operator", "binary greater than or equal to operator", "binary and operator", "binary or operator"]) if binary_operator_expression_node.childCount < currentChildIdxSafe { return Result.Error( @@ -344,9 +380,27 @@ extension BinaryOperatorExpression: CompilableExpression { return Result.Error(maybe_right_hand_side.error()!) } + let evaluators = [ + "binaryEqualOperatorExpression": ("Binary Equal", P4Boolean(), Optional.none, binary_equal_operator_evaluator), + "binaryLessThanOperatorExpression": ("Binary Less Than", P4Boolean(), Optional.none, binary_lt_operator_evaluator), + "binaryLessThanEqualOperatorExpression": ("Binary Less Than Or Equal", P4Boolean(), Optional.none, binary_lte_operator_evaluator), + "binaryGreaterThanOperatorExpression": ("Binary Greater Than", P4Boolean(), Optional.none, binary_gt_operator_evaluator), + "binaryGreaterThanEqualOperatorExpression": ("Binary Greater Than Or Equal", P4Boolean(), Optional.none, binary_gte_operator_evaluator), + "binaryAndOperatorExpression": ("Binary Or", P4Boolean(), binary_and_or_operator_checker, binary_and_operator_evaluator), + "binaryOrOperatorExpression": ("Binary And", P4Boolean(), binary_and_or_operator_checker, binary_or_operator_evaluator), + ] + + guard let selected_evaluator = evaluators[binary_operator_expression_node.nodeType!] else { + return Result.Error(Error(withMessage: "No evaluator for \(binary_operator_expression_node.nodeType!)")) + } + + if let checker = selected_evaluator.2, case .Error(let e) = checker(left_hand_side, right_hand_side) { + return Result.Error(e) + } + return .Ok( BinaryOperatorExpression( - withEvaluator: ("Binary Equal", P4Boolean(), binary_equal_operator_evaluator), + withEvaluator: (selected_evaluator.0, selected_evaluator.1, selected_evaluator.3), withLhs: left_hand_side, withRhs: right_hand_side)) } } @@ -522,6 +576,7 @@ extension FieldAccessExpression: CompilableLValueExpression { node: SwiftTreeSitter.Node, withContext context: CompilerContext ) -> Result { let expression = node.child(at: 0)! + print("expression: \(expression)") #SkipUnlessNodeType( node: expression, type: "fieldAccessExpression") diff --git a/Sources/P4Lang/Parser.swift b/Sources/P4Lang/Parser.swift index a4f41d4..1f7bf4e 100644 --- a/Sources/P4Lang/Parser.swift +++ b/Sources/P4Lang/Parser.swift @@ -63,6 +63,34 @@ public class ParserState: P4Type, P4Value, Equatable, CustomStringConvertible { } } + public func lt(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let other as ParserState: self.state < other.state + default: false + } + } + + public func lte(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let other as ParserState: self.state <= other.state + default: false + } + } + + public func gt(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let other as ParserState: self.state > other.state + default: false + } + } + + public func gte(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let other as ParserState: self.state >= other.state + default: false + } + } + public private(set) var state: Identifier public private(set) var statements: [EvaluatableStatement] @@ -184,7 +212,35 @@ public struct Parser: P4Type, P4Value { public func eq(rhs: any Common.P4Type) -> Bool { return switch rhs { - case is Parser: true + case let parser_rhs as Parser: self.name == parser_rhs.name + default: false + } + } + + public func lt(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let parser_rhs as Parser: self.name < parser_rhs.name + default: false + } + } + + public func lte(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let parser_rhs as Parser: self.name <= parser_rhs.name + default: false + } + } + + public func gt(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let parser_rhs as Parser: self.name > parser_rhs.name + default: false + } + } + + public func gte(rhs: any Common.P4Value) -> Bool { + return switch rhs { + case let parser_rhs as Parser: self.name >= parser_rhs.name default: false } } diff --git a/Sources/P4Runtime/Expressions.swift b/Sources/P4Runtime/Expressions.swift index 656ed0c..01c9e60 100644 --- a/Sources/P4Runtime/Expressions.swift +++ b/Sources/P4Runtime/Expressions.swift @@ -94,10 +94,59 @@ extension TypedIdentifier: EvaluatableLValueExpression { } public func binary_equal_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { - if left.eq(rhs: right) { - return P4BooleanValue(withValue: true) + return Map(input: left.eq(rhs: right)) { input in + P4BooleanValue(withValue: input) + } +} + +public func binary_lt_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { + return Map(input: left.lt(rhs: right)) { input in + P4BooleanValue(withValue: input) + } +} + +public func binary_lte_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { + return Map(input: left.lte(rhs: right)) { input in + P4BooleanValue(withValue: input) + } +} + +public func binary_gt_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { + return Map(input: left.gt(rhs: right)) { input in + P4BooleanValue(withValue: input) + } +} + +public func binary_gte_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { + return Map(input: left.gte(rhs: right)) { input in + P4BooleanValue(withValue: input) + } +} + +public typealias BinaryOperatorChecker = (EvaluatableExpression, EvaluatableExpression) -> Result<()> + +public func binary_and_or_operator_checker(left: EvaluatableExpression, right: EvaluatableExpression) -> Result<()> { + // Check that both are Boolean-typed things! + if !(left.type().eq(rhs: P4Boolean()) && right.type().eq(rhs: P4Boolean())) { + return .Error(Error(withMessage: "And/Or on operands with non-bool type is not allowed")) + } + return .Ok(()) +} + +public func binary_and_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { + let bleft = left as! P4BooleanValue + let bright = right as! P4BooleanValue + return Map(input: bleft.access() && bright.access()) { input in + P4BooleanValue(withValue: input) + } +} + +public func binary_or_operator_evaluator(left: P4Value, right: P4Value) -> P4Value { + let bleft = left as! P4BooleanValue + let bright = right as! P4BooleanValue + return Map(input: bleft.access() || bright.access()) { input in + P4BooleanValue(withValue: input) } - return P4BooleanValue(withValue: false) } extension BinaryOperatorExpression: EvaluatableExpression {