New apis for downloading (#2695)

* New apis for downloading

* Changes according to review

* Added minimal test and fixed issues

* Python backward compatibility
This commit is contained in:
Johan Mabille 2023-08-23 09:07:20 +02:00 committed by GitHub
parent 8b073ca2fe
commit 6e4d91e2ec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 1515 additions and 122 deletions

View File

@ -196,6 +196,7 @@ jobs:
shell: bash -l {0}
run: |
pybind11-stubgen libmambapy.bindings
pre-commit run --files stubs/libmambapy/binginds-stubds/__init__.pyi
python compare_stubs.py libmambapy/libmambapy/__init__.pyi stubs/libmambapy/bindings-stubs/__init__.pyi
- name: build cache statistics
run: sccache --show-stats

View File

@ -143,6 +143,7 @@ set(LIBMAMBA_SOURCES
${LIBMAMBA_SOURCE_DIR}/core/activation.cpp
${LIBMAMBA_SOURCE_DIR}/core/channel.cpp
${LIBMAMBA_SOURCE_DIR}/core/context.cpp
${LIBMAMBA_SOURCE_DIR}/core/download.cpp
${LIBMAMBA_SOURCE_DIR}/core/environment.cpp
${LIBMAMBA_SOURCE_DIR}/core/environments_manager.cpp
${LIBMAMBA_SOURCE_DIR}/core/error_handling.cpp
@ -228,6 +229,7 @@ set(LIBMAMBA_PUBLIC_HEADERS
${LIBMAMBA_INCLUDE_DIR}/mamba/core/channel.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/palette.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/context.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/download.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/environment.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/environments_manager.hpp
${LIBMAMBA_INCLUDE_DIR}/mamba/core/error_handling.hpp

View File

@ -77,6 +77,8 @@ namespace mamba
int retry_timeout{ 2 }; // seconds
int retry_backoff{ 3 }; // retry_timeout * retry_backoff
int max_retries{ 3 }; // max number of retries
std::map<std::string, std::string> proxy_servers;
};
struct OutputParams
@ -182,8 +184,6 @@ namespace mamba
ThreadsParams threads_params;
PrefixParams prefix_params;
std::map<std::string, std::string> proxy_servers;
std::size_t lock_timeout = 0;
bool use_lockfiles = true;
@ -212,7 +212,9 @@ namespace mamba
};
std::string channel_alias = "https://conda.anaconda.org";
std::map<std::string, AuthenticationInfo>& authentication_info();
using authentication_info_map_t = std::map<std::string, AuthenticationInfo>;
authentication_info_map_t& authentication_info();
const authentication_info_map_t& authentication_info() const;
std::vector<fs::u8path> token_locations{ "~/.continuum/anaconda-client/tokens" };
bool override_channels_enabled = true;

View File

@ -0,0 +1,110 @@
// Copyright (c) 2019, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.
#ifndef MAMBA_CORE_DOWNLOAD_HPP
#define MAMBA_CORE_DOWNLOAD_HPP
#include <functional>
#include <optional>
#include <string>
#include <variant>
#include <tl/expected.hpp>
#include "mamba/core/context.hpp"
namespace mamba
{
struct TransferData
{
int http_status = 0;
std::string effective_url = "";
std::size_t downloaded_size = 0;
std::size_t average_speed = 0;
};
struct DownloadSuccess
{
std::string filename = "";
TransferData transfer = {};
std::string cache_control = "";
std::string etag = "";
std::string last_modified = "";
std::size_t attempt_number = std::size_t(1);
};
struct DownloadError
{
std::string message = "";
std::optional<std::size_t> retry_wait_seconds = std::nullopt;
std::optional<TransferData> transfer = std::nullopt;
std::size_t attempt_number = std::size_t(1);
};
struct DownloadProgress
{
std::size_t downloaded_size = 0;
std::size_t total_to_download = 0;
};
using DownloadEvent = std::variant<DownloadProgress, DownloadError, DownloadSuccess>;
struct DownloadRequest
{
using progress_callback_t = std::function<void(const DownloadEvent&)>;
// TODO: remove these functions when we plug a library with continuation
using on_success_callback_t = std::function<bool(const DownloadSuccess&)>;
using on_failure_callback_t = std::function<void(const DownloadError&)>;
std::string name;
std::string url;
std::string filename;
bool head_only;
bool ignore_failure;
std::optional<std::size_t> expected_size = std::nullopt;
std::optional<std::string> if_none_match = std::nullopt;
std::optional<std::string> if_modified_since = std::nullopt;
std::optional<progress_callback_t> progress = std::nullopt;
std::optional<on_success_callback_t> on_success = std::nullopt;
std::optional<on_failure_callback_t> on_failure = std::nullopt;
DownloadRequest(
const std::string& lname,
const std::string& lurl,
const std::string& lfilename,
bool lhead_only = false,
bool lignore_failure = false
);
};
using DownloadRequestList = std::vector<DownloadRequest>;
struct MultiDownloadRequest
{
DownloadRequestList requests;
};
using DownloadResult = tl::expected<DownloadSuccess, DownloadError>;
using DownloadResultList = std::vector<DownloadResult>;
struct MultiDownloadResult
{
DownloadResultList results;
};
struct DownloadOptions
{
bool fail_fast = false;
bool sort = true;
};
MultiDownloadResult
download(MultiDownloadRequest requests, const Context& context, DownloadOptions options = {});
}
#endif

View File

@ -14,6 +14,7 @@
#include <sstream>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>
#include <time.h>
@ -338,6 +339,8 @@ namespace mamba
/// NOTE: this does not check if the file exists.
bool is_yaml_file_name(std::string_view filename);
std::optional<std::string>
proxy_match(const std::string& url, const std::map<std::string, std::string>& proxy_servers);
std::optional<std::string> proxy_match(const std::string& url);
std::string hide_secrets(std::string_view str);

View File

@ -256,6 +256,36 @@ namespace mamba::util
next_valid_iterator();
}
~filter_iterator() = default;
filter_iterator(const filter_iterator&) = default;
filter_iterator(filter_iterator&&) = default;
self_type& operator=(const self_type& rhs)
{
m_pred.reset();
if (rhs.m_pred)
{
m_pred.emplace(*(rhs.m_pred));
}
m_iter = rhs.m_iter;
m_begin_limit = rhs.m_begin_limit;
m_end = rhs.m_end;
return *this;
}
self_type& operator=(self_type&& rhs)
{
m_pred.reset();
if (rhs.m_pred)
{
m_pred.emplace(*std::move(rhs.m_pred));
}
m_iter = std::move(rhs.m_iter);
m_begin_limit = std::move(rhs.m_begin_limit);
m_end = std::move(rhs.m_end);
return *this;
}
self_type& operator++()
{
++m_iter;
@ -267,7 +297,7 @@ namespace mamba::util
enable_bidirectional_iterator<It, self_type&> operator--()
{
--m_iter;
while (m_iter != m_begin_limit && !m_pred(*m_iter))
while (m_iter != m_begin_limit && !(m_pred.value()(*m_iter)))
{
--m_iter;
}
@ -340,13 +370,17 @@ namespace mamba::util
void next_valid_iterator()
{
while (m_iter != m_end && !m_pred(*m_iter))
while (m_iter != m_end && !(m_pred.value()(*m_iter)))
{
++m_iter;
}
}
Predicate m_pred;
// Trick to enable move and copy assignment: since lambdas are
// not assignable, we encapsulate them in an std::optional and
// rely on it to implement assignment operators. The optional
// should be replaced with a dedicated wrapper.
std::optional<Predicate> m_pred;
Iterator m_iter;
Iterator m_begin_limit;
Iterator m_end;

View File

@ -1304,7 +1304,7 @@ namespace mamba
[this](auto&... args) { return detail::ssl_verify_hook(*this, args...); }
));
insert(Configurable("proxy_servers", &ctx.proxy_servers)
insert(Configurable("proxy_servers", &ctx.remote_fetch_params.proxy_servers)
.group("Network")
.set_rc_configurable()
.description("Use a proxy server for network connections")

View File

@ -6,10 +6,218 @@
#include <spdlog/spdlog.h>
#include "mamba/util/string.hpp"
#include "compression.hpp"
namespace mamba
{
/*********************
* CompressionStream *
*********************/
CompressionStream::CompressionStream(writer&& func)
: m_writer(std::move(func))
{
}
size_t CompressionStream::write(char* in, size_t size)
{
return write_impl(in, size);
}
size_t CompressionStream::invoke_writer(char* in, size_t size)
{
return m_writer(in, size);
}
/*************************
* ZstdCompressionStream *
*************************/
class ZstdCompressionStream : public CompressionStream
{
public:
using base_type = CompressionStream;
using writer = base_type::writer;
explicit ZstdCompressionStream(writer&& func);
virtual ~ZstdCompressionStream();
private:
size_t write_impl(char* in, size_t size) override;
static constexpr size_t BUFFER_SIZE = 256 * 1024;
ZSTD_DCtx* p_stream;
char m_buffer[BUFFER_SIZE];
};
ZstdCompressionStream::ZstdCompressionStream(writer&& func)
: base_type(std::move(func))
, p_stream(ZSTD_createDCtx())
{
ZSTD_initDStream(p_stream);
}
ZstdCompressionStream::~ZstdCompressionStream()
{
ZSTD_freeDCtx(p_stream);
}
size_t ZstdCompressionStream::write_impl(char* in, size_t size)
{
ZSTD_inBuffer input = { in, size, 0 };
ZSTD_outBuffer output = { m_buffer, BUFFER_SIZE, 0 };
while (input.pos < input.size)
{
auto ret = ZSTD_decompressStream(p_stream, &output, &input);
if (ZSTD_isError(ret))
{
// This is temporary...
// TODO Remove dependency on spdlog after deciding on what to do with logging
spdlog::error("ZSTD decompression error: {}", ZSTD_getErrorName(ret));
return size + 1;
}
if (output.pos > 0)
{
size_t wcb_res = base_type::invoke_writer(m_buffer, output.pos);
if (wcb_res != output.pos)
{
return size + 1;
}
output.pos = 0;
}
}
return size;
}
/**************************
* Bzip2CompressionStream *
**************************/
class Bzip2CompressionStream : public CompressionStream
{
public:
using base_type = CompressionStream;
using writer = base_type::writer;
explicit Bzip2CompressionStream(writer&& func);
virtual ~Bzip2CompressionStream();
private:
size_t write_impl(char* in, size_t size) override;
static constexpr size_t BUFFER_SIZE = 256 * 1024;
bz_stream m_stream;
char m_buffer[BUFFER_SIZE];
};
Bzip2CompressionStream::Bzip2CompressionStream(writer&& func)
: base_type(std::move(func))
{
m_stream.bzalloc = nullptr;
m_stream.bzfree = nullptr;
m_stream.opaque = nullptr;
int error = BZ2_bzDecompressInit(&m_stream, 0, false);
if (error != BZ_OK)
{
throw std::runtime_error("BZ2_bzDecompressInit failed");
}
}
Bzip2CompressionStream::~Bzip2CompressionStream()
{
BZ2_bzDecompressEnd(&m_stream);
}
size_t Bzip2CompressionStream::write_impl(char* in, size_t size)
{
m_stream.next_in = in;
m_stream.avail_in = static_cast<unsigned int>(size);
while (m_stream.avail_in > 0)
{
m_stream.next_out = m_buffer;
m_stream.avail_out = Bzip2CompressionStream::BUFFER_SIZE;
int ret = BZ2_bzDecompress(&m_stream);
if (ret != BZ_OK && ret != BZ_STREAM_END)
{
// This is temporary...
// TODO Remove dependency on spdlog after deciding on what to do with logging
spdlog::error("Bzip2 decompression error: {}", ret);
return size + 1;
}
size_t wcb_res = base_type::invoke_writer(m_buffer, BUFFER_SIZE - m_stream.avail_out);
if (wcb_res != BUFFER_SIZE - m_stream.avail_out)
{
return size + 1;
}
}
return size;
}
/***********************
* NoCompressionStream *
***********************/
class NoCompressionStream : public CompressionStream
{
public:
using base_type = CompressionStream;
using writer = base_type::writer;
explicit NoCompressionStream(writer&& func);
virtual ~NoCompressionStream() = default;
private:
size_t write_impl(char* in, size_t size) override;
};
NoCompressionStream::NoCompressionStream(writer&& func)
: base_type(std::move(func))
{
}
size_t NoCompressionStream::write_impl(char* in, size_t size)
{
return base_type::invoke_writer(in, size);
}
std::unique_ptr<CompressionStream>
make_compression_stream(const std::string& url, CompressionStream::writer&& func)
{
if (util::ends_with(url, ".json.zst"))
{
return std::make_unique<ZstdCompressionStream>(std::move(func));
}
else if (util::ends_with(url, "json.bz2"))
{
return std::make_unique<Bzip2CompressionStream>(std::move(func));
}
else
{
return std::make_unique<NoCompressionStream>(std::move(func));
}
}
// TODO: remove XXXStreams and rename XXXCompressionStream into XXXStream
/*************
* Old stuff *
*************/
size_t ZstdStream::write(char* in, size_t size)
{
ZSTD_inBuffer input = { in, size, 0 };

View File

@ -14,6 +14,40 @@
namespace mamba
{
class CompressionStream
{
public:
using writer = std::function<size_t(char*, size_t)>;
virtual ~CompressionStream() = default;
CompressionStream(const CompressionStream&) = delete;
CompressionStream& operator=(const CompressionStream&) = delete;
CompressionStream(CompressionStream&&) = delete;
CompressionStream& operator=(CompressionStream&&) = delete;
size_t write(char* in, size_t size);
protected:
CompressionStream(writer&& func);
size_t invoke_writer(char* in, size_t size);
private:
virtual size_t write_impl(char* in, size_t size) = 0;
writer m_writer;
};
std::unique_ptr<CompressionStream>
make_compression_stream(const std::string& url, CompressionStream::writer&& func);
// TODO: remove the following when switching to new CompressionStream
struct ZstdStream
{
static constexpr size_t BUFFER_SIZE = 256000;
@ -89,6 +123,8 @@ namespace mamba
{
return ZSTD_DStreamOutSize();
}
} // namespace mamba
#endif // MAMBA_COMPRESSION_HPP

View File

@ -179,6 +179,11 @@ namespace mamba
return m_authentication_info;
}
const std::map<std::string, AuthenticationInfo>& Context::authentication_info() const
{
return const_cast<Context*>(this)->authentication_info();
}
void Context::load_authentication_info()
{
auto& ctx = Context::instance();

View File

@ -4,7 +4,8 @@
//
// The full license is in the file LICENSE, distributed with this software.
// TODO remove all these includes later?
#include <functional>
#include <spdlog/spdlog.h>
#include "mamba/core/environment.hpp" // for NETRC env var
@ -181,6 +182,51 @@ namespace mamba
return m_serious;
}
/**********
* CURLId *
**********/
CURLId::CURLId(CURL* handle)
: p_handle(handle)
{
}
bool CURLId::operator==(const CURLId& rhs) const
{
return p_handle == rhs.p_handle;
}
bool CURLId::operator!=(const CURLId& rhs) const
{
return !(*this == rhs);
}
bool CURLId::operator<(const CURLId& rhs) const
{
return p_handle < rhs.p_handle;
}
bool CURLId::operator<=(const CURLId& rhs) const
{
return !(*this > rhs);
}
bool CURLId::operator>(const CURLId& rhs) const
{
return rhs < *this;
}
bool CURLId::operator>=(const CURLId& rhs) const
{
return rhs <= *this;
}
std::size_t CURLId::hash() const noexcept
{
std::hash<CURL*> h;
return h(p_handle);
}
/**************
* CURLHandle *
**************/
@ -272,7 +318,7 @@ namespace mamba
}
template <class T>
tl::expected<T, CURLcode> CURLHandle::get_info(CURLINFO option)
tl::expected<T, CURLcode> CURLHandle::get_info(CURLINFO option) const
{
T val;
CURLcode result = curl_easy_getinfo(m_handle, option, &val);
@ -294,14 +340,14 @@ namespace mamba
// defining `long long` is needed to handle `curl_off_t` is `long long` case without
// causing duplication.
template tl::expected<long, CURLcode> CURLHandle::get_info(CURLINFO option);
template tl::expected<char*, CURLcode> CURLHandle::get_info(CURLINFO option);
template tl::expected<double, CURLcode> CURLHandle::get_info(CURLINFO option);
template tl::expected<long long, CURLcode> CURLHandle::get_info(CURLINFO option);
template tl::expected<curl_slist*, CURLcode> CURLHandle::get_info(CURLINFO option);
template tl::expected<long, CURLcode> CURLHandle::get_info(CURLINFO option) const;
template tl::expected<char*, CURLcode> CURLHandle::get_info(CURLINFO option) const;
template tl::expected<double, CURLcode> CURLHandle::get_info(CURLINFO option) const;
template tl::expected<long long, CURLcode> CURLHandle::get_info(CURLINFO option) const;
template tl::expected<curl_slist*, CURLcode> CURLHandle::get_info(CURLINFO option) const;
template <>
tl::expected<std::size_t, CURLcode> CURLHandle::get_info(CURLINFO option)
tl::expected<std::size_t, CURLcode> CURLHandle::get_info(CURLINFO option) const
{
auto res = get_info<curl_off_t>(option);
if (res)
@ -315,7 +361,7 @@ namespace mamba
}
template <>
tl::expected<int, CURLcode> CURLHandle::get_info(CURLINFO option)
tl::expected<int, CURLcode> CURLHandle::get_info(CURLINFO option) const
{
auto res = get_info<long>(option);
if (res)
@ -329,7 +375,7 @@ namespace mamba
}
template <>
tl::expected<std::string, CURLcode> CURLHandle::get_info(CURLINFO option)
tl::expected<std::string, CURLcode> CURLHandle::get_info(CURLINFO option) const
{
auto res = get_info<char*>(option);
if (res)
@ -362,6 +408,11 @@ namespace mamba
);
}
void CURLHandle::reset_handle()
{
curl_easy_reset(m_handle);
}
CURLHandle& CURLHandle::add_header(const std::string& header)
{
p_headers = curl_slist_append(p_headers, header.c_str());
@ -399,7 +450,7 @@ namespace mamba
return m_errorbuffer;
}
std::string CURLHandle::get_curl_effective_url()
std::string CURLHandle::get_curl_effective_url() const
{
return get_info<std::string>(CURLINFO_EFFECTIVE_URL).value();
}
@ -411,7 +462,7 @@ namespace mamba
bool CURLHandle::is_curl_res_ok() const
{
return (m_result == CURLE_OK);
return is_curl_res_ok(m_result);
}
void CURLHandle::set_result(CURLcode res)
@ -421,12 +472,37 @@ namespace mamba
std::string CURLHandle::get_res_error() const
{
return static_cast<std::string>(curl_easy_strerror(m_result));
return get_res_error(m_result);
}
bool CURLHandle::can_proceed()
{
switch (m_result)
return can_retry(m_result);
}
void CURLHandle::perform()
{
m_result = curl_easy_perform(m_handle);
}
CURLId CURLHandle::get_id() const
{
return CURLId(m_handle);
}
bool CURLHandle::is_curl_res_ok(CURLcode res)
{
return res == CURLE_OK;
}
std::string CURLHandle::get_res_error(CURLcode res)
{
return static_cast<std::string>(curl_easy_strerror(res));
}
bool CURLHandle::can_retry(CURLcode res)
{
switch (res)
{
case CURLE_ABORTED_BY_CALLBACK:
case CURLE_BAD_FUNCTION_ARGUMENT:
@ -452,11 +528,6 @@ namespace mamba
}
}
void CURLHandle::perform()
{
m_result = curl_easy_perform(m_handle);
}
CURL* unwrap(const CURLHandle& h)
{
return h.m_handle;
@ -472,50 +543,6 @@ namespace mamba
return !(lhs == rhs);
}
/*****************
* CURLReference *
*****************/
CURLReference::CURLReference(CURL* handle)
: p_handle(handle)
{
}
CURL* unwrap(const CURLReference& h)
{
return h.p_handle;
}
bool operator==(const CURLReference& lhs, const CURLReference& rhs)
{
return unwrap(lhs) == unwrap(rhs);
}
bool operator==(const CURLReference& lhs, const CURLHandle& rhs)
{
return unwrap(lhs) == unwrap(rhs);
}
bool operator==(const CURLHandle& lhs, const CURLReference& rhs)
{
return unwrap(lhs) == unwrap(rhs);
}
bool operator!=(const CURLReference& lhs, const CURLReference& rhs)
{
return !(lhs == rhs);
}
bool operator!=(const CURLReference& lhs, const CURLHandle& rhs)
{
return !(lhs == rhs);
}
bool operator!=(const CURLHandle& lhs, const CURLReference& rhs)
{
return !(lhs == rhs);
}
/*******************
* CURLMultiHandle *
*******************/
@ -562,7 +589,8 @@ namespace mamba
void CURLMultiHandle::add_handle(const CURLHandle& h)
{
CURLMcode code = curl_multi_add_handle(p_handle, unwrap(h));
CURL* unw = unwrap(h);
CURLMcode code = curl_multi_add_handle(p_handle, unw);
if (code != CURLM_CALL_MULTI_PERFORM)
{
if (code != CURLM_OK)
@ -594,7 +622,9 @@ namespace mamba
CURLMsg* msg = curl_multi_info_read(p_handle, &msgs_in_queue);
if (msg != nullptr)
{
return CURLMultiResponse{ msg->easy_handle, msg->data.result, msg->msg == CURLMSG_DONE };
return CURLMultiResponse{ CURLId(msg->easy_handle),
msg->data.result,
msg->msg == CURLMSG_DONE };
}
else
{

View File

@ -64,6 +64,41 @@ namespace mamba
bool m_serious;
};
class CURLId
{
public:
bool operator==(const CURLId& rhs) const;
bool operator!=(const CURLId& rhs) const;
bool operator<(const CURLId& rhs) const;
bool operator<=(const CURLId& rhs) const;
bool operator>(const CURLId& rhs) const;
bool operator>=(const CURLId& rhs) const;
std::size_t hash() const noexcept;
private:
explicit CURLId(CURL* handle = nullptr);
CURL* p_handle;
friend class CURLHandle;
friend class CURLMultiHandle;
};
}
template <>
struct std::hash<mamba::CURLId>
{
std::size_t operator()(const mamba::CURLId& arg) const noexcept
{
return arg.hash();
}
};
namespace mamba
{
class CURLHandle
{
public:
@ -76,7 +111,7 @@ namespace mamba
const std::pair<std::string_view, CurlLogLevel> get_ssl_backend_info();
template <class T>
tl::expected<T, CURLcode> get_info(CURLINFO option);
tl::expected<T, CURLcode> get_info(CURLINFO option) const;
void configure_handle(
const std::string& url,
@ -87,6 +122,8 @@ namespace mamba
const std::string& ssl_verify
);
void reset_handle();
CURLHandle& add_header(const std::string& header);
CURLHandle& add_headers(const std::vector<std::string>& headers);
CURLHandle& reset_headers();
@ -97,18 +134,26 @@ namespace mamba
CURLHandle& set_opt_header();
const char* get_error_buffer() const;
std::string get_curl_effective_url();
std::string get_curl_effective_url() const;
std::size_t get_result() const;
bool is_curl_res_ok() const;
[[deprecated]] std::size_t get_result() const;
[[deprecated]] bool is_curl_res_ok() const;
void set_result(CURLcode res);
[[deprecated]] void set_result(CURLcode res);
std::string get_res_error() const;
[[deprecated]] std::string get_res_error() const;
bool can_proceed();
// Side-effect programming, to remove
[[deprecated]] bool can_proceed();
void perform();
CURLId get_id() const;
// New API to avoid storing result
static bool is_curl_res_ok(CURLcode res);
static std::string get_res_error(CURLcode res);
static bool can_retry(CURLcode res);
private:
CURL* m_handle;
@ -122,29 +167,9 @@ namespace mamba
bool operator==(const CURLHandle& lhs, const CURLHandle& rhs);
bool operator!=(const CURLHandle& lhs, const CURLHandle& rhs);
class CURLReference
{
public:
CURLReference(CURL* handle);
private:
CURL* p_handle;
friend CURL* unwrap(const CURLReference&);
};
bool operator==(const CURLHandle& lhs, const CURLReference& rhs);
bool operator==(const CURLReference& lhs, const CURLHandle& rhs);
bool operator==(const CURLReference& lhs, const CURLReference& rhs);
bool operator!=(const CURLHandle& lhs, const CURLReference& rhs);
bool operator!=(const CURLReference& lhs, const CURLHandle& rhs);
bool operator!=(const CURLReference& lhs, const CURLReference& rhs);
struct CURLMultiResponse
{
CURLReference m_handle_ref;
CURLId m_handle_id;
CURLcode m_transfer_result;
bool m_transfer_done;
};

View File

@ -0,0 +1,695 @@
#include "mamba/core/download.hpp"
#include "mamba/core/util.hpp"
#include "mamba/util/iterator.hpp"
#include "mamba/util/string.hpp"
#include "mamba/util/url.hpp"
#include "curl.hpp"
#include "download_impl.hpp"
namespace mamba
{
/**********************************
* DownloadAttempt implementation *
**********************************/
DownloadAttempt::DownloadAttempt(const DownloadRequest& request)
: p_request(&request)
{
p_stream = make_compression_stream(
p_request->url,
[this](char* in, std::size_t size) { return this->write_data(in, size); }
);
m_retry_wait_seconds = std::size_t(0);
}
CURLId DownloadAttempt::prepare_download(
CURLMultiHandle& downloader,
const Context& context,
on_success_callback success,
on_failure_callback error
)
{
m_retry_wait_seconds = static_cast<std::size_t>(context.remote_fetch_params.retry_timeout);
configure_handle(context);
downloader.add_handle(m_handle);
m_success_callback = std::move(success);
m_failure_callback = std::move(error);
return m_handle.get_id();
}
namespace
{
bool is_http_status_ok(int http_status)
{
return http_status / 100 == 2;
}
}
bool DownloadAttempt::finish_download(CURLMultiHandle& downloader, CURLcode code)
{
if (!CURLHandle::is_curl_res_ok(code))
{
DownloadError error = build_download_error(code);
clean_attempt(downloader, true);
invoke_progress_callback(error);
return m_failure_callback(std::move(error));
}
else
{
TransferData data = get_transfer_data();
if (!is_http_status_ok(data.http_status))
{
DownloadError error = build_download_error(std::move(data));
clean_attempt(downloader, true);
invoke_progress_callback(error);
return m_failure_callback(std::move(error));
}
else
{
DownloadSuccess success = build_download_success(std::move(data));
clean_attempt(downloader, false);
invoke_progress_callback(success);
return m_success_callback(std::move(success));
}
}
}
void DownloadAttempt::clean_attempt(CURLMultiHandle& downloader, bool erase_downloaded)
{
downloader.remove_handle(m_handle);
m_handle.reset_handle();
if (m_file.is_open())
{
m_file.close();
}
if (erase_downloaded && fs::exists(p_request->filename))
{
fs::remove(p_request->filename);
}
m_cache_control.clear();
m_etag.clear();
m_last_modified.clear();
}
void DownloadAttempt::invoke_progress_callback(const DownloadEvent& event) const
{
if (p_request->progress.has_value())
{
p_request->progress.value()(event);
}
}
auto DownloadAttempt::create_completion_function() -> completion_function
{
return [this](CURLMultiHandle& handle, CURLcode code)
{ return this->finish_download(handle, code); };
}
namespace
{
int
curl_debug_callback(CURL* /* handle */, curl_infotype type, char* data, size_t size, void* userptr)
{
auto* logger = reinterpret_cast<spdlog::logger*>(userptr);
auto log = Console::hide_secrets(std::string_view(data, size));
switch (type)
{
case CURLINFO_TEXT:
logger->info(fmt::format("* {}", log));
break;
case CURLINFO_HEADER_OUT:
logger->info(fmt::format("> {}", log));
break;
case CURLINFO_HEADER_IN:
logger->info(fmt::format("< {}", log));
break;
default:
break;
}
return 0;
}
std::string
build_transfer_message(int http_status, const std::string& effective_url, std::size_t size)
{
std::stringstream ss;
ss << "Transfer finalized, status: " << http_status << " [" << effective_url << "] "
<< size << " bytes";
return ss.str();
}
}
void DownloadAttempt::configure_handle(const Context& context)
{
// TODO: we should probably store set_low_speed_limit and set_ssl_no_revoke in
// RemoteFetchParams if the request is slower than 30b/s for 60 seconds, cancel.
const std::string no_low_speed_limit = std::getenv("MAMBA_NO_LOW_SPEED_LIMIT")
? std::getenv("MAMBA_NO_LOW_SPEED_LIMIT")
: "0";
const bool set_low_speed_opt = (no_low_speed_limit == "0");
const std::string ssl_no_revoke_env = std::getenv("MAMBA_SSL_NO_REVOKE")
? std::getenv("MAMBA_SSL_NO_REVOKE")
: "0";
const bool set_ssl_no_revoke = context.remote_fetch_params.ssl_no_revoke
|| (ssl_no_revoke_env != "0");
m_handle.configure_handle(
p_request->url,
set_low_speed_opt,
context.remote_fetch_params.connect_timeout_secs,
set_ssl_no_revoke,
proxy_match(p_request->url),
context.remote_fetch_params.ssl_verify
);
m_handle.set_opt(CURLOPT_NOBODY, p_request->head_only);
m_handle.set_opt(CURLOPT_HEADERFUNCTION, &DownloadAttempt::curl_header_callback);
m_handle.set_opt(CURLOPT_HEADERDATA, this);
m_handle.set_opt(CURLOPT_WRITEFUNCTION, &DownloadAttempt::curl_write_callback);
m_handle.set_opt(CURLOPT_WRITEDATA, this);
if (p_request->progress.has_value())
{
m_handle.set_opt(CURLOPT_XFERINFOFUNCTION, &DownloadAttempt::curl_progress_callback);
m_handle.set_opt(CURLOPT_XFERINFODATA, this);
m_handle.set_opt(CURLOPT_NOPROGRESS, 0L);
}
if (util::ends_with(p_request->url, ".json"))
{
// accept all encodings supported by the libcurl build
m_handle.set_opt(CURLOPT_ACCEPT_ENCODING, "");
m_handle.add_header("Content-Type: application/json");
}
m_handle.set_opt(CURLOPT_VERBOSE, context.output_params.verbosity >= 2);
configure_handle_headers(context);
auto logger = spdlog::get("libcurl");
m_handle.set_opt(CURLOPT_DEBUGFUNCTION, curl_debug_callback);
m_handle.set_opt(CURLOPT_DEBUGDATA, logger.get());
}
void DownloadAttempt::configure_handle_headers(const Context& context)
{
m_handle.reset_headers();
std::string user_agent = fmt::format(
"User-Agent: {} {}",
context.remote_fetch_params.user_agent,
curl_version()
);
m_handle.add_header(user_agent);
// get url host
const auto url_handler = util::URL::parse(p_request->url);
auto host = url_handler.host();
const auto port = url_handler.port();
if (port.size())
{
host += ":" + port;
}
if (context.authentication_info().count(host))
{
const auto& auth = context.authentication_info().at(host);
if (auth.type == AuthenticationType::kBearerToken)
{
m_handle.add_header(fmt::format("Authorization: Bearer {}", auth.value));
}
}
if (p_request->if_none_match.has_value())
{
m_handle.add_header("If-None-Match:" + p_request->if_none_match.value());
}
if (p_request->if_modified_since.has_value())
{
m_handle.add_header("If-Modified-Since:" + p_request->if_modified_since.value());
}
m_handle.set_opt_header();
}
size_t DownloadAttempt::write_data(char* buffer, size_t size)
{
if (!m_file.is_open())
{
m_file = open_ofstream(p_request->filename, std::ios::binary);
if (!m_file)
{
LOG_ERROR << "Could not open file for download " << p_request->filename << ": "
<< strerror(errno);
// Return a size _different_ than the expected write size to signal an error
return size + 1;
}
}
m_file.write(buffer, static_cast<std::streamsize>(size));
if (!m_file)
{
LOG_ERROR << "Could not write to file " << p_request->filename << ": " << strerror(errno);
// Return a size _different_ than the expected write size to signal an error
return size + 1;
}
return size;
}
size_t
DownloadAttempt::curl_header_callback(char* buffer, size_t size, size_t nbitems, void* self)
{
auto* s = reinterpret_cast<DownloadAttempt*>(self);
const size_t buffer_size = size * nbitems;
const std::string_view header(buffer, buffer_size);
auto colon_idx = header.find(':');
if (colon_idx != std::string_view::npos)
{
std::string_view key = header.substr(0, colon_idx);
colon_idx++;
// remove spaces
while (std::isspace(header[colon_idx]))
{
++colon_idx;
}
// remove \r\n header ending
const auto header_end = header.find_first_of("\r\n");
std::string_view value = header.substr(
colon_idx,
(header_end > colon_idx) ? header_end - colon_idx : 0
);
// http headers are case insensitive!
const std::string lkey = util::to_lower(key);
if (lkey == "etag")
{
s->m_etag = value;
}
else if (lkey == "cache-control")
{
s->m_cache_control = value;
}
else if (lkey == "last-modified")
{
s->m_last_modified = value;
}
}
return buffer_size;
}
size_t DownloadAttempt::curl_write_callback(char* buffer, size_t size, size_t nbitems, void* self)
{
return reinterpret_cast<DownloadAttempt*>(self)->p_stream->write(buffer, size * nbitems);
}
int DownloadAttempt::curl_progress_callback(
void* f,
curl_off_t total_to_download,
curl_off_t now_downloaded,
curl_off_t,
curl_off_t
)
{
auto* self = reinterpret_cast<DownloadAttempt*>(f);
self->p_request->progress.value()(DownloadProgress{
static_cast<std::size_t>(total_to_download),
static_cast<std::size_t>(now_downloaded) });
return 0;
}
namespace http
{
static constexpr int PAYLOAD_TOO_LARGE = 413;
static constexpr int TOO_MANY_REQUESTS = 429;
static constexpr int INTERNAL_SERVER_ERROR = 500;
static constexpr int ARBITRARY_ERROR = 10000;
}
bool DownloadAttempt::can_retry(CURLcode code) const
{
return m_handle.can_retry(code) && !util::starts_with(p_request->url, "file://");
}
bool DownloadAttempt::can_retry(const TransferData& data) const
{
return (data.http_status == http::PAYLOAD_TOO_LARGE
|| data.http_status == http::TOO_MANY_REQUESTS
|| data.http_status >= http::INTERNAL_SERVER_ERROR)
&& !util::starts_with(p_request->url, "file://");
}
TransferData DownloadAttempt::get_transfer_data() const
{
return {
/* .http_status = */ m_handle.get_info<int>(CURLINFO_RESPONSE_CODE)
.value_or(http::ARBITRARY_ERROR),
/* .effective_url = */ m_handle.get_info<char*>(CURLINFO_EFFECTIVE_URL).value(),
/* .dwonloaded_size = */ m_handle.get_info<std::size_t>(CURLINFO_SIZE_DOWNLOAD_T).value_or(0),
/* .average_speed = */ m_handle.get_info<std::size_t>(CURLINFO_SPEED_DOWNLOAD_T).value_or(0)
};
}
DownloadError DownloadAttempt::build_download_error(CURLcode code) const
{
DownloadError error;
std::stringstream strerr;
strerr << "Download error (" << code << ") " << m_handle.get_res_error(code) << " ["
<< m_handle.get_curl_effective_url() << "]\n"
<< m_handle.get_error_buffer();
error.message = strerr.str();
if (can_retry(code))
{
error.retry_wait_seconds = m_retry_wait_seconds;
}
return error;
}
DownloadError DownloadAttempt::build_download_error(TransferData data) const
{
DownloadError error;
if (can_retry(data))
{
error.retry_wait_seconds = m_handle.get_info<std::size_t>(CURLINFO_RETRY_AFTER)
.value_or(m_retry_wait_seconds);
}
error.message = build_transfer_message(data.http_status, data.effective_url, data.downloaded_size);
error.transfer = std::move(data);
return error;
}
DownloadSuccess DownloadAttempt::build_download_success(TransferData data) const
{
return { /*.filename = */ p_request->filename,
/*.trnasfer = */ std::move(data),
/*.cache_control = */ m_cache_control,
/*.etag = */ m_etag,
/*.last_modified = */ m_last_modified };
}
/**********************************
* DownloadTracker implementation *
**********************************/
DownloadTracker::DownloadTracker(const DownloadRequest& request, DownloadTrackerOptions options)
: p_request(&request)
, m_options(std::move(options))
, m_attempt(request)
, m_attempt_results()
, m_state(DownloadState::WAITING)
, m_next_retry(std::nullopt)
{
}
auto DownloadTracker::prepare_new_attempt(CURLMultiHandle& handle, const Context& context)
-> completion_map_entry
{
m_next_retry = std::nullopt;
CURLId id = m_attempt.prepare_download(
handle,
context,
[this](DownloadSuccess res)
{
bool finalize_res = invoke_on_success(res);
set_state(finalize_res);
throw_if_required(res);
save(std::move(res));
return is_waiting();
},
[this](DownloadError res)
{
invoke_on_failure(res);
set_state(res);
throw_if_required(res);
save(std::move(res));
return is_waiting();
}
);
return { id, m_attempt.create_completion_function() };
}
bool DownloadTracker::can_start_transfer() const
{
return is_waiting()
&& (!m_next_retry.has_value()
|| m_next_retry.value() < std::chrono::steady_clock::now());
}
const DownloadResult& DownloadTracker::get_result() const
{
return m_attempt_results.back();
}
bool DownloadTracker::invoke_on_success(const DownloadSuccess& res) const
{
if (p_request->on_success.has_value())
{
return p_request->on_success.value()(res);
}
return true;
}
void DownloadTracker::invoke_on_failure(const DownloadError& res) const
{
if (p_request->on_failure.has_value())
{
p_request->on_failure.value()(res);
}
}
bool DownloadTracker::is_waiting() const
{
return m_state == DownloadState::WAITING;
}
void DownloadTracker::set_state(bool success)
{
if (success)
{
m_state = DownloadState::FINISHED;
}
else
{
if (m_attempt_results.size() < m_options.max_retries)
{
m_state = DownloadState::WAITING;
}
else
{
m_state = DownloadState::FAILED;
}
}
}
void DownloadTracker::set_state(const DownloadError& res)
{
if (res.retry_wait_seconds.has_value())
{
if (m_attempt_results.size() < m_options.max_retries)
{
m_state = DownloadState::WAITING;
m_next_retry = std::chrono::steady_clock::now()
+ std::chrono::seconds(res.retry_wait_seconds.value());
}
else
{
m_state = DownloadState::FAILED;
}
}
else
{
m_state = DownloadState::FAILED;
}
}
void DownloadTracker::throw_if_required(const DownloadSuccess& res)
{
if (m_state == DownloadState::FAILED && !p_request->ignore_failure && m_options.fail_fast)
{
throw std::runtime_error(
"Multi-download failed. Reason: "
+ build_transfer_message(
res.transfer.http_status,
res.transfer.effective_url,
res.transfer.downloaded_size
)
);
}
}
void DownloadTracker::throw_if_required(const DownloadError& res)
{
if (m_state == DownloadState::FAILED && !p_request->ignore_failure)
{
throw std::runtime_error(res.message);
}
}
void DownloadTracker::save(DownloadSuccess&& res)
{
res.attempt_number = m_attempt_results.size() + std::size_t(1);
m_attempt_results.push_back(DownloadResult(std::move(res)));
}
void DownloadTracker::save(DownloadError&& res)
{
res.attempt_number = m_attempt_results.size() + std::size_t(1);
m_attempt_results.push_back(tl::unexpected(std::move(res)));
}
/*****************************
* DOWNLOADER IMPLEMENTATION *
*****************************/
Downloader::Downloader(MultiDownloadRequest requests, DownloadOptions options, const Context& context)
: m_requests(std::move(requests))
, m_options(std::move(options))
, p_context(&context)
, m_curl_handle(context.threads_params.download_threads)
, m_trackers()
{
if (m_options.sort)
{
std::sort(
m_requests.requests.begin(),
m_requests.requests.end(),
[](const DownloadRequest& a, const DownloadRequest& b) -> bool
{ return a.expected_size.value_or(SIZE_MAX) > b.expected_size.value_or(SIZE_MAX); }
);
}
m_trackers.reserve(m_requests.requests.size());
std::size_t max_retries = static_cast<std::size_t>(context.remote_fetch_params.max_retries);
DownloadTrackerOptions tracker_options{ max_retries, options.fail_fast };
std::transform(
m_requests.requests.begin(),
m_requests.requests.end(),
std::inserter(m_trackers, m_trackers.begin()),
[tracker_options](const DownloadRequest& req)
{ return DownloadTracker(req, tracker_options); }
);
m_waiting_count = m_trackers.size();
}
MultiDownloadResult Downloader::download()
{
while (!download_done())
{
prepare_next_downloads();
update_downloads();
}
return build_result();
}
void Downloader::prepare_next_downloads()
{
size_t running_attempts = m_completion_map.size();
const size_t max_parallel_downloads = p_context->threads_params.download_threads;
auto start_filter = mamba::util::filter(
m_trackers,
[&](DownloadTracker& tracker)
{ return running_attempts < max_parallel_downloads && tracker.can_start_transfer(); }
);
for (auto& tracker : start_filter)
{
auto [iter, success] = m_completion_map.insert(
tracker.prepare_new_attempt(m_curl_handle, *p_context)
);
if (success)
{
++running_attempts;
}
}
}
void Downloader::update_downloads()
{
std::size_t still_running = m_curl_handle.perform();
while (auto resp = m_curl_handle.pop_message())
{
const auto& msg = resp.value();
if (!msg.m_transfer_done)
{
// We are only interested in messages about finished transfers
continue;
}
auto completion_callback = m_completion_map.find(msg.m_handle_id);
if (completion_callback == m_completion_map.end())
{
spdlog::error(
"Received DONE message from unknown target - running transfers left = {}",
still_running
);
}
else
{
bool still_waiting = completion_callback->second(m_curl_handle, msg.m_transfer_result);
m_completion_map.erase(completion_callback);
if (!still_waiting)
{
--m_waiting_count;
}
}
}
}
bool Downloader::download_done() const
{
return m_waiting_count == 0;
}
MultiDownloadResult Downloader::build_result() const
{
DownloadResultList result;
result.reserve(m_trackers.size());
std::transform(
m_trackers.begin(),
m_trackers.end(),
std::inserter(result, result.begin()),
[](const DownloadTracker& tracker) { return tracker.get_result(); }
);
return { result };
}
/*****************************
* Public API implementation *
*****************************/
DownloadRequest::DownloadRequest(
const std::string& lname,
const std::string& lurl,
const std::string& lfilename,
bool lhead_only,
bool lignore_failure
)
: name(lname)
, url(lurl)
, filename(lfilename)
, head_only(lhead_only)
, ignore_failure(lignore_failure)
{
}
MultiDownloadResult
download(MultiDownloadRequest requests, const Context& context, DownloadOptions options)
{
Downloader dl(std::move(requests), std::move(options), context);
return dl.download();
}
}

View File

@ -0,0 +1,163 @@
// Copyright (c) 2019, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.
#ifndef MAMBA_CORE_DOWNLOAD_IMPL_HPP
#define MAMBA_CORE_DOWNLOAD_IMPL_HPP
#include <chrono>
#include <unordered_map>
#include "mamba/core/download.hpp"
#include "compression.hpp"
#include "curl.hpp"
namespace mamba
{
enum class DownloadState
{
WAITING,
PREPARING,
RUNNING,
FINISHED,
FAILED
};
class DownloadAttempt
{
public:
using completion_function = std::function<bool(CURLMultiHandle&, CURLcode)>;
using on_success_callback = std::function<bool(DownloadSuccess)>;
using on_failure_callback = std::function<bool(DownloadError)>;
explicit DownloadAttempt(const DownloadRequest& request);
CURLId prepare_download(
CURLMultiHandle& downloader,
const Context& context,
on_success_callback success,
on_failure_callback error
);
auto create_completion_function() -> completion_function;
private:
bool finish_download(CURLMultiHandle& downloader, CURLcode code);
void clean_attempt(CURLMultiHandle& downloader, bool erase_downloaded);
void invoke_progress_callback(const DownloadEvent&) const;
void configure_handle(const Context& context);
void configure_handle_headers(const Context& context);
size_t write_data(char* buffer, size_t data);
static size_t curl_header_callback(char* buffer, size_t size, size_t nbitems, void* self);
static size_t curl_write_callback(char* buffer, size_t size, size_t nbitems, void* self);
static int curl_progress_callback(
void* f,
curl_off_t total_to_download,
curl_off_t now_downloaded,
curl_off_t,
curl_off_t
);
bool can_retry(CURLcode code) const;
bool can_retry(const TransferData& data) const;
TransferData get_transfer_data() const;
DownloadError build_download_error(CURLcode code) const;
DownloadError build_download_error(TransferData data) const;
DownloadSuccess build_download_success(TransferData data) const;
const DownloadRequest* p_request;
CURLHandle m_handle;
on_success_callback m_success_callback;
on_failure_callback m_failure_callback;
std::size_t m_retry_wait_seconds;
std::unique_ptr<CompressionStream> p_stream;
std::ofstream m_file;
std::string m_cache_control;
std::string m_etag;
std::string m_last_modified;
};
struct DownloadTrackerOptions
{
std::size_t max_retries = 0;
bool fail_fast = false;
};
class DownloadTracker
{
public:
using completion_function = DownloadAttempt::completion_function;
using completion_map_entry = std::pair<CURLId, completion_function>;
DownloadTracker(const DownloadRequest& request, DownloadTrackerOptions options);
auto prepare_new_attempt(CURLMultiHandle& handle, const Context& context)
-> completion_map_entry;
bool can_start_transfer() const;
const DownloadResult& get_result() const;
private:
bool invoke_on_success(const DownloadSuccess&) const;
void invoke_on_failure(const DownloadError&) const;
bool is_waiting() const;
void set_state(bool success);
void set_state(const DownloadError& res);
void throw_if_required(const DownloadSuccess&);
void throw_if_required(const DownloadError&);
void save(DownloadSuccess&&);
void save(DownloadError&&);
const DownloadRequest* p_request;
DownloadTrackerOptions m_options;
DownloadAttempt m_attempt;
std::vector<DownloadResult> m_attempt_results;
DownloadState m_state;
using time_point_t = std::chrono::steady_clock::time_point;
std::optional<time_point_t> m_next_retry;
};
class Downloader
{
public:
explicit Downloader(MultiDownloadRequest requests, DownloadOptions options, const Context& context);
MultiDownloadResult download();
private:
void prepare_next_downloads();
void update_downloads();
bool download_done() const;
MultiDownloadResult build_result() const;
MultiDownloadRequest m_requests;
DownloadOptions m_options;
const Context* p_context;
CURLMultiHandle m_curl_handle;
std::vector<DownloadTracker> m_trackers;
size_t m_waiting_count;
using completion_function = DownloadTracker::completion_function;
std::unordered_map<CURLId, completion_function> m_completion_map;
};
}
#endif

View File

@ -758,7 +758,7 @@ namespace mamba
DownloadTarget* current_target = nullptr;
for (const auto& target : m_targets)
{
if (target->get_curl_handle() == msg.m_handle_ref)
if (target->get_curl_handle().get_id() == msg.m_handle_id)
{
current_target = target;
break;

View File

@ -1544,12 +1544,13 @@ namespace mamba
return std::string(reinterpret_cast<const char*>(output.data()));
}
std::optional<std::string> proxy_match(const std::string& url)
std::optional<std::string>
proxy_match(const std::string& url, const std::map<std::string, std::string>& proxy_servers)
{
/* This is a reimplementation of requests.utils.select_proxy(), of the python requests
library used by conda */
auto& proxies = Context::instance().proxy_servers;
if (proxies.empty())
if (proxy_servers.empty())
{
return std::nullopt;
}
@ -1573,8 +1574,8 @@ namespace mamba
for (auto& option : options)
{
auto proxy = proxies.find(option);
if (proxy != proxies.end())
auto proxy = proxy_servers.find(option);
if (proxy != proxy_servers.end())
{
return proxy->second;
}
@ -1583,6 +1584,11 @@ namespace mamba
return std::nullopt;
}
std::optional<std::string> proxy_match(const std::string& url)
{
return proxy_match(url, Context::instance().remote_fetch_params.proxy_servers);
}
std::string hide_secrets(std::string_view str)
{
std::string copy(str);

View File

@ -42,6 +42,7 @@ set(LIBMAMBA_TEST_SRCS
src/core/test_channel.cpp
src/core/test_configuration.cpp
src/core/test_cpp.cpp
src/core/test_downloader.cpp
src/core/test_env_file_reading.cpp
src/core/test_environments_manager.cpp
src/core/test_history.cpp

View File

@ -36,7 +36,7 @@ namespace mamba
{
m_channel_alias_bu = ctx.channel_alias;
m_ssl_verify = ctx.remote_fetch_params.ssl_verify;
m_proxy_servers = ctx.proxy_servers;
m_proxy_servers = ctx.remote_fetch_params.proxy_servers;
}
~Configuration()
@ -44,7 +44,7 @@ namespace mamba
config.reset_configurables();
ctx.channel_alias = m_channel_alias_bu;
ctx.remote_fetch_params.ssl_verify = m_ssl_verify;
ctx.proxy_servers = m_proxy_servers;
ctx.remote_fetch_params.proxy_servers = m_proxy_servers;
}
protected:
@ -724,7 +724,7 @@ namespace mamba
std::map<std::string, std::string> expected = { { "http", "foo" },
{ "https", "bar" } };
CHECK_EQ(actual, expected);
CHECK_EQ(ctx.proxy_servers, expected);
CHECK_EQ(ctx.remote_fetch_params.proxy_servers, expected);
CHECK_EQ(config.sources().size(), 1);
CHECK_EQ(config.valid_sources().size(), 1);

View File

@ -0,0 +1,48 @@
// Copyright (c) 2022, QuantStack and Mamba Contributors
//
// Distributed under the terms of the BSD 3-Clause License.
//
// The full license is in the file LICENSE, distributed with this software.
#include <doctest/doctest.h>
#include "mamba/core/download.hpp"
namespace mamba
{
TEST_SUITE("downloader")
{
TEST_CASE("file_does_not_exist")
{
#ifdef __linux__
DownloadRequest request(
"test",
"file:///nonexistent/repodata.json",
"test_download_repodata.json",
false,
true
);
MultiDownloadRequest dl_request{ std::vector{ std::move(request) } };
Context::instance().output_params.quiet = true;
MultiDownloadResult res = download(dl_request, Context::instance());
CHECK_EQ(res.results.size(), std::size_t(1));
CHECK(!res.results[0]);
CHECK_EQ(res.results[0].error().attempt_number, std::size_t(1));
#endif
}
TEST_CASE("file_does_not_exist_throw")
{
#ifdef __linux__
DownloadRequest request(
"test",
"file:///nonexistent/repodata.json",
"test_download_repodata.json"
);
MultiDownloadRequest dl_request{ std::vector{ std::move(request) } };
Context::instance().output_params.quiet = true;
CHECK_THROWS_AS(download(dl_request, Context::instance()), std::runtime_error);
#endif
}
}
}

View File

@ -175,11 +175,12 @@ namespace mamba
{
TEST_CASE("proxy_match")
{
Context::instance().proxy_servers = { { "http", "foo" },
{ "https", "bar" },
{ "https://example.net", "foobar" },
{ "all://example.net", "baz" },
{ "all", "other" } };
Context::instance().remote_fetch_params.proxy_servers = { { "http", "foo" },
{ "https", "bar" },
{ "https://example.net",
"foobar" },
{ "all://example.net", "baz" },
{ "all", "other" } };
CHECK_EQ(*proxy_match("http://example.com/channel"), "foo");
CHECK_EQ(*proxy_match("http://example.net/channel"), "foo");
@ -189,14 +190,16 @@ namespace mamba
CHECK_EQ(*proxy_match("ftp://example.net/channel"), "baz");
CHECK_EQ(*proxy_match("ftp://example.org"), "other");
Context::instance().proxy_servers = { { "http", "foo" },
{ "https", "bar" },
{ "https://example.net", "foobar" },
{ "all://example.net", "baz" } };
Context::instance().remote_fetch_params.proxy_servers = {
{ "http", "foo" },
{ "https", "bar" },
{ "https://example.net", "foobar" },
{ "all://example.net", "baz" }
};
CHECK_FALSE(proxy_match("ftp://example.org").has_value());
Context::instance().proxy_servers = {};
Context::instance().remote_fetch_params.proxy_servers = {};
CHECK_FALSE(proxy_match("http://example.com/channel").has_value());
}

View File

@ -457,6 +457,14 @@ class Context:
def max_retries(self, arg0: int) -> None:
pass
@property
def proxy_servers(self) -> typing.Dict[str, str]:
"""
:type: typing.Dict[str, str]
"""
@proxy_servers.setter
def proxy_servers(self, arg0: typing.Dict[str, str]) -> None:
pass
@property
def retry_backoff(self) -> int:
"""
:type: int
@ -710,7 +718,7 @@ class Context:
:type: typing.Dict[str, str]
"""
@proxy_servers.setter
def proxy_servers(self, arg0: typing.Dict[str, str]) -> None:
def proxy_servers(self, arg1: typing.Dict[str, str]) -> None:
pass
@property
def quiet(self) -> bool:

View File

@ -527,7 +527,6 @@ PYBIND11_MODULE(bindings, m)
.def_readwrite("always_yes", &Context::always_yes)
.def_readwrite("dry_run", &Context::dry_run)
.def_readwrite("download_only", &Context::download_only)
.def_readwrite("proxy_servers", &Context::proxy_servers)
.def_readwrite("add_pip_as_python_dependency", &Context::add_pip_as_python_dependency)
.def_readwrite("envs_dirs", &Context::envs_dirs)
.def_readwrite("pkgs_dirs", &Context::pkgs_dirs)
@ -576,6 +575,7 @@ PYBIND11_MODULE(bindings, m)
.def_readwrite("retry_backoff", &Context::RemoteFetchParams::retry_backoff)
.def_readwrite("user_agent", &Context::RemoteFetchParams::user_agent)
// .def_readwrite("read_timeout_secs", &Context::RemoteFetchParams::read_timeout_secs)
.def_readwrite("proxy_servers", &Context::RemoteFetchParams::proxy_servers)
.def_readwrite("connect_timeout_secs", &Context::RemoteFetchParams::connect_timeout_secs);
py::class_<Context::OutputParams>(ctx, "OutputParams")
@ -681,6 +681,19 @@ PYBIND11_MODULE(bindings, m)
deprecated("Use `remote_fetch_params.connect_timeout_secs` instead.");
self.remote_fetch_params.connect_timeout_secs = cts;
}
)
.def_property(
"proxy_servers",
[](const Context& self)
{
deprecated("Use `remote_fetch_params.proxy_servers` instead.");
return self.remote_fetch_params.proxy_servers;
},
[](Context& self, const std::map<std::string, std::string>& proxies)
{
deprecated("Use `remote_fetch_params.proxy_servers` instead.");
self.remote_fetch_params.proxy_servers = proxies;
}
);
// OutputParams

View File

@ -211,7 +211,7 @@ def init_api_context(use_mamba_experimental=False):
api_ctx.channels = context.channels
api_ctx.platform = context.subdir
# Conda uses a frozendict here
api_ctx.proxy_servers = dict(context.proxy_servers)
api_ctx.remote_fetch_params.proxy_servers = dict(context.proxy_servers)
if "MAMBA_EXTRACT_THREADS" in os.environ:
try: