diff --git a/go.mod b/go.mod
index 527552a..7e722c9 100644
--- a/go.mod
+++ b/go.mod
@@ -63,5 +63,6 @@ require (
 	github.com/projectdiscovery/subfinder/v2 v2.5.6
 	github.com/stretchr/testify v1.8.2
 	github.com/valyala/bytebufferpool v1.0.0 // indirect
+	golang.org/x/sync v0.1.0
 	golang.org/x/sys v0.4.0 // indirect
 )
diff --git a/go.sum b/go.sum
index 1d0a6a0..684c8e2 100644
--- a/go.sum
+++ b/go.sum
@@ -144,6 +144,8 @@ golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw=
+golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
+golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
diff --git a/pkg/domain/enumeration.go b/pkg/domain/enumeration.go
--- a/pkg/domain/enumeration.go
+++ b/pkg/domain/enumeration.go
@@ -7,6 +7,7 @@ import (
 	"github.com/Escape-Technologies/goctopus/pkg/config"
 	"github.com/projectdiscovery/subfinder/v2/pkg/resolve"
 	"github.com/projectdiscovery/subfinder/v2/pkg/runner"
+	log "github.com/sirupsen/logrus"
 )
 
 func makeCallback(domain *address.Addr, subDomains chan *address.Addr) func(s *resolve.HostEntry) {
@@ -15,7 +16,12 @@ func makeCallback(domain *address.Addr, subDomains chan *address.Addr) func(s *r
 	}
 }
 
-func EnumerateSubdomains(domain *address.Addr, subDomains chan *address.Addr) (err error) {
+// EnumerateSubdomains sends domain itself on subDomains and then, unless it
+// returns early, runs a subfinder enumeration that reports its results through
+// makeCallback. threads is the number of threads given to the subfinder runner.
+// A runner is built on every call because its ResultCallback is bound to domain.
+func EnumerateSubdomains(domain *address.Addr, subDomains chan *address.Addr, threads int) (err error) {
+	log.Debugf("Enumerating subdomains for %v with %d threads", domain, threads)
 	subDomains <- domain
 
 	c := config.Get()
@@ -23,14 +29,17 @@ func EnumerateSubdomains(domain *address.Addr, subDomains chan *address.Addr) (e
 		return nil
 	}
 
-	runnerInstance, _ := runner.NewRunner(&runner.Options{
-		Threads:            c.MaxWorkers,              // Thread controls the number of threads to use for active enumerations
+	runnerInstance, err := runner.NewRunner(&runner.Options{
+		Threads:            threads,                   // Thread controls the number of threads to use for active enumerations
 		Timeout:            c.Timeout,                 // Timeout is the seconds to wait for sources to respond
 		MaxEnumerationTime: 5,                         // MaxEnumerationTime is the maximum amount of time in mins to wait for enumeration
 		Resolvers:          resolve.DefaultResolvers, // Use the default list of resolvers by marshaling it to the config
 		ResultCallback:     makeCallback(domain, subDomains),
 		Silent:             true,
 	})
+	if err != nil {
+		return err
+	}
 
 	err = runnerInstance.EnumerateSingleDomain(domain.Address, []io.Writer{})
 	return err
diff --git a/pkg/goctopus/fingerprint.go b/pkg/goctopus/fingerprint.go
--- a/pkg/goctopus/fingerprint.go
+++ b/pkg/goctopus/fingerprint.go
@@ -1,6 +1,7 @@
 package goctopus
 
 import (
+	"context"
 	"sync"
 
 	"github.com/Escape-Technologies/goctopus/internal/utils"
@@ -10,6 +11,7 @@ import (
 	"github.com/Escape-Technologies/goctopus/pkg/endpoint"
 	"github.com/Escape-Technologies/goctopus/pkg/output"
 	log "github.com/sirupsen/logrus"
+	"golang.org/x/sync/semaphore"
 )
 
 func worker(addresses chan *address.Addr, output chan *output.FingerprintOutput, workerId int, wg *sync.WaitGroup) {
@@ -35,17 +37,38 @@ func FingerprintAddress(address *address.Addr) (*output.FingerprintOutput, error
 	}
 }
 
-// An addresses can be a domain or an url
+// asyncEnumeration enumerates the subdomains of address into
+// enumeratedAddresses. It is meant to run in its own goroutine: on return it
+// gives the threads slots reserved by the caller back to sem and marks wg done.
+func asyncEnumeration(address *address.Addr, enumeratedAddresses chan *address.Addr, threads int, sem *semaphore.Weighted, wg *sync.WaitGroup) {
+	defer wg.Done()
+	defer sem.Release(int64(threads))
+	if err := domain.EnumerateSubdomains(address, enumeratedAddresses, threads); err != nil {
+		log.Errorf("Error enumerating subdomains for %v: %v", address, err)
+	}
+}
+
+// FingerprintAddresses fingerprints every address received on addresses and
+// sends the endpoints found on output, which is closed once all work is done.
+// An address can be a domain or a URL; domains are expanded to their subdomains.
 func FingerprintAddresses(addresses chan *address.Addr, output chan *output.FingerprintOutput) {
 	maxWorkers := config.Get().MaxWorkers
 
 	enumeratedAddresses := make(chan *address.Addr, config.Get().MaxWorkers)
 
 	workersWg := sync.WaitGroup{}
 	workersWg.Add(maxWorkers)
 
+	// The semaphore only bounds subdomain enumerations. Workers are already
+	// bounded by maxWorkers and must always be free to drain
+	// enumeratedAddresses, otherwise running enumerations could block forever.
+	sem := semaphore.NewWeighted(int64(maxWorkers))
+	enumerationWg := sync.WaitGroup{}
+	// 10 threads for subdomain enumeration, unless maxWorkers is less than 10
+	enumerationThreads := utils.MinInt(maxWorkers, 10)
+
 	for i := 0; i < maxWorkers; i++ {
 		go worker(enumeratedAddresses, output, i, &workersWg)
 	}
 
 	i := 1
@@ -53,18 +76,26 @@ func FingerprintAddresses(addresses chan *address.Addr, output chan *output.Fing
 		log.Debugf("(%d) Adding %v to the queue", i, address)
 		// If the domain is a url, we don't need to crawl it
 		if utils.IsUrl(address.Address) {
 			enumeratedAddresses <- address
 		} else {
-			if err := domain.EnumerateSubdomains(address, enumeratedAddresses); err != nil {
-				log.Errorf("Error enumerating subdomains for %v: %v", address, err)
-			}
+			// Reserve the enumeration threads before starting, so that at most
+			// maxWorkers enumeration threads run at once; asyncEnumeration
+			// releases them when it is done.
+			if err := sem.Acquire(context.Background(), int64(enumerationThreads)); err != nil {
+				log.Errorf("Error enumerating subdomains for %v: %v", address, err)
+			} else {
+				enumerationWg.Add(1)
+				go asyncEnumeration(address, enumeratedAddresses, enumerationThreads, sem, &enumerationWg)
+			}
 		}
 		i++
 	}
 
+	// Enumerations send on enumeratedAddresses, so wait for them before closing it.
+	enumerationWg.Wait()
 	close(enumeratedAddresses)
 	log.Debugf("Waiting for workers to finish...")
 	workersWg.Wait()
 	close(output)
 	log.Debugf("All workers finished")
 }