/* * * Copyright 2016, Google Inc. * All rights reserved. * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are * met: * * * Redistributions of source code must retain the above copyright * notice, this list of conditions and the following disclaimer. * * Redistributions in binary form must reproduce the above * copyright notice, this list of conditions and the following disclaimer * in the documentation and/or other materials provided with the * distribution. * * Neither the name of Google Inc. nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * */ // Package grpclb implements the load balancing protocol defined at // https://github.com/grpc/grpc/blob/master/doc/load-balancing.md. // The implementation is currently EXPERIMENTAL. package grpclb import ( "errors" "fmt" "sync" "golang.org/x/net/context" "google.golang.org/grpc" lbpb "google.golang.org/grpc/grpclb/grpc_lb_v1" "google.golang.org/grpc/grpclog" "google.golang.org/grpc/naming" ) // Balancer creates a grpclb load balancer. func Balancer(r naming.Resolver) grpc.Balancer { return &balancer{ r: r, } } type remoteBalancerInfo struct { addr grpc.Address name string } // addrInfo consists of the information of a backend server. type addrInfo struct { addr grpc.Address connected bool dropRequest bool } type balancer struct { r naming.Resolver mu sync.Mutex seq int // a sequence number to make sure addrCh does not get stale addresses. w naming.Watcher addrCh chan []grpc.Address rbs []remoteBalancerInfo addrs []addrInfo next int waitCh chan struct{} done bool } func (b *balancer) watchAddrUpdates(w naming.Watcher, ch chan remoteBalancerInfo) error { updates, err := w.Next() if err != nil { return err } b.mu.Lock() defer b.mu.Unlock() if b.done { return grpc.ErrClientConnClosing } var bAddr remoteBalancerInfo if len(b.rbs) > 0 { bAddr = b.rbs[0] } for _, update := range updates { addr := grpc.Address{ Addr: update.Addr, Metadata: update.Metadata, } switch update.Op { case naming.Add: var exist bool for _, v := range b.rbs { // TODO: Is the same addr with different server name a different balancer? if addr == v.addr { exist = true break } } if exist { continue } b.rbs = append(b.rbs, remoteBalancerInfo{addr: addr}) case naming.Delete: for i, v := range b.rbs { if addr == v.addr { copy(b.rbs[i:], b.rbs[i+1:]) b.rbs = b.rbs[:len(b.rbs)-1] break } } default: grpclog.Println("Unknown update.Op ", update.Op) } } // TODO: Fall back to the basic round-robin load balancing if the resulting address is // not a load balancer. if len(b.rbs) > 0 { // For simplicity, always use the first one now. May revisit this decision later. if b.rbs[0] != bAddr { select { case <-ch: default: } ch <- b.rbs[0] } } return nil } func (b *balancer) processServerList(l *lbpb.ServerList, seq int) { servers := l.GetServers() var ( sl []addrInfo addrs []grpc.Address ) for _, s := range servers { // TODO: Support ExpirationInterval addr := grpc.Address{ Addr: fmt.Sprintf("%s:%d", s.IpAddress, s.Port), // TODO: include LoadBalanceToken in the Metadata } sl = append(sl, addrInfo{ addr: addr, // TODO: Support dropRequest feature. }) addrs = append(addrs, addr) } b.mu.Lock() defer b.mu.Unlock() if b.done || seq < b.seq { return } if len(sl) > 0 { // reset b.next to 0 when replacing the server list. b.next = 0 b.addrs = sl b.addrCh <- addrs } return } func (b *balancer) callRemoteBalancer(lbc lbpb.LoadBalancerClient) (retry bool) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() stream, err := lbc.BalanceLoad(ctx, grpc.FailFast(false)) if err != nil { grpclog.Printf("Failed to perform RPC to the remote balancer %v", err) return } b.mu.Lock() if b.done { b.mu.Unlock() return } b.seq++ seq := b.seq b.mu.Unlock() initReq := &lbpb.LoadBalanceRequest{ LoadBalanceRequestType: &lbpb.LoadBalanceRequest_InitialRequest{ InitialRequest: new(lbpb.InitialLoadBalanceRequest), }, } if err := stream.Send(initReq); err != nil { // TODO: backoff on retry? return true } reply, err := stream.Recv() if err != nil { // TODO: backoff on retry? return true } initResp := reply.GetInitialResponse() if initResp == nil { grpclog.Println("Failed to receive the initial response from the remote balancer.") return } // TODO: Support delegation. if initResp.LoadBalancerDelegate != "" { // delegation grpclog.Println("TODO: Delegation is not supported yet.") return } // Retrieve the server list. for { reply, err := stream.Recv() if err != nil { break } if serverList := reply.GetServerList(); serverList != nil { b.processServerList(serverList, seq) } } return true } func (b *balancer) Start(target string, config grpc.BalancerConfig) error { // TODO: Fall back to the basic direct connection if there is no name resolver. if b.r == nil { return errors.New("there is no name resolver installed") } b.mu.Lock() if b.done { b.mu.Unlock() return grpc.ErrClientConnClosing } b.addrCh = make(chan []grpc.Address) w, err := b.r.Resolve(target) if err != nil { b.mu.Unlock() return err } b.w = w b.mu.Unlock() balancerAddrCh := make(chan remoteBalancerInfo, 1) // Spawn a goroutine to monitor the name resolution of remote load balancer. go func() { for { if err := b.watchAddrUpdates(w, balancerAddrCh); err != nil { grpclog.Printf("grpc: the naming watcher stops working due to %v.\n", err) close(balancerAddrCh) return } } }() // Spawn a goroutine to talk to the remote load balancer. go func() { var cc *grpc.ClientConn for { rb, ok := <-balancerAddrCh if cc != nil { cc.Close() } if !ok { // b is closing. return } // Talk to the remote load balancer to get the server list. // // TODO: override the server name in creds using Metadata in addr. var err error creds := config.DialCreds if creds == nil { cc, err = grpc.Dial(rb.addr.Addr, grpc.WithInsecure()) } else { cc, err = grpc.Dial(rb.addr.Addr, grpc.WithTransportCredentials(creds)) } if err != nil { grpclog.Printf("Failed to setup a connection to the remote balancer %v: %v", rb.addr, err) return } go func(cc *grpc.ClientConn) { lbc := lbpb.NewLoadBalancerClient(cc) for { if retry := b.callRemoteBalancer(lbc); !retry { cc.Close() return } } }(cc) } }() return nil } func (b *balancer) down(addr grpc.Address, err error) { b.mu.Lock() defer b.mu.Unlock() for _, a := range b.addrs { if addr == a.addr { a.connected = false break } } } func (b *balancer) Up(addr grpc.Address) func(error) { b.mu.Lock() defer b.mu.Unlock() if b.done { return nil } var cnt int for _, a := range b.addrs { if a.addr == addr { if a.connected { return nil } a.connected = true } if a.connected { cnt++ } } // addr is the only one which is connected. Notify the Get() callers who are blocking. if cnt == 1 && b.waitCh != nil { close(b.waitCh) b.waitCh = nil } return func(err error) { b.down(addr, err) } } func (b *balancer) Get(ctx context.Context, opts grpc.BalancerGetOptions) (addr grpc.Address, put func(), err error) { var ch chan struct{} b.mu.Lock() if b.done { b.mu.Unlock() err = grpc.ErrClientConnClosing return } if len(b.addrs) > 0 { if b.next >= len(b.addrs) { b.next = 0 } next := b.next for { a := b.addrs[next] next = (next + 1) % len(b.addrs) if a.connected { addr = a.addr b.next = next b.mu.Unlock() return } if next == b.next { // Has iterated all the possible address but none is connected. break } } } if !opts.BlockingWait { if len(b.addrs) == 0 { b.mu.Unlock() err = fmt.Errorf("there is no address available") return } // Returns the next addr on b.addrs for a failfast RPC. addr = b.addrs[b.next].addr b.next++ b.mu.Unlock() return } // Wait on b.waitCh for non-failfast RPCs. if b.waitCh == nil { ch = make(chan struct{}) b.waitCh = ch } else { ch = b.waitCh } b.mu.Unlock() for { select { case <-ctx.Done(): err = ctx.Err() return case <-ch: b.mu.Lock() if b.done { b.mu.Unlock() err = grpc.ErrClientConnClosing return } if len(b.addrs) > 0 { if b.next >= len(b.addrs) { b.next = 0 } next := b.next for { a := b.addrs[next] next = (next + 1) % len(b.addrs) if a.connected { addr = a.addr b.next = next b.mu.Unlock() return } if next == b.next { // Has iterated all the possible address but none is connected. break } } } // The newly added addr got removed by Down() again. if b.waitCh == nil { ch = make(chan struct{}) b.waitCh = ch } else { ch = b.waitCh } b.mu.Unlock() } } } func (b *balancer) Notify() <-chan []grpc.Address { return b.addrCh } func (b *balancer) Close() error { b.mu.Lock() defer b.mu.Unlock() b.done = true if b.waitCh != nil { close(b.waitCh) } if b.addrCh != nil { close(b.addrCh) } if b.w != nil { b.w.Close() } return nil }