diff --git a/README.md b/README.md index 1754108..f50131f 100644 --- a/README.md +++ b/README.md @@ -187,15 +187,16 @@ See `test/main.go` for an example. - [x] Strict protocol validation - [x] OID validation - [ ] DN parsing support -- [ ] Full concurrency ability +- [x] Full concurrency ability - [ ] Comprehensive message parsing tests - [x] Abandon request - [x] Add request - [x] Bind request - [x] Compare request (concurrent) +- [x] Delete request (concurrent) - [x] Extended requests -- [x] Modify request -- [x] ModifyDN request +- [x] Modify request (concurrent) +- [x] ModifyDN request (concurrent) - [x] Search request (concurrent) - [x] StartTLS request - [x] Unbind request diff --git a/handler.go b/handler.go index e379d2e..78b9da9 100644 --- a/handler.go +++ b/handler.go @@ -16,6 +16,8 @@ type Handler interface { Bind(*Conn, *Message, *BindRequest) // Perform a Compare request Compare(*Conn, *Message, *CompareRequest) + // Perform a Delete request + Delete(*Conn, *Message, string) // Perform an Extended request Extended(*Conn, *Message, *ExtendedRequest) // Perform a Modify request @@ -50,6 +52,10 @@ func (*BaseHandler) Compare(conn *Conn, msg *Message, req *CompareRequest) { conn.SendResult(msg.MessageID, nil, TypeCompareResponseOp, UnsupportedOperation) } +func (*BaseHandler) Delete(conn *Conn, msg *Message, dn string) { + conn.SendResult(msg.MessageID, nil, TypeDeleteResponseOp, UnsupportedOperation) +} + func (*BaseHandler) Modify(conn *Conn, msg *Message, req *ModifyRequest) { conn.SendResult(msg.MessageID, nil, TypeModifyResponseOp, UnsupportedOperation) } diff --git a/server.go b/server.go index 9fa6272..48fd061 100644 --- a/server.go +++ b/server.go @@ -192,7 +192,12 @@ func (s *LDAPServer) handleMessage(conn *Conn, msg *Message) { }() case TypeDeleteRequestOp: log.Println("Delete request") - conn.SendResult(msg.MessageID, nil, TypeDeleteResponseOp, UnsupportedOperation) + dn := BerGetOctetString(msg.ProtocolOp.Data) + conn.asyncOperations.Add(1) + go func() { + defer conn.asyncOperations.Done() + s.Handler.Delete(conn, msg, dn) + }() case TypeExtendedRequestOp: log.Println("Extended request") req, err := GetExtendedRequest(msg.ProtocolOp.Data) diff --git a/test/main.go b/test/main.go index 97bec39..f75a294 100644 --- a/test/main.go +++ b/test/main.go @@ -47,6 +47,24 @@ func (t *TestHandler) Abandon(conn *ldapserver.Conn, msg *ldapserver.Message, me t.abandonmentLock.Unlock() } +func (t *TestHandler) Add(conn *ldapserver.Conn, msg *ldapserver.Message, req *ldapserver.AddRequest) { + auth := getAuth(conn) + if auth != "uid=authorizeduser,ou=users,dc=example,dc=com" { + log.Println("Not an authorized connection!", auth) + conn.SendResult(msg.MessageID, nil, ldapserver.TypeAddResponseOp, ldapserver.PermissionDenied) + return + } + log.Println("Add DN:", req.Entry) + for _, attr := range req.Attributes { + log.Println(" Attribute:", attr.Description) + log.Println(" Values:", attr.Values) + } + res := &ldapserver.Result{ + ResultCode: ldapserver.ResultSuccess, + } + conn.SendResult(msg.MessageID, nil, ldapserver.TypeAddResponseOp, res) +} + func (t *TestHandler) Bind(conn *ldapserver.Conn, msg *ldapserver.Message, req *ldapserver.BindRequest) { res := &ldapserver.BindResponse{} if req.Version != 3 { @@ -115,6 +133,20 @@ func (t *TestHandler) Compare(conn *ldapserver.Conn, msg *ldapserver.Message, re conn.SendResult(msg.MessageID, nil, ldapserver.TypeCompareResponseOp, res) } +func (t *TestHandler) Delete(conn *ldapserver.Conn, msg *ldapserver.Message, dn string) { + auth := getAuth(conn) + if auth != "uid=authorizeduser,ou=users,dc=example,dc=com" { + log.Println("Not an authorized connection!", auth) + conn.SendResult(msg.MessageID, nil, ldapserver.TypeDeleteResponseOp, ldapserver.PermissionDenied) + return + } + log.Println("Delete DN:", dn) + res := &ldapserver.Result{ + ResultCode: ldapserver.ResultSuccess, + } + conn.SendResult(msg.MessageID, nil, ldapserver.TypeDeleteResponseOp, res) +} + func (t *TestHandler) Modify(conn *ldapserver.Conn, msg *ldapserver.Message, req *ldapserver.ModifyRequest) { auth := getAuth(conn) if auth != "uid=authorizeduser,ou=users,dc=example,dc=com" { @@ -218,3 +250,16 @@ func (t *TestHandler) Search(conn *ldapserver.Conn, msg *ldapserver.Message, req } conn.SendResult(msg.MessageID, nil, ldapserver.TypeSearchResultDoneOp, res) } + +func (t *TestHandler) Extended(conn *ldapserver.Conn, msg *ldapserver.Message, req *ldapserver.ExtendedRequest) { + switch req.Name { + case ldapserver.OIDPasswordModify: + log.Println("Password modify") + // Pretend to handle it + res := &ldapserver.ExtendedResult{} + res.ResultCode = ldapserver.ResultSuccess + conn.SendResult(msg.MessageID, nil, ldapserver.TypeExtendedResponseOp, res) + default: + t.BaseHandler.Extended(conn, msg, req) + } +}