|
4 | 4 | "bytes" |
5 | 5 | "crypto/aes" |
6 | 6 | "io/ioutil" |
| 7 | + "math/rand" |
7 | 8 |
|
8 | 9 | "encoding/base64" |
9 | 10 | "encoding/json" |
@@ -51,38 +52,74 @@ func (c *Client) kmsAuth(url string) error { |
51 | 52 | return nil |
52 | 53 | } |
53 | 54 |
|
54 | | -func kmsSplit(keyProvider string) []string { |
55 | | - return strings.Split(keyProvider, ",") |
56 | | -} |
| 55 | +// parse uri like kms://[email protected];kms02.example.com:9600/kms |
| 56 | +func kmsParseProviderUri(uri string) ([]string, error) { |
| 57 | + original_uri := uri |
57 | 58 |
|
58 | | -// kmsUrl parse KeyProviderUri to list of URL's |
59 | | -func (c *Client) kmsUrl(einfo *hdfs.FileEncryptionInfoProto) ([]string, error) { |
60 | | - defaults, err := c.fetchDefaults() |
61 | | - if err != nil { |
62 | | - return nil, err |
63 | | - } |
64 | | - |
65 | | - uri := defaults.GetKeyProviderUri() |
66 | 59 | if uri == "" { |
67 | | - return nil, errors.New("KeyProviderUri not configured on server") |
| 60 | + return nil, errors.New("KeyProviderUri empty. not configured on server ?") |
68 | 61 | } |
69 | 62 |
|
70 | 63 | var urls []string |
71 | 64 | var proto string |
72 | 65 | if strings.HasPrefix(uri, kmsSchemeHTTPS) { |
73 | 66 | proto = "https://" |
74 | | - urls = kmsSplit(uri[len(kmsSchemeHTTPS):]) |
| 67 | + uri = uri[len(kmsSchemeHTTPS):] |
75 | 68 | } |
76 | 69 | if proto == "" && strings.HasPrefix(uri, kmsSchemeHTTP) { |
77 | 70 | proto = "http://" |
78 | | - urls = kmsSplit(uri[len(kmsSchemeHTTP):]) |
| 71 | + uri = uri[len(kmsSchemeHTTP):] |
79 | 72 | } |
80 | 73 | if proto == "" { |
81 | | - return nil, fmt.Errorf("not supported scheme %v", uri) |
| 74 | + return nil, fmt.Errorf("not supported uri %v", original_uri) |
| 75 | + } |
| 76 | + |
| 77 | + port := ":9600" // default kms port |
| 78 | + path := "" // default path |
| 79 | + |
| 80 | + parts := strings.Split(uri, ";") |
| 81 | + for i, s := range parts { |
| 82 | + path_index := strings.Index(s, "/") |
| 83 | + if path_index > -1 { |
| 84 | + path = s[path_index:] |
| 85 | + s = s[:path_index] |
| 86 | + } |
| 87 | + port_index := strings.Index(s, ":") |
| 88 | + if port_index > -1 { |
| 89 | + port = s[port_index:] |
| 90 | + s = s[:port_index] |
| 91 | + } |
| 92 | + if (path_index > -1 || port_index > -1) && i+1 != len(parts) { |
| 93 | + return nil, fmt.Errorf("bad uri: %v", original_uri) |
| 94 | + } |
| 95 | + urls = append(urls, proto+s) |
82 | 96 | } |
83 | 97 |
|
84 | 98 | for i := range urls { |
85 | | - urls[i] = proto + urls[i] + "/v1/keyversion/" + url.QueryEscape(*einfo.EzKeyVersionName) + "/_eek?eek_op=decrypt" |
| 99 | + urls[i] += port |
| 100 | + urls[i] += path |
| 101 | + } |
| 102 | + |
| 103 | + return urls, nil |
| 104 | +} |
| 105 | + |
| 106 | +// kmsUrl parse KeyProviderUri to list of URL's |
| 107 | +func (c *Client) kmsUrl(einfo *hdfs.FileEncryptionInfoProto) ([]string, error) { |
| 108 | + defaults, err := c.fetchDefaults() |
| 109 | + if err != nil { |
| 110 | + return nil, err |
| 111 | + } |
| 112 | + |
| 113 | + urls, err := kmsParseProviderUri(defaults.GetKeyProviderUri()) |
| 114 | + if err != nil { |
| 115 | + return nil, err |
| 116 | + } |
| 117 | + |
| 118 | + // Reorder urls. Simple method to round robin calls across em. |
| 119 | + rand.Shuffle(len(urls), func(i, j int) { urls[i], urls[j] = urls[j], urls[i] }) |
| 120 | + |
| 121 | + for i := range urls { |
| 122 | + urls[i] = urls[i] + "/v1/keyversion/" + url.QueryEscape(*einfo.EzKeyVersionName) + "/_eek?eek_op=decrypt" |
86 | 123 | } |
87 | 124 |
|
88 | 125 | return urls, nil |
@@ -156,7 +193,7 @@ func (c *Client) kmsGetKey(einfo *hdfs.FileEncryptionInfoProto) (*TransparentEnc |
156 | 193 |
|
157 | 194 | urls, err := c.kmsUrl(einfo) |
158 | 195 | if err != nil { |
159 | | - return nil, err |
| 196 | + return nil, errors.Wrap(err, "fail to get KMS address") |
160 | 197 | } |
161 | 198 |
|
162 | 199 | requestBody, err := json.Marshal(map[string]string{ |
|
0 commit comments