mirror of
https://github.com/iyear/tdl
synced 2025-01-08 11:57:55 +08:00
refactor(kv): storage engine support
This commit is contained in:
parent
3fd9197cc2
commit
57e07b66bf
@ -50,15 +50,17 @@ func New(ctx context.Context, login bool, middlewares ...telegram.Middleware) (*
|
||||
)
|
||||
|
||||
if test := viper.GetString(consts.FlagTest); test != "" {
|
||||
kvd, err = kv.NewFile(filepath.Join(os.TempDir(), test)) // persistent storage
|
||||
var stg kv.Storage
|
||||
stg, err = kv.New(kv.DriverFile, map[string]any{"path": filepath.Join(os.TempDir(), test)})
|
||||
if err != nil {
|
||||
return nil, nil, errors.Wrap(err, "create test kv")
|
||||
}
|
||||
kvd, err = stg.Open(test)
|
||||
} else {
|
||||
kvd, err = kv.New(kv.Options{
|
||||
Path: consts.KVPath,
|
||||
NS: viper.GetString(consts.FlagNamespace),
|
||||
})
|
||||
kvd, err = kv.From(ctx).Open(viper.GetString(consts.FlagNamespace))
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return nil, nil, errors.Wrap(err, "open kv")
|
||||
}
|
||||
|
||||
_clock, err := Clock()
|
||||
|
@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/fatih/color"
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/session"
|
||||
tdtdesktop "github.com/gotd/td/session/tdesktop"
|
||||
"github.com/spf13/viper"
|
||||
@ -32,12 +33,9 @@ type Options struct {
|
||||
func Desktop(ctx context.Context, opts *Options) error {
|
||||
ns := viper.GetString(consts.FlagNamespace)
|
||||
|
||||
kvd, err := kv.New(kv.Options{
|
||||
Path: consts.KVPath,
|
||||
NS: ns,
|
||||
})
|
||||
kvd, err := kv.From(ctx).Open(ns)
|
||||
if err != nil {
|
||||
return err
|
||||
return errors.Wrap(err, "open kv")
|
||||
}
|
||||
|
||||
desktop, err := findDesktop(opts.Desktop)
|
||||
|
41
cmd/root.go
41
cmd/root.go
@ -1,12 +1,15 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/iyear/tdl/pkg/consts"
|
||||
@ -15,12 +18,14 @@ import (
|
||||
)
|
||||
|
||||
func New() *cobra.Command {
|
||||
driverTypeKey := "type"
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "tdl",
|
||||
Short: "Telegram Downloader, but more than a downloader",
|
||||
SilenceErrors: true,
|
||||
SilenceUsage: true,
|
||||
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
// init logger
|
||||
debug, level := viper.GetBool(consts.FlagDebug), zap.InfoLevel
|
||||
if debug {
|
||||
@ -34,15 +39,44 @@ func New() *cobra.Command {
|
||||
logger.From(cmd.Context()).Info("Namespace",
|
||||
zap.String("namespace", ns))
|
||||
}
|
||||
|
||||
// check storage flag
|
||||
storageOpts := viper.GetStringMapString(consts.FlagStorage)
|
||||
driver, err := kv.ParseDriver(storageOpts[driverTypeKey])
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "parse driver")
|
||||
}
|
||||
delete(storageOpts, driverTypeKey)
|
||||
|
||||
opts := make(map[string]any)
|
||||
for k, v := range storageOpts {
|
||||
opts[k] = v
|
||||
}
|
||||
storage, err := kv.New(driver, opts)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create kv storage")
|
||||
}
|
||||
|
||||
cmd.SetContext(kv.With(cmd.Context(), storage))
|
||||
return nil
|
||||
},
|
||||
PersistentPostRunE: func(cmd *cobra.Command, args []string) error {
|
||||
return logger.From(cmd.Context()).Sync()
|
||||
return multierr.Combine(
|
||||
kv.From(cmd.Context()).Close(),
|
||||
logger.From(cmd.Context()).Sync(),
|
||||
)
|
||||
},
|
||||
}
|
||||
|
||||
cmd.AddCommand(NewVersion(), NewLogin(), NewDownload(), NewForward(),
|
||||
NewChat(), NewUpload(), NewBackup(), NewRecover(), NewGen())
|
||||
|
||||
cmd.PersistentFlags().StringToString(consts.FlagStorage, map[string]string{
|
||||
driverTypeKey: kv.DriverLegacy.String(),
|
||||
"path": consts.KVPath,
|
||||
}, fmt.Sprintf("storage options, format: type=driver,key1=value1,key2=value2. Available drivers: [%s]",
|
||||
strings.Join(kv.DriverNames(), ",")))
|
||||
|
||||
cmd.PersistentFlags().String(consts.FlagProxy, "", "proxy address, only socks5 is supported, format: protocol://username:password@host:port")
|
||||
cmd.PersistentFlags().StringP(consts.FlagNamespace, "n", "", "namespace for Telegram session")
|
||||
cmd.PersistentFlags().Bool(consts.FlagDebug, false, "enable debug mode")
|
||||
@ -59,7 +93,8 @@ func New() *cobra.Command {
|
||||
|
||||
// completion
|
||||
_ = cmd.RegisterFlagCompletionFunc(consts.FlagNamespace, func(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) {
|
||||
ns, err := kv.Namespaces(consts.KVPath)
|
||||
engine := kv.From(cmd.Context())
|
||||
ns, err := engine.Namespaces()
|
||||
if err != nil {
|
||||
return nil, cobra.ShellCompDirectiveNoFileComp
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package consts
|
||||
|
||||
const (
|
||||
FlagStorage = "storage"
|
||||
FlagProxy = "proxy"
|
||||
FlagNamespace = "ns"
|
||||
FlagDebug = "debug"
|
||||
|
@ -5,6 +5,5 @@ var (
|
||||
DataDir string
|
||||
KVPath string
|
||||
LogPath string
|
||||
DocsPath = "docs"
|
||||
UploadThumbExt = ".thumb"
|
||||
)
|
||||
|
173
pkg/kv/bolt.go
173
pkg/kv/bolt.go
@ -1,39 +1,172 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"go.etcd.io/bbolt"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"github.com/iyear/tdl/pkg/validator"
|
||||
)
|
||||
|
||||
type Bolt struct {
|
||||
ns []byte
|
||||
db *bbolt.DB
|
||||
func init() {
|
||||
register(DriverBolt, func(m map[string]any) (Storage, error) { return newBolt(m) })
|
||||
}
|
||||
|
||||
func (b *Bolt) Get(key string) ([]byte, error) {
|
||||
var val []byte
|
||||
type bolt struct {
|
||||
path string
|
||||
dbs map[string]*bbolt.DB
|
||||
mu *sync.Mutex
|
||||
}
|
||||
|
||||
if err := b.db.View(func(tx *bbolt.Tx) error {
|
||||
val = tx.Bucket(b.ns).Get([]byte(key))
|
||||
func newBolt(opts map[string]any) (*bolt, error) {
|
||||
type options struct {
|
||||
Path string `validate:"required" mapstructure:"path"`
|
||||
}
|
||||
|
||||
var o options
|
||||
if err := mapstructure.WeakDecode(opts, &o); err != nil {
|
||||
return nil, errors.Wrap(err, "decode options")
|
||||
}
|
||||
|
||||
if err := validator.Struct(&o); err != nil {
|
||||
return nil, errors.Wrap(err, "validate options")
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(o.Path, 0o755); err != nil {
|
||||
return nil, errors.Wrap(err, "create dir")
|
||||
}
|
||||
|
||||
return &bolt{
|
||||
path: o.Path,
|
||||
dbs: make(map[string]*bbolt.DB),
|
||||
mu: &sync.Mutex{},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (b *bolt) Name() string {
|
||||
return DriverBolt.String()
|
||||
}
|
||||
|
||||
func (b *bolt) MigrateTo() (Meta, error) {
|
||||
meta := make(Meta)
|
||||
|
||||
if err := b.walk(func(path string) (rerr error) {
|
||||
ns := filepath.Base(path)
|
||||
meta[ns] = make(map[string][]byte)
|
||||
|
||||
db, err := b.open(ns)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "open")
|
||||
}
|
||||
|
||||
return db.db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(db.ns).ForEach(func(k, v []byte) error {
|
||||
meta[ns][string(k)] = v
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "walk")
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (b *bolt) MigrateFrom(meta Meta) error {
|
||||
for ns, pairs := range meta {
|
||||
db, err := b.open(ns)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "open")
|
||||
}
|
||||
|
||||
if err = db.db.Update(func(tx *bbolt.Tx) error {
|
||||
bk, err := tx.CreateBucketIfNotExists(db.ns)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create bucket")
|
||||
}
|
||||
for key, value := range pairs {
|
||||
if err = bk.Put([]byte(key), value); err != nil {
|
||||
return errors.Wrap(err, "put")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "update")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *bolt) Namespaces() ([]string, error) {
|
||||
namespaces := make([]string, 0)
|
||||
if err := b.walk(func(path string) error {
|
||||
namespaces = append(namespaces, filepath.Base(path))
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "walk")
|
||||
}
|
||||
|
||||
if val == nil {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return val, nil
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
func (b *Bolt) Set(key string, val []byte) error {
|
||||
return b.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(b.ns).Put([]byte(key), val)
|
||||
func (b *bolt) walk(fn func(path string) error) error {
|
||||
return filepath.Walk(b.path, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "walk")
|
||||
}
|
||||
if info.IsDir() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fn(path)
|
||||
})
|
||||
}
|
||||
|
||||
// Delete removes a key from the bucket. If the key does not exist then nothing is done and a nil error is returned
|
||||
func (b *Bolt) Delete(key string) error {
|
||||
return b.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(b.ns).Delete([]byte(key))
|
||||
})
|
||||
func (b *bolt) Open(ns string) (KV, error) {
|
||||
return b.open(ns)
|
||||
}
|
||||
|
||||
func (b *bolt) open(ns string) (*legacyKV, error) {
|
||||
if ns == "" {
|
||||
return nil, errors.New("namespace is required")
|
||||
}
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
if db, ok := b.dbs[ns]; ok {
|
||||
return &legacyKV{db: db, ns: []byte(ns)}, nil
|
||||
}
|
||||
|
||||
db, err := bbolt.Open(filepath.Join(b.path, ns), os.ModePerm, boltOptions)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "open db")
|
||||
}
|
||||
if err = db.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(ns))
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "create bucket")
|
||||
}
|
||||
|
||||
b.dbs[ns] = db
|
||||
|
||||
return &legacyKV{db: db, ns: []byte(ns)}, nil
|
||||
}
|
||||
|
||||
func (b *bolt) Close() error {
|
||||
b.mu.Lock()
|
||||
defer b.mu.Unlock()
|
||||
|
||||
var err error
|
||||
for _, db := range b.dbs {
|
||||
err = multierr.Append(err, db.Close())
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
179
pkg/kv/file.go
179
pkg/kv/file.go
@ -4,80 +4,115 @@ import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/iyear/tdl/pkg/validator"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
func init() {
|
||||
register(DriverFile, newFile)
|
||||
}
|
||||
|
||||
type file struct {
|
||||
path string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func NewFile(path string) (*File, error) {
|
||||
_, err := os.Stat(path)
|
||||
func newFile(opts map[string]any) (Storage, error) {
|
||||
type options struct {
|
||||
Path string `validate:"required" mapstructure:"path"`
|
||||
}
|
||||
|
||||
var o options
|
||||
if err := mapstructure.WeakDecode(opts, &o); err != nil {
|
||||
return nil, errors.Wrap(err, "decode options")
|
||||
}
|
||||
|
||||
if err := validator.Struct(&o); err != nil {
|
||||
return nil, errors.Wrap(err, "validate options")
|
||||
}
|
||||
|
||||
_, err := os.Stat(o.Path)
|
||||
if err == nil {
|
||||
return &File{path: path}, nil
|
||||
return &file{path: o.Path}, nil
|
||||
}
|
||||
|
||||
if !os.IsNotExist(err) {
|
||||
return nil, err
|
||||
return nil, errors.Wrap(err, "stat file")
|
||||
}
|
||||
|
||||
if err = os.WriteFile(path, []byte("{}"), 0o644); err != nil {
|
||||
return nil, err
|
||||
if err = os.WriteFile(o.Path, []byte("{}"), 0o644); err != nil {
|
||||
return nil, errors.Wrap(err, "create file")
|
||||
}
|
||||
|
||||
return &File{path: path}, nil
|
||||
return &file{path: o.Path}, nil
|
||||
}
|
||||
|
||||
func (f *File) Get(key string) ([]byte, error) {
|
||||
func (f *file) Name() string {
|
||||
return DriverFile.String()
|
||||
}
|
||||
|
||||
func (f *file) MigrateTo() (Meta, error) {
|
||||
meta, err := f.read()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read")
|
||||
}
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (f *file) MigrateFrom(meta Meta) error {
|
||||
return f.write(meta)
|
||||
}
|
||||
|
||||
func (f *file) Namespaces() ([]string, error) {
|
||||
pairs, err := f.read()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
namespaces := make([]string, 0, len(pairs))
|
||||
for ns := range pairs {
|
||||
namespaces = append(namespaces, ns)
|
||||
}
|
||||
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
func (f *file) Open(ns string) (KV, error) {
|
||||
if ns == "" {
|
||||
return nil, errors.New("namespace is required")
|
||||
}
|
||||
|
||||
read, err := f.read()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
if _, ok := read[ns]; !ok {
|
||||
read[ns] = make(map[string][]byte)
|
||||
if err = f.write(read); err != nil {
|
||||
return nil, errors.Wrap(err, "write")
|
||||
}
|
||||
}
|
||||
|
||||
return &fileKV{f: f, ns: ns}, nil
|
||||
}
|
||||
|
||||
func (f *file) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *file) read() (map[string]map[string][]byte, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
m, err := f.read()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if val, ok := m[key]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
func (f *File) Set(key string, value []byte) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
m, err := f.read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m[key] = value
|
||||
|
||||
return f.write(m)
|
||||
}
|
||||
|
||||
func (f *File) Delete(key string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
m, err := f.read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
delete(m, key)
|
||||
|
||||
return f.write(m)
|
||||
}
|
||||
|
||||
func (f *File) read() (map[string][]byte, error) {
|
||||
bytes, err := os.ReadFile(f.path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m := make(map[string][]byte)
|
||||
m := make(map[string]map[string][]byte)
|
||||
if err = json.Unmarshal(bytes, &m); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -85,7 +120,10 @@ func (f *File) read() (map[string][]byte, error) {
|
||||
return m, nil
|
||||
}
|
||||
|
||||
func (f *File) write(m map[string][]byte) error {
|
||||
func (f *file) write(m map[string]map[string][]byte) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
|
||||
bytes, err := json.Marshal(m)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -93,3 +131,42 @@ func (f *File) write(m map[string][]byte) error {
|
||||
|
||||
return os.WriteFile(f.path, bytes, 0o644)
|
||||
}
|
||||
|
||||
type fileKV struct {
|
||||
f *file
|
||||
ns string
|
||||
}
|
||||
|
||||
func (f *fileKV) Get(key string) ([]byte, error) {
|
||||
m, err := f.f.read()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
if v, ok := m[f.ns][key]; ok {
|
||||
return v, nil
|
||||
}
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
|
||||
func (f *fileKV) Set(key string, value []byte) error {
|
||||
m, err := f.f.read()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
m[f.ns][key] = value
|
||||
|
||||
return f.f.write(m)
|
||||
}
|
||||
|
||||
func (f *fileKV) Delete(key string) error {
|
||||
m, err := f.f.read()
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "read")
|
||||
}
|
||||
|
||||
delete(m[f.ns], key)
|
||||
|
||||
return f.f.write(m)
|
||||
}
|
||||
|
87
pkg/kv/kv.go
87
pkg/kv/kv.go
@ -1,74 +1,57 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"time"
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"go.etcd.io/bbolt"
|
||||
"go.uber.org/multierr"
|
||||
|
||||
"github.com/iyear/tdl/pkg/validator"
|
||||
"github.com/go-faster/errors"
|
||||
)
|
||||
|
||||
//go:generate go-enum --values --names --flag --nocase
|
||||
|
||||
// Driver
|
||||
// ENUM(legacy, bolt, file)
|
||||
type Driver string
|
||||
|
||||
var ErrNotFound = errors.New("key not found")
|
||||
|
||||
type Meta map[string]map[string][]byte // namespace, key, value
|
||||
|
||||
type Storage interface {
|
||||
Name() string
|
||||
MigrateTo() (Meta, error)
|
||||
MigrateFrom(Meta) error
|
||||
Namespaces() ([]string, error)
|
||||
Open(ns string) (KV, error)
|
||||
io.Closer
|
||||
}
|
||||
|
||||
type KV interface {
|
||||
Get(key string) ([]byte, error)
|
||||
Set(key string, value []byte) error
|
||||
Delete(key string) error
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
NS string `validate:"required"`
|
||||
Path string `validate:"required"`
|
||||
var drivers = map[Driver]func(map[string]any) (Storage, error){}
|
||||
|
||||
func register(name Driver, fn func(map[string]any) (Storage, error)) {
|
||||
drivers[name] = fn
|
||||
}
|
||||
|
||||
func New(opts Options) (KV, error) {
|
||||
if err := validator.Struct(&opts); err != nil {
|
||||
return nil, err
|
||||
func New(driver Driver, opts map[string]any) (Storage, error) {
|
||||
if fn, ok := drivers[driver]; ok {
|
||||
return fn(opts)
|
||||
}
|
||||
|
||||
db, err := bbolt.Open(opts.Path, os.ModePerm, &bbolt.Options{
|
||||
Timeout: time.Second,
|
||||
NoGrowSync: false,
|
||||
FreelistType: bbolt.FreelistArrayType,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = db.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(opts.NS))
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Bolt{db: db, ns: []byte(opts.NS)}, nil
|
||||
return nil, errors.Errorf("unsupported driver: %s", driver)
|
||||
}
|
||||
|
||||
// Namespaces returns all namespaces in the database
|
||||
func Namespaces(path string) (_ []string, rerr error) {
|
||||
db, err := bbolt.Open(path, os.ModePerm, &bbolt.Options{
|
||||
Timeout: time.Second,
|
||||
ReadOnly: true,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer multierr.AppendInvoke(&rerr, multierr.Close(db))
|
||||
type ctxKey struct{}
|
||||
|
||||
namespaces := make([]string, 0)
|
||||
err = db.View(func(tx *bbolt.Tx) error {
|
||||
return tx.ForEach(func(name []byte, _ *bbolt.Bucket) error {
|
||||
namespaces = append(namespaces, string(name))
|
||||
return nil
|
||||
})
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return namespaces, nil
|
||||
func With(ctx context.Context, kv Storage) context.Context {
|
||||
return context.WithValue(ctx, ctxKey{}, kv)
|
||||
}
|
||||
|
||||
func From(ctx context.Context) Storage {
|
||||
return ctx.Value(ctxKey{}).(Storage)
|
||||
}
|
||||
|
92
pkg/kv/kv_enum.go
Normal file
92
pkg/kv/kv_enum.go
Normal file
@ -0,0 +1,92 @@
|
||||
// Code generated by go-enum DO NOT EDIT.
|
||||
// Version: 0.5.8
|
||||
// Revision: 3d844c8ecc59661ed7aa17bfd65727bc06a60ad8
|
||||
// Build Date: 2023-09-18T14:55:21Z
|
||||
// Built By: goreleaser
|
||||
|
||||
package kv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
// DriverLegacy is a Driver of type legacy.
|
||||
DriverLegacy Driver = "legacy"
|
||||
// DriverBolt is a Driver of type bolt.
|
||||
DriverBolt Driver = "bolt"
|
||||
// DriverFile is a Driver of type file.
|
||||
DriverFile Driver = "file"
|
||||
)
|
||||
|
||||
var ErrInvalidDriver = fmt.Errorf("not a valid Driver, try [%s]", strings.Join(_DriverNames, ", "))
|
||||
|
||||
var _DriverNames = []string{
|
||||
string(DriverLegacy),
|
||||
string(DriverBolt),
|
||||
string(DriverFile),
|
||||
}
|
||||
|
||||
// DriverNames returns a list of possible string values of Driver.
|
||||
func DriverNames() []string {
|
||||
tmp := make([]string, len(_DriverNames))
|
||||
copy(tmp, _DriverNames)
|
||||
return tmp
|
||||
}
|
||||
|
||||
// DriverValues returns a list of the values for Driver
|
||||
func DriverValues() []Driver {
|
||||
return []Driver{
|
||||
DriverLegacy,
|
||||
DriverBolt,
|
||||
DriverFile,
|
||||
}
|
||||
}
|
||||
|
||||
// String implements the Stringer interface.
|
||||
func (x Driver) String() string {
|
||||
return string(x)
|
||||
}
|
||||
|
||||
// IsValid provides a quick way to determine if the typed value is
|
||||
// part of the allowed enumerated values
|
||||
func (x Driver) IsValid() bool {
|
||||
_, err := ParseDriver(string(x))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
var _DriverValue = map[string]Driver{
|
||||
"legacy": DriverLegacy,
|
||||
"bolt": DriverBolt,
|
||||
"file": DriverFile,
|
||||
}
|
||||
|
||||
// ParseDriver attempts to convert a string to a Driver.
|
||||
func ParseDriver(name string) (Driver, error) {
|
||||
if x, ok := _DriverValue[name]; ok {
|
||||
return x, nil
|
||||
}
|
||||
// Case insensitive parse, do a separate lookup to prevent unnecessary cost of lowercasing a string if we don't need to.
|
||||
if x, ok := _DriverValue[strings.ToLower(name)]; ok {
|
||||
return x, nil
|
||||
}
|
||||
return Driver(""), fmt.Errorf("%s is %w", name, ErrInvalidDriver)
|
||||
}
|
||||
|
||||
// Set implements the Golang flag.Value interface func.
|
||||
func (x *Driver) Set(val string) error {
|
||||
v, err := ParseDriver(val)
|
||||
*x = v
|
||||
return err
|
||||
}
|
||||
|
||||
// Get implements the Golang flag.Getter interface func.
|
||||
func (x *Driver) Get() interface{} {
|
||||
return *x
|
||||
}
|
||||
|
||||
// Type implements the github.com/spf13/pFlag Value interface.
|
||||
func (x *Driver) Type() string {
|
||||
return "Driver"
|
||||
}
|
155
pkg/kv/kv_test.go
Normal file
155
pkg/kv/kv_test.go
Normal file
@ -0,0 +1,155 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
)
|
||||
|
||||
func forEachStorage(t *testing.T, fn func(e Storage, t *testing.T)) {
|
||||
storages := map[Driver]map[string]any{
|
||||
DriverBolt: {"path": t.TempDir()},
|
||||
DriverLegacy: {"path": filepath.Join(t.TempDir(), "test.db")},
|
||||
DriverFile: {"path": filepath.Join(t.TempDir(), "test.json")},
|
||||
}
|
||||
|
||||
for driver, opts := range storages {
|
||||
storage, err := New(driver, opts)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run(driver.String(), func(t *testing.T) {
|
||||
fn(storage, t)
|
||||
})
|
||||
assert.NoError(t, storage.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func TestNew(t *testing.T) {
|
||||
tests := map[Driver][]struct {
|
||||
name string
|
||||
opts map[string]any
|
||||
wantErr bool
|
||||
}{
|
||||
DriverBolt: {
|
||||
{name: "valid", opts: map[string]any{"path": t.TempDir()}, wantErr: false},
|
||||
{name: "invalid", opts: map[string]any{"path": ""}, wantErr: true},
|
||||
},
|
||||
DriverLegacy: {
|
||||
{name: "valid", opts: map[string]any{"path": filepath.Join(t.TempDir(), "test.db")}, wantErr: false},
|
||||
{name: "invalid", opts: map[string]any{"path": ""}, wantErr: true},
|
||||
},
|
||||
DriverFile: {
|
||||
{name: "valid", opts: map[string]any{"path": filepath.Join(t.TempDir(), "test.json")}, wantErr: false},
|
||||
},
|
||||
Driver("unknown"): {
|
||||
{name: "unknown", opts: map[string]any{"path": ""}, wantErr: true},
|
||||
},
|
||||
}
|
||||
|
||||
for driver, tests := range tests {
|
||||
for _, tt := range tests {
|
||||
t.Run(fmt.Sprintf("%v/%s", driver, tt.name), func(t *testing.T) {
|
||||
kv, err := New(driver, tt.opts)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, kv)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, kv)
|
||||
assert.NoError(t, kv.Close())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestStorage_Open(t *testing.T) {
|
||||
forEachStorage(t, func(e Storage, t *testing.T) {
|
||||
for _, ns := range []string{"foo", "bar", "foo"} {
|
||||
kv, err := e.Open(ns)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kv)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_Namespaces(t *testing.T) {
|
||||
namespaces := []string{"foo", "bar", "baz"}
|
||||
|
||||
forEachStorage(t, func(e Storage, t *testing.T) {
|
||||
for _, ns := range namespaces {
|
||||
kv, err := e.Open(ns)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kv)
|
||||
}
|
||||
|
||||
ns, err := e.Namespaces()
|
||||
require.NoError(t, err)
|
||||
require.ElementsMatch(t, namespaces, ns)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_MigrateTo(t *testing.T) {
|
||||
meta := Meta{
|
||||
"foo": {
|
||||
"1": []byte("2"),
|
||||
"3": []byte("4"),
|
||||
"5": []byte("6"),
|
||||
},
|
||||
"bar": {
|
||||
"7": []byte("8"),
|
||||
"9": []byte("10"),
|
||||
"11": []byte("12"),
|
||||
},
|
||||
}
|
||||
|
||||
forEachStorage(t, func(e Storage, t *testing.T) {
|
||||
for ns, pairs := range meta {
|
||||
kv, err := e.Open(ns)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kv)
|
||||
|
||||
for key, value := range pairs {
|
||||
require.NoError(t, kv.Set(key, value))
|
||||
}
|
||||
}
|
||||
|
||||
m, err := e.MigrateTo()
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, meta, m)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorage_MigrateFrom(t *testing.T) {
|
||||
meta := Meta{
|
||||
"foo": {
|
||||
"1": []byte("2"),
|
||||
"3": []byte("4"),
|
||||
"5": []byte("6"),
|
||||
},
|
||||
"bar": {
|
||||
"7": []byte("8"),
|
||||
"9": []byte("10"),
|
||||
"11": []byte("12"),
|
||||
},
|
||||
}
|
||||
|
||||
forEachStorage(t, func(e Storage, t *testing.T) {
|
||||
require.NoError(t, e.MigrateFrom(meta))
|
||||
|
||||
for ns, pairs := range meta {
|
||||
kv, err := e.Open(ns)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, kv)
|
||||
|
||||
for key, value := range pairs {
|
||||
v, err := kv.Get(key)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, value, v)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
158
pkg/kv/legacy.go
Normal file
158
pkg/kv/legacy.go
Normal file
@ -0,0 +1,158 @@
|
||||
package kv
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"go.etcd.io/bbolt"
|
||||
|
||||
"github.com/iyear/tdl/pkg/validator"
|
||||
)
|
||||
|
||||
var boltOptions = &bbolt.Options{
|
||||
Timeout: time.Second,
|
||||
NoGrowSync: false,
|
||||
FreelistType: bbolt.FreelistArrayType,
|
||||
}
|
||||
|
||||
func init() {
|
||||
register(DriverLegacy, func(m map[string]any) (Storage, error) {
|
||||
return newLegacy(m)
|
||||
})
|
||||
}
|
||||
|
||||
func newLegacy(opts map[string]any) (*legacy, error) {
|
||||
type options struct {
|
||||
Path string `validate:"required" mapstructure:"path"`
|
||||
}
|
||||
|
||||
var o options
|
||||
if err := mapstructure.WeakDecode(opts, &o); err != nil {
|
||||
return nil, errors.Wrap(err, "decode options")
|
||||
}
|
||||
|
||||
if err := validator.Struct(&o); err != nil {
|
||||
return nil, errors.Wrap(err, "validate options")
|
||||
}
|
||||
|
||||
db, err := bbolt.Open(o.Path, os.ModePerm, boltOptions)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "open db")
|
||||
}
|
||||
|
||||
return &legacy{bolt: db}, nil
|
||||
}
|
||||
|
||||
type legacy struct {
|
||||
bolt *bbolt.DB
|
||||
}
|
||||
|
||||
func (l *legacy) Name() string {
|
||||
return DriverLegacy.String()
|
||||
}
|
||||
|
||||
func (l *legacy) MigrateTo() (Meta, error) {
|
||||
meta := make(Meta)
|
||||
|
||||
if err := l.bolt.View(func(tx *bbolt.Tx) error {
|
||||
return tx.ForEach(func(name []byte, b *bbolt.Bucket) error {
|
||||
ns := string(name)
|
||||
meta[ns] = make(map[string][]byte)
|
||||
return b.ForEach(func(k, v []byte) error {
|
||||
meta[ns][string(k)] = v
|
||||
return nil
|
||||
})
|
||||
})
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "iterate buckets")
|
||||
}
|
||||
|
||||
return meta, nil
|
||||
}
|
||||
|
||||
func (l *legacy) MigrateFrom(meta Meta) error {
|
||||
return l.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
for ns, pairs := range meta {
|
||||
b, err := tx.CreateBucketIfNotExists([]byte(ns))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create bucket")
|
||||
}
|
||||
for key, value := range pairs {
|
||||
if err = b.Put([]byte(key), value); err != nil {
|
||||
return errors.Wrap(err, "put")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (l *legacy) Namespaces() ([]string, error) {
|
||||
namespaces := make([]string, 0)
|
||||
if err := l.bolt.View(func(tx *bbolt.Tx) error {
|
||||
return tx.ForEach(func(name []byte, _ *bbolt.Bucket) error {
|
||||
namespaces = append(namespaces, string(name))
|
||||
return nil
|
||||
})
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "iterate namespaces")
|
||||
}
|
||||
return namespaces, nil
|
||||
}
|
||||
|
||||
func (l *legacy) Open(ns string) (KV, error) {
|
||||
return l.open(ns)
|
||||
}
|
||||
|
||||
func (l *legacy) open(ns string) (*legacyKV, error) {
|
||||
if ns == "" {
|
||||
return nil, errors.New("namespace is required")
|
||||
}
|
||||
if err := l.bolt.Update(func(tx *bbolt.Tx) error {
|
||||
_, err := tx.CreateBucketIfNotExists([]byte(ns))
|
||||
return err
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "create bucket")
|
||||
}
|
||||
|
||||
return &legacyKV{db: l.bolt, ns: []byte(ns)}, nil
|
||||
}
|
||||
|
||||
func (l *legacy) Close() error {
|
||||
return l.bolt.Close()
|
||||
}
|
||||
|
||||
type legacyKV struct {
|
||||
db *bbolt.DB
|
||||
ns []byte
|
||||
}
|
||||
|
||||
func (l *legacyKV) Get(key string) ([]byte, error) {
|
||||
var val []byte
|
||||
|
||||
if err := l.db.View(func(tx *bbolt.Tx) error {
|
||||
val = tx.Bucket(l.ns).Get([]byte(key))
|
||||
return nil
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if val == nil {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
func (l *legacyKV) Set(key string, value []byte) error {
|
||||
return l.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(l.ns).Put([]byte(key), value)
|
||||
})
|
||||
}
|
||||
|
||||
func (l *legacyKV) Delete(key string) error {
|
||||
return l.db.Update(func(tx *bbolt.Tx) error {
|
||||
return tx.Bucket(l.ns).Delete([]byte(key))
|
||||
})
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package kv
|
||||
|
||||
import "sync"
|
||||
|
||||
type Memory struct {
|
||||
data map[string][]byte
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMemory() *Memory {
|
||||
return &Memory{
|
||||
data: make(map[string][]byte),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Memory) Get(key string) ([]byte, error) {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
data, ok := m.data[key]
|
||||
if !ok {
|
||||
return nil, ErrNotFound
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (m *Memory) Set(key string, value []byte) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
m.data[key] = value
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Memory) Delete(key string) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
delete(m.data, key)
|
||||
return nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user