mirror of
https://github.com/iyear/tdl
synced 2025-01-08 11:57:55 +08:00
feat(forwarder): fix grouped clone, speed up clone
This commit is contained in:
parent
da2206f41c
commit
d45a0dc338
@ -94,6 +94,7 @@ func Run(ctx context.Context, opts Options) error {
|
||||
}),
|
||||
Progress: newProgress(fwProgress),
|
||||
PartSize: viper.GetInt(consts.FlagPartSize),
|
||||
Threads: viper.GetInt(consts.FlagThreads),
|
||||
})
|
||||
|
||||
go fwProgress.Render()
|
||||
|
@ -63,7 +63,10 @@ func (p *progress) OnDone(elem forwarder.Elem, err error) {
|
||||
return
|
||||
}
|
||||
|
||||
tracker.Increment(1)
|
||||
if tracker.Total == 1 {
|
||||
tracker.Increment(1)
|
||||
}
|
||||
tracker.MarkAsDone()
|
||||
}
|
||||
|
||||
func (p *progress) tuple(elem forwarder.Elem) tuple {
|
||||
|
@ -3,73 +3,132 @@ package forwarder
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"os"
|
||||
|
||||
"github.com/go-faster/errors"
|
||||
"github.com/gotd/td/telegram/downloader"
|
||||
"github.com/gotd/td/telegram/uploader"
|
||||
"github.com/gotd/td/tg"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/multierr"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/iyear/tdl/pkg/tmedia"
|
||||
)
|
||||
|
||||
type CloneOptions struct {
|
||||
Media *tmedia.Media
|
||||
PartSize int
|
||||
Progress uploader.Progress
|
||||
type cloneOptions struct {
|
||||
elem Elem
|
||||
media *tmedia.Media
|
||||
progress progressAdd
|
||||
}
|
||||
|
||||
func (f *Forwarder) CloneMedia(ctx context.Context, opts CloneOptions, dryRun bool) (tg.InputFileClass, error) {
|
||||
type progressAdd interface {
|
||||
add(n int64)
|
||||
}
|
||||
|
||||
func (f *Forwarder) cloneMedia(ctx context.Context, opts cloneOptions, dryRun bool) (_ tg.InputFileClass, rerr error) {
|
||||
// if dry run, just return empty input file
|
||||
if dryRun {
|
||||
// directly call progress callback
|
||||
if err := opts.Progress.Chunk(ctx, uploader.ProgressState{
|
||||
Uploaded: opts.Media.Size,
|
||||
Total: opts.Media.Size,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "dry run chunk")
|
||||
}
|
||||
opts.progress.add(opts.media.Size * 2)
|
||||
|
||||
return &tg.InputFile{}, nil
|
||||
}
|
||||
|
||||
r, w := io.Pipe()
|
||||
temp, err := os.CreateTemp("", "tdl_*")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "create temp file")
|
||||
}
|
||||
defer func() {
|
||||
multierr.AppendInto(&rerr, temp.Close())
|
||||
multierr.AppendInto(&rerr, os.Remove(temp.Name()))
|
||||
}()
|
||||
|
||||
wg, errctx := errgroup.WithContext(ctx)
|
||||
threads := bestThreads(opts.media.Size, f.opts.Threads)
|
||||
|
||||
wg.Go(func() (rerr error) {
|
||||
defer multierr.AppendInvoke(&rerr, multierr.Close(w))
|
||||
|
||||
_, err := downloader.NewDownloader().
|
||||
WithPartSize(opts.PartSize).
|
||||
Download(f.opts.Pool.Client(ctx, opts.Media.DC), opts.Media.InputFileLoc).
|
||||
Stream(errctx, w)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "download")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
_, err = downloader.NewDownloader().
|
||||
WithPartSize(f.opts.PartSize).
|
||||
Download(f.opts.Pool.Client(ctx, opts.media.DC), opts.media.InputFileLoc).
|
||||
WithThreads(threads).
|
||||
Parallel(ctx, writeAt{
|
||||
f: temp,
|
||||
opts: opts,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "download")
|
||||
}
|
||||
|
||||
var file tg.InputFileClass
|
||||
wg.Go(func() (rerr error) {
|
||||
defer multierr.AppendInvoke(&rerr, multierr.Close(r))
|
||||
|
||||
var err error
|
||||
upload := uploader.NewUpload(opts.Media.Name, r, opts.Media.Size)
|
||||
file, err = uploader.NewUploader(f.opts.Pool.Default(ctx)).
|
||||
WithPartSize(opts.PartSize).
|
||||
WithProgress(opts.Progress).
|
||||
Upload(errctx, upload)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "upload")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if _, err = temp.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, errors.Wrap(err, "seek")
|
||||
}
|
||||
|
||||
if err := wg.Wait(); err != nil {
|
||||
return nil, errors.Wrap(err, "wait")
|
||||
upload := uploader.NewUpload(opts.media.Name, temp, opts.media.Size)
|
||||
file, err = uploader.NewUploader(f.opts.Pool.Default(ctx)).
|
||||
WithPartSize(f.opts.PartSize).
|
||||
WithThreads(threads).
|
||||
WithProgress(uploaded{
|
||||
opts: opts,
|
||||
prev: atomic.NewInt64(0),
|
||||
}).
|
||||
Upload(ctx, upload)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "upload")
|
||||
}
|
||||
|
||||
return file, nil
|
||||
}
|
||||
|
||||
type writeAt struct {
|
||||
f io.WriterAt
|
||||
opts cloneOptions
|
||||
}
|
||||
|
||||
func (w writeAt) WriteAt(p []byte, off int64) (int, error) {
|
||||
n, err := w.f.WriteAt(p, off)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.opts.progress.add(int64(n))
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
type uploaded struct {
|
||||
opts cloneOptions
|
||||
prev *atomic.Int64
|
||||
}
|
||||
|
||||
func (u uploaded) Chunk(_ context.Context, state uploader.ProgressState) error {
|
||||
u.opts.progress.add(state.Uploaded - u.prev.Swap(state.Uploaded))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var threadsLevels = []struct {
|
||||
threads int
|
||||
size int64
|
||||
}{
|
||||
{1, 1 << 20},
|
||||
{2, 5 << 20},
|
||||
{4, 20 << 20},
|
||||
{8, 50 << 20},
|
||||
}
|
||||
|
||||
// Get best threads num for download, based on file size
|
||||
func bestThreads(size int64, max int) int {
|
||||
for _, t := range threadsLevels {
|
||||
if size < t.size {
|
||||
return min(t.threads, max)
|
||||
}
|
||||
}
|
||||
return max
|
||||
}
|
||||
|
||||
func min(a, b int) int {
|
||||
if a < b {
|
||||
return a
|
||||
}
|
||||
return b
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/gotd/td/telegram/message"
|
||||
"github.com/gotd/td/telegram/peers"
|
||||
"github.com/gotd/td/tg"
|
||||
"go.uber.org/atomic"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/iyear/tdl/pkg/dcpool"
|
||||
@ -27,6 +28,7 @@ type Mode int
|
||||
type Options struct {
|
||||
Pool dcpool.Pool
|
||||
PartSize int
|
||||
Threads int
|
||||
Iter Iter
|
||||
Progress Progress
|
||||
}
|
||||
@ -100,6 +102,13 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t
|
||||
zap.Int64("to", elem.To().ID()),
|
||||
zap.Int("message", elem.Msg().ID))
|
||||
|
||||
// used for clone progress
|
||||
totalSize, err := mediaSizeSum(elem.Msg(), grouped...)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "media total size")
|
||||
}
|
||||
done := atomic.NewInt64(0)
|
||||
|
||||
forwardTextOnly := func(msg *tg.Message) error {
|
||||
if msg.Message == "" {
|
||||
return errors.Errorf("empty message content, skip send: %d", msg.ID)
|
||||
@ -158,18 +167,21 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t
|
||||
return nil, errors.Errorf("unsupported media %T", msg.Media)
|
||||
}
|
||||
|
||||
mediaFile, err := f.CloneMedia(ctx, CloneOptions{
|
||||
Media: media,
|
||||
PartSize: f.opts.PartSize,
|
||||
Progress: uploadProgress{
|
||||
mediaFile, err := f.cloneMedia(ctx, cloneOptions{
|
||||
elem: elem,
|
||||
media: media,
|
||||
progress: &wrapProgress{
|
||||
elem: elem,
|
||||
progress: f.opts.Progress,
|
||||
done: done,
|
||||
total: totalSize * 2,
|
||||
},
|
||||
}, elem.AsDryRun())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "clone media")
|
||||
}
|
||||
|
||||
var inputMedia tg.InputMediaClass
|
||||
// now we only have to process cloned photo or document
|
||||
switch m := msg.Media.(type) {
|
||||
case *tg.MessageMediaPhoto:
|
||||
@ -179,7 +191,8 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t
|
||||
TTLSeconds: m.TTLSeconds,
|
||||
}
|
||||
photo.SetFlags()
|
||||
return photo, nil
|
||||
|
||||
inputMedia = photo
|
||||
case *tg.MessageMediaDocument:
|
||||
doc, ok := m.Document.AsNotEmpty()
|
||||
if !ok {
|
||||
@ -191,10 +204,10 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t
|
||||
return nil, errors.Errorf("empty document thumb %d", msg.ID)
|
||||
}
|
||||
|
||||
thumbFile, err := f.CloneMedia(ctx, CloneOptions{
|
||||
Media: thumb,
|
||||
PartSize: f.opts.PartSize,
|
||||
Progress: nopProgress{},
|
||||
thumbFile, err := f.cloneMedia(ctx, cloneOptions{
|
||||
elem: elem,
|
||||
media: thumb,
|
||||
progress: nopProgress{},
|
||||
}, elem.AsDryRun())
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "clone thumb")
|
||||
@ -213,10 +226,27 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t
|
||||
}
|
||||
document.SetFlags()
|
||||
|
||||
return document, nil
|
||||
inputMedia = document
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported media %T", msg.Media)
|
||||
}
|
||||
|
||||
// note that they must be separately uploaded using messages uploadMedia first,
|
||||
// using raw inputMediaUploaded* constructors is not supported.
|
||||
messageMedia, err := f.forwardClient(ctx, elem).MessagesUploadMedia(ctx, &tg.MessagesUploadMediaRequest{
|
||||
Peer: elem.To().InputPeer(),
|
||||
Media: inputMedia,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "upload media")
|
||||
}
|
||||
|
||||
inputMedia, ok = tmedia.ConvInputMedia(messageMedia)
|
||||
if !ok && !elem.AsDryRun() {
|
||||
return nil, errors.Errorf("can't convert uploaded media to input class")
|
||||
}
|
||||
|
||||
return inputMedia, nil
|
||||
}
|
||||
|
||||
switch elem.Mode() {
|
||||
@ -266,6 +296,7 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t
|
||||
Entities: gm.Entities,
|
||||
}
|
||||
single.SetFlags()
|
||||
|
||||
media = append(media, single)
|
||||
}
|
||||
|
||||
@ -338,6 +369,24 @@ func (n nopInvoker) Invoke(_ context.Context, _ bin.Encoder, _ bin.Decoder) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
type nopProgress struct{}
|
||||
|
||||
func (nopProgress) add(_ int64) {}
|
||||
|
||||
type wrapProgress struct {
|
||||
elem Elem
|
||||
progress ProgressClone
|
||||
done *atomic.Int64
|
||||
total int64
|
||||
}
|
||||
|
||||
func (w *wrapProgress) add(n int64) {
|
||||
w.progress.OnClone(w.elem, ProgressState{
|
||||
Done: w.done.Add(n),
|
||||
Total: w.total,
|
||||
})
|
||||
}
|
||||
|
||||
func (f *Forwarder) forwardClient(ctx context.Context, elem Elem) *tg.Client {
|
||||
if elem.AsDryRun() {
|
||||
return tg.NewClient(nopInvoker{})
|
||||
@ -369,3 +418,24 @@ func photoOrDocument(media tg.MessageMediaClass) bool {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func mediaSizeSum(msg *tg.Message, grouped ...*tg.Message) (int64, error) {
|
||||
m, ok := tmedia.GetMedia(msg)
|
||||
if !ok {
|
||||
return 0, errors.Errorf("can't get media from message")
|
||||
}
|
||||
total := m.Size
|
||||
|
||||
if len(grouped) > 0 {
|
||||
total = 0
|
||||
for _, gm := range grouped {
|
||||
m, ok := tmedia.GetMedia(gm)
|
||||
if !ok {
|
||||
return 0, errors.Errorf("can't get media from message")
|
||||
}
|
||||
total += m.Size
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
@ -1,14 +1,12 @@
|
||||
package forwarder
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/gotd/td/telegram/uploader"
|
||||
)
|
||||
type ProgressClone interface {
|
||||
OnClone(elem Elem, state ProgressState)
|
||||
}
|
||||
|
||||
type Progress interface {
|
||||
OnAdd(elem Elem)
|
||||
OnClone(elem Elem, state ProgressState)
|
||||
ProgressClone
|
||||
OnDone(elem Elem, err error)
|
||||
}
|
||||
|
||||
@ -16,20 +14,3 @@ type ProgressState struct {
|
||||
Done int64
|
||||
Total int64
|
||||
}
|
||||
|
||||
type uploadProgress struct {
|
||||
elem Elem
|
||||
progress Progress
|
||||
}
|
||||
|
||||
func (p uploadProgress) Chunk(_ context.Context, state uploader.ProgressState) error {
|
||||
p.progress.OnClone(p.elem, ProgressState{
|
||||
Done: state.Uploaded,
|
||||
Total: state.Total,
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
type nopProgress struct{}
|
||||
|
||||
func (p nopProgress) Chunk(_ context.Context, _ uploader.ProgressState) error { return nil }
|
||||
|
Loading…
Reference in New Issue
Block a user