mirror of
https://github.com/iyear/tdl
synced 2025-01-09 04:17:35 +08:00
refactor(dl): extract to interface and impl iter in app
This commit is contained in:
parent
b7bbbf60f1
commit
169e47913a
65
app/dl/dl.go
65
app/dl/dl.go
@ -8,11 +8,11 @@ import (
|
|||||||
"github.com/AlecAivazis/survey/v2"
|
"github.com/AlecAivazis/survey/v2"
|
||||||
"github.com/fatih/color"
|
"github.com/fatih/color"
|
||||||
"github.com/go-faster/errors"
|
"github.com/go-faster/errors"
|
||||||
|
"github.com/gotd/td/telegram/peers"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"go.uber.org/multierr"
|
"go.uber.org/multierr"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/iyear/tdl/app/internal/dliter"
|
|
||||||
"github.com/iyear/tdl/app/internal/tgc"
|
"github.com/iyear/tdl/app/internal/tgc"
|
||||||
"github.com/iyear/tdl/pkg/consts"
|
"github.com/iyear/tdl/pkg/consts"
|
||||||
"github.com/iyear/tdl/pkg/dcpool"
|
"github.com/iyear/tdl/pkg/dcpool"
|
||||||
@ -20,7 +20,10 @@ import (
|
|||||||
"github.com/iyear/tdl/pkg/key"
|
"github.com/iyear/tdl/pkg/key"
|
||||||
"github.com/iyear/tdl/pkg/kv"
|
"github.com/iyear/tdl/pkg/kv"
|
||||||
"github.com/iyear/tdl/pkg/logger"
|
"github.com/iyear/tdl/pkg/logger"
|
||||||
|
"github.com/iyear/tdl/pkg/prog"
|
||||||
|
"github.com/iyear/tdl/pkg/storage"
|
||||||
"github.com/iyear/tdl/pkg/tmessage"
|
"github.com/iyear/tdl/pkg/tmessage"
|
||||||
|
"github.com/iyear/tdl/pkg/utils"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
@ -48,7 +51,7 @@ type parser struct {
|
|||||||
Parser tmessage.ParseSource
|
Parser tmessage.ParseSource
|
||||||
}
|
}
|
||||||
|
|
||||||
func Run(ctx context.Context, opts *Options) error {
|
func Run(ctx context.Context, opts Options) error {
|
||||||
c, kvd, err := tgc.NoLogin(ctx)
|
c, kvd, err := tgc.NoLogin(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@ -78,22 +81,16 @@ func Run(ctx context.Context, opts *Options) error {
|
|||||||
return serve(ctx, kvd, pool, dialogs, opts.Port, opts.Takeout)
|
return serve(ctx, kvd, pool, dialogs, opts.Port, opts.Takeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
iter, err := dliter.New(ctx, &dliter.Options{
|
manager := peers.Options{Storage: storage.NewPeers(kvd)}.Build(pool.Default(ctx))
|
||||||
Pool: pool,
|
|
||||||
KV: kvd,
|
it, err := newIter(pool, manager, dialogs, opts)
|
||||||
Template: opts.Template,
|
|
||||||
Include: opts.Include,
|
|
||||||
Exclude: opts.Exclude,
|
|
||||||
Desc: opts.Desc,
|
|
||||||
Dialogs: dialogs,
|
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !opts.Restart {
|
if !opts.Restart {
|
||||||
// resume download and ask user to continue
|
// resume download and ask user to continue
|
||||||
if err = resume(ctx, kvd, iter, !opts.Continue); err != nil {
|
if err = resume(ctx, kvd, it, !opts.Continue); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@ -102,37 +99,39 @@ func Run(ctx context.Context, opts *Options) error {
|
|||||||
|
|
||||||
defer func() { // save progress
|
defer func() { // save progress
|
||||||
if rerr != nil { // download is interrupted
|
if rerr != nil { // download is interrupted
|
||||||
multierr.AppendInto(&rerr, saveProgress(ctx, kvd, iter))
|
multierr.AppendInto(&rerr, saveProgress(ctx, kvd, it))
|
||||||
} else { // if finished, we should clear resume key
|
} else { // if finished, we should clear resume key
|
||||||
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(iter.Fingerprint())))
|
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(it.Fingerprint())))
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
dlProgress := prog.New(utils.Byte.FormatBinaryBytes)
|
||||||
|
dlProgress.SetNumTrackersExpected(it.Total())
|
||||||
|
prog.EnablePS(ctx, dlProgress)
|
||||||
|
|
||||||
options := downloader.Options{
|
options := downloader.Options{
|
||||||
Pool: pool,
|
Pool: pool,
|
||||||
Dir: opts.Dir,
|
PartSize: viper.GetInt(consts.FlagPartSize),
|
||||||
RewriteExt: opts.RewriteExt,
|
Threads: viper.GetInt(consts.FlagThreads),
|
||||||
SkipSame: opts.SkipSame,
|
Iter: it,
|
||||||
PartSize: viper.GetInt(consts.FlagPartSize),
|
Progress: newProgress(dlProgress, it, opts),
|
||||||
Threads: viper.GetInt(consts.FlagThreads),
|
|
||||||
Iter: iter,
|
|
||||||
Takeout: opts.Takeout,
|
|
||||||
}
|
}
|
||||||
limit := viper.GetInt(consts.FlagLimit)
|
limit := viper.GetInt(consts.FlagLimit)
|
||||||
|
|
||||||
logger.From(ctx).Info("Start download",
|
logger.From(ctx).Info("Start download",
|
||||||
zap.String("dir", options.Dir),
|
zap.String("dir", opts.Dir),
|
||||||
zap.Bool("rewrite_ext", options.RewriteExt),
|
zap.Bool("rewrite_ext", opts.RewriteExt),
|
||||||
zap.Bool("skip_same", options.SkipSame),
|
zap.Bool("skip_same", opts.SkipSame),
|
||||||
zap.Int("part_size", options.PartSize),
|
zap.Int("part_size", options.PartSize),
|
||||||
zap.Int("threads", options.Threads),
|
zap.Int("threads", options.Threads),
|
||||||
zap.Int("limit", limit))
|
zap.Int("limit", limit))
|
||||||
|
|
||||||
dl, err := downloader.New(options)
|
color.Green("All files will be downloaded to '%s' dir", opts.Dir)
|
||||||
if err != nil {
|
|
||||||
return errors.Wrap(err, "create downloader")
|
go dlProgress.Render()
|
||||||
}
|
defer prog.Wait(ctx, dlProgress)
|
||||||
return dl.Download(ctx, limit)
|
|
||||||
|
return downloader.New(options).Download(ctx, limit)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -148,7 +147,7 @@ func collectDialogs(parsers []parser) ([][]*tmessage.Dialog, error) {
|
|||||||
return dialogs, nil
|
return dialogs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
|
func resume(ctx context.Context, kvd kv.KV, iter *iter, ask bool) error {
|
||||||
logger.From(ctx).Debug("Check resume key",
|
logger.From(ctx).Debug("Check resume key",
|
||||||
zap.String("fingerprint", iter.Fingerprint()))
|
zap.String("fingerprint", iter.Fingerprint()))
|
||||||
|
|
||||||
@ -171,7 +170,7 @@ func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
confirm := false
|
confirm := false
|
||||||
resumeStr := fmt.Sprintf("Found unfinished download, continue from '%d/%d'", len(finished), iter.Total(ctx))
|
resumeStr := fmt.Sprintf("Found unfinished download, continue from '%d/%d'", len(finished), iter.Total())
|
||||||
if ask {
|
if ask {
|
||||||
if err = survey.AskOne(&survey.Confirm{
|
if err = survey.AskOne(&survey.Confirm{
|
||||||
Message: color.YellowString(resumeStr + "?"),
|
Message: color.YellowString(resumeStr + "?"),
|
||||||
@ -195,7 +194,7 @@ func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func saveProgress(ctx context.Context, kvd kv.KV, it *dliter.Iter) error {
|
func saveProgress(ctx context.Context, kvd kv.KV, it *iter) error {
|
||||||
finished := it.Finished()
|
finished := it.Finished()
|
||||||
logger.From(ctx).Debug("Save progress",
|
logger.From(ctx).Debug("Save progress",
|
||||||
zap.Int("finished", len(finished)))
|
zap.Int("finished", len(finished)))
|
||||||
|
38
app/dl/elem.go
Normal file
38
app/dl/elem.go
Normal file
@ -0,0 +1,38 @@
|
|||||||
|
package dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/gotd/td/telegram/peers"
|
||||||
|
"github.com/gotd/td/tg"
|
||||||
|
|
||||||
|
"github.com/iyear/tdl/pkg/downloader"
|
||||||
|
"github.com/iyear/tdl/pkg/tmedia"
|
||||||
|
)
|
||||||
|
|
||||||
|
type iterElem struct {
|
||||||
|
id int
|
||||||
|
|
||||||
|
from peers.Peer
|
||||||
|
fromMsg *tg.Message
|
||||||
|
file *tmedia.Media
|
||||||
|
|
||||||
|
to *os.File
|
||||||
|
|
||||||
|
opts Options
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iterElem) File() downloader.File { return i }
|
||||||
|
|
||||||
|
func (i *iterElem) To() io.WriterAt { return i.to }
|
||||||
|
|
||||||
|
func (i *iterElem) AsTakeout() bool { return i.opts.Takeout }
|
||||||
|
|
||||||
|
func (i *iterElem) Location() tg.InputFileLocationClass { return i.file.InputFileLoc }
|
||||||
|
|
||||||
|
func (i *iterElem) Name() string { return i.file.Name }
|
||||||
|
|
||||||
|
func (i *iterElem) Size() int64 { return i.file.Size }
|
||||||
|
|
||||||
|
func (i *iterElem) DC() int { return i.file.DC }
|
323
app/dl/iter.go
Normal file
323
app/dl/iter.go
Normal file
@ -0,0 +1,323 @@
|
|||||||
|
package dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/binary"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
"text/template"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-faster/errors"
|
||||||
|
"github.com/gotd/td/telegram/peers"
|
||||||
|
|
||||||
|
"github.com/iyear/tdl/pkg/dcpool"
|
||||||
|
"github.com/iyear/tdl/pkg/downloader"
|
||||||
|
"github.com/iyear/tdl/pkg/tmedia"
|
||||||
|
"github.com/iyear/tdl/pkg/tmessage"
|
||||||
|
"github.com/iyear/tdl/pkg/tplfunc"
|
||||||
|
"github.com/iyear/tdl/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
const tempExt = ".tmp"
|
||||||
|
|
||||||
|
type fileTemplate struct {
|
||||||
|
DialogID int64
|
||||||
|
MessageID int
|
||||||
|
MessageDate int64
|
||||||
|
FileName string
|
||||||
|
FileSize string
|
||||||
|
DownloadDate int64
|
||||||
|
}
|
||||||
|
|
||||||
|
type iter struct {
|
||||||
|
pool dcpool.Pool
|
||||||
|
manager *peers.Manager
|
||||||
|
dialogs []*tmessage.Dialog
|
||||||
|
tpl *template.Template
|
||||||
|
include map[string]struct{}
|
||||||
|
exclude map[string]struct{}
|
||||||
|
opts Options
|
||||||
|
|
||||||
|
mu *sync.Mutex
|
||||||
|
finished map[int]struct{}
|
||||||
|
fingerprint string
|
||||||
|
preSum []int
|
||||||
|
i, j int
|
||||||
|
elem downloader.Elem
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func newIter(pool dcpool.Pool, manager *peers.Manager, dialog [][]*tmessage.Dialog, opts Options) (*iter, error) {
|
||||||
|
tpl, err := template.New("dl").
|
||||||
|
Funcs(tplfunc.FuncMap(tplfunc.All...)).
|
||||||
|
Parse(opts.Template)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.Wrap(err, "parse template")
|
||||||
|
}
|
||||||
|
|
||||||
|
dialogs := flatDialogs(dialog)
|
||||||
|
// if msgs is empty, return error to avoid range out of index
|
||||||
|
if len(dialogs) == 0 {
|
||||||
|
return nil, errors.Errorf("you must specify at least one message")
|
||||||
|
}
|
||||||
|
|
||||||
|
// include and exclude
|
||||||
|
includeMap := filterMap(opts.Include, utils.FS.AddPrefixDot)
|
||||||
|
excludeMap := filterMap(opts.Exclude, utils.FS.AddPrefixDot)
|
||||||
|
|
||||||
|
// to keep fingerprint stable
|
||||||
|
sortDialogs(dialogs, opts.Desc)
|
||||||
|
|
||||||
|
return &iter{
|
||||||
|
pool: pool,
|
||||||
|
manager: manager,
|
||||||
|
dialogs: dialogs,
|
||||||
|
opts: opts,
|
||||||
|
include: includeMap,
|
||||||
|
exclude: excludeMap,
|
||||||
|
tpl: tpl,
|
||||||
|
|
||||||
|
mu: &sync.Mutex{},
|
||||||
|
finished: make(map[int]struct{}),
|
||||||
|
fingerprint: fingerprint(dialogs),
|
||||||
|
preSum: preSum(dialogs),
|
||||||
|
i: 0,
|
||||||
|
j: 0,
|
||||||
|
elem: nil,
|
||||||
|
err: nil,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Next(ctx context.Context) bool {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
i.err = ctx.Err()
|
||||||
|
return false
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
ok, skip := i.process(ctx)
|
||||||
|
if skip {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) process(ctx context.Context) (ret bool, skip bool) {
|
||||||
|
i.mu.Lock()
|
||||||
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if i.j++; i.i < len(i.dialogs) && i.j >= len(i.dialogs[i.i].Messages) {
|
||||||
|
i.i++
|
||||||
|
i.j = 0
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// end of iteration or error occurred
|
||||||
|
if i.i >= len(i.dialogs) || i.err != nil {
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, msg := i.dialogs[i.i].Peer, i.dialogs[i.i].Messages[i.j]
|
||||||
|
|
||||||
|
// check if finished
|
||||||
|
if _, ok := i.finished[i.ij2n(i.i, i.j)]; ok {
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
|
||||||
|
from, err := i.manager.FromInputPeer(ctx, peer)
|
||||||
|
if err != nil {
|
||||||
|
i.err = errors.Wrap(err, "resolve from input peer")
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
message, err := utils.Telegram.GetSingleMessage(ctx, i.pool.Default(ctx), peer, msg)
|
||||||
|
if err != nil {
|
||||||
|
i.err = errors.Wrap(err, "resolve message")
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
item, ok := tmedia.GetMedia(message)
|
||||||
|
if !ok {
|
||||||
|
i.err = errors.Errorf("can not get media from %d/%d message", from.ID(), message.ID)
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// process include and exclude
|
||||||
|
ext := filepath.Ext(item.Name)
|
||||||
|
if _, ok = i.include[ext]; len(i.include) > 0 && !ok {
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
if _, ok = i.exclude[ext]; len(i.exclude) > 0 && ok {
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
|
||||||
|
toName := bytes.Buffer{}
|
||||||
|
err = i.tpl.Execute(&toName, &fileTemplate{
|
||||||
|
DialogID: from.ID(),
|
||||||
|
MessageID: message.ID,
|
||||||
|
MessageDate: int64(message.Date),
|
||||||
|
FileName: item.Name,
|
||||||
|
FileSize: utils.Byte.FormatBinaryBytes(item.Size),
|
||||||
|
DownloadDate: time.Now().Unix(),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
i.err = errors.Wrap(err, "execute template")
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if i.opts.SkipSame {
|
||||||
|
if stat, err := os.Stat(filepath.Join(i.opts.Dir, toName.String())); err == nil {
|
||||||
|
if utils.FS.GetNameWithoutExt(toName.String()) == utils.FS.GetNameWithoutExt(stat.Name()) &&
|
||||||
|
stat.Size() == item.Size {
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
filename := fmt.Sprintf("%s%s", toName.String(), tempExt)
|
||||||
|
path := filepath.Join(i.opts.Dir, filename)
|
||||||
|
|
||||||
|
// #113. If path contains dirs, create it. So now we support nested dirs.
|
||||||
|
if err = os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||||
|
i.err = errors.Wrap(err, "create dir")
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
to, err := os.Create(path)
|
||||||
|
if err != nil {
|
||||||
|
i.err = errors.Wrap(err, "create file")
|
||||||
|
return false, false
|
||||||
|
}
|
||||||
|
|
||||||
|
i.elem = &iterElem{
|
||||||
|
id: i.ij2n(i.i, i.j),
|
||||||
|
|
||||||
|
from: from,
|
||||||
|
fromMsg: message,
|
||||||
|
file: item,
|
||||||
|
|
||||||
|
to: to,
|
||||||
|
|
||||||
|
opts: i.opts,
|
||||||
|
}
|
||||||
|
|
||||||
|
return true, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Value() downloader.Elem {
|
||||||
|
return i.elem
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Err() error {
|
||||||
|
return i.err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) SetFinished(finished map[int]struct{}) {
|
||||||
|
i.mu.Lock()
|
||||||
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
|
i.finished = finished
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Finished() map[int]struct{} {
|
||||||
|
i.mu.Lock()
|
||||||
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
|
return i.finished
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Fingerprint() string {
|
||||||
|
return i.fingerprint
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Finish(id int) {
|
||||||
|
i.mu.Lock()
|
||||||
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
|
i.finished[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) Total() int {
|
||||||
|
i.mu.Lock()
|
||||||
|
defer i.mu.Unlock()
|
||||||
|
|
||||||
|
total := 0
|
||||||
|
for _, m := range i.dialogs {
|
||||||
|
total += len(m.Messages)
|
||||||
|
}
|
||||||
|
return total
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *iter) ij2n(ii, jj int) int {
|
||||||
|
return i.preSum[ii] + jj
|
||||||
|
}
|
||||||
|
|
||||||
|
func flatDialogs(dialogs [][]*tmessage.Dialog) []*tmessage.Dialog {
|
||||||
|
res := make([]*tmessage.Dialog, 0)
|
||||||
|
for _, d := range dialogs {
|
||||||
|
if len(d) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
res = append(res, d...)
|
||||||
|
}
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|
||||||
|
func filterMap(data []string, keyFn func(key string) string) map[string]struct{} {
|
||||||
|
m := make(map[string]struct{})
|
||||||
|
for _, v := range data {
|
||||||
|
m[keyFn(v)] = struct{}{}
|
||||||
|
}
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
func sortDialogs(dialogs []*tmessage.Dialog, desc bool) {
|
||||||
|
sort.Slice(dialogs, func(i, j int) bool {
|
||||||
|
return utils.Telegram.GetInputPeerID(dialogs[i].Peer) <
|
||||||
|
utils.Telegram.GetInputPeerID(dialogs[j].Peer) // increasing order
|
||||||
|
})
|
||||||
|
|
||||||
|
for _, m := range dialogs {
|
||||||
|
sort.Slice(m.Messages, func(i, j int) bool {
|
||||||
|
if desc {
|
||||||
|
return m.Messages[i] > m.Messages[j]
|
||||||
|
}
|
||||||
|
return m.Messages[i] < m.Messages[j]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// preSum of dialogs
|
||||||
|
func preSum(dialogs []*tmessage.Dialog) []int {
|
||||||
|
sum := make([]int, len(dialogs)+1)
|
||||||
|
for i, m := range dialogs {
|
||||||
|
sum[i+1] = sum[i] + len(m.Messages)
|
||||||
|
}
|
||||||
|
return sum
|
||||||
|
}
|
||||||
|
|
||||||
|
func fingerprint(dialogs []*tmessage.Dialog) string {
|
||||||
|
endian := binary.BigEndian
|
||||||
|
buf, b := &bytes.Buffer{}, make([]byte, 8)
|
||||||
|
for _, m := range dialogs {
|
||||||
|
endian.PutUint64(b, uint64(utils.Telegram.GetInputPeerID(m.Peer)))
|
||||||
|
buf.Write(b)
|
||||||
|
for _, msg := range m.Messages {
|
||||||
|
endian.PutUint64(b, uint64(msg))
|
||||||
|
buf.Write(b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%x", sha256.Sum256(buf.Bytes()))
|
||||||
|
}
|
121
app/dl/progress.go
Normal file
121
app/dl/progress.go
Normal file
@ -0,0 +1,121 @@
|
|||||||
|
package dl
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fatih/color"
|
||||||
|
"github.com/gabriel-vasile/mimetype"
|
||||||
|
"github.com/go-faster/errors"
|
||||||
|
pw "github.com/jedib0t/go-pretty/v6/progress"
|
||||||
|
|
||||||
|
"github.com/iyear/tdl/pkg/downloader"
|
||||||
|
"github.com/iyear/tdl/pkg/prog"
|
||||||
|
"github.com/iyear/tdl/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
type progress struct {
|
||||||
|
pw pw.Writer
|
||||||
|
trackers *sync.Map // map[ID]*pw.Tracker
|
||||||
|
opts Options
|
||||||
|
|
||||||
|
it *iter
|
||||||
|
}
|
||||||
|
|
||||||
|
func newProgress(p pw.Writer, it *iter, opts Options) *progress {
|
||||||
|
return &progress{
|
||||||
|
pw: p,
|
||||||
|
trackers: &sync.Map{},
|
||||||
|
opts: opts,
|
||||||
|
it: it,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) OnAdd(elem downloader.Elem) {
|
||||||
|
tracker := prog.AppendTracker(p.pw, utils.Byte.FormatBinaryBytes, p.processMessage(elem), elem.File().Size())
|
||||||
|
p.trackers.Store(elem.(*iterElem).id, tracker)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) OnDownload(elem downloader.Elem, state downloader.ProgressState) {
|
||||||
|
tracker, ok := p.trackers.Load(elem.(*iterElem).id)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
t := tracker.(*pw.Tracker)
|
||||||
|
t.UpdateTotal(state.Total)
|
||||||
|
t.SetValue(state.Downloaded)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) OnDone(elem downloader.Elem, err error) {
|
||||||
|
e := elem.(*iterElem)
|
||||||
|
|
||||||
|
tracker, ok := p.trackers.Load(e.id)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
t := tracker.(*pw.Tracker)
|
||||||
|
|
||||||
|
if err := e.to.Close(); err != nil {
|
||||||
|
p.fail(t, elem, errors.Wrap(err, "close file"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, context.Canceled) { // don't report user cancel
|
||||||
|
p.fail(t, elem, errors.Wrap(err, "progress"))
|
||||||
|
}
|
||||||
|
_ = os.Remove(e.to.Name()) // just try to remove temp file, ignore error
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.it.Finish(e.id)
|
||||||
|
|
||||||
|
if err := p.donePost(e); err != nil {
|
||||||
|
p.fail(t, elem, errors.Wrap(err, "post file"))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) donePost(elem *iterElem) error {
|
||||||
|
newfile := strings.TrimSuffix(filepath.Base(elem.to.Name()), tempExt)
|
||||||
|
|
||||||
|
if p.opts.RewriteExt {
|
||||||
|
mime, err := mimetype.DetectFile(elem.to.Name())
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "detect mime")
|
||||||
|
}
|
||||||
|
ext := mime.Extension()
|
||||||
|
if ext != "" && (filepath.Ext(newfile) != ext) {
|
||||||
|
newfile = utils.FS.GetNameWithoutExt(newfile) + ext
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := os.Rename(elem.to.Name(), filepath.Join(p.opts.Dir, newfile)); err != nil {
|
||||||
|
return errors.Wrap(err, "rename file")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) fail(t *pw.Tracker, elem downloader.Elem, err error) {
|
||||||
|
p.pw.Log(color.RedString("%s error: %s", p.elemString(elem), err.Error()))
|
||||||
|
t.MarkAsErrored()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) processMessage(elem downloader.Elem) string {
|
||||||
|
return p.elemString(elem)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *progress) elemString(elem downloader.Elem) string {
|
||||||
|
e := elem.(*iterElem)
|
||||||
|
return fmt.Sprintf("%s(%d):%d -> %s",
|
||||||
|
e.from.VisibleName(),
|
||||||
|
e.from.ID(),
|
||||||
|
e.fromMsg.ID,
|
||||||
|
strings.TrimSuffix(e.to.Name(), tempExt))
|
||||||
|
}
|
@ -1,174 +0,0 @@
|
|||||||
package dliter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"path/filepath"
|
|
||||||
"text/template"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-faster/errors"
|
|
||||||
"github.com/gotd/td/telegram/peers"
|
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/downloader"
|
|
||||||
"github.com/iyear/tdl/pkg/storage"
|
|
||||||
"github.com/iyear/tdl/pkg/tmedia"
|
|
||||||
"github.com/iyear/tdl/pkg/tplfunc"
|
|
||||||
"github.com/iyear/tdl/pkg/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func New(ctx context.Context, opts *Options) (*Iter, error) {
|
|
||||||
tpl, err := template.New("dl").
|
|
||||||
Funcs(tplfunc.FuncMap(tplfunc.All...)).
|
|
||||||
Parse(opts.Template)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dialogs := collectDialogs(opts.Dialogs)
|
|
||||||
// if msgs is empty, return error to avoid range out of index
|
|
||||||
if len(dialogs) == 0 {
|
|
||||||
return nil, fmt.Errorf("you must specify at least one message")
|
|
||||||
}
|
|
||||||
|
|
||||||
// include and exclude
|
|
||||||
includeMap := filterMap(opts.Include, utils.FS.AddPrefixDot)
|
|
||||||
excludeMap := filterMap(opts.Exclude, utils.FS.AddPrefixDot)
|
|
||||||
|
|
||||||
// to keep fingerprint stable
|
|
||||||
sortDialogs(dialogs, opts.Desc)
|
|
||||||
|
|
||||||
manager := peers.Options{Storage: storage.NewPeers(opts.KV)}.Build(opts.Pool.Default(ctx))
|
|
||||||
it := &Iter{
|
|
||||||
pool: opts.Pool,
|
|
||||||
dialogs: dialogs,
|
|
||||||
include: includeMap,
|
|
||||||
exclude: excludeMap,
|
|
||||||
curi: 0,
|
|
||||||
curj: -1,
|
|
||||||
preSum: preSum(dialogs),
|
|
||||||
finished: make(map[int]struct{}),
|
|
||||||
template: tpl,
|
|
||||||
manager: manager,
|
|
||||||
fingerprint: fingerprint(dialogs),
|
|
||||||
}
|
|
||||||
|
|
||||||
return it, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) Next(ctx context.Context) (*downloader.Item, error) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
iter.mu.Lock()
|
|
||||||
iter.curj++
|
|
||||||
if iter.curj >= len(iter.dialogs[iter.curi].Messages) {
|
|
||||||
if iter.curi++; iter.curi >= len(iter.dialogs) {
|
|
||||||
return nil, errors.New("no more items")
|
|
||||||
}
|
|
||||||
iter.curj = 0
|
|
||||||
}
|
|
||||||
i, j := iter.curi, iter.curj
|
|
||||||
iter.mu.Unlock()
|
|
||||||
|
|
||||||
// check if finished
|
|
||||||
if _, ok := iter.finished[iter.ij2n(i, j)]; ok {
|
|
||||||
return nil, downloader.ErrSkip
|
|
||||||
}
|
|
||||||
|
|
||||||
return iter.item(ctx, i, j)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) item(ctx context.Context, i, j int) (*downloader.Item, error) {
|
|
||||||
peer, msg := iter.dialogs[i].Peer, iter.dialogs[i].Messages[j]
|
|
||||||
|
|
||||||
id := utils.Telegram.GetInputPeerID(peer)
|
|
||||||
|
|
||||||
message, err := utils.Telegram.GetSingleMessage(ctx, iter.pool.Default(ctx), peer, msg)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.Wrap(err, "resolve message")
|
|
||||||
}
|
|
||||||
|
|
||||||
item, ok := tmedia.GetMedia(message)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("can not get media from %d/%d message",
|
|
||||||
id, message.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// process include and exclude
|
|
||||||
ext := filepath.Ext(item.Name)
|
|
||||||
if len(iter.include) > 0 {
|
|
||||||
if _, ok = iter.include[ext]; !ok {
|
|
||||||
return nil, downloader.ErrSkip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(iter.exclude) > 0 {
|
|
||||||
if _, ok = iter.exclude[ext]; ok {
|
|
||||||
return nil, downloader.ErrSkip
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := bytes.Buffer{}
|
|
||||||
err = iter.template.Execute(&buf, &fileTemplate{
|
|
||||||
DialogID: id,
|
|
||||||
MessageID: message.ID,
|
|
||||||
MessageDate: int64(message.Date),
|
|
||||||
FileName: item.Name,
|
|
||||||
FileSize: utils.Byte.FormatBinaryBytes(item.Size),
|
|
||||||
DownloadDate: time.Now().Unix(),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
item.Name = buf.String()
|
|
||||||
|
|
||||||
return &downloader.Item{
|
|
||||||
ID: iter.ij2n(i, j),
|
|
||||||
Media: item,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) Finish(_ context.Context, id int) error {
|
|
||||||
iter.mu.Lock()
|
|
||||||
defer iter.mu.Unlock()
|
|
||||||
|
|
||||||
iter.finished[id] = struct{}{}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) Total(_ context.Context) int {
|
|
||||||
iter.mu.Lock()
|
|
||||||
defer iter.mu.Unlock()
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
for _, m := range iter.dialogs {
|
|
||||||
total += len(m.Messages)
|
|
||||||
}
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) ij2n(i, j int) int {
|
|
||||||
return iter.preSum[i] + j
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) SetFinished(finished map[int]struct{}) {
|
|
||||||
iter.mu.Lock()
|
|
||||||
defer iter.mu.Unlock()
|
|
||||||
|
|
||||||
iter.finished = finished
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) Finished() map[int]struct{} {
|
|
||||||
iter.mu.Lock()
|
|
||||||
defer iter.mu.Unlock()
|
|
||||||
|
|
||||||
return iter.finished
|
|
||||||
}
|
|
||||||
|
|
||||||
func (iter *Iter) Fingerprint() string {
|
|
||||||
return iter.fingerprint
|
|
||||||
}
|
|
@ -1,71 +0,0 @@
|
|||||||
package dliter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/tmessage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPreSum(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
dialogs []*tmessage.Dialog
|
|
||||||
want []int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2}}},
|
|
||||||
want: []int{0, 3, 5},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}},
|
|
||||||
want: []int{0, 3, 6, 10},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}, {Messages: []int{1}}},
|
|
||||||
want: []int{0, 3, 6, 10, 11},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
got := preSum(tt.dialogs)
|
|
||||||
if !reflect.DeepEqual(got, tt.want) {
|
|
||||||
t.Errorf("preSum() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIter_ij2n(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
dialogs []*tmessage.Dialog
|
|
||||||
input []struct {
|
|
||||||
i, j int
|
|
||||||
}
|
|
||||||
want []int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2}}},
|
|
||||||
input: []struct {
|
|
||||||
i, j int
|
|
||||||
}{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}},
|
|
||||||
want: []int{0, 1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}},
|
|
||||||
input: []struct {
|
|
||||||
i, j int
|
|
||||||
}{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}, {2, 0}, {2, 1}, {2, 2}, {2, 3}},
|
|
||||||
want: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
iter := &Iter{preSum: preSum(tt.dialogs), dialogs: tt.dialogs}
|
|
||||||
|
|
||||||
for i, input := range tt.input {
|
|
||||||
got := iter.ij2n(input.i, input.j)
|
|
||||||
if got != tt.want[i] {
|
|
||||||
t.Errorf("ij2n(%v, %v) = %v, want %v", input.i, input.j, got, tt.want[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,44 +0,0 @@
|
|||||||
package dliter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
"text/template"
|
|
||||||
|
|
||||||
"github.com/gotd/td/telegram/peers"
|
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/dcpool"
|
|
||||||
"github.com/iyear/tdl/pkg/kv"
|
|
||||||
"github.com/iyear/tdl/pkg/tmessage"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Options struct {
|
|
||||||
Pool dcpool.Pool
|
|
||||||
KV kv.KV
|
|
||||||
Template string
|
|
||||||
Include, Exclude []string
|
|
||||||
Desc bool
|
|
||||||
Dialogs [][]*tmessage.Dialog
|
|
||||||
}
|
|
||||||
|
|
||||||
type Iter struct {
|
|
||||||
pool dcpool.Pool
|
|
||||||
dialogs []*tmessage.Dialog
|
|
||||||
include, exclude map[string]struct{}
|
|
||||||
mu sync.Mutex
|
|
||||||
curi int
|
|
||||||
curj int
|
|
||||||
preSum []int
|
|
||||||
finished map[int]struct{}
|
|
||||||
template *template.Template
|
|
||||||
manager *peers.Manager
|
|
||||||
fingerprint string
|
|
||||||
}
|
|
||||||
|
|
||||||
type fileTemplate struct {
|
|
||||||
DialogID int64
|
|
||||||
MessageID int
|
|
||||||
MessageDate int64
|
|
||||||
FileName string
|
|
||||||
FileSize string
|
|
||||||
DownloadDate int64
|
|
||||||
}
|
|
@ -1,71 +0,0 @@
|
|||||||
package dliter
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/tmessage"
|
|
||||||
"github.com/iyear/tdl/pkg/utils"
|
|
||||||
)
|
|
||||||
|
|
||||||
func sortDialogs(dialogs []*tmessage.Dialog, desc bool) {
|
|
||||||
sort.Slice(dialogs, func(i, j int) bool {
|
|
||||||
return utils.Telegram.GetInputPeerID(dialogs[i].Peer) <
|
|
||||||
utils.Telegram.GetInputPeerID(dialogs[j].Peer) // increasing order
|
|
||||||
})
|
|
||||||
|
|
||||||
for _, m := range dialogs {
|
|
||||||
sort.Slice(m.Messages, func(i, j int) bool {
|
|
||||||
if desc {
|
|
||||||
return m.Messages[i] > m.Messages[j]
|
|
||||||
}
|
|
||||||
return m.Messages[i] < m.Messages[j]
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fingerprint(dialogs []*tmessage.Dialog) string {
|
|
||||||
endian := binary.BigEndian
|
|
||||||
buf, b := &bytes.Buffer{}, make([]byte, 8)
|
|
||||||
for _, m := range dialogs {
|
|
||||||
endian.PutUint64(b, uint64(utils.Telegram.GetInputPeerID(m.Peer)))
|
|
||||||
buf.Write(b)
|
|
||||||
for _, msg := range m.Messages {
|
|
||||||
endian.PutUint64(b, uint64(msg))
|
|
||||||
buf.Write(b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("%x", sha256.Sum256(buf.Bytes()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func filterMap(data []string, keyFn func(key string) string) map[string]struct{} {
|
|
||||||
m := make(map[string]struct{})
|
|
||||||
for _, v := range data {
|
|
||||||
m[keyFn(v)] = struct{}{}
|
|
||||||
}
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
func collectDialogs(dialogs [][]*tmessage.Dialog) []*tmessage.Dialog {
|
|
||||||
res := make([]*tmessage.Dialog, 0)
|
|
||||||
for _, d := range dialogs {
|
|
||||||
if len(d) == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
res = append(res, d...)
|
|
||||||
}
|
|
||||||
return res
|
|
||||||
}
|
|
||||||
|
|
||||||
// preSum of dialogs
|
|
||||||
func preSum(dialogs []*tmessage.Dialog) []int {
|
|
||||||
sum := make([]int, len(dialogs)+1)
|
|
||||||
for i, m := range dialogs {
|
|
||||||
sum[i+1] = sum[i] + len(m.Messages)
|
|
||||||
}
|
|
||||||
return sum
|
|
||||||
}
|
|
@ -25,7 +25,7 @@ func NewDownload() *cobra.Command {
|
|||||||
}
|
}
|
||||||
|
|
||||||
opts.Template = viper.GetString(consts.FlagDlTemplate)
|
opts.Template = viper.GetString(consts.FlagDlTemplate)
|
||||||
return dl.Run(logger.Named(cmd.Context(), "dl"), &opts)
|
return dl.Run(logger.Named(cmd.Context(), "dl"), opts)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,168 +2,89 @@ package downloader
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/fatih/color"
|
|
||||||
"github.com/gabriel-vasile/mimetype"
|
|
||||||
"github.com/go-faster/errors"
|
"github.com/go-faster/errors"
|
||||||
"github.com/gotd/td/telegram/downloader"
|
"github.com/gotd/td/telegram/downloader"
|
||||||
"github.com/jedib0t/go-pretty/v6/progress"
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/dcpool"
|
"github.com/iyear/tdl/pkg/dcpool"
|
||||||
"github.com/iyear/tdl/pkg/logger"
|
"github.com/iyear/tdl/pkg/logger"
|
||||||
"github.com/iyear/tdl/pkg/prog"
|
|
||||||
"github.com/iyear/tdl/pkg/utils"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const TempExt = ".tmp"
|
|
||||||
|
|
||||||
var formatter = utils.Byte.FormatBinaryBytes
|
|
||||||
|
|
||||||
type Downloader struct {
|
type Downloader struct {
|
||||||
pw progress.Writer
|
|
||||||
opts Options
|
opts Options
|
||||||
}
|
}
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Pool dcpool.Pool
|
Pool dcpool.Pool
|
||||||
Dir string
|
PartSize int
|
||||||
RewriteExt bool
|
Threads int
|
||||||
SkipSame bool
|
Iter Iter
|
||||||
PartSize int
|
Progress Progress
|
||||||
Threads int
|
|
||||||
Iter Iter
|
|
||||||
Takeout bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(opts Options) (*Downloader, error) {
|
func New(opts Options) *Downloader {
|
||||||
return &Downloader{
|
return &Downloader{
|
||||||
pw: prog.New(formatter),
|
|
||||||
opts: opts,
|
opts: opts,
|
||||||
}, nil
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Downloader) Download(ctx context.Context, limit int) error {
|
func (d *Downloader) Download(ctx context.Context, limit int) error {
|
||||||
color.Green("All files will be downloaded to '%s' dir", d.opts.Dir)
|
wg, wgctx := errgroup.WithContext(ctx)
|
||||||
|
|
||||||
total := d.opts.Iter.Total(ctx)
|
|
||||||
d.pw.SetNumTrackersExpected(total)
|
|
||||||
|
|
||||||
go d.renderPinned(ctx, d.pw)
|
|
||||||
go d.pw.Render()
|
|
||||||
|
|
||||||
wg, errctx := errgroup.WithContext(ctx)
|
|
||||||
wg.SetLimit(limit)
|
wg.SetLimit(limit)
|
||||||
|
|
||||||
for i := 0; i < total; i++ {
|
for d.opts.Iter.Next(wgctx) {
|
||||||
item, err := d.opts.Iter.Next(errctx)
|
elem := d.opts.Iter.Value()
|
||||||
if err != nil {
|
|
||||||
logger.From(errctx).Debug("Iter next failed",
|
|
||||||
zap.Int("index", i), zap.String("error", err.Error()))
|
|
||||||
// skip error means we don't need to log error
|
|
||||||
if !errors.Is(err, ErrSkip) && !errors.Is(err, context.Canceled) {
|
|
||||||
d.pw.Log(color.RedString("failed: %v", err))
|
|
||||||
}
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
wg.Go(func() error {
|
wg.Go(func() (rerr error) {
|
||||||
return d.download(errctx, item)
|
d.opts.Progress.OnAdd(elem)
|
||||||
|
defer func() { d.opts.Progress.OnDone(elem, rerr) }()
|
||||||
|
|
||||||
|
if err := d.download(wgctx, elem); err != nil {
|
||||||
|
// canceled by user, so we directly return error to stop all
|
||||||
|
if errors.Is(err, context.Canceled) {
|
||||||
|
return errors.Wrap(err, "download")
|
||||||
|
}
|
||||||
|
|
||||||
|
// don't return error, just log it
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
err := wg.Wait()
|
if err := d.opts.Iter.Err(); err != nil {
|
||||||
if err != nil {
|
return errors.Wrap(err, "iter")
|
||||||
d.pw.Stop()
|
|
||||||
for d.pw.IsRenderInProgress() {
|
|
||||||
time.Sleep(time.Millisecond * 10)
|
|
||||||
}
|
|
||||||
|
|
||||||
// canceled error is ignored by gotd, so we can't detect it in main entry
|
|
||||||
if errors.Is(err, context.Canceled) {
|
|
||||||
color.Red("Download aborted by user")
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
prog.Wait(ctx, d.pw)
|
return wg.Wait()
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Downloader) download(ctx context.Context, item *Item) error {
|
func (d *Downloader) download(ctx context.Context, elem Elem) error {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return ctx.Err()
|
return ctx.Err()
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.From(ctx).Debug("Start download item",
|
logger.From(ctx).Debug("Start download elem",
|
||||||
zap.Any("item", item))
|
zap.Any("elem", elem))
|
||||||
|
|
||||||
if d.opts.SkipSame {
|
client := d.opts.Pool.Client(ctx, elem.File().DC())
|
||||||
if stat, err := os.Stat(filepath.Join(d.opts.Dir, item.Name)); err == nil {
|
if elem.AsTakeout() {
|
||||||
if utils.FS.GetNameWithoutExt(item.Name) == utils.FS.GetNameWithoutExt(stat.Name()) &&
|
client = d.opts.Pool.Takeout(ctx, elem.File().DC())
|
||||||
stat.Size() == item.Size {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
tracker := prog.AppendTracker(d.pw, formatter, item.Name, item.Size)
|
|
||||||
filename := fmt.Sprintf("%s%s", item.Name, TempExt)
|
|
||||||
path := filepath.Join(d.opts.Dir, filename)
|
|
||||||
|
|
||||||
// #113. If path contains dirs, create it. So now we support nested dirs.
|
|
||||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := os.Create(path)
|
_, err := downloader.NewDownloader().WithPartSize(d.opts.PartSize).
|
||||||
|
Download(client, elem.File().Location()).
|
||||||
|
WithThreads(d.bestThreads(elem.File().Size())).
|
||||||
|
Parallel(ctx, newWriteAt(elem, d.opts.Progress, d.opts.PartSize))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return errors.Wrap(err, "download")
|
||||||
}
|
}
|
||||||
|
|
||||||
client := d.opts.Pool.Client(ctx, item.DC)
|
return nil
|
||||||
if d.opts.Takeout {
|
|
||||||
client = d.opts.Pool.Takeout(ctx, item.DC)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = downloader.NewDownloader().WithPartSize(d.opts.PartSize).
|
|
||||||
Download(client, item.InputFileLoc).
|
|
||||||
WithThreads(d.bestThreads(item.Size)).
|
|
||||||
Parallel(ctx, newWriteAt(f, tracker, d.opts.PartSize))
|
|
||||||
if err := f.Close(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// rename file, remove temp extension and add real extension
|
|
||||||
newfile := strings.TrimSuffix(filename, TempExt)
|
|
||||||
|
|
||||||
if d.opts.RewriteExt {
|
|
||||||
mime, err := mimetype.DetectFile(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
ext := mime.Extension()
|
|
||||||
if ext != "" && (filepath.Ext(newfile) != ext) {
|
|
||||||
newfile = utils.FS.GetNameWithoutExt(newfile) + ext
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = os.Rename(path, filepath.Join(d.opts.Dir, newfile)); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return d.opts.Iter.Finish(ctx, item.ID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// threads level
|
// threads level
|
||||||
|
@ -2,20 +2,26 @@ package downloader
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"io"
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/tmedia"
|
"github.com/gotd/td/tg"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrSkip = errors.New("skip")
|
|
||||||
|
|
||||||
type Iter interface {
|
type Iter interface {
|
||||||
Next(ctx context.Context) (*Item, error)
|
Next(ctx context.Context) bool
|
||||||
Finish(ctx context.Context, id int) error
|
Value() Elem
|
||||||
Total(ctx context.Context) int
|
Err() error
|
||||||
}
|
}
|
||||||
|
|
||||||
type Item struct {
|
type Elem interface {
|
||||||
ID int // unique in iter
|
File() File
|
||||||
*tmedia.Media
|
To() io.WriterAt
|
||||||
|
|
||||||
|
AsTakeout() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type File interface {
|
||||||
|
Location() tg.InputFileLocationClass
|
||||||
|
Size() int64
|
||||||
|
DC() int
|
||||||
}
|
}
|
||||||
|
57
pkg/downloader/progress.go
Normal file
57
pkg/downloader/progress.go
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
package downloader
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"go.uber.org/atomic"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Progress interface {
|
||||||
|
OnAdd(elem Elem)
|
||||||
|
OnDownload(elem Elem, state ProgressState)
|
||||||
|
OnDone(elem Elem, err error)
|
||||||
|
// TODO: OnLog to log something that is not an error but should be sent to the user
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProgressState struct {
|
||||||
|
Downloaded int64
|
||||||
|
Total int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeAt wrapper for file to use progress bar
|
||||||
|
//
|
||||||
|
// do not need mutex because gotd has use syncio.WriteAt
|
||||||
|
type writeAt struct {
|
||||||
|
elem Elem
|
||||||
|
progress Progress
|
||||||
|
partSize int
|
||||||
|
|
||||||
|
downloaded *atomic.Int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func newWriteAt(elem Elem, progress Progress, partSize int) *writeAt {
|
||||||
|
return &writeAt{
|
||||||
|
elem: elem,
|
||||||
|
progress: progress,
|
||||||
|
partSize: partSize,
|
||||||
|
downloaded: atomic.NewInt64(0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *writeAt) WriteAt(p []byte, off int64) (int, error) {
|
||||||
|
at, err := w.elem.To().WriteAt(p, off)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// some small files may finish too fast, terminal history may not be overwritten
|
||||||
|
// this is just a simple way to avoid the problem
|
||||||
|
if at < w.partSize { // last part(every file only exec once)
|
||||||
|
time.Sleep(time.Millisecond * 200) // to ensure the progress render next time
|
||||||
|
}
|
||||||
|
w.progress.OnDownload(w.elem, ProgressState{
|
||||||
|
Downloaded: w.downloaded.Add(int64(at)),
|
||||||
|
Total: w.elem.File().Size(),
|
||||||
|
})
|
||||||
|
return at, nil
|
||||||
|
}
|
@ -1,29 +0,0 @@
|
|||||||
package downloader
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jedib0t/go-pretty/v6/progress"
|
|
||||||
|
|
||||||
"github.com/iyear/tdl/pkg/ps"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (d *Downloader) renderPinned(ctx context.Context, pw progress.Writer) {
|
|
||||||
f := func() { pw.SetPinnedMessages(strings.Join(ps.Humanize(ctx), " ")) }
|
|
||||||
f()
|
|
||||||
|
|
||||||
ticker := time.NewTicker(time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
pw.SetPinnedMessages()
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
f()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,41 +0,0 @@
|
|||||||
package downloader
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/jedib0t/go-pretty/v6/progress"
|
|
||||||
)
|
|
||||||
|
|
||||||
// writeAt wrapper for file to use progress bar
|
|
||||||
//
|
|
||||||
// do not need mutex because gotd has use syncio.WriteAt
|
|
||||||
type writeAt struct {
|
|
||||||
f *os.File
|
|
||||||
tracker *progress.Tracker
|
|
||||||
partSize int
|
|
||||||
}
|
|
||||||
|
|
||||||
func newWriteAt(f *os.File, tracker *progress.Tracker, partSize int) *writeAt {
|
|
||||||
return &writeAt{
|
|
||||||
f: f,
|
|
||||||
tracker: tracker,
|
|
||||||
partSize: partSize,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *writeAt) WriteAt(p []byte, off int64) (int, error) {
|
|
||||||
at, err := w.f.WriteAt(p, off)
|
|
||||||
if err != nil {
|
|
||||||
w.tracker.MarkAsErrored()
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// some small files may finish too fast, terminal history may not be overwritten
|
|
||||||
// this is just a simple way to avoid the problem
|
|
||||||
if at < w.partSize { // last part(every file only exec once)
|
|
||||||
time.Sleep(time.Millisecond * 200) // to ensure the progress render next time
|
|
||||||
}
|
|
||||||
w.tracker.Increment(int64(at))
|
|
||||||
return at, nil
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user