mirror of
https://github.com/iyear/tdl
synced 2025-01-09 04:17:35 +08:00
refactor(dl): extract to interface and impl iter in app
This commit is contained in:
parent
b7bbbf60f1
commit
169e47913a
59
app/dl/dl.go
59
app/dl/dl.go
@ -8,11 +8,11 @@ import (
|
||||
"github.com/AlecAivazis/survey/v2"
|
||||
"github.com/fatih/color"
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/telegram/peers"
|
||||
"github.com/spf13/viper"
|
||||
"go.uber.org/multierr"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/iyear/tdl/app/internal/dliter"
|
||||
"github.com/iyear/tdl/app/internal/tgc"
|
||||
"github.com/iyear/tdl/pkg/consts"
|
||||
"github.com/iyear/tdl/pkg/dcpool"
|
||||
@ -20,7 +20,10 @@ import (
|
||||
"github.com/iyear/tdl/pkg/key"
|
||||
"github.com/iyear/tdl/pkg/kv"
|
||||
"github.com/iyear/tdl/pkg/logger"
|
||||
"github.com/iyear/tdl/pkg/prog"
|
||||
"github.com/iyear/tdl/pkg/storage"
|
||||
"github.com/iyear/tdl/pkg/tmessage"
|
||||
"github.com/iyear/tdl/pkg/utils"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
@ -48,7 +51,7 @@ type parser struct {
|
||||
Parser tmessage.ParseSource
|
||||
}
|
||||
|
||||
func Run(ctx context.Context, opts *Options) error {
|
||||
func Run(ctx context.Context, opts Options) error {
|
||||
c, kvd, err := tgc.NoLogin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -78,22 +81,16 @@ func Run(ctx context.Context, opts *Options) error {
|
||||
return serve(ctx, kvd, pool, dialogs, opts.Port, opts.Takeout)
|
||||
}
|
||||
|
||||
iter, err := dliter.New(ctx, &dliter.Options{
|
||||
Pool: pool,
|
||||
KV: kvd,
|
||||
Template: opts.Template,
|
||||
Include: opts.Include,
|
||||
Exclude: opts.Exclude,
|
||||
Desc: opts.Desc,
|
||||
Dialogs: dialogs,
|
||||
})
|
||||
manager := peers.Options{Storage: storage.NewPeers(kvd)}.Build(pool.Default(ctx))
|
||||
|
||||
it, err := newIter(pool, manager, dialogs, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !opts.Restart {
|
||||
// resume download and ask user to continue
|
||||
if err = resume(ctx, kvd, iter, !opts.Continue); err != nil {
|
||||
if err = resume(ctx, kvd, it, !opts.Continue); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
@ -102,37 +99,39 @@ func Run(ctx context.Context, opts *Options) error {
|
||||
|
||||
defer func() { // save progress
|
||||
if rerr != nil { // download is interrupted
|
||||
multierr.AppendInto(&rerr, saveProgress(ctx, kvd, iter))
|
||||
multierr.AppendInto(&rerr, saveProgress(ctx, kvd, it))
|
||||
} else { // if finished, we should clear resume key
|
||||
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(iter.Fingerprint())))
|
||||
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(it.Fingerprint())))
|
||||
}
|
||||
}()
|
||||
|
||||
dlProgress := prog.New(utils.Byte.FormatBinaryBytes)
|
||||
dlProgress.SetNumTrackersExpected(it.Total())
|
||||
prog.EnablePS(ctx, dlProgress)
|
||||
|
||||
options := downloader.Options{
|
||||
Pool: pool,
|
||||
Dir: opts.Dir,
|
||||
RewriteExt: opts.RewriteExt,
|
||||
SkipSame: opts.SkipSame,
|
||||
PartSize: viper.GetInt(consts.FlagPartSize),
|
||||
Threads: viper.GetInt(consts.FlagThreads),
|
||||
Iter: iter,
|
||||
Takeout: opts.Takeout,
|
||||
Iter: it,
|
||||
Progress: newProgress(dlProgress, it, opts),
|
||||
}
|
||||
limit := viper.GetInt(consts.FlagLimit)
|
||||
|
||||
logger.From(ctx).Info("Start download",
|
||||
zap.String("dir", options.Dir),
|
||||
zap.Bool("rewrite_ext", options.RewriteExt),
|
||||
zap.Bool("skip_same", options.SkipSame),
|
||||
zap.String("dir", opts.Dir),
|
||||
zap.Bool("rewrite_ext", opts.RewriteExt),
|
||||
zap.Bool("skip_same", opts.SkipSame),
|
||||
zap.Int("part_size", options.PartSize),
|
||||
zap.Int("threads", options.Threads),
|
||||
zap.Int("limit", limit))
|
||||
|
||||
dl, err := downloader.New(options)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create downloader")
|
||||
}
|
||||
return dl.Download(ctx, limit)
|
||||
color.Green("All files will be downloaded to '%s' dir", opts.Dir)
|
||||
|
||||
go dlProgress.Render()
|
||||
defer prog.Wait(ctx, dlProgress)
|
||||
|
||||
return downloader.New(options).Download(ctx, limit)
|
||||
})
|
||||
}
|
||||
|
||||
@ -148,7 +147,7 @@ func collectDialogs(parsers []parser) ([][]*tmessage.Dialog, error) {
|
||||
return dialogs, nil
|
||||
}
|
||||
|
||||
func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
|
||||
func resume(ctx context.Context, kvd kv.KV, iter *iter, ask bool) error {
|
||||
logger.From(ctx).Debug("Check resume key",
|
||||
zap.String("fingerprint", iter.Fingerprint()))
|
||||
|
||||
@ -171,7 +170,7 @@ func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
|
||||
}
|
||||
|
||||
confirm := false
|
||||
resumeStr := fmt.Sprintf("Found unfinished download, continue from '%d/%d'", len(finished), iter.Total(ctx))
|
||||
resumeStr := fmt.Sprintf("Found unfinished download, continue from '%d/%d'", len(finished), iter.Total())
|
||||
if ask {
|
||||
if err = survey.AskOne(&survey.Confirm{
|
||||
Message: color.YellowString(resumeStr + "?"),
|
||||
@ -195,7 +194,7 @@ func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func saveProgress(ctx context.Context, kvd kv.KV, it *dliter.Iter) error {
|
||||
func saveProgress(ctx context.Context, kvd kv.KV, it *iter) error {
|
||||
finished := it.Finished()
|
||||
logger.From(ctx).Debug("Save progress",
|
||||
zap.Int("finished", len(finished)))
|
||||
|
38
app/dl/elem.go
Normal file
38
app/dl/elem.go
Normal file
@ -0,0 +1,38 @@
|
||||
package dl
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/gotd/td/telegram/peers"
|
||||
"github.com/gotd/td/tg"
|
||||
|
||||
"github.com/iyear/tdl/pkg/downloader"
|
||||
"github.com/iyear/tdl/pkg/tmedia"
|
||||
)
|
||||
|
||||
type iterElem struct {
|
||||
id int
|
||||
|
||||
from peers.Peer
|
||||
fromMsg *tg.Message
|
||||
file *tmedia.Media
|
||||
|
||||
to *os.File
|
||||
|
||||
opts Options
|
||||
}
|
||||
|
||||
func (i *iterElem) File() downloader.File { return i }
|
||||
|
||||
func (i *iterElem) To() io.WriterAt { return i.to }
|
||||
|
||||
func (i *iterElem) AsTakeout() bool { return i.opts.Takeout }
|
||||
|
||||
func (i *iterElem) Location() tg.InputFileLocationClass { return i.file.InputFileLoc }
|
||||
|
||||
func (i *iterElem) Name() string { return i.file.Name }
|
||||
|
||||
func (i *iterElem) Size() int64 { return i.file.Size }
|
||||
|
||||
func (i *iterElem) DC() int { return i.file.DC }
|
323
app/dl/iter.go
Normal file
323
app/dl/iter.go
Normal file
@ -0,0 +1,323 @@
|
||||
package dl
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"sync"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/telegram/peers"
|
||||
|
||||
"github.com/iyear/tdl/pkg/dcpool"
|
||||
"github.com/iyear/tdl/pkg/downloader"
|
||||
"github.com/iyear/tdl/pkg/tmedia"
|
||||
"github.com/iyear/tdl/pkg/tmessage"
|
||||
"github.com/iyear/tdl/pkg/tplfunc"
|
||||
"github.com/iyear/tdl/pkg/utils"
|
||||
)
|
||||
|
||||
const tempExt = ".tmp"
|
||||
|
||||
type fileTemplate struct {
|
||||
DialogID int64
|
||||
MessageID int
|
||||
MessageDate int64
|
||||
FileName string
|
||||
FileSize string
|
||||
DownloadDate int64
|
||||
}
|
||||
|
||||
type iter struct {
|
||||
pool dcpool.Pool
|
||||
manager *peers.Manager
|
||||
dialogs []*tmessage.Dialog
|
||||
tpl *template.Template
|
||||
include map[string]struct{}
|
||||
exclude map[string]struct{}
|
||||
opts Options
|
||||
|
||||
mu *sync.Mutex
|
||||
finished map[int]struct{}
|
||||
fingerprint string
|
||||
preSum []int
|
||||
i, j int
|
||||
elem downloader.Elem
|
||||
err error
|
||||
}
|
||||
|
||||
func newIter(pool dcpool.Pool, manager *peers.Manager, dialog [][]*tmessage.Dialog, opts Options) (*iter, error) {
|
||||
tpl, err := template.New("dl").
|
||||
Funcs(tplfunc.FuncMap(tplfunc.All...)).
|
||||
Parse(opts.Template)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parse template")
|
||||
}
|
||||
|
||||
dialogs := flatDialogs(dialog)
|
||||
// if msgs is empty, return error to avoid range out of index
|
||||
if len(dialogs) == 0 {
|
||||
return nil, errors.Errorf("you must specify at least one message")
|
||||
}
|
||||
|
||||
// include and exclude
|
||||
includeMap := filterMap(opts.Include, utils.FS.AddPrefixDot)
|
||||
excludeMap := filterMap(opts.Exclude, utils.FS.AddPrefixDot)
|
||||
|
||||
// to keep fingerprint stable
|
||||
sortDialogs(dialogs, opts.Desc)
|
||||
|
||||
return &iter{
|
||||
pool: pool,
|
||||
manager: manager,
|
||||
dialogs: dialogs,
|
||||
opts: opts,
|
||||
include: includeMap,
|
||||
exclude: excludeMap,
|
||||
tpl: tpl,
|
||||
|
||||
mu: &sync.Mutex{},
|
||||
finished: make(map[int]struct{}),
|
||||
fingerprint: fingerprint(dialogs),
|
||||
preSum: preSum(dialogs),
|
||||
i: 0,
|
||||
j: 0,
|
||||
elem: nil,
|
||||
err: nil,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *iter) Next(ctx context.Context) bool {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
i.err = ctx.Err()
|
||||
return false
|
||||
default:
|
||||
}
|
||||
|
||||
for {
|
||||
ok, skip := i.process(ctx)
|
||||
if skip {
|
||||
continue
|
||||
}
|
||||
|
||||
return ok
|
||||
}
|
||||
}
|
||||
|
||||
func (i *iter) process(ctx context.Context) (ret bool, skip bool) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
defer func() {
|
||||
if i.j++; i.i < len(i.dialogs) && i.j >= len(i.dialogs[i.i].Messages) {
|
||||
i.i++
|
||||
i.j = 0
|
||||
}
|
||||
}()
|
||||
|
||||
// end of iteration or error occurred
|
||||
if i.i >= len(i.dialogs) || i.err != nil {
|
||||
return false, false
|
||||
}
|
||||
|
||||
peer, msg := i.dialogs[i.i].Peer, i.dialogs[i.i].Messages[i.j]
|
||||
|
||||
// check if finished
|
||||
if _, ok := i.finished[i.ij2n(i.i, i.j)]; ok {
|
||||
return false, true
|
||||
}
|
||||
|
||||
from, err := i.manager.FromInputPeer(ctx, peer)
|
||||
if err != nil {
|
||||
i.err = errors.Wrap(err, "resolve from input peer")
|
||||
return false, false
|
||||
}
|
||||
|
||||
message, err := utils.Telegram.GetSingleMessage(ctx, i.pool.Default(ctx), peer, msg)
|
||||
if err != nil {
|
||||
i.err = errors.Wrap(err, "resolve message")
|
||||
return false, false
|
||||
}
|
||||
|
||||
item, ok := tmedia.GetMedia(message)
|
||||
if !ok {
|
||||
i.err = errors.Errorf("can not get media from %d/%d message", from.ID(), message.ID)
|
||||
return false, false
|
||||
}
|
||||
|
||||
// process include and exclude
|
||||
ext := filepath.Ext(item.Name)
|
||||
if _, ok = i.include[ext]; len(i.include) > 0 && !ok {
|
||||
return false, true
|
||||
}
|
||||
if _, ok = i.exclude[ext]; len(i.exclude) > 0 && ok {
|
||||
return false, true
|
||||
}
|
||||
|
||||
toName := bytes.Buffer{}
|
||||
err = i.tpl.Execute(&toName, &fileTemplate{
|
||||
DialogID: from.ID(),
|
||||
MessageID: message.ID,
|
||||
MessageDate: int64(message.Date),
|
||||
FileName: item.Name,
|
||||
FileSize: utils.Byte.FormatBinaryBytes(item.Size),
|
||||
DownloadDate: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
i.err = errors.Wrap(err, "execute template")
|
||||
return false, false
|
||||
}
|
||||
|
||||
if i.opts.SkipSame {
|
||||
if stat, err := os.Stat(filepath.Join(i.opts.Dir, toName.String())); err == nil {
|
||||
if utils.FS.GetNameWithoutExt(toName.String()) == utils.FS.GetNameWithoutExt(stat.Name()) &&
|
||||
stat.Size() == item.Size {
|
||||
return false, true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
filename := fmt.Sprintf("%s%s", toName.String(), tempExt)
|
||||
path := filepath.Join(i.opts.Dir, filename)
|
||||
|
||||
// #113. If path contains dirs, create it. So now we support nested dirs.
|
||||
if err = os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
i.err = errors.Wrap(err, "create dir")
|
||||
return false, false
|
||||
}
|
||||
|
||||
to, err := os.Create(path)
|
||||
if err != nil {
|
||||
i.err = errors.Wrap(err, "create file")
|
||||
return false, false
|
||||
}
|
||||
|
||||
i.elem = &iterElem{
|
||||
id: i.ij2n(i.i, i.j),
|
||||
|
||||
from: from,
|
||||
fromMsg: message,
|
||||
file: item,
|
||||
|
||||
to: to,
|
||||
|
||||
opts: i.opts,
|
||||
}
|
||||
|
||||
return true, false
|
||||
}
|
||||
|
||||
func (i *iter) Value() downloader.Elem {
|
||||
return i.elem
|
||||
}
|
||||
|
||||
func (i *iter) Err() error {
|
||||
return i.err
|
||||
}
|
||||
|
||||
func (i *iter) SetFinished(finished map[int]struct{}) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
i.finished = finished
|
||||
}
|
||||
|
||||
func (i *iter) Finished() map[int]struct{} {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
return i.finished
|
||||
}
|
||||
|
||||
func (i *iter) Fingerprint() string {
|
||||
return i.fingerprint
|
||||
}
|
||||
|
||||
func (i *iter) Finish(id int) {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
i.finished[id] = struct{}{}
|
||||
}
|
||||
|
||||
func (i *iter) Total() int {
|
||||
i.mu.Lock()
|
||||
defer i.mu.Unlock()
|
||||
|
||||
total := 0
|
||||
for _, m := range i.dialogs {
|
||||
total += len(m.Messages)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (i *iter) ij2n(ii, jj int) int {
|
||||
return i.preSum[ii] + jj
|
||||
}
|
||||
|
||||
func flatDialogs(dialogs [][]*tmessage.Dialog) []*tmessage.Dialog {
|
||||
res := make([]*tmessage.Dialog, 0)
|
||||
for _, d := range dialogs {
|
||||
if len(d) == 0 {
|
||||
continue
|
||||
}
|
||||
res = append(res, d...)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func filterMap(data []string, keyFn func(key string) string) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, v := range data {
|
||||
m[keyFn(v)] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func sortDialogs(dialogs []*tmessage.Dialog, desc bool) {
|
||||
sort.Slice(dialogs, func(i, j int) bool {
|
||||
return utils.Telegram.GetInputPeerID(dialogs[i].Peer) <
|
||||
utils.Telegram.GetInputPeerID(dialogs[j].Peer) // increasing order
|
||||
})
|
||||
|
||||
for _, m := range dialogs {
|
||||
sort.Slice(m.Messages, func(i, j int) bool {
|
||||
if desc {
|
||||
return m.Messages[i] > m.Messages[j]
|
||||
}
|
||||
return m.Messages[i] < m.Messages[j]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// preSum of dialogs
|
||||
func preSum(dialogs []*tmessage.Dialog) []int {
|
||||
sum := make([]int, len(dialogs)+1)
|
||||
for i, m := range dialogs {
|
||||
sum[i+1] = sum[i] + len(m.Messages)
|
||||
}
|
||||
return sum
|
||||
}
|
||||
|
||||
func fingerprint(dialogs []*tmessage.Dialog) string {
|
||||
endian := binary.BigEndian
|
||||
buf, b := &bytes.Buffer{}, make([]byte, 8)
|
||||
for _, m := range dialogs {
|
||||
endian.PutUint64(b, uint64(utils.Telegram.GetInputPeerID(m.Peer)))
|
||||
buf.Write(b)
|
||||
for _, msg := range m.Messages {
|
||||
endian.PutUint64(b, uint64(msg))
|
||||
buf.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", sha256.Sum256(buf.Bytes()))
|
||||
}
|
121
app/dl/progress.go
Normal file
121
app/dl/progress.go
Normal file
@ -0,0 +1,121 @@
|
||||
package dl
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/go-faster/errors"
|
||||
pw "github.com/jedib0t/go-pretty/v6/progress"
|
||||
|
||||
"github.com/iyear/tdl/pkg/downloader"
|
||||
"github.com/iyear/tdl/pkg/prog"
|
||||
"github.com/iyear/tdl/pkg/utils"
|
||||
)
|
||||
|
||||
type progress struct {
|
||||
pw pw.Writer
|
||||
trackers *sync.Map // map[ID]*pw.Tracker
|
||||
opts Options
|
||||
|
||||
it *iter
|
||||
}
|
||||
|
||||
func newProgress(p pw.Writer, it *iter, opts Options) *progress {
|
||||
return &progress{
|
||||
pw: p,
|
||||
trackers: &sync.Map{},
|
||||
opts: opts,
|
||||
it: it,
|
||||
}
|
||||
}
|
||||
|
||||
func (p *progress) OnAdd(elem downloader.Elem) {
|
||||
tracker := prog.AppendTracker(p.pw, utils.Byte.FormatBinaryBytes, p.processMessage(elem), elem.File().Size())
|
||||
p.trackers.Store(elem.(*iterElem).id, tracker)
|
||||
}
|
||||
|
||||
func (p *progress) OnDownload(elem downloader.Elem, state downloader.ProgressState) {
|
||||
tracker, ok := p.trackers.Load(elem.(*iterElem).id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
t := tracker.(*pw.Tracker)
|
||||
t.UpdateTotal(state.Total)
|
||||
t.SetValue(state.Downloaded)
|
||||
}
|
||||
|
||||
func (p *progress) OnDone(elem downloader.Elem, err error) {
|
||||
e := elem.(*iterElem)
|
||||
|
||||
tracker, ok := p.trackers.Load(e.id)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
t := tracker.(*pw.Tracker)
|
||||
|
||||
if err := e.to.Close(); err != nil {
|
||||
p.fail(t, elem, errors.Wrap(err, "close file"))
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) { // don't report user cancel
|
||||
p.fail(t, elem, errors.Wrap(err, "progress"))
|
||||
}
|
||||
_ = os.Remove(e.to.Name()) // just try to remove temp file, ignore error
|
||||
return
|
||||
}
|
||||
|
||||
p.it.Finish(e.id)
|
||||
|
||||
if err := p.donePost(e); err != nil {
|
||||
p.fail(t, elem, errors.Wrap(err, "post file"))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (p *progress) donePost(elem *iterElem) error {
|
||||
newfile := strings.TrimSuffix(filepath.Base(elem.to.Name()), tempExt)
|
||||
|
||||
if p.opts.RewriteExt {
|
||||
mime, err := mimetype.DetectFile(elem.to.Name())
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "detect mime")
|
||||
}
|
||||
ext := mime.Extension()
|
||||
if ext != "" && (filepath.Ext(newfile) != ext) {
|
||||
newfile = utils.FS.GetNameWithoutExt(newfile) + ext
|
||||
}
|
||||
}
|
||||
|
||||
if err := os.Rename(elem.to.Name(), filepath.Join(p.opts.Dir, newfile)); err != nil {
|
||||
return errors.Wrap(err, "rename file")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *progress) fail(t *pw.Tracker, elem downloader.Elem, err error) {
|
||||
p.pw.Log(color.RedString("%s error: %s", p.elemString(elem), err.Error()))
|
||||
t.MarkAsErrored()
|
||||
}
|
||||
|
||||
func (p *progress) processMessage(elem downloader.Elem) string {
|
||||
return p.elemString(elem)
|
||||
}
|
||||
|
||||
func (p *progress) elemString(elem downloader.Elem) string {
|
||||
e := elem.(*iterElem)
|
||||
return fmt.Sprintf("%s(%d):%d -> %s",
|
||||
e.from.VisibleName(),
|
||||
e.from.ID(),
|
||||
e.fromMsg.ID,
|
||||
strings.TrimSuffix(e.to.Name(), tempExt))
|
||||
}
|
@ -1,174 +0,0 @@
|
||||
package dliter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"text/template"
|
||||
"time"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/telegram/peers"
|
||||
|
||||
"github.com/iyear/tdl/pkg/downloader"
|
||||
"github.com/iyear/tdl/pkg/storage"
|
||||
"github.com/iyear/tdl/pkg/tmedia"
|
||||
"github.com/iyear/tdl/pkg/tplfunc"
|
||||
"github.com/iyear/tdl/pkg/utils"
|
||||
)
|
||||
|
||||
func New(ctx context.Context, opts *Options) (*Iter, error) {
|
||||
tpl, err := template.New("dl").
|
||||
Funcs(tplfunc.FuncMap(tplfunc.All...)).
|
||||
Parse(opts.Template)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dialogs := collectDialogs(opts.Dialogs)
|
||||
// if msgs is empty, return error to avoid range out of index
|
||||
if len(dialogs) == 0 {
|
||||
return nil, fmt.Errorf("you must specify at least one message")
|
||||
}
|
||||
|
||||
// include and exclude
|
||||
includeMap := filterMap(opts.Include, utils.FS.AddPrefixDot)
|
||||
excludeMap := filterMap(opts.Exclude, utils.FS.AddPrefixDot)
|
||||
|
||||
// to keep fingerprint stable
|
||||
sortDialogs(dialogs, opts.Desc)
|
||||
|
||||
manager := peers.Options{Storage: storage.NewPeers(opts.KV)}.Build(opts.Pool.Default(ctx))
|
||||
it := &Iter{
|
||||
pool: opts.Pool,
|
||||
dialogs: dialogs,
|
||||
include: includeMap,
|
||||
exclude: excludeMap,
|
||||
curi: 0,
|
||||
curj: -1,
|
||||
preSum: preSum(dialogs),
|
||||
finished: make(map[int]struct{}),
|
||||
template: tpl,
|
||||
manager: manager,
|
||||
fingerprint: fingerprint(dialogs),
|
||||
}
|
||||
|
||||
return it, nil
|
||||
}
|
||||
|
||||
func (iter *Iter) Next(ctx context.Context) (*downloader.Item, error) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
iter.mu.Lock()
|
||||
iter.curj++
|
||||
if iter.curj >= len(iter.dialogs[iter.curi].Messages) {
|
||||
if iter.curi++; iter.curi >= len(iter.dialogs) {
|
||||
return nil, errors.New("no more items")
|
||||
}
|
||||
iter.curj = 0
|
||||
}
|
||||
i, j := iter.curi, iter.curj
|
||||
iter.mu.Unlock()
|
||||
|
||||
// check if finished
|
||||
if _, ok := iter.finished[iter.ij2n(i, j)]; ok {
|
||||
return nil, downloader.ErrSkip
|
||||
}
|
||||
|
||||
return iter.item(ctx, i, j)
|
||||
}
|
||||
|
||||
func (iter *Iter) item(ctx context.Context, i, j int) (*downloader.Item, error) {
|
||||
peer, msg := iter.dialogs[i].Peer, iter.dialogs[i].Messages[j]
|
||||
|
||||
id := utils.Telegram.GetInputPeerID(peer)
|
||||
|
||||
message, err := utils.Telegram.GetSingleMessage(ctx, iter.pool.Default(ctx), peer, msg)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "resolve message")
|
||||
}
|
||||
|
||||
item, ok := tmedia.GetMedia(message)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("can not get media from %d/%d message",
|
||||
id, message.ID)
|
||||
}
|
||||
|
||||
// process include and exclude
|
||||
ext := filepath.Ext(item.Name)
|
||||
if len(iter.include) > 0 {
|
||||
if _, ok = iter.include[ext]; !ok {
|
||||
return nil, downloader.ErrSkip
|
||||
}
|
||||
}
|
||||
if len(iter.exclude) > 0 {
|
||||
if _, ok = iter.exclude[ext]; ok {
|
||||
return nil, downloader.ErrSkip
|
||||
}
|
||||
}
|
||||
|
||||
buf := bytes.Buffer{}
|
||||
err = iter.template.Execute(&buf, &fileTemplate{
|
||||
DialogID: id,
|
||||
MessageID: message.ID,
|
||||
MessageDate: int64(message.Date),
|
||||
FileName: item.Name,
|
||||
FileSize: utils.Byte.FormatBinaryBytes(item.Size),
|
||||
DownloadDate: time.Now().Unix(),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
item.Name = buf.String()
|
||||
|
||||
return &downloader.Item{
|
||||
ID: iter.ij2n(i, j),
|
||||
Media: item,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (iter *Iter) Finish(_ context.Context, id int) error {
|
||||
iter.mu.Lock()
|
||||
defer iter.mu.Unlock()
|
||||
|
||||
iter.finished[id] = struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (iter *Iter) Total(_ context.Context) int {
|
||||
iter.mu.Lock()
|
||||
defer iter.mu.Unlock()
|
||||
|
||||
total := 0
|
||||
for _, m := range iter.dialogs {
|
||||
total += len(m.Messages)
|
||||
}
|
||||
return total
|
||||
}
|
||||
|
||||
func (iter *Iter) ij2n(i, j int) int {
|
||||
return iter.preSum[i] + j
|
||||
}
|
||||
|
||||
func (iter *Iter) SetFinished(finished map[int]struct{}) {
|
||||
iter.mu.Lock()
|
||||
defer iter.mu.Unlock()
|
||||
|
||||
iter.finished = finished
|
||||
}
|
||||
|
||||
func (iter *Iter) Finished() map[int]struct{} {
|
||||
iter.mu.Lock()
|
||||
defer iter.mu.Unlock()
|
||||
|
||||
return iter.finished
|
||||
}
|
||||
|
||||
func (iter *Iter) Fingerprint() string {
|
||||
return iter.fingerprint
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package dliter
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/iyear/tdl/pkg/tmessage"
|
||||
)
|
||||
|
||||
func TestPreSum(t *testing.T) {
|
||||
tests := []struct {
|
||||
dialogs []*tmessage.Dialog
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2}}},
|
||||
want: []int{0, 3, 5},
|
||||
},
|
||||
{
|
||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}},
|
||||
want: []int{0, 3, 6, 10},
|
||||
},
|
||||
{
|
||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}, {Messages: []int{1}}},
|
||||
want: []int{0, 3, 6, 10, 11},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
got := preSum(tt.dialogs)
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("preSum() = %v, want %v", got, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestIter_ij2n(t *testing.T) {
|
||||
tests := []struct {
|
||||
dialogs []*tmessage.Dialog
|
||||
input []struct {
|
||||
i, j int
|
||||
}
|
||||
want []int
|
||||
}{
|
||||
{
|
||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2}}},
|
||||
input: []struct {
|
||||
i, j int
|
||||
}{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}},
|
||||
want: []int{0, 1, 2, 3, 4},
|
||||
},
|
||||
{
|
||||
dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}},
|
||||
input: []struct {
|
||||
i, j int
|
||||
}{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}, {2, 0}, {2, 1}, {2, 2}, {2, 3}},
|
||||
want: []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
iter := &Iter{preSum: preSum(tt.dialogs), dialogs: tt.dialogs}
|
||||
|
||||
for i, input := range tt.input {
|
||||
got := iter.ij2n(input.i, input.j)
|
||||
if got != tt.want[i] {
|
||||
t.Errorf("ij2n(%v, %v) = %v, want %v", input.i, input.j, got, tt.want[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
@ -1,44 +0,0 @@
|
||||
package dliter
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"text/template"
|
||||
|
||||
"github.com/gotd/td/telegram/peers"
|
||||
|
||||
"github.com/iyear/tdl/pkg/dcpool"
|
||||
"github.com/iyear/tdl/pkg/kv"
|
||||
"github.com/iyear/tdl/pkg/tmessage"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Pool dcpool.Pool
|
||||
KV kv.KV
|
||||
Template string
|
||||
Include, Exclude []string
|
||||
Desc bool
|
||||
Dialogs [][]*tmessage.Dialog
|
||||
}
|
||||
|
||||
type Iter struct {
|
||||
pool dcpool.Pool
|
||||
dialogs []*tmessage.Dialog
|
||||
include, exclude map[string]struct{}
|
||||
mu sync.Mutex
|
||||
curi int
|
||||
curj int
|
||||
preSum []int
|
||||
finished map[int]struct{}
|
||||
template *template.Template
|
||||
manager *peers.Manager
|
||||
fingerprint string
|
||||
}
|
||||
|
||||
type fileTemplate struct {
|
||||
DialogID int64
|
||||
MessageID int
|
||||
MessageDate int64
|
||||
FileName string
|
||||
FileSize string
|
||||
DownloadDate int64
|
||||
}
|
@ -1,71 +0,0 @@
|
||||
package dliter
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/iyear/tdl/pkg/tmessage"
|
||||
"github.com/iyear/tdl/pkg/utils"
|
||||
)
|
||||
|
||||
func sortDialogs(dialogs []*tmessage.Dialog, desc bool) {
|
||||
sort.Slice(dialogs, func(i, j int) bool {
|
||||
return utils.Telegram.GetInputPeerID(dialogs[i].Peer) <
|
||||
utils.Telegram.GetInputPeerID(dialogs[j].Peer) // increasing order
|
||||
})
|
||||
|
||||
for _, m := range dialogs {
|
||||
sort.Slice(m.Messages, func(i, j int) bool {
|
||||
if desc {
|
||||
return m.Messages[i] > m.Messages[j]
|
||||
}
|
||||
return m.Messages[i] < m.Messages[j]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func fingerprint(dialogs []*tmessage.Dialog) string {
|
||||
endian := binary.BigEndian
|
||||
buf, b := &bytes.Buffer{}, make([]byte, 8)
|
||||
for _, m := range dialogs {
|
||||
endian.PutUint64(b, uint64(utils.Telegram.GetInputPeerID(m.Peer)))
|
||||
buf.Write(b)
|
||||
for _, msg := range m.Messages {
|
||||
endian.PutUint64(b, uint64(msg))
|
||||
buf.Write(b)
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%x", sha256.Sum256(buf.Bytes()))
|
||||
}
|
||||
|
||||
func filterMap(data []string, keyFn func(key string) string) map[string]struct{} {
|
||||
m := make(map[string]struct{})
|
||||
for _, v := range data {
|
||||
m[keyFn(v)] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func collectDialogs(dialogs [][]*tmessage.Dialog) []*tmessage.Dialog {
|
||||
res := make([]*tmessage.Dialog, 0)
|
||||
for _, d := range dialogs {
|
||||
if len(d) == 0 {
|
||||
continue
|
||||
}
|
||||
res = append(res, d...)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
// preSum of dialogs
|
||||
func preSum(dialogs []*tmessage.Dialog) []int {
|
||||
sum := make([]int, len(dialogs)+1)
|
||||
for i, m := range dialogs {
|
||||
sum[i+1] = sum[i] + len(m.Messages)
|
||||
}
|
||||
return sum
|
||||
}
|
@ -25,7 +25,7 @@ func NewDownload() *cobra.Command {
|
||||
}
|
||||
|
||||
opts.Template = viper.GetString(consts.FlagDlTemplate)
|
||||
return dl.Run(logger.Named(cmd.Context(), "dl"), &opts)
|
||||
return dl.Run(logger.Named(cmd.Context(), "dl"), opts)
|
||||
},
|
||||
}
|
||||
|
||||
|
@ -2,168 +2,89 @@ package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fatih/color"
|
||||
"github.com/gabriel-vasile/mimetype"
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/telegram/downloader"
|
||||
"github.com/jedib0t/go-pretty/v6/progress"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/iyear/tdl/pkg/dcpool"
|
||||
"github.com/iyear/tdl/pkg/logger"
|
||||
"github.com/iyear/tdl/pkg/prog"
|
||||
"github.com/iyear/tdl/pkg/utils"
|
||||
)
|
||||
|
||||
const TempExt = ".tmp"
|
||||
|
||||
var formatter = utils.Byte.FormatBinaryBytes
|
||||
|
||||
type Downloader struct {
|
||||
pw progress.Writer
|
||||
opts Options
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
Pool dcpool.Pool
|
||||
Dir string
|
||||
RewriteExt bool
|
||||
SkipSame bool
|
||||
PartSize int
|
||||
Threads int
|
||||
Iter Iter
|
||||
Takeout bool
|
||||
Progress Progress
|
||||
}
|
||||
|
||||
func New(opts Options) (*Downloader, error) {
|
||||
func New(opts Options) *Downloader {
|
||||
return &Downloader{
|
||||
pw: prog.New(formatter),
|
||||
opts: opts,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Downloader) Download(ctx context.Context, limit int) error {
|
||||
color.Green("All files will be downloaded to '%s' dir", d.opts.Dir)
|
||||
|
||||
total := d.opts.Iter.Total(ctx)
|
||||
d.pw.SetNumTrackersExpected(total)
|
||||
|
||||
go d.renderPinned(ctx, d.pw)
|
||||
go d.pw.Render()
|
||||
|
||||
wg, errctx := errgroup.WithContext(ctx)
|
||||
wg, wgctx := errgroup.WithContext(ctx)
|
||||
wg.SetLimit(limit)
|
||||
|
||||
for i := 0; i < total; i++ {
|
||||
item, err := d.opts.Iter.Next(errctx)
|
||||
if err != nil {
|
||||
logger.From(errctx).Debug("Iter next failed",
|
||||
zap.Int("index", i), zap.String("error", err.Error()))
|
||||
// skip error means we don't need to log error
|
||||
if !errors.Is(err, ErrSkip) && !errors.Is(err, context.Canceled) {
|
||||
d.pw.Log(color.RedString("failed: %v", err))
|
||||
}
|
||||
continue
|
||||
for d.opts.Iter.Next(wgctx) {
|
||||
elem := d.opts.Iter.Value()
|
||||
|
||||
wg.Go(func() (rerr error) {
|
||||
d.opts.Progress.OnAdd(elem)
|
||||
defer func() { d.opts.Progress.OnDone(elem, rerr) }()
|
||||
|
||||
if err := d.download(wgctx, elem); err != nil {
|
||||
// canceled by user, so we directly return error to stop all
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return errors.Wrap(err, "download")
|
||||
}
|
||||
|
||||
wg.Go(func() error {
|
||||
return d.download(errctx, item)
|
||||
// don't return error, just log it
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err := wg.Wait()
|
||||
if err != nil {
|
||||
d.pw.Stop()
|
||||
for d.pw.IsRenderInProgress() {
|
||||
time.Sleep(time.Millisecond * 10)
|
||||
if err := d.opts.Iter.Err(); err != nil {
|
||||
return errors.Wrap(err, "iter")
|
||||
}
|
||||
|
||||
// canceled error is ignored by gotd, so we can't detect it in main entry
|
||||
if errors.Is(err, context.Canceled) {
|
||||
color.Red("Download aborted by user")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
prog.Wait(ctx, d.pw)
|
||||
|
||||
return nil
|
||||
return wg.Wait()
|
||||
}
|
||||
|
||||
func (d *Downloader) download(ctx context.Context, item *Item) error {
|
||||
func (d *Downloader) download(ctx context.Context, elem Elem) error {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
}
|
||||
|
||||
logger.From(ctx).Debug("Start download item",
|
||||
zap.Any("item", item))
|
||||
logger.From(ctx).Debug("Start download elem",
|
||||
zap.Any("elem", elem))
|
||||
|
||||
client := d.opts.Pool.Client(ctx, elem.File().DC())
|
||||
if elem.AsTakeout() {
|
||||
client = d.opts.Pool.Takeout(ctx, elem.File().DC())
|
||||
}
|
||||
|
||||
_, err := downloader.NewDownloader().WithPartSize(d.opts.PartSize).
|
||||
Download(client, elem.File().Location()).
|
||||
WithThreads(d.bestThreads(elem.File().Size())).
|
||||
Parallel(ctx, newWriteAt(elem, d.opts.Progress, d.opts.PartSize))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "download")
|
||||
}
|
||||
|
||||
if d.opts.SkipSame {
|
||||
if stat, err := os.Stat(filepath.Join(d.opts.Dir, item.Name)); err == nil {
|
||||
if utils.FS.GetNameWithoutExt(item.Name) == utils.FS.GetNameWithoutExt(stat.Name()) &&
|
||||
stat.Size() == item.Size {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
tracker := prog.AppendTracker(d.pw, formatter, item.Name, item.Size)
|
||||
filename := fmt.Sprintf("%s%s", item.Name, TempExt)
|
||||
path := filepath.Join(d.opts.Dir, filename)
|
||||
|
||||
// #113. If path contains dirs, create it. So now we support nested dirs.
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := d.opts.Pool.Client(ctx, item.DC)
|
||||
if d.opts.Takeout {
|
||||
client = d.opts.Pool.Takeout(ctx, item.DC)
|
||||
}
|
||||
|
||||
_, err = downloader.NewDownloader().WithPartSize(d.opts.PartSize).
|
||||
Download(client, item.InputFileLoc).
|
||||
WithThreads(d.bestThreads(item.Size)).
|
||||
Parallel(ctx, newWriteAt(f, tracker, d.opts.PartSize))
|
||||
if err := f.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// rename file, remove temp extension and add real extension
|
||||
newfile := strings.TrimSuffix(filename, TempExt)
|
||||
|
||||
if d.opts.RewriteExt {
|
||||
mime, err := mimetype.DetectFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ext := mime.Extension()
|
||||
if ext != "" && (filepath.Ext(newfile) != ext) {
|
||||
newfile = utils.FS.GetNameWithoutExt(newfile) + ext
|
||||
}
|
||||
}
|
||||
|
||||
if err = os.Rename(path, filepath.Join(d.opts.Dir, newfile)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return d.opts.Iter.Finish(ctx, item.ID)
|
||||
}
|
||||
|
||||
// threads level
|
||||
|
@ -2,20 +2,26 @@ package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/iyear/tdl/pkg/tmedia"
|
||||
"github.com/gotd/td/tg"
|
||||
)
|
||||
|
||||
var ErrSkip = errors.New("skip")
|
||||
|
||||
type Iter interface {
|
||||
Next(ctx context.Context) (*Item, error)
|
||||
Finish(ctx context.Context, id int) error
|
||||
Total(ctx context.Context) int
|
||||
Next(ctx context.Context) bool
|
||||
Value() Elem
|
||||
Err() error
|
||||
}
|
||||
|
||||
type Item struct {
|
||||
ID int // unique in iter
|
||||
*tmedia.Media
|
||||
type Elem interface {
|
||||
File() File
|
||||
To() io.WriterAt
|
||||
|
||||
AsTakeout() bool
|
||||
}
|
||||
|
||||
type File interface {
|
||||
Location() tg.InputFileLocationClass
|
||||
Size() int64
|
||||
DC() int
|
||||
}
|
||||
|
57
pkg/downloader/progress.go
Normal file
57
pkg/downloader/progress.go
Normal file
@ -0,0 +1,57 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"go.uber.org/atomic"
|
||||
)
|
||||
|
||||
type Progress interface {
|
||||
OnAdd(elem Elem)
|
||||
OnDownload(elem Elem, state ProgressState)
|
||||
OnDone(elem Elem, err error)
|
||||
// TODO: OnLog to log something that is not an error but should be sent to the user
|
||||
}
|
||||
|
||||
type ProgressState struct {
|
||||
Downloaded int64
|
||||
Total int64
|
||||
}
|
||||
|
||||
// writeAt wrapper for file to use progress bar
|
||||
//
|
||||
// do not need mutex because gotd has use syncio.WriteAt
|
||||
type writeAt struct {
|
||||
elem Elem
|
||||
progress Progress
|
||||
partSize int
|
||||
|
||||
downloaded *atomic.Int64
|
||||
}
|
||||
|
||||
func newWriteAt(elem Elem, progress Progress, partSize int) *writeAt {
|
||||
return &writeAt{
|
||||
elem: elem,
|
||||
progress: progress,
|
||||
partSize: partSize,
|
||||
downloaded: atomic.NewInt64(0),
|
||||
}
|
||||
}
|
||||
|
||||
func (w *writeAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
at, err := w.elem.To().WriteAt(p, off)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// some small files may finish too fast, terminal history may not be overwritten
|
||||
// this is just a simple way to avoid the problem
|
||||
if at < w.partSize { // last part(every file only exec once)
|
||||
time.Sleep(time.Millisecond * 200) // to ensure the progress render next time
|
||||
}
|
||||
w.progress.OnDownload(w.elem, ProgressState{
|
||||
Downloaded: w.downloaded.Add(int64(at)),
|
||||
Total: w.elem.File().Size(),
|
||||
})
|
||||
return at, nil
|
||||
}
|
@ -1,29 +0,0 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jedib0t/go-pretty/v6/progress"
|
||||
|
||||
"github.com/iyear/tdl/pkg/ps"
|
||||
)
|
||||
|
||||
func (d *Downloader) renderPinned(ctx context.Context, pw progress.Writer) {
|
||||
f := func() { pw.SetPinnedMessages(strings.Join(ps.Humanize(ctx), " ")) }
|
||||
f()
|
||||
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
pw.SetPinnedMessages()
|
||||
return
|
||||
case <-ticker.C:
|
||||
f()
|
||||
}
|
||||
}
|
||||
}
|
@ -1,41 +0,0 @@
|
||||
package downloader
|
||||
|
||||
import (
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/jedib0t/go-pretty/v6/progress"
|
||||
)
|
||||
|
||||
// writeAt wrapper for file to use progress bar
|
||||
//
|
||||
// do not need mutex because gotd has use syncio.WriteAt
|
||||
type writeAt struct {
|
||||
f *os.File
|
||||
tracker *progress.Tracker
|
||||
partSize int
|
||||
}
|
||||
|
||||
func newWriteAt(f *os.File, tracker *progress.Tracker, partSize int) *writeAt {
|
||||
return &writeAt{
|
||||
f: f,
|
||||
tracker: tracker,
|
||||
partSize: partSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (w *writeAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
at, err := w.f.WriteAt(p, off)
|
||||
if err != nil {
|
||||
w.tracker.MarkAsErrored()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// some small files may finish too fast, terminal history may not be overwritten
|
||||
// this is just a simple way to avoid the problem
|
||||
if at < w.partSize { // last part(every file only exec once)
|
||||
time.Sleep(time.Millisecond * 200) // to ensure the progress render next time
|
||||
}
|
||||
w.tracker.Increment(int64(at))
|
||||
return at, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user