From 99804e30532e7f2dea6eb20a99f87f7a4e5b6219 Mon Sep 17 00:00:00 2001 From: Will Hawkins Date: Fri, 27 Mar 2026 01:36:38 -0400 Subject: [PATCH] Better Support For Keysets Signed-off-by: Will Hawkins --- Sources/Common/ProgramTypes.swift | 116 +++++++++++++++ Sources/P4Compiler/Expression.swift | 80 +++++----- Sources/P4Lang/Expressions.swift | 124 ++++++++++++++-- Sources/P4Runtime/Expressions.swift | 14 +- Tests/p4rseTests/ExpressionTests/Keyset.swift | 139 ++++++++++++++++++ tree-sitter-p4/grammar.js | 3 +- tree-sitter-p4/test/corpus/transitions.txt | 51 +++++++ 7 files changed, 476 insertions(+), 51 deletions(-) create mode 100644 Tests/p4rseTests/ExpressionTests/Keyset.swift diff --git a/Sources/Common/ProgramTypes.swift b/Sources/Common/ProgramTypes.swift index 2684a55..03ce4f1 100644 --- a/Sources/Common/ProgramTypes.swift +++ b/Sources/Common/ProgramTypes.swift @@ -656,3 +656,119 @@ public class P4ArrayValue: P4Value { "\(self.value) of \(self.type()) type" } } + +/// A P4 set type +public struct P4Set: P4Type { + public init(withSetType stype: P4Type) { + self.stype = stype + } + + let stype: P4Type + + public func set_type() -> P4Type { + return self.stype + } + + public var description: String { + return "P4Set" + } + + public func eq(rhs: any P4Type) -> Bool { + return switch rhs { + // If rhs is a set type, then they are the same if the types in the set are the same. + case let srhs as P4Set: srhs.eq(rhs: self.stype) + default: false + } + } + + public func def() -> P4Value { + return P4ArrayValue(withType: self, withValue: []) + } +} + +/// An instance of a P4 set +public class P4SetValue: P4Value { + public func type() -> any P4Type { + return P4Set(withSetType: self.stype) + } + + let value: P4Value + let stype: P4Type + + public init(withType type: P4Type, withValue value: P4Value) { + self.stype = type + self.value = value + } + + public func access() -> P4Value { + return self.value + } + + public func eq(rhs: P4Value) -> Bool { + guard let rrhs = rhs as? P4SetValue else { + return false + } + return rrhs.access().eq(rhs: self.value) + } + public func lt(rhs: P4Value) -> Bool { + guard let rrhs = rhs as? P4SetValue else { + return false + } + return rrhs.access().lt(rhs: self.value) + } + public func lte(rhs: P4Value) -> Bool { + guard let rrhs = rhs as? P4SetValue else { + return false + } + return rrhs.access().lte(rhs: self.value) + } + public func gt(rhs: P4Value) -> Bool { + guard let rrhs = rhs as? P4SetValue else { + return false + } + return rrhs.access().gt(rhs: self.value) + } + public func gte(rhs: P4Value) -> Bool { + guard let rrhs = rhs as? P4SetValue else { + return false + } + return rrhs.access().gte(rhs: self.value) + } + + public var description: String { + "P4Set with \(self.value) of \(self.type()) type" + } +} + +public class P4SetDefaultValue: P4Value { + public func type() -> any P4Type { + return P4Set(withSetType: self.stype) + } + + let stype: P4Type + + public init(withType type: P4Type) { + self.stype = type + } + + // Snarf up everything! + public func eq(rhs: P4Value) -> Bool { + return true + } + public func lt(rhs: P4Value) -> Bool { + return true + } + public func lte(rhs: P4Value) -> Bool { + return true + } + public func gt(rhs: P4Value) -> Bool { + return true + } + public func gte(rhs: P4Value) -> Bool { + return true + } + + public var description: String { + "Default of P4Set of \(self.type()) type" + } +} diff --git a/Sources/P4Compiler/Expression.swift b/Sources/P4Compiler/Expression.swift index 8f3a3fa..523c872 100644 --- a/Sources/P4Compiler/Expression.swift +++ b/Sources/P4Compiler/Expression.swift @@ -117,26 +117,40 @@ extension P4StringValue: CompilableExpression { } } +extension KeysetExpression: CompilableExpression { + static func compile( + node: SwiftTreeSitter.Node, withContext context: CompilerContext + ) -> Common.Result<(any Common.EvaluatableExpression)?> { + let keyset_expression_node = node.child(at: 0)! + + #RequireNodesType( + nodes: keyset_expression_node, type: ["expression", "default_keyset"], + nice_type_names: ["expression", "default keyset"]) + + // If there is a default keyset, that's easy! + if keyset_expression_node.nodeType == "default_keyset" { + return .Ok(PlaceholderDefaultKeysetExpression()) + } + + // Compile the expression: + let maybe_compiled_set_expression = Expression.Compile( + node: keyset_expression_node, withContext: context) + guard case .Ok(let compiled_expression) = maybe_compiled_set_expression else { + return .Error(maybe_compiled_set_expression.error()!) + } + + return .Ok(NonDefaultKeysetExpression(compiled_expression)) + } +} + struct Expression { public static func Compile( node: Node, withContext: CompilerContext ) -> Result { - #RequireNodesType( - nodes: node, type: ["expression", "keysetExpression"], - nice_type_names: ["expression", "keyset expression"]) - - // If the node is a keyset expression, then dig out the expression: - var expression_node = - if node.nodeType == "keysetExpression" { - node.child(at: 0)! - } else { - node - } - #RequireNodeType( - node: expression_node, type: "expression", nice_type_name: "expression") + node: node, type: "expression", nice_type_name: "expression") - expression_node = expression_node.child(at: 0)! + let expression_node = node.child(at: 0)! #RequireNodesType( nodes: expression_node, type: ["grouped_expression", "simple_expression"], nice_type_names: ["grouped expression", "simple expression"]) @@ -169,22 +183,10 @@ struct LValue { public static func Compile( node: Node, withContext: CompilerContext ) -> Result { - #RequireNodesType( - nodes: node, type: ["expression", "keysetExpression"], - nice_type_names: ["expression", "keyset expression"]) - - // If the node is a keyset expression, then dig out the expression: - var expression_node = - if node.nodeType == "keysetExpression" { - node.child(at: 0)! - } else { - node - } - #RequireNodeType( - node: expression_node, type: "expression", nice_type_name: "expression") + node: node, type: "expression", nice_type_name: "expression") - expression_node = expression_node.child(at: 0)! + let expression_node = node.child(at: 0)! #RequireNodesType( nodes: expression_node, type: ["grouped_expression", "simple_expression"], nice_type_names: ["grouped expression", "simple expression"]) @@ -253,14 +255,18 @@ extension SelectExpression: CompilableExpression { )) } - var kses: [KeysetExpression] = Array() + var kses: [SelectCaseExpression] = Array() var kses_errors: [Error] = Array() select_body_node.enumerateNamedChildren { current_node in - let maybe_parsed_kse = KeysetExpression.compile( + let maybe_parsed_kse = SelectCaseExpression.compile( node: current_node, withContext: context) if case .Ok(let parsed_kse) = maybe_parsed_kse { - kses.append(parsed_kse as! KeysetExpression) + let parsed_cse = parsed_kse as! SelectCaseExpression + switch parsed_cse.update_type(to: selector.type()) { + case .Ok(let updated_cse): kses.append(updated_cse) + case .Error(let e): kses_errors.append(ErrorOnNode(node: current_node, withError: e.msg)) + } } else { kses_errors.append(Error(withMessage: "\(maybe_parsed_kse.error()!)")) } @@ -272,15 +278,15 @@ extension SelectExpression: CompilableExpression { withMessage: "Error(s) parsing select cases: " + (kses_errors.map { error in return "\(error.msg)" - }.joined(separator: ";\n")))) + }.joined(separator: ";")))) } return .Ok( - SelectExpression(withSelector: selector, withKeysetExpressions: kses), + SelectExpression(withSelector: selector, withSelectCaseExpressions: kses), ) } } -extension KeysetExpression: CompilableExpression { +extension SelectCaseExpression: CompilableExpression { static func compile( node: Node, withContext context: CompilerContext ) -> Result { @@ -300,7 +306,7 @@ extension KeysetExpression: CompilableExpression { return Result.Error(Error(withMessage: "Missing target state in select case")) } - let maybe_parsed_keysetexpression = Expression.Compile( + let maybe_parsed_keysetexpression = KeysetExpression.compile( node: keysetexpression_node, withContext: context) guard case Result.Ok(let keysetexpression) = maybe_parsed_keysetexpression else { return Result.Error(maybe_parsed_keysetexpression.error()!) @@ -313,8 +319,8 @@ extension KeysetExpression: CompilableExpression { } return .Ok( - KeysetExpression( - withKey: keysetexpression, withNextState: targetstate) + SelectCaseExpression( + withKey: keysetexpression as! KeysetExpression, withNextState: targetstate) ) } } diff --git a/Sources/P4Lang/Expressions.swift b/Sources/P4Lang/Expressions.swift index f130ea1..9e7daa8 100644 --- a/Sources/P4Lang/Expressions.swift +++ b/Sources/P4Lang/Expressions.swift @@ -17,43 +17,145 @@ import Common -public struct KeysetExpression { +public class KeysetExpression { + public func update_type(to: P4Type) -> Result { + return .Ok(self) + } + + public func kse_evaluate(execution: Common.ProgramExecution) -> Result { + return .Error(Error(withMessage: "Missing key in keyset expression")) + } + + public func kse_type() -> P4Type { + return P4Boolean() + } +} + +public class NonDefaultKeysetExpression: KeysetExpression { public let key: EvaluatableExpression + + public init(_ key: EvaluatableExpression) { + self.key = key + } + + // Some keyset expressions need additional + // context about their types -- e.g., default. + // Override to update and return true if the + // update is safe. + public override func update_type(to: P4Type) -> Result { + // In the default case, if the current key type + // does not match the updated type, that's an + // error. + return Map(input: key.type().eq(rhs: to)) { input in + input + ? .Ok(self) + : .Error( + Error(withMessage: "Keyset expression type does not match selector expression type")) + } + } + + public override func kse_evaluate(execution: Common.ProgramExecution) -> Result { + return self.key.evaluate(execution: execution) + } + + public override func kse_type() -> P4Type { + return self.key.type() + } + +} + +public class DefaultKeysetExpression: KeysetExpression { + let type: P4Type + + public init(withType type: P4Type) { + self.type = type + } + + public override func update_type(to: P4Type) -> Result { + return Map(input: type.eq(rhs: to)) { input in + input + ? .Ok(DefaultKeysetExpression(withType: to)) + : .Error( + Error(withMessage: "Keyset expression type does not match selector expression type")) + } + } + + public override func kse_evaluate(execution: Common.ProgramExecution) -> Result { + return .Ok(P4SetDefaultValue(withType: self.type)) + } + + public override func kse_type() -> P4Type { + return P4Set(withSetType: self.type) + } +} + +public class PlaceholderDefaultKeysetExpression: KeysetExpression { + public override init() {} + + public override func update_type(to: P4Type) -> Result { + .Ok(DefaultKeysetExpression(withType: to)) + } + + public override func kse_evaluate(execution: Common.ProgramExecution) -> Result { + return .Error(Error(withMessage: "Cannot evaluate a placeholder default keyset expression")) + } + + public override func kse_type() -> P4Type { + return P4Set(withSetType: P4Boolean()) + } +} + +public struct SelectCaseExpression { + public let key: KeysetExpression public let next_state_identifier: Identifier public let next_state: ParserState? - public init(withKey key: EvaluatableExpression, withNextState next_state_id: Identifier) { + public init(withKey key: KeysetExpression, withNextState next_state_id: Identifier) { self.key = key self.next_state_identifier = next_state_id self.next_state = .none } public init( - withKey key: EvaluatableExpression, withNextState next_state_id: Identifier, - withNextState next_state: ParserState + withKey key: KeysetExpression, withNextState next_state_id: Identifier, + withNextState next_state: ParserState? ) { self.key = key self.next_state_identifier = next_state_id self.next_state = next_state } + // Some keyset expressions need additional + // context about their types -- e.g., default. + // Override to update and return true if the + public func update_type(to: P4Type) -> Result { + switch key.update_type(to: to) { + case .Ok(let new_kse): + .Ok( + SelectCaseExpression( + withKey: new_kse, withNextState: self.next_state_identifier, + withNextState: self.next_state)) + case .Error(let e): .Error(e) + } + } } public struct SelectExpression { public let selector: EvaluatableExpression - public let keyset_expressions: [KeysetExpression] + public let select_expressions: [SelectCaseExpression] public init( - withSelector selector: EvaluatableExpression, withKeysetExpressions kses: [KeysetExpression] + withSelector selector: EvaluatableExpression, + withSelectCaseExpressions kses: [SelectCaseExpression] ) { self.selector = selector - self.keyset_expressions = kses + self.select_expressions = kses } - public func append_checked_kse(kse: KeysetExpression) -> SelectExpression { - var new_kse = self.keyset_expressions - new_kse.append(kse) + public func append_checked_sce(sce: SelectCaseExpression) -> SelectExpression { + var new_cses = self.select_expressions + new_cses.append(sce) return SelectExpression( - withSelector: self.selector, withKeysetExpressions: new_kse) + withSelector: self.selector, withSelectCaseExpressions: new_cses) } } diff --git a/Sources/P4Runtime/Expressions.swift b/Sources/P4Runtime/Expressions.swift index cdde769..27aff88 100644 --- a/Sources/P4Runtime/Expressions.swift +++ b/Sources/P4Runtime/Expressions.swift @@ -18,7 +18,7 @@ import Common import P4Lang -extension KeysetExpression: EvaluatableExpression { +extension SelectCaseExpression: EvaluatableExpression { public func evaluate(execution: Common.ProgramExecution) -> Common.Result { return execution.scopes.lookup(identifier: next_state_identifier) } @@ -33,7 +33,7 @@ extension SelectExpression: EvaluatableExpression { public func evaluate(execution: Common.ProgramExecution) -> Common.Result { switch self.selector.evaluate(execution: execution) { case .Ok(let selector_value): - for kse in self.keyset_expressions { + for kse in self.select_expressions { if case .Ok(let kse_key) = kse.key.evaluate(execution: execution), kse_key.eq(rhs: selector_value) { @@ -374,3 +374,13 @@ extension FieldAccessExpression: EvaluatableLValueExpression { return .Ok(()) } } + +extension KeysetExpression: EvaluatableExpression { + public func evaluate(execution: Common.ProgramExecution) -> Common.Result { + return self.kse_evaluate(execution: execution) + } + + public func type() -> any Common.P4Type { + return self.kse_type() + } +} diff --git a/Tests/p4rseTests/ExpressionTests/Keyset.swift b/Tests/p4rseTests/ExpressionTests/Keyset.swift new file mode 100644 index 0000000..dbc750b --- /dev/null +++ b/Tests/p4rseTests/ExpressionTests/Keyset.swift @@ -0,0 +1,139 @@ +// p4rse, Copyright 2026, Will Hawkins +// +// This file is part of p4rse. +// +// This file is free software: you can redistribute it and/or modify +// it under the terms of the GNU General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU General Public License for more details. +// +// You should have received a copy of the GNU General Public License +// along with this program. If not, see . + +import Common +import Foundation +import Macros +import P4Lang +import P4Runtime +import SwiftTreeSitter +import Testing +import TreeSitter +import TreeSitterP4 + +@testable import P4Compiler + +@Test func test_simple_parser_with_transition_select_case_nondefault_expressions() async throws { + let simple_parser_declaration = """ + parser main_parser() { + state start { + transition select (true) { + false: reject; + true: accept; + }; + } + }; + """ + + let program = try #UseOkResult(Program.Compile(simple_parser_declaration)) + let parser = try #UseOkResult(program.find_parser(withName: Identifier(name: "main_parser"))) + let runtime = try #UseOkResult(P4Runtime.ParserRuntime.create(program: program)) + let (state_result, _) = try! #UseOkResult(runtime.run()) + + #expect(parser.states.count() == 1) + + #expect(state_result == P4Lang.accept) +} + +@Test func test_simple_parser_with_transition_select_case_default_expression() async throws { + let simple_parser_declaration = """ + parser main_parser() { + state start { + transition select (5) { + 5: reject; + _: accept; + }; + } + }; + """ + + let program = try #UseOkResult(Program.Compile(simple_parser_declaration)) + let parser = try #UseOkResult(program.find_parser(withName: Identifier(name: "main_parser"))) + let runtime = try #UseOkResult(P4Runtime.ParserRuntime.create(program: program)) + let (state_result, _) = try! #UseOkResult(runtime.run()) + + #expect(parser.states.count() == 1) + + #expect(state_result == P4Lang.reject) +} + +@Test func test_simple_parser_with_transition_select_case_default_expression2() async throws { + let simple_parser_declaration = """ + parser main_parser() { + state start { + transition select (1) { + 5: reject; + _: accept; + }; + } + }; + """ + + let program = try #UseOkResult(Program.Compile(simple_parser_declaration)) + let parser = try #UseOkResult(program.find_parser(withName: Identifier(name: "main_parser"))) + let runtime = try #UseOkResult(P4Runtime.ParserRuntime.create(program: program)) + let (state_result, _) = try! #UseOkResult(runtime.run()) + + #expect(parser.states.count() == 1) + + #expect(state_result == P4Lang.accept) +} + +@Test func test_simple_parser_with_transition_select_case_default_expression3() async throws { + let simple_parser_declaration = """ + parser main_parser() { + state start { + transition select (6) { + 5: reject; + 6: reject; + _: accept; + }; + } + }; + """ + + let program = try #UseOkResult(Program.Compile(simple_parser_declaration)) + let parser = try #UseOkResult(program.find_parser(withName: Identifier(name: "main_parser"))) + let runtime = try #UseOkResult(P4Runtime.ParserRuntime.create(program: program)) + let (state_result, _) = try! #UseOkResult(runtime.run()) + + #expect(parser.states.count() == 1) + + #expect(state_result == P4Lang.reject) +} + +@Test func test_simple_parser_with_transition_select_case_invalid_type() async throws { + let simple_parser_declaration = """ + parser main_parser() { + state start { + transition select (6) { + true: reject; + 6: reject; + _: accept; + }; + } + }; + """ + + #expect( + #RequireErrorResult( + Error( + withMessage: + "Error(s) parsing select cases: {81, 12}: Keyset expression type does not match selector expression type" + ), + Program.Compile(simple_parser_declaration))) +} diff --git a/tree-sitter-p4/grammar.js b/tree-sitter-p4/grammar.js index d459694..3c6b363 100644 --- a/tree-sitter-p4/grammar.js +++ b/tree-sitter-p4/grammar.js @@ -94,7 +94,7 @@ export default grammar({ booleanLiteralExpression: $ => choice($.true, $.false), selectExpression: $ => seq($.select, '(', $.expression, ')', '{', $.selectBody, '}'), // TODO: Should be expression list and not just a single expression transitionSelectionExpression: $ => choice($.identifier, $.selectExpression), - keysetExpression: $ => $.expression, + keysetExpression: $ => choice($.expression, $.default_keyset), binaryOperatorExpression: $ => choice($.binaryEqualOperatorExpression, $.binaryLessThanOperatorExpression, $.binaryLessThanEqualOperatorExpression, @@ -177,6 +177,7 @@ export default grammar({ string_literal: $ => /"[^"]*"/, integer: $ => /[0-9]+/, annotation_literal: $ => /@[A-Za-z_]+/, + default_keyset: $=> '_', double_equal: $=> '==', less_than: $=> '<', diff --git a/tree-sitter-p4/test/corpus/transitions.txt b/tree-sitter-p4/test/corpus/transitions.txt index 47042d2..d755667 100644 --- a/tree-sitter-p4/test/corpus/transitions.txt +++ b/tree-sitter-p4/test/corpus/transitions.txt @@ -87,3 +87,54 @@ parser simple() { ) ) ) + +========================= +Simple Transition Statement (To Select Expression With Default) +========================= +parser simple() { + state start { + transition select (se) { + _: next_state; + }; + } +}; + +--- +(p4program + (declaration + (parserDeclaration + (parserType + (parser) + (identifier) + ) + (parserStates + (parserState + (state) + (identifier) + (parserTransitionStatement + (transition) + (transitionSelectionExpression + (selectExpression + (select) + (expression + (simple_expression + (identifier) + ) + ) + (selectBody + (selectCase + (keysetExpression + (default_keyset) + ) + (colon) + (identifier) + ) + ) + ) + ) + ) + ) + ) + ) + ) +)