-
Notifications
You must be signed in to change notification settings - Fork 1
/
download_queue.go
161 lines (145 loc) · 4.45 KB
/
download_queue.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package podownloader
import (
"PoDownloader/logger"
"context"
"errors"
"github.com/vbauerster/mpb/v8"
"net/http"
"os"
"os/signal"
"sync"
"syscall"
)
// DownloadQueue is a first-in, first-out queue of download tasks.
// Elements are stored as interface{} values. The original note states that
// two kinds of task are held, URLDownloadTask and TextSaveTask; neither type
// is declared in this file, so treat that as an assumption to confirm.
// NOTE(review): EnQueue appends a *PodcastDownloadTask, which is not one of
// those two kinds — verify what the consumers of the queue expect.
type DownloadQueue struct {
// tasks holds the queued elements; index 0 is the front of the queue.
tasks []interface{}
// lock is acquired by every method in this file that reads or modifies tasks.
lock *sync.Mutex
}
// NewDownloadQueueFromDownloadTasks flattens []*PodcastDownloadTask into a
// single *DownloadQueue and returns it.
//
// For every podcast the queue receives, in this order:
//  1. the podcast RSS download task
//  2. the podcast cover download task
//  3. for each episode: its shownotes download task, its cover download task
//     and then its enclosure download tasks
//
// nil entries of podcastDownloadTasks are skipped instead of being
// dereferenced, and nil cover, shownotes and enclosure tasks are filtered out.
// NOTE(review): RSSDownloadTask is appended without a nil check because its
// type is not visible from this file — confirm callers always set it.
func NewDownloadQueueFromDownloadTasks(podcastDownloadTasks []*PodcastDownloadTask) *DownloadQueue {
	var tasks []interface{}
	for _, podcastDownloadTask := range podcastDownloadTasks {
		// A nil element would panic on the field accesses below.
		if podcastDownloadTask == nil {
			continue
		}
		tasks = append(tasks, podcastDownloadTask.RSSDownloadTask)
		if podcastDownloadTask.CoverDownloadTask != nil {
			tasks = append(tasks, podcastDownloadTask.CoverDownloadTask)
		}
		for _, episodeDownloadTask := range podcastDownloadTask.EpisodeDownloadTasks {
			if episodeDownloadTask.ShownotesDownloadTask != nil {
				tasks = append(tasks, episodeDownloadTask.ShownotesDownloadTask)
			}
			if episodeDownloadTask.CoverDownloadTask != nil {
				tasks = append(tasks, episodeDownloadTask.CoverDownloadTask)
			}
			for _, enclosureDownloadTask := range episodeDownloadTask.EnclosureDownloadTasks {
				if enclosureDownloadTask != nil {
					tasks = append(tasks, enclosureDownloadTask)
				}
			}
		}
	}
	return &DownloadQueue{
		tasks: tasks,
		lock:  &sync.Mutex{},
	}
}
// EnQueue appends podcastDownloadTasks to the rear of the queue while holding
// the queue lock.
func (dq *DownloadQueue) EnQueue(podcastDownloadTasks *PodcastDownloadTask) {
	dq.lock.Lock()
	defer dq.lock.Unlock()

	dq.tasks = append(dq.tasks, podcastDownloadTasks)
}
// DeQueue removes the element at the front of the queue and returns it.
// When the queue holds no elements it returns a nil element together with a
// "queue is empty" error.
func (dq *DownloadQueue) DeQueue() (interface{}, error) {
	dq.lock.Lock()
	defer dq.lock.Unlock()

	if len(dq.tasks) == 0 {
		return nil, errors.New("queue is empty")
	}
	frontDownloadTask := dq.tasks[0]
	// Clear the vacated slot: re-slicing alone leaves the backing array
	// referencing the removed task, which keeps it from being garbage
	// collected for as long as the array lives.
	dq.tasks[0] = nil
	dq.tasks = dq.tasks[1:]
	return frontDownloadTask, nil
}
// Front returns the element at the front of the queue without removing it.
// When the queue holds no elements it returns a nil element together with a
// "queue is empty" error.
func (dq *DownloadQueue) Front() (interface{}, error) {
	dq.lock.Lock()
	defer dq.lock.Unlock()

	if len(dq.tasks) == 0 {
		return nil, errors.New("queue is empty")
	}
	return dq.tasks[0], nil
}
// Length reports how many elements are currently queued. The count is read
// while holding the queue lock.
func (dq *DownloadQueue) Length() int {
	dq.lock.Lock()
	size := len(dq.tasks)
	dq.lock.Unlock()
	return size
}
// IsEmpty reports whether the queue currently holds no elements. The check is
// made against a count read while holding the queue lock.
func (dq *DownloadQueue) IsEmpty() bool {
	dq.lock.Lock()
	remaining := len(dq.tasks)
	dq.lock.Unlock()
	return remaining == 0
}
// StartDownload feeds the queued tasks to a pool of download goroutines and
// returns downloadWorker.FailedTaskDestPaths, the destination paths of the
// tasks that failed.
//
// threadCount is the requested number of download goroutines. It is capped at
// the number of queued tasks, and raised to 1 when it is not positive so that
// queued tasks always have a consumer and sync.WaitGroup.Add never receives a
// negative delta. httpClient and logger are handed to NewDownloadWorker.
//
// On SIGINT or SIGTERM the run is cancelled: no further tasks are handed to
// the workers and tasks still buffered in TasksChan are discarded. The call
// returns when progressBar.Wait returns; the progress container is configured
// with doneWg (presumably released by the workers once TasksChan is closed —
// confirm against NewDownloadWorker). The signal registration and the helper
// goroutines are released before returning.
func (dq *DownloadQueue) StartDownload(threadCount int, httpClient *http.Client, logger *logger.Logger) []string {
	// Read the length once so both bounds below apply to the same value.
	taskCount := dq.Length()
	realThreadCount := threadCount
	if realThreadCount < 1 {
		realThreadCount = 1
	}
	// When more download threads are requested than there are tasks, use the
	// number of tasks instead (0 for an empty queue).
	if realThreadCount > taskCount {
		realThreadCount = taskCount
	}

	// doneWg is used to wait for all download workers to finish.
	doneWg := new(sync.WaitGroup)
	doneWg.Add(realThreadCount)
	progressBar := mpb.New(
		mpb.WithWaitGroup(doneWg),
	)

	ctx, cancelFunc := context.WithCancel(context.Background())
	// Release the context and stop the signal goroutine on every return path.
	defer cancelFunc()

	downloadWorker := NewDownloadWorker(doneWg, httpClient, progressBar, logger, realThreadCount)
	// Start all download workers.
	for i := 0; i < realThreadCount; i++ {
		go downloadWorker.WorkerFunc()
	}

	// Listen for SIGINT and SIGTERM. The channel is buffered because the
	// signal package does not block when delivering, so an unbuffered channel
	// can miss the signal.
	termChan := make(chan os.Signal, 1)
	signal.Notify(termChan, syscall.SIGINT, syscall.SIGTERM)
	// Unregister on return so later signals are no longer routed to termChan.
	defer signal.Stop(termChan)
	go func() {
		select {
		case <-termChan:
			logger.Println("Received cancellation signal, waiting for all download tasks done")
			cancelFunc()
		case <-ctx.Done():
			// Download finished without a signal; nothing left to watch.
		}
	}()

	// Producer: feed download tasks to downloadWorker.TasksChan.
	go func() {
		// The producer is the only sender, so it closes the channel on exit.
		defer close(downloadWorker.TasksChan)

		// discardBuffered removes unstarted tasks without ever blocking: a
		// worker may take the last buffered task between a length check and
		// a blocking receive, which would leave the channel open forever.
		discardBuffered := func() {
			for {
				select {
				case <-downloadWorker.TasksChan:
				default:
					return
				}
			}
		}

		for {
			if ctx.Err() != nil {
				discardBuffered()
				return
			}
			task, err := dq.DeQueue()
			if err != nil {
				// Queue exhausted: every task has been handed out.
				return
			}
			// Wait for room in the channel, but give up if cancelled.
			select {
			case downloadWorker.TasksChan <- task:
			case <-ctx.Done():
				discardBuffered()
				return
			}
		}
	}()

	// Wait for all download workers to finish.
	progressBar.Wait()
	return downloadWorker.FailedTaskDestPaths
}