diff --git a/args.zig b/args.zig index 41c3461..5e8cb00 100644 --- a/args.zig +++ b/args.zig @@ -85,10 +85,10 @@ pub fn parseWithVerb(comptime Generic: type, comptime Verb: type, args_iterator: fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterator: anytype, allocator: std.mem.Allocator, error_handling: ErrorHandling) !ParseArgsResult(Generic, MaybeVerb) { var result = ParseArgsResult(Generic, MaybeVerb){ .arena = std.heap.ArenaAllocator.init(allocator), - .options = Generic{}, .verb = if (MaybeVerb != null) null else {}, // no verb by default .positionals = undefined, .executable_name = null, + .options = undefined, }; errdefer result.arena.deinit(); var result_arena_allocator = result.arena.allocator(); @@ -98,7 +98,22 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato var last_error: ?anyerror = null; - while (args_iterator.next()) |item| { + // Create map for required arguments + var required_map = std.StringHashMap(bool).init(allocator); + defer required_map.deinit(); + + // Add the generic arguments + // and init defaults + inline for (std.meta.fields(Generic)) |field| { + if (field.default_value) |default_value_ptr| { + const default_value = @ptrCast(*const field.field_type, default_value_ptr).*; + @field(result.options, field.name) = default_value; + } else { + try required_map.put(field.name, true); + } + } + + args_loop: while (args_iterator.next()) |item| { if (std.mem.startsWith(u8, item, "--")) { if (std.mem.eql(u8, item, "--")) { // double hyphen is considered 'everything from here now is positional' @@ -122,34 +137,33 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato .value = null, }; - var found = false; inline for (std.meta.fields(Generic)) |fld| { if (std.mem.eql(u8, pair.name, fld.name)) { try parseOption(Generic, result_arena_allocator, &result.options, args_iterator, error_handling, &last_error, fld.name, pair.value); - found = true; + _ = required_map.remove(fld.name); + continue :args_loop; } } if (MaybeVerb) |Verb| { if (result.verb) |*verb| { - if (!found) { - const Tag = std.meta.Tag(Verb); - inline for (std.meta.fields(Verb)) |verb_info| { - if (verb.* == @field(Tag, verb_info.name)) { - inline for (std.meta.fields(verb_info.field_type)) |fld| { - if (std.mem.eql(u8, pair.name, fld.name)) { - try parseOption( - verb_info.field_type, - result_arena_allocator, - &@field(verb.*, verb_info.name), - args_iterator, - error_handling, - &last_error, - fld.name, - pair.value, - ); - found = true; - } + const Tag = std.meta.Tag(Verb); + inline for (std.meta.fields(Verb)) |verb_info| { + if (verb.* == @field(Tag, verb_info.name)) { + inline for (std.meta.fields(verb_info.field_type)) |fld| { + if (std.mem.eql(u8, pair.name, fld.name)) { + try parseOption( + verb_info.field_type, + result_arena_allocator, + &@field(verb.*, verb_info.name), + args_iterator, + error_handling, + &last_error, + fld.name, + pair.value, + ); + _ = required_map.remove(fld.name); + continue :args_loop; } } } @@ -157,13 +171,11 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato } } - if (!found) { - last_error = error.EncounteredUnknownArgument; - try error_handling.process(error.EncounteredUnknownArgument, Error{ - .option = pair.name, - .kind = .unknown, - }); - } + last_error = error.EncounteredUnknownArgument; + try error_handling.process(error.EncounteredUnknownArgument, Error{ + .option = pair.name, + .kind = .unknown, + }); } else if (std.mem.startsWith(u8, item, "-")) { if (std.mem.eql(u8, item, "-")) { // single hyphen is considered a positional argument @@ -172,7 +184,6 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato var any_shorthands = false; for (item[1..]) |char, index| { var option_name = [2]u8{ '-', char }; - var found = false; if (@hasDecl(Generic, "shorthands")) { any_shorthands = true; inline for (std.meta.fields(@TypeOf(Generic.shorthands))) |fld| { @@ -191,43 +202,43 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato }); } else { try parseOption(Generic, result_arena_allocator, &result.options, args_iterator, error_handling, &last_error, real_name, null); + _ = required_map.remove(real_name); } - found = true; + continue :args_loop; } } } if (MaybeVerb) |Verb| { if (result.verb) |*verb| { - if (!found) { - const Tag = std.meta.Tag(Verb); - inline for (std.meta.fields(Verb)) |verb_info| { - const VerbType = verb_info.field_type; - if (verb.* == @field(Tag, verb_info.name)) { - const target_value = &@field(verb.*, verb_info.name); - if (@hasDecl(VerbType, "shorthands")) { - any_shorthands = true; - inline for (std.meta.fields(@TypeOf(VerbType.shorthands))) |fld| { - if (fld.name.len != 1) - @compileError("All shorthand fields must be exactly one character long!"); - if (fld.name[0] == char) { - const real_name = @field(VerbType.shorthands, fld.name); - const real_fld_type = @TypeOf(@field(target_value.*, real_name)); - - // -2 because we stripped of the "-" at the beginning - if (requiresArg(real_fld_type) and index != item.len - 2) { - last_error = error.EncounteredUnexpectedArgument; - try error_handling.process(error.EncounteredUnexpectedArgument, Error{ - .option = &option_name, - .kind = .invalid_placement, - }); - } else { - try parseOption(VerbType, result_arena_allocator, target_value, args_iterator, error_handling, &last_error, real_name, null); - } - last_error = null; // we need to reset that error here, as it was set previously - found = true; + const Tag = std.meta.Tag(Verb); + inline for (std.meta.fields(Verb)) |verb_info| { + const VerbType = verb_info.field_type; + if (verb.* == @field(Tag, verb_info.name)) { + const target_value = &@field(verb.*, verb_info.name); + if (@hasDecl(VerbType, "shorthands")) { + any_shorthands = true; + inline for (std.meta.fields(@TypeOf(VerbType.shorthands))) |fld| { + if (fld.name.len != 1) + @compileError("All shorthand fields must be exactly one character long!"); + if (fld.name[0] == char) { + const real_name = @field(VerbType.shorthands, fld.name); + const real_fld_type = @TypeOf(@field(target_value.*, real_name)); + + // -2 because we stripped of the "-" at the beginning + if (requiresArg(real_fld_type) and index != item.len - 2) { + last_error = error.EncounteredUnexpectedArgument; + try error_handling.process(error.EncounteredUnexpectedArgument, Error{ + .option = &option_name, + .kind = .invalid_placement, + }); + } else { + try parseOption(VerbType, result_arena_allocator, target_value, args_iterator, error_handling, &last_error, real_name, null); + _ = required_map.remove(real_name); } + last_error = null; // we need to reset that error here, as it was set previously + continue :args_loop; } } } @@ -235,13 +246,11 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato } } } - if (!found) { - last_error = error.EncounteredUnknownArgument; - try error_handling.process(error.EncounteredUnknownArgument, Error{ - .option = &option_name, - .kind = .unknown, - }); - } + last_error = error.EncounteredUnknownArgument; + try error_handling.process(error.EncounteredUnknownArgument, Error{ + .option = &option_name, + .kind = .unknown, + }); } if (!any_shorthands) { try error_handling.process(error.EncounteredUnsupportedArgument, Error{ @@ -256,25 +265,50 @@ fn parseInternal(comptime Generic: type, comptime MaybeVerb: ?type, args_iterato inline for (std.meta.fields(Verb)) |fld| { if (std.mem.eql(u8, item, fld.name)) { // found active verb, default-initialize it - result.verb = @unionInit(Verb, fld.name, fld.field_type{}); + result.verb = @unionInit(Verb, fld.name, undefined); + const target_value = &@field(result.verb.?, fld.name); + + const VerbType = fld.field_type; + inline for (std.meta.fields(VerbType)) |field| { + if (field.default_value) |default_value_ptr| { + const default_value = @ptrCast(*const field.field_type, default_value_ptr).*; + @field(target_value, field.name) = default_value; + } else { + try required_map.put(field.name, true); + } + } + + continue :args_loop; } } - if (result.verb == null) { - try error_handling.process(error.EncounteredUnknownVerb, Error{ - .option = "verb", - .kind = .unsupported, - }); - } + try error_handling.process(error.EncounteredUnknownVerb, Error{ + .option = "verb", + .kind = .unsupported, + }); continue; } } + // Argument doesn't match anything, so should be a argument to the program itself try arglist.append(try result_arena_allocator.dupeZ(u8, item)); } } + if (required_map.count() > 0) { + last_error = error.MissingRequiredArgument; + + var it = required_map.keyIterator(); + while (it.next()) |option| { + try error_handling.process(error.MissingRequiredArgument, Error{ + .option = option.*, + .kind = .missing_argument, + }); + } + return error.MissingRequiredArgument; + } + if (last_error != null) return error.InvalidArguments; @@ -727,6 +761,22 @@ const TestVerb = union(enum) { }; }; +const TestVerbRequired = union(enum) { + magic: MagicOptions, + booze: BoozeOptions, + + const MagicOptions = struct { invoke: bool = false }; + const BoozeOptions = struct { + cocktail: bool = false, + longdrink: bool, + + pub const shorthands = .{ + .c = "cocktail", + .l = "longdrink", + }; + }; +}; + test "basic parsing (no verbs)" { var titerator = TestIterator.init(&[_][:0]const u8{ "--output", @@ -801,7 +851,7 @@ test "shorthand parsing (no verbs)" { test "basic parsing (with verbs)" { var titerator = TestIterator.init(&[_][:0]const u8{ - "--output", // non-verb options can come before or after verb + "--output", // non-verb options can come before or after verb "foobar", "booze", // verb "--with-offset", @@ -889,6 +939,61 @@ test "shorthand parsing (with verbs)" { try std.testing.expectEqual(false, booze.longdrink); } +test "shorthand parsing (with verbs and required)" { + var allocator = std.heap.ArenaAllocator.init(std.testing.allocator); + defer allocator.deinit(); + + var titerator_magic = TestIterator.init(&[_][:0]const u8{ + "magic", // verb + }); + + var titerator_short = TestIterator.init(&[_][:0]const u8{ + "booze", // verb + "-c", // --cocktail + "-l", + }); + var titerator_long = TestIterator.init(&[_][:0]const u8{ + "booze", // verb + "-l", // --longdring + }); + + var titerator_missing = TestIterator.init(&[_][:0]const u8{ + "booze", // verb + "-c", // --cocktail + }); + + { + var args = try parseInternal(TestGenericOptions, TestVerbRequired, &titerator_magic, allocator.allocator(), .print); + defer args.deinit(); + try std.testing.expect(?TestVerbRequired == @TypeOf(args.verb)); + try std.testing.expect(args.verb.? == .magic); + } + + { + var args = try parseInternal(TestGenericOptions, TestVerbRequired, &titerator_short, allocator.allocator(), .print); + defer args.deinit(); + try std.testing.expect(?TestVerbRequired == @TypeOf(args.verb)); + try std.testing.expect(args.verb.? == .booze); + const booze = args.verb.?.booze; + try std.testing.expectEqual(true, booze.cocktail); + try std.testing.expectEqual(true, booze.longdrink); + } + + { + var args = try parseInternal(TestGenericOptions, TestVerbRequired, &titerator_long, allocator.allocator(), .print); + defer args.deinit(); + try std.testing.expect(?TestVerbRequired == @TypeOf(args.verb)); + try std.testing.expect(args.verb.? == .booze); + const booze = args.verb.?.booze; + try std.testing.expectEqual(false, booze.cocktail); + try std.testing.expectEqual(true, booze.longdrink); + } + + { + try std.testing.expectError(error.MissingRequiredArgument, parseInternal(TestGenericOptions, TestVerbRequired, &titerator_missing, allocator.allocator(), .print)); + } +} + test "strings with sentinel" { var titerator = TestIterator.init(&[_][:0]const u8{ "--output", @@ -944,3 +1049,40 @@ test "index of raw indicator --" { try std.testing.expectEqual(args.raw_start_index, 2); try std.testing.expectEqual(args.positionals.len, 5); } + +test "required argument --" { + var titerator = TestIterator.init(&[_][:0]const u8{ + "--output", + "foobar", + }); + { + var args = try parseInternal( + struct { + output: ?[:0]const u8 = "test", + }, + null, + &titerator, + std.testing.allocator, + .print, + ); + defer args.deinit(); + + try std.testing.expectEqualStrings("foobar", args.options.output.?); + } + + { + var titerator_none = TestIterator.init(&[_][:0]const u8{}); + var args = try parseInternal( + struct { + output: ?[:0]const u8 = "test", + }, + null, + &titerator_none, + std.testing.allocator, + .print, + ); + defer args.deinit(); + + try std.testing.expectEqualStrings("test", args.options.output.?); + } +}