diff --git a/pyblock2/driver/core.py b/pyblock2/driver/core.py index 595fc9f4..9e7c6fea 100644 --- a/pyblock2/driver/core.py +++ b/pyblock2/driver/core.py @@ -4288,7 +4288,10 @@ def dmrg( assert len(proj_weights) == len(proj_mpss) dmrg.projection_weights = bw.VectorFP(proj_weights) dmrg.ext_mpss = bw.bs.VectorMPS(proj_mpss) - impo = self.get_identity_mpo() + if metric_mpo is None: + impo = self.get_identity_mpo() + else: + impo = metric_mpo for ext_mps in dmrg.ext_mpss: if ext_mps.info.tag == ket.info.tag: raise RuntimeError("Same tag for proj_mps and ket!!") diff --git a/src/dmrg/moving_environment.hpp b/src/dmrg/moving_environment.hpp index 9aace3a2..52d2012a 100644 --- a/src/dmrg/moving_environment.hpp +++ b/src/dmrg/moving_environment.hpp @@ -2040,7 +2040,7 @@ template struct MovingEnvironment { shared_ptr> cket = nullptr) { return symm_context_convert_impl( i, mps->info, cmps->info, dot, fuse_left, mask, forward, - is_wfn, infer_info, + is_wfn, infer_info, false, ket == nullptr && !(!forward && infer_info) ? mps->tensors[i] : ket, cket == nullptr && !(forward && infer_info) @@ -2049,14 +2049,15 @@ template struct MovingEnvironment { nullptr, nullptr) .first; } - static shared_ptr> symm_context_convert_group( + static shared_ptr> + symm_context_convert_perturbative( int i, const shared_ptr> &mps, const shared_ptr> &cmps, int dot, bool fuse_left, bool mask, bool forward, bool is_wfn, bool infer_info, const shared_ptr> &pket) { return symm_context_convert_impl(i, mps->info, cmps->info, dot, fuse_left, mask, forward, is_wfn, - infer_info, mps->tensors[i], + infer_info, true, mps->tensors[i], cmps->tensors[i], pket, nullptr) .second; } @@ -2066,7 +2067,7 @@ template struct MovingEnvironment { symm_context_convert_impl(int i, const shared_ptr> &info, const shared_ptr> &cinfo, int dot, bool fuse_left, bool mask, bool forward, - bool is_wfn, bool infer_info, + bool is_wfn, bool infer_info, bool is_pert, shared_ptr> ket, shared_ptr> cket, shared_ptr> pket, @@ -2139,7 +2140,8 @@ template struct MovingEnvironment { shared_ptr> gr_wfn = is_group ? make_shared>(d_alloc) : nullptr; - if (is_group && infer_info) { + if (is_pert) { + assert(is_group && infer_info); // FIXME: multi will have problem vector pket_dqs; for (int iw = 0; iw < pket->n; iw++) { diff --git a/src/dmrg/mps_unfused.hpp b/src/dmrg/mps_unfused.hpp index daa738fb..c5a10743 100644 --- a/src/dmrg/mps_unfused.hpp +++ b/src/dmrg/mps_unfused.hpp @@ -23,6 +23,7 @@ #include "../core/matrix.hpp" #include "../core/sparse_matrix.hpp" #include "mps.hpp" +#include "state_averaged.hpp" #include #include #include @@ -319,6 +320,7 @@ struct TransSparseTensorfind_state(lqz); int ir = tr_right_dim->find_state(rqz); + assert(il != -1 && ir != -1); int kmst = conn_basis->n_states[imqz]; int klst = conn_left_dim->n_states[il]; int krst = conn_right_dim->n_states[ir]; @@ -372,22 +374,26 @@ struct TransSparseTensor struct UnfusedMPS { shared_ptr> info; vector>> tensors; + vector>>> wfns; string canonical_form; int center, n_sites, dot; + bool is_multi; + int nroots; + vector::FP> weights; UnfusedMPS() {} UnfusedMPS(const shared_ptr> &mps) { this->initialize(mps); } static shared_ptr> - forward_left_fused(int ii, const shared_ptr> &mps, bool wfn) { + forward_left_fused(int ii, shared_ptr> info, + shared_ptr> mat, bool wfn) { shared_ptr> ts = make_shared>(); - StateInfo m = *mps->info->basis[ii]; + StateInfo m = *info->basis[ii]; ts->data.resize(m.n); - mps->info->load_left_dims(ii); - StateInfo l = *mps->info->left_dims[ii]; - StateInfo lm = StateInfo::tensor_product( - l, m, *mps->info->left_dims_fci[ii + 1]); + info->load_left_dims(ii); + StateInfo l = *info->left_dims[ii]; + StateInfo lm = + StateInfo::tensor_product(l, m, *info->left_dims_fci[ii + 1]); shared_ptr::ConnectionInfo> clm = StateInfo::get_connection_info(l, m, lm); - shared_ptr> mat = mps->tensors[ii]; assert(wfn == mat->info->is_wavefunction); for (int i = 0; i < mat->info->n; i++) { S bra = mat->info->quanta[i].get_bra(mat->info->delta_quantum); @@ -403,8 +409,7 @@ template struct UnfusedMPS { uint32_t lp = (uint32_t)l.n_states[ibba] * m.n_states[ibbb] * mat->info->n_states_ket[i]; ts->data[ibbb].push_back(make_pair( - make_pair(l.quanta[ibba], - wfn ? mps->info->target - ket : ket), + make_pair(l.quanta[ibba], wfn ? info->target - ket : ket), make_shared>(l.n_states[ibba], m.n_states[ibbb], mat->info->n_states_ket[i]))); memcpy(ts->data[ibbb].back().second->data->data(), @@ -420,17 +425,17 @@ template struct UnfusedMPS { return ts; } static shared_ptr> - forward_right_fused(int ii, const shared_ptr> &mps, bool wfn) { + forward_right_fused(int ii, shared_ptr> info, + shared_ptr> mat, bool wfn) { shared_ptr> ts = make_shared>(); - StateInfo m = *mps->info->basis[ii]; + StateInfo m = *info->basis[ii]; ts->data.resize(m.n); - mps->info->load_right_dims(ii + 1); - StateInfo r = *mps->info->right_dims[ii + 1]; + info->load_right_dims(ii + 1); + StateInfo r = *info->right_dims[ii + 1]; StateInfo mr = - StateInfo::tensor_product(m, r, *mps->info->right_dims_fci[ii]); + StateInfo::tensor_product(m, r, *info->right_dims_fci[ii]); shared_ptr::ConnectionInfo> cmr = StateInfo::get_connection_info(m, r, mr); - shared_ptr> mat = mps->tensors[ii]; assert(wfn == mat->info->is_wavefunction); for (int i = 0; i < mat->info->n; i++) { S bra = mat->info->quanta[i].get_bra(mat->info->delta_quantum); @@ -445,8 +450,8 @@ template struct UnfusedMPS { ikkb = cmr->ij_indices[kk].second; uint32_t lp = (uint32_t)m.n_states[ikka] * r.n_states[ikkb]; ts->data[ikka].push_back(make_pair( - make_pair(wfn ? bra : mps->info->target - bra, - mps->info->target - r.quanta[ikkb]), + make_pair(wfn ? bra : info->target - bra, + info->target - r.quanta[ikkb]), make_shared>(mat->info->n_states_bra[i], m.n_states[ikka], r.n_states[ikkb]))); @@ -463,20 +468,46 @@ template struct UnfusedMPS { r.deallocate(); return ts; } + static vector>>> + forward_multi_mps_tensor(int i, const shared_ptr> &mmps) { + mmps->load_wavefunction(i); + vector>>> ts(mmps->wfns.size()); + for (int iw = 0; iw < ts.size(); iw++) { + ts[iw] = vector>>(mmps->wfns[iw]->n); + for (int k = 0; k < mmps->wfns[iw]->n; k++) { + mmps->info->target = + dynamic_pointer_cast>(mmps->info) + ->targets[k]; + if (mmps->canonical_form[i] == 'J' || + (i == 0 && mmps->canonical_form[i] == 'M')) + ts[iw][k] = forward_left_fused(i, mmps->info, + (*mmps->wfns[iw])[k], true); + else if (mmps->canonical_form[i] == 'T' || + (i == mmps->n_sites - 1 && + mmps->canonical_form[i] == 'M')) + ts[iw][k] = forward_right_fused(i, mmps->info, + (*mmps->wfns[iw])[k], true); + else + assert(false); + } + } + mmps->unload_wavefunction(i); + return ts; + } static shared_ptr> forward_mps_tensor(int i, const shared_ptr> &mps) { assert(mps->tensors[i] != nullptr); mps->load_tensor(i); shared_ptr> ts; if (mps->canonical_form[i] == 'L' || mps->canonical_form[i] == 'K' || - (i == 0 && mps->canonical_form[i] == 'C')) { - ts = forward_left_fused(i, mps, + (i == 0 && mps->canonical_form[i] == 'C')) + ts = forward_left_fused(i, mps->info, mps->tensors[i], mps->canonical_form[i] == 'C' || mps->canonical_form[i] == 'K'); - } else if (mps->canonical_form[i] == 'R' || - mps->canonical_form[i] == 'S' || - (i == mps->n_sites - 1 && mps->canonical_form[i] == 'C')) - ts = forward_right_fused(i, mps, + else if (mps->canonical_form[i] == 'R' || + mps->canonical_form[i] == 'S' || + (i == mps->n_sites - 1 && mps->canonical_form[i] == 'C')) + ts = forward_right_fused(i, mps->info, mps->tensors[i], mps->canonical_form[i] == 'C' || mps->canonical_form[i] == 'S'); else @@ -485,26 +516,26 @@ template struct UnfusedMPS { return ts; } static shared_ptr> - backward_left_fused(int ii, const shared_ptr> &mps, + backward_left_fused(int ii, shared_ptr> info, const shared_ptr> &spt, bool wfn) { shared_ptr> i_alloc = make_shared>(); shared_ptr::FP>> d_alloc = make_shared::FP>>(); - StateInfo m = *mps->info->basis[ii]; - StateInfo l = *mps->info->left_dims[ii]; - StateInfo lm = StateInfo::tensor_product( - l, m, *mps->info->left_dims_fci[ii + 1]); + StateInfo m = *info->basis[ii]; + StateInfo l = *info->left_dims[ii]; + StateInfo lm = + StateInfo::tensor_product(l, m, *info->left_dims_fci[ii + 1]); shared_ptr::ConnectionInfo> clm = StateInfo::get_connection_info(l, m, lm); shared_ptr> minfo = make_shared>(i_alloc); if (wfn) - minfo->initialize(lm, *mps->info->right_dims[ii + 1], - mps->info->target, false, true); + minfo->initialize(lm, *info->right_dims[ii + 1], info->target, + false, true); else - minfo->initialize(lm, *mps->info->left_dims[ii + 1], - mps->info->vacuum, false); + minfo->initialize(lm, *info->left_dims[ii + 1], info->vacuum, + false); shared_ptr> mat = make_shared>(d_alloc); mat->allocate(minfo); @@ -526,8 +557,8 @@ template struct UnfusedMPS { ibbb = clm->ij_indices[bb].second; uint32_t lp = (uint32_t)l.n_states[ibba] * m.n_states[ibbb] * mat->info->n_states_ket[i]; - pair qq = make_pair(l.quanta[ibba], - wfn ? mps->info->target - ket : ket); + pair qq = + make_pair(l.quanta[ibba], wfn ? info->target - ket : ket); if (mp[ibbb].count(qq)) { shared_ptr> ts = mp[ibbb].at(qq); assert(ts->shape[0] == l.n_states[ibba]); @@ -545,26 +576,25 @@ template struct UnfusedMPS { return mat; } static shared_ptr> - backward_right_fused(int ii, const shared_ptr> &mps, + backward_right_fused(int ii, shared_ptr> info, const shared_ptr> &spt, bool wfn) { shared_ptr> i_alloc = make_shared>(); shared_ptr::FP>> d_alloc = make_shared::FP>>(); - StateInfo m = *mps->info->basis[ii]; - StateInfo r = *mps->info->right_dims[ii + 1]; + StateInfo m = *info->basis[ii]; + StateInfo r = *info->right_dims[ii + 1]; StateInfo mr = - StateInfo::tensor_product(m, r, *mps->info->right_dims_fci[ii]); + StateInfo::tensor_product(m, r, *info->right_dims_fci[ii]); shared_ptr::ConnectionInfo> cmr = StateInfo::get_connection_info(m, r, mr); shared_ptr> minfo = make_shared>(i_alloc); if (wfn) - minfo->initialize(*mps->info->left_dims[ii], mr, mps->info->target, - false, true); + minfo->initialize(*info->left_dims[ii], mr, info->target, false, + true); else - minfo->initialize(*mps->info->right_dims[ii], mr, mps->info->vacuum, - false); + minfo->initialize(*info->right_dims[ii], mr, info->vacuum, false); shared_ptr> mat = make_shared>(d_alloc); mat->allocate(minfo); @@ -585,8 +615,8 @@ template struct UnfusedMPS { uint32_t ikka = cmr->ij_indices[kk].first, ikkb = cmr->ij_indices[kk].second; uint32_t lp = (uint32_t)m.n_states[ikka] * r.n_states[ikkb]; - pair qq = make_pair(wfn ? bra : mps->info->target - bra, - mps->info->target - r.quanta[ikkb]); + pair qq = make_pair(wfn ? bra : info->target - bra, + info->target - r.quanta[ikkb]); if (mp[ikka].count(qq)) { shared_ptr> ts = mp[ikka].at(qq); assert(ts->shape[0] == mat->info->n_states_bra[i]); @@ -609,20 +639,55 @@ template struct UnfusedMPS { const shared_ptr> &spt) { shared_ptr> mat; if (mps->canonical_form[i] == 'L' || mps->canonical_form[i] == 'K' || - (i == 0 && mps->canonical_form[i] == 'C')) { - mat = backward_left_fused(i, mps, spt, + (i == 0 && mps->canonical_form[i] == 'C')) + mat = backward_left_fused(i, mps->info, spt, mps->canonical_form[i] == 'C' || mps->canonical_form[i] == 'K'); - } else if (mps->canonical_form[i] == 'R' || - mps->canonical_form[i] == 'S' || - (i == mps->n_sites - 1 && mps->canonical_form[i] == 'C')) - mat = backward_right_fused(i, mps, spt, + else if (mps->canonical_form[i] == 'R' || + mps->canonical_form[i] == 'S' || + (i == mps->n_sites - 1 && mps->canonical_form[i] == 'C')) + mat = backward_right_fused(i, mps->info, spt, mps->canonical_form[i] == 'C' || mps->canonical_form[i] == 'S'); else assert(false); return mat; } + static vector>> + backward_multi_mps_tensor( + int i, const shared_ptr> &mmps, + const vector>>> &spt) { + vector>> wfns(spt.size()); + for (int iw = 0; iw < (int)spt.size(); iw++) { + vector>> mats(spt[iw].size()); + for (int k = 0; k < (int)spt[iw].size(); k++) { + mmps->info->target = + dynamic_pointer_cast>(mmps->info) + ->targets[k]; + if (mmps->canonical_form[i] == 'J' || + (i == 0 && mmps->canonical_form[i] == 'M')) + mats[k] = + backward_left_fused(i, mmps->info, spt[iw][k], true); + else if (mmps->canonical_form[i] == 'T' || + (i == mmps->n_sites - 1 && + mmps->canonical_form[i] == 'M')) + mats[k] = + backward_right_fused(i, mmps->info, spt[iw][k], true); + else + assert(false); + } + vector>> infos(spt[iw].size()); + for (int k = 0; k < (int)spt[iw].size(); k++) + infos[k] = mats[k]->info; + shared_ptr::FP>> d_alloc = + make_shared::FP>>(); + wfns[iw] = make_shared>(d_alloc); + wfns[iw]->allocate(infos); + for (int k = 0; k < (int)spt[iw].size(); k++) + (*wfns[iw])[k]->copy_data_from(mats[k]); + } + return wfns; + } void initialize(const shared_ptr> &mps) { this->info = mps->info; canonical_form = mps->canonical_form; @@ -630,21 +695,55 @@ template struct UnfusedMPS { n_sites = mps->n_sites; dot = mps->dot; tensors.resize(mps->n_sites); - for (int i = 0; i < mps->n_sites; i++) - tensors[i] = forward_mps_tensor(i, mps); + is_multi = mps->get_type() & MPSTypes::MultiWfn; + if (!is_multi) + for (int i = 0; i < mps->n_sites; i++) + tensors[i] = forward_mps_tensor(i, mps); + else { + shared_ptr> mmps = + dynamic_pointer_cast>(mps); + for (int i = 0; i < mps->n_sites; i++) + if (i != mps->center) { + mps->info->target = + dynamic_pointer_cast>(mmps->info) + ->targets[0]; + tensors[i] = forward_mps_tensor(i, mps); + } + nroots = mmps->nroots; + weights = mmps->weights; + wfns = forward_multi_mps_tensor(mps->center, mmps); + } } // Transform from Unfused MPS to normal MPS shared_ptr> finalize(const shared_ptr> ¶_rule = nullptr) const { info->load_mutable(); - shared_ptr> xmps = make_shared>(info); + shared_ptr> xmps; + xmps = is_multi ? make_shared>( + dynamic_pointer_cast>(info)) + : make_shared>(info); xmps->canonical_form = canonical_form; xmps->center = center; xmps->n_sites = n_sites; xmps->dot = dot; xmps->tensors.resize(n_sites); - for (int i = 0; i < xmps->n_sites; i++) - xmps->tensors[i] = backward_mps_tensor(i, xmps, tensors[i]); + if (!is_multi) + for (int i = 0; i < xmps->n_sites; i++) + xmps->tensors[i] = backward_mps_tensor(i, xmps, tensors[i]); + else { + shared_ptr> xmmps = + dynamic_pointer_cast>(xmps); + for (int i = 0; i < xmps->n_sites; i++) + if (i != xmps->center) { + xmps->info->target = + dynamic_pointer_cast>(xmmps->info) + ->targets[0]; + xmps->tensors[i] = backward_mps_tensor(i, xmps, tensors[i]); + } + xmmps->nroots = nroots; + xmmps->weights = weights; + xmmps->wfns = backward_multi_mps_tensor(xmps->center, xmmps, wfns); + } if (para_rule != nullptr) para_rule->comm->barrier(); if (para_rule == nullptr || para_rule->is_root()) { @@ -718,7 +817,12 @@ struct TransUnfusedMPS { const shared_ptr> &cg, S2 target) { shared_ptr> fmps = make_shared>(); umps->info->load_mutable(); - fmps->info = TransMPSInfo::forward(umps->info, target); + fmps->info = + umps->is_multi + ? TransMultiMPSInfo::forward( + dynamic_pointer_cast>(umps->info), + vector{target}) + : TransMPSInfo::forward(umps->info, target); fmps->info->tag = xtag; fmps->info->save_mutable(); fmps->tensors.resize(umps->tensors.size()); @@ -726,7 +830,13 @@ struct TransUnfusedMPS { fmps->center = umps->center; fmps->n_sites = umps->n_sites; fmps->dot = umps->dot; + fmps->nroots = umps->nroots; + fmps->weights = umps->weights; + fmps->is_multi = umps->is_multi; umps->info->load_mutable(); + if (umps->is_multi) + umps->info->target = + dynamic_pointer_cast>(umps->info)->targets[0]; for (int i = 0; i < umps->n_sites; i++) if (umps->canonical_form[i] == 'L') fmps->tensors[i] = TransSparseTensor::forward( @@ -743,13 +853,31 @@ struct TransUnfusedMPS { fmps->tensors[i] = TransSparseTensor::forward( umps->tensors[i], umps->info->basis[i], ri, rj, cg, true, target); - } else { + } else if (!umps->is_multi) { shared_ptr> ri = make_shared>(StateInfo::complementary( *umps->info->right_dims[i + 1], umps->info->target)); fmps->tensors[i] = TransSparseTensor::forward( umps->tensors[i], umps->info->basis[i], umps->info->left_dims[i], ri, cg, true, target); + } else { + fmps->wfns.resize(umps->wfns.size()); + for (int iw = 0; iw < (int)umps->wfns.size(); iw++) { + fmps->wfns[iw].resize(umps->wfns[iw].size()); + for (int k = 0; k < (int)umps->wfns[iw].size(); k++) { + shared_ptr> ri = + make_shared>( + StateInfo::complementary( + *umps->info->right_dims[i + 1], + dynamic_pointer_cast>( + umps->info) + ->targets[k])); + fmps->wfns[iw][k] = + TransSparseTensor::forward( + umps->wfns[iw][k], umps->info->basis[i], + umps->info->left_dims[i], ri, cg, true, target); + } + } } umps->info->deallocate_mutable(); return fmps; diff --git a/src/dmrg/state_averaged.hpp b/src/dmrg/state_averaged.hpp index 3da1663d..73af6055 100644 --- a/src/dmrg/state_averaged.hpp +++ b/src/dmrg/state_averaged.hpp @@ -587,4 +587,79 @@ template struct MultiMPS : MPS { } }; +template +struct TransMultiMPSInfo { + static shared_ptr> + forward(const shared_ptr> &si, const vector &targets) { + return TransMultiMPSInfo::backward(si, targets); + } + static shared_ptr> + backward(const shared_ptr> &si, + const vector &targets) { + return TransMultiMPSInfo::forward(si, targets); + } +}; + +template struct TransMultiMPSInfoAnyBase { + static shared_ptr> + transform(const shared_ptr> &si, const vector &targets) { + int n_sites = si->n_sites; + S vacuum = TransStateInfo::forward( + make_shared>(si->vacuum), targets[0]) + ->quanta[0]; + vector>> basis(n_sites); + for (int i = 0; i < n_sites; i++) + basis[i] = TransStateInfo::forward(si->basis[i], vacuum); + shared_ptr> so = + make_shared>(n_sites, vacuum, targets, basis); + // handle the singlet embedding case + so->left_dims_fci[0] = + TransStateInfo::forward(si->left_dims_fci[0], vacuum); + for (int i = 0; i < n_sites; i++) + so->left_dims_fci[i + 1] = + make_shared>(StateInfo::tensor_product( + *so->left_dims_fci[i], *basis[i], S(S::invalid))); + so->right_dims_fci[n_sites] = + TransStateInfo::forward(si->right_dims_fci[n_sites], vacuum); + for (int i = n_sites - 1; i >= 0; i--) + so->right_dims_fci[i] = + make_shared>(StateInfo::tensor_product( + *basis[i], *so->right_dims_fci[i + 1], S(S::invalid))); + for (int i = 0; i <= n_sites; i++) { + StateInfo::multi_target_filter(*so->left_dims_fci[i], + *so->right_dims_fci[i], targets); + StateInfo::multi_target_filter(*so->right_dims_fci[i], + *so->left_dims_fci[i], targets); + } + for (int i = 0; i <= n_sites; i++) + so->left_dims_fci[i]->collect(); + for (int i = n_sites; i >= 0; i--) + so->right_dims_fci[i]->collect(); + for (int i = 0; i <= n_sites; i++) + so->left_dims[i] = + TransStateInfo::forward(si->left_dims[i], vacuum); + for (int i = n_sites; i >= 0; i--) + so->right_dims[i] = + TransStateInfo::forward(si->right_dims[i], vacuum); + so->check_bond_dimensions(); + so->bond_dim = so->get_max_bond_dimension(); + so->tag = si->tag; + return so; + } +}; + +// Translation between SAny MultiMPSInfo +template +struct TransMultiMPSInfo + : TransMultiMPSInfoAnyBase { + static shared_ptr> + forward(const shared_ptr> &si, const vector &targets) { + return TransMultiMPSInfoAnyBase::transform(si, targets); + } + static shared_ptr> + backward(const shared_ptr> &si, const vector &targets) { + return TransMultiMPSInfoAnyBase::transform(si, targets); + } +}; + } // namespace block2 diff --git a/src/dmrg/sweep_algorithm.hpp b/src/dmrg/sweep_algorithm.hpp index 331adce8..0b6f4223 100644 --- a/src/dmrg/sweep_algorithm.hpp +++ b/src/dmrg/sweep_algorithm.hpp @@ -319,8 +319,8 @@ template struct DMRG { false); xket = context_ket; if (pket != nullptr) { - context_pket = - MovingEnvironment::symm_context_convert_group( + context_pket = MovingEnvironment:: + symm_context_convert_perturbative( i, me->ket, context_ket, 1, !skip_decomp ? forward : fuse_left, false, true, true, true, pket); @@ -861,10 +861,13 @@ template struct DMRG { xold_ket = context_old_ket; xket = context_ket; if (pket != nullptr) { - context_pket = - MovingEnvironment::symm_context_convert_group( - i, me->ket, context_ket, 2, true, false, true, true, - true, pket); + context_pket = MovingEnvironment< + S, FL, FLS>::symm_context_convert_perturbative(i, me->ket, + context_ket, + 2, true, + false, true, + true, true, + pket); xpket = context_pket; } } diff --git a/src/instantiation/block2_dmrg.hpp b/src/instantiation/block2_dmrg.hpp index 47b5cce7..670b0974 100644 --- a/src/instantiation/block2_dmrg.hpp +++ b/src/instantiation/block2_dmrg.hpp @@ -785,6 +785,8 @@ extern template struct block2::AntiHermitianRuleQC; extern template struct block2::MultiMPSInfo; extern template struct block2::MultiMPS; +extern template struct block2::TransMultiMPSInfo; + // sweep_algorithm.hpp extern template struct block2::DMRG; extern template struct block2::Linear; diff --git a/src/instantiation/dmrg_a/state_averaged.cpp b/src/instantiation/dmrg_a/state_averaged.cpp index 86fff1db..b78d7937 100644 --- a/src/instantiation/dmrg_a/state_averaged.cpp +++ b/src/instantiation/dmrg_a/state_averaged.cpp @@ -22,3 +22,5 @@ template struct block2::MultiMPSInfo; template struct block2::MultiMPS; + +template struct block2::TransMultiMPSInfo; \ No newline at end of file diff --git a/src/pybind.cpp b/src/pybind.cpp index 1d8a7459..5e8d8d45 100644 --- a/src/pybind.cpp +++ b/src/pybind.cpp @@ -304,6 +304,7 @@ PYBIND11_MODULE(block2, m) { #ifdef _USE_SANY bind_dmrg(m_sany, "SAny"); bind_trans_mps(m_sany, "sany"); + bind_trans_multi_mps(m_sany, "sany"); bind_fl_trans_mps_spin_specific(m_sany, "sany"); #ifdef _USE_COMPLEX bind_dmrg>(m_sany_cpx, "SAny"); diff --git a/src/pybind/dmrg_a/trans_mps.cpp b/src/pybind/dmrg_a/trans_mps.cpp index 47304f31..c7ff3ca4 100644 --- a/src/pybind/dmrg_a/trans_mps.cpp +++ b/src/pybind/dmrg_a/trans_mps.cpp @@ -21,6 +21,8 @@ #include "../pybind_dmrg.hpp" template void bind_trans_mps(py::module &m, const string &aux_name); +template void bind_trans_multi_mps(py::module &m, + const string &aux_name); template auto bind_fl_trans_mps_spin_specific(py::module &m, const string &aux_name) diff --git a/src/pybind/pybind_dmrg.hpp b/src/pybind/pybind_dmrg.hpp index 42960af3..5603cde8 100644 --- a/src/pybind/pybind_dmrg.hpp +++ b/src/pybind/pybind_dmrg.hpp @@ -484,22 +484,28 @@ template void bind_fl_mps(py::module &m) { .def_readwrite("canonical_form", &UnfusedMPS::canonical_form) .def_static("forward_left_fused", &UnfusedMPS::forward_left_fused, py::arg("i"), - py::arg("mps"), py::arg("wfn")) + py::arg("info"), py::arg("mat"), py::arg("wfn")) .def_static("forward_right_fused", &UnfusedMPS::forward_right_fused, py::arg("i"), - py::arg("mps"), py::arg("wfn")) + py::arg("info"), py::arg("mat"), py::arg("wfn")) .def_static("forward_mps_tensor", &UnfusedMPS::forward_mps_tensor, py::arg("i"), py::arg("mps")) + .def_static("forward_multi_mps_tensor", + &UnfusedMPS::forward_multi_mps_tensor, py::arg("i"), + py::arg("mmps")) .def_static("backward_left_fused", &UnfusedMPS::backward_left_fused, py::arg("i"), - py::arg("mps"), py::arg("spt"), py::arg("wfn")) + py::arg("info"), py::arg("spt"), py::arg("wfn")) .def_static("backward_right_fused", &UnfusedMPS::backward_right_fused, py::arg("i"), - py::arg("mps"), py::arg("spt"), py::arg("wfn")) + py::arg("info"), py::arg("spt"), py::arg("wfn")) .def_static("backward_mps_tensor", &UnfusedMPS::backward_mps_tensor, py::arg("i"), py::arg("mps"), py::arg("spt")) + .def_static("backward_multi_mps_tensor", + &UnfusedMPS::backward_multi_mps_tensor, py::arg("i"), + py::arg("mmps"), py::arg("spt")) .def("initialize", &UnfusedMPS::initialize) .def("finalize", &UnfusedMPS::finalize, py::arg("para_rule") = nullptr) @@ -1039,12 +1045,12 @@ void bind_fl_moving_environment(py::module &m, const string &name) { py::arg("forward"), py::arg("is_wfn"), py::arg("infer_info"), py::arg("ket") = nullptr, py::arg("cket") = nullptr) - .def_static("symm_context_convert_group", - &MovingEnvironment::symm_context_convert_group, - py::arg("i"), py::arg("mps"), py::arg("cmps"), - py::arg("dot"), py::arg("fuse_left"), py::arg("mask"), - py::arg("forward"), py::arg("is_wfn"), - py::arg("infer_info"), py::arg("pket")); + .def_static( + "symm_context_convert_perturbative", + &MovingEnvironment::symm_context_convert_perturbative, + py::arg("i"), py::arg("mps"), py::arg("cmps"), py::arg("dot"), + py::arg("fuse_left"), py::arg("mask"), py::arg("forward"), + py::arg("is_wfn"), py::arg("infer_info"), py::arg("pket")); py::bind_vector>>>( m, ("Vector" + name).c_str()); @@ -2225,6 +2231,13 @@ void bind_trans_mps(py::module &m, const string &aux_name) { &TransMPSInfo::forward); } +template +void bind_trans_multi_mps(py::module &m, const string &aux_name) { + + m.def(("trans_multi_mps_info_to_" + aux_name).c_str(), + &TransMultiMPSInfo::forward); +} + template void bind_fl_trans_mps(py::module &m, const string &aux_name) { @@ -2898,6 +2911,8 @@ extern template auto bind_fl_spin_specific(py::module &m) extern template void bind_trans_mps(py::module &m, const string &aux_name); +extern template void bind_trans_multi_mps(py::module &m, + const string &aux_name); extern template auto bind_fl_trans_mps_spin_specific(py::module &m, const string &aux_name)