refactor(dl): extract to interface and impl iter in app

This commit is contained in:
iyear 2023-11-27 18:57:32 +08:00
parent b7bbbf60f1
commit 169e47913a
14 changed files with 627 additions and 592 deletions

View File

@ -8,11 +8,11 @@ import (
"github.com/AlecAivazis/survey/v2"
"github.com/fatih/color"
"github.com/go-faster/errors"
"github.com/gotd/td/telegram/peers"
"github.com/spf13/viper"
"go.uber.org/multierr"
"go.uber.org/zap"
"github.com/iyear/tdl/app/internal/dliter"
"github.com/iyear/tdl/app/internal/tgc"
"github.com/iyear/tdl/pkg/consts"
"github.com/iyear/tdl/pkg/dcpool"
@ -20,7 +20,10 @@ import (
"github.com/iyear/tdl/pkg/key"
"github.com/iyear/tdl/pkg/kv"
"github.com/iyear/tdl/pkg/logger"
"github.com/iyear/tdl/pkg/prog"
"github.com/iyear/tdl/pkg/storage"
"github.com/iyear/tdl/pkg/tmessage"
"github.com/iyear/tdl/pkg/utils"
)
type Options struct {
@ -48,7 +51,7 @@ type parser struct {
Parser tmessage.ParseSource
}
func Run(ctx context.Context, opts *Options) error {
func Run(ctx context.Context, opts Options) error {
c, kvd, err := tgc.NoLogin(ctx)
if err != nil {
return err
@ -78,22 +81,16 @@ func Run(ctx context.Context, opts *Options) error {
return serve(ctx, kvd, pool, dialogs, opts.Port, opts.Takeout)
}
iter, err := dliter.New(ctx, &dliter.Options{
Pool: pool,
KV: kvd,
Template: opts.Template,
Include: opts.Include,
Exclude: opts.Exclude,
Desc: opts.Desc,
Dialogs: dialogs,
})
manager := peers.Options{Storage: storage.NewPeers(kvd)}.Build(pool.Default(ctx))
it, err := newIter(pool, manager, dialogs, opts)
if err != nil {
return err
}
if !opts.Restart {
// resume download and ask user to continue
if err = resume(ctx, kvd, iter, !opts.Continue); err != nil {
if err = resume(ctx, kvd, it, !opts.Continue); err != nil {
return err
}
} else {
@ -102,37 +99,39 @@ func Run(ctx context.Context, opts *Options) error {
defer func() { // save progress
if rerr != nil { // download is interrupted
multierr.AppendInto(&rerr, saveProgress(ctx, kvd, iter))
multierr.AppendInto(&rerr, saveProgress(ctx, kvd, it))
} else { // if finished, we should clear resume key
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(iter.Fingerprint())))
multierr.AppendInto(&rerr, kvd.Delete(key.Resume(it.Fingerprint())))
}
}()
dlProgress := prog.New(utils.Byte.FormatBinaryBytes)
dlProgress.SetNumTrackersExpected(it.Total())
prog.EnablePS(ctx, dlProgress)
options := downloader.Options{
Pool: pool,
Dir: opts.Dir,
RewriteExt: opts.RewriteExt,
SkipSame: opts.SkipSame,
PartSize: viper.GetInt(consts.FlagPartSize),
Threads: viper.GetInt(consts.FlagThreads),
Iter: iter,
Takeout: opts.Takeout,
Iter: it,
Progress: newProgress(dlProgress, it, opts),
}
limit := viper.GetInt(consts.FlagLimit)
logger.From(ctx).Info("Start download",
zap.String("dir", options.Dir),
zap.Bool("rewrite_ext", options.RewriteExt),
zap.Bool("skip_same", options.SkipSame),
zap.String("dir", opts.Dir),
zap.Bool("rewrite_ext", opts.RewriteExt),
zap.Bool("skip_same", opts.SkipSame),
zap.Int("part_size", options.PartSize),
zap.Int("threads", options.Threads),
zap.Int("limit", limit))
dl, err := downloader.New(options)
if err != nil {
return errors.Wrap(err, "create downloader")
}
return dl.Download(ctx, limit)
color.Green("All files will be downloaded to '%s' dir", opts.Dir)
go dlProgress.Render()
defer prog.Wait(ctx, dlProgress)
return downloader.New(options).Download(ctx, limit)
})
}
@ -148,7 +147,7 @@ func collectDialogs(parsers []parser) ([][]*tmessage.Dialog, error) {
return dialogs, nil
}
func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
func resume(ctx context.Context, kvd kv.KV, iter *iter, ask bool) error {
logger.From(ctx).Debug("Check resume key",
zap.String("fingerprint", iter.Fingerprint()))
@ -171,7 +170,7 @@ func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
}
confirm := false
resumeStr := fmt.Sprintf("Found unfinished download, continue from '%d/%d'", len(finished), iter.Total(ctx))
resumeStr := fmt.Sprintf("Found unfinished download, continue from '%d/%d'", len(finished), iter.Total())
if ask {
if err = survey.AskOne(&survey.Confirm{
Message: color.YellowString(resumeStr + "?"),
@ -195,7 +194,7 @@ func resume(ctx context.Context, kvd kv.KV, iter *dliter.Iter, ask bool) error {
return nil
}
func saveProgress(ctx context.Context, kvd kv.KV, it *dliter.Iter) error {
func saveProgress(ctx context.Context, kvd kv.KV, it *iter) error {
finished := it.Finished()
logger.From(ctx).Debug("Save progress",
zap.Int("finished", len(finished)))

38
app/dl/elem.go Normal file
View File

@ -0,0 +1,38 @@
package dl
import (
"io"
"os"
"github.com/gotd/td/telegram/peers"
"github.com/gotd/td/tg"
"github.com/iyear/tdl/pkg/downloader"
"github.com/iyear/tdl/pkg/tmedia"
)
// iterElem is one downloadable item produced by iter. It serves as both the
// downloader.Elem and the downloader.File of that item, which is why File
// returns the element itself.
type iterElem struct {
	id int // index assigned through iter.ij2n

	from    peers.Peer
	fromMsg *tg.Message
	file    *tmedia.Media

	to *os.File // destination file created by iter.process

	opts Options
}

// File exposes the element as its own downloader.File.
func (e *iterElem) File() downloader.File { return e }

// To returns the destination the downloaded bytes are written to.
func (e *iterElem) To() io.WriterAt { return e.to }

// AsTakeout reports the Takeout option the element was created with.
func (e *iterElem) AsTakeout() bool { return e.opts.Takeout }

// Location returns the input file location of the media.
func (e *iterElem) Location() tg.InputFileLocationClass { return e.file.InputFileLoc }

// Name returns the media name.
func (e *iterElem) Name() string { return e.file.Name }

// Size returns the media size.
func (e *iterElem) Size() int64 { return e.file.Size }

// DC returns the DC recorded on the media.
func (e *iterElem) DC() int { return e.file.DC }

323
app/dl/iter.go Normal file
View File

@ -0,0 +1,323 @@
package dl
import (
"bytes"
"context"
"crypto/sha256"
"encoding/binary"
"fmt"
"os"
"path/filepath"
"sort"
"sync"
"text/template"
"time"
"github.com/go-faster/errors"
"github.com/gotd/td/telegram/peers"
"github.com/iyear/tdl/pkg/dcpool"
"github.com/iyear/tdl/pkg/downloader"
"github.com/iyear/tdl/pkg/tmedia"
"github.com/iyear/tdl/pkg/tmessage"
"github.com/iyear/tdl/pkg/tplfunc"
"github.com/iyear/tdl/pkg/utils"
)
const tempExt = ".tmp"
// fileTemplate is the data handed to the file name template when
// iter.process renders the destination name of a message's media.
type fileTemplate struct {
DialogID int64
MessageID int
MessageDate int64 // filled from int64(message.Date)
FileName string
FileSize string // already formatted with utils.Byte.FormatBinaryBytes
DownloadDate int64 // time.Now().Unix() at the moment the name is rendered
}
// iter walks every message of every dialog and prepares one iterElem per
// message that has to be downloaded.
type iter struct {
pool dcpool.Pool
manager *peers.Manager
dialogs []*tmessage.Dialog
tpl *template.Template // parsed file name template
include map[string]struct{} // extensions to keep; ignored when empty
exclude map[string]struct{} // extensions to drop; ignored when empty
opts Options
// mu is held by process, SetFinished, Finished, Finish and Total.
mu *sync.Mutex
finished map[int]struct{} // ids (see ij2n) that were already downloaded
fingerprint string
preSum []int // prefix sums of message counts, used by ij2n
i, j int // cursor: dialog index and message index within that dialog
elem downloader.Elem // element prepared by the last successful process call
err error // first error that stopped the iteration
}
// newIter parses the file name template, flattens and sorts the dialogs and
// returns an iterator positioned on the first message.
//
// It fails when the template cannot be parsed or when no dialog is left after
// flattening.
func newIter(pool dcpool.Pool, manager *peers.Manager, dialog [][]*tmessage.Dialog, opts Options) (*iter, error) {
	parsed, err := template.New("dl").Funcs(tplfunc.FuncMap(tplfunc.All...)).Parse(opts.Template)
	if err != nil {
		return nil, errors.Wrap(err, "parse template")
	}

	// an empty dialog list would make the cursor index out of range later on
	flat := flatDialogs(dialog)
	if len(flat) == 0 {
		return nil, errors.Errorf("you must specify at least one message")
	}

	it := &iter{
		pool:     pool,
		manager:  manager,
		opts:     opts,
		tpl:      parsed,
		include:  filterMap(opts.Include, utils.FS.AddPrefixDot),
		exclude:  filterMap(opts.Exclude, utils.FS.AddPrefixDot),
		mu:       &sync.Mutex{},
		finished: make(map[int]struct{}),
	}

	// sort before hashing so that the fingerprint stays stable
	sortDialogs(flat, opts.Desc)
	it.dialogs = flat
	it.fingerprint = fingerprint(flat)
	it.preSum = preSum(flat)

	return it, nil
}
// Next prepares the next element to download. It returns false when the
// context is done, when the iteration is over or when an error was recorded;
// Err tells these cases apart.
func (i *iter) Next(ctx context.Context) bool {
	select {
	case <-ctx.Done():
		i.err = ctx.Err()
		return false
	default:
	}

	// keep processing until a message is either accepted or ends the iteration
	for {
		if ok, skip := i.process(ctx); !skip {
			return ok
		}
	}
}
// process handles the message at the current cursor (i.i, i.j) and then moves
// the cursor forward by one, whatever the outcome.
//
// It returns ret=true when i.elem holds a freshly prepared element, skip=true
// when the current position must be passed over and process should be called
// again, and (false, false) when the iteration is over or i.err has been set.
func (i *iter) process(ctx context.Context) (ret bool, skip bool) {
	i.mu.Lock()
	defer i.mu.Unlock()

	// advance the cursor on every exit path and roll over to the next dialog
	// once the current one is exhausted
	defer func() {
		if i.j++; i.i < len(i.dialogs) && i.j >= len(i.dialogs[i.i].Messages) {
			i.i++
			i.j = 0
		}
	}()

	// end of iteration or error occurred
	if i.i >= len(i.dialogs) || i.err != nil {
		return false, false
	}

	// a dialog without messages has nothing at index i.j; skip it instead of
	// panicking on the slice access below (the deferred step above then moves
	// on to the next dialog)
	if i.j >= len(i.dialogs[i.i].Messages) {
		return false, true
	}

	peer, msg := i.dialogs[i.i].Peer, i.dialogs[i.i].Messages[i.j]

	// check if finished
	if _, ok := i.finished[i.ij2n(i.i, i.j)]; ok {
		return false, true
	}

	from, err := i.manager.FromInputPeer(ctx, peer)
	if err != nil {
		i.err = errors.Wrap(err, "resolve from input peer")
		return false, false
	}

	message, err := utils.Telegram.GetSingleMessage(ctx, i.pool.Default(ctx), peer, msg)
	if err != nil {
		i.err = errors.Wrap(err, "resolve message")
		return false, false
	}

	item, ok := tmedia.GetMedia(message)
	if !ok {
		i.err = errors.Errorf("can not get media from %d/%d message", from.ID(), message.ID)
		return false, false
	}

	// process include and exclude
	ext := filepath.Ext(item.Name)
	if _, ok = i.include[ext]; len(i.include) > 0 && !ok {
		return false, true
	}
	if _, ok = i.exclude[ext]; len(i.exclude) > 0 && ok {
		return false, true
	}

	// render the destination name from the user template
	toName := bytes.Buffer{}
	err = i.tpl.Execute(&toName, &fileTemplate{
		DialogID:     from.ID(),
		MessageID:    message.ID,
		MessageDate:  int64(message.Date),
		FileName:     item.Name,
		FileSize:     utils.Byte.FormatBinaryBytes(item.Size),
		DownloadDate: time.Now().Unix(),
	})
	if err != nil {
		i.err = errors.Wrap(err, "execute template")
		return false, false
	}

	// skip files that already exist with the same name (without extension)
	// and the same size
	if i.opts.SkipSame {
		if stat, err := os.Stat(filepath.Join(i.opts.Dir, toName.String())); err == nil {
			if utils.FS.GetNameWithoutExt(toName.String()) == utils.FS.GetNameWithoutExt(stat.Name()) &&
				stat.Size() == item.Size {
				return false, true
			}
		}
	}

	// download into a temp file first
	filename := fmt.Sprintf("%s%s", toName.String(), tempExt)
	path := filepath.Join(i.opts.Dir, filename)

	// #113. If path contains dirs, create it. So now we support nested dirs.
	if err = os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
		i.err = errors.Wrap(err, "create dir")
		return false, false
	}

	to, err := os.Create(path)
	if err != nil {
		i.err = errors.Wrap(err, "create file")
		return false, false
	}

	i.elem = &iterElem{
		id:      i.ij2n(i.i, i.j),
		from:    from,
		fromMsg: message,
		file:    item,
		to:      to,
		opts:    i.opts,
	}

	return true, false
}
// Value returns the element stored by the last process call that produced one.
func (i *iter) Value() downloader.Elem {
return i.elem
}
// Err returns the error recorded by Next or process, or nil if there is none.
func (i *iter) Err() error {
return i.err
}
// SetFinished replaces the set of finished ids; the map is stored as is, not
// copied.
func (i *iter) SetFinished(finished map[int]struct{}) {
i.mu.Lock()
defer i.mu.Unlock()
i.finished = finished
}
// Finished returns a snapshot of the ids that were marked as finished.
//
// A copy is returned rather than the internal map: the caller reads the
// result after the lock is released, while download workers may still be
// calling Finish, and sharing the map would be a data race.
func (i *iter) Finished() map[int]struct{} {
	i.mu.Lock()
	defer i.mu.Unlock()

	snapshot := make(map[int]struct{}, len(i.finished))
	for id := range i.finished {
		snapshot[id] = struct{}{}
	}
	return snapshot
}
// Fingerprint returns the hash computed over the sorted dialogs in newIter.
func (i *iter) Fingerprint() string {
return i.fingerprint
}
// Finish marks the element with the given id as finished.
func (i *iter) Finish(id int) {
i.mu.Lock()
defer i.mu.Unlock()
i.finished[id] = struct{}{}
}
// Total returns the number of messages over all dialogs.
func (i *iter) Total() int {
	i.mu.Lock()
	defer i.mu.Unlock()

	var count int
	for idx := range i.dialogs {
		count += len(i.dialogs[idx].Messages)
	}
	return count
}
// ij2n maps a (dialog index, message index) pair to a single index using the
// prefix sums in preSum.
func (i *iter) ij2n(ii, jj int) int {
return i.preSum[ii] + jj
}
// flatDialogs concatenates the dialog groups into one slice, keeping their
// order. The result is never nil.
func flatDialogs(dialogs [][]*tmessage.Dialog) []*tmessage.Dialog {
	flat := make([]*tmessage.Dialog, 0)
	for idx := range dialogs {
		// appending an empty group is a no-op, so no explicit check is needed
		flat = append(flat, dialogs[idx]...)
	}
	return flat
}
// filterMap builds a set whose members are keyFn applied to every entry of
// data.
func filterMap(data []string, keyFn func(key string) string) map[string]struct{} {
	set := make(map[string]struct{})
	for idx := 0; idx < len(data); idx++ {
		set[keyFn(data[idx])] = struct{}{}
	}
	return set
}
// sortDialogs orders the dialogs by ascending peer id and, inside every
// dialog, the message ids ascending, or descending when desc is set. Both
// sorts happen in place.
func sortDialogs(dialogs []*tmessage.Dialog, desc bool) {
	sort.Slice(dialogs, func(a, b int) bool {
		left := utils.Telegram.GetInputPeerID(dialogs[a].Peer)
		right := utils.Telegram.GetInputPeerID(dialogs[b].Peer)
		return left < right // increasing order
	})

	for _, d := range dialogs {
		ids := d.Messages
		sort.Slice(ids, func(a, b int) bool {
			if !desc {
				return ids[a] < ids[b]
			}
			return ids[a] > ids[b]
		})
	}
}
// preSum returns the prefix sums of the dialogs' message counts: element k is
// the number of messages in the first k dialogs, so the result has
// len(dialogs)+1 entries and starts with 0.
func preSum(dialogs []*tmessage.Dialog) []int {
	prefix := make([]int, 1, len(dialogs)+1)
	for _, d := range dialogs {
		prefix = append(prefix, prefix[len(prefix)-1]+len(d.Messages))
	}
	return prefix
}
// fingerprint hashes every peer id followed by its message ids, each encoded
// as a big-endian 64-bit word, and returns the SHA-256 digest in lowercase hex.
func fingerprint(dialogs []*tmessage.Dialog) string {
	var (
		raw  bytes.Buffer
		word [8]byte
	)
	put := func(v uint64) {
		binary.BigEndian.PutUint64(word[:], v)
		raw.Write(word[:])
	}

	for _, d := range dialogs {
		put(uint64(utils.Telegram.GetInputPeerID(d.Peer)))
		for _, id := range d.Messages {
			put(uint64(id))
		}
	}

	return fmt.Sprintf("%x", sha256.Sum256(raw.Bytes()))
}

121
app/dl/progress.go Normal file
View File

@ -0,0 +1,121 @@
package dl
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"github.com/fatih/color"
"github.com/gabriel-vasile/mimetype"
"github.com/go-faster/errors"
pw "github.com/jedib0t/go-pretty/v6/progress"
"github.com/iyear/tdl/pkg/downloader"
"github.com/iyear/tdl/pkg/prog"
"github.com/iyear/tdl/pkg/utils"
)
// progress reports download events to a progress writer and finalizes the
// downloaded files.
type progress struct {
pw pw.Writer
trackers *sync.Map // map[ID]*pw.Tracker
opts Options
it *iter // used to mark elements as finished
}
// newProgress returns a progress that renders to p and reports finished
// elements back to it.
func newProgress(p pw.Writer, it *iter, opts Options) *progress {
	reporter := new(progress)
	reporter.pw = p
	reporter.trackers = &sync.Map{}
	reporter.opts = opts
	reporter.it = it
	return reporter
}
// OnAdd creates a tracker for elem and remembers it under the element id.
// elem must be an *iterElem.
func (p *progress) OnAdd(elem downloader.Elem) {
	label := p.processMessage(elem)
	tracker := prog.AppendTracker(p.pw, utils.Byte.FormatBinaryBytes, label, elem.File().Size())

	p.trackers.Store(elem.(*iterElem).id, tracker)
}
// OnDownload pushes the current download state to the tracker of elem; it does
// nothing when no tracker was registered for the element.
func (p *progress) OnDownload(elem downloader.Elem, state downloader.ProgressState) {
	value, found := p.trackers.Load(elem.(*iterElem).id)
	if found {
		t := value.(*pw.Tracker)
		t.UpdateTotal(state.Total)
		t.SetValue(state.Downloaded)
	}
}
// OnDone finalizes elem after its download ended with err.
//
// The destination file is closed first. On a download error the temp file is
// removed and the error is reported, unless it is a user cancel. On success
// the temp file is moved to its final name and only then is the element
// marked as finished, so that a failed rename does not leave the element
// recorded as done while its final file is missing.
//
// elem must be an *iterElem. Nothing happens when no tracker was registered
// for it.
func (p *progress) OnDone(elem downloader.Elem, err error) {
	e := elem.(*iterElem)

	tracker, ok := p.trackers.Load(e.id)
	if !ok {
		return
	}
	t := tracker.(*pw.Tracker)

	if cerr := e.to.Close(); cerr != nil {
		p.fail(t, elem, errors.Wrap(cerr, "close file"))
		return
	}

	if err != nil {
		if !errors.Is(err, context.Canceled) { // don't report user cancel
			p.fail(t, elem, errors.Wrap(err, "progress"))
		}
		_ = os.Remove(e.to.Name()) // just try to remove temp file, ignore error
		return
	}

	if perr := p.donePost(e); perr != nil {
		p.fail(t, elem, errors.Wrap(perr, "post file"))
		return
	}

	// the final file is in place, the element can be skipped on resume
	p.it.Finish(e.id)
}
// donePost moves the finished temp file to its final name: the temp extension
// is dropped and, with RewriteExt, the extension is replaced by the one
// detected from the file content.
//
// The file stays in the directory the temp file was created in, so names that
// contain sub directories keep them instead of being moved to the top-level
// download dir.
func (p *progress) donePost(elem *iterElem) error {
	tmp := elem.to.Name()
	newfile := strings.TrimSuffix(filepath.Base(tmp), tempExt)

	if p.opts.RewriteExt {
		mime, err := mimetype.DetectFile(tmp)
		if err != nil {
			return errors.Wrap(err, "detect mime")
		}
		ext := mime.Extension()
		if ext != "" && (filepath.Ext(newfile) != ext) {
			newfile = utils.FS.GetNameWithoutExt(newfile) + ext
		}
	}

	if err := os.Rename(tmp, filepath.Join(filepath.Dir(tmp), newfile)); err != nil {
		return errors.Wrap(err, "rename file")
	}

	return nil
}
// fail logs err for elem in red on the progress writer and marks the tracker
// as errored.
func (p *progress) fail(t *pw.Tracker, elem downloader.Elem, err error) {
p.pw.Log(color.RedString("%s error: %s", p.elemString(elem), err.Error()))
t.MarkAsErrored()
}
// processMessage returns the tracker message of elem, which is currently the
// same text as elemString.
func (p *progress) processMessage(elem downloader.Elem) string {
return p.elemString(elem)
}
// elemString describes elem as "name(peer id):message id -> destination",
// where the destination is the target path without the temp extension.
// elem must be an *iterElem.
func (p *progress) elemString(elem downloader.Elem) string {
	e := elem.(*iterElem)
	name, peerID, msgID := e.from.VisibleName(), e.from.ID(), e.fromMsg.ID
	target := strings.TrimSuffix(e.to.Name(), tempExt)

	return fmt.Sprintf("%s(%d):%d -> %s", name, peerID, msgID, target)
}

View File

@ -1,174 +0,0 @@
package dliter
import (
"bytes"
"context"
"fmt"
"path/filepath"
"text/template"
"time"
"github.com/go-faster/errors"
"github.com/gotd/td/telegram/peers"
"github.com/iyear/tdl/pkg/downloader"
"github.com/iyear/tdl/pkg/storage"
"github.com/iyear/tdl/pkg/tmedia"
"github.com/iyear/tdl/pkg/tplfunc"
"github.com/iyear/tdl/pkg/utils"
)
// New parses the file name template, flattens and sorts the dialogs and
// returns an iterator positioned before the first message.
//
// It fails when the template cannot be parsed or when no dialog is left after
// flattening.
func New(ctx context.Context, opts *Options) (*Iter, error) {
	parsed, err := template.New("dl").Funcs(tplfunc.FuncMap(tplfunc.All...)).Parse(opts.Template)
	if err != nil {
		return nil, err
	}

	// an empty dialog list would make the cursor index out of range later on
	flat := collectDialogs(opts.Dialogs)
	if len(flat) == 0 {
		return nil, fmt.Errorf("you must specify at least one message")
	}

	include := filterMap(opts.Include, utils.FS.AddPrefixDot)
	exclude := filterMap(opts.Exclude, utils.FS.AddPrefixDot)

	// sort before hashing so that the fingerprint stays stable
	sortDialogs(flat, opts.Desc)

	peerManager := peers.Options{Storage: storage.NewPeers(opts.KV)}.Build(opts.Pool.Default(ctx))

	return &Iter{
		pool:        opts.Pool,
		dialogs:     flat,
		include:     include,
		exclude:     exclude,
		curi:        0,
		curj:        -1, // Next increments before use
		preSum:      preSum(flat),
		finished:    make(map[int]struct{}),
		template:    parsed,
		manager:     peerManager,
		fingerprint: fingerprint(flat),
	}, nil
}
// Next advances the cursor by one message and returns the item for it.
//
// It returns the context error when ctx is done, a "no more items" error when
// every message has been visited, downloader.ErrSkip when the message was
// already finished, and otherwise whatever item returns.
//
// The mutex is released on every path, the finished set is read while it is
// held, and calling Next again after the end keeps returning "no more items"
// instead of indexing past the last dialog.
func (iter *Iter) Next(ctx context.Context) (*downloader.Item, error) {
	select {
	case <-ctx.Done():
		return nil, ctx.Err()
	default:
	}

	iter.mu.Lock()

	iter.curj++
	// roll over to the next dialog that still has a message at curj; this
	// also passes over dialogs without messages
	for iter.curi < len(iter.dialogs) && iter.curj >= len(iter.dialogs[iter.curi].Messages) {
		iter.curi++
		iter.curj = 0
	}
	if iter.curi >= len(iter.dialogs) {
		iter.mu.Unlock()
		return nil, errors.New("no more items")
	}

	i, j := iter.curi, iter.curj
	_, done := iter.finished[iter.ij2n(i, j)]
	iter.mu.Unlock()

	// check if finished
	if done {
		return nil, downloader.ErrSkip
	}

	return iter.item(ctx, i, j)
}
// item resolves message j of dialog i and returns its media as a download
// item whose name is rendered from the file name template.
//
// downloader.ErrSkip is returned when the media extension is filtered out by
// the include or exclude set.
func (iter *Iter) item(ctx context.Context, i, j int) (*downloader.Item, error) {
	dialog := iter.dialogs[i]
	peer, msgID := dialog.Peer, dialog.Messages[j]
	peerID := utils.Telegram.GetInputPeerID(peer)

	message, err := utils.Telegram.GetSingleMessage(ctx, iter.pool.Default(ctx), peer, msgID)
	if err != nil {
		return nil, errors.Wrap(err, "resolve message")
	}

	media, ok := tmedia.GetMedia(message)
	if !ok {
		return nil, fmt.Errorf("can not get media from %d/%d message",
			peerID, message.ID)
	}

	// process include and exclude; an empty set disables the filter
	ext := filepath.Ext(media.Name)
	if _, wanted := iter.include[ext]; len(iter.include) > 0 && !wanted {
		return nil, downloader.ErrSkip
	}
	if _, banned := iter.exclude[ext]; len(iter.exclude) > 0 && banned {
		return nil, downloader.ErrSkip
	}

	var name bytes.Buffer
	data := &fileTemplate{
		DialogID:     peerID,
		MessageID:    message.ID,
		MessageDate:  int64(message.Date),
		FileName:     media.Name,
		FileSize:     utils.Byte.FormatBinaryBytes(media.Size),
		DownloadDate: time.Now().Unix(),
	}
	if err = iter.template.Execute(&name, data); err != nil {
		return nil, err
	}
	// the rendered name replaces the original media name
	media.Name = name.String()

	return &downloader.Item{
		ID:    iter.ij2n(i, j),
		Media: media,
	}, nil
}
// Finish marks the item with the given id as finished; it always returns nil.
func (iter *Iter) Finish(_ context.Context, id int) error {
iter.mu.Lock()
defer iter.mu.Unlock()
iter.finished[id] = struct{}{}
return nil
}
// Total returns the number of messages over all dialogs.
func (iter *Iter) Total(_ context.Context) int {
	iter.mu.Lock()
	defer iter.mu.Unlock()

	var count int
	for idx := range iter.dialogs {
		count += len(iter.dialogs[idx].Messages)
	}
	return count
}
// ij2n maps a (dialog index, message index) pair to a single index using the
// prefix sums in preSum.
func (iter *Iter) ij2n(i, j int) int {
return iter.preSum[i] + j
}
// SetFinished replaces the set of finished ids; the map is stored as is, not
// copied.
func (iter *Iter) SetFinished(finished map[int]struct{}) {
iter.mu.Lock()
defer iter.mu.Unlock()
iter.finished = finished
}
// Finished returns the set of finished ids. The internal map itself is
// returned, so it is shared with the iterator after the lock is released.
func (iter *Iter) Finished() map[int]struct{} {
iter.mu.Lock()
defer iter.mu.Unlock()
return iter.finished
}
// Fingerprint returns the hash computed over the sorted dialogs in New.
func (iter *Iter) Fingerprint() string {
return iter.fingerprint
}

View File

@ -1,71 +0,0 @@
package dliter
import (
"reflect"
"testing"
"github.com/iyear/tdl/pkg/tmessage"
)
// TestPreSum checks that preSum returns the running totals of the message
// counts, starting with 0.
func TestPreSum(t *testing.T) {
	type testCase struct {
		dialogs []*tmessage.Dialog
		want    []int
	}

	cases := []testCase{
		{
			dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2}}},
			want:    []int{0, 3, 5},
		},
		{
			dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}},
			want:    []int{0, 3, 6, 10},
		},
		{
			dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}, {Messages: []int{1}}},
			want:    []int{0, 3, 6, 10, 11},
		},
	}

	for idx := range cases {
		tc := cases[idx]
		if got := preSum(tc.dialogs); !reflect.DeepEqual(got, tc.want) {
			t.Errorf("preSum() = %v, want %v", got, tc.want)
		}
	}
}
// TestIter_ij2n checks that ij2n numbers the messages consecutively across
// dialogs.
func TestIter_ij2n(t *testing.T) {
	type pos = struct {
		i, j int
	}

	cases := []struct {
		dialogs []*tmessage.Dialog
		input   []pos
		want    []int
	}{
		{
			dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2}}},
			input:   []pos{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}},
			want:    []int{0, 1, 2, 3, 4},
		},
		{
			dialogs: []*tmessage.Dialog{{Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3}}, {Messages: []int{1, 2, 3, 4}}},
			input:   []pos{{0, 0}, {0, 1}, {0, 2}, {1, 0}, {1, 1}, {1, 2}, {2, 0}, {2, 1}, {2, 2}, {2, 3}},
			want:    []int{0, 1, 2, 3, 4, 5, 6, 7, 8, 9},
		},
	}

	for _, tc := range cases {
		it := &Iter{preSum: preSum(tc.dialogs), dialogs: tc.dialogs}
		for k, in := range tc.input {
			if got := it.ij2n(in.i, in.j); got != tc.want[k] {
				t.Errorf("ij2n(%v, %v) = %v, want %v", in.i, in.j, got, tc.want[k])
			}
		}
	}
}

View File

@ -1,44 +0,0 @@
package dliter
import (
"sync"
"text/template"
"github.com/gotd/td/telegram/peers"
"github.com/iyear/tdl/pkg/dcpool"
"github.com/iyear/tdl/pkg/kv"
"github.com/iyear/tdl/pkg/tmessage"
)
// Options holds everything New needs to build an Iter.
type Options struct {
Pool dcpool.Pool
KV kv.KV // backs the peer storage of the peers manager
Template string // file name template source
Include, Exclude []string // extension filters
Desc bool // sort message ids in descending order
Dialogs [][]*tmessage.Dialog
}
// Iter walks every message of every dialog and yields one download item per
// message.
type Iter struct {
pool dcpool.Pool
dialogs []*tmessage.Dialog
include, exclude map[string]struct{} // extension filters; ignored when empty
mu sync.Mutex
curi int // index of the current dialog
curj int // index of the current message; starts at -1
preSum []int // prefix sums of message counts, used by ij2n
finished map[int]struct{} // ids (see ij2n) that were already downloaded
template *template.Template
manager *peers.Manager
fingerprint string
}
// fileTemplate is the data handed to the file name template when the name of
// an item is rendered.
type fileTemplate struct {
DialogID int64
MessageID int
MessageDate int64
FileName string
FileSize string // formatted size, not a number
DownloadDate int64
}

View File

@ -1,71 +0,0 @@
package dliter
import (
"bytes"
"crypto/sha256"
"encoding/binary"
"fmt"
"sort"
"github.com/iyear/tdl/pkg/tmessage"
"github.com/iyear/tdl/pkg/utils"
)
// sortDialogs orders the dialogs by ascending peer id and, inside every
// dialog, the message ids ascending, or descending when desc is set. Both
// sorts happen in place.
func sortDialogs(dialogs []*tmessage.Dialog, desc bool) {
	sort.Slice(dialogs, func(a, b int) bool {
		left := utils.Telegram.GetInputPeerID(dialogs[a].Peer)
		right := utils.Telegram.GetInputPeerID(dialogs[b].Peer)
		return left < right // increasing order
	})

	for _, d := range dialogs {
		ids := d.Messages
		sort.Slice(ids, func(a, b int) bool {
			if !desc {
				return ids[a] < ids[b]
			}
			return ids[a] > ids[b]
		})
	}
}
// fingerprint hashes every peer id followed by its message ids, each encoded
// as a big-endian 64-bit word, and returns the SHA-256 digest in lowercase hex.
func fingerprint(dialogs []*tmessage.Dialog) string {
	var (
		raw  bytes.Buffer
		word [8]byte
	)
	put := func(v uint64) {
		binary.BigEndian.PutUint64(word[:], v)
		raw.Write(word[:])
	}

	for _, d := range dialogs {
		put(uint64(utils.Telegram.GetInputPeerID(d.Peer)))
		for _, id := range d.Messages {
			put(uint64(id))
		}
	}

	return fmt.Sprintf("%x", sha256.Sum256(raw.Bytes()))
}
// filterMap builds a set whose members are keyFn applied to every entry of
// data.
func filterMap(data []string, keyFn func(key string) string) map[string]struct{} {
	set := make(map[string]struct{})
	for idx := 0; idx < len(data); idx++ {
		set[keyFn(data[idx])] = struct{}{}
	}
	return set
}
// collectDialogs concatenates the dialog groups into one slice, keeping their
// order. The result is never nil.
func collectDialogs(dialogs [][]*tmessage.Dialog) []*tmessage.Dialog {
	flat := make([]*tmessage.Dialog, 0)
	for idx := range dialogs {
		// appending an empty group is a no-op, so no explicit check is needed
		flat = append(flat, dialogs[idx]...)
	}
	return flat
}
// preSum returns the prefix sums of the dialogs' message counts: element k is
// the number of messages in the first k dialogs, so the result has
// len(dialogs)+1 entries and starts with 0.
func preSum(dialogs []*tmessage.Dialog) []int {
	prefix := make([]int, 1, len(dialogs)+1)
	for _, d := range dialogs {
		prefix = append(prefix, prefix[len(prefix)-1]+len(d.Messages))
	}
	return prefix
}

View File

@ -25,7 +25,7 @@ func NewDownload() *cobra.Command {
}
opts.Template = viper.GetString(consts.FlagDlTemplate)
return dl.Run(logger.Named(cmd.Context(), "dl"), &opts)
return dl.Run(logger.Named(cmd.Context(), "dl"), opts)
},
}

View File

@ -2,168 +2,89 @@ package downloader
import (
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/fatih/color"
"github.com/gabriel-vasile/mimetype"
"github.com/go-faster/errors"
"github.com/gotd/td/telegram/downloader"
"github.com/jedib0t/go-pretty/v6/progress"
"go.uber.org/zap"
"golang.org/x/sync/errgroup"
"github.com/iyear/tdl/pkg/dcpool"
"github.com/iyear/tdl/pkg/logger"
"github.com/iyear/tdl/pkg/prog"
"github.com/iyear/tdl/pkg/utils"
)
const TempExt = ".tmp"
var formatter = utils.Byte.FormatBinaryBytes
type Downloader struct {
pw progress.Writer
opts Options
}
type Options struct {
Pool dcpool.Pool
Dir string
RewriteExt bool
SkipSame bool
PartSize int
Threads int
Iter Iter
Takeout bool
Progress Progress
}
func New(opts Options) (*Downloader, error) {
func New(opts Options) *Downloader {
return &Downloader{
pw: prog.New(formatter),
opts: opts,
}, nil
}
}
// Download runs the downloads of the iterator's elements in an errgroup that
// is limited to `limit` concurrent workers.
//
// NOTE(review): this span comes from a diff hunk and interleaves the removed
// and the added version of the function; the notes below refer to the added
// version (the Next/Value loop).
func (d *Downloader) Download(ctx context.Context, limit int) error {
color.Green("All files will be downloaded to '%s' dir", d.opts.Dir)
total := d.opts.Iter.Total(ctx)
d.pw.SetNumTrackersExpected(total)
go d.renderPinned(ctx, d.pw)
go d.pw.Render()
wg, errctx := errgroup.WithContext(ctx)
wg, wgctx := errgroup.WithContext(ctx)
wg.SetLimit(limit)
for i := 0; i < total; i++ {
item, err := d.opts.Iter.Next(errctx)
if err != nil {
logger.From(errctx).Debug("Iter next failed",
zap.Int("index", i), zap.String("error", err.Error()))
// skip error means we don't need to log error
if !errors.Is(err, ErrSkip) && !errors.Is(err, context.Canceled) {
d.pw.Log(color.RedString("failed: %v", err))
}
continue
for d.opts.Iter.Next(wgctx) {
elem := d.opts.Iter.Value()
wg.Go(func() (rerr error) {
d.opts.Progress.OnAdd(elem)
// OnDone receives the named result rerr, i.e. whatever this worker returns.
defer func() { d.opts.Progress.OnDone(elem, rerr) }()
if err := d.download(wgctx, elem); err != nil {
// canceled by user, so we directly return error to stop all
if errors.Is(err, context.Canceled) {
return errors.Wrap(err, "download")
}
wg.Go(func() error {
return d.download(errctx, item)
// don't return error, just log it
// NOTE(review): nothing is logged on this path and rerr stays nil, so the
// deferred OnDone is called with a nil error for a download that failed with
// a non-canceled error.
}
return nil
})
}
err := wg.Wait()
if err != nil {
d.pw.Stop()
for d.pw.IsRenderInProgress() {
time.Sleep(time.Millisecond * 10)
// NOTE(review): returning on an iterator error skips wg.Wait(), so workers
// that were already started may still be running when Download returns.
if err := d.opts.Iter.Err(); err != nil {
return errors.Wrap(err, "iter")
}
// canceled error is ignored by gotd, so we can't detect it in main entry
if errors.Is(err, context.Canceled) {
color.Red("Download aborted by user")
}
return err
}
prog.Wait(ctx, d.pw)
return nil
return wg.Wait()
}
func (d *Downloader) download(ctx context.Context, item *Item) error {
func (d *Downloader) download(ctx context.Context, elem Elem) error {
select {
case <-ctx.Done():
return ctx.Err()
default:
}
logger.From(ctx).Debug("Start download item",
zap.Any("item", item))
logger.From(ctx).Debug("Start download elem",
zap.Any("elem", elem))
client := d.opts.Pool.Client(ctx, elem.File().DC())
if elem.AsTakeout() {
client = d.opts.Pool.Takeout(ctx, elem.File().DC())
}
_, err := downloader.NewDownloader().WithPartSize(d.opts.PartSize).
Download(client, elem.File().Location()).
WithThreads(d.bestThreads(elem.File().Size())).
Parallel(ctx, newWriteAt(elem, d.opts.Progress, d.opts.PartSize))
if err != nil {
return errors.Wrap(err, "download")
}
if d.opts.SkipSame {
if stat, err := os.Stat(filepath.Join(d.opts.Dir, item.Name)); err == nil {
if utils.FS.GetNameWithoutExt(item.Name) == utils.FS.GetNameWithoutExt(stat.Name()) &&
stat.Size() == item.Size {
return nil
}
}
}
tracker := prog.AppendTracker(d.pw, formatter, item.Name, item.Size)
filename := fmt.Sprintf("%s%s", item.Name, TempExt)
path := filepath.Join(d.opts.Dir, filename)
// #113. If path contains dirs, create it. So now we support nested dirs.
if err := os.MkdirAll(filepath.Dir(path), 0o755); err != nil {
return err
}
f, err := os.Create(path)
if err != nil {
return err
}
client := d.opts.Pool.Client(ctx, item.DC)
if d.opts.Takeout {
client = d.opts.Pool.Takeout(ctx, item.DC)
}
_, err = downloader.NewDownloader().WithPartSize(d.opts.PartSize).
Download(client, item.InputFileLoc).
WithThreads(d.bestThreads(item.Size)).
Parallel(ctx, newWriteAt(f, tracker, d.opts.PartSize))
if err := f.Close(); err != nil {
return err
}
if err != nil {
return err
}
// rename file, remove temp extension and add real extension
newfile := strings.TrimSuffix(filename, TempExt)
if d.opts.RewriteExt {
mime, err := mimetype.DetectFile(path)
if err != nil {
return err
}
ext := mime.Extension()
if ext != "" && (filepath.Ext(newfile) != ext) {
newfile = utils.FS.GetNameWithoutExt(newfile) + ext
}
}
if err = os.Rename(path, filepath.Join(d.opts.Dir, newfile)); err != nil {
return err
}
return d.opts.Iter.Finish(ctx, item.ID)
}
// threads level

View File

@ -2,20 +2,26 @@ package downloader
import (
"context"
"errors"
"io"
"github.com/iyear/tdl/pkg/tmedia"
"github.com/gotd/td/tg"
)
var ErrSkip = errors.New("skip")
type Iter interface {
Next(ctx context.Context) (*Item, error)
Finish(ctx context.Context, id int) error
Total(ctx context.Context) int
Next(ctx context.Context) bool
Value() Elem
Err() error
}
type Item struct {
ID int // unique in iter
*tmedia.Media
type Elem interface {
File() File
To() io.WriterAt
AsTakeout() bool
}
type File interface {
Location() tg.InputFileLocationClass
Size() int64
DC() int
}

View File

@ -0,0 +1,57 @@
package downloader
import (
"time"
"go.uber.org/atomic"
)
// Progress receives the life cycle events of every downloaded element.
type Progress interface {
// OnAdd is called when the download of elem is about to start.
OnAdd(elem Elem)
// OnDownload is called after every successful write with the current state.
OnDownload(elem Elem, state ProgressState)
// OnDone is called once the download of elem has ended, with its error if any.
OnDone(elem Elem, err error)
// TODO: OnLog to log something that is not an error but should be sent to the user
}
// ProgressState is the snapshot passed to Progress.OnDownload.
type ProgressState struct {
Downloaded int64 // bytes written so far
Total int64 // size reported by the element's file
}
// writeAt wraps the destination of an element so that every write is reported
// to the Progress.
//
// No mutex is needed because gotd already uses syncio.WriteAt.
type writeAt struct {
elem Elem
progress Progress
partSize int // a write shorter than this is treated as the last part
downloaded *atomic.Int64 // running total of written bytes
}
// newWriteAt returns a writeAt for elem whose byte counter starts at zero.
func newWriteAt(elem Elem, progress Progress, partSize int) *writeAt {
	w := &writeAt{downloaded: atomic.NewInt64(0)}
	w.elem, w.progress, w.partSize = elem, progress, partSize
	return w
}
// WriteAt writes p at offset off to the element's destination and reports the
// new state to the Progress.
//
// On a write error the number of bytes the destination reports as written is
// passed through together with the error, as the io.WriterAt contract asks
// for, instead of always reporting 0. No progress is reported in that case.
func (w *writeAt) WriteAt(p []byte, off int64) (int, error) {
	at, err := w.elem.To().WriteAt(p, off)
	if err != nil {
		return at, err
	}

	// some small files may finish too fast, terminal history may not be overwritten
	// this is just a simple way to avoid the problem
	if at < w.partSize { // last part(every file only exec once)
		time.Sleep(time.Millisecond * 200) // to ensure the progress render next time
	}

	w.progress.OnDownload(w.elem, ProgressState{
		Downloaded: w.downloaded.Add(int64(at)),
		Total:      w.elem.File().Size(),
	})
	return at, nil
}

View File

@ -1,29 +0,0 @@
package downloader
import (
"context"
"strings"
"time"
"github.com/jedib0t/go-pretty/v6/progress"
"github.com/iyear/tdl/pkg/ps"
)
// renderPinned shows the humanized process stats as the pinned message of pw
// and refreshes them every second. When ctx is done the pinned message is
// cleared and the function returns.
func (d *Downloader) renderPinned(ctx context.Context, pw progress.Writer) {
	refresh := func() {
		pw.SetPinnedMessages(strings.Join(ps.Humanize(ctx), " "))
	}
	refresh() // show the stats right away, before the first tick

	tick := time.NewTicker(time.Second)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			refresh()
		case <-ctx.Done():
			pw.SetPinnedMessages()
			return
		}
	}
}

View File

@ -1,41 +0,0 @@
package downloader
import (
"os"
"time"
"github.com/jedib0t/go-pretty/v6/progress"
)
// writeAt wraps the destination file so that every write advances the
// progress tracker.
//
// No mutex is needed because gotd already uses syncio.WriteAt.
type writeAt struct {
f *os.File
tracker *progress.Tracker
partSize int // a write shorter than this is treated as the last part
}
// newWriteAt returns a writeAt that writes to f and reports to tracker.
func newWriteAt(f *os.File, tracker *progress.Tracker, partSize int) *writeAt {
	w := new(writeAt)
	w.f, w.tracker, w.partSize = f, tracker, partSize
	return w
}
// WriteAt writes p at offset off to the file and advances the tracker by the
// number of bytes written.
//
// On a write error the tracker is marked as errored and the number of bytes
// the file reports as written is passed through together with the error, as
// the io.WriterAt contract asks for, instead of always reporting 0.
func (w *writeAt) WriteAt(p []byte, off int64) (int, error) {
	at, err := w.f.WriteAt(p, off)
	if err != nil {
		w.tracker.MarkAsErrored()
		return at, err
	}

	// some small files may finish too fast, terminal history may not be overwritten
	// this is just a simple way to avoid the problem
	if at < w.partSize { // last part(every file only exec once)
		time.Sleep(time.Millisecond * 200) // to ensure the progress render next time
	}

	w.tracker.Increment(int64(at))
	return at, nil
}