hg-git changeset 87:babc85201dc4
merge of upstream work from dulwich project
author    Scott Chacon <schacon@gmail.com>
date      Fri, 08 May 2009 16:12:38 -0700
parents   3ce739f2bd7e
children  52b4be85151d
files     dulwich/__init__.py dulwich/_objects.c dulwich/client.py dulwich/errors.py dulwich/index.py dulwich/lru_cache.py dulwich/misc.py dulwich/object_store.py dulwich/objects.py dulwich/pack.py dulwich/protocol.py dulwich/repo.py dulwich/server.py dulwich/tests/test_lru_cache.py unit-tests/topo-test.py
diffstat  15 files changed, 1498 insertions(+), 471 deletions(-)
--- a/dulwich/__init__.py
+++ b/dulwich/__init__.py
@@ -1,6 +1,6 @@
 # __init__.py -- The git module of dulwich
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
-# Copyright (C) 2008 Jelmer Vernooji <jelmer@samba.org>
+# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
 #
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
@@ -18,9 +18,13 @@
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA 02110-1301, USA.
 
+
+"""Python implementation of the Git file formats and protocols."""
+
+
 import client
 import protocol
 import repo
 import server
 
-__version__ = (0, 1, 1)
+__version__ = (0, 2, 2)
--- a/dulwich/_objects.c +++ b/dulwich/_objects.c @@ -47,12 +47,20 @@ return PyString_FromStringAndSize(sha, 20); } -static PyObject *py_sha_to_hex(PyObject *self, PyObject *py_sha) +static PyObject *sha_to_pyhex(const unsigned char *sha) { char hexsha[41]; - unsigned char *sha; int i; + for (i = 0; i < 20; i++) { + hexsha[i*2] = bytehex((sha[i] & 0xF0) >> 4); + hexsha[i*2+1] = bytehex(sha[i] & 0x0F); + } + + return PyString_FromStringAndSize(hexsha, 40); +} +static PyObject *py_sha_to_hex(PyObject *self, PyObject *py_sha) +{ if (!PyString_Check(py_sha)) { PyErr_SetString(PyExc_TypeError, "sha is not a string"); return NULL; @@ -63,18 +71,65 @@ return NULL; } - sha = (unsigned char *)PyString_AsString(py_sha); - for (i = 0; i < 20; i++) { - hexsha[i*2] = bytehex((sha[i] & 0xF0) >> 4); - hexsha[i*2+1] = bytehex(sha[i] & 0x0F); + return sha_to_pyhex((unsigned char *)PyString_AsString(py_sha)); +} + +static PyObject *py_parse_tree(PyObject *self, PyObject *args) +{ + char *text, *end; + int len, namelen; + PyObject *ret, *item, *name; + + if (!PyArg_ParseTuple(args, "s#", &text, &len)) + return NULL; + + ret = PyDict_New(); + if (ret == NULL) { + return NULL; } - - return PyString_FromStringAndSize(hexsha, 40); + + end = text + len; + + while (text < end) { + long mode; + mode = strtol(text, &text, 8); + + if (*text != ' ') { + PyErr_SetString(PyExc_RuntimeError, "Expected space"); + Py_DECREF(ret); + return NULL; + } + + text++; + + namelen = strlen(text); + + name = PyString_FromStringAndSize(text, namelen); + if (name == NULL) { + Py_DECREF(ret); + return NULL; + } + + item = Py_BuildValue("(lN)", mode, sha_to_pyhex((unsigned char *)text+namelen+1)); + if (item == NULL) { + Py_DECREF(ret); + Py_DECREF(name); + return NULL; + } + PyDict_SetItem(ret, name, item); + Py_DECREF(name); + Py_DECREF(item); + + text += namelen+21; + } + + return ret; } static PyMethodDef py_objects_methods[] = { { "hex_to_sha", (PyCFunction)py_hex_to_sha, METH_O, NULL }, { "sha_to_hex", (PyCFunction)py_sha_to_hex, METH_O, NULL }, + { "parse_tree", (PyCFunction)py_parse_tree, METH_VARARGS, NULL, }, }; void init_objects(void)
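The new py_parse_tree fast path parses raw tree data in the format `<octal mode> <name>\0<20-byte binary sha>`, repeated once per entry. For reference, a minimal pure-Python sketch of the same format (the function name here is illustrative, not part of dulwich's API):

    import binascii

    def parse_tree_sketch(text):
        # Each tree entry is: octal mode, space, name, NUL, 20 raw SHA1 bytes.
        ret = {}
        pos = 0
        while pos < len(text):
            space = text.index(' ', pos)
            mode = int(text[pos:space], 8)       # mode is octal, e.g. 100644
            nul = text.index('\0', space + 1)
            name = text[space + 1:nul]
            sha = text[nul + 1:nul + 21]         # raw 20-byte SHA1
            ret[name] = (mode, binascii.hexlify(sha))
            pos = nul + 21
        return ret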
--- a/dulwich/client.py +++ b/dulwich/client.py @@ -1,5 +1,5 @@ -# server.py -- Implementation of the server side git protocols -# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org> +# client.py -- Implementation of the server side git protocols +# Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org> # Copyright (C) 2008 John Carr # # This program is free software; you can redistribute it and/or @@ -25,38 +25,48 @@ import select import socket import subprocess -import copy -import tempfile -from protocol import ( +from dulwich.errors import ( + ChecksumMismatch, + ) +from dulwich.protocol import ( Protocol, TCP_GIT_PORT, extract_capabilities, ) -from pack import ( +from dulwich.pack import ( write_pack_data, ) -from objects import sha_to_hex + def _fileno_can_read(fileno): + """Check if a file descriptor is readable.""" return len(select.select([fileno], [], [], 0)[0]) > 0 class SimpleFetchGraphWalker(object): + """Graph walker that finds out what commits are missing.""" def __init__(self, local_heads, get_parents): + """Create a new SimpleFetchGraphWalker instance. + + :param local_heads: SHA1s that should be retrieved + :param get_parents: Function for finding the parents of a SHA1. + """ self.heads = set(local_heads) self.get_parents = get_parents self.parents = {} - def ack(self, ref): - if ref in self.heads: - self.heads.remove(ref) - if ref in self.parents: - for p in self.parents[ref]: + def ack(self, sha): + """Ack that a particular revision and its ancestors are present in the target.""" + if sha in self.heads: + self.heads.remove(sha) + if sha in self.parents: + for p in self.parents[sha]: self.ack(p) def next(self): + """Iterate over revisions that might be missing in the target.""" if self.heads: ret = self.heads.pop() ps = self.get_parents(ret) @@ -65,8 +75,10 @@ return ret return None + CAPABILITIES = ["multi_ack", "side-band-64k", "ofs-delta"] + class GitClient(object): """Git smart server client. @@ -104,7 +116,7 @@ refs[ref] = sha return refs, server_capabilities - def send_pack(self, path, get_changed_refs, generate_pack_contents): + def send_pack(self, path, determine_wants, generate_pack_contents): """Upload a pack to a remote repository. :param path: Repository path @@ -112,47 +124,37 @@ objects to upload. 
""" refs, server_capabilities = self.read_refs() - changed_refs = get_changed_refs(refs) + changed_refs = determine_wants(refs) if not changed_refs: - print 'nothing changed' self.proto.write_pkt_line(None) - return None - return_refs = copy.copy(changed_refs) - + return {} want = [] have = [] sent_capabilities = False - for changed_ref in changed_refs: + for changed_ref, new_sha1 in changed_refs.iteritems(): + old_sha1 = refs.get(changed_ref, "0" * 40) if sent_capabilities: - self.proto.write_pkt_line("%s %s %s" % changed_ref) + self.proto.write_pkt_line("%s %s %s" % (old_sha1, new_sha1, changed_ref)) else: - self.proto.write_pkt_line("%s %s %s\0%s" % (changed_ref[0], changed_ref[1], changed_ref[2], self.capabilities())) + self.proto.write_pkt_line("%s %s %s\0%s" % (old_sha1, new_sha1, changed_ref, self.capabilities())) sent_capabilities = True - want.append(changed_ref[1]) - if changed_ref[0] != "0"*40: - have.append(changed_ref[0]) + want.append(new_sha1) + if old_sha1 != "0"*40: + have.append(old_sha1) self.proto.write_pkt_line(None) - shas = generate_pack_contents(want, have) - - # write packfile contents to a temp file - (fd, tmppath) = tempfile.mkstemp(suffix=".pack") - f = os.fdopen(fd, 'w') - (entries, sha) = write_pack_data(f, shas, len(shas)) - - # write that temp file to our filehandle - f = open(tmppath, "r") - self.proto.write_file(f) + objects = generate_pack_contents(want, have) + (entries, sha) = write_pack_data(self.proto.write_file(), objects, len(objects)) self.proto.write(sha) - f.close() # read the final confirmation sha - sha = self.proto.read(20) - if sha: - print "CONFIRM: " + sha_to_hex(sha) + client_sha = self.proto.read(20) + if not client_sha in (None, sha): + raise ChecksumMismatch(sha, client_sha) - return return_refs + return changed_refs - def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress): + def fetch_pack(self, path, determine_wants, graph_walker, pack_data, + progress): """Retrieve a pack from a git smart server. 
:param determine_wants: Callback that returns list of commits to fetch @@ -161,11 +163,10 @@ :param progress: Callback for progress reports (strings) """ (refs, server_capabilities) = self.read_refs() - refsreturn = copy.deepcopy(refs) wants = determine_wants(refs) if not wants: self.proto.write_pkt_line(None) - return + return refs self.proto.write_pkt_line("want %s %s\n" % (wants[0], self.capabilities())) for want in wants[1:]: self.proto.write_pkt_line("want %s\n" % want) @@ -198,7 +199,7 @@ progress(pkt) else: raise AssertionError("Invalid sideband channel %d" % channel) - return refsreturn + return refs class TCPGitClient(GitClient): @@ -233,17 +234,19 @@ :param progress: Callback for writing progress """ self.proto.send_cmd("git-upload-pack", path, "host=%s" % self.host) - return super(TCPGitClient, self).fetch_pack(path, determine_wants, graph_walker, pack_data, progress) + return super(TCPGitClient, self).fetch_pack(path, determine_wants, + graph_walker, pack_data, progress) class SubprocessGitClient(GitClient): + """Git client that talks to a server using a subprocess.""" def __init__(self, *args, **kwargs): self.proc = None self._args = args self._kwargs = kwargs - def _connect(self, service, *args): + def _connect(self, service, *args, **kwargs): argv = [service] + list(args) self.proc = subprocess.Popen(argv, bufsize=0, stdin=subprocess.PIPE, @@ -256,13 +259,30 @@ return GitClient(lambda: _fileno_can_read(self.proc.stdout.fileno()), read_fn, write_fn, *args, **kwargs) def send_pack(self, path, changed_refs, generate_pack_contents): + """Upload a pack to the server. + + :param path: Path to the git repository on the server + :param changed_refs: Dictionary with new values for the refs + :param generate_pack_contents: Function that returns an iterator over + objects to send + """ client = self._connect("git-receive-pack", path) return client.send_pack(path, changed_refs, generate_pack_contents) def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress): + """Retrieve a pack from the server + + :param path: Path to the git repository on the server + :param determine_wants: Function that receives existing refs + on the server and returns a list of desired shas + :param graph_walker: GraphWalker instance + :param pack_data: Function that can write pack data + :param progress: Function that can write progress texts + """ client = self._connect("git-upload-pack", path) - return client.fetch_pack(path, determine_wants, graph_walker, pack_data, progress) + return client.fetch_pack(path, determine_wants, graph_walker, pack_data, + progress) class SSHSubprocess(object): @@ -310,13 +330,15 @@ self._args = args self._kwargs = kwargs - def send_pack(self, path, changed_refs, generate_pack_contents): - remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack '%s'" % path], port=self.port) + def send_pack(self, path, determine_wants, generate_pack_contents): + remote = get_ssh_vendor().connect_ssh(self.host, ["git-receive-pack %s" % path], port=self.port) client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs) - return client.send_pack(path, changed_refs, generate_pack_contents) + return client.send_pack(path, determine_wants, generate_pack_contents) - def fetch_pack(self, path, determine_wants, graph_walker, pack_data, progress): - remote = get_ssh_vendor().connect_ssh(self.host, ["git-upload-pack '%s'" % path], port=self.port) + def fetch_pack(self, path, determine_wants, graph_walker, 
pack_data, + progress): + remote = get_ssh_vendor().connect_ssh(self.host, ["git-upload-pack %s" % path], port=self.port) client = GitClient(lambda: _fileno_can_read(remote.proc.stdout.fileno()), remote.recv, remote.send, *self._args, **self._kwargs) - return client.fetch_pack(path, determine_wants, graph_walker, pack_data, progress) + return client.fetch_pack(path, determine_wants, graph_walker, pack_data, + progress)
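Note the API shift in send_pack: the old get_changed_refs callback returned (old, new, ref) tuples, while determine_wants now receives the server's advertised refs and returns a plain {refname: new_sha} dict, with old SHAs looked up from the advertisement. A sketch of a caller under the new signature; the host, SHA1, and empty pack generator are hypothetical stand-ins:

    from dulwich.client import TCPGitClient

    client = TCPGitClient("git.example.org")   # hypothetical host
    new_head = "1" * 40                        # hypothetical commit SHA1

    def determine_wants(server_refs):
        # server_refs is the {refname: sha} dict the remote advertised;
        # return the refs the remote should end up with.
        return {"refs/heads/master": new_head}

    def generate_pack_contents(want, have):
        # Would normally come from the object store; empty for the sketch.
        return []

    # send_pack now returns the changed-refs dict ({} when nothing to
    # push) and raises ChecksumMismatch on a bad confirmation sha.
    changed = client.send_pack("/repo.git", determine_wants,
                               generate_pack_contents)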
--- a/dulwich/errors.py
+++ b/dulwich/errors.py
@@ -1,5 +1,6 @@
 # errors.py -- errors for dulwich
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
+# Copyright (C) 2009 Jelmer Vernooij <jelmer@samba.org>
 #
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
--- a/dulwich/index.py +++ b/dulwich/index.py @@ -1,5 +1,5 @@ # index.py -- File parser/write for the git index file -# Copryight (C) 2008 Jelmer Vernooij <jelmer@samba.org> +# Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org> # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License @@ -18,13 +18,24 @@ """Parser for the git index file format.""" +import os +import stat import struct +from dulwich.objects import ( + Tree, + hex_to_sha, + sha_to_hex, + ) + + def read_cache_time(f): + """Read a cache time.""" return struct.unpack(">LL", f.read(8)) def write_cache_time(f, t): + """Write a cache time.""" if isinstance(t, int): t = (t, 0) f.write(struct.pack(">LL", *t)) @@ -49,7 +60,8 @@ # Padding: real_size = ((f.tell() - beginoffset + 7) & ~7) f.seek(beginoffset + real_size) - return (name, ctime, mtime, ino, dev, mode, uid, gid, size, sha, flags) + return (name, ctime, mtime, ino, dev, mode, uid, gid, size, + sha_to_hex(sha), flags) def write_cache_entry(f, entry): @@ -63,7 +75,7 @@ (name, ctime, mtime, ino, dev, mode, uid, gid, size, sha, flags) = entry write_cache_time(f, ctime) write_cache_time(f, mtime) - f.write(struct.pack(">LLLLLL20sH", ino, dev, mode, uid, gid, size, sha, flags)) + f.write(struct.pack(">LLLLLL20sH", ino, dev, mode, uid, gid, size, hex_to_sha(sha), flags)) f.write(name) f.write(chr(0)) real_size = ((f.tell() - beginoffset + 7) & ~7) @@ -114,14 +126,29 @@ write_index(f, entries_list) +def cleanup_mode(mode): + if stat.S_ISLNK(fsmode): + mode = stat.S_IFLNK + else: + mode = stat.S_IFREG + mode |= (fsmode & 0111) + return mode + + class Index(object): + """A Git Index file.""" def __init__(self, filename): + """Open an index file. + + :param filename: Path to the index file + """ self._filename = filename self.clear() self.read() def write(self): + """Write current contents of index to disk.""" f = open(self._filename, 'w') try: write_index_dict(f, self._byname) @@ -129,24 +156,37 @@ f.close() def read(self): + """Read current contents of index from disk.""" f = open(self._filename, 'r') try: for x in read_index(f): - self[x[0]] = tuple(x[1:]) finally: f.close() def __len__(self): + """Number of entries in this index file.""" return len(self._byname) def __getitem__(self, name): + """Retrieve entry by relative path.""" return self._byname[name] + def __iter__(self): + """Iterate over the paths in this index.""" + return iter(self._byname) + def get_sha1(self, path): + """Return the (git object) SHA1 for the object at a path.""" return self[path][-2] + def iterblobs(self): + """Iterate over path, sha, mode tuples for use with commit_tree.""" + for path, entry in self: + yield path, entry[-2], cleanup_mode(entry[-6]) + def clear(self): + """Remove all contents from this index.""" self._byname = {} def __setitem__(self, name, x): @@ -161,3 +201,45 @@ def update(self, entries): for name, value in entries.iteritems(): self[name] = value + + +def commit_tree(object_store, blobs): + """Commit a new tree. + + :param object_store: Object store to add trees to + :param blobs: Iterable over blob path, sha, mode entries + :return: SHA1 of the created tree. 
+ """ + trees = {"": {}} + def add_tree(path): + if path in trees: + return trees[path] + dirname, basename = os.path.split(path) + t = add_tree(dirname) + assert isinstance(basename, str) + newtree = {} + t[basename] = newtree + trees[path] = newtree + return newtree + + for path, sha, mode in blobs: + tree_path, basename = os.path.split(path) + tree = add_tree(tree_path) + tree[basename] = (mode, sha) + + def build_tree(path): + tree = Tree() + for basename, entry in trees[path].iteritems(): + if type(entry) == dict: + mode = stat.S_IFDIR + sha = build_tree(os.path.join(path, basename)) + else: + (mode, sha) = entry + tree.add(mode, basename, sha) + object_store.add_object(tree) + return tree.id + return build_tree("") + + +def commit_index(object_store, index): + return commit_tree(object_store, index.blobs())
--- a/dulwich/lru_cache.py +++ b/dulwich/lru_cache.py @@ -12,11 +12,42 @@ # # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software -# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA """A simple least-recently-used (LRU) cache.""" -from collections import deque +_null_key = object() + +class _LRUNode(object): + """This maintains the linked-list which is the lru internals.""" + + __slots__ = ('prev', 'next_key', 'key', 'value', 'cleanup', 'size') + + def __init__(self, key, value, cleanup=None): + self.prev = None + self.next_key = _null_key + self.key = key + self.value = value + self.cleanup = cleanup + # TODO: We could compute this 'on-the-fly' like we used to, and remove + # one pointer from this object, we just need to decide if it + # actually costs us much of anything in normal usage + self.size = None + + def __repr__(self): + if self.prev is None: + prev_key = None + else: + prev_key = self.prev.key + return '%s(%r n:%r p:%r)' % (self.__class__.__name__, self.key, + self.next_key, prev_key) + + def run_cleanup(self): + if self.cleanup is not None: + self.cleanup(self.key, self.value) + self.cleanup = None + # Just make sure to break any refcycles, etc + self.value = None class LRUCache(object): @@ -24,48 +55,117 @@ def __init__(self, max_cache=100, after_cleanup_count=None): self._cache = {} - self._cleanup = {} - self._queue = deque() # Track when things are accessed - self._refcount = {} # number of entries in self._queue for each key + # The "HEAD" of the lru linked list + self._most_recently_used = None + # The "TAIL" of the lru linked list + self._least_recently_used = None self._update_max_cache(max_cache, after_cleanup_count) def __contains__(self, key): return key in self._cache def __getitem__(self, key): - val = self._cache[key] - self._record_access(key) - return val + cache = self._cache + node = cache[key] + # Inlined from _record_access to decrease the overhead of __getitem__ + # We also have more knowledge about structure if __getitem__ is + # succeeding, then we know that self._most_recently_used must not be + # None, etc. + mru = self._most_recently_used + if node is mru: + # Nothing to do, this node is already at the head of the queue + return node.value + # Remove this node from the old location + node_prev = node.prev + next_key = node.next_key + # benchmarking shows that the lookup of _null_key in globals is faster + # than the attribute lookup for (node is self._least_recently_used) + if next_key is _null_key: + # 'node' is the _least_recently_used, because it doesn't have a + # 'next' item. So move the current lru to the previous node. 
+ self._least_recently_used = node_prev + else: + node_next = cache[next_key] + node_next.prev = node_prev + node_prev.next_key = next_key + # Insert this node at the front of the list + node.next_key = mru.key + mru.prev = node + self._most_recently_used = node + node.prev = None + return node.value def __len__(self): return len(self._cache) + def _walk_lru(self): + """Walk the LRU list, only meant to be used in tests.""" + node = self._most_recently_used + if node is not None: + if node.prev is not None: + raise AssertionError('the _most_recently_used entry is not' + ' supposed to have a previous entry' + ' %s' % (node,)) + while node is not None: + if node.next_key is _null_key: + if node is not self._least_recently_used: + raise AssertionError('only the last node should have' + ' no next value: %s' % (node,)) + node_next = None + else: + node_next = self._cache[node.next_key] + if node_next.prev is not node: + raise AssertionError('inconsistency found, node.next.prev' + ' != node: %s' % (node,)) + if node.prev is None: + if node is not self._most_recently_used: + raise AssertionError('only the _most_recently_used should' + ' not have a previous node: %s' + % (node,)) + else: + if node.prev.next_key != node.key: + raise AssertionError('inconsistency found, node.prev.next' + ' != node: %s' % (node,)) + yield node + node = node_next + def add(self, key, value, cleanup=None): """Add a new value to the cache. - Also, if the entry is ever removed from the queue, call cleanup. - Passing it the key and value being removed. + Also, if the entry is ever removed from the cache, call + cleanup(key, value). :param key: The key to store it under :param value: The object to store :param cleanup: None or a function taking (key, value) to indicate - 'value' sohuld be cleaned up. + 'value' should be cleaned up. """ + if key is _null_key: + raise ValueError('cannot use _null_key as a key') if key in self._cache: - self._remove(key) - self._cache[key] = value - if cleanup is not None: - self._cleanup[key] = cleanup - self._record_access(key) + node = self._cache[key] + node.run_cleanup() + node.value = value + node.cleanup = cleanup + else: + node = _LRUNode(key, value, cleanup=cleanup) + self._cache[key] = node + self._record_access(node) if len(self._cache) > self._max_cache: # Trigger the cleanup self.cleanup() + def cache_size(self): + """Get the number of entries we will cache.""" + return self._max_cache + def get(self, key, default=None): - if key in self._cache: - return self[key] - return default + node = self._cache.get(key, None) + if node is None: + return default + self._record_access(node) + return node.value def keys(self): """Get the list of keys currently cached. @@ -78,6 +178,10 @@ """ return self._cache.keys() + def items(self): + """Get the key:value pairs as a dict.""" + return dict((k, n.value) for k, n in self._cache.iteritems()) + def cleanup(self): """Clear the cache until it shrinks to the requested size. 
@@ -87,45 +191,54 @@ # Make sure the cache is shrunk to the correct size while len(self._cache) > self._after_cleanup_count: self._remove_lru() - # No need to compact the queue at this point, because the code that - # calls this would have already triggered it based on queue length def __setitem__(self, key, value): """Add a value to the cache, there will be no cleanup function.""" self.add(key, value, cleanup=None) - def _record_access(self, key): + def _record_access(self, node): """Record that key was accessed.""" - self._queue.append(key) - # Can't use setdefault because you can't += 1 the result - self._refcount[key] = self._refcount.get(key, 0) + 1 - - # If our access queue is too large, clean it up too - if len(self._queue) > self._compact_queue_length: - self._compact_queue() + # Move 'node' to the front of the queue + if self._most_recently_used is None: + self._most_recently_used = node + self._least_recently_used = node + return + elif node is self._most_recently_used: + # Nothing to do, this node is already at the head of the queue + return + # We've taken care of the tail pointer, remove the node, and insert it + # at the front + # REMOVE + if node is self._least_recently_used: + self._least_recently_used = node.prev + if node.prev is not None: + node.prev.next_key = node.next_key + if node.next_key is not _null_key: + node_next = self._cache[node.next_key] + node_next.prev = node.prev + # INSERT + node.next_key = self._most_recently_used.key + self._most_recently_used.prev = node + self._most_recently_used = node + node.prev = None - def _compact_queue(self): - """Compact the queue, leaving things in sorted last appended order.""" - new_queue = deque() - for item in self._queue: - if self._refcount[item] == 1: - new_queue.append(item) - else: - self._refcount[item] -= 1 - self._queue = new_queue - # All entries should be of the same size. There should be one entry in - # queue for each entry in cache, and all refcounts should == 1 - if not (len(self._queue) == len(self._cache) == - len(self._refcount) == sum(self._refcount.itervalues())): - raise AssertionError() - - def _remove(self, key): - """Remove an entry, making sure to maintain the invariants.""" - cleanup = self._cleanup.pop(key, None) - val = self._cache.pop(key) - if cleanup is not None: - cleanup(key, val) - return val + def _remove_node(self, node): + if node is self._least_recently_used: + self._least_recently_used = node.prev + self._cache.pop(node.key) + # If we have removed all entries, remove the head pointer as well + if self._least_recently_used is None: + self._most_recently_used = None + node.run_cleanup() + # Now remove this node from the linked list + if node.prev is not None: + node.prev.next_key = node.next_key + if node.next_key is not _null_key: + node_next = self._cache[node.next_key] + node_next.prev = node.prev + # And remove this node's pointers + node.prev = None + node.next_key = _null_key def _remove_lru(self): """Remove one entry from the lru, and handle consequences. @@ -133,11 +246,7 @@ If there are no more references to the lru, then this entry should be removed from the cache. 
""" - key = self._queue.popleft() - self._refcount[key] -= 1 - if not self._refcount[key]: - del self._refcount[key] - self._remove(key) + self._remove_node(self._least_recently_used) def clear(self): """Clear out all of the cache.""" @@ -155,11 +264,8 @@ if after_cleanup_count is None: self._after_cleanup_count = self._max_cache * 8 / 10 else: - self._after_cleanup_count = min(after_cleanup_count, self._max_cache) - - self._compact_queue_length = 4*self._max_cache - if len(self._queue) > self._compact_queue_length: - self._compact_queue() + self._after_cleanup_count = min(after_cleanup_count, + self._max_cache) self.cleanup() @@ -169,7 +275,8 @@ This differs in that it doesn't care how many actual items there are, it just restricts the cache to be cleaned up after so much data is stored. - The values that are added must support len(value). + The size of items added will be computed using compute_size(value), which + defaults to len() if not supplied. """ def __init__(self, max_size=1024*1024, after_cleanup_size=None, @@ -191,33 +298,41 @@ self._compute_size = compute_size if compute_size is None: self._compute_size = len - # This approximates that texts are > 0.5k in size. It only really - # effects when we clean up the queue, so we don't want it to be too - # large. self._update_max_size(max_size, after_cleanup_size=after_cleanup_size) LRUCache.__init__(self, max_cache=max(int(max_size/512), 1)) def add(self, key, value, cleanup=None): """Add a new value to the cache. - Also, if the entry is ever removed from the queue, call cleanup. - Passing it the key and value being removed. + Also, if the entry is ever removed from the cache, call + cleanup(key, value). :param key: The key to store it under :param value: The object to store :param cleanup: None or a function taking (key, value) to indicate - 'value' sohuld be cleaned up. + 'value' should be cleaned up. """ - if key in self._cache: - self._remove(key) + if key is _null_key: + raise ValueError('cannot use _null_key as a key') + node = self._cache.get(key, None) value_len = self._compute_size(value) if value_len >= self._after_cleanup_size: + # The new value is 'too big to fit', as it would fill up/overflow + # the cache all by itself + if node is not None: + # We won't be replacing the old node, so just remove it + self._remove_node(node) + if cleanup is not None: + cleanup(key, value) return + if node is None: + node = _LRUNode(key, value, cleanup=cleanup) + self._cache[key] = node + else: + self._value_size -= node.size + node.size = value_len self._value_size += value_len - self._cache[key] = value - if cleanup is not None: - self._cleanup[key] = cleanup - self._record_access(key) + self._record_access(node) if self._value_size > self._max_size: # Time to cleanup @@ -233,10 +348,9 @@ while self._value_size > self._after_cleanup_size: self._remove_lru() - def _remove(self, key): - """Remove an entry, making sure to maintain the invariants.""" - val = LRUCache._remove(self, key) - self._value_size -= self._compute_size(val) + def _remove_node(self, node): + self._value_size -= node.size + LRUCache._remove_node(self, node) def resize(self, max_size, after_cleanup_size=None): """Change the number of bytes that will be cached."""
--- a/dulwich/misc.py
+++ b/dulwich/misc.py
@@ -20,19 +20,12 @@
 These utilities can all be deleted when dulwich decides it wants to stop
 support for python 2.4.
 """
-
-from mercurial import demandimport
-import __builtin__
-orig_import = __builtin__.__import__
-demandimport.disable()
-
 try:
     import hashlib
 except ImportError:
     import sha
 import struct
 
-__builtin__.__import__ = orig_import
 
 class defaultdict(dict):
     """A python 2.4 equivalent of collections.defaultdict."""
--- a/dulwich/object_store.py +++ b/dulwich/object_store.py @@ -1,5 +1,5 @@ # object_store.py -- Object store for git objects -# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org> +# Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org> # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License @@ -16,7 +16,13 @@ # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, # MA 02110-1301, USA. + +"""Git object store interfaces and implementation.""" + + +import itertools import os +import stat import tempfile import urllib2 @@ -25,6 +31,7 @@ ) from objects import ( ShaFile, + Tag, Tree, hex_to_sha, sha_to_hex, @@ -42,17 +49,9 @@ PACKDIR = 'pack' -class ObjectStore(object): - """Object store.""" - def __init__(self, path): - """Open an object store. - - :param path: Path of the object store. - """ - self.path = path - self._pack_cache = None - self.pack_dir = os.path.join(self.path, PACKDIR) +class BaseObjectStore(object): + """Object store interface.""" def determine_wants_all(self, refs): return [sha for (ref, sha) in refs.iteritems() if not sha in self and not ref.endswith("^{}")] @@ -65,6 +64,67 @@ return ObjectStoreIterator(self, shas) def __contains__(self, sha): + """Check if a particular object is present by SHA1.""" + raise NotImplementedError(self.__contains__) + + def get_raw(self, name): + """Obtain the raw text for an object. + + :param name: sha for the object. + :return: tuple with object type and object contents. + """ + raise NotImplementedError(self.get_raw) + + def __getitem__(self, sha): + """Obtain an object by SHA1.""" + type, uncomp = self.get_raw(sha) + return ShaFile.from_raw_string(type, uncomp) + + def __iter__(self): + """Iterate over the SHAs that are present in this store.""" + raise NotImplementedError(self.__iter__) + + def add_object(self, obj): + """Add a single object to this object store. + + """ + raise NotImplementedError(self.add_object) + + def add_objects(self, objects): + """Add a set of objects to this object store. + + :param objects: Iterable over a list of objects. + """ + raise NotImplementedError(self.add_objects) + + def find_missing_objects(self, wants, graph_walker, progress=None): + """Find the missing objects required for a set of revisions. + + :param wants: Iterable over SHAs of objects to fetch. + :param graph_walker: Object that can iterate over the list of revisions + to fetch and has an "ack" method that will be called to acknowledge + that a revision is present. + :param progress: Simple progress function that will be called with + updated progress strings. + :return: Iterator over (sha, path) pairs. + """ + return iter(MissingObjectFinder(self, wants, graph_walker, progress).next, None) + + +class DiskObjectStore(BaseObjectStore): + """Git-style object store that exists on disk.""" + + def __init__(self, path): + """Open an object store. + + :param path: Path of the object store. 
+ """ + self.path = path + self._pack_cache = None + self.pack_dir = os.path.join(self.path, PACKDIR) + + def __contains__(self, sha): + """Check if a particular object is present by SHA1.""" for pack in self.packs: if sha in pack: return True @@ -73,6 +133,11 @@ return True return False + def __iter__(self): + """Iterate over the SHAs that are present in this store.""" + iterables = self.packs + [self._iter_shafile_shas()] + return itertools.chain(*iterables) + @property def packs(self): """List with pack objects.""" @@ -93,6 +158,13 @@ # Check from object dir return os.path.join(self.path, dir, file) + def _iter_shafile_shas(self): + for base in os.listdir(self.path): + if len(base) != 2: + continue + for rest in os.listdir(os.path.join(self.path, base)): + yield base+rest + def _get_shafile(self, sha): path = self._get_shafile_path(sha) if os.path.exists(path): @@ -133,13 +205,9 @@ hexsha = sha_to_hex(name) ret = self._get_shafile(hexsha) if ret is not None: - return ret.as_raw_string() + return ret.type, ret.as_raw_string() raise KeyError(hexsha) - def __getitem__(self, sha): - type, uncomp = self.get_raw(sha) - return ShaFile.from_raw_string(type, uncomp) - def move_in_thin_pack(self, path): """Move a specific file containing a pack into the pack directory. @@ -176,7 +244,7 @@ :param path: Path to the pack file. """ p = PackData(path) - entries = p.sorted_entries(self.get_raw) + entries = p.sorted_entries() basename = os.path.join(self.pack_dir, "pack-%s" % iter_sha1(entry[0] for entry in entries)) write_pack_index_v2(basename+".idx", entries, p.get_stored_checksum()) @@ -207,13 +275,16 @@ fd, path = tempfile.mkstemp(dir=self.pack_dir, suffix=".pack") f = os.fdopen(fd, 'w') def commit(): - #os.fsync(fd) - #f.close() + os.fsync(fd) + f.close() if os.path.getsize(path) > 0: self.move_in_pack(path) return f, commit def add_object(self, obj): + """Add a single object to this object store. + + """ self._add_shafile(obj.id, obj) def add_objects(self, objects): @@ -228,6 +299,45 @@ commit() +class MemoryObjectStore(BaseObjectStore): + + def __init__(self): + super(MemoryObjectStore, self).__init__() + self._data = {} + + def __contains__(self, sha): + return sha in self._data + + def __iter__(self): + """Iterate over the SHAs that are present in this store.""" + return self._data.iterkeys() + + def get_raw(self, name): + """Obtain the raw text for an object. + + :param name: sha for the object. + :return: tuple with object type and object contents. + """ + return self[name].as_raw_string() + + def __getitem__(self, name): + return self._data[name] + + def add_object(self, obj): + """Add a single object to this object store. + + """ + self._data[obj.id] = obj + + def add_objects(self, objects): + """Add a set of objects to this object store. + + :param objects: Iterable over a list of objects. + """ + for obj in objects: + self._data[obj.id] = obj + + class ObjectImporter(object): """Interface for importing objects.""" @@ -294,6 +404,12 @@ def tree_lookup_path(lookup_obj, root_sha, path): + """Lookup an object in a Git tree. + + :param lookup_obj: Callback for retrieving object by SHA1 + :param root_sha: SHA1 of the root tree + :param path: Path to lookup + """ parts = path.split("/") sha = root_sha for p in parts: @@ -304,3 +420,58 @@ continue mode, sha = obj[p] return lookup_obj(sha) + + +class MissingObjectFinder(object): + """Find the objects missing from another object store. 
+ + :param object_store: Object store containing at least all objects to be + sent + :param wants: SHA1s of commits to send + :param graph_walker: graph walker object used to see what the remote + repo has and misses + :param progress: Optional function to report progress to. + """ + + def __init__(self, object_store, wants, graph_walker, progress=None): + self.sha_done = set() + self.objects_to_send = set([(w, None, False) for w in wants]) + self.object_store = object_store + if progress is None: + self.progress = lambda x: None + else: + self.progress = progress + ref = graph_walker.next() + while ref: + if ref in self.object_store: + graph_walker.ack(ref) + ref = graph_walker.next() + + def add_todo(self, entries): + self.objects_to_send.update([e for e in entries if not e[0] in self.sha_done]) + + def parse_tree(self, tree): + self.add_todo([(sha, name, not stat.S_ISDIR(mode)) for (mode, name, sha) in tree.entries()]) + + def parse_commit(self, commit): + self.add_todo([(commit.tree, "", False)]) + self.add_todo([(p, None, False) for p in commit.parents]) + + def parse_tag(self, tag): + self.add_todo([(tag.object[1], None, False)]) + + def next(self): + if not self.objects_to_send: + return None + (sha, name, leaf) = self.objects_to_send.pop() + if not leaf: + o = self.object_store[sha] + if isinstance(o, Commit): + self.parse_commit(o) + elif isinstance(o, Tree): + self.parse_tree(o) + elif isinstance(o, Tag): + self.parse_tag(o) + self.sha_done.add(sha) + self.progress("counting objects: %d\r" % len(self.sha_done)) + return (sha, name)
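ObjectStore is split into a BaseObjectStore interface with two implementations: the on-disk store (now DiskObjectStore) and a new MemoryObjectStore, which is convenient for tests. A small sketch of the shared interface:

    from dulwich.objects import Blob
    from dulwich.object_store import MemoryObjectStore

    store = MemoryObjectStore()
    blob = Blob.from_string("some file contents\n")
    store.add_object(blob)

    assert blob.id in store            # BaseObjectStore.__contains__
    assert list(store) == [blob.id]    # __iter__ yields stored SHA1s
    assert store[blob.id] is blob      # lookup by SHA1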
--- a/dulwich/objects.py +++ b/dulwich/objects.py @@ -1,17 +1,17 @@ # objects.py -- Access to base git objects # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net> -# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org> -# +# Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org> +# # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License # as published by the Free Software Foundation; version 2 # of the License or (at your option) a later version of the License. -# +# # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU General Public License for more details. -# +# # You should have received a copy of the GNU General Public License # along with this program; if not, write to the Free Software # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, @@ -23,6 +23,8 @@ import mmap import os +import sha +import stat import zlib from errors import ( @@ -70,9 +72,20 @@ return ''.join([chr(int(hex[i:i+2], 16)) for i in xrange(0, len(hex), 2)]) +def serializable_property(name, docstring=None): + def set(obj, value): + obj._ensure_parsed() + setattr(obj, "_"+name, value) + obj._needs_serialization = True + def get(obj): + obj._ensure_parsed() + return getattr(obj, "_"+name) + return property(get, set, doc=docstring) + + class ShaFile(object): """A git SHA file.""" - + @classmethod def _parse_legacy_object(cls, map): """Parse a legacy object, creating it and setting object._text""" @@ -97,15 +110,31 @@ object._size = size assert text[0] == "\0", "Size not followed by null" text = text[1:] - object._text = text + object.set_raw_string(text) return object def as_legacy_object(self): - return zlib.compress("%s %d\0%s" % (self._type, len(self._text), self._text)) - + text = self.as_raw_string() + return zlib.compress("%s %d\0%s" % (self._type, len(text), text)) + def as_raw_string(self): - return self._num_type, self._text + if self._needs_serialization: + self.serialize() + return self._text + + def as_pretty_string(self): + return self.as_raw_string() + def _ensure_parsed(self): + if self._needs_parsing: + self._parse_text() + + def set_raw_string(self, text): + self._text = text + self._sha = None + self._needs_parsing = True + self._needs_serialization = False + @classmethod def _parse_object(cls, map): """Parse a new style object , creating it and setting object._text""" @@ -121,9 +150,9 @@ byte = ord(map[used]) used += 1 raw = map[used:] - object._text = _decompress(raw) + object.set_raw_string(_decompress(raw)) return object - + @classmethod def _parse_file(cls, map): word = (ord(map[0]) << 8) + ord(map[1]) @@ -131,13 +160,14 @@ return cls._parse_legacy_object(map) else: return cls._parse_object(map) - + def __init__(self): """Don't call this directly""" - + self._sha = None + def _parse_text(self): """For subclasses to do initialisation time parsing""" - + @classmethod def from_file(cls, filename): """Get the contents of a SHA file on disk""" @@ -146,49 +176,52 @@ try: map = mmap.mmap(f.fileno(), size, access=mmap.ACCESS_READ) shafile = cls._parse_file(map) - shafile._parse_text() return shafile finally: f.close() - + @classmethod def from_raw_string(cls, type, string): """Creates an object of the indicated type from the raw string given. - + Type is the numeric type of an object. String is the raw uncompressed contents. 
""" real_class = num_type_map[type] obj = real_class() - obj._num_type = type - obj._text = string - obj._parse_text() + obj.type = type + obj.set_raw_string(string) return obj - + def _header(self): - return "%s %lu\0" % (self._type, len(self._text)) - + return "%s %lu\0" % (self._type, len(self.as_raw_string())) + def sha(self): """The SHA1 object that is the name of this object.""" - ressha = make_sha() - ressha.update(self._header()) - ressha.update(self._text) - return ressha - + if self._needs_serialization or self._sha is None: + self._sha = make_sha() + self._sha.update(self._header()) + self._sha.update(self.as_raw_string()) + return self._sha + @property def id(self): return self.sha().hexdigest() - - @property - def type(self): + + def get_type(self): return self._num_type + def set_type(self, type): + self._num_type = type + + type = property(get_type, set_type) + def __repr__(self): return "<%s %s>" % (self.__class__.__name__, self.id) - + def __eq__(self, other): """Return true id the sha of the two objects match. - + The __le__ etc methods aren't overriden as they make no sense, certainly at this level. """ @@ -200,12 +233,18 @@ _type = BLOB_ID _num_type = 3 + _needs_serialization = False + _needs_parsing = False - @property - def data(self): - """The text contained within the blob object.""" + def get_data(self): return self._text + def set_data(self, data): + self._text = data + + data = property(get_data, set_data, + "The text contained within the blob object.") + @classmethod def from_file(cls, filename): blob = ShaFile.from_file(filename) @@ -217,7 +256,7 @@ def from_string(cls, string): """Create a blob from a string.""" shafile = cls() - shafile._text = string + shafile.set_raw_string(string) return shafile @@ -238,7 +277,7 @@ def from_string(cls, string): """Create a blob from a string.""" shafile = cls() - shafile._text = string + shafile.set_raw_string(string) return shafile def _parse_text(self): @@ -306,33 +345,48 @@ assert text[count] == '\n', "There must be a new line after the headers" count += 1 self._message = text[count:] + self._needs_parsing = False - @property - def object(self): + def get_object(self): """Returns the object pointed by this tag, represented as a tuple(type, sha)""" + self._ensure_parsed() return (self._object_type, self._object_sha) - @property - def name(self): - """Returns the name of this tag""" - return self._name + object = property(get_object) - @property - def tagger(self): - """Returns the name of the person who created this tag""" - return self._tagger + name = serializable_property("name", "The name of this tag") + tagger = serializable_property("tagger", + "Returns the name of the person who created this tag") + tag_time = serializable_property("tag_time", + "The creation timestamp of the tag. As the number of seconds since the epoch") + message = serializable_property("message", "The message attached to this tag") + - @property - def tag_time(self): - """Returns the creation timestamp of the tag. 
- - Returns it as the number of seconds since the epoch""" - return self._tag_time - - @property - def message(self): - """Returns the message attached to this tag""" - return self._message +def parse_tree(text): + ret = {} + count = 0 + while count < len(text): + mode = 0 + chr = text[count] + while chr != ' ': + assert chr >= '0' and chr <= '7', "%s is not a valid mode char" % chr + mode = (mode << 3) + (ord(chr) - ord('0')) + count += 1 + chr = text[count] + count += 1 + chr = text[count] + name = '' + while chr != '\0': + name += chr + count += 1 + chr = text[count] + count += 1 + chr = text[count] + sha = text[count:count+20] + hexsha = sha_to_hex(sha) + ret[name] = (mode, hexsha) + count = count + 20 + return ret class Tree(ShaFile): @@ -342,7 +396,10 @@ _num_type = 2 def __init__(self): + super(Tree, self).__init__() self._entries = {} + self._needs_parsing = False + self._needs_serialization = True @classmethod def from_file(cls, filename): @@ -351,63 +408,78 @@ raise NotTreeError(filename) return tree + def __contains__(self, name): + self._ensure_parsed() + return name in self._entries + def __getitem__(self, name): + self._ensure_parsed() return self._entries[name] def __setitem__(self, name, value): assert isinstance(value, tuple) assert len(value) == 2 + self._ensure_parsed() self._entries[name] = value + self._needs_serialization = True def __delitem__(self, name): + self._ensure_parsed() del self._entries[name] + self._needs_serialization = True def add(self, mode, name, hexsha): + assert type(mode) == int + assert type(name) == str + assert type(hexsha) == str + self._ensure_parsed() self._entries[name] = mode, hexsha + self._needs_serialization = True def entries(self): """Return a list of tuples describing the tree entries""" - return [(mode, name, hexsha) for (name, (mode, hexsha)) in self._entries.iteritems()] - - def entry(self, name): - try: - return self._entries[name] - except: - return (None, None) + self._ensure_parsed() + # The order of this is different from iteritems() for historical reasons + return [(mode, name, hexsha) for (name, mode, hexsha) in self.iteritems()] def iteritems(self): + self._ensure_parsed() for name in sorted(self._entries.keys()): - yield name, self_entries[name][0], self._entries[name][1] + yield name, self._entries[name][0], self._entries[name][1] def _parse_text(self): """Grab the entries in the tree""" - count = 0 - while count < len(self._text): - mode = 0 - chr = self._text[count] - while chr != ' ': - assert chr >= '0' and chr <= '7', "%s is not a valid mode char" % chr - mode = (mode << 3) + (ord(chr) - ord('0')) - count += 1 - chr = self._text[count] - count += 1 - chr = self._text[count] - name = '' - while chr != '\0': - name += chr - count += 1 - chr = self._text[count] - count += 1 - chr = self._text[count] - sha = self._text[count:count+20] - hexsha = sha_to_hex(sha) - self.add(mode, name, hexsha) - count = count + 20 + self._entries = parse_tree(self._text) + self._needs_parsing = False def serialize(self): self._text = "" for name, mode, hexsha in self.iteritems(): self._text += "%04o %s\0%s" % (mode, name, hex_to_sha(hexsha)) + self._needs_serialization = False + + def as_pretty_string(self): + text = "" + for name, mode, hexsha in self.iteritems(): + if mode & stat.S_IFDIR: + kind = "tree" + else: + kind = "blob" + text += "%04o %s %s\t%s\n" % (mode, kind, hexsha, name) + return text + + +def parse_timezone(text): + offset = int(text) + hours = int(offset / 100) + minutes = (offset % 100) + return (hours * 3600) + 
(minutes * 60) + + +def format_timezone(offset): + if offset % 60 != 0: + raise ValueError("Unable to handle non-minute offset.") + return '%+03d%02d' % (offset / 3600, (offset / 60) % 60) class Commit(ShaFile): @@ -417,7 +489,10 @@ _num_type = 1 def __init__(self): + super(Commit, self).__init__() self._parents = [] + self._needs_parsing = False + self._needs_serialization = True @classmethod def from_file(cls, filename): @@ -469,8 +544,9 @@ count += 1 self._author_time = int(text[count:].split(" ", 1)[0]) while text[count] != ' ': + assert text[count] != '\n', "Malformed author information" count += 1 - self._author_timezone = int(text[count:count+6]) + self._author_timezone = parse_timezone(text[count:count+6]) count += 1 while text[count] != '\n': count += 1 @@ -493,8 +569,9 @@ count += 1 self._commit_time = int(text[count:count+10]) while text[count] != ' ': + assert text[count] != '\n', "Malformed committer information" count += 1 - self._commit_timezone = int(text[count:count+6]) + self._commit_timezone = parse_timezone(text[count:count+6]) count += 1 while text[count] != '\n': count += 1 @@ -503,69 +580,54 @@ count += 1 # XXX: There can be an encoding field. self._message = text[count:] + self._needs_parsing = False def serialize(self): self._text = "" self._text += "%s %s\n" % (TREE_ID, self._tree) for p in self._parents: self._text += "%s %s\n" % (PARENT_ID, p) - self._text += "%s %s %s %+05d\n" % (AUTHOR_ID, self._author, str(self._author_time), self._author_timezone) - self._text += "%s %s %s %+05d\n" % (COMMITTER_ID, self._committer, str(self._commit_time), self._commit_timezone) + self._text += "%s %s %s %s\n" % (AUTHOR_ID, self._author, str(self._author_time), format_timezone(self._author_timezone)) + self._text += "%s %s %s %s\n" % (COMMITTER_ID, self._committer, str(self._commit_time), format_timezone(self._commit_timezone)) self._text += "\n" # There must be a new line after the headers self._text += self._message + self._needs_serialization = False - @property - def tree(self): - """Returns the tree that is the state of this commit""" - return self._tree + tree = serializable_property("tree", "Tree that is the state of this commit") - @property - def parents(self): + def get_parents(self): """Return a list of parents of this commit.""" + self._ensure_parsed() return self._parents - @property - def author(self): - """Returns the name of the author of the commit""" - return self._author + def set_parents(self, value): + """Return a list of parents of this commit.""" + self._ensure_parsed() + self._needs_serialization = True + self._parents = value - @property - def committer(self): - """Returns the name of the committer of the commit""" - return self._committer + parents = property(get_parents, set_parents) - @property - def message(self): - """Returns the commit message""" - return self._message + author = serializable_property("author", + "The name of the author of the commit") - @property - def commit_time(self): - """Returns the timestamp of the commit. + committer = serializable_property("committer", + "The name of the committer of the commit") - Returns it as the number of seconds since the epoch. - """ - return self._commit_time + message = serializable_property("message", + "The commit message") - @property - def commit_timezone(self): - """Returns the zone the commit time is in - """ - return self._commit_timezone + commit_time = serializable_property("commit_time", + "The timestamp of the commit. 
As the number of seconds since the epoch.") - @property - def author_time(self): - """Returns the timestamp the commit was written. + commit_timezone = serializable_property("commit_timezone", + "The zone the commit time is in") - Returns it as the number of seconds since the epoch. - """ - return self._author_time + author_time = serializable_property("author_time", + "The timestamp the commit was written. as the number of seconds since the epoch.") - @property - def author_timezone(self): - """Returns the zone the author time is in - """ - return self._author_timezone + author_timezone = serializable_property("author_timezone", + "Returns the zone the author time is in.") type_map = { @@ -586,6 +648,7 @@ try: # Try to import C versions - from _objects import hex_to_sha, sha_to_hex + from dulwich._objects import hex_to_sha, sha_to_hex, parse_tree except ImportError: pass +
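ShaFile subclasses now parse and serialize lazily: a serializable_property setter only flags _needs_serialization, and the object text (and SHA) is rebuilt on demand. A sketch of building a commit through the new mutable properties; the tree SHA below is a placeholder:

    from dulwich.objects import Commit, parse_timezone

    c = Commit()
    c.tree = "0" * 40                  # placeholder tree SHA1
    c.author = c.committer = "Scott Chacon <schacon@gmail.com>"
    c.author_time = c.commit_time = 1241824358
    # parse_timezone("-0700") == -25200 seconds offset from UTC
    c.author_timezone = c.commit_timezone = parse_timezone("-0700")
    c.message = "merge of upstream work from dulwich project"
    sha = c.id                         # serialize() runs here, on demand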
--- a/dulwich/pack.py +++ b/dulwich/pack.py @@ -1,6 +1,6 @@ # pack.py -- For dealing wih packed git objects. # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net> -# Copryight (C) 2008 Jelmer Vernooij <jelmer@samba.org> +# Copryight (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org> # # This program is free software; you can redistribute it and/or # modify it under the terms of the GNU General Public License @@ -80,7 +80,7 @@ return ret -def read_zlib(data, offset, dec_size): +def read_zlib_chunks(data, offset, dec_size): obj = zlib.decompressobj() ret = [] fed = 0 @@ -91,14 +91,23 @@ add += "Z" fed += len(add) ret.append(obj.decompress(add)) + comp_len = fed-len(obj.unused_data) + return ret, comp_len + + +def read_zlib(data, offset, dec_size): + ret, comp_len = read_zlib_chunks(data, offset, dec_size) x = "".join(ret) assert len(x) == dec_size - comp_len = fed-len(obj.unused_data) return x, comp_len def iter_sha1(iter): - """Return the hexdigest of the SHA1 over a set of names.""" + """Return the hexdigest of the SHA1 over a set of names. + + :param iter: Iterator over string objects + :return: 40-byte hex sha1 digest + """ sha1 = make_sha() for name in iter: sha1.update(name) @@ -114,14 +123,9 @@ :param access: Access mechanism. :return: MMAP'd area. """ - #print f, offset, size - if supports_mmap_offset: - mem = mmap.mmap(f.fileno(), size + offset % mmap.PAGESIZE, access=access, - offset=offset / mmap.PAGESIZE * mmap.PAGESIZE) - return mem, offset % mmap.PAGESIZE - else: - mem = mmap.mmap(f.fileno(), size+offset, access=access) - return mem, offset + mem = mmap.mmap(f.fileno(), size+offset, access=access) + return mem, offset + def load_pack_index(filename): f = open(filename, 'r') @@ -254,14 +258,24 @@ return self.calculate_checksum() == self.get_stored_checksum() def calculate_checksum(self): + """Calculate the SHA1 checksum over this pack index. + + :return: This is a 20-byte binary digest + """ return make_sha(self._contents[:-20]).digest() def get_pack_checksum(self): - """Return the SHA1 checksum stored for the corresponding packfile.""" + """Return the SHA1 checksum stored for the corresponding packfile. + + :return: 20-byte binary digest + """ return str(self._contents[-40:-20]) def get_stored_checksum(self): - """Return the SHA1 checksum stored for this index.""" + """Return the SHA1 checksum stored for this index. + + :return: 20-byte binary digest + """ return str(self._contents[-20:]) def object_index(self, sha): @@ -352,6 +366,10 @@ def read_pack_header(f): + """Read the header of a pack file. + + :param f: File-like object to read from + """ header = f.read(12) assert header[:4] == "PACK" (version,) = unpack_from(">L", header, 4) @@ -365,6 +383,10 @@ def unpack_object(map, offset=0): + """Unpack a Git object. + + :return: tuple with type, uncompressed data and compressed size + """ bytes = take_msb_bytes(map, offset) type = (bytes[0] >> 4) & 0x07 size = bytes[0] & 0x0f @@ -395,6 +417,8 @@ def compute_object_size((num, obj)): + """Compute the size of a unresolved object for use with LRUSizeCache. + """ if num in (6, 7): return len(obj[1]) assert isinstance(obj, str) @@ -460,15 +484,15 @@ return self._num_objects def calculate_checksum(self): - """Calculate the checksum for this pack.""" + """Calculate the checksum for this pack. 
+ + :return: 20-byte binary SHA1 digest + """ map, map_offset = simple_mmap(self._file, 0, self._size - 20) try: - r = make_sha(map[map_offset:self._size-20]).digest() + return make_sha(map[map_offset:self._size-20]).digest() + finally: map.close() - return r - except: - map.close() - raise def resolve_object(self, offset, type, obj, get_ref, get_offset=None): """Resolve an object, possibly resolving deltas when necessary. @@ -504,31 +528,47 @@ return ret def iterobjects(self, progress=None): - offset = self._header_size - num = len(self) - map, _ = simple_mmap(self._file, 0, self._size) - try: - for i in range(num): - (type, obj, total_size) = unpack_object(map, offset) - crc32 = zlib.crc32(map[offset:offset+total_size]) & 0xffffffff - yield offset, type, obj, crc32 - offset += total_size + + class ObjectIterator(object): + + def __init__(self, pack): + self.i = 0 + self.offset = pack._header_size + self.num = len(pack) + self.map, _ = simple_mmap(pack._file, 0, pack._size) + + def __del__(self): + self.map.close() + + def __iter__(self): + return self + + def __len__(self): + return self.num + + def next(self): + if self.i == self.num: + raise StopIteration + (type, obj, total_size) = unpack_object(self.map, self.offset) + crc32 = zlib.crc32(self.map[self.offset:self.offset+total_size]) & 0xffffffff + ret = (self.offset, type, obj, crc32) + self.offset += total_size if progress: - progress(i, num) - map.close() - except: - map.close() - raise + progress(self.i, self.num) + self.i+=1 + return ret + return ObjectIterator(self) def iterentries(self, ext_resolve_ref=None, progress=None): found = {} postponed = defaultdict(list) class Postpone(Exception): """Raised to postpone delta resolving.""" + def get_ref_text(sha): assert len(sha) == 20 if sha in found: - return found[sha] + return self.get_object_at(found[sha]) if ext_resolve_ref: try: return ext_resolve_ref(sha) @@ -548,7 +588,7 @@ else: shafile = ShaFile.from_raw_string(type, obj) sha = shafile.sha().digest() - found[sha] = (type, obj) + found[sha] = offset yield sha, offset, crc32 extra.extend(postponed.get(sha, [])) if postponed: @@ -560,12 +600,41 @@ return ret def create_index_v1(self, filename, resolve_ext_ref=None, progress=None): + """Create a version 1 file for this data file. + + :param filename: Index filename. + :param resolve_ext_ref: Function to use for resolving externally referenced + SHA1s (for thin packs) + :param progress: Progress report function + """ entries = self.sorted_entries(resolve_ext_ref, progress=progress) write_pack_index_v1(filename, entries, self.calculate_checksum()) def create_index_v2(self, filename, resolve_ext_ref=None, progress=None): + """Create a version 2 index file for this data file. + + :param filename: Index filename. + :param resolve_ext_ref: Function to use for resolving externally referenced + SHA1s (for thin packs) + :param progress: Progress report function + """ entries = self.sorted_entries(resolve_ext_ref, progress=progress) write_pack_index_v2(filename, entries, self.calculate_checksum()) + + def create_index(self, filename, resolve_ext_ref=None, progress=None, version=2): + """Create an index file for this data file. + + :param filename: Index filename. 
+        :param resolve_ext_ref: Function to use for resolving externally referenced
+            SHA1s (for thin packs)
+        :param progress: Progress report function
+        """
+        if version == 1:
+            self.create_index_v1(filename, resolve_ext_ref, progress)
+        elif version == 2:
+            self.create_index_v2(filename, resolve_ext_ref, progress)
+        else:
+            raise ValueError("unknown index format %d" % version)
 
     def get_stored_checksum(self):
         return self._stored_checksum
@@ -594,6 +663,8 @@
 
 
 class SHA1Writer(object):
+    """Wrapper around a file-like object that remembers the SHA1 of
+    the data written to it."""
 
     def __init__(self, f):
         self.f = f
@@ -656,6 +727,12 @@
 
 
 def write_pack(filename, objects, num_objects):
+    """Write a new pack data file.
+
+    :param filename: Path to the new pack file (without .pack extension)
+    :param objects: Iterable over (object, path) tuples to write
+    :param num_objects: Number of objects to write
+    """
     f = open(filename + ".pack", 'w')
     try:
         entries, data_sum = write_pack_data(f, objects, num_objects)
@@ -679,7 +756,7 @@
     # This helps us find good objects to diff against us
     magic = []
     for obj, path in recency:
-        magic.append( (obj.type, path, 1, -len(obj.as_raw_string()[1]), obj) )
+        magic.append( (obj.type, path, 1, -len(obj.as_raw_string()), obj) )
     magic.sort()
     # Build a map of objects and their index in magic - so we can find preceding objects
     # to diff against
@@ -694,14 +771,15 @@
     f.write(struct.pack(">L", num_objects))  # Number of objects in pack
     for o, path in recency:
         sha1 = o.sha().digest()
-        orig_t, raw = o.as_raw_string()
+        orig_t = o.type
+        raw = o.as_raw_string()
         winner = raw
         t = orig_t
         #for i in range(offs[o]-window, window):
        #    if i < 0 or i >= len(offs): continue
        #    b = magic[i][4]
        #    if b.type != orig_t: continue
-        #    _, base = b.as_raw_string()
+        #    base = b.as_raw_string()
        #    delta = create_delta(base, raw)
        #    if len(delta) < len(winner):
        #        winner = delta
@@ -736,7 +814,11 @@
 
 
 def create_delta(base_buf, target_buf):
-    """Use python difflib to work out how to transform base_buf to target_buf"""
+    """Use Python's difflib to work out how to transform base_buf to target_buf.
+
+    :param base_buf: Base buffer
+    :param target_buf: Target buffer
+    """
     assert isinstance(base_buf, str)
     assert isinstance(target_buf, str)
     out_buf = ""
@@ -888,6 +970,7 @@
 
 
 class Pack(object):
+    """A Git pack object."""
 
     def __init__(self, basename):
         self._basename = basename
@@ -898,6 +981,7 @@
 
     @classmethod
     def from_objects(self, data, idx):
+        """Create a new pack object from pack data and index objects."""
         ret = Pack("")
         ret._data = data
         ret._idx = idx
@@ -905,14 +989,15 @@
 
     def name(self):
         """The SHA over the SHAs of the objects in this pack."""
-        return self.idx.objects_sha1()
+        return self.index.objects_sha1()
 
     @property
     def data(self):
+        """The pack data object being used."""
         if self._data is None:
             self._data = PackData(self._data_path)
-            assert len(self.idx) == len(self._data)
-            idx_stored_checksum = self.idx.get_pack_checksum()
+            assert len(self.index) == len(self._data)
+            idx_stored_checksum = self.index.get_pack_checksum()
             data_stored_checksum = self._data.get_stored_checksum()
             if idx_stored_checksum != data_stored_checksum:
                 raise ChecksumMismatch(sha_to_hex(idx_stored_checksum),
@@ -920,7 +1005,11 @@
         return self._data
 
     @property
-    def idx(self):
+    def index(self):
+        """The index being used.
+
+        :note: This may be an in-memory index
+        """
         if self._idx is None:
             self._idx = load_pack_index(self._idx_path)
         return self._idx
@@ -928,24 +1017,25 @@
     def close(self):
         if self._data is not None:
             self._data.close()
-        self.idx.close()
+        self.index.close()
 
     def __eq__(self, other):
-        return type(self) == type(other) and self.idx == other.idx
+        return type(self) == type(other) and self.index == other.index
 
     def __len__(self):
         """Number of entries in this pack."""
-        return len(self.idx)
+        return len(self.index)
 
     def __repr__(self):
         return "%s(%r)" % (self.__class__.__name__, self._basename)
 
     def __iter__(self):
         """Iterate over all the sha1s of the objects in this pack."""
-        return iter(self.idx)
+        return iter(self.index)
 
     def check(self):
-        if not self.idx.check():
+        """Check the integrity of this pack."""
+        if not self.index.check():
             return False
         if not self.data.check():
             return False
@@ -957,13 +1047,13 @@
     def __contains__(self, sha1):
         """Check whether this pack contains a particular SHA1."""
         try:
-            self.idx.object_index(sha1)
+            self.index.object_index(sha1)
             return True
         except KeyError:
             return False
 
     def get_raw(self, sha1, resolve_ref=None):
-        offset = self.idx.object_index(sha1)
+        offset = self.index.object_index(sha1)
         obj_type, obj = self.data.get_object_at(offset)
         if type(offset) is long:
             offset = int(offset)
@@ -977,6 +1067,7 @@
         return ShaFile.from_raw_string(type, uncomp)
 
     def iterobjects(self, get_raw=None):
+        """Iterate over the objects in this pack."""
         if get_raw is None:
             get_raw = self.get_raw
         for offset, type, obj, crc32 in self.data.iterobjects():
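The pack.py changes above rename the `idx` property to `index` and add a version-dispatching `create_index`. A minimal usage sketch, assuming a hypothetical pack basename (a real one lives under .git/objects/pack/):

from dulwich.pack import Pack, PackData

# Hypothetical basename; substitute a real pack-<sha1> path.
basename = "/tmp/repo/.git/objects/pack/pack-deadbeef"

data = PackData(basename + ".pack")
data.create_index(basename + ".idx")    # version=2 by default
# data.create_index(basename + ".idx", version=1)  # or an explicit v1 index

p = Pack(basename)
print "%d objects, valid: %s" % (len(p), p.check())  # len() consults p.index
for sha in p:        # iterating a Pack walks the index (formerly p.idx)
    assert sha in p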
--- a/dulwich/protocol.py
+++ b/dulwich/protocol.py
@@ -79,13 +79,6 @@
             yield pkt
             pkt = self.read_pkt_line()
 
-    def write_file(self, f):
-        try:
-            for line in f:
-                self.write(line)
-        finally:
-            f.close()
-
     def write_pkt_line(self, line):
         """
         Sends a 'pkt line' to the remote git process
@@ -104,6 +97,25 @@
         except socket.error, e:
             raise GitProtocolError(e)
 
+    def write_file(self):
+        class ProtocolFile(object):
+
+            def __init__(self, proto):
+                self._proto = proto
+                self._offset = 0
+
+            def write(self, data):
+                self._proto.write(data)
+                self._offset += len(data)
+
+            def tell(self):
+                return self._offset
+
+            def close(self):
+                pass
+
+        return ProtocolFile(self)
+
     def write_sideband(self, channel, blob):
         """
         Write data to the sideband (a git multiplexing method)
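The reworked `write_file` above no longer consumes a file; it hands back a writable file-like object over the protocol. A small sketch of the new shape, using throwaway callables as a stand-in transport:

from dulwich.protocol import Protocol

sent = []
proto = Protocol(lambda size: "", sent.append)  # toy transport: no reads, collect writes

f = proto.write_file()
f.write("PACK")          # forwarded verbatim through proto.write
assert f.tell() == 4     # tell() tracks the number of bytes written
assert sent == ["PACK"]
f.close()                # deliberately a no-op; the protocol stays open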
--- a/dulwich/repo.py
+++ b/dulwich/repo.py
@@ -1,23 +1,25 @@
 # repo.py -- For dealing with git repositories.
 # Copyright (C) 2007 James Westby <jw+debian@jameswestby.net>
-# Copyright (C) 2008 Jelmer Vernooij <jelmer@samba.org>
-#
+# Copyright (C) 2008-2009 Jelmer Vernooij <jelmer@samba.org>
+#
 # This program is free software; you can redistribute it and/or
 # modify it under the terms of the GNU General Public License
 # as published by the Free Software Foundation; version 2
-# of the License or (at your option) any later version of
+# of the License or (at your option) any later version of
 # the License.
-#
+#
 # This program is distributed in the hope that it will be useful,
 # but WITHOUT ANY WARRANTY; without even the implied warranty of
 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 # GNU General Public License for more details.
-#
+#
 # You should have received a copy of the GNU General Public License
 # along with this program; if not, write to the Free Software
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA 02110-1301, USA.
 
+"""Repository access."""
+
 import os
 import stat
 import zlib
@@ -42,9 +44,11 @@
 
 OBJECTDIR = 'objects'
 SYMREF = 'ref: '
-
+REFSDIR = 'refs'
+INDEX_FILENAME = "index"
 
 class Tags(object):
+    """Tags container."""
 
     def __init__(self, tagdir, tags):
         self.tagdir = tagdir
@@ -52,7 +56,7 @@
 
     def __getitem__(self, name):
         return self.tags[name]
-
+
     def __setitem__(self, name, ref):
         self.tags[name] = ref
         f = open(os.path.join(self.tagdir, name), 'wb')
@@ -70,69 +74,33 @@
 
 
 def read_packed_refs(f):
+    """Read a packed refs file.
+
+    Yields tuples with ref names and SHA1s.
+
+    :param f: file-like object to read from
+    """
     l = f.readline()
-    assert l == "# pack-refs with: peeled \n"
     for l in f.readlines():
+        if l[0] == "#":
+            # Comment
+            continue
        if l[0] == "^":
            # FIXME: Return somehow
            continue
        yield tuple(l.rstrip("\n").split(" ", 2))
 
 
-class MissingObjectFinder(object):
-
-    def __init__(self, object_store, wants, graph_walker, progress=None):
-        self.sha_done = set()
-        self.objects_to_send = set([(w, None) for w in wants])
-        self.object_store = object_store
-        if progress is None:
-            self.progress = lambda x: None
-        else:
-            self.progress = progress
-        ref = graph_walker.next()
-        while ref:
-            if ref in self.object_store:
-                graph_walker.ack(ref)
-            ref = graph_walker.next()
-
-    def add_todo(self, entries):
-        self.objects_to_send.update([e for e in entries if not e in self.sha_done])
-
-    def parse_tree(self, tree):
-        self.add_todo([(sha, name) for (mode, name, sha) in tree.entries()])
+class Repo(object):
+    """A local git repository."""
 
-    def parse_commit(self, commit):
-        self.add_todo([(commit.tree, "")])
-        self.add_todo([(p, None) for p in commit.parents])
-
-    def parse_tag(self, tag):
-        self.add_todo([(tag.object[1], None)])
-
-    def next(self):
-        if not self.objects_to_send:
-            return None
-        (sha, name) = self.objects_to_send.pop()
-        o = self.object_store[sha]
-        if isinstance(o, Commit):
-            self.parse_commit(o)
-        elif isinstance(o, Tree):
-            self.parse_tree(o)
-        elif isinstance(o, Tag):
-            self.parse_tag(o)
-        self.sha_done.add((sha, name))
-        self.progress("counting objects: %d\r" % len(self.sha_done))
-        return (sha, name)
-
-
-class Repo(object):
-
-    ref_locs = ['', 'refs', 'refs/tags', 'refs/heads', 'refs/remotes']
+    ref_locs = ['', REFSDIR, 'refs/tags', 'refs/heads', 'refs/remotes']
 
     def __init__(self, root):
-        if os.path.isdir(os.path.join(root, ".git", "objects")):
+        if os.path.isdir(os.path.join(root, ".git", OBJECTDIR)):
             self.bare = False
            self._controldir = os.path.join(root, ".git")
-        elif os.path.isdir(os.path.join(root, "objects")):
+        elif os.path.isdir(os.path.join(root, OBJECTDIR)):
             self.bare = True
             self._controldir = root
         else:
@@ -142,32 +110,47 @@
         self._object_store = None
 
     def controldir(self):
+        """Return the path of the control directory."""
         return self._controldir
 
+    def index_path(self):
+        """Return path to the index file."""
+        return os.path.join(self.controldir(), INDEX_FILENAME)
+
+    def open_index(self):
+        """Open the index for this repository."""
+        from dulwich.index import Index
+        return Index(self.index_path())
+
+    def has_index(self):
+        """Check if an index is present."""
+        return os.path.exists(self.index_path())
+
     def find_missing_objects(self, determine_wants, graph_walker, progress):
         """Find the missing objects required for a set of revisions.
 
-        :param determine_wants: Function that takes a dictionary with heads
+        :param determine_wants: Function that takes a dictionary with heads
           and returns the list of heads to fetch.
-        :param graph_walker: Object that can iterate over the list of revisions
-          to fetch and has an "ack" method that will be called to acknowledge
+        :param graph_walker: Object that can iterate over the list of revisions
+          to fetch and has an "ack" method that will be called to acknowledge
           that a revision is present.
-        :param progress: Simple progress function that will be called with
+        :param progress: Simple progress function that will be called with
           updated progress strings.
+        :return: Iterator over (sha, path) pairs.
         """
         wants = determine_wants(self.get_refs())
-        return iter(MissingObjectFinder(self.object_store, wants, graph_walker,
-            progress).next, None)
+        return self.object_store.find_missing_objects(wants,
+            graph_walker, progress)
 
     def fetch_objects(self, determine_wants, graph_walker, progress):
         """Fetch the missing objects required for a set of revisions.
 
-        :param determine_wants: Function that takes a dictionary with heads
+        :param determine_wants: Function that takes a dictionary with heads
          and returns the list of heads to fetch.
-        :param graph_walker: Object that can iterate over the list of revisions
-        to fetch and has an "ack" method that will be called to acknowledge
+        :param graph_walker: Object that can iterate over the list of revisions
+        to fetch and has an "ack" method that will be called to acknowledge
          that a revision is present.
-        :param progress: Simple progress function that will be called with
+        :param progress: Simple progress function that will be called with
          updated progress strings.
         :return: tuple with number of objects, iterator over objects
         """
@@ -175,12 +158,13 @@
             self.find_missing_objects(determine_wants, graph_walker, progress))
 
     def object_dir(self):
+        """Return path of the object directory."""
         return os.path.join(self.controldir(), OBJECTDIR)
 
     @property
     def object_store(self):
         if self._object_store is None:
-            self._object_store = ObjectStore(self.object_dir())
+            self._object_store = DiskObjectStore(self.object_dir())
         return self._object_store
 
     def pack_dir(self):
@@ -201,6 +185,7 @@
         f.close()
 
     def ref(self, name):
+        """Return the SHA1 a ref is pointing to."""
         for dir in self.ref_locs:
             file = os.path.join(self.controldir(), dir, name)
             if os.path.exists(file):
@@ -210,6 +195,7 @@
             return packed_refs[name]
 
     def get_refs(self):
+        """Get dictionary with all refs."""
         ret = {}
         if self.head():
             ret['HEAD'] = self.head()
@@ -222,6 +208,13 @@
         return ret
 
     def get_packed_refs(self):
+        """Get contents of the packed-refs file.
+
+        :return: Dictionary mapping ref names to SHA1s
+
+        :note: Will return an empty dictionary when no packed-refs file is
+            present.
+        """
         path = os.path.join(self.controldir(), 'packed-refs')
         if not os.path.exists(path):
             return {}
@@ -270,12 +263,17 @@
         f.close()
 
     def remove_ref(self, name):
+        """Remove a ref.
+
+        :param name: Name of the ref
+        """
         file = os.path.join(self.controldir(), name)
         if os.path.exists(file):
             os.remove(file)
 
     def tagdir(self):
-        return os.path.join(self.controldir(), 'refs', 'tags')
+        """Tag directory."""
+        return os.path.join(self.controldir(), REFSDIR, 'tags')
 
     def get_tags(self):
         ret = {}
@@ -506,11 +504,11 @@
 
     @classmethod
     def init_bare(cls, path, mkdir=True):
-        for d in [["objects"],
-                  ["objects", "info"],
-                  ["objects", "pack"],
+        for d in [[OBJECTDIR],
+                  [OBJECTDIR, "info"],
+                  [OBJECTDIR, "pack"],
                   ["branches"],
-                  ["refs"],
+                  [REFSDIR],
                   ["refs", "tags"],
                   ["refs", "heads"],
                   ["hooks"],
@@ -521,3 +519,4 @@
         open(os.path.join(path, 'info', 'excludes'), 'w').write("")
 
 create = init_bare
+
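Taken together, the Repo additions above (index helpers, packed-refs parsing, path constants) give the following shape, sketched against a hypothetical repository path:

from dulwich.repo import Repo

repo = Repo("/path/to/project")         # finds .git/, or treats root as bare
refs = repo.get_refs()                  # dict of all refs, plus HEAD if set
sha = repo.ref("refs/heads/master")     # loose refs first, then packed-refs
if repo.has_index():                    # bare repositories have no index
    index = repo.open_index()           # a dulwich.index.Index instance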
--- a/dulwich/server.py
+++ b/dulwich/server.py
@@ -16,6 +16,10 @@
 # Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
 # MA 02110-1301, USA.
 
+
+"""Git smart network protocol server implementation."""
+
+
 import SocketServer
 import tempfile
 
@@ -88,6 +92,7 @@
 
 
 class Handler(object):
+    """Smart protocol command handler base class."""
 
     def __init__(self, backend, read, write):
         self.backend = backend
@@ -98,6 +103,7 @@
 
 
 class UploadPackHandler(Handler):
+    """Protocol handler for uploading a pack to the client."""
 
     def default_capabilities(self):
         return ("multi_ack", "side-band-64k", "thin-pack", "ofs-delta")
@@ -170,6 +176,7 @@
 
 
 class ReceivePackHandler(Handler):
+    """Protocol handler for receiving a pack from the client."""
 
     def default_capabilities(self):
         return ("report-status", "delete-refs")
@@ -204,7 +211,8 @@
 
         # backend can now deal with these refs and read a pack using self.read
         self.backend.apply_pack(client_refs, self.proto.read)
-        # when we have read all the pack from the client, it assumes everything worked OK
+        # when we have read the whole pack from the client, it assumes
+        # everything worked OK.
         # there is NO ack from the server before it reports victory.
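The handler docstrings above (corrected here: upload-pack sends a pack to the client on fetch, receive-pack accepts one on push) advertise fixed capability sets that can be inspected without a live connection. A sketch, where `None` and the lambdas are hypothetical stand-ins for a real backend and transport callables:

from dulwich.server import UploadPackHandler, ReceivePackHandler

# Placeholders only; enough to query the capability lists.
up = UploadPackHandler(None, lambda size: "", lambda data: None)
rp = ReceivePackHandler(None, lambda size: "", lambda data: None)

print up.default_capabilities()  # ('multi_ack', 'side-band-64k', 'thin-pack', 'ofs-delta')
print rp.default_capabilities()  # ('report-status', 'delete-refs')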
new file mode 100644
--- /dev/null
+++ b/dulwich/tests/test_lru_cache.py
@@ -0,0 +1,447 @@
+# Copyright (C) 2006, 2008 Canonical Ltd
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+"""Tests for the lru_cache module."""
+
+from dulwich import (
+    lru_cache,
+    )
+import unittest
+
+
+class TestLRUCache(unittest.TestCase):
+    """Test that LRU cache properly keeps track of entries."""
+
+    def test_cache_size(self):
+        cache = lru_cache.LRUCache(max_cache=10)
+        self.assertEqual(10, cache.cache_size())
+
+        cache = lru_cache.LRUCache(max_cache=256)
+        self.assertEqual(256, cache.cache_size())
+
+        cache.resize(512)
+        self.assertEqual(512, cache.cache_size())
+
+    def test_missing(self):
+        cache = lru_cache.LRUCache(max_cache=10)
+
+        self.failIf('foo' in cache)
+        self.assertRaises(KeyError, cache.__getitem__, 'foo')
+
+        cache['foo'] = 'bar'
+        self.assertEqual('bar', cache['foo'])
+        self.failUnless('foo' in cache)
+        self.failIf('bar' in cache)
+
+    def test_map_None(self):
+        # Make sure that we can properly map None as a key.
+        cache = lru_cache.LRUCache(max_cache=10)
+        self.failIf(None in cache)
+        cache[None] = 1
+        self.assertEqual(1, cache[None])
+        cache[None] = 2
+        self.assertEqual(2, cache[None])
+        # Test the various code paths of __getitem__, to make sure that we can
+        # handle when None is the key for the LRU and the MRU
+        cache[1] = 3
+        cache[None] = 1
+        cache[None]
+        cache[1]
+        cache[None]
+        self.assertEqual([None, 1], [n.key for n in cache._walk_lru()])
+
+    def test_add__null_key(self):
+        cache = lru_cache.LRUCache(max_cache=10)
+        self.assertRaises(ValueError, cache.add, lru_cache._null_key, 1)
+
+    def test_overflow(self):
+        """Adding extra entries will pop out old ones."""
+        cache = lru_cache.LRUCache(max_cache=1, after_cleanup_count=1)
+
+        cache['foo'] = 'bar'
+        # With a max cache of 1, adding 'baz' should pop out 'foo'
+        cache['baz'] = 'biz'
+
+        self.failIf('foo' in cache)
+        self.failUnless('baz' in cache)
+
+        self.assertEqual('biz', cache['baz'])
+
+    def test_by_usage(self):
+        """Accessing entries bumps them up in priority."""
+        cache = lru_cache.LRUCache(max_cache=2)
+
+        cache['baz'] = 'biz'
+        cache['foo'] = 'bar'
+
+        self.assertEqual('biz', cache['baz'])
+
+        # This must kick out 'foo' because it was the least recently accessed
+        cache['nub'] = 'in'
+
+        self.failIf('foo' in cache)
+
+    def test_cleanup(self):
+        """Test that we can use a cleanup function."""
+        cleanup_called = []
+        def cleanup_func(key, val):
+            cleanup_called.append((key, val))
+
+        cache = lru_cache.LRUCache(max_cache=2)
+
+        cache.add('baz', '1', cleanup=cleanup_func)
+        cache.add('foo', '2', cleanup=cleanup_func)
+        cache.add('biz', '3', cleanup=cleanup_func)
+
+        self.assertEqual([('baz', '1')], cleanup_called)
+
+        # 'foo' is now most recent, so final cleanup will call it last
+        cache['foo']
+        cache.clear()
+        self.assertEqual([('baz', '1'), ('biz', '3'), ('foo', '2')],
+                         cleanup_called)
+
+    def test_cleanup_on_replace(self):
+        """Replacing an object should cleanup the old value."""
+        cleanup_called = []
+        def cleanup_func(key, val):
+            cleanup_called.append((key, val))
+
+        cache = lru_cache.LRUCache(max_cache=2)
+        cache.add(1, 10, cleanup=cleanup_func)
+        cache.add(2, 20, cleanup=cleanup_func)
+        cache.add(2, 25, cleanup=cleanup_func)
+
+        self.assertEqual([(2, 20)], cleanup_called)
+        self.assertEqual(25, cache[2])
+
+        # Even __setitem__ should make sure cleanup() is called
+        cache[2] = 26
+        self.assertEqual([(2, 20), (2, 25)], cleanup_called)
+
+    def test_len(self):
+        cache = lru_cache.LRUCache(max_cache=10, after_cleanup_count=10)
+
+        cache[1] = 10
+        cache[2] = 20
+        cache[3] = 30
+        cache[4] = 40
+
+        self.assertEqual(4, len(cache))
+
+        cache[5] = 50
+        cache[6] = 60
+        cache[7] = 70
+        cache[8] = 80
+
+        self.assertEqual(8, len(cache))
+
+        cache[1] = 15  # replacement
+
+        self.assertEqual(8, len(cache))
+
+        cache[9] = 90
+        cache[10] = 100
+        cache[11] = 110
+
+        # We hit the max
+        self.assertEqual(10, len(cache))
+        self.assertEqual([11, 10, 9, 1, 8, 7, 6, 5, 4, 3],
+                         [n.key for n in cache._walk_lru()])
+
+    def test_cleanup_shrinks_to_after_clean_count(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=3)
+
+        cache.add(1, 10)
+        cache.add(2, 20)
+        cache.add(3, 25)
+        cache.add(4, 30)
+        cache.add(5, 35)
+
+        self.assertEqual(5, len(cache))
+        # This will bump us over the max, which causes us to shrink down to
+        # after_cleanup_count size
+        cache.add(6, 40)
+        self.assertEqual(3, len(cache))
+
+    def test_after_cleanup_larger_than_max(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=10)
+        self.assertEqual(5, cache._after_cleanup_count)
+
+    def test_after_cleanup_none(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=None)
+        # By default _after_cleanup_count is 80% of the max size
+        self.assertEqual(4, cache._after_cleanup_count)
+
+    def test_forced_cleanup(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=2)
+
+        # Add these in order
+        cache.add(1, 10)
+        cache.add(2, 20)
+        cache.add(3, 25)
+        cache.add(4, 30)
+        cache.add(5, 35)
+
+        self.assertEqual(5, len(cache))
+        # Force a compaction
+        cache.cleanup()
+        self.assertEqual(2, len(cache))
+
+    def test_preserve_last_access_order(self):
+        cache = lru_cache.LRUCache(max_cache=5)
+
+        # Add these in order
+        cache.add(1, 10)
+        cache.add(2, 20)
+        cache.add(3, 25)
+        cache.add(4, 30)
+        cache.add(5, 35)
+
+        self.assertEqual([5, 4, 3, 2, 1], [n.key for n in cache._walk_lru()])
+
+        # Now access some randomly
+        cache[2]
+        cache[5]
+        cache[3]
+        cache[2]
+        self.assertEqual([2, 3, 5, 4, 1], [n.key for n in cache._walk_lru()])
+
+    def test_get(self):
+        cache = lru_cache.LRUCache(max_cache=5)
+
+        cache.add(1, 10)
+        cache.add(2, 20)
+        self.assertEqual(20, cache.get(2))
+        self.assertEqual(None, cache.get(3))
+        obj = object()
+        self.assertTrue(obj is cache.get(3, obj))
+        self.assertEqual([2, 1], [n.key for n in cache._walk_lru()])
+        self.assertEqual(10, cache.get(1))
+        self.assertEqual([1, 2], [n.key for n in cache._walk_lru()])
+
+    def test_keys(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=5)
+
+        cache[1] = 2
+        cache[2] = 3
+        cache[3] = 4
+        self.assertEqual([1, 2, 3], sorted(cache.keys()))
+        cache[4] = 5
+        cache[5] = 6
+        cache[6] = 7
+        self.assertEqual([2, 3, 4, 5, 6], sorted(cache.keys()))
+
+    def test_resize_smaller(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=4)
+        cache[1] = 2
+        cache[2] = 3
+        cache[3] = 4
+        cache[4] = 5
+        cache[5] = 6
+        self.assertEqual([1, 2, 3, 4, 5], sorted(cache.keys()))
+        cache[6] = 7
+        self.assertEqual([3, 4, 5, 6], sorted(cache.keys()))
+        # Now resize to something smaller, which triggers a cleanup
+        cache.resize(max_cache=3, after_cleanup_count=2)
+        self.assertEqual([5, 6], sorted(cache.keys()))
+        # Adding something will use the new size
+        cache[7] = 8
+        self.assertEqual([5, 6, 7], sorted(cache.keys()))
+        cache[8] = 9
+        self.assertEqual([7, 8], sorted(cache.keys()))
+
+    def test_resize_larger(self):
+        cache = lru_cache.LRUCache(max_cache=5, after_cleanup_count=4)
+        cache[1] = 2
+        cache[2] = 3
+        cache[3] = 4
+        cache[4] = 5
+        cache[5] = 6
+        self.assertEqual([1, 2, 3, 4, 5], sorted(cache.keys()))
+        cache[6] = 7
+        self.assertEqual([3, 4, 5, 6], sorted(cache.keys()))
+        cache.resize(max_cache=8, after_cleanup_count=6)
+        self.assertEqual([3, 4, 5, 6], sorted(cache.keys()))
+        cache[7] = 8
+        cache[8] = 9
+        cache[9] = 10
+        cache[10] = 11
+        self.assertEqual([3, 4, 5, 6, 7, 8, 9, 10], sorted(cache.keys()))
+        cache[11] = 12  # triggers cleanup back to new after_cleanup_count
+        self.assertEqual([6, 7, 8, 9, 10, 11], sorted(cache.keys()))
+
+
+class TestLRUSizeCache(unittest.TestCase):
+
+    def test_basic_init(self):
+        cache = lru_cache.LRUSizeCache()
+        self.assertEqual(2048, cache._max_cache)
+        self.assertEqual(int(cache._max_size*0.8), cache._after_cleanup_size)
+        self.assertEqual(0, cache._value_size)
+
+    def test_add__null_key(self):
+        cache = lru_cache.LRUSizeCache()
+        self.assertRaises(ValueError, cache.add, lru_cache._null_key, 1)
+
+    def test_add_tracks_size(self):
+        cache = lru_cache.LRUSizeCache()
+        self.assertEqual(0, cache._value_size)
+        cache.add('my key', 'my value text')
+        self.assertEqual(13, cache._value_size)
+
+    def test_remove_tracks_size(self):
+        cache = lru_cache.LRUSizeCache()
+        self.assertEqual(0, cache._value_size)
+        cache.add('my key', 'my value text')
+        self.assertEqual(13, cache._value_size)
+        node = cache._cache['my key']
+        cache._remove_node(node)
+        self.assertEqual(0, cache._value_size)
+
+    def test_no_add_over_size(self):
+        """Adding a large value may not be cached at all."""
+        cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
+        self.assertEqual(0, cache._value_size)
+        self.assertEqual({}, cache.items())
+        cache.add('test', 'key')
+        self.assertEqual(3, cache._value_size)
+        self.assertEqual({'test': 'key'}, cache.items())
+        cache.add('test2', 'key that is too big')
+        self.assertEqual(3, cache._value_size)
+        self.assertEqual({'test':'key'}, cache.items())
+        # If adding a key would trigger a cleanup that removes every cached
+        # entry, then the value should not be stored at all
+        cache.add('test3', 'bigkey')
+        self.assertEqual(3, cache._value_size)
+        self.assertEqual({'test':'key'}, cache.items())
+
+        cache.add('test4', 'bikey')
+        self.assertEqual(3, cache._value_size)
+        self.assertEqual({'test':'key'}, cache.items())
+
+    def test_no_add_over_size_cleanup(self):
+        """If a large value is not cached, we will call cleanup right away."""
+        cleanup_calls = []
+        def cleanup(key, value):
+            cleanup_calls.append((key, value))
+
+        cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=5)
+        self.assertEqual(0, cache._value_size)
+        self.assertEqual({}, cache.items())
+        cache.add('test', 'key that is too big', cleanup=cleanup)
+        # key was not added
+        self.assertEqual(0, cache._value_size)
+        self.assertEqual({}, cache.items())
+        # and cleanup was called
+        self.assertEqual([('test', 'key that is too big')], cleanup_calls)
+
+    def test_adding_clears_cache_based_on_size(self):
+        """The cache is cleared in LRU order until small enough"""
+        cache = lru_cache.LRUSizeCache(max_size=20)
+        cache.add('key1', 'value')  # 5 chars
+        cache.add('key2', 'value2')  # 6 chars
+        cache.add('key3', 'value23')  # 7 chars
+        self.assertEqual(5+6+7, cache._value_size)
+        cache['key2']  # reference key2 so it gets a newer reference time
+        cache.add('key4', 'value234')  # 8 chars, over limit
+        # We have to remove 2 keys to get back under limit
+        self.assertEqual(6+8, cache._value_size)
+        self.assertEqual({'key2':'value2', 'key4':'value234'},
+                         cache.items())
+
+    def test_adding_clears_to_after_cleanup_size(self):
+        cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
+        cache.add('key1', 'value')  # 5 chars
+        cache.add('key2', 'value2')  # 6 chars
+        cache.add('key3', 'value23')  # 7 chars
+        self.assertEqual(5+6+7, cache._value_size)
+        cache['key2']  # reference key2 so it gets a newer reference time
+        cache.add('key4', 'value234')  # 8 chars, over limit
+        # We have to remove 3 keys to get back under limit
+        self.assertEqual(8, cache._value_size)
+        self.assertEqual({'key4':'value234'}, cache.items())
+
+    def test_custom_sizes(self):
+        def size_of_list(lst):
+            return sum(len(x) for x in lst)
+        cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10,
+                                       compute_size=size_of_list)
+
+        cache.add('key1', ['val', 'ue'])  # 5 chars
+        cache.add('key2', ['val', 'ue2'])  # 6 chars
+        cache.add('key3', ['val', 'ue23'])  # 7 chars
+        self.assertEqual(5+6+7, cache._value_size)
+        cache['key2']  # reference key2 so it gets a newer reference time
+        cache.add('key4', ['value', '234'])  # 8 chars, over limit
+        # We have to remove 3 keys to get back under limit
+        self.assertEqual(8, cache._value_size)
+        self.assertEqual({'key4':['value', '234']}, cache.items())
+
+    def test_cleanup(self):
+        cache = lru_cache.LRUSizeCache(max_size=20, after_cleanup_size=10)
+
+        # Add these in order
+        cache.add('key1', 'value')  # 5 chars
+        cache.add('key2', 'value2')  # 6 chars
+        cache.add('key3', 'value23')  # 7 chars
+        self.assertEqual(5+6+7, cache._value_size)
+
+        cache.cleanup()
+        # Only the most recent fits after cleaning up
+        self.assertEqual(7, cache._value_size)
+
+    def test_keys(self):
+        cache = lru_cache.LRUSizeCache(max_size=10)
+
+        cache[1] = 'a'
+        cache[2] = 'b'
+        cache[3] = 'cdef'
+        self.assertEqual([1, 2, 3], sorted(cache.keys()))
+
+    def test_resize_smaller(self):
+        cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
+        cache[1] = 'abc'
+        cache[2] = 'def'
+        cache[3] = 'ghi'
+        cache[4] = 'jkl'
+        # Triggers a cleanup
+        self.assertEqual([2, 3, 4], sorted(cache.keys()))
+        # Resize should also cleanup again
+        cache.resize(max_size=6, after_cleanup_size=4)
+        self.assertEqual([4], sorted(cache.keys()))
+        # Adding should use the new max size
+        cache[5] = 'mno'
+        self.assertEqual([4, 5], sorted(cache.keys()))
+        cache[6] = 'pqr'
+        self.assertEqual([6], sorted(cache.keys()))
+
+    def test_resize_larger(self):
+        cache = lru_cache.LRUSizeCache(max_size=10, after_cleanup_size=9)
+        cache[1] = 'abc'
+        cache[2] = 'def'
+        cache[3] = 'ghi'
+        cache[4] = 'jkl'
+        # Triggers a cleanup
+        self.assertEqual([2, 3, 4], sorted(cache.keys()))
+        cache.resize(max_size=15, after_cleanup_size=12)
+        self.assertEqual([2, 3, 4], sorted(cache.keys()))
+        cache[5] = 'mno'
+        cache[6] = 'pqr'
+        self.assertEqual([2, 3, 4, 5, 6], sorted(cache.keys()))
+        cache[7] = 'stu'
+        self.assertEqual([4, 5, 6, 7], sorted(cache.keys()))
+
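The tests above double as API documentation for the two caches; a condensed sketch of the behaviour they pin down:

from dulwich.lru_cache import LRUCache, LRUSizeCache

cache = LRUCache(max_cache=2, after_cleanup_count=2)
cache['a'] = 1
cache['b'] = 2
cache['a']                  # touch 'a' so 'b' becomes least recently used
cache['c'] = 3              # over max_cache: 'b' is evicted
assert 'b' not in cache and 'a' in cache

sized = LRUSizeCache(max_size=20, after_cleanup_size=10)
sized.add('key1', 'value')     # sizes default to len(); 5 bytes
sized.add('key2', 'value2')    # 6 bytes
sized.add('key3', 'value23')   # 7 bytes; total 18, still under max_size
sized.add('key4', 'value234')  # pushes the total over 20: evict in LRU order
assert sized.items() == {'key4': 'value234'}  # shrunk to <= after_cleanup_size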
deleted file mode 100644
--- a/unit-tests/topo-test.py
+++ /dev/null
@@ -1,35 +0,0 @@
-import random, sys
-import unittest
-
-sys.path.append('../')
-
-import toposort
-
-class Ob:
-    def __init__(self, eyedee, parents):
-        self._id = eyedee
-        self.parents = parents
-
-    def id(self):
-        return self._id
-
-
-class TestTopoSorting(unittest.TestCase):
-
-    def testsort(self):
-        data = {
-            'f' : Ob('f', ['d', 'e']),
-            'd' : Ob('d', ['b']),
-            'e' : Ob('e', ['c', 'g']),
-            'g' : Ob('g', ['c']),
-            'c' : Ob('c', ['b', 'h']),
-            'b' : Ob('b', ['a']),
-            'h' : Ob('h', ['a']),
-            'a' : Ob('a', []),
-        }
-        d = toposort.TopoSort(data).items()
-        sort = ['a', 'b', 'd', 'h', 'c', 'g', 'e', 'f']
-        self.assertEquals(d, sort)
-
-if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file