diff --git a/app/internal/tgc/tgc.go b/app/internal/tgc/tgc.go index 0f2e4d9..9d0bf69 100644 --- a/app/internal/tgc/tgc.go +++ b/app/internal/tgc/tgc.go @@ -50,15 +50,17 @@ func New(ctx context.Context, login bool, middlewares ...telegram.Middleware) (* ) if test := viper.GetString(consts.FlagTest); test != "" { - kvd, err = kv.NewFile(filepath.Join(os.TempDir(), test)) // persistent storage + var stg kv.Storage + stg, err = kv.New(kv.DriverFile, map[string]any{"path": filepath.Join(os.TempDir(), test)}) + if err != nil { + return nil, nil, errors.Wrap(err, "create test kv") + } + kvd, err = stg.Open(test) } else { - kvd, err = kv.New(kv.Options{ - Path: consts.KVPath, - NS: viper.GetString(consts.FlagNamespace), - }) + kvd, err = kv.From(ctx).Open(viper.GetString(consts.FlagNamespace)) } if err != nil { - return nil, nil, err + return nil, nil, errors.Wrap(err, "open kv") } _clock, err := Clock() diff --git a/app/login/desktop.go b/app/login/desktop.go index 589067f..60ed0a4 100644 --- a/app/login/desktop.go +++ b/app/login/desktop.go @@ -9,6 +9,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/fatih/color" + "github.com/go-faster/errors" "github.com/gotd/td/session" tdtdesktop "github.com/gotd/td/session/tdesktop" "github.com/spf13/viper" @@ -32,12 +33,9 @@ type Options struct { func Desktop(ctx context.Context, opts *Options) error { ns := viper.GetString(consts.FlagNamespace) - kvd, err := kv.New(kv.Options{ - Path: consts.KVPath, - NS: ns, - }) + kvd, err := kv.From(ctx).Open(ns) if err != nil { - return err + return errors.Wrap(err, "open kv") } desktop, err := findDesktop(opts.Desktop) diff --git a/cmd/root.go b/cmd/root.go index 11ec1c6..4e79bb0 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -1,12 +1,15 @@ package cmd import ( + "fmt" "path/filepath" "strings" "time" + "github.com/go-faster/errors" "github.com/spf13/cobra" "github.com/spf13/viper" + "go.uber.org/multierr" "go.uber.org/zap" "github.com/iyear/tdl/pkg/consts" @@ -15,12 +18,14 @@ import ( ) func New() *cobra.Command { + driverTypeKey := "type" + cmd := &cobra.Command{ Use: "tdl", Short: "Telegram Downloader, but more than a downloader", SilenceErrors: true, SilenceUsage: true, - PersistentPreRun: func(cmd *cobra.Command, args []string) { + PersistentPreRunE: func(cmd *cobra.Command, args []string) error { // init logger debug, level := viper.GetBool(consts.FlagDebug), zap.InfoLevel if debug { @@ -34,15 +39,44 @@ func New() *cobra.Command { logger.From(cmd.Context()).Info("Namespace", zap.String("namespace", ns)) } + + // check storage flag + storageOpts := viper.GetStringMapString(consts.FlagStorage) + driver, err := kv.ParseDriver(storageOpts[driverTypeKey]) + if err != nil { + return errors.Wrap(err, "parse driver") + } + delete(storageOpts, driverTypeKey) + + opts := make(map[string]any) + for k, v := range storageOpts { + opts[k] = v + } + storage, err := kv.New(driver, opts) + if err != nil { + return errors.Wrap(err, "create kv storage") + } + + cmd.SetContext(kv.With(cmd.Context(), storage)) + return nil }, PersistentPostRunE: func(cmd *cobra.Command, args []string) error { - return logger.From(cmd.Context()).Sync() + return multierr.Combine( + kv.From(cmd.Context()).Close(), + logger.From(cmd.Context()).Sync(), + ) }, } cmd.AddCommand(NewVersion(), NewLogin(), NewDownload(), NewForward(), NewChat(), NewUpload(), NewBackup(), NewRecover(), NewGen()) + cmd.PersistentFlags().StringToString(consts.FlagStorage, map[string]string{ + driverTypeKey: kv.DriverLegacy.String(), + "path": consts.KVPath, + }, fmt.Sprintf("storage options, format: type=driver,key1=value1,key2=value2. Available drivers: [%s]", + strings.Join(kv.DriverNames(), ","))) + cmd.PersistentFlags().String(consts.FlagProxy, "", "proxy address, only socks5 is supported, format: protocol://username:password@host:port") cmd.PersistentFlags().StringP(consts.FlagNamespace, "n", "", "namespace for Telegram session") cmd.PersistentFlags().Bool(consts.FlagDebug, false, "enable debug mode") @@ -59,7 +93,8 @@ func New() *cobra.Command { // completion _ = cmd.RegisterFlagCompletionFunc(consts.FlagNamespace, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { - ns, err := kv.Namespaces(consts.KVPath) + engine := kv.From(cmd.Context()) + ns, err := engine.Namespaces() if err != nil { return nil, cobra.ShellCompDirectiveNoFileComp } diff --git a/pkg/consts/flag.go b/pkg/consts/flag.go index 6c4e5ec..7a4505d 100644 --- a/pkg/consts/flag.go +++ b/pkg/consts/flag.go @@ -1,6 +1,7 @@ package consts const ( + FlagStorage = "storage" FlagProxy = "proxy" FlagNamespace = "ns" FlagDebug = "debug" diff --git a/pkg/consts/path.go b/pkg/consts/path.go index efb7233..aac51e0 100644 --- a/pkg/consts/path.go +++ b/pkg/consts/path.go @@ -5,6 +5,5 @@ var ( DataDir string KVPath string LogPath string - DocsPath = "docs" UploadThumbExt = ".thumb" ) diff --git a/pkg/kv/bolt.go b/pkg/kv/bolt.go index b58c070..d248cf8 100644 --- a/pkg/kv/bolt.go +++ b/pkg/kv/bolt.go @@ -1,39 +1,172 @@ package kv import ( + "os" + "path/filepath" + "sync" + + "github.com/go-faster/errors" + "github.com/mitchellh/mapstructure" "go.etcd.io/bbolt" + "go.uber.org/multierr" + + "github.com/iyear/tdl/pkg/validator" ) -type Bolt struct { - ns []byte - db *bbolt.DB +func init() { + register(DriverBolt, func(m map[string]any) (Storage, error) { return newBolt(m) }) } -func (b *Bolt) Get(key string) ([]byte, error) { - var val []byte +type bolt struct { + path string + dbs map[string]*bbolt.DB + mu *sync.Mutex +} - if err := b.db.View(func(tx *bbolt.Tx) error { - val = tx.Bucket(b.ns).Get([]byte(key)) +func newBolt(opts map[string]any) (*bolt, error) { + type options struct { + Path string `validate:"required" mapstructure:"path"` + } + + var o options + if err := mapstructure.WeakDecode(opts, &o); err != nil { + return nil, errors.Wrap(err, "decode options") + } + + if err := validator.Struct(&o); err != nil { + return nil, errors.Wrap(err, "validate options") + } + + if err := os.MkdirAll(o.Path, 0o755); err != nil { + return nil, errors.Wrap(err, "create dir") + } + + return &bolt{ + path: o.Path, + dbs: make(map[string]*bbolt.DB), + mu: &sync.Mutex{}, + }, nil +} + +func (b *bolt) Name() string { + return DriverBolt.String() +} + +func (b *bolt) MigrateTo() (Meta, error) { + meta := make(Meta) + + if err := b.walk(func(path string) (rerr error) { + ns := filepath.Base(path) + meta[ns] = make(map[string][]byte) + + db, err := b.open(ns) + if err != nil { + return errors.Wrap(err, "open") + } + + return db.db.View(func(tx *bbolt.Tx) error { + return tx.Bucket(db.ns).ForEach(func(k, v []byte) error { + meta[ns][string(k)] = v + return nil + }) + }) + }); err != nil { + return nil, errors.Wrap(err, "walk") + } + + return meta, nil +} + +func (b *bolt) MigrateFrom(meta Meta) error { + for ns, pairs := range meta { + db, err := b.open(ns) + if err != nil { + return errors.Wrap(err, "open") + } + + if err = db.db.Update(func(tx *bbolt.Tx) error { + bk, err := tx.CreateBucketIfNotExists(db.ns) + if err != nil { + return errors.Wrap(err, "create bucket") + } + for key, value := range pairs { + if err = bk.Put([]byte(key), value); err != nil { + return errors.Wrap(err, "put") + } + } + return nil + }); err != nil { + return errors.Wrap(err, "update") + } + } + + return nil +} + +func (b *bolt) Namespaces() ([]string, error) { + namespaces := make([]string, 0) + if err := b.walk(func(path string) error { + namespaces = append(namespaces, filepath.Base(path)) return nil }); err != nil { - return nil, err + return nil, errors.Wrap(err, "walk") } - if val == nil { - return nil, ErrNotFound - } - return val, nil + return namespaces, nil } -func (b *Bolt) Set(key string, val []byte) error { - return b.db.Update(func(tx *bbolt.Tx) error { - return tx.Bucket(b.ns).Put([]byte(key), val) +func (b *bolt) walk(fn func(path string) error) error { + return filepath.Walk(b.path, func(path string, info os.FileInfo, err error) error { + if err != nil { + return errors.Wrap(err, "walk") + } + if info.IsDir() { + return nil + } + + return fn(path) }) } -// Delete removes a key from the bucket. If the key does not exist then nothing is done and a nil error is returned -func (b *Bolt) Delete(key string) error { - return b.db.Update(func(tx *bbolt.Tx) error { - return tx.Bucket(b.ns).Delete([]byte(key)) - }) +func (b *bolt) Open(ns string) (KV, error) { + return b.open(ns) +} + +func (b *bolt) open(ns string) (*legacyKV, error) { + if ns == "" { + return nil, errors.New("namespace is required") + } + b.mu.Lock() + defer b.mu.Unlock() + + if db, ok := b.dbs[ns]; ok { + return &legacyKV{db: db, ns: []byte(ns)}, nil + } + + db, err := bbolt.Open(filepath.Join(b.path, ns), os.ModePerm, boltOptions) + if err != nil { + return nil, errors.Wrap(err, "open db") + } + if err = db.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists([]byte(ns)) + return err + }); err != nil { + return nil, errors.Wrap(err, "create bucket") + } + + b.dbs[ns] = db + + return &legacyKV{db: db, ns: []byte(ns)}, nil +} + +func (b *bolt) Close() error { + b.mu.Lock() + defer b.mu.Unlock() + + var err error + for _, db := range b.dbs { + err = multierr.Append(err, db.Close()) + } + + return err } diff --git a/pkg/kv/file.go b/pkg/kv/file.go index c36dfb9..4716a7d 100644 --- a/pkg/kv/file.go +++ b/pkg/kv/file.go @@ -4,80 +4,115 @@ import ( "encoding/json" "os" "sync" + + "github.com/go-faster/errors" + "github.com/iyear/tdl/pkg/validator" + "github.com/mitchellh/mapstructure" ) -type File struct { +func init() { + register(DriverFile, newFile) +} + +type file struct { path string mu sync.Mutex } -func NewFile(path string) (*File, error) { - _, err := os.Stat(path) +func newFile(opts map[string]any) (Storage, error) { + type options struct { + Path string `validate:"required" mapstructure:"path"` + } + + var o options + if err := mapstructure.WeakDecode(opts, &o); err != nil { + return nil, errors.Wrap(err, "decode options") + } + + if err := validator.Struct(&o); err != nil { + return nil, errors.Wrap(err, "validate options") + } + + _, err := os.Stat(o.Path) if err == nil { - return &File{path: path}, nil + return &file{path: o.Path}, nil } if !os.IsNotExist(err) { - return nil, err + return nil, errors.Wrap(err, "stat file") } - if err = os.WriteFile(path, []byte("{}"), 0o644); err != nil { - return nil, err + if err = os.WriteFile(o.Path, []byte("{}"), 0o644); err != nil { + return nil, errors.Wrap(err, "create file") } - return &File{path: path}, nil + return &file{path: o.Path}, nil } -func (f *File) Get(key string) ([]byte, error) { +func (f *file) Name() string { + return DriverFile.String() +} + +func (f *file) MigrateTo() (Meta, error) { + meta, err := f.read() + if err != nil { + return nil, errors.Wrap(err, "read") + } + return meta, nil +} + +func (f *file) MigrateFrom(meta Meta) error { + return f.write(meta) +} + +func (f *file) Namespaces() ([]string, error) { + pairs, err := f.read() + if err != nil { + return nil, errors.Wrap(err, "read") + } + + namespaces := make([]string, 0, len(pairs)) + for ns := range pairs { + namespaces = append(namespaces, ns) + } + + return namespaces, nil +} + +func (f *file) Open(ns string) (KV, error) { + if ns == "" { + return nil, errors.New("namespace is required") + } + + read, err := f.read() + if err != nil { + return nil, errors.Wrap(err, "read") + } + + if _, ok := read[ns]; !ok { + read[ns] = make(map[string][]byte) + if err = f.write(read); err != nil { + return nil, errors.Wrap(err, "write") + } + } + + return &fileKV{f: f, ns: ns}, nil +} + +func (f *file) Close() error { + return nil +} + +func (f *file) read() (map[string]map[string][]byte, error) { f.mu.Lock() defer f.mu.Unlock() - m, err := f.read() - if err != nil { - return nil, err - } - - if val, ok := m[key]; ok { - return val, nil - } - return nil, ErrNotFound -} - -func (f *File) Set(key string, value []byte) error { - f.mu.Lock() - defer f.mu.Unlock() - - m, err := f.read() - if err != nil { - return err - } - - m[key] = value - - return f.write(m) -} - -func (f *File) Delete(key string) error { - f.mu.Lock() - defer f.mu.Unlock() - - m, err := f.read() - if err != nil { - return err - } - - delete(m, key) - - return f.write(m) -} - -func (f *File) read() (map[string][]byte, error) { bytes, err := os.ReadFile(f.path) if err != nil { return nil, err } - m := make(map[string][]byte) + m := make(map[string]map[string][]byte) if err = json.Unmarshal(bytes, &m); err != nil { return nil, err } @@ -85,7 +120,10 @@ func (f *File) read() (map[string][]byte, error) { return m, nil } -func (f *File) write(m map[string][]byte) error { +func (f *file) write(m map[string]map[string][]byte) error { + f.mu.Lock() + defer f.mu.Unlock() + bytes, err := json.Marshal(m) if err != nil { return err @@ -93,3 +131,42 @@ func (f *File) write(m map[string][]byte) error { return os.WriteFile(f.path, bytes, 0o644) } + +type fileKV struct { + f *file + ns string +} + +func (f *fileKV) Get(key string) ([]byte, error) { + m, err := f.f.read() + if err != nil { + return nil, errors.Wrap(err, "read") + } + + if v, ok := m[f.ns][key]; ok { + return v, nil + } + return nil, ErrNotFound +} + +func (f *fileKV) Set(key string, value []byte) error { + m, err := f.f.read() + if err != nil { + return errors.Wrap(err, "read") + } + + m[f.ns][key] = value + + return f.f.write(m) +} + +func (f *fileKV) Delete(key string) error { + m, err := f.f.read() + if err != nil { + return errors.Wrap(err, "read") + } + + delete(m[f.ns], key) + + return f.f.write(m) +} diff --git a/pkg/kv/kv.go b/pkg/kv/kv.go index 529da2b..0509ebc 100644 --- a/pkg/kv/kv.go +++ b/pkg/kv/kv.go @@ -1,74 +1,57 @@ package kv import ( - "errors" - "os" - "time" + "context" + "io" - "go.etcd.io/bbolt" - "go.uber.org/multierr" - - "github.com/iyear/tdl/pkg/validator" + "github.com/go-faster/errors" ) +//go:generate go-enum --values --names --flag --nocase + +// Driver +// ENUM(legacy, bolt, file) +type Driver string + var ErrNotFound = errors.New("key not found") +type Meta map[string]map[string][]byte // namespace, key, value + +type Storage interface { + Name() string + MigrateTo() (Meta, error) + MigrateFrom(Meta) error + Namespaces() ([]string, error) + Open(ns string) (KV, error) + io.Closer +} + type KV interface { Get(key string) ([]byte, error) Set(key string, value []byte) error Delete(key string) error } -type Options struct { - NS string `validate:"required"` - Path string `validate:"required"` +var drivers = map[Driver]func(map[string]any) (Storage, error){} + +func register(name Driver, fn func(map[string]any) (Storage, error)) { + drivers[name] = fn } -func New(opts Options) (KV, error) { - if err := validator.Struct(&opts); err != nil { - return nil, err +func New(driver Driver, opts map[string]any) (Storage, error) { + if fn, ok := drivers[driver]; ok { + return fn(opts) } - db, err := bbolt.Open(opts.Path, os.ModePerm, &bbolt.Options{ - Timeout: time.Second, - NoGrowSync: false, - FreelistType: bbolt.FreelistArrayType, - }) - if err != nil { - return nil, err - } - - if err = db.Update(func(tx *bbolt.Tx) error { - _, err := tx.CreateBucketIfNotExists([]byte(opts.NS)) - return err - }); err != nil { - return nil, err - } - - return &Bolt{db: db, ns: []byte(opts.NS)}, nil + return nil, errors.Errorf("unsupported driver: %s", driver) } -// Namespaces returns all namespaces in the database -func Namespaces(path string) (_ []string, rerr error) { - db, err := bbolt.Open(path, os.ModePerm, &bbolt.Options{ - Timeout: time.Second, - ReadOnly: true, - }) - if err != nil { - return nil, err - } - defer multierr.AppendInvoke(&rerr, multierr.Close(db)) +type ctxKey struct{} - namespaces := make([]string, 0) - err = db.View(func(tx *bbolt.Tx) error { - return tx.ForEach(func(name []byte, _ *bbolt.Bucket) error { - namespaces = append(namespaces, string(name)) - return nil - }) - }) - - if err != nil { - return nil, err - } - return namespaces, nil +func With(ctx context.Context, kv Storage) context.Context { + return context.WithValue(ctx, ctxKey{}, kv) +} + +func From(ctx context.Context) Storage { + return ctx.Value(ctxKey{}).(Storage) } diff --git a/pkg/kv/kv_enum.go b/pkg/kv/kv_enum.go new file mode 100644 index 0000000..904b42d --- /dev/null +++ b/pkg/kv/kv_enum.go @@ -0,0 +1,92 @@ +// Code generated by go-enum DO NOT EDIT. +// Version: 0.5.8 +// Revision: 3d844c8ecc59661ed7aa17bfd65727bc06a60ad8 +// Build Date: 2023-09-18T14:55:21Z +// Built By: goreleaser + +package kv + +import ( + "fmt" + "strings" +) + +const ( + // DriverLegacy is a Driver of type legacy. + DriverLegacy Driver = "legacy" + // DriverBolt is a Driver of type bolt. + DriverBolt Driver = "bolt" + // DriverFile is a Driver of type file. + DriverFile Driver = "file" +) + +var ErrInvalidDriver = fmt.Errorf("not a valid Driver, try [%s]", strings.Join(_DriverNames, ", ")) + +var _DriverNames = []string{ + string(DriverLegacy), + string(DriverBolt), + string(DriverFile), +} + +// DriverNames returns a list of possible string values of Driver. +func DriverNames() []string { + tmp := make([]string, len(_DriverNames)) + copy(tmp, _DriverNames) + return tmp +} + +// DriverValues returns a list of the values for Driver +func DriverValues() []Driver { + return []Driver{ + DriverLegacy, + DriverBolt, + DriverFile, + } +} + +// String implements the Stringer interface. +func (x Driver) String() string { + return string(x) +} + +// IsValid provides a quick way to determine if the typed value is +// part of the allowed enumerated values +func (x Driver) IsValid() bool { + _, err := ParseDriver(string(x)) + return err == nil +} + +var _DriverValue = map[string]Driver{ + "legacy": DriverLegacy, + "bolt": DriverBolt, + "file": DriverFile, +} + +// ParseDriver attempts to convert a string to a Driver. +func ParseDriver(name string) (Driver, error) { + if x, ok := _DriverValue[name]; ok { + return x, nil + } + // Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to. + if x, ok := _DriverValue[strings.ToLower(name)]; ok { + return x, nil + } + return Driver(""), fmt.Errorf("%s is %w", name, ErrInvalidDriver) +} + +// Set implements the Golang flag.Value interface func. +func (x *Driver) Set(val string) error { + v, err := ParseDriver(val) + *x = v + return err +} + +// Get implements the Golang flag.Getter interface func. +func (x *Driver) Get() interface{} { + return *x +} + +// Type implements the github.com/spf13/pFlag Value interface. +func (x *Driver) Type() string { + return "Driver" +} diff --git a/pkg/kv/kv_test.go b/pkg/kv/kv_test.go new file mode 100644 index 0000000..be2bd0a --- /dev/null +++ b/pkg/kv/kv_test.go @@ -0,0 +1,155 @@ +package kv + +import ( + "fmt" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "path/filepath" + "testing" + +) + +func forEachStorage(t *testing.T, fn func(e Storage, t *testing.T)) { + storages := map[Driver]map[string]any{ + DriverBolt: {"path": t.TempDir()}, + DriverLegacy: {"path": filepath.Join(t.TempDir(), "test.db")}, + DriverFile: {"path": filepath.Join(t.TempDir(), "test.json")}, + } + + for driver, opts := range storages { + storage, err := New(driver, opts) + require.NoError(t, err) + + t.Run(driver.String(), func(t *testing.T) { + fn(storage, t) + }) + assert.NoError(t, storage.Close()) + } +} + +func TestNew(t *testing.T) { + tests := map[Driver][]struct { + name string + opts map[string]any + wantErr bool + }{ + DriverBolt: { + {name: "valid", opts: map[string]any{"path": t.TempDir()}, wantErr: false}, + {name: "invalid", opts: map[string]any{"path": ""}, wantErr: true}, + }, + DriverLegacy: { + {name: "valid", opts: map[string]any{"path": filepath.Join(t.TempDir(), "test.db")}, wantErr: false}, + {name: "invalid", opts: map[string]any{"path": ""}, wantErr: true}, + }, + DriverFile: { + {name: "valid", opts: map[string]any{"path": filepath.Join(t.TempDir(), "test.json")}, wantErr: false}, + }, + Driver("unknown"): { + {name: "unknown", opts: map[string]any{"path": ""}, wantErr: true}, + }, + } + + for driver, tests := range tests { + for _, tt := range tests { + t.Run(fmt.Sprintf("%v/%s", driver, tt.name), func(t *testing.T) { + kv, err := New(driver, tt.opts) + if tt.wantErr { + assert.Error(t, err) + assert.Nil(t, kv) + } else { + assert.NoError(t, err) + assert.NotNil(t, kv) + assert.NoError(t, kv.Close()) + } + }) + } + } +} + +func TestStorage_Open(t *testing.T) { + forEachStorage(t, func(e Storage, t *testing.T) { + for _, ns := range []string{"foo", "bar", "foo"} { + kv, err := e.Open(ns) + require.NoError(t, err) + require.NotNil(t, kv) + } + }) +} + +func TestStorage_Namespaces(t *testing.T) { + namespaces := []string{"foo", "bar", "baz"} + + forEachStorage(t, func(e Storage, t *testing.T) { + for _, ns := range namespaces { + kv, err := e.Open(ns) + require.NoError(t, err) + require.NotNil(t, kv) + } + + ns, err := e.Namespaces() + require.NoError(t, err) + require.ElementsMatch(t, namespaces, ns) + }) +} + +func TestStorage_MigrateTo(t *testing.T) { + meta := Meta{ + "foo": { + "1": []byte("2"), + "3": []byte("4"), + "5": []byte("6"), + }, + "bar": { + "7": []byte("8"), + "9": []byte("10"), + "11": []byte("12"), + }, + } + + forEachStorage(t, func(e Storage, t *testing.T) { + for ns, pairs := range meta { + kv, err := e.Open(ns) + require.NoError(t, err) + require.NotNil(t, kv) + + for key, value := range pairs { + require.NoError(t, kv.Set(key, value)) + } + } + + m, err := e.MigrateTo() + assert.NoError(t, err) + assert.Equal(t, meta, m) + }) +} + +func TestStorage_MigrateFrom(t *testing.T) { + meta := Meta{ + "foo": { + "1": []byte("2"), + "3": []byte("4"), + "5": []byte("6"), + }, + "bar": { + "7": []byte("8"), + "9": []byte("10"), + "11": []byte("12"), + }, + } + + forEachStorage(t, func(e Storage, t *testing.T) { + require.NoError(t, e.MigrateFrom(meta)) + + for ns, pairs := range meta { + kv, err := e.Open(ns) + require.NoError(t, err) + require.NotNil(t, kv) + + for key, value := range pairs { + v, err := kv.Get(key) + require.NoError(t, err) + require.Equal(t, value, v) + } + } + }) +} diff --git a/pkg/kv/legacy.go b/pkg/kv/legacy.go new file mode 100644 index 0000000..537071d --- /dev/null +++ b/pkg/kv/legacy.go @@ -0,0 +1,158 @@ +package kv + +import ( + "os" + "time" + + "github.com/go-faster/errors" + "github.com/mitchellh/mapstructure" + "go.etcd.io/bbolt" + + "github.com/iyear/tdl/pkg/validator" +) + +var boltOptions = &bbolt.Options{ + Timeout: time.Second, + NoGrowSync: false, + FreelistType: bbolt.FreelistArrayType, +} + +func init() { + register(DriverLegacy, func(m map[string]any) (Storage, error) { + return newLegacy(m) + }) +} + +func newLegacy(opts map[string]any) (*legacy, error) { + type options struct { + Path string `validate:"required" mapstructure:"path"` + } + + var o options + if err := mapstructure.WeakDecode(opts, &o); err != nil { + return nil, errors.Wrap(err, "decode options") + } + + if err := validator.Struct(&o); err != nil { + return nil, errors.Wrap(err, "validate options") + } + + db, err := bbolt.Open(o.Path, os.ModePerm, boltOptions) + if err != nil { + return nil, errors.Wrap(err, "open db") + } + + return &legacy{bolt: db}, nil +} + +type legacy struct { + bolt *bbolt.DB +} + +func (l *legacy) Name() string { + return DriverLegacy.String() +} + +func (l *legacy) MigrateTo() (Meta, error) { + meta := make(Meta) + + if err := l.bolt.View(func(tx *bbolt.Tx) error { + return tx.ForEach(func(name []byte, b *bbolt.Bucket) error { + ns := string(name) + meta[ns] = make(map[string][]byte) + return b.ForEach(func(k, v []byte) error { + meta[ns][string(k)] = v + return nil + }) + }) + }); err != nil { + return nil, errors.Wrap(err, "iterate buckets") + } + + return meta, nil +} + +func (l *legacy) MigrateFrom(meta Meta) error { + return l.bolt.Update(func(tx *bbolt.Tx) error { + for ns, pairs := range meta { + b, err := tx.CreateBucketIfNotExists([]byte(ns)) + if err != nil { + return errors.Wrap(err, "create bucket") + } + for key, value := range pairs { + if err = b.Put([]byte(key), value); err != nil { + return errors.Wrap(err, "put") + } + } + } + return nil + }) +} + +func (l *legacy) Namespaces() ([]string, error) { + namespaces := make([]string, 0) + if err := l.bolt.View(func(tx *bbolt.Tx) error { + return tx.ForEach(func(name []byte, _ *bbolt.Bucket) error { + namespaces = append(namespaces, string(name)) + return nil + }) + }); err != nil { + return nil, errors.Wrap(err, "iterate namespaces") + } + return namespaces, nil +} + +func (l *legacy) Open(ns string) (KV, error) { + return l.open(ns) +} + +func (l *legacy) open(ns string) (*legacyKV, error) { + if ns == "" { + return nil, errors.New("namespace is required") + } + if err := l.bolt.Update(func(tx *bbolt.Tx) error { + _, err := tx.CreateBucketIfNotExists([]byte(ns)) + return err + }); err != nil { + return nil, errors.Wrap(err, "create bucket") + } + + return &legacyKV{db: l.bolt, ns: []byte(ns)}, nil +} + +func (l *legacy) Close() error { + return l.bolt.Close() +} + +type legacyKV struct { + db *bbolt.DB + ns []byte +} + +func (l *legacyKV) Get(key string) ([]byte, error) { + var val []byte + + if err := l.db.View(func(tx *bbolt.Tx) error { + val = tx.Bucket(l.ns).Get([]byte(key)) + return nil + }); err != nil { + return nil, err + } + + if val == nil { + return nil, ErrNotFound + } + return val, nil +} + +func (l *legacyKV) Set(key string, value []byte) error { + return l.db.Update(func(tx *bbolt.Tx) error { + return tx.Bucket(l.ns).Put([]byte(key), value) + }) +} + +func (l *legacyKV) Delete(key string) error { + return l.db.Update(func(tx *bbolt.Tx) error { + return tx.Bucket(l.ns).Delete([]byte(key)) + }) +} diff --git a/pkg/kv/memory.go b/pkg/kv/memory.go deleted file mode 100644 index fa930ef..0000000 --- a/pkg/kv/memory.go +++ /dev/null @@ -1,41 +0,0 @@ -package kv - -import "sync" - -type Memory struct { - data map[string][]byte - mu sync.RWMutex -} - -func NewMemory() *Memory { - return &Memory{ - data: make(map[string][]byte), - } -} - -func (m *Memory) Get(key string) ([]byte, error) { - m.mu.RLock() - defer m.mu.RUnlock() - - data, ok := m.data[key] - if !ok { - return nil, ErrNotFound - } - return data, nil -} - -func (m *Memory) Set(key string, value []byte) error { - m.mu.Lock() - defer m.mu.Unlock() - - m.data[key] = value - return nil -} - -func (m *Memory) Delete(key string) error { - m.mu.Lock() - defer m.mu.Unlock() - - delete(m.data, key) - return nil -}