diff --git a/hdkeychain/bench_test.go b/hdkeychain/bench_test.go index dde59b190..9c32b5c0e 100644 --- a/hdkeychain/bench_test.go +++ b/hdkeychain/bench_test.go @@ -26,7 +26,7 @@ func BenchmarkDeriveHardened(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - masterKey.Child(hdkeychain.HardenedKeyStart) + masterKey.Child(hdkeychain.HardenedKeyStart, true) } } @@ -41,7 +41,7 @@ func BenchmarkDeriveNormal(b *testing.B) { b.StartTimer() for i := 0; i < b.N; i++ { - masterKey.Child(0) + masterKey.Child(0, true) } } diff --git a/hdkeychain/example_test.go b/hdkeychain/example_test.go index 63d369f31..dc7374f6f 100644 --- a/hdkeychain/example_test.go +++ b/hdkeychain/example_test.go @@ -71,7 +71,7 @@ func Example_defaultWalletLayout() { // Derive the extended key for account 0. This gives the path: // m/0H - acct0, err := masterKey.Child(hdkeychain.HardenedKeyStart + 0) + acct0, err := masterKey.Child(hdkeychain.HardenedKeyStart+0, true) if err != nil { fmt.Println(err) return @@ -80,7 +80,7 @@ func Example_defaultWalletLayout() { // Derive the extended key for the account 0 external chain. This // gives the path: // m/0H/0 - acct0Ext, err := acct0.Child(0) + acct0Ext, err := acct0.Child(0, true) if err != nil { fmt.Println(err) return @@ -89,7 +89,7 @@ func Example_defaultWalletLayout() { // Derive the extended key for the account 0 internal chain. This gives // the path: // m/0H/1 - acct0Int, err := acct0.Child(1) + acct0Int, err := acct0.Child(1, true) if err != nil { fmt.Println(err) return @@ -101,7 +101,7 @@ func Example_defaultWalletLayout() { // Derive the 10th extended key for the account 0 external chain. This // gives the path: // m/0H/0/10 - acct0Ext10, err := acct0Ext.Child(10) + acct0Ext10, err := acct0Ext.Child(10, true) if err != nil { fmt.Println(err) return @@ -110,7 +110,7 @@ func Example_defaultWalletLayout() { // Derive the 1st extended key for the account 0 internal chain. This // gives the path: // m/0H/1/0 - acct0Int0, err := acct0Int.Child(0) + acct0Int0, err := acct0Int.Child(0, true) if err != nil { fmt.Println(err) return diff --git a/hdkeychain/extendedkey.go b/hdkeychain/extendedkey.go index 53486adc1..09b83a53a 100644 --- a/hdkeychain/extendedkey.go +++ b/hdkeychain/extendedkey.go @@ -206,7 +206,11 @@ func (k *ExtendedKey) ParentFingerprint() uint32 { // index does not derive to a usable child. The ErrInvalidChild error will be // returned if this should occur, and the caller is expected to ignore the // invalid child and simply increment to the next index. -func (k *ExtendedKey) Child(i uint32) (*ExtendedKey, error) { +// +// Use the 2nd parameter to fix the bug with leading zeros. This change was made +// as an optional parameter in order to maintain backward compatibility and avoid +// losing wallets created before fixing this bug. +func (k *ExtendedKey) Child(i uint32, fixLeadingZeroBug ...bool) (*ExtendedKey, error) { // Prevent derivation of children beyond the max allowed depth. if k.depth == maxUint8 { return nil, ErrDeriveBeyondMaxDepth @@ -295,6 +299,11 @@ func (k *ExtendedKey) Child(i uint32) (*ExtendedKey, error) { ilNum.Add(ilNum, keyNum) ilNum.Mod(ilNum, btcec.S256().N) childKey = ilNum.Bytes() + // Correction a key-length with leading zero + if len(childKey) < 32 && len(fixLeadingZeroBug) > 0 && fixLeadingZeroBug[0] == true { + extra := make([]byte, 32-len(childKey)) + childKey = append(extra, childKey...) + } isPrivate = true } else { // Case #3. diff --git a/hdkeychain/extendedkey_test.go b/hdkeychain/extendedkey_test.go index 00699d6c2..a58bcbca3 100644 --- a/hdkeychain/extendedkey_test.go +++ b/hdkeychain/extendedkey_test.go @@ -26,6 +26,7 @@ func TestBIP0032Vectors(t *testing.T) { testVec1MasterHex := "000102030405060708090a0b0c0d0e0f" testVec2MasterHex := "fffcf9f6f3f0edeae7e4e1dedbd8d5d2cfccc9c6c3c0bdbab7b4b1aeaba8a5a29f9c999693908d8a8784817e7b7875726f6c696663605d5a5754514e4b484542" testVec3MasterHex := "4b381541583be4423346c643850da4b320e46a87ae3d2a4e6da11eba819cd4acba45d239319ac14f863b8d5ab5a0d0c64d2e8a1e7d1457df2e5a3c51c73235be" + testVec3aMasterHex := "57fb1e450b8afb95c62afbcd49e4100d6790e0822b8905608679180ac34ca0bd45bf7ccc6c5f5218236d0eb93afc78bd117b9f02a6b7df258ea182dfaef5aad7" hkStart := uint32(0x80000000) tests := []struct { @@ -153,6 +154,14 @@ func TestBIP0032Vectors(t *testing.T) { wantPriv: "xprv9uPDJpEQgRQfDcW7BkF7eTya6RPxXeJCqCJGHuCJ4GiRVLzkTXBAJMu2qaMWPrS7AANYqdq6vcBcBUdJCVVFceUvJFjaPdGZ2y9WACViL4L", net: &chaincfg.MainNetParams, }, + { + name: "test vector 3 chain m/44H/60H/0H", + master: testVec3aMasterHex, + path: []uint32{hkStart + 44, hkStart + 60, hkStart}, + wantPub: "xpub6CpsfWjghR6XdCB8yDq7jQRpRKEDP2LT3ZRUgURF9g5xevB7YoTpogkFRqq5nQtVSN8YCMZo2CD8u4zCaxRv85ctCWmzEi9gQ5DBhBFaTNo", + wantPriv: "xprv9yqXG1Cns3YEQi6fsCJ7NGV5sHPiyZcbgLVst61dbLYyn7qy1G9aFtRmaYp481ounqnVf9Go2ymQ4gmxZLEwYSRhU868aDk4ZxzGvqHJVhe", + net: &chaincfg.MainNetParams, + }, // Test vector 1 - Testnet { @@ -224,7 +233,7 @@ tests: for _, childNum := range test.path { var err error - extKey, err = extKey.Child(childNum) + extKey, err = extKey.Child(childNum, true) if err != nil { t.Errorf("err: %v", err) continue tests @@ -381,7 +390,7 @@ tests: for _, childNum := range test.path { var err error - extKey, err = extKey.Child(childNum) + extKey, err = extKey.Child(childNum, true) if err != nil { t.Errorf("err: %v", err) continue tests @@ -500,7 +509,7 @@ tests: for _, childNum := range test.path { var err error - extKey, err = extKey.Child(childNum) + extKey, err = extKey.Child(childNum, true) if err != nil { t.Errorf("err: %v", err) continue tests @@ -830,7 +839,7 @@ func TestErrors(t *testing.T) { } // Deriving a hardened child extended key should fail from a public key. - _, err = pubKey.Child(HardenedKeyStart) + _, err = pubKey.Child(HardenedKeyStart, true) if err != ErrDeriveHardFromPublic { t.Fatalf("Child: mismatched error -- got: %v, want: %v", err, ErrDeriveHardFromPublic) @@ -1052,14 +1061,14 @@ func TestMaximumDepth(t *testing.T) { t.Fatalf("extendedkey depth %d should match expected value %d", extKey.Depth(), i) } - newKey, err := extKey.Child(1) + newKey, err := extKey.Child(1, true) if err != nil { t.Fatalf("Child: unexpected error: %v", err) } extKey = newKey } - noKey, err := extKey.Child(1) + noKey, err := extKey.Child(1, true) if err != ErrDeriveBeyondMaxDepth { t.Fatalf("Child: mismatched error: want %v, got %v", ErrDeriveBeyondMaxDepth, err)