debian-mirror-gitlab/workhorse-vendor/gocloud.dev/postgres/gcppostgres/gcppostgres.go
2023-01-13 15:02:22 +05:30

194 lines
5.5 KiB
Go

// Copyright 2018 The Go Cloud Development Kit Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package gcppostgres provides connections to managed PostgreSQL Cloud SQL instances.
// See https://cloud.google.com/sql/docs/postgres/ for more information.
//
// # URLs
//
// For postgres.Open, gcppostgres registers for the scheme "gcppostgres".
// The default URL opener will create a connection using the default
// credentials from the environment, as described in
// https://cloud.google.com/docs/authentication/production.
// To customize the URL opener, or for more details on the URL format,
// see URLOpener.
//
// See https://gocloud.dev/concepts/urls/ for background information.
package gcppostgres // import "gocloud.dev/postgres/gcppostgres"
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"net"
"net/url"
"strings"
"sync"
"time"
"contrib.go.opencensus.io/integrations/ocsql"
"github.com/GoogleCloudPlatform/cloudsql-proxy/proxy/proxy"
"github.com/lib/pq"
"gocloud.dev/gcp"
"gocloud.dev/gcp/cloudsql"
"gocloud.dev/postgres"
)
// Scheme is the URL scheme gcppostgres registers its URLOpener under on
// the URL mux returned by postgres.DefaultURLMux.
const Scheme = "gcppostgres"
// init registers a lazily-initialized opener for Scheme on the default
// postgres URL mux, so that importing this package is enough to make
// "gcppostgres" URLs openable.
func init() {
	mux := postgres.DefaultURLMux()
	mux.RegisterPostgres(Scheme, &lazyCredsOpener{})
}
// lazyCredsOpener obtains Application Default Credentials on the first call
// to OpenPostgresURL.
//
// The outcome of that first call is cached in opener or err and reused by
// every later call.
type lazyCredsOpener struct {
// init guards the one-time credential lookup and opener construction.
init sync.Once
// opener is the URLOpener built by the one-time setup; it stays nil if
// the setup failed.
opener *URLOpener
// err is the error recorded by the one-time setup, if any.
err error
}
// OpenPostgresURL opens the database identified by u.
//
// On the first call it looks up the default GCP credentials, builds an HTTP
// client and a Cloud SQL certificate source from them, and caches the
// resulting URLOpener. The context of that first call is the one used for
// the credential lookup. Every call then delegates to the cached URLOpener.
//
// If the one-time setup failed, the recorded error is returned by this and
// every subsequent call. The error wraps the setup error, and the URL is
// rendered with its password redacted so credentials do not end up in logs.
func (o *lazyCredsOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
	o.init.Do(func() {
		creds, err := gcp.DefaultCredentials(ctx)
		if err != nil {
			o.err = err
			return
		}
		client, err := gcp.NewHTTPClient(gcp.DefaultTransport(), creds.TokenSource)
		if err != nil {
			o.err = err
			return
		}
		certSource := cloudsql.NewCertSource(client)
		o.opener = &URLOpener{CertSource: certSource}
	})
	if o.err != nil {
		// Redacted hides the password component of the URL; %w keeps the
		// setup error reachable through errors.Is / errors.As.
		return nil, fmt.Errorf("gcppostgres open %v: %w", u.Redacted(), o.err)
	}
	return o.opener.OpenPostgresURL(ctx, u)
}
// URLOpener opens GCP PostgreSQL URLs
// like "gcppostgres://user:password@myproject/us-central1/instanceId/mydb".
//
// The host and path of the URL are read as project/region/instance/dbname.
// The query parameters sslmode, sslcert, sslkey and sslrootcert are not
// allowed in the URL.
type URLOpener struct {
// CertSource specifies how the opener will obtain authentication information.
// CertSource must not be nil.
CertSource proxy.CertSource
// TraceOpts contains options for OpenCensus.
TraceOpts []ocsql.TraceOption
}
// OpenPostgresURL opens a new GCP database connection wrapped with OpenCensus instrumentation.
//
// The host and path of u must have the form project/region/instance/dbname.
// The query parameters sslmode, sslcert, sslkey and sslrootcert are rejected;
// all other query parameters are passed through in the connection string.
// u itself is not modified.
//
// It returns an error if uo.CertSource is nil, if u does not have the
// expected form, or if a rejected query parameter is present. Errors that
// mention the URL render it with the password redacted.
func (uo *URLOpener) OpenPostgresURL(ctx context.Context, u *url.URL) (*sql.DB, error) {
	if uo.CertSource == nil {
		return nil, errors.New("gcppostgres: URLOpener CertSource is nil")
	}
	instance, dbName, err := instanceFromURL(u)
	if err != nil {
		return nil, fmt.Errorf("gcppostgres: open %v: %w", u.Redacted(), err)
	}

	query := u.Query()
	// Only permit parameters that do not conflict with other behavior.
	// Checking a fixed list (rather than ranging over the map) makes the
	// reported parameter deterministic when several are present.
	for _, k := range []string{"sslmode", "sslcert", "sslkey", "sslrootcert"} {
		if _, present := query[k]; present {
			return nil, fmt.Errorf("gcppostgres: open: extra parameter %s not allowed", k)
		}
	}
	query.Set("sslmode", "disable")

	// Build the connection string from a copy of u. The host is a fixed
	// placeholder: the dialer ignores the address it is given and always
	// dials the Cloud SQL instance.
	u2 := new(url.URL)
	*u2 = *u
	u2.Scheme = "postgres"
	u2.Host = "cloudsql"
	u2.Path = "/" + dbName
	u2.RawQuery = query.Encode()

	db := sql.OpenDB(connector{
		client: &proxy.Client{
			Port:  3307,
			Certs: uo.CertSource,
		},
		instance: instance,
		pqConn:   u2.String(),
		// Copy the options so the connector does not share a backing array
		// with uo.TraceOpts.
		traceOpts: append([]ocsql.TraceOption(nil), uo.TraceOpts...),
	})
	return db, nil
}
// instanceFromURL splits the host and path of u, read together as
// project/region/instance/dbname, into the instance name
// "project:region:instance" and the database name. Anything after the third
// slash, including further slashes, belongs to the database name.
//
// It returns an error if fewer than four components are present or if any
// component is empty.
func instanceFromURL(u *url.URL) (instance, db string, _ error) {
	// Everything after the scheme but before the query or fragment.
	full := u.Host + u.Path
	segs := strings.SplitN(full, "/", 4)

	// SplitN yields at most four pieces, so a well-formed input has exactly four.
	ok := len(segs) == 4
	for i := 0; ok && i < len(segs); i++ {
		ok = segs[i] != ""
	}
	if !ok {
		return "", "", fmt.Errorf("%s is not in the form project/region/instance/dbname", full)
	}
	return strings.Join(segs[:3], ":"), segs[3], nil
}
// pqDriver is the driver.Driver returned by connector.Driver. It holds
// everything except the connection string needed to build a connector.
type pqDriver struct {
// client is the Cloud SQL proxy client used to dial the instance.
client *proxy.Client
// instance is the name of the Cloud SQL instance to dial.
instance string
// traceOpts are the OpenCensus options applied to each connection.
traceOpts []ocsql.TraceOption
}
// Open returns a new connection for the connection string name. It builds a
// connector with OpenConnector and connects with a background context,
// since this method receives no context of its own.
//
// An error from OpenConnector is returned as-is rather than being discarded,
// so a failed OpenConnector can never lead to a call on a nil connector.
func (d pqDriver) Open(name string) (driver.Conn, error) {
	c, err := d.OpenConnector(name)
	if err != nil {
		return nil, err
	}
	return c.Connect(context.Background())
}
// OpenConnector returns a connector that uses the driver's client, instance
// and trace options together with the connection string name. It never
// returns a non-nil error.
func (d pqDriver) OpenConnector(name string) (driver.Connector, error) {
	c := connector{
		client:    d.client,
		instance:  d.instance,
		pqConn:    name,
		traceOpts: d.traceOpts,
	}
	return c, nil
}
// connector is the driver.Connector handed to sql.OpenDB. Each Connect call
// dials the instance through client and wraps the resulting connection with
// OpenCensus instrumentation.
type connector struct {
// client is the Cloud SQL proxy client used to dial the instance.
client *proxy.Client
// instance is the name of the Cloud SQL instance to dial.
instance string
// pqConn is the connection string passed to pq.DialOpen.
pqConn string
// traceOpts are the OpenCensus options applied to each connection.
traceOpts []ocsql.TraceOption
}
// Connect opens a connection with pq.DialOpen, dialing the instance through
// the connector's proxy client, and returns it wrapped with OpenCensus
// instrumentation. The supplied context is not used.
func (c connector) Connect(context.Context) (driver.Conn, error) {
	via := dialer{client: c.client, instance: c.instance}
	raw, err := pq.DialOpen(via, c.pqConn)
	if err != nil {
		return nil, err
	}
	wrapped := ocsql.WrapConn(raw, c.traceOpts...)
	return wrapped, nil
}
// Driver returns a pqDriver that shares this connector's client, instance
// and trace options.
func (c connector) Driver() driver.Driver {
	return pqDriver{
		client:    c.client,
		instance:  c.instance,
		traceOpts: c.traceOpts,
	}
}
// dialer is the dialer passed to pq.DialOpen. It routes every dial to a
// single Cloud SQL instance through the proxy client.
type dialer struct {
// client is the Cloud SQL proxy client used to dial the instance.
client *proxy.Client
// instance is the name of the Cloud SQL instance to dial.
instance string
}
// Dial connects to the dialer's Cloud SQL instance through the proxy client.
// The network and address arguments are ignored.
func (d dialer) Dial(network, address string) (net.Conn, error) {
	conn, err := d.client.Dial(d.instance)
	return conn, err
}
// DialTimeout always fails: dialing with a timeout is not supported by this
// dialer. All arguments are ignored.
func (d dialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
	err := errors.New("gcppostgres: DialTimeout not supported")
	return nil, err
}