Skip to content

Commit bae76b8

Browse files
author
Christian Muirhead
committed
Extracted the combined recorder from rpc.Conn
rpc.Conn now accepts a RecorderFactory rather than an ObserverFactory. The Recorder interface has the same methods as Observer, but they can return errors to stop the handling of the request, which is needed for auditing. I've kept the Observer interface, since the lack of errors makes multiplexing simpler. The observer is now embedded in a combined recorder that forwards messages to it but also passes them on to the auditlog recorder, which has the opportunity to interrupt the request.
1 parent c388ea9 commit bae76b8

File tree

12 files changed

+157
-129
lines changed

12 files changed

+157
-129
lines changed

api/apiclient.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ import (
3131
"gopkg.in/retry.v1"
3232

3333
"github.com/juju/juju/api/base"
34-
"github.com/juju/juju/apiserver/observer"
3534
"github.com/juju/juju/apiserver/params"
3635
"github.com/juju/juju/network"
3736
"github.com/juju/juju/rpc"
@@ -196,7 +195,7 @@ func Open(info *Info, opts DialOpts) (Connection, error) {
196195
return nil, errors.Trace(err)
197196
}
198197

199-
client := rpc.NewConn(jsoncodec.New(dialResult.conn), observer.None())
198+
client := rpc.NewConn(jsoncodec.New(dialResult.conn), nil)
200199
client.Start()
201200

202201
bakeryClient := opts.BakeryClient

api/testing/fakeserver.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func FakeAPIServer(root interface{}) net.Conn {
1717
c0, c1 := net.Pipe()
1818
serverCodec := jsoncodec.NewNet(c1)
1919
serverRPC := rpc.NewConn(serverCodec, nil)
20-
serverRPC.Serve(root, nil)
20+
serverRPC.Serve(root, nil, nil)
2121
serverRPC.Start()
2222
go func() {
2323
<-serverRPC.Dead()

apiserver/admin.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,10 +136,10 @@ func (a *admin) login(req params.LoginRequest, loginVersion int) (params.LoginRe
136136
modelTag = a.root.model.Tag().String()
137137
}
138138

139-
var recorder *auditlog.Recorder
139+
var auditRecorder *auditlog.Recorder
140140
if authResult.userLogin {
141141
// We only audit connections from humans.
142-
recorder, err = auditlog.NewRecorder(
142+
auditRecorder, err = auditlog.NewRecorder(
143143
a.srv.auditLogger,
144144
auditlog.ConversationArgs{
145145
Who: req.AuthTag,
@@ -155,7 +155,10 @@ func (a *admin) login(req params.LoginRequest, loginVersion int) (params.LoginRe
155155
}
156156
}
157157

158-
a.root.rpcConn.ServeRoot(apiRoot, recorder, serverError)
158+
recorderFactory := observer.NewRecorderFactory(
159+
a.apiObserver, auditRecorder)
160+
161+
a.root.rpcConn.ServeRoot(apiRoot, recorderFactory, serverError)
159162
return params.LoginResult{
160163
Servers: params.FromNetworkHostsPorts(hostPorts),
161164
ControllerTag: a.root.model.ControllerTag().String(),

apiserver/apiserver.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -967,7 +967,7 @@ func (srv *Server) serveConn(
967967
host string,
968968
) error {
969969
codec := jsoncodec.NewWebsocket(wsConn.Conn)
970-
conn := rpc.NewConn(codec, apiObserver)
970+
conn := rpc.NewConn(codec, observer.NewRecorderFactory(apiObserver, nil))
971971

972972
// Note that we don't overwrite modelUUID here because
973973
// newAPIHandler treats an empty modelUUID as signifying

apiserver/export_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ func TestingAPIHandler(c *gc.C, pool *state.StatePool, st *state.State) (*apiHan
111111
statePool: pool,
112112
tag: names.NewMachineTag("0"),
113113
}
114-
h, err := newAPIHandler(srv, st, nil, st.ModelUUID(), "testing.invalid:1234")
114+
h, err := newAPIHandler(srv, st, nil, st.ModelUUID(), 6543, "testing.invalid:1234")
115115
c.Assert(err, jc.ErrorIsNil)
116116
return h, h.getResources()
117117
}

apiserver/observer/recorder.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
// Copyright 2017 Canonical Ltd.
2+
// Licensed under the AGPLv3, see LICENCE file for details.
3+
4+
package observer
5+
6+
import (
7+
"encoding/json"
8+
9+
"github.com/juju/errors"
10+
11+
"github.com/juju/juju/core/auditlog"
12+
"github.com/juju/juju/rpc"
13+
)
14+
15+
// NewRecorderFactory makes a new rpc.RecorderFactory to make
16+
// recorders that that will update the observer and the auditlog
17+
// recorder when it records a request or reply. The auditlog recorder
18+
// can be nil.
19+
func NewRecorderFactory(observerFactory rpc.ObserverFactory, recorder *auditlog.Recorder) rpc.RecorderFactory {
20+
return func() rpc.Recorder {
21+
return &combinedRecorder{
22+
observer: observerFactory.RPCObserver(),
23+
recorder: recorder,
24+
}
25+
}
26+
}
27+
28+
// combinedRecorder wraps an observer (which might be a multiplexer)
29+
// up with an auditlog recorder into an rpc.Recorder.
30+
type combinedRecorder struct {
31+
observer rpc.Observer
32+
recorder *auditlog.Recorder
33+
}
34+
35+
// ServerRequest implements rpc.Recorder.
36+
func (cr *combinedRecorder) ServerRequest(hdr *rpc.Header, body interface{}) error {
37+
cr.observer.ServerRequest(hdr, body)
38+
if cr.recorder == nil {
39+
return nil
40+
}
41+
// TODO(babbageclunk): make this configurable.
42+
jsonArgs, err := json.Marshal(body)
43+
if err != nil {
44+
return errors.Trace(err)
45+
}
46+
return errors.Trace(cr.recorder.AddRequest(auditlog.RequestArgs{
47+
RequestID: hdr.RequestId,
48+
Facade: hdr.Request.Type,
49+
Method: hdr.Request.Action,
50+
Version: hdr.Request.Version,
51+
Args: string(jsonArgs),
52+
}))
53+
}
54+
55+
// ServerReply implements rpc.Recorder.
56+
func (cr *combinedRecorder) ServerReply(req rpc.Request, replyHdr *rpc.Header, body interface{}) error {
57+
cr.observer.ServerReply(req, replyHdr, body)
58+
if cr.recorder == nil {
59+
return nil
60+
}
61+
var responseErrors []*auditlog.Error
62+
if replyHdr.Error == "" {
63+
var err error
64+
responseErrors, err = extractErrors(body)
65+
if err != nil {
66+
return errors.Trace(err)
67+
}
68+
} else {
69+
responseErrors = []*auditlog.Error{{
70+
Message: replyHdr.Error,
71+
Code: replyHdr.ErrorCode,
72+
}}
73+
}
74+
return errors.Trace(cr.recorder.AddResponse(auditlog.ResponseErrorsArgs{
75+
RequestID: replyHdr.RequestId,
76+
Errors: responseErrors,
77+
}))
78+
}
79+
80+
func extractErrors(body interface{}) ([]*auditlog.Error, error) {
81+
// TODO(babbageclunk): use reflection to find errors in the response body.
82+
return nil, nil
83+
}

apiserver/testing/fakeapi.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/gorilla/websocket"
1212
"github.com/juju/utils"
1313

14+
"github.com/juju/juju/apiserver/observer"
1415
"github.com/juju/juju/apiserver/observer/fakeobserver"
1516
"github.com/juju/juju/rpc"
1617
"github.com/juju/juju/rpc/jsoncodec"
@@ -77,7 +78,7 @@ func (srv *Server) serveAPI(w http.ResponseWriter, req *http.Request) {
7778

7879
func (srv *Server) serveConn(wsConn *websocket.Conn, modelUUID string) {
7980
codec := jsoncodec.NewWebsocket(wsConn)
80-
conn := rpc.NewConn(codec, &fakeobserver.Instance{})
81+
conn := rpc.NewConn(codec, observer.NewRecorderFactory(&fakeobserver.Instance{}, nil))
8182

8283
root := allVersions{
8384
rpcreflect.ValueOf(reflect.ValueOf(srv.newRoot(modelUUID))),

core/auditlog/auditlog.go

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,6 @@ func NewRecorder(log AuditLog, c ConversationArgs) (*Recorder, error) {
138138

139139
// AddRequest records a method call to the API.
140140
func (r *Recorder) AddRequest(m RequestArgs) error {
141-
if r == nil {
142-
return nil
143-
}
144141
return errors.Trace(r.log.AddRequest(Request{
145142
ConversationID: r.callID,
146143
ConnectionID: r.connectionID,
@@ -154,9 +151,6 @@ func (r *Recorder) AddRequest(m RequestArgs) error {
154151

155152
// AddResponse records the result of a method call to the API.
156153
func (r *Recorder) AddResponse(m ResponseErrorsArgs) error {
157-
if r == nil {
158-
return nil
159-
}
160154
return errors.Trace(r.log.AddResponse(ResponseErrors{
161155
ConversationID: r.callID,
162156
ConnectionID: r.connectionID,

rpc/dispatch_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ func (s *dispatchSuite) SetUpSuite(c *gc.C) {
3232
s.BaseSuite.SetUpSuite(c)
3333
rpcServer := func(ws *websocket.Conn) {
3434
codec := jsoncodec.NewWebsocket(ws)
35-
conn := rpc.NewConn(codec, &notifier{})
35+
conn := rpc.NewConn(codec, nil)
3636

37-
conn.Serve(&DispatchRoot{}, nil)
37+
conn.Serve(&DispatchRoot{}, nil, nil)
3838
conn.Start()
3939

4040
<-conn.Dead()

rpc/observers.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,14 @@ type Observer interface {
2828
ServerReply(req Request, hdr *Header, body interface{})
2929
}
3030

31+
// ObserverFactory is a type which can construct a new Observer.
32+
type ObserverFactory interface {
33+
// RPCObserver will return a new Observer usually constructed
34+
// from the state previously built up in the Observer. The
35+
// returned instance will be utilized per RPC request.
36+
RPCObserver() Observer
37+
}
38+
3139
// NewObserverMultiplexer returns a new ObserverMultiplexer
3240
// with the provided RequestNotifiers.
3341
func NewObserverMultiplexer(rpcObservers ...Observer) *ObserverMultiplexer {

0 commit comments

Comments
 (0)