diff --git a/internal/linkcheck/linkcheck.go b/internal/linkcheck/linkcheck.go index 5933d57..a29efa2 100644 --- a/internal/linkcheck/linkcheck.go +++ b/internal/linkcheck/linkcheck.go @@ -1,57 +1,35 @@ package linkcheck import ( - "context" - "crypto/rand" - _ "embed" - "math/big" - "net/http" - "net/http/cookiejar" - "net/url" "regexp" "strings" - "time" + "sync" "github.com/Luzifer/go_helpers/v2/str" ) -const ( - // DefaultCheckTimeout defines the default time the request to a site - // may take to answer - DefaultCheckTimeout = 10 * time.Second - - maxRedirects = 50 -) - type ( // Checker contains logic to detect and resolve links in a message Checker struct { - checkTimeout time.Duration - userAgents []string - - skipValidation bool // Only for tests, not settable from the outside + res *resolver } ) -var ( - defaultUserAgents = []string{} - linkTest = regexp.MustCompile(`(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-]{0,61}[a-z0-9]`) - numericHost = regexp.MustCompile(`^(?:[0-9]+\.)*[0-9]+(?::[0-9]+)?$`) - - //go:embed user-agents.txt - uaList string -) - -func init() { - defaultUserAgents = strings.Split(strings.TrimSpace(uaList), "\n") -} - // New creates a new Checker instance with default settings -func New() *Checker { - return &Checker{ - checkTimeout: DefaultCheckTimeout, - userAgents: defaultUserAgents, +func New(opts ...func(*Checker)) *Checker { + c := &Checker{ + res: defaultResolver, } + + for _, o := range opts { + o(c) + } + + return c +} + +func withResolver(r *resolver) func(*Checker) { + return func(c *Checker) { c.res = r } } // HeuristicScanForLinks takes a message and tries to find links @@ -74,120 +52,6 @@ func (c Checker) ScanForLinks(message string) (links []string) { return c.scan(message, c.scanPlainNoObfuscate) } -// resolveFinal takes a link and looks up the final destination of -// that link after all redirects were followed -func (c Checker) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack []string, userAgent string) string { - if !linkTest.MatchString(link) && !c.skipValidation { - return "" - } - - if str.StringInSlice(link, callStack) || len(callStack) == maxRedirects { - // We got ourselves a loop: Yay! - return link - } - - client := &http.Client{ - CheckRedirect: func(req *http.Request, via []*http.Request) error { - return http.ErrUseLastResponse - }, - Jar: cookieJar, - } - - ctx, cancel := context.WithTimeout(context.Background(), c.checkTimeout) - defer cancel() - - u, err := url.Parse(link) - if err != nil { - return "" - } - - if u.Scheme == "" { - // We have no scheme and the url is in the path, lets add the - // scheme and re-parse the URL to avoid some confusion - u.Scheme = "http" - u, err = url.Parse(u.String()) - if err != nil { - return "" - } - } - - if numericHost.MatchString(u.Host) && !c.skipValidation { - // Host is fully numeric: We don't support scanning that - return "" - } - - // Sanitize host: Trailing dots are valid but not required - u.Host = strings.TrimRight(u.Host, ".") - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) - if err != nil { - return "" - } - - req.Header.Set("User-Agent", userAgent) - - resp, err := client.Do(req) - if err != nil { - return "" - } - defer resp.Body.Close() - - if resp.StatusCode > 299 && resp.StatusCode < 400 { - // We got a redirect - tu, err := url.Parse(resp.Header.Get("location")) - if err != nil { - return "" - } - target := c.resolveReference(u, tu) - return c.resolveFinal(target, cookieJar, append(callStack, link), userAgent) - } - - // We got a response, it's no redirect, we count this as a success - return u.String() -} - -func (Checker) resolveReference(origin *url.URL, loc *url.URL) string { - // Special Case: vkontakte used as shortener / obfuscation - if loc.Path == "/away.php" && loc.Query().Has("to") { - // VK is doing HTML / JS redirect magic so we take that from them - // and execute the redirect directly here in code - return loc.Query().Get("to") - } - - if loc.Host == "consent.youtube.com" && loc.Query().Has("continue") { - // Youtube links end up in consent page but we want the real - // target so we use the continue parameter where we strip the - // cbrd query parameters as that one causes an infinite loop. - - contTarget, err := url.Parse(loc.Query().Get("continue")) - if err == nil { - v := contTarget.Query() - v.Del("cbrd") - - contTarget.RawQuery = v.Encode() - return contTarget.String() - } - - return loc.Query().Get("continue") - } - - if loc.Host == "www.instagram.com" && loc.Query().Has("next") { - // Instagram likes its login page, we on the other side don't - // care about the sign-in or even the content. Therefore we - // just take their redirect target and use that as the next - // URL - return loc.Query().Get("next") - } - - // Default fallback behavior: Do a normal resolve - return origin.ResolveReference(loc).String() -} - -func (Checker) getJar() *cookiejar.Jar { - jar, _ := cookiejar.New(nil) - return jar -} - func (c Checker) scan(message string, scanFns ...func(string) []string) (links []string) { for _, scanner := range scanFns { if links = scanner(message); links != nil { @@ -219,30 +83,40 @@ func (c Checker) scanObfuscateSpecialCharsAndSpaces(set *regexp.Regexp, connecto } func (c Checker) scanPartsConnected(parts []string, connector string) (links []string) { + wg := new(sync.WaitGroup) + for ptJoin := 2; ptJoin < len(parts); ptJoin++ { for i := 0; i <= len(parts)-ptJoin; i++ { - if link := c.resolveFinal(strings.Join(parts[i:i+ptJoin], connector), c.getJar(), nil, c.userAgent()); link != "" && !str.StringInSlice(link, links) { - links = append(links, link) - } + wg.Add(1) + c.res.Resolve(resolverQueueEntry{ + Link: strings.Join(parts[i:i+ptJoin], connector), + Callback: func(link string) { links = str.AppendIfMissing(links, link) }, + WaitGroup: wg, + }) } } + wg.Wait() + return links } func (c Checker) scanPlainNoObfuscate(message string) (links []string) { - parts := regexp.MustCompile(`\s+`).Split(message, -1) + var ( + parts = regexp.MustCompile(`\s+`).Split(message, -1) + wg = new(sync.WaitGroup) + ) for _, part := range parts { - if link := c.resolveFinal(part, c.getJar(), nil, c.userAgent()); link != "" && !str.StringInSlice(link, links) { - links = append(links, link) - } + wg.Add(1) + c.res.Resolve(resolverQueueEntry{ + Link: part, + Callback: func(link string) { links = str.AppendIfMissing(links, link) }, + WaitGroup: wg, + }) } + wg.Wait() + return links } - -func (c Checker) userAgent() string { - n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(c.userAgents)))) - return c.userAgents[n.Int64()] -} diff --git a/internal/linkcheck/linkcheck_test.go b/internal/linkcheck/linkcheck_test.go index bad558a..68c5b28 100644 --- a/internal/linkcheck/linkcheck_test.go +++ b/internal/linkcheck/linkcheck_test.go @@ -18,13 +18,11 @@ func TestInfiniteRedirect(t *testing.T) { hdl.HandleFunc("/test", func(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, "/", http.StatusFound) }) var ( - c = New() + c = New(withResolver(newResolver(1, withSkipVerify()))) ts = httptest.NewServer(hdl) ) t.Cleanup(ts.Close) - c.skipValidation = true - msg := fmt.Sprintf("Here have a redirect loop: %s", ts.URL) // We expect /test to be the first repeat as the callstack will look like this: @@ -41,13 +39,11 @@ func TestMaxRedirects(t *testing.T) { }) var ( - c = New() + c = New(withResolver(newResolver(1, withSkipVerify()))) ts = httptest.NewServer(hdl) ) t.Cleanup(ts.Close) - c.skipValidation = true - msg := fmt.Sprintf("Here have a redirect loop: %s", ts.URL) // We expect the call to `/N` to have N previous entries and therefore be the break-point @@ -203,13 +199,10 @@ func TestUserAgentListNotEmpty(t *testing.T) { } func TestUserAgentRandomizer(t *testing.T) { - var ( - c = New() - uas = map[string]int{} - ) + uas := map[string]int{} for i := 0; i < 10; i++ { - uas[c.userAgent()]++ + uas[defaultResolver.userAgent()]++ } for _, c := range uas { diff --git a/internal/linkcheck/resolver.go b/internal/linkcheck/resolver.go new file mode 100644 index 0000000..8da0b74 --- /dev/null +++ b/internal/linkcheck/resolver.go @@ -0,0 +1,206 @@ +package linkcheck + +import ( + "context" + "crypto/rand" + _ "embed" + "math/big" + "net/http" + "net/http/cookiejar" + "net/url" + "regexp" + "strings" + "sync" + "time" + + "github.com/Luzifer/go_helpers/v2/str" +) + +const ( + // DefaultCheckTimeout defines the default time the request to a site + // may take to answer + DefaultCheckTimeout = 10 * time.Second + + maxRedirects = 50 + resolverPoolSize = 25 +) + +type ( + resolver struct { + resolverC chan resolverQueueEntry + skipValidation bool + } + + resolverQueueEntry struct { + Link string + Callback func(string) + WaitGroup *sync.WaitGroup + } +) + +var ( + defaultUserAgents = []string{} + linkTest = regexp.MustCompile(`(?:[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?\.)+[a-z0-9][a-z0-9-]{0,61}[a-z0-9]`) + numericHost = regexp.MustCompile(`^(?:[0-9]+\.)*[0-9]+(?::[0-9]+)?$`) + + //go:embed user-agents.txt + uaList string + + defaultResolver = newResolver(resolverPoolSize) +) + +func init() { + defaultUserAgents = strings.Split(strings.TrimSpace(uaList), "\n") +} + +func newResolver(poolSize int, opts ...func(*resolver)) *resolver { + r := &resolver{ + resolverC: make(chan resolverQueueEntry), + } + + for _, o := range opts { + o(r) + } + + for i := 0; i < poolSize; i++ { + go r.runResolver() + } + + return r +} + +func withSkipVerify() func(*resolver) { + return func(r *resolver) { r.skipValidation = true } +} + +func (r resolver) Resolve(qe resolverQueueEntry) { + r.resolverC <- qe +} + +func (resolver) getJar() *cookiejar.Jar { + jar, _ := cookiejar.New(nil) + return jar +} + +// resolveFinal takes a link and looks up the final destination of +// that link after all redirects were followed +func (r resolver) resolveFinal(link string, cookieJar *cookiejar.Jar, callStack []string, userAgent string) string { + if !linkTest.MatchString(link) && !r.skipValidation { + return "" + } + + if str.StringInSlice(link, callStack) || len(callStack) == maxRedirects { + // We got ourselves a loop: Yay! + return link + } + + client := &http.Client{ + CheckRedirect: func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + }, + Jar: cookieJar, + } + + ctx, cancel := context.WithTimeout(context.Background(), DefaultCheckTimeout) + defer cancel() + + u, err := url.Parse(link) + if err != nil { + return "" + } + + if u.Scheme == "" { + // We have no scheme and the url is in the path, lets add the + // scheme and re-parse the URL to avoid some confusion + u.Scheme = "http" + u, err = url.Parse(u.String()) + if err != nil { + return "" + } + } + + if numericHost.MatchString(u.Host) && !r.skipValidation { + // Host is fully numeric: We don't support scanning that + return "" + } + + // Sanitize host: Trailing dots are valid but not required + u.Host = strings.TrimRight(u.Host, ".") + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return "" + } + + req.Header.Set("User-Agent", userAgent) + + resp, err := client.Do(req) + if err != nil { + return "" + } + defer resp.Body.Close() + + if resp.StatusCode > 299 && resp.StatusCode < 400 { + // We got a redirect + tu, err := url.Parse(resp.Header.Get("location")) + if err != nil { + return "" + } + target := r.resolveReference(u, tu) + return r.resolveFinal(target, cookieJar, append(callStack, link), userAgent) + } + + // We got a response, it's no redirect, we count this as a success + return u.String() +} + +func (resolver) resolveReference(origin *url.URL, loc *url.URL) string { + // Special Case: vkontakte used as shortener / obfuscation + if loc.Path == "/away.php" && loc.Query().Has("to") { + // VK is doing HTML / JS redirect magic so we take that from them + // and execute the redirect directly here in code + return loc.Query().Get("to") + } + + if loc.Host == "consent.youtube.com" && loc.Query().Has("continue") { + // Youtube links end up in consent page but we want the real + // target so we use the continue parameter where we strip the + // cbrd query parameters as that one causes an infinite loop. + + contTarget, err := url.Parse(loc.Query().Get("continue")) + if err == nil { + v := contTarget.Query() + v.Del("cbrd") + + contTarget.RawQuery = v.Encode() + return contTarget.String() + } + + return loc.Query().Get("continue") + } + + if loc.Host == "www.instagram.com" && loc.Query().Has("next") { + // Instagram likes its login page, we on the other side don't + // care about the sign-in or even the content. Therefore we + // just take their redirect target and use that as the next + // URL + return loc.Query().Get("next") + } + + // Default fallback behavior: Do a normal resolve + return origin.ResolveReference(loc).String() +} + +func (r resolver) runResolver() { + for qe := range r.resolverC { + if link := r.resolveFinal(qe.Link, r.getJar(), nil, r.userAgent()); link != "" { + qe.Callback(link) + } + qe.WaitGroup.Done() + } +} + +func (resolver) userAgent() string { + n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(defaultUserAgents)))) + return defaultUserAgents[n.Int64()] +}