From 017688d836e942e06b1dfabf1ece5dc9c13b986e Mon Sep 17 00:00:00 2001 From: Fabian Knittel Date: Thu, 26 Apr 2012 20:33:21 +0200 Subject: [PATCH 1/6] minor clean-up: remove duplicate dyn_file_func setting `self.dyn_file_func` is currently set twice: Once in the base class and once in the server child class. As it's only used in the non-server case, remove it from the base class. --- tftpy/TftpContexts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index c3a1bd4..78bf34e 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -67,7 +67,7 @@ def add_dup(self, pkt): class TftpContext(object): """The base class of the contexts.""" - def __init__(self, host, port, timeout, dyn_file_func=None): + def __init__(self, host, port, timeout): """Constructor for the base context, setting shared instance variables.""" self.file_to_transfer = None @@ -94,7 +94,6 @@ def __init__(self, host, port, timeout, dyn_file_func=None): self.last_update = 0 # The last packet we sent, if applicable, to make resending easy. self.last_pkt = None - self.dyn_file_func = dyn_file_func # Count the number of retry attempts. self.retry_count = 0 @@ -199,11 +198,11 @@ def __init__(self, host, port, timeout, root, dyn_file_func=None): host, port, timeout, - dyn_file_func ) # At this point we have no idea if this is a download or an upload. We # need to let the start state determine that. self.state = TftpStateServerStart(self) + self.root = root self.dyn_file_func = dyn_file_func From a973612bd8a4962855f06aa58318c492c83abcf5 Mon Sep 17 00:00:00 2001 From: Fabian Knittel Date: Thu, 26 Apr 2012 20:45:22 +0200 Subject: [PATCH 2/6] refactor TftpState: move server-specific stuff to TftpServerState --- tftpy/TftpStates.py | 134 +++++++++++++++++++++++--------------------- 1 file changed, 71 insertions(+), 63 deletions(-) diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index c9f20b5..3098e45 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -24,9 +24,6 @@ def __init__(self, context): file object is required, since in tftp there's always a file involved.""" self.context = context - # This variable is used to store the absolute path to the file being - # managed. Currently only used by the server. - self.full_path = None def handle(self, pkt, raddress, rport): """An abstract method for handling a packet. It is expected to return @@ -76,64 +73,6 @@ def returnSupportedOptions(self, options): log.debug("Returning these accepted options: %s" % accepted_options) return accepted_options - def serverInitial(self, pkt, raddress, rport): - """This method performs initial setup for a server context transfer, - put here to refactor code out of the TftpStateServerRecvRRQ and - TftpStateServerRecvWRQ classes, since their initial setup is - identical. The method returns a boolean, sendoack, to indicate whether - it is required to send an OACK to the client.""" - options = pkt.options - sendoack = False - if not self.context.tidport: - self.context.tidport = rport - log.info("Setting tidport to %s" % rport) - - log.debug("Setting default options, blksize") - self.context.options = { 'blksize': DEF_BLKSIZE } - - if options: - log.debug("Options requested: %s" % options) - supported_options = self.returnSupportedOptions(options) - self.context.options.update(supported_options) - sendoack = True - - # FIXME - only octet mode is supported at this time. - if pkt.mode != 'octet': - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, \ - "Only octet transfers are supported at this time." - - # test host/port of client end - if self.context.host != raddress or self.context.port != rport: - self.sendError(TftpErrors.UnknownTID) - log.error("Expected traffic from %s:%s but received it " - "from %s:%s instead." - % (self.context.host, - self.context.port, - raddress, - rport)) - # FIXME: increment an error count? - # Return same state, we're still waiting for valid traffic. - return self - - log.debug("Requested filename is %s" % pkt.filename) - - # Make sure that the path to the file is contained in the server's - # root directory. - full_path = os.path.join(self.context.root, pkt.filename) - self.full_path = os.path.abspath(full_path) - log.debug("full_path is %s" % full_path) - if self.context.root == full_path[:len(self.context.root)]: - log.info("requested file is in the server root - good") - else: - log.warn("requested file is not within the server root - bad") - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "bad file path" - - self.context.file_to_transfer = pkt.filename - - return sendoack - def sendDAT(self): """This method sends the next DAT packet based on the data in the context. It returns a boolean indicating whether the transfer is @@ -261,7 +200,76 @@ def handleDat(self, pkt): # Default is to ack return TftpStateExpectDAT(self.context) -class TftpStateServerRecvRRQ(TftpState): +class TftpServerState(TftpState): + """The base class for server states.""" + + def __init__(self, context): + TftpState.__init__(self, context) + + # This variable is used to store the absolute path to the file being + # managed. + self.full_path = None + + def serverInitial(self, pkt, raddress, rport): + """This method performs initial setup for a server context transfer, + put here to refactor code out of the TftpStateServerRecvRRQ and + TftpStateServerRecvWRQ classes, since their initial setup is + identical. The method returns a boolean, sendoack, to indicate whether + it is required to send an OACK to the client.""" + options = pkt.options + sendoack = False + if not self.context.tidport: + self.context.tidport = rport + log.info("Setting tidport to %s" % rport) + + log.debug("Setting default options, blksize") + self.context.options = { 'blksize': DEF_BLKSIZE } + + if options: + log.debug("Options requested: %s" % options) + supported_options = self.returnSupportedOptions(options) + self.context.options.update(supported_options) + sendoack = True + + # FIXME - only octet mode is supported at this time. + if pkt.mode != 'octet': + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, \ + "Only octet transfers are supported at this time." + + # test host/port of client end + if self.context.host != raddress or self.context.port != rport: + self.sendError(TftpErrors.UnknownTID) + log.error("Expected traffic from %s:%s but received it " + "from %s:%s instead." + % (self.context.host, + self.context.port, + raddress, + rport)) + # FIXME: increment an error count? + # Return same state, we're still waiting for valid traffic. + return self + + log.debug("Requested filename is %s" % pkt.filename) + + # Make sure that the path to the file is contained in the server's + # root directory. + full_path = os.path.join(self.context.root, pkt.filename) + self.full_path = os.path.abspath(full_path) + log.debug("full_path is %s" % full_path) + if self.context.root == full_path[:len(self.context.root)]: + log.info("requested file is in the server root - good") + else: + log.warn("requested file is not within the server root - bad") + self.sendError(TftpErrors.IllegalTftpOp) + raise TftpException, "bad file path" + + self.context.file_to_transfer = pkt.filename + + return sendoack + + +class TftpStateServerRecvRRQ(TftpServerState): """This class represents the state of the TFTP server when it has just received an RRQ packet.""" def handle(self, pkt, raddress, rport): @@ -306,7 +314,7 @@ def handle(self, pkt, raddress, rport): # Note, we don't have to check any other states in this method, that's # up to the caller. -class TftpStateServerRecvWRQ(TftpState): +class TftpStateServerRecvWRQ(TftpServerState): """This class represents the state of the TFTP server when it has just received a WRQ packet.""" def make_subdirs(self): From c52ea929bb97dde7a70f2b2bae3de170311a59f6 Mon Sep 17 00:00:00 2001 From: Fabian Knittel Date: Thu, 26 Apr 2012 21:16:03 +0200 Subject: [PATCH 3/6] t/test.py: add unit test for insecure path access The test currently fails, because the request path is improperly checked / sanitised. --- t/test.py | 38 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/t/test.py b/t/test.py index 066d90a..e8d0c78 100644 --- a/t/test.py +++ b/t/test.py @@ -286,5 +286,43 @@ def testServerNoOptionsSubdir(self): finalstate = serverstate.state.handle(ack, raddress, rport) self.assertTrue( finalstate is None ) + def testServerInsecurePath(self): + raddress = '127.0.0.2' + rport = 10000 + timeout = 5 + root = os.path.dirname(os.path.abspath(__file__)) + serverstate = tftpy.TftpContextServer(raddress, + rport, + timeout, + root) + rrq = tftpy.TftpPacketRRQ() + rrq.filename = '../setup.py' + rrq.mode = 'octet' + rrq.options = {} + + # Start the download. + self.assertRaisesRegexp(tftpy.TftpException, "bad file path", + serverstate.start, rrq.encode().buffer) + + def testServerSecurePath(self): + raddress = '127.0.0.2' + rport = 10000 + timeout = 5 + root = os.path.dirname(os.path.abspath(__file__)) + serverstate = tftpy.TftpContextServer(raddress, + rport, + timeout, + root) + rrq = tftpy.TftpPacketRRQ() + rrq.filename = '100KBFILE' + rrq.mode = 'octet' + rrq.options = {} + + # Start the download. + serverstate.start(rrq.encode().buffer) + # Should be in expectack state. + self.assertTrue(isinstance(serverstate.state, + tftpy.TftpStateExpectACK)) + if __name__ == '__main__': unittest.main() From 83507c60ff28ab0bd7e87217bdd53c8b56ca4715 Mon Sep 17 00:00:00 2001 From: Fabian Knittel Date: Thu, 26 Apr 2012 21:18:31 +0200 Subject: [PATCH 4/6] tftpy/TftpStates.py: fix security problem regarding path check This patch fixes the request path check. It makes sure that requested paths are _below_ the specified root directory. --- tftpy/TftpStates.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 3098e45..8e4ad1e 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -257,7 +257,7 @@ def serverInitial(self, pkt, raddress, rport): full_path = os.path.join(self.context.root, pkt.filename) self.full_path = os.path.abspath(full_path) log.debug("full_path is %s" % full_path) - if self.context.root == full_path[:len(self.context.root)]: + if self.full_path.startswith(self.context.root): log.info("requested file is in the server root - good") else: log.warn("requested file is not within the server root - bad") From c3089fb212fd66432cf7866bc491792022d06a6d Mon Sep 17 00:00:00 2001 From: Fabian Knittel Date: Fri, 27 Apr 2012 17:45:53 +0200 Subject: [PATCH 5/6] Introduce server-side virtual filesystem layer This patch modifies the TFTP server code to depend on a virtual file system abstraction for all file-system access. The previous file system access code is moved to the TftpVfsNative class. The patch also provides a TftpVfsStack class, which allows multiple VFS classes to be combined into one name space. Includes unit tests. --- COPYING | 1 + t/test.py | 173 +++++++++++++++++++++++++++++++++++-- tftpy/TftpContexts.py | 19 ++++- tftpy/TftpServer.py | 47 ++++------ tftpy/TftpStates.py | 84 +++++------------- tftpy/TftpVfs.py | 193 ++++++++++++++++++++++++++++++++++++++++++ tftpy/__init__.py | 1 + 7 files changed, 412 insertions(+), 106 deletions(-) create mode 100644 tftpy/TftpVfs.py diff --git a/COPYING b/COPYING index c9f2c9c..b81aa90 100644 --- a/COPYING +++ b/COPYING @@ -1,6 +1,7 @@ The MIT License Copyright (c) 2009 Michael P. Soulier +Copyright (c) 2012 Fabian Knittel Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/t/test.py b/t/test.py index e8d0c78..2d351e1 100644 --- a/t/test.py +++ b/t/test.py @@ -5,6 +5,8 @@ import tftpy import os import time +import tempfile +import shutil log = tftpy.log @@ -216,12 +218,12 @@ def testServerNoOptions(self): raddress = '127.0.0.2' rport = 10000 timeout = 5 - root = os.path.dirname(os.path.abspath(__file__)) + vfs = tftpy.TftpVfsCompat(os.path.dirname(os.path.abspath(__file__))) # Testing without the dyn_func_file set. serverstate = tftpy.TftpContextServer(raddress, rport, timeout, - root) + vfs) self.assertTrue( isinstance(serverstate, tftpy.TftpContextServer) ) @@ -253,12 +255,12 @@ def testServerNoOptionsSubdir(self): raddress = '127.0.0.2' rport = 10000 timeout = 5 - root = os.path.dirname(os.path.abspath(__file__)) + vfs = tftpy.TftpVfsCompat(os.path.dirname(os.path.abspath(__file__))) # Testing without the dyn_func_file set. serverstate = tftpy.TftpContextServer(raddress, rport, timeout, - root) + vfs) self.assertTrue( isinstance(serverstate, tftpy.TftpContextServer) ) @@ -290,11 +292,11 @@ def testServerInsecurePath(self): raddress = '127.0.0.2' rport = 10000 timeout = 5 - root = os.path.dirname(os.path.abspath(__file__)) + vfs = tftpy.TftpVfsCompat(os.path.dirname(os.path.abspath(__file__))) serverstate = tftpy.TftpContextServer(raddress, rport, timeout, - root) + vfs) rrq = tftpy.TftpPacketRRQ() rrq.filename = '../setup.py' rrq.mode = 'octet' @@ -308,11 +310,11 @@ def testServerSecurePath(self): raddress = '127.0.0.2' rport = 10000 timeout = 5 - root = os.path.dirname(os.path.abspath(__file__)) + vfs = tftpy.TftpVfsCompat(os.path.dirname(os.path.abspath(__file__))) serverstate = tftpy.TftpContextServer(raddress, rport, timeout, - root) + vfs) rrq = tftpy.TftpPacketRRQ() rrq.filename = '100KBFILE' rrq.mode = 'octet' @@ -324,5 +326,160 @@ def testServerSecurePath(self): self.assertTrue(isinstance(serverstate.state, tftpy.TftpStateExpectACK)) +class TestTftpyVfsReadOnlyDynFileFunc(unittest.TestCase): + def testRead(self): + state = {'called':False} + the_path = 'a path' + def dyn_func(path): + state['called'] = True + return path + vfs = tftpy.TftpVfsReadOnlyDynFileFunc(dyn_func) + ret = vfs.open_read(the_path) + self.assertEqual(the_path, ret) + self.assert_(state['called']) + + def testWrite(self): + state = {'called':False} + the_path = 'a-path' + def dyn_func(path): + state['called'] = True + return path + vfs = tftpy.TftpVfsReadOnlyDynFileFunc(dyn_func) + ret = vfs.open_write(the_path) + self.assertEqual(None, ret) + self.assert_(not state['called']) + +class TestTftpyVfsNative(unittest.TestCase): + def setUp(self): + self.write_root = tempfile.mkdtemp() + def tearDown(self): + shutil.rmtree(self.write_root, ignore_errors=True) + + def testReadExisting(self): + # Copy file to the temporary tftp root + root = os.path.dirname(os.path.abspath(__file__)) + the_path = '100KBFILE' + shutil.copy(os.path.join(root, the_path), self.write_root) + + vfs = tftpy.TftpVfsNative(self.write_root) + fp = vfs.open_read(the_path) + self.assert_(fp is not None) + try: + orig_fp = open(os.path.join(self.write_root, the_path), 'rb') + try: + self.assertEqual(orig_fp.read(), fp.read()) + finally: + orig_fp.close() + finally: + fp.close() + + def testReadNonExistent(self): + # The temporary tftp root is empty + the_path = '100KBFILE' + + vfs = tftpy.TftpVfsNative(self.write_root) + fp = vfs.open_read(the_path) + self.assert_(fp is None) + + def testNonExistentRoot(self): + non_existent_root = os.path.join(self.write_root, 'non-existent') + self.assertRaisesRegexp(tftpy.TftpException, 'tftproot does not exist', + tftpy.TftpVfsNative, non_existent_root) + + def testWriteSubdir(self): + """Write a test string and read it back.""" + the_dir = 'a-path' + the_fn = os.path.join(the_dir, 'a-file') + data = 'test string' + vfs = tftpy.TftpVfsNative(self.write_root) + fp = vfs.open_write(the_fn) + self.assert_(fp is not None) + fp.write(data) + fp.close() + self.assert_(os.path.exists(os.path.join(self.write_root, the_dir))) + self.assert_(os.path.isdir(os.path.join(self.write_root, the_dir))) + self.assert_(os.path.exists(os.path.join(self.write_root, the_fn))) + + def testUnsafeRead(self): + the_path = os.path.join(os.path.pardir, '100KBFILE') + vfs = tftpy.TftpVfsNative(self.write_root) + self.assertRaisesRegexp(tftpy.TftpException, "bad file path", + vfs.open_read, the_path) + +class TestTftpyVfsStack(unittest.TestCase): + def setUp(self): + self.vfs = tftpy.TftpVfsStack() + + def testReadEmpty(self): + self.assert_(self.vfs.open_read('path') is None) + + def testWriteEmpty(self): + self.assert_(self.vfs.open_write('path') is None) + + class MockVfsAccept(object): + def __init__(self): + self.read_fp = object() + self.write_fp = object() + self.path = None + def open_read(self, path): + self.path = path + return self.read_fp + def open_write(self, path): + self.path = path + return self.write_fp + + class MockVfsReject(object): + def __init__(self): + self.path = None + def open_read(self, path): + self.path = path + return None + def open_write(self, path): + self.path = path + return None + + def testReadRoot(self): + fs1 = self.MockVfsAccept() + self.vfs.mount(fs1, '/') + ret = self.vfs.open_read('path') + self.assert_(ret is fs1.read_fp) + self.assertEqual('/path', fs1.path) + + def testWriteRoot(self): + fs1 = self.MockVfsAccept() + self.vfs.mount(fs1, '/') + ret = self.vfs.open_write('path') + self.assert_(ret is fs1.write_fp) + self.assertEqual('/path', fs1.path) + + def testFirstRoot(self): + """Return first valid match""" + fs1 = self.MockVfsReject() + fs2 = self.MockVfsAccept() + fs3 = self.MockVfsAccept() + self.vfs.mount(fs1, '/') + self.vfs.mount(fs2, '/') + self.vfs.mount(fs3, '/') + ret = self.vfs.open_read('path') + self.assert_(ret is fs2.read_fp) + self.assert_(ret is not fs3.read_fp) + self.assertEqual('/path', fs1.path) + self.assertEqual('/path', fs2.path) + self.assertEqual(None, fs3.path) + + def testIterateSubPaths(self): + """Visit all providers that have a matching base path.""" + fs1 = self.MockVfsReject() + fs2 = self.MockVfsReject() + fs3 = self.MockVfsReject() + self.vfs.mount(fs1, '/base') + self.vfs.mount(fs2, '/base/somewhere') + self.vfs.mount(fs3, '/not-relevant') + ret = self.vfs.open_read('/base/somewhere/path') + self.assert_(ret is None) + self.assertEqual('/somewhere/path', fs1.path) + self.assertEqual('/path', fs2.path) + self.assertEqual(None, fs3.path) + if __name__ == '__main__': unittest.main() diff --git a/tftpy/TftpContexts.py b/tftpy/TftpContexts.py index 78bf34e..7231439 100644 --- a/tftpy/TftpContexts.py +++ b/tftpy/TftpContexts.py @@ -193,19 +193,18 @@ def cycle(self): class TftpContextServer(TftpContext): """The context for the server.""" - def __init__(self, host, port, timeout, root, dyn_file_func=None): + def __init__(self, host, port, timeout, vfs): TftpContext.__init__(self, host, port, timeout, ) + self._vfs = vfs + # At this point we have no idea if this is a download or an upload. We # need to let the start state determine that. self.state = TftpStateServerStart(self) - self.root = root - self.dyn_file_func = dyn_file_func - def __str__(self): return "%s:%s %s" % (self.host, self.port, self.state) @@ -236,6 +235,18 @@ def end(self): log.debug("Set metrics.end_time to %s" % self.metrics.end_time) self.metrics.compute() + def open_read(self, path): + """Return a file-like object for the virtual path `path` or ``None`` if + the path does not exist. Throws :class:`TftpException` in case the path + is invalid.""" + return self._vfs.open_read(path) + + def open_write(self, path): + """Return a file-like object for the virtual path `path` or ``None`` if + the path is not writable. Throws :class:`TftpException` in case the + path is invalid.""" + return self._vfs.open_write(path) + class TftpContextClientUpload(TftpContext): """The upload context for the client during an upload. Note: If input is a hyphen, then we will use stdin.""" diff --git a/tftpy/TftpServer.py b/tftpy/TftpServer.py index 364227c..8027bd7 100644 --- a/tftpy/TftpServer.py +++ b/tftpy/TftpServer.py @@ -9,43 +9,22 @@ from TftpPacketTypes import * from TftpPacketFactory import TftpPacketFactory from TftpContexts import TftpContextServer +from TftpVfs import TftpVfsCompat -class TftpServer(TftpSession): +class TftpServerVfs(TftpSession): """This class implements a tftp server object. Run the listen() method to - listen for client requests. It takes two optional arguments. tftproot is - the path to the tftproot directory to serve files from and/or write them - to. dyn_file_func is a callable that must return a file-like object to - read from during downloads. This permits the serving of dynamic - content.""" + listen for client requests. It expects the `vfs` object to provide access + to the filesystem.""" - def __init__(self, tftproot='/tftpboot', dyn_file_func=None): + def __init__(self, vfs): self.listenip = None self.listenport = None self.sock = None - # FIXME: What about multiple roots? - self.root = os.path.abspath(tftproot) - self.dyn_file_func = dyn_file_func + self.vfs = vfs # A dict of sessions, where each session is keyed by a string like # ip:tid for the remote end. self.sessions = {} - if os.path.exists(self.root): - log.debug("tftproot %s does exist" % self.root) - if not os.path.isdir(self.root): - raise TftpException, "The tftproot must be a directory." - else: - log.debug("tftproot %s is a directory" % self.root) - if os.access(self.root, os.R_OK): - log.debug("tftproot %s is readable" % self.root) - else: - raise TftpException, "The tftproot must be readable" - if os.access(self.root, os.W_OK): - log.debug("tftproot %s is writable" % self.root) - else: - log.warning("The tftproot %s is not writable" % self.root) - else: - raise TftpException, "The tftproot does not exist." - def listen(self, listenip="", listenport=DEF_TFTP_PORT, @@ -104,8 +83,7 @@ def listen(self, self.sessions[key] = TftpContextServer(raddress, rport, timeout, - self.root, - self.dyn_file_func) + self.vfs) try: self.sessions[key].start(buffer) except TftpException, err: @@ -182,3 +160,14 @@ def listen(self, else: log.warn("Strange, session %s is not on the deletion list" % key) + +class TftpServer(TftpServerVfs): + """This class implements a tftp server object. Run the listen() method to + listen for client requests. It takes two optional arguments. tftproot is + the path to the tftproot directory to serve files from and/or write them + to. dyn_file_func is a callable that must return a file-like object to + read from during downloads. This permits the serving of dynamic + content.""" + + def __init__(self, tftproot='/tftpboot', dyn_file_func=None): + TftpServerVfs.__init__(self, TftpVfsCompat(tftproot, dyn_file_func)) diff --git a/tftpy/TftpStates.py b/tftpy/TftpStates.py index 8e4ad1e..e2a9ac6 100644 --- a/tftpy/TftpStates.py +++ b/tftpy/TftpStates.py @@ -203,13 +203,6 @@ def handleDat(self, pkt): class TftpServerState(TftpState): """The base class for server states.""" - def __init__(self, context): - TftpState.__init__(self, context) - - # This variable is used to store the absolute path to the file being - # managed. - self.full_path = None - def serverInitial(self, pkt, raddress, rport): """This method performs initial setup for a server context transfer, put here to refactor code out of the TftpStateServerRecvRRQ and @@ -252,18 +245,6 @@ def serverInitial(self, pkt, raddress, rport): log.debug("Requested filename is %s" % pkt.filename) - # Make sure that the path to the file is contained in the server's - # root directory. - full_path = os.path.join(self.context.root, pkt.filename) - self.full_path = os.path.abspath(full_path) - log.debug("full_path is %s" % full_path) - if self.full_path.startswith(self.context.root): - log.info("requested file is in the server root - good") - else: - log.warn("requested file is not within the server root - bad") - self.sendError(TftpErrors.IllegalTftpOp) - raise TftpException, "bad file path" - self.context.file_to_transfer = pkt.filename return sendoack @@ -276,25 +257,16 @@ def handle(self, pkt, raddress, rport): "Handle an initial RRQ packet as a server." log.debug("In TftpStateServerRecvRRQ.handle") sendoack = self.serverInitial(pkt, raddress, rport) - path = self.full_path - log.info("Opening file %s for reading" % path) - if os.path.exists(path): - # Note: Open in binary mode for win32 portability, since win32 - # blows. - self.context.fileobj = open(path, "rb") - elif self.context.dyn_file_func: - log.debug("No such file %s but using dyn_file_func" % path) - self.context.fileobj = \ - self.context.dyn_file_func(self.context.file_to_transfer) - - if self.context.fileobj is None: - log.debug("dyn_file_func returned 'None', treating as " - "FileNotFound") - self.sendError(TftpErrors.FileNotFound) - raise TftpException, "File not found: %s" % path - else: + try: + self.context.fileobj = self.context.open_read( + self.context.file_to_transfer) + except TftpException: + self.sendError(TftpErrors.IllegalTftpOp) + raise + if self.context.fileobj is None: self.sendError(TftpErrors.FileNotFound) - raise TftpException, "File not found: %s" % path + raise TftpException, "File not found: %s" % ( + self.context.file_to_transfer) # Options negotiation. if sendoack: @@ -317,38 +289,20 @@ def handle(self, pkt, raddress, rport): class TftpStateServerRecvWRQ(TftpServerState): """This class represents the state of the TFTP server when it has just received a WRQ packet.""" - def make_subdirs(self): - """The purpose of this method is to, if necessary, create all of the - subdirectories leading up to the file to the written.""" - # Pull off everything below the root. - subpath = self.full_path[len(self.context.root):] - log.debug("make_subdirs: subpath is %s" % subpath) - # Split on directory separators, but drop the last one, as it should - # be the filename. - dirs = subpath.split(os.sep)[:-1] - log.debug("dirs is %s" % dirs) - current = self.context.root - for dir in dirs: - if dir: - current = os.path.join(current, dir) - if os.path.isdir(current): - log.debug("%s is already an existing directory" % current) - else: - os.mkdir(current, 0700) - def handle(self, pkt, raddress, rport): "Handle an initial WRQ packet as a server." log.debug("In TftpStateServerRecvWRQ.handle") sendoack = self.serverInitial(pkt, raddress, rport) - path = self.full_path - log.info("Opening file %s for writing" % path) - if os.path.exists(path): - # FIXME: correct behavior? - log.warn("File %s exists already, overwriting..." % self.context.file_to_transfer) - # FIXME: I think we should upload to a temp file and not overwrite the - # existing file until the file is successfully uploaded. - self.make_subdirs() - self.context.fileobj = open(path, "wb") + try: + self.context.fileobj = self.context.open_write( + self.context.file_to_transfer) + except TftpException: + self.sendError(TftpErrors.IllegalTftpOp) + raise + if self.context.fileobj is None: + self.sendError(TftpErrors.FileNotFound) + raise TftpException, "File not found: %s" % ( + self.context.file_to_transfer) # Options negotiation. if sendoack: diff --git a/tftpy/TftpVfs.py b/tftpy/TftpVfs.py new file mode 100644 index 0000000..385a094 --- /dev/null +++ b/tftpy/TftpVfs.py @@ -0,0 +1,193 @@ +"""This module implements virtual file system functionality. It contains +classes that provide or simulate file system access for TftpServer. + +Each VFS class is expected to provide two methods: + +`open_read` is called when a TFTP path should be read. If the path exists, a +file-like object should be returned, otherwise ``None``. + +`open_write` is called when a TFTP path should be written to. If the path can +be written to, a file-like object should be returned, otherwise ``None``. +""" + +import os +import os.path +from tftpy.TftpShared import TftpException +from tftpy.TftpShared import log + + +class TftpVfsStack(object): + """Allows delegation of VFS requests to other VFS providers, based on the + request path. If a request path matches more than one provider, the first + provider that was registered is tried first, then the second and so on. The + search is aborted as soon a provider succeeds.""" + + def __init__(self): + self._file_systems = [] + + def mount(self, vfs, base_path='/'): + """Registers the VFS provider `vfs` for requests prefixed with + `base_path`.""" + base_path = os.path.normpath(base_path) + # The base_path should end with a slash. + if not base_path.endswith('/'): + base_path += '/' + self._file_systems.append((vfs, base_path)) + + def _matching_vfs_paths(self, path): + """Generator that searches through the list of VFS providers. For each + base path that matches `path`, a tuple with the matching VFS provider + and the sub path (relative to the base path) is yielded.""" + path = os.path.normpath(path) + if not path.startswith('/'): + path = '/' + path + for vfs, base_path in self._file_systems: + if not path.startswith(base_path): + continue + sub_path = path[len(base_path) - 1:] + yield (vfs, sub_path) + + def open_read(self, path): + """Returns the file-like object from the first matching file system that + provides one. Returns ``None`` if no file was found. + """ + for vfs, sub_path in self._matching_vfs_paths(path): + log.debug('attempting path %s, sub path %s' % (path, sub_path)) + fp = vfs.open_read(sub_path) + if fp is not None: + return fp + return None + + def open_write(self, path): + """Returns the file-like object from the first matching file system that + provides one. Returns ``None`` if no file was found or writing isn't + supported. + """ + for vfs, sub_path in self._matching_vfs_paths(path): + log.debug('attempting path %s, sub path %s' % (path, sub_path)) + fp = vfs.open_write(sub_path) + if fp is not None: + return fp + return None + +class TftpVfsNative(object): + """Provides access to the operating system's file system. Access is + provided relative to the sub-tree at `root`. + + Allows read and write access. During write access, any missing + intermediate directories of the target path are automatically created. If + the target file already exists, it is replaced. + """ + def __init__(self, root): + self.root = os.path.abspath(root) + + if os.path.exists(self.root): + log.debug("tftproot %s does exist" % self.root) + if not os.path.isdir(self.root): + raise TftpException, "The tftproot must be a directory." + else: + log.debug("tftproot %s is a directory" % self.root) + if os.access(self.root, os.R_OK): + log.debug("tftproot %s is readable" % self.root) + else: + raise TftpException, "The tftproot must be readable" + if os.access(self.root, os.W_OK): + log.debug("tftproot %s is writable" % self.root) + else: + log.warning("The tftproot %s is not writable" % self.root) + else: + raise TftpException, "The tftproot does not exist." + + def _full_path(self, path): + """Translates the relative path `path` to the absolute path within the + native file system and returns that path. Also makes sure, that the + resulting path resides within the designated sub-tree. + """ + # Remove any starting slash + if path.startswith('/'): + path = path[1:] + # Make sure that the path to the file is contained in the server's + # root directory. + full_path = os.path.abspath(os.path.join(self.root, path)) + log.debug("full_path is %s" % full_path) + if full_path.startswith(self.root): + log.info("requested file is in the server root - good") + else: + log.warn("requested file is not within the server root - bad") + raise TftpException, "bad file path" + return full_path + + def open_read(self, path): + """If `path` exists relative to `root`, a file-like object for read + access is provided. Otherwise ``None`` is returned. + Throws :class:``TftpException`` if the requested path tries to exit the + sub-tree.""" + full_path = self._full_path(path) + if not os.path.exists(full_path): + return None + log.info("Opening file %s for reading" % full_path) + # Note: Open in binary mode for win32 portability, since win32 blows. + return open(full_path, "rb") + + def _make_subdirs(self, full_path): + """The purpose of this method is to, if necessary, create all of the + subdirectories leading up to the file to the written. + """ + # Pull off everything below the root. + subpath = full_path[len(self.root):] + log.debug("make_subdirs: subpath is %s" % subpath) + # Split on directory separators, but drop the last one, as it should + # be the filename. + dirs = subpath.split(os.sep)[:-1] + log.debug("dirs is %s" % dirs) + current = self.root + for directory in dirs: + if directory: + current = os.path.join(current, directory) + if os.path.isdir(current): + log.debug("%s is already an existing directory" % current) + else: + os.mkdir(current, 0700) + + def open_write(self, path): + """A file-like object for write access is provided. Any missing + directories are created automatically. Throws :class:``TftpException`` + if the requested path tries to exit the sub-tree. + """ + full_path = self._full_path(path) + log.info("Opening file %s for writing" % full_path) + if os.path.exists(full_path): + # FIXME: correct behavior? + log.warn("File %s exists already, overwriting..." % full_path) + # FIXME: I think we should upload to a temp file and not overwrite the + # existing file until the file is successfully uploaded. + self._make_subdirs(full_path) + return open(full_path, "wb") + +class TftpVfsReadOnlyDynFileFunc(object): + """Allows read access to potentially dynamic content provided by the + call-back `dyn_file_func`. + """ + def __init__(self, dyn_file_func): + self.dyn_file_func = dyn_file_func + + def open_read(self, path): + """Returns the file-like object provided by the call-back, based on + `path`.""" + return self.dyn_file_func(path) + + def open_write(self, path): + """Always returns ``None``, because the `dyn_file_func` does not support + write access.""" + return None + +class TftpVfsCompat(TftpVfsStack): + """Provides a layered VFS which combines a native file-system and, as a + fall-back in case the file isn't found in the native file-system, an + optional call-back function for potentially dynamic content. + """ + def __init__(self, tftproot='/tftpboot', dyn_file_func=None): + TftpVfsStack.__init__(self) + self.mount(TftpVfsNative(tftproot), '/') + if dyn_file_func is not None: + self.mount(TftpVfsReadOnlyDynFileFunc(dyn_file_func), '/') diff --git a/tftpy/__init__.py b/tftpy/__init__.py index e8ef87f..ad0bf44 100644 --- a/tftpy/__init__.py +++ b/tftpy/__init__.py @@ -22,3 +22,4 @@ from TftpServer import * from TftpContexts import * from TftpStates import * +from TftpVfs import * From e73099ee8394230249290b36c33cdf4de397a9e3 Mon Sep 17 00:00:00 2001 From: Fabian Knittel Date: Fri, 27 Apr 2012 17:50:26 +0200 Subject: [PATCH 6/6] bin/tftpy_server.py: support for multiple roots Allows the ``--root`` option to be specified more than once. Any write requests are sent to the first root. --- bin/tftpy_server.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/bin/tftpy_server.py b/bin/tftpy_server.py index 6a1d3a3..c3ebc1b 100755 --- a/bin/tftpy_server.py +++ b/bin/tftpy_server.py @@ -21,7 +21,8 @@ def main(): '--root', type='string', help='path to serve from', - default=None) + default=None, + action="append") parser.add_option('-d', '--debug', action='store_true', @@ -38,7 +39,11 @@ def main(): parser.print_help() sys.exit(1) - server = tftpy.TftpServer(options.root) + vfs = tftpy.TftpVfsStack() + for root in options.root: + vfs.mount(tftpy.TftpVfsNative(root), '/') + + server = tftpy.TftpServerVfs(vfs) try: server.listen(options.ip, options.port) except tftpy.TftpException, err: