Compose records in arg parser (#234)

This commit is contained in:
Patrick Stevens
2024-09-04 00:04:23 +01:00
committed by GitHub
parent 86b938c81e
commit 2220f88053
9 changed files with 1334 additions and 301 deletions

View File

@@ -57,6 +57,150 @@ type private ArgToParse =
| Positional of ParseFunction
| NonPositional of ParseFunction
type private HasPositional = HasPositional
type private HasNoPositional = HasNoPositional
[<AutoOpen>]
module private TeqUtils =
let exFalso<'a> (_ : Teq<HasNoPositional, HasPositional>) : 'a = failwith "LOGIC ERROR!"
let exFalso'<'a> (_ : Teq<HasPositional, HasNoPositional>) : 'a = failwith "LOGIC ERROR!"
[<RequireQualifiedAccess>]
type private ParseTree<'hasPositional> =
| NonPositionalLeaf of ParseFunction * Teq<'hasPositional, HasNoPositional>
| PositionalLeaf of ParseFunction * Teq<'hasPositional, HasPositional>
/// `assemble` takes the SynExpr's (e.g. each record field contents) corresponding to each `Ident` in
/// the branch (e.g. each record field name),
/// and composes them into a `SynExpr` (e.g. the record-typed object).
| Branch of
fields : (Ident * ParseTree<HasNoPositional>) list *
assemble : (Map<string, SynExpr> -> SynExpr) *
Teq<'hasPositional, HasNoPositional>
/// `assemble` takes the SynExpr's (e.g. each record field contents) corresponding to each `Ident` in
/// the branch (e.g. each record field name),
/// and composes them into a `SynExpr` (e.g. the record-typed object).
| BranchPos of
posField : Ident *
fields : ParseTree<HasPositional> *
(Ident * ParseTree<HasNoPositional>) list *
assemble : (Map<string, SynExpr> -> SynExpr) *
Teq<'hasPositional, HasPositional>
type private ParseTreeEval<'ret> =
abstract Eval<'a> : ParseTree<'a> -> 'ret
type private ParseTreeCrate =
abstract Apply<'ret> : ParseTreeEval<'ret> -> 'ret
[<RequireQualifiedAccess>]
module private ParseTreeCrate =
let make<'a> (p : ParseTree<'a>) =
{ new ParseTreeCrate with
member _.Apply a = a.Eval p
}
[<RequireQualifiedAccess>]
module private ParseTree =
[<RequireQualifiedAccess>]
type State =
| Positional of ParseTree<HasPositional> * ParseTree<HasNoPositional> list
| NoPositional of ParseTree<HasNoPositional> list
let private cast (t : Teq<'a, 'b>) : Teq<ParseTree<'a>, ParseTree<'b>> = Teq.Cong.believeMe t
/// The `Ident` here is the field name.
let branch (assemble : Map<string, SynExpr> -> SynExpr) (subs : (Ident * ParseTreeCrate) list) : ParseTreeCrate =
let rec go
(selfIdent : Ident option)
(acc : (Ident * ParseTree<HasNoPositional>) list, pos : (Ident * ParseTree<HasPositional>) option)
(subs : (Ident * ParseTreeCrate) list)
: ParseTreeCrate
=
match subs with
| [] ->
match pos with
| None -> ParseTree.Branch (List.rev acc, assemble, Teq.refl) |> ParseTreeCrate.make
| Some (posField, pos) ->
ParseTree.BranchPos (posField, pos, List.rev acc, assemble, Teq.refl)
|> ParseTreeCrate.make
| (fieldName, sub) :: subs ->
{ new ParseTreeEval<_> with
member _.Eval (t : ParseTree<'a>) =
match t with
| ParseTree.NonPositionalLeaf (_, teq)
| ParseTree.Branch (_, _, teq) ->
go selfIdent (((fieldName, Teq.cast (cast teq) t) :: acc), pos) subs
| ParseTree.PositionalLeaf (_, teq)
| ParseTree.BranchPos (_, _, _, _, teq) ->
match pos with
| None -> go selfIdent (acc, Some (fieldName, Teq.cast (cast teq) t)) subs
| Some (ident, _) ->
failwith
$"Multiple entries tried to claim positional args! %s{ident.idText} and %s{fieldName.idText}"
}
|> sub.Apply
go None ([], None) subs
let rec accumulatorsNonPos (tree : ParseTree<HasNoPositional>) : ParseFunction list =
match tree with
| ParseTree.PositionalLeaf (_, teq) -> exFalso teq
| ParseTree.BranchPos (_, _, _, _, teq) -> exFalso teq
| ParseTree.NonPositionalLeaf (pf, _) -> [ pf ]
| ParseTree.Branch (trees, _, _) -> trees |> List.collect (snd >> accumulatorsNonPos)
/// Returns the positional arg separately.
let rec accumulatorsPos (tree : ParseTree<HasPositional>) : ParseFunction list * ParseFunction =
match tree with
| ParseTree.PositionalLeaf (pf, _) -> [], pf
| ParseTree.NonPositionalLeaf (_, teq) -> exFalso' teq
| ParseTree.Branch (_, _, teq) -> exFalso' teq
| ParseTree.BranchPos (_, tree, trees, _, _) ->
let nonPos = trees |> List.collect (snd >> accumulatorsNonPos)
let nonPos2, pos = accumulatorsPos tree
nonPos @ nonPos2, pos
/// Collect all the ParseFunctions which are necessary to define variables, throwing away
/// all information relevant to composing the resulting variables into records.
/// Returns the list of non-positional parsers, and any positional parser that exists.
let accumulators<'a> (tree : ParseTree<'a>) : ParseFunction list * ParseFunction option =
// Sad duplication of some code here, but it was the easiest way to make it type-safe :(
match tree with
| ParseTree.PositionalLeaf (pf, _) -> [], Some pf
| ParseTree.NonPositionalLeaf (pf, _) -> [ pf ], None
| ParseTree.Branch (trees, _, _) -> trees |> List.collect (snd >> accumulatorsNonPos) |> (fun i -> i, None)
| ParseTree.BranchPos (_, tree, trees, _, _) ->
let nonPos = trees |> List.collect (snd >> accumulatorsNonPos)
let nonPos2, pos = accumulatorsPos tree
nonPos @ nonPos2, Some pos
/// Build the return value.
let rec instantiate<'a> (tree : ParseTree<'a>) : SynExpr =
match tree with
| ParseTree.NonPositionalLeaf (pf, _) -> SynExpr.createIdent' pf.TargetVariable
| ParseTree.PositionalLeaf (pf, _) -> SynExpr.createIdent' pf.TargetVariable
| ParseTree.Branch (trees, assemble, _) ->
trees
|> List.map (fun (fieldName, contents) ->
let instantiated = instantiate contents
fieldName.idText, instantiated
)
|> Map.ofList
|> assemble
| ParseTree.BranchPos (posField, tree, trees, assemble, _) ->
let withPos = instantiate tree
trees
|> List.map (fun (fieldName, contents) ->
let instantiated = instantiate contents
fieldName.idText, instantiated
)
|> Map.ofList
|> Map.add posField.idText withPos
|> assemble
[<RequireQualifiedAccess>]
module internal ArgParserGenerator =
@@ -254,16 +398,21 @@ module internal ArgParserGenerator =
| Accumulation.Required -> parseElt, Accumulation.List, childTy
| _ -> failwith $"Could not decide how to parse arguments for field %s{fieldName.idText} of type %O{ty}"
let private toParseSpec (finalRecord : RecordType) : ParserSpec =
let rec private toParseSpec
(counter : int)
(ambientRecords : RecordType list)
(finalRecord : RecordType)
: ParseTreeCrate * int
=
finalRecord.Fields
|> List.iter (fun (SynField.SynField (isStatic = isStatic)) ->
if isStatic then
failwith "No static record fields allowed in ArgParserGenerator"
)
let args : ArgToParse list =
finalRecord.Fields
|> List.map (fun (SynField.SynField (attrs, _, identOption, fieldType, _, _, _, _, _)) ->
let counter, contents =
((counter, []), finalRecord.Fields)
||> List.fold (fun (counter, acc) (SynField.SynField (attrs, _, identOption, fieldType, _, _, _, _, _)) ->
let attrs = attrs |> List.collect (fun a -> a.Attributes)
let positionalArgAttr =
@@ -313,6 +462,20 @@ module internal ArgParserGenerator =
| None -> failwith "expected args field to have a name, but it did not"
| Some i -> i
let ambientRecordMatch =
match fieldType with
| SynType.LongIdent (SynLongIdent.SynLongIdent (id, _, _)) ->
let target = List.last(id).idText
ambientRecords |> List.tryFind (fun r -> r.Name.idText = target)
| _ -> None
match ambientRecordMatch with
| Some ambient ->
// This field has a type we need to obtain from parsing another record.
let spec, counter = toParseSpec counter ambientRecords ambient
counter, (ident, spec) :: acc
| None ->
let parser, accumulation, parseTy = createParseFunction ident attrs fieldType
match positionalArgAttr with
@@ -322,47 +485,41 @@ module internal ArgParserGenerator =
{
FieldName = ident
Parser = parser
TargetVariable = ident
TargetVariable = Ident.create $"arg_%i{counter}"
Accumulation = accumulation
TargetType = parseTy
ArgForm = argify ident
Help = helpText
}
|> ArgToParse.Positional
|> fun t -> ParseTree.PositionalLeaf (t, Teq.refl)
|> ParseTreeCrate.make
| _ -> failwith $"Expected positional arg accumulation type to be List, but it was %O{fieldType}"
| None ->
{
FieldName = ident
Parser = parser
TargetVariable = ident
TargetVariable = Ident.create $"arg_%i{counter}"
Accumulation = accumulation
TargetType = parseTy
ArgForm = argify ident
Help = helpText
}
|> ArgToParse.NonPositional
|> fun t -> ParseTree.NonPositionalLeaf (t, Teq.refl)
|> ParseTreeCrate.make
|> fun tree -> counter + 1, (ident, tree) :: acc
)
let positional, nonPositionals =
let mutable p = None
let n = ResizeArray ()
let tree =
contents
|> List.rev
|> ParseTree.branch (fun args ->
args
|> Map.toList
|> List.map (fun (ident, expr) -> SynLongIdent.create [ Ident.create ident ], expr)
|> AstHelper.instantiateRecord
)
for arg in args do
match arg with
| ArgToParse.Positional arg ->
match p with
| None -> p <- Some arg
| Some existing ->
failwith
$"Multiple args were tagged with `Positional`: %s{existing.TargetVariable.idText}, %s{arg.TargetVariable.idText}"
| ArgToParse.NonPositional arg -> n.Add arg
p, List.ofSeq n
{
NonPositionals = nonPositionals
Positionals = positional
}
tree, counter
/// let helpText : string = ...
let private helpText
@@ -478,10 +635,10 @@ module internal ArgParserGenerator =
| Accumulation.List ->
[
SynExpr.createIdent "value"
|> SynExpr.pipeThroughFunction arg.Parser
|> SynExpr.pipeThroughFunction (
SynExpr.createLongIdent' [ arg.TargetVariable ; Ident.create "Add" ]
)
|> SynExpr.applyFunction arg.Parser
SynExpr.CreateConst () |> SynExpr.pipeThroughFunction (SynExpr.createIdent "Ok")
]
|> SynExpr.sequential
@@ -508,14 +665,14 @@ module internal ArgParserGenerator =
|> SynBinding.withXmlDoc (
[
" Processes the key-value pair, returning Error if no key was matched."
" If the key is an arg which can arity 1, but throws when consuming that arg, we return Error(<the message>)."
" If the key is an arg which can have arity 1, but throws when consuming that arg, we return Error(<the message>)."
" This can nevertheless be a successful parse, e.g. when the key may have arity 0."
]
|> PreXmlDoc.create'
)
/// `let setFlagValue (key : string) : bool = ...`
let private setFlagValue (parseState : Ident) (argParseErrors : Ident) (flags : ParseFunction list) : SynBinding =
let private setFlagValue (argParseErrors : Ident) (flags : ParseFunction list) : SynBinding =
(SynExpr.CreateConst false, flags)
||> List.fold (fun finalExpr flag ->
let multipleErrorMessage =
@@ -568,7 +725,7 @@ module internal ArgParserGenerator =
(leftoverArgParser : SynExpr)
: SynBinding
=
/// `go (AwaitingValue arg) args
/// `go (AwaitingValue arg) args`
let recurseValue =
SynExpr.createIdent "go"
|> SynExpr.applyTo (
@@ -608,9 +765,9 @@ module internal ArgParserGenerator =
argStartsWithDashes
(SynExpr.sequential
[
(SynExpr.createIdent "arg"
|> SynExpr.pipeThroughFunction leftoverArgParser
|> SynExpr.pipeThroughFunction (SynExpr.createLongIdent' [ leftoverArgs ; Ident.create "Add" ]))
SynExpr.createIdent "arg"
|> SynExpr.pipeThroughFunction leftoverArgParser
|> SynExpr.pipeThroughFunction (SynExpr.createLongIdent' [ leftoverArgs ; Ident.create "Add" ])
recurseKey
])
@@ -786,18 +943,24 @@ module internal ArgParserGenerator =
SynPat.named "state"
|> SynPat.annotateType (SynType.createLongIdent [ parseState ])
SynPat.named "args"
|> SynPat.annotateType (SynType.appPostfix "list" (SynType.string))
|> SynPat.annotateType (SynType.appPostfix "list" SynType.string)
]
SynBinding.basic [ Ident.create "go" ] args body
|> SynBinding.withRecursion true
/// Takes a single argument, `args : string list`, and returns something of the type indicated by `recordType`.
let createRecordParse (parseState : Ident) (recordType : RecordType) : SynExpr =
let spec = toParseSpec recordType
let createRecordParse (parseState : Ident) (ambientRecords : RecordType list) (recordType : RecordType) : SynExpr =
let spec, _ = toParseSpec 0 ambientRecords recordType
// For each argument (positional and non-positional), create an accumulator for it.
let nonPos, pos =
{ new ParseTreeEval<_> with
member _.Eval tree = ParseTree.accumulators tree
}
|> spec.Apply
let bindings =
spec.NonPositionals
nonPos
|> List.map (fun pf ->
match pf.Accumulation with
| Accumulation.Required
@@ -816,7 +979,7 @@ module internal ArgParserGenerator =
let bindings, leftoverArgsName, leftoverArgsParser =
let bindingName, leftoverArgsParser, leftoverArgsType =
match spec.Positionals with
match pos with
| None ->
Ident.create "parser_LeftoverArgs",
(SynExpr.createLambda "x" (SynExpr.createIdent "x")),
@@ -839,7 +1002,7 @@ module internal ArgParserGenerator =
|> SynExpr.applyTo (SynExpr.CreateConst ())
|> SynBinding.basic [ argParseErrors ] []
let helpText = helpText recordType.Name spec.Positionals spec.NonPositionals
let helpText = helpText recordType.Name pos nonPos
let bindings = errorCollection :: helpText :: bindings
@@ -849,7 +1012,7 @@ module internal ArgParserGenerator =
// Determine whether any required arg is missing, and freeze args into immutable form.
let freezeNonPositionalArgs =
spec.NonPositionals
nonPos
|> List.map (fun pf ->
match pf.Accumulation with
| Accumulation.Choice spec ->
@@ -912,7 +1075,7 @@ module internal ArgParserGenerator =
let errorMessage =
SynExpr.createIdent "sprintf"
|> SynExpr.applyTo (SynExpr.CreateConst "Required argument '%s' received no value")
|> SynExpr.applyTo (SynExpr.CreateConst (argify pf.TargetVariable))
|> SynExpr.applyTo (SynExpr.CreateConst pf.ArgForm)
[
SynMatchClause.create
@@ -935,7 +1098,7 @@ module internal ArgParserGenerator =
)
let freezePositional =
match spec.Positionals with
match pos with
| None ->
// Check if there are leftover args. If there are, throw.
let errorMessage =
@@ -969,20 +1132,12 @@ module internal ArgParserGenerator =
let freezeArgs = freezePositional @ freezeNonPositionalArgs
let retPositional =
match spec.Positionals with
| None -> []
| Some pf ->
[
SynLongIdent.createI pf.TargetVariable, SynExpr.createIdent' pf.TargetVariable
]
let retValue =
let happyPath =
spec.NonPositionals
|> List.map (fun pf -> SynLongIdent.createI pf.TargetVariable, SynExpr.createIdent' pf.TargetVariable)
|> fun np -> retPositional @ np
|> AstHelper.instantiateRecord
{ new ParseTreeEval<_> with
member _.Eval tree = ParseTree.instantiate tree
}
|> spec.Apply
let sadPath =
SynExpr.createIdent' argParseErrors
@@ -1001,7 +1156,7 @@ module internal ArgParserGenerator =
SynExpr.ifThenElse areErrors sadPath happyPath
let flags =
spec.NonPositionals
nonPos
|> List.filter (fun pf ->
match pf.TargetType with
| PrimitiveType pt -> (pt |> List.map _.idText) = [ "System" ; "Boolean" ]
@@ -1019,8 +1174,8 @@ module internal ArgParserGenerator =
|> SynExpr.createLet (
bindings
@ [
processKeyValue argParseErrors (Option.toList spec.Positionals @ spec.NonPositionals)
setFlagValue parseState argParseErrors flags
processKeyValue argParseErrors (Option.toList pos @ nonPos)
setFlagValue argParseErrors flags
mainLoop parseState argParseErrors leftoverArgsName leftoverArgsParser
]
)
@@ -1029,10 +1184,13 @@ module internal ArgParserGenerator =
(opens : SynOpenDeclTarget list)
(ns : LongIdent)
((taggedType : SynTypeDefn, spec : ArgParserOutputSpec))
(_allUnionTypesTODO : SynTypeDefn list)
(allUnionTypes : SynTypeDefn list)
(allRecordTypes : SynTypeDefn list)
: SynModuleOrNamespace
=
// The type for which we're generating args may refer to any of these records/unions.
let allRecordTypes = allRecordTypes |> List.map RecordType.OfRecord
let taggedType = RecordType.OfRecord taggedType
let modAttrs, modName =
@@ -1086,7 +1244,7 @@ module internal ArgParserGenerator =
|> SynPat.annotateType (SynType.appPostfix "list" SynType.string)
let parsePrime =
createRecordParse parseStateIdent taggedType
createRecordParse parseStateIdent allRecordTypes taggedType
|> SynBinding.basic
[ Ident.create "parse'" ]
[
@@ -1146,16 +1304,19 @@ module internal ArgParserGenerator =
let ast, _ =
Ast.fromFilename context.InputFilename |> Async.RunSynchronously |> Array.head
let types = Ast.extractTypeDefn ast
let types =
Ast.extractTypeDefn ast
|> List.groupBy (fst >> List.map _.idText >> String.concat ".")
|> List.map (fun (_, v) -> fst (List.head v), List.collect snd v)
let opens = AstHelper.extractOpens ast
let namespaceAndTypes =
types
|> List.choose (fun (ns, types) ->
|> List.collect (fun (ns, types) ->
let typeWithAttr =
types
|> List.tryPick (fun ty ->
|> List.choose (fun ty ->
match Ast.getAttribute<ArgParserAttribute> ty with
| None -> None
| Some attr ->
@@ -1175,8 +1336,8 @@ module internal ArgParserGenerator =
Some (ty, spec)
)
match typeWithAttr with
| Some taggedType ->
typeWithAttr
|> List.map (fun taggedType ->
let unions, records, others =
(([], [], []), types)
||> List.fold (fun
@@ -1194,8 +1355,8 @@ module internal ArgParserGenerator =
failwith
$"Error: all types recursively defined together with an ArgParserGenerator type must be discriminated unions or records. %+A{others}"
Some (ns, taggedType, unions, records)
| _ -> None
(ns, taggedType, unions, records)
)
)
let modules =