changeset 698:a58ae693ab72

determine_wants: factor ref filtering code out into a separate function This will be used in another context in an upcoming patch.
author Siddharth Agarwal <sid0@fb.com>
date Tue, 04 Mar 2014 12:57:37 -0800
parents a7383625c891
children 6dc550f2fa78
files hggit/git_handler.py
diffstat 1 files changed, 31 insertions(+), 22 deletions(-) [+]
line wrap: on
line diff
--- a/hggit/git_handler.py
+++ b/hggit/git_handler.py
@@ -973,33 +973,14 @@
 
         return new_refs
 
-
     def fetch_pack(self, remote_name, heads=None):
         client, path = self.get_transport_and_path(remote_name)
         graphwalker = self.git.get_graph_walker()
+
         def determine_wants(refs):
-            if heads:
-                want = []
-                # contains pairs of ('refs/(heads|tags|...)/foo', 'foo')
-                # if ref is just '<foo>', then we get ('foo', 'foo')
-                stripped_refs = [
-                    (r, r[r.find('/', r.find('/')+1)+1:])
-                        for r in refs]
-                for h in heads:
-                    r = [pair[0] for pair in stripped_refs if pair[1] == h]
-                    if not r:
-                        raise hgutil.Abort("ref %s not found on remote server" % h)
-                    elif len(r) == 1:
-                        want.append(refs[r[0]])
-                    else:
-                        raise hgutil.Abort("ambiguous reference %s: %r" % (h, r))
-            else:
-                want = [sha for ref, sha in refs.iteritems()
-                        if not ref.endswith('^{}')
-                        and ( ref.startswith('refs/heads/') or ref.startswith('refs/tags/') ) ]
-            want = [x for x in want if x not in self.git]
+            filteredrefs = self.filter_refs(refs, heads)
+            return [x for x in filteredrefs.itervalues() if x not in self.git]
 
-            return want
         try:
             progress = GitProgress(self.ui)
             f = StringIO.StringIO()
@@ -1018,6 +999,34 @@
 
     ## REFERENCES HANDLING
 
+    def filter_refs(self, refs, heads):
+        '''For a dictionary of refs: shas, if heads has any elements then return refs
+        that match the heads. Otherwise, return refs that are heads or tags.
+
+        '''
+        filteredrefs = {}
+        if heads:
+            # contains pairs of ('refs/(heads|tags|...)/foo', 'foo')
+            # if ref is just '<foo>', then we get ('foo', 'foo')
+            stripped_refs = [
+                (r, r[r.find('/', r.find('/')+1)+1:])
+                    for r in refs]
+            for h in heads:
+                r = [pair[0] for pair in stripped_refs if pair[1] == h]
+                if not r:
+                    raise hgutil.Abort("ref %s not found on remote server" % h)
+                elif len(r) == 1:
+                    filteredrefs[r[0]] = refs[r[0]]
+                else:
+                    raise hgutil.Abort("ambiguous reference %s: %r" % (h, r))
+        else:
+            for ref, sha in refs.iteritems():
+                if (not ref.endswith('^{}')
+                    and (ref.startswith('refs/heads/')
+                         or ref.startswith('refs/tags/'))):
+                    filteredrefs[ref] = sha
+        return filteredrefs
+
     def update_references(self):
         heads = self.local_heads()