1
0
Fork 0
mirror of https://github.com/Luzifer/go-dhparam.git synced 2024-11-09 15:50:02 +00:00

Added GenerateWithContext (#2)

This commit is contained in:
Alessandro (Ale) Segala 2020-05-04 08:59:29 -07:00 committed by GitHub
parent fd3eda34cc
commit 107406a558
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 22 deletions

View file

@ -1,6 +1,7 @@
package dhparam package dhparam
import ( import (
"context"
"crypto/rand" "crypto/rand"
"math/big" "math/big"
@ -42,6 +43,12 @@ func nullCallback(r GeneratorResult) {}
// The bit size should be adjusted to be high enough for the current requirements. Also you should keep // The bit size should be adjusted to be high enough for the current requirements. Also you should keep
// in mind the higher the bitsize, the longer the generation might take. // in mind the higher the bitsize, the longer the generation might take.
func Generate(bits int, generator Generator, cb GeneratorCallback) (*DH, error) { func Generate(bits int, generator Generator, cb GeneratorCallback) (*DH, error) {
// Invoke GenerateWithContext with a background context
return GenerateWithContext(context.Background(), bits, generator, cb)
}
// GenerateWithContext is just like the Generate function, but it accepts a ctx parameter with a context, that can be used to interrupt the generation if needed
func GenerateWithContext(ctx context.Context, bits int, generator Generator, cb GeneratorCallback) (*DH, error) {
var ( var (
err error err error
padd, rem int64 padd, rem int64
@ -62,35 +69,38 @@ func Generate(bits int, generator Generator, cb GeneratorCallback) (*DH, error)
} }
for { for {
if prime, err = genPrime(bits, big.NewInt(padd), big.NewInt(rem)); err != nil { select {
return nil, err case <-ctx.Done():
} return nil, ctx.Err()
default:
if prime, err = genPrime(bits, big.NewInt(padd), big.NewInt(rem)); err != nil {
return nil, err
}
if prime.BitLen() > bits { if prime.BitLen() > bits {
continue continue
} }
t := new(big.Int) t := new(big.Int)
t.Rsh(prime, 1) t.Rsh(prime, 1)
cb(GeneratorFoundPossiblePrime) cb(GeneratorFoundPossiblePrime)
if prime.ProbablyPrime(0) { if prime.ProbablyPrime(0) {
cb(GeneratorFirstConfirmation) cb(GeneratorFirstConfirmation)
} else { } else {
continue continue
} }
if t.ProbablyPrime(0) { if t.ProbablyPrime(0) {
cb(GeneratorSafePrimeFound) cb(GeneratorSafePrimeFound)
break return &DH{
P: prime,
G: int(generator),
}, nil
}
} }
} }
return &DH{
P: prime,
G: int(generator),
}, nil
} }
func genPrime(bits int, padd, rem *big.Int) (*big.Int, error) { func genPrime(bits int, padd, rem *big.Int) (*big.Int, error) {

View file

@ -2,12 +2,14 @@ package dhparam
import ( import (
"bytes" "bytes"
"context"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
"testing" "testing"
"time"
) )
func opensslOutput(r GeneratorResult) { func opensslOutput(r GeneratorResult) {
@ -86,6 +88,23 @@ func TestGenerator2048bit(t *testing.T) {
execGeneratorIntegration(t, 2048, GeneratorTwo) execGeneratorIntegration(t, 2048, GeneratorTwo)
} }
func TestGeneratorInterrupt(t *testing.T) {
start := time.Now()
ctx, cancel := context.WithTimeout(context.TODO(), 100*time.Millisecond)
dh, err := GenerateWithContext(ctx, 4096, GeneratorTwo, nil)
cancel()
duration := time.Since(start)
if duration > 1*time.Second {
t.Fatal("Function was not canceled early")
}
if err != context.DeadlineExceeded {
t.Fatal("Expected error to be context.DeadlineExceeded")
}
if dh != nil {
t.Fatal("Expected result to be nil")
}
}
func TestGenerator5(t *testing.T) { func TestGenerator5(t *testing.T) {
execGeneratorIntegration(t, 512, GeneratorFive) execGeneratorIntegration(t, 512, GeneratorFive)
} }