diff --git a/cli/src/input.rs b/cli/src/input.rs index d9516248d1..1e6c0a31e7 100644 --- a/cli/src/input.rs +++ b/cli/src/input.rs @@ -32,6 +32,10 @@ impl Prepare for InputOptions { program.color_opt = global.color.into(); + if let Ok(nickel_path) = std::env::var("NICKEL_PATH") { + program.add_import_paths(nickel_path.split(':')); + } + #[cfg(debug_assertions)] if self.nostdlib { program.set_skip_stdlib(); diff --git a/core/src/cache.rs b/core/src/cache.rs index e6baac4f7e..6c7d98bc5c 100644 --- a/core/src/cache.rs +++ b/core/src/cache.rs @@ -90,6 +90,7 @@ pub struct Cache { wildcards: HashMap, /// Whether processing should try to continue even in case of errors. Needed by the NLS. error_tolerance: ErrorTolerance, + import_paths: Vec, #[cfg(debug_assertions)] /// Skip loading the stdlib, used for debugging purpose @@ -337,12 +338,20 @@ impl Cache { rev_imports: HashMap::new(), stdlib_ids: None, error_tolerance, + import_paths: Vec::new(), #[cfg(debug_assertions)] skip_stdlib: false, } } + pub fn add_import_paths

(&mut self, paths: impl Iterator) + where + PathBuf: From

, + { + self.import_paths.extend(paths.map(PathBuf::from)); + } + /// Same as [Self::add_file], but assume that the path is already normalized, and take the /// timestamp as a parameter. fn add_file_(&mut self, path: PathBuf, timestamp: SystemTime) -> io::Result { @@ -1314,16 +1323,38 @@ impl ImportResolver for Cache { parent: Option, pos: &TermPos, ) -> Result<(ResolvedTerm, FileId), ImportError> { - let parent_path = parent.and_then(|p| self.get_path(p)).map(PathBuf::from); - let path_buf = with_parent(path, parent_path); + // `parent` is the file that did the import. We first look in its containing directory. + let mut parent_path = parent + .and_then(|p| self.get_path(p)) + .map(PathBuf::from) + .unwrap_or_default(); + parent_path.pop(); + + let possible_parents: Vec = std::iter::once(parent_path) + .chain(self.import_paths.iter().cloned()) + .collect(); + + // Try to import from all possibilities, taking the first one that succeeds. + let (id_op, path_buf) = possible_parents + .iter() + .find_map(|parent| { + let mut path_buf = parent.clone(); + path_buf.push(path); + self.get_or_add_file(&path_buf).ok().map(|x| (x, path_buf)) + }) + .ok_or_else(|| { + let parents = possible_parents + .iter() + .map(|p| p.to_string_lossy()) + .collect::>(); + ImportError::IOError( + path.to_string_lossy().into_owned(), + format!("could not find import (looked in [{}])", parents.join(", ")), + *pos, + ) + })?; + let format = InputFormat::from_path(&path_buf).unwrap_or_default(); - let id_op = self.get_or_add_file(&path_buf).map_err(|err| { - ImportError::IOError( - path_buf.to_string_lossy().into_owned(), - format!("{err}"), - *pos, - ) - })?; let (result, file_id) = match id_op { CacheOp::Cached(id) => (ResolvedTerm::FromCache, id), CacheOp::Done(id) => (ResolvedTerm::FromFile { path: path_buf }, id), @@ -1356,14 +1387,6 @@ impl ImportResolver for Cache { } } -/// Compute the path of a file relatively to a parent. -fn with_parent(path: &OsStr, parent: Option) -> PathBuf { - let mut path_buf = parent.unwrap_or_default(); - path_buf.pop(); - path_buf.push(Path::new(path)); - path_buf -} - /// Normalize the path of a file for unique identification in the cache. /// /// The returned path will be an absolute path. diff --git a/core/src/program.rs b/core/src/program.rs index 195bbbe5be..8bdbb5042c 100644 --- a/core/src/program.rs +++ b/core/src/program.rs @@ -297,6 +297,13 @@ impl Program { self.overrides.extend(overrides); } + pub fn add_import_paths

(&mut self, paths: impl Iterator) + where + PathBuf: From

, + { + self.vm.import_resolver_mut().add_import_paths(paths); + } + /// Only parse the program, don't typecheck or evaluate. returns the [`RichTerm`] AST pub fn parse(&mut self) -> Result { self.vm diff --git a/core/tests/integration/imports/missing-nickel-path.ncl b/core/tests/integration/imports/missing-nickel-path.ncl new file mode 100644 index 0000000000..4f567a0fd3 --- /dev/null +++ b/core/tests/integration/imports/missing-nickel-path.ncl @@ -0,0 +1,5 @@ +# test.type = 'error' +# +# [test.metadata] +# error = 'ImportError::IoError' +2 == (import "two.ncl") diff --git a/core/tests/integration/imports/needs-nickel-path.ncl b/core/tests/integration/imports/needs-nickel-path.ncl new file mode 100644 index 0000000000..dcbc9750f8 --- /dev/null +++ b/core/tests/integration/imports/needs-nickel-path.ncl @@ -0,0 +1,3 @@ +# test.type = 'pass' +# nickel_path = ['tests/integration/imports/imported'] +2 == (import "two.ncl") diff --git a/core/tests/integration/main.rs b/core/tests/integration/main.rs index 7affb080aa..808dc060d3 100644 --- a/core/tests/integration/main.rs +++ b/core/tests/integration/main.rs @@ -63,12 +63,15 @@ fn run_test(test_case: TestCase, path: String) { let test = test_case.annotation.test; for _ in 0..repeat { - let p = TestProgram::new_from_source( + let mut p = TestProgram::new_from_source( Cursor::new(program.clone()), path.as_str(), std::io::stderr(), ) .expect(""); + if let Some(imports) = &test_case.annotation.nickel_path { + p.add_import_paths(imports.iter()); + } match test.clone() { Expectation::Error(expected_err) => { let err = eval_strategy.eval_program_to_err(p); @@ -92,6 +95,7 @@ struct Test { test: Expectation, repeat: Option, eval: Option, + nickel_path: Option>, } #[derive(Clone, Copy, Deserialize)] @@ -196,6 +200,8 @@ enum ErrorExpectation { ParseTypedFieldWithoutDefinition, #[serde(rename = "ImportError::ParseError")] ImportParseError, + #[serde(rename = "ImportError::IoError")] + ImportIoError, #[serde(rename = "ExportError::NumberOutOfRange")] SerializeNumberOutOfRange, } @@ -229,6 +235,7 @@ impl PartialEq for ErrorExpectation { Error::TypecheckError(TypecheckError::FlatTypeInTermPosition { .. }), ) | (ImportParseError, Error::ImportError(ImportError::ParseErrors(..))) + | (ImportIoError, Error::ImportError(ImportError::IOError(..))) | ( SerializeNumberOutOfRange, Error::EvalError(EvalError::SerializationError(ExportError::NumberOutOfRange { @@ -340,6 +347,7 @@ impl std::fmt::Display for ErrorExpectation { "ParseError::TypedFieldWithoutDefinition".to_owned() } ImportParseError => "ImportError::ParseError".to_owned(), + ImportIoError => "ImportError::IoError".to_owned(), EvalBlameError => "EvalError::BlameError".to_owned(), EvalTypeError => "EvalError::TypeError".to_owned(), EvalEqError => "EvalError::EqError".to_owned(),