compiler: Flesh Out CST Visitor Framework

As a use case, use it to implement text serialization
of the CST.

Signed-off-by: Will Hawkins <hawkinsw@obs.cr>
This commit is contained in:
Will Hawkins
2026-06-15 23:37:04 -04:00
parent d22776b018
commit aa12974dd6
5 changed files with 277 additions and 31 deletions
+172 -10
View File
@@ -24,23 +24,154 @@ public struct CSTTextSerializer {
public struct CSTTextSerializerContext { public struct CSTTextSerializerContext {
public let serialized: String public let serialized: String
public let indents: Int
public init(_ serialized: String = "") { public init(_ serialized: String = "", _ indents: Int = 0) {
self.serialized = serialized self.serialized = serialized
self.indents = indents
} }
static func produceIndent(_ indent: Int, _ marker: String) -> String {
return repeatElement(marker, count: indent).joined()
}
public func append(_ a: String) -> CSTTextSerializerContext { public func append(_ a: String) -> CSTTextSerializerContext {
return CSTTextSerializerContext(self.serialized + a) return CSTTextSerializerContext(
self.serialized + Self.produceIndent(self.indents, "\t") + a + "\n", self.indents)
}
public func indent() -> CSTTextSerializerContext {
return CSTTextSerializerContext(self.serialized, self.indents + 1)
}
public func unindent() -> CSTTextSerializerContext {
return CSTTextSerializerContext(self.serialized, self.indents - 1)
} }
} }
extension CSTTextSerializer: CSTVisitor<CSTTextSerializerContext> { extension CSTTextSerializer: CSTVisitor<CSTTextSerializerContext> {
public func visit(
node: P4Parser.CST.KeysetExpression, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
var context = context.append("Keyset Expression:").indent()
return .Ok(context.unindent())
if case CST.KeysetExpression.Value(let x) = node {
switch driver.visit(x, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
}
return .Ok(context.unindent())
}
public func visit(
node: P4Parser.CST.SelectCaseExpression, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
var context = context.append("Case Expression:").indent()
switch driver.visit(node.key, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
context = context.append("Next State:").indent()
switch driver.visit(node.next_state_identifier, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
context = context.unindent()
return .Ok(context.unindent())
}
public func visit(
node: P4Parser.CST.SelectExpression, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
var context = context.append("Select Expression:").indent()
context = context.append("Selector:").indent()
switch driver.visit(node.selector, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
context = context.unindent()
context = context.append("Case Expressions:").indent()
for ce in node.case_expressions {
switch driver.visit(ce, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
}
context = context.unindent()
return .Ok(context.unindent())
}
public func visit(
node: P4Parser.CST.Statements, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
var context = context
for s in node.statements {
switch driver.visit(s, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
}
return .Ok(context)
}
public func visit(
node: P4Parser.CST.ExpressionStatement, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
var context = context.append("Expression Statement:").indent()
switch driver.visit(node.expression, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e):
return .Error(e)
}
return .Ok(context.unindent())
}
public func visit(
node: P4Parser.CST.Control, driver: P4Parser.CSTVisitorDriver, context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Control Declaration"))
}
public func visit(
node: P4Parser.CST.ExternDeclaration, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Extern Declaration"))
}
public func visit(
node: P4Parser.CST.FunctionDeclaration, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Function Declaration"))
}
public func visit(
node: P4Parser.CST.StructDeclaration, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Struct Declaration"))
}
public func visit(
node: P4Parser.CST.VariableDeclarationStatement, driver: P4Parser.CSTVisitorDriver,
context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Variable Declaration Statement"))
}
public func visit( public func visit(
node: CST.BinaryOperatorExpression, driver: CSTVisitorDriver, node: CST.BinaryOperatorExpression, driver: CSTVisitorDriver,
context: CSTTextSerializerContext context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> { ) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Binary Operator Expression")) return .Ok(context.append("Binary Operator Expression"))
} }
public func visit( public func visit(
@@ -54,28 +185,44 @@ extension CSTTextSerializer: CSTVisitor<CSTTextSerializerContext> {
node: CST.Identifier, driver: CSTVisitorDriver, node: CST.Identifier, driver: CSTVisitorDriver,
context: CSTTextSerializerContext context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> { ) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("Identifier Expression")) return .Ok(context.append("Identifier: \(node.id)"))
} }
public func visit( public func visit(
node: CST.Parser, driver: CSTVisitorDriver, context: CSTTextSerializerContext node: CST.Parser, driver: CSTVisitorDriver, context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> { ) -> Common.Result<CSTTextSerializerContext> {
var context = context.append("Identifier Expression") var context = context.append("Parser Expression")
context = context.indent()
for s in node.states.states { for s in node.states.states {
switch driver.visit(state: s, visitor: self, context: context) { switch driver.visit(s, visitor: self, context: context) {
case .Ok(let c): context = c case .Ok(let c): context = c
case .Error(let e): return .Error(e) case .Error(let e): return .Error(e)
} }
} }
return .Ok(context.unindent())
return .Ok(context)
} }
public func visit( public func visit(
node: CST.ParserStateDirectTransition, driver: CSTVisitorDriver, node: CST.ParserStateDirectTransition, driver: CSTVisitorDriver,
context: CSTTextSerializerContext context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> { ) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("State: Direct Transition")) var context = context.append("State: Direct Transition").indent()
context = context.append("Statements:")
context = context.indent()
if let statements = node.statements {
switch driver.visit(statements, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
}
context = context.unindent()
context = context.append("Next State:").indent()
switch driver.visit(node.next_state_identifier!, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
context = context.unindent()
return .Ok(context.unindent())
} }
public func visit( public func visit(
@@ -89,6 +236,21 @@ extension CSTTextSerializer: CSTVisitor<CSTTextSerializerContext> {
node: CST.ParserStateSelectTransition, driver: CSTVisitorDriver, node: CST.ParserStateSelectTransition, driver: CSTVisitorDriver,
context: CSTTextSerializerContext context: CSTTextSerializerContext
) -> Common.Result<CSTTextSerializerContext> { ) -> Common.Result<CSTTextSerializerContext> {
return .Ok(context.append("State: Direct Transition"))
var context = context.append("State: Select Transition").indent()
if let statements = node.statements {
context = context.indent()
switch driver.visit(statements, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
context = context.unindent()
}
switch driver.visit(node.te, visitor: self, context: context) {
case .Ok(let c): context = c
case .Error(let e): return .Error(e)
}
return .Ok(context.unindent())
} }
} }
+5 -4
View File
@@ -22,15 +22,16 @@ extension P4Value: CST.Categories.Expression {}
public struct CST { public struct CST {
public struct Categories { public struct Categories {
public protocol Expression {} public protocol LanguageElement {}
public protocol Statement {} public protocol Expression: Categories.LanguageElement {}
public protocol State {} public protocol Statement : Categories.LanguageElement {}
public protocol State : Categories.LanguageElement {}
public protocol Declaration: Categories.Statement {} public protocol Declaration: Categories.Statement {}
} }
struct Expression {} struct Expression {}
public struct Statements { public struct Statements: Categories.Statement {
public let statements: [Categories.Statement] public let statements: [Categories.Statement]
public init(_ s: [Categories.Statement]) { public init(_ s: [Categories.Statement]) {
+18 -1
View File
@@ -58,11 +58,28 @@ public protocol ParsableStatement {
public protocol CSTVisitor<T> { public protocol CSTVisitor<T> {
associatedtype T associatedtype T
// Declarations
func visit(node: CST.Control, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.ExternDeclaration, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.FunctionDeclaration, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.StructDeclaration, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.Parser, driver: CSTVisitorDriver, context: T) -> Result<T>
// Statements
func visit(node: CST.Statements, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.VariableDeclarationStatement, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.ExpressionStatement, driver: CSTVisitorDriver, context: T) -> Result<T>
// Expressions
func visit(node: CST.KeysetExpression, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.SelectCaseExpression, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.SelectExpression, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.BinaryOperatorExpression, driver: CSTVisitorDriver, context: T) -> Result<T> func visit(node: CST.BinaryOperatorExpression, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.Literal, driver: CSTVisitorDriver, context: T) -> Result<T> func visit(node: CST.Literal, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.Identifier, driver: CSTVisitorDriver, context: T) -> Result<T> func visit(node: CST.Identifier, driver: CSTVisitorDriver, context: T) -> Result<T>
func visit(node: CST.Parser, driver: CSTVisitorDriver, context: T) -> Result<T> // Parser
func visit( func visit(
node: CST.ParserStateDirectTransition, driver: CSTVisitorDriver, context: T node: CST.ParserStateDirectTransition, driver: CSTVisitorDriver, context: T
) -> Result<T> ) -> Result<T>
+54 -13
View File
@@ -21,17 +21,56 @@ public struct CSTVisitorDriver {
public init() {} public init() {}
public func visit<T>( public func visit<T>(
expression: any CST.Categories.Expression, visitor: any CSTVisitor<T>, context: T _ elem: any CST.Categories.LanguageElement, visitor: any CSTVisitor<T>, context: T
) -> Result<T> { ) -> Result<T> {
return switch expression { return switch elem {
case let e as CST.BinaryOperatorExpression: case let elem as CST.Categories.Expression:
visitor.visit(node: e, driver: self, context: context) visit(expression: elem, visitor: visitor, context: context)
case let e as CST.Literal: visitor.visit(node: e, driver: self, context: context) case let elem as CST.Categories.Statement:
default: .Error(Error(withMessage: "AST Expression Element Is Not Visitable")) visit(statement: elem, visitor: visitor, context: context)
case let elem as CST.Categories.State: visit(state: elem, visitor: visitor, context: context)
case let elem as CST.Categories.Declaration:
visit(declaration: elem, visitor: visitor, context: context)
default: .Error(Error(withMessage: "AST Language Element (\(elem)) Is Not Visitable"))
} }
} }
public func visit<T>( func visit<T>(
declaration: any CST.Categories.Declaration, visitor: any CSTVisitor<T>, context: T
) -> Result<T> {
return switch declaration {
case let elem as CST.Control: visitor.visit(node: elem, driver: self, context: context)
case let elem as CST.ExternDeclaration:
visitor.visit(node: elem, driver: self, context: context)
case let elem as CST.FunctionDeclaration:
visitor.visit(node: elem, driver: self, context: context)
case let elem as CST.StructDeclaration:
visitor.visit(node: elem, driver: self, context: context)
case let elem as CST.Parser: visitor.visit(node: elem, driver: self, context: context)
default: .Error(Error(withMessage: "AST Declaration Element (\(declaration)) Is Not Visitable"))
}
}
func visit<T>(
expression: any CST.Categories.Expression, visitor: any CSTVisitor<T>, context: T
) -> Result<T> {
return switch expression {
case let s as CST.Identifier:
visitor.visit(node: s, driver: self, context: context)
case let s as CST.KeysetExpression:
visitor.visit(node: s, driver: self, context: context)
case let s as CST.SelectCaseExpression:
visitor.visit(node: s, driver: self, context: context)
case let e as CST.SelectExpression:
visitor.visit(node: e, driver: self, context: context)
case let e as CST.BinaryOperatorExpression:
visitor.visit(node: e, driver: self, context: context)
case let e as CST.Literal: visitor.visit(node: e, driver: self, context: context)
default: .Error(Error(withMessage: "AST Expression Element (\(expression)) Is Not Visitable"))
}
}
func visit<T>(
state: any CST.Categories.State, visitor: any CSTVisitor<T>, context: T state: any CST.Categories.State, visitor: any CSTVisitor<T>, context: T
) -> Result<T> { ) -> Result<T> {
return switch state { return switch state {
@@ -45,26 +84,28 @@ public struct CSTVisitorDriver {
} }
} }
public func visit<T>( func visit<T>(
statement: any CST.Categories.Statement, visitor: any CSTVisitor<T>, context: T statement: any CST.Categories.Statement, visitor: any CSTVisitor<T>, context: T
) -> Result<T> { ) -> Result<T> {
return switch statement { return switch statement {
case let s as CST.Statements: visitor.visit(node: s, driver: self, context: context)
case let s as CST.ExpressionStatement: visitor.visit(node: s, driver: self, context: context)
case let s as CST.VariableDeclarationStatement:
visitor.visit(node: s, driver: self, context: context)
case let s as CST.Parser: visitor.visit(node: s, driver: self, context: context) case let s as CST.Parser: visitor.visit(node: s, driver: self, context: context)
default: .Error(Error(withMessage: "AST Statement Element Is Not Visitable")) default: .Error(Error(withMessage: "AST Statement Element (\(statement)) Is Not Visitable"))
} }
} }
public func visit<T>( public func start<T>(
program: CST.Program, visitor: any CSTVisitor<T>, context: T program: CST.Program, visitor: any CSTVisitor<T>, context: T
) -> Result<T> { ) -> Result<T> {
var context = context var context = context
for s in program.statements.statements { switch visit(statement: program.statements, visitor: visitor, context: context) {
switch visit(statement: s, visitor: visitor, context: context) {
case .Ok(let c): context = c case .Ok(let c): context = c
case .Error(let e): return .Error(e) case .Error(let e): return .Error(e)
} }
}
return .Ok(context) return .Ok(context)
} }
+26 -1
View File
@@ -46,5 +46,30 @@ import TreeSitterP4
let v = CSTTextSerializer() let v = CSTTextSerializer()
let c = CSTTextSerializerContext(); let c = CSTTextSerializerContext();
let vd = CSTVisitorDriver(); let vd = CSTVisitorDriver();
#expect(#RequireOkResult((vd.visit(program: program, visitor: v, context: c)))) let result = try #UseOkResult((vd.start(program: program, visitor: v, context: c)))
let expected = """
Parser Expression
State: Direct Transition
Statements:
Expression Statement:
Literal Expression
Next State:
Identifier: accept
State: Select Transition
Select Expression:
Selector:
Literal Expression
Case Expressions:
Case Expression:
Keyset Expression:
Next State:
Identifier: accept
Case Expression:
Keyset Expression:
Next State:
Identifier: reject
"""
#expect(result.serialized == expected)
} }