From d45a0dc3387e450d5bb7a2fad6b7e0e6550d5b1f Mon Sep 17 00:00:00 2001 From: iyear Date: Sat, 9 Dec 2023 00:21:27 +0800 Subject: [PATCH] feat(forwarder): fix grouped clone, speed up clone --- app/forward/forward.go | 1 + app/forward/progress.go | 5 +- pkg/forwarder/clone.go | 141 ++++++++++++++++++++++++++----------- pkg/forwarder/forwarder.go | 90 ++++++++++++++++++++--- pkg/forwarder/progress.go | 27 ++----- 5 files changed, 189 insertions(+), 75 deletions(-) diff --git a/app/forward/forward.go b/app/forward/forward.go index 86daf9f..d76a84c 100644 --- a/app/forward/forward.go +++ b/app/forward/forward.go @@ -94,6 +94,7 @@ func Run(ctx context.Context, opts Options) error { }), Progress: newProgress(fwProgress), PartSize: viper.GetInt(consts.FlagPartSize), + Threads: viper.GetInt(consts.FlagThreads), }) go fwProgress.Render() diff --git a/app/forward/progress.go b/app/forward/progress.go index 42e7fae..cf40def 100644 --- a/app/forward/progress.go +++ b/app/forward/progress.go @@ -63,7 +63,10 @@ func (p *progress) OnDone(elem forwarder.Elem, err error) { return } - tracker.Increment(1) + if tracker.Total == 1 { + tracker.Increment(1) + } + tracker.MarkAsDone() } func (p *progress) tuple(elem forwarder.Elem) tuple { diff --git a/pkg/forwarder/clone.go b/pkg/forwarder/clone.go index 9437892..5c1da07 100644 --- a/pkg/forwarder/clone.go +++ b/pkg/forwarder/clone.go @@ -3,73 +3,132 @@ package forwarder import ( "context" "io" + "os" "github.com/go-faster/errors" "github.com/gotd/td/telegram/downloader" "github.com/gotd/td/telegram/uploader" "github.com/gotd/td/tg" + "go.uber.org/atomic" "go.uber.org/multierr" - "golang.org/x/sync/errgroup" "github.com/iyear/tdl/pkg/tmedia" ) -type CloneOptions struct { - Media *tmedia.Media - PartSize int - Progress uploader.Progress +type cloneOptions struct { + elem Elem + media *tmedia.Media + progress progressAdd } -func (f *Forwarder) CloneMedia(ctx context.Context, opts CloneOptions, dryRun bool) (tg.InputFileClass, error) { +type progressAdd interface { + add(n int64) +} + +func (f *Forwarder) cloneMedia(ctx context.Context, opts cloneOptions, dryRun bool) (_ tg.InputFileClass, rerr error) { // if dry run, just return empty input file if dryRun { // directly call progress callback - if err := opts.Progress.Chunk(ctx, uploader.ProgressState{ - Uploaded: opts.Media.Size, - Total: opts.Media.Size, - }); err != nil { - return nil, errors.Wrap(err, "dry run chunk") - } + opts.progress.add(opts.media.Size * 2) return &tg.InputFile{}, nil } - r, w := io.Pipe() + temp, err := os.CreateTemp("", "tdl_*") + if err != nil { + return nil, errors.Wrap(err, "create temp file") + } + defer func() { + multierr.AppendInto(&rerr, temp.Close()) + multierr.AppendInto(&rerr, os.Remove(temp.Name())) + }() - wg, errctx := errgroup.WithContext(ctx) + threads := bestThreads(opts.media.Size, f.opts.Threads) - wg.Go(func() (rerr error) { - defer multierr.AppendInvoke(&rerr, multierr.Close(w)) - - _, err := downloader.NewDownloader(). - WithPartSize(opts.PartSize). - Download(f.opts.Pool.Client(ctx, opts.Media.DC), opts.Media.InputFileLoc). - Stream(errctx, w) - if err != nil { - return errors.Wrap(err, "download") - } - return nil - }) + _, err = downloader.NewDownloader(). + WithPartSize(f.opts.PartSize). + Download(f.opts.Pool.Client(ctx, opts.media.DC), opts.media.InputFileLoc). + WithThreads(threads). + Parallel(ctx, writeAt{ + f: temp, + opts: opts, + }) + if err != nil { + return nil, errors.Wrap(err, "download") + } var file tg.InputFileClass - wg.Go(func() (rerr error) { - defer multierr.AppendInvoke(&rerr, multierr.Close(r)) - var err error - upload := uploader.NewUpload(opts.Media.Name, r, opts.Media.Size) - file, err = uploader.NewUploader(f.opts.Pool.Default(ctx)). - WithPartSize(opts.PartSize). - WithProgress(opts.Progress). - Upload(errctx, upload) - if err != nil { - return errors.Wrap(err, "upload") - } - return nil - }) + if _, err = temp.Seek(0, io.SeekStart); err != nil { + return nil, errors.Wrap(err, "seek") + } - if err := wg.Wait(); err != nil { - return nil, errors.Wrap(err, "wait") + upload := uploader.NewUpload(opts.media.Name, temp, opts.media.Size) + file, err = uploader.NewUploader(f.opts.Pool.Default(ctx)). + WithPartSize(f.opts.PartSize). + WithThreads(threads). + WithProgress(uploaded{ + opts: opts, + prev: atomic.NewInt64(0), + }). + Upload(ctx, upload) + if err != nil { + return nil, errors.Wrap(err, "upload") } return file, nil } + +type writeAt struct { + f io.WriterAt + opts cloneOptions +} + +func (w writeAt) WriteAt(p []byte, off int64) (int, error) { + n, err := w.f.WriteAt(p, off) + if err != nil { + return 0, err + } + + w.opts.progress.add(int64(n)) + + return n, nil +} + +type uploaded struct { + opts cloneOptions + prev *atomic.Int64 +} + +func (u uploaded) Chunk(_ context.Context, state uploader.ProgressState) error { + u.opts.progress.add(state.Uploaded - u.prev.Swap(state.Uploaded)) + + return nil +} + +var threadsLevels = []struct { + threads int + size int64 +}{ + {1, 1 << 20}, + {2, 5 << 20}, + {4, 20 << 20}, + {8, 50 << 20}, +} + +// Get best threads num for download, based on file size +func bestThreads(size int64, max int) int { + for _, t := range threadsLevels { + if size < t.size { + return min(t.threads, max) + } + } + return max +} + +func min(a, b int) int { + if a < b { + return a + } + return b +} diff --git a/pkg/forwarder/forwarder.go b/pkg/forwarder/forwarder.go index 1c60844..ce6f4a5 100644 --- a/pkg/forwarder/forwarder.go +++ b/pkg/forwarder/forwarder.go @@ -10,6 +10,7 @@ import ( "github.com/gotd/td/telegram/message" "github.com/gotd/td/telegram/peers" "github.com/gotd/td/tg" + "go.uber.org/atomic" "go.uber.org/zap" "github.com/iyear/tdl/pkg/dcpool" @@ -27,6 +28,7 @@ type Mode int type Options struct { Pool dcpool.Pool PartSize int + Threads int Iter Iter Progress Progress } @@ -100,6 +102,13 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t zap.Int64("to", elem.To().ID()), zap.Int("message", elem.Msg().ID)) + // used for clone progress + totalSize, err := mediaSizeSum(elem.Msg(), grouped...) + if err != nil { + return errors.Wrap(err, "media total size") + } + done := atomic.NewInt64(0) + forwardTextOnly := func(msg *tg.Message) error { if msg.Message == "" { return errors.Errorf("empty message content, skip send: %d", msg.ID) @@ -158,18 +167,21 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t return nil, errors.Errorf("unsupported media %T", msg.Media) } - mediaFile, err := f.CloneMedia(ctx, CloneOptions{ - Media: media, - PartSize: f.opts.PartSize, - Progress: uploadProgress{ + mediaFile, err := f.cloneMedia(ctx, cloneOptions{ + elem: elem, + media: media, + progress: &wrapProgress{ elem: elem, progress: f.opts.Progress, + done: done, + total: totalSize * 2, }, }, elem.AsDryRun()) if err != nil { return nil, errors.Wrap(err, "clone media") } + var inputMedia tg.InputMediaClass // now we only have to process cloned photo or document switch m := msg.Media.(type) { case *tg.MessageMediaPhoto: @@ -179,7 +191,8 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t TTLSeconds: m.TTLSeconds, } photo.SetFlags() - return photo, nil + + inputMedia = photo case *tg.MessageMediaDocument: doc, ok := m.Document.AsNotEmpty() if !ok { @@ -191,10 +204,10 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t return nil, errors.Errorf("empty document thumb %d", msg.ID) } - thumbFile, err := f.CloneMedia(ctx, CloneOptions{ - Media: thumb, - PartSize: f.opts.PartSize, - Progress: nopProgress{}, + thumbFile, err := f.cloneMedia(ctx, cloneOptions{ + elem: elem, + media: thumb, + progress: nopProgress{}, }, elem.AsDryRun()) if err != nil { return nil, errors.Wrap(err, "clone thumb") @@ -213,10 +226,27 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t } document.SetFlags() - return document, nil + inputMedia = document default: return nil, errors.Errorf("unsupported media %T", msg.Media) } + + // note that they must be separately uploaded using messages uploadMedia first, + // using raw inputMediaUploaded* constructors is not supported. + messageMedia, err := f.forwardClient(ctx, elem).MessagesUploadMedia(ctx, &tg.MessagesUploadMediaRequest{ + Peer: elem.To().InputPeer(), + Media: inputMedia, + }) + if err != nil { + return nil, errors.Wrap(err, "upload media") + } + + inputMedia, ok = tmedia.ConvInputMedia(messageMedia) + if !ok && !elem.AsDryRun() { + return nil, errors.Errorf("can't convert uploaded media to input class") + } + + return inputMedia, nil } switch elem.Mode() { @@ -266,6 +296,7 @@ func (f *Forwarder) forwardMessage(ctx context.Context, elem Elem, grouped ...*t Entities: gm.Entities, } single.SetFlags() + media = append(media, single) } @@ -338,6 +369,24 @@ func (n nopInvoker) Invoke(_ context.Context, _ bin.Encoder, _ bin.Decoder) erro return nil } +type nopProgress struct{} + +func (nopProgress) add(_ int64) {} + +type wrapProgress struct { + elem Elem + progress ProgressClone + done *atomic.Int64 + total int64 +} + +func (w *wrapProgress) add(n int64) { + w.progress.OnClone(w.elem, ProgressState{ + Done: w.done.Add(n), + Total: w.total, + }) +} + func (f *Forwarder) forwardClient(ctx context.Context, elem Elem) *tg.Client { if elem.AsDryRun() { return tg.NewClient(nopInvoker{}) @@ -369,3 +418,24 @@ func photoOrDocument(media tg.MessageMediaClass) bool { return false } } + +func mediaSizeSum(msg *tg.Message, grouped ...*tg.Message) (int64, error) { + m, ok := tmedia.GetMedia(msg) + if !ok { + return 0, errors.Errorf("can't get media from message") + } + total := m.Size + + if len(grouped) > 0 { + total = 0 + for _, gm := range grouped { + m, ok := tmedia.GetMedia(gm) + if !ok { + return 0, errors.Errorf("can't get media from message") + } + total += m.Size + } + } + + return total, nil +} diff --git a/pkg/forwarder/progress.go b/pkg/forwarder/progress.go index 4d36938..5e98bd3 100644 --- a/pkg/forwarder/progress.go +++ b/pkg/forwarder/progress.go @@ -1,14 +1,12 @@ package forwarder -import ( - "context" - - "github.com/gotd/td/telegram/uploader" -) +type ProgressClone interface { + OnClone(elem Elem, state ProgressState) +} type Progress interface { OnAdd(elem Elem) - OnClone(elem Elem, state ProgressState) + ProgressClone OnDone(elem Elem, err error) } @@ -16,20 +14,3 @@ type ProgressState struct { Done int64 Total int64 } - -type uploadProgress struct { - elem Elem - progress Progress -} - -func (p uploadProgress) Chunk(_ context.Context, state uploader.ProgressState) error { - p.progress.OnClone(p.elem, ProgressState{ - Done: state.Uploaded, - Total: state.Total, - }) - return nil -} - -type nopProgress struct{} - -func (p nopProgress) Chunk(_ context.Context, _ uploader.ProgressState) error { return nil }