diff --git a/pkg/flag/base64.go b/pkg/flag/base64.go new file mode 100644 index 00000000..f60267cd --- /dev/null +++ b/pkg/flag/base64.go @@ -0,0 +1,86 @@ +package flag + +import ( + "encoding/base64" + "fmt" + "strings" +) + +// Base64 implements flag.Value, and is used to populate []byte values from baes64 encoded strings. +type Base64 struct { + val []byte + len int +} + +// NewBase64 returns a Base64 which accepts values which decode to len byte strings. +func NewBase64(len int) *Base64 { + return &Base64{ + len: len, + } +} + +func (f *Base64) String() string { + return base64.StdEncoding.EncodeToString(f.val) +} + +// Set will set the []byte value of the Base64 to the base64 decoded values of the string, returning an error if it cannot be decoded or is of the wrong length. +func (f *Base64) Set(s string) error { + b, err := base64.StdEncoding.DecodeString(s) + if err != nil { + return err + } + + if len(b) != f.len { + return fmt.Errorf("expected %d-byte secret", f.len) + } + + f.val = b + return nil +} + +// Bytes returns the set []byte value. +// If no value has been set, a nil []byte is returned. +func (f *Base64) Bytes() []byte { + return f.val +} + +// NewBase64List returns a Base64List which accepts a comma-separated list of strings which must decode to len byte strings. +func NewBase64List(len int) *Base64List { + return &Base64List{ + len: len, + } +} + +// Base64List implements flag.Value and is used to populate [][]byte values from a comma-separated list of base64 encoded strings. +type Base64List struct { + val [][]byte + len int +} + +// Set will set the [][]byte value of the Base64List to the base64 decoded values of the comma-separated strings, returning an error on the first error it encounters. +func (f *Base64List) Set(ss string) error { + if ss == "" { + return nil + } + for i, s := range strings.Split(ss, ",") { + b64 := NewBase64(f.len) + err := b64.Set(s) + if err != nil { + return fmt.Errorf("error decoding string %d: %q", i, err) + } + f.val = append(f.val, b64.Bytes()) + } + return nil +} + +func (f *Base64List) String() string { + ss := []string{} + for _, b := range f.val { + ss = append(ss, base64.StdEncoding.EncodeToString(b)) + } + return strings.Join(ss, ",") +} + +func (f *Base64List) BytesSlice() [][]byte { + return f.val +} diff --git a/pkg/flag/base64_test.go b/pkg/flag/base64_test.go new file mode 100644 index 00000000..f2aa25e8 --- /dev/null +++ b/pkg/flag/base64_test.go @@ -0,0 +1,134 @@ +package flag + +import ( + "encoding/base64" + "strings" + "testing" + + "github.com/kylelemons/godebug/pretty" +) + +func TestBase64(t *testing.T) { + toB64 := func(b []byte) string { + return base64.StdEncoding.EncodeToString(b) + } + + tests := []struct { + s string + l int + b []byte + wantError bool + }{ + { + s: toB64([]byte("123456")), + l: 6, + b: []byte("123456"), + }, + { + s: toB64([]byte("123456")), + l: 5, + wantError: true, + }, + { + s: "not base64", + l: 5, + wantError: true, + }, + } + + for i, tt := range tests { + b64 := NewBase64(tt.l) + err := b64.Set(tt.s) + if tt.wantError { + if err == nil { + t.Errorf("case %d: want err, got nil", i) + } + continue + } + + if err != nil { + t.Errorf("case %d: unexpected error %q", i, err) + } + + if diff := pretty.Compare(tt.b, b64.Bytes()); diff != "" { + t.Errorf("case %d: Compare(want, got) = %v", i, + diff) + } + + if b64.String() != tt.s { + t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s) + } + } +} + +func TestBase64List(t *testing.T) { + // toCSB64 == to comma separated base 64 + toCSB64 := func(bb ...[]byte) string { + ss := []string{} + for _, b := range bb { + ss = append(ss, base64.StdEncoding.EncodeToString(b)) + } + return strings.Join(ss, ",") + } + + b123 := []byte("123456") + b567 := []byte("567890") + bShort := []byte("1234") + + tests := []struct { + s string + l int + bb [][]byte + wantError bool + }{ + { + s: toCSB64(b123, b567), + l: 6, + bb: [][]byte{b123, b567}, + }, + { + s: toCSB64(b123), + l: 6, + bb: [][]byte{b123}, + }, + { + s: "", + l: 6, + bb: [][]byte{}, + }, + { + s: toCSB64(b123, bShort), + l: 6, + wantError: true, + }, + { + s: toCSB64(bShort, b123), + l: 6, + wantError: true, + }, + } + + for i, tt := range tests { + b64 := NewBase64List(tt.l) + err := b64.Set(tt.s) + if tt.wantError { + if err == nil { + t.Errorf("case %d: want err, got nil", i) + } + continue + } + + if err != nil { + t.Errorf("case %d: unexpected error %q", i, err) + } + + if diff := pretty.Compare(tt.bb, b64.BytesSlice()); diff != "" { + t.Errorf("case %d: Compare(want, got) = %v", i, + diff) + } + + if b64.String() != tt.s { + t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s) + } + } +}