diff --git a/.isort.cfg b/.isort.cfg index 443b418db..30ef4b9dd 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -5,3 +5,4 @@ multi_line_output=3 include_trailing_comma=True force_grid_wrap=0 use_parentheses=True +profile=black diff --git a/libmamba/src/core/channel.cpp b/libmamba/src/core/channel.cpp index 33ae3440e..ad884775f 100644 --- a/libmamba/src/core/channel.cpp +++ b/libmamba/src/core/channel.cpp @@ -491,11 +491,11 @@ namespace mamba { const Channel& channel = ca.second; std::string test_url = join_url(channel.location(), channel.name()); - // original code splits with '/' and compares tokens if (starts_with(url, test_url)) { auto subname = std::string(strip(url.replace(0u, test_url.size(), ""), "/")); + return channel_configuration(channel.location(), join_url(channel.name(), subname), scheme, diff --git a/libmamba/src/core/solver.cpp b/libmamba/src/core/solver.cpp index ccad6d866..b4c7bcdd9 100644 --- a/libmamba/src/core/solver.cpp +++ b/libmamba/src/core/solver.cpp @@ -63,11 +63,30 @@ namespace mamba } } - inline bool channel_match(Solvable* s, const std::string& channel) + inline bool channel_match(Solvable* s, const Channel& needle) { MRepo* mrepo = reinterpret_cast(s->repo->appdata); const Channel* chan = mrepo->channel(); - return chan && chan->name() == channel; + + if (!chan) + return false; + + if ((*chan) == needle) + return true; + + auto& custom_multichannels = Context::instance().custom_multichannels; + auto x = custom_multichannels.find(needle.name()); + if (x != custom_multichannels.end()) + { + for (auto el : (x->second)) + { + const Channel& inner = make_channel(el); + if ((*chan) == inner) + return true; + } + } + + return false; } void MSolver::add_global_job(int job_flag) @@ -84,9 +103,10 @@ namespace mamba // conda_build_form does **NOT** contain the channel info Id match = pool_conda_matchspec(pool, ms.conda_build_form().c_str()); + const Channel& c = make_channel(ms.channel); for (Id* wp = pool_whatprovides_ptr(pool, match); *wp; wp++) { - if (channel_match(pool_id2solvable(pool, *wp), ms.channel)) + if (channel_match(pool_id2solvable(pool, *wp), c)) { queue_push(&selected_pkgs, *wp); } @@ -264,11 +284,13 @@ namespace mamba Id match = pool_conda_matchspec(pool, ms.conda_build_form().c_str()); std::set matching_solvables; + const Channel& c = make_channel(ms.channel); + for (Id* wp = pool_whatprovides_ptr(pool, match); *wp; wp++) { if (!ms.channel.empty()) { - if (!channel_match(pool_id2solvable(pool, *wp), ms.channel)) + if (!channel_match(pool_id2solvable(pool, *wp), c)) { continue; } diff --git a/libmamba/tests/test_channel.cpp b/libmamba/tests/test_channel.cpp index a7f28bff0..ab43f3fa9 100644 --- a/libmamba/tests/test_channel.cpp +++ b/libmamba/tests/test_channel.cpp @@ -339,6 +339,15 @@ namespace mamba ChannelContext::instance().reset(); } + TEST(Channel, channel_name) + { + std::string value = "https://repo.mamba.pm/conda-forge"; + const Channel& c = make_channel(value); + EXPECT_EQ(c.scheme(), "https"); + EXPECT_EQ(c.location(), "repo.mamba.pm"); + EXPECT_EQ(c.name(), "conda-forge"); + EXPECT_EQ(c.platforms(), std::vector({ platform, "noarch" })); + } TEST(Channel, make_channel) { diff --git a/mamba/mamba/mamba.py b/mamba/mamba/mamba.py index dfe2a3d35..f2907c2f7 100644 --- a/mamba/mamba/mamba.py +++ b/mamba/mamba/mamba.py @@ -43,6 +43,7 @@ from conda.gateways.disk.create import mkdir_p from conda.gateways.disk.delete import delete_trash, path_is_clean, rm_rf from conda.gateways.disk.test import is_conda_environment from conda.misc import explicit, touch_nonadmin +from conda.models.channel import MultiChannel from conda.models.match_spec import MatchSpec import libmambapy as api @@ -377,8 +378,11 @@ def install(args, parser, command="install"): for spec in specs: # CONDA TODO: correct handling for subdir isn't yet done spec_channel = spec.get_exact_value("channel") - if spec_channel and spec_channel.base_url not in channels: - channels.append(spec_channel.base_url) + if spec_channel: + if isinstance(spec_channel, MultiChannel): + channels.append(spec_channel.name) + elif spec_channel.base_url not in channels: + channels.append(spec_channel.base_url) index_args["channel_urls"] = channels diff --git a/mamba/tests/test_all.py b/mamba/tests/test_all.py index e40866b28..a0fccc765 100644 --- a/mamba/tests/test_all.py +++ b/mamba/tests/test_all.py @@ -9,6 +9,7 @@ import pytest from utils import ( Environment, add_glibc_virtual_package, + config_file, copy_channels_osx, platform_shells, run_mamba_conda, @@ -178,6 +179,33 @@ def test_empty_create(): ) +multichannel_config = { + "channels": ["conda-forge"], + "custom_multichannels": {"conda-forge2": ["conda-forge"]}, +} + + +@pytest.mark.parametrize("config_file", [multichannel_config], indirect=["config_file"]) +def test_multi_channels(config_file): + # we need to create a config file first + output = subprocess.check_output( + [ + "mamba", + "create", + "-n", + "multichannels", + "conda-forge2::xtensor", + "--dry-run", + "--json", + ] + ) + res = json.loads(output.decode()) + for pkg in res["actions"]["FETCH"]: + assert pkg["channel"].startswith("https://conda.anaconda.org/conda-forge") + for pkg in res["actions"]["LINK"]: + assert pkg["base_url"] == "https://conda.anaconda.org/conda-forge" + + def test_update_py(): # check updating a package when a newer version if platform.system() == "Windows": diff --git a/mamba/tests/utils.py b/mamba/tests/utils.py index 4fc3441b9..3942beb23 100644 --- a/mamba/tests/utils.py +++ b/mamba/tests/utils.py @@ -4,6 +4,10 @@ import shutil import subprocess import time import uuid +from pathlib import Path + +import pytest +import yaml def get_lines(std_pipe): @@ -133,6 +137,23 @@ def run_mamba_conda(channels, package): run("mamba", channels, package) +@pytest.fixture +def config_file(request): + file_loc = Path.home() / ".condarc" + old_config_file = None + if file_loc.exists(): + old_config_file = file_loc.rename(Path.home() / ".condarc.bkup") + + with open(file_loc, "w") as fo: + yaml.dump(request.param, fo) + + yield file_loc + + if old_config_file: + file_loc.unlink() + old_config_file.rename(file_loc) + + def add_glibc_virtual_package(): version = get_glibc_version() here = os.path.dirname(os.path.abspath(__file__))