pkg/flag: add new Base64, Base64List flag.Values
Allows setting of []byte's with base64 encoded strings and [][]bytes with comma-separated base64 encoded strings.
This commit is contained in:
parent
c8feb5c33d
commit
0feb1dd719
2 changed files with 220 additions and 0 deletions
86
pkg/flag/base64.go
Normal file
86
pkg/flag/base64.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package flag
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Base64 implements flag.Value, and is used to populate []byte values from baes64 encoded strings.
|
||||
type Base64 struct {
|
||||
val []byte
|
||||
len int
|
||||
}
|
||||
|
||||
// NewBase64 returns a Base64 which accepts values which decode to len byte strings.
|
||||
func NewBase64(len int) *Base64 {
|
||||
return &Base64{
|
||||
len: len,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Base64) String() string {
|
||||
return base64.StdEncoding.EncodeToString(f.val)
|
||||
}
|
||||
|
||||
// Set will set the []byte value of the Base64 to the base64 decoded values of the string, returning an error if it cannot be decoded or is of the wrong length.
|
||||
func (f *Base64) Set(s string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(s)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(b) != f.len {
|
||||
return fmt.Errorf("expected %d-byte secret", f.len)
|
||||
}
|
||||
|
||||
f.val = b
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bytes returns the set []byte value.
|
||||
// If no value has been set, a nil []byte is returned.
|
||||
func (f *Base64) Bytes() []byte {
|
||||
return f.val
|
||||
}
|
||||
|
||||
// NewBase64List returns a Base64List which accepts a comma-separated list of strings which must decode to len byte strings.
|
||||
func NewBase64List(len int) *Base64List {
|
||||
return &Base64List{
|
||||
len: len,
|
||||
}
|
||||
}
|
||||
|
||||
// Base64List implements flag.Value and is used to populate [][]byte values from a comma-separated list of base64 encoded strings.
|
||||
type Base64List struct {
|
||||
val [][]byte
|
||||
len int
|
||||
}
|
||||
|
||||
// Set will set the [][]byte value of the Base64List to the base64 decoded values of the comma-separated strings, returning an error on the first error it encounters.
|
||||
func (f *Base64List) Set(ss string) error {
|
||||
if ss == "" {
|
||||
return nil
|
||||
}
|
||||
for i, s := range strings.Split(ss, ",") {
|
||||
b64 := NewBase64(f.len)
|
||||
err := b64.Set(s)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error decoding string %d: %q", i, err)
|
||||
}
|
||||
f.val = append(f.val, b64.Bytes())
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *Base64List) String() string {
|
||||
ss := []string{}
|
||||
for _, b := range f.val {
|
||||
ss = append(ss, base64.StdEncoding.EncodeToString(b))
|
||||
}
|
||||
return strings.Join(ss, ",")
|
||||
}
|
||||
|
||||
func (f *Base64List) BytesSlice() [][]byte {
|
||||
return f.val
|
||||
}
|
134
pkg/flag/base64_test.go
Normal file
134
pkg/flag/base64_test.go
Normal file
|
@ -0,0 +1,134 @@
|
|||
package flag
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/kylelemons/godebug/pretty"
|
||||
)
|
||||
|
||||
func TestBase64(t *testing.T) {
|
||||
toB64 := func(b []byte) string {
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
s string
|
||||
l int
|
||||
b []byte
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
s: toB64([]byte("123456")),
|
||||
l: 6,
|
||||
b: []byte("123456"),
|
||||
},
|
||||
{
|
||||
s: toB64([]byte("123456")),
|
||||
l: 5,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
s: "not base64",
|
||||
l: 5,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
b64 := NewBase64(tt.l)
|
||||
err := b64.Set(tt.s)
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want err, got nil", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error %q", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.b, b64.Bytes()); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got) = %v", i,
|
||||
diff)
|
||||
}
|
||||
|
||||
if b64.String() != tt.s {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBase64List(t *testing.T) {
|
||||
// toCSB64 == to comma separated base 64
|
||||
toCSB64 := func(bb ...[]byte) string {
|
||||
ss := []string{}
|
||||
for _, b := range bb {
|
||||
ss = append(ss, base64.StdEncoding.EncodeToString(b))
|
||||
}
|
||||
return strings.Join(ss, ",")
|
||||
}
|
||||
|
||||
b123 := []byte("123456")
|
||||
b567 := []byte("567890")
|
||||
bShort := []byte("1234")
|
||||
|
||||
tests := []struct {
|
||||
s string
|
||||
l int
|
||||
bb [][]byte
|
||||
wantError bool
|
||||
}{
|
||||
{
|
||||
s: toCSB64(b123, b567),
|
||||
l: 6,
|
||||
bb: [][]byte{b123, b567},
|
||||
},
|
||||
{
|
||||
s: toCSB64(b123),
|
||||
l: 6,
|
||||
bb: [][]byte{b123},
|
||||
},
|
||||
{
|
||||
s: "",
|
||||
l: 6,
|
||||
bb: [][]byte{},
|
||||
},
|
||||
{
|
||||
s: toCSB64(b123, bShort),
|
||||
l: 6,
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
s: toCSB64(bShort, b123),
|
||||
l: 6,
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
b64 := NewBase64List(tt.l)
|
||||
err := b64.Set(tt.s)
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("case %d: want err, got nil", i)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("case %d: unexpected error %q", i, err)
|
||||
}
|
||||
|
||||
if diff := pretty.Compare(tt.bb, b64.BytesSlice()); diff != "" {
|
||||
t.Errorf("case %d: Compare(want, got) = %v", i,
|
||||
diff)
|
||||
}
|
||||
|
||||
if b64.String() != tt.s {
|
||||
t.Errorf("case %d: want=%q, got=%q", i, b64.String(), tt.s)
|
||||
}
|
||||
}
|
||||
}
|
Reference in a new issue