feat(dl): resume download. #86

This commit is contained in:
iyear 2023-01-28 16:51:09 +08:00
parent 0b92e80239
commit 98b81a3133
7 changed files with 179 additions and 35 deletions

View File

@ -2,13 +2,18 @@ package dl
import (
"context"
"encoding/json"
"errors"
"fmt"
"github.com/AlecAivazis/survey/v2"
"github.com/fatih/color"
"github.com/gotd/contrib/middleware/floodwait"
"github.com/iyear/tdl/app/internal/tgc"
"github.com/iyear/tdl/pkg/consts"
"github.com/iyear/tdl/pkg/dcpool"
"github.com/iyear/tdl/pkg/downloader"
"github.com/iyear/tdl/pkg/key"
"github.com/iyear/tdl/pkg/kv"
"github.com/jedib0t/go-pretty/v6/text"
"github.com/spf13/viper"
"go.uber.org/multierr"
@ -62,6 +67,18 @@ func Run(ctx context.Context, opts *Options) error {
return err
}
// resume download and ask user to continue
if err = resume(ctx, kvd, it); err != nil {
return err
}
defer func() {
if rerr != nil { // download is interrupted
multierr.AppendInto(&rerr, saveProgress(kvd, it))
} else { // if finished, we should clear resume key
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(it.fingerprint)))
}
}()
options := &downloader.Options{
Pool: pool,
Dir: opts.Dir,
@ -74,3 +91,46 @@ func Run(ctx context.Context, opts *Options) error {
return downloader.New(options).Download(ctx, viper.GetInt(consts.FlagLimit))
})
}
// resume restores progress saved by an earlier, interrupted run.
//
// It reads the entry stored under the resume key derived from the iterator's
// fingerprint. When no entry exists, or the stored set of finished item IDs
// is empty, it returns nil without prompting. Otherwise the user is asked
// whether to continue: on confirmation the stored set is installed on the
// iterator, on refusal the stored entry is deleted.
func resume(ctx context.Context, kvd kv.KV, it *iter) error {
	resumeKey := key.Resume(it.fingerprint)

	raw, err := kvd.Get(resumeKey)
	switch {
	case err == nil, errors.Is(err, kv.ErrNotFound):
		// a missing entry simply means there is nothing to resume
	default:
		return err
	}
	if len(raw) == 0 { // no progress
		return nil
	}

	done := map[int]struct{}{}
	if err := json.Unmarshal(raw, &done); err != nil {
		return err
	}
	if len(done) == 0 { // nothing was finished, no need to resume
		return nil
	}

	var proceed bool
	prompt := &survey.Confirm{
		Message: fmt.Sprintf("Found unfinished download, continue from '%d/%d'?", len(done), it.Total(ctx)),
	}
	if err := survey.AskOne(prompt, &proceed); err != nil {
		return err
	}

	if proceed {
		it.setFinished(done)
		return nil
	}
	// the user declined: drop the stale progress entry
	return kvd.Delete(resumeKey)
}
// saveProgress encodes the iterator's set of finished item IDs as JSON and
// stores it under the resume key derived from the iterator's fingerprint.
// It returns the marshalling error or the storage error, whichever occurs.
func saveProgress(kvd kv.KV, it *iter) error {
	data, err := json.Marshal(it.finished)
	if err == nil {
		err = kvd.Set(key.Resume(it.fingerprint), data)
	}
	return err
}

View File

@ -3,6 +3,8 @@ package dl
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"errors"
"fmt"
"github.com/gotd/td/telegram/peers"
@ -14,6 +16,7 @@ import (
"github.com/iyear/tdl/pkg/storage"
"github.com/iyear/tdl/pkg/utils"
"path/filepath"
"sort"
"sync"
"text/template"
"time"
@ -26,8 +29,10 @@ type iter struct {
mu sync.Mutex
curi int
curj int
finished map[int]struct{}
template *template.Template
manager *peers.Manager
fingerprint string
}
type dialog struct {
@ -74,43 +79,67 @@ func newIter(pool dcpool.Pool, kvd kv.KV, tmpl string, include, exclude []string
excludeMap[utils.FS.AddPrefixDot(v)] = struct{}{}
}
return &iter{
pool: pool,
dialogs: mm,
include: includeMap,
exclude: excludeMap,
curi: 0,
curj: -1,
template: t,
manager: peers.Options{Storage: storage.NewPeers(kvd)}.Build(pool.Client(pool.Default())),
}, nil
// to keep fingerprint stable
sortDialogs(mm)
it := &iter{
pool: pool,
dialogs: mm,
include: includeMap,
exclude: excludeMap,
curi: 0,
curj: -1,
finished: make(map[int]struct{}),
template: t,
manager: peers.Options{Storage: storage.NewPeers(kvd)}.Build(pool.Client(pool.Default())),
fingerprint: fingerprint(mm),
}
return it, nil
}
func (i *iter) Next(ctx context.Context) (*downloader.Item, error) {
// sortDialogs puts dialogs into a canonical order, in place: dialogs are
// ordered by increasing peer ID, and the messages of each dialog are ordered
// by decreasing message ID.
func sortDialogs(dialogs []*dialog) {
	byPeerID := func(a, b int) bool {
		left := utils.Telegram.GetInputPeerID(dialogs[a].peer)
		right := utils.Telegram.GetInputPeerID(dialogs[b].peer)
		return left < right // increasing order
	}
	sort.Slice(dialogs, byPeerID)

	for idx := range dialogs {
		msgs := dialogs[idx].msgs
		sort.Slice(msgs, func(a, b int) bool {
			return msgs[a] > msgs[b] // decreasing order
		})
	}
}
func (iter *iter) Next(ctx context.Context) (*downloader.Item, error) {
select {
case <-ctx.Done():
return nil, ctx.Err()
default:
}
i.mu.Lock()
i.curj++
if i.curj >= len(i.dialogs[i.curi].msgs) {
if i.curi++; i.curi >= len(i.dialogs) {
iter.mu.Lock()
iter.curj++
if iter.curj >= len(iter.dialogs[iter.curi].msgs) {
if iter.curi++; iter.curi >= len(iter.dialogs) {
return nil, errors.New("no more items")
}
i.curj = 0
iter.curj = 0
}
iter.mu.Unlock()
// check if finished
if _, ok := iter.finished[iter.ij2n(iter.curi, iter.curj)]; ok {
return nil, downloader.ErrSkip
}
curi := i.dialogs[i.curi]
cur := curi.msgs[i.curj]
i.mu.Unlock()
return i.item(ctx, curi.peer, cur)
return iter.item(ctx, iter.curi, iter.curj)
}
func (i *iter) item(ctx context.Context, peer tg.InputPeerClass, msg int) (*downloader.Item, error) {
it := query.Messages(i.pool.Client(i.pool.Default())).
func (iter *iter) item(ctx context.Context, i, j int) (*downloader.Item, error) {
peer, msg := iter.dialogs[i].peer, iter.dialogs[i].msgs[j]
it := query.Messages(iter.pool.Client(iter.pool.Default())).
GetHistory(peer).OffsetID(msg + 1).
BatchSize(1).Iter()
id := utils.Telegram.GetInputPeerID(peer)
@ -138,19 +167,19 @@ func (i *iter) item(ctx context.Context, peer tg.InputPeerClass, msg int) (*down
// process include and exclude
ext := filepath.Ext(media.Name)
if len(i.include) > 0 {
if _, ok = i.include[ext]; !ok {
if len(iter.include) > 0 {
if _, ok = iter.include[ext]; !ok {
return nil, downloader.ErrSkip
}
}
if len(i.exclude) > 0 {
if _, ok = i.exclude[ext]; ok {
if len(iter.exclude) > 0 {
if _, ok = iter.exclude[ext]; ok {
return nil, downloader.ErrSkip
}
}
buf := bytes.Buffer{}
err := i.template.Execute(&buf, &fileTemplate{
err := iter.template.Execute(&buf, &fileTemplate{
DialogID: id,
MessageID: message.ID,
MessageDate: int64(message.Date),
@ -163,16 +192,56 @@ func (i *iter) item(ctx context.Context, peer tg.InputPeerClass, msg int) (*down
}
media.Name = buf.String()
media.ID = iter.ij2n(i, j)
return media, nil
}
func (i *iter) Total(_ context.Context) int {
i.mu.Lock()
defer i.mu.Unlock()
// setFinished replaces the iterator's set of finished item IDs with the
// given one. The swap is done while holding the iterator's mutex.
func (iter *iter) setFinished(finished map[int]struct{}) {
	iter.mu.Lock()
	iter.finished = finished
	iter.mu.Unlock()
}
// Finish records id in the iterator's set of finished item IDs while holding
// the iterator's mutex. The context is not used and the returned error is
// always nil.
func (iter *iter) Finish(_ context.Context, id int) error {
	var done struct{} // zero-size marker: the map is used as a set

	iter.mu.Lock()
	defer iter.mu.Unlock()

	iter.finished[id] = done
	return nil
}
func (iter *iter) Total(_ context.Context) int {
iter.mu.Lock()
defer iter.mu.Unlock()
total := 0
for _, m := range i.dialogs {
for _, m := range iter.dialogs {
total += len(m.msgs)
}
return total
}
// ij2n flattens a (dialog index, message index) pair into a single number:
// the count of messages in all dialogs before dialog i, plus j.
func (iter *iter) ij2n(i, j int) int {
	flat := j
	for k := i - 1; k >= 0; k-- {
		flat += len(iter.dialogs[k].msgs)
	}
	return flat
}
// fingerprint returns the hex-encoded SHA-256 digest of the given dialogs.
//
// The hashed byte stream is, for each dialog in slice order, the peer ID
// followed by every message ID of that dialog, each written as an 8-byte
// big-endian unsigned integer. The result therefore depends on the order of
// dialogs and of their messages.
func fingerprint(dialogs []*dialog) string {
	var (
		stream  bytes.Buffer
		scratch [8]byte
	)
	// put appends v to the stream as 8 big-endian bytes.
	put := func(v uint64) {
		binary.BigEndian.PutUint64(scratch[:], v)
		stream.Write(scratch[:])
	}

	for _, d := range dialogs {
		put(uint64(utils.Telegram.GetInputPeerID(d.peer)))
		for _, id := range d.msgs {
			put(uint64(id))
		}
	}

	digest := sha256.Sum256(stream.Bytes())
	return fmt.Sprintf("%x", digest)
}

View File

@ -68,12 +68,13 @@ func (d *Downloader) Download(ctx context.Context, limit int) error {
wg.SetLimit(limit)
for i := 0; i < total; i++ {
wg.Go(func() error {
wg.Go(func() (rerr error) {
item, err := d.iter.Next(errctx)
if err != nil {
// skip error means we don't need to log error
if !errors.Is(err, ErrSkip) {
if !errors.Is(err, ErrSkip) && !errors.Is(err, context.Canceled) {
d.pw.Log(color.RedString("failed: %v", err))
return err
}
return nil
}
@ -165,5 +166,5 @@ func (d *Downloader) download(ctx context.Context, item *Item) error {
return err
}
return nil
return d.iter.Finish(ctx, item.ID)
}

View File

@ -10,10 +10,12 @@ var ErrSkip = errors.New("skip")
// Iter supplies the items that a Downloader fetches and is told when each
// one has been downloaded.
type Iter interface {
// Next returns the next item to download. An error matching ErrSkip means
// the current item should be skipped; the Downloader does not log it as a
// failure.
Next(ctx context.Context) (*Item, error)
// Finish is called with an Item's ID once that item has been downloaded.
Finish(ctx context.Context, id int) error
// Total reports how many items the iterator covers.
Total(ctx context.Context) int
}
type Item struct {
ID int // unique in iter
InputFileLoc tg.InputFileLocationClass
Name string
Size int64

View File

@ -53,3 +53,7 @@ func PeersPhone(phone string) string {
func PeersContactsHash() string {
return New("peers", "contacts", "hash")
}
// Resume returns the storage key under which download progress for the given
// fingerprint is kept. The key is built by New from the "resume" prefix and
// the fingerprint.
func Resume(fingerprint string) string {
	const prefix = "resume"
	return New(prefix, fingerprint)
}

View File

@ -30,3 +30,10 @@ func (b *Bolt) Set(key string, val []byte) error {
return tx.Bucket(b.ns).Put([]byte(key), val)
})
}
// Delete removes key from the bucket inside an update transaction.
// If the key does not exist, nothing is done and a nil error is returned.
func (b *Bolt) Delete(key string) error {
	remove := func(tx *bbolt.Tx) error {
		bucket := tx.Bucket(b.ns)
		return bucket.Delete([]byte(key))
	}
	return b.db.Update(remove)
}

View File

@ -15,6 +15,7 @@ var (
// KV is the minimal key-value storage interface: values are opaque byte
// slices addressed by string keys.
type KV interface {
// Get returns the value stored under key.
// NOTE(review): callers test the returned error against ErrNotFound for a
// missing key — confirm every implementation returns it.
Get(key string) ([]byte, error)
// Set stores value under key.
Set(key string, value []byte) error
// Delete removes key. The Bolt implementation treats a missing key as a
// no-op and returns nil.
Delete(key string) error
}
type Options struct {