Implement [<ArgumentFlag>] for two-case DUs (#242)

This commit is contained in:
Patrick Stevens
2024-09-04 22:48:36 +01:00
committed by GitHub
parent bdce82fb7a
commit e4cbab3209
15 changed files with 1085 additions and 126 deletions

View File

@@ -12,6 +12,23 @@ type internal ArgParserOutputSpec =
ExtensionMethods : bool
}
type internal FlagDu =
{
Name : Ident
Case1Name : Ident
Case2Name : Ident
/// Hopefully this is simply the const bool True or False, but it might e.g. be a literal
Case1Arg : SynExpr
/// Hopefully this is simply the const bool True or False, but it might e.g. be a literal
Case2Arg : SynExpr
}
static member FromBoolean (flagDu : FlagDu) (value : SynExpr) =
SynExpr.ifThenElse
(SynExpr.equals value flagDu.Case1Arg)
(SynExpr.createLongIdent' [ flagDu.Name ; flagDu.Case2Name ])
(SynExpr.createLongIdent' [ flagDu.Name ; flagDu.Case1Name ])
/// The default value of an argument which admits default values can be pulled from different sources.
/// This defines which source a particular default value comes from.
type private ArgumentDefaultSpec =
@@ -20,6 +37,7 @@ type private ArgumentDefaultSpec =
/// From calling the static member `{typeWeParseInto}.Default{name}()`
/// For example, if `type MyArgs = { Thing : Choice<int, int> }`, then
/// we would use `MyArgs.DefaultThing () : int`.
///
| FunctionCall of name : Ident
type private Accumulation<'choice> =
@@ -33,6 +51,13 @@ type private ParseFunction<'acc> =
FieldName : Ident
TargetVariable : Ident
ArgForm : string
/// If this is a boolean-like field (e.g. a bool or a flag DU), the help text should look a bit different:
/// we should lie to the user about the value of the cases there.
/// Similarly, if we're reading from an environment variable with the laxer parsing rules of accepting e.g.
/// "0" instead of "false", we need to know if we're reading a bool.
/// In that case, `boolCases` is Some, and contains the construction of the flag (or boolean, in which case
/// you get no data).
BoolCases : Choice<FlagDu, unit> option
Help : SynExpr option
/// A function string -> %TargetType%, where TargetVariable is probably a `%TargetType% option`.
/// (Depending on `Accumulation`, we'll remove the `option` at the end of the parse, asserting that the
@@ -236,12 +261,24 @@ module internal ArgParserGenerator =
result.ToString ()
let private identifyAsFlag (flagDus : FlagDu list) (ty : SynType) : FlagDu option =
match ty with
| SynType.LongIdent (SynLongIdent.SynLongIdent (ident, _, _)) ->
flagDus
|> List.tryPick (fun du ->
let duName = du.Name.idText
let ident = List.last(ident).idText
if duName = ident then Some du else None
)
| _ -> None
/// Builds a function or lambda of one string argument, which returns a `ty` (as modified by the `Accumulation`;
/// for example, maybe it returns a `ty option` or a `ty list`).
/// The resulting SynType is the type of the *element* being parsed; so if the Accumulation is List, the SynType
/// is the list element.
let rec private createParseFunction<'choice>
(choice : ArgumentDefaultSpec option -> 'choice)
(flagDus : FlagDu list)
(fieldName : Ident)
(attrs : SynAttribute list)
(ty : SynType)
@@ -334,7 +371,8 @@ module internal ArgParserGenerator =
Accumulation.Required,
ty
| OptionType eltTy ->
let parseElt, acc, childTy = createParseFunction choice fieldName attrs eltTy
let parseElt, acc, childTy =
createParseFunction choice flagDus fieldName attrs eltTy
match acc with
| Accumulation.Optional ->
@@ -353,7 +391,7 @@ module internal ArgParserGenerator =
failwith
$"ArgParser was unable to prove types %O{elt1} and %O{elt2} to be equal in a Choice. We require them to be equal."
let parseElt, acc, childTy = createParseFunction choice fieldName attrs elt1
let parseElt, acc, childTy = createParseFunction choice flagDus fieldName attrs elt1
match acc with
| Accumulation.Optional ->
@@ -391,6 +429,7 @@ module internal ArgParserGenerator =
| [ "Myriad" ; "Plugins" ; "ArgumentDefaultEnvironmentVariableAttribute" ]
| [ "WoofWare" ; "Myriad" ; "Plugins" ; "ArgumentDefaultEnvironmentVariable" ]
| [ "WoofWare" ; "Myriad" ; "Plugins" ; "ArgumentDefaultEnvironmentVariableAttribute" ] ->
ArgumentDefaultSpec.EnvironmentVariable attr.ArgExpr |> Some
| _ -> None
)
@@ -410,13 +449,26 @@ module internal ArgParserGenerator =
failwith
$"ArgParser requires Choice to be of the form Choice<'a, 'a>; that is, two arguments, both the same. For field %s{fieldName.idText}, got: %s{elts}"
| ListType eltTy ->
let parseElt, acc, childTy = createParseFunction choice fieldName attrs eltTy
let parseElt, acc, childTy =
createParseFunction choice flagDus fieldName attrs eltTy
parseElt, Accumulation.List acc, childTy
| _ -> failwith $"Could not decide how to parse arguments for field %s{fieldName.idText} of type %O{ty}"
| ty ->
match identifyAsFlag flagDus ty with
| None -> failwith $"Could not decide how to parse arguments for field %s{fieldName.idText} of type %O{ty}"
| Some flagDu ->
// Parse as a bool, and then do the `if-then` dance.
let parser =
SynExpr.createIdent "x"
|> SynExpr.applyFunction (SynExpr.createLongIdent [ "System" ; "Boolean" ; "Parse" ])
|> FlagDu.FromBoolean flagDu
|> SynExpr.createLambda "x"
parser, Accumulation.Required, ty
let rec private toParseSpec
(counter : int)
(flagDus : FlagDu list)
(ambientRecords : RecordType list)
(finalRecord : RecordType)
: ParseTreeCrate * int
@@ -489,7 +541,7 @@ module internal ArgParserGenerator =
match ambientRecordMatch with
| Some ambient ->
// This field has a type we need to obtain from parsing another record.
let spec, counter = toParseSpec counter ambientRecords ambient
let spec, counter = toParseSpec counter flagDus ambientRecords ambient
counter, (ident, spec) :: acc
| None ->
@@ -503,7 +555,13 @@ module internal ArgParserGenerator =
| None -> ()
let parser, accumulation, parseTy =
createParseFunction<unit> getChoice ident attrs fieldType
createParseFunction<unit> getChoice flagDus ident attrs fieldType
let isBoolLike =
match parseTy with
| PrimitiveType ident when ident |> List.map _.idText = [ "System" ; "Boolean" ] ->
Some (Choice2Of2 ())
| parseTy -> identifyAsFlag flagDus parseTy |> Option.map Choice1Of2
match accumulation with
| Accumulation.List (Accumulation.List _) ->
@@ -519,6 +577,7 @@ module internal ArgParserGenerator =
TargetType = parseTy
ArgForm = argify ident
Help = helpText
BoolCases = isBoolLike
}
|> fun t -> ParseTree.PositionalLeaf (t, Teq.refl)
| Accumulation.List Accumulation.Required ->
@@ -530,6 +589,7 @@ module internal ArgParserGenerator =
TargetType = parseTy
ArgForm = argify ident
Help = helpText
BoolCases = isBoolLike
}
|> fun t -> ParseTree.PositionalLeaf (t, Teq.refl)
| Accumulation.Choice _
@@ -546,7 +606,13 @@ module internal ArgParserGenerator =
| Some spec -> spec
let parser, accumulation, parseTy =
createParseFunction getChoice ident attrs fieldType
createParseFunction getChoice flagDus ident attrs fieldType
let isBoolLike =
match parseTy with
| PrimitiveType ident when ident |> List.map _.idText = [ "System" ; "Boolean" ] ->
Some (Choice2Of2 ())
| parseTy -> identifyAsFlag flagDus parseTy |> Option.map Choice1Of2
{
FieldName = ident
@@ -556,6 +622,7 @@ module internal ArgParserGenerator =
TargetType = parseTy
ArgForm = argify ident
Help = helpText
BoolCases = isBoolLike
}
|> fun t -> ParseTree.NonPositionalLeaf (t, Teq.refl)
|> ParseTreeCrate.make
@@ -581,7 +648,11 @@ module internal ArgParserGenerator =
(args : ParseFunctionNonPositional list)
: SynBinding
=
let describeNonPositional (acc : Accumulation<ArgumentDefaultSpec>) : SynExpr =
let describeNonPositional
(acc : Accumulation<ArgumentDefaultSpec>)
(flagCases : Choice<FlagDu, unit> option)
: SynExpr
=
match acc with
| Accumulation.Required -> SynExpr.CreateConst ""
| Accumulation.Optional -> SynExpr.CreateConst " (optional)"
@@ -596,18 +667,43 @@ module internal ArgParserGenerator =
)
|> SynExpr.paren
| Accumulation.Choice (ArgumentDefaultSpec.FunctionCall var) ->
SynExpr.callMethod var.idText (SynExpr.createIdent' typeName)
match flagCases with
| None -> SynExpr.callMethod var.idText (SynExpr.createIdent' typeName)
| Some (Choice2Of2 ()) -> SynExpr.callMethod var.idText (SynExpr.createIdent' typeName)
| Some (Choice1Of2 flagDu) ->
// Care required here. The return value from the Default call is not a bool,
// but we should display it as such to the user!
[
SynMatchClause.create
(SynPat.identWithArgs [ flagDu.Name ; flagDu.Case1Name ] (SynArgPats.create []))
(SynExpr.ifThenElse
(SynExpr.equals flagDu.Case1Arg (SynExpr.CreateConst true))
(SynExpr.CreateConst "false")
(SynExpr.CreateConst "true"))
SynMatchClause.create
(SynPat.identWithArgs [ flagDu.Name ; flagDu.Case2Name ] (SynArgPats.create []))
(SynExpr.ifThenElse
(SynExpr.equals flagDu.Case2Arg (SynExpr.CreateConst true))
(SynExpr.CreateConst "false")
(SynExpr.CreateConst "true"))
]
|> SynExpr.createMatch (SynExpr.callMethod var.idText (SynExpr.createIdent' typeName))
|> SynExpr.pipeThroughFunction (
SynExpr.applyFunction (SynExpr.createIdent "sprintf") (SynExpr.CreateConst " (default value: %O)")
)
|> SynExpr.paren
| Accumulation.List _ -> SynExpr.CreateConst " (can be repeated)"
let describePositional _ =
let describePositional _ _ =
SynExpr.CreateConst " (positional args) (can be repeated)"
let toPrintable (describe : 'a -> SynExpr) (arg : ParseFunction<'a>) : SynExpr =
let ty = arg.TargetType |> SynType.toHumanReadableString
/// We may sometimes lie about the type name, if e.g. this is a flag DU which we're pretending is a boolean.
/// So the `renderTypeName` takes the Accumulation which tells us whether we're lying.
let toPrintable (describe : 'a -> Choice<FlagDu, unit> option -> SynExpr) (arg : ParseFunction<'a>) : SynExpr =
let ty =
match arg.BoolCases with
| None -> SynType.toHumanReadableString arg.TargetType
| Some _ -> "bool"
let helpText =
match arg.Help with
@@ -617,7 +713,7 @@ module internal ArgParserGenerator =
|> SynExpr.applyTo (SynExpr.paren helpText)
|> SynExpr.paren
let descriptor = describe arg.Accumulation
let descriptor = describe arg.Accumulation arg.BoolCases
let prefix = $"%s{arg.ArgForm} %s{ty}"
@@ -765,9 +861,12 @@ module internal ArgParserGenerator =
)
/// `let setFlagValue (key : string) : bool = ...`
let private setFlagValue (argParseErrors : Ident) (flags : ParseFunction<'a> list) : SynBinding =
/// The second member of the `flags` list tuple is the constant "true" with which we will interpret the
/// arity-0 `--foo`. So in the case of a boolean-typed field, this is `true`; in the case of a Flag-typed field,
/// this is `FlagType.WhicheverCaseHadTrue`.
let private setFlagValue (argParseErrors : Ident) (flags : (ParseFunction<'a> * SynExpr) list) : SynBinding =
(SynExpr.CreateConst false, flags)
||> List.fold (fun finalExpr flag ->
||> List.fold (fun finalExpr (flag, trueCase) ->
let multipleErrorMessage =
SynExpr.createIdent "sprintf"
|> SynExpr.applyTo (SynExpr.CreateConst "Flag '%s' was supplied multiple times")
@@ -789,7 +888,7 @@ module internal ArgParserGenerator =
([
SynExpr.assign
(SynLongIdent.createI flag.TargetVariable)
(SynExpr.applyFunction (SynExpr.createIdent "Some") (SynExpr.CreateConst true))
(SynExpr.pipeThroughFunction (SynExpr.createIdent "Some") trueCase)
SynExpr.CreateConst true
]
|> SynExpr.sequential)
@@ -1059,8 +1158,14 @@ module internal ArgParserGenerator =
|> SynBinding.withRecursion true
/// Takes a single argument, `args : string list`, and returns something of the type indicated by `recordType`.
let createRecordParse (parseState : Ident) (ambientRecords : RecordType list) (recordType : RecordType) : SynExpr =
let spec, _ = toParseSpec 0 ambientRecords recordType
let createRecordParse
(parseState : Ident)
(flagDus : FlagDu list)
(ambientRecords : RecordType list)
(recordType : RecordType)
: SynExpr
=
let spec, _ = toParseSpec 0 flagDus ambientRecords recordType
// For each argument (positional and non-positional), create an accumulator for it.
let nonPos, pos =
{ new ParseTreeEval<_> with
@@ -1143,8 +1248,15 @@ module internal ArgParserGenerator =
/// Assumes access to a non-null variable `x` containing the string value.
let parser =
match pf.TargetType with
| PrimitiveType ident when ident |> List.map _.idText = [ "System" ; "Boolean" ] ->
match pf.BoolCases with
| Some boolLike ->
let trueCase, falseCase =
match boolLike with
| Choice2Of2 () -> SynExpr.CreateConst true, SynExpr.CreateConst false
| Choice1Of2 flag ->
FlagDu.FromBoolean flag (SynExpr.CreateConst true),
FlagDu.FromBoolean flag (SynExpr.CreateConst false)
// We permit environment variables to be populated with 0 and 1 as well.
SynExpr.ifThenElse
(SynExpr.applyFunction
@@ -1167,9 +1279,9 @@ module internal ArgParserGenerator =
[ "System" ; "StringComparison" ; "OrdinalIgnoreCase" ]
]))
(SynExpr.createIdent "x" |> SynExpr.pipeThroughFunction pf.Parser)
(SynExpr.CreateConst false))
(SynExpr.CreateConst true)
| _ -> (SynExpr.createIdent "x" |> SynExpr.pipeThroughFunction pf.Parser)
falseCase)
trueCase
| None -> (SynExpr.createIdent "x" |> SynExpr.pipeThroughFunction pf.Parser)
let errorMessage =
SynExpr.createIdent "sprintf"
@@ -1308,10 +1420,17 @@ module internal ArgParserGenerator =
let flags =
nonPos
|> List.filter (fun pf ->
|> List.choose (fun pf ->
match pf.TargetType with
| PrimitiveType pt -> (pt |> List.map _.idText) = [ "System" ; "Boolean" ]
| _ -> false
| PrimitiveType pt ->
if (pt |> List.map _.idText) = [ "System" ; "Boolean" ] then
Some (pf, SynExpr.CreateConst true)
else
None
| ty ->
match identifyAsFlag flagDus ty with
| Some flag -> (pf, FlagDu.FromBoolean flag (SynExpr.CreateConst true)) |> Some
| _ -> None
)
let leftoverArgAcc =
@@ -1336,18 +1455,81 @@ module internal ArgParserGenerator =
]
)
// The type for which we're generating args may refer to any of the supplied records/unions.
let createModule
(opens : SynOpenDeclTarget list)
(ns : LongIdent)
((taggedType : SynTypeDefn, spec : ArgParserOutputSpec))
(allUnionTypes : SynTypeDefn list)
(allRecordTypes : SynTypeDefn list)
(allUnionTypes : UnionType list)
(allRecordTypes : RecordType list)
: SynModuleOrNamespace
=
// The type for which we're generating args may refer to any of these records/unions.
let allRecordTypes = allRecordTypes |> List.map RecordType.OfRecord
let flagDus =
allUnionTypes
|> List.choose (fun ty ->
match ty.Cases with
| [ c1 ; c2 ] ->
match c1.Fields, c2.Fields with
| [], [] ->
let c1Attr =
c1.Attributes
|> List.tryPick (fun attr ->
match attr.TypeName with
| SynLongIdent.SynLongIdent (id, _, _) ->
match id |> List.last |> _.idText with
| "ArgumentFlagAttribute"
| "ArgumentFlag" -> Some (SynExpr.stripOptionalParen attr.ArgExpr)
| _ -> None
)
let taggedType = RecordType.OfRecord taggedType
let c2Attr =
c2.Attributes
|> List.tryPick (fun attr ->
match attr.TypeName with
| SynLongIdent.SynLongIdent (id, _, _) ->
match id |> List.last |> _.idText with
| "ArgumentFlagAttribute"
| "ArgumentFlag" -> Some (SynExpr.stripOptionalParen attr.ArgExpr)
| _ -> None
)
match c1Attr, c2Attr with
| Some c1Attr, Some c2Attr ->
// Sanity check where possible
match c1Attr, c2Attr with
| SynExpr.Const (SynConst.Bool b1, _), SynExpr.Const (SynConst.Bool b2, _) ->
if b1 = b2 then
failwith
"[<ArgumentFlag>] must have opposite argument values on each case in a two-case discriminated union."
| _, _ -> ()
{
Name = ty.Name
Case1Name = c1.Name
Case1Arg = c1Attr
Case2Name = c2.Name
Case2Arg = c2Attr
}
|> Some
| Some _, None
| None, Some _ ->
failwith
"[<ArgumentFlag>] must be placed on both cases of a two-case discriminated union, with opposite argument values on each case."
| _, _ -> None
| _, _ ->
failwith "[<ArgumentFlag>] may only be placed on discriminated union members with no data."
| _ -> None
)
let taggedType =
match taggedType with
| SynTypeDefn.SynTypeDefn (sci,
SynTypeDefnRepr.Simple (SynTypeDefnSimpleRepr.Record (access, fields, _), _),
smd,
_,
_,
_) -> RecordType.OfRecord sci smd access fields
| _ -> failwith "[<ArgParser>] currently only supports being placed on records."
let modAttrs, modName =
if spec.ExtensionMethods then
@@ -1368,13 +1550,15 @@ module internal ArgParserGenerator =
[
SynUnionCase.create
{
Attrs = []
Attributes = []
Fields = []
Ident = Ident.create "AwaitingKey"
Name = Ident.create "AwaitingKey"
XmlDoc = Some (PreXmlDoc.create "Ready to consume a key or positional arg")
Access = None
}
SynUnionCase.create
{
Attrs = []
Attributes = []
Fields =
[
{
@@ -1383,7 +1567,9 @@ module internal ArgParserGenerator =
Type = SynType.string
}
]
Ident = Ident.create "AwaitingValue"
Name = Ident.create "AwaitingValue"
XmlDoc = Some (PreXmlDoc.create "Waiting to receive a value for the key we've already consumed")
Access = None
}
]
|> SynTypeDefnRepr.union
@@ -1400,7 +1586,7 @@ module internal ArgParserGenerator =
|> SynPat.annotateType (SynType.appPostfix "list" SynType.string)
let parsePrime =
createRecordParse parseStateIdent allRecordTypes taggedType
createRecordParse parseStateIdent flagDus allRecordTypes taggedType
|> SynBinding.basic
[ Ident.create "parse'" ]
[
@@ -1498,12 +1684,12 @@ module internal ArgParserGenerator =
(([], [], []), types)
||> List.fold (fun
(unions, records, others)
(SynTypeDefn.SynTypeDefn (_, repr, _, _, _, _) as ty) ->
(SynTypeDefn.SynTypeDefn (sci, repr, smd, _, _, _) as ty) ->
match repr with
| SynTypeDefnRepr.Simple (SynTypeDefnSimpleRepr.Union _, _) ->
ty :: unions, records, others
| SynTypeDefnRepr.Simple (SynTypeDefnSimpleRepr.Record _, _) ->
unions, ty :: records, others
| SynTypeDefnRepr.Simple (SynTypeDefnSimpleRepr.Union (access, cases, _), _) ->
UnionType.OfUnion sci smd access cases :: unions, records, others
| SynTypeDefnRepr.Simple (SynTypeDefnSimpleRepr.Record (access, fields, _), _) ->
unions, RecordType.OfRecord sci smd access fields :: records, others
| _ -> unions, records, ty :: others
)