diff --git a/modules/bibdocfile/lib/bibdocfile.py b/modules/bibdocfile/lib/bibdocfile.py index 133870fe1a..4bd5ed367a 100644 --- a/modules/bibdocfile/lib/bibdocfile.py +++ b/modules/bibdocfile/lib/bibdocfile.py @@ -336,7 +336,8 @@ def _sql_generate_conjunctive_where(to_process): q_str.append(_val_or_null(entry[0], eq_name = entry[1], q_args = q_args)) return (" AND ".join(q_str), q_args) -def file_strip_ext(afile, skip_version=False, only_known_extensions=False, allow_subformat=True): +def file_strip_ext(afile, skip_version=False, + only_known_extensions=False, allow_subformat=True): """ Strip in the best way the extension from a filename. @@ -348,12 +349,14 @@ def file_strip_ext(afile, skip_version=False, only_known_extensions=False, allow 'foo' >>> file_strip_ext("foo.buz", only_known_extensions=True) 'foo.buz' - >>> file_strip_ext("foo.buz;1", skip_version=False, - ... only_known_extensions=True) + >>> file_strip_ext("foo.buz;1", + skip_version=False, + allow_subformat=False, + only_known_extensions=True) 'foo.buz;1' >>> file_strip_ext("foo.gif;icon") 'foo' - >>> file_strip_ext("foo.gif:icon", allow_subformat=False) + >>> file_strip_ext("foo.gif:icon", only_known_extensions=True) 'foo.gif:icon' @param afile: the path/name of a file. @@ -372,11 +375,17 @@ def file_strip_ext(afile, skip_version=False, only_known_extensions=False, allow if skip_version or allow_subformat: afile = afile.split(';')[0] nextfile = _extensions.sub('', afile) + extension = afile[len(nextfile) + 1:] if nextfile == afile and not only_known_extensions: nextfile = os.path.splitext(afile)[0] while nextfile != afile: afile = nextfile - nextfile = _extensions.sub('', afile) + tmp_nextfile = _extensions.sub('', afile) + new_extension = afile[len(tmp_nextfile) + 1:] + if new_extension != extension: + nextfile = tmp_nextfile + else: + break return nextfile def normalize_format(docformat, allow_subformat=True): diff --git a/modules/bibdocfile/lib/bibdocfile_unit_tests.py b/modules/bibdocfile/lib/bibdocfile_unit_tests.py new file mode 100644 index 0000000000..82234af477 --- /dev/null +++ b/modules/bibdocfile/lib/bibdocfile_unit_tests.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- +## +## This file is part of Invenio. +## Copyright (C) 2014 CERN. +## +## Invenio is free software; you can redistribute it and/or +## modify it under the terms of the GNU General Public License as +## published by the Free Software Foundation; either version 2 of the +## License, or (at your option) any later version. +## +## Invenio is distributed in the hope that it will be useful, but +## WITHOUT ANY WARRANTY; without even the implied warranty of +## MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +## General Public License for more details. +## +## You should have received a copy of the GNU General Public License +## along with Invenio; if not, write to the Free Software Foundation, Inc., +## 59 Temple Place, Suite 330, Boston, MA 02111-1307, USA. + +"""BibDocFile Unit Test Suite.""" + +import unittest +from invenio.testutils import make_test_suite, run_test_suite +from invenio.bibdocfile import file_strip_ext + + +class BibDocFileTest(unittest.TestCase): + """Unit tests about""" + + def test_strip_ext(self): + """bibdocfile - test file_strip_ext """ + self.assertEqual(file_strip_ext("foo.tar.gz"), 'foo') + self.assertEqual(file_strip_ext("foo.buz.gz"), 'foo.buz') + self.assertEqual(file_strip_ext("foo.buz"), 'foo') + self.assertEqual(file_strip_ext("foo.buz", + only_known_extensions=True), 'foo.buz') + self.assertEqual(file_strip_ext("foo.buz;1", + skip_version=False, + only_known_extensions=True, + allow_subformat=False), 'foo.buz;1') + self.assertEqual(file_strip_ext("foo.gif;icon"), 'foo') + self.assertEqual(file_strip_ext("foo.gif:icon", + only_known_extensions=True), 'foo.gif:icon') + self.assertEqual(file_strip_ext("foo.pdf.pdf", + only_known_extensions=True), 'foo.pdf') + + +TEST_SUITE = make_test_suite(BibDocFileTest) +if __name__ == "__main__": + run_test_suite(TEST_SUITE, warn_user=False) diff --git a/modules/bibupload/lib/bibupload.py b/modules/bibupload/lib/bibupload.py index a9d368fe0e..47a55e1ba9 100644 --- a/modules/bibupload/lib/bibupload.py +++ b/modules/bibupload/lib/bibupload.py @@ -95,13 +95,12 @@ task_update_progress, task_sleep_now_if_required, fix_argv_paths, \ RecoverableError from invenio.bibdocfile import BibRecDocs, file_strip_ext, normalize_format, \ - get_docname_from_url, check_valid_url, download_url, \ + get_docname_from_url, check_valid_url, \ KEEP_OLD_VALUE, decompose_bibdocfile_url, InvenioBibDocFileError, \ bibdocfile_url_p, CFG_BIBDOCFILE_AVAILABLE_FLAGS, guess_format_from_url, \ BibRelation, MoreInfo, guess_via_magic - -from invenio.search_engine import search_pattern - +from invenio.filedownloadutils import (download_url, + InvenioFileDownloadFormatError) from invenio.bibupload_revisionverifier import RevisionVerifier, \ InvenioBibUploadConflictingRevisionsError, \ InvenioBibUploadInvalidRevisionError, \ @@ -2010,6 +2009,9 @@ def _process_document_moreinfos(more_infos, docname, version, docformat, mode): if guessed_format != docformat: raise RuntimeError("Given URL %s was supposed to refer to format %s but was found to be of format %s. Is this document behind an authentication page?" % (url, docformat, guessed_format)) write_message("%s saved into %s" % (url, downloaded_url), verbose=9) + except InvenioFileDownloadFormatError, err: + write_message("WARNING: format detection problem when downloading '%s' because of: %s" % (url, err), stream=sys.stderr) + raise except Exception, err: write_message("ERROR: in downloading '%s' because of: %s" % (url, err), stream=sys.stderr) raise diff --git a/modules/miscutil/lib/filedownloadutils.py b/modules/miscutil/lib/filedownloadutils.py index 5b697bcc56..0c69de686a 100644 --- a/modules/miscutil/lib/filedownloadutils.py +++ b/modules/miscutil/lib/filedownloadutils.py @@ -34,6 +34,7 @@ import tempfile import shutil import sys +from mimetypes import guess_all_extensions from invenio.urlutils import make_invenio_opener @@ -59,6 +60,11 @@ class InvenioFileCopyError(Exception): pass +class InvenioFileDownloadFormatError(Exception): + """A problem with format detection occurred.""" + pass + + def download_url(url, content_type=None, download_to_file=None, retry_count=10, timeout=10.0): """ @@ -80,8 +86,8 @@ def download_url(url, content_type=None, download_to_file=None, @param url: where the file lives on the interwebs @type url: string - @param content_type: desired content_type to check for in external URLs. - (optional) + @param content_type: desired MIME content_type or extension to check for + in external URLs. (optional) @type content_type: string @param download_to_file: where the file should live after download. @@ -97,7 +103,8 @@ def download_url(url, content_type=None, download_to_file=None, @type timeout: float @return: the path of the downloaded/copied file - @raise InvenioFileDownloadError: raised upon URL/HTTP errors, file errors or wrong format + @raise InvenioFileDownloadError: raised upon URL/HTTP errors, file errors + or wrong format """ if not download_to_file: download_to_file = safe_mkstemp(suffix=".tmp", @@ -133,7 +140,8 @@ def download_external_url(url, download_to_file, content_type=None, @param download_to_file: the path to download the file to @type download_to_file: string - @param content_type: the content_type of the file (optional) + @param content_type: desired MIME content_type or extension to check for + in external URLs. (optional) @type content_type: string @param retry_count: max number of retries for downloading the file @@ -146,6 +154,20 @@ def download_external_url(url, download_to_file, content_type=None, @rtype: string @raise StandardError: if the download failed """ + if content_type and "/" in content_type: + # Probably a MIME type passed + # Try to map it to a list of extensions. + extensions = guess_all_extensions(content_type) + if extensions is None: + msg = 'The content type to check is invalid "%s"' \ + % (content_type,) + raise InvenioFileDownloadFormatError(msg) + elif content_type: + # We assume extension is passed + if content_type[0] != '.': + content_type = '.' + content_type + extensions = [content_type] + error_str = "" error_code = None retry_attempt = 0 @@ -205,7 +227,7 @@ def download_external_url(url, download_to_file, content_type=None, else: # When we get here, it means that the download was a success. try: - finalize_download(url, download_to_file, content_type, request) + finalize_download(url, download_to_file, extensions, request) finally: request.close() return download_to_file @@ -215,15 +237,25 @@ def download_external_url(url, download_to_file, content_type=None, raise InvenioFileDownloadError(msg, code=error_code) -def finalize_download(url, download_to_file, content_type, request): +def finalize_download(url, download_to_file, extensions, request): """ Finalizes the download operation by doing various checks, such as format type, size check etc. """ # If format is given, a format check is performed. - if content_type and content_type not in request.headers['content-type']: - msg = 'The downloaded file is not of the desired format' - raise InvenioFileDownloadError(msg) + if extensions: + downloaded_extensions = guess_all_extensions( + request.headers['content-type'] + ) + + # Ease comparison by making things lowercase. + extensions = map(str.lower, extensions) + downloaded_extensions = map(str.lower, downloaded_extensions) + if not set(extensions) & set(downloaded_extensions): + msg = 'The downloaded file format "%s" is probably' \ + ' not of the desired formats: %r' \ + % (request.headers['content-type'], extensions) + raise InvenioFileDownloadError(msg) # Save the downloaded file to desired or generated location. to_file = open(download_to_file, 'w')