changeset 36851:bf676267f64f

wireprotoserver: split ssh protocol handler and server We want to formalize the interface for protocol handlers. Today, server functionality (which is domain specific) is interleaved with protocol handling functionality (which conforms to a generic interface) in the sshserver class. This commit splits the ssh protocol handling code out of the sshserver class. Differential Revision: https://phab.mercurial-scm.org/D2080
author Gregory Szorc <gregory.szorc@gmail.com>
date Wed, 07 Feb 2018 20:17:05 -0800
parents 5767664d39a5
children 2ad145fbde54
files mercurial/wireprotoserver.py tests/sshprotoext.py tests/test-sshserver.py
diffstat 3 files changed, 30 insertions(+), 21 deletions(-) [+]
line wrap: on
line diff
--- a/mercurial/wireprotoserver.py
+++ b/mercurial/wireprotoserver.py
@@ -354,19 +354,12 @@
     fout.write(b'\n')
     fout.flush()
 
-class sshserver(baseprotocolhandler):
-    def __init__(self, ui, repo):
+class sshv1protocolhandler(baseprotocolhandler):
+    """Handler for requests services via version 1 of SSH protocol."""
+    def __init__(self, ui, fin, fout):
         self._ui = ui
-        self._repo = repo
-        self._fin = ui.fin
-        self._fout = ui.fout
-
-        hook.redirect(True)
-        ui.fout = repo.ui.fout = ui.ferr
-
-        # Prevent insertion/deletion of CRs
-        util.setbinary(self._fin)
-        util.setbinary(self._fout)
+        self._fin = fin
+        self._fout = fout
 
     @property
     def name(self):
@@ -403,6 +396,26 @@
     def redirect(self):
         pass
 
+    def _client(self):
+        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
+        return 'remote:ssh:' + client
+
+class sshserver(object):
+    def __init__(self, ui, repo):
+        self._ui = ui
+        self._repo = repo
+        self._fin = ui.fin
+        self._fout = ui.fout
+
+        hook.redirect(True)
+        ui.fout = repo.ui.fout = ui.ferr
+
+        # Prevent insertion/deletion of CRs
+        util.setbinary(self._fin)
+        util.setbinary(self._fout)
+
+        self._proto = sshv1protocolhandler(self._ui, self._fin, self._fout)
+
     def serve_forever(self):
         while self.serve_one():
             pass
@@ -410,8 +423,8 @@
 
     def serve_one(self):
         cmd = self._fin.readline()[:-1]
-        if cmd and wireproto.commands.commandavailable(cmd, self):
-            rsp = wireproto.dispatch(self._repo, self, cmd)
+        if cmd and wireproto.commands.commandavailable(cmd, self._proto):
+            rsp = wireproto.dispatch(self._repo, self._proto, cmd)
 
             if isinstance(rsp, bytes):
                 _sshv1respondbytes(self._fout, rsp)
@@ -432,7 +445,3 @@
         elif cmd:
             _sshv1respondbytes(self._fout, b'')
         return cmd != ''
-
-    def _client(self):
-        client = encoding.environ.get('SSH_CLIENT', '').split(' ', 1)[0]
-        return 'remote:ssh:' + client
--- a/tests/sshprotoext.py
+++ b/tests/sshprotoext.py
@@ -48,7 +48,7 @@
         wireprotoserver._sshv1respondbytes(self._fout, b'')
         l = self._fin.readline()
         assert l == b'between\n'
-        rsp = wireproto.dispatch(self._repo, self, b'between')
+        rsp = wireproto.dispatch(self._repo, self._proto, b'between')
         wireprotoserver._sshv1respondbytes(self._fout, rsp)
 
         super(prehelloserver, self).serve_forever()
@@ -73,7 +73,7 @@
 
         # Send the upgrade response.
         self._fout.write(b'upgraded %s %s\n' % (token, name))
-        servercaps = wireproto.capabilities(self._repo, self)
+        servercaps = wireproto.capabilities(self._repo, self._proto)
         rsp = b'capabilities: %s' % servercaps
         self._fout.write(b'%d\n' % len(rsp))
         self._fout.write(rsp)
--- a/tests/test-sshserver.py
+++ b/tests/test-sshserver.py
@@ -24,7 +24,7 @@
     def assertparse(self, cmd, input, expected):
         server = mockserver(input)
         _func, spec = wireproto.commands[cmd]
-        self.assertEqual(server.getargs(spec), expected)
+        self.assertEqual(server._proto.getargs(spec), expected)
 
 def mockserver(inbytes):
     ui = mockui(inbytes)