diff --git a/generator.go b/generator.go index af6853d..28573b6 100644 --- a/generator.go +++ b/generator.go @@ -1,6 +1,7 @@ package dhparam import ( + "context" "crypto/rand" "math/big" @@ -42,6 +43,12 @@ func nullCallback(r GeneratorResult) {} // The bit size should be adjusted to be high enough for the current requirements. Also you should keep // in mind the higher the bitsize, the longer the generation might take. func Generate(bits int, generator Generator, cb GeneratorCallback) (*DH, error) { + // Invoke GenerateWithContext with a background context + return GenerateWithContext(context.Background(), bits, generator, cb) +} + +// GenerateWithContext is just like the Generate function, but it accepts a ctx parameter with a context, that can be used to interrupt the generation if needed +func GenerateWithContext(ctx context.Context, bits int, generator Generator, cb GeneratorCallback) (*DH, error) { var ( err error padd, rem int64 @@ -62,35 +69,38 @@ func Generate(bits int, generator Generator, cb GeneratorCallback) (*DH, error) } for { - if prime, err = genPrime(bits, big.NewInt(padd), big.NewInt(rem)); err != nil { - return nil, err - } + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + if prime, err = genPrime(bits, big.NewInt(padd), big.NewInt(rem)); err != nil { + return nil, err + } - if prime.BitLen() > bits { - continue - } + if prime.BitLen() > bits { + continue + } - t := new(big.Int) - t.Rsh(prime, 1) + t := new(big.Int) + t.Rsh(prime, 1) - cb(GeneratorFoundPossiblePrime) + cb(GeneratorFoundPossiblePrime) - if prime.ProbablyPrime(0) { - cb(GeneratorFirstConfirmation) - } else { - continue - } + if prime.ProbablyPrime(0) { + cb(GeneratorFirstConfirmation) + } else { + continue + } - if t.ProbablyPrime(0) { - cb(GeneratorSafePrimeFound) - break + if t.ProbablyPrime(0) { + cb(GeneratorSafePrimeFound) + return &DH{ + P: prime, + G: int(generator), + }, nil + } } } - - return &DH{ - P: prime, - G: int(generator), - }, nil } func genPrime(bits int, padd, rem *big.Int) (*big.Int, error) { diff --git a/generator_test.go b/generator_test.go index 103c2b2..b04d913 100644 --- a/generator_test.go +++ b/generator_test.go @@ -2,12 +2,14 @@ package dhparam import ( "bytes" + "context" "fmt" "io/ioutil" "os" "os/exec" "strings" "testing" + "time" ) func opensslOutput(r GeneratorResult) { @@ -86,6 +88,23 @@ func TestGenerator2048bit(t *testing.T) { execGeneratorIntegration(t, 2048, GeneratorTwo) } +func TestGeneratorInterrupt(t *testing.T) { + start := time.Now() + ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond) + dh, err := GenerateWithContext(ctx, 4096, GeneratorTwo, nil) + cancel() + duration := time.Since(start) + if duration > 1*time.Second { + t.Fatal("Function was not canceled early") + } + if err != context.DeadlineExceeded { + t.Fatal("Expected error to be context.DeadlineExceeded") + } + if dh != nil { + t.Fatal("Expected result to be nil") + } +} + func TestGenerator5(t *testing.T) { execGeneratorIntegration(t, 512, GeneratorFive) }